1use std::sync::Arc;
2
3use derive_more::{Display, From};
4use schemars::JsonSchema;
5use serde::{Deserialize, Serialize, de::DeserializeOwned};
6use serde_json::value::RawValue;
7
8use crate::{
9 AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
10 ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
11};
12
/// Identifier that correlates a JSON-RPC request with its response.
///
/// Serialized untagged, so on the wire it is a bare JSON `null`, integer, or
/// string. Because of `#[serde(untagged)]`, deserialization tries the variants
/// in declaration order, and the derived `PartialOrd`/`Ord` also follow that
/// order, so the variant order below is significant.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Hash,
    Eq,
    Deserialize,
    Serialize,
    PartialOrd,
    Ord,
    Display,
    JsonSchema,
    From,
)]
#[serde(untagged)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
// Conversions are generated from `String` and `i64` (see `Response::new`,
// which accepts `impl Into<RequestId>`).
#[from(String, i64)]
pub enum RequestId {
    /// The JSON `null` id; its `Display` output is the literal `null`.
    #[display("null")]
    Null,
    /// An integer id, displayed as the number itself.
    Number(i64),
    /// A string id, displayed as the raw string (no quotes).
    Str(String),
}
48
/// A JSON-RPC request: a method call carrying an `id`, unlike a [`Notification`].
///
/// Field order here is the order in which fields are serialized.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Request<Params> {
    /// Identifier used to match this request with its [`Response`].
    pub id: RequestId,
    /// Name of the method being invoked.
    pub method: Arc<str>,
    /// Method parameters; the `params` key is left out entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Params>,
}
61
/// A JSON-RPC response: either a successful `result` or an `error`, each
/// tagged with the `id` of the request it answers.
///
/// Serialized untagged, so on the wire it is `{ "id": …, "result": … }` or
/// `{ "id": …, "error": … }`.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[serde(untagged)]
#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
// Untagged deserialization tries `Result` before `Error`.
// NOTE(review): if the `Result` type parameter can deserialize from a missing
// field (e.g. an `Option`), an error payload could be captured by the `Result`
// variant — confirm this cannot happen for the concrete response types used.
pub enum Response<Result> {
    /// Successful completion of request `id`.
    Result { id: RequestId, result: Result },
    /// Failure of request `id`.
    Error { id: RequestId, error: Error },
}
73
74impl<R> Response<R> {
75 #[must_use]
76 pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
77 match result {
78 Ok(result) => Self::Result {
79 id: id.into(),
80 result,
81 },
82 Err(error) => Self::Error {
83 id: id.into(),
84 error,
85 },
86 }
87 }
88}
89
/// A JSON-RPC notification: a method call with no `id`, so no response is
/// associated with it.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Notification<Params> {
    /// Name of the notification method.
    pub method: Arc<str>,
    /// Notification parameters; the `params` key is left out entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Params>,
}
101
/// A message sent from the `Local` side to the `Remote` side.
///
/// Requests and notifications carry the *remote* side's incoming types (what
/// the peer will decode), while responses carry the *local* side's outgoing
/// response type. Serialized untagged, so each variant appears on the wire as
/// the bare request/response/notification object.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[serde(untagged)]
#[schemars(inline)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
// Variant order matters for untagged deserialization: `Request` (requires
// `id` and `method`) is tried before `Response` and `Notification`.
pub enum OutgoingMessage<Local: Side, Remote: Side> {
    /// A request for the remote side to handle.
    Request(Request<Remote::InRequest>),
    /// This side's answer to a request it received.
    Response(Response<Local::OutResponse>),
    /// A notification for the remote side.
    Notification(Notification<Remote::InNotification>),
}
114
/// The `jsonrpc` protocol version marker.
///
/// Its only variant serializes as the string `"2.0"`; no other value
/// deserializes successfully.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
enum JsonRpcVersion {
    /// JSON-RPC 2.0, written on the wire as `"2.0"`.
    #[serde(rename = "2.0")]
    V2,
}
121
/// JSON-RPC 2.0 envelope around a message.
///
/// Serializes as a `"jsonrpc": "2.0"` member followed by the fields of
/// `message` flattened into the same object.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
pub struct JsonRpcMessage<M> {
    /// Protocol version marker; always `"2.0"`.
    jsonrpc: JsonRpcVersion,
    /// The wrapped message (e.g. an `OutgoingMessage`), flattened alongside `jsonrpc`.
    #[serde(flatten)]
    message: M,
}
133
134impl<M> JsonRpcMessage<M> {
135 #[must_use]
138 pub fn wrap(message: M) -> Self {
139 Self {
140 jsonrpc: JsonRpcVersion::V2,
141 message,
142 }
143 }
144}
145
/// One end of a JSON-RPC connection, described by the message types it
/// receives and the responses it sends.
pub trait Side: Clone {
    /// Requests this side receives from its peer.
    type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Notifications this side receives from its peer.
    type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
    /// Responses this side sends back to its peer.
    type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;

