use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::canvas::{CanvasDeclaration, CanvasHandler};
use crate::generated::api_types::OpenCanvasInstance;
use crate::handler::{
AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, PermissionHandler,
UserInputHandler,
};
use crate::hooks::SessionHooks;
pub use crate::session_fs::{
DirEntry, DirEntryKind, FileInfo, FsError, SessionFsCapabilities, SessionFsConfig,
SessionFsConventions, SessionFsProvider, SessionFsSqliteProvider, SessionFsSqliteQueryResult,
SessionFsSqliteQueryType,
};
pub use crate::trace_context::{TraceContext, TraceContextProvider};
use crate::transforms::SystemMessageTransform;
/// Crate-internal connection state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// NOTE(review): `dead_code` is silenced here, presumably because not every
// variant is constructed yet — confirm and drop the allow once they are.
#[allow(dead_code)]
#[non_exhaustive]
pub(crate) enum ConnectionState {
/// Not connected.
Disconnected,
/// A connection attempt is in progress.
Connecting,
/// Connected.
Connected,
/// The connection is in an error state.
Error,
}
/// Kind of session lifecycle event.
///
/// Each variant (de)serializes as the dotted string given in its
/// `#[serde(rename = ...)]` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SessionLifecycleEventType {
/// Wire value `"session.created"`.
#[serde(rename = "session.created")]
Created,
/// Wire value `"session.deleted"`.
#[serde(rename = "session.deleted")]
Deleted,
/// Wire value `"session.updated"`.
#[serde(rename = "session.updated")]
Updated,
/// Wire value `"session.foreground"`.
#[serde(rename = "session.foreground")]
Foreground,
/// Wire value `"session.background"`.
#[serde(rename = "session.background")]
Background,
}
/// Optional metadata attached to a [`SessionLifecycleEvent`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SessionLifecycleEventMetadata {
/// Serialized as `startTime`. Carried as an opaque string here.
// NOTE(review): timestamp format is not visible in this file — confirm.
#[serde(rename = "startTime")]
pub start_time: String,
/// Serialized as `modifiedTime`. Carried as an opaque string here.
#[serde(rename = "modifiedTime")]
pub modified_time: String,
/// Optional summary; omitted from the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub summary: Option<String>,
}
/// A session lifecycle event: its kind, the session it refers to, and
/// optional metadata.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SessionLifecycleEvent {
/// Event kind; serialized under the key `type`.
#[serde(rename = "type")]
pub event_type: SessionLifecycleEventType,
/// Session the event refers to; serialized under the key `sessionId`.
#[serde(rename = "sessionId")]
pub session_id: SessionId,
/// Optional metadata; omitted from the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<SessionLifecycleEventMetadata>,
}
/// Newtype around a `String` identifying a session.
///
/// `#[serde(transparent)]` makes it (de)serialize as the bare string.
/// The derived `Default` is the empty string.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct SessionId(String);
impl SessionId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
pub fn into_inner(self) -> String {
self.0
}
}
/// Lets a `SessionId` be used wherever a `&str` is expected.
impl std::ops::Deref for SessionId {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
/// Writes the raw identifier verbatim (formatter width/padding flags are not applied).
impl std::fmt::Display for SessionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: &str = self.0.as_str();
        f.write_str(text)
    }
}
impl From<String> for SessionId {
fn from(s: String) -> Self {
Self(s)
}
}
impl From<&str> for SessionId {
fn from(s: &str) -> Self {
Self(s.to_owned())
}
}
/// Cheap borrowed view of the identifier as `&str`.
impl AsRef<str> for SessionId {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Borrows the identifier as `str`, which lets `&str` keys look up
/// `SessionId`-keyed hash collections.
impl std::borrow::Borrow<str> for SessionId {
    fn borrow(&self) -> &str {
        self.0.as_str()
    }
}
impl From<SessionId> for String {
fn from(id: SessionId) -> String {
id.0
}
}
/// `SessionId == str`
impl PartialEq<str> for SessionId {
    fn eq(&self, other: &str) -> bool {
        self.0.as_str() == other
    }
}

/// `SessionId == String`
impl PartialEq<String> for SessionId {
    fn eq(&self, other: &String) -> bool {
        self.0.as_str() == other.as_str()
    }
}

/// `String == SessionId`
impl PartialEq<SessionId> for String {
    fn eq(&self, other: &SessionId) -> bool {
        self.as_str() == other.0.as_str()
    }
}

/// `SessionId == &str`
impl PartialEq<&str> for SessionId {
    fn eq(&self, other: &&str) -> bool {
        let rhs: &str = other;
        self.0.as_str() == rhs
    }
}

/// `SessionId == &SessionId`
impl PartialEq<&SessionId> for SessionId {
    fn eq(&self, other: &&SessionId) -> bool {
        self.0.as_str() == other.0.as_str()
    }
}

/// `&SessionId == SessionId`
impl PartialEq<SessionId> for &SessionId {
    fn eq(&self, other: &SessionId) -> bool {
        self.0.as_str() == other.0.as_str()
    }
}
/// Newtype around a `String` identifying a request.
///
/// `#[serde(transparent)]` makes it (de)serialize as the bare string.
/// The derived `Default` is the empty string.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct RequestId(String);
impl RequestId {
    /// Builds a `RequestId` from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Borrows the identifier as a string slice.
    ///
    /// Mirrors `SessionId::as_str` so both identifier newtypes offer the
    /// same accessors.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Consumes the wrapper and returns the owned `String`.
    pub fn into_inner(self) -> String {
        self.0
    }
}
/// Lets a `RequestId` be used wherever a `&str` is expected.
impl std::ops::Deref for RequestId {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
/// Writes the raw identifier verbatim (formatter width/padding flags are not applied).
impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: &str = self.0.as_str();
        f.write_str(text)
    }
}
impl From<String> for RequestId {
fn from(s: String) -> Self {
Self(s)
}
}
impl From<&str> for RequestId {
fn from(s: &str) -> Self {
Self(s.to_owned())
}
}
/// Cheap borrowed view of the identifier as `&str`.
impl AsRef<str> for RequestId {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
/// Borrows the identifier as `str`, which lets `&str` keys look up
/// `RequestId`-keyed hash collections.
impl std::borrow::Borrow<str> for RequestId {
    fn borrow(&self) -> &str {
        self.0.as_str()
    }
}
impl From<RequestId> for String {
fn from(id: RequestId) -> String {
id.0
}
}
/// `RequestId == str`
impl PartialEq<str> for RequestId {
    fn eq(&self, other: &str) -> bool {
        self.0.as_str() == other
    }
}

/// `RequestId == String`
impl PartialEq<String> for RequestId {
    fn eq(&self, other: &String) -> bool {
        self.0.as_str() == other.as_str()
    }
}

/// `String == RequestId`
impl PartialEq<RequestId> for String {
    fn eq(&self, other: &RequestId) -> bool {
        self.as_str() == other.0.as_str()
    }
}

/// `RequestId == &str`
impl PartialEq<&str> for RequestId {
    fn eq(&self, other: &&str) -> bool {
        let rhs: &str = other;
        self.0.as_str() == rhs
    }
}
/// Declaration of a tool, optionally paired with an in-process handler.
///
/// Field names serialize in camelCase. The `handler` is never
/// (de)serialized (`#[serde(skip)]`).
#[derive(Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct Tool {
/// Tool name (required when deserializing).
pub name: String,
/// Optional namespaced name; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub namespaced_name: Option<String>,
/// Description; defaults to the empty string when absent on input.
#[serde(default)]
pub description: String,
/// Optional instructions; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Parameter map; omitted from output when empty.
// NOTE(review): presumably a JSON-Schema object split into top-level keys — confirm against `crate::tool::tool_parameters`.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub parameters: HashMap<String, Value>,
/// Serialized only when `true`.
#[serde(default, skip_serializing_if = "is_false")]
pub overrides_built_in_tool: bool,
/// Serialized only when `true`.
#[serde(default, skip_serializing_if = "is_false")]
pub skip_permission: bool,
/// In-process handler; skipped by serde, so it is `None` after deserialization.
#[serde(skip)]
pub(crate) handler: Option<Arc<dyn crate::tool::ToolHandler>>,
}
/// Serde `skip_serializing_if` predicate: returns `true` when the flag is `false`.
#[inline]
fn is_false(b: &bool) -> bool {
    matches!(*b, false)
}
impl Tool {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
..Default::default()
}
}
pub fn with_namespaced_name(mut self, namespaced_name: impl Into<String>) -> Self {
self.namespaced_name = Some(namespaced_name.into());
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = description.into();
self
}
pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
self.instructions = Some(instructions.into());
self
}
pub fn with_parameters(mut self, parameters: Value) -> Self {
self.parameters = crate::tool::tool_parameters(parameters);
self
}
pub fn with_overrides_built_in_tool(mut self, overrides: bool) -> Self {
self.overrides_built_in_tool = overrides;
self
}
pub fn with_skip_permission(mut self, skip: bool) -> Self {
self.skip_permission = skip;
self
}
pub fn with_handler(mut self, handler: Arc<dyn crate::tool::ToolHandler>) -> Self {
self.handler = Some(handler);
self
}
pub fn handler(&self) -> Option<&Arc<dyn crate::tool::ToolHandler>> {
self.handler.as_ref()
}
}
/// Manual `Debug`: the handler is a trait object, so only its presence is
/// shown (as the string `"<set>"` or `"None"`).
impl std::fmt::Debug for Tool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handler_state: &str = if self.handler.is_some() {
            "<set>"
        } else {
            "None"
        };

        let mut out = f.debug_struct("Tool");
        out.field("name", &self.name);
        out.field("namespaced_name", &self.namespaced_name);
        out.field("description", &self.description);
        out.field("instructions", &self.instructions);
        out.field("parameters", &self.parameters);
        out.field("overrides_built_in_tool", &self.overrides_built_in_tool);
        out.field("skip_permission", &self.skip_permission);
        out.field("handler", &handler_state);
        out.finish()
    }
}
/// Data passed to [`CommandHandler::on_command`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CommandContext {
/// Session the command belongs to.
pub session_id: SessionId,
/// The command string.
// NOTE(review): presumably the full text as typed (name plus args) — confirm against the dispatcher.
pub command: String,
/// Name of the command.
pub command_name: String,
/// Argument text.
pub args: String,
}
/// Async callback for a command; implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait CommandHandler: Send + Sync {
/// Handles one command invocation described by `ctx`.
///
/// # Errors
/// Returns `crate::Error` when the implementation fails.
async fn on_command(&self, ctx: CommandContext) -> Result<(), crate::Error>;
}
/// A command: its name, an optional description, and the handler to run.
///
/// `Debug` and `Serialize` are implemented by hand because the handler is a
/// trait object.
#[non_exhaustive]
#[derive(Clone)]
pub struct CommandDefinition {
/// Command name.
pub name: String,
/// Optional description.
pub description: Option<String>,
/// Handler for this command.
pub handler: Arc<dyn CommandHandler>,
}
impl CommandDefinition {
pub fn new(name: impl Into<String>, handler: Arc<dyn CommandHandler>) -> Self {
Self {
name: name.into(),
description: None,
handler,
}
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
}
/// Manual `Debug`: the handler is always present and is shown as `"<set>"`.
impl std::fmt::Debug for CommandDefinition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CommandDefinition");
        out.field("name", &self.name);
        out.field("description", &self.description);
        out.field("handler", &"<set>");
        out.finish()
    }
}
/// Serializes only `name` and, when present, `description`; the handler is
/// never written.
impl Serialize for CommandDefinition {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;

        // One field for `name`, plus one more when a description exists.
        let field_count = 1 + usize::from(self.description.is_some());
        let mut out = serializer.serialize_struct("CommandDefinition", field_count)?;
        out.serialize_field("name", &self.name)?;
        if let Some(text) = self.description.as_ref() {
            out.serialize_field("description", text)?;
        }
        out.end()
    }
}
/// Configuration of a custom agent.
///
/// Field names serialize in camelCase; every `Option` field is omitted from
/// output when `None` and defaults to `None` on input. `name` and `prompt`
/// are required on input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct CustomAgentConfig {
/// Agent name.
pub name: String,
/// Optional display name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub display_name: Option<String>,
/// Optional description.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Optional list of tool names.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<String>>,
/// Agent prompt.
pub prompt: String,
/// Optional MCP server configurations keyed by name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mcp_servers: Option<HashMap<String, McpServerConfig>>,
/// Optional `infer` flag.
// NOTE(review): semantics of `infer` are not visible in this file — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub infer: Option<bool>,
/// Optional list of skill names.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub skills: Option<Vec<String>>,
/// Optional model name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
}
impl CustomAgentConfig {
pub fn new(name: impl Into<String>, prompt: impl Into<String>) -> Self {
Self {
name: name.into(),
prompt: prompt.into(),
..Self::default()
}
}
pub fn with_display_name(mut self, display_name: impl Into<String>) -> Self {
self.display_name = Some(display_name.into());
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn with_tools<I, S>(mut self, tools: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.tools = Some(tools.into_iter().map(Into::into).collect());
self
}
pub fn with_mcp_servers(mut self, mcp_servers: HashMap<String, McpServerConfig>) -> Self {
self.mcp_servers = Some(mcp_servers);
self
}
pub fn with_infer(mut self, infer: bool) -> Self {
self.infer = Some(infer);
self
}
pub fn with_skills<I, S>(mut self, skills: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.skills = Some(skills.into_iter().map(Into::into).collect());
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = Some(model.into());
self
}
}
/// Settings for the default agent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DefaultAgentConfig {
/// Optional list of excluded tool names; serialized as `excludedTools`
/// and omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub excluded_tools: Option<Vec<String>>,
}
/// Settings for infinite sessions.
///
/// Field names serialize in camelCase; every field is optional and omitted
/// from output when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct InfiniteSessionConfig {
/// Optional on/off switch.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enabled: Option<bool>,
/// Optional threshold for background compaction.
// NOTE(review): presumably a 0.0–1.0 fraction of the context window — confirm; no range is enforced here.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub background_compaction_threshold: Option<f64>,
/// Optional threshold for buffer exhaustion.
// NOTE(review): same unit question as above — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub buffer_exhaustion_threshold: Option<f64>,
}
impl InfiniteSessionConfig {
    /// Creates a config with every field unset (same as `Default`).
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the `enabled` flag.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self {
            enabled: Some(enabled),
            ..self
        }
    }

    /// Sets the background compaction threshold (stored as given, not validated).
    pub fn with_background_compaction_threshold(self, threshold: f64) -> Self {
        Self {
            background_compaction_threshold: Some(threshold),
            ..self
        }
    }

    /// Sets the buffer exhaustion threshold (stored as given, not validated).
    pub fn with_buffer_exhaustion_threshold(self, threshold: f64) -> Self {
        Self {
            buffer_exhaustion_threshold: Some(threshold),
            ..self
        }
    }
}
/// Repository reference used by [`CloudSessionOptions`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct CloudSessionRepository {
/// Repository owner.
pub owner: String,
/// Repository name.
pub name: String,
/// Optional branch; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub branch: Option<String>,
}
impl CloudSessionRepository {
pub fn new(owner: impl Into<String>, name: impl Into<String>) -> Self {
Self {
owner: owner.into(),
name: name.into(),
branch: None,
}
}
pub fn with_branch(mut self, branch: impl Into<String>) -> Self {
self.branch = Some(branch.into());
self
}
}
/// Options for a cloud session.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct CloudSessionOptions {
/// Optional repository; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub repository: Option<CloudSessionRepository>,
}
impl CloudSessionOptions {
pub fn with_repository(repository: CloudSessionRepository) -> Self {
Self {
repository: Some(repository),
}
}
}
/// Identifies an extension by source and name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ExtensionInfo {
/// Extension source.
pub source: String,
/// Extension name.
pub name: String,
}
impl ExtensionInfo {
pub fn new(source: impl Into<String>, name: impl Into<String>) -> Self {
Self {
source: source.into(),
name: name.into(),
}
}
}
/// MCP server configuration, internally tagged by a `type` field.
///
/// Variant names serialize in lowercase (`"stdio"`, `"http"`, `"sse"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
#[non_exhaustive]
pub enum McpServerConfig {
/// Stdio server. Serialized as `"stdio"`; `"local"` is also accepted when
/// deserializing.
#[serde(alias = "local")]
Stdio(McpStdioServerConfig),
/// HTTP server.
Http(McpHttpServerConfig),
/// SSE server; shares the payload shape of [`McpServerConfig::Http`].
Sse(McpHttpServerConfig),
}
/// Payload of [`McpServerConfig::Stdio`]: a command, its arguments,
/// environment and working directory.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpStdioServerConfig {
/// Optional list of tool names; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<String>>,
/// Optional timeout; omitted from output when `None`.
// NOTE(review): unit (ms vs s) is not visible in this file — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timeout: Option<i64>,
/// Command (required on input).
pub command: String,
/// Arguments; omitted from output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub args: Vec<String>,
/// Environment variables; omitted from output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub env: HashMap<String, String>,
/// Optional working directory; serialized under the key `cwd`.
#[serde(default, skip_serializing_if = "Option::is_none", rename = "cwd")]
pub working_directory: Option<String>,
}
/// Payload of [`McpServerConfig::Http`] and [`McpServerConfig::Sse`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpHttpServerConfig {
/// Optional list of tool names; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<String>>,
/// Optional timeout; omitted from output when `None`.
// NOTE(review): unit (ms vs s) is not visible in this file — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timeout: Option<i64>,
/// Server URL (required on input).
pub url: String,
/// Headers; omitted from output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub headers: HashMap<String, String>,
}
/// Configuration of a model provider.
///
/// Field names serialize in camelCase; every `Option` field is omitted from
/// output when `None` and defaults to `None` on input. `base_url` is
/// required on input.
///
/// `Debug` is implemented by hand so that `api_key` and `bearer_token` are
/// never printed.
#[derive(Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ProviderConfig {
    /// Optional provider type; serialized under the key `type`.
    #[serde(default, skip_serializing_if = "Option::is_none", rename = "type")]
    pub provider_type: Option<String>,
    /// Optional wire API name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub wire_api: Option<String>,
    /// Base URL of the provider.
    pub base_url: String,
    /// Optional API key (redacted in `Debug` output).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Optional bearer token (redacted in `Debug` output).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bearer_token: Option<String>,
    /// Optional Azure-specific options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub azure: Option<AzureProviderOptions>,
    /// Optional extra headers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
    /// Optional model id.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,
    /// Optional wire model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub wire_model: Option<String>,
    /// Optional prompt token limit.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_prompt_tokens: Option<i64>,
    /// Optional output token limit.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<i64>,
}

impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("provider_type", &self.provider_type)
            .field("wire_api", &self.wire_api)
            .field("base_url", &self.base_url)
            // Credentials: show presence only, never the value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("azure", &self.azure)
            .field("headers", &self.headers)
            .field("model_id", &self.model_id)
            .field("wire_model", &self.wire_model)
            .field("max_prompt_tokens", &self.max_prompt_tokens)
            .field("max_output_tokens", &self.max_output_tokens)
            .finish()
    }
}
impl ProviderConfig {
pub fn new(base_url: impl Into<String>) -> Self {
Self {
base_url: base_url.into(),
..Self::default()
}
}
pub fn with_provider_type(mut self, provider_type: impl Into<String>) -> Self {
self.provider_type = Some(provider_type.into());
self
}
pub fn with_wire_api(mut self, wire_api: impl Into<String>) -> Self {
self.wire_api = Some(wire_api.into());
self
}
pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
pub fn with_bearer_token(mut self, bearer_token: impl Into<String>) -> Self {
self.bearer_token = Some(bearer_token.into());
self
}
pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self {
self.azure = Some(azure);
self
}
pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
self.headers = Some(headers);
self
}
pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
self.model_id = Some(model_id.into());
self
}
pub fn with_wire_model(mut self, wire_model: impl Into<String>) -> Self {
self.wire_model = Some(wire_model.into());
self
}
pub fn with_max_prompt_tokens(mut self, max: i64) -> Self {
self.max_prompt_tokens = Some(max);
self
}
pub fn with_max_output_tokens(mut self, max: i64) -> Self {
self.max_output_tokens = Some(max);
self
}
}
/// Azure-specific provider options, used by `ProviderConfig::azure`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AzureProviderOptions {
/// Optional API version; serialized as `apiVersion`, omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub api_version: Option<String>,
}
/// Options for creating a session.
///
/// Every field is optional and `Default` leaves all of them `None`.
/// `into_wire` splits this value into the serializable session-create
/// payload and a `SessionConfigRuntime` that keeps the in-process
/// handlers.
#[derive(Clone)]
#[non_exhaustive]
pub struct SessionConfig {
/// Session id stored on the config.
// NOTE(review): `into_wire` does not read this field; it sends the `session_id` argument it is given. Presumably the caller resolves one from the other — confirm.
pub session_id: Option<SessionId>,
// ---- Plain options: forwarded unchanged into the wire payload by `into_wire`. ----
pub model: Option<String>,
pub client_name: Option<String>,
pub reasoning_effort: Option<String>,
pub streaming: Option<bool>,
pub system_message: Option<SystemMessageConfig>,
/// Tool declarations. `into_wire` takes each tool's handler out into the
/// runtime handler map (keyed by tool name) before forwarding the list.
pub tools: Option<Vec<Tool>>,
/// Canvas declarations; forwarded into the wire payload.
pub canvases: Option<Vec<CanvasDeclaration>>,
/// Canvas handler; kept in the runtime half, not sent on the wire.
pub canvas_handler: Option<Arc<dyn CanvasHandler>>,
pub request_canvas_renderer: Option<bool>,
pub request_extensions: Option<bool>,
pub extension_info: Option<ExtensionInfo>,
pub available_tools: Option<Vec<String>>,
pub excluded_tools: Option<Vec<String>>,
pub mcp_servers: Option<HashMap<String, McpServerConfig>>,
pub enable_config_discovery: Option<bool>,
pub skill_directories: Option<Vec<PathBuf>>,
pub instruction_directories: Option<Vec<PathBuf>>,
pub disabled_skills: Option<Vec<String>>,
/// Hooks flag stored on the config.
// NOTE(review): `into_wire` ignores this field and derives the wire `hooks` flag from `hooks_handler.is_some()` — confirm that is intended.
pub hooks: Option<bool>,
pub custom_agents: Option<Vec<CustomAgentConfig>>,
pub default_agent: Option<DefaultAgentConfig>,
pub agent: Option<String>,
pub infinite_sessions: Option<InfiniteSessionConfig>,
pub provider: Option<ProviderConfig>,
pub enable_session_telemetry: Option<bool>,
pub model_capabilities: Option<crate::generated::api_types::ModelCapabilitiesOverride>,
pub config_dir: Option<PathBuf>,
pub working_directory: Option<PathBuf>,
/// GitHub token; forwarded on the wire and shown as `<redacted>` by `Debug`.
pub github_token: Option<String>,
pub remote_session: Option<crate::generated::api_types::RemoteSessionMode>,
pub cloud: Option<CloudSessionOptions>,
pub include_sub_agent_streaming_events: Option<bool>,
/// Commands. Name and description go on the wire; the full definitions
/// (with handlers) are kept in the runtime half.
pub commands: Option<Vec<CommandDefinition>>,
// ---- In-process callbacks: kept in the runtime half, never serialized. ----
pub session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
/// Permission handler; its presence (or a permission policy) sets the
/// wire `request_permission` flag.
pub permission_handler: Option<Arc<dyn PermissionHandler>>,
/// Presence sets the wire `request_elicitation` flag.
pub elicitation_handler: Option<Arc<dyn ElicitationHandler>>,
/// Presence sets the wire `request_user_input` flag.
pub user_input_handler: Option<Arc<dyn UserInputHandler>>,
/// Presence sets the wire `request_exit_plan_mode` flag.
pub exit_plan_mode_handler: Option<Arc<dyn ExitPlanModeHandler>>,
/// Presence sets the wire `request_auto_mode_switch` flag.
pub auto_mode_switch_handler: Option<Arc<dyn AutoModeSwitchHandler>>,
/// Presence sets the wire `hooks` flag.
pub hooks_handler: Option<Arc<dyn SessionHooks>>,
/// Permission policy set through `approve_all_permissions`,
/// `deny_all_permissions` or `approve_permissions_if`.
pub(crate) permission_policy: Option<crate::permission::Policy>,
pub system_message_transform: Option<Arc<dyn SystemMessageTransform>>,
}
/// Manual `Debug`: trait-object fields are shown as `Some("<set>")`/`None`,
/// `github_token` as `Some("<redacted>")`/`None`, and `permission_policy`
/// is not printed.
impl std::fmt::Debug for SessionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Maps an optional value to an optional fixed label so that only
        // its presence is visible in the output.
        fn label<T>(slot: &Option<T>, text: &'static str) -> Option<&'static str> {
            slot.as_ref().map(|_| text)
        }
        const SET: &str = "<set>";

        let mut out = f.debug_struct("SessionConfig");

        out.field("session_id", &self.session_id)
            .field("model", &self.model)
            .field("client_name", &self.client_name)
            .field("reasoning_effort", &self.reasoning_effort)
            .field("streaming", &self.streaming)
            .field("system_message", &self.system_message)
            .field("tools", &self.tools)
            .field("canvases", &self.canvases)
            .field("canvas_handler", &label(&self.canvas_handler, SET));

        out.field("request_canvas_renderer", &self.request_canvas_renderer)
            .field("request_extensions", &self.request_extensions)
            .field("extension_info", &self.extension_info)
            .field("available_tools", &self.available_tools)
            .field("excluded_tools", &self.excluded_tools)
            .field("mcp_servers", &self.mcp_servers)
            .field("enable_config_discovery", &self.enable_config_discovery)
            .field("skill_directories", &self.skill_directories)
            .field("instruction_directories", &self.instruction_directories)
            .field("disabled_skills", &self.disabled_skills)
            .field("hooks", &self.hooks)
            .field("custom_agents", &self.custom_agents)
            .field("default_agent", &self.default_agent)
            .field("agent", &self.agent)
            .field("infinite_sessions", &self.infinite_sessions)
            .field("provider", &self.provider)
            .field("enable_session_telemetry", &self.enable_session_telemetry)
            .field("model_capabilities", &self.model_capabilities)
            .field("config_dir", &self.config_dir)
            .field("working_directory", &self.working_directory)
            .field("github_token", &label(&self.github_token, "<redacted>"));

        out.field("remote_session", &self.remote_session)
            .field("cloud", &self.cloud)
            .field(
                "include_sub_agent_streaming_events",
                &self.include_sub_agent_streaming_events,
            )
            .field("commands", &self.commands);

        out.field("session_fs_provider", &label(&self.session_fs_provider, SET))
            .field("permission_handler", &label(&self.permission_handler, SET))
            .field("elicitation_handler", &label(&self.elicitation_handler, SET))
            .field("user_input_handler", &label(&self.user_input_handler, SET))
            .field(
                "exit_plan_mode_handler",
                &label(&self.exit_plan_mode_handler, SET),
            )
            .field(
                "auto_mode_switch_handler",
                &label(&self.auto_mode_switch_handler, SET),
            )
            .field("hooks_handler", &label(&self.hooks_handler, SET))
            .field(
                "system_message_transform",
                &label(&self.system_message_transform, SET),
            );

        out.finish()
    }
}
/// Default configuration: every field is `None`.
impl Default for SessionConfig {
// Fields are listed one by one, so adding a field to `SessionConfig`
// without updating this initializer is a compile error.
fn default() -> Self {
Self {
session_id: None,
model: None,
client_name: None,
reasoning_effort: None,
streaming: None,
system_message: None,
tools: None,
canvases: None,
canvas_handler: None,
request_canvas_renderer: None,
request_extensions: None,
extension_info: None,
available_tools: None,
excluded_tools: None,
mcp_servers: None,
enable_config_discovery: None,
skill_directories: None,
instruction_directories: None,
disabled_skills: None,
hooks: None,
custom_agents: None,
default_agent: None,
agent: None,
infinite_sessions: None,
provider: None,
enable_session_telemetry: None,
model_capabilities: None,
config_dir: None,
working_directory: None,
github_token: None,
remote_session: None,
cloud: None,
include_sub_agent_streaming_events: None,
commands: None,
session_fs_provider: None,
permission_handler: None,
elicitation_handler: None,
user_input_handler: None,
exit_plan_mode_handler: None,
auto_mode_switch_handler: None,
hooks_handler: None,
permission_policy: None,
system_message_transform: None,
}
}
}
/// In-process half of a `SessionConfig`, produced by
/// `SessionConfig::into_wire`: the handlers and providers that are kept
/// locally instead of being sent in the session-create payload.
pub(crate) struct SessionConfigRuntime {
pub permission_handler: Option<Arc<dyn PermissionHandler>>,
pub permission_policy: Option<crate::permission::Policy>,
pub elicitation_handler: Option<Arc<dyn ElicitationHandler>>,
pub user_input_handler: Option<Arc<dyn UserInputHandler>>,
pub exit_plan_mode_handler: Option<Arc<dyn ExitPlanModeHandler>>,
pub auto_mode_switch_handler: Option<Arc<dyn AutoModeSwitchHandler>>,
pub hooks_handler: Option<Arc<dyn SessionHooks>>,
pub system_message_transform: Option<Arc<dyn SystemMessageTransform>>,
/// Tool handlers keyed by tool name, taken out of `SessionConfig::tools`.
pub tool_handlers: HashMap<String, Arc<dyn crate::tool::ToolHandler>>,
pub canvas_handler: Option<Arc<dyn CanvasHandler>>,
pub session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
/// Full command definitions, including their handlers.
pub commands: Option<Vec<CommandDefinition>>,
}
impl SessionConfig {
    /// Splits the configuration into the serializable session-create payload
    /// and the in-process runtime half.
    ///
    /// * `session_id` is the id written into the payload; the config's own
    ///   `session_id` field is not read here.
    /// * The `request_*` flags and `hooks` in the payload are derived from
    ///   whether the corresponding handler (or, for permissions, a policy)
    ///   is set.
    /// * Each tool's handler is taken out of the tool and stored in
    ///   `tool_handlers` under the tool's name.
    ///
    /// # Errors
    /// Returns `crate::Error::InvalidConfig` when two tools that both carry
    /// a handler share the same name.
    pub(crate) fn into_wire(
        mut self,
        session_id: SessionId,
    ) -> Result<(crate::wire::SessionCreateWire, SessionConfigRuntime), crate::Error> {
        // Capability flags must be computed before the handlers are moved
        // into the runtime struct below.
        let permission_active =
            self.permission_handler.is_some() || self.permission_policy.is_some();
        let request_user_input = self.user_input_handler.is_some();
        let request_exit_plan_mode = self.exit_plan_mode_handler.is_some();
        let request_auto_mode_switch = self.auto_mode_switch_handler.is_some();
        let request_elicitation = self.elicitation_handler.is_some();
        let hooks_flag = self.hooks_handler.is_some();

        let mut tool_handlers: HashMap<String, Arc<dyn crate::tool::ToolHandler>> = HashMap::new();
        if let Some(tools) = self.tools.as_mut() {
            for tool in tools.iter_mut() {
                // `insert` returns the previous value, so `is_some()` means a
                // handler was already registered under this name.
                if let Some(handler) = tool.handler.take()
                    && tool_handlers.insert(tool.name.clone(), handler).is_some()
                {
                    return Err(crate::Error::InvalidConfig(format!(
                        "duplicate tool handler registered for name {:?}",
                        tool.name
                    )));
                }
            }
        }

        // Borrowed, because the full definitions move into the runtime half.
        let wire_commands = self.commands.as_ref().map(|cmds| {
            cmds.iter()
                .map(|c| crate::wire::CommandWireDefinition {
                    name: c.name.clone(),
                    description: c.description.clone(),
                })
                .collect()
        });

        let wire = crate::wire::SessionCreateWire {
            session_id,
            model: self.model,
            client_name: self.client_name,
            reasoning_effort: self.reasoning_effort,
            streaming: self.streaming,
            system_message: self.system_message,
            tools: self.tools,
            // `self` is owned and `canvases` is used nowhere else: move, don't clone.
            canvases: self.canvases,
            request_canvas_renderer: self.request_canvas_renderer,
            request_extensions: self.request_extensions,
            extension_info: self.extension_info,
            available_tools: self.available_tools,
            excluded_tools: self.excluded_tools,
            mcp_servers: self.mcp_servers,
            env_value_mode: "direct",
            enable_config_discovery: self.enable_config_discovery,
            request_user_input,
            request_permission: permission_active,
            request_exit_plan_mode,
            request_auto_mode_switch,
            request_elicitation,
            hooks: hooks_flag,
            skill_directories: self.skill_directories,
            instruction_directories: self.instruction_directories,
            disabled_skills: self.disabled_skills,
            custom_agents: self.custom_agents,
            default_agent: self.default_agent,
            agent: self.agent,
            infinite_sessions: self.infinite_sessions,
            provider: self.provider,
            enable_session_telemetry: self.enable_session_telemetry,
            model_capabilities: self.model_capabilities,
            config_dir: self.config_dir,
            working_directory: self.working_directory,
            github_token: self.github_token,
            remote_session: self.remote_session,
            cloud: self.cloud,
            include_sub_agent_streaming_events: self.include_sub_agent_streaming_events,
            commands: wire_commands,
        };

        let runtime = SessionConfigRuntime {
            permission_handler: self.permission_handler,
            permission_policy: self.permission_policy,
            elicitation_handler: self.elicitation_handler,
            user_input_handler: self.user_input_handler,
            exit_plan_mode_handler: self.exit_plan_mode_handler,
            auto_mode_switch_handler: self.auto_mode_switch_handler,
            hooks_handler: self.hooks_handler,
            system_message_transform: self.system_message_transform,
            tool_handlers,
            // Moved for the same reason as `canvases` above.
            canvas_handler: self.canvas_handler,
            session_fs_provider: self.session_fs_provider,
            commands: self.commands,
        };

        Ok((wire, runtime))
    }

