use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::core::types::AgentEvent;
/// Initialization payload: the protocol version together with the client's
/// capabilities and identity.
///
/// Serialize-only. Fields are emitted under their camelCase names
/// (`protocolVersion`, `clientCapabilities`, `clientInfo`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
    /// Protocol version, written as a JSON integer.
    pub protocol_version: u32,
    /// Capabilities the client advertises.
    pub client_capabilities: ClientCapabilities,
    /// Name, optional title, and version of the client.
    pub client_info: ClientInfo,
}
/// Capabilities advertised by the client. Serialize-only.
#[derive(Debug, Serialize)]
pub struct ClientCapabilities {
/// File-system capability flags.
pub fs: FsCapabilities,
/// Terminal capability flag (presumably "client can host terminals" — confirm against the protocol spec).
pub terminal: bool,
}
/// File-system capability flags, serialized as `readTextFile` and
/// `writeTextFile`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FsCapabilities {
    /// Flag emitted as `readTextFile`.
    pub read_text_file: bool,
    /// Flag emitted as `writeTextFile`.
    pub write_text_file: bool,
}
/// Identity of the client, built from static strings. Serialize-only.
#[derive(Debug, Serialize)]
pub struct ClientInfo {
/// Client name.
pub name: &'static str,
/// Optional title; the key is omitted from the output when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<&'static str>,
/// Client version string.
pub version: &'static str,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionNewParams {
pub cwd: String,
#[serde(rename = "mcpServers", default)]
pub mcp_servers: Vec<serde_json::Value>,
}
/// A prompt content block, internally tagged by a `type` key.
///
/// The only variant, [`ContentBlock::Text`], serializes as
/// `{"type":"text","text":"..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text content.
    Text {
        /// The text payload.
        text: String,
    },
}
/// Parameters carrying a prompt for a session.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionPromptParams {
    /// Session identifier, serialized as `sessionId`.
    pub session_id: String,
    /// Prompt content blocks, in order.
    pub prompt: Vec<ContentBlock>,
}
/// Parameters identifying the session a cancel request applies to.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionCancelParams {
    /// Session identifier, serialized as `sessionId`.
    pub session_id: String,
}
/// Envelope for a session update: the session it belongs to plus the update
/// payload itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionUpdateParams {
    /// Session identifier, serialized as `sessionId`.
    pub session_id: String,
    /// The update payload; see [`SessionUpdate`] for the variants.
    pub update: SessionUpdate,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "sessionUpdate")]
