1use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12 initialize_request, initialize_timeout_duration, respond_to_control_request,
13 send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14 ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18 apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24 ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25 PermissionMode,
26};
27
28#[derive(Debug)]
29#[allow(dead_code)]
30struct ClientState {
31 messages: Vec<Message>,
32 current_stream_buffer: String,
33 is_streaming: bool,
34 server_info: Option<HashMap<String, serde_json::Value>>,
35}
36
/// Client for a Claude agent CLI session reached through a [`Transport`].
///
/// Flow visible in this file: `connect` opens the transport and runs the initialize control
/// request, the `query*` methods write user turns, and the `receive_*` / `stream_message`
/// methods read replies while answering control requests inline.
pub struct ClaudeAgentClient {
    /// Channel to the CLI; replaced when a resume session is materialized during `connect`.
    transport: Box<dyn Transport>,
    /// Shared bookkeeping, including the history returned by `get_conversation_history`.
    state: Arc<RwLock<ClientState>>,
    /// Session id stamped on outgoing user messages (options `session_id`, else `resume`,
    /// else `"default"`).
    session_id: String,
    /// Whether `transport.connect()` has succeeded since construction or the last `abort`.
    connected: bool,
    /// Whether the initialize control request has completed successfully.
    initialized: bool,
    /// Response of the initialize control request, exposed via `get_server_info`.
    initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
    /// Callbacks used when sending control requests and answering incoming ones.
    control_callbacks: ControlCallbacks,
    /// Batcher for `transcript_mirror` lines; `None` when `TranscriptMirrorBatcher::from_options`
    /// yields none for the options in use.
    transcript_mirror: Option<TranscriptMirrorBatcher>,
    /// Options the client was built from; set by `new` (not `with_transport`) and needed to
    /// materialize a resume session on connect.
    source_options: Option<ClaudeAgentOptions>,
    /// Resume session materialized on connect; cleaned up by `abort` / `close`.
    materialized_resume: Option<MaterializedResume>,
}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("ClaudeAgentClient")
53 .field("session_id", &self.session_id)
54 .finish_non_exhaustive()
55 }
56}
57
58impl ClaudeAgentClient {
59 pub fn spawn_stream_message(
60 options: ClaudeAgentOptions,
61 content: impl Into<String>,
62 ) -> mpsc::UnboundedReceiver<StreamEvent> {
63 let content = content.into();
64 let (tx, rx) = mpsc::unbounded_channel();
65 tokio::spawn(async move {
66 if let Err(err) = Self::run_spawned_stream(options, content, tx.clone()).await {
67 let _ = tx.send(StreamEvent::Error(err.to_string()));
68 }
69 });
70 rx
71 }
72
73 async fn run_spawned_stream(
74 options: ClaudeAgentOptions,
75 content: String,
76 tx: mpsc::UnboundedSender<StreamEvent>,
77 ) -> Result<()> {
78 let mut client = Self::new(options)?;
79 client.connect().await?;
80 client.require_connected()?;
81 let payload = client.build_user_payload(&content, None)?;
82 let json_payload = serde_json::to_vec(&payload)?;
83 client.transport.write(&json_payload).await?;
84 client.transport.write(b"\n").await?;
85 {
86 let mut state = client.state.write().await;
87 state.is_streaming = true;
88 }
89 while let Some(data) = client.transport.read().await? {
90 let line = String::from_utf8_lossy(&data);
91 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
92 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
93 respond_to_control_request(
94 client.transport.as_mut(),
95 &value,
96 &client.control_callbacks,
97 )
98 .await?;
99 continue;
100 }
101 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
102 if let Some(batcher) = &mut client.transcript_mirror {
103 for message in batcher.enqueue_value(&value).await? {
104 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
105 }
106 }
107 continue;
108 }
109 let message = match parse_message_line(&line) {
110 Ok(Some(message)) => message,
111 Ok(None) => continue,
112 Err(err) => {
113 tracing::warn!("skipping unparseable CLI message: {err}");
116 continue;
117 }
118 };
119 for event in stream_events_from_message(&message, &client.session_id) {
120 let _ = tx.send(event);
121 }
122 let done = matches!(message, Message::ResultMsg { .. });
123 if done {
124 if let Some(batcher) = &mut client.transcript_mirror {
125 for message in batcher.flush().await? {
126 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
127 }
128 }
129 }
130 {
131 let mut state = client.state.write().await;
132 state.messages.push(message);
133 if done {
134 state.is_streaming = false;
135 }
136 }
137 if done {
138 break;
139 }
140 }
141 Ok(())
142 }
143
144 pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
145 validate_session_store_options(&options)?;
146 let transport_options = TransportOptions::from(&options);
147 let transport = SubprocessCLITransport::new(transport_options);
148 let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
149 client.source_options = Some(options);
150 Ok(client)
151 }
152
153 pub fn with_transport(
154 options: ClaudeAgentOptions,
155 transport: Box<dyn Transport>,
156 ) -> Result<Self> {
157 let session_id = options
158 .session_id
159 .clone()
160 .or_else(|| options.resume.clone())
161 .unwrap_or_else(|| "default".to_string());
162 let state = Arc::new(RwLock::new(ClientState {
163 messages: Vec::new(),
164 current_stream_buffer: String::new(),
165 is_streaming: false,
166 server_info: None,
167 }));
168 Ok(Self {
169 transport,
170 state,
171 session_id,
172 connected: false,
173 initialized: false,
174 initialization_result: None,
175 control_callbacks: ControlCallbacks::from_options(&options),
176 transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
177 source_options: None,
178 materialized_resume: None,
179 })
180 }
181
182 pub async fn connect(&mut self) -> Result<()> {
183 if !self.connected {
184 self.materialize_resume_before_connect().await?;
185 self.transport.connect().await?;
186 self.connected = true;
187 }
188 self.ensure_initialized().await?;
189 Ok(())
190 }
191
192 pub async fn connect_with_prompt(&mut self, content: impl Into<String>) -> Result<()> {
193 self.connect().await?;
194 let content = content.into();
195 let payload = self.build_user_payload(&content, None)?;
196 let mut json_payload = serde_json::to_vec(&payload)?;
197 json_payload.push(b'\n');
198 self.transport.write(&json_payload).await
199 }
200
201 pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
202 where
203 S: Stream<Item = serde_json::Value> + Unpin,
204 {
205 self.connect().await?;
206 self.write_message_stream(stream, "default").await
207 }
208
209 async fn materialize_resume_before_connect(&mut self) -> Result<()> {
210 let Some(options) = self.source_options.clone() else {
211 return Ok(());
212 };
213 let Some(materialized) = materialize_resume_session(&options).await? else {
214 return Ok(());
215 };
216 let options = apply_materialized_options(&options, &materialized);
217 self.session_id = options
218 .session_id
219 .clone()
220 .or_else(|| options.resume.clone())
221 .unwrap_or_else(|| "default".to_string());
222 self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
223 &options,
224 )));
225 self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
226 self.source_options = Some(options);
227 self.materialized_resume = Some(materialized);
228 Ok(())
229 }
230
231 fn require_connected(&self) -> Result<()> {
232 if self.connected && self.initialized {
233 Ok(())
234 } else {
235 Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
236 }
237 }
238
239 async fn ensure_initialized(&mut self) -> Result<()> {
240 if self.initialized {
241 return Ok(());
242 }
243
244 let response = send_control_request_with_callbacks_and_timeout(
245 self.transport.as_mut(),
246 initialize_request(&self.control_callbacks),
247 &self.control_callbacks,
248 initialize_timeout_duration(),
249 )
250 .await?;
251 self.initialization_result = Some(response);
252 self.initialized = true;
253 Ok(())
254 }
255
256 pub async fn send_message(&mut self, content: impl Into<String>) -> Result<MessageResponse> {
257 self.query(content).await?;
258 let messages = self.receive_response().await?;
259 let mut content_parts: Vec<String> = Vec::new();
260 let mut blocks: Vec<ContentBlock> = Vec::new();
261 let mut usage: Option<HashMap<String, serde_json::Value>> = None;
262 let mut stop_reason: Option<String> = None;
263 let mut model = String::new();
264
265 for message in messages {
266 match message {
267 Message::AssistantMsg {
268 content: assistant_content,
269 ..
270 } => {
271 if model.is_empty() {
273 model.clone_from(&assistant_content.model);
274 }
275 for block in &assistant_content.content {
276 match block {
277 ContentBlock::Text { text } => content_parts.push(text.clone()),
278 ContentBlock::Thinking { thinking, .. } => {
279 content_parts.push(thinking.clone())
280 }
281 _ => {}
282 }
283 blocks.push(block.clone());
284 }
285 }
286 Message::ResultMsg {
287 stop_reason: reason,
288 usage: u,
289 ..
290 } => {
291 stop_reason = reason;
292 if let Some(u) = u {
293 usage = Some(u.into_iter().collect());
294 }
295 }
296 _ => {}
297 }
298 }
299
300 Ok(MessageResponse {
301 content: content_parts.join(""),
302 blocks,
303 model,
304 stop_reason,
305 session_id: self.session_id.clone(),
306 usage,
307 })
308 }
309
310 pub async fn query(&mut self, content: impl Into<String>) -> Result<()> {
311 self.require_connected()?;
312 let content_str = content.into();
313 let payload = self.build_user_payload(&content_str, None)?;
314 let mut json_payload = serde_json::to_vec(&payload)?;
315 json_payload.push(b'\n');
316 self.transport.write(&json_payload).await
317 }
318
319 pub async fn query_with_session_id(
320 &mut self,
321 content: impl Into<String>,
322 session_id: impl Into<String>,
323 ) -> Result<()> {
324 self.require_connected()?;
325 let content_str = content.into();
326 let session_id = session_id.into();
327 let payload = self.build_user_payload(&content_str, Some(&session_id))?;
328 let mut json_payload = serde_json::to_vec(&payload)?;
329 json_payload.push(b'\n');
330 self.transport.write(&json_payload).await
331 }
332
333 pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
334 where
335 S: Stream<Item = serde_json::Value> + Unpin,
336 {
337 self.query_stream_with_session_id(stream, "default").await
338 }
339
340 pub async fn query_stream_with_session_id<S>(
341 &mut self,
342 stream: S,
343 session_id: impl Into<String>,
344 ) -> Result<()>
345 where
346 S: Stream<Item = serde_json::Value> + Unpin,
347 {
348 self.require_connected()?;
349 self.write_message_stream(stream, &session_id.into()).await
350 }
351
352 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
353 self.receive_messages_until(true).await
354 }
355
356 pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
357 self.receive_messages_until(false).await
358 }
359
360 async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
361 self.require_connected()?;
362 let mut messages = Vec::new();
363 while let Some(data) = self.transport.read().await? {
364 let line = String::from_utf8_lossy(&data);
365 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
366 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
367 respond_to_control_request(
368 self.transport.as_mut(),
369 &value,
370 &self.control_callbacks,
371 )
372 .await?;
373 continue;
374 }
375 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
376 if let Some(batcher) = &mut self.transcript_mirror {
377 messages.extend(batcher.enqueue_value(&value).await?);
378 }
379 continue;
380 }
381 let message = match parse_message_line(&line) {
382 Ok(Some(message)) => message,
383 Ok(None) => continue,
384 Err(err) => {
385 tracing::warn!("skipping unparseable CLI message: {err}");
388 continue;
389 }
390 };
391 let done = matches!(message, Message::ResultMsg { .. });
392 if done {
393 if let Some(batcher) = &mut self.transcript_mirror {
394 messages.extend(batcher.flush().await?);
395 }
396 }
397 {
398 let mut state = self.state.write().await;
399 state.messages.push(message.clone());
400 }
401 messages.push(message);
402 if stop_at_result && done {
403 break;
404 }
405 }
406 Ok(messages)
407 }
408
409 pub async fn stream_message(
410 &mut self,
411 content: impl Into<String>,
412 ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
413 self.require_connected()?;
414 let content_str = content.into();
415 let payload = self.build_user_payload(&content_str, None)?;
416 let json_payload = serde_json::to_vec(&payload)?;
417 self.transport.write(&json_payload).await?;
418 self.transport
419 .write(
420 b"
421",
422 )
423 .await?;
424 let (tx, rx) = mpsc::unbounded_channel();
425 {
426 let mut state = self.state.write().await;
427 state.is_streaming = true;
428 }
429 while let Some(data) = self.transport.read().await? {
430 let line = String::from_utf8_lossy(&data);
431 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
432 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
433 respond_to_control_request(
434 self.transport.as_mut(),
435 &value,
436 &self.control_callbacks,
437 )
438 .await?;
439 continue;
440 }
441 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
442 if let Some(batcher) = &mut self.transcript_mirror {
443 for message in batcher.enqueue_value(&value).await? {
444 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
445 }
446 }
447 continue;
448 }
449 let message = match parse_message_line(&line) {
450 Ok(Some(message)) => message,
451 Ok(None) => continue,
452 Err(err) => {
453 tracing::warn!("skipping unparseable CLI message: {err}");
456 continue;
457 }
458 };
459 for event in stream_events_from_message(&message, &self.session_id) {
460 let _ = tx.send(event);
461 }
462 let done = matches!(message, Message::ResultMsg { .. });
463 if done {
464 if let Some(batcher) = &mut self.transcript_mirror {
465 for message in batcher.flush().await? {
466 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
467 }
468 }
469 }
470 {
471 let mut state = self.state.write().await;
472 state.messages.push(message);
473 if done {
474 state.is_streaming = false;
475 }
476 }
477 if done {
478 break;
479 }
480 }
481 Ok(rx)
482 }
483
484 async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
485 where
486 S: Stream<Item = serde_json::Value> + Unpin,
487 {
488 while let Some(mut message) = stream.next().await {
489 if let Some(object) = message.as_object_mut() {
490 object
491 .entry("session_id")
492 .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
493 }
494 let mut json_payload = serde_json::to_vec(&message)?;
495 json_payload.push(b'\n');
496 self.transport.write(&json_payload).await?;
497 }
498 Ok(())
499 }
500
501 pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
502 let state = self.state.read().await;
503 Ok(state.messages.clone())
504 }
505
506 pub async fn abort(&mut self) -> Result<()> {
507 if let Some(batcher) = &mut self.transcript_mirror {
508 let _ = batcher.flush().await?;
509 }
510 self.transport.close().await?;
511 if let Some(materialized) = &self.materialized_resume {
512 materialized.cleanup().await;
513 }
514 self.materialized_resume = None;
515 self.connected = false;
516 self.initialized = false;
517 Ok(())
518 }
519
520 pub async fn disconnect(&mut self) -> Result<()> {
521 self.abort().await
522 }
523
524 pub async fn close(mut self) -> Result<()> {
525 if let Some(batcher) = &mut self.transcript_mirror {
526 let _ = batcher.flush().await?;
527 }
528 self.transport.close().await?;
529 if let Some(materialized) = &self.materialized_resume {
530 materialized.cleanup().await;
531 }
532 Ok(())
533 }
534
535 pub async fn interrupt(&mut self) -> Result<()> {
536 self.require_connected()?;
537 send_control_request_with_callbacks(
538 self.transport.as_mut(),
539 serde_json::json!({"subtype": "interrupt"}),
540 &self.control_callbacks,
541 )
542 .await?;
543 Ok(())
544 }
545
546 pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
547 self.require_connected()?;
548 send_control_request_with_callbacks(
549 self.transport.as_mut(),
550 serde_json::json!({
551 "subtype": "set_permission_mode",
552 "mode": mode,
553 }),
554 &self.control_callbacks,
555 )
556 .await?;
557 Ok(())
558 }
559
560 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
561 self.require_connected()?;
562 let model = model.map(serde_json::Value::String);
563 send_control_request_with_callbacks(
564 self.transport.as_mut(),
565 serde_json::json!({
566 "subtype": "set_model",
567 "model": model.unwrap_or(serde_json::Value::Null),
568 }),
569 &self.control_callbacks,
570 )
571 .await?;
572 Ok(())
573 }
574
575 pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
576 self.require_connected()?;
577 send_control_request_with_callbacks(
578 self.transport.as_mut(),
579 serde_json::json!({
580 "subtype": "rewind_files",
581 "user_message_id": user_message_id.into(),
582 }),
583 &self.control_callbacks,
584 )
585 .await?;
586 Ok(())
587 }
588
589 pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
590 self.require_connected()?;
591 send_control_request_with_callbacks(
592 self.transport.as_mut(),
593 serde_json::json!({
594 "subtype": "mcp_reconnect",
595 "serverName": server_name.into(),
596 }),
597 &self.control_callbacks,
598 )
599 .await?;
600 Ok(())
601 }
602
603 pub async fn toggle_mcp_server(
604 &mut self,
605 server_name: impl Into<String>,
606 enabled: bool,
607 ) -> Result<()> {
608 self.require_connected()?;
609 send_control_request_with_callbacks(
610 self.transport.as_mut(),
611 serde_json::json!({
612 "subtype": "mcp_toggle",
613 "serverName": server_name.into(),
614 "enabled": enabled,
615 }),
616 &self.control_callbacks,
617 )
618 .await?;
619 Ok(())
620 }
621
622 pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
623 self.require_connected()?;
624 send_control_request_with_callbacks(
625 self.transport.as_mut(),
626 serde_json::json!({
627 "subtype": "stop_task",
628 "task_id": task_id.into(),
629 }),
630 &self.control_callbacks,
631 )
632 .await?;
633 Ok(())
634 }
635
636 pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
637 self.require_connected()?;
638 let response = send_control_request_with_callbacks(
639 self.transport.as_mut(),
640 serde_json::json!({"subtype": "mcp_status"}),
641 &self.control_callbacks,
642 )
643 .await?;
644 let value = serde_json::Value::Object(response);
645 Ok(serde_json::from_value(value)?)
646 }
647
648 pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
649 self.require_connected()?;
650 let response = send_control_request_with_callbacks(
651 self.transport.as_mut(),
652 serde_json::json!({"subtype": "get_context_usage"}),
653 &self.control_callbacks,
654 )
655 .await?;
656 Ok(serde_json::from_value(serde_json::Value::Object(response))?)
657 }
658
659 pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
660 self.initialization_result.as_ref()
661 }
662
663 fn build_user_payload(
664 &self,
665 content: &str,
666 session_id: Option<&str>,
667 ) -> Result<serde_json::Map<String, serde_json::Value>> {
668 let mut payload = serde_json::Map::new();
669 payload.insert(
670 "type".to_string(),
671 serde_json::Value::String("user".to_string()),
672 );
673 payload.insert(
674 "session_id".to_string(),
675 serde_json::Value::String(
676 session_id
677 .map(String::from)
678 .unwrap_or_else(|| self.session_id.clone()),
679 ),
680 );
681 let message = serde_json::json!({"role": "user", "content": content});
682 payload.insert("message".to_string(), message);
683 Ok(payload)
684 }
685}