    /// Sets the permission handler.
    pub fn with_permission_handler(mut self, handler: Arc<dyn PermissionHandler>) -> Self {
        self.permission_handler = Some(handler);
        self
    }

    /// Sets the elicitation handler.
    pub fn with_elicitation_handler(mut self, handler: Arc<dyn ElicitationHandler>) -> Self {
        self.elicitation_handler = Some(handler);
        self
    }

    /// Sets the user-input handler.
    pub fn with_user_input_handler(mut self, handler: Arc<dyn UserInputHandler>) -> Self {
        self.user_input_handler = Some(handler);
        self
    }

    /// Sets the exit-plan-mode handler.
    pub fn with_exit_plan_mode_handler(mut self, handler: Arc<dyn ExitPlanModeHandler>) -> Self {
        self.exit_plan_mode_handler = Some(handler);
        self
    }

    /// Sets the auto-mode-switch handler.
    pub fn with_auto_mode_switch_handler(
        mut self,
        handler: Arc<dyn AutoModeSwitchHandler>,
    ) -> Self {
        self.auto_mode_switch_handler = Some(handler);
        self
    }

    /// Sets the command definitions.
    pub fn with_commands(mut self, commands: Vec<CommandDefinition>) -> Self {
        self.commands = Some(commands);
        self
    }

    /// Sets the session filesystem provider.
    pub fn with_session_fs_provider(mut self, provider: Arc<dyn SessionFsProvider>) -> Self {
        self.session_fs_provider = Some(provider);
        self
    }

    /// Sets the hooks handler (stored in `hooks_handler`; the `hooks` flag
    /// field is left untouched).
    pub fn with_hooks(mut self, hooks: Arc<dyn SessionHooks>) -> Self {
        self.hooks_handler = Some(hooks);
        self
    }

    /// Sets the system-message transform.
    pub fn with_system_message_transform(
        mut self,
        transform: Arc<dyn SystemMessageTransform>,
    ) -> Self {
        self.system_message_transform = Some(transform);
        self
    }

    /// Sets the permission policy to `Policy::ApproveAll`, replacing any
    /// previously set policy.
    pub fn approve_all_permissions(mut self) -> Self {
        self.permission_policy = Some(crate::permission::Policy::ApproveAll);
        self
    }

    /// Sets the permission policy to `Policy::DenyAll`, replacing any
    /// previously set policy.
    pub fn deny_all_permissions(mut self) -> Self {
        self.permission_policy = Some(crate::permission::Policy::DenyAll);
        self
    }

    /// Sets the permission policy to `Policy::Predicate` wrapping
    /// `predicate`, replacing any previously set policy.
    pub fn approve_permissions_if<F>(mut self, predicate: F) -> Self
    where
        F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static,
    {
        self.permission_policy = Some(crate::permission::Policy::Predicate(Arc::new(predicate)));
        self
    }

    /// Sets the session id stored on the config.
    pub fn with_session_id(mut self, id: impl Into<SessionId>) -> Self {
        self.session_id = Some(id.into());
        self
    }

    /// Sets the model name.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the client name.
    pub fn with_client_name(mut self, name: impl Into<String>) -> Self {
        self.client_name = Some(name.into());
        self
    }

