use std::sync::Arc;
use crate::Claude;
use crate::command::query::QueryCommand;
use crate::error::Result;
use crate::types::QueryResult;
#[cfg(feature = "json")]
use crate::streaming::{StreamEvent, stream_query};
/// A stateful conversation driven through a shared [`Claude`] client.
///
/// Remembers the session id so that later queries resume the same conversation,
/// and accumulates the reported cost, turn counts and every recorded result.
#[derive(Debug, Clone)]
pub struct Session {
/// Shared client used to run every query issued through this session.
claude: Arc<Claude>,
/// Id forwarded to each query (via `replace_session`); `None` until one is known.
session_id: Option<String>,
/// Recorded query results, oldest first.
history: Vec<QueryResult>,
/// Running sum of the `cost_usd` values of recorded results.
cumulative_cost_usd: f64,
/// Running sum of the `num_turns` values of recorded results.
cumulative_turns: u32,
}
impl Session {
    /// Creates a fresh session with no session id.
    ///
    /// The id is adopted from the first recorded query result. After that, every
    /// query sent through this session resumes that conversation.
    pub fn new(claude: Arc<Claude>) -> Self {
        Self {
            claude,
            session_id: None,
            history: Vec::new(),
            cumulative_cost_usd: 0.0,
            cumulative_turns: 0,
        }
    }

    /// Creates a session that resumes the existing conversation `session_id`.
    ///
    /// Cost, turn and history tracking start empty. They only reflect queries
    /// sent through this `Session` value, not earlier activity in the resumed
    /// conversation.
    pub fn resume(claude: Arc<Claude>, session_id: impl Into<String>) -> Self {
        Self {
            claude,
            session_id: Some(session_id.into()),
            history: Vec::new(),
            cumulative_cost_usd: 0.0,
            cumulative_turns: 0,
        }
    }

    /// Sends `prompt` as the next turn of this session and waits for its result.
    ///
    /// Shorthand for [`Session::execute`] with a default [`QueryCommand`].
    ///
    /// # Errors
    ///
    /// Same as [`Session::execute`].
    #[cfg(feature = "json")]
    pub async fn send(&mut self, prompt: impl Into<String>) -> Result<QueryResult> {
        self.execute(QueryCommand::new(prompt)).await
    }

    /// Runs a pre-built `cmd` within this session and records its result.
    ///
    /// If the session already has an id, `cmd` is rewritten with
    /// `replace_session`. This resumes that id and drops any conflicting
    /// `--continue` / `--session-id` / `--fork-session` flags. On success, the
    /// result is recorded: its session id is adopted, its cost and turns are added
    /// to the totals, and it is appended to the history.
    ///
    /// # Errors
    ///
    /// Propagates any error from `QueryCommand::execute_json`. The session is
    /// left unchanged in that case.
    #[cfg(feature = "json")]
    pub async fn execute(&mut self, cmd: QueryCommand) -> Result<QueryResult> {
        let cmd = match &self.session_id {
            Some(id) => cmd.replace_session(id),
            None => cmd,
        };
        let result = cmd.execute_json(&self.claude).await?;
        self.record(&result);
        Ok(result)
    }

    /// Sends `prompt` as the next turn and streams its events to `handler`.
    ///
    /// Shorthand for [`Session::stream_execute`] with a default [`QueryCommand`].
    ///
    /// # Errors
    ///
    /// Same as [`Session::stream_execute`].
    #[cfg(feature = "json")]
    pub async fn stream<F>(&mut self, prompt: impl Into<String>, handler: F) -> Result<()>
    where
        F: FnMut(StreamEvent),
    {
        self.stream_execute(QueryCommand::new(prompt), handler).await
    }

    /// Runs `cmd` within this session with stream-json output, forwarding every
    /// event to `handler`.
    ///
    /// The session id is applied as in [`Session::execute`], and the output
    /// format is set to `OutputFormat::StreamJson`. While events are forwarded,
    /// two things are captured:
    /// - the first session id carried by any event;
    /// - the first result event whose payload deserializes into a [`QueryResult`].
    ///
    /// Once the stream finishes, the captured id is adopted and the captured
    /// result is recorded. A result event that fails to deserialize is still
    /// forwarded to `handler`, but it is not recorded.
    ///
    /// # Errors
    ///
    /// Returns the error from `stream_query`, if any. Any session id or result
    /// observed before the failure is still applied to the session.
    #[cfg(feature = "json")]
    pub async fn stream_execute<F>(&mut self, cmd: QueryCommand, mut handler: F) -> Result<()>
    where
        F: FnMut(StreamEvent),
    {
        use crate::types::OutputFormat;

        let cmd = match &self.session_id {
            Some(id) => cmd.replace_session(id),
            None => cmd,
        }
        .output_format(OutputFormat::StreamJson);

        let mut captured_session_id: Option<String> = None;
        let mut captured_result: Option<QueryResult> = None;
        // Scope the closure so its mutable borrows of the capture slots end before
        // the captured values are folded back into `self`.
        let outcome = {
            let wrap = |event: StreamEvent| {
                if captured_session_id.is_none()
                    && let Some(sid) = event.session_id()
                {
                    captured_session_id = Some(sid.to_string());
                }
                if event.is_result()
                    && captured_result.is_none()
                    && let Ok(qr) = serde_json::from_value::<QueryResult>(event.data.clone())
                {
                    captured_result = Some(qr);
                }
                handler(event);
            };
            stream_query(&self.claude, &cmd, wrap).await
        };

        // Apply what was observed even if the stream ultimately failed.
        if let Some(sid) = captured_session_id {
            self.session_id = Some(sid);
        }
        if let Some(qr) = captured_result {
            self.record(&qr);
        }
        outcome.map(|_| ())
    }

