1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3#![deny(rustdoc::broken_intra_doc_links)]
4#![cfg_attr(test, allow(clippy::unwrap_used))]
5
6pub mod embeddedcli;
8pub mod handler;
10pub mod hooks;
12mod jsonrpc;
13pub mod permission;
15pub mod resolve;
17mod router;
18pub mod session;
20pub mod session_fs;
22mod session_fs_dispatch;
23pub mod subscription;
25pub mod tool;
27pub mod trace_context;
29pub mod transforms;
31pub mod types;
33
34pub mod generated;
36
37use std::ffi::OsString;
38use std::path::{Path, PathBuf};
39use std::process::Stdio;
40use std::sync::{Arc, OnceLock};
41use std::time::Instant;
42
43use async_trait::async_trait;
44pub(crate) use jsonrpc::{
47 JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes,
48};
49
/// Re-exports of the crate's internal JSON-RPC types for use by tests.
///
/// Only compiled when the `test-support` feature is enabled.
#[cfg(feature = "test-support")]
pub mod test_support {
    pub use crate::jsonrpc::{
        JsonRpcClient, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
        error_codes,
    };
}
58use serde::{Deserialize, Serialize};
59use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader};
60use tokio::net::TcpStream;
61use tokio::process::{Child, Command};
62use tokio::sync::{broadcast, mpsc, oneshot};
63use tracing::{Instrument, debug, error, info, warn};
64pub use types::*;
65
66mod sdk_protocol_version;
67pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version};
68pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError};
69
/// Oldest CLI protocol version this SDK accepts.
///
/// `Client::verify_protocol_version` checks the server-reported version against the
/// inclusive range `MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION`.
const MIN_PROTOCOL_VERSION: u32 = 2;
72
/// Errors returned by the Copilot SDK client.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// A framing, CLI-startup, or protocol-version failure (see [`ProtocolError`]).
    #[error("protocol error: {0}")]
    Protocol(ProtocolError),

    /// The CLI answered a JSON-RPC request with an error object.
    #[error("RPC error {code}: {message}")]
    Rpc {
        /// JSON-RPC error code reported by the server.
        code: i32,
        /// Error message reported by the server.
        message: String,
    },

    /// A session-level failure (see [`SessionError`]).
    #[error("session error: {0}")]
    Session(SessionError),

    /// An underlying I/O error, e.g. from spawning the CLI or opening a TCP connection.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// A JSON serialization or deserialization error.
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// A required executable could not be located.
    #[error("binary not found: {name} ({hint})")]
    BinaryNotFound {
        /// Name of the executable that was searched for.
        name: &'static str,
        /// Additional hint text included in the error message.
        hint: &'static str,
    },

    /// The [`ClientOptions`] given to [`Client::start`] contain an unsupported combination.
    #[error("invalid client configuration: {0}")]
    InvalidConfig(String),
}
118
119impl Error {
120 pub fn is_transport_failure(&self) -> bool {
124 matches!(
125 self,
126 Error::Protocol(ProtocolError::RequestCancelled) | Error::Io(_)
127 )
128 }
129}
130
/// All errors collected while stopping a client.
///
/// `Display` summarizes how many errors occurred and shows the first one; `source()`
/// returns the first error.
// NOTE(review): the code that builds this value (presumably a stop routine) is outside
// this excerpt.
#[derive(Debug)]
pub struct StopErrors(Vec<Error>);
144
145impl StopErrors {
146 pub fn errors(&self) -> &[Error] {
149 &self.0
150 }
151
152 pub fn into_errors(self) -> Vec<Error> {
154 self.0
155 }
156}
157
158impl std::fmt::Display for StopErrors {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 match self.0.as_slice() {
161 [] => write!(f, "stop completed with no errors"),
162 [only] => write!(f, "stop failed: {only}"),
163 [first, rest @ ..] => write!(
164 f,
165 "stop failed with {n} errors; first: {first}",
166 n = 1 + rest.len(),
167 ),
168 }
169 }
170}
171
172impl std::error::Error for StopErrors {
173 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
174 self.0
175 .first()
176 .map(|e| e as &(dyn std::error::Error + 'static))
177 }
178}
179
/// Framing, CLI startup, and protocol-version failures.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ProtocolError {
    /// A message frame had no `Content-Length` header.
    #[error("missing Content-Length header")]
    MissingContentLength,

    /// A `Content-Length` header value could not be used; holds the raw value.
    #[error("invalid Content-Length value: \"{0}\"")]
    InvalidContentLength(String),

    /// A pending request was cancelled before it completed.
    ///
    /// `Error::is_transport_failure` classifies this as a transport failure.
    #[error("request cancelled")]
    RequestCancelled,

    /// A CLI launched for [`Transport::Tcp`] did not print its listening port within the
    /// startup timeout (10 seconds).
    #[error("timed out waiting for CLI to report listening port")]
    CliStartupTimeout,

    /// A CLI launched for [`Transport::Tcp`] closed its stdout before printing its
    /// listening port.
    #[error("CLI exited before reporting listening port")]
    CliStartupFailed,

    /// The CLI's protocol version lies outside the range this SDK supports.
    #[error("version mismatch: server={server}, supported={min}–{max}")]
    VersionMismatch {
        /// Version reported by the CLI.
        server: u32,
        /// Oldest supported version.
        min: u32,
        /// Newest supported version.
        max: u32,
    },

    /// The protocol version differs from one seen earlier.
    // NOTE(review): the place this is raised (e.g. on reconnect) is outside this excerpt.
    #[error("version changed: was {previous}, now {current}")]
    VersionChanged {
        /// Version seen earlier.
        previous: u32,
        /// Version reported now.
        current: u32,
    },
}
224
/// Failures scoped to a single session.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SessionError {
    /// The CLI does not know the session.
    ///
    /// [`Client::call`] maps any server error whose message contains `"Session not found"`
    /// to this variant, using the request's `sessionId` (or `"unknown"`).
    #[error("session not found: {0}")]
    NotFound(SessionId),

    /// An error reported by the agent; the text is displayed verbatim.
    #[error("{0}")]
    AgentError(String),

    /// An operation did not finish within the given duration.
    #[error("timed out after {0:?}")]
    Timeout(std::time::Duration),

    /// A send was attempted while a `send_and_wait` call was still in flight.
    #[error("cannot send while send_and_wait is in flight")]
    SendWhileWaiting,

    /// The session's event loop shut down before the session became idle.
    #[error("event loop closed before session reached idle")]
    EventLoopClosed,

    /// Elicitation was requested but the host does not support it.
    #[error(
        "elicitation not supported by host — check session.capabilities().ui.elicitation first"
    )]
    ElicitationNotSupported,

    /// The client has session filesystem support configured, but the session was created
    /// without a `SessionFsProvider`.
    #[error(
        "session was created on a client with session_fs configured but no SessionFsProvider was supplied"
    )]
    SessionFsProviderRequired,

    /// The `SessionFsConfig` is invalid, e.g. a required path is blank; holds the reason.
    #[error("invalid SessionFsConfig: {0}")]
    InvalidSessionFsConfig(String),

    /// The CLI created a session under a different ID than the one the SDK registered.
    #[error("CLI returned session ID {returned} after SDK registered {requested}")]
    SessionIdMismatch {
        /// ID the SDK asked for and registered.
        requested: SessionId,
        /// ID the CLI actually returned.
        returned: SessionId,
    },
}
280
/// How [`Client::start`] reaches the Copilot CLI.
#[derive(Debug, Default)]
#[non_exhaustive]
pub enum Transport {
    /// Spawn the CLI with `--stdio` and speak JSON-RPC over its stdin/stdout (the default).
    #[default]
    Stdio,
    /// Spawn the CLI with `--port <port>` and connect to `127.0.0.1` on the port it prints
    /// to stdout. A connection token is always used (generated when not supplied).
    Tcp {
        /// Port passed to the CLI's `--port` flag.
        // NOTE(review): presumably 0 lets the CLI choose a free port — confirm.
        port: u16,
    },
    /// Connect to an already-running CLI server; no process is spawned. `github_token`
    /// and `use_logged_in_user = Some(true)` are rejected with this transport.
    External {
        /// Host name or address of the server.
        host: String,
        /// TCP port of the server.
        port: u16,
    },
}
301
/// Which CLI executable [`Client::start`] launches for spawning transports.
#[derive(Debug, Clone, Default)]
pub enum CliProgram {
    /// Locate the CLI automatically (via `resolve::copilot_binary`). This is the default.
    #[default]
    Resolve,
    /// Launch the executable at this path.
    Path(PathBuf),
}
312
313impl From<PathBuf> for CliProgram {
314 fn from(path: PathBuf) -> Self {
315 Self::Path(path)
316 }
317}
318
/// Configuration for [`Client::start`].
///
/// Build with [`ClientOptions::new`] and the `with_*` methods. Fields that affect the
/// spawned process (arguments, environment) only apply to [`Transport::Stdio`] and
/// [`Transport::Tcp`].
#[non_exhaustive]
pub struct ClientOptions {
    /// CLI executable to launch.
    pub program: CliProgram,
    /// Arguments placed right after the program, before every SDK-generated argument.
    pub prefix_args: Vec<OsString>,
    /// Working directory of the spawned CLI; also reported by `Client::cwd`.
    pub cwd: PathBuf,
    /// Extra environment variables for the CLI, applied after (so overriding) the
    /// variables the SDK sets itself.
    pub env: Vec<(OsString, OsString)>,
    /// Environment variables removed from the CLI's environment, applied last.
    pub env_remove: Vec<OsString>,
    /// Arguments appended after all SDK-generated arguments.
    pub extra_args: Vec<String>,
    /// How to reach the CLI.
    pub transport: Transport,
    /// GitHub token handed to the CLI via the `COPILOT_SDK_AUTH_TOKEN` environment variable
    /// (with `--auth-token-env`). Rejected with [`Transport::External`].
    pub github_token: Option<String>,
    /// Whether the CLI may use the logged-in user; `false` passes `--no-auto-login`.
    /// When `None`, defaults to `true` only if no `github_token` is set.
    pub use_logged_in_user: Option<bool>,
    /// CLI `--log-level`; [`LogLevel::Info`] when `None`.
    pub log_level: Option<LogLevel>,
    /// Passed as `--session-idle-timeout <secs>` when greater than zero.
    pub session_idle_timeout_seconds: Option<u64>,
    /// Optional caller-supplied model list handler.
    // NOTE(review): how it is consulted is outside this excerpt.
    pub on_list_models: Option<Arc<dyn ListModelsHandler>>,
    /// Session filesystem provider configuration; validated and registered with the CLI
    /// during `Client::start`.
    pub session_fs: Option<SessionFsConfig>,
    /// Optional trace-context provider.
    // NOTE(review): how it is consulted is outside this excerpt.
    pub on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
    /// OpenTelemetry settings exported to the spawned CLI as environment variables.
    pub telemetry: Option<TelemetryConfig>,
    /// Exported to the spawned CLI as `COPILOT_HOME`.
    pub copilot_home: Option<PathBuf>,
    /// Token for TCP connections, exported as `COPILOT_CONNECTION_TOKEN`. Must be non-empty
    /// and cannot be combined with [`Transport::Stdio`]; generated for [`Transport::Tcp`]
    /// when unset.
    pub tcp_connection_token: Option<String>,
    /// Passes `--remote` to the spawned CLI when `true`.
    pub remote: bool,
}
413
414impl std::fmt::Debug for ClientOptions {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("ClientOptions")
417 .field("program", &self.program)
418 .field("prefix_args", &self.prefix_args)
419 .field("cwd", &self.cwd)
420 .field("env", &self.env)
421 .field("env_remove", &self.env_remove)
422 .field("extra_args", &self.extra_args)
423 .field("transport", &self.transport)
424 .field(
425 "github_token",
426 &self.github_token.as_ref().map(|_| "<redacted>"),
427 )
428 .field("use_logged_in_user", &self.use_logged_in_user)
429 .field("log_level", &self.log_level)
430 .field(
431 "session_idle_timeout_seconds",
432 &self.session_idle_timeout_seconds,
433 )
434 .field(
435 "on_list_models",
436 &self.on_list_models.as_ref().map(|_| "<set>"),
437 )
438 .field("session_fs", &self.session_fs)
439 .field(
440 "on_get_trace_context",
441 &self.on_get_trace_context.as_ref().map(|_| "<set>"),
442 )
443 .field("telemetry", &self.telemetry)
444 .field("copilot_home", &self.copilot_home)
445 .field(
446 "tcp_connection_token",
447 &self.tcp_connection_token.as_ref().map(|_| "<redacted>"),
448 )
449 .field("remote", &self.remote)
450 .finish()
451 }
452}
453
/// Caller-supplied source of the available model list.
///
/// Installed through `ClientOptions::with_list_models_handler`.
// NOTE(review): presumably consulted instead of asking the CLI when models are listed —
// the call site is outside this excerpt.
#[async_trait]
pub trait ListModelsHandler: Send + Sync + 'static {
    /// Returns the models to report, or an error.
    async fn list_models(&self) -> Result<Vec<Model>, Error>;
}
467
/// Verbosity passed to the CLI through `--log-level`.
///
/// Serialized (and rendered by `as_str`/`Display`) as the lowercase variant name. When unset
/// in [`ClientOptions`], [`LogLevel::Info`] is used.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    /// `none`
    None,
    /// `error`
    Error,
    /// `warning`
    Warning,
    /// `info` (the default used when no level is configured)
    Info,
    /// `debug`
    Debug,
    /// `all`
    All,
}
485
486impl LogLevel {
487 pub fn as_str(self) -> &'static str {
489 match self {
490 Self::None => "none",
491 Self::Error => "error",
492 Self::Warning => "warning",
493 Self::Info => "info",
494 Self::Debug => "debug",
495 Self::All => "all",
496 }
497 }
498}
499
500impl std::fmt::Display for LogLevel {
501 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
502 f.write_str(self.as_str())
503 }
504}
505
/// OpenTelemetry exporter the spawned CLI should use; exported as
/// `COPILOT_OTEL_EXPORTER_TYPE`.
///
/// Serialized (and rendered by `as_str`) in kebab-case: `otlp-http` or `file`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum OtelExporterType {
    /// `otlp-http` — presumably exports to `TelemetryConfig::otlp_endpoint` (confirm).
    OtlpHttp,
    /// `file` — presumably writes to `TelemetryConfig::file_path` (confirm).
    File,
}
521
522impl OtelExporterType {
523 pub fn as_str(self) -> &'static str {
525 match self {
526 Self::OtlpHttp => "otlp-http",
527 Self::File => "file",
528 }
529 }
530}
531
/// OpenTelemetry settings forwarded to a CLI that the SDK spawns, as environment variables.
///
/// Supplying any `TelemetryConfig` (even an empty one) sets `COPILOT_OTEL_ENABLED=true`;
/// each populated field adds its own variable as documented below.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TelemetryConfig {
    /// Exported as `OTEL_EXPORTER_OTLP_ENDPOINT`.
    pub otlp_endpoint: Option<String>,
    /// Exported as `COPILOT_OTEL_FILE_EXPORTER_PATH`.
    pub file_path: Option<PathBuf>,
    /// Exported as `COPILOT_OTEL_EXPORTER_TYPE`.
    pub exporter_type: Option<OtelExporterType>,
    /// Exported as `COPILOT_OTEL_SOURCE_NAME`.
    pub source_name: Option<String>,
    /// Exported as `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT` (`"true"`/`"false"`).
    pub capture_content: Option<bool>,
}
583
584impl TelemetryConfig {
585 pub fn new() -> Self {
588 Self::default()
589 }
590
591 pub fn with_otlp_endpoint(mut self, endpoint: impl Into<String>) -> Self {
593 self.otlp_endpoint = Some(endpoint.into());
594 self
595 }
596
597 pub fn with_file_path(mut self, path: impl Into<PathBuf>) -> Self {
599 self.file_path = Some(path.into());
600 self
601 }
602
603 pub fn with_exporter_type(mut self, exporter_type: OtelExporterType) -> Self {
605 self.exporter_type = Some(exporter_type);
606 self
607 }
608
609 pub fn with_source_name(mut self, source_name: impl Into<String>) -> Self {
613 self.source_name = Some(source_name.into());
614 self
615 }
616
617 pub fn with_capture_content(mut self, capture: bool) -> Self {
621 self.capture_content = Some(capture);
622 self
623 }
624
625 pub fn is_empty(&self) -> bool {
628 self.otlp_endpoint.is_none()
629 && self.file_path.is_none()
630 && self.exporter_type.is_none()
631 && self.source_name.is_none()
632 && self.capture_content.is_none()
633 }
634}
635
636impl Default for ClientOptions {
637 fn default() -> Self {
638 Self {
639 program: CliProgram::Resolve,
640 prefix_args: Vec::new(),
641 cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
642 env: Vec::new(),
643 env_remove: Vec::new(),
644 extra_args: Vec::new(),
645 transport: Transport::default(),
646 github_token: None,
647 use_logged_in_user: None,
648 log_level: None,
649 session_idle_timeout_seconds: None,
650 on_list_models: None,
651 session_fs: None,
652 on_get_trace_context: None,
653 telemetry: None,
654 copilot_home: None,
655 tcp_connection_token: None,
656 remote: false,
657 }
658 }
659}
660
661impl ClientOptions {
662 pub fn new() -> Self {
678 Self::default()
679 }
680
681 pub fn with_program(mut self, program: impl Into<CliProgram>) -> Self {
683 self.program = program.into();
684 self
685 }
686
687 pub fn with_prefix_args<I, S>(mut self, args: I) -> Self
689 where
690 I: IntoIterator<Item = S>,
691 S: Into<OsString>,
692 {
693 self.prefix_args = args.into_iter().map(Into::into).collect();
694 self
695 }
696
697 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
699 self.cwd = cwd.into();
700 self
701 }
702
703 pub fn with_env<I, K, V>(mut self, env: I) -> Self
705 where
706 I: IntoIterator<Item = (K, V)>,
707 K: Into<OsString>,
708 V: Into<OsString>,
709 {
710 self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
711 self
712 }
713
714 pub fn with_env_remove<I, S>(mut self, names: I) -> Self
716 where
717 I: IntoIterator<Item = S>,
718 S: Into<OsString>,
719 {
720 self.env_remove = names.into_iter().map(Into::into).collect();
721 self
722 }
723
724 pub fn with_extra_args<I, S>(mut self, args: I) -> Self
726 where
727 I: IntoIterator<Item = S>,
728 S: Into<String>,
729 {
730 self.extra_args = args.into_iter().map(Into::into).collect();
731 self
732 }
733
734 pub fn with_transport(mut self, transport: Transport) -> Self {
736 self.transport = transport;
737 self
738 }
739
740 pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
743 self.github_token = Some(token.into());
744 self
745 }
746
747 pub fn with_use_logged_in_user(mut self, use_logged_in: bool) -> Self {
750 self.use_logged_in_user = Some(use_logged_in);
751 self
752 }
753
754 pub fn with_log_level(mut self, level: LogLevel) -> Self {
756 self.log_level = Some(level);
757 self
758 }
759
760 pub fn with_session_idle_timeout_seconds(mut self, seconds: u64) -> Self {
763 self.session_idle_timeout_seconds = Some(seconds);
764 self
765 }
766
767 pub fn with_list_models_handler<H>(mut self, handler: H) -> Self
770 where
771 H: ListModelsHandler + 'static,
772 {
773 self.on_list_models = Some(Arc::new(handler));
774 self
775 }
776
777 pub fn with_session_fs(mut self, config: SessionFsConfig) -> Self {
779 self.session_fs = Some(config);
780 self
781 }
782
783 pub fn with_trace_context_provider<P>(mut self, provider: P) -> Self
787 where
788 P: TraceContextProvider + 'static,
789 {
790 self.on_get_trace_context = Some(Arc::new(provider));
791 self
792 }
793
794 pub fn with_telemetry(mut self, config: TelemetryConfig) -> Self {
796 self.telemetry = Some(config);
797 self
798 }
799
800 pub fn with_copilot_home(mut self, home: impl Into<PathBuf>) -> Self {
803 self.copilot_home = Some(home.into());
804 self
805 }
806
807 pub fn with_tcp_connection_token(mut self, token: impl Into<String>) -> Self {
811 self.tcp_connection_token = Some(token.into());
812 self
813 }
814
815 pub fn with_remote(mut self, enabled: bool) -> Self {
818 self.remote = enabled;
819 self
820 }
821}
822
823fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<(), Error> {
825 if cfg.initial_cwd.trim().is_empty() {
826 return Err(Error::Session(SessionError::InvalidSessionFsConfig(
827 "initial_cwd must not be empty".to_string(),
828 )));
829 }
830 if cfg.session_state_path.trim().is_empty() {
831 return Err(Error::Session(SessionError::InvalidSessionFsConfig(
832 "session_state_path must not be empty".to_string(),
833 )));
834 }
835 Ok(())
836}
837
838fn generate_connection_token() -> String {
845 let mut bytes = [0u8; 16];
846 getrandom::getrandom(&mut bytes)
847 .expect("OS CSPRNG (getrandom) is unavailable; cannot generate connection token");
848 let mut hex = String::with_capacity(32);
849 for byte in bytes {
850 use std::fmt::Write;
851 let _ = write!(hex, "{byte:02x}");
852 }
853 hex
854}
855
/// Handle to a connection with the Copilot CLI.
///
/// Cloning is cheap: every clone shares the same underlying state through an `Arc`.
#[derive(Clone)]
pub struct Client {
    inner: Arc<ClientInner>,
}
864
865impl std::fmt::Debug for Client {
866 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
867 f.debug_struct("Client")
868 .field("cwd", &self.inner.cwd)
869 .field("pid", &self.pid())
870 .finish()
871 }
872}
873
/// Shared state behind every clone of a [`Client`].
struct ClientInner {
    // The spawned CLI process; `None` for external servers and caller-supplied streams.
    child: parking_lot::Mutex<Option<Child>>,
    // JSON-RPC connection to the CLI.
    rpc: JsonRpcClient,
    // Working directory the client was created with.
    cwd: PathBuf,
    // Server-to-client requests; handed to the session router on first registration.
    request_rx: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<JsonRpcRequest>>>,
    // Broadcast of every notification from the CLI (capacity 1024).
    notification_tx: broadcast::Sender<JsonRpcNotification>,
    // Routes notifications/requests to registered sessions.
    router: router::SessionRouter,
    // Protocol version recorded once after negotiation (read by `Client::protocol_version`).
    negotiated_protocol_version: OnceLock<u32>,
    // Connection state; starts as `Connected`.
    state: parking_lot::Mutex<ConnectionState>,
    // Typed `session.lifecycle` events (capacity 256), fed by the lifecycle dispatcher task.
    lifecycle_tx: broadcast::Sender<SessionLifecycleEvent>,
    // Optional caller-supplied model list handler.
    on_list_models: Option<Arc<dyn ListModelsHandler>>,
    // Lazily filled model list; presumably reset by swapping in a fresh cell — confirm.
    models_cache: parking_lot::Mutex<Arc<tokio::sync::OnceCell<Vec<Model>>>>,
    // Whether `ClientOptions::session_fs` was set.
    session_fs_configured: bool,
    // Whether the session FS config declared sqlite capability.
    session_fs_sqlite_declared: bool,
    // Optional trace-context provider.
    on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
    // Token in effect: user-supplied or generated for TCP, user-supplied for external, none
    // for stdio.
    effective_connection_token: Option<String>,
}
895
896impl Client {
897 pub async fn start(options: ClientOptions) -> Result<Self, Error> {
910 let start_time = Instant::now();
911 if let Some(cfg) = &options.session_fs {
912 validate_session_fs_config(cfg)?;
913 }
914 if matches!(options.transport, Transport::External { .. }) {
917 if options.github_token.is_some() {
918 return Err(Error::InvalidConfig(
919 "github_token cannot be used with Transport::External \
920 (external server manages its own auth)"
921 .to_string(),
922 ));
923 }
924 if options.use_logged_in_user == Some(true) {
925 return Err(Error::InvalidConfig(
926 "use_logged_in_user cannot be used with Transport::External \
927 (external server manages its own auth)"
928 .to_string(),
929 ));
930 }
931 }
932 if let Some(token) = &options.tcp_connection_token {
936 if token.is_empty() {
937 return Err(Error::InvalidConfig(
938 "tcp_connection_token must be a non-empty string".to_string(),
939 ));
940 }
941 if matches!(options.transport, Transport::Stdio) {
942 return Err(Error::InvalidConfig(
943 "tcp_connection_token cannot be used with Transport::Stdio".to_string(),
944 ));
945 }
946 }
947 let effective_connection_token: Option<String> = match &options.transport {
948 Transport::Stdio => None,
949 Transport::Tcp { .. } => Some(
950 options
951 .tcp_connection_token
952 .clone()
953 .unwrap_or_else(generate_connection_token),
954 ),
955 Transport::External { .. } => options.tcp_connection_token.clone(),
956 };
957 let mut options = options;
958 if matches!(options.transport, Transport::Tcp { .. })
959 && options.tcp_connection_token.is_none()
960 {
961 options.tcp_connection_token = effective_connection_token.clone();
964 }
965 let session_fs_config = options.session_fs.clone();
966 let session_fs_sqlite_declared = session_fs_config
967 .as_ref()
968 .and_then(|c| c.capabilities.as_ref())
969 .is_some_and(|caps| caps.sqlite);
970 let program = match &options.program {
971 CliProgram::Path(path) => {
972 info!(path = %path.display(), "using explicit copilot CLI path");
973 path.clone()
974 }
975 CliProgram::Resolve => {
976 let resolved = resolve::copilot_binary()?;
977 info!(path = %resolved.display(), "resolved copilot CLI");
978 #[cfg(windows)]
979 {
980 if let Some(ext) = resolved.extension().and_then(|e| e.to_str()).filter(|ext| {
981 ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat")
982 }) {
983 warn!(
984 path = %resolved.display(),
985 ext = %ext,
986 "resolved copilot CLI is a .cmd/.bat wrapper; \
987 this may cause console window flashes on Windows"
988 );
989 }
990 }
991 resolved
992 }
993 };
994
995 let client = match options.transport {
996 Transport::External { ref host, port } => {
997 info!(host = %host, port = %port, "connecting to external CLI server");
998 let connect_start = Instant::now();
999 let stream = TcpStream::connect((host.as_str(), port)).await?;
1000 debug!(
1001 elapsed_ms = connect_start.elapsed().as_millis(),
1002 host = %host,
1003 port,
1004 "Client::start TCP connect complete"
1005 );
1006 let (reader, writer) = tokio::io::split(stream);
1007 Self::from_transport(
1008 reader,
1009 writer,
1010 None,
1011 options.cwd,
1012 options.on_list_models,
1013 session_fs_config.is_some(),
1014 session_fs_sqlite_declared,
1015 options.on_get_trace_context,
1016 effective_connection_token.clone(),
1017 )?
1018 }
1019 Transport::Tcp { port } => {
1020 let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?;
1021 let connect_start = Instant::now();
1022 let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?;
1023 debug!(
1024 elapsed_ms = connect_start.elapsed().as_millis(),
1025 port = actual_port,
1026 "Client::start TCP connect complete"
1027 );
1028 let (reader, writer) = tokio::io::split(stream);
1029 Self::drain_stderr(&mut child);
1030 Self::from_transport(
1031 reader,
1032 writer,
1033 Some(child),
1034 options.cwd,
1035 options.on_list_models,
1036 session_fs_config.is_some(),
1037 session_fs_sqlite_declared,
1038 options.on_get_trace_context,
1039 effective_connection_token.clone(),
1040 )?
1041 }
1042 Transport::Stdio => {
1043 let mut child = Self::spawn_stdio(&program, &options)?;
1044 let stdin = child.stdin.take().expect("stdin is piped");
1045 let stdout = child.stdout.take().expect("stdout is piped");
1046 Self::drain_stderr(&mut child);
1047 Self::from_transport(
1048 stdout,
1049 stdin,
1050 Some(child),
1051 options.cwd,
1052 options.on_list_models,
1053 session_fs_config.is_some(),
1054 session_fs_sqlite_declared,
1055 options.on_get_trace_context,
1056 effective_connection_token.clone(),
1057 )?
1058 }
1059 };
1060
1061 debug!(
1062 elapsed_ms = start_time.elapsed().as_millis(),
1063 "Client::start transport setup complete"
1064 );
1065 client.verify_protocol_version().await?;
1066 debug!(
1067 elapsed_ms = start_time.elapsed().as_millis(),
1068 "Client::start protocol verification complete"
1069 );
1070 if let Some(cfg) = session_fs_config {
1071 let session_fs_start = Instant::now();
1072 let capabilities = cfg.capabilities.as_ref().map(|c| {
1073 crate::generated::api_types::SessionFsSetProviderCapabilities {
1074 sqlite: Some(c.sqlite),
1075 }
1076 });
1077 let request = crate::generated::api_types::SessionFsSetProviderRequest {
1078 capabilities,
1079 conventions: cfg.conventions.into_wire(),
1080 initial_cwd: cfg.initial_cwd,
1081 session_state_path: cfg.session_state_path,
1082 };
1083 client.rpc().session_fs().set_provider(request).await?;
1084 debug!(
1085 elapsed_ms = session_fs_start.elapsed().as_millis(),
1086 "Client::start session filesystem setup complete"
1087 );
1088 }
1089 debug!(
1090 elapsed_ms = start_time.elapsed().as_millis(),
1091 "Client::start complete"
1092 );
1093 Ok(client)
1094 }
1095
1096 pub fn from_streams(
1100 reader: impl AsyncRead + Unpin + Send + 'static,
1101 writer: impl AsyncWrite + Unpin + Send + 'static,
1102 cwd: PathBuf,
1103 ) -> Result<Self, Error> {
1104 Self::from_transport(reader, writer, None, cwd, None, false, false, None, None)
1105 }
1106
/// Test helper: like [`Client::from_streams`], but with a trace-context provider attached.
#[cfg(any(test, feature = "test-support"))]
pub fn from_streams_with_trace_provider(
    reader: impl AsyncRead + Unpin + Send + 'static,
    writer: impl AsyncWrite + Unpin + Send + 'static,
    cwd: PathBuf,
    provider: Arc<dyn TraceContextProvider>,
) -> Result<Self, Error> {
    let trace_provider = Some(provider);
    Self::from_transport(
        reader,
        writer,
        None,
        cwd,
        None,
        false,
        false,
        trace_provider,
        None,
    )
}
1133
/// Test helper: like [`Client::from_streams`], but recording `token` as the effective
/// connection token.
#[cfg(any(test, feature = "test-support"))]
pub fn from_streams_with_connection_token(
    reader: impl AsyncRead + Unpin + Send + 'static,
    writer: impl AsyncWrite + Unpin + Send + 'static,
    cwd: PathBuf,
    token: Option<String>,
) -> Result<Self, Error> {
    let effective_token = token;
    Self::from_transport(
        reader,
        writer,
        None,
        cwd,
        None,
        false,
        false,
        None,
        effective_token,
    )
}
1146
/// Test helper exposing the crate's connection-token generator (32 lowercase hex chars).
#[cfg(any(test, feature = "test-support"))]
pub fn generate_connection_token_for_test() -> String {
    crate::generate_connection_token()
}
1156
1157 #[allow(clippy::too_many_arguments)]
1158 fn from_transport(
1159 reader: impl AsyncRead + Unpin + Send + 'static,
1160 writer: impl AsyncWrite + Unpin + Send + 'static,
1161 child: Option<Child>,
1162 cwd: PathBuf,
1163 on_list_models: Option<Arc<dyn ListModelsHandler>>,
1164 session_fs_configured: bool,
1165 session_fs_sqlite_declared: bool,
1166 on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
1167 effective_connection_token: Option<String>,
1168 ) -> Result<Self, Error> {
1169 let setup_start = Instant::now();
1170 let (request_tx, request_rx) = mpsc::unbounded_channel::<JsonRpcRequest>();
1171 let (notification_broadcast_tx, _) = broadcast::channel::<JsonRpcNotification>(1024);
1172 let rpc = JsonRpcClient::new(
1173 writer,
1174 reader,
1175 notification_broadcast_tx.clone(),
1176 request_tx,
1177 );
1178
1179 let pid = child.as_ref().and_then(|c| c.id());
1180 info!(pid = ?pid, "copilot CLI client ready");
1181
1182 let client = Self {
1183 inner: Arc::new(ClientInner {
1184 child: parking_lot::Mutex::new(child),
1185 rpc,
1186 cwd,
1187 request_rx: parking_lot::Mutex::new(Some(request_rx)),
1188 notification_tx: notification_broadcast_tx,
1189 router: router::SessionRouter::new(),
1190 negotiated_protocol_version: OnceLock::new(),
1191 state: parking_lot::Mutex::new(ConnectionState::Connected),
1192 lifecycle_tx: broadcast::channel(256).0,
1193 on_list_models,
1194 models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
1195 session_fs_configured,
1196 session_fs_sqlite_declared,
1197 on_get_trace_context,
1198 effective_connection_token,
1199 }),
1200 };
1201 client.spawn_lifecycle_dispatcher();
1202 debug!(
1203 elapsed_ms = setup_start.elapsed().as_millis(),
1204 pid = ?pid,
1205 "Client::from_transport setup complete"
1206 );
1207 Ok(client)
1208 }
1209
1210 fn spawn_lifecycle_dispatcher(&self) {
1214 let inner = Arc::clone(&self.inner);
1215 let mut notif_rx = inner.notification_tx.subscribe();
1216 tokio::spawn(async move {
1217 loop {
1218 match notif_rx.recv().await {
1219 Ok(notification) => {
1220 if notification.method != "session.lifecycle" {
1221 continue;
1222 }
1223 let Some(params) = notification.params.as_ref() else {
1224 continue;
1225 };
1226 let event: SessionLifecycleEvent =
1227 match serde_json::from_value(params.clone()) {
1228 Ok(e) => e,
1229 Err(e) => {
1230 warn!(
1231 error = %e,
1232 "failed to deserialize session.lifecycle notification"
1233 );
1234 continue;
1235 }
1236 };
1237 let _ = inner.lifecycle_tx.send(event);
1240 }
1241 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
1242 warn!(missed = n, "lifecycle dispatcher lagged");
1243 }
1244 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
1245 }
1246 }
1247 });
1248 }
1249
1250 fn build_command(program: &Path, options: &ClientOptions) -> Command {
1251 let mut command = Command::new(program);
1252 for arg in &options.prefix_args {
1253 command.arg(arg);
1254 }
1255 if let Some(token) = &options.github_token {
1258 command.env("COPILOT_SDK_AUTH_TOKEN", token);
1259 }
1260 if let Some(telemetry) = &options.telemetry {
1263 command.env("COPILOT_OTEL_ENABLED", "true");
1264 if let Some(endpoint) = &telemetry.otlp_endpoint {
1265 command.env("OTEL_EXPORTER_OTLP_ENDPOINT", endpoint);
1266 }
1267 if let Some(path) = &telemetry.file_path {
1268 command.env("COPILOT_OTEL_FILE_EXPORTER_PATH", path);
1269 }
1270 if let Some(exporter) = telemetry.exporter_type {
1271 command.env("COPILOT_OTEL_EXPORTER_TYPE", exporter.as_str());
1272 }
1273 if let Some(source) = &telemetry.source_name {
1274 command.env("COPILOT_OTEL_SOURCE_NAME", source);
1275 }
1276 if let Some(capture) = telemetry.capture_content {
1277 command.env(
1278 "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
1279 if capture { "true" } else { "false" },
1280 );
1281 }
1282 }
1283 if let Some(home) = &options.copilot_home {
1284 command.env("COPILOT_HOME", home);
1285 }
1286 if let Some(token) = &options.tcp_connection_token {
1287 command.env("COPILOT_CONNECTION_TOKEN", token);
1288 }
1289 for (key, value) in &options.env {
1290 command.env(key, value);
1291 }
1292 for key in &options.env_remove {
1293 command.env_remove(key);
1294 }
1295 command
1296 .current_dir(&options.cwd)
1297 .stdout(Stdio::piped())
1298 .stderr(Stdio::piped());
1299
1300 #[cfg(windows)]
1301 {
1302 use std::os::windows::process::CommandExt;
1303 const CREATE_NO_WINDOW: u32 = 0x08000000;
1304 command.as_std_mut().creation_flags(CREATE_NO_WINDOW);
1305 }
1306
1307 command
1308 }
1309
1310 fn auth_args(options: &ClientOptions) -> Vec<&'static str> {
1318 let mut args: Vec<&'static str> = Vec::new();
1319 if options.github_token.is_some() {
1320 args.push("--auth-token-env");
1321 args.push("COPILOT_SDK_AUTH_TOKEN");
1322 }
1323 let use_logged_in = options
1324 .use_logged_in_user
1325 .unwrap_or(options.github_token.is_none());
1326 if !use_logged_in {
1327 args.push("--no-auto-login");
1328 }
1329 args
1330 }
1331
1332 fn session_idle_timeout_args(options: &ClientOptions) -> Vec<String> {
1336 match options.session_idle_timeout_seconds {
1337 Some(secs) if secs > 0 => {
1338 vec!["--session-idle-timeout".to_string(), secs.to_string()]
1339 }
1340 _ => Vec::new(),
1341 }
1342 }
1343
1344 fn remote_args(options: &ClientOptions) -> Vec<String> {
1345 if options.remote {
1346 vec!["--remote".to_string()]
1347 } else {
1348 Vec::new()
1349 }
1350 }
1351
1352 fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result<Child, Error> {
1353 info!(cwd = ?options.cwd, program = %program.display(), "spawning copilot CLI (stdio)");
1354 let mut command = Self::build_command(program, options);
1355 let log_level = options.log_level.unwrap_or(LogLevel::Info);
1356 command
1357 .args([
1358 "--server",
1359 "--stdio",
1360 "--no-auto-update",
1361 "--log-level",
1362 log_level.as_str(),
1363 ])
1364 .args(Self::auth_args(options))
1365 .args(Self::session_idle_timeout_args(options))
1366 .args(Self::remote_args(options))
1367 .args(&options.extra_args)
1368 .stdin(Stdio::piped());
1369 let spawn_start = Instant::now();
1370 let child = command.spawn()?;
1371 debug!(
1372 elapsed_ms = spawn_start.elapsed().as_millis(),
1373 "Client::spawn_stdio subprocess spawned"
1374 );
1375 Ok(child)
1376 }
1377
1378 async fn spawn_tcp(
1379 program: &Path,
1380 options: &ClientOptions,
1381 port: u16,
1382 ) -> Result<(Child, u16), Error> {
1383 info!(cwd = ?options.cwd, program = %program.display(), port = %port, "spawning copilot CLI (tcp)");
1384 let mut command = Self::build_command(program, options);
1385 let log_level = options.log_level.unwrap_or(LogLevel::Info);
1386 command
1387 .args([
1388 "--server",
1389 "--port",
1390 &port.to_string(),
1391 "--no-auto-update",
1392 "--log-level",
1393 log_level.as_str(),
1394 ])
1395 .args(Self::auth_args(options))
1396 .args(Self::session_idle_timeout_args(options))
1397 .args(Self::remote_args(options))
1398 .args(&options.extra_args)
1399 .stdin(Stdio::null());
1400 let spawn_start = Instant::now();
1401 let mut child = command.spawn()?;
1402 debug!(
1403 elapsed_ms = spawn_start.elapsed().as_millis(),
1404 "Client::spawn_tcp subprocess spawned"
1405 );
1406 let stdout = child.stdout.take().expect("stdout is piped");
1407
1408 let (port_tx, port_rx) = oneshot::channel::<u16>();
1409 let span = tracing::error_span!("copilot_cli_port_scan");
1410 tokio::spawn(
1411 async move {
1412 let port_re = regex::Regex::new(r"listening on port (\d+)").expect("valid regex");
1414 let mut lines = BufReader::new(stdout).lines();
1415 let mut port_tx = Some(port_tx);
1416 while let Ok(Some(line)) = lines.next_line().await {
1417 debug!(line = %line, "CLI stdout");
1418 if let Some(tx) = port_tx.take() {
1419 if let Some(caps) = port_re.captures(&line)
1420 && let Some(p) =
1421 caps.get(1).and_then(|m| m.as_str().parse::<u16>().ok())
1422 {
1423 let _ = tx.send(p);
1424 continue;
1425 }
1426 port_tx = Some(tx);
1428 }
1429 }
1430 }
1431 .instrument(span),
1432 );
1433
1434 let port_wait_start = Instant::now();
1435 let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx)
1436 .await
1437 .map_err(|_| Error::Protocol(ProtocolError::CliStartupTimeout))?
1438 .map_err(|_| Error::Protocol(ProtocolError::CliStartupFailed))?;
1439
1440 debug!(
1441 elapsed_ms = port_wait_start.elapsed().as_millis(),
1442 port = actual_port,
1443 "Client::spawn_tcp TCP port wait complete"
1444 );
1445 info!(port = %actual_port, "CLI server listening");
1446 Ok((child, actual_port))
1447 }
1448
1449 fn drain_stderr(child: &mut Child) {
1450 if let Some(stderr) = child.stderr.take() {
1451 let span = tracing::error_span!("copilot_cli");
1452 tokio::spawn(
1453 async move {
1454 let mut reader = BufReader::new(stderr).lines();
1455 while let Ok(Some(line)) = reader.next_line().await {
1456 warn!(line = %line, "CLI stderr");
1457 }
1458 }
1459 .instrument(span),
1460 );
1461 }
1462 }
1463
1464 pub fn cwd(&self) -> &PathBuf {
1466 &self.inner.cwd
1467 }
1468
1469 pub fn rpc(&self) -> crate::generated::rpc::ClientRpc<'_> {
1480 crate::generated::rpc::ClientRpc { client: self }
1481 }
1482
1483 pub(crate) async fn send_request(
1485 &self,
1486 method: &str,
1487 params: Option<serde_json::Value>,
1488 ) -> Result<JsonRpcResponse, Error> {
1489 self.inner.rpc.send_request(method, params).await
1490 }
1491
1492 pub async fn call(
1512 &self,
1513 method: &str,
1514 params: Option<serde_json::Value>,
1515 ) -> Result<serde_json::Value, Error> {
1516 let session_id: Option<SessionId> = params
1517 .as_ref()
1518 .and_then(|p| p.get("sessionId"))
1519 .and_then(|v| v.as_str())
1520 .map(SessionId::from);
1521 let response = self.send_request(method, params).await?;
1522 if let Some(err) = response.error {
1523 if err.message.contains("Session not found") {
1524 return Err(Error::Session(SessionError::NotFound(
1525 session_id.unwrap_or_else(|| "unknown".into()),
1526 )));
1527 }
1528 return Err(Error::Rpc {
1529 code: err.code,
1530 message: err.message,
1531 });
1532 }
1533 Ok(response.result.unwrap_or(serde_json::Value::Null))
1534 }
1535
    /// Writes a JSON-RPC response to the transport.
    ///
    /// Presumably this answers server-initiated requests.
    /// NOTE(review): confirm with the callers.
    pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<(), Error> {
        self.inner.rpc.write(response).await
    }
