1use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12 initialize_request, initialize_timeout_duration, respond_to_control_request,
13 send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14 ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18 apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24 ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25 PermissionMode, UserMessageInput,
26};
27
/// Mutable bookkeeping the client keeps behind an `Arc<RwLock<..>>`.
// `dead_code` is allowed because, in the part of the file visible here, most
// fields are only written and never read.
// NOTE(review): confirm no reader exists elsewhere before removing a field.
#[derive(Debug)]
#[allow(dead_code)]
struct ClientState {
    /// Parsed messages appended by the read loops, in the order they were pushed.
    messages: Vec<Message>,
    /// Created empty; nothing in this section writes to it afterwards.
    current_stream_buffer: String,
    /// Set when a streaming exchange starts and cleared once its result message is seen.
    is_streaming: bool,
    /// Created as `None`; nothing in this section assigns it.
    server_info: Option<HashMap<String, serde_json::Value>>,
}
36
/// Client that exchanges newline-delimited JSON messages with the CLI through a [`Transport`].
pub struct ClaudeAgentClient {
    /// Channel to the CLI; replaced with a fresh subprocess transport when a
    /// resume session is materialized before connecting.
    transport: Box<dyn Transport>,
    /// Shared bookkeeping (message history, streaming flag).
    state: Arc<RwLock<ClientState>>,
    /// Taken from `options.session_id`, else `options.resume`, else `"default"`.
    session_id: String,
    /// `true` once `transport.connect()` has succeeded; reset by `abort`.
    connected: bool,
    /// `true` once the initialize control request has succeeded; reset by `abort`.
    initialized: bool,
    /// Response to the initialize control request, exposed by `get_server_info`.
    initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
    /// Callbacks built from the options and passed to every control exchange.
    control_callbacks: ControlCallbacks,
    /// Optional batcher that receives `"transcript_mirror"` lines.
    transcript_mirror: Option<TranscriptMirrorBatcher>,
    /// Options the client was built from. Only `new` sets this; with
    /// `with_transport` it stays `None`, which skips resume materialization.
    source_options: Option<ClaudeAgentOptions>,
    /// Materialized resume session, cleaned up by `abort` and `close`.
    materialized_resume: Option<MaterializedResume>,
}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("ClaudeAgentClient")
53 .field("session_id", &self.session_id)
54 .finish_non_exhaustive()
55 }
56}
57
58impl ClaudeAgentClient {
59 pub fn spawn_stream_message(
73 options: ClaudeAgentOptions,
74 content: impl Into<UserMessageInput>,
75 ) -> mpsc::UnboundedReceiver<StreamEvent> {
76 let content = content.into();
77 let (tx, rx) = mpsc::unbounded_channel();
78 tokio::spawn(async move {
79 if let Err(err) = Self::run_spawned_stream(options, content, tx.clone()).await {
80 let _ = tx.send(StreamEvent::Error(err.to_string()));
81 }
82 });
83 rx
84 }
85
86 async fn run_spawned_stream(
87 options: ClaudeAgentOptions,
88 content: UserMessageInput,
89 tx: mpsc::UnboundedSender<StreamEvent>,
90 ) -> Result<()> {
91 let mut client = Self::new(options)?;
92 client.connect().await?;
93 client.require_connected()?;
94 let payload = client.build_user_payload(&content, None)?;
95 let json_payload = serde_json::to_vec(&payload)?;
96 client.transport.write(&json_payload).await?;
97 client.transport.write(b"\n").await?;
98 {
99 let mut state = client.state.write().await;
100 state.is_streaming = true;
101 }
102 while let Some(data) = client.transport.read().await? {
103 let line = String::from_utf8_lossy(&data);
104 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
105 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
106 respond_to_control_request(
107 client.transport.as_mut(),
108 &value,
109 &client.control_callbacks,
110 )
111 .await?;
112 continue;
113 }
114 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
115 if let Some(batcher) = &mut client.transcript_mirror {
116 for message in batcher.enqueue_value(&value).await? {
117 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
118 }
119 }
120 continue;
121 }
122 let message = match parse_message_line(&line) {
123 Ok(Some(message)) => message,
124 Ok(None) => continue,
125 Err(err) => {
126 tracing::warn!("skipping unparseable CLI message: {err}");
129 continue;
130 }
131 };
132 for event in stream_events_from_message(&message, &client.session_id) {
133 let _ = tx.send(event);
134 }
135 let done = matches!(message, Message::ResultMsg { .. });
136 if done {
137 if let Some(batcher) = &mut client.transcript_mirror {
138 for message in batcher.flush().await? {
139 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
140 }
141 }
142 }
143 {
144 let mut state = client.state.write().await;
145 state.messages.push(message);
146 if done {
147 state.is_streaming = false;
148 }
149 }
150 if done {
151 break;
152 }
153 }
154 Ok(())
155 }
156
157 pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
158 validate_session_store_options(&options)?;
159 let transport_options = TransportOptions::from(&options);
160 let transport = SubprocessCLITransport::new(transport_options);
161 let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
162 client.source_options = Some(options);
163 Ok(client)
164 }
165
166 pub fn with_transport(
167 options: ClaudeAgentOptions,
168 transport: Box<dyn Transport>,
169 ) -> Result<Self> {
170 let session_id = options
171 .session_id
172 .clone()
173 .or_else(|| options.resume.clone())
174 .unwrap_or_else(|| "default".to_string());
175 let state = Arc::new(RwLock::new(ClientState {
176 messages: Vec::new(),
177 current_stream_buffer: String::new(),
178 is_streaming: false,
179 server_info: None,
180 }));
181 Ok(Self {
182 transport,
183 state,
184 session_id,
185 connected: false,
186 initialized: false,
187 initialization_result: None,
188 control_callbacks: ControlCallbacks::from_options(&options),
189 transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
190 source_options: None,
191 materialized_resume: None,
192 })
193 }
194
195 pub async fn connect(&mut self) -> Result<()> {
196 if !self.connected {
197 self.materialize_resume_before_connect().await?;
198 self.transport.connect().await?;
199 self.connected = true;
200 }
201 self.ensure_initialized().await?;
202 Ok(())
203 }
204
205 pub async fn connect_with_prompt(
206 &mut self,
207 content: impl Into<UserMessageInput>,
208 ) -> Result<()> {
209 self.connect().await?;
210 let content = content.into();
211 let payload = self.build_user_payload(&content, None)?;
212 let mut json_payload = serde_json::to_vec(&payload)?;
213 json_payload.push(b'\n');
214 self.transport.write(&json_payload).await
215 }
216
217 pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
218 where
219 S: Stream<Item = serde_json::Value> + Unpin,
220 {
221 self.connect().await?;
222 self.write_message_stream(stream, "default").await
223 }
224
225 async fn materialize_resume_before_connect(&mut self) -> Result<()> {
226 let Some(options) = self.source_options.clone() else {
227 return Ok(());
228 };
229 let Some(materialized) = materialize_resume_session(&options).await? else {
230 return Ok(());
231 };
232 let options = apply_materialized_options(&options, &materialized);
233 self.session_id = options
234 .session_id
235 .clone()
236 .or_else(|| options.resume.clone())
237 .unwrap_or_else(|| "default".to_string());
238 self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
239 &options,
240 )));
241 self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
242 self.source_options = Some(options);
243 self.materialized_resume = Some(materialized);
244 Ok(())
245 }
246
247 fn require_connected(&self) -> Result<()> {
248 if self.connected && self.initialized {
249 Ok(())
250 } else {
251 Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
252 }
253 }
254
255 async fn ensure_initialized(&mut self) -> Result<()> {
256 if self.initialized {
257 return Ok(());
258 }
259
260 let response = send_control_request_with_callbacks_and_timeout(
261 self.transport.as_mut(),
262 initialize_request(&self.control_callbacks),
263 &self.control_callbacks,
264 initialize_timeout_duration(),
265 )
266 .await?;
267 self.initialization_result = Some(response);
268 self.initialized = true;
269 Ok(())
270 }
271
272 pub async fn send_message(
273 &mut self,
274 content: impl Into<UserMessageInput>,
275 ) -> Result<MessageResponse> {
276 self.query(content).await?;
277 let messages = self.receive_response().await?;
278 let mut content_parts: Vec<String> = Vec::new();
279 let mut blocks: Vec<ContentBlock> = Vec::new();
280 let mut usage: Option<HashMap<String, serde_json::Value>> = None;
281 let mut stop_reason: Option<String> = None;
282 let mut model = String::new();
283
284 for message in messages {
285 match message {
286 Message::AssistantMsg {
287 content: assistant_content,
288 ..
289 } => {
290 if model.is_empty() {
292 model.clone_from(&assistant_content.model);
293 }
294 for block in &assistant_content.content {
295 match block {
296 ContentBlock::Text { text } => content_parts.push(text.clone()),
297 ContentBlock::Thinking { thinking, .. } => {
298 content_parts.push(thinking.clone())
299 }
300 _ => {}
301 }
302 blocks.push(block.clone());
303 }
304 }
305 Message::ResultMsg {
306 stop_reason: reason,
307 usage: u,
308 ..
309 } => {
310 stop_reason = reason;
311 if let Some(u) = u {
312 usage = Some(u.into_iter().collect());
313 }
314 }
315 _ => {}
316 }
317 }
318
319 Ok(MessageResponse {
320 content: content_parts.join(""),
321 blocks,
322 model,
323 stop_reason,
324 session_id: self.session_id.clone(),
325 usage,
326 })
327 }
328
329 pub async fn query(&mut self, content: impl Into<UserMessageInput>) -> Result<()> {
330 self.require_connected()?;
331 let content = content.into();
332 let payload = self.build_user_payload(&content, None)?;
333 let mut json_payload = serde_json::to_vec(&payload)?;
334 json_payload.push(b'\n');
335 self.transport.write(&json_payload).await
336 }
337
338 pub async fn query_with_session_id(
339 &mut self,
340 content: impl Into<UserMessageInput>,
341 session_id: impl Into<String>,
342 ) -> Result<()> {
343 self.require_connected()?;
344 let content = content.into();
345 let session_id = session_id.into();
346 let payload = self.build_user_payload(&content, Some(&session_id))?;
347 let mut json_payload = serde_json::to_vec(&payload)?;
348 json_payload.push(b'\n');
349 self.transport.write(&json_payload).await
350 }
351
352 pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
353 where
354 S: Stream<Item = serde_json::Value> + Unpin,
355 {
356 self.query_stream_with_session_id(stream, "default").await
357 }
358
359 pub async fn query_stream_with_session_id<S>(
360 &mut self,
361 stream: S,
362 session_id: impl Into<String>,
363 ) -> Result<()>
364 where
365 S: Stream<Item = serde_json::Value> + Unpin,
366 {
367 self.require_connected()?;
368 self.write_message_stream(stream, &session_id.into()).await
369 }
370
371 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
372 self.receive_messages_until(true).await
373 }
374
375 pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
376 self.receive_messages_until(false).await
377 }
378
379 async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
380 self.require_connected()?;
381 let mut messages = Vec::new();
382 while let Some(data) = self.transport.read().await? {
383 let line = String::from_utf8_lossy(&data);
384 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
385 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
386 respond_to_control_request(
387 self.transport.as_mut(),
388 &value,
389 &self.control_callbacks,
390 )
391 .await?;
392 continue;
393 }
394 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
395 if let Some(batcher) = &mut self.transcript_mirror {
396 messages.extend(batcher.enqueue_value(&value).await?);
397 }
398 continue;
399 }
400 let message = match parse_message_line(&line) {
401 Ok(Some(message)) => message,
402 Ok(None) => continue,
403 Err(err) => {
404 tracing::warn!("skipping unparseable CLI message: {err}");
407 continue;
408 }
409 };
410 let done = matches!(message, Message::ResultMsg { .. });
411 if done {
412 if let Some(batcher) = &mut self.transcript_mirror {
413 messages.extend(batcher.flush().await?);
414 }
415 }
416 {
417 let mut state = self.state.write().await;
418 state.messages.push(message.clone());
419 }
420 messages.push(message);
421 if stop_at_result && done {
422 break;
423 }
424 }
425 Ok(messages)
426 }
427
428 pub async fn stream_message(
429 &mut self,
430 content: impl Into<UserMessageInput>,
431 ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
432 self.require_connected()?;
433 let content = content.into();
434 let payload = self.build_user_payload(&content, None)?;
435 let json_payload = serde_json::to_vec(&payload)?;
436 self.transport.write(&json_payload).await?;
437 self.transport
438 .write(
439 b"
440",
441 )
442 .await?;
443 let (tx, rx) = mpsc::unbounded_channel();
444 {
445 let mut state = self.state.write().await;
446 state.is_streaming = true;
447 }
448 while let Some(data) = self.transport.read().await? {
449 let line = String::from_utf8_lossy(&data);
450 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
451 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
452 respond_to_control_request(
453 self.transport.as_mut(),
454 &value,
455 &self.control_callbacks,
456 )
457 .await?;
458 continue;
459 }
460 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
461 if let Some(batcher) = &mut self.transcript_mirror {
462 for message in batcher.enqueue_value(&value).await? {
463 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
464 }
465 }
466 continue;
467 }
468 let message = match parse_message_line(&line) {
469 Ok(Some(message)) => message,
470 Ok(None) => continue,
471 Err(err) => {
472 tracing::warn!("skipping unparseable CLI message: {err}");
475 continue;
476 }
477 };
478 for event in stream_events_from_message(&message, &self.session_id) {
479 let _ = tx.send(event);
480 }
481 let done = matches!(message, Message::ResultMsg { .. });
482 if done {
483 if let Some(batcher) = &mut self.transcript_mirror {
484 for message in batcher.flush().await? {
485 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
486 }
487 }
488 }
489 {
490 let mut state = self.state.write().await;
491 state.messages.push(message);
492 if done {
493 state.is_streaming = false;
494 }
495 }
496 if done {
497 break;
498 }
499 }
500 Ok(rx)
501 }
502
503 async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
504 where
505 S: Stream<Item = serde_json::Value> + Unpin,
506 {
507 while let Some(mut message) = stream.next().await {
508 if let Some(object) = message.as_object_mut() {
509 object
510 .entry("session_id")
511 .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
512 }
513 let mut json_payload = serde_json::to_vec(&message)?;
514 json_payload.push(b'\n');
515 self.transport.write(&json_payload).await?;
516 }
517 Ok(())
518 }
519
520 pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
521 let state = self.state.read().await;
522 Ok(state.messages.clone())
523 }
524
525 pub async fn abort(&mut self) -> Result<()> {
526 if let Some(batcher) = &mut self.transcript_mirror {
527 let _ = batcher.flush().await?;
528 }
529 self.transport.close().await?;
530 if let Some(materialized) = &self.materialized_resume {
531 materialized.cleanup().await;
532 }
533 self.materialized_resume = None;
534 self.connected = false;
535 self.initialized = false;
536 Ok(())
537 }
538
539 pub async fn disconnect(&mut self) -> Result<()> {
540 self.abort().await
541 }
542
543 pub async fn close(mut self) -> Result<()> {
544 if let Some(batcher) = &mut self.transcript_mirror {
545 let _ = batcher.flush().await?;
546 }
547 self.transport.close().await?;
548 if let Some(materialized) = &self.materialized_resume {
549 materialized.cleanup().await;
550 }
551 Ok(())
552 }
553
554 pub async fn interrupt(&mut self) -> Result<()> {
555 self.require_connected()?;
556 send_control_request_with_callbacks(
557 self.transport.as_mut(),
558 serde_json::json!({"subtype": "interrupt"}),
559 &self.control_callbacks,
560 )
561 .await?;
562 Ok(())
563 }
564
565 pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
566 self.require_connected()?;
567 send_control_request_with_callbacks(
568 self.transport.as_mut(),
569 serde_json::json!({
570 "subtype": "set_permission_mode",
571 "mode": mode,
572 }),
573 &self.control_callbacks,
574 )
575 .await?;
576 Ok(())
577 }
578
579 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
580 self.require_connected()?;
581 let model = model.map(serde_json::Value::String);
582 send_control_request_with_callbacks(
583 self.transport.as_mut(),
584 serde_json::json!({
585 "subtype": "set_model",
586 "model": model.unwrap_or(serde_json::Value::Null),
587 }),
588 &self.control_callbacks,
589 )
590 .await?;
591 Ok(())
592 }
593
594 pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
595 self.require_connected()?;
596 send_control_request_with_callbacks(
597 self.transport.as_mut(),
598 serde_json::json!({
599 "subtype": "rewind_files",
600 "user_message_id": user_message_id.into(),
601 }),
602 &self.control_callbacks,
603 )
604 .await?;
605 Ok(())
606 }
607
608 pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
609 self.require_connected()?;
610 send_control_request_with_callbacks(
611 self.transport.as_mut(),
612 serde_json::json!({
613 "subtype": "mcp_reconnect",
614 "serverName": server_name.into(),
615 }),
616 &self.control_callbacks,
617 )
618 .await?;
619 Ok(())
620 }
621
622 pub async fn toggle_mcp_server(
623 &mut self,
624 server_name: impl Into<String>,
625 enabled: bool,
626 ) -> Result<()> {
627 self.require_connected()?;
628 send_control_request_with_callbacks(
629 self.transport.as_mut(),
630 serde_json::json!({
631 "subtype": "mcp_toggle",
632 "serverName": server_name.into(),
633 "enabled": enabled,
634 }),
635 &self.control_callbacks,
636 )
637 .await?;
638 Ok(())
639 }
640
641 pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
642 self.require_connected()?;
643 send_control_request_with_callbacks(
644 self.transport.as_mut(),
645 serde_json::json!({
646 "subtype": "stop_task",
647 "task_id": task_id.into(),
648 }),
649 &self.control_callbacks,
650 )
651 .await?;
652 Ok(())
653 }
654
655 pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
656 self.require_connected()?;
657 let response = send_control_request_with_callbacks(
658 self.transport.as_mut(),
659 serde_json::json!({"subtype": "mcp_status"}),
660 &self.control_callbacks,
661 )
662 .await?;
663 let value = serde_json::Value::Object(response);
664 Ok(serde_json::from_value(value)?)
665 }
666
667 pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
668 self.require_connected()?;
669 let response = send_control_request_with_callbacks(
670 self.transport.as_mut(),
671 serde_json::json!({"subtype": "get_context_usage"}),
672 &self.control_callbacks,
673 )
674 .await?;
675 Ok(serde_json::from_value(serde_json::Value::Object(response))?)
676 }
677
678 pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
679 self.initialization_result.as_ref()
680 }
681
682 fn build_user_payload(
683 &self,
684 content: &UserMessageInput,
685 session_id: Option<&str>,
686 ) -> Result<serde_json::Map<String, serde_json::Value>> {
687 let mut payload = serde_json::Map::new();
688 payload.insert(
689 "type".to_string(),
690 serde_json::Value::String("user".to_string()),
691 );
692 payload.insert(
693 "session_id".to_string(),
694 serde_json::Value::String(
695 session_id
696 .map(String::from)
697 .unwrap_or_else(|| self.session_id.clone()),
698 ),
699 );
700 let message = serde_json::json!({"role": "user", "content": content.to_content_value()});
703 payload.insert("message".to_string(), message);
704 Ok(payload)
705 }
706}