claude_wrapper/session.rs
1//! Multi-turn session management.
2//!
3//! A [`Session`] threads Claude's `session_id` across turns automatically,
4//! so callers never need to scrape it out of a result event or pass
5//! `--resume` by hand.
6//!
7//! # Ownership
8//!
9//! [`Session`] holds an `Arc<Claude>`. Wrapping the client in an `Arc`
10//! means a session can outlive the original client binding, be moved
11//! between tasks, and sit inside long-lived actor state -- which is the
12//! usage shape that callers like centralino-rs need. One `Arc::clone`
13//! per session is a negligible cost in exchange.
14//!
15//! # Two entry points
16//!
17//! - [`Session::send`] takes a plain prompt. Use this for straightforward
18//! multi-turn chat.
19//! - [`Session::execute`] takes a fully-configured [`QueryCommand`]. Use
20//! this when you want per-turn options like `model`, `max_turns`,
21//! `permission_mode`, etc. The session automatically overrides any
22//! session-related flags on the command (`--resume`, `--continue`,
23//! `--session-id`, `--fork-session`) so they can't conflict.
24//!
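//! Streaming follows the same split: [`Session::stream`] and
//! [`Session::stream_execute`].
//!
//! All four turn-taking methods are gated behind the crate's `json`
//! feature; the constructors and accessors are always available.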
27//!
28//! # Example
29//!
30//! ```no_run
31//! use std::sync::Arc;
32//! use claude_wrapper::{Claude, QueryCommand};
33//! use claude_wrapper::session::Session;
34//!
35//! # async fn example() -> claude_wrapper::Result<()> {
36//! let claude = Arc::new(Claude::builder().build()?);
37//!
38//! let mut session = Session::new(Arc::clone(&claude));
39//!
40//! // Simple path
41//! let first = session.send("explain quicksort").await?;
42//!
43//! // Full control: custom model, effort, permission mode, etc.
44//! let second = session
45//! .execute(QueryCommand::new("now mergesort").model("opus"))
46//! .await?;
47//!
48//! println!("total cost: ${:.4}", session.total_cost_usd());
49//! println!("turns: {}", session.total_turns());
50//! # Ok(())
51//! # }
52//! ```
53//!
54//! # Resuming an existing session
55//!
56//! ```no_run
57//! # use std::sync::Arc;
58//! # use claude_wrapper::{Claude};
59//! # use claude_wrapper::session::Session;
60//! # async fn example() -> claude_wrapper::Result<()> {
61//! # let claude = Arc::new(Claude::builder().build()?);
62//! // Reattach to a session you stored earlier
63//! let mut session = Session::resume(claude, "sess-abc123");
64//! let result = session.send("pick up where we left off").await?;
65//! # Ok(())
66//! # }
67//! ```
68
69use std::sync::Arc;
70
71use crate::Claude;
72use crate::command::query::QueryCommand;
73use crate::error::Result;
74use crate::types::QueryResult;
75
76#[cfg(feature = "json")]
77use crate::streaming::{StreamEvent, stream_query};
78
79/// A multi-turn conversation handle.
80///
81/// Owns an `Arc<Claude>` so it can be moved between tasks and live
82/// inside long-running actors. Tracks `session_id`, cumulative cost,
83/// turn count, and per-turn result history.
84#[derive(Debug, Clone)]
85pub struct Session {
86 claude: Arc<Claude>,
87 session_id: Option<String>,
88 history: Vec<QueryResult>,
89 cumulative_cost_usd: f64,
90 cumulative_turns: u32,
91}
92
93impl Session {
94 /// Start a fresh session. The first turn will discover a session id
95 /// from its result; subsequent turns reuse it via `--resume`.
96 pub fn new(claude: Arc<Claude>) -> Self {
97 Self {
98 claude,
99 session_id: None,
100 history: Vec::new(),
101 cumulative_cost_usd: 0.0,
102 cumulative_turns: 0,
103 }
104 }
105
106 /// Reattach to an existing session by id. The next turn immediately
107 /// passes `--resume <id>`. Cost and turn counters start at zero
108 /// since no history is available.
109 pub fn resume(claude: Arc<Claude>, session_id: impl Into<String>) -> Self {
110 Self {
111 claude,
112 session_id: Some(session_id.into()),
113 history: Vec::new(),
114 cumulative_cost_usd: 0.0,
115 cumulative_turns: 0,
116 }
117 }
118
119 /// Send a plain-prompt turn. Equivalent to
120 /// `execute(QueryCommand::new(prompt))`.
121 #[cfg(feature = "json")]
122 pub async fn send(&mut self, prompt: impl Into<String>) -> Result<QueryResult> {
123 self.execute(QueryCommand::new(prompt)).await
124 }
125
126 /// Send a turn with a fully-configured [`QueryCommand`].
127 ///
128 /// Any session-related flags on `cmd` (`--resume`, `--continue`,
129 /// `--session-id`, `--fork-session`) are overridden with this
130 /// session's current id, so they can't conflict.
131 #[cfg(feature = "json")]
132 pub async fn execute(&mut self, cmd: QueryCommand) -> Result<QueryResult> {
133 let cmd = match &self.session_id {
134 Some(id) => cmd.replace_session(id),
135 None => cmd,
136 };
137
138 let result = cmd.execute_json(&self.claude).await?;
139 self.record(&result);
140 Ok(result)
141 }
142
143 /// Stream a plain-prompt turn, dispatching each NDJSON event to
144 /// `handler`. The session's id is captured from the first event
145 /// that carries one, so subsequent turns can resume, and the id
146 /// persists even if the stream errors partway through.
147 #[cfg(feature = "json")]
148 pub async fn stream<F>(&mut self, prompt: impl Into<String>, handler: F) -> Result<()>
149 where
150 F: FnMut(StreamEvent),
151 {
152 self.stream_execute(QueryCommand::new(prompt), handler)
153 .await
154 }
155
156 /// Stream a turn with a fully-configured [`QueryCommand`], with the
157 /// same session-id capture semantics as [`Session::stream`].
158 ///
159 /// The command's output format is forced to `stream-json` and any
160 /// session-related flags are overridden as in [`Session::execute`].
161 #[cfg(feature = "json")]
162 pub async fn stream_execute<F>(&mut self, cmd: QueryCommand, mut handler: F) -> Result<()>
163 where
164 F: FnMut(StreamEvent),
165 {
166 use crate::types::OutputFormat;
167
168 let cmd = match &self.session_id {
169 Some(id) => cmd.replace_session(id),
170 None => cmd,
171 }
172 .output_format(OutputFormat::StreamJson);
173
174 // Capture session_id and result state from events inside a
175 // wrapper closure. The captures happen before the caller's
176 // handler runs, and self is updated after the stream completes
177 // (even on error) so id persists across partial failures.
178 let mut captured_session_id: Option<String> = None;
179 let mut captured_result: Option<QueryResult> = None;
180
181 let outcome = {
182 let wrap = |event: StreamEvent| {
183 if captured_session_id.is_none()
184 && let Some(sid) = event.session_id()
185 {
186 captured_session_id = Some(sid.to_string());
187 }
188 if event.is_result()
189 && captured_result.is_none()
190 && let Ok(qr) = serde_json::from_value::<QueryResult>(event.data.clone())
191 {
192 captured_result = Some(qr);
193 }
194 handler(event);
195 };
196 stream_query(&self.claude, &cmd, wrap).await
197 };
198
199 if let Some(sid) = captured_session_id {
200 self.session_id = Some(sid);
201 }
202 if let Some(qr) = captured_result {
203 self.record(&qr);
204 }
205
206 outcome.map(|_| ())
207 }
208
209 /// Current session id, if one has been established.
210 pub fn id(&self) -> Option<&str> {
211 self.session_id.as_deref()
212 }
213
214 /// Cumulative cost in USD across all turns in this session.
215 pub fn total_cost_usd(&self) -> f64 {
216 self.cumulative_cost_usd
217 }
218
219 /// Cumulative turn count across all turns in this session.
220 pub fn total_turns(&self) -> u32 {
221 self.cumulative_turns
222 }
223
224 /// Full per-turn result history.
225 pub fn history(&self) -> &[QueryResult] {
226 &self.history
227 }
228
229 /// Result of the most recent turn, if any.
230 pub fn last_result(&self) -> Option<&QueryResult> {
231 self.history.last()
232 }
233
234 fn record(&mut self, result: &QueryResult) {
235 self.session_id = Some(result.session_id.clone());
236 self.cumulative_cost_usd += result.cost_usd.unwrap_or(0.0);
237 self.cumulative_turns += result.num_turns.unwrap_or(0);
238 self.history.push(result.clone());
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Client pointing at a fixed binary path. None of these tests run a
    /// query, so the binary is never invoked.
    fn test_claude() -> Arc<Claude> {
        let client = Claude::builder()
            .binary("/usr/local/bin/claude")
            .build()
            .unwrap();
        Arc::new(client)
    }

    /// Shorthand for a successful [`QueryResult`] with the given fields.
    fn turn(text: &str, session_id: &str, cost_usd: f64, num_turns: u32) -> QueryResult {
        QueryResult {
            result: text.into(),
            session_id: session_id.into(),
            cost_usd: Some(cost_usd),
            duration_ms: None,
            num_turns: Some(num_turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    #[test]
    fn new_session_has_no_id() {
        let fresh = Session::new(test_claude());
        assert_eq!(fresh.id(), None);
        assert_eq!(fresh.total_cost_usd(), 0.0);
        assert_eq!(fresh.total_turns(), 0);
        assert!(fresh.history().is_empty());
        assert!(fresh.last_result().is_none());
    }

    #[test]
    fn resume_session_has_preset_id() {
        let resumed = Session::resume(test_claude(), "sess-abc");
        assert_eq!(resumed.id(), Some("sess-abc"));
        assert_eq!(resumed.total_cost_usd(), 0.0);
        assert_eq!(resumed.total_turns(), 0);
    }

    #[test]
    fn record_updates_state() {
        let mut s = Session::new(test_claude());
        s.record(&turn("ok", "sess-1", 0.05, 3));

        assert_eq!(s.id(), Some("sess-1"));
        assert!((s.total_cost_usd() - 0.05).abs() < f64::EPSILON);
        assert_eq!(s.total_turns(), 3);
        assert_eq!(s.history().len(), 1);

        let last = s.last_result().expect("one turn was recorded");
        assert_eq!(last.session_id, "sess-1");
    }

    #[test]
    fn record_accumulates_across_turns() {
        let mut s = Session::new(test_claude());
        for r in [turn("a", "sess-1", 0.01, 2), turn("b", "sess-1", 0.02, 1)] {
            s.record(&r);
        }

        assert_eq!(s.total_turns(), 3);
        assert!((s.total_cost_usd() - 0.03).abs() < f64::EPSILON);
        assert_eq!(s.history().len(), 2);
    }

    #[test]
    fn replace_session_clears_conflicting_flags() {
        use crate::command::ClaudeCommand;

        // replace_session must drop --continue / --session-id /
        // --fork-session and point --resume at the supplied id.
        let args = QueryCommand::new("hi")
            .continue_session()
            .session_id("old")
            .fork_session()
            .replace_session("new-id")
            .args();
        let has = |flag: &str| args.iter().any(|a| a == flag);

        assert!(has("--resume"));
        assert!(has("new-id"));
        assert!(!has("--continue"));
        assert!(!has("--session-id"));
        assert!(!has("--fork-session"));
    }
}