1use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9use std::thread;
10
11use serde::Serialize;
12use tracing::debug;
13
14use crate::error::Error;
15use crate::shutdown::ShutdownSignal;
16use crate::transports::{Stdio, Transport};
17use crate::types::{Message, Notification, Request, RequestId, Response};
18
19trait HandlerFn: Send + Sync {
20 fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, Error>;
21}
22
23struct HandlerWrapper<F, P, R>
24where
25 F: Fn(P) -> Result<R, Error> + Send + Sync + 'static,
26 P: serde::de::DeserializeOwned + Send + Sync + 'static,
27 R: Serialize + Send + Sync + 'static,
28{
29 f: Arc<F>,
30 _phantom: std::marker::PhantomData<(P, R)>,
31}
32
33impl<F, P, R> HandlerFn for HandlerWrapper<F, P, R>
34where
35 F: Fn(P) -> Result<R, Error> + Send + Sync + 'static,
36 P: serde::de::DeserializeOwned + Send + Sync + 'static,
37 R: Serialize + Send + Sync + 'static,
38{
39 fn call(&self, params: serde_json::Value) -> Result<serde_json::Value, Error> {
40 let parsed: P = serde_json::from_value(params)?;
41 let result = (self.f)(parsed)?;
42 Ok(serde_json::to_value(result)?)
43 }
44}
45
/// A unit of work queued on the [`ThreadPool`].
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A pool thread that repeatedly takes a [`Job`] from the shared receiver and runs it.
struct Worker {
    _handle: thread::JoinHandle<()>,
}

