1use std::cell::{Cell, RefCell};
4use std::collections::{HashMap, VecDeque};
5use std::io::{self, ErrorKind};
6use std::io::{Read, Write};
7use std::os::unix::net::UnixListener;
8use std::os::unix::net::{SocketAddr, UnixStream};
9use std::sync::{Arc, Mutex};
10use std::thread::JoinHandle;
11
12use popol::{Sources, Timeout};
14use serde_json::Value;
15
16pub mod command;
17pub mod errors;
18pub mod json_rpc2;
19
20use command::Context;
21
22use crate::errors::Error;
23use crate::json_rpc2::{Request, Response};
24
/// Key identifying a source registered with the poller.
///
/// The type now also derives `Eq` and `Hash` (both are structurally valid:
/// the only payload is a `String`), so it can be used as a map/set key in
/// addition to being compared.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RPCEvent {
    /// The listening socket; readiness means a client is waiting to be accepted.
    Listening,
    /// An accepted connection, identified by the name it was registered under.
    Connect(String),
}
30
31pub struct JSONRPCv2<T: Send + Sync + 'static> {
32 socket_path: String,
33 sources: Sources<RPCEvent>,
34 socket: UnixListener,
35 handler: Arc<Handler<T>>,
36 pub(crate) conn: HashMap<String, UnixStream>,
38 conn_queue: Mutex<Cell<HashMap<String, VecDeque<Response<Value>>>>>,
39}
40
41pub struct Handler<T: Send + Sync + 'static> {
42 stop: Cell<bool>,
43 rpc_method:
44 RefCell<HashMap<String, Arc<dyn Fn(&T, &Value) -> Result<Value, errors::Error> + 'static>>>,
45 ctx: Arc<dyn Context<Ctx = T>>,
46}
47
48unsafe impl<T: Send + Sync> Sync for Handler<T> {}
49unsafe impl<T: Send + Sync> Send for Handler<T> {}
50
51impl<T: Send + Sync + 'static> Handler<T> {
52 pub fn new(ctx: Arc<dyn Context<Ctx = T>>) -> Self {
53 Handler::<T> {
54 stop: Cell::new(false),
55 rpc_method: RefCell::new(HashMap::new()),
56 ctx,
57 }
58 }
59
60 pub fn add_method<F>(&self, method: &str, callback: F)
61 where
62 F: Fn(&T, &Value) -> Result<Value, errors::Error> + 'static,
63 {
64 self.rpc_method
65 .borrow_mut()
66 .insert(method.to_owned(), Arc::new(callback));
67 }
68
69 pub fn run_callback(&self, req: &Request<Value>) -> Option<Result<Value, errors::Error>> {
70 let binding = self.rpc_method.borrow();
71 let Some(callback) = binding.get(&req.method) else {
72 return Some(Err(errors::RpcError {
73 message: format!("method `{}` not found", req.method),
74 code: -1,
75 data: None,
76 }
77 .into()));
78 };
79 let resp = callback(self.ctx(), &req.params);
80 Some(resp)
81 }
82
83 pub fn has_rpc(&self, method: &str) -> bool {
84 self.rpc_method.borrow().contains_key(method)
85 }
86
87 fn ctx(&self) -> &T {
88 self.ctx.ctx()
89 }
90
91 pub fn stop(&self) {
92 self.stop.set(true);
93 }
94}
95
96impl<T: Send + Sync + 'static> JSONRPCv2<T> {
97 pub fn new(ctx: Arc<dyn Context<Ctx = T>>, path: &str) -> Result<Self, Error> {
98 let listnet = UnixListener::bind(path)?;
99 let sources = Sources::<RPCEvent>::new();
100 Ok(Self {
101 sources,
102 socket: listnet,
103 handler: Arc::new(Handler::new(ctx)),
104 socket_path: path.to_owned(),
105 conn: HashMap::new(),
106 conn_queue: Mutex::new(Cell::new(HashMap::new())),
107 })
108 }
109
110 pub fn add_rpc<F>(&self, name: &str, callback: F) -> Result<(), ()>
111 where
112 F: Fn(&T, &Value) -> Result<Value, errors::Error> + 'static,
113 {
114 if self.handler.has_rpc(name) {
115 return Err(());
116 }
117 self.handler.add_method(name, callback);
118 Ok(())
119 }
120
121 pub fn add_connection(&mut self, key: &SocketAddr, stream: UnixStream) {
122 let path = if let Some(path) = key.as_pathname() {
123 path.to_str().unwrap()
124 } else {
125 "unnamed"
126 };
127 let res = stream.set_nonblocking(true);
128 debug_assert!(res.is_ok());
129 let event = RPCEvent::Connect(path.to_string());
130 self.sources.register(event, &stream, popol::interest::ALL);
131 self.conn.insert(path.to_owned(), stream);
132 }
133
134 pub fn send_resp(&self, key: String, resp: Response<Value>) {
135 let queue = self.conn_queue.lock().unwrap();
136
137 let mut conns = queue.take();
138 log::debug!(target: "jsonrpc", "{:?}", conns);
139 if conns.contains_key(&key) {
140 let Some(queue) = conns.get_mut(&key) else {
141 panic!("queue not found");
142 };
143 queue.push_back(resp);
144 } else {
145 let mut q = VecDeque::new();
146 q.push_back(resp);
147 conns.insert(key, q);
148 }
149 log::debug!(target: "jsonrpc", "{:?}", conns);
150 queue.set(conns);
151 }
152
153 pub fn pop_resp(&self, key: String) -> Option<Response<Value>> {
154 let queue = self.conn_queue.lock().unwrap();
155
156 let mut conns = queue.take();
157 if !conns.contains_key(&key) {
158 return None;
159 }
160 let Some(q) = conns.get_mut(&key) else {
161 return None;
162 };
163 let resp = q.pop_front();
164 queue.set(conns);
165 resp
166 }
167
168 #[allow(dead_code)]
169 fn ctx(&self) -> &T {
170 self.handler.ctx()
171 }
172
173 pub fn listen(mut self) -> io::Result<()> {
174 self.socket.set_nonblocking(true)?;
175 self.sources
176 .register(RPCEvent::Listening, &self.socket, popol::interest::READ);
177
178 log::info!(target: "jsonrpc", "starting server on {}", self.socket_path);
179 let mut events = vec![];
180 while !self.handler.stop.get() {
181 self.sources.poll(&mut events, Timeout::Never)?;
183
184 for mut event in events.drain(..) {
185 match &event.key {
186 RPCEvent::Listening => {
187 let conn = self.socket.accept();
188 let Ok((stream, addr)) = conn else {
189 if let Err(err) = &conn {
190 if err.kind() == ErrorKind::WouldBlock {
191 break;
192 }
193 }
194 log::error!(target: "jsonrpc", "fail to accept the connection: {:?}", conn);
195 continue;
196 };
197 log::trace!(target: "jsonrpc", "new connection to unix rpc socket");
198 self.add_connection(&addr, stream);
199 }
200 RPCEvent::Connect(addr) => {
201 if event.is_hangup() {
202 break;
203 }
204 if event.is_error() {
205 log::error!(target: "jsonrpc", "an error occurs");
206 continue;
207 }
208
209 if event.is_invalid() {
210 log::info!(target: "jsonrpc", "event invalid, unregister event from the tracking one");
211 self.sources.unregister(&event.key);
212 break;
213 }
214
215 if event.is_readable() {
216 let Some(mut stream) = self.conn.get(addr) else {
217 log::error!(target: "jsonrpc", "connection not found `{addr}`");
218 continue;
219 };
220 let mut buff = String::new();
221 if let Err(err) = stream.read_to_string(&mut buff) {
222 if err.kind() != ErrorKind::WouldBlock {
223 return Err(err);
224 }
225 log::info!(target: "jsonrpc", "blocking with err {:?}!", err);
226 }
227 if buff.is_empty() {
228 log::warn!(target: "jsonrpc", "buffer is empty");
229 break;
230 }
231 let buff = buff.trim();
232 log::info!(target: "jsonrpc", "buffer read {buff}");
233 let requ: Request<Value> =
234 serde_json::from_str(&buff).map_err(|err| {
235 io::Error::new(io::ErrorKind::Other, format!("{err}"))
236 })?;
237 log::trace!(target: "jsonrpc", "request {:?}", requ);
238 let Some(resp) = self.handler.run_callback(&requ) else {
239 log::error!(target: "jsonrpc", "`{}` not found!", requ.method);
240 break;
241 };
242 let response = match resp {
244 Ok(result) => Response {
245 id: requ.id.clone().unwrap(),
246 jsonrpc: requ.jsonrpc.to_owned(),
247 result: Some(result),
248 error: None,
249 },
250 Err(err) => Response {
251 result: None,
252 error: Some(err.into()),
253 id: requ.id.unwrap().clone(),
254 jsonrpc: requ.jsonrpc.clone(),
255 },
256 };
257 log::trace!(target: "jsonrpc", "send response: `{:?}`", response);
258 self.send_resp(addr.to_string(), response);
259 }
260
261 if event.is_writable() {
262 let stream = self.conn.get(addr);
263 if stream.is_none() {
264 log::error!(target: "jsonrpc", "connection not found `{addr}`");
265 continue;
266 };
267
268 let mut stream = stream.unwrap();
269 let Some(resp) = self.pop_resp(addr.to_string()) else {
270 break;
271 };
272 let buff = serde_json::to_string(&resp).unwrap();
273 if let Err(err) = stream.write_all(buff.as_bytes()) {
274 if err.kind() != ErrorKind::WouldBlock {
275 return Err(err);
276 }
277 }
278 match stream.flush() {
279 Ok(()) => {
283 event.source.unset(popol::interest::WRITE);
284 }
285 Err(err)
289 if [io::ErrorKind::WouldBlock, io::ErrorKind::WriteZero]
290 .contains(&err.kind()) =>
291 {
292 event.source.set(popol::interest::WRITE);
293 }
294 Err(err) => {
295 log::error!(target: "jsonrpc", "{}: Write error: {}", addr, err.to_string());
296 }
297 }
298 stream.shutdown(std::net::Shutdown::Both)?;
299 }
300 }
301 }
302 }
303 }
304 Ok(())
305 }
306
307 pub fn handler(&self) -> Arc<Handler<T>> {
308 self.handler.clone()
309 }
310
311 pub fn spawn(self) -> JoinHandle<io::Result<()>> {
312 std::thread::spawn(move || self.listen())
313 }
314}
315
316impl<T: Send + Sync + 'static> Drop for JSONRPCv2<T> {
317 fn drop(&mut self) {
318 let _ = std::fs::remove_file(&self.socket_path).unwrap();
319 }
320}
321
#[cfg(test)]
mod tests {
    use std::{
        io::Write, os::unix::net::UnixStream, path::Path, str::FromStr, sync::Arc, time::Duration,
    };

