1use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12 initialize_request, initialize_timeout_duration, respond_to_control_request,
13 send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14 ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18 apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24 ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25 PermissionMode,
26};
27
/// Mutable state shared by a [`ClaudeAgentClient`], kept behind a Tokio `RwLock`.
#[derive(Debug)]
// Only `messages` is both written and read (by `get_conversation_history`); `is_streaming`
// is written but never read, and the other fields are never touched in this file. The
// allow keeps them compiling without dead-code warnings.
// NOTE(review): confirm whether `current_stream_buffer` / `server_info` are still needed.
#[allow(dead_code)]
struct ClientState {
    /// Every parsed CLI message received so far, in arrival order.
    messages: Vec<Message>,
    /// Not written or read anywhere in this file.
    current_stream_buffer: String,
    /// Set to `true` when a streaming call starts reading a response and cleared when
    /// that response's result message arrives.
    is_streaming: bool,
    /// Never populated here; the client exposes the initialize response through
    /// `ClaudeAgentClient::get_server_info` instead.
    server_info: Option<HashMap<String, serde_json::Value>>,
}
36
/// Client for a Claude CLI session driven over a [`Transport`].
///
/// Build it with [`ClaudeAgentClient::new`] (subprocess transport) or
/// [`ClaudeAgentClient::with_transport`], then call [`ClaudeAgentClient::connect`] before
/// sending queries.
pub struct ClaudeAgentClient {
    /// Channel to the CLI; every read and write goes through it. It may be replaced with a
    /// freshly built subprocess transport when a resume session is materialized.
    transport: Box<dyn Transport>,
    /// Shared message history and streaming flag.
    state: Arc<RwLock<ClientState>>,
    /// Session id stamped on outgoing user payloads (from `session_id`, else `resume`,
    /// else `"default"`).
    session_id: String,
    /// Whether `transport.connect()` has succeeded since construction or the last abort.
    connected: bool,
    /// Whether the initialize control request has completed since the last connect.
    initialized: bool,
    /// Response of the initialize control request, returned by `get_server_info`.
    initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
    /// Callbacks used to answer control requests coming from the CLI.
    control_callbacks: ControlCallbacks,
    /// Optional batcher for `transcript_mirror` frames.
    transcript_mirror: Option<TranscriptMirrorBatcher>,
    /// Options the client was created from; only set by `new`, and required for resume
    /// materialization before connecting.
    source_options: Option<ClaudeAgentOptions>,
    /// Materialized resume session, cleaned up by `abort`/`close`.
    materialized_resume: Option<MaterializedResume>,
}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("ClaudeAgentClient")
53 .field("session_id", &self.session_id)
54 .finish_non_exhaustive()
55 }
56}
57
58impl ClaudeAgentClient {
59 pub fn spawn_stream_message(
70 options: ClaudeAgentOptions,
71 content: impl Into<String>,
72 ) -> mpsc::UnboundedReceiver<StreamEvent> {
73 let content = content.into();
74 let (tx, rx) = mpsc::unbounded_channel();
75 tokio::spawn(async move {
76 if let Err(err) = Self::run_spawned_stream(options, content, tx.clone()).await {
77 let _ = tx.send(StreamEvent::Error(err.to_string()));
78 }
79 });
80 rx
81 }
82
83 async fn run_spawned_stream(
84 options: ClaudeAgentOptions,
85 content: String,
86 tx: mpsc::UnboundedSender<StreamEvent>,
87 ) -> Result<()> {
88 let mut client = Self::new(options)?;
89 client.connect().await?;
90 client.require_connected()?;
91 let payload = client.build_user_payload(&content, None)?;
92 let json_payload = serde_json::to_vec(&payload)?;
93 client.transport.write(&json_payload).await?;
94 client.transport.write(b"\n").await?;
95 {
96 let mut state = client.state.write().await;
97 state.is_streaming = true;
98 }
99 while let Some(data) = client.transport.read().await? {
100 let line = String::from_utf8_lossy(&data);
101 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
102 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
103 respond_to_control_request(
104 client.transport.as_mut(),
105 &value,
106 &client.control_callbacks,
107 )
108 .await?;
109 continue;
110 }
111 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
112 if let Some(batcher) = &mut client.transcript_mirror {
113 for message in batcher.enqueue_value(&value).await? {
114 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
115 }
116 }
117 continue;
118 }
119 let message = match parse_message_line(&line) {
120 Ok(Some(message)) => message,
121 Ok(None) => continue,
122 Err(err) => {
123 tracing::warn!("skipping unparseable CLI message: {err}");
126 continue;
127 }
128 };
129 for event in stream_events_from_message(&message, &client.session_id) {
130 let _ = tx.send(event);
131 }
132 let done = matches!(message, Message::ResultMsg { .. });
133 if done {
134 if let Some(batcher) = &mut client.transcript_mirror {
135 for message in batcher.flush().await? {
136 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
137 }
138 }
139 }
140 {
141 let mut state = client.state.write().await;
142 state.messages.push(message);
143 if done {
144 state.is_streaming = false;
145 }
146 }
147 if done {
148 break;
149 }
150 }
151 Ok(())
152 }
153
154 pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
155 validate_session_store_options(&options)?;
156 let transport_options = TransportOptions::from(&options);
157 let transport = SubprocessCLITransport::new(transport_options);
158 let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
159 client.source_options = Some(options);
160 Ok(client)
161 }
162
163 pub fn with_transport(
164 options: ClaudeAgentOptions,
165 transport: Box<dyn Transport>,
166 ) -> Result<Self> {
167 let session_id = options
168 .session_id
169 .clone()
170 .or_else(|| options.resume.clone())
171 .unwrap_or_else(|| "default".to_string());
172 let state = Arc::new(RwLock::new(ClientState {
173 messages: Vec::new(),
174 current_stream_buffer: String::new(),
175 is_streaming: false,
176 server_info: None,
177 }));
178 Ok(Self {
179 transport,
180 state,
181 session_id,
182 connected: false,
183 initialized: false,
184 initialization_result: None,
185 control_callbacks: ControlCallbacks::from_options(&options),
186 transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
187 source_options: None,
188 materialized_resume: None,
189 })
190 }
191
192 pub async fn connect(&mut self) -> Result<()> {
193 if !self.connected {
194 self.materialize_resume_before_connect().await?;
195 self.transport.connect().await?;
196 self.connected = true;
197 }
198 self.ensure_initialized().await?;
199 Ok(())
200 }
201
202 pub async fn connect_with_prompt(&mut self, content: impl Into<String>) -> Result<()> {
203 self.connect().await?;
204 let content = content.into();
205 let payload = self.build_user_payload(&content, None)?;
206 let mut json_payload = serde_json::to_vec(&payload)?;
207 json_payload.push(b'\n');
208 self.transport.write(&json_payload).await
209 }
210
211 pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
212 where
213 S: Stream<Item = serde_json::Value> + Unpin,
214 {
215 self.connect().await?;
216 self.write_message_stream(stream, "default").await
217 }
218
219 async fn materialize_resume_before_connect(&mut self) -> Result<()> {
220 let Some(options) = self.source_options.clone() else {
221 return Ok(());
222 };
223 let Some(materialized) = materialize_resume_session(&options).await? else {
224 return Ok(());
225 };
226 let options = apply_materialized_options(&options, &materialized);
227 self.session_id = options
228 .session_id
229 .clone()
230 .or_else(|| options.resume.clone())
231 .unwrap_or_else(|| "default".to_string());
232 self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
233 &options,
234 )));
235 self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
236 self.source_options = Some(options);
237 self.materialized_resume = Some(materialized);
238 Ok(())
239 }
240
241 fn require_connected(&self) -> Result<()> {
242 if self.connected && self.initialized {
243 Ok(())
244 } else {
245 Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
246 }
247 }
248
249 async fn ensure_initialized(&mut self) -> Result<()> {
250 if self.initialized {
251 return Ok(());
252 }
253
254 let response = send_control_request_with_callbacks_and_timeout(
255 self.transport.as_mut(),
256 initialize_request(&self.control_callbacks),
257 &self.control_callbacks,
258 initialize_timeout_duration(),
259 )
260 .await?;
261 self.initialization_result = Some(response);
262 self.initialized = true;
263 Ok(())
264 }
265
266 pub async fn send_message(&mut self, content: impl Into<String>) -> Result<MessageResponse> {
267 self.query(content).await?;
268 let messages = self.receive_response().await?;
269 let mut content_parts: Vec<String> = Vec::new();
270 let mut blocks: Vec<ContentBlock> = Vec::new();
271 let mut usage: Option<HashMap<String, serde_json::Value>> = None;
272 let mut stop_reason: Option<String> = None;
273 let mut model = String::new();
274
275 for message in messages {
276 match message {
277 Message::AssistantMsg {
278 content: assistant_content,
279 ..
280 } => {
281 if model.is_empty() {
283 model.clone_from(&assistant_content.model);
284 }
285 for block in &assistant_content.content {
286 match block {
287 ContentBlock::Text { text } => content_parts.push(text.clone()),
288 ContentBlock::Thinking { thinking, .. } => {
289 content_parts.push(thinking.clone())
290 }
291 _ => {}
292 }
293 blocks.push(block.clone());
294 }
295 }
296 Message::ResultMsg {
297 stop_reason: reason,
298 usage: u,
299 ..
300 } => {
301 stop_reason = reason;
302 if let Some(u) = u {
303 usage = Some(u.into_iter().collect());
304 }
305 }
306 _ => {}
307 }
308 }
309
310 Ok(MessageResponse {
311 content: content_parts.join(""),
312 blocks,
313 model,
314 stop_reason,
315 session_id: self.session_id.clone(),
316 usage,
317 })
318 }
319
320 pub async fn query(&mut self, content: impl Into<String>) -> Result<()> {
321 self.require_connected()?;
322 let content_str = content.into();
323 let payload = self.build_user_payload(&content_str, None)?;
324 let mut json_payload = serde_json::to_vec(&payload)?;
325 json_payload.push(b'\n');
326 self.transport.write(&json_payload).await
327 }
328
329 pub async fn query_with_session_id(
330 &mut self,
331 content: impl Into<String>,
332 session_id: impl Into<String>,
333 ) -> Result<()> {
334 self.require_connected()?;
335 let content_str = content.into();
336 let session_id = session_id.into();
337 let payload = self.build_user_payload(&content_str, Some(&session_id))?;
338 let mut json_payload = serde_json::to_vec(&payload)?;
339 json_payload.push(b'\n');
340 self.transport.write(&json_payload).await
341 }
342
343 pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
344 where
345 S: Stream<Item = serde_json::Value> + Unpin,
346 {
347 self.query_stream_with_session_id(stream, "default").await
348 }
349
350 pub async fn query_stream_with_session_id<S>(
351 &mut self,
352 stream: S,
353 session_id: impl Into<String>,
354 ) -> Result<()>
355 where
356 S: Stream<Item = serde_json::Value> + Unpin,
357 {
358 self.require_connected()?;
359 self.write_message_stream(stream, &session_id.into()).await
360 }
361
362 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
363 self.receive_messages_until(true).await
364 }
365
366 pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
367 self.receive_messages_until(false).await
368 }
369
370 async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
371 self.require_connected()?;
372 let mut messages = Vec::new();
373 while let Some(data) = self.transport.read().await? {
374 let line = String::from_utf8_lossy(&data);
375 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
376 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
377 respond_to_control_request(
378 self.transport.as_mut(),
379 &value,
380 &self.control_callbacks,
381 )
382 .await?;
383 continue;
384 }
385 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
386 if let Some(batcher) = &mut self.transcript_mirror {
387 messages.extend(batcher.enqueue_value(&value).await?);
388 }
389 continue;
390 }
391 let message = match parse_message_line(&line) {
392 Ok(Some(message)) => message,
393 Ok(None) => continue,
394 Err(err) => {
395 tracing::warn!("skipping unparseable CLI message: {err}");
398 continue;
399 }
400 };
401 let done = matches!(message, Message::ResultMsg { .. });
402 if done {
403 if let Some(batcher) = &mut self.transcript_mirror {
404 messages.extend(batcher.flush().await?);
405 }
406 }
407 {
408 let mut state = self.state.write().await;
409 state.messages.push(message.clone());
410 }
411 messages.push(message);
412 if stop_at_result && done {
413 break;
414 }
415 }
416 Ok(messages)
417 }
418
419 pub async fn stream_message(
420 &mut self,
421 content: impl Into<String>,
422 ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
423 self.require_connected()?;
424 let content_str = content.into();
425 let payload = self.build_user_payload(&content_str, None)?;
426 let json_payload = serde_json::to_vec(&payload)?;
427 self.transport.write(&json_payload).await?;
428 self.transport
429 .write(
430 b"
431",
432 )
433 .await?;
434 let (tx, rx) = mpsc::unbounded_channel();
435 {
436 let mut state = self.state.write().await;
437 state.is_streaming = true;
438 }
439 while let Some(data) = self.transport.read().await? {
440 let line = String::from_utf8_lossy(&data);
441 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
442 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
443 respond_to_control_request(
444 self.transport.as_mut(),
445 &value,
446 &self.control_callbacks,
447 )
448 .await?;
449 continue;
450 }
451 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
452 if let Some(batcher) = &mut self.transcript_mirror {
453 for message in batcher.enqueue_value(&value).await? {
454 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
455 }
456 }
457 continue;
458 }
459 let message = match parse_message_line(&line) {
460 Ok(Some(message)) => message,
461 Ok(None) => continue,
462 Err(err) => {
463 tracing::warn!("skipping unparseable CLI message: {err}");
466 continue;
467 }
468 };
469 for event in stream_events_from_message(&message, &self.session_id) {
470 let _ = tx.send(event);
471 }
472 let done = matches!(message, Message::ResultMsg { .. });
473 if done {
474 if let Some(batcher) = &mut self.transcript_mirror {
475 for message in batcher.flush().await? {
476 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
477 }
478 }
479 }
480 {
481 let mut state = self.state.write().await;
482 state.messages.push(message);
483 if done {
484 state.is_streaming = false;
485 }
486 }
487 if done {
488 break;
489 }
490 }
491 Ok(rx)
492 }
493
494 async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
495 where
496 S: Stream<Item = serde_json::Value> + Unpin,
497 {
498 while let Some(mut message) = stream.next().await {
499 if let Some(object) = message.as_object_mut() {
500 object
501 .entry("session_id")
502 .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
503 }
504 let mut json_payload = serde_json::to_vec(&message)?;
505 json_payload.push(b'\n');
506 self.transport.write(&json_payload).await?;
507 }
508 Ok(())
509 }
510
511 pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
512 let state = self.state.read().await;
513 Ok(state.messages.clone())
514 }
515
516 pub async fn abort(&mut self) -> Result<()> {
517 if let Some(batcher) = &mut self.transcript_mirror {
518 let _ = batcher.flush().await?;
519 }
520 self.transport.close().await?;
521 if let Some(materialized) = &self.materialized_resume {
522 materialized.cleanup().await;
523 }
524 self.materialized_resume = None;
525 self.connected = false;
526 self.initialized = false;
527 Ok(())
528 }
529
530 pub async fn disconnect(&mut self) -> Result<()> {
531 self.abort().await
532 }
533
534 pub async fn close(mut self) -> Result<()> {
535 if let Some(batcher) = &mut self.transcript_mirror {
536 let _ = batcher.flush().await?;
537 }
538 self.transport.close().await?;
539 if let Some(materialized) = &self.materialized_resume {
540 materialized.cleanup().await;
541 }
542 Ok(())
543 }
544
545 pub async fn interrupt(&mut self) -> Result<()> {
546 self.require_connected()?;
547 send_control_request_with_callbacks(
548 self.transport.as_mut(),
549 serde_json::json!({"subtype": "interrupt"}),
550 &self.control_callbacks,
551 )
552 .await?;
553 Ok(())
554 }
555
556 pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
557 self.require_connected()?;
558 send_control_request_with_callbacks(
559 self.transport.as_mut(),
560 serde_json::json!({
561 "subtype": "set_permission_mode",
562 "mode": mode,
563 }),
564 &self.control_callbacks,
565 )
566 .await?;
567 Ok(())
568 }
569
570 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
571 self.require_connected()?;
572 let model = model.map(serde_json::Value::String);
573 send_control_request_with_callbacks(
574 self.transport.as_mut(),
575 serde_json::json!({
576 "subtype": "set_model",
577 "model": model.unwrap_or(serde_json::Value::Null),
578 }),
579 &self.control_callbacks,
580 )
581 .await?;
582 Ok(())
583 }
584
585 pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
586 self.require_connected()?;
587 send_control_request_with_callbacks(
588 self.transport.as_mut(),
589 serde_json::json!({
590 "subtype": "rewind_files",
591 "user_message_id": user_message_id.into(),
592 }),
593 &self.control_callbacks,
594 )
595 .await?;
596 Ok(())
597 }
598
599 pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
600 self.require_connected()?;
601 send_control_request_with_callbacks(
602 self.transport.as_mut(),
603 serde_json::json!({
604 "subtype": "mcp_reconnect",
605 "serverName": server_name.into(),
606 }),
607 &self.control_callbacks,
608 )
609 .await?;
610 Ok(())
611 }
612
613 pub async fn toggle_mcp_server(
614 &mut self,
615 server_name: impl Into<String>,
616 enabled: bool,
617 ) -> Result<()> {
618 self.require_connected()?;
619 send_control_request_with_callbacks(
620 self.transport.as_mut(),
621 serde_json::json!({
622 "subtype": "mcp_toggle",
623 "serverName": server_name.into(),
624 "enabled": enabled,
625 }),
626 &self.control_callbacks,
627 )
628 .await?;
629 Ok(())
630 }
631
632 pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
633 self.require_connected()?;
634 send_control_request_with_callbacks(
635 self.transport.as_mut(),
636 serde_json::json!({
637 "subtype": "stop_task",
638 "task_id": task_id.into(),
639 }),
640 &self.control_callbacks,
641 )
642 .await?;
643 Ok(())
644 }
645
646 pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
647 self.require_connected()?;
648 let response = send_control_request_with_callbacks(
649 self.transport.as_mut(),
650 serde_json::json!({"subtype": "mcp_status"}),
651 &self.control_callbacks,
652 )
653 .await?;
654 let value = serde_json::Value::Object(response);
655 Ok(serde_json::from_value(value)?)
656 }
657
658 pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
659 self.require_connected()?;
660 let response = send_control_request_with_callbacks(
661 self.transport.as_mut(),
662 serde_json::json!({"subtype": "get_context_usage"}),
663 &self.control_callbacks,
664 )
665 .await?;
666 Ok(serde_json::from_value(serde_json::Value::Object(response))?)
667 }
668
669 pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
670 self.initialization_result.as_ref()
671 }
672
673 fn build_user_payload(
674 &self,
675 content: &str,
676 session_id: Option<&str>,
677 ) -> Result<serde_json::Map<String, serde_json::Value>> {
678 let mut payload = serde_json::Map::new();
679 payload.insert(
680 "type".to_string(),
681 serde_json::Value::String("user".to_string()),
682 );
683 payload.insert(
684 "session_id".to_string(),
685 serde_json::Value::String(
686 session_id
687 .map(String::from)
688 .unwrap_or_else(|| self.session_id.clone()),
689 ),
690 );
691 let message = serde_json::json!({"role": "user", "content": content});
692 payload.insert("message".to_string(), message);
693 Ok(payload)
694 }
695}