1use crate::error::{CopilotError, Result};
9use crate::transport::{MessageFramer, MessageReader, MessageWriter, StdioTransport, Transport};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::HashMap;
13use std::pin::Pin;
14use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
27#[serde(untagged)]
28pub enum JsonRpcId {
29 Num(i64),
30 Str(String),
31}
32
33impl From<i64> for JsonRpcId {
34 fn from(n: i64) -> Self {
35 Self::Num(n)
36 }
37}
38
39impl From<String> for JsonRpcId {
40 fn from(s: String) -> Self {
41 Self::Str(s)
42 }
43}
44
45impl From<&str> for JsonRpcId {
46 fn from(s: &str) -> Self {
47 Self::Str(s.to_string())
48 }
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct JsonRpcRequest {
54 pub jsonrpc: String,
55 pub method: String,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub params: Option<Value>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub id: Option<JsonRpcId>,
60}
61
62impl JsonRpcRequest {
63 pub fn new(method: impl Into<String>, params: Option<Value>, id: Option<JsonRpcId>) -> Self {
65 Self {
66 jsonrpc: "2.0".to_string(),
67 method: method.into(),
68 params,
69 id,
70 }
71 }
72
73 pub fn notification(method: impl Into<String>, params: Option<Value>) -> Self {
75 Self::new(method, params, None)
76 }
77
78 pub fn is_notification(&self) -> bool {
80 self.id.is_none()
81 }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct JsonRpcError {
87 pub code: i32,
88 pub message: String,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 pub data: Option<Value>,
91}
92
93impl JsonRpcError {
94 pub fn new(code: i32, message: impl Into<String>) -> Self {
96 Self {
97 code,
98 message: message.into(),
99 data: None,
100 }
101 }
102
103 pub fn with_data(code: i32, message: impl Into<String>, data: Value) -> Self {
105 Self {
106 code,
107 message: message.into(),
108 data: Some(data),
109 }
110 }
111
112 pub const PARSE_ERROR: i32 = -32700;
114 pub const INVALID_REQUEST: i32 = -32600;
115 pub const METHOD_NOT_FOUND: i32 = -32601;
116 pub const INVALID_PARAMS: i32 = -32602;
117 pub const INTERNAL_ERROR: i32 = -32603;
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct JsonRpcResponse {
123 pub jsonrpc: String,
124 #[serde(skip_serializing_if = "Option::is_none")]
125 pub id: Option<JsonRpcId>,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 pub result: Option<Value>,
128 #[serde(skip_serializing_if = "Option::is_none")]
129 pub error: Option<JsonRpcError>,
130}
131
132impl JsonRpcResponse {
133 pub fn success(id: JsonRpcId, result: Value) -> Self {
135 Self {
136 jsonrpc: "2.0".to_string(),
137 id: Some(id),
138 result: Some(result),
139 error: None,
140 }
141 }
142
143 pub fn error(id: JsonRpcId, error: JsonRpcError) -> Self {
145 Self {
146 jsonrpc: "2.0".to_string(),
147 id: Some(id),
148 result: None,
149 error: Some(error),
150 }
151 }
152
153 pub fn is_error(&self) -> bool {
155 self.error.is_some()
156 }
157}
158
159pub type NotificationHandler = Arc<dyn Fn(&str, &Value) + Send + Sync>;
165
166pub type RequestHandlerFuture =
168 Pin<Box<dyn std::future::Future<Output = std::result::Result<Value, JsonRpcError>> + Send>>;
169
170pub type RequestHandler = Arc<dyn Fn(&str, &Value) -> RequestHandlerFuture + Send + Sync>;
172
173struct PendingRequest {
178 sender: oneshot::Sender<std::result::Result<Value, JsonRpcError>>,
179}
180
181struct SharedState<T: Transport> {
186 framer: Mutex<MessageFramer<T>>,
187 running: AtomicBool,
188 pending_requests: RwLock<HashMap<i64, PendingRequest>>,
189 notification_handler: RwLock<Option<NotificationHandler>>,
190 request_handler: RwLock<Option<RequestHandler>>,
191}
192
193pub struct JsonRpcClient<T: Transport> {
206 state: Arc<SharedState<T>>,
207 next_id: AtomicI64,
208 shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
209}
210
211impl<T: Transport + 'static> JsonRpcClient<T> {
212 pub fn new(transport: T) -> Self {
214 Self {
215 state: Arc::new(SharedState {
216 framer: Mutex::new(MessageFramer::new(transport)),
217 running: AtomicBool::new(false),
218 pending_requests: RwLock::new(HashMap::new()),
219 notification_handler: RwLock::new(None),
220 request_handler: RwLock::new(None),
221 }),
222 next_id: AtomicI64::new(1),
223 shutdown_tx: Mutex::new(None),
224 }
225 }
226
227 pub async fn start(&self) -> Result<()> {
229 if self.state.running.swap(true, Ordering::SeqCst) {
230 return Ok(()); }
232
233 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
234 *self.shutdown_tx.lock().await = Some(shutdown_tx);
235
236 let state = Arc::clone(&self.state);
238
239 tokio::spawn(async move {
240 loop {
241 tokio::select! {
242 _ = shutdown_rx.recv() => {
243 break;
244 }
245 result = async {
246 let mut framer = state.framer.lock().await;
247 framer.read_message().await
248 } => {
249 match result {
250 Ok(message_str) => {
251 if let Ok(message) = serde_json::from_str::<Value>(&message_str) {
252 Self::dispatch_message(&state, message).await;
253 }
254 }
255 Err(CopilotError::ConnectionClosed) => {
256 state.running.store(false, Ordering::SeqCst);
257 let mut pending = state.pending_requests.write().await;
259 for (_, req) in pending.drain() {
260 let _ = req.sender.send(Err(JsonRpcError::new(
261 -32801,
262 "Connection closed",
263 )));
264 }
265 break;
266 }
267 Err(_) => {
268 if !state.running.load(Ordering::SeqCst) {
270 break;
271 }
272 }
273 }
274 }
275 }
276 }
277 });
278
279 Ok(())
280 }
281
282 pub async fn stop(&self) {
284 self.state.running.store(false, Ordering::SeqCst);
285
286 if let Some(tx) = self.shutdown_tx.lock().await.take() {
288 let _ = tx.send(()).await;
289 }
290
291 let mut pending = self.state.pending_requests.write().await;
293 for (_, req) in pending.drain() {
294 let _ = req
295 .sender
296 .send(Err(JsonRpcError::new(-32801, "Connection closed")));
297 }
298 }
299
300 pub fn is_running(&self) -> bool {
302 self.state.running.load(Ordering::SeqCst)
303 }
304
305 pub async fn set_notification_handler<F>(&self, handler: F)
307 where
308 F: Fn(&str, &Value) + Send + Sync + 'static,
309 {
310 *self.state.notification_handler.write().await = Some(Arc::new(handler));
311 }
312
313 pub async fn set_request_handler<F>(&self, handler: F)
315 where
316 F: Fn(&str, &Value) -> RequestHandlerFuture + Send + Sync + 'static,
317 {
318 *self.state.request_handler.write().await = Some(Arc::new(handler));
319 }
320
321 pub async fn invoke(&self, method: &str, params: Option<Value>) -> Result<Value> {
323 self.invoke_with_timeout(method, params, Duration::from_secs(30))
324 .await
325 }
326
327 pub async fn invoke_with_timeout(
329 &self,
330 method: &str,
331 params: Option<Value>,
332 timeout: Duration,
333 ) -> Result<Value> {
334 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
335
336 let (tx, rx) = oneshot::channel();
338
339 {
341 let mut pending = self.state.pending_requests.write().await;
342 pending.insert(id, PendingRequest { sender: tx });
343 }
344
345 let request = JsonRpcRequest::new(method, params, Some(JsonRpcId::Num(id)));
347 let request_json = serde_json::to_string(&request)?;
348
349 if let Err(e) = self.send_raw(&request_json).await {
350 self.state.pending_requests.write().await.remove(&id);
352 return Err(e);
353 }
354
355 match tokio::time::timeout(timeout, rx).await {
357 Ok(Ok(Ok(result))) => Ok(result),
358 Ok(Ok(Err(rpc_error))) => Err(CopilotError::JsonRpc {
359 code: rpc_error.code,
360 message: rpc_error.message,
361 data: rpc_error.data,
362 }),
363 Ok(Err(_)) => {
364 self.state.pending_requests.write().await.remove(&id);
366 Err(CopilotError::ConnectionClosed)
367 }
368 Err(_) => {
369 self.state.pending_requests.write().await.remove(&id);
371 Err(CopilotError::Timeout(timeout))
372 }
373 }
374 }
375
376 pub async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
378 let request = JsonRpcRequest::notification(method, params);
379 let request_json = serde_json::to_string(&request)?;
380 self.send_raw(&request_json).await
381 }
382
383 pub async fn send_response(&self, id: JsonRpcId, result: Value) -> Result<()> {
385 let response = JsonRpcResponse::success(id, result);
386 let response_json = serde_json::to_string(&response)?;
387 self.send_raw(&response_json).await
388 }
389
390 pub async fn send_error_response(&self, id: JsonRpcId, error: JsonRpcError) -> Result<()> {
392 let response = JsonRpcResponse::error(id, error);
393 let response_json = serde_json::to_string(&response)?;
394 self.send_raw(&response_json).await
395 }
396
397 async fn send_raw(&self, message: &str) -> Result<()> {
399 let mut framer = self.state.framer.lock().await;
400 framer.write_message(message).await
401 }
402
403 async fn dispatch_message(state: &SharedState<T>, message: Value) {
405 if message.get("id").is_some()
407 && !message.get("id").map(|v| v.is_null()).unwrap_or(true)
408 && (message.get("result").is_some() || message.get("error").is_some())
409 && message.get("method").is_none()
410 {
411 Self::handle_response(state, message).await;
412 return;
413 }
414
415 if message.get("method").is_some() {
417 if let Ok(request) = serde_json::from_value::<JsonRpcRequest>(message) {
418 if request.is_notification() {
419 Self::handle_notification(state, &request).await;
420 } else {
421 Self::handle_request(state, &request).await;
422 }
423 }
424 }
425 }
426
427 async fn handle_response(state: &SharedState<T>, message: Value) {
429 let response: JsonRpcResponse = match serde_json::from_value(message) {
431 Ok(r) => r,
432 Err(_) => return,
433 };
434
435 let id = match &response.id {
437 Some(JsonRpcId::Num(n)) => *n,
438 _ => return, };
440
441 let pending_req = {
443 let mut pending = state.pending_requests.write().await;
444 pending.remove(&id)
445 };
446
447 if let Some(req) = pending_req {
448 let result = if let Some(error) = response.error {
449 Err(error)
450 } else {
451 Ok(response.result.unwrap_or(Value::Null))
452 };
453 let _ = req.sender.send(result);
454 }
455 }
456
457 async fn handle_notification(state: &SharedState<T>, request: &JsonRpcRequest) {
459 let handler = state.notification_handler.read().await;
460 if let Some(handler) = handler.as_ref() {
461 let params = request.params.as_ref().unwrap_or(&Value::Null);
462 handler(&request.method, params);
463 }
464 }
465
466 async fn handle_request(state: &SharedState<T>, request: &JsonRpcRequest) {
468 let id = match &request.id {
469 Some(id) => id.clone(),
470 None => return, };
472
473 let handler = state.request_handler.read().await;
474 let params = request.params.as_ref().unwrap_or(&Value::Null);
475
476 let response = if let Some(handler) = handler.as_ref() {
477 match handler(&request.method, params).await {
479 Ok(result) => JsonRpcResponse::success(id, result),
480 Err(error) => JsonRpcResponse::error(id, error),
481 }
482 } else {
483 JsonRpcResponse::error(
485 id,
486 JsonRpcError::new(
487 JsonRpcError::METHOD_NOT_FOUND,
488 format!("Method not found: {}", request.method),
489 ),
490 )
491 };
492
493 if let Ok(response_json) = serde_json::to_string(&response) {
495 let mut framer = state.framer.lock().await;
496 let _ = framer.write_message(&response_json).await;
497 }
498 }
499}
500
501struct StdioSharedState {
507 writer: Mutex<MessageWriter<tokio::process::ChildStdin>>,
508 running: AtomicBool,
509 pending_requests: RwLock<HashMap<i64, PendingRequest>>,
510 notification_handler: RwLock<Option<NotificationHandler>>,
511 request_handler: RwLock<Option<RequestHandler>>,
512}
513
514pub struct StdioJsonRpcClient {
519 state: Arc<StdioSharedState>,
520 reader: Mutex<Option<MessageReader<tokio::process::ChildStdout>>>,
521 next_id: AtomicI64,
522 shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
523}
524
525impl StdioJsonRpcClient {
526 pub fn new(transport: StdioTransport) -> Self {
528 let (writer, reader) = transport.split();
529 Self {
530 state: Arc::new(StdioSharedState {
531 writer: Mutex::new(MessageWriter::new(writer)),
532 running: AtomicBool::new(false),
533 pending_requests: RwLock::new(HashMap::new()),
534 notification_handler: RwLock::new(None),
535 request_handler: RwLock::new(None),
536 }),
537 reader: Mutex::new(Some(MessageReader::new(reader))),
538 next_id: AtomicI64::new(1),
539 shutdown_tx: Mutex::new(None),
540 }
541 }
542
543 pub async fn start(&self) -> Result<()> {
545 let reader = self.reader.lock().await.take().ok_or_else(|| {
546 CopilotError::InvalidConfig("Reader already taken or client already started".into())
547 })?;
548 self.start_with_reader(reader).await
549 }
550
551 async fn start_with_reader(
553 &self,
554 mut reader: MessageReader<tokio::process::ChildStdout>,
555 ) -> Result<()> {
556 if self.state.running.swap(true, Ordering::SeqCst) {
557 return Ok(()); }
559
560 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
561 *self.shutdown_tx.lock().await = Some(shutdown_tx);
562
563 let state = Arc::clone(&self.state);
565
566 tokio::spawn(async move {
567 loop {
568 tokio::select! {
569 _ = shutdown_rx.recv() => {
570 break;
571 }
572 result = reader.read_message() => {
573 match result {
574 Ok(message_str) => {
575 if let Ok(message) = serde_json::from_str::<Value>(&message_str) {
576 Self::dispatch_message(&state, message).await;
577 }
578 }
579 Err(CopilotError::ConnectionClosed) => {
580 state.running.store(false, Ordering::SeqCst);
581 let mut pending = state.pending_requests.write().await;
583 for (_, req) in pending.drain() {
584 let _ = req.sender.send(Err(JsonRpcError::new(
585 -32801,
586 "Connection closed",
587 )));
588 }
589 break;
590 }
591 Err(_) => {
592 if !state.running.load(Ordering::SeqCst) {
594 break;
595 }
596 }
597 }
598 }
599 }
600 }
601 });
602
603 Ok(())
604 }
605
606 pub async fn stop(&self) {
608 self.state.running.store(false, Ordering::SeqCst);
609
610 if let Some(tx) = self.shutdown_tx.lock().await.take() {
612 let _ = tx.send(()).await;
613 }
614
615 let mut pending = self.state.pending_requests.write().await;
617 for (_, req) in pending.drain() {
618 let _ = req
619 .sender
620 .send(Err(JsonRpcError::new(-32801, "Connection closed")));
621 }
622 }
623
624 pub fn is_running(&self) -> bool {
626 self.state.running.load(Ordering::SeqCst)
627 }
628
629 pub async fn set_notification_handler<F>(&self, handler: F)
631 where
632 F: Fn(&str, &Value) + Send + Sync + 'static,
633 {
634 *self.state.notification_handler.write().await = Some(Arc::new(handler));
635 }
636
637 pub async fn set_request_handler<F>(&self, handler: F)
639 where
640 F: Fn(&str, &Value) -> RequestHandlerFuture + Send + Sync + 'static,
641 {
642 *self.state.request_handler.write().await = Some(Arc::new(handler));
643 }
644
645 pub async fn invoke(&self, method: &str, params: Option<Value>) -> Result<Value> {
647 self.invoke_with_timeout(method, params, Duration::from_secs(30))
648 .await
649 }
650
651 pub async fn invoke_with_timeout(
653 &self,
654 method: &str,
655 params: Option<Value>,
656 timeout: Duration,
657 ) -> Result<Value> {
658 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
659
660 let (tx, rx) = oneshot::channel();
662
663 {
665 let mut pending = self.state.pending_requests.write().await;
666 pending.insert(id, PendingRequest { sender: tx });
667 }
668
669 let request = JsonRpcRequest::new(method, params, Some(JsonRpcId::Num(id)));
671 let request_json = serde_json::to_string(&request)?;
672
673 if let Err(e) = self.send_raw(&request_json).await {
674 self.state.pending_requests.write().await.remove(&id);
676 return Err(e);
677 }
678
679 match tokio::time::timeout(timeout, rx).await {
681 Ok(Ok(Ok(result))) => Ok(result),
682 Ok(Ok(Err(rpc_error))) => Err(CopilotError::JsonRpc {
683 code: rpc_error.code,
684 message: rpc_error.message,
685 data: rpc_error.data,
686 }),
687 Ok(Err(_)) => {
688 self.state.pending_requests.write().await.remove(&id);
690 Err(CopilotError::ConnectionClosed)
691 }
692 Err(_) => {
693 self.state.pending_requests.write().await.remove(&id);
695 Err(CopilotError::Timeout(timeout))
696 }
697 }
698 }
699
700 pub async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
702 let request = JsonRpcRequest::notification(method, params);
703 let request_json = serde_json::to_string(&request)?;
704 self.send_raw(&request_json).await
705 }
706
707 async fn send_raw(&self, message: &str) -> Result<()> {
709 let mut writer = self.state.writer.lock().await;
710 writer.write_message(message).await
711 }
712
713 async fn dispatch_message(state: &StdioSharedState, message: Value) {
715 if message.get("id").is_some()
717 && !message.get("id").map(|v| v.is_null()).unwrap_or(true)
718 && (message.get("result").is_some() || message.get("error").is_some())
719 && message.get("method").is_none()
720 {
721 Self::handle_response(state, message).await;
722 return;
723 }
724
725 if message.get("method").is_some() {
727 if let Ok(request) = serde_json::from_value::<JsonRpcRequest>(message) {
728 if request.is_notification() {
729 Self::handle_notification(state, &request).await;
730 } else {
731 Self::handle_request(state, &request).await;
732 }
733 }
734 }
735 }
736
737 async fn handle_response(state: &StdioSharedState, message: Value) {
739 let response: JsonRpcResponse = match serde_json::from_value(message) {
741 Ok(r) => r,
742 Err(_) => return,
743 };
744
745 let id = match &response.id {
747 Some(JsonRpcId::Num(n)) => *n,
748 _ => return,
749 };
750
751 let pending_req = {
753 let mut pending = state.pending_requests.write().await;
754 pending.remove(&id)
755 };
756
757 if let Some(req) = pending_req {
758 let result = if let Some(error) = response.error {
759 Err(error)
760 } else {
761 Ok(response.result.unwrap_or(Value::Null))
762 };
763 let _ = req.sender.send(result);
764 }
765 }
766
767 async fn handle_notification(state: &StdioSharedState, request: &JsonRpcRequest) {
769 let handler = state.notification_handler.read().await;
770 if let Some(handler) = handler.as_ref() {
771 let params = request.params.as_ref().unwrap_or(&Value::Null);
772 handler(&request.method, params);
773 }
774 }
775
776 async fn handle_request(state: &StdioSharedState, request: &JsonRpcRequest) {
778 let id = match &request.id {
779 Some(id) => id.clone(),
780 None => return,
781 };
782
783 let handler = state.request_handler.read().await;
784 let params = request.params.as_ref().unwrap_or(&Value::Null);
785
786 let response = if let Some(handler) = handler.as_ref() {
787 match handler(&request.method, params).await {
789 Ok(result) => JsonRpcResponse::success(id.clone(), result),
790 Err(error) => JsonRpcResponse::error(id.clone(), error),
791 }
792 } else {
793 JsonRpcResponse::error(
794 id.clone(),
795 JsonRpcError::new(
796 JsonRpcError::METHOD_NOT_FOUND,
797 format!("Method not found: {}", request.method),
798 ),
799 )
800 };
801
802 if let Ok(response_json) = serde_json::to_string(&response) {
804 let mut writer = state.writer.lock().await;
805 let _ = writer.write_message(&response_json).await;
806 }
807 }
808}
809
810struct TcpSharedState {
816 writer: Mutex<MessageWriter<OwnedWriteHalf>>,
817 running: AtomicBool,
818 pending_requests: RwLock<HashMap<i64, PendingRequest>>,
819 notification_handler: RwLock<Option<NotificationHandler>>,
820 request_handler: RwLock<Option<RequestHandler>>,
821}
822
823pub struct TcpJsonRpcClient {
825 state: Arc<TcpSharedState>,
826 reader: Mutex<Option<MessageReader<OwnedReadHalf>>>,
827 next_id: AtomicI64,
828 shutdown_tx: Mutex<Option<mpsc::Sender<()>>>,
829}
830
831impl TcpJsonRpcClient {
832 pub async fn connect(addr: impl AsRef<str>) -> Result<Self> {
834 let stream = TcpStream::connect(addr.as_ref())
835 .await
836 .map_err(CopilotError::Transport)?;
837 Ok(Self::new(stream))
838 }
839
840 pub fn new(stream: TcpStream) -> Self {
842 let (reader, writer) = stream.into_split();
843 Self {
844 state: Arc::new(TcpSharedState {
845 writer: Mutex::new(MessageWriter::new(writer)),
846 running: AtomicBool::new(false),
847 pending_requests: RwLock::new(HashMap::new()),
848 notification_handler: RwLock::new(None),
849 request_handler: RwLock::new(None),
850 }),
851 reader: Mutex::new(Some(MessageReader::new(reader))),
852 next_id: AtomicI64::new(1),
853 shutdown_tx: Mutex::new(None),
854 }
855 }
856
857 pub async fn start(&self) -> Result<()> {
859 let reader = self.reader.lock().await.take().ok_or_else(|| {
860 CopilotError::InvalidConfig("Reader already taken or client already started".into())
861 })?;
862 self.start_with_reader(reader).await
863 }
864
865 async fn start_with_reader(&self, mut reader: MessageReader<OwnedReadHalf>) -> Result<()> {
866 if self.state.running.swap(true, Ordering::SeqCst) {
867 return Ok(()); }
869
870 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
871 *self.shutdown_tx.lock().await = Some(shutdown_tx);
872
873 let state = Arc::clone(&self.state);
874
875 tokio::spawn(async move {
876 loop {
877 tokio::select! {
878 _ = shutdown_rx.recv() => {
879 break;
880 }
881 result = reader.read_message() => {
882 match result {
883 Ok(message_str) => {
884 if let Ok(message) = serde_json::from_str::<Value>(&message_str) {
885 Self::dispatch_message(&state, message).await;
886 }
887 }
888 Err(CopilotError::ConnectionClosed) => {
889 state.running.store(false, Ordering::SeqCst);
890 let mut pending = state.pending_requests.write().await;
891 for (_, req) in pending.drain() {
892 let _ = req.sender.send(Err(JsonRpcError::new(
893 -32801,
894 "Connection closed",
895 )));
896 }
897 break;
898 }
899 Err(_) => {
900 if !state.running.load(Ordering::SeqCst) {
901 break;
902 }
903 }
904 }
905 }
906 }
907 }
908 });
909
910 Ok(())
911 }
912
913 pub async fn stop(&self) {
915 self.state.running.store(false, Ordering::SeqCst);
916
917 if let Some(tx) = self.shutdown_tx.lock().await.take() {
918 let _ = tx.send(()).await;
919 }
920
921 let mut pending = self.state.pending_requests.write().await;
922 for (_, req) in pending.drain() {
923 let _ = req
924 .sender
925 .send(Err(JsonRpcError::new(-32801, "Connection closed")));
926 }
927 }
928
929 pub fn is_running(&self) -> bool {
931 self.state.running.load(Ordering::SeqCst)
932 }
933
934 pub async fn set_notification_handler<F>(&self, handler: F)
936 where
937 F: Fn(&str, &Value) + Send + Sync + 'static,
938 {
939 *self.state.notification_handler.write().await = Some(Arc::new(handler));
940 }
941
942 pub async fn set_request_handler<F>(&self, handler: F)
944 where
945 F: Fn(&str, &Value) -> RequestHandlerFuture + Send + Sync + 'static,
946 {
947 *self.state.request_handler.write().await = Some(Arc::new(handler));
948 }
949
950 pub async fn invoke(&self, method: &str, params: Option<Value>) -> Result<Value> {
952 self.invoke_with_timeout(method, params, Duration::from_secs(30))
953 .await
954 }
955
956 pub async fn invoke_with_timeout(
958 &self,
959 method: &str,
960 params: Option<Value>,
961 timeout: Duration,
962 ) -> Result<Value> {
963 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
964
965 let (tx, rx) = oneshot::channel();
966 {
967 let mut pending = self.state.pending_requests.write().await;
968 pending.insert(id, PendingRequest { sender: tx });
969 }
970
971 let request = JsonRpcRequest::new(method, params, Some(JsonRpcId::Num(id)));
972 let request_json = serde_json::to_string(&request)?;
973
974 if let Err(e) = self.send_raw(&request_json).await {
975 self.state.pending_requests.write().await.remove(&id);
976 return Err(e);
977 }
978
979 match tokio::time::timeout(timeout, rx).await {
980 Ok(Ok(Ok(result))) => Ok(result),
981 Ok(Ok(Err(rpc_error))) => Err(CopilotError::JsonRpc {
982 code: rpc_error.code,
983 message: rpc_error.message,
984 data: rpc_error.data,
985 }),
986 Ok(Err(_)) => {
987 self.state.pending_requests.write().await.remove(&id);
988 Err(CopilotError::ConnectionClosed)
989 }
990 Err(_) => {
991 self.state.pending_requests.write().await.remove(&id);
992 Err(CopilotError::Timeout(timeout))
993 }
994 }
995 }
996
997 pub async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
999 let request = JsonRpcRequest::notification(method, params);
1000 let request_json = serde_json::to_string(&request)?;
1001 self.send_raw(&request_json).await
1002 }
1003
1004 async fn send_raw(&self, message: &str) -> Result<()> {
1005 let mut writer = self.state.writer.lock().await;
1006 writer.write_message(message).await
1007 }
1008
1009 async fn dispatch_message(state: &TcpSharedState, message: Value) {
1010 if message.get("id").is_some()
1011 && !message.get("id").map(|v| v.is_null()).unwrap_or(true)
1012 && (message.get("result").is_some() || message.get("error").is_some())
1013 && message.get("method").is_none()
1014 {
1015 Self::handle_response(state, message).await;
1016 return;
1017 }
1018
1019 if message.get("method").is_some() {
1020 if let Ok(request) = serde_json::from_value::<JsonRpcRequest>(message) {
1021 if request.is_notification() {
1022 Self::handle_notification(state, &request).await;
1023 } else {
1024 Self::handle_request(state, &request).await;
1025 }
1026 }
1027 }
1028 }
1029
1030 async fn handle_response(state: &TcpSharedState, message: Value) {
1031 let response: JsonRpcResponse = match serde_json::from_value(message) {
1032 Ok(r) => r,
1033 Err(_) => return,
1034 };
1035
1036 let id = match &response.id {
1037 Some(JsonRpcId::Num(n)) => *n,
1038 _ => return,
1039 };
1040
1041 let pending_req = {
1042 let mut pending = state.pending_requests.write().await;
1043 pending.remove(&id)
1044 };
1045
1046 if let Some(req) = pending_req {
1047 let result = if let Some(error) = response.error {
1048 Err(error)
1049 } else {
1050 Ok(response.result.unwrap_or(Value::Null))
1051 };
1052 let _ = req.sender.send(result);
1053 }
1054 }
1055
1056 async fn handle_notification(state: &TcpSharedState, request: &JsonRpcRequest) {
1057 let handler = state.notification_handler.read().await;
1058 if let Some(handler) = handler.as_ref() {
1059 let params = request.params.as_ref().unwrap_or(&Value::Null);
1060 handler(&request.method, params);
1061 }
1062 }
1063
1064 async fn handle_request(state: &TcpSharedState, request: &JsonRpcRequest) {
1065 let id = match &request.id {
1066 Some(id) => id.clone(),
1067 None => return,
1068 };
1069
1070 let handler = state.request_handler.read().await;
1071 let params = request.params.as_ref().unwrap_or(&Value::Null);
1072
1073 let response = if let Some(handler) = handler.as_ref() {
1074 match handler(&request.method, params).await {
1075 Ok(result) => JsonRpcResponse::success(id.clone(), result),
1076 Err(error) => JsonRpcResponse::error(id.clone(), error),
1077 }
1078 } else {
1079 JsonRpcResponse::error(
1080 id.clone(),
1081 JsonRpcError::new(
1082 JsonRpcError::METHOD_NOT_FOUND,
1083 format!("Method not found: {}", request.method),
1084 ),
1085 )
1086 };
1087
1088 if let Ok(response_json) = serde_json::to_string(&response) {
1089 let mut writer = state.writer.lock().await;
1090 let _ = writer.write_message(&response_json).await;
1091 }
1092 }
1093}
1094
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::MemoryTransport;
    use serde_json::json;

    /// Frames every message into an in-memory transport, then reads them back
    /// from the written bytes and checks each one comes back unchanged.
    async fn assert_framing_roundtrip(messages: &[String]) {
        let mut outbound = MessageFramer::new(MemoryTransport::new(Vec::new()));
        for message in messages {
            outbound.write_message(message).await.unwrap();
        }

        let bytes = outbound.transport().written_data().to_vec();
        let mut inbound = MessageFramer::new(MemoryTransport::new(bytes));
        for expected in messages {
            let actual = inbound.read_message().await.unwrap();
            assert_eq!(*expected, actual);
        }
    }

    #[test]
    fn test_json_rpc_request_serialization() {
        let params = json!({"key": "value"});
        let request = JsonRpcRequest::new("test_method", Some(params), Some(JsonRpcId::Num(1)));

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["jsonrpc"], "2.0");
        assert_eq!(encoded["method"], "test_method");
        assert_eq!(encoded["params"]["key"], "value");
        assert_eq!(encoded["id"], 1);
    }

    #[test]
    fn test_json_rpc_notification_serialization() {
        let request = JsonRpcRequest::notification("notify_method", Some(json!([1, 2, 3])));

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["jsonrpc"], "2.0");
        assert_eq!(encoded["method"], "notify_method");
        assert!(encoded.get("id").is_none());
    }

    #[test]
    fn test_json_rpc_response_success() {
        let response = JsonRpcResponse::success(JsonRpcId::Num(1), json!({"result": "ok"}));

        let encoded = serde_json::to_value(&response).unwrap();
        assert_eq!(encoded["jsonrpc"], "2.0");
        assert_eq!(encoded["id"], 1);
        assert_eq!(encoded["result"]["result"], "ok");
        assert!(encoded.get("error").is_none());
    }

    #[test]
    fn test_json_rpc_response_error() {
        let failure = JsonRpcError::new(-32600, "Invalid Request");
        let response = JsonRpcResponse::error(JsonRpcId::Num(1), failure);

        let encoded = serde_json::to_value(&response).unwrap();
        assert_eq!(encoded["jsonrpc"], "2.0");
        assert_eq!(encoded["id"], 1);
        assert_eq!(encoded["error"]["code"], -32600);
        assert_eq!(encoded["error"]["message"], "Invalid Request");
    }

    #[test]
    fn test_json_rpc_id_from_i64() {
        let id: JsonRpcId = 42i64.into();
        assert_eq!(id, JsonRpcId::Num(42));
    }

    #[test]
    fn test_json_rpc_id_from_string() {
        let id: JsonRpcId = "test-id".into();
        assert_eq!(id, JsonRpcId::Str("test-id".to_string()));
    }

    #[test]
    fn test_json_rpc_error_constants() {
        assert_eq!(JsonRpcError::PARSE_ERROR, -32700);
        assert_eq!(JsonRpcError::INVALID_REQUEST, -32600);
        assert_eq!(JsonRpcError::METHOD_NOT_FOUND, -32601);
        assert_eq!(JsonRpcError::INVALID_PARAMS, -32602);
        assert_eq!(JsonRpcError::INTERNAL_ERROR, -32603);
    }

    #[test]
    fn test_request_is_notification() {
        let without_id = JsonRpcRequest::notification("method", None);
        assert!(without_id.is_notification());

        let with_id = JsonRpcRequest::new("method", None, Some(JsonRpcId::Num(1)));
        assert!(!with_id.is_notification());
    }

    #[tokio::test]
    async fn test_large_payload_64kb_boundary() {
        let large_data = "x".repeat(65536 - 50);
        let msg = json!({"jsonrpc": "2.0", "method": "test", "params": {"data": large_data}});
        let msg_str = serde_json::to_string(&msg).unwrap();

        assert_framing_roundtrip(&[msg_str]).await;
    }

    #[tokio::test]
    async fn test_large_payload_100kb() {
        let large_data = "y".repeat(100_000);
        let msg = json!({"jsonrpc": "2.0", "method": "test", "params": {"data": large_data}});
        let msg_str = serde_json::to_string(&msg).unwrap();

        assert_framing_roundtrip(&[msg_str]).await;
    }

    #[tokio::test]
    async fn test_multiple_large_messages_sequential() {
        let msg1_data = "a".repeat(50_000);
        let msg2_data = "b".repeat(80_000);
        let msg1 = json!({"jsonrpc": "2.0", "id": 1, "method": "test1", "params": {"data": msg1_data}});
        let msg2 = json!({"jsonrpc": "2.0", "id": 2, "method": "test2", "params": {"data": msg2_data}});
        let msg1_str = serde_json::to_string(&msg1).unwrap();
        let msg2_str = serde_json::to_string(&msg2).unwrap();

        // Both messages are written before either is read back.
        assert_framing_roundtrip(&[msg1_str, msg2_str]).await;
    }
}