#![doc = include_str!("../README.md")]
#![warn(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]
#![cfg_attr(test, allow(clippy::unwrap_used))]
pub mod embeddedcli;
pub mod handler;
pub mod hooks;
mod jsonrpc;
pub mod permission;
pub mod resolve;
mod router;
pub mod session;
pub mod session_fs;
mod session_fs_dispatch;
pub mod subscription;
pub mod tool;
pub mod trace_context;
pub mod transforms;
pub mod types;
pub mod generated;
use std::ffi::OsString;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::{Arc, OnceLock};
use async_trait::async_trait;
pub(crate) use jsonrpc::{
JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes,
};
/// Re-exports of the crate's internal JSON-RPC types, compiled only when the
/// `test-support` feature is enabled.
#[cfg(feature = "test-support")]
pub mod test_support {
pub use crate::jsonrpc::{
JsonRpcClient, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
error_codes,
};
}
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader};
use tokio::net::TcpStream;
use tokio::process::{Child, Command};
use tokio::sync::{broadcast, mpsc, oneshot};
use tracing::{Instrument, debug, error, info, warn};
pub use types::*;
mod sdk_protocol_version;
pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version};
pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError};
/// Lowest server protocol version accepted by `Client::verify_protocol_version`;
/// the upper bound of the accepted range is `SDK_PROTOCOL_VERSION`.
const MIN_PROTOCOL_VERSION: u32 = 2;
/// Top-level error type returned by the client APIs in this crate.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
/// A transport- or protocol-level failure; see `ProtocolError`.
#[error("protocol error: {0}")]
Protocol(ProtocolError),
/// The server answered a request with a JSON-RPC error object.
#[error("RPC error {code}: {message}")]
Rpc {
/// Error code taken from the JSON-RPC error object.
code: i32,
/// Error message taken from the JSON-RPC error object.
message: String,
},
/// A session-level failure; see `SessionError`.
#[error("session error: {0}")]
Session(SessionError),
/// An underlying I/O error (process spawn, socket connect, kill, ...).
#[error(transparent)]
Io(#[from] std::io::Error),
/// A JSON serialization or deserialization error.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// A required binary could not be located.
#[error("binary not found: {name} ({hint})")]
BinaryNotFound {
/// Name of the binary that was looked up.
name: &'static str,
/// Human-readable hint included in the error message.
hint: &'static str,
},
/// The supplied `ClientOptions` were rejected before connecting.
#[error("invalid client configuration: {0}")]
InvalidConfig(String),
}
impl Error {
pub fn is_transport_failure(&self) -> bool {
matches!(
self,
Error::Protocol(ProtocolError::RequestCancelled) | Error::Io(_)
)
}
}
/// Collection of every error gathered while running `Client::stop`.
#[derive(Debug)]
pub struct StopErrors(Vec<Error>);
impl StopErrors {
pub fn errors(&self) -> &[Error] {
&self.0
}
pub fn into_errors(self) -> Vec<Error> {
self.0
}
}
impl std::fmt::Display for StopErrors {
    /// Formats a summary: a fixed message when empty, the single error when
    /// there is exactly one, otherwise the count plus the first error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = self.0.len();
        match self.0.first() {
            None => write!(f, "stop completed with no errors"),
            Some(only) if total == 1 => write!(f, "stop failed: {only}"),
            Some(first) => write!(
                f,
                "stop failed with {n} errors; first: {first}",
                n = total,
            ),
        }
    }
}
impl std::error::Error for StopErrors {
    /// Exposes the first collected error (if any) as the error source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let first = self.0.first()?;
        Some(first as &(dyn std::error::Error + 'static))
    }
}
/// Failures at the transport / wire-protocol level.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ProtocolError {
/// A message frame did not carry a `Content-Length` header.
#[error("missing Content-Length header")]
MissingContentLength,
/// The `Content-Length` header value was invalid; the payload is the
/// offending value.
#[error("invalid Content-Length value: \"{0}\"")]
InvalidContentLength(String),
/// The request was cancelled before a response was received.
/// `Error::is_transport_failure` treats this as a transport failure.
#[error("request cancelled")]
RequestCancelled,
/// The spawned CLI did not report its listening port within the startup
/// timeout.
#[error("timed out waiting for CLI to report listening port")]
CliStartupTimeout,
/// The spawned CLI closed its stdout before reporting a listening port.
#[error("CLI exited before reporting listening port")]
CliStartupFailed,
/// The server's protocol version is outside the range this SDK supports.
#[error("version mismatch: server={server}, supported={min}–{max}")]
VersionMismatch {
/// Version reported by the server.
server: u32,
/// Lowest version this SDK accepts.
min: u32,
/// Highest version this SDK accepts.
max: u32,
},
/// A later verification saw a different protocol version than the one
/// negotiated earlier on the same client.
#[error("version changed: was {previous}, now {current}")]
VersionChanged {
/// Version recorded by the earlier verification.
previous: u32,
/// Version reported by the latest verification.
current: u32,
},
}
/// Failures scoped to a single session.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SessionError {
/// The server reported that the session does not exist. `Client::call`
/// produces this when an RPC error message contains "Session not found".
#[error("session not found: {0}")]
NotFound(SessionId),
/// An error message reported by the agent, carried verbatim.
#[error("{0}")]
AgentError(String),
/// An operation did not finish within the given duration.
#[error("timed out after {0:?}")]
Timeout(std::time::Duration),
/// A send was attempted while a `send_and_wait` call was still in flight.
#[error("cannot send while send_and_wait is in flight")]
SendWhileWaiting,
/// The event loop closed before the session reached idle.
#[error("event loop closed before session reached idle")]
EventLoopClosed,
/// The host does not support elicitation.
#[error(
"elicitation not supported by host — check session.capabilities().ui.elicitation first"
)]
ElicitationNotSupported,
/// The client was configured with `session_fs` but the session was not
/// given a `SessionFsProvider`.
#[error(
"session was created on a client with session_fs configured but no SessionFsProvider was supplied"
)]
SessionFsProviderRequired,
/// The `SessionFsConfig` failed validation; the payload explains why.
#[error("invalid SessionFsConfig: {0}")]
InvalidSessionFsConfig(String),
}
/// How `Client::start` reaches the CLI server.
#[derive(Debug, Default)]
#[non_exhaustive]
pub enum Transport {
/// Spawn the CLI with `--server --stdio` and talk over its stdin/stdout.
/// This is the default.
#[default]
Stdio,
/// Spawn the CLI with `--server --port <port>` and connect to
/// `127.0.0.1` on the port the CLI reports on stdout.
Tcp {
/// Port passed to the CLI via `--port`.
port: u16,
},
/// Connect over TCP to a server that is already running; no child
/// process is spawned.
External {
/// Host name or address to connect to.
host: String,
/// TCP port to connect to.
port: u16,
},
}
/// Which CLI executable `Client::start` launches.
#[derive(Debug, Clone, Default)]
pub enum CliProgram {
/// Locate the binary via `resolve::copilot_binary()`. This is the default.
#[default]
Resolve,
/// Use exactly this path.
Path(PathBuf),
}
impl From<PathBuf> for CliProgram {
fn from(path: PathBuf) -> Self {
Self::Path(path)
}
}
/// Configuration consumed by `Client::start`.
///
/// `Debug` is implemented by hand so that `github_token` and
/// `tcp_connection_token` are printed as `<redacted>`.
#[non_exhaustive]
pub struct ClientOptions {
/// Which CLI executable to launch.
pub program: CliProgram,
/// Arguments placed immediately after the program, before the SDK's own
/// server flags.
pub prefix_args: Vec<OsString>,
/// Working directory of the spawned CLI; also reported by `Client::cwd`.
pub cwd: PathBuf,
/// Extra environment variables for the child. Applied after the variables
/// the SDK injects, so they override those.
pub env: Vec<(OsString, OsString)>,
/// Environment variables removed from the child. Applied last, so removal
/// wins over both `env` and SDK-injected values.
pub env_remove: Vec<OsString>,
/// Arguments appended after all SDK-generated arguments.
pub extra_args: Vec<String>,
/// How to reach the CLI server.
pub transport: Transport,
/// Token handed to the child through the `COPILOT_SDK_AUTH_TOKEN`
/// environment variable (announced with `--auth-token-env`).
pub github_token: Option<String>,
/// Whether the logged-in user may be used. When unset it defaults to
/// `true` only if no `github_token` is given; a `false` value adds
/// `--no-auto-login`.
pub use_logged_in_user: Option<bool>,
/// Value for `--log-level`; `LogLevel::Info` is used when unset.
pub log_level: Option<LogLevel>,
/// Passed as `--session-idle-timeout` when set to a value greater than 0.
pub session_idle_timeout_seconds: Option<u64>,
/// When set, `Client::list_models` calls this handler instead of the RPC.
pub on_list_models: Option<Arc<dyn ListModelsHandler>>,
/// Session filesystem configuration; validated by `Client::start` and
/// sent to the server after the protocol version check.
pub session_fs: Option<SessionFsConfig>,
/// Provider consulted when the client resolves a trace context.
pub on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
/// Telemetry settings, forwarded to the child as environment variables.
pub telemetry: Option<TelemetryConfig>,
/// Exported to the child as `COPILOT_HOME`.
pub copilot_home: Option<PathBuf>,
/// Connection token for TCP transports. Must be non-empty and cannot be
/// combined with `Transport::Stdio`; for `Transport::Tcp` a random token
/// is generated when this is unset.
pub tcp_connection_token: Option<String>,
}
impl std::fmt::Debug for ClientOptions {
    /// Prints every field, replacing secrets with `<redacted>` and handler
    /// trait objects with `<set>` so their contents never reach logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let github_token = self.github_token.as_ref().map(|_| "<redacted>");
        let tcp_connection_token = self.tcp_connection_token.as_ref().map(|_| "<redacted>");
        let on_list_models = self.on_list_models.as_ref().map(|_| "<set>");
        let on_get_trace_context = self.on_get_trace_context.as_ref().map(|_| "<set>");

