Skip to main content

claude_wrapper/
session.rs

1//! Multi-turn session management.
2//!
3//! A [`Session`] threads Claude's `session_id` across turns automatically,
4//! so callers never need to scrape it out of a result event or pass
5//! `--resume` by hand.
6//!
7//! # Ownership
8//!
9//! [`Session`] holds an `Arc<Claude>`. Wrapping the client in an `Arc`
10//! means a session can outlive the original client binding, be moved
11//! between tasks, and sit inside long-lived actor state -- which is the
12//! usage shape that callers like centralino-rs need. One `Arc::clone`
13//! per session is a negligible cost in exchange.
14//!
15//! # Two entry points
16//!
17//! - [`Session::send`] takes a plain prompt. Use this for straightforward
18//!   multi-turn chat.
19//! - [`Session::execute`] takes a fully-configured [`QueryCommand`]. Use
20//!   this when you want per-turn options like `model`, `max_turns`,
21//!   `permission_mode`, etc. The session automatically overrides any
22//!   session-related flags on the command (`--resume`, `--continue`,
23//!   `--session-id`, `--fork-session`) so they can't conflict.
24//!
25//! Streaming follows the same split: [`Session::stream`] and
26//! [`Session::stream_execute`].
27//!
28//! # Example
29//!
30//! ```no_run
31//! use std::sync::Arc;
32//! use claude_wrapper::{Claude, QueryCommand};
33//! use claude_wrapper::session::Session;
34//!
35//! # async fn example() -> claude_wrapper::Result<()> {
36//! let claude = Arc::new(Claude::builder().build()?);
37//!
38//! let mut session = Session::new(Arc::clone(&claude));
39//!
40//! // Simple path
41//! let first = session.send("explain quicksort").await?;
42//!
43//! // Full control: custom model, effort, permission mode, etc.
44//! let second = session
45//!     .execute(QueryCommand::new("now mergesort").model("opus"))
46//!     .await?;
47//!
48//! println!("total cost: ${:.4}", session.total_cost_usd());
49//! println!("turns: {}", session.total_turns());
50//! # Ok(())
51//! # }
52//! ```
53//!
54//! # Resuming an existing session
55//!
56//! ```no_run
57//! # use std::sync::Arc;
58//! # use claude_wrapper::{Claude};
59//! # use claude_wrapper::session::Session;
60//! # async fn example() -> claude_wrapper::Result<()> {
61//! # let claude = Arc::new(Claude::builder().build()?);
62//! // Reattach to a session you stored earlier
63//! let mut session = Session::resume(claude, "sess-abc123");
64//! let result = session.send("pick up where we left off").await?;
65//! # Ok(())
66//! # }
67//! ```
68
69use std::sync::Arc;
70
71use crate::Claude;
72use crate::command::query::QueryCommand;
73use crate::error::Result;
74use crate::types::QueryResult;
75
76#[cfg(feature = "json")]
77use crate::streaming::{StreamEvent, stream_query};
78
/// A multi-turn conversation handle.
///
/// Owns an `Arc<Claude>` so it can be moved between tasks and live
/// inside long-running actors. Tracks `session_id`, cumulative cost,
/// turn count, and per-turn result history.
///
/// Cloning a `Session` clones the `Arc` handle (sharing the client) and
/// copies the id, counters, and history; the two clones then evolve
/// independently.
#[derive(Debug, Clone)]
pub struct Session {
    /// Shared client used to run every turn of this session.
    claude: Arc<Claude>,
    /// Id of the conversation; `None` until a turn has reported one
    /// (or until the session is built with a known id).
    session_id: Option<String>,
    /// One entry per recorded turn, oldest first.
    history: Vec<QueryResult>,
    /// Sum of the `cost_usd` values of all recorded results; results
    /// without a cost contribute zero.
    cumulative_cost_usd: f64,
    /// Sum of the `num_turns` values of all recorded results; results
    /// without a turn count contribute zero.
    cumulative_turns: u32,
}
92
93impl Session {
94    /// Start a fresh session. The first turn will discover a session id
95    /// from its result; subsequent turns reuse it via `--resume`.
96    pub fn new(claude: Arc<Claude>) -> Self {
97        Self {
98            claude,
99            session_id: None,
100            history: Vec::new(),
101            cumulative_cost_usd: 0.0,
102            cumulative_turns: 0,
103        }
104    }
105
106    /// Reattach to an existing session by id. The next turn immediately
107    /// passes `--resume <id>`. Cost and turn counters start at zero
108    /// since no history is available.
109    pub fn resume(claude: Arc<Claude>, session_id: impl Into<String>) -> Self {
110        Self {
111            claude,
112            session_id: Some(session_id.into()),
113            history: Vec::new(),
114            cumulative_cost_usd: 0.0,
115            cumulative_turns: 0,
116        }
117    }
118
119    /// Send a plain-prompt turn. Equivalent to
120    /// `execute(QueryCommand::new(prompt))`.
121    #[cfg(feature = "json")]
122    pub async fn send(&mut self, prompt: impl Into<String>) -> Result<QueryResult> {
123        self.execute(QueryCommand::new(prompt)).await
124    }
125
126    /// Send a turn with a fully-configured [`QueryCommand`].
127    ///
128    /// Any session-related flags on `cmd` (`--resume`, `--continue`,
129    /// `--session-id`, `--fork-session`) are overridden with this
130    /// session's current id, so they can't conflict.
131    #[cfg(feature = "json")]
132    pub async fn execute(&mut self, cmd: QueryCommand) -> Result<QueryResult> {
133        let cmd = match &self.session_id {
134            Some(id) => cmd.replace_session(id),
135            None => cmd,
136        };
137
138        let result = cmd.execute_json(&self.claude).await?;
139        self.record(&result);
140        Ok(result)
141    }
142
143    /// Stream a plain-prompt turn, dispatching each NDJSON event to
144    /// `handler`. The session's id is captured from the first event
145    /// that carries one, so subsequent turns can resume, and the id
146    /// persists even if the stream errors partway through.
147    #[cfg(feature = "json")]
148    pub async fn stream<F>(&mut self, prompt: impl Into<String>, handler: F) -> Result<()>
149    where
150        F: FnMut(StreamEvent),
151    {
152        self.stream_execute(QueryCommand::new(prompt), handler)
153            .await
154    }
155
156    /// Stream a turn with a fully-configured [`QueryCommand`], with the
157    /// same session-id capture semantics as [`Session::stream`].
158    ///
159    /// The command's output format is forced to `stream-json` and any
160    /// session-related flags are overridden as in [`Session::execute`].
161    #[cfg(feature = "json")]
162    pub async fn stream_execute<F>(&mut self, cmd: QueryCommand, mut handler: F) -> Result<()>
163    where
164        F: FnMut(StreamEvent),
165    {
166        use crate::types::OutputFormat;
167
168        let cmd = match &self.session_id {
169            Some(id) => cmd.replace_session(id),
170            None => cmd,
171        }
172        .output_format(OutputFormat::StreamJson);
173
174        // Capture session_id and result state from events inside a
175        // wrapper closure. The captures happen before the caller's
176        // handler runs, and self is updated after the stream completes
177        // (even on error) so id persists across partial failures.
178        let mut captured_session_id: Option<String> = None;
179        let mut captured_result: Option<QueryResult> = None;
180
181        let outcome = {
182            let wrap = |event: StreamEvent| {
183                if captured_session_id.is_none()
184                    && let Some(sid) = event.session_id()
185                {
186                    captured_session_id = Some(sid.to_string());
187                }
188                if event.is_result()
189                    && captured_result.is_none()
190                    && let Ok(qr) = serde_json::from_value::<QueryResult>(event.data.clone())
191                {
192                    captured_result = Some(qr);
193                }
194                handler(event);
195            };
196            stream_query(&self.claude, &cmd, wrap).await
197        };
198
199        if let Some(sid) = captured_session_id {
200            self.session_id = Some(sid);
201        }
202        if let Some(qr) = captured_result {
203            self.record(&qr);
204        }
205
206        outcome.map(|_| ())
207    }
208
209    /// Current session id, if one has been established.
210    pub fn id(&self) -> Option<&str> {
211        self.session_id.as_deref()
212    }
213
214    /// Cumulative cost in USD across all turns in this session.
215    pub fn total_cost_usd(&self) -> f64 {
216        self.cumulative_cost_usd
217    }
218
219    /// Cumulative turn count across all turns in this session.
220    pub fn total_turns(&self) -> u32 {
221        self.cumulative_turns
222    }
223
224    /// Full per-turn result history.
225    pub fn history(&self) -> &[QueryResult] {
226        &self.history
227    }
228
229    /// Result of the most recent turn, if any.
230    pub fn last_result(&self) -> Option<&QueryResult> {
231        self.history.last()
232    }
233
234    fn record(&mut self, result: &QueryResult) {
235        self.session_id = Some(result.session_id.clone());
236        self.cumulative_cost_usd += result.cost_usd.unwrap_or(0.0);
237        self.cumulative_turns += result.num_turns.unwrap_or(0);
238        self.history.push(result.clone());
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a shared client configured with a fixed binary path.
    fn test_claude() -> Arc<Claude> {
        let client = Claude::builder()
            .binary("/usr/local/bin/claude")
            .build()
            .unwrap();
        Arc::new(client)
    }

    /// Build a successful result for session `sess-1` with the given
    /// text, cost, and turn count; every other field is left empty.
    fn turn_result(text: &'static str, cost: f64, turns: u32) -> QueryResult {
        QueryResult {
            result: text.into(),
            session_id: "sess-1".into(),
            cost_usd: Some(cost),
            duration_ms: None,
            num_turns: Some(turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    // A brand-new session has no id, zeroed counters, and no history.
    #[test]
    fn new_session_has_no_id() {
        let fresh = Session::new(test_claude());

        assert!(fresh.id().is_none());
        assert_eq!(fresh.total_cost_usd(), 0.0);
        assert_eq!(fresh.total_turns(), 0);
        assert!(fresh.history().is_empty());
        assert!(fresh.last_result().is_none());
    }

    // A resumed session exposes the supplied id but starts its
    // counters from zero.
    #[test]
    fn resume_session_has_preset_id() {
        let resumed = Session::resume(test_claude(), "sess-abc");

        assert_eq!(resumed.id(), Some("sess-abc"));
        assert_eq!(resumed.total_cost_usd(), 0.0);
        assert_eq!(resumed.total_turns(), 0);
    }

    // Recording a single result sets the id, totals, and history.
    #[test]
    fn record_updates_state() {
        let mut session = Session::new(test_claude());
        let only_turn = turn_result("ok", 0.05, 3);

        session.record(&only_turn);

        assert_eq!(session.id(), Some("sess-1"));
        assert!((session.total_cost_usd() - 0.05).abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 3);
        assert_eq!(session.history().len(), 1);

        let last_id = session.last_result().map(|r| r.session_id.as_str());
        assert_eq!(last_id, Some("sess-1"));
    }

    // Totals add up across several recorded results.
    #[test]
    fn record_accumulates_across_turns() {
        let mut session = Session::new(test_claude());
        let turns = [turn_result("a", 0.01, 2), turn_result("b", 0.02, 1)];

        for turn in &turns {
            session.record(turn);
        }

        assert_eq!(session.total_turns(), 3);
        assert!((session.total_cost_usd() - 0.03).abs() < f64::EPSILON);
        assert_eq!(session.history().len(), 2);
    }

    // `replace_session` must strip --continue/--session-id/
    // --fork-session and leave --resume pointing at the new id.
    #[test]
    fn replace_session_clears_conflicting_flags() {
        use crate::command::ClaudeCommand;

        let cmd = QueryCommand::new("hi")
            .continue_session()
            .session_id("old")
            .fork_session()
            .replace_session("new-id");

        let args = cmd.args();
        let has = |needle: &str| args.iter().any(|arg| arg.as_str() == needle);

        assert!(has("--resume"));
        assert!(has("new-id"));
        assert!(!has("--continue"));
        assert!(!has("--session-id"));
        assert!(!has("--fork-session"));
    }
}