claude_wrapper/
session.rs1use std::sync::Arc;
70
71use crate::Claude;
72use crate::budget::BudgetTracker;
73use crate::command::query::QueryCommand;
74use crate::error::Result;
75use crate::types::QueryResult;
76
77#[cfg(feature = "json")]
78use crate::streaming::{StreamEvent, stream_query};
79
80#[derive(Debug, Clone)]
86pub struct Session {
87 claude: Arc<Claude>,
88 session_id: Option<String>,
89 history: Vec<QueryResult>,
90 cumulative_cost_usd: f64,
91 cumulative_turns: u32,
92 budget: Option<BudgetTracker>,
93}
94
95impl Session {
96 pub fn new(claude: Arc<Claude>) -> Self {
99 Self {
100 claude,
101 session_id: None,
102 history: Vec::new(),
103 cumulative_cost_usd: 0.0,
104 cumulative_turns: 0,
105 budget: None,
106 }
107 }
108
109 pub fn resume(claude: Arc<Claude>, session_id: impl Into<String>) -> Self {
113 Self {
114 claude,
115 session_id: Some(session_id.into()),
116 history: Vec::new(),
117 cumulative_cost_usd: 0.0,
118 cumulative_turns: 0,
119 budget: None,
120 }
121 }
122
123 pub fn with_budget(mut self, budget: BudgetTracker) -> Self {
132 self.budget = Some(budget);
133 self
134 }
135
136 pub fn budget(&self) -> Option<&BudgetTracker> {
138 self.budget.as_ref()
139 }
140
141 #[cfg(feature = "json")]
144 pub async fn send(&mut self, prompt: impl Into<String>) -> Result<QueryResult> {
145 self.execute(QueryCommand::new(prompt)).await
146 }
147
148 #[cfg(feature = "json")]
154 pub async fn execute(&mut self, cmd: QueryCommand) -> Result<QueryResult> {
155 if let Some(b) = &self.budget {
156 b.check()?;
157 }
158
159 let cmd = match &self.session_id {
160 Some(id) => cmd.replace_session(id),
161 None => cmd,
162 };
163
164 let result = cmd.execute_json(&self.claude).await?;
165 self.record(&result);
166 Ok(result)
167 }
168
169 #[cfg(feature = "json")]
174 pub async fn stream<F>(&mut self, prompt: impl Into<String>, handler: F) -> Result<()>
175 where
176 F: FnMut(StreamEvent),
177 {
178 self.stream_execute(QueryCommand::new(prompt), handler)
179 .await
180 }
181
182 #[cfg(feature = "json")]
188 pub async fn stream_execute<F>(&mut self, cmd: QueryCommand, mut handler: F) -> Result<()>
189 where
190 F: FnMut(StreamEvent),
191 {
192 use crate::types::OutputFormat;
193
194 if let Some(b) = &self.budget {
195 b.check()?;
196 }
197
198 let cmd = match &self.session_id {
199 Some(id) => cmd.replace_session(id),
200 None => cmd,
201 }
202 .output_format(OutputFormat::StreamJson);
203
204 let mut captured_session_id: Option<String> = None;
209 let mut captured_result: Option<QueryResult> = None;
210
211 let outcome = {
212 let wrap = |event: StreamEvent| {
213 if captured_session_id.is_none()
214 && let Some(sid) = event.session_id()
215 {
216 captured_session_id = Some(sid.to_string());
217 }
218 if event.is_result()
219 && captured_result.is_none()
220 && let Ok(qr) = serde_json::from_value::<QueryResult>(event.data.clone())
221 {
222 captured_result = Some(qr);
223 }
224 handler(event);
225 };
226 stream_query(&self.claude, &cmd, wrap).await
227 };
228
229 if let Some(sid) = captured_session_id {
230 self.session_id = Some(sid);
231 }
232 if let Some(qr) = captured_result {
233 self.record(&qr);
234 }
235
236 outcome.map(|_| ())
237 }
238
239 pub fn id(&self) -> Option<&str> {
241 self.session_id.as_deref()
242 }
243
244 pub fn total_cost_usd(&self) -> f64 {
246 self.cumulative_cost_usd
247 }
248
249 pub fn total_turns(&self) -> u32 {
251 self.cumulative_turns
252 }
253
254 pub fn history(&self) -> &[QueryResult] {
256 &self.history
257 }
258
259 pub fn last_result(&self) -> Option<&QueryResult> {
261 self.history.last()
262 }
263
264 fn record(&mut self, result: &QueryResult) {
265 self.session_id = Some(result.session_id.clone());
266 let cost = result.cost_usd.unwrap_or(0.0);
267 self.cumulative_cost_usd += cost;
268 self.cumulative_turns += result.num_turns.unwrap_or(0);
269 if let Some(b) = &self.budget {
270 b.record(cost);
271 }
272 self.history.push(result.clone());
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client pointing at a fixed binary path; no process is spawned here.
    fn test_claude() -> Arc<Claude> {
        let claude = Claude::builder()
            .binary("/usr/local/bin/claude")
            .build()
            .unwrap();
        Arc::new(claude)
    }

    /// Builds a successful result for session `sess-1` with the given text, cost and turns.
    fn turn(text: &str, cost_usd: f64, num_turns: u32) -> QueryResult {
        QueryResult {
            result: text.into(),
            session_id: "sess-1".into(),
            cost_usd: Some(cost_usd),
            duration_ms: None,
            num_turns: Some(num_turns),
            is_error: false,
            extra: Default::default(),
        }
    }

    /// A brand-new session knows no id and has empty totals and history.
    #[test]
    fn new_session_has_no_id() {
        let fresh = Session::new(test_claude());

        assert!(fresh.id().is_none());
        assert_eq!(fresh.total_cost_usd(), 0.0);
        assert_eq!(fresh.total_turns(), 0);
        assert!(fresh.history().is_empty());
        assert!(fresh.last_result().is_none());
    }

    /// A resumed session exposes the preset id but starts with zero totals.
    #[test]
    fn resume_session_has_preset_id() {
        let resumed = Session::resume(test_claude(), "sess-abc");

        assert_eq!(resumed.id(), Some("sess-abc"));
        assert_eq!(resumed.total_cost_usd(), 0.0);
        assert_eq!(resumed.total_turns(), 0);
    }

    /// Recording one result sets the id, the totals and the history.
    #[test]
    fn record_updates_state() {
        let mut session = Session::new(test_claude());

        session.record(&turn("ok", 0.05, 3));

        assert_eq!(session.id(), Some("sess-1"));
        assert!((session.total_cost_usd() - 0.05).abs() < f64::EPSILON);
        assert_eq!(session.total_turns(), 3);
        assert_eq!(session.history().len(), 1);

        let last_id = session.last_result().map(|r| r.session_id.as_str());
        assert_eq!(last_id, Some("sess-1"));
    }

    /// Totals and history grow with every recorded result.
    #[test]
    fn record_accumulates_across_turns() {
        let mut session = Session::new(test_claude());

        for result in [turn("a", 0.01, 2), turn("b", 0.02, 1)] {
            session.record(&result);
        }

        assert_eq!(session.total_turns(), 3);
        assert!((session.total_cost_usd() - 0.03).abs() < f64::EPSILON);
        assert_eq!(session.history().len(), 2);
    }

    /// The cost of a recorded result is charged to the attached budget tracker.
    #[test]
    fn record_forwards_cost_to_budget() {
        use crate::budget::BudgetTracker;

        let budget = BudgetTracker::builder().build();
        let mut session = Session::new(test_claude()).with_budget(budget.clone());

        session.record(&turn("ok", 0.07, 1));

        assert!((budget.total_usd() - 0.07).abs() < 1e-9);
        assert!((session.total_cost_usd() - 0.07).abs() < 1e-9);
    }

    /// A tracker that is already over its limit fails the check a turn would run first.
    #[test]
    fn budget_pre_check_would_block_next_turn() {
        use crate::budget::BudgetTracker;
        use crate::error::Error;

        let budget = BudgetTracker::builder().max_usd(0.10).build();
        // Push the tracker past its limit before the session ever runs a turn.
        budget.record(0.15);

        let session = Session::new(test_claude()).with_budget(budget);
        let verdict = session.budget().unwrap().check();

        match verdict {
            Err(Error::BudgetExceeded { total_usd, max_usd }) => {
                assert!((total_usd - 0.15).abs() < 1e-9);
                assert!((max_usd - 0.10).abs() < 1e-9);
            }
            other => panic!("expected BudgetExceeded, got {other:?}"),
        }
    }

    /// `replace_session` drops the other session flags and resumes the new id.
    #[test]
    fn replace_session_clears_conflicting_flags() {
        use crate::command::ClaudeCommand;

        let args = QueryCommand::new("hi")
            .continue_session()
            .session_id("old")
            .fork_session()
            .replace_session("new-id")
            .args();

        for expected in ["--resume", "new-id"] {
            assert!(args.contains(&expected.to_string()));
        }
        for cleared in ["--continue", "--session-id", "--fork-session"] {
            assert!(!args.contains(&cleared.to_string()));
        }
    }
}