1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3#![deny(rustdoc::broken_intra_doc_links)]
4#![cfg_attr(test, allow(clippy::unwrap_used))]
5
6pub mod embeddedcli;
8pub mod handler;
10pub mod hooks;
12mod jsonrpc;
13pub mod permission;
15pub mod resolve;
17mod router;
18pub mod session;
20pub mod session_fs;
22mod session_fs_dispatch;
23pub mod subscription;
25pub mod tool;
27pub mod trace_context;
29pub mod transforms;
31pub mod types;
33
34pub mod generated;
36
37use std::ffi::OsString;
38use std::path::{Path, PathBuf};
39use std::process::Stdio;
40use std::sync::{Arc, OnceLock};
41use std::time::Instant;
42
43use async_trait::async_trait;
44pub(crate) use jsonrpc::{
47 JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes,
48};
49
/// Re-exports of the crate-internal JSON-RPC types.
///
/// Only compiled with the `test-support` feature; without it these types stay
/// private to the crate (the `jsonrpc` module itself is private).
#[cfg(feature = "test-support")]
pub mod test_support {
    pub use crate::jsonrpc::{
        JsonRpcClient, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
        error_codes,
    };
}
58use serde::{Deserialize, Serialize};
59use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader};
60use tokio::net::TcpStream;
61use tokio::process::{Child, Command};
62use tokio::sync::{broadcast, mpsc, oneshot};
63use tracing::{Instrument, debug, error, info, warn};
64pub use types::*;
65
66mod sdk_protocol_version;
67pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version};
68pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError};
69
/// Oldest server protocol version this SDK accepts; the newest accepted version
/// is `SDK_PROTOCOL_VERSION`. Enforced by `Client::verify_protocol_version`.
const MIN_PROTOCOL_VERSION: u32 = 2;
72
/// Errors returned by this crate.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// A transport/protocol-level failure; see [`ProtocolError`].
    #[error("protocol error: {0}")]
    Protocol(ProtocolError),

    /// The server answered a request with a JSON-RPC error object.
    #[error("RPC error {code}: {message}")]
    Rpc {
        /// JSON-RPC error code reported by the server.
        code: i32,
        /// Error message reported by the server.
        message: String,
    },

    /// A session-level failure; see [`SessionError`].
    #[error("session error: {0}")]
    Session(SessionError),

    /// An I/O error, e.g. while spawning the CLI or opening a TCP connection.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// A JSON serialization or deserialization error.
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// A required executable could not be located.
    #[error("binary not found: {name} ({hint})")]
    BinaryNotFound {
        /// Name of the missing binary.
        name: &'static str,
        /// Additional hint included in the error message.
        hint: &'static str,
    },

    /// The [`ClientOptions`] given to [`Client::start`] are contradictory or invalid.
    #[error("invalid client configuration: {0}")]
    InvalidConfig(String),
}
118
119impl Error {
120 pub fn is_transport_failure(&self) -> bool {
124 matches!(
125 self,
126 Error::Protocol(ProtocolError::RequestCancelled) | Error::Io(_)
127 )
128 }
129}
130
/// Every error collected while stopping a client.
///
/// `Display` summarizes the error count and the first error, and `source()`
/// returns the first error. Use [`StopErrors::errors`] or
/// [`StopErrors::into_errors`] to inspect all of them.
#[derive(Debug)]
pub struct StopErrors(Vec<Error>);
144
145impl StopErrors {
146 pub fn errors(&self) -> &[Error] {
149 &self.0
150 }
151
152 pub fn into_errors(self) -> Vec<Error> {
154 self.0
155 }
156}
157
158impl std::fmt::Display for StopErrors {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 match self.0.as_slice() {
161 [] => write!(f, "stop completed with no errors"),
162 [only] => write!(f, "stop failed: {only}"),
163 [first, rest @ ..] => write!(
164 f,
165 "stop failed with {n} errors; first: {first}",
166 n = 1 + rest.len(),
167 ),
168 }
169 }
170}
171
172impl std::error::Error for StopErrors {
173 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
174 self.0
175 .first()
176 .map(|e| e as &(dyn std::error::Error + 'static))
177 }
178}
179
/// Low-level protocol and CLI-startup failures.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ProtocolError {
    /// An incoming message lacked a `Content-Length` header.
    #[error("missing Content-Length header")]
    MissingContentLength,

    /// A `Content-Length` header could not be parsed (presumably holds the raw
    /// header value — confirm in the `jsonrpc` module).
    #[error("invalid Content-Length value: \"{0}\"")]
    InvalidContentLength(String),

    /// A request was cancelled. Classified as a transport failure by
    /// [`Error::is_transport_failure`].
    #[error("request cancelled")]
    RequestCancelled,

    /// A TCP-mode CLI did not print its listening port before the startup timeout.
    #[error("timed out waiting for CLI to report listening port")]
    CliStartupTimeout,

    /// A TCP-mode CLI's stdout ended before it printed its listening port.
    #[error("CLI exited before reporting listening port")]
    CliStartupFailed,

    /// The server's protocol version is outside the range this SDK supports.
    #[error("version mismatch: server={server}, supported={min}–{max}")]
    VersionMismatch {
        /// Version reported by the server.
        server: u32,
        /// Oldest version this SDK supports.
        min: u32,
        /// Newest version this SDK supports.
        max: u32,
    },

    /// A later handshake reported a different version than the one first negotiated.
    #[error("version changed: was {previous}, now {current}")]
    VersionChanged {
        /// Version negotiated earlier.
        previous: u32,
        /// Version reported now.
        current: u32,
    },
}
224
/// Session-level failures.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SessionError {
    /// The server reported that the session does not exist. `Client::call`
    /// uses the ID `"unknown"` when the request carried no `sessionId`.
    #[error("session not found: {0}")]
    NotFound(SessionId),

    /// An error reported by the agent; the string is shown verbatim.
    #[error("{0}")]
    AgentError(String),

    /// An operation did not complete within the given duration.
    #[error("timed out after {0:?}")]
    Timeout(std::time::Duration),

    /// A send was attempted while a `send_and_wait` was still in flight.
    #[error("cannot send while send_and_wait is in flight")]
    SendWhileWaiting,

    /// The session's event loop closed before the session became idle.
    #[error("event loop closed before session reached idle")]
    EventLoopClosed,

    /// The host does not support elicitation.
    #[error(
        "elicitation not supported by host — check session.capabilities().ui.elicitation first"
    )]
    ElicitationNotSupported,

    /// The client has `session_fs` configured but the session was given no
    /// `SessionFsProvider`.
    #[error(
        "session was created on a client with session_fs configured but no SessionFsProvider was supplied"
    )]
    SessionFsProviderRequired,

    /// The `SessionFsConfig` failed validation; holds the reason.
    #[error("invalid SessionFsConfig: {0}")]
    InvalidSessionFsConfig(String),

    /// The CLI created a session with a different ID than the SDK requested.
    #[error("CLI returned session ID {returned} after SDK registered {requested}")]
    SessionIdMismatch {
        /// ID the SDK registered.
        requested: SessionId,
        /// ID the CLI returned.
        returned: SessionId,
    },
}
280
/// How the client reaches the Copilot CLI server.
#[derive(Debug, Default)]
#[non_exhaustive]
pub enum Transport {
    /// Spawn the CLI with `--stdio` and speak JSON-RPC over its stdin/stdout.
    /// This is the default.
    #[default]
    Stdio,
    /// Spawn the CLI with `--port <port>`, wait for it to print the port it is
    /// listening on, and connect to `127.0.0.1` on that port.
    Tcp {
        /// Value passed to the CLI's `--port` flag.
        port: u16,
    },
    /// Connect to an already-running server; no CLI process is spawned.
    External {
        /// Host name or address of the server.
        host: String,
        /// TCP port of the server.
        port: u16,
    },
}
301
/// Which CLI executable to launch.
#[derive(Debug, Clone, Default)]
pub enum CliProgram {
    /// Locate the `copilot` CLI automatically (via `resolve::copilot_binary`).
    /// This is the default.
    #[default]
    Resolve,
    /// Launch the executable at this path.
    Path(PathBuf),
}
312
313impl From<PathBuf> for CliProgram {
314 fn from(path: PathBuf) -> Self {
315 Self::Path(path)
316 }
317}
318
/// Configuration for [`Client::start`].
///
/// Build with [`ClientOptions::new`] and the `with_*` methods.
#[non_exhaustive]
pub struct ClientOptions {
    /// CLI executable to launch when the transport spawns a process.
    pub program: CliProgram,
    /// Arguments placed before every SDK-generated argument on the CLI command line.
    pub prefix_args: Vec<OsString>,
    /// Working directory for the spawned CLI; also reported by [`Client::cwd`].
    pub cwd: PathBuf,
    /// Extra environment variables for the spawned CLI. They are applied after
    /// the SDK-managed variables, so they take precedence.
    pub env: Vec<(OsString, OsString)>,
    /// Environment variables removed from the spawned CLI (applied after `env`).
    pub env_remove: Vec<OsString>,
    /// Arguments appended after every SDK-generated argument.
    pub extra_args: Vec<String>,
    /// How to reach the CLI server.
    pub transport: Transport,
    /// GitHub token handed to the CLI through the `COPILOT_SDK_AUTH_TOKEN`
    /// environment variable (`--auth-token-env`). Rejected with
    /// [`Transport::External`].
    pub github_token: Option<String>,
    /// Controls CLI auto-login. When `None`, it defaults to `true` unless
    /// `github_token` is set; `false` passes `--no-auto-login`. `Some(true)` is
    /// rejected with [`Transport::External`].
    pub use_logged_in_user: Option<bool>,
    /// CLI `--log-level`; [`LogLevel::Info`] when `None`.
    pub log_level: Option<LogLevel>,
    /// Passed as `--session-idle-timeout` when set to a non-zero value.
    pub session_idle_timeout_seconds: Option<u64>,
    /// Optional custom [`ListModelsHandler`].
    pub on_list_models: Option<Arc<dyn ListModelsHandler>>,
    /// Session filesystem configuration. It is validated and registered with
    /// the server during [`Client::start`].
    pub session_fs: Option<SessionFsConfig>,
    /// Optional trace-context provider.
    pub on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
    /// OpenTelemetry settings, exported to the spawned CLI as environment variables.
    pub telemetry: Option<TelemetryConfig>,
    /// Exported to the spawned CLI as `COPILOT_HOME`.
    pub copilot_home: Option<PathBuf>,
    /// Connection token exported as `COPILOT_CONNECTION_TOKEN`.
    ///
    /// Must be non-empty and is rejected with [`Transport::Stdio`]. For
    /// [`Transport::Tcp`], a random token is generated when this is `None`.
    pub tcp_connection_token: Option<String>,
    /// Passes `--remote` to the spawned CLI when `true`.
    pub remote: bool,
}
413
414impl std::fmt::Debug for ClientOptions {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("ClientOptions")
417 .field("program", &self.program)
418 .field("prefix_args", &self.prefix_args)
419 .field("cwd", &self.cwd)
420 .field("env", &self.env)
421 .field("env_remove", &self.env_remove)
422 .field("extra_args", &self.extra_args)
423 .field("transport", &self.transport)
424 .field(
425 "github_token",
426 &self.github_token.as_ref().map(|_| "<redacted>"),
427 )
428 .field("use_logged_in_user", &self.use_logged_in_user)
429 .field("log_level", &self.log_level)
430 .field(
431 "session_idle_timeout_seconds",
432 &self.session_idle_timeout_seconds,
433 )
434 .field(
435 "on_list_models",
436 &self.on_list_models.as_ref().map(|_| "<set>"),
437 )
438 .field("session_fs", &self.session_fs)
439 .field(
440 "on_get_trace_context",
441 &self.on_get_trace_context.as_ref().map(|_| "<set>"),
442 )
443 .field("telemetry", &self.telemetry)
444 .field("copilot_home", &self.copilot_home)
445 .field(
446 "tcp_connection_token",
447 &self.tcp_connection_token.as_ref().map(|_| "<redacted>"),
448 )
449 .field("remote", &self.remote)
450 .finish()
451 }
452}
453
/// Custom source for the list of available models.
///
/// Install it with [`ClientOptions::with_list_models_handler`].
#[async_trait]
pub trait ListModelsHandler: Send + Sync + 'static {
    /// Returns the available models.
    async fn list_models(&self) -> Result<Vec<Model>, Error>;
}
467
/// Log verbosity passed to the CLI through `--log-level` (serialized lowercase).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    /// `none`
    None,
    /// `error`
    Error,
    /// `warning`
    Warning,
    /// `info` — the level used when [`ClientOptions::log_level`] is unset.
    Info,
    /// `debug`
    Debug,
    /// `all`
    All,
}
485
486impl LogLevel {
487 pub fn as_str(self) -> &'static str {
489 match self {
490 Self::None => "none",
491 Self::Error => "error",
492 Self::Warning => "warning",
493 Self::Info => "info",
494 Self::Debug => "debug",
495 Self::All => "all",
496 }
497 }
498}
499
500impl std::fmt::Display for LogLevel {
501 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
502 f.write_str(self.as_str())
503 }
504}
505
/// OpenTelemetry exporter selection (serialized in kebab-case).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum OtelExporterType {
    /// OTLP over HTTP (`otlp-http`).
    OtlpHttp,
    /// File exporter (`file`).
    File,
}
521
522impl OtelExporterType {
523 pub fn as_str(self) -> &'static str {
525 match self {
526 Self::OtlpHttp => "otlp-http",
527 Self::File => "file",
528 }
529 }
530}
531
/// OpenTelemetry settings for the spawned CLI.
///
/// When set on [`ClientOptions`], the CLI is started with
/// `COPILOT_OTEL_ENABLED=true` plus one environment variable per populated field.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TelemetryConfig {
    /// Exported as `OTEL_EXPORTER_OTLP_ENDPOINT`.
    pub otlp_endpoint: Option<String>,
    /// Exported as `COPILOT_OTEL_FILE_EXPORTER_PATH`.
    pub file_path: Option<PathBuf>,
    /// Exported as `COPILOT_OTEL_EXPORTER_TYPE`.
    pub exporter_type: Option<OtelExporterType>,
    /// Exported as `COPILOT_OTEL_SOURCE_NAME`.
    pub source_name: Option<String>,
    /// Exported as `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT` (`true`/`false`).
    pub capture_content: Option<bool>,
}
583
584impl TelemetryConfig {
585 pub fn new() -> Self {
588 Self::default()
589 }
590
591 pub fn with_otlp_endpoint(mut self, endpoint: impl Into<String>) -> Self {
593 self.otlp_endpoint = Some(endpoint.into());
594 self
595 }
596
597 pub fn with_file_path(mut self, path: impl Into<PathBuf>) -> Self {
599 self.file_path = Some(path.into());
600 self
601 }
602
603 pub fn with_exporter_type(mut self, exporter_type: OtelExporterType) -> Self {
605 self.exporter_type = Some(exporter_type);
606 self
607 }
608
609 pub fn with_source_name(mut self, source_name: impl Into<String>) -> Self {
613 self.source_name = Some(source_name.into());
614 self
615 }
616
617 pub fn with_capture_content(mut self, capture: bool) -> Self {
621 self.capture_content = Some(capture);
622 self
623 }
624
625 pub fn is_empty(&self) -> bool {
628 self.otlp_endpoint.is_none()
629 && self.file_path.is_none()
630 && self.exporter_type.is_none()
631 && self.source_name.is_none()
632 && self.capture_content.is_none()
633 }
634}
635
636impl Default for ClientOptions {
637 fn default() -> Self {
638 Self {
639 program: CliProgram::Resolve,
640 prefix_args: Vec::new(),
641 cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
642 env: Vec::new(),
643 env_remove: Vec::new(),
644 extra_args: Vec::new(),
645 transport: Transport::default(),
646 github_token: None,
647 use_logged_in_user: None,
648 log_level: None,
649 session_idle_timeout_seconds: None,
650 on_list_models: None,
651 session_fs: None,
652 on_get_trace_context: None,
653 telemetry: None,
654 copilot_home: None,
655 tcp_connection_token: None,
656 remote: false,
657 }
658 }
659}
660
661impl ClientOptions {
662 pub fn new() -> Self {
678 Self::default()
679 }
680
681 pub fn with_program(mut self, program: impl Into<CliProgram>) -> Self {
683 self.program = program.into();
684 self
685 }
686
687 pub fn with_prefix_args<I, S>(mut self, args: I) -> Self
689 where
690 I: IntoIterator<Item = S>,
691 S: Into<OsString>,
692 {
693 self.prefix_args = args.into_iter().map(Into::into).collect();
694 self
695 }
696
697 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
699 self.cwd = cwd.into();
700 self
701 }
702
703 pub fn with_env<I, K, V>(mut self, env: I) -> Self
705 where
706 I: IntoIterator<Item = (K, V)>,
707 K: Into<OsString>,
708 V: Into<OsString>,
709 {
710 self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
711 self
712 }
713
714 pub fn with_env_remove<I, S>(mut self, names: I) -> Self
716 where
717 I: IntoIterator<Item = S>,
718 S: Into<OsString>,
719 {
720 self.env_remove = names.into_iter().map(Into::into).collect();
721 self
722 }
723
724 pub fn with_extra_args<I, S>(mut self, args: I) -> Self
726 where
727 I: IntoIterator<Item = S>,
728 S: Into<String>,
729 {
730 self.extra_args = args.into_iter().map(Into::into).collect();
731 self
732 }
733
734 pub fn with_transport(mut self, transport: Transport) -> Self {
736 self.transport = transport;
737 self
738 }
739
740 pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
743 self.github_token = Some(token.into());
744 self
745 }
746
747 pub fn with_use_logged_in_user(mut self, use_logged_in: bool) -> Self {
750 self.use_logged_in_user = Some(use_logged_in);
751 self
752 }
753
754 pub fn with_log_level(mut self, level: LogLevel) -> Self {
756 self.log_level = Some(level);
757 self
758 }
759
760 pub fn with_session_idle_timeout_seconds(mut self, seconds: u64) -> Self {
763 self.session_idle_timeout_seconds = Some(seconds);
764 self
765 }
766
767 pub fn with_list_models_handler<H>(mut self, handler: H) -> Self
770 where
771 H: ListModelsHandler + 'static,
772 {
773 self.on_list_models = Some(Arc::new(handler));
774 self
775 }
776
777 pub fn with_session_fs(mut self, config: SessionFsConfig) -> Self {
779 self.session_fs = Some(config);
780 self
781 }
782
783 pub fn with_trace_context_provider<P>(mut self, provider: P) -> Self
787 where
788 P: TraceContextProvider + 'static,
789 {
790 self.on_get_trace_context = Some(Arc::new(provider));
791 self
792 }
793
794 pub fn with_telemetry(mut self, config: TelemetryConfig) -> Self {
796 self.telemetry = Some(config);
797 self
798 }
799
800 pub fn with_copilot_home(mut self, home: impl Into<PathBuf>) -> Self {
803 self.copilot_home = Some(home.into());
804 self
805 }
806
807 pub fn with_tcp_connection_token(mut self, token: impl Into<String>) -> Self {
811 self.tcp_connection_token = Some(token.into());
812 self
813 }
814
815 pub fn with_remote(mut self, enabled: bool) -> Self {
818 self.remote = enabled;
819 self
820 }
821}
822
823fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<(), Error> {
825 if cfg.initial_cwd.trim().is_empty() {
826 return Err(Error::Session(SessionError::InvalidSessionFsConfig(
827 "initial_cwd must not be empty".to_string(),
828 )));
829 }
830 if cfg.session_state_path.trim().is_empty() {
831 return Err(Error::Session(SessionError::InvalidSessionFsConfig(
832 "session_state_path must not be empty".to_string(),
833 )));
834 }
835 Ok(())
836}
837
838fn generate_connection_token() -> String {
845 let mut bytes = [0u8; 16];
846 getrandom::getrandom(&mut bytes)
847 .expect("OS CSPRNG (getrandom) is unavailable; cannot generate connection token");
848 let mut hex = String::with_capacity(32);
849 for byte in bytes {
850 use std::fmt::Write;
851 let _ = write!(hex, "{byte:02x}");
852 }
853 hex
854}
855
/// Handle to a Copilot CLI server connection.
///
/// Cloning is cheap: every clone shares the same `Arc`-backed connection state.
#[derive(Clone)]
pub struct Client {
    inner: Arc<ClientInner>,
}
864
865impl std::fmt::Debug for Client {
866 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
867 f.debug_struct("Client")
868 .field("cwd", &self.inner.cwd)
869 .field("pid", &self.pid())
870 .finish()
871 }
872}
873
/// State shared by every clone of a [`Client`].
struct ClientInner {
    /// The spawned CLI process; `None` for external servers and caller-supplied streams.
    child: parking_lot::Mutex<Option<Child>>,
    /// JSON-RPC connection used for all requests and responses.
    rpc: JsonRpcClient,
    /// Working directory reported by `Client::cwd`.
    cwd: PathBuf,
    /// Receiver of server-to-client requests. Passed to the session router's
    /// `ensure_started` or taken via `take_request_rx`.
    request_rx: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<JsonRpcRequest>>>,
    /// Broadcast of incoming notifications. Subscribed to by the lifecycle
    /// dispatcher and passed to the session router.
    notification_tx: broadcast::Sender<JsonRpcNotification>,
    /// Per-session routing (presumably dispatches traffic to registered sessions — confirm in `router`).
    router: router::SessionRouter,
    /// Protocol version recorded by the first successful `verify_protocol_version`.
    negotiated_protocol_version: OnceLock<u32>,
    /// Connection state; starts as `ConnectionState::Connected`.
    state: parking_lot::Mutex<ConnectionState>,
    /// Broadcast of decoded `session.lifecycle` notifications.
    lifecycle_tx: broadcast::Sender<SessionLifecycleEvent>,
    /// Custom model-list handler from `ClientOptions`.
    on_list_models: Option<Arc<dyn ListModelsHandler>>,
    /// Cache cell for the model list (presumably filled on first use — confirm at call site).
    models_cache: parking_lot::Mutex<Arc<tokio::sync::OnceCell<Vec<Model>>>>,
    /// Whether `ClientOptions::session_fs` was set when the client started.
    session_fs_configured: bool,
    /// Trace-context provider from `ClientOptions`.
    on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
    /// Connection token actually in use (caller-supplied, or generated for TCP).
    effective_connection_token: Option<String>,
}
894
895impl Client {
896 pub async fn start(options: ClientOptions) -> Result<Self, Error> {
909 let start_time = Instant::now();
910 if let Some(cfg) = &options.session_fs {
911 validate_session_fs_config(cfg)?;
912 }
913 if matches!(options.transport, Transport::External { .. }) {
916 if options.github_token.is_some() {
917 return Err(Error::InvalidConfig(
918 "github_token cannot be used with Transport::External \
919 (external server manages its own auth)"
920 .to_string(),
921 ));
922 }
923 if options.use_logged_in_user == Some(true) {
924 return Err(Error::InvalidConfig(
925 "use_logged_in_user cannot be used with Transport::External \
926 (external server manages its own auth)"
927 .to_string(),
928 ));
929 }
930 }
931 if let Some(token) = &options.tcp_connection_token {
935 if token.is_empty() {
936 return Err(Error::InvalidConfig(
937 "tcp_connection_token must be a non-empty string".to_string(),
938 ));
939 }
940 if matches!(options.transport, Transport::Stdio) {
941 return Err(Error::InvalidConfig(
942 "tcp_connection_token cannot be used with Transport::Stdio".to_string(),
943 ));
944 }
945 }
946 let effective_connection_token: Option<String> = match &options.transport {
947 Transport::Stdio => None,
948 Transport::Tcp { .. } => Some(
949 options
950 .tcp_connection_token
951 .clone()
952 .unwrap_or_else(generate_connection_token),
953 ),
954 Transport::External { .. } => options.tcp_connection_token.clone(),
955 };
956 let mut options = options;
957 if matches!(options.transport, Transport::Tcp { .. })
958 && options.tcp_connection_token.is_none()
959 {
960 options.tcp_connection_token = effective_connection_token.clone();
963 }
964 let session_fs_config = options.session_fs.clone();
965 let program = match &options.program {
966 CliProgram::Path(path) => {
967 info!(path = %path.display(), "using explicit copilot CLI path");
968 path.clone()
969 }
970 CliProgram::Resolve => {
971 let resolved = resolve::copilot_binary()?;
972 info!(path = %resolved.display(), "resolved copilot CLI");
973 #[cfg(windows)]
974 {
975 if let Some(ext) = resolved.extension().and_then(|e| e.to_str()).filter(|ext| {
976 ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat")
977 }) {
978 warn!(
979 path = %resolved.display(),
980 ext = %ext,
981 "resolved copilot CLI is a .cmd/.bat wrapper; \
982 this may cause console window flashes on Windows"
983 );
984 }
985 }
986 resolved
987 }
988 };
989
990 let client = match options.transport {
991 Transport::External { ref host, port } => {
992 info!(host = %host, port = %port, "connecting to external CLI server");
993 let connect_start = Instant::now();
994 let stream = TcpStream::connect((host.as_str(), port)).await?;
995 debug!(
996 elapsed_ms = connect_start.elapsed().as_millis(),
997 host = %host,
998 port,
999 "Client::start TCP connect complete"
1000 );
1001 let (reader, writer) = tokio::io::split(stream);
1002 Self::from_transport(
1003 reader,
1004 writer,
1005 None,
1006 options.cwd,
1007 options.on_list_models,
1008 session_fs_config.is_some(),
1009 options.on_get_trace_context,
1010 effective_connection_token.clone(),
1011 )?
1012 }
1013 Transport::Tcp { port } => {
1014 let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?;
1015 let connect_start = Instant::now();
1016 let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?;
1017 debug!(
1018 elapsed_ms = connect_start.elapsed().as_millis(),
1019 port = actual_port,
1020 "Client::start TCP connect complete"
1021 );
1022 let (reader, writer) = tokio::io::split(stream);
1023 Self::drain_stderr(&mut child);
1024 Self::from_transport(
1025 reader,
1026 writer,
1027 Some(child),
1028 options.cwd,
1029 options.on_list_models,
1030 session_fs_config.is_some(),
1031 options.on_get_trace_context,
1032 effective_connection_token.clone(),
1033 )?
1034 }
1035 Transport::Stdio => {
1036 let mut child = Self::spawn_stdio(&program, &options)?;
1037 let stdin = child.stdin.take().expect("stdin is piped");
1038 let stdout = child.stdout.take().expect("stdout is piped");
1039 Self::drain_stderr(&mut child);
1040 Self::from_transport(
1041 stdout,
1042 stdin,
1043 Some(child),
1044 options.cwd,
1045 options.on_list_models,
1046 session_fs_config.is_some(),
1047 options.on_get_trace_context,
1048 effective_connection_token.clone(),
1049 )?
1050 }
1051 };
1052
1053 debug!(
1054 elapsed_ms = start_time.elapsed().as_millis(),
1055 "Client::start transport setup complete"
1056 );
1057 client.verify_protocol_version().await?;
1058 debug!(
1059 elapsed_ms = start_time.elapsed().as_millis(),
1060 "Client::start protocol verification complete"
1061 );
1062 if let Some(cfg) = session_fs_config {
1063 let session_fs_start = Instant::now();
1064 let request = crate::generated::api_types::SessionFsSetProviderRequest {
1065 conventions: cfg.conventions.into_wire(),
1066 initial_cwd: cfg.initial_cwd,
1067 session_state_path: cfg.session_state_path,
1068 };
1069 client.rpc().session_fs().set_provider(request).await?;
1070 debug!(
1071 elapsed_ms = session_fs_start.elapsed().as_millis(),
1072 "Client::start session filesystem setup complete"
1073 );
1074 }
1075 debug!(
1076 elapsed_ms = start_time.elapsed().as_millis(),
1077 "Client::start complete"
1078 );
1079 Ok(client)
1080 }
1081
1082 pub fn from_streams(
1086 reader: impl AsyncRead + Unpin + Send + 'static,
1087 writer: impl AsyncWrite + Unpin + Send + 'static,
1088 cwd: PathBuf,
1089 ) -> Result<Self, Error> {
1090 Self::from_transport(reader, writer, None, cwd, None, false, None, None)
1091 }
1092
1093 #[cfg(any(test, feature = "test-support"))]
1101 pub fn from_streams_with_trace_provider(
1102 reader: impl AsyncRead + Unpin + Send + 'static,
1103 writer: impl AsyncWrite + Unpin + Send + 'static,
1104 cwd: PathBuf,
1105 provider: Arc<dyn TraceContextProvider>,
1106 ) -> Result<Self, Error> {
1107 Self::from_transport(reader, writer, None, cwd, None, false, Some(provider), None)
1108 }
1109
1110 #[cfg(any(test, feature = "test-support"))]
1114 pub fn from_streams_with_connection_token(
1115 reader: impl AsyncRead + Unpin + Send + 'static,
1116 writer: impl AsyncWrite + Unpin + Send + 'static,
1117 cwd: PathBuf,
1118 token: Option<String>,
1119 ) -> Result<Self, Error> {
1120 Self::from_transport(reader, writer, None, cwd, None, false, None, token)
1121 }
1122
1123 #[cfg(any(test, feature = "test-support"))]
1129 pub fn generate_connection_token_for_test() -> String {
1130 generate_connection_token()
1131 }
1132
1133 #[allow(clippy::too_many_arguments)]
1134 fn from_transport(
1135 reader: impl AsyncRead + Unpin + Send + 'static,
1136 writer: impl AsyncWrite + Unpin + Send + 'static,
1137 child: Option<Child>,
1138 cwd: PathBuf,
1139 on_list_models: Option<Arc<dyn ListModelsHandler>>,
1140 session_fs_configured: bool,
1141 on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
1142 effective_connection_token: Option<String>,
1143 ) -> Result<Self, Error> {
1144 let setup_start = Instant::now();
1145 let (request_tx, request_rx) = mpsc::unbounded_channel::<JsonRpcRequest>();
1146 let (notification_broadcast_tx, _) = broadcast::channel::<JsonRpcNotification>(1024);
1147 let rpc = JsonRpcClient::new(
1148 writer,
1149 reader,
1150 notification_broadcast_tx.clone(),
1151 request_tx,
1152 );
1153
1154 let pid = child.as_ref().and_then(|c| c.id());
1155 info!(pid = ?pid, "copilot CLI client ready");
1156
1157 let client = Self {
1158 inner: Arc::new(ClientInner {
1159 child: parking_lot::Mutex::new(child),
1160 rpc,
1161 cwd,
1162 request_rx: parking_lot::Mutex::new(Some(request_rx)),
1163 notification_tx: notification_broadcast_tx,
1164 router: router::SessionRouter::new(),
1165 negotiated_protocol_version: OnceLock::new(),
1166 state: parking_lot::Mutex::new(ConnectionState::Connected),
1167 lifecycle_tx: broadcast::channel(256).0,
1168 on_list_models,
1169 models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
1170 session_fs_configured,
1171 on_get_trace_context,
1172 effective_connection_token,
1173 }),
1174 };
1175 client.spawn_lifecycle_dispatcher();
1176 debug!(
1177 elapsed_ms = setup_start.elapsed().as_millis(),
1178 pid = ?pid,
1179 "Client::from_transport setup complete"
1180 );
1181 Ok(client)
1182 }
1183
1184 fn spawn_lifecycle_dispatcher(&self) {
1188 let inner = Arc::clone(&self.inner);
1189 let mut notif_rx = inner.notification_tx.subscribe();
1190 tokio::spawn(async move {
1191 loop {
1192 match notif_rx.recv().await {
1193 Ok(notification) => {
1194 if notification.method != "session.lifecycle" {
1195 continue;
1196 }
1197 let Some(params) = notification.params.as_ref() else {
1198 continue;
1199 };
1200 let event: SessionLifecycleEvent =
1201 match serde_json::from_value(params.clone()) {
1202 Ok(e) => e,
1203 Err(e) => {
1204 warn!(
1205 error = %e,
1206 "failed to deserialize session.lifecycle notification"
1207 );
1208 continue;
1209 }
1210 };
1211 let _ = inner.lifecycle_tx.send(event);
1214 }
1215 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
1216 warn!(missed = n, "lifecycle dispatcher lagged");
1217 }
1218 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
1219 }
1220 }
1221 });
1222 }
1223
1224 fn build_command(program: &Path, options: &ClientOptions) -> Command {
1225 let mut command = Command::new(program);
1226 for arg in &options.prefix_args {
1227 command.arg(arg);
1228 }
1229 if let Some(token) = &options.github_token {
1232 command.env("COPILOT_SDK_AUTH_TOKEN", token);
1233 }
1234 if let Some(telemetry) = &options.telemetry {
1237 command.env("COPILOT_OTEL_ENABLED", "true");
1238 if let Some(endpoint) = &telemetry.otlp_endpoint {
1239 command.env("OTEL_EXPORTER_OTLP_ENDPOINT", endpoint);
1240 }
1241 if let Some(path) = &telemetry.file_path {
1242 command.env("COPILOT_OTEL_FILE_EXPORTER_PATH", path);
1243 }
1244 if let Some(exporter) = telemetry.exporter_type {
1245 command.env("COPILOT_OTEL_EXPORTER_TYPE", exporter.as_str());
1246 }
1247 if let Some(source) = &telemetry.source_name {
1248 command.env("COPILOT_OTEL_SOURCE_NAME", source);
1249 }
1250 if let Some(capture) = telemetry.capture_content {
1251 command.env(
1252 "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
1253 if capture { "true" } else { "false" },
1254 );
1255 }
1256 }
1257 if let Some(home) = &options.copilot_home {
1258 command.env("COPILOT_HOME", home);
1259 }
1260 if let Some(token) = &options.tcp_connection_token {
1261 command.env("COPILOT_CONNECTION_TOKEN", token);
1262 }
1263 for (key, value) in &options.env {
1264 command.env(key, value);
1265 }
1266 for key in &options.env_remove {
1267 command.env_remove(key);
1268 }
1269 command
1270 .current_dir(&options.cwd)
1271 .stdout(Stdio::piped())
1272 .stderr(Stdio::piped());
1273
1274 #[cfg(windows)]
1275 {
1276 use std::os::windows::process::CommandExt;
1277 const CREATE_NO_WINDOW: u32 = 0x08000000;
1278 command.as_std_mut().creation_flags(CREATE_NO_WINDOW);
1279 }
1280
1281 command
1282 }
1283
1284 fn auth_args(options: &ClientOptions) -> Vec<&'static str> {
1292 let mut args: Vec<&'static str> = Vec::new();
1293 if options.github_token.is_some() {
1294 args.push("--auth-token-env");
1295 args.push("COPILOT_SDK_AUTH_TOKEN");
1296 }
1297 let use_logged_in = options
1298 .use_logged_in_user
1299 .unwrap_or(options.github_token.is_none());
1300 if !use_logged_in {
1301 args.push("--no-auto-login");
1302 }
1303 args
1304 }
1305
1306 fn session_idle_timeout_args(options: &ClientOptions) -> Vec<String> {
1310 match options.session_idle_timeout_seconds {
1311 Some(secs) if secs > 0 => {
1312 vec!["--session-idle-timeout".to_string(), secs.to_string()]
1313 }
1314 _ => Vec::new(),
1315 }
1316 }
1317
1318 fn remote_args(options: &ClientOptions) -> Vec<String> {
1319 if options.remote {
1320 vec!["--remote".to_string()]
1321 } else {
1322 Vec::new()
1323 }
1324 }
1325
1326 fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result<Child, Error> {
1327 info!(cwd = ?options.cwd, program = %program.display(), "spawning copilot CLI (stdio)");
1328 let mut command = Self::build_command(program, options);
1329 let log_level = options.log_level.unwrap_or(LogLevel::Info);
1330 command
1331 .args([
1332 "--server",
1333 "--stdio",
1334 "--no-auto-update",
1335 "--log-level",
1336 log_level.as_str(),
1337 ])
1338 .args(Self::auth_args(options))
1339 .args(Self::session_idle_timeout_args(options))
1340 .args(Self::remote_args(options))
1341 .args(&options.extra_args)
1342 .stdin(Stdio::piped());
1343 let spawn_start = Instant::now();
1344 let child = command.spawn()?;
1345 debug!(
1346 elapsed_ms = spawn_start.elapsed().as_millis(),
1347 "Client::spawn_stdio subprocess spawned"
1348 );
1349 Ok(child)
1350 }
1351
1352 async fn spawn_tcp(
1353 program: &Path,
1354 options: &ClientOptions,
1355 port: u16,
1356 ) -> Result<(Child, u16), Error> {
1357 info!(cwd = ?options.cwd, program = %program.display(), port = %port, "spawning copilot CLI (tcp)");
1358 let mut command = Self::build_command(program, options);
1359 let log_level = options.log_level.unwrap_or(LogLevel::Info);
1360 command
1361 .args([
1362 "--server",
1363 "--port",
1364 &port.to_string(),
1365 "--no-auto-update",
1366 "--log-level",
1367 log_level.as_str(),
1368 ])
1369 .args(Self::auth_args(options))
1370 .args(Self::session_idle_timeout_args(options))
1371 .args(Self::remote_args(options))
1372 .args(&options.extra_args)
1373 .stdin(Stdio::null());
1374 let spawn_start = Instant::now();
1375 let mut child = command.spawn()?;
1376 debug!(
1377 elapsed_ms = spawn_start.elapsed().as_millis(),
1378 "Client::spawn_tcp subprocess spawned"
1379 );
1380 let stdout = child.stdout.take().expect("stdout is piped");
1381
1382 let (port_tx, port_rx) = oneshot::channel::<u16>();
1383 let span = tracing::error_span!("copilot_cli_port_scan");
1384 tokio::spawn(
1385 async move {
1386 let port_re = regex::Regex::new(r"listening on port (\d+)").expect("valid regex");
1388 let mut lines = BufReader::new(stdout).lines();
1389 let mut port_tx = Some(port_tx);
1390 while let Ok(Some(line)) = lines.next_line().await {
1391 debug!(line = %line, "CLI stdout");
1392 if let Some(tx) = port_tx.take() {
1393 if let Some(caps) = port_re.captures(&line)
1394 && let Some(p) =
1395 caps.get(1).and_then(|m| m.as_str().parse::<u16>().ok())
1396 {
1397 let _ = tx.send(p);
1398 continue;
1399 }
1400 port_tx = Some(tx);
1402 }
1403 }
1404 }
1405 .instrument(span),
1406 );
1407
1408 let port_wait_start = Instant::now();
1409 let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx)
1410 .await
1411 .map_err(|_| Error::Protocol(ProtocolError::CliStartupTimeout))?
1412 .map_err(|_| Error::Protocol(ProtocolError::CliStartupFailed))?;
1413
1414 debug!(
1415 elapsed_ms = port_wait_start.elapsed().as_millis(),
1416 port = actual_port,
1417 "Client::spawn_tcp TCP port wait complete"
1418 );
1419 info!(port = %actual_port, "CLI server listening");
1420 Ok((child, actual_port))
1421 }
1422
1423 fn drain_stderr(child: &mut Child) {
1424 if let Some(stderr) = child.stderr.take() {
1425 let span = tracing::error_span!("copilot_cli");
1426 tokio::spawn(
1427 async move {
1428 let mut reader = BufReader::new(stderr).lines();
1429 while let Ok(Some(line)) = reader.next_line().await {
1430 warn!(line = %line, "CLI stderr");
1431 }
1432 }
1433 .instrument(span),
1434 );
1435 }
1436 }
1437
1438 pub fn cwd(&self) -> &PathBuf {
1440 &self.inner.cwd
1441 }
1442
1443 pub fn rpc(&self) -> crate::generated::rpc::ClientRpc<'_> {
1454 crate::generated::rpc::ClientRpc { client: self }
1455 }
1456
1457 pub(crate) async fn send_request(
1459 &self,
1460 method: &str,
1461 params: Option<serde_json::Value>,
1462 ) -> Result<JsonRpcResponse, Error> {
1463 self.inner.rpc.send_request(method, params).await
1464 }
1465
1466 pub async fn call(
1486 &self,
1487 method: &str,
1488 params: Option<serde_json::Value>,
1489 ) -> Result<serde_json::Value, Error> {
1490 let session_id: Option<SessionId> = params
1491 .as_ref()
1492 .and_then(|p| p.get("sessionId"))
1493 .and_then(|v| v.as_str())
1494 .map(SessionId::from);
1495 let response = self.send_request(method, params).await?;
1496 if let Some(err) = response.error {
1497 if err.message.contains("Session not found") {
1498 return Err(Error::Session(SessionError::NotFound(
1499 session_id.unwrap_or_else(|| "unknown".into()),
1500 )));
1501 }
1502 return Err(Error::Rpc {
1503 code: err.code,
1504 message: err.message,
1505 });
1506 }
1507 Ok(response.result.unwrap_or(serde_json::Value::Null))
1508 }
1509
1510 pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<(), Error> {
1512 self.inner.rpc.write(response).await
1513 }
1514
1515 #[expect(dead_code, reason = "reserved for future pub(crate) use")]
1519 pub(crate) fn take_request_rx(&self) -> Option<mpsc::UnboundedReceiver<JsonRpcRequest>> {
1520 self.inner.request_rx.lock().take()
1521 }
1522
1523 pub(crate) fn register_session(
1531 &self,
1532 session_id: &SessionId,
1533 ) -> crate::router::SessionChannels {
1534 self.inner
1535 .router
1536 .ensure_started(&self.inner.notification_tx, &self.inner.request_rx);
1537 self.inner.router.register(session_id)
1538 }
1539
1540 pub(crate) fn unregister_session(&self, session_id: &SessionId) {
1542 self.inner.router.unregister(session_id);
1543 }
1544
1545 pub fn protocol_version(&self) -> Option<u32> {
1552 self.inner.negotiated_protocol_version.get().copied()
1553 }
1554
1555 pub async fn verify_protocol_version(&self) -> Result<(), Error> {
1579 let handshake_start = Instant::now();
1580 let mut used_fallback_ping = false;
1581 let server_version = match self.connect_handshake().await {
1585 Ok(v) => v,
1586 Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => {
1587 used_fallback_ping = true;
1588 self.ping(None).await?.protocol_version
1589 }
1590 Err(e) => return Err(e),
1591 };
1592
1593 match server_version {
1594 None => {
1595 warn!("CLI server did not report protocolVersion; skipping version check");
1596 }
1597 Some(v) if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&v) => {
1598 return Err(Error::Protocol(ProtocolError::VersionMismatch {
1599 server: v,
1600 min: MIN_PROTOCOL_VERSION,
1601 max: SDK_PROTOCOL_VERSION,
1602 }));
1603 }
1604 Some(v) => {
1605 if let Some(&existing) = self.inner.negotiated_protocol_version.get() {
1606 if existing != v {
1607 return Err(Error::Protocol(ProtocolError::VersionChanged {
1608 previous: existing,
1609 current: v,
1610 }));
1611 }
1612 } else {
1613 let _ = self.inner.negotiated_protocol_version.set(v);
1614 }
1615 }
1616 }
1617
1618 debug!(
1619 elapsed_ms = handshake_start.elapsed().as_millis(),
1620 protocol_version = ?server_version,
1621 used_fallback_ping,
1622 "Client::verify_protocol_version protocol handshake complete"
1623 );
1624 Ok(())
1625 }
1626
1627 async fn connect_handshake(&self) -> Result<Option<u32>, Error> {
1634 let result = self
1635 .rpc()
1636 .connect(crate::generated::api_types::ConnectRequest {
1637 token: self.inner.effective_connection_token.clone(),
1638 })
1639 .await?;
1640 Ok(u32::try_from(result.protocol_version).ok())
1641 }
1642
1643 pub async fn ping(&self, message: Option<&str>) -> Result<crate::types::PingResponse, Error> {
1651 let params = match message {
1652 Some(m) => serde_json::json!({ "message": m }),
1653 None => serde_json::json!({}),
1654 };
1655 let value = self
1656 .call(generated::api_types::rpc_methods::PING, Some(params))
1657 .await?;
1658 Ok(serde_json::from_value(value)?)
1659 }
1660
1661 pub async fn list_sessions(
1664 &self,
1665 filter: Option<SessionListFilter>,
1666 ) -> Result<Vec<SessionMetadata>, Error> {
1667 let params = match filter {
1668 Some(f) => serde_json::json!({ "filter": f }),
1669 None => serde_json::json!({}),
1670 };
1671 let result = self.call("session.list", Some(params)).await?;
1672 let response: ListSessionsResponse = serde_json::from_value(result)?;
1673 Ok(response.sessions)
1674 }
1675
1676 pub async fn get_session_metadata(
1694 &self,
1695 session_id: &SessionId,
1696 ) -> Result<Option<SessionMetadata>, Error> {
1697 let result = self
1698 .call(
1699 "session.getMetadata",
1700 Some(serde_json::json!({ "sessionId": session_id })),
1701 )
1702 .await?;
1703 let response: GetSessionMetadataResponse = serde_json::from_value(result)?;
1704 Ok(response.session)
1705 }
1706
1707 pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), Error> {
1709 self.call(
1710 "session.delete",
1711 Some(serde_json::json!({ "sessionId": session_id })),
1712 )
1713 .await?;
1714 Ok(())
1715 }
1716
1717 pub async fn get_last_session_id(&self) -> Result<Option<SessionId>, Error> {
1733 let result = self
1734 .call("session.getLastId", Some(serde_json::json!({})))
1735 .await?;
1736 let response: GetLastSessionIdResponse = serde_json::from_value(result)?;
1737 Ok(response.session_id)
1738 }
1739
1740 pub async fn get_foreground_session_id(&self) -> Result<Option<SessionId>, Error> {
1745 let result = self
1746 .call("session.getForeground", Some(serde_json::json!({})))
1747 .await?;
1748 let response: GetForegroundSessionResponse = serde_json::from_value(result)?;
1749 Ok(response.session_id)
1750 }
1751
1752 pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<(), Error> {
1757 self.call(
1758 "session.setForeground",
1759 Some(serde_json::json!({ "sessionId": session_id })),
1760 )
1761 .await?;
1762 Ok(())
1763 }
1764
1765 pub async fn get_status(&self) -> Result<GetStatusResponse, Error> {
1767 let result = self.call("status.get", Some(serde_json::json!({}))).await?;
1768 Ok(serde_json::from_value(result)?)
1769 }
1770
1771 pub async fn get_auth_status(&self) -> Result<GetAuthStatusResponse, Error> {
1773 let result = self
1774 .call("auth.getStatus", Some(serde_json::json!({})))
1775 .await?;
1776 Ok(serde_json::from_value(result)?)
1777 }
1778
1779 pub async fn list_models(&self) -> Result<Vec<Model>, Error> {
1784 let cache = self.inner.models_cache.lock().clone();
1785 let models = cache
1786 .get_or_try_init(|| async {
1787 if let Some(handler) = &self.inner.on_list_models {
1788 handler.list_models().await
1789 } else {
1790 Ok(self.rpc().models().list().await?.models)
1791 }
1792 })
1793 .await?;
1794 Ok(models.clone())
1795 }
1796
1797 pub(crate) async fn resolve_trace_context(&self) -> TraceContext {
1800 if let Some(provider) = &self.inner.on_get_trace_context {
1801 provider.get_trace_context().await
1802 } else {
1803 TraceContext::default()
1804 }
1805 }
1806
1807 pub fn pid(&self) -> Option<u32> {
1809 self.inner.child.lock().as_ref().and_then(|c| c.id())
1810 }
1811
1812 pub async fn stop(&self) -> Result<(), StopErrors> {
1838 let pid = self.pid();
1839 info!(pid = ?pid, "stopping CLI process");
1840 let mut errors: Vec<Error> = Vec::new();
1841
1842 for session_id in self.inner.router.session_ids() {
1845 match self
1846 .call(
1847 "session.destroy",
1848 Some(serde_json::json!({ "sessionId": session_id })),
1849 )
1850 .await
1851 {
1852 Ok(_) => {}
1853 Err(e) => {
1854 warn!(
1855 session_id = %session_id,
1856 error = %e,
1857 "session.destroy failed during Client::stop",
1858 );
1859 errors.push(e);
1860 }
1861 }
1862 self.inner.router.unregister(&session_id);
1863 }
1864
1865 let child = self.inner.child.lock().take();
1866 *self.inner.state.lock() = ConnectionState::Disconnected;
1867 *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
1868 if let Some(mut child) = child
1869 && let Err(e) = child.kill().await
1870 {
1871 errors.push(Error::Io(e));
1872 }
1873
1874 info!(pid = ?pid, errors = errors.len(), "CLI process stopped");
1875 if errors.is_empty() {
1876 Ok(())
1877 } else {
1878 Err(StopErrors(errors))
1879 }
1880 }
1881
1882 pub fn force_stop(&self) {
1912 let pid = self.pid();
1913 info!(pid = ?pid, "force-stopping CLI process");
1914 if let Some(mut child) = self.inner.child.lock().take()
1915 && let Err(e) = child.start_kill()
1916 {
1917 error!(pid = ?pid, error = %e, "failed to send kill signal");
1918 }
1919 self.inner.rpc.force_close();
1920 self.inner.router.clear();
1923 *self.inner.state.lock() = ConnectionState::Disconnected;
1924 *self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
1925 }
1926
1927 pub fn subscribe_lifecycle(&self) -> LifecycleSubscription {
1961 LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe())
1962 }
1963
1964 pub fn state(&self) -> ConnectionState {
1971 *self.inner.state.lock()
1972 }
1973}
1974
1975impl Drop for ClientInner {
1976 fn drop(&mut self) {
1977 if let Some(ref mut child) = *self.child.lock() {
1978 let pid = child.id();
1979 if let Err(e) = child.start_kill() {
1980 error!(pid = ?pid, error = %e, "failed to kill CLI process on drop");
1981 } else {
1982 info!(pid = ?pid, "kill signal sent for CLI process on drop");
1983 }
1984 }
1985 }
1986}
1987
#[cfg(test)]
mod tests {
    // Unit tests for `Error` classification, `ClientOptions`/`TelemetryConfig`
    // builders, CLI command construction, and `Client` behaviors. Tests that need a
    // `Client` without a real server build one over in-memory `tokio::io::duplex` pipes.
    use super::*;

