1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3#![deny(rustdoc::broken_intra_doc_links)]
4#![cfg_attr(test, allow(clippy::unwrap_used))]
5
6pub mod canvas;
8mod canvas_dispatch;
9#[cfg(feature = "bundled-cli")]
11pub(crate) mod embeddedcli;
12mod errors;
13pub use errors::*;
14pub mod copilot_request_handler;
18#[doc(hidden)]
21pub mod github_telemetry;
22pub mod handler;
24pub mod hooks;
26mod jsonrpc;
27pub mod permission;
29pub mod provider_token;
31mod provider_token_dispatch;
32pub(crate) mod resolve;
34mod router;
35pub mod session;
37pub mod session_fs;
39mod session_fs_dispatch;
40pub mod subscription;
42pub mod tool;
44pub mod trace_context;
46pub mod transforms;
48pub mod types;
50mod wire;
51
52pub mod session_events;
54
55pub mod rpc;
58
59pub(crate) mod generated;
64
65pub mod mode;
68
69use std::ffi::OsString;
70use std::path::{Path, PathBuf};
71use std::process::Stdio;
72use std::sync::{Arc, OnceLock};
73use std::time::{Duration, Instant};
74
75use async_trait::async_trait;
76pub(crate) use jsonrpc::{
79 JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes,
80};
81pub use mode::{BUILTIN_TOOLS_ISOLATED, ClientMode, ToolSet};
82pub use provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs};
83
#[cfg(feature = "test-support")]
/// JSON-RPC plumbing re-exported for tests (only compiled with the
/// `test-support` feature), so test code can drive a client over in-memory
/// streams without depending on crate-private modules.
pub mod test_support {
    pub use crate::jsonrpc::{
        JsonRpcClient, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
        error_codes,
    };
}
92use serde::{Deserialize, Serialize};
93use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader};
94use tokio::net::TcpStream;
95use tokio::process::{Child, Command};
96use tokio::sync::{broadcast, mpsc, oneshot};
97use tracing::{Instrument, debug, error, info, warn};
98pub use types::*;
99
100mod sdk_protocol_version;
101pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version};
102pub use subscription::{EventSubscription, LifecycleSubscription};
103
/// Oldest CLI server protocol version this SDK accepts.
/// NOTE(review): presumably enforced by `verify_protocol_version` (not visible
/// in this chunk) — confirm.
const MIN_PROTOCOL_VERSION: u32 = 3;
/// Upper bound on waiting for the CLI runtime to shut down.
/// NOTE(review): the consumer is outside this chunk — confirm exactly which
/// shutdown step this bounds.
const RUNTIME_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
107
/// How the SDK reaches a Copilot CLI server.
///
/// `Debug` is implemented by hand instead of derived so that connection tokens
/// print as `<redacted>`. `ClientOptions`'s `Debug` output embeds the
/// transport, and option structs are exactly the kind of value that ends up in
/// logs.
#[derive(Default)]
#[non_exhaustive]
pub enum Transport {
    /// Spawn the CLI with `--stdio` and speak JSON-RPC over its stdin/stdout.
    /// This is the default.
    #[default]
    Stdio,
    /// Spawn the CLI as a TCP server (`--port`) and connect to it on
    /// `127.0.0.1`.
    Tcp {
        /// Port requested from the CLI via `--port`. The port actually
        /// connected to is the one the CLI reports on stdout
        /// (`listening on port N`).
        port: u16,
        /// Shared secret for this connection. When `None`, `Client::start`
        /// generates a random token. Either way it is exported to the CLI as
        /// `COPILOT_CONNECTION_TOKEN`. An empty string is rejected.
        connection_token: Option<String>,
    },
    /// Connect to an already-running CLI server; no process is spawned.
    External {
        /// Host name or address of the server.
        host: String,
        /// Port of the server.
        port: u16,
        /// Shared secret for this connection, used exactly as given. An empty
        /// string is rejected.
        connection_token: Option<String>,
    },
}

impl std::fmt::Debug for Transport {
    /// Same shape as the derived output, except that a present connection
    /// token prints as `Some("<redacted>")` instead of its value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn redacted(token: &Option<String>) -> Option<&'static str> {
            token.as_ref().map(|_| "<redacted>")
        }

        match self {
            Self::Stdio => f.write_str("Stdio"),
            Self::Tcp {
                port,
                connection_token,
            } => f
                .debug_struct("Tcp")
                .field("port", port)
                .field("connection_token", &redacted(connection_token))
                .finish(),
            Self::External {
                host,
                port,
                connection_token,
            } => f
                .debug_struct("External")
                .field("host", host)
                .field("port", port)
                .field("connection_token", &redacted(connection_token))
                .finish(),
        }
    }
}
135
/// Where `Client::start` obtains the Copilot CLI executable for the spawning
/// transports.
#[derive(Debug, Clone, Default)]
pub enum CliProgram {
    /// Let the SDK locate the CLI via its resolver, which is handed
    /// `ClientOptions::bundled_cli_extract_dir`. This is the default.
    #[default]
    Resolve,
    /// Run exactly this executable.
    Path(PathBuf),
}

