Skip to main content

json_rpc/
server.rs

//! JSON-RPC server with builder-style configuration.
//!
//! This module provides a `Server` that is configured through builder-style
//! `with_*` methods, dispatches calls to handlers added with `Server::register`,
//! and executes requests on a worker thread pool.
6
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9use std::thread;
10
11use serde::Serialize;
12use tracing::debug;
13
14use crate::error::Error;
15use crate::shutdown::ShutdownSignal;
16use crate::transports::{Stdio, Transport};
17use crate::types::{Message, Notification, Request, RequestId, Response};
18
19trait HandlerFn: Send + Sync {
20    fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, Error>;
21}
22
/// Adapts a typed handler `Fn(P) -> Result<R, Error>` to the type-erased
/// `HandlerFn` interface.
///
/// The trait bounds live on the `HandlerFn` impl instead of on this struct, so
/// the bounds are not repeated on every mention of the type.
/// `PhantomData<fn(P) -> R>` records that the wrapper consumes `P` and produces
/// `R` without storing either. As a result, the wrapper's `Send`/`Sync` depend
/// only on the handler `F`, not on the parameter or result types.
struct HandlerWrapper<F, P, R> {
    /// The wrapped handler function.
    f: Arc<F>,
    _phantom: std::marker::PhantomData<fn(P) -> R>,
}
32
33impl<F, P, R> HandlerFn for HandlerWrapper<F, P, R>
34where
35    F: Fn(P) -> Result<R, Error> + Send + Sync + 'static,
36    P: serde::de::DeserializeOwned + Send + Sync + 'static,
37    R: Serialize + Send + Sync + 'static,
38{
39    fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, Error> {
40        let parsed: P = serde_json::from_value(params)?;
41        let result = (self.f)(parsed)?;
42        Ok(serde_json::to_value(result)?)
43    }
44}
45
/// A unit of work for the thread pool: a boxed closure that can be sent to
/// another thread and is run exactly once.
type Job = Box<dyn FnOnce() + Send>;
47
48struct Worker {
49    _handle: thread::JoinHandle<()>,
50}
51
52impl Worker {
53    fn spawn(_id: usize, receiver: Arc<Mutex<std::sync::mpsc::Receiver<Job>>>) -> Self {
54        let handle = thread::spawn(move || {
55            loop {
56                let job = {
57                    let rx = match receiver.lock() {
58                        Ok(guard) => guard,
59                        Err(_) => break,
60                    };
61                    rx.recv()
62                };
63
64                match job {
65                    Ok(job) => job(),
66                    Err(_) => break,
67                }
68            }
69        });
70
71        Self { _handle: handle }
72    }
73}
74
75struct ThreadPool {
76    workers: Vec<Worker>,
77    sender: Option<std::sync::mpsc::Sender<Job>>,
78}
79
80impl ThreadPool {
81    fn new(size: usize) -> Self {
82        assert!(size > 0, "Thread pool size must be greater than 0");
83
84        let (sender, receiver) = std::sync::mpsc::channel();
85        let receiver = Arc::new(Mutex::new(receiver));
86
87        let mut workers = Vec::with_capacity(size);
88
89        for id in 0..size {
90            workers.push(Worker::spawn(id, Arc::clone(&receiver)));
91        }
92
93        Self {
94            workers,
95            sender: Some(sender),
96        }
97    }
98
99    fn execute<F>(&self, job: F) -> Result<(), Error>
100    where
101        F: FnOnce() + Send + 'static,
102    {
103        let job = Box::new(job);
104        let sender = self.sender.as_ref().ok_or_else(|| {
105            Error::TransportError(std::io::Error::new(
106                std::io::ErrorKind::NotConnected,
107                "Thread pool is not available",
108            ))
109        })?;
110
111        sender.send(job).map_err(|_| {
112            Error::TransportError(std::io::Error::new(
113                std::io::ErrorKind::BrokenPipe,
114                "Failed to send job to thread pool",
115            ))
116        })
117    }
118}
119
120impl Drop for ThreadPool {
121    fn drop(&mut self) {
122        drop(self.sender.take());
123        for _worker in &mut self.workers {}
124    }
125}
126
127struct ResponseData {
128    response: Response,
129    batch_id: Option<usize>,
130    batch_index: Option<usize>,
131}
132
133struct BatchContext {
134    responses: Vec<Option<Response>>,
135    expected_count: usize,
136}
137
138pub struct Server {
139    handlers: HashMap<String, Box<dyn HandlerFn>>,
140    thread_pool_size: usize,
141    shutdown_signal: Option<ShutdownSignal>,
142    transport: Option<Box<dyn Transport>>,
143}
144
145impl Server {
146    pub fn new() -> Self {
147        Self {
148            handlers: HashMap::new(),
149            thread_pool_size: num_cpus::get(),
150            shutdown_signal: None,
151            transport: None,
152        }
153    }
154
155    pub fn with_thread_pool_size(mut self, size: usize) -> Self {
156        assert!(size > 0, "Thread pool size must be greater than 0");
157        self.thread_pool_size = size;
158        self
159    }
160
161    pub fn with_shutdown_signal(mut self, signal: ShutdownSignal) -> Self {
162        self.shutdown_signal = Some(signal);
163        self
164    }
165
166    pub fn with_transport<T>(mut self, transport: T) -> Self
167    where
168        T: Transport + 'static,
169    {
170        self.transport = Some(Box::new(transport));
171        self
172    }
173
174    pub fn register<F, P, R>(&mut self, method: &str, handler: F) -> Result<(), Error>
175    where
176        F: Fn(P) -> Result<R, Error> + Send + Sync + 'static,
177        P: serde::de::DeserializeOwned + Send + Sync + 'static,
178        R: Serialize + Send + Sync + 'static,
179    {
180        let wrapper = HandlerWrapper {
181            f: Arc::new(handler),
182            _phantom: std::marker::PhantomData,
183        };
184        self.handlers.insert(method.to_string(), Box::new(wrapper));
185        Ok(())
186    }
187
188    pub fn run(&mut self) -> Result<(), Error> {
189        let mut transport = self
190            .transport
191            .take()
192            .unwrap_or_else(|| Box::new(Stdio::default()) as Box<dyn Transport>);
193        let thread_pool = ThreadPool::new(self.thread_pool_size);
194        let handlers = Arc::new(std::sync::Mutex::new(std::mem::take(&mut self.handlers)));
195        let shutdown_signal = self.shutdown_signal.clone();
196        let (response_sender, response_receiver) = std::sync::mpsc::channel::<ResponseData>();
197        let mut batches: HashMap<usize, BatchContext> = HashMap::new();
198        let mut next_batch_id: usize = 0;
199
200        loop {
201            if let Some(ref signal) = shutdown_signal
202                && signal.is_shutdown_requested()
203            {
204                break;
205            }
206
207            let json_str = match transport.receive_message() {
208                Ok(msg) => {
209                    debug!("Received message from transport: {}", msg);
210                    msg
211                }
212                Err(Error::TransportError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
213                    debug!("EOF received, breaking loop");
214                    break;
215                }
216                Err(e) => {
217                    debug!("Transport error: {}", e);
218                    let error = crate::types::Error::internal_error("Internal error");
219                    let response = Response::error(RequestId::Null, error);
220                    let json = match serde_json::to_string(&response) {
221                        Ok(json) => json,
222                        Err(e) => {
223                            eprintln!("Failed to serialize internal error response: {}", e);
224                            continue;
225                        }
226                    };
227                    debug!("Sending internal error response: {}", json);
228                    let _ = transport.send_message(&json);
229                    continue;
230                }
231            };
232
233            let value: serde_json::Value = match serde_json::from_str(&json_str) {
234                Ok(v) => {
235                    debug!("JSON parsed successfully");
236                    v
237                }
238                Err(_e) => {
239                    debug!("Failed to parse JSON string: {}", json_str);
240                    let error = crate::types::Error::parse_error("Parse error");
241                    let response = Response::error(RequestId::Null, error);
242                    let json = match serde_json::to_string(&response) {
243                        Ok(json) => json,
244                        Err(e) => {
245                            eprintln!("Failed to serialize parse error response: {}", e);
246                            continue;
247                        }
248                    };
249                    debug!("Sending parse error response: {}", json);
250                    let _ = transport.send_message(&json);
251                    continue;
252                }
253            };
254
255            let request_id = value.get("id").and_then(|id_value| match id_value {
256                serde_json::Value::Null => Some(RequestId::Null),
257                serde_json::Value::Number(n) => n.as_u64().map(RequestId::Number),
258                serde_json::Value::String(s) => Some(RequestId::String(s.clone())),
259                _ => None,
260            });
261            debug!("Extracted request_id: {:?}", request_id);
262
263            let message = match Message::from_json(value) {
264                Ok(msg) => {
265                    debug!("Message parsed successfully");
266                    msg
267                }
268                Err(Error::InvalidRequest(e)) => {
269                    debug!("Invalid Request error caught: {}", e);
270                    let error = crate::types::Error::invalid_request("Invalid Request");
271                    let id_to_use = request_id.unwrap_or(RequestId::Null);
272                    debug!("Using request_id in error response: {:?}", id_to_use);
273                    let response = Response::error(id_to_use, error);
274                    let json = match serde_json::to_string(&response) {
275                        Ok(json) => json,
276                        Err(e) => {
277                            eprintln!("Failed to serialize invalid request error response: {}", e);
278                            continue;
279                        }
280                    };
281                    debug!("Sending Invalid Request error response: {}", json);
282                    let _ = transport.send_message(&json);
283                    continue;
284                }
285                Err(e) => {
286                    debug!("Error parsing message: {}", e);
287                    eprintln!("Error parsing message: {}", e);
288                    let error = crate::types::Error::internal_error("Internal error");
289                    let response = Response::error(request_id.unwrap_or(RequestId::Null), error);
290                    let json = match serde_json::to_string(&response) {
291                        Ok(json) => json,
292                        Err(e) => {
293                            eprintln!("Failed to serialize internal error response: {}", e);
294                            continue;
295                        }
296                    };
297                    debug!("Sending internal error response: {}", json);
298                    let _ = transport.send_message(&json);
299                    continue;
300                }
301            };
302
303            let handlers_clone = Arc::clone(&handlers);
304
305            match message {
306                Message::Request(request) => {
307                    let sender_clone = response_sender.clone();
308                    thread_pool.execute(move || {
309                        if let Err(e) = Self::process_request(handlers_clone, sender_clone, request)
310                        {
311                            eprintln!("Error processing request: {}", e);
312                        }
313                    })?;
314                }
315                Message::Notification(notification) => {
316                    if let Err(e) = Self::process_notification(handlers_clone, notification) {
317                        eprintln!("Error processing notification: {}", e);
318                    }
319                }
320                Message::Batch(messages) => {
321                    let batch_id = next_batch_id;
322                    next_batch_id = next_batch_id.wrapping_add(1);
323
324                    let request_count = messages
325                        .iter()
326                        .filter(|m| matches!(m, Message::Request(_) | Message::Response(_)))
327                        .count();
328
329                    if request_count > 0 {
330                        batches.insert(
331                            batch_id,
332                            BatchContext {
333                                responses: vec![None; request_count],
334                                expected_count: request_count,
335                            },
336                        );
337
338                        if let Err(e) = Self::process_batch(
339                            &thread_pool,
340                            handlers_clone,
341                            response_sender.clone(),
342                            batch_id,
343                            messages,
344                        ) {
345                            eprintln!("Error processing batch: {}", e);
346                            batches.remove(&batch_id);
347                        }
348                    } else {
349                        eprintln!("Batch contains only notifications - no response sent");
350                    }
351                }
352                Message::Response(_response) => {}
353            }
354
355            while let Ok(response_data) =
356                response_receiver.recv_timeout(std::time::Duration::from_millis(100))
357            {
358                if let Some(batch_id) = response_data.batch_id
359                    && let Some(batch_index) = response_data.batch_index
360                    && let Some(batch) = batches.get_mut(&batch_id)
361                    && batch_index < batch.responses.len()
362                {
363                    batch.responses[batch_index] = Some(response_data.response);
364
365                    let completed = batch.responses.iter().filter(|r| r.is_some()).count();
366                    if completed == batch.expected_count {
367                        let responses: Vec<Response> =
368                            batch.responses.drain(..).flatten().collect();
369
370                        if !responses.is_empty() {
371                            let batch_json = serde_json::to_string(&responses)?;
372                            transport.send_message(&batch_json)?;
373                        }
374
375                        batches.remove(&batch_id);
376                    }
377                } else {
378                    let json = serde_json::to_string(&response_data.response)?;
379                    transport.send_message(&json)?;
380                }
381            }
382        }
383
384        while let Ok(response_data) =
385            response_receiver.recv_timeout(std::time::Duration::from_millis(100))
386        {
387            let json = serde_json::to_string(&response_data.response)?;
388            transport.send_message(&json)?;
389        }
390
391        Ok(())
392    }
393
394    fn process_request(
395        handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
396        sender: std::sync::mpsc::Sender<ResponseData>,
397        request: Request,
398    ) -> Result<(), Error> {
399        Self::process_request_with_batch(handlers, sender, request, None, None)
400    }
401
402    fn process_request_with_batch(
403        handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
404        sender: std::sync::mpsc::Sender<ResponseData>,
405        request: Request,
406        batch_id: Option<usize>,
407        batch_index: Option<usize>,
408    ) -> Result<(), Error> {
409        let id = request.id.clone();
410        let method_name = request.method.clone();
411        let params = request.params.unwrap_or(serde_json::Value::Null);
412
413        let response = match handlers.lock() {
414            Ok(handlers_lock) => match handlers_lock.get(&method_name) {
415                Some(handler) => match handler.call(params) {
416                    Ok(result) => Response::success(id, result),
417                    Err(Error::RpcError { code, message }) => {
418                        let error = crate::types::Error::new(code, message, None);
419                        Response::error(id, error)
420                    }
421                    Err(e) => {
422                        let error = crate::types::Error::new(-32603, e.to_string(), None);
423                        Response::error(id, error)
424                    }
425                },
426                None => {
427                    let error = crate::types::Error::method_not_found(format!(
428                        "Unknown method: {}",
429                        method_name
430                    ));
431                    Response::error(id, error)
432                }
433            },
434            Err(_) => {
435                let error = crate::types::Error::internal_error("Internal server error");
436                Response::error(id, error)
437            }
438        };
439
440        sender
441            .send(ResponseData {
442                response,
443                batch_id,
444                batch_index,
445            })
446            .map_err(|e| {
447                Error::TransportError(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
448            })?;
449
450        Ok(())
451    }
452
453    fn process_notification(
454        handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
455        notification: Notification,
456    ) -> Result<(), Error> {
457        eprintln!("Processing notification: {}", notification.method);
458        let method_name = notification.method.clone();
459        let params = notification.params.unwrap_or(serde_json::Value::Null);
460
461        match handlers.lock() {
462            Ok(handlers_lock) => match handlers_lock.get(&method_name) {
463                Some(handler) => {
464                    let _ = handler.call(params);
465                    Ok(())
466                }
467                None => Ok(()),
468            },
469            Err(_) => Ok(()),
470        }
471    }
472
473    fn process_batch(
474        thread_pool: &ThreadPool,
475        handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
476        sender: std::sync::mpsc::Sender<ResponseData>,
477        batch_id: usize,
478        messages: Vec<Message>,
479    ) -> Result<(), Error> {
480        let mut request_index = 0;
481
482        for message in messages {
483            match message {
484                Message::Request(request) => {
485                    let handlers_clone = Arc::clone(&handlers);
486                    let sender_clone = sender.clone();
487                    let index = request_index;
488                    request_index += 1;
489
490                    thread_pool.execute(move || {
491                        if let Err(e) = Self::process_request_with_batch(
492                            handlers_clone,
493                            sender_clone,
494                            request,
495                            Some(batch_id),
496                            Some(index),
497                        ) {
498                            eprintln!("Error processing request in batch: {}", e);
499                        }
500                    })?;
501                }
502                Message::Notification(notification) => {
503                    if let Err(e) = Self::process_notification(handlers.clone(), notification) {
504                        eprintln!("Error processing notification in batch: {}", e);
505                    }
506                }
507                Message::Response(response) => {
508                    let sender_clone = sender.clone();
509                    let index = request_index;
510                    request_index += 1;
511
512                    sender_clone
513                        .send(ResponseData {
514                            response,
515                            batch_id: Some(batch_id),
516                            batch_index: Some(index),
517                        })
518                        .map_err(|e| {
519                            Error::TransportError(std::io::Error::new(
520                                std::io::ErrorKind::BrokenPipe,
521                                e,
522                            ))
523                        })?;
524                }
525                _ => {
526                    debug!("Unexpected message type in batch: {:?}", message);
527                }
528            }
529        }
530
531        Ok(())
532    }
533}
534
535impl Default for Server {
536    fn default() -> Self {
537        Self::new()
538    }
539}