1use std::collections::HashMap;
2use std::future::Future;
3use std::path::{Path, PathBuf};
4use std::pin::Pin;
5use std::process::Child;
6use std::sync::Arc;
7use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
8use std::time::Duration;
9
10use serde::Serialize;
11use serde_json::{Value, json};
12use tokio::sync::{Mutex, RwLock, broadcast, mpsc, oneshot};
13
14use crate::api::{Codex, ResumeThread, Thread, ThreadOptions};
15use crate::error::{ClientError, IncomingClassified, RpcError, classify_incoming};
16use crate::events::{
17 ServerEvent, ServerNotification, ServerRequestEvent, parse_notification, parse_server_request,
18};
19use crate::protocol::requests;
20use crate::protocol::responses;
21use crate::protocol::server_requests;
22use crate::protocol::shared::{EmptyObject, RequestId};
23use crate::transport::TransportHandle;
24use crate::transport::stdio::spawn_stdio_transport;
25use crate::transport::ws::connect_ws_transport;
26use crate::transport::ws_daemon::{ensure_local_ws_app_server, start_ws_server};
27
28type PendingMap = HashMap<RequestId, oneshot::Sender<Result<Value, RpcError>>>;
29type RefreshFuture = Pin<
30 Box<
31 dyn Future<Output = Result<server_requests::ChatgptAuthTokensRefreshResponse, ClientError>>
32 + Send,
33 >,
34>;
35type RefreshHandler =
36 Arc<dyn Fn(server_requests::ChatgptAuthTokensRefreshParams) -> RefreshFuture + Send + Sync>;
37type ApplyPatchApprovalFuture = Pin<
38 Box<
39 dyn Future<Output = Result<server_requests::ApplyPatchApprovalResponse, ClientError>>
40 + Send,
41 >,
42>;
43type ApplyPatchApprovalHandler = Arc<
44 dyn Fn(server_requests::ApplyPatchApprovalParams) -> ApplyPatchApprovalFuture + Send + Sync,
45>;
46type ExecCommandApprovalFuture = Pin<
47 Box<
48 dyn Future<Output = Result<server_requests::ExecCommandApprovalResponse, ClientError>>
49 + Send,
50 >,
51>;
52type ExecCommandApprovalHandler = Arc<
53 dyn Fn(server_requests::ExecCommandApprovalParams) -> ExecCommandApprovalFuture + Send + Sync,
54>;
55type CommandExecutionRequestApprovalFuture = Pin<
56 Box<
57 dyn Future<
58 Output = Result<
59 server_requests::CommandExecutionRequestApprovalResponse,
60 ClientError,
61 >,
62 > + Send,
63 >,
64>;
65type CommandExecutionRequestApprovalHandler = Arc<
66 dyn Fn(
67 server_requests::CommandExecutionRequestApprovalParams,
68 ) -> CommandExecutionRequestApprovalFuture
69 + Send
70 + Sync,
71>;
72type FileChangeRequestApprovalFuture = Pin<
73 Box<
74 dyn Future<Output = Result<server_requests::FileChangeRequestApprovalResponse, ClientError>>
75 + Send,
76 >,
77>;
78type FileChangeRequestApprovalHandler = Arc<
79 dyn Fn(server_requests::FileChangeRequestApprovalParams) -> FileChangeRequestApprovalFuture
80 + Send
81 + Sync,
82>;
83type ToolRequestUserInputFuture = Pin<
84 Box<
85 dyn Future<Output = Result<server_requests::ToolRequestUserInputResponse, ClientError>>
86 + Send,
87 >,
88>;
89type ToolRequestUserInputHandler = Arc<
90 dyn Fn(server_requests::ToolRequestUserInputParams) -> ToolRequestUserInputFuture + Send + Sync,
91>;
92type DynamicToolCallFuture = Pin<
93 Box<dyn Future<Output = Result<server_requests::DynamicToolCallResponse, ClientError>> + Send>,
94>;
95type DynamicToolCallHandler =
96 Arc<dyn Fn(server_requests::DynamicToolCallParams) -> DynamicToolCallFuture + Send + Sync>;
97
/// Behavioural knobs shared by every transport constructor.
#[derive(Debug, Clone)]
pub struct ClientOptions {
    /// Timeout applied to a request when the caller does not pass one explicitly.
    pub default_timeout: Duration,
}
102
103impl Default for ClientOptions {
104 fn default() -> Self {
105 Self {
106 default_timeout: Duration::from_secs(30),
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
112pub struct StdioConfig {
113 pub codex_binary: String,
114 pub args: Vec<String>,
115 pub env: HashMap<String, String>,
116 pub options: ClientOptions,
117}
118
119impl Default for StdioConfig {
120 fn default() -> Self {
121 Self {
122 codex_binary: "codex".to_string(),
123 args: vec!["app-server".to_string()],
124 env: HashMap::new(),
125 options: ClientOptions::default(),
126 }
127 }
128}
129
130#[derive(Debug, Clone)]
131pub struct WsConfig {
132 pub url: String,
133 pub env: HashMap<String, String>,
134 pub options: ClientOptions,
135}
136
137impl WsConfig {
138 pub fn new(
139 url: impl Into<String>,
140 env: HashMap<String, String>,
141 options: ClientOptions,
142 ) -> Self {
143 Self {
144 url: url.into(),
145 env,
146 options,
147 }
148 }
149
150 pub fn with_url(mut self, url: impl Into<String>) -> Self {
151 self.url = url.into();
152 self
153 }
154
155 pub fn with_env(mut self, env: HashMap<String, String>) -> Self {
156 self.env = env;
157 self
158 }
159}
160
161impl Default for WsConfig {
162 fn default() -> Self {
163 Self {
164 url: String::from("ws://127.0.0.1:4222"),
165 env: HashMap::new(),
166 options: ClientOptions::default(),
167 }
168 }
169}
170
/// Parameters for launching (or reusing) a WebSocket app server.
#[derive(Debug, Clone)]
pub struct WsStartConfig {
    /// Address the server should listen on.
    pub listen_url: String,
    /// Address clients should use to connect to it.
    pub connect_url: String,
    /// Environment variables for the launched server.
    pub env: HashMap<String, String>,
    /// Whether an already-running server may be reused instead of starting a new one.
    pub reuse_existing: bool,
}
178
179impl WsStartConfig {
180 pub fn new(
181 listen_url: impl Into<String>,
182 connect_url: impl Into<String>,
183 env: HashMap<String, String>,
184 ) -> Self {
185 Self {
186 listen_url: listen_url.into(),
187 connect_url: connect_url.into(),
188 env,
189 reuse_existing: true,
190 }
191 }
192
193 pub fn with_listen_url(mut self, listen_url: impl Into<String>) -> Self {
194 self.listen_url = listen_url.into();
195 self
196 }
197
198 pub fn with_connect_url(mut self, connect_url: impl Into<String>) -> Self {
199 self.connect_url = connect_url.into();
200 self
201 }
202
203 pub fn with_env(mut self, env: HashMap<String, String>) -> Self {
204 self.env = env;
205 self
206 }
207
208 pub fn with_reuse_existing(mut self, reuse_existing: bool) -> Self {
209 self.reuse_existing = reuse_existing;
210 self
211 }
212}
213
214impl Default for WsStartConfig {
215 fn default() -> Self {
216 Self {
217 listen_url: String::from("ws://127.0.0.1:4222"),
218 connect_url: String::from("ws://127.0.0.1:4222"),
219 env: HashMap::new(),
220 reuse_existing: true,
221 }
222 }
223}
224
/// How a WebSocket app server was (or should be) launched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsStartMode {
    /// Detached background server; the resulting handle owns no child process.
    Daemon,
    /// Server process owned by the returned handle and stopped when it is shut down.
    Blocking,
}
230
231#[derive(Debug)]
232pub struct WsServerHandle {
233 listen_url: String,
234 connect_url: String,
235 mode: WsStartMode,
236 reused_existing: bool,
237 log_path: Option<PathBuf>,
238 process_group_id: Option<u32>,
239 child: Option<Child>,
240}
241
242impl WsServerHandle {
243 pub fn listen_url(&self) -> &str {
244 &self.listen_url
245 }
246
247 pub fn connect_url(&self) -> &str {
248 &self.connect_url
249 }
250
251 pub fn mode(&self) -> WsStartMode {
252 self.mode
253 }
254
255 pub fn reused_existing(&self) -> bool {
256 self.reused_existing
257 }
258
259 pub fn started_new_process(&self) -> bool {
260 !self.reused_existing
261 }
262
263 pub fn owns_process(&self) -> bool {
264 self.child.is_some()
265 }
266
267 pub fn log_path(&self) -> Option<&Path> {
268 self.log_path.as_deref()
269 }
270
271 pub fn connect_config(&self, options: ClientOptions) -> WsConfig {
272 WsConfig::new(self.connect_url.clone(), HashMap::new(), options)
273 }
274
275 pub fn shutdown(&mut self) -> Result<(), ClientError> {
276 if let Some(process_group_id) = self.process_group_id.take() {
277 let _ = terminate_process_group(process_group_id);
278 }
279
280 let Some(mut child) = self.child.take() else {
281 return Ok(());
282 };
283
284 for _ in 0..20 {
285 if child.try_wait()?.is_some() {
286 return Ok(());
287 }
288 std::thread::sleep(Duration::from_millis(100));
289 }
290
291 let _ = child.kill();
292 let _ = child.wait()?;
293 Ok(())
294 }
295
296 pub(crate) fn from_reused_existing(
297 listen_url: String,
298 connect_url: String,
299 mode: WsStartMode,
300 log_path: Option<PathBuf>,
301 ) -> Self {
302 Self {
303 listen_url,
304 connect_url,
305 mode,
306 reused_existing: true,
307 log_path,
308 process_group_id: None,
309 child: None,
310 }
311 }
312
313 pub(crate) fn daemon_started(
314 listen_url: String,
315 connect_url: String,
316 log_path: PathBuf,
317 ) -> Self {
318 Self {
319 listen_url,
320 connect_url,
321 mode: WsStartMode::Daemon,
322 reused_existing: false,
323 log_path: Some(log_path),
324 process_group_id: None,
325 child: None,
326 }
327 }
328
329 pub(crate) fn blocking_started(listen_url: String, connect_url: String, child: Child) -> Self {
330 let process_group_id = Some(child.id());
331 Self {
332 listen_url,
333 connect_url,
334 mode: WsStartMode::Blocking,
335 reused_existing: false,
336 log_path: None,
337 process_group_id,
338 child: Some(child),
339 }
340 }
341}
342
343impl Drop for WsServerHandle {
344 fn drop(&mut self) {
345 let _ = self.shutdown();
346 }
347}
348
/// Sends `SIGTERM` to the process group `process_group_id` using the `kill` utility.
///
/// The leading `-` on the target (`kill -TERM -<pgid>`) addresses the whole group rather
/// than a single process.
///
/// # Errors
/// Fails if `kill` cannot be launched or exits with a non-success status.
#[cfg(unix)]
fn terminate_process_group(process_group_id: u32) -> std::io::Result<()> {
    let group_target = format!("-{process_group_id}");
    let status = std::process::Command::new("kill")
        .args(["-TERM", group_target.as_str()])
        .status()?;

    if !status.success() {
        return Err(std::io::Error::other(format!(
            "failed to terminate process group {process_group_id} with status {status}"
        )));
    }
    Ok(())
}

/// No process-group signalling off Unix: always succeeds without doing anything.
#[cfg(not(unix))]
fn terminate_process_group(_pgid: u32) -> std::io::Result<()> {
    Ok(())
}
368
369struct Inner {
370 outbound: mpsc::Sender<Value>,
371 pending: Mutex<PendingMap>,
372 default_timeout: Duration,
373 initialized: AtomicBool,
374 ready: AtomicBool,
375 next_id: AtomicI64,
376 event_tx: broadcast::Sender<ServerEvent>,
377 event_rx: Mutex<broadcast::Receiver<ServerEvent>>,
378 refresh_handler: RwLock<Option<RefreshHandler>>,
379 apply_patch_approval_handler: RwLock<Option<ApplyPatchApprovalHandler>>,
380 exec_command_approval_handler: RwLock<Option<ExecCommandApprovalHandler>>,
381 command_execution_request_approval_handler:
382 RwLock<Option<CommandExecutionRequestApprovalHandler>>,
383 file_change_request_approval_handler: RwLock<Option<FileChangeRequestApprovalHandler>>,
384 tool_request_user_input_handler: RwLock<Option<ToolRequestUserInputHandler>>,
385 dynamic_tool_call_handler: RwLock<Option<DynamicToolCallHandler>>,
386}
387
388#[derive(Clone)]
389pub struct CodexClient {
390 inner: Arc<Inner>,
391}
392
/// Generates `pub async fn $name(&self, params)`, which sends the JSON-RPC method `$rpc`
/// with serialized `params` and deserializes the reply into `$output`.
///
/// Generated methods use the default timeout and require the client to be ready.
macro_rules! typed_method {
    ($name:ident, $rpc:literal, $input:ty, $output:ty) => {
        pub async fn $name(&self, params: $input) -> Result<$output, ClientError> {
            self.request_typed_internal($rpc, params, None, true).await
        }
    };
}
401
/// Generates a parameterless `pub async fn $name(&self)`, which sends `$rpc` with
/// `params: null` and deserializes the reply into `$output`.
///
/// Generated methods use the default timeout and require the client to be ready.
macro_rules! typed_null_method {
    ($name:ident, $rpc:literal, $output:ty) => {
        pub async fn $name(&self) -> Result<$output, ClientError> {
            self.request_typed_value_internal($rpc, Value::Null, None, true).await
        }
    };
}
410
411impl CodexClient {
412 pub async fn spawn_stdio(config: StdioConfig) -> Result<Self, ClientError> {
413 let handle = spawn_stdio_transport(&config.codex_binary, &config.args, &config.env).await?;
414 Ok(Self::from_transport(handle, config.options.default_timeout))
415 }
416
417 pub async fn connect_ws(config: WsConfig) -> Result<Self, ClientError> {
418 let handle = connect_ws_transport(&config.url).await?;
419 Ok(Self::from_transport(handle, config.options.default_timeout))
420 }
421
422 pub async fn start_ws(config: WsStartConfig) -> Result<WsServerHandle, ClientError> {
423 Self::start_ws_daemon(config).await
424 }
425
426 pub async fn start_ws_daemon(config: WsStartConfig) -> Result<WsServerHandle, ClientError> {
427 start_ws_server(&config, WsStartMode::Daemon).await
428 }
429
430 pub async fn start_ws_blocking(config: WsStartConfig) -> Result<WsServerHandle, ClientError> {
431 start_ws_server(&config, WsStartMode::Blocking).await
432 }
433
434 pub async fn start_and_connect_ws(config: WsConfig) -> Result<Self, ClientError> {
435 ensure_local_ws_app_server(&config.url, &config.env).await?;
436
437 let handle = connect_ws_transport(&config.url).await?;
438 Ok(Self::from_transport(handle, config.options.default_timeout))
439 }
440
441 fn from_transport(handle: TransportHandle, default_timeout: Duration) -> Self {
442 let (event_tx, event_rx) = broadcast::channel(1024);
443 let inner = Arc::new(Inner {
444 outbound: handle.outbound,
445 pending: Mutex::new(HashMap::new()),
446 default_timeout,
447 initialized: AtomicBool::new(false),
448 ready: AtomicBool::new(false),
449 next_id: AtomicI64::new(1),
450 event_tx,
451 event_rx: Mutex::new(event_rx),
452 refresh_handler: RwLock::new(None),
453 apply_patch_approval_handler: RwLock::new(None),
454 exec_command_approval_handler: RwLock::new(None),
455 command_execution_request_approval_handler: RwLock::new(None),
456 file_change_request_approval_handler: RwLock::new(None),
457 tool_request_user_input_handler: RwLock::new(None),
458 dynamic_tool_call_handler: RwLock::new(None),
459 });
460
461 tokio::spawn(run_inbound_loop(handle.inbound, inner.clone()));
462 Self { inner }
463 }
464
465 pub fn as_api(&self) -> Codex {
466 Codex::from_client(self.clone())
467 }
468
469 pub fn start_thread(&self, options: ThreadOptions) -> Thread {
470 self.as_api().start_thread(options)
471 }
472
473 pub fn resume_thread(&self, target: impl Into<ResumeThread>, options: ThreadOptions) -> Thread {
474 self.as_api().resume_thread(target, options)
475 }
476
477 pub fn resume_thread_by_id(&self, id: impl Into<String>, options: ThreadOptions) -> Thread {
478 self.as_api().resume_thread_by_id(id, options)
479 }
480
481 pub fn resume_latest_thread(&self, options: ThreadOptions) -> Thread {
482 self.as_api().resume_latest_thread(options)
483 }
484
485 pub fn subscribe(&self) -> broadcast::Receiver<ServerEvent> {
486 self.inner.event_tx.subscribe()
487 }
488
489 pub async fn next_event(&self) -> Result<ServerEvent, ClientError> {
490 let mut rx = self.inner.event_rx.lock().await;
491 rx.recv().await.map_err(|err| {
492 ClientError::TransportSend(format!("event channel receive failed: {err}"))
493 })
494 }
495
496 pub async fn set_chatgpt_auth_tokens_refresh_handler<F, Fut>(&self, handler: F)
497 where
498 F: Fn(server_requests::ChatgptAuthTokensRefreshParams) -> Fut + Send + Sync + 'static,
499 Fut: Future<Output = Result<server_requests::ChatgptAuthTokensRefreshResponse, ClientError>>
500 + Send
501 + 'static,
502 {
503 let wrapped: RefreshHandler = Arc::new(move |params| Box::pin(handler(params)));
504 *self.inner.refresh_handler.write().await = Some(wrapped);
505 }
506
507 pub async fn clear_chatgpt_auth_tokens_refresh_handler(&self) {
508 *self.inner.refresh_handler.write().await = None;
509 }
510
511 pub async fn set_apply_patch_approval_handler<F, Fut>(&self, handler: F)
512 where
513 F: Fn(server_requests::ApplyPatchApprovalParams) -> Fut + Send + Sync + 'static,
514 Fut: Future<Output = Result<server_requests::ApplyPatchApprovalResponse, ClientError>>
515 + Send
516 + 'static,
517 {
518 let wrapped: ApplyPatchApprovalHandler = Arc::new(move |params| Box::pin(handler(params)));
519 *self.inner.apply_patch_approval_handler.write().await = Some(wrapped);
520 }
521
522 pub async fn clear_apply_patch_approval_handler(&self) {
523 *self.inner.apply_patch_approval_handler.write().await = None;
524 }
525
526 pub async fn set_exec_command_approval_handler<F, Fut>(&self, handler: F)
527 where
528 F: Fn(server_requests::ExecCommandApprovalParams) -> Fut + Send + Sync + 'static,
529 Fut: Future<Output = Result<server_requests::ExecCommandApprovalResponse, ClientError>>
530 + Send
531 + 'static,
532 {
533 let wrapped: ExecCommandApprovalHandler = Arc::new(move |params| Box::pin(handler(params)));
534 *self.inner.exec_command_approval_handler.write().await = Some(wrapped);
535 }
536
537 pub async fn clear_exec_command_approval_handler(&self) {
538 *self.inner.exec_command_approval_handler.write().await = None;
539 }
540
541 pub async fn set_command_execution_request_approval_handler<F, Fut>(&self, handler: F)
542 where
543 F: Fn(server_requests::CommandExecutionRequestApprovalParams) -> Fut
544 + Send
545 + Sync
546 + 'static,
547 Fut: Future<
548 Output = Result<
549 server_requests::CommandExecutionRequestApprovalResponse,
550 ClientError,
551 >,
552 > + Send
553 + 'static,
554 {
555 let wrapped: CommandExecutionRequestApprovalHandler =
556 Arc::new(move |params| Box::pin(handler(params)));
557 *self
558 .inner
559 .command_execution_request_approval_handler
560 .write()
561 .await = Some(wrapped);
562 }
563
564 pub async fn clear_command_execution_request_approval_handler(&self) {
565 *self
566 .inner
567 .command_execution_request_approval_handler
568 .write()
569 .await = None;
570 }
571
572 pub async fn set_file_change_request_approval_handler<F, Fut>(&self, handler: F)
573 where
574 F: Fn(server_requests::FileChangeRequestApprovalParams) -> Fut + Send + Sync + 'static,
575 Fut: Future<Output = Result<server_requests::FileChangeRequestApprovalResponse, ClientError>>
576 + Send
577 + 'static,
578 {
579 let wrapped: FileChangeRequestApprovalHandler =
580 Arc::new(move |params| Box::pin(handler(params)));
581 *self
582 .inner
583 .file_change_request_approval_handler
584 .write()
585 .await = Some(wrapped);
586 }
587
588 pub async fn clear_file_change_request_approval_handler(&self) {
589 *self
590 .inner
591 .file_change_request_approval_handler
592 .write()
593 .await = None;
594 }
595
596 pub async fn set_tool_request_user_input_handler<F, Fut>(&self, handler: F)
597 where
598 F: Fn(server_requests::ToolRequestUserInputParams) -> Fut + Send + Sync + 'static,
599 Fut: Future<Output = Result<server_requests::ToolRequestUserInputResponse, ClientError>>
600 + Send
601 + 'static,
602 {
603 let wrapped: ToolRequestUserInputHandler =
604 Arc::new(move |params| Box::pin(handler(params)));
605 *self.inner.tool_request_user_input_handler.write().await = Some(wrapped);
606 }
607
608 pub async fn clear_tool_request_user_input_handler(&self) {
609 *self.inner.tool_request_user_input_handler.write().await = None;
610 }
611
612 pub async fn set_dynamic_tool_call_handler<F, Fut>(&self, handler: F)
613 where
614 F: Fn(server_requests::DynamicToolCallParams) -> Fut + Send + Sync + 'static,
615 Fut: Future<Output = Result<server_requests::DynamicToolCallResponse, ClientError>>
616 + Send
617 + 'static,
618 {
619 let wrapped: DynamicToolCallHandler = Arc::new(move |params| Box::pin(handler(params)));
620 *self.inner.dynamic_tool_call_handler.write().await = Some(wrapped);
621 }
622
623 pub async fn clear_dynamic_tool_call_handler(&self) {
624 *self.inner.dynamic_tool_call_handler.write().await = None;
625 }
626
627 pub async fn initialize(
628 &self,
629 params: requests::InitializeParams,
630 ) -> Result<responses::InitializeResult, ClientError> {
631 if self.inner.initialized.load(Ordering::SeqCst) {
632 return Err(ClientError::AlreadyInitialized);
633 }
634
635 let result: responses::InitializeResult = self
636 .request_typed_internal("initialize", params, None, false)
637 .await?;
638
639 self.inner.initialized.store(true, Ordering::SeqCst);
640 Ok(result)
641 }
642
643 pub async fn initialized(&self) -> Result<(), ClientError> {
644 if !self.inner.initialized.load(Ordering::SeqCst) {
645 return Err(ClientError::NotInitialized {
646 method: "initialized".to_string(),
647 });
648 }
649 self.send_notification("initialized", EmptyObject::default(), false)
650 .await?;
651 self.inner.ready.store(true, Ordering::SeqCst);
652 Ok(())
653 }
654
655 pub async fn send_raw_request(
656 &self,
657 method: impl Into<String>,
658 params: Value,
659 timeout: Option<Duration>,
660 ) -> Result<Value, ClientError> {
661 let method = method.into();
662 let requires_ready = method != "initialize";
663 self.request_value_internal(&method, params, timeout, requires_ready)
664 .await
665 }
666
667 pub async fn send_raw_notification(
668 &self,
669 method: impl Into<String>,
670 params: Value,
671 ) -> Result<(), ClientError> {
672 let method = method.into();
673 let requires_ready = method != "initialized";
674 self.send_notification(&method, params, requires_ready)
675 .await
676 }
677
678 pub async fn respond_server_request<R: Serialize>(
679 &self,
680 id: RequestId,
681 result: R,
682 ) -> Result<(), ClientError> {
683 let result = serde_json::to_value(result)?;
684 self.send_message(json!({ "id": id, "result": result }))
685 .await
686 }
687
688 pub async fn respond_server_request_error(
689 &self,
690 id: RequestId,
691 error: RpcError,
692 ) -> Result<(), ClientError> {
693 self.send_message(json!({ "id": id, "error": error })).await
694 }
695
696 pub async fn respond_chatgpt_auth_tokens_refresh(
697 &self,
698 id: RequestId,
699 response: server_requests::ChatgptAuthTokensRefreshResponse,
700 ) -> Result<(), ClientError> {
701 self.respond_server_request(id, response).await
702 }
703
704 pub async fn respond_apply_patch_approval(
705 &self,
706 id: RequestId,
707 response: server_requests::ApplyPatchApprovalResponse,
708 ) -> Result<(), ClientError> {
709 self.respond_server_request(id, response).await
710 }
711
712 pub async fn respond_exec_command_approval(
713 &self,
714 id: RequestId,
715 response: server_requests::ExecCommandApprovalResponse,
716 ) -> Result<(), ClientError> {
717 self.respond_server_request(id, response).await
718 }
719
720 pub async fn respond_command_execution_request_approval(
721 &self,
722 id: RequestId,
723 response: server_requests::CommandExecutionRequestApprovalResponse,
724 ) -> Result<(), ClientError> {
725 self.respond_server_request(id, response).await
726 }
727
728 pub async fn respond_file_change_request_approval(
729 &self,
730 id: RequestId,
731 response: server_requests::FileChangeRequestApprovalResponse,
732 ) -> Result<(), ClientError> {
733 self.respond_server_request(id, response).await
734 }
735
736 pub async fn respond_tool_request_user_input(
737 &self,
738 id: RequestId,
739 response: server_requests::ToolRequestUserInputResponse,
740 ) -> Result<(), ClientError> {
741 self.respond_server_request(id, response).await
742 }
743
744 pub async fn respond_dynamic_tool_call(
745 &self,
746 id: RequestId,
747 response: server_requests::DynamicToolCallResponse,
748 ) -> Result<(), ClientError> {
749 self.respond_server_request(id, response).await
750 }
751
752 typed_method!(
753 thread_start,
754 "thread/start",
755 requests::ThreadStartParams,
756 responses::ThreadResult
757 );
758 typed_method!(
759 thread_resume,
760 "thread/resume",
761 requests::ThreadResumeParams,
762 responses::ThreadResult
763 );
764 typed_method!(
765 thread_fork,
766 "thread/fork",
767 requests::ThreadForkParams,
768 responses::ThreadResult
769 );
770 typed_method!(
771 thread_archive,
772 "thread/archive",
773 requests::ThreadArchiveParams,
774 responses::ThreadArchiveResult
775 );
776 typed_method!(
777 thread_name_set,
778 "thread/name/set",
779 requests::ThreadSetNameParams,
780 responses::ThreadSetNameResult
781 );
782 typed_method!(
783 thread_unarchive,
784 "thread/unarchive",
785 requests::ThreadUnarchiveParams,
786 responses::ThreadUnarchiveResult
787 );
788 typed_method!(
789 thread_compact_start,
790 "thread/compact/start",
791 requests::ThreadCompactStartParams,
792 responses::ThreadCompactStartResult
793 );
794 typed_method!(
795 thread_background_terminals_clean,
796 "thread/backgroundTerminals/clean",
797 requests::ThreadBackgroundTerminalsCleanParams,
798 responses::ThreadBackgroundTerminalsCleanResult
799 );
800 typed_method!(
801 thread_rollback,
802 "thread/rollback",
803 requests::ThreadRollbackParams,
804 responses::ThreadRollbackResult
805 );
806 typed_method!(
807 thread_list,
808 "thread/list",
809 requests::ThreadListParams,
810 responses::ThreadListResult
811 );
812 typed_method!(
813 thread_loaded_list,
814 "thread/loaded/list",
815 requests::ThreadLoadedListParams,
816 responses::ThreadLoadedListResult
817 );
818 typed_method!(
819 thread_read,
820 "thread/read",
821 requests::ThreadReadParams,
822 responses::ThreadReadResult
823 );
824 typed_method!(
825 skills_list,
826 "skills/list",
827 requests::SkillsListParams,
828 responses::SkillsListResult
829 );
830 typed_method!(
831 skills_remote_list,
832 "skills/remote/list",
833 requests::SkillsRemoteReadParams,
834 responses::SkillsRemoteReadResult
835 );
836 typed_method!(
837 skills_remote_export,
838 "skills/remote/export",
839 requests::SkillsRemoteWriteParams,
840 responses::SkillsRemoteWriteResult
841 );
842 typed_method!(
843 app_list,
844 "app/list",
845 requests::AppsListParams,
846 responses::AppsListResult
847 );
848 typed_method!(
849 skills_config_write,
850 "skills/config/write",
851 requests::SkillsConfigWriteParams,
852 responses::SkillsConfigWriteResult
853 );
854 typed_method!(
855 turn_start,
856 "turn/start",
857 requests::TurnStartParams,
858 responses::TurnResult
859 );
860 typed_method!(
861 turn_steer,
862 "turn/steer",
863 requests::TurnSteerParams,
864 responses::TurnSteerResult
865 );
866 typed_method!(
867 turn_interrupt,
868 "turn/interrupt",
869 requests::TurnInterruptParams,
870 EmptyObject
871 );
872 typed_method!(
873 review_start,
874 "review/start",
875 requests::ReviewStartParams,
876 responses::ReviewStartResult
877 );
878 typed_method!(
879 model_list,
880 "model/list",
881 requests::ModelListParams,
882 responses::ModelListResult
883 );
884 typed_method!(
885 experimental_feature_list,
886 "experimentalFeature/list",
887 requests::ExperimentalFeatureListParams,
888 responses::ExperimentalFeatureListResult
889 );
890 typed_method!(
891 collaboration_mode_list,
892 "collaborationMode/list",
893 requests::CollaborationModeListParams,
894 responses::CollaborationModeListResult
895 );
896 typed_method!(
897 mock_experimental_method,
898 "mock/experimentalMethod",
899 requests::MockExperimentalMethodParams,
900 responses::MockExperimentalMethodResult
901 );
902 typed_method!(
903 mcp_server_oauth_login,
904 "mcpServer/oauth/login",
905 requests::McpServerOauthLoginParams,
906 responses::McpServerOauthLoginResult
907 );
908 typed_method!(
909 mcp_server_status_list,
910 "mcpServerStatus/list",
911 requests::ListMcpServerStatusParams,
912 responses::McpServerStatusListResult
913 );
914 typed_method!(
915 windows_sandbox_setup_start,
916 "windowsSandbox/setupStart",
917 requests::WindowsSandboxSetupStartParams,
918 responses::WindowsSandboxSetupStartResult
919 );
920 typed_method!(
921 account_login_start,
922 "account/login/start",
923 requests::LoginAccountParams,
924 responses::LoginAccountResult
925 );
926 typed_method!(
927 account_login_cancel,
928 "account/login/cancel",
929 requests::CancelLoginAccountParams,
930 EmptyObject
931 );
932 typed_method!(
933 feedback_upload,
934 "feedback/upload",
935 requests::FeedbackUploadParams,
936 responses::FeedbackUploadResult
937 );
938 typed_method!(
939 command_exec,
940 "command/exec",
941 requests::CommandExecParams,
942 responses::CommandExecResult
943 );
944 typed_method!(
945 config_read,
946 "config/read",
947 requests::ConfigReadParams,
948 responses::ConfigReadResult
949 );
950 typed_method!(
951 config_value_write,
952 "config/value/write",
953 requests::ConfigValueWriteParams,
954 responses::ConfigValueWriteResult
955 );
956 typed_method!(
957 config_batch_write,
958 "config/batchWrite",
959 requests::ConfigBatchWriteParams,
960 responses::ConfigBatchWriteResult
961 );
962 typed_method!(
963 account_read,
964 "account/read",
965 requests::GetAccountParams,
966 responses::GetAccountResult
967 );
968 typed_method!(
969 fuzzy_file_search_session_start,
970 "fuzzyFileSearch/sessionStart",
971 requests::FuzzyFileSearchSessionStartParams,
972 responses::FuzzyFileSearchSessionStartResult
973 );
974 typed_method!(
975 fuzzy_file_search_session_update,
976 "fuzzyFileSearch/sessionUpdate",
977 requests::FuzzyFileSearchSessionUpdateParams,
978 responses::FuzzyFileSearchSessionUpdateResult
979 );
980 typed_method!(
981 fuzzy_file_search_session_stop,
982 "fuzzyFileSearch/sessionStop",
983 requests::FuzzyFileSearchSessionStopParams,
984 responses::FuzzyFileSearchSessionStopResult
985 );
986
987 pub async fn skills_remote_read(
989 &self,
990 params: requests::SkillsRemoteReadParams,
991 ) -> Result<responses::SkillsRemoteReadResult, ClientError> {
992 self.skills_remote_list(params).await
993 }
994
995 pub async fn skills_remote_write(
996 &self,
997 params: requests::SkillsRemoteWriteParams,
998 ) -> Result<responses::SkillsRemoteWriteResult, ClientError> {
999 self.skills_remote_export(params).await
1000 }
1001
1002 typed_null_method!(
1003 config_mcp_server_reload,
1004 "config/mcpServer/reload",
1005 EmptyObject
1006 );
1007 typed_null_method!(account_logout, "account/logout", EmptyObject);
1008 typed_null_method!(
1009 account_rate_limits_read,
1010 "account/rateLimits/read",
1011 responses::AccountRateLimitsReadResult
1012 );
1013 typed_null_method!(
1014 config_requirements_read,
1015 "configRequirements/read",
1016 responses::ConfigRequirementsReadResult
1017 );
1018
1019 async fn send_notification<P: Serialize>(
1020 &self,
1021 method: &str,
1022 params: P,
1023 requires_ready: bool,
1024 ) -> Result<(), ClientError> {
1025 if requires_ready && !self.inner.ready.load(Ordering::SeqCst) {
1026 return Err(ClientError::NotReady {
1027 method: method.to_string(),
1028 });
1029 }
1030
1031 let value = serde_json::to_value(params)?;
1032 self.send_message(json!({ "method": method, "params": value }))
1033 .await
1034 }
1035
1036 async fn request_typed_internal<P, R>(
1037 &self,
1038 method: &str,
1039 params: P,
1040 timeout: Option<Duration>,
1041 requires_ready: bool,
1042 ) -> Result<R, ClientError>
1043 where
1044 P: Serialize,
1045 R: serde::de::DeserializeOwned,
1046 {
1047 let value = serde_json::to_value(params)?;
1048 self.request_typed_value_internal(method, value, timeout, requires_ready)
1049 .await
1050 }
1051
1052 async fn request_typed_value_internal<R>(
1053 &self,
1054 method: &str,
1055 params: Value,
1056 timeout: Option<Duration>,
1057 requires_ready: bool,
1058 ) -> Result<R, ClientError>
1059 where
1060 R: serde::de::DeserializeOwned,
1061 {
1062 let raw = self
1063 .request_value_internal(method, params, timeout, requires_ready)
1064 .await?;
1065
1066 serde_json::from_value(raw).map_err(|source| ClientError::UnexpectedResult {
1067 method: method.to_string(),
1068 source,
1069 })
1070 }
1071
1072 async fn request_value_internal(
1073 &self,
1074 method: &str,
1075 params: Value,
1076 timeout: Option<Duration>,
1077 requires_ready: bool,
1078 ) -> Result<Value, ClientError> {
1079 if requires_ready && !self.inner.ready.load(Ordering::SeqCst) {
1080 return Err(ClientError::NotReady {
1081 method: method.to_string(),
1082 });
1083 }
1084
1085 if method == "initialize" && self.inner.initialized.load(Ordering::SeqCst) {
1086 return Err(ClientError::AlreadyInitialized);
1087 }
1088
1089 let id_num = self.inner.next_id.fetch_add(1, Ordering::SeqCst);
1090 let id = RequestId::Integer(id_num);
1091
1092 let request = json!({
1093 "method": method,
1094 "id": id,
1095 "params": params,
1096 });
1097
1098 let (tx, rx) = oneshot::channel();
1099 self.inner.pending.lock().await.insert(id.clone(), tx);
1100
1101 if let Err(err) = self.send_message(request).await {
1102 self.inner.pending.lock().await.remove(&id);
1103 return Err(err);
1104 }
1105
1106 let timeout = timeout.unwrap_or(self.inner.default_timeout);
1107 match tokio::time::timeout(timeout, rx).await {
1108 Ok(Ok(Ok(value))) => Ok(value),
1109 Ok(Ok(Err(error))) => Err(ClientError::Rpc { error }),
1110 Ok(Err(_)) => Err(ClientError::TransportClosed),
1111 Err(_) => {
1112 self.inner.pending.lock().await.remove(&id);
1113 Err(ClientError::Timeout {
1114 method: method.to_string(),
1115 timeout_ms: timeout.as_millis() as u64,
1116 })
1117 }
1118 }
1119 }
1120
1121 async fn send_message(&self, value: Value) -> Result<(), ClientError> {
1122 self.inner.outbound.send(value).await.map_err(|err| {
1123 ClientError::TransportSend(format!("failed to send outbound frame: {err}"))
1124 })
1125 }
1126}
1127
1128async fn run_inbound_loop(
1129 mut inbound: mpsc::Receiver<Result<Value, ClientError>>,
1130 inner: Arc<Inner>,
1131) {
1132 while let Some(frame) = inbound.recv().await {
1133 match frame {
1134 Ok(value) => {
1135 if let Err(err) = process_incoming_value(value, &inner).await {
1136 fail_all_pending(&inner, &format!("processing inbound frame failed: {err}"))
1137 .await;
1138 let _ = inner.event_tx.send(ServerEvent::TransportClosed);
1139 break;
1140 }
1141 }
1142 Err(err) => {
1143 fail_all_pending(&inner, &format!("transport error: {err}")).await;
1144 let _ = inner.event_tx.send(ServerEvent::TransportClosed);
1145 break;
1146 }
1147 }
1148 }
1149}
1150
1151async fn process_incoming_value(value: Value, inner: &Arc<Inner>) -> Result<(), ClientError> {
1152 match classify_incoming(value)? {
1153 IncomingClassified::Response { id, result } => {
1154 if let Some(sender) = inner.pending.lock().await.remove(&id) {
1155 let _ = sender.send(result);
1156 }
1157 }
1158 IncomingClassified::Notification {
1159 method,
1160 params,
1161 raw: _,
1162 } => {
1163 let parsed = parse_notification(method.clone(), params.clone())
1164 .unwrap_or(ServerNotification::Unknown { method, params });
1165 let _ = inner.event_tx.send(ServerEvent::Notification(parsed));
1166 }
1167 IncomingClassified::ServerRequest {
1168 id,
1169 method,
1170 params,
1171 raw: _,
1172 } => {
1173 let parsed = parse_server_request(id.clone(), method.clone(), params.clone())
1174 .unwrap_or(ServerRequestEvent::Unknown { id, method, params });
1175 if !try_auto_handle_server_request(inner, &parsed).await {
1176 let _ = inner.event_tx.send(ServerEvent::ServerRequest(parsed));
1177 }
1178 }
1179 }
1180 Ok(())
1181}
1182
1183async fn try_auto_handle_server_request(inner: &Arc<Inner>, request: &ServerRequestEvent) -> bool {
1184 match request {
1185 ServerRequestEvent::ChatgptAuthTokensRefresh { id, params } => {
1186 let handler = inner.refresh_handler.read().await.clone();
1187 let Some(handler) = handler else {
1188 return false;
1189 };
1190
1191 let response = handler(params.clone()).await;
1192 send_server_request_handler_result(inner, id, response, "chatgptAuthTokens refresh")
1193 .await
1194 }
1195 ServerRequestEvent::ApplyPatchApproval { id, params } => {
1196 let handler = inner.apply_patch_approval_handler.read().await.clone();
1197 let Some(handler) = handler else {
1198 return false;
1199 };
1200
1201 let response = handler(params.clone()).await;
1202 send_server_request_handler_result(inner, id, response, "applyPatchApproval").await
1203 }
1204 ServerRequestEvent::ExecCommandApproval { id, params } => {
1205 let handler = inner.exec_command_approval_handler.read().await.clone();
1206 let Some(handler) = handler else {
1207 return false;
1208 };
1209
1210 let response = handler(params.clone()).await;
1211 send_server_request_handler_result(inner, id, response, "execCommandApproval").await
1212 }
1213 ServerRequestEvent::CommandExecutionRequestApproval { id, params } => {
1214 let handler = inner
1215 .command_execution_request_approval_handler
1216 .read()
1217 .await
1218 .clone();
1219 let Some(handler) = handler else {
1220 return false;
1221 };
1222
1223 let response = handler(params.clone()).await;
1224 send_server_request_handler_result(
1225 inner,
1226 id,
1227 response,
1228 "item/commandExecution/requestApproval",
1229 )
1230 .await
1231 }
1232 ServerRequestEvent::FileChangeRequestApproval { id, params } => {
1233 let handler = inner
1234 .file_change_request_approval_handler
1235 .read()
1236 .await
1237 .clone();
1238 let Some(handler) = handler else {
1239 return false;
1240 };
1241
1242 let response = handler(params.clone()).await;
1243 send_server_request_handler_result(
1244 inner,
1245 id,
1246 response,
1247 "item/fileChange/requestApproval",
1248 )
1249 .await
1250 }
1251 ServerRequestEvent::ToolRequestUserInput { id, params } => {
1252 let handler = inner.tool_request_user_input_handler.read().await.clone();
1253 let Some(handler) = handler else {
1254 return false;
1255 };
1256
1257 let response = handler(params.clone()).await;
1258 send_server_request_handler_result(inner, id, response, "item/tool/requestUserInput")
1259 .await
1260 }
1261 ServerRequestEvent::DynamicToolCall { id, params } => {
1262 let handler = inner.dynamic_tool_call_handler.read().await.clone();
1263 let Some(handler) = handler else {
1264 return false;
1265 };
1266
1267 let response = handler(params.clone()).await;
1268 send_server_request_handler_result(inner, id, response, "item/tool/call").await
1269 }
1270 _ => false,
1271 }
1272}
1273
1274async fn send_server_request_handler_result<R: Serialize>(
1275 inner: &Arc<Inner>,
1276 id: &RequestId,
1277 response: Result<R, ClientError>,
1278 context: &str,
1279) -> bool {
1280 let payload = match response {
1281 Ok(result) => json!({ "id": id, "result": result }),
1282 Err(err) => json!({
1283 "id": id,
1284 "error": {
1285 "code": -32001,
1286 "message": format!("{context} handler failed: {err}")
1287 }
1288 }),
1289 };
1290
1291 if inner.outbound.send(payload).await.is_err() {
1292 let _ = inner.event_tx.send(ServerEvent::TransportClosed);
1293 }
1294
1295 true
1296}
1297
1298async fn fail_all_pending(inner: &Arc<Inner>, message: &str) {
1299 let mut pending = inner.pending.lock().await;
1300 let entries = std::mem::take(&mut *pending);
1301 drop(pending);
1302
1303 for (_, sender) in entries {
1304 let _ = sender.send(Err(RpcError {
1305 code: -32098,
1306 message: message.to_string(),
1307 data: None,
1308 }));
1309 }
1310}
1311
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, timeout};

    type InboundTx = mpsc::Sender<Result<Value, ClientError>>;

    /// Builds a client wired to in-memory channels instead of a real transport.
    ///
    /// Returns three things:
    /// - the client,
    /// - the sender used to inject frames as if they came from the server, and
    /// - the receiver that observes frames the client writes to the server.
    fn test_client() -> (
        CodexClient,
        mpsc::Sender<Result<Value, ClientError>>,
        mpsc::Receiver<Value>,
    ) {
        let (to_server_tx, to_server_rx) = mpsc::channel::<Value>(32);
        let (from_server_tx, from_server_rx) = mpsc::channel::<Result<Value, ClientError>>(32);
        let transport = TransportHandle {
            outbound: to_server_tx,
            inbound: from_server_rx,
        };
        let client = CodexClient::from_transport(transport, Duration::from_secs(5));
        (client, from_server_tx, to_server_rx)
    }

    /// Injects an `applyPatchApproval` server request with empty params under `id`.
    async fn push_apply_patch_request(inbound_tx: &InboundTx, id: i64) {
        let frame = json!({
            "id": id,
            "method": "applyPatchApproval",
            "params": {}
        });
        inbound_tx
            .send(Ok(frame))
            .await
            .expect("send inbound server request");
    }

    #[tokio::test]
    async fn auto_handles_apply_patch_approval_when_handler_registered() {
        let (client, inbound_tx, mut outbound_rx) = test_client();

        client
            .set_apply_patch_approval_handler(|_| async {
                let mut approval = server_requests::ApplyPatchApprovalResponse::default();
                approval
                    .extra
                    .insert("decision".to_string(), Value::String("approve".to_string()));
                Ok(approval)
            })
            .await;

        push_apply_patch_request(&inbound_tx, 42).await;

        let reply = timeout(Duration::from_secs(2), outbound_rx.recv())
            .await
            .expect("timed out waiting for outbound response")
            .expect("expected outbound response frame");
        assert_eq!(reply.get("id"), Some(&json!(42)));
        assert_eq!(reply.pointer("/result/decision"), Some(&json!("approve")));
    }

    #[tokio::test]
    async fn unhandled_server_request_is_published_as_event() {
        let (client, inbound_tx, mut outbound_rx) = test_client();

        push_apply_patch_request(&inbound_tx, 7).await;

        let event = timeout(Duration::from_secs(2), client.next_event())
            .await
            .expect("timed out waiting for event")
            .expect("event receive");

        let ServerEvent::ServerRequest(ServerRequestEvent::ApplyPatchApproval { id, .. }) = event
        else {
            panic!("unexpected event: {event:?}");
        };
        assert_eq!(id, RequestId::Integer(7));

        // With no handler registered, nothing may be written back to the server.
        let no_reply = timeout(Duration::from_millis(200), outbound_rx.recv()).await;
        assert!(
            no_reply.is_err(),
            "did not expect auto-response when handler is absent"
        );
    }
}