Skip to main content

codex/mcp/
client.rs

1use std::{collections::HashSet, io, path::PathBuf, sync::Arc, time::Duration};
2
3use serde_json::{json, Value};
4use thiserror::Error;
5
6use super::{
7    AppCallHandle, ApprovalDecision, ClientInfo, CodexCallHandle, CodexCallParams, CodexCallResult,
8    CodexReplyParams, InitializeParams, RequestId, StdioServerConfig, METHOD_CODEX,
9    METHOD_CODEX_APPROVAL, METHOD_THREAD_FORK, METHOD_THREAD_LIST, METHOD_THREAD_RESUME,
10    METHOD_THREAD_START, METHOD_TURN_INTERRUPT, METHOD_TURN_START,
11};
12
13use super::jsonrpc::{map_response, JsonRpcTransport};
14
15/// Errors surfaced while managing MCP/app-server transports.
16#[derive(Debug, Error)]
17pub enum McpError {
18    #[error("failed to spawn `{command}`: {source}")]
19    Spawn {
20        command: String,
21        #[source]
22        source: io::Error,
23    },
24    #[error("server did not respond to initialize: {0}")]
25    Handshake(String),
26    #[error("transport task failed: {0}")]
27    Transport(String),
28    #[error("server returned JSON-RPC error {code}: {message}")]
29    Rpc {
30        code: i64,
31        message: String,
32        data: Option<Value>,
33    },
34    #[error("server reported an error: {0}")]
35    Server(String),
36    #[error("request was cancelled")]
37    Cancelled,
38    #[error("timed out after {0:?}")]
39    Timeout(Duration),
40    #[error("serialization failed: {0}")]
41    Serialization(#[from] serde_json::Error),
42    #[error("transport channel closed unexpectedly")]
43    ChannelClosed,
44}
45
46/// Client wrapper around the stdio MCP server.
47pub struct CodexMcpServer {
48    transport: Arc<JsonRpcTransport>,
49}
50
51impl CodexMcpServer {
52    /// Launch `codex mcp-server`, issue `initialize`, and return a connected handle.
53    pub async fn start(config: StdioServerConfig, client: ClientInfo) -> Result<Self, McpError> {
54        Self::with_capabilities(config, client, Value::Object(Default::default())).await
55    }
56
57    /// Launch with explicit capabilities to send during `initialize`.
58    pub async fn with_capabilities(
59        config: StdioServerConfig,
60        client: ClientInfo,
61        capabilities: Value,
62    ) -> Result<Self, McpError> {
63        let capabilities = match capabilities {
64            Value::Null => Value::Object(Default::default()),
65            other => other,
66        };
67        let transport = JsonRpcTransport::spawn_mcp(config).await?;
68        let params = InitializeParams {
69            client,
70            protocol_version: "2024-11-05".to_string(),
71            capabilities,
72        };
73
74        transport
75            .initialize(params, transport.startup_timeout())
76            .await
77            .map_err(|err| McpError::Handshake(err.to_string()))?;
78
79        Ok(Self {
80            transport: Arc::new(transport),
81        })
82    }
83
84    /// Send a new Codex prompt via `codex/codex`.
85    pub async fn codex(&self, params: CodexCallParams) -> Result<CodexCallHandle, McpError> {
86        self.invoke_tool_call("codex", serde_json::to_value(params)?)
87            .await
88    }
89
90    /// Continue an existing conversation via `codex/codex-reply`.
91    pub async fn codex_reply(&self, params: CodexReplyParams) -> Result<CodexCallHandle, McpError> {
92        self.invoke_tool_call("codex-reply", serde_json::to_value(params)?)
93            .await
94    }
95
96    /// Send an approval decision back to the MCP server.
97    pub async fn send_approval(&self, decision: ApprovalDecision) -> Result<(), McpError> {
98        let (_, rx) = self
99            .transport
100            .request(METHOD_CODEX_APPROVAL, serde_json::to_value(decision)?)
101            .await?;
102
103        match rx.await {
104            Ok(Ok(_)) => Ok(()),
105            Ok(Err(err)) => Err(err),
106            Err(_) => Err(McpError::ChannelClosed),
107        }
108    }
109
110    /// Request cancellation for a pending call.
111    pub fn cancel(&self, request_id: RequestId) -> Result<(), McpError> {
112        self.transport.cancel(request_id)
113    }
114
115    /// Gracefully shut down the MCP server.
116    pub async fn shutdown(&self) -> Result<(), McpError> {
117        self.transport.shutdown().await
118    }
119
120    async fn invoke_tool_call(
121        &self,
122        tool_name: &str,
123        arguments: Value,
124    ) -> Result<CodexCallHandle, McpError> {
125        let events = self.transport.register_codex_listener().await;
126        let request = json!({
127            "name": tool_name,
128            "arguments": arguments,
129        });
130        let (request_id, raw_response) = self.transport.request(METHOD_CODEX, request).await?;
131        let response = map_response::<CodexCallResult>(raw_response);
132
133        Ok(CodexCallHandle {
134            request_id,
135            events,
136            response,
137        })
138    }
139}
140
141/// Client wrapper around the stdio app-server.
142pub struct CodexAppServer {
143    transport: Arc<JsonRpcTransport>,
144}
145
146impl CodexAppServer {
147    /// Launch `codex app-server`, issue `initialize`, and return a connected handle.
148    pub async fn start(config: StdioServerConfig, client: ClientInfo) -> Result<Self, McpError> {
149        Self::with_capabilities(config, client, Value::Object(Default::default())).await
150    }
151
152    /// Launch with `capabilities.experimentalApi=true` in the `initialize` handshake.
153    pub async fn start_experimental(
154        config: StdioServerConfig,
155        client: ClientInfo,
156    ) -> Result<Self, McpError> {
157        Self::with_capabilities(config, client, json!({ "experimentalApi": true })).await
158    }
159
160    /// Launch with explicit capabilities to send during `initialize`.
161    pub async fn with_capabilities(
162        config: StdioServerConfig,
163        client: ClientInfo,
164        capabilities: Value,
165    ) -> Result<Self, McpError> {
166        let capabilities = match capabilities {
167            Value::Null => Value::Object(Default::default()),
168            other => other,
169        };
170        let transport = JsonRpcTransport::spawn_app(config).await?;
171        let params = InitializeParams {
172            client,
173            protocol_version: "2024-11-05".to_string(),
174            capabilities,
175        };
176
177        transport
178            .initialize(params, transport.startup_timeout())
179            .await
180            .map_err(|err| McpError::Handshake(err.to_string()))?;
181
182        Ok(Self {
183            transport: Arc::new(transport),
184        })
185    }
186
187    /// Start a new thread (or use a provided ID) via `thread/start`.
188    pub async fn thread_start(
189        &self,
190        params: super::ThreadStartParams,
191    ) -> Result<AppCallHandle, McpError> {
192        self.invoke_app_call(METHOD_THREAD_START, serde_json::to_value(params)?)
193            .await
194    }
195
196    /// Resume an existing thread via `thread/resume`.
197    pub async fn thread_resume(
198        &self,
199        params: super::ThreadResumeParams,
200    ) -> Result<AppCallHandle, McpError> {
201        self.invoke_app_call(METHOD_THREAD_RESUME, serde_json::to_value(params)?)
202            .await
203    }
204
205    /// List threads via `thread/list`.
206    pub async fn thread_list(
207        &self,
208        params: super::ThreadListParams,
209    ) -> Result<super::ThreadListResponse, McpError> {
210        let (_, rx) = self
211            .transport
212            .request(METHOD_THREAD_LIST, serde_json::to_value(params)?)
213            .await?;
214        let mapped = map_response::<super::ThreadListResponse>(rx);
215        match mapped.await {
216            Ok(result) => result,
217            Err(_) => Err(McpError::ChannelClosed),
218        }
219    }
220
221    /// Fork an existing thread via `thread/fork`.
222    pub async fn thread_fork(
223        &self,
224        params: super::ThreadForkParams,
225    ) -> Result<super::ThreadForkResponse, McpError> {
226        let (_, rx) = self
227            .transport
228            .request(METHOD_THREAD_FORK, serde_json::to_value(params)?)
229            .await?;
230        let mapped = map_response::<super::ThreadForkResponse>(rx);
231        match mapped.await {
232            Ok(result) => result,
233            Err(_) => Err(McpError::ChannelClosed),
234        }
235    }
236
237    /// Start a new turn on a thread via `turn/start`.
238    pub async fn turn_start(
239        &self,
240        params: super::TurnStartParams,
241    ) -> Result<AppCallHandle, McpError> {
242        self.invoke_app_call(METHOD_TURN_START, serde_json::to_value(params)?)
243            .await
244    }
245
246    /// Start a new turn on a thread via `turn/start` (pinned fork flow subset).
247    pub async fn turn_start_v2(
248        &self,
249        params: super::TurnStartParamsV2,
250    ) -> Result<AppCallHandle, McpError> {
251        self.invoke_app_call(METHOD_TURN_START, serde_json::to_value(params)?)
252            .await
253    }
254
255    /// Select the deterministic "last" thread id using `thread/list` paging and tuple ordering.
256    pub async fn select_last_thread_id(&self, cwd: PathBuf) -> Result<Option<String>, McpError> {
257        let mut cursor: Option<String> = None;
258        let mut seen_cursors: HashSet<String> = HashSet::new();
259        let mut best: Option<(i64, i64, String)> = None;
260
261        loop {
262            let page = self
263                .thread_list(super::ThreadListParams {
264                    cwd: Some(cwd.clone()),
265                    cursor: cursor.clone(),
266                    limit: Some(100),
267                    sort_key: Some(super::ThreadListSortKey::UpdatedAt),
268                    archived: None,
269                    model_providers: None,
270                    source_kinds: None,
271                })
272                .await?;
273
274            for thread in page.data {
275                let candidate = (thread.updated_at, thread.created_at, thread.id);
276                let should_replace = match best.as_ref() {
277                    None => true,
278                    Some(current) => {
279                        (candidate.0, candidate.1, &candidate.2)
280                            > (current.0, current.1, &current.2)
281                    }
282                };
283
284                if should_replace {
285                    best = Some(candidate);
286                }
287            }
288
289            let Some(next_cursor) = page.next_cursor else {
290                break;
291            };
292
293            if !seen_cursors.insert(next_cursor.clone()) {
294                return Err(McpError::Transport(format!(
295                    "thread/list pagination cursor repeated: {next_cursor}"
296                )));
297            }
298            cursor = Some(next_cursor);
299        }
300
301        Ok(best.map(|(_, _, id)| id))
302    }
303
304    /// Interrupt an active turn via `turn/interrupt`.
305    pub async fn turn_interrupt(
306        &self,
307        params: super::TurnInterruptParams,
308    ) -> Result<AppCallHandle, McpError> {
309        self.invoke_app_call(METHOD_TURN_INTERRUPT, serde_json::to_value(params)?)
310            .await
311    }
312
313    /// Request cancellation for a pending call.
314    pub fn cancel(&self, request_id: RequestId) -> Result<(), McpError> {
315        self.transport.cancel(request_id)
316    }
317
318    /// Gracefully shut down the app-server.
319    pub async fn shutdown(&self) -> Result<(), McpError> {
320        self.transport.shutdown().await
321    }
322
323    async fn invoke_app_call(
324        &self,
325        method: &str,
326        params: Value,
327    ) -> Result<AppCallHandle, McpError> {
328        let events = self.transport.register_app_listener().await;
329        let (request_id, raw_response) = self.transport.request(method, params).await?;
330        let response = map_response::<Value>(raw_response);
331
332        Ok(AppCallHandle {
333            request_id,
334            events,
335            response,
336        })
337    }
338}