1540
    /// Takes the receiver of incoming JSON-RPC requests and leaves `None` in its place.
    ///
    /// Returns `None` if the receiver was already taken. `register_session`
    /// hands the same slot to the session router, which presumably takes it
    /// first. TODO: confirm in `router`.
    #[expect(dead_code, reason = "reserved for future pub(crate) use")]
    pub(crate) fn take_request_rx(&self) -> Option<mpsc::UnboundedReceiver<JsonRpcRequest>> {
        self.inner.request_rx.lock().take()
    }
1548
1549 pub(crate) fn register_session(
1557 &self,
1558 session_id: &SessionId,
1559 ) -> crate::router::SessionChannels {
1560 self.inner
1561 .router
1562 .ensure_started(&self.inner.notification_tx, &self.inner.request_rx);
1563 self.inner.router.register(session_id)
1564 }
1565
    /// Removes `session_id` from the session router.
    pub(crate) fn unregister_session(&self, session_id: &SessionId) {
        self.inner.router.unregister(session_id);
    }
1570
    /// Returns the protocol version recorded by `verify_protocol_version`.
    ///
    /// Returns `None` if no in-range version has been negotiated yet. That
    /// includes the case where the server did not report a version.
    pub fn protocol_version(&self) -> Option<u32> {
        self.inner.negotiated_protocol_version.get().copied()
    }