    /// Sets the reasoning effort.
    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
        self.reasoning_effort = Some(effort.into());
        self
    }

    /// Sets the streaming flag.
    pub fn with_streaming(mut self, streaming: bool) -> Self {
        self.streaming = Some(streaming);
        self
    }

    /// Sets the system-message configuration.
    pub fn with_system_message(mut self, system_message: SystemMessageConfig) -> Self {
        self.system_message = Some(system_message);
        self
    }

    /// Sets the tool list.
    pub fn with_tools<I: IntoIterator<Item = Tool>>(mut self, tools: I) -> Self {
        self.tools = Some(tools.into_iter().collect());
        self
    }

    /// Sets the canvas declarations.
    pub fn with_canvases<I: IntoIterator<Item = CanvasDeclaration>>(mut self, canvases: I) -> Self {
        self.canvases = Some(canvases.into_iter().collect());
        self
    }

    /// Sets the canvas handler.
    pub fn with_canvas_handler(mut self, handler: Arc<dyn CanvasHandler>) -> Self {
        self.canvas_handler = Some(handler);
        self
    }

    /// Sets the `request_canvas_renderer` flag.
    pub fn with_request_canvas_renderer(mut self, request: bool) -> Self {
        self.request_canvas_renderer = Some(request);
        self
    }

    /// Sets the `request_extensions` flag.
    pub fn with_request_extensions(mut self, request: bool) -> Self {
        self.request_extensions = Some(request);
        self
    }

    /// Sets the extension info.
    pub fn with_extension_info(mut self, extension_info: ExtensionInfo) -> Self {
        self.extension_info = Some(extension_info);
        self
    }

    /// Sets the available-tools list, converting each item into a `String`.
    pub fn with_available_tools<I, S>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.available_tools = Some(tools.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the excluded-tools list, converting each item into a `String`.
    pub fn with_excluded_tools<I, S>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.excluded_tools = Some(tools.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the MCP server map.
    pub fn with_mcp_servers(mut self, servers: HashMap<String, McpServerConfig>) -> Self {
        self.mcp_servers = Some(servers);
        self
    }

    /// Sets the `enable_config_discovery` flag.
    pub fn with_enable_config_discovery(mut self, enable: bool) -> Self {
        self.enable_config_discovery = Some(enable);
        self
    }

    /// Sets the skill directories, converting each item into a `PathBuf`.
    pub fn with_skill_directories<I, P>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: Into<PathBuf>,
    {
        self.skill_directories = Some(paths.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the instruction directories, converting each item into a `PathBuf`.
    pub fn with_instruction_directories<I, P>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = P>,
        P: Into<PathBuf>,
    {
        self.instruction_directories = Some(paths.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the disabled-skills list, converting each item into a `String`.
    pub fn with_disabled_skills<I, S>(mut self, names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.disabled_skills = Some(names.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the custom agents.
    pub fn with_custom_agents<I: IntoIterator<Item = CustomAgentConfig>>(
        mut self,
        agents: I,
    ) -> Self {
        self.custom_agents = Some(agents.into_iter().collect());
        self
    }

    /// Sets the default-agent configuration.
    pub fn with_default_agent(mut self, agent: DefaultAgentConfig) -> Self {
        self.default_agent = Some(agent);
        self
    }

    /// Sets the agent name.
    pub fn with_agent(mut self, name: impl Into<String>) -> Self {
        self.agent = Some(name.into());
        self
    }

    /// Sets the infinite-session configuration.
    pub fn with_infinite_sessions(mut self, config: InfiniteSessionConfig) -> Self {
        self.infinite_sessions = Some(config);
        self
    }

    /// Sets the provider configuration.
    pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the `enable_session_telemetry` flag.
    pub fn with_enable_session_telemetry(mut self, enable: bool) -> Self {
        self.enable_session_telemetry = Some(enable);
        self
    }

    /// Sets the model-capabilities override.
    pub fn with_model_capabilities(
        mut self,
        capabilities: crate::generated::api_types::ModelCapabilitiesOverride,
    ) -> Self {
        self.model_capabilities = Some(capabilities);
        self
    }

    /// Sets the configuration directory.
    pub fn with_config_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.config_dir = Some(dir.into());
        self
    }

    /// Sets the working directory.
    pub fn with_working_directory(mut self, dir: impl Into<PathBuf>) -> Self {
        self.working_directory = Some(dir.into());
        self
    }

    /// Sets the GitHub token.
    pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
        self.github_token = Some(token.into());
        self
    }

    /// Sets the `include_sub_agent_streaming_events` flag.
    pub fn with_include_sub_agent_streaming_events(mut self, include: bool) -> Self {
        self.include_sub_agent_streaming_events = Some(include);
        self
    }

    /// Sets the remote-session mode.
    pub fn with_remote_session(
        mut self,
        mode: crate::generated::api_types::RemoteSessionMode,
    ) -> Self {
        self.remote_session = Some(mode);
        self
    }

    /// Sets the cloud-session options.
    pub fn with_cloud(mut self, cloud: CloudSessionOptions) -> Self {
        self.cloud = Some(cloud);
        self
    }
}
/// Options for resuming an existing session.
///
/// Has largely the same fields as `SessionConfig`, except that `session_id`
/// is required, there are no `model` or `cloud` fields, and the
/// resume-specific `open_canvases`, `suppress_resume_event` and
/// `continue_pending_work` are added.
#[derive(Clone)]
#[non_exhaustive]
pub struct ResumeSessionConfig {
/// Id of the session to resume (required).
pub session_id: SessionId,
pub client_name: Option<String>,
pub reasoning_effort: Option<String>,
pub streaming: Option<bool>,
pub system_message: Option<SystemMessageConfig>,
pub tools: Option<Vec<Tool>>,
pub canvases: Option<Vec<CanvasDeclaration>>,
pub canvas_handler: Option<Arc<dyn CanvasHandler>>,
/// Optional list of open canvas instances.
pub open_canvases: Option<Vec<OpenCanvasInstance>>,
pub request_canvas_renderer: Option<bool>,
pub request_extensions: Option<bool>,
pub extension_info: Option<ExtensionInfo>,
pub available_tools: Option<Vec<String>>,
pub excluded_tools: Option<Vec<String>>,
pub mcp_servers: Option<HashMap<String, McpServerConfig>>,
pub enable_config_discovery: Option<bool>,
pub skill_directories: Option<Vec<PathBuf>>,
pub instruction_directories: Option<Vec<PathBuf>>,
pub disabled_skills: Option<Vec<String>>,
pub hooks: Option<bool>,
pub custom_agents: Option<Vec<CustomAgentConfig>>,
pub default_agent: Option<DefaultAgentConfig>,
pub agent: Option<String>,
pub infinite_sessions: Option<InfiniteSessionConfig>,
pub provider: Option<ProviderConfig>,
pub enable_session_telemetry: Option<bool>,
pub model_capabilities: Option<crate::generated::api_types::ModelCapabilitiesOverride>,
pub config_dir: Option<PathBuf>,
pub working_directory: Option<PathBuf>,
/// GitHub token; shown as `<redacted>` by the `Debug` impl.
pub github_token: Option<String>,
pub remote_session: Option<crate::generated::api_types::RemoteSessionMode>,
pub include_sub_agent_streaming_events: Option<bool>,
pub commands: Option<Vec<CommandDefinition>>,
pub session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
/// Optional `suppress_resume_event` flag.
// NOTE(review): the wire handling of this flag is outside the visible part of the file — confirm semantics.
pub suppress_resume_event: Option<bool>,
/// Optional `continue_pending_work` flag.
// NOTE(review): same as above — confirm semantics.
pub continue_pending_work: Option<bool>,
pub permission_handler: Option<Arc<dyn PermissionHandler>>,
pub elicitation_handler: Option<Arc<dyn ElicitationHandler>>,
pub user_input_handler: Option<Arc<dyn UserInputHandler>>,
pub exit_plan_mode_handler: Option<Arc<dyn ExitPlanModeHandler>>,
pub auto_mode_switch_handler: Option<Arc<dyn AutoModeSwitchHandler>>,
pub hooks_handler: Option<Arc<dyn SessionHooks>>,
/// Crate-private permission policy.
pub(crate) permission_policy: Option<crate::permission::Policy>,
pub system_message_transform: Option<Arc<dyn SystemMessageTransform>>,
}
/// Manual `Debug` for [`ResumeSessionConfig`].
///
/// Handler, provider, and transform slots are rendered as `Some("<set>")` or
/// `None` instead of through their own `Debug`, and `github_token` is rendered
/// as `Some("<redacted>")` so the token text is never written to the output.
/// `permission_policy` is not part of the output.
impl std::fmt::Debug for ResumeSessionConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Swaps the payload of an `Option` for a fixed label; `None` stays `None`.
        fn label<T>(slot: &Option<T>, text: &'static str) -> Option<&'static str> {
            slot.as_ref().map(|_| text)
        }
        const SET: &str = "<set>";
        const REDACTED: &str = "<redacted>";

        let mut out = f.debug_struct("ResumeSessionConfig");

        // Plain data fields, printed through their own `Debug`.
        out.field("session_id", &self.session_id)
            .field("client_name", &self.client_name)
            .field("reasoning_effort", &self.reasoning_effort)
            .field("streaming", &self.streaming)
            .field("system_message", &self.system_message)
            .field("tools", &self.tools)
            .field("canvases", &self.canvases)
            .field("canvas_handler", &label(&self.canvas_handler, SET))
            .field("open_canvases", &self.open_canvases)
            .field("request_canvas_renderer", &self.request_canvas_renderer)
            .field("request_extensions", &self.request_extensions)
            .field("extension_info", &self.extension_info)
            .field("available_tools", &self.available_tools)
            .field("excluded_tools", &self.excluded_tools)
            .field("mcp_servers", &self.mcp_servers)
            .field("enable_config_discovery", &self.enable_config_discovery)
            .field("skill_directories", &self.skill_directories)
            .field("instruction_directories", &self.instruction_directories)
            .field("disabled_skills", &self.disabled_skills)
            .field("hooks", &self.hooks)
            .field("custom_agents", &self.custom_agents)
            .field("default_agent", &self.default_agent)
            .field("agent", &self.agent)
            .field("infinite_sessions", &self.infinite_sessions)
            .field("provider", &self.provider)
            .field("enable_session_telemetry", &self.enable_session_telemetry)
            .field("model_capabilities", &self.model_capabilities)
            .field("config_dir", &self.config_dir)
            .field("working_directory", &self.working_directory);

        // The token is replaced by a marker so it cannot leak through logs.
        out.field("github_token", &label(&self.github_token, REDACTED))
            .field("remote_session", &self.remote_session)
            .field(
                "include_sub_agent_streaming_events",
                &self.include_sub_agent_streaming_events,
            )
            .field("commands", &self.commands);

        // Trait-object slots only report whether they are populated.
        out.field("session_fs_provider", &label(&self.session_fs_provider, SET))
            .field("permission_handler", &label(&self.permission_handler, SET))
            .field("elicitation_handler", &label(&self.elicitation_handler, SET))
            .field("user_input_handler", &label(&self.user_input_handler, SET))
            .field(
                "exit_plan_mode_handler",
                &label(&self.exit_plan_mode_handler, SET),
            )
            .field(
                "auto_mode_switch_handler",
                &label(&self.auto_mode_switch_handler, SET),
            )
            .field("hooks_handler", &label(&self.hooks_handler, SET))
            .field(
                "system_message_transform",
                &label(&self.system_message_transform, SET),
            );

        out.field("suppress_resume_event", &self.suppress_resume_event)
            .field("continue_pending_work", &self.continue_pending_work);

        out.finish()
    }
}
impl ResumeSessionConfig {
pub(crate) fn into_wire(
mut self,
) -> Result<(crate::wire::SessionResumeWire, SessionConfigRuntime), crate::Error> {
let permission_active =
self.permission_handler.is_some() || self.permission_policy.is_some();
let request_user_input = self.user_input_handler.is_some();
let request_exit_plan_mode = self.exit_plan_mode_handler.is_some();
let request_auto_mode_switch = self.auto_mode_switch_handler.is_some();
let request_elicitation = self.elicitation_handler.is_some();
let hooks_flag = self.hooks_handler.is_some();
let mut tool_handlers: HashMap<String, Arc<dyn crate::tool::ToolHandler>> = HashMap::new();
if let Some(tools) = self.tools.as_mut() {
for tool in tools.iter_mut() {
if let Some(handler) = tool.handler.take()
&& tool_handlers.insert(tool.name.clone(), handler).is_some()
{
return Err(crate::Error::InvalidConfig(format!(
"duplicate tool handler registered for name {:?}",
tool.name
)));
}
}
}
let wire_commands = self.commands.as_ref().map(|cmds| {
cmds.iter()
.map(|c| crate::wire::CommandWireDefinition {
name: c.name.clone(),
description: c.description.clone(),
})
.collect()
});
let wire_canvases = self.canvases.clone();
let canvas_handler = self.canvas_handler.clone();
let wire = crate::wire::SessionResumeWire {
session_id: self.session_id,
client_name: self.client_name,
reasoning_effort: self.reasoning_effort,
streaming: self.streaming,
system_message: self.system_message,
tools: self.tools,
canvases: wire_canvases,
open_canvases: self.open_canvases,
request_canvas_renderer: self.request_canvas_renderer,
request_extensions: self.request_extensions,
extension_info: self.extension_info,
available_tools: self.available_tools,
excluded_tools: self.excluded_tools,
mcp_servers: self.mcp_servers,
env_value_mode: "direct",
enable_config_discovery: self.enable_config_discovery,
request_user_input,
request_permission: permission_active,
request_exit_plan_mode,
request_auto_mode_switch,
request_elicitation,
hooks: hooks_flag,
skill_directories: self.skill_directories,
instruction_directories: self.instruction_directories,
disabled_skills: self.disabled_skills,
custom_agents: self.custom_agents,
default_agent: self.default_agent,
agent: self.agent,
infinite_sessions: self.infinite_sessions,
provider: self.provider,
enable_session_telemetry: self.enable_session_telemetry,
model_capabilities: self.model_capabilities,
config_dir: self.config_dir,
working_directory: self.working_directory,
github_token: self.github_token,
remote_session: self.remote_session,
include_sub_agent_streaming_events: self.include_sub_agent_streaming_events,
commands: wire_commands,
suppress_resume_event: self.suppress_resume_event,
continue_pending_work: self.continue_pending_work,
};
let runtime = SessionConfigRuntime {
permission_handler: self.permission_handler,
permission_policy: self.permission_policy,
elicitation_handler: self.elicitation_handler,
user_input_handler: self.user_input_handler,
exit_plan_mode_handler: self.exit_plan_mode_handler,
auto_mode_switch_handler: self.auto_mode_switch_handler,
hooks_handler: self.hooks_handler,
system_message_transform: self.system_message_transform,
tool_handlers,
canvas_handler,
session_fs_provider: self.session_fs_provider,
commands: self.commands,
};
Ok((wire, runtime))
}
pub fn new(session_id: SessionId) -> Self {
Self {
session_id,
client_name: None,
reasoning_effort: None,
streaming: None,
system_message: None,
tools: None,
canvases: None,
canvas_handler: None,
open_canvases: None,
request_canvas_renderer: None,
request_extensions: None,
extension_info: None,
available_tools: None,
excluded_tools: None,
mcp_servers: None,
enable_config_discovery: None,
skill_directories: None,
instruction_directories: None,
disabled_skills: None,
hooks: None,
custom_agents: None,
default_agent: None,
agent: None,
infinite_sessions: None,
provider: None,
enable_session_telemetry: None,
model_capabilities: None,
config_dir: None,
working_directory: None,
github_token: None,
remote_session: None,
include_sub_agent_streaming_events: None,
commands: None,
session_fs_provider: None,
suppress_resume_event: None,
continue_pending_work: None,
permission_handler: None,
elicitation_handler: None,
user_input_handler: None,
exit_plan_mode_handler: None,
auto_mode_switch_handler: None,
hooks_handler: None,
permission_policy: None,
system_message_transform: None,
}
}
pub fn with_permission_handler(mut self, handler: Arc<dyn PermissionHandler>) -> Self {
self.permission_handler = Some(handler);
self
}
pub fn with_elicitation_handler(mut self, handler: Arc<dyn ElicitationHandler>) -> Self {
self.elicitation_handler = Some(handler);
self
}
pub fn with_user_input_handler(mut self, handler: Arc<dyn UserInputHandler>) -> Self {
self.user_input_handler = Some(handler);
self
}
pub fn with_exit_plan_mode_handler(mut self, handler: Arc<dyn ExitPlanModeHandler>) -> Self {
self.exit_plan_mode_handler = Some(handler);
self
}
pub fn with_auto_mode_switch_handler(
mut self,
handler: Arc<dyn AutoModeSwitchHandler>,
) -> Self {
self.auto_mode_switch_handler = Some(handler);
self
}
pub fn with_hooks(mut self, hooks: Arc<dyn SessionHooks>) -> Self {
self.hooks_handler = Some(hooks);
self
}
pub fn with_system_message_transform(
mut self,
transform: Arc<dyn SystemMessageTransform>,
) -> Self {
self.system_message_transform = Some(transform);
self
}
pub fn with_commands(mut self, commands: Vec<CommandDefinition>) -> Self {
self.commands = Some(commands);
self
}
pub fn with_session_fs_provider(mut self, provider: Arc<dyn SessionFsProvider>) -> Self {
self.session_fs_provider = Some(provider);
self
}
pub fn approve_all_permissions(mut self) -> Self {
self.permission_policy = Some(crate::permission::Policy::ApproveAll);
self
}
pub fn deny_all_permissions(mut self) -> Self {
self.permission_policy = Some(crate::permission::Policy::DenyAll);
self
}
pub fn approve_permissions_if<F>(mut self, predicate: F) -> Self
where
F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static,
{
self.permission_policy = Some(crate::permission::Policy::Predicate(Arc::new(predicate)));
self
}
pub fn with_client_name(mut self, name: impl Into<String>) -> Self {
self.client_name = Some(name.into());
self
}
pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
self.reasoning_effort = Some(effort.into());
self
}
pub fn with_streaming(mut self, streaming: bool) -> Self {
self.streaming = Some(streaming);
self
}
pub fn with_system_message(mut self, system_message: SystemMessageConfig) -> Self {
self.system_message = Some(system_message);
self
}
pub fn with_tools<I: IntoIterator<Item = Tool>>(mut self, tools: I) -> Self {
self.tools = Some(tools.into_iter().collect());
self
}
pub fn with_canvases<I: IntoIterator<Item = CanvasDeclaration>>(mut self, canvases: I) -> Self {
self.canvases = Some(canvases.into_iter().collect());
self
}
pub fn with_canvas_handler(mut self, handler: Arc<dyn CanvasHandler>) -> Self {
self.canvas_handler = Some(handler);
self
}
pub fn with_open_canvases<I: IntoIterator<Item = OpenCanvasInstance>>(
mut self,
open_canvases: I,
) -> Self {
self.open_canvases = Some(open_canvases.into_iter().collect());
self
}
pub fn with_request_canvas_renderer(mut self, request: bool) -> Self {
self.request_canvas_renderer = Some(request);
self
}
pub fn with_request_extensions(mut self, request: bool) -> Self {
self.request_extensions = Some(request);
self
}
pub fn with_extension_info(mut self, extension_info: ExtensionInfo) -> Self {
self.extension_info = Some(extension_info);
self
}
pub fn with_available_tools<I, S>(mut self, tools: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.available_tools = Some(tools.into_iter().map(Into::into).collect());
self
}
pub fn with_excluded_tools<I, S>(mut self, tools: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.excluded_tools = Some(tools.into_iter().map(Into::into).collect());
self
}
pub fn with_mcp_servers(mut self, servers: HashMap<String, McpServerConfig>) -> Self {
self.mcp_servers = Some(servers);
self
}
pub fn with_enable_config_discovery(mut self, enable: bool) -> Self {
self.enable_config_discovery = Some(enable);
self
}
pub fn with_skill_directories<I, P>(mut self, paths: I) -> Self
where
I: IntoIterator<Item = P>,
P: Into<PathBuf>,
{
self.skill_directories = Some(paths.into_iter().map(Into::into).collect());
self
}
pub fn with_instruction_directories<I, P>(mut self, paths: I) -> Self
where
I: IntoIterator<Item = P>,
P: Into<PathBuf>,
{
self.instruction_directories = Some(paths.into_iter().map(Into::into).collect());
self
}
pub fn with_disabled_skills<I, S>(mut self, names: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.disabled_skills = Some(names.into_iter().map(Into::into).collect());
self
}
pub fn with_custom_agents<I: IntoIterator<Item = CustomAgentConfig>>(
mut self,
agents: I,
) -> Self {
self.custom_agents = Some(agents.into_iter().collect());
self
}
pub fn with_default_agent(mut self, agent: DefaultAgentConfig) -> Self {
self.default_agent = Some(agent);
self
}
pub fn with_agent(mut self, name: impl Into<String>) -> Self {
self.agent = Some(name.into());
self
}
pub fn with_infinite_sessions(mut self, config: InfiniteSessionConfig) -> Self {
self.infinite_sessions = Some(config);
self
}
pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
self.provider = Some(provider);
self
}
pub fn with_enable_session_telemetry(mut self, enable: bool) -> Self {
self.enable_session_telemetry = Some(enable);
self
}
pub fn with_model_capabilities(
mut self,
capabilities: crate::generated::api_types::ModelCapabilitiesOverride,
) -> Self {
self.model_capabilities = Some(capabilities);
self
}
pub fn with_config_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.config_dir = Some(dir.into());
self
}
pub fn with_working_directory(mut self, dir: impl Into<PathBuf>) -> Self {
self.working_directory = Some(dir.into());
self
}
pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
self.github_token = Some(token.into());
self
}
pub fn with_include_sub_agent_streaming_events(mut self, include: bool) -> Self {
self.include_sub_agent_streaming_events = Some(include);
self
}
pub fn with_remote_session(
mut self,
mode: crate::generated::api_types::RemoteSessionMode,
) -> Self {
self.remote_session = Some(mode);
self
}
pub fn with_suppress_resume_event(mut self, suppress: bool) -> Self {
self.suppress_resume_event = Some(suppress);
self
}
pub fn with_continue_pending_work(mut self, continue_pending: bool) -> Self {
self.continue_pending_work = Some(continue_pending);
self
}
}
/// System message configuration: an optional `mode`, optional `content`, and
/// optional per-section overrides.
///
/// Serialized with camelCase keys; fields that are `None` are omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct SystemMessageConfig {
/// Optional mode string; the accepted values are not defined by this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub mode: Option<String>,
/// Optional content string.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Optional overrides keyed by a section name string.
#[serde(skip_serializing_if = "Option::is_none")]
pub sections: Option<HashMap<String, SectionOverride>>,
}
impl SystemMessageConfig {
    /// Builds a configuration with every field unset; identical to `Default`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the configuration with `mode` replaced by `mode`.
    pub fn with_mode(self, mode: impl Into<String>) -> Self {
        Self {
            mode: Some(mode.into()),
            ..self
        }
    }

    /// Returns the configuration with `content` replaced by `content`.
    pub fn with_content(self, content: impl Into<String>) -> Self {
        Self {
            content: Some(content.into()),
            ..self
        }
    }

