Skip to main content

claude_code_sdk_rust/
client.rs

1//! Main client for interactive sessions with Claude CLI.
2
3use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12    initialize_request, initialize_timeout_duration, respond_to_control_request,
13    send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14    ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18    apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24    ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25    PermissionMode,
26};
27
28#[derive(Debug)]
29#[allow(dead_code)]
30struct ClientState {
31    messages: Vec<Message>,
32    current_stream_buffer: String,
33    is_streaming: bool,
34    server_info: Option<HashMap<String, serde_json::Value>>,
35}
36
37pub struct ClaudeAgentClient {
38    transport: Box<dyn Transport>,
39    state: Arc<RwLock<ClientState>>,
40    session_id: String,
41    connected: bool,
42    initialized: bool,
43    initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
44    control_callbacks: ControlCallbacks,
45    transcript_mirror: Option<TranscriptMirrorBatcher>,
46    source_options: Option<ClaudeAgentOptions>,
47    materialized_resume: Option<MaterializedResume>,
48}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("ClaudeAgentClient")
53            .field("session_id", &self.session_id)
54            .finish_non_exhaustive()
55    }
56}
57
58impl ClaudeAgentClient {
59    pub fn spawn_stream_message(
60        options: ClaudeAgentOptions,
61        content: impl Into<String>,
62    ) -> mpsc::UnboundedReceiver<StreamEvent> {
63        let content = content.into();
64        let (tx, rx) = mpsc::unbounded_channel();
65        tokio::spawn(async move {
66            if let Err(err) = Self::run_spawned_stream(options, content, tx.clone()).await {
67                let _ = tx.send(StreamEvent::Error(err.to_string()));
68            }
69        });
70        rx
71    }
72
73    async fn run_spawned_stream(
74        options: ClaudeAgentOptions,
75        content: String,
76        tx: mpsc::UnboundedSender<StreamEvent>,
77    ) -> Result<()> {
78        let mut client = Self::new(options)?;
79        client.connect().await?;
80        client.require_connected()?;
81        let payload = client.build_user_payload(&content, None)?;
82        let json_payload = serde_json::to_vec(&payload)?;
83        client.transport.write(&json_payload).await?;
84        client.transport.write(b"\n").await?;
85        {
86            let mut state = client.state.write().await;
87            state.is_streaming = true;
88        }
89        while let Some(data) = client.transport.read().await? {
90            let line = String::from_utf8_lossy(&data);
91            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
92            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
93                respond_to_control_request(
94                    client.transport.as_mut(),
95                    &value,
96                    &client.control_callbacks,
97                )
98                .await?;
99                continue;
100            }
101            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
102                if let Some(batcher) = &mut client.transcript_mirror {
103                    for message in batcher.enqueue_value(&value).await? {
104                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
105                    }
106                }
107                continue;
108            }
109            let message = match parse_message_line(&line) {
110                Ok(Some(message)) => message,
111                Ok(None) => continue,
112                Err(err) => {
113                    // A single unrecognized message shape must not kill the
114                    // whole turn. Log it (payload included) and keep going.
115                    tracing::warn!("skipping unparseable CLI message: {err}");
116                    continue;
117                }
118            };
119            for event in stream_events_from_message(&message, &client.session_id) {
120                let _ = tx.send(event);
121            }
122            let done = matches!(message, Message::ResultMsg { .. });
123            if done {
124                if let Some(batcher) = &mut client.transcript_mirror {
125                    for message in batcher.flush().await? {
126                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
127                    }
128                }
129            }
130            {
131                let mut state = client.state.write().await;
132                state.messages.push(message);
133                if done {
134                    state.is_streaming = false;
135                }
136            }
137            if done {
138                break;
139            }
140        }
141        Ok(())
142    }
143
144    pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
145        validate_session_store_options(&options)?;
146        let transport_options = TransportOptions::from(&options);
147        let transport = SubprocessCLITransport::new(transport_options);
148        let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
149        client.source_options = Some(options);
150        Ok(client)
151    }
152
153    pub fn with_transport(
154        options: ClaudeAgentOptions,
155        transport: Box<dyn Transport>,
156    ) -> Result<Self> {
157        let session_id = options
158            .session_id
159            .clone()
160            .or_else(|| options.resume.clone())
161            .unwrap_or_else(|| "default".to_string());
162        let state = Arc::new(RwLock::new(ClientState {
163            messages: Vec::new(),
164            current_stream_buffer: String::new(),
165            is_streaming: false,
166            server_info: None,
167        }));
168        Ok(Self {
169            transport,
170            state,
171            session_id,
172            connected: false,
173            initialized: false,
174            initialization_result: None,
175            control_callbacks: ControlCallbacks::from_options(&options),
176            transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
177            source_options: None,
178            materialized_resume: None,
179        })
180    }
181
182    pub async fn connect(&mut self) -> Result<()> {
183        if !self.connected {
184            self.materialize_resume_before_connect().await?;
185            self.transport.connect().await?;
186            self.connected = true;
187        }
188        self.ensure_initialized().await?;
189        Ok(())
190    }
191
192    pub async fn connect_with_prompt(&mut self, content: impl Into<String>) -> Result<()> {
193        self.connect().await?;
194        let content = content.into();
195        let payload = self.build_user_payload(&content, None)?;
196        let mut json_payload = serde_json::to_vec(&payload)?;
197        json_payload.push(b'\n');
198        self.transport.write(&json_payload).await
199    }
200
201    pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
202    where
203        S: Stream<Item = serde_json::Value> + Unpin,
204    {
205        self.connect().await?;
206        self.write_message_stream(stream, "default").await
207    }
208
209    async fn materialize_resume_before_connect(&mut self) -> Result<()> {
210        let Some(options) = self.source_options.clone() else {
211            return Ok(());
212        };
213        let Some(materialized) = materialize_resume_session(&options).await? else {
214            return Ok(());
215        };
216        let options = apply_materialized_options(&options, &materialized);
217        self.session_id = options
218            .session_id
219            .clone()
220            .or_else(|| options.resume.clone())
221            .unwrap_or_else(|| "default".to_string());
222        self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
223            &options,
224        )));
225        self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
226        self.source_options = Some(options);
227        self.materialized_resume = Some(materialized);
228        Ok(())
229    }
230
231    fn require_connected(&self) -> Result<()> {
232        if self.connected && self.initialized {
233            Ok(())
234        } else {
235            Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
236        }
237    }
238
239    async fn ensure_initialized(&mut self) -> Result<()> {
240        if self.initialized {
241            return Ok(());
242        }
243
244        let response = send_control_request_with_callbacks_and_timeout(
245            self.transport.as_mut(),
246            initialize_request(&self.control_callbacks),
247            &self.control_callbacks,
248            initialize_timeout_duration(),
249        )
250        .await?;
251        self.initialization_result = Some(response);
252        self.initialized = true;
253        Ok(())
254    }
255
256    pub async fn send_message(&mut self, content: impl Into<String>) -> Result<MessageResponse> {
257        self.query(content).await?;
258        let messages = self.receive_response().await?;
259        let mut content_parts: Vec<String> = Vec::new();
260        let mut blocks: Vec<ContentBlock> = Vec::new();
261        let mut usage: Option<HashMap<String, serde_json::Value>> = None;
262        let mut stop_reason: Option<String> = None;
263        let mut model = String::new();
264
265        for message in messages {
266            match message {
267                Message::AssistantMsg {
268                    content: assistant_content,
269                    ..
270                } => {
271                    // Track the model from the first assistant message
272                    if model.is_empty() {
273                        model.clone_from(&assistant_content.model);
274                    }
275                    for block in &assistant_content.content {
276                        match block {
277                            ContentBlock::Text { text } => content_parts.push(text.clone()),
278                            ContentBlock::Thinking { thinking, .. } => {
279                                content_parts.push(thinking.clone())
280                            }
281                            _ => {}
282                        }
283                        blocks.push(block.clone());
284                    }
285                }
286                Message::ResultMsg {
287                    stop_reason: reason,
288                    usage: u,
289                    ..
290                } => {
291                    stop_reason = reason;
292                    if let Some(u) = u {
293                        usage = Some(u.into_iter().collect());
294                    }
295                }
296                _ => {}
297            }
298        }
299
300        Ok(MessageResponse {
301            content: content_parts.join(""),
302            blocks,
303            model,
304            stop_reason,
305            session_id: self.session_id.clone(),
306            usage,
307        })
308    }
309
310    pub async fn query(&mut self, content: impl Into<String>) -> Result<()> {
311        self.require_connected()?;
312        let content_str = content.into();
313        let payload = self.build_user_payload(&content_str, None)?;
314        let mut json_payload = serde_json::to_vec(&payload)?;
315        json_payload.push(b'\n');
316        self.transport.write(&json_payload).await
317    }
318
319    pub async fn query_with_session_id(
320        &mut self,
321        content: impl Into<String>,
322        session_id: impl Into<String>,
323    ) -> Result<()> {
324        self.require_connected()?;
325        let content_str = content.into();
326        let session_id = session_id.into();
327        let payload = self.build_user_payload(&content_str, Some(&session_id))?;
328        let mut json_payload = serde_json::to_vec(&payload)?;
329        json_payload.push(b'\n');
330        self.transport.write(&json_payload).await
331    }
332
333    pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
334    where
335        S: Stream<Item = serde_json::Value> + Unpin,
336    {
337        self.query_stream_with_session_id(stream, "default").await
338    }
339
340    pub async fn query_stream_with_session_id<S>(
341        &mut self,
342        stream: S,
343        session_id: impl Into<String>,
344    ) -> Result<()>
345    where
346        S: Stream<Item = serde_json::Value> + Unpin,
347    {
348        self.require_connected()?;
349        self.write_message_stream(stream, &session_id.into()).await
350    }
351
352    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
353        self.receive_messages_until(true).await
354    }
355
356    pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
357        self.receive_messages_until(false).await
358    }
359
360    async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
361        self.require_connected()?;
362        let mut messages = Vec::new();
363        while let Some(data) = self.transport.read().await? {
364            let line = String::from_utf8_lossy(&data);
365            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
366            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
367                respond_to_control_request(
368                    self.transport.as_mut(),
369                    &value,
370                    &self.control_callbacks,
371                )
372                .await?;
373                continue;
374            }
375            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
376                if let Some(batcher) = &mut self.transcript_mirror {
377                    messages.extend(batcher.enqueue_value(&value).await?);
378                }
379                continue;
380            }
381            let message = match parse_message_line(&line) {
382                Ok(Some(message)) => message,
383                Ok(None) => continue,
384                Err(err) => {
385                    // A single unrecognized message shape must not kill the
386                    // whole turn. Log it (payload included) and keep going.
387                    tracing::warn!("skipping unparseable CLI message: {err}");
388                    continue;
389                }
390            };
391            let done = matches!(message, Message::ResultMsg { .. });
392            if done {
393                if let Some(batcher) = &mut self.transcript_mirror {
394                    messages.extend(batcher.flush().await?);
395                }
396            }
397            {
398                let mut state = self.state.write().await;
399                state.messages.push(message.clone());
400            }
401            messages.push(message);
402            if stop_at_result && done {
403                break;
404            }
405        }
406        Ok(messages)
407    }
408
409    pub async fn stream_message(
410        &mut self,
411        content: impl Into<String>,
412    ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
413        self.require_connected()?;
414        let content_str = content.into();
415        let payload = self.build_user_payload(&content_str, None)?;
416        let json_payload = serde_json::to_vec(&payload)?;
417        self.transport.write(&json_payload).await?;
418        self.transport
419            .write(
420                b"
421",
422            )
423            .await?;
424        let (tx, rx) = mpsc::unbounded_channel();
425        {
426            let mut state = self.state.write().await;
427            state.is_streaming = true;
428        }
429        while let Some(data) = self.transport.read().await? {
430            let line = String::from_utf8_lossy(&data);
431            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
432            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
433                respond_to_control_request(
434                    self.transport.as_mut(),
435                    &value,
436                    &self.control_callbacks,
437                )
438                .await?;
439                continue;
440            }
441            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
442                if let Some(batcher) = &mut self.transcript_mirror {
443                    for message in batcher.enqueue_value(&value).await? {
444                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
445                    }
446                }
447                continue;
448            }
449            let message = match parse_message_line(&line) {
450                Ok(Some(message)) => message,
451                Ok(None) => continue,
452                Err(err) => {
453                    // A single unrecognized message shape must not kill the
454                    // whole turn. Log it (payload included) and keep going.
455                    tracing::warn!("skipping unparseable CLI message: {err}");
456                    continue;
457                }
458            };
459            for event in stream_events_from_message(&message, &self.session_id) {
460                let _ = tx.send(event);
461            }
462            let done = matches!(message, Message::ResultMsg { .. });
463            if done {
464                if let Some(batcher) = &mut self.transcript_mirror {
465                    for message in batcher.flush().await? {
466                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
467                    }
468                }
469            }
470            {
471                let mut state = self.state.write().await;
472                state.messages.push(message);
473                if done {
474                    state.is_streaming = false;
475                }
476            }
477            if done {
478                break;
479            }
480        }
481        Ok(rx)
482    }
483
484    async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
485    where
486        S: Stream<Item = serde_json::Value> + Unpin,
487    {
488        while let Some(mut message) = stream.next().await {
489            if let Some(object) = message.as_object_mut() {
490                object
491                    .entry("session_id")
492                    .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
493            }
494            let mut json_payload = serde_json::to_vec(&message)?;
495            json_payload.push(b'\n');
496            self.transport.write(&json_payload).await?;
497        }
498        Ok(())
499    }
500
501    pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
502        let state = self.state.read().await;
503        Ok(state.messages.clone())
504    }
505
506    pub async fn abort(&mut self) -> Result<()> {
507        if let Some(batcher) = &mut self.transcript_mirror {
508            let _ = batcher.flush().await?;
509        }
510        self.transport.close().await?;
511        if let Some(materialized) = &self.materialized_resume {
512            materialized.cleanup().await;
513        }
514        self.materialized_resume = None;
515        self.connected = false;
516        self.initialized = false;
517        Ok(())
518    }
519
520    pub async fn disconnect(&mut self) -> Result<()> {
521        self.abort().await
522    }
523
524    pub async fn close(mut self) -> Result<()> {
525        if let Some(batcher) = &mut self.transcript_mirror {
526            let _ = batcher.flush().await?;
527        }
528        self.transport.close().await?;
529        if let Some(materialized) = &self.materialized_resume {
530            materialized.cleanup().await;
531        }
532        Ok(())
533    }
534
535    pub async fn interrupt(&mut self) -> Result<()> {
536        self.require_connected()?;
537        send_control_request_with_callbacks(
538            self.transport.as_mut(),
539            serde_json::json!({"subtype": "interrupt"}),
540            &self.control_callbacks,
541        )
542        .await?;
543        Ok(())
544    }
545
546    pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
547        self.require_connected()?;
548        send_control_request_with_callbacks(
549            self.transport.as_mut(),
550            serde_json::json!({
551                "subtype": "set_permission_mode",
552                "mode": mode,
553            }),
554            &self.control_callbacks,
555        )
556        .await?;
557        Ok(())
558    }
559
560    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
561        self.require_connected()?;
562        let model = model.map(serde_json::Value::String);
563        send_control_request_with_callbacks(
564            self.transport.as_mut(),
565            serde_json::json!({
566                "subtype": "set_model",
567                "model": model.unwrap_or(serde_json::Value::Null),
568            }),
569            &self.control_callbacks,
570        )
571        .await?;
572        Ok(())
573    }
574
575    pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
576        self.require_connected()?;
577        send_control_request_with_callbacks(
578            self.transport.as_mut(),
579            serde_json::json!({
580                "subtype": "rewind_files",
581                "user_message_id": user_message_id.into(),
582            }),
583            &self.control_callbacks,
584        )
585        .await?;
586        Ok(())
587    }
588
589    pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
590        self.require_connected()?;
591        send_control_request_with_callbacks(
592            self.transport.as_mut(),
593            serde_json::json!({
594                "subtype": "mcp_reconnect",
595                "serverName": server_name.into(),
596            }),
597            &self.control_callbacks,
598        )
599        .await?;
600        Ok(())
601    }
602
603    pub async fn toggle_mcp_server(
604        &mut self,
605        server_name: impl Into<String>,
606        enabled: bool,
607    ) -> Result<()> {
608        self.require_connected()?;
609        send_control_request_with_callbacks(
610            self.transport.as_mut(),
611            serde_json::json!({
612                "subtype": "mcp_toggle",
613                "serverName": server_name.into(),
614                "enabled": enabled,
615            }),
616            &self.control_callbacks,
617        )
618        .await?;
619        Ok(())
620    }
621
622    pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
623        self.require_connected()?;
624        send_control_request_with_callbacks(
625            self.transport.as_mut(),
626            serde_json::json!({
627                "subtype": "stop_task",
628                "task_id": task_id.into(),
629            }),
630            &self.control_callbacks,
631        )
632        .await?;
633        Ok(())
634    }
635
636    pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
637        self.require_connected()?;
638        let response = send_control_request_with_callbacks(
639            self.transport.as_mut(),
640            serde_json::json!({"subtype": "mcp_status"}),
641            &self.control_callbacks,
642        )
643        .await?;
644        let value = serde_json::Value::Object(response);
645        Ok(serde_json::from_value(value)?)
646    }
647
648    pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
649        self.require_connected()?;
650        let response = send_control_request_with_callbacks(
651            self.transport.as_mut(),
652            serde_json::json!({"subtype": "get_context_usage"}),
653            &self.control_callbacks,
654        )
655        .await?;
656        Ok(serde_json::from_value(serde_json::Value::Object(response))?)
657    }
658
659    pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
660        self.initialization_result.as_ref()
661    }
662
663    fn build_user_payload(
664        &self,
665        content: &str,
666        session_id: Option<&str>,
667    ) -> Result<serde_json::Map<String, serde_json::Value>> {
668        let mut payload = serde_json::Map::new();
669        payload.insert(
670            "type".to_string(),
671            serde_json::Value::String("user".to_string()),
672        );
673        payload.insert(
674            "session_id".to_string(),
675            serde_json::Value::String(
676                session_id
677                    .map(String::from)
678                    .unwrap_or_else(|| self.session_id.clone()),
679            ),
680        );
681        let message = serde_json::json!({"role": "user", "content": content});
682        payload.insert("message".to_string(), message);
683        Ok(payload)
684    }
685}