pub enum SessionUpdate {
#[serde(rename = "agent_message_chunk")]
AgentMessageChunk {
#[serde(default)]
content: Value,
},
#[serde(rename = "agent_thought_chunk")]
AgentThoughtChunk {
#[serde(default)]
content: Value,
},
#[serde(rename = "tool_call")]
ToolCall {
#[serde(rename = "toolCallId", default)]
tool_call_id: String,
#[serde(default)]
title: String,
#[serde(default)]
kind: String,
#[serde(default)]
status: String,
#[serde(rename = "rawInput", default)]
raw_input: Value,
#[serde(default)]
locations: Vec<Value>,
},
#[serde(rename = "tool_call_update")]
ToolCallUpdate {
#[serde(rename = "toolCallId", default)]
tool_call_id: String,
#[serde(default)]
status: String,
#[serde(default)]
content: Vec<Value>,
},
#[serde(rename = "stop")]
Stop {
#[serde(rename = "stopReason", default)]
stop_reason: String,
},
#[serde(other)]
Unknown,
}
/// Parameters naming the file a read request targets.
#[derive(Debug, Serialize, Deserialize)]
pub struct FsReadParams {
/// Path of the file, as a string.
pub path: String,
}
/// Optional settings for creating a terminal.
///
/// Both fields are left out of the serialized output when they are `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct TerminalCreateParams {
/// Working directory, if one is specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub cwd: Option<String>,
/// Environment variables as name/value pairs, if any are specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub env: Option<std::collections::HashMap<String, String>>,
}
/// Parameters of a permission request tied to a tool and a session.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionRequestParams {
    /// Tool name, serialized as `toolName`.
    pub tool_name: String,
    /// Free-form description accompanying the request.
    pub description: String,
    /// Session identifier, serialized as `sessionId`.
    pub session_id: String,
}
/// What an agent reports about itself: protocol version, capability flags,
/// identity, and authentication methods.
///
/// Keys are camelCase on the wire; any key missing during deserialization
/// takes the corresponding value from `AgentCapabilities::default()`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default, rename_all = "camelCase")]
pub struct AgentCapabilities {
    /// Protocol version (`protocolVersion`); `0` when absent.
    pub protocol_version: u32,
    /// Capability flags (`agentCapabilities`).
    pub agent_capabilities: AgentCapabilityFlags,
    /// Agent identity (`agentInfo`).
    pub agent_info: AgentInfo,
    /// Authentication method descriptors (`authMethods`), kept as raw JSON.
    pub auth_methods: Vec<Value>,
}
/// Capability flags reported by an agent.
///
/// `loadSession` is modeled explicitly; every other key of the object is
/// collected into `extra` through `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct AgentCapabilityFlags {
    /// Flag carried as `loadSession`; `false` when absent.
    #[serde(default)]
    pub load_session: bool,
    /// Remaining keys of the capabilities object, kept as raw JSON.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, Value>,
}
/// Identity reported by an agent.
///
/// Any field missing during deserialization becomes an empty string.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct AgentInfo {
    /// Agent name.
    pub name: String,
    /// Agent title.
    pub title: String,
    /// Agent version string.
    pub version: String,
}
fn extract_text_from_content(content: &Value) -> String {
if let Some(t) = content.get("text").and_then(|v| v.as_str()) {
return t.to_owned();
}
if let Some(arr) = content.as_array() {
return arr
.iter()
.filter_map(|b| b.get("text").and_then(|v| v.as_str()))
.collect::<Vec<_>>()
.join("");
}
if let Some(s) = content.as_str() {
return s.to_owned();
}
String::new()
}
pub(crate) fn update_to_event(params: &SessionUpdateParams) -> Vec<AgentEvent> {
match ¶ms.update {
SessionUpdate::AgentMessageChunk { content } => {
let text = extract_text_from_content(content);
if text.is_empty() {
vec![]
} else {
vec![AgentEvent::Text { text, is_delta: true }]
}
}
SessionUpdate::AgentThoughtChunk { content } => {
let text = content
.get("thought")
.and_then(|v| v.as_str())
.or_else(|| content.as_str())
.unwrap_or("")
.to_owned();
vec![AgentEvent::Thinking { text }]
}
SessionUpdate::ToolCall { tool_call_id, title, raw_input, .. } => {
vec![AgentEvent::ToolStart {
id: tool_call_id.clone(),
name: title.clone(),
input: raw_input.clone(),
}]
}
SessionUpdate::ToolCallUpdate { tool_call_id, status, content } => {
let output = content
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<_>>()
.join("");
let is_error = status == "error";
vec![AgentEvent::ToolResult {
id: tool_call_id.clone(),
output,
is_error,
duration_ms: None,
}]
}
SessionUpdate::Stop { stop_reason } => {
vec![
AgentEvent::TurnComplete { input_tokens: 0, output_tokens: 0 },
AgentEvent::SessionEnd {
result: stop_reason.clone(),
cost_usd: None,
is_error: false,
},
]
}
SessionUpdate::Unknown => vec![],
}
}
/// Unit tests for the update-to-event conversion and for the wire format of
/// the request parameter types.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps an update in params with a fixed session id.
    fn make_update(update: SessionUpdate) -> SessionUpdateParams {
        SessionUpdateParams { session_id: "s1".to_string(), update }
    }

    #[test]
    fn update_to_event_text_delta_array() {
        let p = make_update(SessionUpdate::AgentMessageChunk {
            content: json!([{"type": "text", "text": "hello"}]),
        });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], AgentEvent::Text { text, is_delta: true } if text == "hello"));
    }

    #[test]
    fn update_to_event_text_delta_single_object() {
        let p = make_update(SessionUpdate::AgentMessageChunk {
            content: json!({"type": "text", "text": "hello"}),
        });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], AgentEvent::Text { text, is_delta: true } if text == "hello"));
    }

    #[test]
    fn update_to_event_text_delta_plain_string() {
        let p = make_update(SessionUpdate::AgentMessageChunk {
            content: json!("hello"),
        });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], AgentEvent::Text { text, is_delta: true } if text == "hello"));
    }

    #[test]
    fn update_to_event_text_delta_empty_returns_no_events() {
        let p = make_update(SessionUpdate::AgentMessageChunk {
            content: json!(null),
        });
        let events = update_to_event(&p);
        assert!(events.is_empty());
    }

    #[test]
    fn update_to_event_thinking_plain_string() {
        let p = make_update(SessionUpdate::AgentThoughtChunk {
            content: json!("thinking..."),
        });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], AgentEvent::Thinking { text } if text == "thinking..."));
    }

    #[test]
    fn update_to_event_thinking_thought_field() {
        let p = make_update(SessionUpdate::AgentThoughtChunk {
            content: json!({"thought": "deep thought"}),
        });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], AgentEvent::Thinking { text } if text == "deep thought"));
    }

    #[test]
    fn update_to_event_tool_start() {
        let p = make_update(SessionUpdate::ToolCall {
            tool_call_id: "t1".to_string(),
            title: "bash".to_string(),
            kind: "bash".to_string(),
            status: "pending".to_string(),
            raw_input: json!({"cmd": "ls"}),
            locations: vec![],
        });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 1);
        assert!(
            matches!(&events[0], AgentEvent::ToolStart { id, name, .. } if id == "t1" && name == "bash")
        );
    }

    #[test]
    fn update_to_event_tool_result() {
        let p = make_update(SessionUpdate::ToolCallUpdate {
            tool_call_id: "t1".to_string(),
            status: "done".to_string(),
            content: vec![json!("ok")],
        });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 1);
        assert!(
            matches!(&events[0], AgentEvent::ToolResult { id, output, is_error, .. }
                if id == "t1" && output == "ok" && !is_error)
        );
    }

    #[test]
    fn update_to_event_stop_emits_two_events() {
        let p = make_update(SessionUpdate::Stop { stop_reason: "end_turn".to_string() });
        let events = update_to_event(&p);
        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], AgentEvent::TurnComplete { .. }));
        assert!(matches!(&events[1], AgentEvent::SessionEnd { is_error: false, .. }));
    }

    #[test]
    fn update_to_event_unknown_returns_empty() {
        let p = make_update(SessionUpdate::Unknown);
        let events = update_to_event(&p);
        assert!(events.is_empty());
    }

    #[test]
    fn session_update_params_round_trip() {
        let original = SessionUpdateParams {
            session_id: "abc".to_string(),
            update: SessionUpdate::AgentMessageChunk {
                content: json!([{"type": "text", "text": "hi"}]),
            },
        };
        let serialized = serde_json::to_string(&original).unwrap();
        let deserialized: SessionUpdateParams = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.session_id, "abc");
        assert!(matches!(
            deserialized.update,
            SessionUpdate::AgentMessageChunk { .. }
        ));
    }

    #[test]
    fn session_update_params_gemini_round_trip() {
        // The tag key comes after `content` here; internally tagged
        // deserialization must not depend on key order.
        let raw = r#"{"sessionId":"s1","update":{"content":{"text":"hello","type":"text"},"sessionUpdate":"agent_message_chunk"}}"#;
        let params: SessionUpdateParams = serde_json::from_str(raw).unwrap();
        let events = update_to_event(&params);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], AgentEvent::Text { text, is_delta: true } if text == "hello"));
    }

    #[test]
    fn initialize_params_serialize_camel_case() {
        let p = InitializeParams {
            protocol_version: 1,
            client_capabilities: ClientCapabilities {
                fs: FsCapabilities { read_text_file: true, write_text_file: true },
                terminal: true,
            },
            client_info: ClientInfo { name: "gate4agent", title: Some("Gate4Agent"), version: "0.2.0" },
        };
        let s = serde_json::to_string(&p).unwrap();
        assert!(s.contains("protocolVersion"), "must use protocolVersion");
        assert!(s.contains("clientInfo"), "must use clientInfo");
        assert!(s.contains("clientCapabilities"), "must include clientCapabilities");
        assert!(s.contains(r#""protocolVersion":1"#), "protocolVersion must be integer 1");
    }

    #[test]
    fn session_new_params_serialize() {
        let p = SessionNewParams { cwd: "/home/user".to_string(), mcp_servers: vec![] };
        let s = serde_json::to_string(&p).unwrap();
        assert!(s.contains("\"cwd\""), "must use cwd");
        assert!(s.contains("\"mcpServers\""), "must use mcpServers");
    }

    #[test]
    fn session_prompt_params_wraps_content_blocks() {
        let p = SessionPromptParams {
            session_id: "s1".to_string(),
            prompt: vec![ContentBlock::Text { text: "hello".to_string() }],
        };
        let s = serde_json::to_string(&p).unwrap();
        assert!(s.contains("\"prompt\""), "must have prompt field");
        assert!(s.contains("\"type\":\"text\""), "content block must have type=text");
        assert!(s.contains("\"text\":\"hello\""), "must have text content");
    }
}