    /// Returns the configuration with `sections` replaced by `sections`.
    pub fn with_sections(self, sections: HashMap<String, SectionOverride>) -> Self {
        Self {
            sections: Some(sections),
            ..self
        }
    }
}
/// Override for a single system-message section: an optional `action` and
/// optional `content`.
///
/// Serialized with camelCase keys; fields that are `None` are omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SectionOverride {
/// Optional action string; the recognised actions are not defined by this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub action: Option<String>,
/// Optional content string.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
}
/// Session-creation result, (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateSessionResult {
/// Session identifier; required when deserializing.
pub session_id: SessionId,
/// Optional workspace path; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub workspace_path: Option<PathBuf>,
/// Optional remote URL. Defaults to `None` when absent on input, and the
/// snake_case key `remote_url` is accepted as an alias.
#[serde(default, alias = "remote_url")]
pub remote_url: Option<String>,
/// Optional capabilities; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub capabilities: Option<SessionCapabilities>,
}
/// Crate-internal session-resume result, (de)serialized with camelCase keys.
///
/// Every field defaults to `None` when missing from the input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct ResumeSessionResult {
/// Optional session identifier.
#[serde(default)]
pub session_id: Option<SessionId>,
/// Optional workspace path; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub workspace_path: Option<PathBuf>,
/// Optional remote URL; the snake_case key `remote_url` is accepted as an alias.
#[serde(default, alias = "remote_url")]
pub remote_url: Option<String>,
/// Optional capabilities; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub capabilities: Option<SessionCapabilities>,
/// Optional list of open canvas instances. Read from `openCanvases` or the
/// alias `openCanvasInstances`; omitted from serialized output when `None`.
#[serde(
default,
alias = "openCanvasInstances",
skip_serializing_if = "Option::is_none"
)]
pub open_canvases: Option<Vec<OpenCanvasInstance>>,
}
/// Log severity, serialized as a lowercase string (`info`, `warning`, `error`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
/// The `Default` level.
#[default]
Info,
/// Serialized as `warning`.
Warning,
/// Serialized as `error`.
Error,
}
/// Optional settings for a log call, serialized with camelCase keys; fields
/// that are `None` are omitted.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LogOptions {
/// Optional severity level.
#[serde(skip_serializing_if = "Option::is_none")]
pub level: Option<LogLevel>,
/// Optional `ephemeral` flag; its effect is decided by the receiver and is
/// not defined in this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub ephemeral: Option<bool>,
}
impl LogOptions {
    /// Returns the options with `level` replaced by `level`.
    pub fn with_level(self, level: LogLevel) -> Self {
        Self {
            level: Some(level),
            ..self
        }
    }

    /// Returns the options with `ephemeral` replaced by `ephemeral`.
    pub fn with_ephemeral(self, ephemeral: bool) -> Self {
        Self {
            ephemeral: Some(ephemeral),
            ..self
        }
    }
}
/// Optional settings accompanying a model change. Not (de)serializable itself.
#[derive(Debug, Clone, Default)]
pub struct SetModelOptions {
/// Optional reasoning-effort string; accepted values are not defined here.
pub reasoning_effort: Option<String>,
/// Optional model-capabilities override.
pub model_capabilities: Option<crate::generated::api_types::ModelCapabilitiesOverride>,
}
impl SetModelOptions {
    /// Returns the options with `reasoning_effort` replaced by `effort`.
    pub fn with_reasoning_effort(self, effort: impl Into<String>) -> Self {
        Self {
            reasoning_effort: Some(effort.into()),
            ..self
        }
    }

    /// Returns the options with `model_capabilities` replaced by `caps`.
    pub fn with_model_capabilities(
        self,
        caps: crate::generated::api_types::ModelCapabilitiesOverride,
    ) -> Self {
        Self {
            model_capabilities: Some(caps),
            ..self
        }
    }
}
/// Ping reply, (de)serialized with camelCase keys.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PingResponse {
/// Message text; an empty string when absent from the input.
#[serde(default)]
pub message: String,
/// Timestamp kept as a string; empty when absent. Its format is not defined here.
#[serde(default)]
pub timestamp: String,
/// Optional protocol version; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub protocol_version: Option<u32>,
}
/// Line range of a file attachment.
///
/// NOTE(review): whether the bounds are inclusive or zero-based is not stated
/// in this type — confirm against the consumer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AttachmentLineRange {
/// First line of the range.
pub start: u32,
/// Last line of the range.
pub end: u32,
}
/// A position inside a file expressed as `line` and `character`.
///
/// NOTE(review): the index base (zero or one) is not stated in this type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AttachmentSelectionPosition {
/// Line component of the position.
pub line: u32,
/// Character component of the position.
pub character: u32,
}
/// A selection span described by a `start` and an `end` position.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AttachmentSelectionRange {
/// Position where the selection begins.
pub start: AttachmentSelectionPosition,
/// Position where the selection ends.
pub end: AttachmentSelectionPosition,
}
/// Kind of GitHub item a reference attachment points at, serialized in
/// snake_case (`issue`, `pr`, `discussion`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum GitHubReferenceType {
/// Serialized as `issue`.
Issue,
/// Serialized as `pr`.
Pr,
/// Serialized as `discussion`.
Discussion,
}
/// An attachment, serialized as an internally tagged object.
///
/// The `type` key carries the variant name in camelCase (`file`, `directory`,
/// `selection`, `blob`), except for [`Attachment::GitHubReference`], which is
/// tagged `github_reference`. Variant fields use camelCase keys.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(
tag = "type",
rename_all = "camelCase",
rename_all_fields = "camelCase"
)]
#[non_exhaustive]
pub enum Attachment {
/// A file, optionally limited to a line range.
File {
/// Path of the file.
path: PathBuf,
/// Optional display name; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
display_name: Option<String>,
/// Optional line range; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
line_range: Option<AttachmentLineRange>,
},
/// A directory.
Directory {
/// Path of the directory.
path: PathBuf,
/// Optional display name; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
display_name: Option<String>,
},
/// A text selection taken from a file.
Selection {
/// Path of the file the selection belongs to.
file_path: PathBuf,
/// The selected text.
text: String,
/// Optional display name; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
display_name: Option<String>,
/// Start and end positions of the selection.
selection: AttachmentSelectionRange,
},
/// Inline data with a MIME type.
Blob {
/// Payload as a string. NOTE(review): the encoding (e.g. base64) is not
/// stated in this type — confirm against the producer.
data: String,
/// MIME type of the payload.
mime_type: String,
/// Optional display name; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
display_name: Option<String>,
},
/// A reference to a GitHub item; this variant has no `display_name` field.
#[serde(rename = "github_reference")]
GitHubReference {
/// Number of the referenced item.
number: u64,
/// Title of the referenced item.
title: String,
/// Kind of item being referenced.
reference_type: GitHubReferenceType,
/// State string; the possible values are not defined by this type.
state: String,
/// URL of the referenced item.
url: String,
},
}
impl Attachment {
    /// Returns the stored display name, if the variant has one and it is set.
    ///
    /// `GitHubReference` has no such field and always yields `None`.
    pub fn display_name(&self) -> Option<&str> {
        let slot = match self {
            Self::File { display_name, .. }
            | Self::Directory { display_name, .. }
            | Self::Selection { display_name, .. }
            | Self::Blob { display_name, .. } => display_name,
            Self::GitHubReference { .. } => return None,
        };
        slot.as_deref()
    }

    /// Returns a label for the attachment.
    ///
    /// Preference order: the trimmed display name when it is non-empty; for a
    /// GitHub reference, the trimmed title or `#<number>` when the title is
    /// blank; otherwise the name derived from the variant's path or kind.
    pub fn label(&self) -> Option<String> {
        let explicit = self.display_name().map(str::trim).unwrap_or("");
        if !explicit.is_empty() {
            return Some(explicit.to_string());
        }

        if let Self::GitHubReference { number, title, .. } = self {
            let title = title.trim();
            let text = if title.is_empty() {
                format!("#{}", number)
            } else {
                title.to_string()
            };
            return Some(text);
        }

        self.derived_display_name()
    }

    /// Fills in `display_name` with a derived name when it is missing or
    /// blank after trimming. A non-blank name is left untouched, as are
    /// variants without a derivable name.
    pub fn ensure_display_name(&mut self) {
        let already_named = self
            .display_name()
            .is_some_and(|name| !name.trim().is_empty());
        if already_named {
            return;
        }

        if let Some(derived) = self.derived_display_name() {
            match self {
                Self::File { display_name, .. }
                | Self::Directory { display_name, .. }
                | Self::Selection { display_name, .. }
                | Self::Blob { display_name, .. } => *display_name = Some(derived),
                Self::GitHubReference { .. } => {}
            }
        }
    }

    /// Computes a fallback name: the path-based name for `File`, `Directory`
    /// and `Selection`, the literal `attachment` for `Blob`, and `None` for
    /// `GitHubReference`.
    fn derived_display_name(&self) -> Option<String> {
        let source = match self {
            Self::File { path, .. } | Self::Directory { path, .. } => path.as_path(),
            Self::Selection { file_path, .. } => file_path.as_path(),
            Self::Blob { .. } => return Some("attachment".to_string()),
            Self::GitHubReference { .. } => return None,
        };
        Some(attachment_name_from_path(source))
    }
}
/// Derives a human-readable name from `path`.
///
/// Uses the final path component when there is one; otherwise falls back to
/// the whole path as text, or to the literal `attachment` when the path is
/// empty. Non-UTF-8 data is converted lossily.
fn attachment_name_from_path(path: &Path) -> String {
    if let Some(base) = path.file_name() {
        let base = base.to_string_lossy();
        if !base.is_empty() {
            return base.into_owned();
        }
    }

    // No usable final component (e.g. `/` or `..`): use the full path text.
    let whole = path.to_string_lossy();
    match whole.is_empty() {
        true => "attachment".to_string(),
        false => whole.into_owned(),
    }
}
pub fn ensure_attachment_display_names(attachments: &mut [Attachment]) {
for attachment in attachments {
attachment.ensure_display_name();
}
}
/// Delivery mode for a message, serialized as a lowercase string
/// (`enqueue`, `immediate`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum DeliveryMode {
/// Serialized as `enqueue`.
Enqueue,
/// Serialized as `immediate`.
Immediate,
}
/// Options for sending a message: the prompt plus optional delivery settings.
///
/// Build one with [`MessageOptions::new`] or convert from `&str`, `String`,
/// or `&String`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct MessageOptions {
/// Prompt text of the message.
pub prompt: String,
/// Optional delivery mode.
pub mode: Option<DeliveryMode>,
/// Optional attachments sent with the prompt.
pub attachments: Option<Vec<Attachment>>,
/// Optional wait timeout; how it is applied is decided by the sender code,
/// which is outside this type.
pub wait_timeout: Option<Duration>,
/// Optional request headers as name/value pairs.
pub request_headers: Option<HashMap<String, String>>,
/// Optional `traceparent` value.
pub traceparent: Option<String>,
/// Optional `tracestate` value.
pub tracestate: Option<String>,
}
impl MessageOptions {
    /// Creates options carrying `prompt` with every optional field unset.
    pub fn new(prompt: impl Into<String>) -> Self {
        let prompt = prompt.into();
        Self {
            prompt,
            mode: None,
            attachments: None,
            wait_timeout: None,
            request_headers: None,
            traceparent: None,
            tracestate: None,
        }
    }

    /// Returns the options with `mode` replaced by `mode`.
    pub fn with_mode(self, mode: DeliveryMode) -> Self {
        Self {
            mode: Some(mode),
            ..self
        }
    }

    /// Returns the options with `attachments` replaced by `attachments`.
    pub fn with_attachments(self, attachments: Vec<Attachment>) -> Self {
        Self {
            attachments: Some(attachments),
            ..self
        }
    }