    use lampo_common::logger;
    use ntest::timeout;
    use serde_json::Value;

    use crate::{
        command::Context,
        json_rpc2::{Id, Request, Response},
        JSONRPCv2,
    };

    /// Socket shared by the server under test and its clients.
    const SOCKET_PATH: &str = "/tmp/tmp.sock";

    /// Context without state: the callbacks under test never look at it.
    struct DummyCtx;

    impl Context for DummyCtx {
        type Ctx = DummyCtx;

        fn ctx(&self) -> &Self::Ctx {
            self
        }
    }

    /// Sends `request` over a fresh connection and decodes the server's reply.
    fn round_trip(request: &Request<Value>) -> Response<Value> {
        let buff = serde_json::to_string(request).unwrap();
        let mut stream = UnixStream::connect(Path::new(SOCKET_PATH))
            .unwrap_or_else(|_| panic!("server is not running"));
        log::info!(target: "client", "sending {buff}");
        stream.write_all(buff.as_bytes()).unwrap();
        stream.flush().unwrap();
        log::info!(target: "client", "waiting for server response");
        log::info!(target: "client", "read answer from server");
        let resp: Response<Value> = serde_json::from_reader(stream).unwrap();
        log::info!(target: "client", "msg received: {:?}", resp);
        resp
    }

