lampo_jsonrpc/
lib.rs

//! Full-featured JSON-RPC 2.0 server/client built on non-blocking I/O,
//! with a minimal dependency footprint.
3use std::cell::{Cell, RefCell};
4use std::collections::{HashMap, VecDeque};
5use std::io::{self, ErrorKind};
6use std::io::{Read, Write};
7use std::os::unix::net::UnixListener;
8use std::os::unix::net::{SocketAddr, UnixStream};
9use std::sync::{Arc, Mutex};
10use std::thread::JoinHandle;
11
12// FIXME: use mio for a better platform support.
13use popol::{Sources, Timeout};
14use serde_json::Value;
15
16pub mod command;
17pub mod errors;
18pub mod json_rpc2;
19
20use command::Context;
21
22use crate::errors::Error;
23use crate::json_rpc2::{Request, Response};
24
/// Key attached to every source registered with the `popol` poller.
#[derive(Clone, Debug, PartialEq)]
pub enum RPCEvent {
    /// The listening socket; read readiness means a client is waiting to be accepted.
    Listening,
    /// An accepted client connection, carrying its key in the server's connection table.
    Connect(String),
}
30
31pub struct JSONRPCv2<T: Send + Sync + 'static> {
32    socket_path: String,
33    sources: Sources<RPCEvent>,
34    socket: UnixListener,
35    handler: Arc<Handler<T>>,
36    // FIXME: should be not the name but the fd int as key?
37    pub(crate) conn: HashMap<String, UnixStream>,
38    conn_queue: Mutex<Cell<HashMap<String, VecDeque<Response<Value>>>>>,
39}
40
41pub struct Handler<T: Send + Sync + 'static> {
42    stop: Cell<bool>,
43    rpc_method:
44        RefCell<HashMap<String, Arc<dyn Fn(&T, &Value) -> Result<Value, errors::Error> + 'static>>>,
45    ctx: Arc<dyn Context<Ctx = T>>,
46}
47
48unsafe impl<T: Send + Sync> Sync for Handler<T> {}
49unsafe impl<T: Send + Sync> Send for Handler<T> {}
50
51impl<T: Send + Sync + 'static> Handler<T> {
52    pub fn new(ctx: Arc<dyn Context<Ctx = T>>) -> Self {
53        Handler::<T> {
54            stop: Cell::new(false),
55            rpc_method: RefCell::new(HashMap::new()),
56            ctx,
57        }
58    }
59
60    pub fn add_method<F>(&self, method: &str, callback: F)
61    where
62        F: Fn(&T, &Value) -> Result<Value, errors::Error> + 'static,
63    {
64        self.rpc_method
65            .borrow_mut()
66            .insert(method.to_owned(), Arc::new(callback));
67    }
68
69    pub fn run_callback(&self, req: &Request<Value>) -> Option<Result<Value, errors::Error>> {
70        let binding = self.rpc_method.borrow();
71        let Some(callback) = binding.get(&req.method) else {
72            return Some(Err(errors::RpcError {
73                message: format!("method `{}` not found", req.method),
74                code: -1,
75                data: None,
76            }
77            .into()));
78        };
79        let resp = callback(self.ctx(), &req.params);
80        Some(resp)
81    }
82
83    pub fn has_rpc(&self, method: &str) -> bool {
84        self.rpc_method.borrow().contains_key(method)
85    }
86
87    fn ctx(&self) -> &T {
88        self.ctx.ctx()
89    }
90
91    pub fn stop(&self) {
92        self.stop.set(true);
93    }
94}
95
96impl<T: Send + Sync + 'static> JSONRPCv2<T> {
97    pub fn new(ctx: Arc<dyn Context<Ctx = T>>, path: &str) -> Result<Self, Error> {
98        let listnet = UnixListener::bind(path)?;
99        let sources = Sources::<RPCEvent>::new();
100        Ok(Self {
101            sources,
102            socket: listnet,
103            handler: Arc::new(Handler::new(ctx)),
104            socket_path: path.to_owned(),
105            conn: HashMap::new(),
106            conn_queue: Mutex::new(Cell::new(HashMap::new())),
107        })
108    }
109
110    pub fn add_rpc<F>(&self, name: &str, callback: F) -> Result<(), ()>
111    where
112        F: Fn(&T, &Value) -> Result<Value, errors::Error> + 'static,
113    {
114        if self.handler.has_rpc(name) {
115            return Err(());
116        }
117        self.handler.add_method(name, callback);
118        Ok(())
119    }
120
121    pub fn add_connection(&mut self, key: &SocketAddr, stream: UnixStream) {
122        let path = if let Some(path) = key.as_pathname() {
123            path.to_str().unwrap()
124        } else {
125            "unnamed"
126        };
127        let res = stream.set_nonblocking(true);
128        debug_assert!(res.is_ok());
129        let event = RPCEvent::Connect(path.to_string());
130        self.sources.register(event, &stream, popol::interest::ALL);
131        self.conn.insert(path.to_owned(), stream);
132    }
133
134    pub fn send_resp(&self, key: String, resp: Response<Value>) {
135        let queue = self.conn_queue.lock().unwrap();
136
137        let mut conns = queue.take();
138        log::debug!(target: "jsonrpc", "{:?}", conns);
139        if conns.contains_key(&key) {
140            let Some(queue) = conns.get_mut(&key) else {
141                panic!("queue not found");
142            };
143            queue.push_back(resp);
144        } else {
145            let mut q = VecDeque::new();
146            q.push_back(resp);
147            conns.insert(key, q);
148        }
149        log::debug!(target: "jsonrpc", "{:?}", conns);
150        queue.set(conns);
151    }
152
153    pub fn pop_resp(&self, key: String) -> Option<Response<Value>> {
154        let queue = self.conn_queue.lock().unwrap();
155
156        let mut conns = queue.take();
157        if !conns.contains_key(&key) {
158            return None;
159        }
160        let Some(q) = conns.get_mut(&key) else {
161            return None;
162        };
163        let resp = q.pop_front();
164        queue.set(conns);
165        resp
166    }
167
168    #[allow(dead_code)]
169    fn ctx(&self) -> &T {
170        self.handler.ctx()
171    }
172
173    pub fn listen(mut self) -> io::Result<()> {
174        self.socket.set_nonblocking(true)?;
175        self.sources
176            .register(RPCEvent::Listening, &self.socket, popol::interest::READ);
177
178        log::info!(target: "jsonrpc", "starting server on {}", self.socket_path);
179        let mut events = vec![];
180        while !self.handler.stop.get() {
181            // Blocking while we are waiting new events!
182            self.sources.poll(&mut events, Timeout::Never)?;
183
184            for mut event in events.drain(..) {
185                match &event.key {
186                    RPCEvent::Listening => {
187                        let conn = self.socket.accept();
188                        let Ok((stream, addr)) = conn else {
189                            if let Err(err) = &conn {
190                                if err.kind() == ErrorKind::WouldBlock {
191                                    break;
192                                }
193                            }
194                            log::error!(target: "jsonrpc", "fail to accept the connection: {:?}", conn);
195                            continue;
196                        };
197                        log::trace!(target: "jsonrpc", "new connection to unix rpc socket");
198                        self.add_connection(&addr, stream);
199                    }
200                    RPCEvent::Connect(addr) => {
201                        if event.is_hangup() {
202                            break;
203                        }
204                        if event.is_error() {
205                            log::error!(target: "jsonrpc", "an error occurs");
206                            continue;
207                        }
208
209                        if event.is_invalid() {
210                            log::info!(target: "jsonrpc", "event invalid, unregister event from the tracking one");
211                            self.sources.unregister(&event.key);
212                            break;
213                        }
214
215                        if event.is_readable() {
216                            let Some(mut stream) = self.conn.get(addr) else {
217                                log::error!(target: "jsonrpc", "connection not found `{addr}`");
218                                continue;
219                            };
220                            let mut buff = String::new();
221                            if let Err(err) = stream.read_to_string(&mut buff) {
222                                if err.kind() != ErrorKind::WouldBlock {
223                                    return Err(err);
224                                }
225                                log::info!(target: "jsonrpc", "blocking with err {:?}!", err);
226                            }
227                            if buff.is_empty() {
228                                log::warn!(target: "jsonrpc", "buffer is empty");
229                                break;
230                            }
231                            let buff = buff.trim();
232                            log::info!(target: "jsonrpc", "buffer read {buff}");
233                            let requ: Request<Value> =
234                                serde_json::from_str(&buff).map_err(|err| {
235                                    io::Error::new(io::ErrorKind::Other, format!("{err}"))
236                                })?;
237                            log::trace!(target: "jsonrpc", "request {:?}", requ);
238                            let Some(resp) = self.handler.run_callback(&requ) else {
239                                log::error!(target: "jsonrpc", "`{}` not found!", requ.method);
240                                break;
241                            };
242                            // FIXME; the id in the JSON RPC can be null!
243                            let response = match resp {
244                                Ok(result) => Response {
245                                    id: requ.id.clone().unwrap(),
246                                    jsonrpc: requ.jsonrpc.to_owned(),
247                                    result: Some(result),
248                                    error: None,
249                                },
250                                Err(err) => Response {
251                                    result: None,
252                                    error: Some(err.into()),
253                                    id: requ.id.unwrap().clone(),
254                                    jsonrpc: requ.jsonrpc.clone(),
255                                },
256                            };
257                            log::trace!(target: "jsonrpc", "send response: `{:?}`", response);
258                            self.send_resp(addr.to_string(), response);
259                        }
260
261                        if event.is_writable() {
262                            let stream = self.conn.get(addr);
263                            if stream.is_none() {
264                                log::error!(target: "jsonrpc", "connection not found `{addr}`");
265                                continue;
266                            };
267
268                            let mut stream = stream.unwrap();
269                            let Some(resp) = self.pop_resp(addr.to_string()) else {
270                                break;
271                            };
272                            let buff = serde_json::to_string(&resp).unwrap();
273                            if let Err(err) = stream.write_all(buff.as_bytes()) {
274                                if err.kind() != ErrorKind::WouldBlock {
275                                    return Err(err);
276                                }
277                            }
278                            match stream.flush() {
279                                // In this case, we've written all the data, we
280                                // are no longer interested in writing to this
281                                // socket.
282                                Ok(()) => {
283                                    event.source.unset(popol::interest::WRITE);
284                                }
285                                // In this case, the write couldn't complete. Set
286                                // our interest to `WRITE` to be notified when the
287                                // socket is ready to write again.
288                                Err(err)
289                                    if [io::ErrorKind::WouldBlock, io::ErrorKind::WriteZero]
290                                        .contains(&err.kind()) =>
291                                {
292                                    event.source.set(popol::interest::WRITE);
293                                }
294                                Err(err) => {
295                                    log::error!(target: "jsonrpc", "{}: Write error: {}", addr, err.to_string());
296                                }
297                            }
298                            stream.shutdown(std::net::Shutdown::Both)?;
299                        }
300                    }
301                }
302            }
303        }
304        Ok(())
305    }
306
307    pub fn handler(&self) -> Arc<Handler<T>> {
308        self.handler.clone()
309    }
310
311    pub fn spawn(self) -> JoinHandle<io::Result<()>> {
312        std::thread::spawn(move || self.listen())
313    }
314}
315
316impl<T: Send + Sync + 'static> Drop for JSONRPCv2<T> {
317    fn drop(&mut self) {
318        let _ = std::fs::remove_file(&self.socket_path).unwrap();
319    }
320}
321
#[cfg(test)]
mod tests {
    use std::{
        io::Write, os::unix::net::UnixStream, path::Path, str::FromStr, sync::Arc, time::Duration,
    };

