1use std::sync::{Arc, Mutex};
2
3use crate::client::{ClaudeClient, CommandRunner, DefaultRunner};
4use crate::config::{ClaudeConfig, ClaudeConfigBuilder};
5use crate::error::ClaudeError;
6use crate::types::ClaudeResponse;
7
8#[cfg(feature = "stream")]
9use crate::stream::StreamEvent;
10#[cfg(feature = "stream")]
11use std::pin::Pin;
12#[cfg(feature = "stream")]
13use tokio_stream::Stream;
14
15#[derive(Debug)]
38pub struct Conversation<R: CommandRunner = DefaultRunner> {
39 config: ClaudeConfig,
40 runner: R,
41 session_id: Arc<Mutex<Option<String>>>,
42}
43
44impl<R: CommandRunner> Conversation<R> {
45 #[must_use]
47 pub fn session_id(&self) -> Option<String> {
48 self.session_id.lock().unwrap().clone()
49 }
50}
51
52impl<R: CommandRunner + Clone> Conversation<R> {
53 pub(crate) fn with_runner(config: ClaudeConfig, runner: R) -> Self {
55 Self {
56 config,
57 runner,
58 session_id: Arc::new(Mutex::new(None)),
59 }
60 }
61
62 pub(crate) fn with_runner_resume(
65 config: ClaudeConfig,
66 runner: R,
67 session_id: impl Into<String>,
68 ) -> Self {
69 Self {
70 config,
71 runner,
72 session_id: Arc::new(Mutex::new(Some(session_id.into()))),
73 }
74 }
75
76 pub async fn ask(&mut self, prompt: &str) -> Result<ClaudeResponse, ClaudeError> {
80 self.ask_with(prompt, |b| b).await
81 }
82
83 pub async fn ask_with<F>(
88 &mut self,
89 prompt: &str,
90 config_fn: F,
91 ) -> Result<ClaudeResponse, ClaudeError>
92 where
93 F: FnOnce(ClaudeConfigBuilder) -> ClaudeConfigBuilder,
94 {
95 let builder = config_fn(self.config.to_builder());
96 let mut config = builder.build();
97
98 if let Some(ref id) = *self.session_id.lock().unwrap() {
99 config.resume = Some(id.clone());
100 }
101
102 let client = ClaudeClient::with_runner(config, self.runner.clone());
103 let response = client.ask(prompt).await?;
104
105 *self.session_id.lock().unwrap() = Some(response.session_id.clone());
106
107 Ok(response)
108 }
109}
110
/// Wraps `inner` so that session IDs seen in passing events are recorded in `session_id`.
///
/// The ID is taken from `SystemInit` events and from `Result` events. Every item,
/// including errors, is forwarded unchanged and in its original order. Updates
/// happen only as the returned stream is polled.
#[cfg(feature = "stream")]
fn wrap_stream(
    mut inner: Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>,
    session_id: Arc<Mutex<Option<String>>>,
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>> {
    Box::pin(async_stream::stream! {
        // `inner` is a `Pin<Box<_>>`, which is `Unpin`, so it can be polled via
        // `&mut` directly without re-pinning.
        while let Some(item) = tokio_stream::StreamExt::next(&mut inner).await {
            let observed = match &item {
                Ok(StreamEvent::SystemInit { session_id: sid, .. }) => Some(sid.clone()),
                Ok(StreamEvent::Result(response)) => Some(response.session_id.clone()),
                _ => None,
            };
            if let Some(sid) = observed {
                *session_id.lock().unwrap() = Some(sid);
            }
            yield item;
        }
    })
}
136
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
impl Conversation {
    /// Sends `prompt` with the base configuration and returns a stream of events.
    ///
    /// Equivalent to [`ask_stream_with`](Self::ask_stream_with) with an identity closure.
    ///
    /// # Errors
    ///
    /// Returns any error raised while starting the stream. Errors that occur
    /// afterwards are delivered as `Err` items of the stream.
    pub async fn ask_stream(
        &mut self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    {
        self.ask_stream_with(prompt, |builder| builder).await
    }

    /// Sends `prompt` with a per-turn configuration override and returns a stream of events.
    ///
    /// `config_fn` receives a builder seeded from the base configuration. If a session
    /// ID is known, it is written into the built config's `resume` field. The
    /// conversation's stored session ID is updated from `SystemInit`/`Result` events
    /// only as the returned stream is consumed.
    ///
    /// # Errors
    ///
    /// Returns any error from [`ClaudeClient::ask_stream`] while starting the stream.
    pub async fn ask_stream_with<F>(
        &mut self,
        prompt: &str,
        config_fn: F,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    where
        F: FnOnce(ClaudeConfigBuilder) -> ClaudeConfigBuilder,
    {
        let mut config = config_fn(self.config.to_builder()).build();
        if let Some(previous) = self.session_id() {
            config.resume = Some(previous);
        }

        let client = ClaudeClient::new(config);
        let events = client.ask_stream(prompt).await?;

        // Hand the stream a shared handle so it can record the session ID later.
        let shared = Arc::clone(&self.session_id);
        Ok(wrap_stream(events, shared))
    }
}
192
#[cfg(all(test, any(unix, windows)))]
mod tests {
    //! Unit tests for `Conversation`, driven by a scripted `CommandRunner`.
    //!
    //! Gated on Unix/Windows because fabricating an `ExitStatus` requires the
    //! platform-specific `ExitStatusExt` extension trait.

    use super::*;
    use std::collections::VecDeque;
    use std::io;
    use std::process::{ExitStatus, Output};

    #[cfg(feature = "stream")]
    use crate::stream::StreamEvent;
    #[cfg(feature = "stream")]
    use crate::types::Usage;

    /// Builds an `ExitStatus` for a process that exited normally with `code`.
    #[cfg(unix)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::unix::process::ExitStatusExt;
        // `from_raw` takes a raw wait(2) status; a normal exit keeps the code in bits 8..16.
        ExitStatus::from_raw(code << 8)
    }

    /// Builds an `ExitStatus` for a process that exited with `code`.
    #[cfg(windows)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::windows::process::ExitStatusExt;
        ExitStatus::from_raw(u32::try_from(code).expect("test exit codes are non-negative"))
    }

    /// Runner that replays queued outputs and records the arguments of every call.
    #[derive(Clone)]
    struct RecordingRunner {
        responses: Arc<Mutex<VecDeque<io::Result<Output>>>>,
        captured_args: Arc<Mutex<Vec<Vec<String>>>>,
    }

    impl RecordingRunner {
        fn new(responses: Vec<io::Result<Output>>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(VecDeque::from(responses))),
                captured_args: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn captured_args(&self) -> Vec<Vec<String>> {
            self.captured_args.lock().unwrap().clone()
        }
    }

    impl CommandRunner for RecordingRunner {
        async fn run(&self, args: &[String]) -> io::Result<Output> {
            self.captured_args.lock().unwrap().push(args.to_vec());
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .expect("RecordingRunner: no more responses")
        }
    }

    /// Successful JSON result output carrying `session_id`.
    fn make_success_output(session_id: &str) -> io::Result<Output> {
        let json = format!(
            r#"{{"type":"result","subtype":"success","is_error":false,"duration_ms":100,"duration_api_ms":90,"num_turns":1,"result":"Hello!","stop_reason":"end_turn","session_id":"{session_id}","total_cost_usd":0.001,"usage":{{"input_tokens":10,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":5,"server_tool_use":{{"web_search_requests":0,"web_fetch_requests":0}}}}}}"#
        );
        Ok(Output {
            status: exit_status(0),
            stdout: json.into_bytes(),
            stderr: Vec::new(),
        })
    }

    #[tokio::test]
    async fn session_id_initially_none() {
        let runner = RecordingRunner::new(vec![]);
        let conv = Conversation::with_runner(ClaudeConfig::default(), runner);
        assert!(conv.session_id().is_none());
    }

    #[tokio::test]
    async fn ask_captures_session_id() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner);

        let resp = conv.ask("hello").await.unwrap();
        assert_eq!(resp.session_id, "sid-001");
        assert_eq!(conv.session_id(), Some("sid-001".to_string()));
    }

    #[tokio::test]
    async fn second_turn_sends_resume() {
        let runner = RecordingRunner::new(vec![
            make_success_output("sid-001"),
            make_success_output("sid-001"),
        ]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner.clone());

        conv.ask("turn 1").await.unwrap();
        conv.ask("turn 2").await.unwrap();

        let args = runner.captured_args();
        // The first turn has no session to resume.
        assert!(!args[0].contains(&"--resume".to_string()));
        // The second turn resumes the session reported by the first.
        let idx = args[1].iter().position(|a| a == "--resume").unwrap();
        assert_eq!(args[1][idx + 1], "sid-001");
    }

    #[tokio::test]
    async fn ask_with_overrides_config() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner.clone());

        conv.ask_with("hello", |b| b.max_turns(5)).await.unwrap();

        let args = &runner.captured_args()[0];
        let idx = args.iter().position(|a| a == "--max-turns").unwrap();
        assert_eq!(args[idx + 1], "5");
    }

    #[tokio::test]
    async fn ask_with_does_not_affect_base_config() {
        let runner = RecordingRunner::new(vec![
            make_success_output("sid-001"),
            make_success_output("sid-001"),
        ]);
        let config = ClaudeConfig::builder().max_turns(1).build();
        let mut conv = Conversation::with_runner(config, runner.clone());

        conv.ask_with("turn 1", |b| b.max_turns(5)).await.unwrap();
        conv.ask("turn 2").await.unwrap();

        let args = runner.captured_args();
        let idx1 = args[0].iter().position(|a| a == "--max-turns").unwrap();
        assert_eq!(args[0][idx1 + 1], "5");
        let idx2 = args[1].iter().position(|a| a == "--max-turns").unwrap();
        assert_eq!(args[1][idx2 + 1], "1");
    }

    #[tokio::test]
    async fn error_preserves_session_id() {
        let error_output: io::Result<Output> = Ok(Output {
            status: exit_status(1),
            stdout: Vec::new(),
            stderr: b"error".to_vec(),
        });
        let runner = RecordingRunner::new(vec![make_success_output("sid-001"), error_output]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner);

        conv.ask("turn 1").await.unwrap();
        assert_eq!(conv.session_id(), Some("sid-001".to_string()));

        // The failing turn must actually surface an error, not just leave the ID alone.
        let second = conv.ask("turn 2").await;
        assert!(second.is_err(), "a failing turn must return Err");
        assert_eq!(conv.session_id(), Some("sid-001".to_string()));
    }

    #[tokio::test]
    async fn conversation_resume_sends_resume_on_first_turn() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let mut conv = Conversation::with_runner_resume(
            ClaudeConfig::default(),
            runner.clone(),
            "existing-sid",
        );

        conv.ask("hello").await.unwrap();

        let args = &runner.captured_args()[0];
        let idx = args.iter().position(|a| a == "--resume").unwrap();
        assert_eq!(args[idx + 1], "existing-sid");
    }

    #[cfg(feature = "stream")]
    #[tokio::test]
    async fn wrap_stream_captures_session_id_from_system_init() {
        let session_id: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let events: Vec<Result<StreamEvent, ClaudeError>> = vec![
            Ok(StreamEvent::SystemInit {
                session_id: "sid-stream-001".into(),
                model: "haiku".into(),
            }),
            Ok(StreamEvent::AssistantText("Hello!".into())),
        ];
        let inner: Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>> =
            Box::pin(tokio_stream::iter(events));

        let wrapped = wrap_stream(inner, Arc::clone(&session_id));
        tokio::pin!(wrapped);

        let mut count = 0;
        while (tokio_stream::StreamExt::next(&mut wrapped).await).is_some() {
            count += 1;
        }

        assert_eq!(
            *session_id.lock().unwrap(),
            Some("sid-stream-001".to_string())
        );
        assert_eq!(count, 2);
    }

    #[cfg(feature = "stream")]
    #[tokio::test]
    async fn wrap_stream_updates_session_id_from_result() {
        let session_id: Arc<Mutex<Option<String>>> =
            Arc::new(Mutex::new(Some("old-sid".to_string())));
        let response = ClaudeResponse {
            result: "Hello!".into(),
            is_error: false,
            duration_ms: 100,
            num_turns: 1,
            session_id: "new-sid".into(),
            total_cost_usd: 0.001,
            stop_reason: "end_turn".into(),
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        };
        let events: Vec<Result<StreamEvent, ClaudeError>> = vec![
            Ok(StreamEvent::SystemInit {
                session_id: "old-sid".into(),
                model: "haiku".into(),
            }),
            Ok(StreamEvent::Result(response)),
        ];
        let inner: Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>> =
            Box::pin(tokio_stream::iter(events));

        let wrapped = wrap_stream(inner, Arc::clone(&session_id));
        tokio::pin!(wrapped);
        while (tokio_stream::StreamExt::next(&mut wrapped).await).is_some() {}

        assert_eq!(*session_id.lock().unwrap(), Some("new-sid".to_string()));
    }

    #[tokio::test]
    async fn client_conversation_creates_working_conversation() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let config = ClaudeConfig::builder().model("haiku").build();
        let client = ClaudeClient::with_runner(config, runner);

        let mut conv = client.conversation();
        let resp = conv.ask("hello").await.unwrap();
        assert_eq!(resp.session_id, "sid-001");
    }

    #[tokio::test]
    async fn client_conversation_resume_sends_resume() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let client = ClaudeClient::with_runner(ClaudeConfig::default(), runner.clone());

        let mut conv = client.conversation_resume("existing-sid");
        conv.ask("hello").await.unwrap();

        let args = &runner.captured_args()[0];
        let idx = args.iter().position(|a| a == "--resume").unwrap();
        assert_eq!(args[idx + 1], "existing-sid");
    }
}