Skip to main content

claude_code_sdk_rust/
client.rs

1//! Main client for interactive sessions with Claude CLI.
2
3use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12    initialize_request, initialize_timeout_duration, respond_to_control_request,
13    send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14    ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18    apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24    ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25    PermissionMode,
26};
27
28#[derive(Debug)]
29#[allow(dead_code)]
30struct ClientState {
31    messages: Vec<Message>,
32    current_stream_buffer: String,
33    is_streaming: bool,
34    server_info: Option<HashMap<String, serde_json::Value>>,
35}
36
37pub struct ClaudeAgentClient {
38    transport: Box<dyn Transport>,
39    state: Arc<RwLock<ClientState>>,
40    session_id: String,
41    connected: bool,
42    initialized: bool,
43    initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
44    control_callbacks: ControlCallbacks,
45    transcript_mirror: Option<TranscriptMirrorBatcher>,
46    source_options: Option<ClaudeAgentOptions>,
47    materialized_resume: Option<MaterializedResume>,
48}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("ClaudeAgentClient")
53            .field("session_id", &self.session_id)
54            .finish_non_exhaustive()
55    }
56}
57
58impl ClaudeAgentClient {
59    /// Fire-and-forget streaming for a single prompt.
60    ///
61    /// Spawns a background task that creates a fresh client from `options`,
62    /// connects, and streams `content`, forwarding each [`StreamEvent`] to the
63    /// returned receiver. Unlike [`Self::stream_message`], this returns
64    /// immediately and the spawned task owns the client for the lifetime of the
65    /// stream, so the caller only needs to hold the receiver. Connection or
66    /// streaming failures are surfaced as a final [`StreamEvent::Error`].
67    ///
68    /// Must be called from within a Tokio runtime.
69    pub fn spawn_stream_message(
70        options: ClaudeAgentOptions,
71        content: impl Into<String>,
72    ) -> mpsc::UnboundedReceiver<StreamEvent> {
73        let content = content.into();
74        let (tx, rx) = mpsc::unbounded_channel();
75        tokio::spawn(async move {
76            if let Err(err) = Self::run_spawned_stream(options, content, tx.clone()).await {
77                let _ = tx.send(StreamEvent::Error(err.to_string()));
78            }
79        });
80        rx
81    }
82
83    async fn run_spawned_stream(
84        options: ClaudeAgentOptions,
85        content: String,
86        tx: mpsc::UnboundedSender<StreamEvent>,
87    ) -> Result<()> {
88        let mut client = Self::new(options)?;
89        client.connect().await?;
90        client.require_connected()?;
91        let payload = client.build_user_payload(&content, None)?;
92        let json_payload = serde_json::to_vec(&payload)?;
93        client.transport.write(&json_payload).await?;
94        client.transport.write(b"\n").await?;
95        {
96            let mut state = client.state.write().await;
97            state.is_streaming = true;
98        }
99        while let Some(data) = client.transport.read().await? {
100            let line = String::from_utf8_lossy(&data);
101            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
102            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
103                respond_to_control_request(
104                    client.transport.as_mut(),
105                    &value,
106                    &client.control_callbacks,
107                )
108                .await?;
109                continue;
110            }
111            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
112                if let Some(batcher) = &mut client.transcript_mirror {
113                    for message in batcher.enqueue_value(&value).await? {
114                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
115                    }
116                }
117                continue;
118            }
119            let message = match parse_message_line(&line) {
120                Ok(Some(message)) => message,
121                Ok(None) => continue,
122                Err(err) => {
123                    // A single unrecognized message shape must not kill the
124                    // whole turn. Log it (payload included) and keep going.
125                    tracing::warn!("skipping unparseable CLI message: {err}");
126                    continue;
127                }
128            };
129            for event in stream_events_from_message(&message, &client.session_id) {
130                let _ = tx.send(event);
131            }
132            let done = matches!(message, Message::ResultMsg { .. });
133            if done {
134                if let Some(batcher) = &mut client.transcript_mirror {
135                    for message in batcher.flush().await? {
136                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
137                    }
138                }
139            }
140            {
141                let mut state = client.state.write().await;
142                state.messages.push(message);
143                if done {
144                    state.is_streaming = false;
145                }
146            }
147            if done {
148                break;
149            }
150        }
151        Ok(())
152    }
153
154    pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
155        validate_session_store_options(&options)?;
156        let transport_options = TransportOptions::from(&options);
157        let transport = SubprocessCLITransport::new(transport_options);
158        let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
159        client.source_options = Some(options);
160        Ok(client)
161    }
162
163    pub fn with_transport(
164        options: ClaudeAgentOptions,
165        transport: Box<dyn Transport>,
166    ) -> Result<Self> {
167        let session_id = options
168            .session_id
169            .clone()
170            .or_else(|| options.resume.clone())
171            .unwrap_or_else(|| "default".to_string());
172        let state = Arc::new(RwLock::new(ClientState {
173            messages: Vec::new(),
174            current_stream_buffer: String::new(),
175            is_streaming: false,
176            server_info: None,
177        }));
178        Ok(Self {
179            transport,
180            state,
181            session_id,
182            connected: false,
183            initialized: false,
184            initialization_result: None,
185            control_callbacks: ControlCallbacks::from_options(&options),
186            transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
187            source_options: None,
188            materialized_resume: None,
189        })
190    }
191
192    pub async fn connect(&mut self) -> Result<()> {
193        if !self.connected {
194            self.materialize_resume_before_connect().await?;
195            self.transport.connect().await?;
196            self.connected = true;
197        }
198        self.ensure_initialized().await?;
199        Ok(())
200    }
201
202    pub async fn connect_with_prompt(&mut self, content: impl Into<String>) -> Result<()> {
203        self.connect().await?;
204        let content = content.into();
205        let payload = self.build_user_payload(&content, None)?;
206        let mut json_payload = serde_json::to_vec(&payload)?;
207        json_payload.push(b'\n');
208        self.transport.write(&json_payload).await
209    }
210
211    pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
212    where
213        S: Stream<Item = serde_json::Value> + Unpin,
214    {
215        self.connect().await?;
216        self.write_message_stream(stream, "default").await
217    }
218
219    async fn materialize_resume_before_connect(&mut self) -> Result<()> {
220        let Some(options) = self.source_options.clone() else {
221            return Ok(());
222        };
223        let Some(materialized) = materialize_resume_session(&options).await? else {
224            return Ok(());
225        };
226        let options = apply_materialized_options(&options, &materialized);
227        self.session_id = options
228            .session_id
229            .clone()
230            .or_else(|| options.resume.clone())
231            .unwrap_or_else(|| "default".to_string());
232        self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
233            &options,
234        )));
235        self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
236        self.source_options = Some(options);
237        self.materialized_resume = Some(materialized);
238        Ok(())
239    }
240
241    fn require_connected(&self) -> Result<()> {
242        if self.connected && self.initialized {
243            Ok(())
244        } else {
245            Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
246        }
247    }
248
249    async fn ensure_initialized(&mut self) -> Result<()> {
250        if self.initialized {
251            return Ok(());
252        }
253
254        let response = send_control_request_with_callbacks_and_timeout(
255            self.transport.as_mut(),
256            initialize_request(&self.control_callbacks),
257            &self.control_callbacks,
258            initialize_timeout_duration(),
259        )
260        .await?;
261        self.initialization_result = Some(response);
262        self.initialized = true;
263        Ok(())
264    }
265
266    pub async fn send_message(&mut self, content: impl Into<String>) -> Result<MessageResponse> {
267        self.query(content).await?;
268        let messages = self.receive_response().await?;
269        let mut content_parts: Vec<String> = Vec::new();
270        let mut blocks: Vec<ContentBlock> = Vec::new();
271        let mut usage: Option<HashMap<String, serde_json::Value>> = None;
272        let mut stop_reason: Option<String> = None;
273        let mut model = String::new();
274
275        for message in messages {
276            match message {
277                Message::AssistantMsg {
278                    content: assistant_content,
279                    ..
280                } => {
281                    // Track the model from the first assistant message
282                    if model.is_empty() {
283                        model.clone_from(&assistant_content.model);
284                    }
285                    for block in &assistant_content.content {
286                        match block {
287                            ContentBlock::Text { text } => content_parts.push(text.clone()),
288                            ContentBlock::Thinking { thinking, .. } => {
289                                content_parts.push(thinking.clone())
290                            }
291                            _ => {}
292                        }
293                        blocks.push(block.clone());
294                    }
295                }
296                Message::ResultMsg {
297                    stop_reason: reason,
298                    usage: u,
299                    ..
300                } => {
301                    stop_reason = reason;
302                    if let Some(u) = u {
303                        usage = Some(u.into_iter().collect());
304                    }
305                }
306                _ => {}
307            }
308        }
309
310        Ok(MessageResponse {
311            content: content_parts.join(""),
312            blocks,
313            model,
314            stop_reason,
315            session_id: self.session_id.clone(),
316            usage,
317        })
318    }
319
320    pub async fn query(&mut self, content: impl Into<String>) -> Result<()> {
321        self.require_connected()?;
322        let content_str = content.into();
323        let payload = self.build_user_payload(&content_str, None)?;
324        let mut json_payload = serde_json::to_vec(&payload)?;
325        json_payload.push(b'\n');
326        self.transport.write(&json_payload).await
327    }
328
329    pub async fn query_with_session_id(
330        &mut self,
331        content: impl Into<String>,
332        session_id: impl Into<String>,
333    ) -> Result<()> {
334        self.require_connected()?;
335        let content_str = content.into();
336        let session_id = session_id.into();
337        let payload = self.build_user_payload(&content_str, Some(&session_id))?;
338        let mut json_payload = serde_json::to_vec(&payload)?;
339        json_payload.push(b'\n');
340        self.transport.write(&json_payload).await
341    }
342
343    pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
344    where
345        S: Stream<Item = serde_json::Value> + Unpin,
346    {
347        self.query_stream_with_session_id(stream, "default").await
348    }
349
350    pub async fn query_stream_with_session_id<S>(
351        &mut self,
352        stream: S,
353        session_id: impl Into<String>,
354    ) -> Result<()>
355    where
356        S: Stream<Item = serde_json::Value> + Unpin,
357    {
358        self.require_connected()?;
359        self.write_message_stream(stream, &session_id.into()).await
360    }
361
362    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
363        self.receive_messages_until(true).await
364    }
365
366    pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
367        self.receive_messages_until(false).await
368    }
369
370    async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
371        self.require_connected()?;
372        let mut messages = Vec::new();
373        while let Some(data) = self.transport.read().await? {
374            let line = String::from_utf8_lossy(&data);
375            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
376            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
377                respond_to_control_request(
378                    self.transport.as_mut(),
379                    &value,
380                    &self.control_callbacks,
381                )
382                .await?;
383                continue;
384            }
385            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
386                if let Some(batcher) = &mut self.transcript_mirror {
387                    messages.extend(batcher.enqueue_value(&value).await?);
388                }
389                continue;
390            }
391            let message = match parse_message_line(&line) {
392                Ok(Some(message)) => message,
393                Ok(None) => continue,
394                Err(err) => {
395                    // A single unrecognized message shape must not kill the
396                    // whole turn. Log it (payload included) and keep going.
397                    tracing::warn!("skipping unparseable CLI message: {err}");
398                    continue;
399                }
400            };
401            let done = matches!(message, Message::ResultMsg { .. });
402            if done {
403                if let Some(batcher) = &mut self.transcript_mirror {
404                    messages.extend(batcher.flush().await?);
405                }
406            }
407            {
408                let mut state = self.state.write().await;
409                state.messages.push(message.clone());
410            }
411            messages.push(message);
412            if stop_at_result && done {
413                break;
414            }
415        }
416        Ok(messages)
417    }
418
419    pub async fn stream_message(
420        &mut self,
421        content: impl Into<String>,
422    ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
423        self.require_connected()?;
424        let content_str = content.into();
425        let payload = self.build_user_payload(&content_str, None)?;
426        let json_payload = serde_json::to_vec(&payload)?;
427        self.transport.write(&json_payload).await?;
428        self.transport
429            .write(
430                b"
431",
432            )
433            .await?;
434        let (tx, rx) = mpsc::unbounded_channel();
435        {
436            let mut state = self.state.write().await;
437            state.is_streaming = true;
438        }
439        while let Some(data) = self.transport.read().await? {
440            let line = String::from_utf8_lossy(&data);
441            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
442            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
443                respond_to_control_request(
444                    self.transport.as_mut(),
445                    &value,
446                    &self.control_callbacks,
447                )
448                .await?;
449                continue;
450            }
451            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
452                if let Some(batcher) = &mut self.transcript_mirror {
453                    for message in batcher.enqueue_value(&value).await? {
454                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
455                    }
456                }
457                continue;
458            }
459            let message = match parse_message_line(&line) {
460                Ok(Some(message)) => message,
461                Ok(None) => continue,
462                Err(err) => {
463                    // A single unrecognized message shape must not kill the
464                    // whole turn. Log it (payload included) and keep going.
465                    tracing::warn!("skipping unparseable CLI message: {err}");
466                    continue;
467                }
468            };
469            for event in stream_events_from_message(&message, &self.session_id) {
470                let _ = tx.send(event);
471            }
472            let done = matches!(message, Message::ResultMsg { .. });
473            if done {
474                if let Some(batcher) = &mut self.transcript_mirror {
475                    for message in batcher.flush().await? {
476                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
477                    }
478                }
479            }
480            {
481                let mut state = self.state.write().await;
482                state.messages.push(message);
483                if done {
484                    state.is_streaming = false;
485                }
486            }
487            if done {
488                break;
489            }
490        }
491        Ok(rx)
492    }
493
494    async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
495    where
496        S: Stream<Item = serde_json::Value> + Unpin,
497    {
498        while let Some(mut message) = stream.next().await {
499            if let Some(object) = message.as_object_mut() {
500                object
501                    .entry("session_id")
502                    .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
503            }
504            let mut json_payload = serde_json::to_vec(&message)?;
505            json_payload.push(b'\n');
506            self.transport.write(&json_payload).await?;
507        }
508        Ok(())
509    }
510
511    pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
512        let state = self.state.read().await;
513        Ok(state.messages.clone())
514    }
515
516    pub async fn abort(&mut self) -> Result<()> {
517        if let Some(batcher) = &mut self.transcript_mirror {
518            let _ = batcher.flush().await?;
519        }
520        self.transport.close().await?;
521        if let Some(materialized) = &self.materialized_resume {
522            materialized.cleanup().await;
523        }
524        self.materialized_resume = None;
525        self.connected = false;
526        self.initialized = false;
527        Ok(())
528    }
529
530    pub async fn disconnect(&mut self) -> Result<()> {
531        self.abort().await
532    }
533
534    pub async fn close(mut self) -> Result<()> {
535        if let Some(batcher) = &mut self.transcript_mirror {
536            let _ = batcher.flush().await?;
537        }
538        self.transport.close().await?;
539        if let Some(materialized) = &self.materialized_resume {
540            materialized.cleanup().await;
541        }
542        Ok(())
543    }
544
545    pub async fn interrupt(&mut self) -> Result<()> {
546        self.require_connected()?;
547        send_control_request_with_callbacks(
548            self.transport.as_mut(),
549            serde_json::json!({"subtype": "interrupt"}),
550            &self.control_callbacks,
551        )
552        .await?;
553        Ok(())
554    }
555
556    pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
557        self.require_connected()?;
558        send_control_request_with_callbacks(
559            self.transport.as_mut(),
560            serde_json::json!({
561                "subtype": "set_permission_mode",
562                "mode": mode,
563            }),
564            &self.control_callbacks,
565        )
566        .await?;
567        Ok(())
568    }
569
570    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
571        self.require_connected()?;
572        let model = model.map(serde_json::Value::String);
573        send_control_request_with_callbacks(
574            self.transport.as_mut(),
575            serde_json::json!({
576                "subtype": "set_model",
577                "model": model.unwrap_or(serde_json::Value::Null),
578            }),
579            &self.control_callbacks,
580        )
581        .await?;
582        Ok(())
583    }
584
585    pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
586        self.require_connected()?;
587        send_control_request_with_callbacks(
588            self.transport.as_mut(),
589            serde_json::json!({
590                "subtype": "rewind_files",
591                "user_message_id": user_message_id.into(),
592            }),
593            &self.control_callbacks,
594        )
595        .await?;
596        Ok(())
597    }
598
599    pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
600        self.require_connected()?;
601        send_control_request_with_callbacks(
602            self.transport.as_mut(),
603            serde_json::json!({
604                "subtype": "mcp_reconnect",
605                "serverName": server_name.into(),
606            }),
607            &self.control_callbacks,
608        )
609        .await?;
610        Ok(())
611    }
612
613    pub async fn toggle_mcp_server(
614        &mut self,
615        server_name: impl Into<String>,
616        enabled: bool,
617    ) -> Result<()> {
618        self.require_connected()?;
619        send_control_request_with_callbacks(
620            self.transport.as_mut(),
621            serde_json::json!({
622                "subtype": "mcp_toggle",
623                "serverName": server_name.into(),
624                "enabled": enabled,
625            }),
626            &self.control_callbacks,
627        )
628        .await?;
629        Ok(())
630    }
631
632    pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
633        self.require_connected()?;
634        send_control_request_with_callbacks(
635            self.transport.as_mut(),
636            serde_json::json!({
637                "subtype": "stop_task",
638                "task_id": task_id.into(),
639            }),
640            &self.control_callbacks,
641        )
642        .await?;
643        Ok(())
644    }
645
646    pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
647        self.require_connected()?;
648        let response = send_control_request_with_callbacks(
649            self.transport.as_mut(),
650            serde_json::json!({"subtype": "mcp_status"}),
651            &self.control_callbacks,
652        )
653        .await?;
654        let value = serde_json::Value::Object(response);
655        Ok(serde_json::from_value(value)?)
656    }
657
658    pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
659        self.require_connected()?;
660        let response = send_control_request_with_callbacks(
661            self.transport.as_mut(),
662            serde_json::json!({"subtype": "get_context_usage"}),
663            &self.control_callbacks,
664        )
665        .await?;
666        Ok(serde_json::from_value(serde_json::Value::Object(response))?)
667    }
668
669    pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
670        self.initialization_result.as_ref()
671    }
672
673    fn build_user_payload(
674        &self,
675        content: &str,
676        session_id: Option<&str>,
677    ) -> Result<serde_json::Map<String, serde_json::Value>> {
678        let mut payload = serde_json::Map::new();
679        payload.insert(
680            "type".to_string(),
681            serde_json::Value::String("user".to_string()),
682        );
683        payload.insert(
684            "session_id".to_string(),
685            serde_json::Value::String(
686                session_id
687                    .map(String::from)
688                    .unwrap_or_else(|| self.session_id.clone()),
689            ),
690        );
691        let message = serde_json::json!({"role": "user", "content": content});
692        payload.insert("message".to_string(), message);
693        Ok(payload)
694    }
695}