1580
1581 pub async fn verify_protocol_version(&self) -> Result<(), Error> {
1605 let handshake_start = Instant::now();
1606 let mut used_fallback_ping = false;
1607 let server_version = match self.connect_handshake().await {
1611 Ok(v) => v,
1612 Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => {
1613 used_fallback_ping = true;
1614 self.ping(None).await?.protocol_version
1615 }
1616 Err(e) => return Err(e),
1617 };
1618
1619 match server_version {
1620 None => {
1621 warn!("CLI server did not report protocolVersion; skipping version check");
1622 }
1623 Some(v) if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&v) => {
1624 return Err(Error::Protocol(ProtocolError::VersionMismatch {
1625 server: v,
1626 min: MIN_PROTOCOL_VERSION,
1627 max: SDK_PROTOCOL_VERSION,
1628 }));
1629 }
1630 Some(v) => {
1631 if let Some(&existing) = self.inner.negotiated_protocol_version.get() {
1632 if existing != v {
1633 return Err(Error::Protocol(ProtocolError::VersionChanged {
1634 previous: existing,
1635 current: v,
1636 }));
1637 }
1638 } else {
1639 let _ = self.inner.negotiated_protocol_version.set(v);
1640 }
1641 }
1642 }
1643
1644 debug!(
1645 elapsed_ms = handshake_start.elapsed().as_millis(),
1646 protocol_version = ?server_version,
1647 used_fallback_ping,
1648 "Client::verify_protocol_version protocol handshake complete"
1649 );
1650 Ok(())
1651 }
1652
1653 async fn connect_handshake(&self) -> Result<Option<u32>, Error> {
1660 let result = self
1661 .rpc()
1662 .connect(crate::generated::api_types::ConnectRequest {
1663 token: self.inner.effective_connection_token.clone(),
1664 })
1665 .await?;
1666 Ok(u32::try_from(result.protocol_version).ok())
1667 }
1668
1669 pub async fn ping(&self, message: Option<&str>) -> Result<crate::types::PingResponse, Error> {
1677 let params = match message {
1678 Some(m) => serde_json::json!({ "message": m }),
1679 None => serde_json::json!({}),
1680 };
1681 let value = self
1682 .call(generated::api_types::rpc_methods::PING, Some(params))
1683 .await?;
1684 Ok(serde_json::from_value(value)?)
1685 }
1686
1687 pub async fn list_sessions(
1690 &self,
1691 filter: Option<SessionListFilter>,
1692 ) -> Result<Vec<SessionMetadata>, Error> {
1693 let params = match filter {
1694 Some(f) => serde_json::json!({ "filter": f }),
1695 None => serde_json::json!({}),
1696 };
1697 let result = self.call("session.list", Some(params)).await?;
1698 let response: ListSessionsResponse = serde_json::from_value(result)?;
1699 Ok(response.sessions)
1700 }
1701
1702 pub async fn get_session_metadata(
1720 &self,
1721 session_id: &SessionId,
1722 ) -> Result<Option<SessionMetadata>, Error> {
1723 let result = self
1724 .call(
1725 "session.getMetadata",
1726 Some(serde_json::json!({ "sessionId": session_id })),
1727 )
1728 .await?;
1729 let response: GetSessionMetadataResponse = serde_json::from_value(result)?;
1730 Ok(response.session)
1731 }
1732
1733 pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), Error> {
1735 self.call(
1736 "session.delete",
1737 Some(serde_json::json!({ "sessionId": session_id })),
1738 )
1739 .await?;
1740 Ok(())
1741 }
1742
1743 pub async fn get_last_session_id(&self) -> Result<Option<SessionId>, Error> {
1759 let result = self
1760 .call("session.getLastId", Some(serde_json::json!({})))
1761 .await?;
1762 let response: GetLastSessionIdResponse = serde_json::from_value(result)?;
1763 Ok(response.session_id)
1764 }
1765
1766 pub async fn get_foreground_session_id(&self) -> Result<Option<SessionId>, Error> {
1771 let result = self
1772 .call("session.getForeground", Some(serde_json::json!({})))
1773 .await?;
1774 let response: GetForegroundSessionResponse = serde_json::from_value(result)?;
1775 Ok(response.session_id)
1776 }
1777
1778 pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<(), Error> {
1783 self.call(
1784 "session.setForeground",
1785 Some(serde_json::json!({ "sessionId": session_id })),
1786 )
1787 .await?;
1788 Ok(())
1789 }
1790
1791 pub async fn get_status(&self) -> Result<GetStatusResponse, Error> {
1793 let result = self.call("status.get", Some(serde_json::json!({}))).await?;
1794 Ok(serde_json::from_value(result)?)
1795 }
1796
1797 pub async fn get_auth_status(&self) -> Result<GetAuthStatusResponse, Error> {
1799 let result = self
1800 .call("auth.getStatus", Some(serde_json::json!({})))
1801 .await?;
1802 Ok(serde_json::from_value(result)?)
1803 }
1804
1805 pub async fn list_models(&self) -> Result<Vec<Model>, Error> {
1810 let cache = self.inner.models_cache.lock().clone();
1811 let models = cache
1812 .get_or_try_init(|| async {
1813 if let Some(handler) = &self.inner.on_list_models {
1814 handler.list_models().await
1815 } else {
1816 Ok(self.rpc().models().list().await?.models)
1817 }
1818 })
1819 .await?;
1820 Ok(models.clone())
1821 }
1822
1823 pub(crate) async fn resolve_trace_context(&self) -> TraceContext {
1826 if let Some(provider) = &self.inner.on_get_trace_context {
1827 provider.get_trace_context().await
1828 } else {
1829 TraceContext::default()
1830 }
1831 }
1832
    /// Returns the OS process id of the attached CLI child.
    ///
    /// Returns `None` when no child is attached, for example after
    /// `stop`/`force_stop` have taken it. It is also `None` when
    /// `Child::id()` no longer reports one, such as after the process has
    /// been reaped.
    pub fn pid(&self) -> Option<u32> {
        self.inner.child.lock().as_ref().and_then(|c| c.id())
    }
