use std::sync::Arc;
use derive_more::{Display, From};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
// Identifier that ties a JSON-RPC request to its response.
//
// The enum is serialized untagged, so each variant maps directly onto a bare
// JSON value (`null`, an integer, or a string). During deserialization serde
// tries the variants in declaration order. `From<String>` and `From<i64>` are
// generated through the `#[from(String, i64)]` attribute.
//
// NOTE: plain `//` comments on purpose; `///` docs would be copied into the
// schema produced by the `JsonSchema` derive.
#[derive(
    Debug,
    PartialEq,
    Clone,
    Hash,
    Eq,
    Deserialize,
    Serialize,
    PartialOrd,
    Ord,
    Display,
    JsonSchema,
    From,
)]
#[serde(untagged)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[from(String, i64)]
pub enum RequestId {
    // The JSON `null` id; rendered as the text "null" by `Display`.
    #[display("null")]
    Null,
    // A signed integer id.
    Number(i64),
    // A string id.
    Str(String),
}
// A JSON-RPC request: an `id`, the `method` to invoke and optional `params`.
//
// `#[skip_serializing_none]` is deliberately listed ahead of `#[derive(...)]`:
// the attribute rewrites the struct by adding `skip_serializing_if` to every
// `Option` field, and serde's derive only sees that rewrite when the attribute
// is expanded first. With this order `params: None` is left out of the
// serialized object.
//
// NOTE: plain `//` comments on purpose; `///` docs would be copied into the
// schema produced by the `JsonSchema` derive.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Request<Params> {
    // Identifier echoed back in the matching `Response`.
    pub id: RequestId,
    // Name of the method being invoked.
    pub method: Arc<str>,
    // Method arguments; omitted from the serialized form when `None`.
    pub params: Option<Params>,
}
// Reply to a request: either a successful `result` or an `error`, each paired
// with the `id` of the request being answered.
//
// Serialized untagged, so the wire object is `{ id, result }` or
// `{ id, error }` with no discriminator; on deserialization the `Result`
// variant is attempted before `Error`.
//
// NOTE: plain `//` comments on purpose; `///` docs would be copied into the
// schema produced by the `JsonSchema` derive.
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_enums,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[serde(untagged)]
#[schemars(rename = "{Result}", extend("x-docs-ignore" = true))]
pub enum Response<Result, Error> {
    // Successful outcome.
    Result { id: RequestId, result: Result },
    // Failed outcome.
    Error { id: RequestId, error: Error },
}
impl<R, E> Response<R, E> {
    /// Builds a response for the request identified by `id`.
    ///
    /// `Ok(value)` becomes [`Response::Result`] and `Err(error)` becomes
    /// [`Response::Error`]; in both cases `id` is converted into a
    /// `RequestId` and stored alongside the payload.
    #[must_use]
    pub fn new(id: impl Into<RequestId>, result: std::result::Result<R, E>) -> Self {
        // Both variants need the converted id, so do the conversion once up front.
        let id = id.into();
        match result {
            Ok(result) => Self::Result { id, result },
            Err(error) => Self::Error { id, error },
        }
    }
}
// A JSON-RPC notification: a `method` plus optional `params`. Unlike a
// request it carries no `id`.
//
// `#[skip_serializing_none]` is deliberately listed ahead of `#[derive(...)]`:
// the attribute rewrites the struct by adding `skip_serializing_if` to every
// `Option` field, and serde's derive only sees that rewrite when the attribute
// is expanded first. With this order `params: None` is left out of the
// serialized object.
//
// NOTE: plain `//` comments on purpose; `///` docs would be copied into the
// schema produced by the `JsonSchema` derive.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, Debug, JsonSchema)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
#[schemars(rename = "{Params}", extend("x-docs-ignore" = true))]
pub struct Notification<Params> {
    // Name of the method being notified.
    pub method: Arc<str>,
    // Notification payload; omitted from the serialized form when `None`.
    pub params: Option<Params>,
}
// Protocol version marker. The single variant (de)serializes as the string
// "2.0", so any other value is rejected when deserializing.
//
// NOTE: plain `//` comments on purpose; `///` docs would be copied into the
// schema produced by the `JsonSchema` derive.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
enum JsonRpcVersion {
    #[serde(rename = "2.0")]
    V2,
}
// Envelope that adds the `jsonrpc` version member to a message.
//
// The wrapped message is flattened, so its own fields appear next to
// `"jsonrpc"` in the same JSON object rather than nested under a key.
//
// NOTE: plain `//` comments on purpose; `///` docs would be copied into the
// schema produced by the `JsonSchema` derive.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[schemars(inline)]
pub struct JsonRpcMessage<M> {
    // Version marker; always `JsonRpcVersion::V2` when built through `wrap`.
    jsonrpc: JsonRpcVersion,
    // The request / response / notification being carried.
    #[serde(flatten)]
    message: M,
}
impl<M> JsonRpcMessage<M> {
    /// Wraps `message` in an envelope tagged with the 2.0 protocol version.
    #[must_use]
    pub fn wrap(message: M) -> Self {
        let jsonrpc = JsonRpcVersion::V2;
        Self { jsonrpc, message }
    }

