use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use bon::Builder;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
pub struct SessionConfig<C> {
#[builder(into)]
pub model: Option<String>,
pub cwd: Option<PathBuf>,
#[builder(default)]
pub additional_dirs: Vec<PathBuf>,
#[builder(default)]
pub env: HashMap<String, String>,
#[builder(default)]
pub env_remove: HashSet<String>,
pub max_turns: Option<u32>,
#[builder(into)]
pub system_prompt: Option<String>,
pub backend: C,
}
impl<C> SessionConfig<C> {
    /// Returns this configuration with its backend replaced by `backend`.
    ///
    /// All other fields are carried over unchanged; the previous backend
    /// value is dropped.
    #[must_use]
    pub fn with_backend<D>(self, backend: D) -> SessionConfig<D> {
        self.map_backend(|_| backend)
    }

    /// Transforms the backend with `f`, keeping every other field unchanged.
    ///
    /// This is the infallible counterpart of [`Self::try_map_backend`].
    #[must_use]
    pub fn map_backend<D>(self, f: impl FnOnce(C) -> D) -> SessionConfig<D> {
        // Delegate with an uninhabited error type so the list of forwarded
        // fields is spelled out in exactly one place (`try_map_backend`).
        match self.try_map_backend(|backend| Ok::<D, std::convert::Infallible>(f(backend))) {
            Ok(config) => config,
            Err(never) => match never {},
        }
    }

    /// Transforms the backend with the fallible `f`, keeping every other
    /// field unchanged.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f` unchanged; the remaining fields of
    /// `self` are dropped in that case.
    pub fn try_map_backend<D, E>(
        self,
        f: impl FnOnce(C) -> std::result::Result<D, E>,
    ) -> std::result::Result<SessionConfig<D>, E> {
        // Exhaustive destructuring (no `..`) turns any new field on
        // `SessionConfig` into a compile error here until it is forwarded.
        let SessionConfig {
            backend,
            model,
            cwd,
            additional_dirs,
            env,
            env_remove,
            max_turns,
            system_prompt,
        } = self;
        Ok(SessionConfig {
            backend: f(backend)?,
            model,
            cwd,
            additional_dirs,
            env,
            env_remove,
            max_turns,
            system_prompt,
        })
    }
}
/// Options that apply to a single turn.
///
/// Every field is optional. Fields missing from serialized input are filled
/// from `TurnOptions::default()`, which leaves them `None`.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize)]
#[serde(default)]
pub struct TurnOptions {
    /// Output schema for the turn, if any.
    /// NOTE(review): presumably a JSON Schema document — confirm with consumers.
    pub output_schema: Option<serde_json::Value>,
    /// Time limit for the turn, if any.
    pub timeout: Option<std::time::Duration>,
}
/// Prompt input: either a single string or a list of parts.
///
/// Plain text converts in via `From<String>` / `From<&str>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Input {
    /// A single plain-text prompt.
    Text(String),
    /// An ordered sequence of [`InputPart`]s.
    Structured(Vec<InputPart>),
}
/// One element of an [`Input::Structured`] prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum InputPart {
    /// A text fragment.
    Text(String),
    /// An image, referenced by its filesystem path.
    Image(PathBuf),
}
impl From<String> for Input {
fn from(s: String) -> Self {
Input::Text(s)
}
}
impl From<&str> for Input {
fn from(s: &str) -> Self {
Input::Text(s.to_owned())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the payload of an `Input::Text`, failing the test on any other variant.
    fn text_of(input: Input) -> String {
        if let Input::Text(text) = input {
            text
        } else {
            panic!("expected Text variant")
        }
    }

    #[test]
    fn session_config_builder() {
        let built = SessionConfig::builder()
            .model("gpt-4")
            .cwd(PathBuf::from("/tmp"))
            .backend("test-backend")
            .build();
        assert_eq!(built.backend, "test-backend");
        assert_eq!(built.model, Some(String::from("gpt-4")));
        assert_eq!(built.cwd, Some(PathBuf::from("/tmp")));
        assert!(built.additional_dirs.is_empty());
        assert!(built.env.is_empty());
    }

    #[test]
    fn with_backend_preserves_shared_fields() {
        let original = SessionConfig::builder()
            .model("gpt-4")
            .cwd(PathBuf::from("/tmp"))
            .max_turns(5)
            .backend("original")
            .build();
        let swapped: SessionConfig<u32> = original.with_backend(42);
        assert_eq!(swapped.backend, 42);
        assert_eq!(swapped.model, Some(String::from("gpt-4")));
        assert_eq!(swapped.cwd, Some(PathBuf::from("/tmp")));
        assert_eq!(swapped.max_turns, Some(5));
    }

    #[test]
    fn map_backend_transforms() {
        let mapped = SessionConfig::builder()
            .model("gpt-4")
            .backend("hello")
            .build()
            .map_backend(|s: &str| s.len());
        assert_eq!(mapped.backend, 5);
        assert_eq!(mapped.model.as_deref(), Some("gpt-4"));
    }

    #[test]
    fn try_map_backend_ok() {
        let vars = HashMap::from([(String::from("KEY"), String::from("VAL"))]);
        let source = SessionConfig::builder()
            .model("gpt-4")
            .cwd(PathBuf::from("/tmp"))
            .additional_dirs(vec![PathBuf::from("/extra")])
            .env(vars)
            .max_turns(10)
            .system_prompt("be helpful")
            .backend(42u32)
            .build();
        let outcome: std::result::Result<SessionConfig<String>, &str> =
            source.try_map_backend(|n| Ok(n.to_string()));
        let converted = outcome.unwrap();
        assert_eq!(converted.backend, "42");
        assert_eq!(converted.model, Some(String::from("gpt-4")));
        assert_eq!(converted.cwd, Some(PathBuf::from("/tmp")));
        assert_eq!(converted.additional_dirs, vec![PathBuf::from("/extra")]);
        assert_eq!(converted.env.get("KEY").map(String::as_str), Some("VAL"));
        assert_eq!(converted.max_turns, Some(10));
        assert_eq!(converted.system_prompt, Some(String::from("be helpful")));
    }

    #[test]
    fn try_map_backend_err() {
        let source = SessionConfig::builder().backend(42u32).build();
        let outcome: std::result::Result<SessionConfig<String>, &str> =
            source.try_map_backend(|_| Err("mismatch"));
        assert_eq!(outcome.unwrap_err(), "mismatch");
    }

    #[test]
    fn env_remove_preserved_through_map_backend() {
        let source = SessionConfig::builder()
            .env_remove(HashSet::from([String::from("SECRET")]))
            .backend("original")
            .build();
        let mapped = source.map_backend(|s: &str| s.len());
        assert_eq!(mapped.backend, 8);
        assert!(mapped.env_remove.contains("SECRET"));
    }

    #[test]
    fn env_remove_preserved_through_try_map_backend() {
        let source = SessionConfig::builder()
            .env_remove(HashSet::from([String::from("KEY")]))
            .backend(42u32)
            .build();
        let outcome: std::result::Result<SessionConfig<String>, &str> =
            source.try_map_backend(|n| Ok(n.to_string()));
        assert!(outcome.unwrap().env_remove.contains("KEY"));
    }

    #[test]
    fn input_from_str() {
        assert_eq!(text_of(Input::from("hello")), "hello");
    }

    #[test]
    fn input_from_string() {
        assert_eq!(text_of(Input::from(String::from("hello"))), "hello");
    }
}