Skip to main content

claude_wrapper/
session.rs

1//! Type-safe session management for multi-turn conversations.
2//!
3//! The [`Session`] struct consolidates session control into a single abstraction
4//! that prevents conflicting session flags at the type level. Instead of
5//! independently calling `.continue_session()`, `.resume()`, `.session_id()`,
6//! or `.fork_session()` on a [`QueryCommand`] (which can
7//! be combined incorrectly), a `Session` encodes the session mode in its
8//! construction and provides `.query()` with automatic resume behavior.
9//!
10//! # Example
11//!
12//! ```no_run
13//! use claude_wrapper::{Claude, QueryCommand};
14//! use claude_wrapper::session::Session;
15//!
16//! # async fn example() -> claude_wrapper::Result<()> {
17//! let claude = Claude::builder().build()?;
18//!
19//! // Start a session with an initial query
20//! let first = QueryCommand::new("explain quicksort")
21//!     .execute_json(&claude)
22//!     .await?;
23//!
24//! // Wrap it in a Session for automatic resume
25//! let mut session = Session::from_result(&claude, &first);
26//!
27//! // Follow-up queries auto-resume the session
28//! let second = session.query("now explain mergesort")
29//!     .model("sonnet")
30//!     .execute()
31//!     .await?;
32//!
33//! println!("total cost: ${:.4}", session.total_cost_usd());
34//! println!("total turns: {}", session.total_turns());
35//! # Ok(())
36//! # }
37//! ```
38
39use crate::Claude;
40use crate::command::query::QueryCommand;
41use crate::error::Result;
42use crate::types::{Effort, InputFormat, OutputFormat, PermissionMode, QueryResult};
43
44/// A type-safe session handle for multi-turn conversations.
45///
46/// `Session` wraps a [`Claude`] client reference and a session ID, providing
47/// `.query()` that automatically resumes the session. It tracks cumulative
48/// cost and turn count across all queries in the session.
49///
50/// Conflicting session flags are impossible because the session mode is
51/// encoded in construction rather than as independent builder methods.
52#[derive(Debug)]
53pub struct Session<'a> {
54    claude: &'a Claude,
55    session_id: String,
56    cumulative_cost_usd: f64,
57    cumulative_turns: u32,
58}
59
60impl<'a> Session<'a> {
61    /// Create a session from a completed query result.
62    ///
63    /// This is the most common way to start a session: run an initial
64    /// [`QueryCommand::execute_json()`] and then wrap the result.
65    ///
66    /// # Example
67    ///
68    /// ```no_run
69    /// use claude_wrapper::{Claude, QueryCommand};
70    /// use claude_wrapper::session::Session;
71    ///
72    /// # async fn example() -> claude_wrapper::Result<()> {
73    /// let claude = Claude::builder().build()?;
74    /// let result = QueryCommand::new("hello")
75    ///     .execute_json(&claude).await?;
76    ///
77    /// let mut session = Session::from_result(&claude, &result);
78    /// # Ok(())
79    /// # }
80    /// ```
81    pub fn from_result(claude: &'a Claude, result: &QueryResult) -> Self {
82        Self {
83            claude,
84            session_id: result.session_id.clone(),
85            cumulative_cost_usd: result.cost_usd.unwrap_or(0.0),
86            cumulative_turns: result.num_turns.unwrap_or(0),
87        }
88    }
89
90    /// Attach to an existing session by ID.
91    ///
92    /// Cost and turn counters start at zero since we have no history.
93    pub fn from_id(claude: &'a Claude, session_id: impl Into<String>) -> Self {
94        Self {
95            claude,
96            session_id: session_id.into(),
97            cumulative_cost_usd: 0.0,
98            cumulative_turns: 0,
99        }
100    }
101
102    /// Continue the most recent session.
103    ///
104    /// Runs the first query with `--continue` to discover the session ID,
105    /// then returns a `Session` that uses `--resume` for subsequent queries.
106    pub async fn continue_recent(
107        claude: &'a Claude,
108        prompt: impl Into<String>,
109    ) -> Result<(Self, QueryResult)> {
110        let result = QueryCommand::new(prompt)
111            .continue_session()
112            .execute_json(claude)
113            .await?;
114
115        let session = Self {
116            claude,
117            session_id: result.session_id.clone(),
118            cumulative_cost_usd: result.cost_usd.unwrap_or(0.0),
119            cumulative_turns: result.num_turns.unwrap_or(0),
120        };
121        Ok((session, result))
122    }
123
124    /// Send a follow-up query in this session.
125    ///
126    /// Returns a [`SessionQuery`] builder with `--resume` pre-set.
127    /// Configure additional options (model, effort, etc.) on the builder,
128    /// then call `.execute()`.
129    ///
130    /// # Example
131    ///
132    /// ```no_run
133    /// # use claude_wrapper::{Claude, QueryCommand};
134    /// # use claude_wrapper::session::Session;
135    /// # async fn example() -> claude_wrapper::Result<()> {
136    /// # let claude = Claude::builder().build()?;
137    /// # let result = QueryCommand::new("hello").execute_json(&claude).await?;
138    /// let mut session = Session::from_result(&claude, &result);
139    ///
140    /// let follow_up = session.query("what about the edge cases?")
141    ///     .model("opus")
142    ///     .max_turns(5)
143    ///     .execute()
144    ///     .await?;
145    /// # Ok(())
146    /// # }
147    /// ```
148    pub fn query(&mut self, prompt: impl Into<String>) -> SessionQuery<'_, 'a> {
149        SessionQuery::new(self, prompt)
150    }
151
152    /// Fork this session into a new one.
153    ///
154    /// Sends a query with `--resume` and `--fork-session`, creating a new
155    /// session branched from this one. Returns the new `Session` and the
156    /// query result. The original session is not modified.
157    pub async fn fork(&self, prompt: impl Into<String>) -> Result<(Session<'a>, QueryResult)> {
158        let result = QueryCommand::new(prompt)
159            .resume(&self.session_id)
160            .fork_session()
161            .execute_json(self.claude)
162            .await?;
163
164        let forked = Session {
165            claude: self.claude,
166            session_id: result.session_id.clone(),
167            cumulative_cost_usd: self.cumulative_cost_usd + result.cost_usd.unwrap_or(0.0),
168            cumulative_turns: self.cumulative_turns + result.num_turns.unwrap_or(0),
169        };
170        Ok((forked, result))
171    }
172
173    /// Get the current session ID.
174    pub fn id(&self) -> &str {
175        &self.session_id
176    }
177
178    /// Get cumulative cost in USD across all queries in this session.
179    pub fn total_cost_usd(&self) -> f64 {
180        self.cumulative_cost_usd
181    }
182
183    /// Get cumulative turn count across all queries in this session.
184    pub fn total_turns(&self) -> u32 {
185        self.cumulative_turns
186    }
187}
188
189/// Builder for a follow-up query within a session.
190///
191/// This wraps a [`QueryCommand`] with `--resume` pre-set. Session-related
192/// methods (`.continue_session()`, `.session_id()`, `.fork_session()`,
193/// `.resume()`) are intentionally not exposed, preventing conflicting flags
194/// at the type level.
195///
196/// All other `QueryCommand` options are available via delegation.
197#[derive(Debug)]
198pub struct SessionQuery<'s, 'a> {
199    session: &'s mut Session<'a>,
200    command: QueryCommand,
201}
202
203impl<'s, 'a> SessionQuery<'s, 'a> {
204    fn new(session: &'s mut Session<'a>, prompt: impl Into<String>) -> Self {
205        let command = QueryCommand::new(prompt).resume(&session.session_id);
206        Self { session, command }
207    }
208
209    /// Set the model to use.
210    #[must_use]
211    pub fn model(mut self, model: impl Into<String>) -> Self {
212        self.command = self.command.model(model);
213        self
214    }
215
216    /// Set a custom system prompt.
217    #[must_use]
218    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
219        self.command = self.command.system_prompt(prompt);
220        self
221    }
222
223    /// Append to the default system prompt.
224    #[must_use]
225    pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
226        self.command = self.command.append_system_prompt(prompt);
227        self
228    }
229
230    /// Set the output format.
231    #[must_use]
232    pub fn output_format(mut self, format: OutputFormat) -> Self {
233        self.command = self.command.output_format(format);
234        self
235    }
236
237    /// Set the maximum budget in USD.
238    #[must_use]
239    pub fn max_budget_usd(mut self, budget: f64) -> Self {
240        self.command = self.command.max_budget_usd(budget);
241        self
242    }
243
244    /// Set the permission mode.
245    #[must_use]
246    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
247        self.command = self.command.permission_mode(mode);
248        self
249    }
250
251    /// Add allowed tools.
252    #[must_use]
253    pub fn allowed_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
254        self.command = self.command.allowed_tools(tools);
255        self
256    }
257
258    /// Add a single allowed tool.
259    #[must_use]
260    pub fn allowed_tool(mut self, tool: impl Into<String>) -> Self {
261        self.command = self.command.allowed_tool(tool);
262        self
263    }
264
265    /// Add disallowed tools.
266    #[must_use]
267    pub fn disallowed_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
268        self.command = self.command.disallowed_tools(tools);
269        self
270    }
271
272    /// Add an MCP config file path.
273    #[must_use]
274    pub fn mcp_config(mut self, path: impl Into<String>) -> Self {
275        self.command = self.command.mcp_config(path);
276        self
277    }
278
279    /// Add an additional directory for tool access.
280    #[must_use]
281    pub fn add_dir(mut self, dir: impl Into<String>) -> Self {
282        self.command = self.command.add_dir(dir);
283        self
284    }
285
286    /// Set the effort level.
287    #[must_use]
288    pub fn effort(mut self, effort: Effort) -> Self {
289        self.command = self.command.effort(effort);
290        self
291    }
292
293    /// Set the maximum number of turns.
294    #[must_use]
295    pub fn max_turns(mut self, turns: u32) -> Self {
296        self.command = self.command.max_turns(turns);
297        self
298    }
299
300    /// Set a JSON schema for structured output validation.
301    #[must_use]
302    pub fn json_schema(mut self, schema: impl Into<String>) -> Self {
303        self.command = self.command.json_schema(schema);
304        self
305    }
306
307    /// Set a fallback model.
308    #[must_use]
309    pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
310        self.command = self.command.fallback_model(model);
311        self
312    }
313
314    /// Disable session persistence.
315    #[must_use]
316    pub fn no_session_persistence(mut self) -> Self {
317        self.command = self.command.no_session_persistence();
318        self
319    }
320
321    /// Bypass all permission checks.
322    #[must_use]
323    pub fn dangerously_skip_permissions(mut self) -> Self {
324        self.command = self.command.dangerously_skip_permissions();
325        self
326    }
327
328    /// Set the agent for the session.
329    #[must_use]
330    pub fn agent(mut self, agent: impl Into<String>) -> Self {
331        self.command = self.command.agent(agent);
332        self
333    }
334
335    /// Set custom agents as a JSON object.
336    #[must_use]
337    pub fn agents_json(mut self, json: impl Into<String>) -> Self {
338        self.command = self.command.agents_json(json);
339        self
340    }
341
342    /// Set the list of available built-in tools.
343    #[must_use]
344    pub fn tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
345        self.command = self.command.tools(tools);
346        self
347    }
348
349    /// Add a file resource to download at startup.
350    #[must_use]
351    pub fn file(mut self, spec: impl Into<String>) -> Self {
352        self.command = self.command.file(spec);
353        self
354    }
355
356    /// Include partial message chunks as they arrive.
357    #[must_use]
358    pub fn include_partial_messages(mut self) -> Self {
359        self.command = self.command.include_partial_messages();
360        self
361    }
362
363    /// Set the input format.
364    #[must_use]
365    pub fn input_format(mut self, format: InputFormat) -> Self {
366        self.command = self.command.input_format(format);
367        self
368    }
369
370    /// Only use MCP servers from `--mcp-config`.
371    #[must_use]
372    pub fn strict_mcp_config(mut self) -> Self {
373        self.command = self.command.strict_mcp_config();
374        self
375    }
376
377    /// Path to a settings JSON file or a JSON string.
378    #[must_use]
379    pub fn settings(mut self, settings: impl Into<String>) -> Self {
380        self.command = self.command.settings(settings);
381        self
382    }
383
384    /// Set a per-command retry policy.
385    #[must_use]
386    pub fn retry(mut self, policy: crate::retry::RetryPolicy) -> Self {
387        self.command = self.command.retry(policy);
388        self
389    }
390
391    /// Execute the query, updating the session's cumulative cost and turns.
392    pub async fn execute(self) -> Result<QueryResult> {
393        let result = self.command.execute_json(self.session.claude).await?;
394        self.session.cumulative_cost_usd += result.cost_usd.unwrap_or(0.0);
395        self.session.cumulative_turns += result.num_turns.unwrap_or(0);
396        self.session.session_id.clone_from(&result.session_id);
397        Ok(result)
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ClaudeCommand;

    fn test_claude() -> Claude {
        Claude::builder()
            .binary("/usr/local/bin/claude")
            .build()
            .unwrap()
    }

    fn test_result(session_id: &str, cost: f64, turns: u32) -> QueryResult {
        QueryResult {
            result: "test".into(),
            session_id: session_id.into(),
            cost_usd: Some(cost),
            duration_ms: None,
            num_turns: Some(turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    /// Return the argument immediately following `flag`, if the flag is present.
    ///
    /// Checking adjacency (rather than two independent `contains` calls) makes
    /// sure the value is actually attached to the flag.
    fn flag_value<'v>(args: &'v [String], flag: &str) -> Option<&'v str> {
        args.windows(2)
            .find(|pair| pair[0] == flag)
            .map(|pair| pair[1].as_str())
    }

    #[test]
    fn session_from_result_captures_state() {
        let claude = test_claude();
        let result = test_result("sess-abc", 0.05, 3);
        let session = Session::from_result(&claude, &result);

        assert_eq!(session.id(), "sess-abc");
        assert!((session.total_cost_usd() - 0.05).abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 3);
    }

    #[test]
    fn session_from_id_starts_clean() {
        let claude = test_claude();
        let session = Session::from_id(&claude, "sess-xyz");

        assert_eq!(session.id(), "sess-xyz");
        assert!((session.total_cost_usd()).abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 0);
    }

    #[test]
    fn session_from_result_handles_none_cost_and_turns() {
        let claude = test_claude();
        let result = QueryResult {
            result: "ok".into(),
            session_id: "s1".into(),
            cost_usd: None,
            duration_ms: None,
            num_turns: None,
            is_error: false,
            extra: Default::default(),
        };
        let session = Session::from_result(&claude, &result);

        assert_eq!(session.total_cost_usd(), 0.0);
        assert_eq!(session.total_turns(), 0);
    }

    #[test]
    fn session_query_sets_resume_flag() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up");

        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--resume"), Some("sess-123"));
    }

    #[test]
    fn session_from_result_query_resumes_result_id() {
        let claude = test_claude();
        let result = test_result("sess-abc", 0.01, 1);
        let mut session = Session::from_result(&claude, &result);
        let sq = session.query("follow up");

        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--resume"), Some("sess-abc"));
    }

    #[test]
    fn session_query_model_delegation() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up").model("sonnet");

        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--model"), Some("sonnet"));
        // Delegated options must not drop the pre-set resume flag.
        assert_eq!(flag_value(&args, "--resume"), Some("sess-123"));
    }

    #[test]
    fn session_query_effort_delegation() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up").effort(Effort::High);

        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--effort"), Some("high"));
    }

    #[test]
    fn session_query_max_turns_delegation() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("follow up").max_turns(10);

        let args = sq.command.args();
        assert_eq!(flag_value(&args, "--max-turns"), Some("10"));
    }

    #[test]
    fn session_query_prompt_is_last_arg() {
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("my prompt");

        let args = sq.command.args();
        assert_eq!(args.last().unwrap(), "my prompt");
    }

    #[test]
    fn session_query_does_not_have_continue_or_fork() {
        // The type-level guarantee (SessionQuery exposes no `.continue_session()`,
        // `.session_id()`, `.fork_session()` or `.resume()`) is enforced by the
        // compiler, not by this test. This runtime test checks the complementary
        // property: the command built by `Session::query` carries only `--resume`
        // and none of the other session flags.
        let claude = test_claude();
        let mut session = Session::from_id(&claude, "sess-123");
        let sq = session.query("test");

        let args = sq.command.args();
        assert!(!args.contains(&"--continue".to_string()));
        assert!(!args.contains(&"--fork-session".to_string()));
        assert!(!args.contains(&"--session-id".to_string()));
    }
}