use crate::Claude;
use crate::command::query::QueryCommand;
use crate::error::Result;
use crate::types::{Effort, InputFormat, OutputFormat, PermissionMode, QueryResult};
/// A handle to one Claude CLI conversation that keeps running totals.
///
/// It borrows a `Claude` client, remembers the session id used to resume the
/// conversation, and tracks the cost and turn count accumulated through this
/// handle.
#[derive(Debug)]
pub struct Session<'c> {
    /// Client used to run every query issued through this handle.
    claude: &'c Claude,
    /// Id handed to `QueryCommand::resume` for follow-up queries.
    session_id: String,
    /// Running total of query cost in USD, as tracked by this handle.
    cumulative_cost_usd: f64,
    /// Running total of conversation turns, as tracked by this handle.
    cumulative_turns: u32,
}
impl<'a> Session<'a> {
    /// Creates a session handle from the result of a previous query.
    ///
    /// The session id is copied from `result`. The running totals are seeded
    /// with that result's cost and turn count; missing values count as zero.
    pub fn from_result(claude: &'a Claude, result: &QueryResult) -> Self {
        Self {
            claude,
            session_id: result.session_id.clone(),
            cumulative_cost_usd: result.cost_usd.unwrap_or(0.0),
            cumulative_turns: result.num_turns.unwrap_or(0),
        }
    }

    /// Creates a session handle for an already-known session id.
    ///
    /// Cost and turn totals start at zero, because nothing about earlier
    /// queries in that session is available here.
    pub fn from_id(claude: &'a Claude, session_id: impl Into<String>) -> Self {
        Self {
            claude,
            session_id: session_id.into(),
            cumulative_cost_usd: 0.0,
            cumulative_turns: 0,
        }
    }

    /// Runs `prompt` with the continue-session option
    /// (`QueryCommand::continue_session`).
    ///
    /// Returns a handle to the session the query ran in, together with the
    /// query result. The handle is seeded exactly as [`Session::from_result`]
    /// would seed it.
    ///
    /// # Errors
    ///
    /// Returns any error produced while executing the underlying
    /// `QueryCommand`.
    pub async fn continue_recent(
        claude: &'a Claude,
        prompt: impl Into<String>,
    ) -> Result<(Self, QueryResult)> {
        let result = QueryCommand::new(prompt)
            .continue_session()
            .execute_json(claude)
            .await?;
        // Reuse `from_result` so the seeding rules live in one place.
        let session = Self::from_result(claude, &result);
        Ok((session, result))
    }

    /// Starts building a follow-up query that resumes this session.
    ///
    /// The session is borrowed mutably so that `SessionQuery::execute` can
    /// update its id and totals.
    pub fn query(&mut self, prompt: impl Into<String>) -> SessionQuery<'_, 'a> {
        SessionQuery::new(self, prompt)
    }

    /// Runs `prompt` against this session with `fork_session` enabled.
    ///
    /// Returns a new handle for the session id reported by the result, along
    /// with the result itself. The new handle's totals are this session's
    /// totals plus the fork query's cost and turns. `self` is not modified.
    ///
    /// # Errors
    ///
    /// Returns any error produced while executing the underlying
    /// `QueryCommand`.
    pub async fn fork(&self, prompt: impl Into<String>) -> Result<(Session<'a>, QueryResult)> {
        let result = QueryCommand::new(prompt)
            .resume(&self.session_id)
            .fork_session()
            .execute_json(self.claude)
            .await?;
        let forked = Session {
            claude: self.claude,
            session_id: result.session_id.clone(),
            cumulative_cost_usd: self.cumulative_cost_usd + result.cost_usd.unwrap_or(0.0),
            // Saturate instead of overflowing: `+` would panic in debug builds
            // and silently wrap in release builds.
            cumulative_turns: self
                .cumulative_turns
                .saturating_add(result.num_turns.unwrap_or(0)),
        };
        Ok((forked, result))
    }

    /// Returns the session id that follow-up queries will resume.
    pub fn id(&self) -> &str {
        &self.session_id
    }

    /// Returns the cost in USD accumulated through this handle.
    pub fn total_cost_usd(&self) -> f64 {
        self.cumulative_cost_usd
    }

    /// Returns the number of turns accumulated through this handle.
    pub fn total_turns(&self) -> u32 {
        self.cumulative_turns
    }
}
/// A pending follow-up query against a [`Session`].
///
/// Options are configured with chained builder calls before the query is run.
#[derive(Debug)]
pub struct SessionQuery<'sess, 'c> {
    /// Session being resumed; mutably borrowed so it can be updated afterwards.
    session: &'sess mut Session<'c>,
    /// The underlying CLI command being configured.
    command: QueryCommand,
}
impl<'s, 'a> SessionQuery<'s, 'a> {
    /// Builds a query for `prompt` that resumes `session`'s current id.
    fn new(session: &'s mut Session<'a>, prompt: impl Into<String>) -> Self {
        let command = QueryCommand::new(prompt).resume(&session.session_id);
        Self { session, command }
    }