    /// Registers two echo methods and calls each of them from its own client
    /// thread, the second one three seconds after the first.
    #[test]
    #[timeout(9000)]
    fn register_rpc() {
        logger::init(log::Level::Debug).unwrap();
        let _ = std::fs::remove_file(SOCKET_PATH);
        let server = JSONRPCv2::new(Arc::new(DummyCtx), SOCKET_PATH).unwrap();
        let _ = server.add_rpc("foo", |_: &DummyCtx, request| {
            Ok(serde_json::json!(request))
        });
        let res = server.add_rpc("secon", |_: &DummyCtx, request| {
            Ok(serde_json::json!(request))
        });
        assert!(res.is_ok(), "{:?}", res);

        let handler = server.handler();
        let worker = server.spawn();

        let request = Request::<Value> {
            id: Some(0.into()),
            jsonrpc: String::from_str("2.0").unwrap(),
            method: "foo".to_owned(),
            params: serde_json::Value::Array([].to_vec()),
        };
        let client_worker = std::thread::spawn(move || {
            let resp = round_trip(&request);
            assert_eq!(resp.id, request.id.unwrap());
            resp
        });

        let client_worker2 = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_secs(3));
            let request = Request::<Value> {
                id: Some(1.into()),
                jsonrpc: String::from_str("2.0").unwrap(),
                method: "secon".to_owned(),
                params: serde_json::Value::Array([].to_vec()),
            };
            round_trip(&request)
        });

        let first = client_worker.join().unwrap();
        assert_eq!(Id::Str("0".to_owned()), first.id);
        let second = client_worker2.join().unwrap();
        assert_eq!(Id::Str("1".to_owned()), second.id);
        handler.stop();

        let _ = worker.join();
    }
}