Skip to main content

claude_wrapper/
session.rs

1//! Multi-turn session management for short-lived processes.
2//!
3//! A [`Session`] threads Claude's `session_id` across turns automatically,
4//! so callers never need to scrape it out of a result event or pass
5//! `--resume` by hand.
6//!
7//! # When to use
8//!
9//! [`Session`] is the right fit when each turn can stand on its own
10//! and the host process is short-lived: CLIs, build scripts, batch
11//! jobs, lambdas. Each turn spawns a fresh `claude` subprocess and
12//! resumes the conversation via `--resume <session_id>`.
13//!
14//! For long-running hosts (IDE backends, daemons, agent servers,
15//! chat UIs) where holding a `claude` subprocess open across many
16//! turns is cheap, prefer [`DuplexSession`](crate::duplex::DuplexSession).
17//! It supports mid-turn interrupts, mid-turn permission decisions, and
18//! a broadcast event stream that [`Session`] cannot offer because of
19//! the transient subprocess model.
20//!
21//! # Ownership
22//!
23//! [`Session`] holds an `Arc<Claude>`. Wrapping the client in an `Arc`
24//! means a session can outlive the original client binding, be moved
25//! between tasks, and sit inside long-lived actor state -- which is the
26//! usage shape that callers like centralino-rs need. One `Arc::clone`
27//! per session is a negligible cost in exchange.
28//!
29//! # Two entry points
30//!
31//! - [`Session::send`] takes a plain prompt. Use this for straightforward
32//!   multi-turn chat.
33//! - [`Session::execute`] takes a fully-configured [`QueryCommand`]. Use
34//!   this when you want per-turn options like `model`, `max_turns`,
//!   `permission_mode`, etc. Once the session has an id, it overrides any
//!   session-related flags on the command (`--resume`, `--continue`,
//!   `--session-id`, `--fork-session`) so they can't conflict.
38//!
39//! Streaming follows the same split: [`Session::stream`] and
40//! [`Session::stream_execute`].
41//!
42//! # Example
43//!
44//! ```no_run
45//! use std::sync::Arc;
46//! use claude_wrapper::{Claude, QueryCommand};
47//! use claude_wrapper::session::Session;
48//!
49//! # async fn example() -> claude_wrapper::Result<()> {
50//! let claude = Arc::new(Claude::builder().build()?);
51//!
52//! let mut session = Session::new(Arc::clone(&claude));
53//!
54//! // Simple path
55//! let first = session.send("explain quicksort").await?;
56//!
57//! // Full control: custom model, effort, permission mode, etc.
58//! let second = session
59//!     .execute(QueryCommand::new("now mergesort").model("opus"))
60//!     .await?;
61//!
62//! println!("total cost: ${:.4}", session.total_cost_usd());
63//! println!("turns: {}", session.total_turns());
64//! # Ok(())
65//! # }
66//! ```
67//!
68//! # Resuming an existing session
69//!
70//! ```no_run
71//! # use std::sync::Arc;
72//! # use claude_wrapper::{Claude};
73//! # use claude_wrapper::session::Session;
74//! # async fn example() -> claude_wrapper::Result<()> {
75//! # let claude = Arc::new(Claude::builder().build()?);
76//! // Reattach to a session you stored earlier
77//! let mut session = Session::resume(claude, "sess-abc123");
78//! let result = session.send("pick up where we left off").await?;
79//! # Ok(())
80//! # }
81//! ```
82
83use std::sync::Arc;
84
85use crate::Claude;
86use crate::budget::BudgetTracker;
87use crate::command::query::QueryCommand;
88use crate::error::Result;
89use crate::types::QueryResult;
90
91#[cfg(feature = "json")]
92use crate::streaming::{StreamEvent, stream_query};
93
94/// A multi-turn conversation handle.
95///
96/// Owns an `Arc<Claude>` so it can be moved between tasks and live
97/// inside long-running actors. Tracks `session_id`, cumulative cost,
98/// turn count, and per-turn result history.
99#[derive(Debug, Clone)]
100pub struct Session {
101    claude: Arc<Claude>,
102    session_id: Option<String>,
103    history: Vec<QueryResult>,
104    cumulative_cost_usd: f64,
105    cumulative_turns: u32,
106    budget: Option<BudgetTracker>,
107}
108
109impl Session {
110    /// Start a fresh session. The first turn will discover a session id
111    /// from its result; subsequent turns reuse it via `--resume`.
112    pub fn new(claude: Arc<Claude>) -> Self {
113        Self {
114            claude,
115            session_id: None,
116            history: Vec::new(),
117            cumulative_cost_usd: 0.0,
118            cumulative_turns: 0,
119            budget: None,
120        }
121    }
122
123    /// Reattach to an existing session by id. The next turn immediately
124    /// passes `--resume <id>`. Cost and turn counters start at zero
125    /// since no history is available.
126    pub fn resume(claude: Arc<Claude>, session_id: impl Into<String>) -> Self {
127        Self {
128            claude,
129            session_id: Some(session_id.into()),
130            history: Vec::new(),
131            cumulative_cost_usd: 0.0,
132            cumulative_turns: 0,
133            budget: None,
134        }
135    }
136
137    /// Attach a [`BudgetTracker`] to this session. Every turn's cost
138    /// (from [`QueryResult::cost_usd`]) is recorded on the tracker, and
139    /// [`Session::execute`]/[`Session::stream_execute`] return
140    /// [`crate::error::Error::BudgetExceeded`]
141    /// before dispatching a turn if the tracker's ceiling has been hit.
142    ///
143    /// Clone a tracker across several sessions to enforce a shared
144    /// ceiling; each `Session` then sees the same running total.
145    pub fn with_budget(mut self, budget: BudgetTracker) -> Self {
146        self.budget = Some(budget);
147        self
148    }
149
150    /// The attached [`BudgetTracker`], if any.
151    pub fn budget(&self) -> Option<&BudgetTracker> {
152        self.budget.as_ref()
153    }
154
155    /// Send a plain-prompt turn. Equivalent to
156    /// `execute(QueryCommand::new(prompt))`.
157    #[cfg(feature = "json")]
158    pub async fn send(&mut self, prompt: impl Into<String>) -> Result<QueryResult> {
159        self.execute(QueryCommand::new(prompt)).await
160    }
161
162    /// Send a turn with a fully-configured [`QueryCommand`].
163    ///
164    /// Any session-related flags on `cmd` (`--resume`, `--continue`,
165    /// `--session-id`, `--fork-session`) are overridden with this
166    /// session's current id, so they can't conflict.
167    #[cfg(feature = "json")]
168    pub async fn execute(&mut self, cmd: QueryCommand) -> Result<QueryResult> {
169        if let Some(b) = &self.budget {
170            b.check()?;
171        }
172
173        let cmd = match &self.session_id {
174            Some(id) => cmd.replace_session(id),
175            None => cmd,
176        };
177
178        let result = cmd.execute_json(&self.claude).await?;
179        self.record(&result);
180        Ok(result)
181    }
182
183    /// Stream a plain-prompt turn, dispatching each NDJSON event to
184    /// `handler`. The session's id is captured from the first event
185    /// that carries one, so subsequent turns can resume, and the id
186    /// persists even if the stream errors partway through.
187    #[cfg(feature = "json")]
188    pub async fn stream<F>(&mut self, prompt: impl Into<String>, handler: F) -> Result<()>
189    where
190        F: FnMut(StreamEvent),
191    {
192        self.stream_execute(QueryCommand::new(prompt), handler)
193            .await
194    }
195
196    /// Stream a turn with a fully-configured [`QueryCommand`], with the
197    /// same session-id capture semantics as [`Session::stream`].
198    ///
199    /// The command's output format is forced to `stream-json` and any
200    /// session-related flags are overridden as in [`Session::execute`].
201    #[cfg(feature = "json")]
202    pub async fn stream_execute<F>(&mut self, cmd: QueryCommand, mut handler: F) -> Result<()>
203    where
204        F: FnMut(StreamEvent),
205    {
206        use crate::types::OutputFormat;
207
208        if let Some(b) = &self.budget {
209            b.check()?;
210        }
211
212        let cmd = match &self.session_id {
213            Some(id) => cmd.replace_session(id),
214            None => cmd,
215        }
216        .output_format(OutputFormat::StreamJson);
217
218        // Capture session_id and result state from events inside a
219        // wrapper closure. The captures happen before the caller's
220        // handler runs, and self is updated after the stream completes
221        // (even on error) so id persists across partial failures.
222        let mut captured_session_id: Option<String> = None;
223        let mut captured_result: Option<QueryResult> = None;
224
225        let outcome = {
226            let wrap = |event: StreamEvent| {
227                if captured_session_id.is_none()
228                    && let Some(sid) = event.session_id()
229                {
230                    captured_session_id = Some(sid.to_string());
231                }
232                if event.is_result()
233                    && captured_result.is_none()
234                    && let Ok(qr) = serde_json::from_value::<QueryResult>(event.data.clone())
235                {
236                    captured_result = Some(qr);
237                }
238                handler(event);
239            };
240            stream_query(&self.claude, &cmd, wrap).await
241        };
242
243        if let Some(sid) = captured_session_id {
244            self.session_id = Some(sid);
245        }
246        if let Some(qr) = captured_result {
247            self.record(&qr);
248        }
249
250        outcome.map(|_| ())
251    }
252
253    /// Current session id, if one has been established.
254    pub fn id(&self) -> Option<&str> {
255        self.session_id.as_deref()
256    }
257
258    /// Cumulative cost in USD across all turns in this session.
259    pub fn total_cost_usd(&self) -> f64 {
260        self.cumulative_cost_usd
261    }
262
263    /// Cumulative turn count across all turns in this session.
264    pub fn total_turns(&self) -> u32 {
265        self.cumulative_turns
266    }
267
268    /// Full per-turn result history.
269    pub fn history(&self) -> &[QueryResult] {
270        &self.history
271    }
272
273    /// Result of the most recent turn, if any.
274    pub fn last_result(&self) -> Option<&QueryResult> {
275        self.history.last()
276    }
277
278    fn record(&mut self, result: &QueryResult) {
279        self.session_id = Some(result.session_id.clone());
280        let cost = result.cost_usd.unwrap_or(0.0);
281        self.cumulative_cost_usd += cost;
282        self.cumulative_turns += result.num_turns.unwrap_or(0);
283        if let Some(b) = &self.budget {
284            b.record(cost);
285        }
286        self.history.push(result.clone());
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing accumulated USD amounts.
    const USD_TOLERANCE: f64 = 1e-9;

    /// Whether two USD amounts agree within [`USD_TOLERANCE`].
    fn usd_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < USD_TOLERANCE
    }

    fn test_claude() -> Arc<Claude> {
        Arc::new(
            Claude::builder()
                .binary("/usr/local/bin/claude")
                .build()
                .unwrap(),
        )
    }

    /// A successful turn result with the given text, id, cost, and turn count.
    fn make_result(text: &str, session_id: &str, cost_usd: f64, num_turns: u32) -> QueryResult {
        QueryResult {
            result: text.into(),
            session_id: session_id.into(),
            cost_usd: Some(cost_usd),
            duration_ms: None,
            num_turns: Some(num_turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    #[test]
    fn new_session_has_no_id() {
        let session = Session::new(test_claude());
        assert!(session.id().is_none());
        assert_eq!(session.total_cost_usd(), 0.0);
        assert_eq!(session.total_turns(), 0);
        assert!(session.history().is_empty());
        assert!(session.last_result().is_none());
        assert!(session.budget().is_none());
    }

    #[test]
    fn resume_session_has_preset_id() {
        let session = Session::resume(test_claude(), "sess-abc");
        assert_eq!(session.id(), Some("sess-abc"));
        assert_eq!(session.total_cost_usd(), 0.0);
        assert_eq!(session.total_turns(), 0);
        assert!(session.history().is_empty());
    }

    #[test]
    fn record_updates_state() {
        let mut session = Session::new(test_claude());
        session.record(&make_result("ok", "sess-1", 0.05, 3));

        assert_eq!(session.id(), Some("sess-1"));
        assert!(usd_eq(session.total_cost_usd(), 0.05));
        assert_eq!(session.total_turns(), 3);
        assert_eq!(session.history().len(), 1);
        assert_eq!(
            session.last_result().map(|r| r.session_id.as_str()),
            Some("sess-1")
        );
    }

    #[test]
    fn record_accumulates_across_turns() {
        let mut session = Session::new(test_claude());
        session.record(&make_result("a", "sess-1", 0.01, 2));
        session.record(&make_result("b", "sess-1", 0.02, 1));

        assert_eq!(session.total_turns(), 3);
        assert!(usd_eq(session.total_cost_usd(), 0.03));
        assert_eq!(session.history().len(), 2);
        assert_eq!(session.last_result().map(|r| r.result.as_str()), Some("b"));
    }

    #[test]
    fn record_adopts_result_session_id_over_resumed_one() {
        let mut session = Session::resume(test_claude(), "sess-old");
        session.record(&make_result("ok", "sess-new", 0.0, 1));
        assert_eq!(session.id(), Some("sess-new"));
    }

    #[test]
    fn record_forwards_cost_to_budget() {
        let budget = BudgetTracker::builder().build();
        let mut session = Session::new(test_claude()).with_budget(budget.clone());

        session.record(&make_result("ok", "sess-1", 0.07, 1));

        assert!(usd_eq(budget.total_usd(), 0.07));
        assert!(usd_eq(session.total_cost_usd(), 0.07));
    }

    #[test]
    fn budget_pre_check_would_block_next_turn() {
        use crate::error::Error;

        // The execute() pre-check defers to BudgetTracker::check().
        // Exercise that directly with a pre-loaded tracker, so we don't
        // need a live Claude CLI.
        let budget = BudgetTracker::builder().max_usd(0.10).build();
        budget.record(0.15);

        let session = Session::new(test_claude()).with_budget(budget);
        match session.budget().unwrap().check() {
            Err(Error::BudgetExceeded { total_usd, max_usd }) => {
                assert!(usd_eq(total_usd, 0.15));
                assert!(usd_eq(max_usd, 0.10));
            }
            other => panic!("expected BudgetExceeded, got {other:?}"),
        }
    }

    #[test]
    fn replace_session_clears_conflicting_flags() {
        use crate::command::ClaudeCommand;

        // Verify that replace_session strips --continue/--session-id/
        // --fork-session and sets --resume to the given id.
        let cmd = QueryCommand::new("hi")
            .continue_session()
            .session_id("old")
            .fork_session()
            .replace_session("new-id");

        let args = cmd.args();
        let has = |flag: &str| args.iter().any(|a| a == flag);
        assert!(has("--resume"));
        assert!(has("new-id"));
        assert!(!has("--continue"));
        assert!(!has("--session-id"));
        assert!(!has("--fork-session"));
    }
}