    use lampo_common::logger;
    use ntest::timeout;
    use serde_json::Value;

    use crate::{
        command::Context,
        json_rpc2::{Id, Request, Response},
        JSONRPCv2,
    };

    /// Socket path bound by the server and dialled by the test clients.
    const SOCKET_PATH: &str = "/tmp/tmp.sock";

    /// Context stub; the callbacks registered in the test ignore it.
    struct DummyCtx;

    impl Context for DummyCtx {
        type Ctx = DummyCtx;

        fn ctx(&self) -> &Self::Ctx {
            self
        }
    }

    /// Build a JSON-RPC 2.0 request for `method` with an empty params array.
    fn build_request(id: Id, method: &str) -> Request<Value> {
        Request::<Value> {
            id: Some(id),
            jsonrpc: String::from_str("2.0").unwrap(),
            method: method.to_owned(),
            params: Value::Array(Vec::new()),
        }
    }

    /// Send `request` over a fresh connection and block until the reply is decoded.
    fn call_server(request: &Request<Value>) -> Response<Value> {
        let payload = serde_json::to_string(request).unwrap();
        let mut stream = UnixStream::connect(Path::new(SOCKET_PATH))
            .unwrap_or_else(|_| panic!("server is not running"));
        log::info!(target: "client", "sending {payload}");
        stream.write_all(payload.as_bytes()).unwrap();
        stream.flush().unwrap();
        log::info!(target: "client", "waiting for server response");
        log::info!(target: "client", "read answer from server");
        let resp: Response<Value> = serde_json::from_reader(stream).unwrap();
        log::info!(target: "client", "msg received: {:?}", resp);
        resp
    }