    #[test]
    fn is_transport_failure_matches_request_cancelled() {
        let err = Error::Protocol(ProtocolError::RequestCancelled);
        assert!(err.is_transport_failure());
    }

    #[test]
    fn is_transport_failure_matches_io_error() {
        let err = Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"));
        assert!(err.is_transport_failure());
    }

    #[test]
    fn is_transport_failure_rejects_rpc_error() {
        let err = Error::Rpc {
            code: -1,
            message: "bad".into(),
        };
        assert!(!err.is_transport_failure());
    }

    #[test]
    fn is_transport_failure_rejects_session_error() {
        let err = Error::Session(SessionError::NotFound("s1".into()));
        assert!(!err.is_transport_failure());
    }

    #[test]
    fn client_options_builder_composes() {
        let opts = ClientOptions::new()
            .with_program(CliProgram::Path(PathBuf::from("/usr/local/bin/copilot")))
            .with_prefix_args(["node"])
            .with_cwd(PathBuf::from("/tmp"))
            .with_env([("KEY", "value")])
            .with_env_remove(["UNWANTED"])
            .with_extra_args(["--quiet"])
            .with_github_token("ghp_test")
            .with_use_logged_in_user(false)
            .with_log_level(LogLevel::Debug)
            .with_session_idle_timeout_seconds(120)
            .with_remote(true);
        assert!(matches!(opts.program, CliProgram::Path(_)));
        assert_eq!(opts.prefix_args, vec![std::ffi::OsString::from("node")]);
        assert_eq!(opts.cwd, PathBuf::from("/tmp"));
        assert_eq!(
            opts.env,
            vec![(
                std::ffi::OsString::from("KEY"),
                std::ffi::OsString::from("value")
            )]
        );
        assert_eq!(opts.env_remove, vec![std::ffi::OsString::from("UNWANTED")]);
        assert_eq!(opts.extra_args, vec!["--quiet".to_string()]);
        assert_eq!(opts.github_token.as_deref(), Some("ghp_test"));
        assert_eq!(opts.use_logged_in_user, Some(false));
        assert!(matches!(opts.log_level, Some(LogLevel::Debug)));
        assert_eq!(opts.session_idle_timeout_seconds, Some(120));
        assert!(opts.remote);
    }