    /// Sets the model for this query (forwards to `QueryCommand::model`).
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.command = self.command.model(model);
        self
    }

    /// Sets the system prompt (forwards to `QueryCommand::system_prompt`).
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.command = self.command.system_prompt(prompt);
        self
    }

    /// Adds appended system-prompt text (forwards to
    /// `QueryCommand::append_system_prompt`).
    #[must_use]
    pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.command = self.command.append_system_prompt(prompt);
        self
    }

    /// Sets the output format (forwards to `QueryCommand::output_format`).
    #[must_use]
    pub fn output_format(mut self, format: OutputFormat) -> Self {
        self.command = self.command.output_format(format);
        self
    }

    /// Sets a spending cap in USD (forwards to `QueryCommand::max_budget_usd`).
    #[must_use]
    pub fn max_budget_usd(mut self, budget: f64) -> Self {
        self.command = self.command.max_budget_usd(budget);
        self
    }

    /// Sets the permission mode (forwards to `QueryCommand::permission_mode`).
    #[must_use]
    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
        self.command = self.command.permission_mode(mode);
        self
    }

    /// Adds allowed tools (forwards to `QueryCommand::allowed_tools`).
    #[must_use]
    pub fn allowed_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.command = self.command.allowed_tools(tools);
        self
    }

    /// Adds one allowed tool (forwards to `QueryCommand::allowed_tool`).
    #[must_use]
    pub fn allowed_tool(mut self, tool: impl Into<String>) -> Self {
        self.command = self.command.allowed_tool(tool);
        self
    }

    /// Adds disallowed tools (forwards to `QueryCommand::disallowed_tools`).
    #[must_use]
    pub fn disallowed_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.command = self.command.disallowed_tools(tools);
        self
    }

    /// Sets an MCP config path (forwards to `QueryCommand::mcp_config`).
    #[must_use]
    pub fn mcp_config(mut self, path: impl Into<String>) -> Self {
        self.command = self.command.mcp_config(path);
        self
    }

    /// Adds an extra directory (forwards to `QueryCommand::add_dir`).
    #[must_use]
    pub fn add_dir(mut self, dir: impl Into<String>) -> Self {
        self.command = self.command.add_dir(dir);
        self
    }

    /// Sets the effort level (forwards to `QueryCommand::effort`).
    #[must_use]
    pub fn effort(mut self, effort: Effort) -> Self {
        self.command = self.command.effort(effort);
        self
    }

    /// Caps the number of turns (forwards to `QueryCommand::max_turns`).
    #[must_use]
    pub fn max_turns(mut self, turns: u32) -> Self {
        self.command = self.command.max_turns(turns);
        self
    }

    /// Sets a JSON schema (forwards to `QueryCommand::json_schema`).
    #[must_use]
    pub fn json_schema(mut self, schema: impl Into<String>) -> Self {
        self.command = self.command.json_schema(schema);
        self
    }

    /// Sets a fallback model (forwards to `QueryCommand::fallback_model`).
    #[must_use]
    pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
        self.command = self.command.fallback_model(model);
        self
    }

    /// Forwards to `QueryCommand::no_session_persistence`.
    #[must_use]
    pub fn no_session_persistence(mut self) -> Self {
        self.command = self.command.no_session_persistence();
        self
    }

    /// Forwards to `QueryCommand::dangerously_skip_permissions`.
    #[must_use]
    pub fn dangerously_skip_permissions(mut self) -> Self {
        self.command = self.command.dangerously_skip_permissions();
        self
    }

    /// Selects an agent (forwards to `QueryCommand::agent`).
    #[must_use]
    pub fn agent(mut self, agent: impl Into<String>) -> Self {
        self.command = self.command.agent(agent);
        self
    }

    /// Supplies agent definitions as JSON (forwards to
    /// `QueryCommand::agents_json`).
    #[must_use]
    pub fn agents_json(mut self, json: impl Into<String>) -> Self {
        self.command = self.command.agents_json(json);
        self
    }

    /// Sets the tool list (forwards to `QueryCommand::tools`).
    #[must_use]
    pub fn tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.command = self.command.tools(tools);
        self
    }

    /// Adds a file spec (forwards to `QueryCommand::file`).
    #[must_use]
    pub fn file(mut self, spec: impl Into<String>) -> Self {
        self.command = self.command.file(spec);
        self
    }

    /// Forwards to `QueryCommand::include_partial_messages`.
    #[must_use]
    pub fn include_partial_messages(mut self) -> Self {
        self.command = self.command.include_partial_messages();
        self
    }

    /// Sets the input format (forwards to `QueryCommand::input_format`).
    #[must_use]
    pub fn input_format(mut self, format: InputFormat) -> Self {
        self.command = self.command.input_format(format);
        self
    }

    /// Forwards to `QueryCommand::strict_mcp_config`.
    #[must_use]
    pub fn strict_mcp_config(mut self) -> Self {
        self.command = self.command.strict_mcp_config();
        self
    }

    /// Sets settings (forwards to `QueryCommand::settings`).
    #[must_use]
    pub fn settings(mut self, settings: impl Into<String>) -> Self {
        self.command = self.command.settings(settings);
        self
    }

    /// Sets the retry policy (forwards to `QueryCommand::retry`).
    #[must_use]
    pub fn retry(mut self, policy: crate::retry::RetryPolicy) -> Self {
        self.command = self.command.retry(policy);
        self
    }

    /// Executes the query and folds its outcome into the parent session.
    ///
    /// On success, the session's cost and turn totals grow by the result's
    /// values (missing values count as zero). The session id is replaced with
    /// the id reported by the result, so the next query resumes from there.
    /// If the reported id is empty, the previous id is kept. If execution
    /// fails, the session is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `QueryCommand::execute_json`.
    pub async fn execute(self) -> Result<QueryResult> {
        let result = self.command.execute_json(self.session.claude).await?;
        self.session.cumulative_cost_usd += result.cost_usd.unwrap_or(0.0);
        // Saturate instead of overflowing: `+=` would panic in debug builds and
        // silently wrap in release builds.
        self.session.cumulative_turns = self
            .session
            .cumulative_turns
            .saturating_add(result.num_turns.unwrap_or(0));
        // An empty id cannot be resumed, so keep the last usable one rather
        // than breaking every subsequent query.
        if !result.session_id.is_empty() {
            self.session.session_id.clone_from(&result.session_id);
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ClaudeCommand;

    /// Builds a `Claude` client with a fixed binary path for argument-building tests.
    fn test_claude() -> Claude {
        Claude::builder()
            .binary("/usr/local/bin/claude")
            .build()
            .unwrap()
    }

    /// Builds a successful `QueryResult` with the given id, cost and turn count.
    fn test_result(session_id: &str, cost: f64, turns: u32) -> QueryResult {
        QueryResult {
            result: "test".into(),
            session_id: session_id.into(),
            cost_usd: Some(cost),
            duration_ms: None,
            num_turns: Some(turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    /// Returns the argument immediately following `flag`, if `flag` is present.
    ///
    /// Checking the pairing (rather than mere presence of both strings) ensures
    /// a value is attached to the intended flag.
    fn flag_value<'v>(args: &'v [String], flag: &str) -> Option<&'v str> {
        args.iter()
            .position(|arg| arg == flag)
            .and_then(|i| args.get(i + 1))
            .map(String::as_str)
    }

    #[test]
    fn session_from_result_captures_state() {
        let claude = test_claude();
        let result = test_result("sess-abc", 0.05, 3);
        let session = Session::from_result(&claude, &result);
        assert_eq!(session.id(), "sess-abc");
        assert!((session.total_cost_usd() - 0.05).abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 3);
    }

    #[test]
    fn session_from_id_starts_clean() {
        let claude = test_claude();
        let session = Session::from_id(&claude, "sess-xyz");
        assert_eq!(session.id(), "sess-xyz");
        assert!(session.total_cost_usd().abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 0);
    }

    #[test]
    fn session_from_result_handles_none_cost_and_turns() {
        let claude = test_claude();
        let result = QueryResult {
            result: "ok".into(),
            session_id: "s1".into(),
            cost_usd: None,
            duration_ms: None,
            num_turns: None,
            is_error: false,
            extra: Default::default(),
        };
        let session = Session::from_result(&claude, &result);
        assert!(session.total_cost_usd().abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 0);
    }

    #[test]
    fn session_query_sets_resume_flag() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up");
        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--resume"), Some("sess-123"));
    }

    #[test]
    fn session_query_model_delegation() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up").model("sonnet");
        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--model"), Some("sonnet"));
    }

    #[test]
    fn session_query_effort_delegation() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up").effort(Effort::High);
        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--effort"), Some("high"));
    }

    #[test]
    fn session_query_max_turns_delegation() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up").max_turns(10);
        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--max-turns"), Some("10"));
    }

    #[test]
    fn session_query_prompt_is_last_arg() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("my prompt");
        let args = sq.command.args();
        assert_eq!(args.last().unwrap(), "my prompt");
    }

    #[test]
    fn session_query_does_not_have_continue_or_fork() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("test");
        let args = sq.command.args();
        assert!(!args.contains(&"--continue".to_string()));
        assert!(!args.contains(&"--fork-session".to_string()));
        assert!(!args.contains(&"--session-id".to_string()));
    }
}