Skip to main content

claude_wrapper/
conversation.rs

1//! Host-side bookkeeping wrapper around [`DuplexSession`].
2//!
//! [`Conversation`] keeps the full, append-only history of [`TurnResult`]s,
//! cumulative cost, and an optional [`BudgetTracker`] hard stop on top
//! of an underlying [`DuplexSession`]. The duplex session itself
//! remains the transport; this wrapper only adds accounting that
//! `DuplexSession::send` does not provide on its own.
8//!
9//! For the equivalent shape on top of transient `--resume` subprocess
10//! turns, see [`Session`](crate::session::Session). [`Conversation`]
11//! is the duplex-flavoured peer.
12//!
13//! # When to use
14//!
15//! Reach for [`Conversation`] when you already want a
16//! [`DuplexSession`] (long-running host, mid-turn interrupts, broadcast
17//! subscribers) AND want to answer questions like:
18//!
19//! - How much have I spent on this conversation so far?
20//! - What's the full history of turns and their costs?
21//! - Stop accepting new turns once I hit $5.
22//!
23//! If you do not need bookkeeping, use [`DuplexSession`] directly. If
24//! you want accounting on short-lived per-turn subprocess calls, use
25//! [`Session`](crate::session::Session) instead.
26//!
27//! # Example
28//!
29//! ```no_run
30//! use claude_wrapper::Claude;
31//! use claude_wrapper::conversation::Conversation;
32//! use claude_wrapper::duplex::{DuplexOptions, DuplexSession};
33//!
34//! # async fn example() -> claude_wrapper::Result<()> {
35//! let claude = Claude::builder().build()?;
36//! let session = DuplexSession::spawn(
37//!     &claude,
38//!     DuplexOptions::default().model("haiku"),
39//! ).await?;
40//!
41//! let mut conv = Conversation::new(session);
42//! let _first = conv.send("hello").await?;
43//! let _second = conv.send("and again").await?;
44//!
45//! println!("turns: {}", conv.total_turns());
46//! println!("cost:  ${:.4}", conv.total_cost_usd());
47//!
48//! conv.close().await?;
49//! # Ok(())
50//! # }
51//! ```
52//!
53//! # Budget tracking
54//!
//! Attach a [`BudgetTracker`] (the same type [`Session`](crate::session::Session)
//! uses) to enforce a cumulative USD ceiling. The pre-turn check runs
//! before delegating to [`DuplexSession::send`]; once the ceiling is
//! hit, [`Conversation::send`] returns
//! [`Error::BudgetExceeded`](crate::error::Error::BudgetExceeded)
//! without touching the underlying session. Because the check happens
//! only before a turn, the turn that crosses the ceiling still completes
//! and is recorded; it is the following call that gets refused.
61//!
62//! ```no_run
63//! use claude_wrapper::{BudgetTracker, Claude};
64//! use claude_wrapper::conversation::Conversation;
65//! use claude_wrapper::duplex::{DuplexOptions, DuplexSession};
66//!
67//! # async fn example() -> claude_wrapper::Result<()> {
68//! let budget = BudgetTracker::builder().max_usd(5.00).build();
69//! let claude = Claude::builder().build()?;
70//! let session = DuplexSession::spawn(&claude, DuplexOptions::default()).await?;
71//!
72//! let mut conv = Conversation::new(session).with_budget(budget.clone());
73//! let _ = conv.send("hello").await?;
74//! println!("spent: ${:.4}", budget.total_usd());
75//! # Ok(())
76//! # }
77//! ```
78//!
79//! # Beyond bookkeeping
80//!
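//! [`Conversation::send`] is the only entry point that updates
//! history. For [`DuplexSession::subscribe`],
//! [`DuplexSession::interrupt`], and
//! [`DuplexSession::respond_to_permission`], use [`Conversation::session`]
//! to reach the inner handle. Those calls bypass the wrapper's
//! accounting on purpose: an interrupt still produces a `TurnResult`
//! that the in-flight [`Conversation::send`] records cleanly when the
//! truncated turn lands. Note that `send` holds `&mut self` until the
//! turn completes, so `Conversation::session` cannot be called on the
//! same `Conversation` while a `send` is pending. A mid-turn interrupt
//! therefore has to come through some other handle to the session.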
89
90use crate::budget::BudgetTracker;
91use crate::duplex::{DuplexSession, TurnResult};
92use crate::error::Result;
93
94/// Host-side bookkeeping over a [`DuplexSession`].
95///
96/// See the [module docs](crate::conversation) for the full design.
97#[derive(Debug)]
98pub struct Conversation {
99    inner: DuplexSession,
100    state: ConversationState,
101}
102
103/// Pure accounting state extracted so unit tests can exercise the
104/// bookkeeping logic without spawning a duplex child.
105#[derive(Debug, Default)]
106struct ConversationState {
107    history: Vec<TurnResult>,
108    cumulative_cost_usd: f64,
109    cumulative_turns: u32,
110    budget: Option<BudgetTracker>,
111}
112
113impl ConversationState {
114    fn record(&mut self, turn: TurnResult) {
115        let cost = turn.total_cost_usd().unwrap_or(0.0);
116        self.cumulative_cost_usd += cost;
117        self.cumulative_turns = self.cumulative_turns.saturating_add(1);
118        if let Some(b) = &self.budget {
119            b.record(cost);
120        }
121        self.history.push(turn);
122    }
123}
124
125impl Conversation {
126    /// Wrap a [`DuplexSession`] in a fresh [`Conversation`].
127    ///
128    /// The conversation starts with an empty history and zeroed
129    /// counters; the underlying session is not touched until the
130    /// first [`Conversation::send`].
131    #[must_use]
132    pub fn new(session: DuplexSession) -> Self {
133        Self {
134            inner: session,
135            state: ConversationState::default(),
136        }
137    }
138
139    /// Attach a [`BudgetTracker`] for cumulative-cost ceilings.
140    ///
141    /// Every turn's cost (from [`TurnResult::total_cost_usd`]) is
142    /// recorded on the tracker, and [`Conversation::send`] returns
143    /// [`Error::BudgetExceeded`](crate::error::Error::BudgetExceeded)
144    /// before dispatching a turn if the ceiling has been hit. Clone a
145    /// tracker across several conversations to enforce a shared
146    /// ceiling.
147    #[must_use]
148    pub fn with_budget(mut self, budget: BudgetTracker) -> Self {
149        self.state.budget = Some(budget);
150        self
151    }
152
153    /// The attached [`BudgetTracker`], if any.
154    #[must_use]
155    pub fn budget(&self) -> Option<&BudgetTracker> {
156        self.state.budget.as_ref()
157    }
158
159    /// Send one user message and record the resulting [`TurnResult`].
160    ///
161    /// Pre-turn: if a [`BudgetTracker`] is attached and its ceiling is
162    /// hit, returns
163    /// [`Error::BudgetExceeded`](crate::error::Error::BudgetExceeded)
164    /// without touching the underlying session.
165    ///
166    /// On success the returned reference points at the just-recorded
167    /// last entry of [`Conversation::history`]. Errors from the
168    /// underlying [`DuplexSession::send`]
169    /// (e.g. [`Error::DuplexTurnInFlight`](crate::error::Error::DuplexTurnInFlight),
170    /// [`Error::DuplexClosed`](crate::error::Error::DuplexClosed))
171    /// propagate unchanged and do not update the history or cost
172    /// counters.
173    pub async fn send(&mut self, prompt: impl Into<String>) -> Result<&TurnResult> {
174        if let Some(b) = &self.state.budget {
175            b.check()?;
176        }
177
178        let turn = self.inner.send(prompt).await?;
179        self.state.record(turn);
180        Ok(self
181            .state
182            .history
183            .last()
184            .expect("just-pushed entry must be present"))
185    }
186
187    /// Per-turn result history, in arrival order.
188    #[must_use]
189    pub fn history(&self) -> &[TurnResult] {
190        &self.state.history
191    }
192
193    /// Result of the most recent turn, if any.
194    #[must_use]
195    pub fn last(&self) -> Option<&TurnResult> {
196        self.state.history.last()
197    }
198
199    /// Cumulative cost in USD across every recorded turn.
200    #[must_use]
201    pub fn total_cost_usd(&self) -> f64 {
202        self.state.cumulative_cost_usd
203    }
204
205    /// Number of turns recorded through [`Conversation::send`].
206    #[must_use]
207    pub fn total_turns(&self) -> u32 {
208        self.state.cumulative_turns
209    }
210
211    /// Session id from the most recent turn's `result` payload, if
212    /// any. Returns `None` until the first turn lands; later turns on
213    /// a single duplex child reuse the same id.
214    #[must_use]
215    pub fn session_id(&self) -> Option<&str> {
216        self.state.history.last().and_then(TurnResult::session_id)
217    }
218
219    /// Borrow the underlying [`DuplexSession`].
220    ///
221    /// Use this for [`DuplexSession::subscribe`],
222    /// [`DuplexSession::interrupt`], and
223    /// [`DuplexSession::respond_to_permission`]. Those calls bypass
224    /// [`Conversation`]'s bookkeeping on purpose -- an interrupt still
225    /// produces a [`TurnResult`] that the in-flight
226    /// [`Conversation::send`] records cleanly when the truncated turn
227    /// lands.
228    #[must_use]
229    pub fn session(&self) -> &DuplexSession {
230        &self.inner
231    }
232
233    /// Close the underlying [`DuplexSession`] and wait for its task to
234    /// exit. Consumes the [`Conversation`].
235    pub async fn close(self) -> Result<()> {
236        self.inner.close().await
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Tolerance used when comparing accumulated USD floats.
    const EPS: f64 = 1e-9;

    /// Build a minimal `result` turn carrying a session id and a cost.
    fn turn(session_id: &str, cost: f64) -> TurnResult {
        let payload = json!({
            "type": "result",
            "result": "ok",
            "session_id": session_id,
            "total_cost_usd": cost,
        });
        TurnResult {
            result: payload,
            events: Vec::new(),
        }
    }

    /// Replay `costs`, in order, as turns on `sess-1` into a fresh state.
    fn state_with(costs: &[f64]) -> ConversationState {
        let mut state = ConversationState::default();
        for &cost in costs {
            state.record(turn("sess-1", cost));
        }
        state
    }

    /// True when `actual` is within `EPS` of `expected`.
    fn close_to(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Session id of the newest recorded turn, mirroring `Conversation::session_id`.
    fn latest_session_id(state: &ConversationState) -> Option<&str> {
        state.history.last().and_then(TurnResult::session_id)
    }

    #[test]
    fn record_pushes_turn_and_updates_counters() {
        let state = state_with(&[0.05]);

        assert_eq!(state.history.len(), 1);
        assert_eq!(state.cumulative_turns, 1);
        assert!(close_to(state.cumulative_cost_usd, 0.05));
        assert_eq!(state.history[0].session_id(), Some("sess-1"));
    }

    #[test]
    fn record_accumulates_across_turns() {
        let state = state_with(&[0.01, 0.02, 0.03]);

        assert_eq!(state.history.len(), 3);
        assert_eq!(state.cumulative_turns, 3);
        assert!(close_to(state.cumulative_cost_usd, 0.06));
    }

    #[test]
    fn record_treats_missing_cost_as_zero() {
        let no_cost = TurnResult {
            result: json!({ "type": "result", "session_id": "sess-1" }),
            events: Vec::new(),
        };
        let mut state = ConversationState::default();
        state.record(no_cost);

        assert_eq!(state.cumulative_turns, 1);
        assert_eq!(state.cumulative_cost_usd, 0.0);
    }

    #[test]
    fn record_forwards_cost_to_budget() {
        let tracker = BudgetTracker::builder().build();
        let mut state = ConversationState {
            budget: Some(tracker.clone()),
            ..Default::default()
        };

        state.record(turn("sess-1", 0.07));

        assert!(close_to(tracker.total_usd(), 0.07));
        assert!(close_to(state.cumulative_cost_usd, 0.07));
    }

    #[test]
    fn budget_check_blocks_before_send() {
        use crate::error::Error;

        // `Conversation::send` relies on `BudgetTracker::check` for its
        // pre-turn gate, so test that gate directly on an over-spent
        // tracker instead of standing up a live duplex session.
        let tracker = BudgetTracker::builder().max_usd(0.10).build();
        tracker.record(0.15);

        let outcome = tracker.check();
        match outcome {
            Err(Error::BudgetExceeded { total_usd, max_usd }) => {
                assert!(close_to(total_usd, 0.15));
                assert!(close_to(max_usd, 0.10));
            }
            unexpected => panic!("expected BudgetExceeded, got {unexpected:?}"),
        }
    }

    #[test]
    fn last_returns_most_recent() {
        let empty = ConversationState::default();
        assert!(empty.history.last().is_none());

        let state = state_with(&[0.01, 0.02]);
        let newest = state.history.last().expect("last entry present");
        assert!(close_to(newest.total_cost_usd().unwrap(), 0.02));
    }

    #[test]
    fn session_id_pulled_from_last_turn() {
        let mut state = ConversationState::default();
        // Nothing recorded yet, so there is no session id to report.
        assert!(latest_session_id(&state).is_none());

        state.record(turn("sess-A", 0.01));
        state.record(turn("sess-B", 0.02));
        assert_eq!(latest_session_id(&state), Some("sess-B"));
    }
}