1837
1838 pub async fn stop(&self) -> Result<(), StopErrors> {
1864 let pid = self.pid();
1865 info!(pid = ?pid, "stopping CLI process");
1866 let mut errors: Vec<Error> = Vec::new();
1867
1868 for session_id in self.inner.router.session_ids() {
1871 match self
1872 .call(
1873 "session.destroy",
1874 Some(serde_json::json!({ "sessionId": session_id })),
1875 )
1876 .await
1877 {
1878 Ok(_) => {}
1879 Err(e) => {
1880 warn!(
1881 session_id = %session_id,
1882 error = %e,
1883 "session.destroy failed during Client::stop",
1884 );
1885 errors.push(e);
1886 }
1887 }
1888 self.inner.router.unregister(&session_id);
1889 }
1890
1891 let child = self.inner.child.lock().take();
1892 *self.inner.state.lock() = ConnectionState::Disconnected;
1893 *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
1894 if let Some(mut child) = child
1895 && let Err(e) = child.kill().await
1896 {
1897 errors.push(Error::Io(e));
1898 }
1899
1900 info!(pid = ?pid, errors = errors.len(), "CLI process stopped");
1901 if errors.is_empty() {
1902 Ok(())
1903 } else {
1904 Err(StopErrors(errors))
1905 }
1906 }
1907
1908 pub fn force_stop(&self) {
1938 let pid = self.pid();
1939 info!(pid = ?pid, "force-stopping CLI process");
1940 if let Some(mut child) = self.inner.child.lock().take()
1941 && let Err(e) = child.start_kill()
1942 {
1943 error!(pid = ?pid, error = %e, "failed to send kill signal");
1944 }
1945 self.inner.rpc.force_close();
1946 self.inner.router.clear();
1949 *self.inner.state.lock() = ConnectionState::Disconnected;
1950 *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
1951 }
1952
    /// Returns a new subscription to the client's lifecycle broadcast channel.
    ///
    /// Each call creates an independent receiver. Events sent before the
    /// subscription is created are not delivered to it.
    pub fn subscribe_lifecycle(&self) -> LifecycleSubscription {
        LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe())
    }
