Skip to main content

claude_agent_sdk/
client.rs

1//! Main client for interactive sessions with Claude CLI.
2
3use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12    initialize_request, initialize_timeout_duration, respond_to_control_request,
13    send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14    ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18    apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24    ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25    PermissionMode,
26};
27
28#[derive(Debug)]
29#[allow(dead_code)]
30struct ClientState {
31    messages: Vec<Message>,
32    current_stream_buffer: String,
33    is_streaming: bool,
34    server_info: Option<HashMap<String, serde_json::Value>>,
35}
36
37pub struct ClaudeAgentClient {
38    transport: Box<dyn Transport>,
39    state: Arc<RwLock<ClientState>>,
40    session_id: String,
41    connected: bool,
42    initialized: bool,
43    initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
44    control_callbacks: ControlCallbacks,
45    transcript_mirror: Option<TranscriptMirrorBatcher>,
46    source_options: Option<ClaudeAgentOptions>,
47    materialized_resume: Option<MaterializedResume>,
48}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("ClaudeAgentClient")
53            .field("session_id", &self.session_id)
54            .finish_non_exhaustive()
55    }
56}
57
58impl ClaudeAgentClient {
59    pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
60        validate_session_store_options(&options)?;
61        let transport_options = TransportOptions::from(&options);
62        let transport = SubprocessCLITransport::new(transport_options);
63        let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
64        client.source_options = Some(options);
65        Ok(client)
66    }
67
68    pub fn with_transport(
69        options: ClaudeAgentOptions,
70        transport: Box<dyn Transport>,
71    ) -> Result<Self> {
72        let session_id = options
73            .session_id
74            .clone()
75            .or_else(|| options.resume.clone())
76            .unwrap_or_else(|| "default".to_string());
77        let state = Arc::new(RwLock::new(ClientState {
78            messages: Vec::new(),
79            current_stream_buffer: String::new(),
80            is_streaming: false,
81            server_info: None,
82        }));
83        Ok(Self {
84            transport,
85            state,
86            session_id,
87            connected: false,
88            initialized: false,
89            initialization_result: None,
90            control_callbacks: ControlCallbacks::from_options(&options),
91            transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
92            source_options: None,
93            materialized_resume: None,
94        })
95    }
96
97    pub async fn connect(&mut self) -> Result<()> {
98        if !self.connected {
99            self.materialize_resume_before_connect().await?;
100            self.transport.connect().await?;
101            self.connected = true;
102        }
103        self.ensure_initialized().await?;
104        Ok(())
105    }
106
107    pub async fn connect_with_prompt(&mut self, content: impl Into<String>) -> Result<()> {
108        self.connect().await?;
109        let content = content.into();
110        let payload = self.build_user_payload(&content, None)?;
111        let mut json_payload = serde_json::to_vec(&payload)?;
112        json_payload.push(b'\n');
113        self.transport.write(&json_payload).await
114    }
115
116    pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
117    where
118        S: Stream<Item = serde_json::Value> + Unpin,
119    {
120        self.connect().await?;
121        self.write_message_stream(stream, "default").await
122    }
123
124    async fn materialize_resume_before_connect(&mut self) -> Result<()> {
125        let Some(options) = self.source_options.clone() else {
126            return Ok(());
127        };
128        let Some(materialized) = materialize_resume_session(&options).await? else {
129            return Ok(());
130        };
131        let options = apply_materialized_options(&options, &materialized);
132        self.session_id = options
133            .session_id
134            .clone()
135            .or_else(|| options.resume.clone())
136            .unwrap_or_else(|| "default".to_string());
137        self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
138            &options,
139        )));
140        self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
141        self.source_options = Some(options);
142        self.materialized_resume = Some(materialized);
143        Ok(())
144    }
145
146    fn require_connected(&self) -> Result<()> {
147        if self.connected && self.initialized {
148            Ok(())
149        } else {
150            Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
151        }
152    }
153
154    async fn ensure_initialized(&mut self) -> Result<()> {
155        if self.initialized {
156            return Ok(());
157        }
158
159        let response = send_control_request_with_callbacks_and_timeout(
160            self.transport.as_mut(),
161            initialize_request(&self.control_callbacks),
162            &self.control_callbacks,
163            initialize_timeout_duration(),
164        )
165        .await?;
166        self.initialization_result = Some(response);
167        self.initialized = true;
168        Ok(())
169    }
170
171    pub async fn send_message(&mut self, content: impl Into<String>) -> Result<MessageResponse> {
172        self.query(content).await?;
173        let messages = self.receive_response().await?;
174        let mut content_parts: Vec<String> = Vec::new();
175        let mut blocks: Vec<ContentBlock> = Vec::new();
176        let mut usage: Option<HashMap<String, serde_json::Value>> = None;
177        let mut stop_reason: Option<String> = None;
178        let mut model = String::new();
179
180        for message in messages {
181            match message {
182                Message::AssistantMsg {
183                    content: assistant_content,
184                    ..
185                } => {
186                    // Track the model from the first assistant message
187                    if model.is_empty() {
188                        model.clone_from(&assistant_content.model);
189                    }
190                    for block in &assistant_content.content {
191                        match block {
192                            ContentBlock::Text { text } => content_parts.push(text.clone()),
193                            ContentBlock::Thinking { thinking, .. } => {
194                                content_parts.push(thinking.clone())
195                            }
196                            _ => {}
197                        }
198                        blocks.push(block.clone());
199                    }
200                }
201                Message::ResultMsg {
202                    stop_reason: reason,
203                    usage: u,
204                    ..
205                } => {
206                    stop_reason = reason;
207                    if let Some(u) = u {
208                        usage = Some(u.into_iter().collect());
209                    }
210                }
211                _ => {}
212            }
213        }
214
215        Ok(MessageResponse {
216            content: content_parts.join(""),
217            blocks,
218            model,
219            stop_reason,
220            session_id: self.session_id.clone(),
221            usage,
222        })
223    }
224
225    pub async fn query(&mut self, content: impl Into<String>) -> Result<()> {
226        self.require_connected()?;
227        let content_str = content.into();
228        let payload = self.build_user_payload(&content_str, None)?;
229        let mut json_payload = serde_json::to_vec(&payload)?;
230        json_payload.push(b'\n');
231        self.transport.write(&json_payload).await
232    }
233
234    pub async fn query_with_session_id(
235        &mut self,
236        content: impl Into<String>,
237        session_id: impl Into<String>,
238    ) -> Result<()> {
239        self.require_connected()?;
240        let content_str = content.into();
241        let session_id = session_id.into();
242        let payload = self.build_user_payload(&content_str, Some(&session_id))?;
243        let mut json_payload = serde_json::to_vec(&payload)?;
244        json_payload.push(b'\n');
245        self.transport.write(&json_payload).await
246    }
247
248    pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
249    where
250        S: Stream<Item = serde_json::Value> + Unpin,
251    {
252        self.query_stream_with_session_id(stream, "default").await
253    }
254
255    pub async fn query_stream_with_session_id<S>(
256        &mut self,
257        stream: S,
258        session_id: impl Into<String>,
259    ) -> Result<()>
260    where
261        S: Stream<Item = serde_json::Value> + Unpin,
262    {
263        self.require_connected()?;
264        self.write_message_stream(stream, &session_id.into()).await
265    }
266
267    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
268        self.receive_messages_until(true).await
269    }
270
271    pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
272        self.receive_messages_until(false).await
273    }
274
275    async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
276        self.require_connected()?;
277        let mut messages = Vec::new();
278        while let Some(data) = self.transport.read().await? {
279            let line = String::from_utf8_lossy(&data);
280            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
281            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
282                respond_to_control_request(
283                    self.transport.as_mut(),
284                    &value,
285                    &self.control_callbacks,
286                )
287                .await?;
288                continue;
289            }
290            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
291                if let Some(batcher) = &mut self.transcript_mirror {
292                    messages.extend(batcher.enqueue_value(&value).await?);
293                }
294                continue;
295            }
296            let Some(message) = parse_message_line(&line)? else {
297                continue;
298            };
299            let done = matches!(message, Message::ResultMsg { .. });
300            if done {
301                if let Some(batcher) = &mut self.transcript_mirror {
302                    messages.extend(batcher.flush().await?);
303                }
304            }
305            {
306                let mut state = self.state.write().await;
307                state.messages.push(message.clone());
308            }
309            messages.push(message);
310            if stop_at_result && done {
311                break;
312            }
313        }
314        Ok(messages)
315    }
316
317    pub async fn stream_message(
318        &mut self,
319        content: impl Into<String>,
320    ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
321        self.require_connected()?;
322        let content_str = content.into();
323        let payload = self.build_user_payload(&content_str, None)?;
324        let json_payload = serde_json::to_vec(&payload)?;
325        self.transport.write(&json_payload).await?;
326        self.transport
327            .write(
328                b"
329",
330            )
331            .await?;
332        let (tx, rx) = mpsc::unbounded_channel();
333        {
334            let mut state = self.state.write().await;
335            state.is_streaming = true;
336        }
337        while let Some(data) = self.transport.read().await? {
338            let line = String::from_utf8_lossy(&data);
339            let value = serde_json::from_slice::<serde_json::Value>(&data)?;
340            if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
341                respond_to_control_request(
342                    self.transport.as_mut(),
343                    &value,
344                    &self.control_callbacks,
345                )
346                .await?;
347                continue;
348            }
349            if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
350                if let Some(batcher) = &mut self.transcript_mirror {
351                    for message in batcher.enqueue_value(&value).await? {
352                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
353                    }
354                }
355                continue;
356            }
357            let Some(message) = parse_message_line(&line)? else {
358                continue;
359            };
360            for event in stream_events_from_message(&message, &self.session_id) {
361                let _ = tx.send(event);
362            }
363            let done = matches!(message, Message::ResultMsg { .. });
364            if done {
365                if let Some(batcher) = &mut self.transcript_mirror {
366                    for message in batcher.flush().await? {
367                        let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
368                    }
369                }
370            }
371            {
372                let mut state = self.state.write().await;
373                state.messages.push(message);
374                if done {
375                    state.is_streaming = false;
376                }
377            }
378            if done {
379                break;
380            }
381        }
382        Ok(rx)
383    }
384
385    /// Fire-and-forget streaming for a single prompt.
386    ///
387    /// Spawns a background task that creates a fresh client from `options`,
388    /// connects, and streams `prompt`, forwarding each [`StreamEvent`] to the
389    /// returned receiver. Unlike [`Self::stream_message`], this returns
390    /// immediately and the spawned task owns the client for the lifetime of the
391    /// stream, so the caller only needs to hold the receiver. Connection or
392    /// streaming failures are surfaced as a final [`StreamEvent::Error`].
393    ///
394    /// Must be called from within a Tokio runtime.
395    pub fn spawn_stream_message(
396        options: ClaudeAgentOptions,
397        prompt: impl Into<String>,
398    ) -> mpsc::UnboundedReceiver<StreamEvent> {
399        let prompt = prompt.into();
400        let (tx, rx) = mpsc::unbounded_channel();
401        tokio::spawn(async move {
402            let mut client = match ClaudeAgentClient::new(options) {
403                Ok(client) => client,
404                Err(err) => {
405                    let _ = tx.send(StreamEvent::Error(err.to_string()));
406                    return;
407                }
408            };
409            if let Err(err) = client.connect().await {
410                let _ = tx.send(StreamEvent::Error(err.to_string()));
411                return;
412            }
413            match client.stream_message(prompt).await {
414                Ok(mut events) => {
415                    while let Some(event) = events.recv().await {
416                        if tx.send(event).is_err() {
417                            break;
418                        }
419                    }
420                }
421                Err(err) => {
422                    let _ = tx.send(StreamEvent::Error(err.to_string()));
423                }
424            }
425            // Keep the client (and its transport) alive until streaming ends.
426            drop(client);
427        });
428        rx
429    }
430
431    async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
432    where
433        S: Stream<Item = serde_json::Value> + Unpin,
434    {
435        while let Some(mut message) = stream.next().await {
436            if let Some(object) = message.as_object_mut() {
437                object
438                    .entry("session_id")
439                    .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
440            }
441            let mut json_payload = serde_json::to_vec(&message)?;
442            json_payload.push(b'\n');
443            self.transport.write(&json_payload).await?;
444        }
445        Ok(())
446    }
447
448    pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
449        let state = self.state.read().await;
450        Ok(state.messages.clone())
451    }
452
453    pub async fn abort(&mut self) -> Result<()> {
454        if let Some(batcher) = &mut self.transcript_mirror {
455            let _ = batcher.flush().await?;
456        }
457        self.transport.close().await?;
458        if let Some(materialized) = &self.materialized_resume {
459            materialized.cleanup().await;
460        }
461        self.materialized_resume = None;
462        self.connected = false;
463        self.initialized = false;
464        Ok(())
465    }
466
467    pub async fn disconnect(&mut self) -> Result<()> {
468        self.abort().await
469    }
470
471    pub async fn close(mut self) -> Result<()> {
472        if let Some(batcher) = &mut self.transcript_mirror {
473            let _ = batcher.flush().await?;
474        }
475        self.transport.close().await?;
476        if let Some(materialized) = &self.materialized_resume {
477            materialized.cleanup().await;
478        }
479        Ok(())
480    }
481
482    pub async fn interrupt(&mut self) -> Result<()> {
483        self.require_connected()?;
484        send_control_request_with_callbacks(
485            self.transport.as_mut(),
486            serde_json::json!({"subtype": "interrupt"}),
487            &self.control_callbacks,
488        )
489        .await?;
490        Ok(())
491    }
492
493    pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
494        self.require_connected()?;
495        send_control_request_with_callbacks(
496            self.transport.as_mut(),
497            serde_json::json!({
498                "subtype": "set_permission_mode",
499                "mode": mode,
500            }),
501            &self.control_callbacks,
502        )
503        .await?;
504        Ok(())
505    }
506
507    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
508        self.require_connected()?;
509        let model = model.map(serde_json::Value::String);
510        send_control_request_with_callbacks(
511            self.transport.as_mut(),
512            serde_json::json!({
513                "subtype": "set_model",
514                "model": model.unwrap_or(serde_json::Value::Null),
515            }),
516            &self.control_callbacks,
517        )
518        .await?;
519        Ok(())
520    }
521
522    pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
523        self.require_connected()?;
524        send_control_request_with_callbacks(
525            self.transport.as_mut(),
526            serde_json::json!({
527                "subtype": "rewind_files",
528                "user_message_id": user_message_id.into(),
529            }),
530            &self.control_callbacks,
531        )
532        .await?;
533        Ok(())
534    }
535
536    pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
537        self.require_connected()?;
538        send_control_request_with_callbacks(
539            self.transport.as_mut(),
540            serde_json::json!({
541                "subtype": "mcp_reconnect",
542                "serverName": server_name.into(),
543            }),
544            &self.control_callbacks,
545        )
546        .await?;
547        Ok(())
548    }
549
550    pub async fn toggle_mcp_server(
551        &mut self,
552        server_name: impl Into<String>,
553        enabled: bool,
554    ) -> Result<()> {
555        self.require_connected()?;
556        send_control_request_with_callbacks(
557            self.transport.as_mut(),
558            serde_json::json!({
559                "subtype": "mcp_toggle",
560                "serverName": server_name.into(),
561                "enabled": enabled,
562            }),
563            &self.control_callbacks,
564        )
565        .await?;
566        Ok(())
567    }
568
569    pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
570        self.require_connected()?;
571        send_control_request_with_callbacks(
572            self.transport.as_mut(),
573            serde_json::json!({
574                "subtype": "stop_task",
575                "task_id": task_id.into(),
576            }),
577            &self.control_callbacks,
578        )
579        .await?;
580        Ok(())
581    }
582
583    pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
584        self.require_connected()?;
585        let response = send_control_request_with_callbacks(
586            self.transport.as_mut(),
587            serde_json::json!({"subtype": "mcp_status"}),
588            &self.control_callbacks,
589        )
590        .await?;
591        let value = serde_json::Value::Object(response);
592        Ok(serde_json::from_value(value)?)
593    }
594
595    pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
596        self.require_connected()?;
597        let response = send_control_request_with_callbacks(
598            self.transport.as_mut(),
599            serde_json::json!({"subtype": "get_context_usage"}),
600            &self.control_callbacks,
601        )
602        .await?;
603        Ok(serde_json::from_value(serde_json::Value::Object(response))?)
604    }
605
606    pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
607        self.initialization_result.as_ref()
608    }
609
610    fn build_user_payload(
611        &self,
612        content: &str,
613        session_id: Option<&str>,
614    ) -> Result<serde_json::Map<String, serde_json::Value>> {
615        let mut payload = serde_json::Map::new();
616        payload.insert(
617            "type".to_string(),
618            serde_json::Value::String("user".to_string()),
619        );
620        payload.insert(
621            "session_id".to_string(),
622            serde_json::Value::String(
623                session_id
624                    .map(String::from)
625                    .unwrap_or_else(|| self.session_id.clone()),
626            ),
627        );
628        let message = serde_json::json!({"role": "user", "content": content});
629        payload.insert("message".to_string(), message);
630        Ok(payload)
631    }
632}