        let mut out = f.debug_struct("ClientOptions");
        out.field("program", &self.program);
        out.field("prefix_args", &self.prefix_args);
        out.field("cwd", &self.cwd);
        out.field("env", &self.env);
        out.field("env_remove", &self.env_remove);
        out.field("extra_args", &self.extra_args);
        out.field("transport", &self.transport);
        out.field("github_token", &github_token);
        out.field("use_logged_in_user", &self.use_logged_in_user);
        out.field("log_level", &self.log_level);
        out.field(
            "session_idle_timeout_seconds",
            &self.session_idle_timeout_seconds,
        );
        out.field("on_list_models", &on_list_models);
        out.field("session_fs", &self.session_fs);
        out.field("on_get_trace_context", &on_get_trace_context);
        out.field("telemetry", &self.telemetry);
        out.field("copilot_home", &self.copilot_home);
        out.field("tcp_connection_token", &tcp_connection_token);
        out.finish()
    }
}
/// Custom source for the model list. When one is installed through
/// `ClientOptions::with_list_models_handler`, `Client::list_models` calls it
/// instead of issuing the RPC.
#[async_trait]
pub trait ListModelsHandler: Send + Sync + 'static {
/// Returns the models to expose to callers of `Client::list_models`.
async fn list_models(&self) -> Result<Vec<Model>, Error>;
}
/// Log level passed to the CLI through `--log-level`. Serialized in lowercase.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
/// Rendered as `none`.
None,
/// Rendered as `error`.
Error,
/// Rendered as `warning`.
Warning,
/// Rendered as `info`; used when no level is configured.
Info,
/// Rendered as `debug`.
Debug,
/// Rendered as `all`.
All,
}
impl LogLevel {
pub fn as_str(self) -> &'static str {
match self {
Self::None => "none",
Self::Error => "error",
Self::Warning => "warning",
Self::Info => "info",
Self::Debug => "debug",
Self::All => "all",
}
}
}
impl std::fmt::Display for LogLevel {
    /// Writes the same lowercase name that `LogLevel::as_str` returns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = (*self).as_str();
        f.write_str(label)
    }
}
/// Telemetry exporter selection, exported to the CLI as
/// `COPILOT_OTEL_EXPORTER_TYPE`. Serialized in kebab-case.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum OtelExporterType {
/// Rendered as `otlp-http`.
OtlpHttp,
/// Rendered as `file`.
File,
}
impl OtelExporterType {
pub fn as_str(self) -> &'static str {
match self {
Self::OtlpHttp => "otlp-http",
Self::File => "file",
}
}
}
/// Telemetry settings forwarded to the spawned CLI as environment variables.
/// Supplying a config at all (even an empty one) sets `COPILOT_OTEL_ENABLED=true`.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TelemetryConfig {
/// Exported as `OTEL_EXPORTER_OTLP_ENDPOINT`.
pub otlp_endpoint: Option<String>,
/// Exported as `COPILOT_OTEL_FILE_EXPORTER_PATH`.
pub file_path: Option<PathBuf>,
/// Exported as `COPILOT_OTEL_EXPORTER_TYPE`.
pub exporter_type: Option<OtelExporterType>,
/// Exported as `COPILOT_OTEL_SOURCE_NAME`.
pub source_name: Option<String>,
/// Exported as `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT`
/// (`"true"` / `"false"`).
pub capture_content: Option<bool>,
}
impl TelemetryConfig {
    /// Creates a config with every field unset.
    pub fn new() -> Self {
        Default::default()
    }
    /// Sets the OTLP endpoint.
    pub fn with_otlp_endpoint(self, endpoint: impl Into<String>) -> Self {
        Self {
            otlp_endpoint: Some(endpoint.into()),
            ..self
        }
    }
    /// Sets the file exporter path.
    pub fn with_file_path(self, path: impl Into<PathBuf>) -> Self {
        Self {
            file_path: Some(path.into()),
            ..self
        }
    }
    /// Sets the exporter type.
    pub fn with_exporter_type(self, exporter_type: OtelExporterType) -> Self {
        Self {
            exporter_type: Some(exporter_type),
            ..self
        }
    }
    /// Sets the telemetry source name.
    pub fn with_source_name(self, source_name: impl Into<String>) -> Self {
        Self {
            source_name: Some(source_name.into()),
            ..self
        }
    }
    /// Sets whether message content is captured.
    pub fn with_capture_content(self, capture: bool) -> Self {
        Self {
            capture_content: Some(capture),
            ..self
        }
    }
    /// Returns `true` when no field has been set.
    pub fn is_empty(&self) -> bool {
        let Self {
            otlp_endpoint,
            file_path,
            exporter_type,
            source_name,
            capture_content,
        } = self;
        let any_set = otlp_endpoint.is_some()
            || file_path.is_some()
            || exporter_type.is_some()
            || source_name.is_some()
            || capture_content.is_some();
        !any_set
    }
}
impl Default for ClientOptions {
fn default() -> Self {
Self {
program: CliProgram::Resolve,
prefix_args: Vec::new(),
cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
env: Vec::new(),
env_remove: Vec::new(),
extra_args: Vec::new(),
transport: Transport::default(),
github_token: None,
use_logged_in_user: None,
log_level: None,
session_idle_timeout_seconds: None,
on_list_models: None,
session_fs: None,
on_get_trace_context: None,
telemetry: None,
copilot_home: None,
tcp_connection_token: None,
}
}
}
impl ClientOptions {
    /// Creates options with default values; same as `ClientOptions::default()`.
    pub fn new() -> Self {
        Default::default()
    }
    /// Selects the CLI executable to launch.
    pub fn with_program(self, program: impl Into<CliProgram>) -> Self {
        Self {
            program: program.into(),
            ..self
        }
    }
    /// Replaces the arguments placed directly after the program.
    pub fn with_prefix_args<I, S>(self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<OsString>,
    {
        Self {
            prefix_args: args.into_iter().map(|arg| arg.into()).collect(),
            ..self
        }
    }
    /// Sets the working directory for the spawned CLI.
    pub fn with_cwd(self, cwd: impl Into<PathBuf>) -> Self {
        Self {
            cwd: cwd.into(),
            ..self
        }
    }
    /// Replaces the extra environment variables set on the child.
    pub fn with_env<I, K, V>(self, env: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<OsString>,
        V: Into<OsString>,
    {
        Self {
            env: env
                .into_iter()
                .map(|(key, value)| (key.into(), value.into()))
                .collect(),
            ..self
        }
    }
    /// Replaces the list of environment variables removed from the child.
    pub fn with_env_remove<I, S>(self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<OsString>,
    {
        Self {
            env_remove: names.into_iter().map(|name| name.into()).collect(),
            ..self
        }
    }
    /// Replaces the arguments appended after the SDK-generated ones.
    pub fn with_extra_args<I, S>(self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            extra_args: args.into_iter().map(|arg| arg.into()).collect(),
            ..self
        }
    }
    /// Selects the transport.
    pub fn with_transport(self, transport: Transport) -> Self {
        Self { transport, ..self }
    }
    /// Sets the GitHub token handed to the CLI.
    pub fn with_github_token(self, token: impl Into<String>) -> Self {
        Self {
            github_token: Some(token.into()),
            ..self
        }
    }
    /// Sets whether the logged-in user may be used for authentication.
    pub fn with_use_logged_in_user(self, use_logged_in: bool) -> Self {
        Self {
            use_logged_in_user: Some(use_logged_in),
            ..self
        }
    }
    /// Sets the CLI log level.
    pub fn with_log_level(self, level: LogLevel) -> Self {
        Self {
            log_level: Some(level),
            ..self
        }
    }
    /// Sets the session idle timeout in seconds.
    pub fn with_session_idle_timeout_seconds(self, seconds: u64) -> Self {
        Self {
            session_idle_timeout_seconds: Some(seconds),
            ..self
        }
    }
    /// Installs a handler that `Client::list_models` calls instead of the RPC.
    pub fn with_list_models_handler<H>(self, handler: H) -> Self
    where
        H: ListModelsHandler + 'static,
    {
        let handler: Arc<dyn ListModelsHandler> = Arc::new(handler);
        Self {
            on_list_models: Some(handler),
            ..self
        }
    }
    /// Sets the session filesystem configuration.
    pub fn with_session_fs(self, config: SessionFsConfig) -> Self {
        Self {
            session_fs: Some(config),
            ..self
        }
    }
    /// Installs the provider used to resolve trace contexts.
    pub fn with_trace_context_provider<P>(self, provider: P) -> Self
    where
        P: TraceContextProvider + 'static,
    {
        let provider: Arc<dyn TraceContextProvider> = Arc::new(provider);
        Self {
            on_get_trace_context: Some(provider),
            ..self
        }
    }
    /// Sets the telemetry configuration.
    pub fn with_telemetry(self, config: TelemetryConfig) -> Self {
        Self {
            telemetry: Some(config),
            ..self
        }
    }
    /// Sets the directory exported to the CLI as `COPILOT_HOME`.
    pub fn with_copilot_home(self, home: impl Into<PathBuf>) -> Self {
        Self {
            copilot_home: Some(home.into()),
            ..self
        }
    }
    /// Sets the connection token used with TCP transports.
    pub fn with_tcp_connection_token(self, token: impl Into<String>) -> Self {
        Self {
            tcp_connection_token: Some(token.into()),
            ..self
        }
    }
}
fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<(), Error> {
if cfg.initial_cwd.trim().is_empty() {
return Err(Error::Session(SessionError::InvalidSessionFsConfig(
"initial_cwd must not be empty".to_string(),
)));
}
if cfg.session_state_path.trim().is_empty() {
return Err(Error::Session(SessionError::InvalidSessionFsConfig(
"session_state_path must not be empty".to_string(),
)));
}
Ok(())
}
/// Generates a connection token: 16 bytes from the OS random source,
/// rendered as 32 lowercase hex characters.
///
/// # Panics
///
/// Panics if `getrandom` reports that the OS random source is unavailable.
fn generate_connection_token() -> String {
    use std::fmt::Write;

    let mut raw = [0u8; 16];
    getrandom::getrandom(&mut raw)
        .expect("OS CSPRNG (getrandom) is unavailable; cannot generate connection token");
    raw.iter().fold(String::with_capacity(32), |mut token, b| {
        // Writing into a `String` cannot fail, so the result is ignored.
        let _ = write!(token, "{b:02x}");
        token
    })
}
/// Handle to a CLI server connection. Cloning is cheap: every clone shares
/// the same reference-counted `ClientInner`.
#[derive(Clone)]
pub struct Client {
// Shared connection state.
inner: Arc<ClientInner>,
}
impl std::fmt::Debug for Client {
    /// Prints only the working directory and the child process id.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pid = self.pid();
        let mut out = f.debug_struct("Client");
        out.field("cwd", &self.inner.cwd);
        out.field("pid", &pid);
        out.finish()
    }
}
/// State shared by all clones of a `Client`. Dropping it sends a kill signal
/// to the child process, if one is still held.
struct ClientInner {
// Spawned CLI process; `None` for external/stream-based clients and after
// `stop` / `force_stop` have taken it.
child: parking_lot::Mutex<Option<Child>>,
// JSON-RPC connection used for all requests and responses.
rpc: JsonRpcClient,
// Working directory the client was created with.
cwd: PathBuf,
// Receiver for requests arriving from the server; passed to the router's
// `ensure_started` or removed via `take_request_rx`.
request_rx: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<JsonRpcRequest>>>,
// Broadcast sender for incoming notifications.
notification_tx: broadcast::Sender<JsonRpcNotification>,
// Per-session routing of notifications and requests.
router: router::SessionRouter,
// Protocol version recorded by the first successful verification.
negotiated_protocol_version: OnceLock<u32>,
// Connection state; set to `Disconnected` by `stop` / `force_stop`.
state: parking_lot::Mutex<ConnectionState>,
// Broadcast sender for decoded `session.lifecycle` events.
lifecycle_tx: broadcast::Sender<SessionLifecycleEvent>,
// Optional override used by `Client::list_models`.
on_list_models: Option<Arc<dyn ListModelsHandler>>,
// Whether the client was started with a `session_fs` configuration.
session_fs_configured: bool,
// Optional provider used by `resolve_trace_context`.
on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
// Token sent in the `connect` handshake, if any.
effective_connection_token: Option<String>,
}
impl Client {
pub async fn start(options: ClientOptions) -> Result<Self, Error> {
if let Some(cfg) = &options.session_fs {
validate_session_fs_config(cfg)?;
}
if let Some(token) = &options.tcp_connection_token {
if token.is_empty() {
return Err(Error::InvalidConfig(
"tcp_connection_token must be a non-empty string".to_string(),
));
}
if matches!(options.transport, Transport::Stdio) {
return Err(Error::InvalidConfig(
"tcp_connection_token cannot be used with Transport::Stdio".to_string(),
));
}
}
let effective_connection_token: Option<String> = match &options.transport {
Transport::Stdio => None,
Transport::Tcp { .. } => Some(
options
.tcp_connection_token
.clone()
.unwrap_or_else(generate_connection_token),
),
Transport::External { .. } => options.tcp_connection_token.clone(),
};
let mut options = options;
if matches!(options.transport, Transport::Tcp { .. })
&& options.tcp_connection_token.is_none()
{
options.tcp_connection_token = effective_connection_token.clone();
}
let session_fs_config = options.session_fs.clone();
let program = match &options.program {
CliProgram::Path(path) => {
info!(path = %path.display(), "using explicit copilot CLI path");
path.clone()
}
CliProgram::Resolve => {
let resolved = resolve::copilot_binary()?;
info!(path = %resolved.display(), "resolved copilot CLI");
#[cfg(windows)]
{
if let Some(ext) = resolved.extension().and_then(|e| e.to_str()) {
if ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat") {
warn!(
path = %resolved.display(),
ext = %ext,
"resolved copilot CLI is a .cmd/.bat wrapper; \
this may cause console window flashes on Windows"
);
}
}
}
resolved
}
};
let client = match options.transport {
Transport::External { ref host, port } => {
info!(host = %host, port = %port, "connecting to external CLI server");
let stream = TcpStream::connect((host.as_str(), port)).await?;
let (reader, writer) = tokio::io::split(stream);
Self::from_transport(
reader,
writer,
None,
options.cwd,
options.on_list_models,
session_fs_config.is_some(),
options.on_get_trace_context,
effective_connection_token.clone(),
)?
}
Transport::Tcp { port } => {
let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?;
let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?;
let (reader, writer) = tokio::io::split(stream);
Self::drain_stderr(&mut child);
Self::from_transport(
reader,
writer,
Some(child),
options.cwd,
options.on_list_models,
session_fs_config.is_some(),
options.on_get_trace_context,
effective_connection_token.clone(),
)?
}
Transport::Stdio => {
let mut child = Self::spawn_stdio(&program, &options)?;
let stdin = child.stdin.take().expect("stdin is piped");
let stdout = child.stdout.take().expect("stdout is piped");
Self::drain_stderr(&mut child);
Self::from_transport(
stdout,
stdin,
Some(child),
options.cwd,
options.on_list_models,
session_fs_config.is_some(),
options.on_get_trace_context,
effective_connection_token.clone(),
)?
}
};
client.verify_protocol_version().await?;
if let Some(cfg) = session_fs_config {
let request = crate::generated::api_types::SessionFsSetProviderRequest {
conventions: cfg.conventions.into_wire(),
initial_cwd: cfg.initial_cwd,
session_state_path: cfg.session_state_path,
};
client.rpc().session_fs().set_provider(request).await?;
}
Ok(client)
}
pub fn from_streams(
reader: impl AsyncRead + Unpin + Send + 'static,
writer: impl AsyncWrite + Unpin + Send + 'static,
cwd: PathBuf,
) -> Result<Self, Error> {
Self::from_transport(reader, writer, None, cwd, None, false, None, None)
}
/// Test-only variant of `from_streams` that also installs a trace-context
/// provider.
#[cfg(any(test, feature = "test-support"))]
pub fn from_streams_with_trace_provider(
    reader: impl AsyncRead + Unpin + Send + 'static,
    writer: impl AsyncWrite + Unpin + Send + 'static,
    cwd: PathBuf,
    provider: Arc<dyn TraceContextProvider>,
) -> Result<Self, Error> {
    let no_child: Option<Child> = None;
    let no_token: Option<String> = None;
    let trace_provider = Some(provider);
    Self::from_transport(
        reader,
        writer,
        no_child,
        cwd,
        None,
        false,
        trace_provider,
        no_token,
    )
}
/// Test-only variant of `from_streams` that sets the token sent in the
/// `connect` handshake.
#[cfg(any(test, feature = "test-support"))]
pub fn from_streams_with_connection_token(
    reader: impl AsyncRead + Unpin + Send + 'static,
    writer: impl AsyncWrite + Unpin + Send + 'static,
    cwd: PathBuf,
    token: Option<String>,
) -> Result<Self, Error> {
    let no_child: Option<Child> = None;
    let session_fs_configured = false;
    Self::from_transport(
        reader,
        writer,
        no_child,
        cwd,
        None,
        session_fs_configured,
        None,
        token,
    )
}
/// Test-only access to the private `generate_connection_token` helper.
#[cfg(any(test, feature = "test-support"))]
pub fn generate_connection_token_for_test() -> String {
generate_connection_token()
}
#[allow(clippy::too_many_arguments)]
fn from_transport(
reader: impl AsyncRead + Unpin + Send + 'static,
writer: impl AsyncWrite + Unpin + Send + 'static,
child: Option<Child>,
cwd: PathBuf,
on_list_models: Option<Arc<dyn ListModelsHandler>>,
session_fs_configured: bool,
on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
effective_connection_token: Option<String>,
) -> Result<Self, Error> {
let (request_tx, request_rx) = mpsc::unbounded_channel::<JsonRpcRequest>();
let (notification_broadcast_tx, _) = broadcast::channel::<JsonRpcNotification>(1024);
let rpc = JsonRpcClient::new(
writer,
reader,
notification_broadcast_tx.clone(),
request_tx,
);
let pid = child.as_ref().and_then(|c| c.id());
info!(pid = ?pid, "copilot CLI client ready");
let client = Self {
inner: Arc::new(ClientInner {
child: parking_lot::Mutex::new(child),
rpc,
cwd,
request_rx: parking_lot::Mutex::new(Some(request_rx)),
notification_tx: notification_broadcast_tx,
router: router::SessionRouter::new(),
negotiated_protocol_version: OnceLock::new(),
state: parking_lot::Mutex::new(ConnectionState::Connected),
lifecycle_tx: broadcast::channel(256).0,
on_list_models,
session_fs_configured,
on_get_trace_context,
effective_connection_token,
}),
};
client.spawn_lifecycle_dispatcher();
Ok(client)
}
/// Spawns the background task that turns `session.lifecycle` notifications
/// into typed `SessionLifecycleEvent`s and republishes them on the lifecycle
/// broadcast channel.
///
/// The task captures only a clone of the lifecycle sender and its own
/// notification receiver, never the `Arc<ClientInner>`. Holding the `Arc`
/// would keep `ClientInner::notification_tx` alive from inside the task, so
/// the notification channel could never report `Closed`, the task would
/// never exit, and `Drop for ClientInner` (which kills the child process)
/// would never run.
fn spawn_lifecycle_dispatcher(&self) {
    let lifecycle_tx = self.inner.lifecycle_tx.clone();
    let mut notif_rx = self.inner.notification_tx.subscribe();
    tokio::spawn(async move {
        loop {
            let notification = match notif_rx.recv().await {
                Ok(notification) => notification,
                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                    warn!(missed = n, "lifecycle dispatcher lagged");
                    continue;
                }
                // Every notification sender is gone: the client was dropped.
                Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
            };
            if notification.method != "session.lifecycle" {
                continue;
            }
            let Some(params) = notification.params.as_ref() else {
                continue;
            };
            let event: SessionLifecycleEvent = match serde_json::from_value(params.clone()) {
                Ok(event) => event,
                Err(e) => {
                    warn!(
                        error = %e,
                        "failed to deserialize session.lifecycle notification"
                    );
                    continue;
                }
            };
            // A send error only means nobody is subscribed right now.
            let _ = lifecycle_tx.send(event);
        }
    });
}
/// Builds the base `Command` shared by the stdio and TCP spawn paths:
/// program plus `prefix_args`, environment, working directory and piped
/// stdout/stderr. Server flags and stdin are configured by the callers.
///
/// Environment precedence follows statement order: SDK-injected variables
/// are set first, then `options.env` (which can therefore override them),
/// and `options.env_remove` is applied last, so removal wins over both.
fn build_command(program: &Path, options: &ClientOptions) -> Command {
let mut command = Command::new(program);
for arg in &options.prefix_args {
command.arg(arg);
}
// SDK-injected environment variables.
if let Some(token) = &options.github_token {
command.env("COPILOT_SDK_AUTH_TOKEN", token);
}
if let Some(telemetry) = &options.telemetry {
// The mere presence of a telemetry config enables OTel in the child.
command.env("COPILOT_OTEL_ENABLED", "true");
if let Some(endpoint) = &telemetry.otlp_endpoint {
command.env("OTEL_EXPORTER_OTLP_ENDPOINT", endpoint);
}
if let Some(path) = &telemetry.file_path {
command.env("COPILOT_OTEL_FILE_EXPORTER_PATH", path);
}
if let Some(exporter) = telemetry.exporter_type {
command.env("COPILOT_OTEL_EXPORTER_TYPE", exporter.as_str());
}
if let Some(source) = &telemetry.source_name {
command.env("COPILOT_OTEL_SOURCE_NAME", source);
}
if let Some(capture) = telemetry.capture_content {
command.env(
"OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
if capture { "true" } else { "false" },
);
}
}
if let Some(home) = &options.copilot_home {
command.env("COPILOT_HOME", home);
}
if let Some(token) = &options.tcp_connection_token {
command.env("COPILOT_CONNECTION_TOKEN", token);
}
// Caller-supplied variables come after the injected ones so they override.
for (key, value) in &options.env {
command.env(key, value);
}
// Removals come last so they win over everything set above.
for key in &options.env_remove {
command.env_remove(key);
}
command
.current_dir(&options.cwd)
.stdout(Stdio::piped())
.stderr(Stdio::piped());
#[cfg(windows)]
{
use std::os::windows::process::CommandExt;
// Win32 process-creation flag of the same name.
const CREATE_NO_WINDOW: u32 = 0x08000000;
command.as_std_mut().creation_flags(CREATE_NO_WINDOW);
}
command
}
/// Returns the authentication-related CLI arguments.
///
/// With a `github_token`, the CLI is told to read it from the
/// `COPILOT_SDK_AUTH_TOKEN` environment variable. `--no-auto-login` is added
/// when use of the logged-in user is disabled, which is also the default
/// whenever a token is supplied and `use_logged_in_user` is unset.
fn auth_args(options: &ClientOptions) -> Vec<&'static str> {
    let has_token = options.github_token.is_some();
    let allow_logged_in_user = options.use_logged_in_user.unwrap_or(!has_token);

    let mut args = Vec::new();
    if has_token {
        args.extend(["--auth-token-env", "COPILOT_SDK_AUTH_TOKEN"]);
    }
    if !allow_logged_in_user {
        args.push("--no-auto-login");
    }
    args
}
/// Returns `--session-idle-timeout <secs>` when a timeout greater than zero
/// is configured, and no arguments otherwise.
fn session_idle_timeout_args(options: &ClientOptions) -> Vec<String> {
    options
        .session_idle_timeout_seconds
        .filter(|&secs| secs > 0)
        .map(|secs| vec!["--session-idle-timeout".to_string(), secs.to_string()])
        .unwrap_or_default()
}
/// Spawns the CLI in stdio server mode with stdin piped (stdout and stderr
/// are piped by `build_command`).
///
/// Argument order: server flags, log level, auth arguments, idle-timeout
/// arguments, then the caller's `extra_args`.
fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result<Child, Error> {
    info!(cwd = ?options.cwd, program = %program.display(), "spawning copilot CLI (stdio)");
    let level = options.log_level.unwrap_or(LogLevel::Info).as_str();
    let mut cmd = Self::build_command(program, options);
    cmd.arg("--server")
        .arg("--stdio")
        .arg("--no-auto-update")
        .arg("--log-level")
        .arg(level);
    cmd.args(Self::auth_args(options));
    cmd.args(Self::session_idle_timeout_args(options));
    cmd.args(&options.extra_args);
    cmd.stdin(Stdio::piped());
    let child = cmd.spawn()?;
    Ok(child)
}
async fn spawn_tcp(
program: &Path,
options: &ClientOptions,
port: u16,
) -> Result<(Child, u16), Error> {
info!(cwd = ?options.cwd, program = %program.display(), port = %port, "spawning copilot CLI (tcp)");
let mut command = Self::build_command(program, options);
let log_level = options.log_level.unwrap_or(LogLevel::Info);
command
.args([
"--server",
"--port",
&port.to_string(),
"--no-auto-update",
"--log-level",
log_level.as_str(),
])
.args(Self::auth_args(options))
.args(Self::session_idle_timeout_args(options))
.args(&options.extra_args)
.stdin(Stdio::null());
let mut child = command.spawn()?;
let stdout = child.stdout.take().expect("stdout is piped");
let (port_tx, port_rx) = oneshot::channel::<u16>();
let span = tracing::error_span!("copilot_cli_port_scan");
tokio::spawn(
async move {
let port_re = regex::Regex::new(r"listening on port (\d+)").expect("valid regex");
let mut lines = BufReader::new(stdout).lines();
let mut port_tx = Some(port_tx);
while let Ok(Some(line)) = lines.next_line().await {
debug!(line = %line, "CLI stdout");
if let Some(tx) = port_tx.take() {
if let Some(caps) = port_re.captures(&line)
&& let Some(p) =
caps.get(1).and_then(|m| m.as_str().parse::<u16>().ok())
{
let _ = tx.send(p);
continue;
}
port_tx = Some(tx);
}
}
}
.instrument(span),
);
let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx)
.await
.map_err(|_| Error::Protocol(ProtocolError::CliStartupTimeout))?
.map_err(|_| Error::Protocol(ProtocolError::CliStartupFailed))?;
info!(port = %actual_port, "CLI server listening");
Ok((child, actual_port))
}
/// Takes the child's stderr (if still present) and spawns a task that logs
/// every line at `warn` level until the stream ends or a read fails.
fn drain_stderr(child: &mut Child) {
    let Some(stderr) = child.stderr.take() else {
        return;
    };
    let span = tracing::error_span!("copilot_cli");
    let pump = async move {
        let mut lines = BufReader::new(stderr).lines();
        while let Ok(Some(line)) = lines.next_line().await {
            warn!(line = %line, "CLI stderr");
        }
    };
    tokio::spawn(pump.instrument(span));
}
/// Returns the working directory this client was created with.
pub fn cwd(&self) -> &PathBuf {
&self.inner.cwd
}
/// Returns the generated typed RPC facade bound to this client.
pub fn rpc(&self) -> crate::generated::rpc::ClientRpc<'_> {
crate::generated::rpc::ClientRpc { client: self }
}
/// Sends a JSON-RPC request and returns the raw response. Unlike `call`, an
/// error object inside the response is not turned into an `Err` here.
pub(crate) async fn send_request(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<JsonRpcResponse, Error> {
self.inner.rpc.send_request(method, params).await
}
pub async fn call(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<serde_json::Value, Error> {
let session_id: Option<SessionId> = params
.as_ref()
.and_then(|p| p.get("sessionId"))
.and_then(|v| v.as_str())
.map(SessionId::from);
let response = self.send_request(method, params).await?;
if let Some(err) = response.error {
if err.message.contains("Session not found") {
return Err(Error::Session(SessionError::NotFound(
session_id.unwrap_or_else(|| "unknown".into()),
)));
}
return Err(Error::Rpc {
code: err.code,
message: err.message,
});
}
Ok(response.result.unwrap_or(serde_json::Value::Null))
}
/// Writes a JSON-RPC response on the connection.
pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<(), Error> {
self.inner.rpc.write(response).await
}
/// Removes and returns the incoming-request receiver; later calls return
/// `None`.
#[expect(dead_code, reason = "reserved for future pub(crate) use")]
pub(crate) fn take_request_rx(&self) -> Option<mpsc::UnboundedReceiver<JsonRpcRequest>> {
self.inner.request_rx.lock().take()
}
/// Registers `session_id` with the router and returns its channels.
///
/// `ensure_started` is called first with the notification sender and the
/// request-receiver slot; presumably it starts the routing task on first use
/// (NOTE(review): confirm in `router`).
pub(crate) fn register_session(
&self,
session_id: &SessionId,
) -> crate::router::SessionChannels {
self.inner
.router
.ensure_started(&self.inner.notification_tx, &self.inner.request_rx);
self.inner.router.register(session_id)
}
/// Removes `session_id` from the router.
pub(crate) fn unregister_session(&self, session_id: &SessionId) {
self.inner.router.unregister(session_id);
}
/// Returns the protocol version recorded by `verify_protocol_version`, or
/// `None` if none has been recorded yet.
pub fn protocol_version(&self) -> Option<u32> {
self.inner.negotiated_protocol_version.get().copied()
}
/// Checks that the server speaks a protocol version this SDK supports and
/// records it.
///
/// The version comes from the `connect` handshake; if the server answers
/// that method with `METHOD_NOT_FOUND`, a `ping` is used instead. A server
/// that reports no version is accepted with a warning and nothing is
/// recorded.
///
/// # Errors
///
/// - `ProtocolError::VersionMismatch` when the version lies outside
///   `MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION`.
/// - `ProtocolError::VersionChanged` when a version was recorded earlier and
///   the server now reports a different one.
/// - Any error from the handshake or ping request.
pub async fn verify_protocol_version(&self) -> Result<(), Error> {
    let server_version = match self.connect_handshake().await {
        Ok(v) => v,
        Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => {
            self.ping(None).await?.protocol_version
        }
        Err(e) => return Err(e),
    };
    let Some(version) = server_version else {
        warn!("CLI server did not report protocolVersion; skipping version check");
        return Ok(());
    };
    if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&version) {
        return Err(Error::Protocol(ProtocolError::VersionMismatch {
            server: version,
            min: MIN_PROTOCOL_VERSION,
            max: SDK_PROTOCOL_VERSION,
        }));
    }
    // `get_or_init` records and reads in one atomic step, so two concurrent
    // verifications that see different versions cannot both succeed.
    let negotiated = *self
        .inner
        .negotiated_protocol_version
        .get_or_init(|| version);
    if negotiated != version {
        return Err(Error::Protocol(ProtocolError::VersionChanged {
            previous: negotiated,
            current: version,
        }));
    }
    Ok(())
}
/// Performs the `connect` handshake, sending the effective connection token
/// (if any), and returns the protocol version the server reported.
///
/// NOTE(review): a reported version that does not fit in `u32` becomes
/// `None`, which `verify_protocol_version` treats as "not reported" and
/// accepts with a warning. Confirm that this is intended.
async fn connect_handshake(&self) -> Result<Option<u32>, Error> {
let result = self
.rpc()
.connect(crate::generated::api_types::ConnectRequest {
token: self.inner.effective_connection_token.clone(),
})
.await?;
Ok(u32::try_from(result.protocol_version).ok())
}
/// Sends a ping, including `message` in the params when one is given, and
/// decodes the server's reply.
pub async fn ping(&self, message: Option<&str>) -> Result<crate::types::PingResponse, Error> {
    let params = message.map_or_else(
        || serde_json::json!({}),
        |m| serde_json::json!({ "message": m }),
    );
    let raw = self
        .call(generated::api_types::rpc_methods::PING, Some(params))
        .await?;
    let pong: crate::types::PingResponse = serde_json::from_value(raw)?;
    Ok(pong)
}
pub async fn list_sessions(
&self,
filter: Option<SessionListFilter>,
) -> Result<Vec<SessionMetadata>, Error> {
let params = match filter {
Some(f) => serde_json::json!({ "filter": f }),
None => serde_json::json!({}),
};
let result = self.call("session.list", Some(params)).await?;
let response: ListSessionsResponse = serde_json::from_value(result)?;
Ok(response.sessions)
}
/// Fetches metadata for `session_id` via `session.getMetadata`; yields
/// `None` when the response contains no session.
pub async fn get_session_metadata(
    &self,
    session_id: &SessionId,
) -> Result<Option<SessionMetadata>, Error> {
    let params = serde_json::json!({ "sessionId": session_id });
    let raw = self.call("session.getMetadata", Some(params)).await?;
    let decoded: GetSessionMetadataResponse = serde_json::from_value(raw)?;
    Ok(decoded.session)
}
/// Deletes `session_id` via `session.delete`; the response payload is
/// discarded.
pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), Error> {
    let params = serde_json::json!({ "sessionId": session_id });
    self.call("session.delete", Some(params))
        .await
        .map(|_| ())
}
/// Returns the id reported by `session.getLastId`, if any.
pub async fn get_last_session_id(&self) -> Result<Option<SessionId>, Error> {
    let params = serde_json::json!({});
    let raw = self.call("session.getLastId", Some(params)).await?;
    let decoded: GetLastSessionIdResponse = serde_json::from_value(raw)?;
    Ok(decoded.session_id)
}
/// Returns the id reported by `session.getForeground`, if any.
pub async fn get_foreground_session_id(&self) -> Result<Option<SessionId>, Error> {
    let params = serde_json::json!({});
    let raw = self.call("session.getForeground", Some(params)).await?;
    let decoded: GetForegroundSessionResponse = serde_json::from_value(raw)?;
    Ok(decoded.session_id)
}
/// Makes `session_id` the foreground session via `session.setForeground`;
/// the response payload is discarded.
pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<(), Error> {
    let params = serde_json::json!({ "sessionId": session_id });
    self.call("session.setForeground", Some(params))
        .await
        .map(|_| ())
}
/// Fetches the server status via `status.get`.
pub async fn get_status(&self) -> Result<GetStatusResponse, Error> {
    let raw = self.call("status.get", Some(serde_json::json!({}))).await?;
    let status: GetStatusResponse = serde_json::from_value(raw)?;
    Ok(status)
}
/// Fetches the authentication status via `auth.getStatus`.
pub async fn get_auth_status(&self) -> Result<GetAuthStatusResponse, Error> {
    let params = serde_json::json!({});
    let raw = self.call("auth.getStatus", Some(params)).await?;
    let status: GetAuthStatusResponse = serde_json::from_value(raw)?;
    Ok(status)
}
/// Lists the available models. A configured `ListModelsHandler` takes
/// precedence; otherwise the models RPC is used.
pub async fn list_models(&self) -> Result<Vec<Model>, Error> {
    match &self.inner.on_list_models {
        Some(handler) => handler.list_models().await,
        None => {
            let listing = self.rpc().models().list().await?;
            Ok(listing.models)
        }
    }
}
/// Asks the configured trace-context provider for the current context, or
/// returns the default context when no provider is installed.
pub(crate) async fn resolve_trace_context(&self) -> TraceContext {
    match &self.inner.on_get_trace_context {
        Some(provider) => provider.get_trace_context().await,
        None => TraceContext::default(),
    }
}
pub fn pid(&self) -> Option<u32> {
self.inner.child.lock().as_ref().and_then(|c| c.id())
}
pub async fn stop(&self) -> Result<(), StopErrors> {
let pid = self.pid();
info!(pid = ?pid, "stopping CLI process");
let mut errors: Vec<Error> = Vec::new();
for session_id in self.inner.router.session_ids() {
match self
.call(
"session.destroy",
Some(serde_json::json!({ "sessionId": session_id })),
)
.await
{
Ok(_) => {}
Err(e) => {
warn!(
session_id = %session_id,
error = %e,
"session.destroy failed during Client::stop",
);
errors.push(e);
}
}
self.inner.router.unregister(&session_id);
}
let child = self.inner.child.lock().take();
*self.inner.state.lock() = ConnectionState::Disconnected;
if let Some(mut child) = child
&& let Err(e) = child.kill().await
{
errors.push(Error::Io(e));
}
info!(pid = ?pid, errors = errors.len(), "CLI process stopped");
if errors.is_empty() {
Ok(())
} else {
Err(StopErrors(errors))
}
}
pub fn force_stop(&self) {
let pid = self.pid();
info!(pid = ?pid, "force-stopping CLI process");
if let Some(mut child) = self.inner.child.lock().take()
&& let Err(e) = child.start_kill()
{
error!(pid = ?pid, error = %e, "failed to send kill signal");
}
self.inner.router.clear();
*self.inner.state.lock() = ConnectionState::Disconnected;
}
/// Creates a new subscription to the session lifecycle broadcast channel.
pub fn subscribe_lifecycle(&self) -> LifecycleSubscription {
LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe())
}
/// Returns the current connection state (a copy of the stored value).
pub fn state(&self) -> ConnectionState {
*self.inner.state.lock()
}
}
impl Drop for ClientInner {
fn drop(&mut self) {
if let Some(ref mut child) = *self.child.lock() {
let pid = child.id();
if let Err(e) = child.start_kill() {
error!(pid = ?pid, error = %e, "failed to kill CLI process on drop");
} else {
info!(pid = ?pid, "kill signal sent for CLI process on drop");
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn is_transport_failure_matches_request_cancelled() {
let err = Error::Protocol(ProtocolError::RequestCancelled);
assert!(err.is_transport_failure());
}
#[test]
fn is_transport_failure_matches_io_error() {
let err = Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone"));
assert!(err.is_transport_failure());
}
#[test]
fn is_transport_failure_rejects_rpc_error() {
let err = Error::Rpc {
code: -1,
message: "bad".into(),
};
assert!(!err.is_transport_failure());
}
#[test]
fn is_transport_failure_rejects_session_error() {
let err = Error::Session(SessionError::NotFound("s1".into()));
assert!(!err.is_transport_failure());
}
/// Each `with_*` builder call should land in the matching `ClientOptions`
/// field, and chaining them must not lose earlier settings.
#[test]
fn client_options_builder_composes() {
    let program = CliProgram::Path(PathBuf::from("/usr/local/bin/copilot"));
    let built = ClientOptions::new()
        .with_program(program)
        .with_prefix_args(["node"])
        .with_cwd(PathBuf::from("/tmp"))
        .with_env([("KEY", "value")])
        .with_env_remove(["UNWANTED"])
        .with_extra_args(["--quiet"])
        .with_github_token("ghp_test")
        .with_use_logged_in_user(false)
        .with_log_level(LogLevel::Debug)
        .with_session_idle_timeout_seconds(120);

    // Process launch settings.
    assert!(matches!(built.program, CliProgram::Path(_)));
    assert_eq!(built.prefix_args, vec![OsString::from("node")]);
    assert_eq!(built.cwd, PathBuf::from("/tmp"));

    // Environment additions and removals.
    let expected_env = vec![(OsString::from("KEY"), OsString::from("value"))];
    assert_eq!(built.env, expected_env);
    assert_eq!(built.env_remove, vec![OsString::from("UNWANTED")]);

    // Remaining CLI / auth / logging settings.
    assert_eq!(built.extra_args, vec![String::from("--quiet")]);
    assert_eq!(built.github_token.as_deref(), Some("ghp_test"));
    assert_eq!(built.use_logged_in_user, Some(false));
    assert!(matches!(built.log_level, Some(LogLevel::Debug)));
    assert_eq!(built.session_idle_timeout_seconds, Some(120));
}
/// Protocol errors other than a cancelled request (here: a CLI startup
/// timeout) are not transport failures.
#[test]
fn is_transport_failure_rejects_other_protocol_errors() {
    let startup_timeout = Error::Protocol(ProtocolError::CliStartupTimeout);
    let is_transport = startup_timeout.is_transport_failure();
    assert!(!is_transport);
}
/// An `env_remove` entry for the auth-token variable must take precedence
/// over the token that `github_token` would otherwise inject.
#[test]
fn build_command_lets_env_remove_strip_injected_token() {
    let token_var = std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN");
    let opts = ClientOptions {
        env_remove: vec![token_var.to_os_string()],
        github_token: Some(String::from("secret")),
        ..Default::default()
    };
    let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
    // `get_envs` reports an explicitly removed variable as `(key, None)`,
    // so the lookup has to yield `Some(None)`: an entry exists, but it is a
    // removal rather than a value.
    let action = cmd
        .as_std()
        .get_envs()
        .find_map(|(name, value)| (name == token_var).then_some(value));
    assert_eq!(
        action,
        Some(None),
        "env_remove should win over github_token"
    );
}
/// An explicit `env` entry for the auth-token variable overrides the value
/// that `github_token` would otherwise inject.
#[test]
fn build_command_lets_env_override_injected_token() {
    let opts = ClientOptions {
        github_token: Some("from-options".to_string()),
        env: vec![(
            std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN"),
            std::ffi::OsString::from("from-env"),
        )],
        ..Default::default()
    };
    let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
    // Use the module's shared `env_value` lookup instead of re-implementing
    // the same `get_envs` scan inline, matching the other build_command tests.
    assert_eq!(
        env_value(&cmd, "COPILOT_SDK_AUTH_TOKEN"),
        Some(std::ffi::OsStr::new("from-env")),
    );
}
/// With only `github_token` set, the token is exported to the child through
/// the `COPILOT_SDK_AUTH_TOKEN` variable.
#[test]
fn build_command_injects_github_token_by_default() {
    let opts = ClientOptions {
        github_token: Some("just-the-token".to_string()),
        ..Default::default()
    };
    let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
    // Use the module's shared `env_value` lookup instead of re-implementing
    // the same `get_envs` scan inline, matching the other build_command tests.
    assert_eq!(
        env_value(&cmd, "COPILOT_SDK_AUTH_TOKEN"),
        Some(std::ffi::OsStr::new("just-the-token")),
    );
}
/// Looks up `key` among the environment changes recorded on `cmd`.
///
/// Returns the configured value, or `None` when the command has no entry
/// for `key` or the entry is an explicit removal.
fn env_value<'a>(cmd: &'a tokio::process::Command, key: &str) -> Option<&'a std::ffi::OsStr> {
    let wanted = std::ffi::OsStr::new(key);
    for (name, value) in cmd.as_std().get_envs() {
        if name == wanted {
            // `value` is already `None` for a removal entry.
            return value;
        }
    }
    None
}
/// Each `TelemetryConfig` builder call should populate its field, and a
/// populated config must not report itself as empty.
#[test]
fn telemetry_config_builder_composes() {
    let endpoint = "http://collector:4318";
    let log_path = "/var/log/copilot.jsonl";
    let built = TelemetryConfig::new()
        .with_otlp_endpoint(endpoint)
        .with_file_path(PathBuf::from(log_path))
        .with_exporter_type(OtelExporterType::OtlpHttp)
        .with_source_name("my-app")
        .with_capture_content(true);

    assert_eq!(built.otlp_endpoint.as_deref(), Some(endpoint));
    assert_eq!(built.file_path.as_deref(), Some(Path::new(log_path)));
    assert_eq!(built.exporter_type, Some(OtelExporterType::OtlpHttp));
    assert_eq!(built.source_name.as_deref(), Some("my-app"));
    assert_eq!(built.capture_content, Some(true));

    // Emptiness: populated config vs. a freshly created one.
    assert!(!built.is_empty());
    let untouched = TelemetryConfig::new();
    assert!(untouched.is_empty());
}
/// A fully populated `TelemetryConfig` should export every telemetry
/// environment variable, plus the enabling flag, on the spawned command.
#[test]
fn build_command_sets_otel_env_when_telemetry_enabled() {
    let telemetry = TelemetryConfig {
        otlp_endpoint: Some(String::from("http://collector:4318")),
        file_path: Some(PathBuf::from("/var/log/copilot.jsonl")),
        exporter_type: Some(OtelExporterType::OtlpHttp),
        source_name: Some(String::from("my-app")),
        capture_content: Some(true),
    };
    let opts = ClientOptions {
        telemetry: Some(telemetry),
        ..Default::default()
    };
    let cmd = Client::build_command(Path::new("/bin/echo"), &opts);

    let expected_env = [
        ("COPILOT_OTEL_ENABLED", "true"),
        ("OTEL_EXPORTER_OTLP_ENDPOINT", "http://collector:4318"),
        ("COPILOT_OTEL_FILE_EXPORTER_PATH", "/var/log/copilot.jsonl"),
        ("COPILOT_OTEL_EXPORTER_TYPE", "otlp-http"),
        ("COPILOT_OTEL_SOURCE_NAME", "my-app"),
        ("OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", "true"),
    ];
    for (key, expected) in expected_env {
        // The key is paired into both sides so a failing comparison shows
        // which variable was wrong.
        assert_eq!(
            (key, env_value(&cmd, key)),
            (key, Some(std::ffi::OsStr::new(expected))),
        );
    }
}
/// With default options (no telemetry config) none of the telemetry
/// environment variables may be set on the command.
#[test]
fn build_command_omits_otel_env_when_telemetry_none() {
    let otel_keys = [
        "COPILOT_OTEL_ENABLED",
        "OTEL_EXPORTER_OTLP_ENDPOINT",
        "COPILOT_OTEL_FILE_EXPORTER_PATH",
        "COPILOT_OTEL_EXPORTER_TYPE",
        "COPILOT_OTEL_SOURCE_NAME",
        "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
    ];
    let defaults = ClientOptions::default();
    let cmd = Client::build_command(Path::new("/bin/echo"), &defaults);
    for key in otel_keys {
        let found = env_value(&cmd, key);
        assert!(
            found.is_none(),
            "expected {key} to be unset when telemetry is None",
        );
    }
}
/// A telemetry config with only the endpoint set should export the enabling
/// flag and the endpoint, and leave every other telemetry variable unset.
#[test]
fn build_command_omits_unset_telemetry_fields() {
    let endpoint = "http://collector:4318";
    let endpoint_only = TelemetryConfig {
        otlp_endpoint: Some(endpoint.to_string()),
        ..Default::default()
    };
    let opts = ClientOptions {
        telemetry: Some(endpoint_only),
        ..Default::default()
    };
    let cmd = Client::build_command(Path::new("/bin/echo"), &opts);

    // Variables that must be present.
    let enabled = env_value(&cmd, "COPILOT_OTEL_ENABLED");
    assert_eq!(enabled, Some(std::ffi::OsStr::new("true")));
    let exported_endpoint = env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT");
    assert_eq!(exported_endpoint, Some(std::ffi::OsStr::new(endpoint)));

    // Variables backed by fields that were left unset.
    let unset_keys = [
        "COPILOT_OTEL_FILE_EXPORTER_PATH",
        "COPILOT_OTEL_EXPORTER_TYPE",
        "COPILOT_OTEL_SOURCE_NAME",
        "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
    ];
    for key in unset_keys {
        assert!(env_value(&cmd, key).is_none(), "{key} should be unset");
    }
}
/// When both the telemetry config and `options.env` set the OTLP endpoint,
/// the value from `options.env` must be the one on the command.
#[test]
fn build_command_lets_user_env_override_telemetry() {
    let from_config = TelemetryConfig {
        otlp_endpoint: Some(String::from("http://from-config:4318")),
        ..Default::default()
    };
    let user_env = vec![(
        OsString::from("OTEL_EXPORTER_OTLP_ENDPOINT"),
        OsString::from("http://from-user-env:4318"),
    )];
    let opts = ClientOptions {
        telemetry: Some(from_config),
        env: user_env,
        ..Default::default()
    };
    let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
    let effective = env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT");
    assert_eq!(
        effective,
        Some(std::ffi::OsStr::new("http://from-user-env:4318")),
        "user-supplied options.env should override telemetry config",
    );
}
/// `COPILOT_HOME` is exported only when a copilot home directory has been
/// configured on the options.
#[test]
fn build_command_sets_copilot_home_env_when_configured() {
    let program = Path::new("/bin/echo");

    // Configured: the variable carries the configured path.
    let configured = ClientOptions::new().with_copilot_home(PathBuf::from("/custom/copilot"));
    let with_home = Client::build_command(program, &configured);
    assert_eq!(
        env_value(&with_home, "COPILOT_HOME"),
        Some(std::ffi::OsStr::new("/custom/copilot")),
    );

    // Default options: the variable is absent.
    let defaults = ClientOptions::default();
    let without_home = Client::build_command(program, &defaults);
    assert!(env_value(&without_home, "COPILOT_HOME").is_none());
}
/// `COPILOT_CONNECTION_TOKEN` is exported only when a TCP connection token
/// has been configured on the options.
#[test]
fn build_command_sets_connection_token_env_when_configured() {
    let program = Path::new("/bin/echo");

    // Configured: the variable carries the token.
    let configured = ClientOptions::new().with_tcp_connection_token("secret-token");
    let with_token = Client::build_command(program, &configured);
    assert_eq!(
        env_value(&with_token, "COPILOT_CONNECTION_TOKEN"),
        Some(std::ffi::OsStr::new("secret-token")),
    );

    // Default options: the variable is absent.
    let defaults = ClientOptions::default();
    let without_token = Client::build_command(program, &defaults);
    assert!(env_value(&without_token, "COPILOT_CONNECTION_TOKEN").is_none());
}
#[tokio::test]
async fn start_rejects_token_with_stdio_transport() {
let opts = ClientOptions::new()
.with_tcp_connection_token("token-123")
.with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
let err = Client::start(opts).await.unwrap_err();
assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}");
let Error::InvalidConfig(msg) = err else {
unreachable!()
};
assert!(
msg.contains("Stdio"),
"error should explain the stdio incompatibility: {msg}"
);
}
#[tokio::test]
async fn start_rejects_empty_connection_token() {
let opts = ClientOptions::new()
.with_tcp_connection_token("")
.with_transport(Transport::Tcp { port: 0 })
.with_program(CliProgram::Path(PathBuf::from("/bin/echo")));
let err = Client::start(opts).await.unwrap_err();
assert!(matches!(err, Error::InvalidConfig(_)), "got {err:?}");
}
/// `capture_content` must be exported as the lowercase strings `true` /
/// `false` in `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT`.
#[test]
fn telemetry_config_capture_content_serializes_as_lowercase_bool() {
    for (flag, rendered) in [(true, "true"), (false, "false")] {
        let telemetry = TelemetryConfig {
            capture_content: Some(flag),
            ..Default::default()
        };
        let opts = ClientOptions {
            telemetry: Some(telemetry),
            ..Default::default()
        };
        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
        let exported = env_value(&cmd, "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT");
        assert_eq!(exported, Some(std::ffi::OsStr::new(rendered)));
    }
}
/// Default options produce no session-idle-timeout CLI arguments.
#[test]
fn session_idle_timeout_args_are_omitted_by_default() {
    let defaults = ClientOptions::default();
    let args = Client::session_idle_timeout_args(&defaults);
    assert!(args.is_empty());
}
/// A timeout of zero seconds produces no session-idle-timeout CLI arguments.
#[test]
fn session_idle_timeout_args_omitted_for_zero() {
    let zero_timeout = ClientOptions {
        session_idle_timeout_seconds: Some(0),
        ..Default::default()
    };
    let args = Client::session_idle_timeout_args(&zero_timeout);
    assert!(args.is_empty());
}
/// A positive timeout produces the `--session-idle-timeout` flag followed by
/// the number of seconds.
#[test]
fn session_idle_timeout_args_emit_flag_for_positive_value() {
    let five_minutes = ClientOptions {
        session_idle_timeout_seconds: Some(300),
        ..Default::default()
    };
    let expected = vec![String::from("--session-idle-timeout"), String::from("300")];
    assert_eq!(Client::session_idle_timeout_args(&five_minutes), expected);
}
/// For every `LogLevel`, the JSON encoding is the quoted `as_str()` value
/// and decoding that JSON yields the original level again.
#[test]
fn log_level_str_round_trips() {
    let every_level = [
        LogLevel::None,
        LogLevel::Error,
        LogLevel::Warning,
        LogLevel::Info,
        LogLevel::Debug,
        LogLevel::All,
    ];
    for level in every_level {
        // Encode: JSON form must be the string name wrapped in quotes.
        let encoded = serde_json::to_string(&level).unwrap();
        let s = level.as_str();
        assert_eq!(encoded, format!("\"{s}\""));

        // Decode: must give back the same variant.
        let decoded: LogLevel = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, level);
    }
}
/// The `Debug` output of `ClientOptions` must show placeholders for the
/// list-models handler and the GitHub token, never the token itself.
#[test]
fn client_options_debug_redacts_handler() {
    /// Minimal handler; only its presence matters for this test.
    struct StubHandler;

    #[async_trait]
    impl ListModelsHandler for StubHandler {
        async fn list_models(&self) -> Result<Vec<Model>, Error> {
            Ok(Vec::new())
        }
    }

    let opts = ClientOptions {
        github_token: Some(String::from("secret-token")),
        on_list_models: Some(Arc::new(StubHandler)),
        ..Default::default()
    };
    let rendered = format!("{opts:?}");

    let shows_handler_placeholder = rendered.contains("on_list_models: Some(\"<set>\")");
    assert!(shows_handler_placeholder);
    let shows_token_placeholder = rendered.contains("github_token: Some(\"<redacted>\")");
    assert!(shows_token_placeholder);
    let leaks_token = rendered.contains("secret-token");
    assert!(!leaks_token);
}
/// With an `on_list_models` handler installed, `Client::list_models` must
/// return the handler's models and call the handler exactly once.
#[tokio::test]
async fn list_models_uses_on_list_models_handler_when_set() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Handler that counts its invocations and returns a canned model list.
    struct CountingHandler {
        calls: Arc<AtomicUsize>,
        models: Vec<Model>,
    }

    #[async_trait]
    impl ListModelsHandler for CountingHandler {
        async fn list_models(&self) -> Result<Vec<Model>, Error> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.models.clone())
        }
    }

    let invocations = Arc::new(AtomicUsize::new(0));
    let byok_model = Model {
        billing: None,
        capabilities: ModelCapabilities {
            limits: None,
            supports: None,
        },
        default_reasoning_effort: None,
        id: "byok-gpt-4".into(),
        name: "BYOK GPT-4".into(),
        policy: None,
        supported_reasoning_efforts: Vec::new(),
    };
    let handler = Arc::new(CountingHandler {
        calls: Arc::clone(&invocations),
        models: vec![byok_model],
    });

    // Stand-in JSON-RPC client. The request/notification receivers and the
    // peer ends of both duplex pipes are dropped when this block ends.
    // NOTE(review): presumably the RPC layer is never exercised because the
    // handler answers the call; confirm against `Client::list_models`.
    let rpc = {
        let (req_tx, _req_rx) = mpsc::unbounded_channel();
        let (notif_tx, _notif_rx) = broadcast::channel(16);
        let (read_pipe, _write_pipe) = tokio::io::duplex(64);
        let (_unused_read, write_pipe) = tokio::io::duplex(64);
        JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
    };

    // Hand-assembled client: no child process, already marked connected.
    let client = Client {
        inner: Arc::new(ClientInner {
            child: parking_lot::Mutex::new(None),
            rpc,
            cwd: PathBuf::from("."),
            request_rx: parking_lot::Mutex::new(None),
            notification_tx: broadcast::channel(16).0,
            router: router::SessionRouter::new(),
            negotiated_protocol_version: OnceLock::new(),
            state: parking_lot::Mutex::new(ConnectionState::Connected),
            lifecycle_tx: broadcast::channel(16).0,
            on_list_models: Some(handler),
            session_fs_configured: false,
            on_get_trace_context: None,
            effective_connection_token: None,
        }),
    };

    let listed = client.list_models().await.unwrap();
    assert_eq!(listed.len(), 1);
    assert_eq!(listed[0].id, "byok-gpt-4");
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}
}