Skip to main content

claude_wrapper/
session.rs

1//! Multi-turn session management.
2//!
3//! A [`Session`] threads Claude's `session_id` across turns automatically,
4//! so callers never need to scrape it out of a result event or pass
5//! `--resume` by hand.
6//!
7//! # Ownership
8//!
9//! [`Session`] holds an `Arc<Claude>`. Wrapping the client in an `Arc`
10//! means a session can outlive the original client binding, be moved
11//! between tasks, and sit inside long-lived actor state -- which is the
12//! usage shape that callers like centralino-rs need. One `Arc::clone`
13//! per session is a negligible cost in exchange.
14//!
15//! # Two entry points
16//!
17//! - [`Session::send`] takes a plain prompt. Use this for straightforward
18//!   multi-turn chat.
19//! - [`Session::execute`] takes a fully-configured [`QueryCommand`]. Use
20//!   this when you want per-turn options like `model`, `max_turns`,
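//!   `permission_mode`, etc. Once the session has an id (after the first
//!   turn, or immediately for [`Session::resume`]), any session-related
//!   flags on the command (`--resume`, `--continue`, `--session-id`,
//!   `--fork-session`) are overridden with that id so they can't conflict.
//!   On the first turn of a [`Session::new`] session the command is sent
//!   as-is.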
24//!
25//! Streaming follows the same split: [`Session::stream`] and
26//! [`Session::stream_execute`].
27//!
28//! # Example
29//!
30//! ```no_run
31//! use std::sync::Arc;
32//! use claude_wrapper::{Claude, QueryCommand};
33//! use claude_wrapper::session::Session;
34//!
35//! # async fn example() -> claude_wrapper::Result<()> {
36//! let claude = Arc::new(Claude::builder().build()?);
37//!
38//! let mut session = Session::new(Arc::clone(&claude));
39//!
40//! // Simple path
41//! let first = session.send("explain quicksort").await?;
42//!
43//! // Full control: custom model, effort, permission mode, etc.
44//! let second = session
45//!     .execute(QueryCommand::new("now mergesort").model("opus"))
46//!     .await?;
47//!
48//! println!("total cost: ${:.4}", session.total_cost_usd());
49//! println!("turns: {}", session.total_turns());
50//! # Ok(())
51//! # }
52//! ```
53//!
54//! # Resuming an existing session
55//!
56//! ```no_run
57//! # use std::sync::Arc;
58//! # use claude_wrapper::{Claude};
59//! # use claude_wrapper::session::Session;
60//! # async fn example() -> claude_wrapper::Result<()> {
61//! # let claude = Arc::new(Claude::builder().build()?);
62//! // Reattach to a session you stored earlier
63//! let mut session = Session::resume(claude, "sess-abc123");
64//! let result = session.send("pick up where we left off").await?;
65//! # Ok(())
66//! # }
67//! ```
68
69use std::sync::Arc;
70
71use crate::Claude;
72use crate::budget::BudgetTracker;
73use crate::command::query::QueryCommand;
74use crate::error::Result;
75use crate::types::QueryResult;
76
77#[cfg(feature = "json")]
78use crate::streaming::{StreamEvent, stream_query};
79
80/// A multi-turn conversation handle.
81///
82/// Owns an `Arc<Claude>` so it can be moved between tasks and live
83/// inside long-running actors. Tracks `session_id`, cumulative cost,
84/// turn count, and per-turn result history.
85#[derive(Debug, Clone)]
86pub struct Session {
87    claude: Arc<Claude>,
88    session_id: Option<String>,
89    history: Vec<QueryResult>,
90    cumulative_cost_usd: f64,
91    cumulative_turns: u32,
92    budget: Option<BudgetTracker>,
93}
94
95impl Session {
96    /// Start a fresh session. The first turn will discover a session id
97    /// from its result; subsequent turns reuse it via `--resume`.
98    pub fn new(claude: Arc<Claude>) -> Self {
99        Self {
100            claude,
101            session_id: None,
102            history: Vec::new(),
103            cumulative_cost_usd: 0.0,
104            cumulative_turns: 0,
105            budget: None,
106        }
107    }
108
109    /// Reattach to an existing session by id. The next turn immediately
110    /// passes `--resume <id>`. Cost and turn counters start at zero
111    /// since no history is available.
112    pub fn resume(claude: Arc<Claude>, session_id: impl Into<String>) -> Self {
113        Self {
114            claude,
115            session_id: Some(session_id.into()),
116            history: Vec::new(),
117            cumulative_cost_usd: 0.0,
118            cumulative_turns: 0,
119            budget: None,
120        }
121    }
122
123    /// Attach a [`BudgetTracker`] to this session. Every turn's cost
124    /// (from [`QueryResult::cost_usd`]) is recorded on the tracker, and
125    /// [`Session::execute`]/[`Session::stream_execute`] return
126    /// [`crate::error::Error::BudgetExceeded`]
127    /// before dispatching a turn if the tracker's ceiling has been hit.
128    ///
129    /// Clone a tracker across several sessions to enforce a shared
130    /// ceiling; each `Session` then sees the same running total.
131    pub fn with_budget(mut self, budget: BudgetTracker) -> Self {
132        self.budget = Some(budget);
133        self
134    }
135
136    /// The attached [`BudgetTracker`], if any.
137    pub fn budget(&self) -> Option<&BudgetTracker> {
138        self.budget.as_ref()
139    }
140
141    /// Send a plain-prompt turn. Equivalent to
142    /// `execute(QueryCommand::new(prompt))`.
143    #[cfg(feature = "json")]
144    pub async fn send(&mut self, prompt: impl Into<String>) -> Result<QueryResult> {
145        self.execute(QueryCommand::new(prompt)).await
146    }
147
148    /// Send a turn with a fully-configured [`QueryCommand`].
149    ///
150    /// Any session-related flags on `cmd` (`--resume`, `--continue`,
151    /// `--session-id`, `--fork-session`) are overridden with this
152    /// session's current id, so they can't conflict.
153    #[cfg(feature = "json")]
154    pub async fn execute(&mut self, cmd: QueryCommand) -> Result<QueryResult> {
155        if let Some(b) = &self.budget {
156            b.check()?;
157        }
158
159        let cmd = match &self.session_id {
160            Some(id) => cmd.replace_session(id),
161            None => cmd,
162        };
163
164        let result = cmd.execute_json(&self.claude).await?;
165        self.record(&result);
166        Ok(result)
167    }
168
169    /// Stream a plain-prompt turn, dispatching each NDJSON event to
170    /// `handler`. The session's id is captured from the first event
171    /// that carries one, so subsequent turns can resume, and the id
172    /// persists even if the stream errors partway through.
173    #[cfg(feature = "json")]
174    pub async fn stream<F>(&mut self, prompt: impl Into<String>, handler: F) -> Result<()>
175    where
176        F: FnMut(StreamEvent),
177    {
178        self.stream_execute(QueryCommand::new(prompt), handler)
179            .await
180    }
181
182    /// Stream a turn with a fully-configured [`QueryCommand`], with the
183    /// same session-id capture semantics as [`Session::stream`].
184    ///
185    /// The command's output format is forced to `stream-json` and any
186    /// session-related flags are overridden as in [`Session::execute`].
187    #[cfg(feature = "json")]
188    pub async fn stream_execute<F>(&mut self, cmd: QueryCommand, mut handler: F) -> Result<()>
189    where
190        F: FnMut(StreamEvent),
191    {
192        use crate::types::OutputFormat;
193
194        if let Some(b) = &self.budget {
195            b.check()?;
196        }
197
198        let cmd = match &self.session_id {
199            Some(id) => cmd.replace_session(id),
200            None => cmd,
201        }
202        .output_format(OutputFormat::StreamJson);
203
204        // Capture session_id and result state from events inside a
205        // wrapper closure. The captures happen before the caller's
206        // handler runs, and self is updated after the stream completes
207        // (even on error) so id persists across partial failures.
208        let mut captured_session_id: Option<String> = None;
209        let mut captured_result: Option<QueryResult> = None;
210
211        let outcome = {
212            let wrap = |event: StreamEvent| {
213                if captured_session_id.is_none()
214                    && let Some(sid) = event.session_id()
215                {
216                    captured_session_id = Some(sid.to_string());
217                }
218                if event.is_result()
219                    && captured_result.is_none()
220                    && let Ok(qr) = serde_json::from_value::<QueryResult>(event.data.clone())
221                {
222                    captured_result = Some(qr);
223                }
224                handler(event);
225            };
226            stream_query(&self.claude, &cmd, wrap).await
227        };
228
229        if let Some(sid) = captured_session_id {
230            self.session_id = Some(sid);
231        }
232        if let Some(qr) = captured_result {
233            self.record(&qr);
234        }
235
236        outcome.map(|_| ())
237    }
238
239    /// Current session id, if one has been established.
240    pub fn id(&self) -> Option<&str> {
241        self.session_id.as_deref()
242    }
243
244    /// Cumulative cost in USD across all turns in this session.
245    pub fn total_cost_usd(&self) -> f64 {
246        self.cumulative_cost_usd
247    }
248
249    /// Cumulative turn count across all turns in this session.
250    pub fn total_turns(&self) -> u32 {
251        self.cumulative_turns
252    }
253
254    /// Full per-turn result history.
255    pub fn history(&self) -> &[QueryResult] {
256        &self.history
257    }
258
259    /// Result of the most recent turn, if any.
260    pub fn last_result(&self) -> Option<&QueryResult> {
261        self.history.last()
262    }
263
264    fn record(&mut self, result: &QueryResult) {
265        self.session_id = Some(result.session_id.clone());
266        let cost = result.cost_usd.unwrap_or(0.0);
267        self.cumulative_cost_usd += cost;
268        self.cumulative_turns += result.num_turns.unwrap_or(0);
269        if let Some(b) = &self.budget {
270            b.record(cost);
271        }
272        self.history.push(result.clone());
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::budget::BudgetTracker;
    use crate::command::ClaudeCommand;
    use crate::error::Error;

    /// Client pointing at a fixed binary path; never actually spawned here.
    fn test_claude() -> Arc<Claude> {
        let claude = Claude::builder()
            .binary("/usr/local/bin/claude")
            .build()
            .expect("building a Claude client with an explicit binary should succeed");
        Arc::new(claude)
    }

    /// A successful `QueryResult` carrying only the fields these tests read.
    fn turn(text: &str, sid: &str, cost: f64, turns: u32) -> QueryResult {
        QueryResult {
            result: text.into(),
            session_id: sid.into(),
            cost_usd: Some(cost),
            duration_ms: None,
            num_turns: Some(turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    #[test]
    fn new_session_has_no_id() {
        let fresh = Session::new(test_claude());

        assert!(fresh.id().is_none());
        assert_eq!(fresh.total_cost_usd(), 0.0);
        assert_eq!(fresh.total_turns(), 0);
        assert!(fresh.history().is_empty());
        assert!(fresh.last_result().is_none());
    }

    #[test]
    fn resume_session_has_preset_id() {
        let resumed = Session::resume(test_claude(), "sess-abc");

        assert_eq!(resumed.id(), Some("sess-abc"));
        assert_eq!(resumed.total_cost_usd(), 0.0);
        assert_eq!(resumed.total_turns(), 0);
    }

    #[test]
    fn record_updates_state() {
        let mut s = Session::new(test_claude());
        s.record(&turn("ok", "sess-1", 0.05, 3));

        assert_eq!(s.id(), Some("sess-1"));
        assert!((s.total_cost_usd() - 0.05).abs() < f64::EPSILON);
        assert_eq!(s.total_turns(), 3);
        assert_eq!(s.history().len(), 1);

        let latest_sid = s.last_result().map(|r| r.session_id.as_str());
        assert_eq!(latest_sid, Some("sess-1"));
    }

    #[test]
    fn record_accumulates_across_turns() {
        let mut s = Session::new(test_claude());
        for r in [turn("a", "sess-1", 0.01, 2), turn("b", "sess-1", 0.02, 1)] {
            s.record(&r);
        }

        assert_eq!(s.total_turns(), 3);
        assert!((s.total_cost_usd() - 0.03).abs() < f64::EPSILON);
        assert_eq!(s.history().len(), 2);
    }

    #[test]
    fn record_forwards_cost_to_budget() {
        let tracker = BudgetTracker::builder().build();
        let mut s = Session::new(test_claude()).with_budget(tracker.clone());

        s.record(&turn("ok", "sess-1", 0.07, 1));

        assert!((tracker.total_usd() - 0.07).abs() < 1e-9);
        assert!((s.total_cost_usd() - 0.07).abs() < 1e-9);
    }

    #[test]
    fn budget_pre_check_would_block_next_turn() {
        // execute() defers its pre-check to BudgetTracker::check(); drive
        // that directly with an already-exceeded tracker so no live Claude
        // CLI is needed.
        let tracker = BudgetTracker::builder().max_usd(0.10).build();
        tracker.record(0.15);

        let s = Session::new(test_claude()).with_budget(tracker);
        let verdict = s.budget().unwrap().check();

        let Err(Error::BudgetExceeded { total_usd, max_usd }) = verdict else {
            panic!("expected BudgetExceeded, got {verdict:?}");
        };
        assert!((total_usd - 0.15).abs() < 1e-9);
        assert!((max_usd - 0.10).abs() < 1e-9);
    }

    #[test]
    fn replace_session_clears_conflicting_flags() {
        // replace_session must drop --continue / --session-id /
        // --fork-session and bind --resume to the supplied id.
        let args = QueryCommand::new("hi")
            .continue_session()
            .session_id("old")
            .fork_session()
            .replace_session("new-id")
            .args();
        let has = |needle: &str| args.iter().any(|a| a == needle);

        assert!(has("--resume"));
        assert!(has("new-id"));
        assert!(!has("--continue"));
        assert!(!has("--session-id"));
        assert!(!has("--fork-session"));
    }
}
420}