1989
    /// Returns a copy of the current connection state.
    pub fn state(&self) -> ConnectionState {
        *self.inner.state.lock()
    }
1999}
2000
2001impl Drop for ClientInner {
2002 fn drop(&mut self) {
2003 if let Some(ref mut child) = *self.child.lock() {
2004 let pid = child.id();
2005 if let Err(e) = child.start_kill() {
2006 error!(pid = ?pid, error = %e, "failed to kill CLI process on drop");
2007 } else {
2008 info!(pid = ?pid, "kill signal sent for CLI process on drop");
2009 }
2010 }
2011 }
2012}
2013
#[cfg(test)]
mod tests {
    use super::*;

    // Returns the value `cmd` will set for `key`. Returns `None` both when the
    // key is untouched and when it is explicitly removed. Tests that must tell
    // those apart inspect `get_envs()` directly.
    fn env_value<'a>(cmd: &'a tokio::process::Command, key: &str) -> Option<&'a std::ffi::OsStr> {
        cmd.as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new(key))
            .and_then(|(_, v)| v)
    }

    // ---- Error::is_transport_failure ----

    #[test]
    fn is_transport_failure_matches_request_cancelled() {
        let err = Error::Protocol(ProtocolError::RequestCancelled);
        assert!(err.is_transport_failure());
    }

    #[test]
    fn is_transport_failure_matches_io_error() {
        let err = Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"));
        assert!(err.is_transport_failure());
    }

    #[test]
    fn is_transport_failure_rejects_rpc_error() {
        let err = Error::Rpc {
            code: -1,
            message: "bad".into(),
        };
        assert!(!err.is_transport_failure());
    }

    #[test]
    fn is_transport_failure_rejects_session_error() {
        let err = Error::Session(SessionError::NotFound("s1".into()));
        assert!(!err.is_transport_failure());
    }

    #[test]
    fn is_transport_failure_rejects_other_protocol_errors() {
        let err = Error::Protocol(ProtocolError::CliStartupTimeout);
        assert!(!err.is_transport_failure());
    }

    // ---- ClientOptions builder ----

    #[test]
    fn client_options_builder_composes() {
        let opts = ClientOptions::new()
            .with_program(CliProgram::Path(PathBuf::from("/usr/local/bin/copilot")))
            .with_prefix_args(["node"])
            .with_cwd(PathBuf::from("/tmp"))
            .with_env([("KEY", "value")])
            .with_env_remove(["UNWANTED"])
            .with_extra_args(["--quiet"])
            .with_github_token("ghp_test")
            .with_use_logged_in_user(false)
            .with_log_level(LogLevel::Debug)
            .with_session_idle_timeout_seconds(120)
            .with_remote(true);
        assert!(matches!(opts.program, CliProgram::Path(_)));
        assert_eq!(opts.prefix_args, vec![std::ffi::OsString::from("node")]);
        assert_eq!(opts.cwd, PathBuf::from("/tmp"));
        assert_eq!(
            opts.env,
            vec![(
                std::ffi::OsString::from("KEY"),
                std::ffi::OsString::from("value")
            )]
        );
        assert_eq!(opts.env_remove, vec![std::ffi::OsString::from("UNWANTED")]);
        assert_eq!(opts.extra_args, vec!["--quiet".to_string()]);
        assert_eq!(opts.github_token.as_deref(), Some("ghp_test"));
        assert_eq!(opts.use_logged_in_user, Some(false));
        assert!(matches!(opts.log_level, Some(LogLevel::Debug)));
        assert_eq!(opts.session_idle_timeout_seconds, Some(120));
        assert!(opts.remote);
    }

    // ---- Client::build_command: auth token environment ----

    #[test]
    fn build_command_lets_env_remove_strip_injected_token() {
        let opts = ClientOptions {
            github_token: Some("secret".to_string()),
            env_remove: vec![std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN")],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        // `get_envs` yields `(key, None)` for a removed variable. The outer
        // `Some` shows the key was touched; the inner `None` shows it was
        // removed rather than set.
        let action = cmd
            .as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
            .map(|(_, v)| v);
        assert_eq!(
            action,
            Some(None),
            "env_remove should win over github_token"
        );
    }

    #[test]
    fn build_command_lets_env_override_injected_token() {
        let opts = ClientOptions {
            github_token: Some("from-options".to_string()),
            env: vec![(
                std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN"),
                std::ffi::OsString::from("from-env"),
            )],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_SDK_AUTH_TOKEN"),
            Some(std::ffi::OsStr::new("from-env"))
        );
    }

    #[test]
    fn build_command_injects_github_token_by_default() {
        let opts = ClientOptions {
            github_token: Some("just-the-token".to_string()),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_SDK_AUTH_TOKEN"),
            Some(std::ffi::OsStr::new("just-the-token"))
        );
    }

    // ---- Telemetry configuration ----

    #[test]
    fn telemetry_config_builder_composes() {
        let cfg = TelemetryConfig::new()
            .with_otlp_endpoint("http://collector:4318")
            .with_file_path(PathBuf::from("/var/log/copilot.jsonl"))
            .with_exporter_type(OtelExporterType::OtlpHttp)
            .with_source_name("my-app")
            .with_capture_content(true);

        assert_eq!(cfg.otlp_endpoint.as_deref(), Some("http://collector:4318"));
        assert_eq!(
            cfg.file_path.as_deref(),
            Some(Path::new("/var/log/copilot.jsonl")),
        );
        assert_eq!(cfg.exporter_type, Some(OtelExporterType::OtlpHttp));
        assert_eq!(cfg.source_name.as_deref(), Some("my-app"));
        assert_eq!(cfg.capture_content, Some(true));
        assert!(!cfg.is_empty());
        assert!(TelemetryConfig::new().is_empty());
    }

    #[test]
    fn build_command_sets_otel_env_when_telemetry_enabled() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://collector:4318".to_string()),
                file_path: Some(PathBuf::from("/var/log/copilot.jsonl")),
                exporter_type: Some(OtelExporterType::OtlpHttp),
                source_name: Some("my-app".to_string()),
                capture_content: Some(true),
            }),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://collector:4318")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_FILE_EXPORTER_PATH"),
            Some(std::ffi::OsStr::new("/var/log/copilot.jsonl")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_EXPORTER_TYPE"),
            Some(std::ffi::OsStr::new("otlp-http")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_SOURCE_NAME"),
            Some(std::ffi::OsStr::new("my-app")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"),
            Some(std::ffi::OsStr::new("true")),
        );
    }

    #[test]
    fn build_command_omits_otel_env_when_telemetry_none() {
        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        for key in [
            "COPILOT_OTEL_ENABLED",
            "OTEL_EXPORTER_OTLP_ENDPOINT",
            "COPILOT_OTEL_FILE_EXPORTER_PATH",
            "COPILOT_OTEL_EXPORTER_TYPE",
            "COPILOT_OTEL_SOURCE_NAME",
            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
        ] {
            assert!(
                env_value(&cmd, key).is_none(),
                "expected {key} to be unset when telemetry is None",
            );
        }
    }

    #[test]
    fn build_command_omits_unset_telemetry_fields() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://collector:4318".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        // Any configured telemetry enables OTEL, even with a single field set.
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://collector:4318")),
        );
        // Fields left as `None` must not leak into the environment.
        for key in [
            "COPILOT_OTEL_FILE_EXPORTER_PATH",
            "COPILOT_OTEL_EXPORTER_TYPE",
            "COPILOT_OTEL_SOURCE_NAME",
            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
        ] {
            assert!(env_value(&cmd, key).is_none(), "{key} should be unset");
        }
    }

    #[test]
    fn build_command_lets_user_env_override_telemetry() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://from-config:4318".to_string()),
                ..Default::default()
            }),
            env: vec![(
                std::ffi::OsString::from("OTEL_EXPORTER_OTLP_ENDPOINT"),
                std::ffi::OsString::from("http://from-user-env:4318"),
            )],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://from-user-env:4318")),
            "user-supplied options.env should override telemetry config",
        );
    }

    #[test]
    fn telemetry_config_capture_content_serializes_as_lowercase_bool() {
        let opts_true = ClientOptions {
            telemetry: Some(TelemetryConfig {
                capture_content: Some(true),
                ..Default::default()
            }),
            ..Default::default()
        };
        let opts_false = ClientOptions {
            telemetry: Some(TelemetryConfig {
                capture_content: Some(false),
                ..Default::default()
            }),
            ..Default::default()
        };
        let cmd_true = Client::build_command(Path::new("/bin/echo"), &opts_true);
        let cmd_false = Client::build_command(Path::new("/bin/echo"), &opts_false);
        assert_eq!(
            env_value(
                &cmd_true,
                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
            ),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(
                &cmd_false,
                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
            ),
            Some(std::ffi::OsStr::new("false")),
        );
    }

    // ---- Client::build_command: other environment variables ----

    #[test]
    fn build_command_sets_copilot_home_env_when_configured() {
        let opts = ClientOptions::new().with_copilot_home(PathBuf::from("/custom/copilot"));
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_HOME"),
            Some(std::ffi::OsStr::new("/custom/copilot")),
        );

        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert!(env_value(&cmd, "COPILOT_HOME").is_none());
    }

    #[test]
    fn build_command_sets_connection_token_env_when_configured() {
        let opts = ClientOptions::new().with_tcp_connection_token("secret-token");
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_CONNECTION_TOKEN"),
            Some(std::ffi::OsStr::new("secret-token")),
        );

        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert!(env_value(&cmd, "COPILOT_CONNECTION_TOKEN").is_none());
    }

    // ---- Client::start configuration validation ----

    #[tokio::test]
    async fn start_rejects_token_with_stdio_transport() {
        let opts = ClientOptions::new()
            .with_tcp_connection_token("token-123")
            .with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
        let err = Client::start(opts).await.unwrap_err();
        assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}");
        let Error::InvalidConfig(msg) = err else {
            unreachable!()
        };
        assert!(
            msg.contains("Stdio"),
            "error should explain the stdio incompatibility: {msg}"
        );
    }

    #[tokio::test]
    async fn start_rejects_empty_connection_token() {
        let opts = ClientOptions::new()
            .with_tcp_connection_token("")
            .with_transport(Transport::Tcp { port: 0 })
            .with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
        let err = Client::start(opts).await.unwrap_err();
        assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}");
    }

    // ---- CLI argument helpers ----

    #[test]
    fn session_idle_timeout_args_are_omitted_by_default() {
        let opts = ClientOptions::default();
        assert!(Client::session_idle_timeout_args(&opts).is_empty());
    }

    #[test]
    fn session_idle_timeout_args_omitted_for_zero() {
        let opts = ClientOptions {
            session_idle_timeout_seconds: Some(0),
            ..Default::default()
        };
        assert!(Client::session_idle_timeout_args(&opts).is_empty());
    }

    #[test]
    fn session_idle_timeout_args_emit_flag_for_positive_value() {
        let opts = ClientOptions {
            session_idle_timeout_seconds: Some(300),
            ..Default::default()
        };
        assert_eq!(
            Client::session_idle_timeout_args(&opts),
            vec!["--session-idle-timeout".to_string(), "300".to_string()]
        );
    }

    #[test]
    fn remote_args_omitted_by_default() {
        let opts = ClientOptions::default();
        assert!(Client::remote_args(&opts).is_empty());
    }

    #[test]
    fn remote_args_emit_flag_when_enabled() {
        let opts = ClientOptions {
            remote: true,
            ..Default::default()
        };
        assert_eq!(Client::remote_args(&opts), vec!["--remote".to_string()]);
    }

    // ---- LogLevel / Debug formatting ----

    #[test]
    fn log_level_str_round_trips() {
        for level in [
            LogLevel::None,
            LogLevel::Error,
            LogLevel::Warning,
            LogLevel::Info,
            LogLevel::Debug,
            LogLevel::All,
        ] {
            let s = level.as_str();
            let json = serde_json::to_string(&level).unwrap();
            assert_eq!(json, format!("\"{s}\""));
            let parsed: LogLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, level);
        }
    }

    #[test]
    fn client_options_debug_redacts_handler() {
        struct StubHandler;
        #[async_trait]
        impl ListModelsHandler for StubHandler {
            async fn list_models(&self) -> Result<Vec<Model>, Error> {
                Ok(vec![])
            }
        }
        let opts = ClientOptions {
            on_list_models: Some(Arc::new(StubHandler)),
            github_token: Some("secret-token".into()),
            ..Default::default()
        };
        let debug = format!("{opts:?}");
        assert!(debug.contains("on_list_models: Some(\"<set>\")"));
        assert!(debug.contains("github_token: Some(\"<redacted>\")"));
        assert!(!debug.contains("secret-token"));
    }

    // ---- Client::list_models caching ----

    #[tokio::test]
    async fn list_models_uses_on_list_models_handler_when_set() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        struct CountingHandler {
            calls: Arc<AtomicUsize>,
            models: Vec<Model>,
        }
        #[async_trait]
        impl ListModelsHandler for CountingHandler {
            async fn list_models(&self) -> Result<Vec<Model>, Error> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Ok(self.models.clone())
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let model = Model {
            id: "byok-gpt-4".into(),
            name: "BYOK GPT-4".into(),
            ..Default::default()
        };
        let handler: Arc<dyn ListModelsHandler> = Arc::new(CountingHandler {
            calls: Arc::clone(&calls),
            models: vec![model],
        });

        let client = client_with_list_models_handler(handler);

        let result = client.list_models().await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, "byok-gpt-4");
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // A second call must be served from the cache without re-invoking
        // the handler.
        let cached = client.list_models().await.unwrap();
        assert_eq!(cached[0].id, "byok-gpt-4");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn list_models_serializes_concurrent_cache_misses() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        struct SlowCountingHandler {
            calls: Arc<AtomicUsize>,
            models: Vec<Model>,
        }
        #[async_trait]
        impl ListModelsHandler for SlowCountingHandler {
            async fn list_models(&self) -> Result<Vec<Model>, Error> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                // Keep the first fetch in flight long enough for the second
                // caller to observe the miss.
                tokio::time::sleep(std::time::Duration::from_millis(25)).await;
                Ok(self.models.clone())
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let model = Model {
            id: "single-flight-model".into(),
            name: "Single Flight Model".into(),
            ..Default::default()
        };
        let handler: Arc<dyn ListModelsHandler> = Arc::new(SlowCountingHandler {
            calls: Arc::clone(&calls),
            models: vec![model],
        });
        let client = client_with_list_models_handler(handler);

        let (first, second) = tokio::join!(client.list_models(), client.list_models());
        assert_eq!(first.unwrap()[0].id, "single-flight-model");
        assert_eq!(second.unwrap()[0].id, "single-flight-model");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // ---- Cancellation cleanup ----

    #[tokio::test]
    async fn cancelled_create_session_unregisters_pending_session() {
        // The server half never replies, so create_session stays pending
        // until it is aborted.
        let (client_write, _server_read) = tokio::io::duplex(8192);
        let (_server_write, client_read) = tokio::io::duplex(8192);
        let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
        let handle = tokio::spawn({
            let client = client.clone();
            async move { client.create_session(SessionConfig::default()).await }
        });

        wait_for_pending_session_registration(&client).await;
        handle.abort();
        let _ = handle.await;

        assert!(client.inner.router.session_ids().is_empty());
        client.force_stop();
    }

    #[tokio::test]
    async fn cancelled_resume_session_unregisters_pending_session() {
        let (client_write, _server_read) = tokio::io::duplex(8192);
        let (_server_write, client_read) = tokio::io::duplex(8192);
        let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
        let session_id = SessionId::new("resume-cancel-test");
        let handle = tokio::spawn({
            let client = client.clone();
            async move {
                client
                    .resume_session(ResumeSessionConfig::new(session_id))
                    .await
            }
        });

        wait_for_pending_session_registration(&client).await;
        handle.abort();
        let _ = handle.await;

        assert!(client.inner.router.session_ids().is_empty());
        client.force_stop();
    }

    // ---- Fixtures ----

    // Builds a `Client` with no child process. Its JSON-RPC transport has the
    // peer pipe ends dropped when the `rpc` block ends, so only the
    // `on_list_models` handler path is usable.
    fn client_with_list_models_handler(handler: Arc<dyn ListModelsHandler>) -> Client {
        Client {
            inner: Arc::new(ClientInner {
                child: parking_lot::Mutex::new(None),
                rpc: {
                    let (req_tx, _req_rx) = mpsc::unbounded_channel();
                    let (notif_tx, _notif_rx) = broadcast::channel(16);
                    let (read_pipe, _write_pipe) = tokio::io::duplex(64);
                    let (_unused_read, write_pipe) = tokio::io::duplex(64);
                    JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
                },
                cwd: PathBuf::from("."),
                request_rx: parking_lot::Mutex::new(None),
                notification_tx: broadcast::channel(16).0,
                router: router::SessionRouter::new(),
                negotiated_protocol_version: OnceLock::new(),
                state: parking_lot::Mutex::new(ConnectionState::Connected),
                lifecycle_tx: broadcast::channel(16).0,
                on_list_models: Some(handler),
                models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
                session_fs_configured: false,
                session_fs_sqlite_declared: false,
                on_get_trace_context: None,
                effective_connection_token: None,
            }),
        }
    }

    // Polls the router every 10 ms until a session is registered. Panics if
    // none appears within one second.
    async fn wait_for_pending_session_registration(client: &Client) {
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
        while client.inner.router.session_ids().is_empty() {
            assert!(
                tokio::time::Instant::now() < deadline,
                "session was not registered"
            );
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    }
}