    /// Consumes the envelope and returns the wrapped message, discarding the
    /// version marker.
    #[must_use]
    pub fn into_inner(self) -> M {
        let Self { message, .. } = self;
        message
    }
}
/// Error produced when a JSON-RPC batch is built from an empty list of
/// messages.
///
/// Its `Display` text is "JSON-RPC batch must contain at least one message".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)]
#[display("JSON-RPC batch must contain at least one message")]
#[non_exhaustive]
pub struct EmptyJsonRpcBatch;

// No source error to expose, so the default trait methods are sufficient.
impl std::error::Error for EmptyJsonRpcBatch {}
// A batch of JSON-RPC messages that holds at least one entry when it is
// created through `new`, `TryFrom` or deserialization.
//
// Serialized transparently, i.e. as a plain JSON array of the wrapped
// messages. `Deserialize` is implemented by hand rather than derived so the
// non-empty check also runs on incoming data; the schema advertises the same
// constraint through `length(min = 1)`.
//
// NOTE: plain `//` comments on purpose; `///` docs would be copied into the
// schema produced by the `JsonSchema` derive.
#[derive(Debug, Serialize, JsonSchema)]
#[schemars(inline)]
#[serde(transparent)]
#[allow(
    clippy::exhaustive_structs,
    reason = "This comes from the JSON-RPC specification itself"
)]
pub struct JsonRpcBatch<M>(#[schemars(length(min = 1))] Vec<JsonRpcMessage<M>>);
impl<M> JsonRpcBatch<M> {
pub fn new(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, EmptyJsonRpcBatch> {
if messages.is_empty() {
Err(EmptyJsonRpcBatch)
} else {
Ok(Self(messages))
}
}
#[must_use]
pub fn as_slice(&self) -> &[JsonRpcMessage<M>] {
&self.0
}
#[must_use]
pub fn into_vec(self) -> Vec<JsonRpcMessage<M>> {
self.0
}
}
impl<M> TryFrom<Vec<JsonRpcMessage<M>>> for JsonRpcBatch<M> {
type Error = EmptyJsonRpcBatch;
fn try_from(messages: Vec<JsonRpcMessage<M>>) -> Result<Self, Self::Error> {
Self::new(messages)
}
}
impl<'de, M> Deserialize<'de> for JsonRpcBatch<M>
where
M: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let messages = Vec::<JsonRpcMessage<M>>::deserialize(deserializer)?;
Self::new(messages).map_err(serde::de::Error::custom)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AgentNotification, CancelNotification, ClientNotification, ContentBlock, ContentChunk,
        SessionId, SessionNotification, SessionUpdate, TextContent,
    };
    use serde_json::{Number, Value, json};

    /// Builds the wrapped `cancel` notification shared by several tests.
    fn wrapped_cancel() -> JsonRpcMessage<Notification<ClientNotification>> {
        let cancel = CancelNotification {
            session_id: SessionId("test-123".into()),
            meta: None,
        };
        JsonRpcMessage::wrap(Notification {
            method: "cancel".into(),
            params: Some(ClientNotification::CancelNotification(cancel)),
        })
    }

    /// Each JSON value allowed as an id parses into the matching variant.
    #[test]
    fn id_deserialization() {
        let cases = [
            (Value::Null, RequestId::Null),
            (
                Value::Number(Number::from_u128(1).unwrap()),
                RequestId::Number(1),
            ),
            (
                Value::Number(Number::from_i128(-1).unwrap()),
                RequestId::Number(-1),
            ),
            (
                Value::String("id".to_owned()),
                RequestId::Str("id".to_owned()),
            ),
        ];
        for (raw, expected) in cases {
            let parsed = serde_json::from_value::<RequestId>(raw).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// Each variant serializes to the bare JSON value it represents.
    #[test]
    fn id_serialization() {
        let cases = [
            (RequestId::Null, Value::Null),
            (
                RequestId::Number(1),
                Value::Number(Number::from_u128(1).unwrap()),
            ),
            (
                RequestId::Number(-1),
                Value::Number(Number::from_i128(-1).unwrap()),
            ),
            (
                RequestId::Str("id".to_owned()),
                Value::String("id".to_owned()),
            ),
        ];
        for (id, expected) in cases {
            let encoded = serde_json::to_value(id).unwrap();
            assert_eq!(encoded, expected);
        }
    }

    /// `Display` prints "null" for the null id and the inner value otherwise.
    #[test]
    fn id_display() {
        let cases = [
            (RequestId::Null, "null"),
            (RequestId::Number(1), "1"),
            (RequestId::Number(-1), "-1"),
            (RequestId::Str("id".to_owned()), "id"),
        ];
        for (id, rendered) in cases {
            assert_eq!(id.to_string(), rendered);
        }
    }

    /// An empty JSON array is rejected with the empty-batch message.
    #[test]
    fn batch_deserialization_requires_at_least_one_message() {
        let empty_array = Value::Array(Vec::new());
        let failure =
            serde_json::from_value::<JsonRpcBatch<Notification<ClientNotification>>>(empty_array)
                .unwrap_err();
        assert!(failure.to_string().contains("at least one message"));
    }

    /// A one-element batch serializes to a JSON array and parses back again.
    #[test]
    fn batch_serialization_round_trips_non_empty_messages() {
        let batch = JsonRpcBatch::new(vec![wrapped_cancel()]).unwrap();

        let encoded = serde_json::to_value(&batch).unwrap();
        assert_eq!(
            encoded,
            json!([{
                "jsonrpc": "2.0",
                "method": "cancel",
                "params": {
                    "sessionId": "test-123"
                },
            }])
        );

        let decoded =
            serde_json::from_value::<JsonRpcBatch<Notification<ClientNotification>>>(encoded)
                .unwrap();
        assert_eq!(decoded.as_slice().len(), 1);
    }

    /// Wrapped notifications serialize to the expected flat wire objects.
    #[test]
    fn notification_wire_format() {
        // Client -> agent: cancel notification.
        let cancel_json: Value = serde_json::to_value(&wrapped_cancel()).unwrap();
        assert_eq!(
            cancel_json,
            json!({
                "jsonrpc": "2.0",
                "method": "cancel",
                "params": {
                    "sessionId": "test-123"
                },
            })
        );

        // Agent -> client: session update carrying a text chunk.
        let chunk = ContentChunk {
            content: ContentBlock::Text(TextContent {
                annotations: None,
                text: "Hello".to_string(),
                meta: None,
            }),
            message_id: None,
            meta: None,
        };
        let session_update = SessionNotification {
            session_id: SessionId("test-456".into()),
            update: SessionUpdate::AgentMessageChunk(chunk),
            meta: None,
        };
        let update_msg = JsonRpcMessage::wrap(Notification {
            method: "sessionUpdate".into(),
            params: Some(AgentNotification::SessionNotification(session_update)),
        });

        let update_json: Value = serde_json::to_value(&update_msg).unwrap();
        assert_eq!(
            update_json,
            json!({
                "jsonrpc": "2.0",
                "method": "sessionUpdate",
                "params": {
                    "sessionId": "test-456",
                    "update": {
                        "sessionUpdate": "agent_message_chunk",
                        "content": {
                            "type": "text",
                            "text": "Hello"
                        }
                    }
                }
            })
        );
    }
}