impl From<PathBuf> for CliProgram {
    /// A bare path means "use this executable" (`CliProgram::Path`).
    fn from(path: PathBuf) -> Self {
        CliProgram::Path(path)
    }
}
152
/// `true` when this crate was compiled with the `has_bundled_cli` cfg set.
///
/// NOTE(review): this checks the `has_bundled_cli` cfg, not the `bundled-cli`
/// Cargo feature that gates `install_bundled_cli`. Presumably a build script
/// sets the cfg when a CLI binary is actually embedded — confirm.
pub const HAS_BUNDLED_CLI: bool = cfg!(has_bundled_cli);
160
/// Return the path of the CLI bundled into this crate, if any.
///
/// With the `bundled-cli` feature this delegates to `embeddedcli::path()`.
/// Without it, it always returns `None`. NOTE(review): whether `path()`
/// extracts the binary to disk on first use is not visible here — confirm.
pub fn install_bundled_cli() -> Option<PathBuf> {
    #[cfg(feature = "bundled-cli")]
    {
        embeddedcli::path()
    }
    #[cfg(not(feature = "bundled-cli"))]
    {
        None
    }
}
194
/// Configuration for `Client::start`.
///
/// Build it with `ClientOptions::new()` and the `with_*` methods. Fields that
/// only affect the spawned CLI process (arguments, environment, working
/// directory) have no effect with `Transport::External`.
#[non_exhaustive]
pub struct ClientOptions {
    /// How the CLI executable is located. Defaults to `CliProgram::Resolve`.
    pub program: CliProgram,
    /// Arguments inserted right after the program, before every
    /// SDK-generated argument.
    pub prefix_args: Vec<OsString>,
    /// Working directory of the spawned CLI; also the value reported by
    /// `Client::cwd`. Defaults to the process's current directory (or `.`).
    pub working_directory: PathBuf,
    /// Environment variables set on the child after the SDK's own variables,
    /// so entries here override them.
    pub env: Vec<(OsString, OsString)>,
    /// Environment variables removed from the child's environment. Applied
    /// last, after `env`.
    pub env_remove: Vec<OsString>,
    /// Arguments appended after every SDK-generated argument.
    pub extra_args: Vec<String>,
    /// How to reach the CLI server. Defaults to `Transport::Stdio`.
    pub transport: Transport,
    /// GitHub token handed to the CLI through the `COPILOT_SDK_AUTH_TOKEN`
    /// environment variable (selected with `--auth-token-env`), never on argv.
    /// Redacted in `Debug` output. Rejected with `Transport::External`.
    pub github_token: Option<String>,
    /// Whether the CLI may fall back to the logged-in user. When `None`, this
    /// is `true` unless `github_token` is set. When it resolves to `false`,
    /// `--no-auto-login` is passed. `Some(true)` is rejected with
    /// `Transport::External`.
    pub use_logged_in_user: Option<bool>,
    /// CLI log level, passed as `--log-level`.
    pub log_level: Option<LogLevel>,
    /// Passed as `--session-idle-timeout` when greater than zero. `None` or `0`
    /// omit the flag.
    pub session_idle_timeout_seconds: Option<u64>,
    /// Caller-provided model list source; see `ListModelsHandler`.
    pub on_list_models: Option<Arc<dyn ListModelsHandler>>,
    /// Session filesystem provider configuration. Validated up front and
    /// registered with the server after protocol verification.
    pub session_fs: Option<SessionFsConfig>,
    /// Handler for Copilot requests. When set, `Client::start` installs a
    /// dispatcher for it and registers it with the server.
    pub request_handler: Option<Arc<dyn crate::copilot_request_handler::CopilotRequestHandler>>,
    #[doc(hidden)]
    /// Callback handed to the session router for GitHub telemetry
    /// notifications.
    pub on_github_telemetry: Option<crate::github_telemetry::GitHubTelemetryCallback>,
    /// Trace-context provider stored on the client.
    /// NOTE(review): presumably consulted when propagating trace context on
    /// outgoing requests — confirm.
    pub on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
    /// OpenTelemetry settings. When set, the CLI is started with
    /// `COPILOT_OTEL_ENABLED=true` plus one variable per populated field.
    pub telemetry: Option<TelemetryConfig>,
    /// Exported to the CLI as `COPILOT_HOME`.
    pub base_directory: Option<PathBuf>,
    /// Passes `--remote` to the CLI when `true`.
    pub enable_remote_sessions: bool,
    /// Forwarded to the CLI resolver when `program` is `CliProgram::Resolve`.
    pub bundled_cli_extract_dir: Option<PathBuf>,
    /// Client mode. `ClientMode::Empty` requires `base_directory` or
    /// `session_fs` and sets `COPILOT_DISABLE_KEYTAR=1` for the child.
    pub mode: ClientMode,
}
322
323impl std::fmt::Debug for ClientOptions {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 f.debug_struct("ClientOptions")
326 .field("program", &self.program)
327 .field("prefix_args", &self.prefix_args)
328 .field("working_directory", &self.working_directory)
329 .field("env", &self.env)
330 .field("env_remove", &self.env_remove)
331 .field("extra_args", &self.extra_args)
332 .field("transport", &self.transport)
333 .field(
334 "github_token",
335 &self.github_token.as_ref().map(|_| "<redacted>"),
336 )
337 .field("use_logged_in_user", &self.use_logged_in_user)
338 .field("log_level", &self.log_level)
339 .field(
340 "session_idle_timeout_seconds",
341 &self.session_idle_timeout_seconds,
342 )
343 .field(
344 "on_list_models",
345 &self.on_list_models.as_ref().map(|_| "<set>"),
346 )
347 .field("session_fs", &self.session_fs)
348 .field(
349 "request_handler",
350 &self.request_handler.as_ref().map(|_| "<set>"),
351 )
352 .field(
353 "on_github_telemetry",
354 &self.on_github_telemetry.as_ref().map(|_| "<set>"),
355 )
356 .field(
357 "on_get_trace_context",
358 &self.on_get_trace_context.as_ref().map(|_| "<set>"),
359 )
360 .field("telemetry", &self.telemetry)
361 .field("base_directory", &self.base_directory)
362 .field("enable_remote_sessions", &self.enable_remote_sessions)
363 .field("bundled_cli_extract_dir", &self.bundled_cli_extract_dir)
364 .finish()
365 }
366}
367
/// Caller-provided source of the model list.
///
/// Installed with `ClientOptions::with_list_models_handler` and stored on the
/// client. NOTE(review): presumably used instead of asking the CLI when
/// listing models, and its result possibly cached in `models_cache` — confirm
/// against the list-models implementation.
#[async_trait]
pub trait ListModelsHandler: Send + Sync + 'static {
    /// Produce the models to report.
    ///
    /// NOTE(review): errors are presumably surfaced to the list-models caller —
    /// confirm.
    async fn list_models(&self) -> Result<Vec<Model>>;
}
381
382#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
384#[serde(rename_all = "lowercase")]
385pub enum LogLevel {
386 None,
388 Error,
390 Warning,
392 Info,
394 Debug,
396 All,
398}
399
400impl LogLevel {
401 pub fn as_str(self) -> &'static str {
403 match self {
404 Self::None => "none",
405 Self::Error => "error",
406 Self::Warning => "warning",
407 Self::Info => "info",
408 Self::Debug => "debug",
409 Self::All => "all",
410 }
411 }
412}
413
414impl std::fmt::Display for LogLevel {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.write_str(self.as_str())
417 }
418}
419
420#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
425#[serde(rename_all = "kebab-case")]
426#[non_exhaustive]
427pub enum OtelExporterType {
428 OtlpHttp,
431 File,
434}
435
436impl OtelExporterType {
437 pub fn as_str(self) -> &'static str {
439 match self {
440 Self::OtlpHttp => "otlp-http",
441 Self::File => "file",
442 }
443 }
444}
445
446#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
452#[non_exhaustive]
453pub enum OtlpHttpProtocol {
454 #[serde(rename = "http/json")]
456 HttpJson,
457 #[serde(rename = "http/protobuf")]
459 HttpProtobuf,
460}
461
462impl OtlpHttpProtocol {
463 pub fn as_str(self) -> &'static str {
465 match self {
466 Self::HttpJson => "http/json",
467 Self::HttpProtobuf => "http/protobuf",
468 }
469 }
470}
471
472#[derive(Debug, Clone, Default)]
507#[non_exhaustive]
508pub struct TelemetryConfig {
509 pub otlp_endpoint: Option<String>,
511 pub otlp_protocol: Option<OtlpHttpProtocol>,
513 pub file_path: Option<PathBuf>,
515 pub exporter_type: Option<OtelExporterType>,
518 pub source_name: Option<String>,
522 pub capture_content: Option<bool>,
526}
527
528impl TelemetryConfig {
529 pub fn new() -> Self {
532 Self::default()
533 }
534
535 pub fn with_otlp_endpoint(mut self, endpoint: impl Into<String>) -> Self {
537 self.otlp_endpoint = Some(endpoint.into());
538 self
539 }
540
541 pub fn with_otlp_protocol(mut self, protocol: OtlpHttpProtocol) -> Self {
543 self.otlp_protocol = Some(protocol);
544 self
545 }
546
547 pub fn with_file_path(mut self, path: impl Into<PathBuf>) -> Self {
549 self.file_path = Some(path.into());
550 self
551 }
552
553 pub fn with_exporter_type(mut self, exporter_type: OtelExporterType) -> Self {
555 self.exporter_type = Some(exporter_type);
556 self
557 }
558
559 pub fn with_source_name(mut self, source_name: impl Into<String>) -> Self {
563 self.source_name = Some(source_name.into());
564 self
565 }
566
567 pub fn with_capture_content(mut self, capture: bool) -> Self {
571 self.capture_content = Some(capture);
572 self
573 }
574
575 pub fn is_empty(&self) -> bool {
578 self.otlp_endpoint.is_none()
579 && self.otlp_protocol.is_none()
580 && self.file_path.is_none()
581 && self.exporter_type.is_none()
582 && self.source_name.is_none()
583 && self.capture_content.is_none()
584 }
585}
586
587impl Default for ClientOptions {
588 fn default() -> Self {
589 Self {
590 program: CliProgram::Resolve,
591 prefix_args: Vec::new(),
592 working_directory: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
593 env: Vec::new(),
594 env_remove: Vec::new(),
595 extra_args: Vec::new(),
596 transport: Transport::default(),
597 github_token: None,
598 use_logged_in_user: None,
599 log_level: None,
600 session_idle_timeout_seconds: None,
601 on_list_models: None,
602 session_fs: None,
603 request_handler: None,
604 on_github_telemetry: None,
605 on_get_trace_context: None,
606 telemetry: None,
607 base_directory: None,
608 enable_remote_sessions: false,
609 bundled_cli_extract_dir: None,
610 mode: ClientMode::default(),
611 }
612 }
613}
614
615impl ClientOptions {
616 pub fn new() -> Self {
632 Self::default()
633 }
634
635 pub fn with_program(mut self, program: impl Into<CliProgram>) -> Self {
637 self.program = program.into();
638 self
639 }
640
641 pub fn with_prefix_args<I, S>(mut self, args: I) -> Self
643 where
644 I: IntoIterator<Item = S>,
645 S: Into<OsString>,
646 {
647 self.prefix_args = args.into_iter().map(Into::into).collect();
648 self
649 }
650
651 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
653 self.working_directory = cwd.into();
654 self
655 }
656
657 pub fn with_env<I, K, V>(mut self, env: I) -> Self
659 where
660 I: IntoIterator<Item = (K, V)>,
661 K: Into<OsString>,
662 V: Into<OsString>,
663 {
664 self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
665 self
666 }
667
668 pub fn with_env_remove<I, S>(mut self, names: I) -> Self
670 where
671 I: IntoIterator<Item = S>,
672 S: Into<OsString>,
673 {
674 self.env_remove = names.into_iter().map(Into::into).collect();
675 self
676 }
677
678 pub fn with_extra_args<I, S>(mut self, args: I) -> Self
680 where
681 I: IntoIterator<Item = S>,
682 S: Into<String>,
683 {
684 self.extra_args = args.into_iter().map(Into::into).collect();
685 self
686 }
687
688 pub fn with_transport(mut self, transport: Transport) -> Self {
690 self.transport = transport;
691 self
692 }
693
694 pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
697 self.github_token = Some(token.into());
698 self
699 }
700
701 pub fn with_use_logged_in_user(mut self, use_logged_in: bool) -> Self {
704 self.use_logged_in_user = Some(use_logged_in);
705 self
706 }
707
708 pub fn with_log_level(mut self, level: LogLevel) -> Self {
710 self.log_level = Some(level);
711 self
712 }
713
714 pub fn with_session_idle_timeout_seconds(mut self, seconds: u64) -> Self {
717 self.session_idle_timeout_seconds = Some(seconds);
718 self
719 }
720
721 pub fn with_list_models_handler<H>(mut self, handler: H) -> Self
724 where
725 H: ListModelsHandler + 'static,
726 {
727 self.on_list_models = Some(Arc::new(handler));
728 self
729 }
730
731 pub fn with_session_fs(mut self, config: SessionFsConfig) -> Self {
733 self.session_fs = Some(config);
734 self
735 }
736
737 pub fn with_request_handler<H>(mut self, handler: H) -> Self
742 where
743 H: crate::copilot_request_handler::CopilotRequestHandler,
744 {
745 self.request_handler = Some(Arc::new(handler));
746 self
747 }
748
749 #[doc(hidden)]
755 pub fn with_on_github_telemetry<F>(mut self, callback: F) -> Self
756 where
757 F: Fn(crate::github_telemetry::GitHubTelemetryNotification) + Send + Sync + 'static,
758 {
759 self.on_github_telemetry = Some(Arc::new(callback));
760 self
761 }
762
763 pub fn with_trace_context_provider<P>(mut self, provider: P) -> Self
767 where
768 P: TraceContextProvider + 'static,
769 {
770 self.on_get_trace_context = Some(Arc::new(provider));
771 self
772 }
773
774 pub fn with_telemetry(mut self, config: TelemetryConfig) -> Self {
776 self.telemetry = Some(config);
777 self
778 }
779
780 pub fn with_base_directory(mut self, dir: impl Into<PathBuf>) -> Self {
783 self.base_directory = Some(dir.into());
784 self
785 }
786
787 pub fn with_enable_remote_sessions(mut self, enabled: bool) -> Self {
790 self.enable_remote_sessions = enabled;
791 self
792 }
793
794 pub fn with_bundled_cli_extract_dir(mut self, dir: impl Into<PathBuf>) -> Self {
804 self.bundled_cli_extract_dir = Some(dir.into());
805 self
806 }
807
808 pub fn with_mode(mut self, mode: ClientMode) -> Self {
813 self.mode = mode;
814 self
815 }
816}
817
818fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<()> {
820 if cfg.initial_cwd.trim().is_empty() {
821 return Err(Error::with_message(
822 ErrorKind::Session(SessionErrorKind::InvalidSessionFsConfig),
823 "invalid SessionFsConfig: initial_cwd must not be empty",
824 ));
825 }
826 if cfg.session_state_path.trim().is_empty() {
827 return Err(Error::with_message(
828 ErrorKind::Session(SessionErrorKind::InvalidSessionFsConfig),
829 "invalid SessionFsConfig: session_state_path must not be empty",
830 ));
831 }
832 Ok(())
833}
834
835fn generate_connection_token() -> String {
842 let mut bytes = [0u8; 16];
843 getrandom::getrandom(&mut bytes)
844 .expect("OS CSPRNG (getrandom) is unavailable; cannot generate connection token");
845 let mut hex = String::with_capacity(32);
846 for byte in bytes {
847 use std::fmt::Write;
848 let _ = write!(hex, "{byte:02x}");
849 }
850 hex
851}
852
/// Handle to a Copilot CLI connection.
///
/// Cloning is cheap: every clone shares the same `ClientInner` through an
/// `Arc`.
#[derive(Clone)]
pub struct Client {
    inner: Arc<ClientInner>,
}
861
862impl std::fmt::Debug for Client {
863 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
864 f.debug_struct("Client")
865 .field("working_directory", &self.inner.cwd)
866 .field("pid", &self.pid())
867 .finish()
868 }
869}
870
/// State shared by every clone of a `Client`.
struct ClientInner {
    /// Spawned CLI process; `None` when connecting to an external server or
    /// when built from caller-supplied streams.
    child: parking_lot::Mutex<Option<Child>>,
    /// JSON-RPC connection to the server.
    rpc: JsonRpcClient,
    /// Working directory reported by `Client::cwd`.
    cwd: PathBuf,
    /// Server-to-client requests produced by `rpc`. Held in an `Option` so it
    /// can be taken once. NOTE(review): presumably taken by the session
    /// router's `ensure_started` — confirm.
    request_rx: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<JsonRpcRequest>>>,
    /// Broadcast of every server notification (capacity 1024); `rpc` holds a
    /// clone of this sender.
    notification_tx: broadcast::Sender<JsonRpcNotification>,
    /// Routes notifications/requests to sessions.
    router: router::SessionRouter,
    /// Protocol version agreed at start-up; set at most once.
    /// NOTE(review): written by `verify_protocol_version` — confirm.
    negotiated_protocol_version: OnceLock<u32>,
    /// Connection state; starts as `Connected`.
    state: parking_lot::Mutex<ConnectionState>,
    /// Typed `session.lifecycle` events republished by the lifecycle
    /// dispatcher (capacity 256).
    lifecycle_tx: broadcast::Sender<SessionLifecycleEvent>,
    /// Caller-provided model list source, if any.
    on_list_models: Option<Arc<dyn ListModelsHandler>>,
    /// Model list cache. NOTE(review): the `Arc<OnceCell>` behind a mutex
    /// suggests the cell is swapped out to invalidate — confirm.
    models_cache: parking_lot::Mutex<Arc<tokio::sync::OnceCell<Vec<Model>>>>,
    /// Whether a session filesystem provider was configured.
    session_fs_configured: bool,
    /// Whether that provider declared SQLite capability.
    session_fs_sqlite_declared: bool,
    /// Copilot request dispatcher; set once in `start` when a request handler
    /// is configured.
    llm_inference: OnceLock<Arc<copilot_request_handler::CopilotRequestDispatcher>>,
    /// GitHub telemetry callback, forwarded to the router.
    on_github_telemetry: Option<crate::github_telemetry::GitHubTelemetryCallback>,
    /// Trace-context provider, if any.
    on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
    /// Connection token in effect. `start` sets it to the generated or
    /// supplied token for TCP, the caller's token for external servers, and
    /// `None` for stdio.
    effective_connection_token: Option<String>,
    /// Client mode taken from `ClientOptions::mode`.
    pub(crate) mode: ClientMode,
}
903
904impl Client {
905 pub async fn start(options: ClientOptions) -> Result<Self> {
918 let start_time = Instant::now();
919 if options.mode == ClientMode::Empty
920 && options.base_directory.is_none()
921 && options.session_fs.is_none()
922 {
923 return Err(Error::with_message(
924 ErrorKind::InvalidConfig,
925 "ClientMode::Empty requires either `base_directory` or \
926 `session_fs` to be set (no implicit ~/.copilot fallback).",
927 ));
928 }
929 if let Some(cfg) = &options.session_fs {
930 validate_session_fs_config(cfg)?;
931 }
932 if matches!(options.transport, Transport::External { .. }) {
935 if options.github_token.is_some() {
936 return Err(Error::with_message(
937 ErrorKind::InvalidConfig,
938 "invalid client configuration: github_token cannot be used with \
939 Transport::External (external server manages its own auth)",
940 ));
941 }
942 if options.use_logged_in_user == Some(true) {
943 return Err(Error::with_message(
944 ErrorKind::InvalidConfig,
945 "invalid client configuration: use_logged_in_user cannot be used with \
946 Transport::External (external server manages its own auth)",
947 ));
948 }
949 }
950 match &options.transport {
954 Transport::Tcp {
955 connection_token: Some(t),
956 ..
957 }
958 | Transport::External {
959 connection_token: Some(t),
960 ..
961 } if t.is_empty() => {
962 return Err(Error::with_message(
963 ErrorKind::InvalidConfig,
964 "invalid client configuration: connection_token must be a non-empty string",
965 ));
966 }
967 _ => {}
968 }
969 let mut options = options;
974 let effective_connection_token: Option<String> = match &mut options.transport {
975 Transport::Stdio => None,
976 Transport::Tcp {
977 connection_token, ..
978 } => Some(
979 connection_token
980 .get_or_insert_with(generate_connection_token)
981 .clone(),
982 ),
983 Transport::External {
984 connection_token, ..
985 } => connection_token.clone(),
986 };
987 let session_fs_config = options.session_fs.clone();
988 let request_handler = options.request_handler.clone();
989 let session_fs_sqlite_declared = session_fs_config
990 .as_ref()
991 .and_then(|c| c.capabilities.as_ref())
992 .is_some_and(|caps| caps.sqlite);
993 let program = match &options.program {
994 CliProgram::Path(path) => {
995 info!(path = %path.display(), "using explicit copilot CLI path");
996 path.clone()
997 }
998 CliProgram::Resolve => {
999 let resolved = resolve::copilot_binary_with_extract_dir(
1000 options.bundled_cli_extract_dir.as_deref(),
1001 )?;
1002 info!(path = %resolved.display(), "resolved copilot CLI");
1003 #[cfg(windows)]
1004 {
1005 if let Some(ext) = resolved.extension().and_then(|e| e.to_str()).filter(|ext| {
1006 ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat")
1007 }) {
1008 warn!(
1009 path = %resolved.display(),
1010 ext = %ext,
1011 "resolved copilot CLI is a .cmd/.bat wrapper; \
1012 this may cause console window flashes on Windows"
1013 );
1014 }
1015 }
1016 resolved
1017 }
1018 };
1019
1020 let client = match options.transport {
1021 Transport::External {
1022 ref host,
1023 port,
1024 connection_token: _,
1025 } => {
1026 info!(host = %host, port = %port, "connecting to external CLI server");
1027 let connect_start = Instant::now();
1028 let stream = TcpStream::connect((host.as_str(), port)).await?;
1029 debug!(
1030 elapsed_ms = connect_start.elapsed().as_millis(),
1031 host = %host,
1032 port,
1033 "Client::start TCP connect complete"
1034 );
1035 let (reader, writer) = tokio::io::split(stream);
1036 Self::from_transport(
1037 reader,
1038 writer,
1039 None,
1040 options.working_directory,
1041 options.on_list_models,
1042 session_fs_config.is_some(),
1043 session_fs_sqlite_declared,
1044 options.on_get_trace_context,
1045 options.on_github_telemetry,
1046 effective_connection_token.clone(),
1047 options.mode,
1048 )?
1049 }
1050 Transport::Tcp {
1051 port,
1052 connection_token: _,
1053 } => {
1054 let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?;
1055 let connect_start = Instant::now();
1056 let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?;
1057 debug!(
1058 elapsed_ms = connect_start.elapsed().as_millis(),
1059 port = actual_port,
1060 "Client::start TCP connect complete"
1061 );
1062 let (reader, writer) = tokio::io::split(stream);
1063 Self::drain_stderr(&mut child);
1064 Self::from_transport(
1065 reader,
1066 writer,
1067 Some(child),
1068 options.working_directory,
1069 options.on_list_models,
1070 session_fs_config.is_some(),
1071 session_fs_sqlite_declared,
1072 options.on_get_trace_context,
1073 options.on_github_telemetry,
1074 effective_connection_token.clone(),
1075 options.mode,
1076 )?
1077 }
1078 Transport::Stdio => {
1079 let mut child = Self::spawn_stdio(&program, &options)?;
1080 let stdin = child.stdin.take().expect("stdin is piped");
1081 let stdout = child.stdout.take().expect("stdout is piped");
1082 Self::drain_stderr(&mut child);
1083 Self::from_transport(
1084 stdout,
1085 stdin,
1086 Some(child),
1087 options.working_directory,
1088 options.on_list_models,
1089 session_fs_config.is_some(),
1090 session_fs_sqlite_declared,
1091 options.on_get_trace_context,
1092 options.on_github_telemetry,
1093 effective_connection_token.clone(),
1094 options.mode,
1095 )?
1096 }
1097 };
1098
1099 debug!(
1100 elapsed_ms = start_time.elapsed().as_millis(),
1101 "Client::start transport setup complete"
1102 );
1103 client.verify_protocol_version().await?;
1104 debug!(
1105 elapsed_ms = start_time.elapsed().as_millis(),
1106 "Client::start protocol verification complete"
1107 );
1108 if let Some(cfg) = session_fs_config {
1109 let session_fs_start = Instant::now();
1110 let capabilities = cfg.capabilities.as_ref().map(|c| {
1111 crate::generated::api_types::SessionFsSetProviderCapabilities {
1112 sqlite: Some(c.sqlite),
1113 }
1114 });
1115 let request = crate::generated::api_types::SessionFsSetProviderRequest {
1116 capabilities,
1117 conventions: cfg.conventions.into_wire(),
1118 initial_cwd: cfg.initial_cwd,
1119 session_state_path: cfg.session_state_path,
1120 };
1121 client.rpc().session_fs().set_provider(request).await?;
1122 debug!(
1123 elapsed_ms = session_fs_start.elapsed().as_millis(),
1124 "Client::start session filesystem setup complete"
1125 );
1126 }
1127 if let Some(handler) = request_handler {
1128 let llm_inference_start = Instant::now();
1129 let dispatcher = Arc::new(copilot_request_handler::CopilotRequestDispatcher::new(
1130 handler,
1131 ));
1132 dispatcher.set_client(Arc::downgrade(&client.inner));
1133 let _ = client.inner.llm_inference.set(dispatcher.clone());
1134 client.inner.router.ensure_started(
1137 &client.inner.notification_tx,
1138 &client.inner.request_rx,
1139 Some(dispatcher.clone()),
1140 client.inner.on_github_telemetry.clone(),
1141 );
1142 client.rpc().llm_inference().set_provider().await?;
1143 debug!(
1144 elapsed_ms = llm_inference_start.elapsed().as_millis(),
1145 "Client::start Copilot request handler registration complete"
1146 );
1147 }
1148 debug!(
1149 elapsed_ms = start_time.elapsed().as_millis(),
1150 "Client::start complete"
1151 );
1152 Ok(client)
1153 }
1154
1155 pub fn from_streams(
1159 reader: impl AsyncRead + Unpin + Send + 'static,
1160 writer: impl AsyncWrite + Unpin + Send + 'static,
1161 cwd: PathBuf,
1162 ) -> Result<Self> {
1163 Self::from_transport(
1164 reader,
1165 writer,
1166 None,
1167 cwd,
1168 None,
1169 false,
1170 false,
1171 None,
1172 None,
1173 None,
1174 ClientMode::default(),
1175 )
1176 }
1177
1178 #[cfg(any(test, feature = "test-support"))]
1186 pub fn from_streams_with_trace_provider(
1187 reader: impl AsyncRead + Unpin + Send + 'static,
1188 writer: impl AsyncWrite + Unpin + Send + 'static,
1189 cwd: PathBuf,
1190 provider: Arc<dyn TraceContextProvider>,
1191 ) -> Result<Self> {
1192 Self::from_transport(
1193 reader,
1194 writer,
1195 None,
1196 cwd,
1197 None,
1198 false,
1199 false,
1200 Some(provider),
1201 None,
1202 None,
1203 ClientMode::default(),
1204 )
1205 }
1206
1207 #[cfg(any(test, feature = "test-support"))]
1211 pub fn from_streams_with_connection_token(
1212 reader: impl AsyncRead + Unpin + Send + 'static,
1213 writer: impl AsyncWrite + Unpin + Send + 'static,
1214 cwd: PathBuf,
1215 token: Option<String>,
1216 ) -> Result<Self> {
1217 Self::from_transport(
1218 reader,
1219 writer,
1220 None,
1221 cwd,
1222 None,
1223 false,
1224 false,
1225 None,
1226 None,
1227 token,
1228 ClientMode::default(),
1229 )
1230 }
1231
1232 #[doc(hidden)]
1235 #[cfg(any(test, feature = "test-support"))]
1236 pub fn from_streams_with_github_telemetry(
1237 reader: impl AsyncRead + Unpin + Send + 'static,
1238 writer: impl AsyncWrite + Unpin + Send + 'static,
1239 cwd: PathBuf,
1240 on_github_telemetry: crate::github_telemetry::GitHubTelemetryCallback,
1241 ) -> Result<Self> {
1242 Self::from_transport(
1243 reader,
1244 writer,
1245 None,
1246 cwd,
1247 None,
1248 false,
1249 false,
1250 None,
1251 Some(on_github_telemetry),
1252 None,
1253 ClientMode::default(),
1254 )
1255 }
1256
1257 #[cfg(any(test, feature = "test-support"))]
1263 pub fn generate_connection_token_for_test() -> String {
1264 generate_connection_token()
1265 }
1266
1267 #[allow(clippy::too_many_arguments)]
1268 fn from_transport(
1269 reader: impl AsyncRead + Unpin + Send + 'static,
1270 writer: impl AsyncWrite + Unpin + Send + 'static,
1271 child: Option<Child>,
1272 cwd: PathBuf,
1273 on_list_models: Option<Arc<dyn ListModelsHandler>>,
1274 session_fs_configured: bool,
1275 session_fs_sqlite_declared: bool,
1276 on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
1277 on_github_telemetry: Option<crate::github_telemetry::GitHubTelemetryCallback>,
1278 effective_connection_token: Option<String>,
1279 mode: ClientMode,
1280 ) -> Result<Self> {
1281 let setup_start = Instant::now();
1282 let (request_tx, request_rx) = mpsc::unbounded_channel::<JsonRpcRequest>();
1283 let (notification_broadcast_tx, _) = broadcast::channel::<JsonRpcNotification>(1024);
1284 let rpc = JsonRpcClient::new(
1285 writer,
1286 reader,
1287 notification_broadcast_tx.clone(),
1288 request_tx,
1289 );
1290
1291 let pid = child.as_ref().and_then(|c| c.id());
1292 info!(pid = ?pid, "copilot CLI client ready");
1293
1294 let client = Self {
1295 inner: Arc::new(ClientInner {
1296 child: parking_lot::Mutex::new(child),
1297 rpc,
1298 cwd,
1299 request_rx: parking_lot::Mutex::new(Some(request_rx)),
1300 notification_tx: notification_broadcast_tx,
1301 router: router::SessionRouter::new(),
1302 negotiated_protocol_version: OnceLock::new(),
1303 state: parking_lot::Mutex::new(ConnectionState::Connected),
1304 lifecycle_tx: broadcast::channel(256).0,
1305 on_list_models,
1306 models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
1307 session_fs_configured,
1308 session_fs_sqlite_declared,
1309 llm_inference: OnceLock::new(),
1310 on_github_telemetry,
1311 on_get_trace_context,
1312 effective_connection_token,
1313 mode,
1314 }),
1315 };
1316 client.spawn_lifecycle_dispatcher();
1317 debug!(
1318 elapsed_ms = setup_start.elapsed().as_millis(),
1319 pid = ?pid,
1320 "Client::from_transport setup complete"
1321 );
1322 Ok(client)
1323 }
1324
1325 fn spawn_lifecycle_dispatcher(&self) {
1329 let inner = Arc::clone(&self.inner);
1330 let mut notif_rx = inner.notification_tx.subscribe();
1331 tokio::spawn(async move {
1332 loop {
1333 match notif_rx.recv().await {
1334 Ok(notification) => {
1335 if notification.method != "session.lifecycle" {
1336 continue;
1337 }
1338 let Some(params) = notification.params.as_ref() else {
1339 continue;
1340 };
1341 let event: SessionLifecycleEvent =
1342 match serde_json::from_value(params.clone()) {
1343 Ok(e) => e,
1344 Err(e) => {
1345 warn!(
1346 error = %e,
1347 "failed to deserialize session.lifecycle notification"
1348 );
1349 continue;
1350 }
1351 };
1352 let _ = inner.lifecycle_tx.send(event);
1355 }
1356 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
1357 warn!(missed = n, "lifecycle dispatcher lagged");
1358 }
1359 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
1360 }
1361 }
1362 });
1363 }
1364
1365 fn build_command(program: &Path, options: &ClientOptions) -> Command {
1366 let mut command = Command::new(program);
1367 for arg in &options.prefix_args {
1368 command.arg(arg);
1369 }
1370 if let Some(token) = &options.github_token {
1373 command.env("COPILOT_SDK_AUTH_TOKEN", token);
1374 }
1375 if let Some(telemetry) = &options.telemetry {
1378 command.env("COPILOT_OTEL_ENABLED", "true");
1379 if let Some(endpoint) = &telemetry.otlp_endpoint {
1380 command.env("OTEL_EXPORTER_OTLP_ENDPOINT", endpoint);
1381 }
1382 if let Some(protocol) = telemetry.otlp_protocol {
1383 command.env("OTEL_EXPORTER_OTLP_PROTOCOL", protocol.as_str());
1384 }
1385 if let Some(path) = &telemetry.file_path {
1386 command.env("COPILOT_OTEL_FILE_EXPORTER_PATH", path);
1387 }
1388 if let Some(exporter) = telemetry.exporter_type {
1389 command.env("COPILOT_OTEL_EXPORTER_TYPE", exporter.as_str());
1390 }
1391 if let Some(source) = &telemetry.source_name {
1392 command.env("COPILOT_OTEL_SOURCE_NAME", source);
1393 }
1394 if let Some(capture) = telemetry.capture_content {
1395 command.env(
1396 "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
1397 if capture { "true" } else { "false" },
1398 );
1399 }
1400 }
1401 if let Some(dir) = &options.base_directory {
1402 command.env("COPILOT_HOME", dir);
1403 }
1404 if options.mode == ClientMode::Empty {
1407 command.env("COPILOT_DISABLE_KEYTAR", "1");
1408 }
1409 if let Transport::Tcp {
1410 connection_token: Some(token),
1411 ..
1412 } = &options.transport
1413 {
1414 command.env("COPILOT_CONNECTION_TOKEN", token);
1415 }
1416 for (key, value) in &options.env {
1417 command.env(key, value);
1418 }
1419 for key in &options.env_remove {
1420 command.env_remove(key);
1421 }
1422 command
1423 .current_dir(&options.working_directory)
1424 .stdout(Stdio::piped())
1425 .stderr(Stdio::piped());
1426
1427 #[cfg(windows)]
1428 {
1429 use std::os::windows::process::CommandExt;
1430 const CREATE_NO_WINDOW: u32 = 0x08000000;
1431 command.as_std_mut().creation_flags(CREATE_NO_WINDOW);
1432 }
1433
1434 command
1435 }
1436
1437 fn auth_args(options: &ClientOptions) -> Vec<&'static str> {
1445 let mut args: Vec<&'static str> = Vec::new();
1446 if options.github_token.is_some() {
1447 args.push("--auth-token-env");
1448 args.push("COPILOT_SDK_AUTH_TOKEN");
1449 }
1450 let use_logged_in = options
1451 .use_logged_in_user
1452 .unwrap_or(options.github_token.is_none());
1453 if !use_logged_in {
1454 args.push("--no-auto-login");
1455 }
1456 args
1457 }
1458
1459 fn session_idle_timeout_args(options: &ClientOptions) -> Vec<String> {
1463 match options.session_idle_timeout_seconds {
1464 Some(secs) if secs > 0 => {
1465 vec!["--session-idle-timeout".to_string(), secs.to_string()]
1466 }
1467 _ => Vec::new(),
1468 }
1469 }
1470
1471 fn remote_args(options: &ClientOptions) -> Vec<String> {
1472 if options.enable_remote_sessions {
1473 vec!["--remote".to_string()]
1474 } else {
1475 Vec::new()
1476 }
1477 }
1478
1479 fn log_level_args(options: &ClientOptions) -> Vec<&'static str> {
1480 match options.log_level {
1481 Some(level) => vec!["--log-level", level.as_str()],
1482 None => Vec::new(),
1483 }
1484 }
1485
1486 fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result<Child> {
1487 info!(cwd = ?options.working_directory, program = %program.display(), "spawning copilot CLI (stdio)");
1488 let mut command = Self::build_command(program, options);
1489 command
1490 .args(["--server", "--stdio", "--no-auto-update"])
1491 .args(Self::log_level_args(options))
1492 .args(Self::auth_args(options))
1493 .args(Self::session_idle_timeout_args(options))
1494 .args(Self::remote_args(options))
1495 .args(&options.extra_args)
1496 .stdin(Stdio::piped());
1497 let spawn_start = Instant::now();
1498 let child = command.spawn()?;
1499 debug!(
1500 elapsed_ms = spawn_start.elapsed().as_millis(),
1501 "Client::spawn_stdio subprocess spawned"
1502 );
1503 Ok(child)
1504 }
1505
1506 async fn spawn_tcp(program: &Path, options: &ClientOptions, port: u16) -> Result<(Child, u16)> {
1507 info!(cwd = ?options.working_directory, program = %program.display(), port = %port, "spawning copilot CLI (tcp)");
1508 let mut command = Self::build_command(program, options);
1509 command
1510 .args(["--server", "--port", &port.to_string(), "--no-auto-update"])
1511 .args(Self::log_level_args(options))
1512 .args(Self::auth_args(options))
1513 .args(Self::session_idle_timeout_args(options))
1514 .args(Self::remote_args(options))
1515 .args(&options.extra_args)
1516 .stdin(Stdio::null());
1517 let spawn_start = Instant::now();
1518 let mut child = command.spawn()?;
1519 debug!(
1520 elapsed_ms = spawn_start.elapsed().as_millis(),
1521 "Client::spawn_tcp subprocess spawned"
1522 );
1523 let stdout = child.stdout.take().expect("stdout is piped");
1524
1525 let (port_tx, port_rx) = oneshot::channel::<u16>();
1526 let span = tracing::error_span!("copilot_cli_port_scan");
1527 tokio::spawn(
1528 async move {
1529 let port_re = regex::Regex::new(r"listening on port (\d+)").expect("valid regex");
1531 let mut lines = BufReader::new(stdout).lines();
1532 let mut port_tx = Some(port_tx);
1533 while let Ok(Some(line)) = lines.next_line().await {
1534 debug!(line = %line, "CLI stdout");
1535 if let Some(tx) = port_tx.take() {
1536 if let Some(caps) = port_re.captures(&line)
1537 && let Some(p) =
1538 caps.get(1).and_then(|m| m.as_str().parse::<u16>().ok())
1539 {
1540 let _ = tx.send(p);
1541 continue;
1542 }
1543 port_tx = Some(tx);
1545 }
1546 }
1547 }
1548 .instrument(span),
1549 );
1550
1551 let port_wait_start = Instant::now();
1552 let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx)
1553 .await
1554 .map_err(|_| Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupTimeout)))?
1555 .map_err(|_| Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupFailed)))?;
1556
1557 debug!(
1558 elapsed_ms = port_wait_start.elapsed().as_millis(),
1559 port = actual_port,
1560 "Client::spawn_tcp TCP port wait complete"
1561 );
1562 info!(port = %actual_port, "CLI server listening");
1563 Ok((child, actual_port))
1564 }
1565
1566 fn drain_stderr(child: &mut Child) {
1567 if let Some(stderr) = child.stderr.take() {
1568 let span = tracing::error_span!("copilot_cli");
1569 tokio::spawn(
1570 async move {
1571 let mut reader = BufReader::new(stderr).lines();
1572 while let Ok(Some(line)) = reader.next_line().await {
1573 warn!(line = %line, "CLI stderr");
1574 }
1575 }
1576 .instrument(span),
1577 );
1578 }
1579 }
1580
1581 pub fn cwd(&self) -> &PathBuf {
1583 &self.inner.cwd
1584 }
1585
1586 pub fn mode(&self) -> ClientMode {
1588 self.inner.mode
1589 }
1590
1591 pub fn rpc(&self) -> crate::generated::rpc::ClientRpc<'_> {
1602 crate::generated::rpc::ClientRpc { client: self }
1603 }
1604
1605 #[allow(dead_code, reason = "convenience for future internal use")]
1607 pub(crate) async fn send_request(
1608 &self,
1609 method: &str,
1610 params: Option<serde_json::Value>,
1611 ) -> Result<JsonRpcResponse> {
1612 self.inner.rpc.send_request(method, params).await
1613 }
1614
1615 pub async fn call(
1635 &self,
1636 method: &str,
1637 params: Option<serde_json::Value>,
1638 ) -> Result<serde_json::Value> {
1639 self.call_with_inline_callback(method, params, None).await
1640 }
1641
1642 pub(crate) async fn call_with_inline_callback(
1657 &self,
1658 method: &str,
1659 params: Option<serde_json::Value>,
1660 inline_callback: Option<crate::jsonrpc::InlineResponseCallback>,
1661 ) -> Result<serde_json::Value> {
1662 let session_id: Option<SessionId> = params
1663 .as_ref()
1664 .and_then(|p| p.get("sessionId"))
1665 .and_then(|v| v.as_str())
1666 .map(SessionId::from);
1667 let response = self
1668 .inner
1669 .rpc
1670 .send_request_with_inline_callback(method, params, inline_callback)
1671 .await?;
1672 if let Some(err) = response.error {
1673 if err.message.contains("Session not found") {
1674 return Err(ErrorKind::Session(SessionErrorKind::NotFound(
1675 session_id.unwrap_or_else(|| "unknown".into()),
1676 ))
1677 .into());
1678 }
1679 return Err(Error::with_message(
1680 ErrorKind::Rpc { code: err.code },
1681 err.message,
1682 ));
1683 }
1684 Ok(response.result.unwrap_or(serde_json::Value::Null))
1685 }
1686
1687 pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<()> {
1689 self.inner.rpc.write(response).await
1690 }
1691
1692 pub(crate) fn from_inner(inner: Arc<ClientInner>) -> Self {
1694 Self { inner }
1695 }
1696
1697 #[expect(dead_code, reason = "reserved for future pub(crate) use")]
1701 pub(crate) fn take_request_rx(&self) -> Option<mpsc::UnboundedReceiver<JsonRpcRequest>> {
1702 self.inner.request_rx.lock().take()
1703 }
1704
1705 pub(crate) fn register_session(
1713 &self,
1714 session_id: &SessionId,
1715 ) -> crate::router::SessionChannels {
1716 self.inner.router.ensure_started(
1717 &self.inner.notification_tx,
1718 &self.inner.request_rx,
1719 self.inner.llm_inference.get().cloned(),
1720 self.inner.on_github_telemetry.clone(),
1721 );
1722 self.inner.router.register(session_id)
1723 }
1724
1725 pub(crate) fn unregister_session(&self, session_id: &SessionId) {
1727 self.inner.router.unregister(session_id);
1728 }
1729
1730 pub fn protocol_version(&self) -> Option<u32> {
1737 self.inner.negotiated_protocol_version.get().copied()
1738 }
1739
1740 pub async fn verify_protocol_version(&self) -> Result<()> {
1764 let handshake_start = Instant::now();
1765 let mut used_fallback_ping = false;
1766 let server_version = match self.connect_handshake().await {
1770 Ok(v) => v,
1771 Err(ref e) if e.rpc_code() == Some(error_codes::METHOD_NOT_FOUND) => {
1772 used_fallback_ping = true;
1773 self.ping(None).await?.protocol_version
1774 }
1775 Err(e) => return Err(e),
1776 };
1777
1778 match server_version {
1779 None => {
1780 warn!("CLI server did not report protocolVersion; skipping version check");
1781 }
1782 Some(v) if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&v) => {
1783 return Err(ErrorKind::Protocol(ProtocolErrorKind::VersionMismatch {
1784 server: v,
1785 min: MIN_PROTOCOL_VERSION,
1786 max: SDK_PROTOCOL_VERSION,
1787 })
1788 .into());
1789 }
1790 Some(v) => {
1791 if let Some(&existing) = self.inner.negotiated_protocol_version.get() {
1792 if existing != v {
1793 return Err(ErrorKind::Protocol(ProtocolErrorKind::VersionChanged {
1794 previous: existing,
1795 current: v,
1796 })
1797 .into());
1798 }
1799 } else {
1800 let _ = self.inner.negotiated_protocol_version.set(v);
1801 }
1802 }
1803 }
1804
1805 debug!(
1806 elapsed_ms = handshake_start.elapsed().as_millis(),
1807 protocol_version = ?server_version,
1808 used_fallback_ping,
1809 "Client::verify_protocol_version protocol handshake complete"
1810 );
1811 Ok(())
1812 }
1813
1814 async fn connect_handshake(&self) -> Result<Option<u32>> {
1821 let result = self
1822 .rpc()
1823 .connect(crate::generated::api_types::ConnectRequest {
1824 token: self.inner.effective_connection_token.clone(),
1825 })
1826 .await?;
1827 Ok(u32::try_from(result.protocol_version).ok())
1828 }
1829
1830 pub async fn ping(&self, message: Option<&str>) -> Result<crate::types::PingResponse> {
1838 let params = match message {
1839 Some(m) => serde_json::json!({ "message": m }),
1840 None => serde_json::json!({}),
1841 };
1842 let value = self
1843 .call(generated::api_types::rpc_methods::PING, Some(params))
1844 .await?;
1845 Ok(serde_json::from_value(value)?)
1846 }
1847
1848 pub async fn list_sessions(
1851 &self,
1852 filter: Option<SessionListFilter>,
1853 ) -> Result<Vec<SessionMetadata>> {
1854 let params = match filter {
1855 Some(f) => serde_json::json!({ "filter": f }),
1856 None => serde_json::json!({}),
1857 };
1858 let result = self.call("session.list", Some(params)).await?;
1859 let response: ListSessionsResponse = serde_json::from_value(result)?;
1860 Ok(response.sessions)
1861 }
1862
1863 pub async fn get_session_metadata(
1881 &self,
1882 session_id: &SessionId,
1883 ) -> Result<Option<SessionMetadata>> {
1884 let result = self
1885 .call(
1886 "session.getMetadata",
1887 Some(serde_json::json!({ "sessionId": session_id })),
1888 )
1889 .await?;
1890 let response: GetSessionMetadataResponse = serde_json::from_value(result)?;
1891 Ok(response.session)
1892 }
1893
1894 pub async fn delete_session(&self, session_id: &SessionId) -> Result<()> {
1896 self.call(
1897 "session.delete",
1898 Some(serde_json::json!({ "sessionId": session_id })),
1899 )
1900 .await?;
1901 Ok(())
1902 }
1903
1904 pub async fn get_last_session_id(&self) -> Result<Option<SessionId>> {
1920 let result = self
1921 .call("session.getLastId", Some(serde_json::json!({})))
1922 .await?;
1923 let response: GetLastSessionIdResponse = serde_json::from_value(result)?;
1924 Ok(response.session_id)
1925 }
1926
1927 pub async fn get_foreground_session_id(&self) -> Result<Option<SessionId>> {
1932 let result = self
1933 .call("session.getForeground", Some(serde_json::json!({})))
1934 .await?;
1935 let response: GetForegroundSessionResponse = serde_json::from_value(result)?;
1936 Ok(response.session_id)
1937 }
1938
1939 pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<()> {
1944 self.call(
1945 "session.setForeground",
1946 Some(serde_json::json!({ "sessionId": session_id })),
1947 )
1948 .await?;
1949 Ok(())
1950 }
1951
1952 pub async fn get_status(&self) -> Result<GetStatusResponse> {
1954 let result = self.call("status.get", Some(serde_json::json!({}))).await?;
1955 Ok(serde_json::from_value(result)?)
1956 }
1957
1958 pub async fn get_auth_status(&self) -> Result<GetAuthStatusResponse> {
1960 let result = self
1961 .call("auth.getStatus", Some(serde_json::json!({})))
1962 .await?;
1963 Ok(serde_json::from_value(result)?)
1964 }
1965
1966 pub async fn list_models(&self) -> Result<Vec<Model>> {
1971 let cache = self.inner.models_cache.lock().clone();
1972 let models = cache
1973 .get_or_try_init(|| async {
1974 if let Some(handler) = &self.inner.on_list_models {
1975 handler.list_models().await
1976 } else {
1977 Ok(self.rpc().models().list().await?.models)
1978 }
1979 })
1980 .await?;
1981 Ok(models.clone())
1982 }
1983
1984 pub(crate) async fn resolve_trace_context(&self) -> TraceContext {
1987 if let Some(provider) = &self.inner.on_get_trace_context {
1988 provider.get_trace_context().await
1989 } else {
1990 TraceContext::default()
1991 }
1992 }
1993
1994 pub fn pid(&self) -> Option<u32> {
1996 self.inner.child.lock().as_ref().and_then(|c| c.id())
1997 }
1998
1999 pub async fn stop(&self) -> std::result::Result<(), StopErrors> {
2026 let pid = self.pid();
2027 info!(pid = ?pid, "stopping CLI process");
2028 let mut errors: Vec<Error> = Vec::new();
2029
2030 for session_id in self.inner.router.session_ids() {
2033 match self
2034 .call(
2035 "session.destroy",
2036 Some(serde_json::json!({ "sessionId": session_id })),
2037 )
2038 .await
2039 {
2040 Ok(_) => {}
2041 Err(e) => {
2042 warn!(
2043 session_id = %session_id,
2044 error = %e,
2045 "session.destroy failed during Client::stop",
2046 );
2047 errors.push(e);
2048 }
2049 }
2050 self.inner.router.unregister(&session_id);
2051 }
2052
2053 let should_shutdown_runtime = self.inner.child.lock().is_some();
2054 if should_shutdown_runtime {
2055 let runtime_shutdown_start = Instant::now();
2056 match tokio::time::timeout(RUNTIME_SHUTDOWN_TIMEOUT, self.rpc().runtime().shutdown())
2057 .await
2058 {
2059 Ok(Ok(())) => {
2060 debug!(
2061 elapsed_ms = runtime_shutdown_start.elapsed().as_millis(),
2062 "Client::stop runtime shutdown complete"
2063 );
2064 }
2065 Ok(Err(e)) => {
2066 warn!(
2067 elapsed_ms = runtime_shutdown_start.elapsed().as_millis(),
2068 error = %e,
2069 "runtime.shutdown failed during Client::stop",
2070 );
2071 errors.push(e);
2072 }
2073 Err(_) => {
2074 let e = std::io::Error::new(
2075 std::io::ErrorKind::TimedOut,
2076 "runtime.shutdown timed out during Client::stop",
2077 );
2078 warn!(
2079 elapsed_ms = runtime_shutdown_start.elapsed().as_millis(),
2080 timeout = ?RUNTIME_SHUTDOWN_TIMEOUT,
2081 error = %e,
2082 "runtime.shutdown timed out during Client::stop",
2083 );
2084 errors.push(e.into());
2085 }
2086 }
2087 }
2088
2089 let child = self.inner.child.lock().take();
2090 *self.inner.state.lock() = ConnectionState::Disconnected;
2091 *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
2092 if let Some(mut child) = child {
2093 match child.try_wait() {
2094 Ok(Some(_status)) => {}
2095 Ok(None) => {
2096 if let Err(e) = child.kill().await {
2103 errors.push(e.into());
2104 }
2105 }
2106 Err(e) => errors.push(e.into()),
2107 }
2108 }
2109
2110 info!(pid = ?pid, errors = errors.len(), "CLI process stopped");
2111 if errors.is_empty() {
2112 Ok(())
2113 } else {
2114 Err(StopErrors(errors))
2115 }
2116 }
2117
2118 pub fn force_stop(&self) {
2148 let pid = self.pid();
2149 info!(pid = ?pid, "force-stopping CLI process");
2150 if let Some(mut child) = self.inner.child.lock().take()
2151 && let Err(e) = child.start_kill()
2152 {
2153 error!(pid = ?pid, error = %e, "failed to send kill signal");
2154 }
2155 self.inner.rpc.force_close();
2156 self.inner.router.clear();
2159 *self.inner.state.lock() = ConnectionState::Disconnected;
2160 *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
2161 }
2162
2163 pub fn subscribe_lifecycle(&self) -> LifecycleSubscription {
2198 LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe())
2199 }
2200}
2201
2202impl Drop for ClientInner {
2203 fn drop(&mut self) {
2204 if let Some(ref mut child) = *self.child.lock() {
2205 let pid = child.id();
2206 if let Err(e) = child.start_kill() {
2207 error!(pid = ?pid, error = %e, "failed to kill CLI process on drop");
2208 } else {
2209 info!(pid = ?pid, "kill signal sent for CLI process on drop");
2210 }
2211 }
2212 }
2213}
2214
// Unit tests for error classification, `ClientOptions`/`TelemetryConfig`
// builders, the `Command` produced by `Client::build_command`, CLI-flag helpers,
// the model-list cache, and session-registration cleanup.
#[cfg(test)]
mod tests {
    use super::*;