    /// Returns the options with `wait_timeout` replaced by `timeout`.
    pub fn with_wait_timeout(self, timeout: Duration) -> Self {
        Self {
            wait_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns the options with `request_headers` replaced by `headers`.
    pub fn with_request_headers(self, headers: HashMap<String, String>) -> Self {
        Self {
            request_headers: Some(headers),
            ..self
        }
    }

    /// Copies both trace fields from `ctx`, overwriting the current values
    /// (a `None` in `ctx` clears the corresponding field).
    pub fn with_trace_context(self, ctx: TraceContext) -> Self {
        Self {
            traceparent: ctx.traceparent,
            tracestate: ctx.tracestate,
            ..self
        }
    }

    /// Returns the options with `traceparent` replaced by `traceparent`.
    pub fn with_traceparent(self, traceparent: impl Into<String>) -> Self {
        Self {
            traceparent: Some(traceparent.into()),
            ..self
        }
    }

    /// Returns the options with `tracestate` replaced by `tracestate`.
    pub fn with_tracestate(self, tracestate: impl Into<String>) -> Self {
        Self {
            tracestate: Some(tracestate.into()),
            ..self
        }
    }
}
impl From<&str> for MessageOptions {
fn from(prompt: &str) -> Self {
Self::new(prompt)
}
}
impl From<String> for MessageOptions {
fn from(prompt: String) -> Self {
Self::new(prompt)
}
}
impl From<&String> for MessageOptions {
fn from(prompt: &String) -> Self {
Self::new(prompt.clone())
}
}
/// Status reply, (de)serialized with camelCase keys. Both fields are required
/// when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct GetStatusResponse {
/// Version string.
pub version: String,
/// Protocol version number.
pub protocol_version: u32,
}
/// Authentication status reply, (de)serialized with camelCase keys; optional
/// fields that are `None` are omitted from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct GetAuthStatusResponse {
/// Authentication flag; required when deserializing.
pub is_authenticated: bool,
/// Optional authentication type string; possible values are not defined here.
#[serde(skip_serializing_if = "Option::is_none")]
pub auth_type: Option<String>,
/// Optional host string.
#[serde(skip_serializing_if = "Option::is_none")]
pub host: Option<String>,
/// Optional login string.
#[serde(skip_serializing_if = "Option::is_none")]
pub login: Option<String>,
/// Optional status message.
#[serde(skip_serializing_if = "Option::is_none")]
pub status_message: Option<String>,
}
/// Pairs a session event with the identifier of its session; (de)serialized
/// with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionEventNotification {
/// Identifier of the session the event belongs to.
pub session_id: SessionId,
/// The event itself.
pub event: SessionEvent,
}
/// A session event with an untyped JSON payload, (de)serialized with
/// camelCase keys.
///
/// The event kind is kept as a raw string in `event_type` (wire key `type`);
/// use [`SessionEvent::parsed_type`] for the enum form and
/// [`SessionEvent::typed_data`] to decode `data`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionEvent {
/// Event identifier.
pub id: String,
/// Timestamp kept as a string; its format is not defined by this type.
pub timestamp: String,
/// Optional parent event identifier. Unlike the other optional fields it
/// has no `skip_serializing_if`, so `None` is serialized as `null`.
pub parent_id: Option<String>,
/// Optional `ephemeral` flag; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub ephemeral: Option<bool>,
/// Optional agent identifier; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub agent_id: Option<String>,
/// Optional debug timestamp, presumably in milliseconds going by the name;
/// omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub debug_cli_received_at_ms: Option<i64>,
/// Optional debug timestamp, presumably in milliseconds going by the name;
/// omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub debug_ws_forwarded_at_ms: Option<i64>,
/// Raw event type string, carried on the wire under the key `type`.
#[serde(rename = "type")]
pub event_type: String,
/// Event payload as arbitrary JSON.
pub data: Value,
}
impl SessionEvent {
pub fn parsed_type(&self) -> crate::generated::SessionEventType {
use serde::de::IntoDeserializer;
let deserializer: serde::de::value::StrDeserializer<'_, serde::de::value::Error> =
self.event_type.as_str().into_deserializer();
crate::generated::SessionEventType::deserialize(deserializer)
.unwrap_or(crate::generated::SessionEventType::Unknown)
}
pub fn typed_data<T: serde::de::DeserializeOwned>(&self) -> Option<T> {
serde_json::from_value(self.data.clone()).ok()
}
pub fn is_transient_error(&self) -> bool {
self.event_type == "session.error"
&& self.data.get("errorType").and_then(|v| v.as_str()) == Some("model_call")
}
}
/// A tool invocation, (de)serialized with camelCase keys.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ToolInvocation {
/// Identifier of the session the call belongs to.
pub session_id: SessionId,
/// Identifier of this tool call.
pub tool_call_id: String,
/// Name of the tool being invoked.
pub tool_name: String,
/// Arguments as arbitrary JSON; decode with [`ToolInvocation::params`].
pub arguments: Value,
/// Optional `traceparent` value; defaults to `None` and is omitted from
/// serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub traceparent: Option<String>,
/// Optional `tracestate` value; defaults to `None` and is omitted from
/// serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tracestate: Option<String>,
}
impl ToolInvocation {
pub fn params<P: serde::de::DeserializeOwned>(&self) -> Result<P, crate::Error> {
serde_json::from_value(self.arguments.clone()).map_err(crate::Error::from)
}
pub fn trace_context(&self) -> TraceContext {
TraceContext {
traceparent: self.traceparent.clone(),
tracestate: self.tracestate.clone(),
}
}
}
/// One binary item of a tool result, (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolBinaryResult {
/// Payload as a string. NOTE(review): the encoding (e.g. base64) is not
/// stated in this type — confirm against the consumer.
pub data: String,
/// MIME type of the payload.
pub mime_type: String,
/// Type string, carried on the wire under the key `type`; the possible
/// values are not defined here.
pub r#type: String,
/// Optional description; defaults to `None` and is omitted from serialized
/// output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
/// Structured tool result, (de)serialized with camelCase keys; optional
/// fields that are `None` are omitted from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResultExpanded {
/// Text result (wire key `textResultForLlm`).
pub text_result_for_llm: String,
/// Result type string; the possible values are not defined by this type.
pub result_type: String,
/// Optional binary results (wire key `binaryResultsForLlm`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub binary_results_for_llm: Option<Vec<ToolBinaryResult>>,
/// Optional session log text.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_log: Option<String>,
/// Optional error text.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Optional telemetry as a map from string keys to arbitrary JSON values.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_telemetry: Option<HashMap<String, Value>>,
}
/// A tool result that is either a bare string or a structured object.
///
/// The enum is untagged: no variant name appears on the wire, and
/// deserialization tries the variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
#[non_exhaustive]
pub enum ToolResult {
/// A plain string result.
Text(String),
/// A structured result.
Expanded(ToolResultExpanded),
}
/// Envelope that carries a [`ToolResult`] under the key `result`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResultResponse {
/// The wrapped tool result.
pub result: ToolResult,
}
/// Metadata describing a session, (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionMetadata {
/// Session identifier.
pub session_id: SessionId,
/// Start time kept as a string; its format is not defined by this type.
pub start_time: String,
/// Modification time kept as a string; its format is not defined by this type.
pub modified_time: String,
/// Optional summary; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub summary: Option<String>,
/// Remote flag; required when deserializing.
pub is_remote: bool,
}
/// Reply that carries a list of session metadata entries under `sessions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListSessionsResponse {
/// The session metadata entries.
pub sessions: Vec<SessionMetadata>,
}
/// Optional filter criteria for listing sessions.
///
/// Every field defaults to `None` and is omitted from serialized output when
/// `None`. Keys are camelCase except `working_directory`, which uses `cwd`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionListFilter {
/// Optional working directory, carried on the wire under the key `cwd`.
#[serde(default, skip_serializing_if = "Option::is_none", rename = "cwd")]
pub working_directory: Option<String>,
/// Optional git root string.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub git_root: Option<String>,
/// Optional repository string; its format is not defined by this type.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub repository: Option<String>,
/// Optional branch string.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub branch: Option<String>,
}
/// Reply that carries optional session metadata under `session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetSessionMetadataResponse {
/// The metadata, if any; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session: Option<SessionMetadata>,
}
/// Reply that carries an optional session identifier under `sessionId`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetLastSessionIdResponse {
/// The session identifier, if any; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<SessionId>,
}
/// Reply that carries an optional session identifier under `sessionId`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetForegroundSessionResponse {
/// The session identifier, if any; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<SessionId>,
}
/// Reply that carries a list of session events under `events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetMessagesResponse {
/// The events, in the order they were deserialized.
pub events: Vec<SessionEvent>,
}
/// Outcome of an elicitation, (de)serialized with camelCase keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ElicitationResult {
/// Action string; the recognised values are not defined by this type.
pub action: String,
/// Optional content as arbitrary JSON; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Value>,
}
/// Elicitation mode, serialized in camelCase (`form`, `url`).
///
/// Any unrecognised string deserializes to [`ElicitationMode::Unknown`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub enum ElicitationMode {
/// Serialized as `form`.
Form,
/// Serialized as `url`.
Url,
/// Catch-all for values not listed above.
#[serde(other)]
Unknown,
}
/// An elicitation request, (de)serialized with camelCase keys; optional
/// fields that are `None` are omitted from serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ElicitationRequest {
/// Message text; required when deserializing.
pub message: String,
/// Optional schema as arbitrary JSON; its structure is not defined by this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub requested_schema: Option<Value>,
/// Optional elicitation mode.
#[serde(skip_serializing_if = "Option::is_none")]
pub mode: Option<ElicitationMode>,
/// Optional source string.
#[serde(skip_serializing_if = "Option::is_none")]
pub elicitation_source: Option<String>,
/// Optional URL string.
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
}
/// Capabilities reported for a session, (de)serialized with camelCase keys.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionCapabilities {
/// Optional UI capabilities; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub ui: Option<UiCapabilities>,
}
/// UI capability flags, (de)serialized with camelCase keys; flags that are
/// `None` are omitted from serialized output.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UiCapabilities {
/// Optional elicitation flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub elicitation: Option<bool>,
/// Optional canvases flag.
#[serde(skip_serializing_if = "Option::is_none")]
pub canvases: Option<bool>,
}
/// Borrowed options describing a UI text input. All fields are optional and
/// the string fields borrow for the lifetime `'a`.
#[derive(Debug, Clone, Default)]
pub struct UiInputOptions<'a> {
/// Optional title text.
pub title: Option<&'a str>,
/// Optional description text.
pub description: Option<&'a str>,
/// Optional minimum length; the unit (bytes or characters) is not defined here.
pub min_length: Option<u64>,
/// Optional maximum length; the unit (bytes or characters) is not defined here.
pub max_length: Option<u64>,
/// Optional input format.
pub format: Option<InputFormat>,
/// Optional default value.
pub default: Option<&'a str>,
}
/// Format hint for a text input; see [`InputFormat::as_str`] for the string
/// form of each variant.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum InputFormat {
    /// String form `email`.
    Email,
    /// String form `uri`.
    Uri,
    /// String form `date`.
    Date,
    /// String form `date-time`.
    DateTime,
}

