use crate::error::ClaudeError;
use crate::error::Result;
use crate::types::InputFormat;
use crate::types::Model;
use crate::types::OutputFormat;
use crate::types::PermissionMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::path::PathBuf;
/// One MCP server entry inside an [`MCPConfig`].
///
/// Serialized as an internally tagged object: the `"type"` key holds `"stdio"` or
/// `"http"`, and the variant's fields sit next to it in the same object.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MCPServer {
    /// Server described by a command line (`command` plus `args`).
    #[serde(rename = "stdio")]
    Stdio {
        /// Command used for this server.
        command: String,
        /// Arguments for `command`. A missing `args` key deserializes as an empty list,
        /// so hand-written configs may omit it.
        #[serde(default)]
        args: Vec<String>,
        /// Optional environment variables; the key is left out of the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        env: Option<HashMap<String, String>>,
    },
    /// Server reached at a URL.
    #[serde(rename = "http")]
    Http {
        /// Endpoint URL of the server.
        url: String,
        /// Optional extra headers; the key is left out of the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        headers: Option<HashMap<String, String>>,
    },
}
impl MCPServer {
    /// Builds an [`MCPServer::Stdio`] entry with no extra environment variables.
    pub fn stdio(command: impl Into<String>, args: Vec<String>) -> Self {
        let command = command.into();
        Self::Stdio {
            command,
            args,
            env: None,
        }
    }

    /// Builds an [`MCPServer::Stdio`] entry that carries the given environment map.
    pub fn stdio_with_env(
        command: impl Into<String>,
        args: Vec<String>,
        env: HashMap<String, String>,
    ) -> Self {
        let command = command.into();
        let env = Some(env);
        Self::Stdio { command, args, env }
    }

    /// Builds an [`MCPServer::Http`] entry with no extra headers.
    pub fn http(url: impl Into<String>) -> Self {
        let url = url.into();
        Self::Http { url, headers: None }
    }