    /// Decodes an incoming request from its `method` name and raw JSON `params`.
    ///
    /// # Errors
    ///
    /// Returns an error when the method or its parameters cannot be decoded.
    fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;

    /// Decodes an incoming notification from its `method` name and raw JSON `params`.
    ///
    /// # Errors
    ///
    /// Returns an error when the method or its parameters cannot be decoded.
    fn decode_notification(method: &str, params: Option<&RawValue>)
    -> Result<Self::InNotification>;
}
170
/// Marker type for the client end of the connection.
///
/// Used as a [`Side`] parameter: the client receives `AgentRequest` and
/// `AgentNotification` messages and answers with `ClientResponse`.
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct ClientSide;
180
181impl Side for ClientSide {
182 type InRequest = AgentRequest;
183 type InNotification = AgentNotification;
184 type OutResponse = ClientResponse;
185
186 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
187 let params = params.ok_or_else(Error::invalid_params)?;
188
189 match method {
190 m if m == CLIENT_METHOD_NAMES.session_request_permission => {
191 serde_json::from_str(params.get())
192 .map(AgentRequest::RequestPermissionRequest)
193 .map_err(Into::into)
194 }
195 m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
196 .map(AgentRequest::WriteTextFileRequest)
197 .map_err(Into::into),
198 m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
199 .map(AgentRequest::ReadTextFileRequest)
200 .map_err(Into::into),
201 m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
202 .map(AgentRequest::CreateTerminalRequest)
203 .map_err(Into::into),
204 m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
205 .map(AgentRequest::TerminalOutputRequest)
206 .map_err(Into::into),
207 m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
208 .map(AgentRequest::KillTerminalRequest)
209 .map_err(Into::into),
210 m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
211 .map(AgentRequest::ReleaseTerminalRequest)
212 .map_err(Into::into),
213 m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
214 serde_json::from_str(params.get())
215 .map(AgentRequest::WaitForTerminalExitRequest)
216 .map_err(Into::into)
217 }
218 #[cfg(feature = "unstable_elicitation")]
219 m if m == CLIENT_METHOD_NAMES.session_elicitation => serde_json::from_str(params.get())
220 .map(AgentRequest::ElicitationRequest)
221 .map_err(Into::into),
222 _ => {
223 if let Some(custom_method) = method.strip_prefix('_') {
224 Ok(AgentRequest::ExtMethodRequest(ExtRequest {
225 method: custom_method.into(),
226 params: params.to_owned().into(),
227 }))
228 } else {
229 Err(Error::method_not_found())
230 }
231 }
232 }
233 }
234
235 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
236 let params = params.ok_or_else(Error::invalid_params)?;
237
238 match method {
239 m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
240 .map(AgentNotification::SessionNotification)
241 .map_err(Into::into),
242 #[cfg(feature = "unstable_elicitation")]
243 m if m == CLIENT_METHOD_NAMES.session_elicitation_complete => {
244 serde_json::from_str(params.get())
245 .map(AgentNotification::ElicitationCompleteNotification)
246 .map_err(Into::into)
247 }
248 _ => {
249 if let Some(custom_method) = method.strip_prefix('_') {
250 Ok(AgentNotification::ExtNotification(ExtNotification {
251 method: custom_method.into(),
252 params: params.to_owned().into(),
253 }))
254 } else {
255 Err(Error::method_not_found())
256 }
257 }
258 }
259 }
260}
261
/// Marker type for the agent end of the connection.
///
/// Used as a [`Side`] parameter: the agent receives `ClientRequest` and
/// `ClientNotification` messages and answers with `AgentResponse`.
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct AgentSide;
271
272impl Side for AgentSide {
273 type InRequest = ClientRequest;
274 type InNotification = ClientNotification;
275 type OutResponse = AgentResponse;
276
277 fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
278 let params = params.ok_or_else(Error::invalid_params)?;
279
280 match method {
281 m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
282 .map(ClientRequest::InitializeRequest)
283 .map_err(Into::into),
284 m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
285 .map(ClientRequest::AuthenticateRequest)
286 .map_err(Into::into),
287 #[cfg(feature = "unstable_logout")]
288 m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
289 .map(ClientRequest::LogoutRequest)
290 .map_err(Into::into),
291 m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
292 .map(ClientRequest::NewSessionRequest)
293 .map_err(Into::into),
294 m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
295 .map(ClientRequest::LoadSessionRequest)
296 .map_err(Into::into),
297 m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
298 .map(ClientRequest::ListSessionsRequest)
299 .map_err(Into::into),
300 #[cfg(feature = "unstable_session_fork")]
301 m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
302 .map(ClientRequest::ForkSessionRequest)
303 .map_err(Into::into),
304 #[cfg(feature = "unstable_session_resume")]
305 m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
306 .map(ClientRequest::ResumeSessionRequest)
307 .map_err(Into::into),
308 #[cfg(feature = "unstable_session_close")]
309 m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
310 .map(ClientRequest::CloseSessionRequest)
311 .map_err(Into::into),
312 m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
313 .map(ClientRequest::SetSessionModeRequest)
314 .map_err(Into::into),
315 m if m == AGENT_METHOD_NAMES.session_set_config_option => {
316 serde_json::from_str(params.get())
317 .map(ClientRequest::SetSessionConfigOptionRequest)
318 .map_err(Into::into)
319 }
320 #[cfg(feature = "unstable_session_model")]
321 m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
322 .map(ClientRequest::SetSessionModelRequest)
323 .map_err(Into::into),
324 m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
325 .map(ClientRequest::PromptRequest)
326 .map_err(Into::into),
327 #[cfg(feature = "unstable_nes")]
328 m if m == AGENT_METHOD_NAMES.nes_start => serde_json::from_str(params.get())
329 .map(ClientRequest::StartNesRequest)
330 .map_err(Into::into),
331 #[cfg(feature = "unstable_nes")]
332 m if m == AGENT_METHOD_NAMES.nes_suggest => serde_json::from_str(params.get())
333 .map(ClientRequest::SuggestNesRequest)
334 .map_err(Into::into),
335 #[cfg(feature = "unstable_nes")]
336 m if m == AGENT_METHOD_NAMES.nes_close => serde_json::from_str(params.get())
337 .map(ClientRequest::CloseNesRequest)
338 .map_err(Into::into),
339 _ => {
340 if let Some(custom_method) = method.strip_prefix('_') {
341 Ok(ClientRequest::ExtMethodRequest(ExtRequest {
342 method: custom_method.into(),
343 params: params.to_owned().into(),
344 }))
345 } else {
346 Err(Error::method_not_found())
347 }
348 }
349 }
350 }
351
352 fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
353 let params = params.ok_or_else(Error::invalid_params)?;
354
355 match method {
356 m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
357 .map(ClientNotification::CancelNotification)
358 .map_err(Into::into),
359 #[cfg(feature = "unstable_nes")]
360 m if m == AGENT_METHOD_NAMES.document_did_open => serde_json::from_str(params.get())
361 .map(ClientNotification::DidOpenDocumentNotification)
362 .map_err(Into::into),
363 #[cfg(feature = "unstable_nes")]
364 m if m == AGENT_METHOD_NAMES.document_did_change => serde_json::from_str(params.get())
365 .map(ClientNotification::DidChangeDocumentNotification)
366 .map_err(Into::into),
367 #[cfg(feature = "unstable_nes")]
368 m if m == AGENT_METHOD_NAMES.document_did_close => serde_json::from_str(params.get())
369 .map(ClientNotification::DidCloseDocumentNotification)
370 .map_err(Into::into),
371 #[cfg(feature = "unstable_nes")]
372 m if m == AGENT_METHOD_NAMES.document_did_save => serde_json::from_str(params.get())
373 .map(ClientNotification::DidSaveDocumentNotification)
374 .map_err(Into::into),
375 #[cfg(feature = "unstable_nes")]
376 m if m == AGENT_METHOD_NAMES.document_did_focus => serde_json::from_str(params.get())
377 .map(ClientNotification::DidFocusDocumentNotification)
378 .map_err(Into::into),
379 #[cfg(feature = "unstable_nes")]
380 m if m == AGENT_METHOD_NAMES.nes_accept => serde_json::from_str(params.get())
381 .map(ClientNotification::AcceptNesNotification)
382 .map_err(Into::into),
383 #[cfg(feature = "unstable_nes")]
384 m if m == AGENT_METHOD_NAMES.nes_reject => serde_json::from_str(params.get())
385 .map(ClientNotification::RejectNesNotification)
386 .map_err(Into::into),
387 _ => {
388 if let Some(custom_method) = method.strip_prefix('_') {
389 Ok(ClientNotification::ExtNotification(ExtNotification {
390 method: custom_method.into(),
391 params: params.to_owned().into(),
392 }))
393 } else {
394 Err(Error::method_not_found())
395 }
396 }
397 }
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{Number, Value};

    /// Pairs each JSON value with the `RequestId` it corresponds to, in both
    /// the serialize and the deserialize direction.
    fn id_cases() -> [(Value, RequestId); 4] {
        [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ]
    }

    #[test]
    fn id_deserialization() {
        for (json, expected) in id_cases() {
            let decoded = serde_json::from_value::<RequestId>(json).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn id_serialization() {
        for (expected, id) in id_cases() {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, expected) in cases {
            assert_eq!(id.to_string(), expected);
        }
    }
}
454
#[cfg(feature = "unstable_nes")]
#[cfg(test)]
mod nes_rpc_tests {
    use super::*;
    use serde_json::json;

    /// Serializes `value` to JSON text and wraps it as raw `params`, the form
    /// in which `decode_request` / `decode_notification` receive them.
    fn raw_params(value: &serde_json::Value) -> Box<serde_json::value::RawValue> {
        let text = serde_json::to_string(value).unwrap();
        serde_json::value::RawValue::from_string(text).unwrap()
    }

    #[test]
    fn test_decode_nes_start_request() {
        let raw = raw_params(&json!({
            "workspaceUri": "file:///Users/alice/projects/my-app",
            "workspaceFolders": [
                { "uri": "file:///Users/alice/projects/my-app", "name": "my-app" }
            ]
        }));
        let decoded = AgentSide::decode_request("nes/start", Some(&raw)).unwrap();
        assert!(matches!(decoded, ClientRequest::StartNesRequest(_)));
    }

    #[test]
    fn test_decode_nes_suggest_request() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "uri": "file:///path/to/file.rs",
            "version": 2,
            "position": { "line": 5, "character": 12 },
            "triggerKind": "automatic"
        }));
        let decoded = AgentSide::decode_request("nes/suggest", Some(&raw)).unwrap();
        assert!(matches!(decoded, ClientRequest::SuggestNesRequest(_)));
    }

    #[test]
    fn test_decode_nes_close_request() {
        let raw = raw_params(&json!({
            "sessionId": "session_123"
        }));
        let decoded = AgentSide::decode_request("nes/close", Some(&raw)).unwrap();
        assert!(matches!(decoded, ClientRequest::CloseNesRequest(_)));
    }

    #[test]
    fn test_decode_document_did_open_notification() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "uri": "file:///path/to/file.rs",
            "languageId": "rust",
            "version": 1,
            "text": "fn main() {}"
        }));
        let decoded = AgentSide::decode_notification("document/didOpen", Some(&raw)).unwrap();
        assert!(matches!(
            decoded,
            ClientNotification::DidOpenDocumentNotification(_)
        ));
    }

    #[test]
    fn test_decode_document_did_change_notification() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "uri": "file:///path/to/file.rs",
            "version": 2,
            "contentChanges": [{ "text": "fn main() { let x = 1; }" }]
        }));
        let decoded = AgentSide::decode_notification("document/didChange", Some(&raw)).unwrap();
        assert!(matches!(
            decoded,
            ClientNotification::DidChangeDocumentNotification(_)
        ));
    }

    #[test]
    fn test_decode_document_did_close_notification() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "uri": "file:///path/to/file.rs"
        }));
        let decoded = AgentSide::decode_notification("document/didClose", Some(&raw)).unwrap();
        assert!(matches!(
            decoded,
            ClientNotification::DidCloseDocumentNotification(_)
        ));
    }

    #[test]
    fn test_decode_document_did_save_notification() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "uri": "file:///path/to/file.rs"
        }));
        let decoded = AgentSide::decode_notification("document/didSave", Some(&raw)).unwrap();
        assert!(matches!(
            decoded,
            ClientNotification::DidSaveDocumentNotification(_)
        ));
    }

    #[test]
    fn test_decode_document_did_focus_notification() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "uri": "file:///path/to/file.rs",
            "version": 2,
            "position": { "line": 5, "character": 12 },
            "visibleRange": {
                "start": { "line": 0, "character": 0 },
                "end": { "line": 45, "character": 0 }
            }
        }));
        let decoded = AgentSide::decode_notification("document/didFocus", Some(&raw)).unwrap();
        assert!(matches!(
            decoded,
            ClientNotification::DidFocusDocumentNotification(_)
        ));
    }

    #[test]
    fn test_decode_nes_accept_notification() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "id": "sugg_001"
        }));
        let decoded = AgentSide::decode_notification("nes/accept", Some(&raw)).unwrap();
        assert!(matches!(
            decoded,
            ClientNotification::AcceptNesNotification(_)
        ));
    }

    #[test]
    fn test_decode_nes_reject_notification() {
        let raw = raw_params(&json!({
            "sessionId": "session_123",
            "id": "sugg_001",
            "reason": "rejected"
        }));
        let decoded = AgentSide::decode_notification("nes/reject", Some(&raw)).unwrap();
        assert!(matches!(
            decoded,
            ClientNotification::RejectNesNotification(_)
        ));
    }
}
619
/// Pins the exact JSON emitted for outgoing notifications in both directions.
///
/// Covers the `"jsonrpc": "2.0"` envelope, the method name, and the params
/// object, in which fields set to `None` do not appear.
#[test]
fn test_notification_wire_format() {
    use super::*;

    use serde_json::{Value, json};

    // Client -> agent: the remote side is `AgentSide`, so params are a
    // `ClientNotification`.
    let cancel_params = ClientNotification::CancelNotification(CancelNotification {
        session_id: SessionId("test-123".into()),
        meta: None,
    });
    let cancel = JsonRpcMessage::wrap(OutgoingMessage::<ClientSide, AgentSide>::Notification(
        Notification {
            method: "cancel".into(),
            params: Some(cancel_params),
        },
    ));
    let expected_cancel = json!({
        "jsonrpc": "2.0",
        "method": "cancel",
        "params": {
            "sessionId": "test-123"
        },
    });
    let actual_cancel: Value = serde_json::to_value(&cancel).unwrap();
    assert_eq!(actual_cancel, expected_cancel);

    // Agent -> client: the remote side is `ClientSide`, so params are an
    // `AgentNotification`.
    let chunk = ContentChunk {
        content: ContentBlock::Text(TextContent {
            annotations: None,
            text: "Hello".to_string(),
            meta: None,
        }),
        #[cfg(feature = "unstable_message_id")]
        message_id: None,
        meta: None,
    };
    let update_params = AgentNotification::SessionNotification(SessionNotification {
        session_id: SessionId("test-456".into()),
        update: SessionUpdate::AgentMessageChunk(chunk),
        meta: None,
    });
    let update = JsonRpcMessage::wrap(OutgoingMessage::<AgentSide, ClientSide>::Notification(
        Notification {
            method: "sessionUpdate".into(),
            params: Some(update_params),
        },
    ));
    let expected_update = json!({
        "jsonrpc": "2.0",
        "method": "sessionUpdate",
        "params": {
            "sessionId": "test-456",
            "update": {
                "sessionUpdate": "agent_message_chunk",
                "content": {
                    "type": "text",
                    "text": "Hello"
                }
            }
        }
    });
    let actual_update: Value = serde_json::to_value(&update).unwrap();
    assert_eq!(actual_update, expected_update);
}