1use std::{
2 boxed::Box,
3 collections::HashMap,
4 sync::{Arc, RwLock},
5};
6
7use std::net::SocketAddr;
8
9use rpcx_protocol::*;
10use std::{
11 io::{BufReader, BufWriter, Write},
12 net::{Shutdown, TcpListener, TcpStream},
13};
14
15use std::{
16 os::unix::io::{AsRawFd, RawFd},
17 thread,
18};
19
20use scoped_threadpool::Pool;
21
22pub mod plugin;
23pub use plugin::*;
24
25pub type RpcxFn = fn(&[u8], SerializeType) -> Result<Vec<u8>>;
26pub struct Server {
27 pub addr: String,
28 raw_fd: Option<RawFd>,
29 pub services: Arc<RwLock<HashMap<String, Box<RpcxFn>>>>,
30 thread_number: u32,
31 register_plugins: Arc<RwLock<Vec<Box<dyn RegisterPlugin + Send + Sync>>>>,
32 connect_plugins: Arc<RwLock<Vec<Box<dyn ConnectPlugin + Send + Sync>>>>,
33}
34
35impl Server {
36 pub fn new(s: String, n: u32) -> Self {
37 let mut thread_number = n;
38 if n == 0 {
39 thread_number = num_cpus::get() as u32;
40 thread_number *= 2;
41 }
42 Server {
43 addr: s,
44 services: Arc::new(RwLock::new(HashMap::new())),
45 thread_number,
46 register_plugins: Arc::new(RwLock::new(Vec::new())),
47 connect_plugins: Arc::new(RwLock::new(Vec::new())),
48 raw_fd: None,
49 }
50 }
51
52 pub fn register_fn(
53 &mut self,
54 service_path: String,
55 service_method: String,
56 meta: String,
57 f: RpcxFn,
58 ) {
59 let mut plugins = self.register_plugins.write().unwrap();
61 for p in plugins.iter_mut() {
62 let pp = &mut **p;
63 match pp.register_fn(
64 service_path.as_str(),
65 service_method.as_str(),
66 meta.clone(),
67 f,
68 ) {
69 Ok(_) => {}
70 Err(err) => eprintln!("{}", err),
71 }
72 }
73
74 let key = format!("{}.{}", service_path, service_method);
76 let services = self.services.clone();
77 let mut map = services.write().unwrap();
78 map.insert(key, Box::new(f));
79 }
80
81 pub fn get_fn(&self, service_path: String, service_method: String) -> Option<RpcxFn> {
82 let key = format!("{}.{}", service_path, service_method);
83 let map = self.services.read().unwrap();
84 let box_fn = map.get(&key)?;
85 Some(**box_fn)
86 }
87
88 pub fn start_with_listener(&self, listener: TcpListener) -> Result<()> {
89 let thread_number = self.thread_number;
90
91 'accept_loop: for stream in listener.incoming() {
92 match stream {
93 Ok(stream) => {
94 let services_cloned = self.services.clone();
95 thread::spawn(move || {
96 Server::process(thread_number, services_cloned, stream);
97 });
98 }
99 Err(e) => {
100 return Err(Error::new(ErrorKind::Network, e));
102 }
103 }
104 }
105
106 Ok(())
107 }
108 pub fn start(&mut self) -> Result<()> {
109 let addr = self
110 .addr
111 .parse::<SocketAddr>()
112 .map_err(|err| Error::new(ErrorKind::Other, err))?;
113
114 let listener = TcpListener::bind(&addr)?;
115 println!("Listening on: {}", addr);
116
117 self.raw_fd = Some(listener.as_raw_fd());
118
119 self.start_with_listener(listener)
120 }
121
122 pub fn close(&self) {
123 if let Some(raw_fd) = self.raw_fd {
124 unsafe {
125 libc::close(raw_fd);
126 }
127 }
128 }
129 fn process(
130 thread_number: u32,
131 service: Arc<RwLock<HashMap<String, Box<RpcxFn>>>>,
132 stream: TcpStream,
133 ) {
134 let services_cloned = service;
135 let local_stream = stream.try_clone().unwrap();
136
137 let mut pool = Pool::new(thread_number);
138 pool.scoped(|scoped| {
139 let mut reader = BufReader::new(stream.try_clone().unwrap());
140 loop {
141 let mut msg = Message::new();
142 match msg.decode(&mut reader) {
143 Ok(()) => {
144 let service_path = &msg.service_path;
145 let service_method = &msg.service_method;
146 let key = format!("{}.{}", service_path, service_method);
147 let map = &services_cloned.read().unwrap();
148 match map.get(&key) {
149 Some(box_fn) => {
150 let f = **box_fn;
151 let local_stream_in_child = local_stream.try_clone().unwrap();
152
153 scoped.execute(move || {
154 invoke_fn(local_stream_in_child.try_clone().unwrap(), msg, f)
155 });
156 }
157 None => {
158 let err = format!("service {} not found", key);
159 let reply_msg = msg.get_reply().unwrap();
160 let mut metadata = reply_msg.metadata.borrow_mut();
161 (*metadata).insert(SERVICE_ERROR.to_string(), err);
162 drop(metadata);
163 let data = reply_msg.encode();
164 let mut writer = BufWriter::new(local_stream.try_clone().unwrap());
165 writer.write_all(&data).unwrap();
166 writer.flush().unwrap();
167 }
168 }
169 }
170 Err(err) => {
171 eprintln!("failed to read: {}", err.to_string());
172 match local_stream.shutdown(Shutdown::Both) {
173 Ok(()) => {
174 if let Ok(sa) = local_stream.peer_addr() {
175 println!("client {} is closed", sa)
176 }
177 }
178 Err(e) => {
179 if let Ok(sa) = local_stream.peer_addr() {
180 println!("client {} is closed. err: {}", sa, e)
181 }
182 }
183 }
184 return;
185 }
186 }
187 }
188 });
189 }
190}
191
192fn invoke_fn(stream: TcpStream, msg: Message, f: RpcxFn) {
193 let mut reply_msg = msg.get_reply().unwrap();
194 let reply = f(&msg.payload, msg.get_serialize_type().unwrap()).unwrap();
195 reply_msg.payload = reply;
196 let data = reply_msg.encode();
197
198 let mut writer = BufWriter::new(stream.try_clone().unwrap());
199 match writer.write_all(&data) {
200 Ok(()) => {}
201 Err(_err) => {}
202 }
203 match writer.flush() {
204 Ok(()) => {}
205 Err(_err) => {}
206 }
207}
208
/// Registers a typed service function on an rpcx server.
///
/// Wraps `$service_fn` (taking `$arg_type`, returning `$reply_type`) in an `RpcxFn` that:
/// - decodes the request payload into a `Default`-initialised `$arg_type` via `from_slice`;
/// - calls the function;
/// - encodes the reply via `into_bytes`, using the request's serialize type.
///
/// The wrapper is then passed to `$rpc_server.register_fn` with `$service_path` and
/// `$service_method` converted to `String` and `$meta` passed through.
///
/// `$service_fn` must not capture its environment, because the wrapper is coerced to a
/// plain `fn` pointer.
#[macro_export]
macro_rules! register_func {
    ($rpc_server:expr, $service_path:expr, $service_method:expr, $service_fn:expr, $meta:expr, $arg_type:ty, $reply_type:ty) => {{
        // `$crate::` keeps the expansion working even where `RpcxFn` is not imported.
        let f: $crate::RpcxFn = |x, st| {
            let mut args: $arg_type = Default::default();
            args.from_slice(st, x)?;
            let reply: $reply_type = $service_fn(args);
            reply.into_bytes(st)
        };
        $rpc_server.register_fn(
            $service_path.to_string(),
            $service_method.to_string(),
            $meta,
            f,
        );
    }};
}