use std::sync::Arc;
use derive_more::{Display, From};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::value::RawValue;
use crate::{
AGENT_METHOD_NAMES, AgentNotification, AgentRequest, AgentResponse, CLIENT_METHOD_NAMES,
ClientNotification, ClientRequest, ClientResponse, Error, ExtNotification, ExtRequest, Result,
};
#[derive(
Debug,
PartialEq,
Clone,
Hash,
Eq,
Deserialize,
Serialize,
PartialOrd,
Ord,
Display,
JsonSchema,
From,
)]
#[serde(untagged)]
#[allow(
clippy::exhaustive_enums,
reason = "This comes from the JSON-RPC specification itself"
)]
#[from(String, i64)]
pub enum RequestId {
#[display("null")]
Null,
Number(i64),
Str(String),
}
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
clippy::exhaustive_structs,
reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Request<Params> {
pub id: RequestId,
pub method: Arc<str>,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Params>,
}
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
clippy::exhaustive_enums,
reason = "This comes from the JSON-RPC specification itself"
)]
#[serde(untagged)]
#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
pub enum Response<Result> {
Result { id: RequestId, result: Result },
Error { id: RequestId, error: Error },
}
impl<R> Response<R> {
#[must_use]
pub fn new(id: impl Into<RequestId>, result: Result<R>) -> Self {
match result {
Ok(result) => Self::Result {
id: id.into(),
result,
},
Err(error) => Self::Error {
id: id.into(),
error,
},
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
clippy::exhaustive_structs,
reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Notification<Params> {
pub method: Arc<str>,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Params>,
}
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[serde(untagged)]
#[schemars(inline)]
#[allow(
clippy::exhaustive_enums,
reason = "This comes from the JSON-RPC specification itself"
)]
pub enum OutgoingMessage<Local: Side, Remote: Side> {
Request(Request<Remote::InRequest>),
Response(Response<Local::OutResponse>),
Notification(Notification<Remote::InNotification>),
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
enum JsonRpcVersion {
#[serde(rename = "2.0")]
V2,
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
pub struct JsonRpcMessage<M> {
jsonrpc: JsonRpcVersion,
#[serde(flatten)]
message: M,
}
impl<M> JsonRpcMessage<M> {
#[must_use]
pub fn wrap(message: M) -> Self {
Self {
jsonrpc: JsonRpcVersion::V2,
message,
}
}
}
pub trait Side: Clone {
type InRequest: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
type InNotification: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
type OutResponse: Clone + Serialize + DeserializeOwned + JsonSchema + 'static;
fn decode_request(method: &str, params: Option<&RawValue>) -> Result<Self::InRequest>;
fn decode_notification(method: &str, params: Option<&RawValue>)
-> Result<Self::InNotification>;
}
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct ClientSide;
impl Side for ClientSide {
type InRequest = AgentRequest;
type InNotification = AgentNotification;
type OutResponse = ClientResponse;
fn decode_request(method: &str, params: Option<&RawValue>) -> Result<AgentRequest> {
let params = params.ok_or_else(Error::invalid_params)?;
match method {
m if m == CLIENT_METHOD_NAMES.session_request_permission => {
serde_json::from_str(params.get())
.map(AgentRequest::RequestPermissionRequest)
.map_err(Into::into)
}
m if m == CLIENT_METHOD_NAMES.fs_write_text_file => serde_json::from_str(params.get())
.map(AgentRequest::WriteTextFileRequest)
.map_err(Into::into),
m if m == CLIENT_METHOD_NAMES.fs_read_text_file => serde_json::from_str(params.get())
.map(AgentRequest::ReadTextFileRequest)
.map_err(Into::into),
m if m == CLIENT_METHOD_NAMES.terminal_create => serde_json::from_str(params.get())
.map(AgentRequest::CreateTerminalRequest)
.map_err(Into::into),
m if m == CLIENT_METHOD_NAMES.terminal_output => serde_json::from_str(params.get())
.map(AgentRequest::TerminalOutputRequest)
.map_err(Into::into),
m if m == CLIENT_METHOD_NAMES.terminal_kill => serde_json::from_str(params.get())
.map(AgentRequest::KillTerminalRequest)
.map_err(Into::into),
m if m == CLIENT_METHOD_NAMES.terminal_release => serde_json::from_str(params.get())
.map(AgentRequest::ReleaseTerminalRequest)
.map_err(Into::into),
m if m == CLIENT_METHOD_NAMES.terminal_wait_for_exit => {
serde_json::from_str(params.get())
.map(AgentRequest::WaitForTerminalExitRequest)
.map_err(Into::into)
}
#[cfg(feature = "unstable_elicitation")]
m if m == CLIENT_METHOD_NAMES.elicitation_create => serde_json::from_str(params.get())
.map(AgentRequest::CreateElicitationRequest)
.map_err(Into::into),
_ => {
if let Some(custom_method) = method.strip_prefix('_') {
Ok(AgentRequest::ExtMethodRequest(ExtRequest {
method: custom_method.into(),
params: params.to_owned().into(),
}))
} else {
Err(Error::method_not_found())
}
}
}
}
fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<AgentNotification> {
let params = params.ok_or_else(Error::invalid_params)?;
match method {
m if m == CLIENT_METHOD_NAMES.session_update => serde_json::from_str(params.get())
.map(AgentNotification::SessionNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_elicitation")]
m if m == CLIENT_METHOD_NAMES.elicitation_complete => {
serde_json::from_str(params.get())
.map(AgentNotification::CompleteElicitationNotification)
.map_err(Into::into)
}
_ => {
if let Some(custom_method) = method.strip_prefix('_') {
Ok(AgentNotification::ExtNotification(ExtNotification {
method: custom_method.into(),
params: params.to_owned().into(),
}))
} else {
Err(Error::method_not_found())
}
}
}
}
}
#[derive(Clone, Default, Debug, JsonSchema)]
#[non_exhaustive]
pub struct AgentSide;
impl Side for AgentSide {
type InRequest = ClientRequest;
type InNotification = ClientNotification;
type OutResponse = AgentResponse;
fn decode_request(method: &str, params: Option<&RawValue>) -> Result<ClientRequest> {
let params = params.ok_or_else(Error::invalid_params)?;
match method {
m if m == AGENT_METHOD_NAMES.initialize => serde_json::from_str(params.get())
.map(ClientRequest::InitializeRequest)
.map_err(Into::into),
m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get())
.map(ClientRequest::AuthenticateRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_llm_providers")]
m if m == AGENT_METHOD_NAMES.providers_list => serde_json::from_str(params.get())
.map(ClientRequest::ListProvidersRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_llm_providers")]
m if m == AGENT_METHOD_NAMES.providers_set => serde_json::from_str(params.get())
.map(ClientRequest::SetProvidersRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_llm_providers")]
m if m == AGENT_METHOD_NAMES.providers_disable => serde_json::from_str(params.get())
.map(ClientRequest::DisableProvidersRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_logout")]
m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get())
.map(ClientRequest::LogoutRequest)
.map_err(Into::into),
m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get())
.map(ClientRequest::NewSessionRequest)
.map_err(Into::into),
m if m == AGENT_METHOD_NAMES.session_load => serde_json::from_str(params.get())
.map(ClientRequest::LoadSessionRequest)
.map_err(Into::into),
m if m == AGENT_METHOD_NAMES.session_list => serde_json::from_str(params.get())
.map(ClientRequest::ListSessionsRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_session_fork")]
m if m == AGENT_METHOD_NAMES.session_fork => serde_json::from_str(params.get())
.map(ClientRequest::ForkSessionRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_session_resume")]
m if m == AGENT_METHOD_NAMES.session_resume => serde_json::from_str(params.get())
.map(ClientRequest::ResumeSessionRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_session_close")]
m if m == AGENT_METHOD_NAMES.session_close => serde_json::from_str(params.get())
.map(ClientRequest::CloseSessionRequest)
.map_err(Into::into),
m if m == AGENT_METHOD_NAMES.session_set_mode => serde_json::from_str(params.get())
.map(ClientRequest::SetSessionModeRequest)
.map_err(Into::into),
m if m == AGENT_METHOD_NAMES.session_set_config_option => {
serde_json::from_str(params.get())
.map(ClientRequest::SetSessionConfigOptionRequest)
.map_err(Into::into)
}
#[cfg(feature = "unstable_session_model")]
m if m == AGENT_METHOD_NAMES.session_set_model => serde_json::from_str(params.get())
.map(ClientRequest::SetSessionModelRequest)
.map_err(Into::into),
m if m == AGENT_METHOD_NAMES.session_prompt => serde_json::from_str(params.get())
.map(ClientRequest::PromptRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.nes_start => serde_json::from_str(params.get())
.map(ClientRequest::StartNesRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.nes_suggest => serde_json::from_str(params.get())
.map(ClientRequest::SuggestNesRequest)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.nes_close => serde_json::from_str(params.get())
.map(ClientRequest::CloseNesRequest)
.map_err(Into::into),
_ => {
if let Some(custom_method) = method.strip_prefix('_') {
Ok(ClientRequest::ExtMethodRequest(ExtRequest {
method: custom_method.into(),
params: params.to_owned().into(),
}))
} else {
Err(Error::method_not_found())
}
}
}
}
fn decode_notification(method: &str, params: Option<&RawValue>) -> Result<ClientNotification> {
let params = params.ok_or_else(Error::invalid_params)?;
match method {
m if m == AGENT_METHOD_NAMES.session_cancel => serde_json::from_str(params.get())
.map(ClientNotification::CancelNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.document_did_open => serde_json::from_str(params.get())
.map(ClientNotification::DidOpenDocumentNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.document_did_change => serde_json::from_str(params.get())
.map(ClientNotification::DidChangeDocumentNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.document_did_close => serde_json::from_str(params.get())
.map(ClientNotification::DidCloseDocumentNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.document_did_save => serde_json::from_str(params.get())
.map(ClientNotification::DidSaveDocumentNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.document_did_focus => serde_json::from_str(params.get())
.map(ClientNotification::DidFocusDocumentNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.nes_accept => serde_json::from_str(params.get())
.map(ClientNotification::AcceptNesNotification)
.map_err(Into::into),
#[cfg(feature = "unstable_nes")]
m if m == AGENT_METHOD_NAMES.nes_reject => serde_json::from_str(params.get())
.map(ClientNotification::RejectNesNotification)
.map_err(Into::into),
_ => {
if let Some(custom_method) = method.strip_prefix('_') {
Ok(ClientNotification::ExtNotification(ExtNotification {
method: custom_method.into(),
params: params.to_owned().into(),
}))
} else {
Err(Error::method_not_found())
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::{Number, Value};
#[test]
fn id_deserialization() {
let id = serde_json::from_value::<RequestId>(Value::Null).unwrap();
assert_eq!(id, RequestId::Null);
let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_u128(1).unwrap()))
.unwrap();
assert_eq!(id, RequestId::Number(1));
let id = serde_json::from_value::<RequestId>(Value::Number(Number::from_i128(-1).unwrap()))
.unwrap();
assert_eq!(id, RequestId::Number(-1));
let id = serde_json::from_value::<RequestId>(Value::String("id".to_owned())).unwrap();
assert_eq!(id, RequestId::Str("id".to_owned()));
}
#[test]
fn id_serialization() {
let id = serde_json::to_value(RequestId::Null).unwrap();
assert_eq!(id, Value::Null);
let id = serde_json::to_value(RequestId::Number(1)).unwrap();
assert_eq!(id, Value::Number(Number::from_u128(1).unwrap()));
let id = serde_json::to_value(RequestId::Number(-1)).unwrap();
assert_eq!(id, Value::Number(Number::from_i128(-1).unwrap()));
let id = serde_json::to_value(RequestId::Str("id".to_owned())).unwrap();
assert_eq!(id, Value::String("id".to_owned()));
}
#[test]
fn id_display() {
let id = RequestId::Null;
assert_eq!(id.to_string(), "null");
let id = RequestId::Number(1);
assert_eq!(id.to_string(), "1");
let id = RequestId::Number(-1);
assert_eq!(id.to_string(), "-1");
let id = RequestId::Str("id".to_owned());
assert_eq!(id.to_string(), "id");
}
}
#[cfg(feature = "unstable_nes")]
#[cfg(test)]
mod nes_rpc_tests {
use super::*;
use serde_json::json;
#[test]
fn test_decode_nes_start_request() {
let params = serde_json::to_string(&json!({
"workspaceUri": "file:///Users/alice/projects/my-app",
"workspaceFolders": [
{ "uri": "file:///Users/alice/projects/my-app", "name": "my-app" }
]
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let request = AgentSide::decode_request("nes/start", Some(&raw)).unwrap();
assert!(matches!(request, ClientRequest::StartNesRequest(_)));
}
#[test]
fn test_decode_nes_suggest_request() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"uri": "file:///path/to/file.rs",
"version": 2,
"position": { "line": 5, "character": 12 },
"triggerKind": "automatic"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let request = AgentSide::decode_request("nes/suggest", Some(&raw)).unwrap();
assert!(matches!(request, ClientRequest::SuggestNesRequest(_)));
}
#[test]
fn test_decode_nes_close_request() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let request = AgentSide::decode_request("nes/close", Some(&raw)).unwrap();
assert!(matches!(request, ClientRequest::CloseNesRequest(_)));
}
#[test]
fn test_decode_document_did_open_notification() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"uri": "file:///path/to/file.rs",
"languageId": "rust",
"version": 1,
"text": "fn main() {}"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let notification = AgentSide::decode_notification("document/didOpen", Some(&raw)).unwrap();
assert!(matches!(
notification,
ClientNotification::DidOpenDocumentNotification(_)
));
}
#[test]
fn test_decode_document_did_change_notification() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"uri": "file:///path/to/file.rs",
"version": 2,
"contentChanges": [{ "text": "fn main() { let x = 1; }" }]
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let notification =
AgentSide::decode_notification("document/didChange", Some(&raw)).unwrap();
assert!(matches!(
notification,
ClientNotification::DidChangeDocumentNotification(_)
));
}
#[test]
fn test_decode_document_did_close_notification() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"uri": "file:///path/to/file.rs"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let notification = AgentSide::decode_notification("document/didClose", Some(&raw)).unwrap();
assert!(matches!(
notification,
ClientNotification::DidCloseDocumentNotification(_)
));
}
#[test]
fn test_decode_document_did_save_notification() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"uri": "file:///path/to/file.rs"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let notification = AgentSide::decode_notification("document/didSave", Some(&raw)).unwrap();
assert!(matches!(
notification,
ClientNotification::DidSaveDocumentNotification(_)
));
}
#[test]
fn test_decode_document_did_focus_notification() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"uri": "file:///path/to/file.rs",
"version": 2,
"position": { "line": 5, "character": 12 },
"visibleRange": {
"start": { "line": 0, "character": 0 },
"end": { "line": 45, "character": 0 }
}
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let notification = AgentSide::decode_notification("document/didFocus", Some(&raw)).unwrap();
assert!(matches!(
notification,
ClientNotification::DidFocusDocumentNotification(_)
));
}
#[test]
fn test_decode_nes_accept_notification() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"id": "sugg_001"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let notification = AgentSide::decode_notification("nes/accept", Some(&raw)).unwrap();
assert!(matches!(
notification,
ClientNotification::AcceptNesNotification(_)
));
}
#[test]
fn test_decode_nes_reject_notification() {
let params = serde_json::to_string(&json!({
"sessionId": "session_123",
"id": "sugg_001",
"reason": "rejected"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let notification = AgentSide::decode_notification("nes/reject", Some(&raw)).unwrap();
assert!(matches!(
notification,
ClientNotification::RejectNesNotification(_)
));
}
}
#[cfg(feature = "unstable_llm_providers")]
#[cfg(test)]
mod providers_rpc_tests {
use super::*;
use serde_json::json;
#[test]
fn test_decode_providers_list_request() {
let params = serde_json::to_string(&json!({})).unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let request = AgentSide::decode_request("providers/list", Some(&raw)).unwrap();
assert!(matches!(request, ClientRequest::ListProvidersRequest(_)));
}
#[test]
fn test_decode_providers_set_request() {
let params = serde_json::to_string(&json!({
"id": "main",
"apiType": "anthropic",
"baseUrl": "https://api.anthropic.com"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let request = AgentSide::decode_request("providers/set", Some(&raw)).unwrap();
assert!(matches!(request, ClientRequest::SetProvidersRequest(_)));
}
#[test]
fn test_decode_providers_set_request_with_headers() {
let params = serde_json::to_string(&json!({
"id": "main",
"apiType": "openai",
"baseUrl": "https://api.openai.com/v1",
"headers": {
"Authorization": "Bearer sk-test"
}
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let request = AgentSide::decode_request("providers/set", Some(&raw)).unwrap();
assert!(matches!(request, ClientRequest::SetProvidersRequest(_)));
}
#[test]
fn test_decode_providers_disable_request() {
let params = serde_json::to_string(&json!({
"id": "secondary"
}))
.unwrap();
let raw = serde_json::value::RawValue::from_string(params).unwrap();
let request = AgentSide::decode_request("providers/disable", Some(&raw)).unwrap();
assert!(matches!(request, ClientRequest::DisableProvidersRequest(_)));
}
}
#[test]
fn test_notification_wire_format() {
use super::*;
use serde_json::{Value, json};
let outgoing_msg = JsonRpcMessage::wrap(
OutgoingMessage::<ClientSide, AgentSide>::Notification(Notification {
method: "cancel".into(),
params: Some(ClientNotification::CancelNotification(CancelNotification {
session_id: SessionId("test-123".into()),
meta: None,
})),
}),
);
let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
assert_eq!(
serialized,
json!({
"jsonrpc": "2.0",
"method": "cancel",
"params": {
"sessionId": "test-123"
},
})
);
let outgoing_msg = JsonRpcMessage::wrap(
OutgoingMessage::<AgentSide, ClientSide>::Notification(Notification {
method: "sessionUpdate".into(),
params: Some(AgentNotification::SessionNotification(
SessionNotification {
session_id: SessionId("test-456".into()),
update: SessionUpdate::AgentMessageChunk(ContentChunk {
content: ContentBlock::Text(TextContent {
annotations: None,
text: "Hello".to_string(),
meta: None,
}),
#[cfg(feature = "unstable_message_id")]
message_id: None,
meta: None,
}),
meta: None,
},
)),
}),
);
let serialized: Value = serde_json::to_value(&outgoing_msg).unwrap();
assert_eq!(
serialized,
json!({
"jsonrpc": "2.0",
"method": "sessionUpdate",
"params": {
"sessionId": "test-456",
"update": {
"sessionUpdate": "agent_message_chunk",
"content": {
"type": "text",
"text": "Hello"
}
}
}
})
);
}