rpcx_server/
lib.rs

1use std::{
2    boxed::Box,
3    collections::HashMap,
4    sync::{Arc, RwLock},
5};
6
7use std::net::SocketAddr;
8
9use rpcx_protocol::*;
10use std::{
11    io::{BufReader, BufWriter, Write},
12    net::{Shutdown, TcpListener, TcpStream},
13};
14
15use std::{
16    os::unix::io::{AsRawFd, RawFd},
17    thread,
18};
19
20use scoped_threadpool::Pool;
21
22pub mod plugin;
23pub use plugin::*;
24
25pub type RpcxFn = fn(&[u8], SerializeType) -> Result<Vec<u8>>;
26pub struct Server {
27    pub addr: String,
28    raw_fd: Option<RawFd>,
29    pub services: Arc<RwLock<HashMap<String, Box<RpcxFn>>>>,
30    thread_number: u32,
31    register_plugins: Arc<RwLock<Vec<Box<dyn RegisterPlugin + Send + Sync>>>>,
32    connect_plugins: Arc<RwLock<Vec<Box<dyn ConnectPlugin + Send + Sync>>>>,
33}
34
35impl Server {
36    pub fn new(s: String, n: u32) -> Self {
37        let mut thread_number = n;
38        if n == 0 {
39            thread_number = num_cpus::get() as u32;
40            thread_number *= 2;
41        }
42        Server {
43            addr: s,
44            services: Arc::new(RwLock::new(HashMap::new())),
45            thread_number,
46            register_plugins: Arc::new(RwLock::new(Vec::new())),
47            connect_plugins: Arc::new(RwLock::new(Vec::new())),
48            raw_fd: None,
49        }
50    }
51
52    pub fn register_fn(
53        &mut self,
54        service_path: String,
55        service_method: String,
56        meta: String,
57        f: RpcxFn,
58    ) {
59        // invoke register plugins
60        let mut plugins = self.register_plugins.write().unwrap();
61        for p in plugins.iter_mut() {
62            let pp = &mut **p;
63            match pp.register_fn(
64                service_path.as_str(),
65                service_method.as_str(),
66                meta.clone(),
67                f,
68            ) {
69                Ok(_) => {}
70                Err(err) => eprintln!("{}", err),
71            }
72        }
73
74        // invoke service
75        let key = format!("{}.{}", service_path, service_method);
76        let services = self.services.clone();
77        let mut map = services.write().unwrap();
78        map.insert(key, Box::new(f));
79    }
80
81    pub fn get_fn(&self, service_path: String, service_method: String) -> Option<RpcxFn> {
82        let key = format!("{}.{}", service_path, service_method);
83        let map = self.services.read().unwrap();
84        let box_fn = map.get(&key)?;
85        Some(**box_fn)
86    }
87
88    pub fn start_with_listener(&self, listener: TcpListener) -> Result<()> {
89        let thread_number = self.thread_number;
90
91        'accept_loop: for stream in listener.incoming() {
92            match stream {
93                Ok(stream) => {
94                    let services_cloned = self.services.clone();
95                    thread::spawn(move || {
96                        Server::process(thread_number, services_cloned, stream);
97                    });
98                }
99                Err(e) => {
100                    //println!("Unable to accept: {}", e);
101                    return Err(Error::new(ErrorKind::Network, e));
102                }
103            }
104        }
105
106        Ok(())
107    }
108    pub fn start(&mut self) -> Result<()> {
109        let addr = self
110            .addr
111            .parse::<SocketAddr>()
112            .map_err(|err| Error::new(ErrorKind::Other, err))?;
113
114        let listener = TcpListener::bind(&addr)?;
115        println!("Listening on: {}", addr);
116
117        self.raw_fd = Some(listener.as_raw_fd());
118
119        self.start_with_listener(listener)
120    }
121
122    pub fn close(&self) {
123        if let Some(raw_fd) = self.raw_fd {
124            unsafe {
125                libc::close(raw_fd);
126            }
127        }
128    }
129    fn process(
130        thread_number: u32,
131        service: Arc<RwLock<HashMap<String, Box<RpcxFn>>>>,
132        stream: TcpStream,
133    ) {
134        let services_cloned = service;
135        let local_stream = stream.try_clone().unwrap();
136
137        let mut pool = Pool::new(thread_number);
138        pool.scoped(|scoped| {
139            let mut reader = BufReader::new(stream.try_clone().unwrap());
140            loop {
141                let mut msg = Message::new();
142                match msg.decode(&mut reader) {
143                    Ok(()) => {
144                        let service_path = &msg.service_path;
145                        let service_method = &msg.service_method;
146                        let key = format!("{}.{}", service_path, service_method);
147                        let map = &services_cloned.read().unwrap();
148                        match map.get(&key) {
149                            Some(box_fn) => {
150                                let f = **box_fn;
151                                let local_stream_in_child = local_stream.try_clone().unwrap();
152
153                                scoped.execute(move || {
154                                    invoke_fn(local_stream_in_child.try_clone().unwrap(), msg, f)
155                                });
156                            }
157                            None => {
158                                let err = format!("service {} not found", key);
159                                let reply_msg = msg.get_reply().unwrap();
160                                let mut metadata = reply_msg.metadata.borrow_mut();
161                                (*metadata).insert(SERVICE_ERROR.to_string(), err);
162                                drop(metadata);
163                                let data = reply_msg.encode();
164                                let mut writer = BufWriter::new(local_stream.try_clone().unwrap());
165                                writer.write_all(&data).unwrap();
166                                writer.flush().unwrap();
167                            }
168                        }
169                    }
170                    Err(err) => {
171                        eprintln!("failed to read: {}", err.to_string());
172                        match local_stream.shutdown(Shutdown::Both) {
173                            Ok(()) => {
174                                if let Ok(sa) = local_stream.peer_addr() {
175                                    println!("client {} is closed", sa)
176                                }
177                            }
178                            Err(e) => {
179                                if let Ok(sa) = local_stream.peer_addr() {
180                                    println!("client {} is closed. err: {}", sa, e)
181                                }
182                            }
183                        }
184                        return;
185                    }
186                }
187            }
188        });
189    }
190}
191
192fn invoke_fn(stream: TcpStream, msg: Message, f: RpcxFn) {
193    let mut reply_msg = msg.get_reply().unwrap();
194    let reply = f(&msg.payload, msg.get_serialize_type().unwrap()).unwrap();
195    reply_msg.payload = reply;
196    let data = reply_msg.encode();
197
198    let mut writer = BufWriter::new(stream.try_clone().unwrap());
199    match writer.write_all(&data) {
200        Ok(()) => {}
201        Err(_err) => {}
202    }
203    match writer.flush() {
204        Ok(()) => {}
205        Err(_err) => {}
206    }
207}
208
/// Registers a typed service function on a server.
///
/// Arguments, in order: the server expression, the service path, the service
/// method, the function to call, the metadata value passed on to
/// `register_fn`, the argument type and the reply type.
///
/// The macro wraps `$service_fn` in a handler that builds a default
/// `$arg_type`, fills it from the request bytes with `from_slice`, calls the
/// function, and encodes the `$reply_type` result with `into_bytes`. A
/// decoding error is returned from the handler without calling the function.
///
/// The handler type is named through `$crate`, so the macro expands correctly
/// even where the caller has not imported `RpcxFn`.
#[macro_export]
macro_rules! register_func {
    ($rpc_server:expr, $service_path:expr, $service_method:expr, $service_fn:expr, $meta:expr, $arg_type:ty, $reply_type:ty) => {{
        let f: $crate::RpcxFn = |x, st| {
            // Decode the request bytes into a default-initialised argument.
            let mut args: $arg_type = Default::default();
            args.from_slice(st, x)?;
            let reply: $reply_type = $service_fn(args);
            reply.into_bytes(st)
        };
        $rpc_server.register_fn(
            $service_path.to_string(),
            $service_method.to_string(),
            $meta,
            f,
        );
    }};
}