impl Worker {
    /// Spawns the worker thread.
    ///
    /// The thread exits when the job channel is disconnected (every `Sender`
    /// dropped and the queue drained) or when the receiver mutex is poisoned.
    /// A panicking job is caught so it costs only that job; otherwise each panic
    /// would permanently remove a thread from the pool. `_id` is currently unused.
    fn spawn(_id: usize, receiver: Arc<Mutex<std::sync::mpsc::Receiver<Job>>>) -> Self {
        let handle = thread::spawn(move || loop {
            // The guard lives only for this statement: the lock is held while
            // waiting for the next job, never while a job is running.
            let next = match receiver.lock() {
                Ok(rx) => rx.recv(),
                Err(_) => break,
            };
            let Ok(job) = next else { break };
            // The panic hook (stderr by default) has already reported the panic;
            // the payload itself is not needed here.
            let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
        });

        Self { _handle: handle }
    }
}
74
75struct ThreadPool {
76 workers: Vec<Worker>,
77 sender: Option<std::sync::mpsc::Sender<Job>>,
78}
79
80impl ThreadPool {
81 fn new(size: usize) -> Self {
82 assert!(size > 0, "Thread pool size must be greater than 0");
83
84 let (sender, receiver) = std::sync::mpsc::channel();
85 let receiver = Arc::new(Mutex::new(receiver));
86
87 let mut workers = Vec::with_capacity(size);
88
89 for id in 0..size {
90 workers.push(Worker::spawn(id, Arc::clone(&receiver)));
91 }
92
93 Self {
94 workers,
95 sender: Some(sender),
96 }
97 }
98
99 fn execute<F>(&self, job: F) -> Result<(), Error>
100 where
101 F: FnOnce() + Send + 'static,
102 {
103 let job = Box::new(job);
104 let sender = self.sender.as_ref().ok_or_else(|| {
105 Error::TransportError(std::io::Error::new(
106 std::io::ErrorKind::NotConnected,
107 "Thread pool is not available",
108 ))
109 })?;
110
111 sender.send(job).map_err(|_| {
112 Error::TransportError(std::io::Error::new(
113 std::io::ErrorKind::BrokenPipe,
114 "Failed to send job to thread pool",
115 ))
116 })
117 }
118}
119
120impl Drop for ThreadPool {
121 fn drop(&mut self) {
122 drop(self.sender.take());
123 for _worker in &mut self.workers {}
124 }
125}
126
127struct ResponseData {
128 response: Response,
129 batch_id: Option<usize>,
130 batch_index: Option<usize>,
131}
132
133struct BatchContext {
134 responses: Vec<Option<Response>>,
135 expected_count: usize,
136}
137
138pub struct Server {
139 handlers: HashMap<String, Box<dyn HandlerFn>>,
140 thread_pool_size: usize,
141 shutdown_signal: Option<ShutdownSignal>,
142 transport: Option<Box<dyn Transport>>,
143}
144
145impl Server {
146 pub fn new() -> Self {
147 Self {
148 handlers: HashMap::new(),
149 thread_pool_size: num_cpus::get(),
150 shutdown_signal: None,
151 transport: None,
152 }
153 }
154
155 pub fn with_thread_pool_size(mut self, size: usize) -> Self {
156 assert!(size > 0, "Thread pool size must be greater than 0");
157 self.thread_pool_size = size;
158 self
159 }
160
161 pub fn with_shutdown_signal(mut self, signal: ShutdownSignal) -> Self {
162 self.shutdown_signal = Some(signal);
163 self
164 }
165
166 pub fn with_transport<T>(mut self, transport: T) -> Self
167 where
168 T: Transport + 'static,
169 {
170 self.transport = Some(Box::new(transport));
171 self
172 }
173
174 pub fn register<F, P, R>(&mut self, method: &str, handler: F) -> Result<(), Error>
175 where
176 F: Fn(P) -> Result<R, Error> + Send + Sync + 'static,
177 P: serde::de::DeserializeOwned + Send + Sync + 'static,
178 R: Serialize + Send + Sync + 'static,
179 {
180 let wrapper = HandlerWrapper {
181 f: Arc::new(handler),
182 _phantom: std::marker::PhantomData,
183 };
184 self.handlers.insert(method.to_string(), Box::new(wrapper));
185 Ok(())
186 }
187
188 pub fn run(&mut self) -> Result<(), Error> {
189 let mut transport = self
190 .transport
191 .take()
192 .unwrap_or_else(|| Box::new(Stdio::default()) as Box<dyn Transport>);
193 let thread_pool = ThreadPool::new(self.thread_pool_size);
194 let handlers = Arc::new(std::sync::Mutex::new(std::mem::take(&mut self.handlers)));
195 let shutdown_signal = self.shutdown_signal.clone();
196 let (response_sender, response_receiver) = std::sync::mpsc::channel::<ResponseData>();
197 let mut batches: HashMap<usize, BatchContext> = HashMap::new();
198 let mut next_batch_id: usize = 0;
199
200 loop {
201 if let Some(ref signal) = shutdown_signal
202 && signal.is_shutdown_requested()
203 {
204 break;
205 }
206
207 let json_str = match transport.receive_message() {
208 Ok(msg) => {
209 debug!("Received message from transport: {}", msg);
210 msg
211 }
212 Err(Error::TransportError(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
213 debug!("EOF received, breaking loop");
214 break;
215 }
216 Err(e) => {
217 debug!("Transport error: {}", e);
218 let error = crate::types::Error::internal_error("Internal error");
219 let response = Response::error(RequestId::Null, error);
220 let json = match serde_json::to_string(&response) {
221 Ok(json) => json,
222 Err(e) => {
223 eprintln!("Failed to serialize internal error response: {}", e);
224 continue;
225 }
226 };
227 debug!("Sending internal error response: {}", json);
228 let _ = transport.send_message(&json);
229 continue;
230 }
231 };
232
233 let value: serde_json::Value = match serde_json::from_str(&json_str) {
234 Ok(v) => {
235 debug!("JSON parsed successfully");
236 v
237 }
238 Err(_e) => {
239 debug!("Failed to parse JSON string: {}", json_str);
240 let error = crate::types::Error::parse_error("Parse error");
241 let response = Response::error(RequestId::Null, error);
242 let json = match serde_json::to_string(&response) {
243 Ok(json) => json,
244 Err(e) => {
245 eprintln!("Failed to serialize parse error response: {}", e);
246 continue;
247 }
248 };
249 debug!("Sending parse error response: {}", json);
250 let _ = transport.send_message(&json);
251 continue;
252 }
253 };
254
255 let request_id = value.get("id").and_then(|id_value| match id_value {
256 serde_json::Value::Null => Some(RequestId::Null),
257 serde_json::Value::Number(n) => n.as_u64().map(RequestId::Number),
258 serde_json::Value::String(s) => Some(RequestId::String(s.clone())),
259 _ => None,
260 });
261 debug!("Extracted request_id: {:?}", request_id);
262
263 let message = match Message::from_json(value) {
264 Ok(msg) => {
265 debug!("Message parsed successfully");
266 msg
267 }
268 Err(Error::InvalidRequest(e)) => {
269 debug!("Invalid Request error caught: {}", e);
270 let error = crate::types::Error::invalid_request("Invalid Request");
271 let id_to_use = request_id.unwrap_or(RequestId::Null);
272 debug!("Using request_id in error response: {:?}", id_to_use);
273 let response = Response::error(id_to_use, error);
274 let json = match serde_json::to_string(&response) {
275 Ok(json) => json,
276 Err(e) => {
277 eprintln!("Failed to serialize invalid request error response: {}", e);
278 continue;
279 }
280 };
281 debug!("Sending Invalid Request error response: {}", json);
282 let _ = transport.send_message(&json);
283 continue;
284 }
285 Err(e) => {
286 debug!("Error parsing message: {}", e);
287 eprintln!("Error parsing message: {}", e);
288 let error = crate::types::Error::internal_error("Internal error");
289 let response = Response::error(request_id.unwrap_or(RequestId::Null), error);
290 let json = match serde_json::to_string(&response) {
291 Ok(json) => json,
292 Err(e) => {
293 eprintln!("Failed to serialize internal error response: {}", e);
294 continue;
295 }
296 };
297 debug!("Sending internal error response: {}", json);
298 let _ = transport.send_message(&json);
299 continue;
300 }
301 };
302
303 let handlers_clone = Arc::clone(&handlers);
304
305 match message {
306 Message::Request(request) => {
307 let sender_clone = response_sender.clone();
308 thread_pool.execute(move || {
309 if let Err(e) = Self::process_request(handlers_clone, sender_clone, request)
310 {
311 eprintln!("Error processing request: {}", e);
312 }
313 })?;
314 }
315 Message::Notification(notification) => {
316 if let Err(e) = Self::process_notification(handlers_clone, notification) {
317 eprintln!("Error processing notification: {}", e);
318 }
319 }
320 Message::Batch(messages) => {
321 let batch_id = next_batch_id;
322 next_batch_id = next_batch_id.wrapping_add(1);
323
324 let request_count = messages
325 .iter()
326 .filter(|m| matches!(m, Message::Request(_) | Message::Response(_)))
327 .count();
328
329 if request_count > 0 {
330 batches.insert(
331 batch_id,
332 BatchContext {
333 responses: vec![None; request_count],
334 expected_count: request_count,
335 },
336 );
337
338 if let Err(e) = Self::process_batch(
339 &thread_pool,
340 handlers_clone,
341 response_sender.clone(),
342 batch_id,
343 messages,
344 ) {
345 eprintln!("Error processing batch: {}", e);
346 batches.remove(&batch_id);
347 }
348 } else {
349 eprintln!("Batch contains only notifications - no response sent");
350 }
351 }
352 Message::Response(_response) => {}
353 }
354
355 while let Ok(response_data) =
356 response_receiver.recv_timeout(std::time::Duration::from_millis(100))
357 {
358 if let Some(batch_id) = response_data.batch_id
359 && let Some(batch_index) = response_data.batch_index
360 && let Some(batch) = batches.get_mut(&batch_id)
361 && batch_index < batch.responses.len()
362 {
363 batch.responses[batch_index] = Some(response_data.response);
364
365 let completed = batch.responses.iter().filter(|r| r.is_some()).count();
366 if completed == batch.expected_count {
367 let responses: Vec<Response> =
368 batch.responses.drain(..).flatten().collect();
369
370 if !responses.is_empty() {
371 let batch_json = serde_json::to_string(&responses)?;
372 transport.send_message(&batch_json)?;
373 }
374
375 batches.remove(&batch_id);
376 }
377 } else {
378 let json = serde_json::to_string(&response_data.response)?;
379 transport.send_message(&json)?;
380 }
381 }
382 }
383
384 while let Ok(response_data) =
385 response_receiver.recv_timeout(std::time::Duration::from_millis(100))
386 {
387 let json = serde_json::to_string(&response_data.response)?;
388 transport.send_message(&json)?;
389 }
390
391 Ok(())
392 }
393
394 fn process_request(
395 handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
396 sender: std::sync::mpsc::Sender<ResponseData>,
397 request: Request,
398 ) -> Result<(), Error> {
399 Self::process_request_with_batch(handlers, sender, request, None, None)
400 }
401
402 fn process_request_with_batch(
403 handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
404 sender: std::sync::mpsc::Sender<ResponseData>,
405 request: Request,
406 batch_id: Option<usize>,
407 batch_index: Option<usize>,
408 ) -> Result<(), Error> {
409 let id = request.id.clone();
410 let method_name = request.method.clone();
411 let params = request.params.unwrap_or(serde_json::Value::Null);
412
413 let response = match handlers.lock() {
414 Ok(handlers_lock) => match handlers_lock.get(&method_name) {
415 Some(handler) => match handler.call(params) {
416 Ok(result) => Response::success(id, result),
417 Err(Error::RpcError { code, message }) => {
418 let error = crate::types::Error::new(code, message, None);
419 Response::error(id, error)
420 }
421 Err(e) => {
422 let error = crate::types::Error::new(-32603, e.to_string(), None);
423 Response::error(id, error)
424 }
425 },
426 None => {
427 let error = crate::types::Error::method_not_found(format!(
428 "Unknown method: {}",
429 method_name
430 ));
431 Response::error(id, error)
432 }
433 },
434 Err(_) => {
435 let error = crate::types::Error::internal_error("Internal server error");
436 Response::error(id, error)
437 }
438 };
439
440 sender
441 .send(ResponseData {
442 response,
443 batch_id,
444 batch_index,
445 })
446 .map_err(|e| {
447 Error::TransportError(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
448 })?;
449
450 Ok(())
451 }
452
453 fn process_notification(
454 handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
455 notification: Notification,
456 ) -> Result<(), Error> {
457 eprintln!("Processing notification: {}", notification.method);
458 let method_name = notification.method.clone();
459 let params = notification.params.unwrap_or(serde_json::Value::Null);
460
461 match handlers.lock() {
462 Ok(handlers_lock) => match handlers_lock.get(&method_name) {
463 Some(handler) => {
464 let _ = handler.call(params);
465 Ok(())
466 }
467 None => Ok(()),
468 },
469 Err(_) => Ok(()),
470 }
471 }
472
473 fn process_batch(
474 thread_pool: &ThreadPool,
475 handlers: Arc<std::sync::Mutex<HashMap<String, Box<dyn HandlerFn>>>>,
476 sender: std::sync::mpsc::Sender<ResponseData>,
477 batch_id: usize,
478 messages: Vec<Message>,
479 ) -> Result<(), Error> {
480 let mut request_index = 0;
481
482 for message in messages {
483 match message {
484 Message::Request(request) => {
485 let handlers_clone = Arc::clone(&handlers);
486 let sender_clone = sender.clone();
487 let index = request_index;
488 request_index += 1;
489
490 thread_pool.execute(move || {
491 if let Err(e) = Self::process_request_with_batch(
492 handlers_clone,
493 sender_clone,
494 request,
495 Some(batch_id),
496 Some(index),
497 ) {
498 eprintln!("Error processing request in batch: {}", e);
499 }
500 })?;
501 }
502 Message::Notification(notification) => {
503 if let Err(e) = Self::process_notification(handlers.clone(), notification) {
504 eprintln!("Error processing notification in batch: {}", e);
505 }
506 }
507 Message::Response(response) => {
508 let sender_clone = sender.clone();
509 let index = request_index;
510 request_index += 1;
511
512 sender_clone
513 .send(ResponseData {
514 response,
515 batch_id: Some(batch_id),
516 batch_index: Some(index),
517 })
518 .map_err(|e| {
519 Error::TransportError(std::io::Error::new(
520 std::io::ErrorKind::BrokenPipe,
521 e,
522 ))
523 })?;
524 }
525 _ => {
526 debug!("Unexpected message type in batch: {:?}", message);
527 }
528 }
529 }
530
531 Ok(())
532 }
533}
534
535impl Default for Server {
536 fn default() -> Self {
537 Self::new()
538 }
539}