    /// Builds an [`MCPServer::Http`] entry that carries the given header map.
    pub fn http_with_headers(url: impl Into<String>, headers: HashMap<String, String>) -> Self {
        let url = url.into();
        let headers = Some(headers);
        Self::Http { url, headers }
    }
}
/// Collection of MCP servers keyed by server name.
///
/// Serialized under the JSON key `mcpServers`, e.g.
/// `{"mcpServers":{"name":{"type":"stdio","command":"...","args":[]}}}`.
/// `MCPConfig::default()` yields an empty server map.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MCPConfig {
    /// Server definitions keyed by the name they are registered under.
    #[serde(rename = "mcpServers")]
    pub mcp_servers: HashMap<String, MCPServer>,
}
/// Every option for one session, normally assembled through [`SessionConfig::builder`]
/// and checked by [`SessionConfig::validate`].
///
/// NOTE(review): most fields look like one-to-one mirrors of command-line options;
/// how they are forwarded is decided outside this file — confirm there.
#[derive(Debug, Clone, Default)]
pub struct SessionConfig {
    /// Prompt text; `validate` rejects an empty value.
    pub query: String,
    /// Session ID to resume; `validate` rejects it alongside `continue_last_session`
    /// or `explicit_session_id`.
    pub resume_session_id: Option<String>,
    /// Caller-chosen session ID; `validate` rejects it alongside `resume_session_id`.
    pub explicit_session_id: Option<String>,
    /// Continue-the-last-session flag; `validate` rejects it alongside `resume_session_id`.
    pub continue_last_session: bool,
    /// Fork-session flag (not checked by `validate`).
    pub fork_session: bool,
    /// Model override, if any.
    pub model: Option<Model>,
    /// Fallback model, if any.
    pub fallback_model: Option<Model>,
    /// Output format; the tests in this module expect `OutputFormat::StreamingJson`
    /// as the default.
    pub output_format: OutputFormat,
    /// Input format override, if any.
    pub input_format: Option<InputFormat>,
    /// MCP server configuration, if any.
    pub mcp_config: Option<MCPConfig>,
    /// Strict MCP configuration flag.
    pub strict_mcp_config: bool,
    /// Permission mode override, if any.
    pub permission_mode: Option<PermissionMode>,
    /// Skip-permissions flag; `validate` requires it to match
    /// `allow_dangerously_skip_permissions`.
    pub dangerously_skip_permissions: bool,
    /// Allow-skip-permissions flag; `validate` requires it to match
    /// `dangerously_skip_permissions`.
    pub allow_dangerously_skip_permissions: bool,
    /// System prompt text, if any.
    pub system_prompt: Option<String>,
    /// Text to append to the system prompt, if any.
    pub append_system_prompt: Option<String>,
    /// Tool list, if any.
    pub tools: Option<Vec<String>>,
    /// Allowed tool names; the builder's `allow_tool` appends to this list.
    pub allowed_tools: Option<Vec<String>>,
    /// Disallowed tool names; the builder's `disallow_tool` appends to this list.
    pub disallowed_tools: Option<Vec<String>>,
    /// JSON schema as a raw string, if any.
    pub json_schema: Option<String>,
    /// Include-partial-messages flag.
    pub include_partial_messages: bool,
    /// Replay-user-messages flag.
    pub replay_user_messages: bool,
    /// Settings as a raw string, if any.
    pub settings: Option<String>,
    /// Setting source names, if any.
    pub setting_sources: Option<Vec<String>>,
    /// Extra directories; the builder's `add_dir` appends here.
    pub additional_dirs: Vec<PathBuf>,
    /// Plugin directories; the builder's `plugin_dir` appends here.
    pub plugin_dirs: Vec<PathBuf>,
    /// IDE flag.
    pub ide: bool,
    /// Agent definitions as a raw (JSON) string, if any.
    pub agents: Option<String>,
    /// Debug flag.
    pub debug: bool,
    /// Debug filter string, if any.
    pub debug_filter: Option<String>,
    /// Working directory, if any.
    pub working_dir: Option<PathBuf>,
    /// Environment variables, if any; the builder's `env_var` inserts into this map.
    pub env: Option<HashMap<String, String>>,
    /// Verbose flag.
    pub verbose: bool,
}
impl SessionConfig {
    /// Starts a [`SessionConfigBuilder`] seeded with the given prompt.
    pub fn builder(query: impl Into<String>) -> SessionConfigBuilder {
        SessionConfigBuilder::new(query)
    }

    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns [`ClaudeError::InvalidConfiguration`] when:
    /// - `query` is empty or consists only of whitespace;
    /// - `continue_last_session` is set together with `resume_session_id`;
    /// - `resume_session_id` is set together with `explicit_session_id`;
    /// - exactly one of `dangerously_skip_permissions` and
    ///   `allow_dangerously_skip_permissions` is set.
    pub fn validate(&self) -> Result<()> {
        // A whitespace-only prompt carries no content, so treat it the same as "".
        if self.query.trim().is_empty() {
            return Err(ClaudeError::InvalidConfiguration {
                message: "Query cannot be empty".to_string(),
            });
        }
        if self.continue_last_session && self.resume_session_id.is_some() {
            return Err(ClaudeError::InvalidConfiguration {
                message: "Cannot set both continue_last_session and resume_session_id".to_string(),
            });
        }
        if self.resume_session_id.is_some() && self.explicit_session_id.is_some() {
            return Err(ClaudeError::InvalidConfiguration {
                message: "Cannot set both resume_session_id and explicit_session_id".to_string(),
            });
        }
        // XOR: the two dangerous-permission flags must be enabled (or disabled) together.
        if self.dangerously_skip_permissions ^ self.allow_dangerously_skip_permissions {
            return Err(ClaudeError::InvalidConfiguration {
                message: "Dangerous permissions require both flags enabled together (use enable_dangerous_permissions())".to_string(),
            });
        }
        Ok(())
    }
}
/// Fluent builder for [`SessionConfig`]; `build()` runs [`SessionConfig::validate`]
/// before handing the configuration back.
#[derive(Debug, Clone)]
pub struct SessionConfigBuilder {
    config: SessionConfig,
}
impl SessionConfigBuilder {
#[must_use]
pub fn new(query: impl Into<String>) -> Self {
Self {
config: SessionConfig {
query: query.into(),
..Default::default()
},
}
}
#[must_use]
pub fn resume_session_id(mut self, id: impl Into<String>) -> Self {
self.config.resume_session_id = Some(id.into());
self
}
#[must_use]
pub fn explicit_session_id(mut self, id: impl Into<String>) -> Self {
self.config.explicit_session_id = Some(id.into());
self
}
#[must_use]
pub fn continue_last_session(mut self, yes: bool) -> Self {
self.config.continue_last_session = yes;
self
}
#[must_use]
pub fn fork_session(mut self, yes: bool) -> Self {
self.config.fork_session = yes;
self
}
#[must_use]
pub fn model(mut self, model: Model) -> Self {
self.config.model = Some(model);
self
}
#[must_use]
pub fn fallback_model(mut self, model: Model) -> Self {
self.config.fallback_model = Some(model);
self
}
#[must_use]
pub fn output_format(mut self, format: OutputFormat) -> Self {
self.config.output_format = format;
self
}
#[must_use]
pub fn input_format(mut self, format: InputFormat) -> Self {
self.config.input_format = Some(format);
self
}
#[must_use]
pub fn mcp_config(mut self, config: MCPConfig) -> Self {
self.config.mcp_config = Some(config);
self
}
#[must_use]
pub fn strict_mcp_config(mut self, yes: bool) -> Self {
self.config.strict_mcp_config = yes;
self
}
#[must_use]
pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
self.config.permission_mode = Some(mode);
self
}
#[must_use]
pub fn enable_dangerous_permissions(mut self) -> Self {
self.config.allow_dangerously_skip_permissions = true;
self.config.dangerously_skip_permissions = true;
self
}
#[must_use]
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.config.system_prompt = Some(prompt.into());
self
}
#[must_use]
pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.config.append_system_prompt = Some(prompt.into());
self
}
#[must_use]
pub fn tools(mut self, tools: Vec<String>) -> Self {
self.config.tools = Some(tools);
self
}
#[must_use]
pub fn allowed_tools(mut self, tools: Vec<String>) -> Self {
self.config.allowed_tools = Some(tools);
self
}
#[must_use]
pub fn disallowed_tools(mut self, tools: Vec<String>) -> Self {
self.config.disallowed_tools = Some(tools);
self
}
#[must_use]
pub fn allow_tool(mut self, tool: impl Into<String>) -> Self {
self.config
.allowed_tools
.get_or_insert_with(Vec::new)
.push(tool.into());
self
}
#[must_use]
pub fn disallow_tool(mut self, tool: impl Into<String>) -> Self {
self.config
.disallowed_tools
.get_or_insert_with(Vec::new)
.push(tool.into());
self
}
#[must_use]
pub fn json_schema(mut self, schema: impl Into<String>) -> Self {
self.config.json_schema = Some(schema.into());
self
}
#[must_use]
pub fn include_partial_messages(mut self, yes: bool) -> Self {
self.config.include_partial_messages = yes;
self
}
#[must_use]
pub fn replay_user_messages(mut self, yes: bool) -> Self {
self.config.replay_user_messages = yes;
self
}
#[must_use]
pub fn settings(mut self, s: impl Into<String>) -> Self {
self.config.settings = Some(s.into());
self
}
#[must_use]
pub fn setting_sources(mut self, sources: Vec<String>) -> Self {
self.config.setting_sources = Some(sources);
self
}
#[must_use]
pub fn add_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.config.additional_dirs.push(dir.into());
self
}
#[must_use]
pub fn plugin_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.config.plugin_dirs.push(dir.into());
self
}
#[must_use]
pub fn ide(mut self, yes: bool) -> Self {
self.config.ide = yes;
self
}
#[must_use]
pub fn agents(mut self, json: impl Into<String>) -> Self {
self.config.agents = Some(json.into());
self
}
#[must_use]
pub fn debug(mut self, yes: bool) -> Self {
self.config.debug = yes;
self
}
#[must_use]
pub fn debug_filter(mut self, filter: impl Into<String>) -> Self {
self.config.debug_filter = Some(filter.into());
self
}
#[must_use]
pub fn working_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.config.working_dir = Some(dir.into());
self
}
#[must_use]
pub fn env(mut self, env: HashMap<String, String>) -> Self {
self.config.env = Some(env);
self
}
#[must_use]
pub fn env_var(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
self.config
.env
.get_or_insert_with(HashMap::new)
.insert(key.into(), val.into());
self
}
#[must_use]
pub fn verbose(mut self, verbose: bool) -> Self {
self.config.verbose = verbose;
self
}
pub fn build(self) -> Result<SessionConfig> {
self.config.validate()?;
Ok(self.config)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_config_validation_empty_query() {
        let config = SessionConfig::builder("").build();
        assert!(config.is_err());
        assert!(
            config
                .unwrap_err()
                .to_string()
                .contains("Query cannot be empty")
        );
    }

    #[test]
    fn test_session_config_validation_valid() {
        let config = SessionConfig::builder("test query").build();
        assert!(config.is_ok());
    }

    #[test]
    fn test_session_config_validation_session_conflicts() {
        let config = SessionConfig {
            query: "test".to_string(),
            continue_last_session: true,
            resume_session_id: Some("id".to_string()),
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("continue_last_session and resume_session_id")
        );
        let config = SessionConfig {
            query: "test".to_string(),
            resume_session_id: Some("id1".to_string()),
            explicit_session_id: Some("id2".to_string()),
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("resume_session_id and explicit_session_id")
        );
    }

    #[test]
    fn test_session_config_validation_dangerous_permissions() {
        let config = SessionConfig {
            query: "test".to_string(),
            dangerously_skip_permissions: true,
            allow_dangerously_skip_permissions: false,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("enable_dangerous_permissions")
        );
        let config = SessionConfig {
            query: "test".to_string(),
            dangerously_skip_permissions: true,
            allow_dangerously_skip_permissions: true,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_session_config_validation_allow_flag_alone_rejected() {
        // The XOR check must also fire when only the "allow" half is set.
        let config = SessionConfig {
            query: "test".to_string(),
            dangerously_skip_permissions: false,
            allow_dangerously_skip_permissions: true,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("enable_dangerous_permissions"));
    }

    #[test]
    fn test_enable_dangerous_permissions() {
        let config = SessionConfig::builder("test")
            .enable_dangerous_permissions()
            .build()
            .unwrap();
        assert!(config.dangerously_skip_permissions);
        assert!(config.allow_dangerously_skip_permissions);
    }

    #[test]
    fn test_session_config_builder() {
        let config = SessionConfig::builder("my query")
            .resume_session_id("test-id")
            .model(Model::Sonnet)
            .output_format(OutputFormat::Json)
            .verbose(true)
            .build()
            .unwrap();
        assert_eq!(config.query, "my query");
        assert_eq!(config.resume_session_id.as_deref(), Some("test-id"));
        assert_eq!(config.model, Some(Model::Sonnet));
        assert_eq!(config.output_format, OutputFormat::Json);
        assert!(config.verbose);
    }

    #[test]
    fn test_session_config_builder_new_fields() {
        let config = SessionConfig::builder("query")
            .fallback_model(Model::Haiku)
            .input_format(InputFormat::StreamJson)
            .permission_mode(PermissionMode::AcceptEdits)
            .strict_mcp_config(true)
            .json_schema(r#"{"type":"object"}"#)
            .include_partial_messages(true)
            .replay_user_messages(true)
            .tools(vec!["Read".to_string(), "Write".to_string()])
            .settings(r#"{"key":"value"}"#)
            .setting_sources(vec!["source1".to_string()])
            .add_dir("/tmp/dir1")
            .add_dir("/tmp/dir2")
            .plugin_dir("/tmp/plugins")
            .ide(true)
            .agents(r#"{"agent":"config"}"#)
            .debug(true)
            .debug_filter("filter*")
            .env_var("KEY", "VALUE")
            .build()
            .unwrap();
        assert_eq!(config.fallback_model, Some(Model::Haiku));
        assert_eq!(config.input_format, Some(InputFormat::StreamJson));
        assert_eq!(config.permission_mode, Some(PermissionMode::AcceptEdits));
        assert!(config.strict_mcp_config);
        assert_eq!(config.json_schema.as_deref(), Some(r#"{"type":"object"}"#));
        assert!(config.include_partial_messages);
        assert!(config.replay_user_messages);
        assert_eq!(
            config.tools,
            Some(vec!["Read".to_string(), "Write".to_string()])
        );
        assert_eq!(config.settings.as_deref(), Some(r#"{"key":"value"}"#));
        assert_eq!(config.setting_sources, Some(vec!["source1".to_string()]));
        assert_eq!(config.additional_dirs.len(), 2);
        assert_eq!(config.plugin_dirs.len(), 1);
        assert!(config.ide);
        assert_eq!(config.agents.as_deref(), Some(r#"{"agent":"config"}"#));
        assert!(config.debug);
        assert_eq!(config.debug_filter.as_deref(), Some("filter*"));
        assert_eq!(config.env.as_ref().unwrap().get("KEY").unwrap(), "VALUE");
    }

    #[test]
    fn test_allow_and_disallow_tool_accumulate() {
        let config = SessionConfig::builder("query")
            .allow_tool("Read")
            .allow_tool("Write")
            .disallow_tool("Bash")
            .build()
            .unwrap();
        assert_eq!(
            config.allowed_tools,
            Some(vec!["Read".to_string(), "Write".to_string()])
        );
        assert_eq!(config.disallowed_tools, Some(vec!["Bash".to_string()]));
    }

    #[test]
    fn test_allow_tool_appends_to_existing_list() {
        let config = SessionConfig::builder("query")
            .allowed_tools(vec!["Read".to_string()])
            .allow_tool("Grep")
            .build()
            .unwrap();
        assert_eq!(
            config.allowed_tools,
            Some(vec!["Read".to_string(), "Grep".to_string()])
        );
    }

    #[test]
    fn test_env_var_merges_into_env_map() {
        let mut base = HashMap::new();
        base.insert("A".to_string(), "1".to_string());
        let config = SessionConfig::builder("query")
            .env(base)
            .env_var("B", "2")
            .build()
            .unwrap();
        let env = config.env.unwrap();
        assert_eq!(env.len(), 2);
        assert_eq!(env["A"], "1");
        assert_eq!(env["B"], "2");
    }

    #[test]
    fn test_default_output_format() {
        let config = SessionConfig::builder("test").build().unwrap();
        assert_eq!(config.output_format, OutputFormat::StreamingJson);
    }

    #[test]
    fn test_mcp_config_serialization_stdio() {
        let mut servers = HashMap::new();
        servers.insert(
            "test".to_string(),
            MCPServer::stdio("cmd", vec!["arg1".to_string(), "arg2".to_string()]),
        );
        let mcp_config = MCPConfig {
            mcp_servers: servers,
        };
        let json = serde_json::to_string(&mcp_config).unwrap();
        assert!(json.contains(r#""type":"stdio""#));
        let deserialized: MCPConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.mcp_servers.len(), 1);
        assert!(deserialized.mcp_servers.contains_key("test"));
        match &deserialized.mcp_servers["test"] {
            MCPServer::Stdio { command, args, env } => {
                assert_eq!(command, "cmd");
                assert_eq!(args, &vec!["arg1".to_string(), "arg2".to_string()]);
                assert!(env.is_none());
            }
            MCPServer::Http { .. } => panic!("Expected Stdio server"),
        }
    }

    #[test]
    fn test_mcp_config_serialization_stdio_with_env() {
        let mut env = HashMap::new();
        env.insert("API_KEY".to_string(), "secret".to_string());
        let mut servers = HashMap::new();
        servers.insert(
            "env-server".to_string(),
            MCPServer::stdio_with_env("npx", vec!["-y".to_string()], env),
        );
        let mcp_config = MCPConfig {
            mcp_servers: servers,
        };
        let json = serde_json::to_string(&mcp_config).unwrap();
        assert!(json.contains(r#""env":{"API_KEY":"secret"}"#));
        let deserialized: MCPConfig = serde_json::from_str(&json).unwrap();
        match &deserialized.mcp_servers["env-server"] {
            MCPServer::Stdio { command, args, env } => {
                assert_eq!(command, "npx");
                assert_eq!(args, &vec!["-y".to_string()]);
                assert_eq!(env.as_ref().unwrap()["API_KEY"], "secret");
            }
            MCPServer::Http { .. } => panic!("Expected Stdio server"),
        }
    }

    #[test]
    fn test_mcp_config_serialization_http() {
        let mut servers = HashMap::new();
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        servers.insert(
            "http-server".to_string(),
            MCPServer::http_with_headers("https://example.com/mcp", headers),
        );
        let mcp_config = MCPConfig {
            mcp_servers: servers,
        };
        let json = serde_json::to_string(&mcp_config).unwrap();
        assert!(json.contains(r#""type":"http""#));
        let deserialized: MCPConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.mcp_servers.len(), 1);
        assert!(deserialized.mcp_servers.contains_key("http-server"));
        match &deserialized.mcp_servers["http-server"] {
            MCPServer::Http { url, headers } => {
                assert_eq!(url, "https://example.com/mcp");
                assert!(headers.is_some());
                assert_eq!(headers.as_ref().unwrap()["Authorization"], "Bearer token");
            }
            MCPServer::Stdio { .. } => panic!("Expected Http server"),
        }
    }

    #[test]
    fn test_mcp_server_http_without_headers_omits_key() {
        let json = serde_json::to_string(&MCPServer::http("https://example.com")).unwrap();
        assert!(json.contains(r#""type":"http""#));
        assert!(!json.contains("headers"));
    }

    #[test]
    fn test_mcp_config_mixed_servers() {
        let mut servers = HashMap::new();
        servers.insert(
            "stdio-server".to_string(),
            MCPServer::stdio("node", vec!["server.js".to_string()]),
        );
        servers.insert(
            "http-server".to_string(),
            MCPServer::http("https://api.example.com/mcp"),
        );
        let mcp_config = MCPConfig {
            mcp_servers: servers,
        };
        let json = serde_json::to_string(&mcp_config).unwrap();
        let deserialized: MCPConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.mcp_servers.len(), 2);
        assert!(matches!(
            &deserialized.mcp_servers["stdio-server"],
            MCPServer::Stdio { .. }
        ));
        assert!(matches!(
            &deserialized.mcp_servers["http-server"],
            MCPServer::Http { .. }
        ));
    }
}