Skip to main content

agentox_core/client/
session.rs

//! MCP session management — handles the protocol handshake and method calls.
2
3use crate::client::transport::Transport;
4use crate::error::{SessionError, TransportError};
5use crate::protocol::jsonrpc::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
6use crate::protocol::mcp_types::*;
7use std::sync::atomic::{AtomicI64, Ordering};
8
9/// An MCP session that manages the protocol lifecycle over a transport.
10pub struct McpSession {
11    transport: Box<dyn Transport>,
12    next_id: AtomicI64,
13    server_capabilities: Option<ServerCapabilities>,
14    server_info: Option<Implementation>,
15    protocol_version: Option<String>,
16}
17
18impl McpSession {
19    /// Create a new session wrapping a transport.
20    pub fn new(transport: Box<dyn Transport>) -> Self {
21        Self {
22            transport,
23            next_id: AtomicI64::new(1),
24            server_capabilities: None,
25            server_info: None,
26            protocol_version: None,
27        }
28    }
29
30    /// Perform the MCP initialize handshake.
31    pub async fn initialize(&mut self) -> Result<InitializeResult, SessionError> {
32        let params = InitializeParams {
33            protocol_version: "2025-11-25".to_string(),
34            capabilities: ClientCapabilities::default(),
35            client_info: Implementation {
36                name: "agentox".to_string(),
37                version: Some(env!("CARGO_PKG_VERSION").to_string()),
38            },
39        };
40
41        let response = self.call("initialize", &params).await?;
42        let result: InitializeResult = serde_json::from_value(response).map_err(|e| {
43            SessionError::UnexpectedFormat(format!("invalid initialize result: {e}"))
44        })?;
45
46        // Send the initialized notification
47        let notif = JsonRpcNotification::new("notifications/initialized", None);
48        self.transport
49            .send_notification(&notif)
50            .await
51            .map_err(SessionError::Transport)?;
52
53        // Store server info
54        self.server_capabilities = Some(result.capabilities.clone());
55        self.server_info = Some(result.server_info.clone());
56        self.protocol_version = Some(result.protocol_version.clone());
57
58        tracing::info!(
59            server = %result.server_info.name,
60            version = ?result.server_info.version,
61            protocol = %result.protocol_version,
62            "MCP session initialized"
63        );
64
65        Ok(result)
66    }
67
68    /// List all tools, following pagination cursors.
69    pub async fn list_tools(&mut self) -> Result<Vec<Tool>, SessionError> {
70        let mut all_tools = Vec::new();
71        let mut cursor: Option<String> = None;
72
73        loop {
74            let params = match &cursor {
75                Some(c) => serde_json::json!({ "cursor": c }),
76                None => serde_json::json!({}),
77            };
78
79            let response = self.call("tools/list", &params).await?;
80            let result: ListToolsResult = serde_json::from_value(response).map_err(|e| {
81                SessionError::UnexpectedFormat(format!("invalid tools/list result: {e}"))
82            })?;
83
84            all_tools.extend(result.tools);
85
86            match result.next_cursor {
87                Some(next) if !next.is_empty() => cursor = Some(next),
88                _ => break,
89            }
90        }
91
92        Ok(all_tools)
93    }
94
95    /// Call a specific tool.
96    pub async fn call_tool(
97        &mut self,
98        name: &str,
99        arguments: serde_json::Value,
100    ) -> Result<CallToolResult, SessionError> {
101        let params = CallToolParams {
102            name: name.to_string(),
103            arguments: Some(arguments),
104        };
105        let response = self.call("tools/call", &params).await?;
106        let result: CallToolResult = serde_json::from_value(response).map_err(|e| {
107            SessionError::UnexpectedFormat(format!("invalid tools/call result: {e}"))
108        })?;
109        Ok(result)
110    }
111
112    /// Send a raw string message (bypasses all type checking). Used for fuzzing.
113    pub async fn send_raw(&mut self, raw: &str) -> Result<Option<String>, TransportError> {
114        self.transport.send_raw(raw).await
115    }
116
117    /// Send a typed request and get the raw JSON-RPC response.
118    pub async fn send_request(
119        &mut self,
120        req: &JsonRpcRequest,
121    ) -> Result<JsonRpcResponse, TransportError> {
122        self.transport.send_request(req).await
123    }
124
125    /// Get the server capabilities (available after initialize).
126    pub fn server_capabilities(&self) -> Option<&ServerCapabilities> {
127        self.server_capabilities.as_ref()
128    }
129
130    /// Get the server info (available after initialize).
131    pub fn server_info(&self) -> Option<&Implementation> {
132        self.server_info.as_ref()
133    }
134
135    /// Get the negotiated protocol version.
136    pub fn protocol_version(&self) -> Option<&str> {
137        self.protocol_version.as_deref()
138    }
139
140    /// Shut down the session and underlying transport.
141    pub async fn shutdown(&mut self) -> Result<(), TransportError> {
142        self.transport.shutdown().await
143    }
144
145    /// Get the next request ID.
146    pub fn next_id(&self) -> i64 {
147        self.next_id.fetch_add(1, Ordering::SeqCst)
148    }
149
150    /// Send a typed method call and return the result value.
151    async fn call<P: serde::Serialize>(
152        &mut self,
153        method: &str,
154        params: &P,
155    ) -> Result<serde_json::Value, SessionError> {
156        let id = self.next_id();
157        let params_value = serde_json::to_value(params)
158            .map_err(|e| SessionError::UnexpectedFormat(e.to_string()))?;
159
160        let req = JsonRpcRequest::new(id, method, Some(params_value));
161        let response = self
162            .transport
163            .send_request(&req)
164            .await
165            .map_err(SessionError::Transport)?;
166
167        if let Some(error) = response.error {
168            return Err(SessionError::JsonRpc {
169                code: error.code,
170                message: error.message,
171            });
172        }
173
174        response.result.ok_or_else(|| {
175            SessionError::UnexpectedFormat("response has neither result nor error".to_string())
176        })
177    }
178}