    /// A cancelled request counts as a transport failure.
    #[test]
    fn is_transport_failure_matches_request_cancelled() {
        let err = Error::from(ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled));
        assert!(err.is_transport_failure());
    }

    /// An I/O error (here a broken pipe) counts as a transport failure.
    #[test]
    fn is_transport_failure_matches_io_error() {
        let err = Error::from(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"));
        assert!(err.is_transport_failure());
    }

    /// A JSON-RPC error returned by the server is not a transport failure.
    #[test]
    fn is_transport_failure_rejects_rpc_error() {
        let err = Error::with_message(ErrorKind::Rpc { code: -1 }, "bad");
        assert!(!err.is_transport_failure());
    }

    /// A session-level error is not a transport failure.
    #[test]
    fn is_transport_failure_rejects_session_error() {
        let err = Error::from(ErrorKind::Session(SessionErrorKind::NotFound("s1".into())));
        assert!(!err.is_transport_failure());
    }

    /// Each `ClientOptions` builder method stores its value in the matching field.
    #[test]
    fn client_options_builder_composes() {
        let opts = ClientOptions::new()
            .with_program(CliProgram::Path(PathBuf::from("/usr/local/bin/copilot")))
            .with_prefix_args(["node"])
            .with_cwd(PathBuf::from("/tmp"))
            .with_env([("KEY", "value")])
            .with_env_remove(["UNWANTED"])
            .with_extra_args(["--quiet"])
            .with_github_token("ghp_test")
            .with_use_logged_in_user(false)
            .with_log_level(LogLevel::Debug)
            .with_session_idle_timeout_seconds(120)
            .with_enable_remote_sessions(true);
        assert!(matches!(opts.program, CliProgram::Path(_)));
        assert_eq!(opts.prefix_args, vec![std::ffi::OsString::from("node")]);
        assert_eq!(opts.working_directory, PathBuf::from("/tmp"));
        assert_eq!(
            opts.env,
            vec![(
                std::ffi::OsString::from("KEY"),
                std::ffi::OsString::from("value")
            )]
        );
        assert_eq!(opts.env_remove, vec![std::ffi::OsString::from("UNWANTED")]);
        assert_eq!(opts.extra_args, vec!["--quiet".to_string()]);
        assert_eq!(opts.github_token.as_deref(), Some("ghp_test"));
        assert_eq!(opts.use_logged_in_user, Some(false));
        assert!(matches!(opts.log_level, Some(LogLevel::Debug)));
        assert_eq!(opts.session_idle_timeout_seconds, Some(120));
        assert!(opts.enable_remote_sessions);
    }

    /// Protocol errors other than `RequestCancelled` (here a startup timeout)
    /// are not transport failures.
    #[test]
    fn is_transport_failure_rejects_other_protocol_errors() {
        let err = Error::from(ErrorKind::Protocol(ProtocolErrorKind::CliStartupTimeout));
        assert!(!err.is_transport_failure());
    }

    /// An explicit `env_remove` entry takes precedence over the token that
    /// `github_token` injects.
    #[test]
    fn build_command_lets_env_remove_strip_injected_token() {
        let opts = ClientOptions {
            github_token: Some("secret".to_string()),
            env_remove: vec![std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN")],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        let action = cmd
            // `get_envs` reports an explicitly removed variable as `(key, None)`,
            // so a removal shows up as `Some(None)` here.
            .as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
            .map(|(_, v)| v);
        assert_eq!(
            action,
            Some(None),
            "env_remove should win over github_token"
        );
    }

    /// A value supplied through `env` overrides the token injected from
    /// `github_token`.
    #[test]
    fn build_command_lets_env_override_injected_token() {
        let opts = ClientOptions {
            github_token: Some("from-options".to_string()),
            env: vec![(
                std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN"),
                std::ffi::OsString::from("from-env"),
            )],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        let value = cmd
            .as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
            .and_then(|(_, v)| v);
        assert_eq!(value, Some(std::ffi::OsStr::new("from-env")));
    }

    /// With no overrides, `github_token` is exported as `COPILOT_SDK_AUTH_TOKEN`.
    #[test]
    fn build_command_injects_github_token_by_default() {
        let opts = ClientOptions {
            github_token: Some("just-the-token".to_string()),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        let value = cmd
            .as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
            .and_then(|(_, v)| v);
        assert_eq!(value, Some(std::ffi::OsStr::new("just-the-token")));
    }

    /// Returns the value `cmd` explicitly sets for `key`.
    ///
    /// Returns `None` both when the variable is untouched and when it is
    /// explicitly removed.
    fn env_value<'a>(cmd: &'a tokio::process::Command, key: &str) -> Option<&'a std::ffi::OsStr> {
        cmd.as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new(key))
            .and_then(|(_, v)| v)
    }

    /// Each `TelemetryConfig` builder method stores its value.
    ///
    /// `is_empty` is false once any field is set and true for a fresh config.
    #[test]
    fn telemetry_config_builder_composes() {
        let cfg = TelemetryConfig::new()
            .with_otlp_endpoint("http://collector:4318")
            .with_otlp_protocol(OtlpHttpProtocol::HttpProtobuf)
            .with_file_path(PathBuf::from("/var/log/copilot.jsonl"))
            .with_exporter_type(OtelExporterType::OtlpHttp)
            .with_source_name("my-app")
            .with_capture_content(true);

        assert_eq!(cfg.otlp_endpoint.as_deref(), Some("http://collector:4318"));
        assert_eq!(cfg.otlp_protocol, Some(OtlpHttpProtocol::HttpProtobuf));
        assert_eq!(
            cfg.file_path.as_deref(),
            Some(Path::new("/var/log/copilot.jsonl")),
        );
        assert_eq!(cfg.exporter_type, Some(OtelExporterType::OtlpHttp));
        assert_eq!(cfg.source_name.as_deref(), Some("my-app"));
        assert_eq!(cfg.capture_content, Some(true));
        assert!(!cfg.is_empty());
        assert!(TelemetryConfig::new().is_empty());
    }

    /// `OtlpHttpProtocol`'s `as_str` and its serde representation agree, and
    /// both round-trip.
    #[test]
    fn otlp_http_protocol_serde_matches_env_value() {
        for (protocol, wire) in [
            (OtlpHttpProtocol::HttpJson, "http/json"),
            (OtlpHttpProtocol::HttpProtobuf, "http/protobuf"),
        ] {
            assert_eq!(protocol.as_str(), wire);

            let serialized = serde_json::to_string(&protocol).unwrap();
            assert_eq!(serialized, format!("\"{wire}\""));

            let deserialized: OtlpHttpProtocol = serde_json::from_str(&serialized).unwrap();
            assert_eq!(deserialized, protocol);
        }
    }

    /// A fully populated `TelemetryConfig` maps every field to its OTel env var.
    #[test]
    fn build_command_sets_otel_env_when_telemetry_enabled() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://collector:4318".to_string()),
                otlp_protocol: Some(OtlpHttpProtocol::HttpProtobuf),
                file_path: Some(PathBuf::from("/var/log/copilot.jsonl")),
                exporter_type: Some(OtelExporterType::OtlpHttp),
                source_name: Some("my-app".to_string()),
                capture_content: Some(true),
            }),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://collector:4318")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_PROTOCOL"),
            Some(std::ffi::OsStr::new("http/protobuf")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_FILE_EXPORTER_PATH"),
            Some(std::ffi::OsStr::new("/var/log/copilot.jsonl")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_EXPORTER_TYPE"),
            Some(std::ffi::OsStr::new("otlp-http")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_SOURCE_NAME"),
            Some(std::ffi::OsStr::new("my-app")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"),
            Some(std::ffi::OsStr::new("true")),
        );
    }

    /// With `telemetry: None`, no OTel-related env vars are set.
    #[test]
    fn build_command_omits_otel_env_when_telemetry_none() {
        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        for key in [
            "COPILOT_OTEL_ENABLED",
            "OTEL_EXPORTER_OTLP_ENDPOINT",
            "OTEL_EXPORTER_OTLP_PROTOCOL",
            "COPILOT_OTEL_FILE_EXPORTER_PATH",
            "COPILOT_OTEL_EXPORTER_TYPE",
            "COPILOT_OTEL_SOURCE_NAME",
            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
        ] {
            assert!(
                env_value(&cmd, key).is_none(),
                "expected {key} to be unset when telemetry is None",
            );
        }
    }

    /// Only the telemetry fields that are set produce env vars.
    #[test]
    fn build_command_omits_unset_telemetry_fields() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://collector:4318".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            // Setting just the endpoint is enough to enable OTel export.
            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://collector:4318")),
        );
        for key in [
            // Fields left at their defaults must not be forwarded.
            "OTEL_EXPORTER_OTLP_PROTOCOL",
            "COPILOT_OTEL_FILE_EXPORTER_PATH",
            "COPILOT_OTEL_EXPORTER_TYPE",
            "COPILOT_OTEL_SOURCE_NAME",
            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
        ] {
            assert!(env_value(&cmd, key).is_none(), "{key} should be unset");
        }
    }

    /// A variable set explicitly in `options.env` wins over the same variable
    /// derived from the telemetry config.
    #[test]
    fn build_command_lets_user_env_override_telemetry() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://from-config:4318".to_string()),
                ..Default::default()
            }),
            env: vec![(
                std::ffi::OsString::from("OTEL_EXPORTER_OTLP_ENDPOINT"),
                std::ffi::OsString::from("http://from-user-env:4318"),
            )],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://from-user-env:4318")),
            "user-supplied options.env should override telemetry config",
        );
    }

    /// `with_base_directory` is exported as `COPILOT_HOME`.
    ///
    /// Without it, `COPILOT_HOME` is left unset.
    #[test]
    fn build_command_sets_copilot_home_env_when_configured() {
        let opts = ClientOptions::new().with_base_directory(PathBuf::from("/custom/copilot"));
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_HOME"),
            Some(std::ffi::OsStr::new("/custom/copilot")),
        );

        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert!(env_value(&cmd, "COPILOT_HOME").is_none());
    }

    /// A TCP transport `connection_token` is exported as
    /// `COPILOT_CONNECTION_TOKEN`.
    ///
    /// The default options leave it unset.
    #[test]
    fn build_command_sets_connection_token_env_when_configured() {
        let opts = ClientOptions::new().with_transport(Transport::Tcp {
            port: 0,
            connection_token: Some("secret-token".to_string()),
        });
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_CONNECTION_TOKEN"),
            Some(std::ffi::OsStr::new("secret-token")),
        );

        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert!(env_value(&cmd, "COPILOT_CONNECTION_TOKEN").is_none());
    }

    /// `Client::start` rejects an empty TCP connection token with `InvalidConfig`.
    #[tokio::test]
    async fn start_rejects_empty_connection_token() {
        let opts = ClientOptions::new()
            .with_transport(Transport::Tcp {
                port: 0,
                connection_token: Some(String::new()),
            })
            .with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
        let err = Client::start(opts).await.unwrap_err();
        assert!(
            matches!(err.kind(), ErrorKind::InvalidConfig),
            "got {err:?}"
        );
    }

    /// `Client::start` rejects an empty connection token for an external server
    /// with `InvalidConfig`.
    #[tokio::test]
    async fn start_rejects_empty_external_connection_token() {
        let opts = ClientOptions::new()
            .with_transport(Transport::External {
                host: "127.0.0.1".to_string(),
                port: 1,
                connection_token: Some(String::new()),
            })
            .with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
        let err = Client::start(opts).await.unwrap_err();
        assert!(
            matches!(err.kind(), ErrorKind::InvalidConfig),
            "got {err:?}"
        );
    }

    /// `capture_content` is exported as the literal strings `"true"` / `"false"`.
    #[test]
    fn telemetry_config_capture_content_serializes_as_lowercase_bool() {
        let opts_true = ClientOptions {
            telemetry: Some(TelemetryConfig {
                capture_content: Some(true),
                ..Default::default()
            }),
            ..Default::default()
        };
        let opts_false = ClientOptions {
            telemetry: Some(TelemetryConfig {
                capture_content: Some(false),
                ..Default::default()
            }),
            ..Default::default()
        };
        let cmd_true = Client::build_command(Path::new("/bin/echo"), &opts_true);
        let cmd_false = Client::build_command(Path::new("/bin/echo"), &opts_false);
        assert_eq!(
            env_value(
                &cmd_true,
                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
            ),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(
                &cmd_false,
                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
            ),
            Some(std::ffi::OsStr::new("false")),
        );
    }

    /// No idle-timeout flag is emitted by default.
    #[test]
    fn session_idle_timeout_args_are_omitted_by_default() {
        let opts = ClientOptions::default();
        assert!(Client::session_idle_timeout_args(&opts).is_empty());
    }

    /// A timeout of zero emits no flag.
    #[test]
    fn session_idle_timeout_args_omitted_for_zero() {
        let opts = ClientOptions {
            session_idle_timeout_seconds: Some(0),
            ..Default::default()
        };
        assert!(Client::session_idle_timeout_args(&opts).is_empty());
    }

    /// A positive timeout emits `--session-idle-timeout <seconds>`.
    #[test]
    fn session_idle_timeout_args_emit_flag_for_positive_value() {
        let opts = ClientOptions {
            session_idle_timeout_seconds: Some(300),
            ..Default::default()
        };
        assert_eq!(
            Client::session_idle_timeout_args(&opts),
            vec!["--session-idle-timeout".to_string(), "300".to_string()]
        );
    }

    /// `--remote` is not emitted by default.
    #[test]
    fn remote_args_omitted_by_default() {
        let opts = ClientOptions::default();
        assert!(Client::remote_args(&opts).is_empty());
    }

    /// `enable_remote_sessions` emits `--remote`.
    #[test]
    fn remote_args_emit_flag_when_enabled() {
        let opts = ClientOptions {
            enable_remote_sessions: true,
            ..Default::default()
        };
        assert_eq!(Client::remote_args(&opts), vec!["--remote".to_string()]);
    }

    /// Without a configured log level, no `--log-level` flag is emitted.
    #[test]
    fn log_level_args_omitted_when_unset() {
        let opts = ClientOptions::default();
        assert!(opts.log_level.is_none());
        assert!(
            Client::log_level_args(&opts).is_empty(),
            "with no caller-supplied log_level the SDK must not pass --log-level"
        );
    }

    /// A configured log level emits `--log-level <level>`.
    #[test]
    fn log_level_args_emit_flag_when_set() {
        let opts = ClientOptions::default().with_log_level(LogLevel::Debug);
        assert_eq!(Client::log_level_args(&opts), vec!["--log-level", "debug"]);
    }

    /// For every `LogLevel`, `as_str` matches the serde representation, and the
    /// value round-trips.
    #[test]
    fn log_level_str_round_trips() {
        for level in [
            LogLevel::None,
            LogLevel::Error,
            LogLevel::Warning,
            LogLevel::Info,
            LogLevel::Debug,
            LogLevel::All,
        ] {
            let s = level.as_str();
            let json = serde_json::to_string(&level).unwrap();
            assert_eq!(json, format!("\"{s}\""));
            let parsed: LogLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, level);
        }
    }

    /// `ClientOptions`'s `Debug` output shows handlers as `"<set>"` and never
    /// prints the GitHub token.
    #[test]
    fn client_options_debug_redacts_handler() {
        struct StubHandler;
        #[async_trait]
        impl ListModelsHandler for StubHandler {
            async fn list_models(&self) -> Result<Vec<Model>> {
                Ok(vec![])
            }
        }
        let opts = ClientOptions {
            on_list_models: Some(Arc::new(StubHandler)),
            github_token: Some("secret-token".into()),
            ..Default::default()
        };
        let debug = format!("{opts:?}");
        assert!(debug.contains("on_list_models: Some(\"<set>\")"));
        assert!(debug.contains("github_token: Some(\"<redacted>\")"));
        assert!(!debug.contains("secret-token"));
    }

    /// When `on_list_models` is set, `list_models` returns the handler's models
    /// and calls it exactly once.
    #[tokio::test]
    async fn list_models_uses_on_list_models_handler_when_set() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        struct CountingHandler {
            calls: Arc<AtomicUsize>,
            models: Vec<Model>,
        }
        #[async_trait]
        impl ListModelsHandler for CountingHandler {
            async fn list_models(&self) -> Result<Vec<Model>> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Ok(self.models.clone())
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let model = Model {
            id: "byok-gpt-4".into(),
            name: "BYOK GPT-4".into(),
            ..Default::default()
        };
        let handler: Arc<dyn ListModelsHandler> = Arc::new(CountingHandler {
            calls: Arc::clone(&calls),
            models: vec![model.clone()],
        });

        let client = client_with_list_models_handler(handler);

        let result = client.list_models().await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, "byok-gpt-4");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two concurrent `list_models` calls on an empty cache trigger a single
    /// handler invocation, and both receive its result.
    #[tokio::test]
    async fn list_models_serializes_concurrent_cache_misses() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        struct SlowCountingHandler {
            calls: Arc<AtomicUsize>,
            models: Vec<Model>,
        }
        #[async_trait]
        impl ListModelsHandler for SlowCountingHandler {
            async fn list_models(&self) -> Result<Vec<Model>> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                // Hold the fetch open long enough for the second call to overlap.
                tokio::time::sleep(std::time::Duration::from_millis(25)).await;
                Ok(self.models.clone())
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let model = Model {
            id: "single-flight-model".into(),
            name: "Single Flight Model".into(),
            ..Default::default()
        };
        let handler: Arc<dyn ListModelsHandler> = Arc::new(SlowCountingHandler {
            calls: Arc::clone(&calls),
            models: vec![model],
        });
        let client = client_with_list_models_handler(handler);

        let (first, second) = tokio::join!(client.list_models(), client.list_models());
        assert_eq!(first.unwrap()[0].id, "single-flight-model");
        assert_eq!(second.unwrap()[0].id, "single-flight-model");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Aborting a `resume_session` task after it has registered the session
    /// with the router must leave no router entry behind.
    #[tokio::test]
    async fn cancelled_resume_session_unregisters_pending_session() {
        // The server halves are kept alive but never written to, so the resume
        // request never receives a response and stays pending until aborted.
        let (client_write, _server_read) = tokio::io::duplex(8192);
        let (_server_write, client_read) = tokio::io::duplex(8192);
        let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
        let session_id = SessionId::new("resume-cancel-test");
        let handle = tokio::spawn({
            let client = client.clone();
            async move {
                client
                    .resume_session(ResumeSessionConfig::new(session_id))
                    .await
            }
        });

        wait_for_pending_session_registration(&client).await;
        handle.abort();
        let _ = handle.await;

        assert!(client.inner.router.session_ids().is_empty());
        client.force_stop();
    }

    /// Builds a `Client` with `on_list_models` set and no server.
    ///
    /// The JSON-RPC client is wired to in-memory duplex pipes whose peer ends
    /// are dropped when the `rpc` block ends, so nothing can answer requests.
    /// Tests using this helper must not depend on the server.
    fn client_with_list_models_handler(handler: Arc<dyn ListModelsHandler>) -> Client {
        Client {
            inner: Arc::new(ClientInner {
                child: parking_lot::Mutex::new(None),
                rpc: {
                    let (req_tx, _req_rx) = mpsc::unbounded_channel();
                    let (notif_tx, _notif_rx) = broadcast::channel(16);
                    let (read_pipe, _write_pipe) = tokio::io::duplex(64);
                    let (_unused_read, write_pipe) = tokio::io::duplex(64);
                    JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
                },
                cwd: PathBuf::from("."),
                request_rx: parking_lot::Mutex::new(None),
                notification_tx: broadcast::channel(16).0,
                router: router::SessionRouter::new(),
                negotiated_protocol_version: OnceLock::new(),
                state: parking_lot::Mutex::new(ConnectionState::Connected),
                lifecycle_tx: broadcast::channel(16).0,
                on_list_models: Some(handler),
                models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
                session_fs_configured: false,
                session_fs_sqlite_declared: false,
                llm_inference: OnceLock::new(),
                on_github_telemetry: None,
                on_get_trace_context: None,
                effective_connection_token: None,
                mode: ClientMode::default(),
            }),
        }
    }

    /// Polls the router every 10 ms until at least one session is registered.
    ///
    /// Panics if that has not happened within one second.
    async fn wait_for_pending_session_registration(client: &Client) {
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
        while client.inner.router.session_ids().is_empty() {
            assert!(
                tokio::time::Instant::now() < deadline,
                "session was not registered"
            );
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    }
}