    #[test]
    fn is_transport_failure_rejects_other_protocol_errors() {
        let err = Error::Protocol(ProtocolError::CliStartupTimeout);
        assert!(!err.is_transport_failure());
    }

    #[test]
    fn build_command_lets_env_remove_strip_injected_token() {
        let opts = ClientOptions {
            github_token: Some("secret".to_string()),
            env_remove: vec![std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN")],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        // `get_envs` reports an explicitly removed variable as `Some(None)`. `env_value`
        // collapses that to `None`, so inspect the raw entry here instead.
        let action = cmd
            .as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
            .map(|(_, v)| v);
        assert_eq!(
            action,
            Some(None),
            "env_remove should win over github_token"
        );
    }

    #[test]
    fn build_command_lets_env_override_injected_token() {
        let opts = ClientOptions {
            github_token: Some("from-options".to_string()),
            env: vec![(
                std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN"),
                std::ffi::OsString::from("from-env"),
            )],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        let value = env_value(&cmd, "COPILOT_SDK_AUTH_TOKEN");
        assert_eq!(value, Some(std::ffi::OsStr::new("from-env")));
    }

    #[test]
    fn build_command_injects_github_token_by_default() {
        let opts = ClientOptions {
            github_token: Some("just-the-token".to_string()),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        let value = env_value(&cmd, "COPILOT_SDK_AUTH_TOKEN");
        assert_eq!(value, Some(std::ffi::OsStr::new("just-the-token")));
    }

    /// Returns the value `cmd` explicitly sets for `key`. Yields `None` both when the
    /// key is untouched and when it is explicitly removed.
    fn env_value<'a>(cmd: &'a tokio::process::Command, key: &str) -> Option<&'a std::ffi::OsStr> {
        cmd.as_std()
            .get_envs()
            .find(|(k, _)| *k == std::ffi::OsStr::new(key))
            .and_then(|(_, v)| v)
    }

    #[test]
    fn telemetry_config_builder_composes() {
        let cfg = TelemetryConfig::new()
            .with_otlp_endpoint("http://collector:4318")
            .with_file_path(PathBuf::from("/var/log/copilot.jsonl"))
            .with_exporter_type(OtelExporterType::OtlpHttp)
            .with_source_name("my-app")
            .with_capture_content(true);

        assert_eq!(cfg.otlp_endpoint.as_deref(), Some("http://collector:4318"));
        assert_eq!(
            cfg.file_path.as_deref(),
            Some(Path::new("/var/log/copilot.jsonl")),
        );
        assert_eq!(cfg.exporter_type, Some(OtelExporterType::OtlpHttp));
        assert_eq!(cfg.source_name.as_deref(), Some("my-app"));
        assert_eq!(cfg.capture_content, Some(true));
        assert!(!cfg.is_empty());
        assert!(TelemetryConfig::new().is_empty());
    }

    #[test]
    fn build_command_sets_otel_env_when_telemetry_enabled() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://collector:4318".to_string()),
                file_path: Some(PathBuf::from("/var/log/copilot.jsonl")),
                exporter_type: Some(OtelExporterType::OtlpHttp),
                source_name: Some("my-app".to_string()),
                capture_content: Some(true),
            }),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://collector:4318")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_FILE_EXPORTER_PATH"),
            Some(std::ffi::OsStr::new("/var/log/copilot.jsonl")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_EXPORTER_TYPE"),
            Some(std::ffi::OsStr::new("otlp-http")),
        );
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_SOURCE_NAME"),
            Some(std::ffi::OsStr::new("my-app")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"),
            Some(std::ffi::OsStr::new("true")),
        );
    }

    #[test]
    fn build_command_omits_otel_env_when_telemetry_none() {
        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        for key in [
            "COPILOT_OTEL_ENABLED",
            "OTEL_EXPORTER_OTLP_ENDPOINT",
            "COPILOT_OTEL_FILE_EXPORTER_PATH",
            "COPILOT_OTEL_EXPORTER_TYPE",
            "COPILOT_OTEL_SOURCE_NAME",
            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
        ] {
            assert!(
                env_value(&cmd, key).is_none(),
                "expected {key} to be unset when telemetry is None",
            );
        }
    }

    #[test]
    fn build_command_omits_unset_telemetry_fields() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://collector:4318".to_string()),
                ..Default::default()
            }),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://collector:4318")),
        );
        for key in [
            "COPILOT_OTEL_FILE_EXPORTER_PATH",
            "COPILOT_OTEL_EXPORTER_TYPE",
            "COPILOT_OTEL_SOURCE_NAME",
            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
        ] {
            assert!(env_value(&cmd, key).is_none(), "{key} should be unset");
        }
    }

    #[test]
    fn build_command_lets_user_env_override_telemetry() {
        let opts = ClientOptions {
            telemetry: Some(TelemetryConfig {
                otlp_endpoint: Some("http://from-config:4318".to_string()),
                ..Default::default()
            }),
            env: vec![(
                std::ffi::OsString::from("OTEL_EXPORTER_OTLP_ENDPOINT"),
                std::ffi::OsString::from("http://from-user-env:4318"),
            )],
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
            Some(std::ffi::OsStr::new("http://from-user-env:4318")),
            "user-supplied options.env should override telemetry config",
        );
    }

    #[test]
    fn build_command_sets_copilot_home_env_when_configured() {
        let opts = ClientOptions::new().with_copilot_home(PathBuf::from("/custom/copilot"));
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_HOME"),
            Some(std::ffi::OsStr::new("/custom/copilot")),
        );

        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert!(env_value(&cmd, "COPILOT_HOME").is_none());
    }

    #[test]
    fn build_command_sets_connection_token_env_when_configured() {
        let opts = ClientOptions::new().with_tcp_connection_token("secret-token");
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert_eq!(
            env_value(&cmd, "COPILOT_CONNECTION_TOKEN"),
            Some(std::ffi::OsStr::new("secret-token")),
        );

        let opts = ClientOptions::default();
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        assert!(env_value(&cmd, "COPILOT_CONNECTION_TOKEN").is_none());
    }

    #[tokio::test]
    async fn start_rejects_token_with_stdio_transport() {
        let opts = ClientOptions::new()
            .with_tcp_connection_token("token-123")
            .with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
        let err = Client::start(opts).await.unwrap_err();
        assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}");
        let Error::InvalidConfig(msg) = err else {
            unreachable!()
        };
        assert!(
            msg.contains("Stdio"),
            "error should explain the stdio incompatibility: {msg}"
        );
    }

    #[tokio::test]
    async fn start_rejects_empty_connection_token() {
        let opts = ClientOptions::new()
            .with_tcp_connection_token("")
            .with_transport(Transport::Tcp { port: 0 })
            .with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
        let err = Client::start(opts).await.unwrap_err();
        assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}");
    }

    #[test]
    fn telemetry_config_capture_content_serializes_as_lowercase_bool() {
        let opts_true = ClientOptions {
            telemetry: Some(TelemetryConfig {
                capture_content: Some(true),
                ..Default::default()
            }),
            ..Default::default()
        };
        let opts_false = ClientOptions {
            telemetry: Some(TelemetryConfig {
                capture_content: Some(false),
                ..Default::default()
            }),
            ..Default::default()
        };
        let cmd_true = Client::build_command(Path::new("/bin/echo"), &opts_true);
        let cmd_false = Client::build_command(Path::new("/bin/echo"), &opts_false);
        assert_eq!(
            env_value(
                &cmd_true,
                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
            ),
            Some(std::ffi::OsStr::new("true")),
        );
        assert_eq!(
            env_value(
                &cmd_false,
                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
            ),
            Some(std::ffi::OsStr::new("false")),
        );
    }

    #[test]
    fn session_idle_timeout_args_are_omitted_by_default() {
        let opts = ClientOptions::default();
        assert!(Client::session_idle_timeout_args(&opts).is_empty());
    }

    #[test]
    fn session_idle_timeout_args_omitted_for_zero() {
        let opts = ClientOptions {
            session_idle_timeout_seconds: Some(0),
            ..Default::default()
        };
        assert!(Client::session_idle_timeout_args(&opts).is_empty());
    }

    #[test]
    fn session_idle_timeout_args_emit_flag_for_positive_value() {
        let opts = ClientOptions {
            session_idle_timeout_seconds: Some(300),
            ..Default::default()
        };
        assert_eq!(
            Client::session_idle_timeout_args(&opts),
            vec!["--session-idle-timeout".to_string(), "300".to_string()]
        );
    }

    #[test]
    fn remote_args_omitted_by_default() {
        let opts = ClientOptions::default();
        assert!(Client::remote_args(&opts).is_empty());
    }

    #[test]
    fn remote_args_emit_flag_when_enabled() {
        let opts = ClientOptions {
            remote: true,
            ..Default::default()
        };
        assert_eq!(Client::remote_args(&opts), vec!["--remote".to_string()]);
    }

    #[test]
    fn log_level_str_round_trips() {
        for level in [
            LogLevel::None,
            LogLevel::Error,
            LogLevel::Warning,
            LogLevel::Info,
            LogLevel::Debug,
            LogLevel::All,
        ] {
            let s = level.as_str();
            let json = serde_json::to_string(&level).unwrap();
            assert_eq!(json, format!("\"{s}\""));
            let parsed: LogLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, level);
        }
    }

    #[test]
    fn client_options_debug_redacts_handler() {
        struct StubHandler;
        #[async_trait]
        impl ListModelsHandler for StubHandler {
            async fn list_models(&self) -> Result<Vec<Model>, Error> {
                Ok(vec![])
            }
        }
        let opts = ClientOptions {
            on_list_models: Some(Arc::new(StubHandler)),
            github_token: Some("secret-token".into()),
            ..Default::default()
        };
        let debug = format!("{opts:?}");
        assert!(debug.contains("on_list_models: Some(\"<set>\")"));
        assert!(debug.contains("github_token: Some(\"<redacted>\")"));
        assert!(!debug.contains("secret-token"));
    }

    #[tokio::test]
    async fn list_models_uses_on_list_models_handler_when_set() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        struct CountingHandler {
            calls: Arc<AtomicUsize>,
            models: Vec<Model>,
        }
        #[async_trait]
        impl ListModelsHandler for CountingHandler {
            async fn list_models(&self) -> Result<Vec<Model>, Error> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Ok(self.models.clone())
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let model = Model {
            billing: None,
            capabilities: ModelCapabilities {
                limits: None,
                supports: None,
            },
            default_reasoning_effort: None,
            id: "byok-gpt-4".into(),
            model_picker_category: None,
            model_picker_price_category: None,
            name: "BYOK GPT-4".into(),
            policy: None,
            supported_reasoning_efforts: Vec::new(),
        };
        let handler: Arc<dyn ListModelsHandler> = Arc::new(CountingHandler {
            calls: Arc::clone(&calls),
            models: vec![model.clone()],
        });

        let client = client_with_list_models_handler(handler);

        let result = client.list_models().await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, "byok-gpt-4");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn list_models_serializes_concurrent_cache_misses() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        struct SlowCountingHandler {
            calls: Arc<AtomicUsize>,
            models: Vec<Model>,
        }
        #[async_trait]
        impl ListModelsHandler for SlowCountingHandler {
            async fn list_models(&self) -> Result<Vec<Model>, Error> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_millis(25)).await;
                Ok(self.models.clone())
            }
        }

        let calls = Arc::new(AtomicUsize::new(0));
        let model = Model {
            billing: None,
            capabilities: ModelCapabilities {
                limits: None,
                supports: None,
            },
            default_reasoning_effort: None,
            id: "single-flight-model".into(),
            model_picker_category: None,
            model_picker_price_category: None,
            name: "Single Flight Model".into(),
            policy: None,
            supported_reasoning_efforts: Vec::new(),
        };
        let handler: Arc<dyn ListModelsHandler> = Arc::new(SlowCountingHandler {
            calls: Arc::clone(&calls),
            models: vec![model],
        });
        let client = client_with_list_models_handler(handler);

        let (first, second) = tokio::join!(client.list_models(), client.list_models());
        assert_eq!(first.unwrap()[0].id, "single-flight-model");
        assert_eq!(second.unwrap()[0].id, "single-flight-model");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cancelled_create_session_unregisters_pending_session() {
        // The `_server_*` halves stay bound for the whole test so the pipes remain open
        // and the request stays pending until the task is aborted.
        let (client_write, _server_read) = tokio::io::duplex(8192);
        let (_server_write, client_read) = tokio::io::duplex(8192);
        let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
        let handle = tokio::spawn({
            let client = client.clone();
            async move { client.create_session(SessionConfig::default()).await }
        });

        wait_for_pending_session_registration(&client).await;
        handle.abort();
        let _ = handle.await;

        assert!(client.inner.router.session_ids().is_empty());
        client.force_stop();
    }

    #[tokio::test]
    async fn cancelled_resume_session_unregisters_pending_session() {
        let (client_write, _server_read) = tokio::io::duplex(8192);
        let (_server_write, client_read) = tokio::io::duplex(8192);
        let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
        let session_id = SessionId::new("resume-cancel-test");
        let handle = tokio::spawn({
            let client = client.clone();
            async move {
                client
                    .resume_session(ResumeSessionConfig::new(session_id))
                    .await
            }
        });

        wait_for_pending_session_registration(&client).await;
        handle.abort();
        let _ = handle.await;

        assert!(client.inner.router.session_ids().is_empty());
        client.force_stop();
    }

    /// Builds a `Client` with no child process, configured with `handler` as its
    /// `on_list_models` hook.
    ///
    /// The RPC transport uses duplex pipes whose peer halves are dropped immediately,
    /// so this is only suitable for code paths that never touch the wire.
    fn client_with_list_models_handler(handler: Arc<dyn ListModelsHandler>) -> Client {
        Client {
            inner: Arc::new(ClientInner {
                child: parking_lot::Mutex::new(None),
                rpc: {
                    let (req_tx, _req_rx) = mpsc::unbounded_channel();
                    let (notif_tx, _notif_rx) = broadcast::channel(16);
                    let (read_pipe, _write_pipe) = tokio::io::duplex(64);
                    let (_unused_read, write_pipe) = tokio::io::duplex(64);
                    JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
                },
                cwd: PathBuf::from("."),
                request_rx: parking_lot::Mutex::new(None),
                notification_tx: broadcast::channel(16).0,
                router: router::SessionRouter::new(),
                negotiated_protocol_version: OnceLock::new(),
                state: parking_lot::Mutex::new(ConnectionState::Connected),
                lifecycle_tx: broadcast::channel(16).0,
                on_list_models: Some(handler),
                models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
                session_fs_configured: false,
                on_get_trace_context: None,
                effective_connection_token: None,
            }),
        }
    }

    /// Polls every 10 ms until the router holds at least one session, panicking if
    /// none appears within 1 s.
    async fn wait_for_pending_session_registration(client: &Client) {
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
        while client.inner.router.session_ids().is_empty() {
            assert!(
                tokio::time::Instant::now() < deadline,
                "session was not registered"
            );
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    }
}