    #[test]
    #[timeout(9000)]
    fn register_rpc() {
        logger::init(log::Level::Debug).unwrap();
        let _ = std::fs::remove_file(SOCKET_PATH);
        let server = JSONRPCv2::new(Arc::new(DummyCtx), SOCKET_PATH).unwrap();
        let _ = server.add_rpc("foo", |_: &DummyCtx, request| {
            Ok(serde_json::json!(request))
        });
        let res = server.add_rpc("secon", |_: &DummyCtx, request| {
            Ok(serde_json::json!(request))
        });
        assert!(res.is_ok(), "{:?}", res);

        let handler = server.handler();
        let worker = server.spawn();

        let first_request = build_request(0.into(), "foo");
        let client_worker = std::thread::spawn(move || {
            let resp = call_server(&first_request);
            assert_eq!(resp.id, first_request.id.unwrap());
            resp
        });

        // The second client starts later so the two requests do not overlap.
        let client_worker2 = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_secs(3));
            call_server(&build_request(1.into(), "secon"))
        });

        let first = client_worker.join().unwrap();
        assert_eq!(Id::Str("0".to_owned()), first.id);
        let second = client_worker2.join().unwrap();
        assert_eq!(Id::Str("1".to_owned()), second.id);
        handler.stop();

        let _ = worker.join();
    }
}