impl InputFormat {
    /// Returns the fixed string that names this format.
    pub fn as_str(&self) -> &'static str {
        match *self {
            InputFormat::Email => "email",
            InputFormat::Uri => "uri",
            InputFormat::Date => "date",
            InputFormat::DateTime => "date-time",
        }
    }
}
pub use crate::generated::api_types::{
Model, ModelBilling, ModelCapabilities, ModelCapabilitiesLimits, ModelCapabilitiesLimitsVision,
ModelCapabilitiesSupports, ModelList, ModelPolicy, PermissionDecision,
PermissionDecisionApproveOnce, PermissionDecisionReject, PermissionDecisionUserNotAvailable,
};
/// Kind of a permission request, serialized in kebab-case (for example
/// `CustomTool` becomes `custom-tool`).
///
/// Any unrecognised string deserializes to [`PermissionRequestKind::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[non_exhaustive]
pub enum PermissionRequestKind {
/// Serialized as `shell`.
Shell,
/// Serialized as `write`.
Write,
/// Serialized as `read`.
Read,
/// Serialized as `url`.
Url,
/// Serialized as `mcp`.
Mcp,
/// Serialized as `custom-tool`.
CustomTool,
/// Serialized as `memory`.
Memory,
/// Serialized as `hook`.
Hook,
/// Catch-all for values not listed above.
#[serde(other)]
Unknown,
}
/// Payload of a permission request, (de)serialized with camelCase keys.
///
/// `kind` and `tool_call_id` are typed; every other key of the object is
/// collected into `extra` through `#[serde(flatten)]`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PermissionRequestData {
/// Optional request kind; defaults to `None` and is omitted from serialized
/// output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub kind: Option<PermissionRequestKind>,
/// Optional tool call identifier; defaults to `None` and is omitted from
/// serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Remaining keys of the request, flattened into the same JSON object.
#[serde(flatten)]
pub extra: Value,
}
/// Payload of an exit-plan-mode request, (de)serialized with camelCase keys.
///
/// Every field has a deserialization default, so an empty object is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExitPlanModeData {
/// Summary text; an empty string when absent from the input.
#[serde(default)]
pub summary: String,
/// Optional plan content; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub plan_content: Option<String>,
/// Action strings; an empty list when absent from the input.
#[serde(default)]
pub actions: Vec<String>,
/// Recommended action; falls back to `default_recommended_action()` when
/// absent from the input.
#[serde(default = "default_recommended_action")]
pub recommended_action: String,
}
/// Serde default for `ExitPlanModeData::recommended_action`: the string
/// `autopilot`.
fn default_recommended_action() -> String {
    String::from("autopilot")
}
/// `Default` that mirrors the serde defaults: empty summary, no plan content,
/// no actions, and the recommended action from `default_recommended_action()`.
impl Default for ExitPlanModeData {
    fn default() -> Self {
        let recommended_action = default_recommended_action();
        Self {
            recommended_action,
            actions: Vec::new(),
            plan_content: None,
            summary: String::new(),
        }
    }
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use serde_json::json;
use super::{
Attachment, AttachmentLineRange, AttachmentSelectionPosition, AttachmentSelectionRange,
ConnectionState, CustomAgentConfig, DeliveryMode, ExtensionInfo, GitHubReferenceType,
InfiniteSessionConfig, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionEvent,
SessionId, SystemMessageConfig, Tool, ToolBinaryResult, ToolResult, ToolResultExpanded,
ToolResultResponse, ensure_attachment_display_names,
};
use crate::generated::session_events::TypedSessionEvent;
#[test]
fn tool_builder_composes() {
    // Parameter schema handed to the builder.
    let schema = json!({
        "type": "object",
        "properties": { "name": { "type": "string" } },
        "required": ["name"]
    });
    let greet = Tool::new("greet")
        .with_description("Say hello")
        .with_namespaced_name("hello/greet")
        .with_instructions("Pass the user's name")
        .with_parameters(schema)
        .with_overrides_built_in_tool(true)
        .with_skip_permission(true);

    // Each builder call must be reflected in the matching field.
    assert_eq!(greet.name, "greet");
    assert_eq!(greet.description, "Say hello");
    assert_eq!(greet.namespaced_name.as_deref(), Some("hello/greet"));
    assert_eq!(greet.instructions.as_deref(), Some("Pass the user's name"));
    assert_eq!(greet.parameters.get("type").unwrap(), &json!("object"));
    assert!(greet.overrides_built_in_tool);
    assert!(greet.skip_permission);
}
#[test]
fn custom_agent_config_builder_with_model() {
    let built = CustomAgentConfig::new("my-agent", "You are helpful.")
        .with_model("claude-haiku-4.5")
        .with_display_name("My Agent");

    // The constructor name and both optional builder values must be stored.
    assert_eq!(built.name, "my-agent");
    assert_eq!(built.model.as_deref(), Some("claude-haiku-4.5"));
    assert_eq!(built.display_name.as_deref(), Some("My Agent"));
}
#[test]
fn custom_agent_config_serializes_model() {
    let config = CustomAgentConfig::new("model-agent", "prompt").with_model("claude-haiku-4.5");
    let serialized = serde_json::to_value(&config).unwrap();

    // A configured model is emitted under the `model` key next to `name`.
    assert_eq!(serialized["model"], "claude-haiku-4.5");
    assert_eq!(serialized["name"], "model-agent");
}
#[test]
fn custom_agent_config_omits_model_when_none() {
    let config = CustomAgentConfig::new("no-model-agent", "prompt");
    let serialized = serde_json::to_value(&config).unwrap();

    // Without `with_model`, the `model` key must be absent, not `null`.
    let model_entry = serialized.get("model");
    assert!(model_entry.is_none());
}
#[test]
#[should_panic(expected = "tool parameter schema must be a JSON object")]
fn tool_with_parameters_panics_on_non_object_value() {
    // JSON `null` is not an object, so the builder is expected to panic.
    let not_an_object = json!(null);
    let _ = Tool::new("noop").with_parameters(not_an_object);
}
#[test]
fn tool_result_expanded_serializes_binary_results_for_llm() {
    // A single binary attachment carried alongside the text result.
    let chart = ToolBinaryResult {
        data: String::from("aW1n"),
        mime_type: String::from("image/png"),
        r#type: String::from("image"),
        description: Some(String::from("chart preview")),
    };
    let expanded = ToolResultExpanded {
        text_result_for_llm: String::from("rendered chart"),
        result_type: String::from("success"),
        binary_results_for_llm: Some(vec![chart]),
        session_log: None,
        error: None,
        tool_telemetry: None,
    };
    let response = ToolResultResponse {
        result: ToolResult::Expanded(expanded),
    };

    // Expected wire form: camelCase keys, and the `None` fields left out entirely.
    let expected = json!({
        "result": {
            "textResultForLlm": "rendered chart",
            "resultType": "success",
            "binaryResultsForLlm": [
                {
                    "data": "aW1n",
                    "mimeType": "image/png",
                    "type": "image",
                    "description": "chart preview"
                }
            ]
        }
    });
    let serialized = serde_json::to_value(&response).unwrap();
    assert_eq!(serialized, expected);
}
#[test]
fn tool_result_expanded_omits_binary_results_for_llm_when_none() {
    let expanded = ToolResultExpanded {
        text_result_for_llm: String::from("ok"),
        result_type: String::from("success"),
        binary_results_for_llm: None,
        session_log: None,
        error: None,
        tool_telemetry: None,
    };
    let response = ToolResultResponse {
        result: ToolResult::Expanded(expanded),
    };
    let serialized = serde_json::to_value(&response).unwrap();

    // The text result is present, while the unset binary list has no key at all.
    let result = &serialized["result"];
    assert_eq!(result["textResultForLlm"], "ok");
    assert!(result.get("binaryResultsForLlm").is_none());
}
#[test]
fn session_config_default_wire_flags_off_without_handlers() {
    let session_id = SessionId::from("default-flags");
    let (wire, _runtime) = SessionConfig::default()
        .into_wire(session_id)
        .expect("default config has no duplicate handlers");

    // A default config installs no handlers or hooks, so every flag stays off.
    assert!(!wire.request_user_input);
    assert!(!wire.request_permission);
    assert!(!wire.request_elicitation);
    assert!(!wire.request_exit_plan_mode);
    assert!(!wire.request_auto_mode_switch);
    assert!(!wire.hooks);
}
#[test]
fn resume_session_config_new_wire_flags_off_without_handlers() {
    let session_id = SessionId::from("resume-flags");
    let (wire, _runtime) = ResumeSessionConfig::new(session_id)
        .into_wire()
        .expect("default resume config has no duplicate handlers");

    // A freshly constructed resume config must not advertise any handler or hooks.
    assert!(!wire.request_user_input);
    assert!(!wire.request_permission);
    assert!(!wire.request_elicitation);
    assert!(!wire.request_exit_plan_mode);
    assert!(!wire.request_auto_mode_switch);
    assert!(!wire.hooks);
}
#[test]
#[allow(clippy::field_reassign_with_default)]
fn session_config_into_wire_serializes_bucket_b_fields() {
    use std::path::PathBuf;

    use super::{CloudSessionOptions, CloudSessionRepository};

    let repository = CloudSessionRepository::new("github", "copilot-sdk").with_branch("main");

    let mut populated = SessionConfig::default();
    populated.config_dir = Some(PathBuf::from("/tmp/cfg"));
    populated.working_directory = Some(PathBuf::from("/tmp/work"));
    populated.github_token = Some("ghs_secret".to_string());
    populated.include_sub_agent_streaming_events = Some(false);
    populated.enable_session_telemetry = Some(false);
    populated.remote_session = Some(crate::generated::api_types::RemoteSessionMode::Export);
    populated.cloud = Some(CloudSessionOptions::with_repository(repository));

    // Populated config: every field must appear under its wire key.
    let (wire, _runtime) = populated
        .into_wire(SessionId::from("custom-id"))
        .expect("no duplicate handlers");
    let body = serde_json::to_value(&wire).unwrap();
    assert_eq!(body["sessionId"], "custom-id");
    assert_eq!(body["configDir"], "/tmp/cfg");
    assert_eq!(body["workingDirectory"], "/tmp/work");
    assert_eq!(body["gitHubToken"], "ghs_secret");
    assert_eq!(body["includeSubAgentStreamingEvents"], false);
    assert_eq!(body["enableSessionTelemetry"], false);
    assert_eq!(body["remoteSession"], "export");
    let repo_json = &body["cloud"]["repository"];
    assert_eq!(repo_json["owner"], "github");
    assert_eq!(repo_json["name"], "copilot-sdk");
    assert_eq!(repo_json["branch"], "main");

    // Default config: the optional keys checked here must be absent.
    let (bare_wire, _) = SessionConfig::default()
        .into_wire(SessionId::from("empty"))
        .expect("default has no duplicate handlers");
    let bare = serde_json::to_value(&bare_wire).unwrap();
    assert!(bare.get("gitHubToken").is_none());
    assert!(bare.get("enableSessionTelemetry").is_none());
    assert!(bare.get("remoteSession").is_none());
    assert!(bare.get("cloud").is_none());
}
#[test]
fn resume_session_config_into_wire_serializes_bucket_b_fields() {
    use std::path::PathBuf;

    let mut populated = ResumeSessionConfig::new(SessionId::from("sess-1"));
    populated.working_directory = Some(PathBuf::from("/tmp/work"));
    populated.config_dir = Some(PathBuf::from("/tmp/cfg"));
    populated.github_token = Some("ghs_secret".to_string());
    populated.include_sub_agent_streaming_events = Some(true);
    populated.enable_session_telemetry = Some(false);
    populated.remote_session = Some(crate::generated::api_types::RemoteSessionMode::On);

    // Populated config: every field must appear under its wire key.
    let (wire, _) = populated.into_wire().expect("no duplicate handlers");
    let body = serde_json::to_value(&wire).unwrap();
    assert_eq!(body["sessionId"], "sess-1");
    assert_eq!(body["workingDirectory"], "/tmp/work");
    assert_eq!(body["configDir"], "/tmp/cfg");
    assert_eq!(body["gitHubToken"], "ghs_secret");
    assert_eq!(body["includeSubAgentStreamingEvents"], true);
    assert_eq!(body["enableSessionTelemetry"], false);
    assert_eq!(body["remoteSession"], "on");

    // Untouched config: `remoteSession` must be left out.
    let (bare_wire, _) = ResumeSessionConfig::new(SessionId::from("sess-2"))
        .into_wire()
        .expect("default resume has no duplicate handlers");
    let bare = serde_json::to_value(&bare_wire).unwrap();
    assert!(bare.get("remoteSession").is_none());
}
#[test]
fn session_config_builder_composes() {
    use std::collections::HashMap;

    let built = SessionConfig::default()
        .with_session_id(SessionId::from("sess-1"))
        .with_model("claude-sonnet-4")
        .with_client_name("test-app")
        .with_reasoning_effort("medium")
        .with_streaming(true)
        .with_tools([Tool::new("greet")])
        .with_available_tools(["bash", "view"])
        .with_excluded_tools(["dangerous"])
        .with_mcp_servers(HashMap::new())
        .with_enable_config_discovery(true)
        .with_skill_directories([PathBuf::from("/tmp/skills")])
        .with_disabled_skills(["broken-skill"])
        .with_agent("researcher")
        .with_config_dir(PathBuf::from("/tmp/config"))
        .with_working_directory(PathBuf::from("/tmp/work"))
        .with_github_token("ghp_test")
        .with_enable_session_telemetry(false)
        .with_include_sub_agent_streaming_events(false)
        .with_extension_info(ExtensionInfo::new("github-app", "counter"));

    // Expected list contents, named once so the assertions stay on one line.
    let want_available = ["bash".to_string(), "view".to_string()];
    let want_excluded = ["dangerous".to_string()];
    let want_skill_dirs = [PathBuf::from("/tmp/skills")];
    let want_disabled = ["broken-skill".to_string()];

    assert_eq!(built.session_id.as_ref().map(|id| id.as_str()), Some("sess-1"));
    assert_eq!(built.model.as_deref(), Some("claude-sonnet-4"));
    assert_eq!(built.client_name.as_deref(), Some("test-app"));
    assert_eq!(built.reasoning_effort.as_deref(), Some("medium"));
    assert_eq!(built.streaming, Some(true));
    assert_eq!(built.tools.as_ref().map(|tools| tools.len()), Some(1));
    assert_eq!(built.available_tools.as_deref(), Some(&want_available[..]));
    assert_eq!(built.excluded_tools.as_deref(), Some(&want_excluded[..]));
    assert!(built.mcp_servers.is_some());
    assert_eq!(built.enable_config_discovery, Some(true));
    assert_eq!(built.skill_directories.as_deref(), Some(&want_skill_dirs[..]));
    assert_eq!(built.disabled_skills.as_deref(), Some(&want_disabled[..]));
    assert_eq!(built.agent.as_deref(), Some("researcher"));
    assert_eq!(built.config_dir, Some(PathBuf::from("/tmp/config")));
    assert_eq!(built.working_directory, Some(PathBuf::from("/tmp/work")));
    assert_eq!(built.github_token.as_deref(), Some("ghp_test"));
    assert_eq!(built.enable_session_telemetry, Some(false));
    assert_eq!(built.include_sub_agent_streaming_events, Some(false));
    assert_eq!(
        built.extension_info,
        Some(ExtensionInfo::new("github-app", "counter"))
    );
}
#[test]
fn resume_session_config_builder_composes() {
    use std::collections::HashMap;

    let built = ResumeSessionConfig::new(SessionId::from("sess-2"))
        .with_client_name("test-app")
        .with_streaming(true)
        .with_tools([Tool::new("greet")])
        .with_available_tools(["bash", "view"])
        .with_excluded_tools(["dangerous"])
        .with_mcp_servers(HashMap::new())
        .with_enable_config_discovery(true)
        .with_skill_directories([PathBuf::from("/tmp/skills")])
        .with_disabled_skills(["broken-skill"])
        .with_agent("researcher")
        .with_config_dir(PathBuf::from("/tmp/config"))
        .with_working_directory(PathBuf::from("/tmp/work"))
        .with_github_token("ghp_test")
        .with_enable_session_telemetry(false)
        .with_include_sub_agent_streaming_events(true)
        .with_suppress_resume_event(true)
        .with_continue_pending_work(true)
        .with_extension_info(ExtensionInfo::new("github-app", "counter"));

    // Expected list contents, named once so the assertions stay on one line.
    let want_available = ["bash".to_string(), "view".to_string()];
    let want_excluded = ["dangerous".to_string()];
    let want_skill_dirs = [PathBuf::from("/tmp/skills")];
    let want_disabled = ["broken-skill".to_string()];

    assert_eq!(built.session_id.as_str(), "sess-2");
    assert_eq!(built.client_name.as_deref(), Some("test-app"));
    assert_eq!(built.streaming, Some(true));
    assert_eq!(built.tools.as_ref().map(|tools| tools.len()), Some(1));
    assert_eq!(built.available_tools.as_deref(), Some(&want_available[..]));
    assert_eq!(built.excluded_tools.as_deref(), Some(&want_excluded[..]));
    assert!(built.mcp_servers.is_some());
    assert_eq!(built.enable_config_discovery, Some(true));
    assert_eq!(built.skill_directories.as_deref(), Some(&want_skill_dirs[..]));
    assert_eq!(built.disabled_skills.as_deref(), Some(&want_disabled[..]));
    assert_eq!(built.agent.as_deref(), Some("researcher"));
    assert_eq!(built.config_dir, Some(PathBuf::from("/tmp/config")));
    assert_eq!(built.working_directory, Some(PathBuf::from("/tmp/work")));
    assert_eq!(built.github_token.as_deref(), Some("ghp_test"));
    assert_eq!(built.enable_session_telemetry, Some(false));
    assert_eq!(built.include_sub_agent_streaming_events, Some(true));
    assert_eq!(built.suppress_resume_event, Some(true));
    assert_eq!(built.continue_pending_work, Some(true));
    assert_eq!(
        built.extension_info,
        Some(ExtensionInfo::new("github-app", "counter"))
    );
}
#[test]
fn resume_session_config_serializes_continue_pending_work_to_camel_case() {
    // Flag set: emitted as `continuePendingWork: true`.
    let enabled =
        ResumeSessionConfig::new(SessionId::from("sess-1")).with_continue_pending_work(true);
    let (enabled_wire, _) = enabled.into_wire().expect("no duplicate handlers");
    let enabled_json = serde_json::to_value(&enabled_wire).unwrap();
    assert_eq!(enabled_json["continuePendingWork"], true);

    // Flag never set: the key is left out.
    let untouched = ResumeSessionConfig::new(SessionId::from("sess-2"));
    let (untouched_wire, _) = untouched.into_wire().expect("no duplicate handlers");
    let untouched_json = serde_json::to_value(&untouched_wire).unwrap();
    assert!(untouched_json.get("continuePendingWork").is_none());
}
#[test]
fn resume_session_config_serializes_suppress_resume_event_to_disable_resume_on_wire() {
    let (wire, _) = ResumeSessionConfig::new(SessionId::from("sess-1"))
        .with_suppress_resume_event(true)
        .into_wire()
        .expect("no duplicate handlers");
    let serialized = serde_json::to_value(&wire).unwrap();

    // The setting travels as `disableResume`; the Rust-side name must not leak out.
    assert_eq!(serialized["disableResume"], true);
    assert!(serialized.get("suppressResumeEvent").is_none());
}
#[test]
fn session_config_serializes_instruction_directories_to_camel_case() {
    // Directories configured: emitted as an array under `instructionDirectories`.
    let configured =
        SessionConfig::default().with_instruction_directories([PathBuf::from("/tmp/instr")]);
    let (configured_wire, _) = configured
        .into_wire(SessionId::from("instr-on"))
        .expect("no duplicate handlers");
    let configured_json = serde_json::to_value(&configured_wire).unwrap();
    let expected = serde_json::json!(["/tmp/instr"]);
    assert_eq!(configured_json["instructionDirectories"], expected);

    // Nothing configured: the key is left out.
    let (bare_wire, _) = SessionConfig::default()
        .into_wire(SessionId::from("instr-off"))
        .expect("no duplicate handlers");
    let bare_json = serde_json::to_value(&bare_wire).unwrap();
    assert!(bare_json.get("instructionDirectories").is_none());
}
#[test]
fn resume_session_config_serializes_instruction_directories_to_camel_case() {
    // Directories configured: emitted as an array under `instructionDirectories`.
    let configured = ResumeSessionConfig::new(SessionId::from("sess-1"))
        .with_instruction_directories([PathBuf::from("/tmp/instr")]);
    let (configured_wire, _) = configured.into_wire().expect("no duplicate handlers");
    let configured_json = serde_json::to_value(&configured_wire).unwrap();
    let expected = serde_json::json!(["/tmp/instr"]);
    assert_eq!(configured_json["instructionDirectories"], expected);

    // Nothing configured: the key is left out.
    let (bare_wire, _) = ResumeSessionConfig::new(SessionId::from("sess-2"))
        .into_wire()
        .expect("no duplicate handlers");
    let bare_json = serde_json::to_value(&bare_wire).unwrap();
    assert!(bare_json.get("instructionDirectories").is_none());
}
#[test]
fn custom_agent_config_builder_composes() {
    use std::collections::HashMap;

    let built = CustomAgentConfig::new("researcher", "You are a research assistant.")
        .with_display_name("Research Assistant")
        .with_description("Investigates technical questions.")
        .with_tools(["bash", "view"])
        .with_mcp_servers(HashMap::new())
        .with_infer(true)
        .with_skills(["rust-coding-skill"]);

    // Expected list contents for the slice comparisons below.
    let want_tools = ["bash".to_string(), "view".to_string()];
    let want_skills = ["rust-coding-skill".to_string()];

    assert_eq!(built.name, "researcher");
    assert_eq!(built.prompt, "You are a research assistant.");
    assert_eq!(built.display_name.as_deref(), Some("Research Assistant"));
    assert_eq!(
        built.description.as_deref(),
        Some("Investigates technical questions.")
    );
    assert_eq!(built.tools.as_deref(), Some(&want_tools[..]));
    assert!(built.mcp_servers.is_some());
    assert_eq!(built.infer, Some(true));
    assert_eq!(built.skills.as_deref(), Some(&want_skills[..]));
}
#[test]
fn infinite_session_config_builder_composes() {
    let built = InfiniteSessionConfig::new()
        .with_enabled(true)
        .with_background_compaction_threshold(0.75)
        .with_buffer_exhaustion_threshold(0.92);

    // Both thresholds and the enabled flag must be stored exactly as given.
    assert_eq!(built.enabled, Some(true));
    assert_eq!(built.background_compaction_threshold, Some(0.75));
    assert_eq!(built.buffer_exhaustion_threshold, Some(0.92));
}
#[test]
fn provider_config_builder_composes() {
    use std::collections::HashMap;

    let mut custom_headers = HashMap::new();
    custom_headers.insert("X-Custom".to_string(), "value".to_string());

    let built = ProviderConfig::new("https://api.example.com")
        .with_provider_type("openai")
        .with_wire_api("completions")
        .with_api_key("sk-test")
        .with_bearer_token("bearer-test")
        .with_headers(custom_headers)
        .with_model_id("gpt-4")
        .with_wire_model("azure-gpt-4-deployment")
        .with_max_prompt_tokens(8192)
        .with_max_output_tokens(2048);

    // Builder values are stored on the struct.
    assert_eq!(built.base_url, "https://api.example.com");
    assert_eq!(built.provider_type.as_deref(), Some("openai"));
    assert_eq!(built.wire_api.as_deref(), Some("completions"));
    assert_eq!(built.api_key.as_deref(), Some("sk-test"));
    assert_eq!(built.bearer_token.as_deref(), Some("bearer-test"));
    let stored_header = built
        .headers
        .as_ref()
        .and_then(|headers| headers.get("X-Custom"))
        .map(String::as_str);
    assert_eq!(stored_header, Some("value"));
    assert_eq!(built.model_id.as_deref(), Some("gpt-4"));
    assert_eq!(built.wire_model.as_deref(), Some("azure-gpt-4-deployment"));
    assert_eq!(built.max_prompt_tokens, Some(8192));
    assert_eq!(built.max_output_tokens, Some(2048));

    // Model and token-limit fields serialize under camelCase keys.
    let populated = serde_json::to_value(&built).unwrap();
    assert_eq!(populated["modelId"], "gpt-4");
    assert_eq!(populated["wireModel"], "azure-gpt-4-deployment");
    assert_eq!(populated["maxPromptTokens"], 8192);
    assert_eq!(populated["maxOutputTokens"], 2048);

    // When never set, those same keys are left out.
    let bare = serde_json::to_value(&ProviderConfig::new("https://api.example.com")).unwrap();
    assert!(bare.get("modelId").is_none());
    assert!(bare.get("wireModel").is_none());
    assert!(bare.get("maxPromptTokens").is_none());
    assert!(bare.get("maxOutputTokens").is_none());
}
#[test]
fn system_message_config_builder_composes() {
    use std::collections::HashMap;

    let built = SystemMessageConfig::new()
        .with_mode("replace")
        .with_content("Custom system message.")
        .with_sections(HashMap::new());

    // Mode and content are stored verbatim; an empty section map still counts as set.
    assert_eq!(built.mode.as_deref(), Some("replace"));
    assert_eq!(built.content.as_deref(), Some("Custom system message."));
    assert!(built.sections.is_some());
}
#[test]
fn delivery_mode_serializes_to_kebab_case_strings() {
    // Serialization yields the quoted lowercase variant name.
    let enqueue = serde_json::to_string(&DeliveryMode::Enqueue).unwrap();
    assert_eq!(enqueue, "\"enqueue\"");
    let immediate = serde_json::to_string(&DeliveryMode::Immediate).unwrap();
    assert_eq!(immediate, "\"immediate\"");

    // The same string parses back to the variant.
    let parsed = serde_json::from_str::<DeliveryMode>("\"immediate\"").unwrap();
    assert_eq!(parsed, DeliveryMode::Immediate);
}
#[test]
fn connection_state_distinguishes_variants() {
    // Two different variants must not compare equal.
    let (up, down) = (ConnectionState::Connected, ConnectionState::Disconnected);
    assert_ne!(up, down);
}
#[test]
fn session_event_round_trips_agent_id_on_envelope() {
    // Envelope that carries an `agentId`.
    let wire = json!({
        "id": "evt-1",
        "timestamp": "2026-04-30T12:00:00Z",
        "parentId": null,
        "agentId": "sub-agent-42",
        "type": "assistant.message",
        "data": { "message": "hi" }
    });
    // `wire` is not used again, so it is moved into `from_value` instead of cloned.
    let event: SessionEvent = serde_json::from_value(wire).unwrap();
    assert_eq!(event.agent_id.as_deref(), Some("sub-agent-42"));
    let roundtripped = serde_json::to_value(&event).unwrap();
    assert_eq!(roundtripped["agentId"], "sub-agent-42");

    // Envelope without `agentId`: the field parses as `None` and is not
    // written back out on serialization.
    let main_agent_event: SessionEvent = serde_json::from_value(json!({
        "id": "evt-2",
        "timestamp": "2026-04-30T12:00:01Z",
        "parentId": null,
        "type": "session.idle",
        "data": {}
    }))
    .unwrap();
    assert!(main_agent_event.agent_id.is_none());
    let roundtripped = serde_json::to_value(&main_agent_event).unwrap();
    assert!(roundtripped.get("agentId").is_none());
}
#[test]
fn typed_session_event_round_trips_agent_id_on_envelope() {
    let envelope = json!({
        "id": "evt-1",
        "timestamp": "2026-04-30T12:00:00Z",
        "parentId": null,
        "agentId": "sub-agent-42",
        "type": "session.idle",
        "data": {}
    });

    // `agentId` must be readable after parsing into the typed event...
    let typed = serde_json::from_value::<TypedSessionEvent>(envelope).unwrap();
    assert_eq!(typed.agent_id.as_deref(), Some("sub-agent-42"));

    // ...and must be written back out when the event is serialized again.
    let reserialized = serde_json::to_value(&typed).unwrap();
    assert_eq!(reserialized["agentId"], "sub-agent-42");
}
#[test]
fn connection_state_variants_compile() {
    // Names every variant once; the test has no runtime assertions.
    let _ = [
        ConnectionState::Disconnected,
        ConnectionState::Connecting,
        ConnectionState::Connected,
        ConnectionState::Error,
    ];
}
#[test]
fn deserializes_runtime_attachment_variants() {
    // One wire object per `type` tag: file, directory, selection, blob, github_reference.
    let wire = json!([
        {
            "type": "file",
            "path": "/tmp/file.rs",
            "displayName": "file.rs",
            "lineRange": { "start": 7, "end": 12 }
        },
        {
            "type": "directory",
            "path": "/tmp/project",
            "displayName": "project"
        },
        {
            "type": "selection",
            "filePath": "/tmp/lib.rs",
            "displayName": "lib.rs",
            "text": "fn main() {}",
            "selection": {
                "start": { "line": 1, "character": 2 },
                "end": { "line": 3, "character": 4 }
            }
        },
        {
            "type": "blob",
            "data": "Zm9v",
            "mimeType": "image/png",
            "displayName": "image.png"
        },
        {
            "type": "github_reference",
            "number": 42,
            "title": "Fix rendering",
            "referenceType": "issue",
            "state": "open",
            "url": "https://github.com/example/repo/issues/42"
        }
    ]);
    let parsed: Vec<Attachment> =
        serde_json::from_value(wire).expect("attachments should deserialize");
    assert_eq!(parsed.len(), 5);

    // File attachment with a line range.
    assert!(matches!(
        &parsed[0],
        Attachment::File {
            path,
            display_name,
            line_range: Some(AttachmentLineRange { start: 7, end: 12 }),
        } if path == &PathBuf::from("/tmp/file.rs") && display_name.as_deref() == Some("file.rs")
    ));
    // Directory attachment.
    assert!(matches!(
        &parsed[1],
        Attachment::Directory { path, display_name }
            if path == &PathBuf::from("/tmp/project") && display_name.as_deref() == Some("project")
    ));
    // Selection attachment with start/end positions.
    assert!(matches!(
        &parsed[2],
        Attachment::Selection {
            file_path,
            display_name,
            selection:
                AttachmentSelectionRange {
                    start: AttachmentSelectionPosition { line: 1, character: 2 },
                    end: AttachmentSelectionPosition { line: 3, character: 4 },
                },
            ..
        } if file_path == &PathBuf::from("/tmp/lib.rs") && display_name.as_deref() == Some("lib.rs")
    ));
    // Blob attachment with data and MIME type.
    assert!(matches!(
        &parsed[3],
        Attachment::Blob {
            data,
            mime_type,
            display_name,
        } if data == "Zm9v" && mime_type == "image/png" && display_name.as_deref() == Some("image.png")
    ));
    // GitHub reference of type issue.
    assert!(matches!(
        &parsed[4],
        Attachment::GitHubReference {
            number: 42,
            title,
            reference_type: GitHubReferenceType::Issue,
            state,
            url,
        } if title == "Fix rendering"
            && state == "open"
            && url == "https://github.com/example/repo/issues/42"
    ));
}
#[test]
fn ensures_display_names_for_variants_that_support_them() {
    // Four attachments, none of which starts with a display name.
    let file = Attachment::File {
        path: PathBuf::from("/tmp/file.rs"),
        display_name: None,
        line_range: None,
    };
    let selection = Attachment::Selection {
        file_path: PathBuf::from("/tmp/src/lib.rs"),
        display_name: None,
        text: "fn main() {}".to_string(),
        selection: AttachmentSelectionRange {
            start: AttachmentSelectionPosition {
                line: 0,
                character: 0,
            },
            end: AttachmentSelectionPosition {
                line: 0,
                character: 10,
            },
        },
    };
    let blob = Attachment::Blob {
        data: "Zm9v".to_string(),
        mime_type: "image/png".to_string(),
        display_name: None,
    };
    let issue = Attachment::GitHubReference {
        number: 7,
        title: "Track regressions".to_string(),
        reference_type: GitHubReferenceType::Issue,
        state: "open".to_string(),
        url: "https://example.com/issues/7".to_string(),
    };
    let mut items = vec![file, selection, blob, issue];

    ensure_attachment_display_names(&mut items);

    // File and selection get their file name; the blob gets "attachment".
    assert_eq!(items[0].display_name(), Some("file.rs"));
    assert_eq!(items[1].display_name(), Some("lib.rs"));
    assert_eq!(items[2].display_name(), Some("attachment"));
    // The GitHub reference has no display name; its label is the title.
    assert_eq!(items[3].display_name(), None);
    assert_eq!(items[3].label(), Some("Track regressions".to_string()));
}
}
#[cfg(test)]
mod permission_builder_tests {
use std::sync::Arc;
use crate::handler::{ApproveAllHandler, PermissionHandler, PermissionResult};
use crate::permission;
use crate::types::{
PermissionDecision, PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig,
SessionId,
};
fn data() -> PermissionRequestData {
PermissionRequestData {
extra: serde_json::json!({"tool": "shell"}),
..Default::default()
}
}
/// Takes the permission handler and policy out of a `SessionConfig` and
/// passes them to `permission::resolve_handler`.
fn resolve_create(mut cfg: SessionConfig) -> Option<Arc<dyn PermissionHandler>> {
    let handler = cfg.permission_handler.take();
    let policy = cfg.permission_policy.take();
    permission::resolve_handler(handler, policy)
}
/// Takes the permission handler and policy out of a `ResumeSessionConfig`
/// and passes them to `permission::resolve_handler`.
fn resolve_resume(mut cfg: ResumeSessionConfig) -> Option<Arc<dyn PermissionHandler>> {
    let handler = cfg.permission_handler.take();
    let policy = cfg.permission_policy.take();
    permission::resolve_handler(handler, policy)
}
/// Sends the shared `data()` request to `handler` under session `"s1"` and
/// request id `"1"`, returning the handler's result.
async fn dispatch(handler: &Arc<dyn PermissionHandler>) -> PermissionResult {
    let session = SessionId::from("s1");
    let request = RequestId::new("1");
    handler.handle(session, request, data()).await
}
#[tokio::test]
async fn approve_all_with_handler_present_approves() {
    let config = SessionConfig::default()
        .with_permission_handler(Arc::new(ApproveAllHandler))
        .approve_all_permissions();
    let resolved = resolve_create(config).expect("policy + handler yields handler");

    // With both a handler and the approve-all policy, the request is approved once.
    let outcome = dispatch(&resolved).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
    ));
}
#[tokio::test]
async fn approve_all_standalone_produces_handler() {
    // No explicit handler: the policy alone must still resolve to one.
    let config = SessionConfig::default().approve_all_permissions();
    let resolved = resolve_create(config).expect("policy alone yields handler");

    let outcome = dispatch(&resolved).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
    ));
}
#[tokio::test]
async fn approve_all_is_order_independent() {
    // Same two builder calls, applied in opposite orders.
    let handler_first = SessionConfig::default()
        .with_permission_handler(Arc::new(ApproveAllHandler))
        .approve_all_permissions();
    let policy_first = SessionConfig::default()
        .approve_all_permissions()
        .with_permission_handler(Arc::new(ApproveAllHandler));

    let resolved_handler_first = resolve_create(handler_first).unwrap();
    let resolved_policy_first = resolve_create(policy_first).unwrap();

    // Both orderings must approve the request.
    let outcome = dispatch(&resolved_handler_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
    ));
    let outcome = dispatch(&resolved_policy_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
    ));
}
#[tokio::test]
async fn deny_all_is_order_independent() {
    // An approving handler combined with the deny-all policy, in both orders.
    let handler_first = SessionConfig::default()
        .with_permission_handler(Arc::new(ApproveAllHandler))
        .deny_all_permissions();
    let policy_first = SessionConfig::default()
        .deny_all_permissions()
        .with_permission_handler(Arc::new(ApproveAllHandler));

    let resolved_handler_first = resolve_create(handler_first).unwrap();
    let resolved_policy_first = resolve_create(policy_first).unwrap();

    // Both orderings must reject the request despite the approving handler.
    let outcome = dispatch(&resolved_handler_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::Reject(_))
    ));
    let outcome = dispatch(&resolved_policy_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::Reject(_))
    ));
}
#[tokio::test]
async fn approve_permissions_if_consults_predicate() {
    // The predicate approves everything except requests whose `tool` is "shell".
    let config = SessionConfig::default().approve_permissions_if(|request| {
        let tool = request.extra.get("tool").and_then(|value| value.as_str());
        tool != Some("shell")
    });
    let resolved = resolve_create(config).unwrap();

    // The shared request names the shell tool, so it must be rejected.
    let outcome = dispatch(&resolved).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::Reject(_))
    ));
}
#[tokio::test]
async fn approve_permissions_if_is_order_independent() {
    // Approves everything except requests whose `tool` is "shell".
    let not_shell = |request: &PermissionRequestData| {
        let tool = request.extra.get("tool").and_then(|value| value.as_str());
        tool != Some("shell")
    };

    // Same predicate and handler, applied in opposite orders.
    let handler_first = SessionConfig::default()
        .with_permission_handler(Arc::new(ApproveAllHandler))
        .approve_permissions_if(not_shell);
    let predicate_first = SessionConfig::default()
        .approve_permissions_if(not_shell)
        .with_permission_handler(Arc::new(ApproveAllHandler));

    let resolved_handler_first = resolve_create(handler_first).unwrap();
    let resolved_predicate_first = resolve_create(predicate_first).unwrap();

    // Both orderings must reject the shell request.
    let outcome = dispatch(&resolved_handler_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::Reject(_))
    ));
    let outcome = dispatch(&resolved_predicate_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::Reject(_))
    ));
}
#[tokio::test]
async fn resume_session_config_approve_all_works() {
    let config = ResumeSessionConfig::new(SessionId::from("s1"))
        .with_permission_handler(Arc::new(ApproveAllHandler))
        .approve_all_permissions();
    let resolved = resolve_resume(config).unwrap();

    // The resume config must approve the request just as the create config does.
    let outcome = dispatch(&resolved).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
    ));
}
#[tokio::test]
async fn resume_session_config_approve_all_is_order_independent() {
    // Same two builder calls on a resume config, applied in opposite orders.
    let handler_first = ResumeSessionConfig::new(SessionId::from("s1"))
        .with_permission_handler(Arc::new(ApproveAllHandler))
        .approve_all_permissions();
    let policy_first = ResumeSessionConfig::new(SessionId::from("s1"))
        .approve_all_permissions()
        .with_permission_handler(Arc::new(ApproveAllHandler));

    let resolved_handler_first = resolve_resume(handler_first).unwrap();
    let resolved_policy_first = resolve_resume(policy_first).unwrap();

    // Both orderings must approve the request.
    let outcome = dispatch(&resolved_handler_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
    ));
    let outcome = dispatch(&resolved_policy_first).await;
    assert!(matches!(
        outcome,
        PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
    ));
}
}