    /// Returns the current session id, or `None` if none is known yet.
    pub fn id(&self) -> Option<&str> {
        self.session_id.as_deref()
    }

    /// Returns the sum of `cost_usd` over all recorded results.
    ///
    /// Results without a cost contribute zero.
    pub fn total_cost_usd(&self) -> f64 {
        self.cumulative_cost_usd
    }

    /// Returns the sum of `num_turns` over all recorded results.
    ///
    /// Results without a turn count contribute zero. The total saturates at
    /// `u32::MAX`.
    pub fn total_turns(&self) -> u32 {
        self.cumulative_turns
    }

    /// Returns every recorded result, oldest first.
    pub fn history(&self) -> &[QueryResult] {
        &self.history
    }

    /// Returns the most recently recorded result, if any.
    pub fn last_result(&self) -> Option<&QueryResult> {
        self.history.last()
    }

    /// Folds a completed query result into the session's running state.
    ///
    /// Adopts the result's session id, adds its cost and turn count to the totals
    /// (missing values count as zero), and appends a copy to the history.
    fn record(&mut self, result: &QueryResult) {
        self.session_id = Some(result.session_id.clone());
        self.cumulative_cost_usd += result.cost_usd.unwrap_or(0.0);
        // A plain `+=` on u32 panics on overflow in debug builds and wraps in
        // release; saturate so a long-lived session can never do either.
        self.cumulative_turns = self
            .cumulative_turns
            .saturating_add(result.num_turns.unwrap_or(0));
        self.history.push(result.clone());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Client pointed at a fixed binary path; none of these tests spawn it.
    fn test_claude() -> Arc<Claude> {
        let client = Claude::builder()
            .binary("/usr/local/bin/claude")
            .build()
            .unwrap();
        Arc::new(client)
    }

    /// Builds a non-error `QueryResult` with the given text, id, cost and turns.
    fn make_result(
        text: &'static str,
        sid: &'static str,
        cost: f64,
        turns: u32,
    ) -> QueryResult {
        QueryResult {
            result: text.into(),
            session_id: sid.into(),
            cost_usd: Some(cost),
            duration_ms: None,
            num_turns: Some(turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    #[test]
    fn new_session_has_no_id() {
        let fresh = Session::new(test_claude());
        assert!(fresh.id().is_none());
        assert_eq!(fresh.total_cost_usd(), 0.0);
        assert_eq!(fresh.total_turns(), 0);
        assert!(fresh.history().is_empty());
        assert!(fresh.last_result().is_none());
    }

    #[test]
    fn resume_session_has_preset_id() {
        let resumed = Session::resume(test_claude(), "sess-abc");
        assert_eq!(resumed.id(), Some("sess-abc"));
        assert_eq!(resumed.total_cost_usd(), 0.0);
        assert_eq!(resumed.total_turns(), 0);
    }

    #[test]
    fn record_updates_state() {
        let mut session = Session::new(test_claude());
        let outcome = make_result("ok", "sess-1", 0.05, 3);
        session.record(&outcome);

        assert_eq!(session.id(), Some("sess-1"));
        assert!((session.total_cost_usd() - 0.05).abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 3);
        assert_eq!(session.history().len(), 1);

        let last_sid = session.last_result().map(|r| r.session_id.as_str());
        assert_eq!(last_sid, Some("sess-1"));
    }

    #[test]
    fn record_accumulates_across_turns() {
        let mut session = Session::new(test_claude());
        let turns = [
            make_result("a", "sess-1", 0.01, 2),
            make_result("b", "sess-1", 0.02, 1),
        ];
        for outcome in &turns {
            session.record(outcome);
        }

        assert_eq!(session.total_turns(), 3);
        assert!((session.total_cost_usd() - 0.03).abs() < f64::EPSILON);
        assert_eq!(session.history().len(), 2);
    }

    #[test]
    fn replace_session_clears_conflicting_flags() {
        use crate::command::ClaudeCommand;

        let cmd = QueryCommand::new("hi")
            .continue_session()
            .session_id("old")
            .fork_session()
            .replace_session("new-id");
        let args = cmd.args();
        let has = |flag: &str| args.iter().any(|a| a.as_str() == flag);

        // The new id is passed via --resume ...
        assert!(has("--resume"));
        assert!(has("new-id"));
        // ... and every competing session-selection flag is gone.
        assert!(!has("--continue"));
        assert!(!has("--session-id"));
        assert!(!has("--fork-session"));
    }
}