Skip to main content

claude_code/
conversation.rs

1use std::sync::{Arc, Mutex};
2
3use crate::client::{ClaudeClient, CommandRunner, DefaultRunner};
4use crate::config::{ClaudeConfig, ClaudeConfigBuilder};
5use crate::error::ClaudeError;
6use crate::types::ClaudeResponse;
7
8#[cfg(feature = "stream")]
9use crate::stream::StreamEvent;
10#[cfg(feature = "stream")]
11use std::pin::Pin;
12#[cfg(feature = "stream")]
13use tokio_stream::Stream;
14
15/// Stateful multi-turn conversation wrapper around [`ClaudeClient`].
16///
17/// Manages `session_id` automatically across turns using `--resume`.
18/// The base config is cloned per turn; each turn builds a temporary
19/// config with `--resume <session_id>` injected.
20///
21/// # Design decisions
22///
23/// **Ownership model:** Owns cloned copies of [`ClaudeConfig`] and the runner
24/// instead of borrowing `&ClaudeClient`. `ClaudeClient` is stateless (config +
25/// runner only, no connection pool), so cloning is cheap and avoids lifetime
26/// parameters that complicate async usage (spawn, struct storage).
27///
28/// **session_id storage:** Uses `Arc<Mutex<Option<String>>>` so that the
29/// streaming path can update the session ID while the caller consumes the
30/// returned `Stream` (which outlives the `&mut self` borrow).
31///
32/// # Note
33///
34/// Callers must set [`ClaudeConfigBuilder::no_session_persistence`]`(false)` in
35/// the config for multi-turn to work. The library does not override this; option
36/// validation is the CLI's responsibility.
37#[derive(Debug)]
38pub struct Conversation<R: CommandRunner = DefaultRunner> {
39    config: ClaudeConfig,
40    runner: R,
41    session_id: Arc<Mutex<Option<String>>>,
42}
43
44impl<R: CommandRunner> Conversation<R> {
45    /// Returns the current session ID, or `None` if no turn has completed.
46    #[must_use]
47    pub fn session_id(&self) -> Option<String> {
48        self.session_id.lock().unwrap().clone()
49    }
50}
51
52impl<R: CommandRunner + Clone> Conversation<R> {
53    /// Creates a new conversation (internal; use [`ClaudeClient::conversation`]).
54    pub(crate) fn with_runner(config: ClaudeConfig, runner: R) -> Self {
55        Self {
56            config,
57            runner,
58            session_id: Arc::new(Mutex::new(None)),
59        }
60    }
61
62    /// Creates a conversation resuming an existing session (internal;
63    /// use [`ClaudeClient::conversation_resume`]).
64    pub(crate) fn with_runner_resume(
65        config: ClaudeConfig,
66        runner: R,
67        session_id: impl Into<String>,
68    ) -> Self {
69        Self {
70            config,
71            runner,
72            session_id: Arc::new(Mutex::new(Some(session_id.into()))),
73        }
74    }
75
76    /// Sends a prompt and returns the response.
77    ///
78    /// Shorthand for `ask_with(prompt, |b| b)`.
79    pub async fn ask(&mut self, prompt: &str) -> Result<ClaudeResponse, ClaudeError> {
80        self.ask_with(prompt, |b| b).await
81    }
82
83    /// Sends a prompt with per-turn config overrides and returns the response.
84    ///
85    /// The closure receives a [`ClaudeConfigBuilder`] pre-filled with the base
86    /// config. Overrides apply to this turn only; the base config is unchanged.
87    pub async fn ask_with<F>(
88        &mut self,
89        prompt: &str,
90        config_fn: F,
91    ) -> Result<ClaudeResponse, ClaudeError>
92    where
93        F: FnOnce(ClaudeConfigBuilder) -> ClaudeConfigBuilder,
94    {
95        let builder = config_fn(self.config.to_builder());
96        let mut config = builder.build();
97
98        if let Some(ref id) = *self.session_id.lock().unwrap() {
99            config.resume = Some(id.clone());
100        }
101
102        let client = ClaudeClient::with_runner(config, self.runner.clone());
103        let response = client.ask(prompt).await?;
104
105        *self.session_id.lock().unwrap() = Some(response.session_id.clone());
106
107        Ok(response)
108    }
109}
110
#[cfg(feature = "stream")]
/// Forwards every item from `inner` unchanged while recording the session ID it
/// reports.
///
/// The shared `session_id` is overwritten when a [`StreamEvent::SystemInit`]
/// arrives, and again when the final [`StreamEvent::Result`] arrives. Errors
/// and all other events pass through without touching it. The update happens
/// before the corresponding item is handed to the consumer.
fn wrap_stream(
    inner: Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>,
    session_id: Arc<Mutex<Option<String>>>,
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>> {
    let forwarded = async_stream::stream! {
        for await item in inner {
            let reported = match &item {
                Ok(StreamEvent::SystemInit { session_id: sid, .. }) => Some(sid.clone()),
                Ok(StreamEvent::Result(response)) => Some(response.session_id.clone()),
                _ => None,
            };
            if let Some(sid) = reported {
                *session_id.lock().unwrap() = Some(sid);
            }
            yield item;
        }
    };
    Box::pin(forwarded)
}
136
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
impl Conversation {
    /// Streams the next turn for `prompt` using the base config unchanged.
    ///
    /// Equivalent to `ask_stream_with(prompt, |b| b)`.
    ///
    /// Streaming is only offered on `Conversation<DefaultRunner>` (conversations
    /// obtained from a client built with [`ClaudeClient::new`]). The
    /// [`CommandRunner`] trait's [`run`](CommandRunner::run) method only yields a
    /// finished [`std::process::Output`], which cannot be streamed. A streaming
    /// turn therefore always launches a real CLI subprocess.
    ///
    /// **Note:** The base config's timeout is **not** applied to the returned
    /// stream. Wrap it with [`tokio_stream::StreamExt::timeout()`] if needed.
    ///
    /// # Errors
    ///
    /// Propagates any [`ClaudeError`] returned by [`ClaudeClient::ask_stream`].
    pub async fn ask_stream(
        &mut self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    {
        self.ask_stream_with(prompt, |builder| builder).await
    }

    /// Streams the next turn for `prompt`, letting `config_fn` adjust the config
    /// for this turn only.
    ///
    /// `config_fn` receives a [`ClaudeConfigBuilder`] seeded from the base
    /// config; the stored base config is never modified. If a session ID is
    /// known, it is written into this turn's `resume` field after `config_fn`
    /// runs.
    ///
    /// Events are forwarded unchanged. While the caller consumes the stream, the
    /// conversation's session ID is set from [`StreamEvent::SystemInit`] and
    /// refreshed from [`StreamEvent::Result`]. It is therefore available to the
    /// next turn even though the stream outlives this `&mut self` borrow.
    ///
    /// Only available for `Conversation<DefaultRunner>`; see
    /// [`ask_stream`](Self::ask_stream) for why.
    ///
    /// # Errors
    ///
    /// Propagates any [`ClaudeError`] returned by [`ClaudeClient::ask_stream`].
    pub async fn ask_stream_with<F>(
        &mut self,
        prompt: &str,
        config_fn: F,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    where
        F: FnOnce(ClaudeConfigBuilder) -> ClaudeConfigBuilder,
    {
        let mut turn_config = config_fn(self.config.to_builder()).build();

        // Continue the existing session, if there is one.
        let known_session = self.session_id.lock().unwrap().clone();
        if let Some(resume_id) = known_session {
            turn_config.resume = Some(resume_id);
        }

        let streaming_client = ClaudeClient::new(turn_config);
        let events = streaming_client.ask_stream(prompt).await?;

        // The wrapper holds its own handle so it can record the session ID
        // after this borrow of `self` ends.
        let tracker = Arc::clone(&self.session_id);
        Ok(wrap_stream(events, tracker))
    }
}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;
    use std::io;
    use std::process::{ExitStatus, Output};

    #[cfg(feature = "stream")]
    use crate::stream::StreamEvent;
    #[cfg(feature = "stream")]
    use crate::types::Usage;

    /// A [`CommandRunner`] that records arguments and returns pre-configured
    /// responses. Clone-compatible (unlike mockall mocks), which is required
    /// for `Conversation` since it clones the runner for each turn.
    #[derive(Clone)]
    struct RecordingRunner {
        responses: Arc<Mutex<VecDeque<io::Result<Output>>>>,
        captured_args: Arc<Mutex<Vec<Vec<String>>>>,
    }

    impl RecordingRunner {
        fn new(responses: Vec<io::Result<Output>>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(VecDeque::from(responses))),
                captured_args: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn captured_args(&self) -> Vec<Vec<String>> {
            self.captured_args.lock().unwrap().clone()
        }
    }

    impl CommandRunner for RecordingRunner {
        async fn run(&self, args: &[String]) -> io::Result<Output> {
            self.captured_args.lock().unwrap().push(args.to_vec());
            self.responses
                .lock()
                .unwrap()
                .pop_front()
                .expect("RecordingRunner: no more responses")
        }
    }

    /// Builds an [`ExitStatus`] whose `code()` is `code` (Unix).
    ///
    /// A Unix raw wait status stores the exit code in bits 8..16, hence the shift
    /// (`from_raw(256)` is exit code 1).
    #[cfg(unix)]
    fn exit_status(code: u8) -> ExitStatus {
        use std::os::unix::process::ExitStatusExt;
        ExitStatus::from_raw(i32::from(code) << 8)
    }

    /// Builds an [`ExitStatus`] whose `code()` is `code` (Windows).
    ///
    /// On Windows the raw value is the exit code itself.
    #[cfg(windows)]
    fn exit_status(code: u8) -> ExitStatus {
        use std::os::windows::process::ExitStatusExt;
        ExitStatus::from_raw(u32::from(code))
    }

    /// Canned successful CLI output: a JSON `result` message carrying `session_id`.
    fn make_success_output(session_id: &str) -> io::Result<Output> {
        let json = format!(
            r#"{{"type":"result","subtype":"success","is_error":false,"duration_ms":100,"duration_api_ms":90,"num_turns":1,"result":"Hello!","stop_reason":"end_turn","session_id":"{session_id}","total_cost_usd":0.001,"usage":{{"input_tokens":10,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":5,"server_tool_use":{{"web_search_requests":0,"web_fetch_requests":0}}}}}}"#
        );
        Ok(Output {
            status: exit_status(0),
            stdout: json.into_bytes(),
            stderr: Vec::new(),
        })
    }

    #[tokio::test]
    async fn session_id_initially_none() {
        let runner = RecordingRunner::new(vec![]);
        let conv = Conversation::with_runner(ClaudeConfig::default(), runner);
        assert!(conv.session_id().is_none());
    }

    #[tokio::test]
    async fn ask_captures_session_id() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner);

        let resp = conv.ask("hello").await.unwrap();
        assert_eq!(resp.session_id, "sid-001");
        assert_eq!(conv.session_id(), Some("sid-001".to_string()));
    }

    #[tokio::test]
    async fn second_turn_sends_resume() {
        let runner = RecordingRunner::new(vec![
            make_success_output("sid-001"),
            make_success_output("sid-001"),
        ]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner.clone());

        conv.ask("turn 1").await.unwrap();
        conv.ask("turn 2").await.unwrap();

        let args = runner.captured_args();
        // Turn 1: no --resume
        assert!(!args[0].contains(&"--resume".to_string()));
        // Turn 2: --resume sid-001
        let idx = args[1].iter().position(|a| a == "--resume").unwrap();
        assert_eq!(args[1][idx + 1], "sid-001");
    }

    #[tokio::test]
    async fn ask_with_overrides_config() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner.clone());

        conv.ask_with("hello", |b| b.max_turns(5)).await.unwrap();

        let args = &runner.captured_args()[0];
        let idx = args.iter().position(|a| a == "--max-turns").unwrap();
        assert_eq!(args[idx + 1], "5");
    }

    #[tokio::test]
    async fn ask_with_does_not_affect_base_config() {
        let runner = RecordingRunner::new(vec![
            make_success_output("sid-001"),
            make_success_output("sid-001"),
        ]);
        let config = ClaudeConfig::builder().max_turns(1).build();
        let mut conv = Conversation::with_runner(config, runner.clone());

        conv.ask_with("turn 1", |b| b.max_turns(5)).await.unwrap();
        conv.ask("turn 2").await.unwrap();

        let args = runner.captured_args();
        let idx1 = args[0].iter().position(|a| a == "--max-turns").unwrap();
        assert_eq!(args[0][idx1 + 1], "5");
        let idx2 = args[1].iter().position(|a| a == "--max-turns").unwrap();
        assert_eq!(args[1][idx2 + 1], "1");
    }

    #[tokio::test]
    async fn error_preserves_session_id() {
        let error_output: io::Result<Output> = Ok(Output {
            status: exit_status(1),
            stdout: Vec::new(),
            stderr: b"error".to_vec(),
        });
        let runner = RecordingRunner::new(vec![make_success_output("sid-001"), error_output]);
        let mut conv = Conversation::with_runner(ClaudeConfig::default(), runner);

        conv.ask("turn 1").await.unwrap();
        assert_eq!(conv.session_id(), Some("sid-001".to_string()));

        let _ = conv.ask("turn 2").await;
        assert_eq!(conv.session_id(), Some("sid-001".to_string()));
    }

    #[tokio::test]
    async fn conversation_resume_sends_resume_on_first_turn() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let mut conv = Conversation::with_runner_resume(
            ClaudeConfig::default(),
            runner.clone(),
            "existing-sid",
        );

        conv.ask("hello").await.unwrap();

        let args = &runner.captured_args()[0];
        let idx = args.iter().position(|a| a == "--resume").unwrap();
        assert_eq!(args[idx + 1], "existing-sid");
    }

    #[cfg(feature = "stream")]
    #[tokio::test]
    async fn wrap_stream_captures_session_id_from_system_init() {
        let session_id: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let events: Vec<Result<StreamEvent, ClaudeError>> = vec![
            Ok(StreamEvent::SystemInit {
                session_id: "sid-stream-001".into(),
                model: "haiku".into(),
            }),
            Ok(StreamEvent::AssistantText("Hello!".into())),
        ];
        let inner: Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>> =
            Box::pin(tokio_stream::iter(events));

        let wrapped = wrap_stream(inner, Arc::clone(&session_id));
        tokio::pin!(wrapped);

        let mut count = 0;
        while (tokio_stream::StreamExt::next(&mut wrapped).await).is_some() {
            count += 1;
        }

        assert_eq!(
            *session_id.lock().unwrap(),
            Some("sid-stream-001".to_string())
        );
        assert_eq!(count, 2);
    }

    #[cfg(feature = "stream")]
    #[tokio::test]
    async fn wrap_stream_updates_session_id_from_result() {
        let session_id: Arc<Mutex<Option<String>>> =
            Arc::new(Mutex::new(Some("old-sid".to_string())));
        let response = ClaudeResponse {
            result: "Hello!".into(),
            is_error: false,
            duration_ms: 100,
            num_turns: 1,
            session_id: "new-sid".into(),
            total_cost_usd: 0.001,
            stop_reason: "end_turn".into(),
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        };
        let events: Vec<Result<StreamEvent, ClaudeError>> = vec![
            Ok(StreamEvent::SystemInit {
                session_id: "old-sid".into(),
                model: "haiku".into(),
            }),
            Ok(StreamEvent::Result(response)),
        ];
        let inner: Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>> =
            Box::pin(tokio_stream::iter(events));

        let wrapped = wrap_stream(inner, Arc::clone(&session_id));
        tokio::pin!(wrapped);
        while (tokio_stream::StreamExt::next(&mut wrapped).await).is_some() {}

        assert_eq!(*session_id.lock().unwrap(), Some("new-sid".to_string()));
    }

    #[tokio::test]
    async fn client_conversation_creates_working_conversation() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let config = ClaudeConfig::builder().model("haiku").build();
        let client = ClaudeClient::with_runner(config, runner);

        let mut conv = client.conversation();
        let resp = conv.ask("hello").await.unwrap();
        assert_eq!(resp.session_id, "sid-001");
    }

    #[tokio::test]
    async fn client_conversation_resume_sends_resume() {
        let runner = RecordingRunner::new(vec![make_success_output("sid-001")]);
        let client = ClaudeClient::with_runner(ClaudeConfig::default(), runner.clone());

        let mut conv = client.conversation_resume("existing-sid");
        conv.ask("hello").await.unwrap();

        let args = &runner.captured_args()[0];
        let idx = args.iter().position(|a| a == "--resume").unwrap();
        assert_eq!(args[idx + 1], "existing-sid");
    }
}