Skip to main content

claude_code_sdk_rust/
client.rs

1//! Main client for interactive sessions with Claude CLI.
2
3use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12    initialize_request, initialize_timeout_duration, respond_to_control_request,
13    send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14    ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18    apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24    ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25    PermissionMode, UserMessageInput,
26};
27
28#[derive(Debug)]
29#[allow(dead_code)]
30struct ClientState {
31    messages: Vec<Message>,
32    current_stream_buffer: String,
33    is_streaming: bool,
34    server_info: Option<HashMap<String, serde_json::Value>>,
35}
36
37pub struct ClaudeAgentClient {
38    transport: Box<dyn Transport>,
39    state: Arc<RwLock<ClientState>>,
40    session_id: String,
41    connected: bool,
42    initialized: bool,
43    initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
44    control_callbacks: ControlCallbacks,
45    transcript_mirror: Option<TranscriptMirrorBatcher>,
46    source_options: Option<ClaudeAgentOptions>,
47    materialized_resume: Option<MaterializedResume>,
48}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("ClaudeAgentClient")
53            .field("session_id", &self.session_id)
54            .finish_non_exhaustive()
55    }
56}
57
58impl ClaudeAgentClient {
59    /// Fire-and-forget streaming for a single prompt.
60    ///
61    /// Spawns a background task that creates a fresh client from `options`,
62    /// connects, and streams `content`, forwarding each [`StreamEvent`] to the
63    /// returned receiver. Unlike [`Self::stream_message`], this returns
64    /// immediately and the spawned task owns the client for the lifetime of the
65    /// stream, so the caller only needs to hold the receiver. Connection or
66    /// streaming failures are surfaced as a final [`StreamEvent::Error`].
67    ///
68    /// Accepts any [`UserMessageInput`]: a plain `String`/`&str` for text-only
69    /// prompts, or a `Vec<InputContentBlock>` to deliver text plus images.
70    ///
71    /// Must be called from within a Tokio runtime.
72    pub fn spawn_stream_message(
73        options: ClaudeAgentOptions,
74        content: impl Into<UserMessageInput>,
75    ) -> mpsc::UnboundedReceiver<StreamEvent> {
76        let content = content.into();
77        let (tx, rx) = mpsc::unbounded_channel();
78        tokio::spawn(async move {
79            if let Err(err) = Self::run_spawned_stream(options, content, tx.clone()).await {
80                let _ = tx.send(StreamEvent::Error(err.to_string()));
81            }
82        });
83        rx
84    }
85
86    async fn run_spawned_stream(
87        options: ClaudeAgentOptions,
88        content: UserMessageInput,
89        tx: mpsc::UnboundedSender<StreamEvent>,
90    ) -> Result<()> {
91        let mut client = Self::new(options)?;
92        client.connect().await?;
93        client.require_connected()?;
94        let payload = client.build_user_payload(&content, None)?;
95        let json_payload = serde_json::to_vec(&payload)?;
96        client.transport.write(&json_payload).await?;
97        client.transport.write(b"\n").await?;
98        {
99            let mut state = client.state.write().await;
100            state.is_streaming = true;
101        }
102        while let Some(data) = client.transport.read().await? {
103            let line = String::from_utf8_lossy(&data);
104            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
105            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
106                respond_to_control_request(
107                    client.transport.as_mut(),
108                    &value,
109                    &client.control_callbacks,
110                )
111                .await?;
112                continue;
113            }
114            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
115                if let Some(batcher) = &mut client.transcript_mirror {
116                    for message in batcher.enqueue_value(&value).await? {
117                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
118                    }
119                }
120                continue;
121            }
122            let message = match parse_message_line(&line) {
123                Ok(Some(message)) => message,
124                Ok(None) => continue,
125                Err(err) => {
126                    // A single unrecognized message shape must not kill the
127                    // whole turn. Log it (payload included) and keep going.
128                    tracing::warn!("skipping unparseable CLI message: {err}");
129                    continue;
130                }
131            };
132            for event in stream_events_from_message(&message, &client.session_id) {
133                let _ = tx.send(event);
134            }
135            let done = matches!(message, Message::ResultMsg { .. });
136            if done {
137                if let Some(batcher) = &mut client.transcript_mirror {
138                    for message in batcher.flush().await? {
139                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
140                    }
141                }
142            }
143            {
144                let mut state = client.state.write().await;
145                state.messages.push(message);
146                if done {
147                    state.is_streaming = false;
148                }
149            }
150            if done {
151                break;
152            }
153        }
154        Ok(())
155    }
156
157    pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
158        validate_session_store_options(&options)?;
159        let transport_options = TransportOptions::from(&options);
160        let transport = SubprocessCLITransport::new(transport_options);
161        let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
162        client.source_options = Some(options);
163        Ok(client)
164    }
165
166    pub fn with_transport(
167        options: ClaudeAgentOptions,
168        transport: Box<dyn Transport>,
169    ) -> Result<Self> {
170        let session_id = options
171            .session_id
172            .clone()
173            .or_else(|| options.resume.clone())
174            .unwrap_or_else(|| "default".to_string());
175        let state = Arc::new(RwLock::new(ClientState {
176            messages: Vec::new(),
177            current_stream_buffer: String::new(),
178            is_streaming: false,
179            server_info: None,
180        }));
181        Ok(Self {
182            transport,
183            state,
184            session_id,
185            connected: false,
186            initialized: false,
187            initialization_result: None,
188            control_callbacks: ControlCallbacks::from_options(&options),
189            transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
190            source_options: None,
191            materialized_resume: None,
192        })
193    }
194
195    pub async fn connect(&mut self) -> Result<()> {
196        if !self.connected {
197            self.materialize_resume_before_connect().await?;
198            self.transport.connect().await?;
199            self.connected = true;
200        }
201        self.ensure_initialized().await?;
202        Ok(())
203    }
204
205    pub async fn connect_with_prompt(
206        &mut self,
207        content: impl Into<UserMessageInput>,
208    ) -> Result<()> {
209        self.connect().await?;
210        let content = content.into();
211        let payload = self.build_user_payload(&content, None)?;
212        let mut json_payload = serde_json::to_vec(&payload)?;
213        json_payload.push(b'\n');
214        self.transport.write(&json_payload).await
215    }
216
217    pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
218    where
219        S: Stream<Item = serde_json::Value> + Unpin,
220    {
221        self.connect().await?;
222        self.write_message_stream(stream, "default").await
223    }
224
225    async fn materialize_resume_before_connect(&mut self) -> Result<()> {
226        let Some(options) = self.source_options.clone() else {
227            return Ok(());
228        };
229        let Some(materialized) = materialize_resume_session(&options).await? else {
230            return Ok(());
231        };
232        let options = apply_materialized_options(&options, &materialized);
233        self.session_id = options
234            .session_id
235            .clone()
236            .or_else(|| options.resume.clone())
237            .unwrap_or_else(|| "default".to_string());
238        self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
239            &options,
240        )));
241        self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
242        self.source_options = Some(options);
243        self.materialized_resume = Some(materialized);
244        Ok(())
245    }
246
247    fn require_connected(&self) -> Result<()> {
248        if self.connected && self.initialized {
249            Ok(())
250        } else {
251            Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
252        }
253    }
254
255    async fn ensure_initialized(&mut self) -> Result<()> {
256        if self.initialized {
257            return Ok(());
258        }
259
260        let response = send_control_request_with_callbacks_and_timeout(
261            self.transport.as_mut(),
262            initialize_request(&self.control_callbacks),
263            &self.control_callbacks,
264            initialize_timeout_duration(),
265        )
266        .await?;
267        self.initialization_result = Some(response);
268        self.initialized = true;
269        Ok(())
270    }
271
272    pub async fn send_message(
273        &mut self,
274        content: impl Into<UserMessageInput>,
275    ) -> Result<MessageResponse> {
276        self.query(content).await?;
277        let messages = self.receive_response().await?;
278        let mut content_parts: Vec<String> = Vec::new();
279        let mut blocks: Vec<ContentBlock> = Vec::new();
280        let mut usage: Option<HashMap<String, serde_json::Value>> = None;
281        let mut stop_reason: Option<String> = None;
282        let mut model = String::new();
283
284        for message in messages {
285            match message {
286                Message::AssistantMsg {
287                    content: assistant_content,
288                    ..
289                } => {
290                    // Track the model from the first assistant message
291                    if model.is_empty() {
292                        model.clone_from(&assistant_content.model);
293                    }
294                    for block in &assistant_content.content {
295                        match block {
296                            ContentBlock::Text { text } => content_parts.push(text.clone()),
297                            ContentBlock::Thinking { thinking, .. } => {
298                                content_parts.push(thinking.clone())
299                            }
300                            _ => {}
301                        }
302                        blocks.push(block.clone());
303                    }
304                }
305                Message::ResultMsg {
306                    stop_reason: reason,
307                    usage: u,
308                    ..
309                } => {
310                    stop_reason = reason;
311                    if let Some(u) = u {
312                        usage = Some(u.into_iter().collect());
313                    }
314                }
315                _ => {}
316            }
317        }
318
319        Ok(MessageResponse {
320            content: content_parts.join(""),
321            blocks,
322            model,
323            stop_reason,
324            session_id: self.session_id.clone(),
325            usage,
326        })
327    }
328
329    pub async fn query(&mut self, content: impl Into<UserMessageInput>) -> Result<()> {
330        self.require_connected()?;
331        let content = content.into();
332        let payload = self.build_user_payload(&content, None)?;
333        let mut json_payload = serde_json::to_vec(&payload)?;
334        json_payload.push(b'\n');
335        self.transport.write(&json_payload).await
336    }
337
338    pub async fn query_with_session_id(
339        &mut self,
340        content: impl Into<UserMessageInput>,
341        session_id: impl Into<String>,
342    ) -> Result<()> {
343        self.require_connected()?;
344        let content = content.into();
345        let session_id = session_id.into();
346        let payload = self.build_user_payload(&content, Some(&session_id))?;
347        let mut json_payload = serde_json::to_vec(&payload)?;
348        json_payload.push(b'\n');
349        self.transport.write(&json_payload).await
350    }
351
352    pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
353    where
354        S: Stream<Item = serde_json::Value> + Unpin,
355    {
356        self.query_stream_with_session_id(stream, "default").await
357    }
358
359    pub async fn query_stream_with_session_id<S>(
360        &mut self,
361        stream: S,
362        session_id: impl Into<String>,
363    ) -> Result<()>
364    where
365        S: Stream<Item = serde_json::Value> + Unpin,
366    {
367        self.require_connected()?;
368        self.write_message_stream(stream, &session_id.into()).await
369    }
370
371    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
372        self.receive_messages_until(true).await
373    }
374
375    pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
376        self.receive_messages_until(false).await
377    }
378
379    async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
380        self.require_connected()?;
381        let mut messages = Vec::new();
382        while let Some(data) = self.transport.read().await? {
383            let line = String::from_utf8_lossy(&data);
384            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
385            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
386                respond_to_control_request(
387                    self.transport.as_mut(),
388                    &value,
389                    &self.control_callbacks,
390                )
391                .await?;
392                continue;
393            }
394            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
395                if let Some(batcher) = &mut self.transcript_mirror {
396                    messages.extend(batcher.enqueue_value(&value).await?);
397                }
398                continue;
399            }
400            let message = match parse_message_line(&line) {
401                Ok(Some(message)) => message,
402                Ok(None) => continue,
403                Err(err) => {
404                    // A single unrecognized message shape must not kill the
405                    // whole turn. Log it (payload included) and keep going.
406                    tracing::warn!("skipping unparseable CLI message: {err}");
407                    continue;
408                }
409            };
410            let done = matches!(message, Message::ResultMsg { .. });
411            if done {
412                if let Some(batcher) = &mut self.transcript_mirror {
413                    messages.extend(batcher.flush().await?);
414                }
415            }
416            {
417                let mut state = self.state.write().await;
418                state.messages.push(message.clone());
419            }
420            messages.push(message);
421            if stop_at_result && done {
422                break;
423            }
424        }
425        Ok(messages)
426    }
427
428    pub async fn stream_message(
429        &mut self,
430        content: impl Into<UserMessageInput>,
431    ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
432        self.require_connected()?;
433        let content = content.into();
434        let payload = self.build_user_payload(&content, None)?;
435        let json_payload = serde_json::to_vec(&payload)?;
436        self.transport.write(&json_payload).await?;
437        self.transport
438            .write(
439                b"
440",
441            )
442            .await?;
443        let (tx, rx) = mpsc::unbounded_channel();
444        {
445            let mut state = self.state.write().await;
446            state.is_streaming = true;
447        }
448        while let Some(data) = self.transport.read().await? {
449            let line = String::from_utf8_lossy(&data);
450            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
451            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
452                respond_to_control_request(
453                    self.transport.as_mut(),
454                    &value,
455                    &self.control_callbacks,
456                )
457                .await?;
458                continue;
459            }
460            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
461                if let Some(batcher) = &mut self.transcript_mirror {
462                    for message in batcher.enqueue_value(&value).await? {
463                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
464                    }
465                }
466                continue;
467            }
468            let message = match parse_message_line(&line) {
469                Ok(Some(message)) => message,
470                Ok(None) => continue,
471                Err(err) => {
472                    // A single unrecognized message shape must not kill the
473                    // whole turn. Log it (payload included) and keep going.
474                    tracing::warn!("skipping unparseable CLI message: {err}");
475                    continue;
476                }
477            };
478            for event in stream_events_from_message(&message, &self.session_id) {
479                let _ = tx.send(event);
480            }
481            let done = matches!(message, Message::ResultMsg { .. });
482            if done {
483                if let Some(batcher) = &mut self.transcript_mirror {
484                    for message in batcher.flush().await? {
485                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
486                    }
487                }
488            }
489            {
490                let mut state = self.state.write().await;
491                state.messages.push(message);
492                if done {
493                    state.is_streaming = false;
494                }
495            }
496            if done {
497                break;
498            }
499        }
500        Ok(rx)
501    }
502
503    async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
504    where
505        S: Stream<Item = serde_json::Value> + Unpin,
506    {
507        while let Some(mut message) = stream.next().await {
508            if let Some(object) = message.as_object_mut() {
509                object
510                    .entry("session_id")
511                    .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
512            }
513            let mut json_payload = serde_json::to_vec(&message)?;
514            json_payload.push(b'\n');
515            self.transport.write(&json_payload).await?;
516        }
517        Ok(())
518    }
519
520    pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
521        let state = self.state.read().await;
522        Ok(state.messages.clone())
523    }
524
525    pub async fn abort(&mut self) -> Result<()> {
526        if let Some(batcher) = &mut self.transcript_mirror {
527            let _ = batcher.flush().await?;
528        }
529        self.transport.close().await?;
530        if let Some(materialized) = &self.materialized_resume {
531            materialized.cleanup().await;
532        }
533        self.materialized_resume = None;
534        self.connected = false;
535        self.initialized = false;
536        Ok(())
537    }
538
539    pub async fn disconnect(&mut self) -> Result<()> {
540        self.abort().await
541    }
542
543    pub async fn close(mut self) -> Result<()> {
544        if let Some(batcher) = &mut self.transcript_mirror {
545            let _ = batcher.flush().await?;
546        }
547        self.transport.close().await?;
548        if let Some(materialized) = &self.materialized_resume {
549            materialized.cleanup().await;
550        }
551        Ok(())
552    }
553
554    pub async fn interrupt(&mut self) -> Result<()> {
555        self.require_connected()?;
556        send_control_request_with_callbacks(
557            self.transport.as_mut(),
558            serde_json::json!({"subtype": "interrupt"}),
559            &self.control_callbacks,
560        )
561        .await?;
562        Ok(())
563    }
564
565    pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
566        self.require_connected()?;
567        send_control_request_with_callbacks(
568            self.transport.as_mut(),
569            serde_json::json!({
570                "subtype": "set_permission_mode",
571                "mode": mode,
572            }),
573            &self.control_callbacks,
574        )
575        .await?;
576        Ok(())
577    }
578
579    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
580        self.require_connected()?;
581        let model = model.map(serde_json::Value::String);
582        send_control_request_with_callbacks(
583            self.transport.as_mut(),
584            serde_json::json!({
585                "subtype": "set_model",
586                "model": model.unwrap_or(serde_json::Value::Null),
587            }),
588            &self.control_callbacks,
589        )
590        .await?;
591        Ok(())
592    }
593
594    pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
595        self.require_connected()?;
596        send_control_request_with_callbacks(
597            self.transport.as_mut(),
598            serde_json::json!({
599                "subtype": "rewind_files",
600                "user_message_id": user_message_id.into(),
601            }),
602            &self.control_callbacks,
603        )
604        .await?;
605        Ok(())
606    }
607
608    pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
609        self.require_connected()?;
610        send_control_request_with_callbacks(
611            self.transport.as_mut(),
612            serde_json::json!({
613                "subtype": "mcp_reconnect",
614                "serverName": server_name.into(),
615            }),
616            &self.control_callbacks,
617        )
618        .await?;
619        Ok(())
620    }
621
622    pub async fn toggle_mcp_server(
623        &mut self,
624        server_name: impl Into<String>,
625        enabled: bool,
626    ) -> Result<()> {
627        self.require_connected()?;
628        send_control_request_with_callbacks(
629            self.transport.as_mut(),
630            serde_json::json!({
631                "subtype": "mcp_toggle",
632                "serverName": server_name.into(),
633                "enabled": enabled,
634            }),
635            &self.control_callbacks,
636        )
637        .await?;
638        Ok(())
639    }
640
641    pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
642        self.require_connected()?;
643        send_control_request_with_callbacks(
644            self.transport.as_mut(),
645            serde_json::json!({
646                "subtype": "stop_task",
647                "task_id": task_id.into(),
648            }),
649            &self.control_callbacks,
650        )
651        .await?;
652        Ok(())
653    }
654
655    pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
656        self.require_connected()?;
657        let response = send_control_request_with_callbacks(
658            self.transport.as_mut(),
659            serde_json::json!({"subtype": "mcp_status"}),
660            &self.control_callbacks,
661        )
662        .await?;
663        let value = serde_json::Value::Object(response);
664        Ok(serde_json::from_value(value)?)
665    }
666
667    pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
668        self.require_connected()?;
669        let response = send_control_request_with_callbacks(
670            self.transport.as_mut(),
671            serde_json::json!({"subtype": "get_context_usage"}),
672            &self.control_callbacks,
673        )
674        .await?;
675        Ok(serde_json::from_value(serde_json::Value::Object(response))?)
676    }
677
678    pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
679        self.initialization_result.as_ref()
680    }
681
682    fn build_user_payload(
683        &self,
684        content: &UserMessageInput,
685        session_id: Option<&str>,
686    ) -> Result<serde_json::Map<String, serde_json::Value>> {
687        let mut payload = serde_json::Map::new();
688        payload.insert(
689            "type".to_string(),
690            serde_json::Value::String("user".to_string()),
691        );
692        payload.insert(
693            "session_id".to_string(),
694            serde_json::Value::String(
695                session_id
696                    .map(String::from)
697                    .unwrap_or_else(|| self.session_id.clone()),
698            ),
699        );
700        // `content` is a JSON string for text-only prompts, or an array of
701        // content blocks (text + images) for multimodal prompts.
702        let message = serde_json::json!({"role": "user", "content": content.to_content_value()});
703        payload.insert("message".to_string(), message);
704        Ok(payload)
705    }
706}