1use futures::{Stream, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7
8use crate::client_stream::stream_events_from_message;
9use crate::client_types::{MessageResponse, StreamEvent};
10use crate::error::{CLIConnectionError, Result};
11use crate::internal::control::{
12 initialize_request, initialize_timeout_duration, respond_to_control_request,
13 send_control_request_with_callbacks, send_control_request_with_callbacks_and_timeout,
14 ControlCallbacks,
15};
16use crate::internal::parser::parse_message_line;
17use crate::internal::session_resume::{
18 apply_materialized_options, materialize_resume_session, MaterializedResume,
19};
20use crate::internal::session_store_validation::validate_session_store_options;
21use crate::internal::transcript_mirror::TranscriptMirrorBatcher;
22use crate::internal::transport::{SubprocessCLITransport, Transport, TransportOptions};
23use crate::types::{
24 ClaudeAgentOptions, ContentBlock, ContextUsageResponse, MCPStatusResponse, Message,
25 PermissionMode,
26};
27
/// Mutable per-client state, shared behind an `Arc<RwLock<_>>` by `ClaudeAgentClient`.
#[derive(Debug)]
// `current_stream_buffer` and `server_info` are never read in this file, and
// `is_streaming` is only ever written, hence the lint allowance.
// NOTE(review): presumably kept for future use / API parity — confirm before removing.
#[allow(dead_code)]
struct ClientState {
    /// Conversation messages parsed from the transport by the receive/stream loops,
    /// in arrival order (transcript-mirror output is not recorded here).
    messages: Vec<Message>,
    /// Not read or updated by any code in this file.
    current_stream_buffer: String,
    /// Set to `true` by `stream_message` while it reads a response.
    is_streaming: bool,
    /// Not written anywhere in this file; server info is exposed via the
    /// client's `initialization_result` instead.
    server_info: Option<HashMap<String, serde_json::Value>>,
}
36
37pub struct ClaudeAgentClient {
38 transport: Box<dyn Transport>,
39 state: Arc<RwLock<ClientState>>,
40 session_id: String,
41 connected: bool,
42 initialized: bool,
43 initialization_result: Option<serde_json::Map<String, serde_json::Value>>,
44 control_callbacks: ControlCallbacks,
45 transcript_mirror: Option<TranscriptMirrorBatcher>,
46 source_options: Option<ClaudeAgentOptions>,
47 materialized_resume: Option<MaterializedResume>,
48}
49
50impl std::fmt::Debug for ClaudeAgentClient {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("ClaudeAgentClient")
53 .field("session_id", &self.session_id)
54 .finish_non_exhaustive()
55 }
56}
57
58impl ClaudeAgentClient {
59 pub fn new(options: ClaudeAgentOptions) -> Result<Self> {
60 validate_session_store_options(&options)?;
61 let transport_options = TransportOptions::from(&options);
62 let transport = SubprocessCLITransport::new(transport_options);
63 let mut client = Self::with_transport(options.clone(), Box::new(transport))?;
64 client.source_options = Some(options);
65 Ok(client)
66 }
67
68 pub fn with_transport(
69 options: ClaudeAgentOptions,
70 transport: Box<dyn Transport>,
71 ) -> Result<Self> {
72 let session_id = options
73 .session_id
74 .clone()
75 .or_else(|| options.resume.clone())
76 .unwrap_or_else(|| "default".to_string());
77 let state = Arc::new(RwLock::new(ClientState {
78 messages: Vec::new(),
79 current_stream_buffer: String::new(),
80 is_streaming: false,
81 server_info: None,
82 }));
83 Ok(Self {
84 transport,
85 state,
86 session_id,
87 connected: false,
88 initialized: false,
89 initialization_result: None,
90 control_callbacks: ControlCallbacks::from_options(&options),
91 transcript_mirror: TranscriptMirrorBatcher::from_options(&options),
92 source_options: None,
93 materialized_resume: None,
94 })
95 }
96
97 pub async fn connect(&mut self) -> Result<()> {
98 if !self.connected {
99 self.materialize_resume_before_connect().await?;
100 self.transport.connect().await?;
101 self.connected = true;
102 }
103 self.ensure_initialized().await?;
104 Ok(())
105 }
106
107 pub async fn connect_with_prompt(&mut self, content: impl Into<String>) -> Result<()> {
108 self.connect().await?;
109 let content = content.into();
110 let payload = self.build_user_payload(&content, None)?;
111 let mut json_payload = serde_json::to_vec(&payload)?;
112 json_payload.push(b'\n');
113 self.transport.write(&json_payload).await
114 }
115
116 pub async fn connect_with_stream<S>(&mut self, stream: S) -> Result<()>
117 where
118 S: Stream<Item = serde_json::Value> + Unpin,
119 {
120 self.connect().await?;
121 self.write_message_stream(stream, "default").await
122 }
123
124 async fn materialize_resume_before_connect(&mut self) -> Result<()> {
125 let Some(options) = self.source_options.clone() else {
126 return Ok(());
127 };
128 let Some(materialized) = materialize_resume_session(&options).await? else {
129 return Ok(());
130 };
131 let options = apply_materialized_options(&options, &materialized);
132 self.session_id = options
133 .session_id
134 .clone()
135 .or_else(|| options.resume.clone())
136 .unwrap_or_else(|| "default".to_string());
137 self.transport = Box::new(SubprocessCLITransport::new(TransportOptions::from(
138 &options,
139 )));
140 self.transcript_mirror = TranscriptMirrorBatcher::from_options(&options);
141 self.source_options = Some(options);
142 self.materialized_resume = Some(materialized);
143 Ok(())
144 }
145
146 fn require_connected(&self) -> Result<()> {
147 if self.connected && self.initialized {
148 Ok(())
149 } else {
150 Err(CLIConnectionError::new("Not connected. Call connect() first.").into())
151 }
152 }
153
154 async fn ensure_initialized(&mut self) -> Result<()> {
155 if self.initialized {
156 return Ok(());
157 }
158
159 let response = send_control_request_with_callbacks_and_timeout(
160 self.transport.as_mut(),
161 initialize_request(&self.control_callbacks),
162 &self.control_callbacks,
163 initialize_timeout_duration(),
164 )
165 .await?;
166 self.initialization_result = Some(response);
167 self.initialized = true;
168 Ok(())
169 }
170
171 pub async fn send_message(&mut self, content: impl Into<String>) -> Result<MessageResponse> {
172 self.query(content).await?;
173 let messages = self.receive_response().await?;
174 let mut content_parts: Vec<String> = Vec::new();
175 let mut blocks: Vec<ContentBlock> = Vec::new();
176 let mut usage: Option<HashMap<String, serde_json::Value>> = None;
177 let mut stop_reason: Option<String> = None;
178 let mut model = String::new();
179
180 for message in messages {
181 match message {
182 Message::AssistantMsg {
183 content: assistant_content,
184 ..
185 } => {
186 if model.is_empty() {
188 model.clone_from(&assistant_content.model);
189 }
190 for block in &assistant_content.content {
191 match block {
192 ContentBlock::Text { text } => content_parts.push(text.clone()),
193 ContentBlock::Thinking { thinking, .. } => {
194 content_parts.push(thinking.clone())
195 }
196 _ => {}
197 }
198 blocks.push(block.clone());
199 }
200 }
201 Message::ResultMsg {
202 stop_reason: reason,
203 usage: u,
204 ..
205 } => {
206 stop_reason = reason;
207 if let Some(u) = u {
208 usage = Some(u.into_iter().collect());
209 }
210 }
211 _ => {}
212 }
213 }
214
215 Ok(MessageResponse {
216 content: content_parts.join(""),
217 blocks,
218 model,
219 stop_reason,
220 session_id: self.session_id.clone(),
221 usage,
222 })
223 }
224
225 pub async fn query(&mut self, content: impl Into<String>) -> Result<()> {
226 self.require_connected()?;
227 let content_str = content.into();
228 let payload = self.build_user_payload(&content_str, None)?;
229 let mut json_payload = serde_json::to_vec(&payload)?;
230 json_payload.push(b'\n');
231 self.transport.write(&json_payload).await
232 }
233
234 pub async fn query_with_session_id(
235 &mut self,
236 content: impl Into<String>,
237 session_id: impl Into<String>,
238 ) -> Result<()> {
239 self.require_connected()?;
240 let content_str = content.into();
241 let session_id = session_id.into();
242 let payload = self.build_user_payload(&content_str, Some(&session_id))?;
243 let mut json_payload = serde_json::to_vec(&payload)?;
244 json_payload.push(b'\n');
245 self.transport.write(&json_payload).await
246 }
247
248 pub async fn query_stream<S>(&mut self, stream: S) -> Result<()>
249 where
250 S: Stream<Item = serde_json::Value> + Unpin,
251 {
252 self.query_stream_with_session_id(stream, "default").await
253 }
254
255 pub async fn query_stream_with_session_id<S>(
256 &mut self,
257 stream: S,
258 session_id: impl Into<String>,
259 ) -> Result<()>
260 where
261 S: Stream<Item = serde_json::Value> + Unpin,
262 {
263 self.require_connected()?;
264 self.write_message_stream(stream, &session_id.into()).await
265 }
266
267 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
268 self.receive_messages_until(true).await
269 }
270
271 pub async fn receive_messages(&mut self) -> Result<Vec<Message>> {
272 self.receive_messages_until(false).await
273 }
274
275 async fn receive_messages_until(&mut self, stop_at_result: bool) -> Result<Vec<Message>> {
276 self.require_connected()?;
277 let mut messages = Vec::new();
278 while let Some(data) = self.transport.read().await? {
279 let line = String::from_utf8_lossy(&data);
280 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
281 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
282 respond_to_control_request(
283 self.transport.as_mut(),
284 &value,
285 &self.control_callbacks,
286 )
287 .await?;
288 continue;
289 }
290 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
291 if let Some(batcher) = &mut self.transcript_mirror {
292 messages.extend(batcher.enqueue_value(&value).await?);
293 }
294 continue;
295 }
296 let Some(message) = parse_message_line(&line)? else {
297 continue;
298 };
299 let done = matches!(message, Message::ResultMsg { .. });
300 if done {
301 if let Some(batcher) = &mut self.transcript_mirror {
302 messages.extend(batcher.flush().await?);
303 }
304 }
305 {
306 let mut state = self.state.write().await;
307 state.messages.push(message.clone());
308 }
309 messages.push(message);
310 if stop_at_result && done {
311 break;
312 }
313 }
314 Ok(messages)
315 }
316
317 pub async fn stream_message(
318 &mut self,
319 content: impl Into<String>,
320 ) -> Result<mpsc::UnboundedReceiver<StreamEvent>> {
321 self.require_connected()?;
322 let content_str = content.into();
323 let payload = self.build_user_payload(&content_str, None)?;
324 let json_payload = serde_json::to_vec(&payload)?;
325 self.transport.write(&json_payload).await?;
326 self.transport
327 .write(
328 b"
329",
330 )
331 .await?;
332 let (tx, rx) = mpsc::unbounded_channel();
333 {
334 let mut state = self.state.write().await;
335 state.is_streaming = true;
336 }
337 while let Some(data) = self.transport.read().await? {
338 let line = String::from_utf8_lossy(&data);
339 let value = serde_json::from_slice::<serde_json::Value>(&data)?;
340 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
341 respond_to_control_request(
342 self.transport.as_mut(),
343 &value,
344 &self.control_callbacks,
345 )
346 .await?;
347 continue;
348 }
349 if value.get("type").and_then(|v| v.as_str()) == Some("transcript_mirror") {
350 if let Some(batcher) = &mut self.transcript_mirror {
351 for message in batcher.enqueue_value(&value).await? {
352 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
353 }
354 }
355 continue;
356 }
357 let Some(message) = parse_message_line(&line)? else {
358 continue;
359 };
360 for event in stream_events_from_message(&message, &self.session_id) {
361 let _ = tx.send(event);
362 }
363 let done = matches!(message, Message::ResultMsg { .. });
364 if done {
365 if let Some(batcher) = &mut self.transcript_mirror {
366 for message in batcher.flush().await? {
367 let _ = tx.send(StreamEvent::Error(format!("{message:?}")));
368 }
369 }
370 }
371 {
372 let mut state = self.state.write().await;
373 state.messages.push(message);
374 if done {
375 state.is_streaming = false;
376 }
377 }
378 if done {
379 break;
380 }
381 }
382 Ok(rx)
383 }
384
385 pub fn spawn_stream_message(
396 options: ClaudeAgentOptions,
397 prompt: impl Into<String>,
398 ) -> mpsc::UnboundedReceiver<StreamEvent> {
399 let prompt = prompt.into();
400 let (tx, rx) = mpsc::unbounded_channel();
401 tokio::spawn(async move {
402 let mut client = match ClaudeAgentClient::new(options) {
403 Ok(client) => client,
404 Err(err) => {
405 let _ = tx.send(StreamEvent::Error(err.to_string()));
406 return;
407 }
408 };
409 if let Err(err) = client.connect().await {
410 let _ = tx.send(StreamEvent::Error(err.to_string()));
411 return;
412 }
413 match client.stream_message(prompt).await {
414 Ok(mut events) => {
415 while let Some(event) = events.recv().await {
416 if tx.send(event).is_err() {
417 break;
418 }
419 }
420 }
421 Err(err) => {
422 let _ = tx.send(StreamEvent::Error(err.to_string()));
423 }
424 }
425 drop(client);
427 });
428 rx
429 }
430
431 async fn write_message_stream<S>(&mut self, mut stream: S, session_id: &str) -> Result<()>
432 where
433 S: Stream<Item = serde_json::Value> + Unpin,
434 {
435 while let Some(mut message) = stream.next().await {
436 if let Some(object) = message.as_object_mut() {
437 object
438 .entry("session_id")
439 .or_insert_with(|| serde_json::Value::String(session_id.to_string()));
440 }
441 let mut json_payload = serde_json::to_vec(&message)?;
442 json_payload.push(b'\n');
443 self.transport.write(&json_payload).await?;
444 }
445 Ok(())
446 }
447
448 pub async fn get_conversation_history(&self) -> Result<Vec<Message>> {
449 let state = self.state.read().await;
450 Ok(state.messages.clone())
451 }
452
453 pub async fn abort(&mut self) -> Result<()> {
454 if let Some(batcher) = &mut self.transcript_mirror {
455 let _ = batcher.flush().await?;
456 }
457 self.transport.close().await?;
458 if let Some(materialized) = &self.materialized_resume {
459 materialized.cleanup().await;
460 }
461 self.materialized_resume = None;
462 self.connected = false;
463 self.initialized = false;
464 Ok(())
465 }
466
467 pub async fn disconnect(&mut self) -> Result<()> {
468 self.abort().await
469 }
470
471 pub async fn close(mut self) -> Result<()> {
472 if let Some(batcher) = &mut self.transcript_mirror {
473 let _ = batcher.flush().await?;
474 }
475 self.transport.close().await?;
476 if let Some(materialized) = &self.materialized_resume {
477 materialized.cleanup().await;
478 }
479 Ok(())
480 }
481
482 pub async fn interrupt(&mut self) -> Result<()> {
483 self.require_connected()?;
484 send_control_request_with_callbacks(
485 self.transport.as_mut(),
486 serde_json::json!({"subtype": "interrupt"}),
487 &self.control_callbacks,
488 )
489 .await?;
490 Ok(())
491 }
492
493 pub async fn set_permission_mode(&mut self, mode: PermissionMode) -> Result<()> {
494 self.require_connected()?;
495 send_control_request_with_callbacks(
496 self.transport.as_mut(),
497 serde_json::json!({
498 "subtype": "set_permission_mode",
499 "mode": mode,
500 }),
501 &self.control_callbacks,
502 )
503 .await?;
504 Ok(())
505 }
506
507 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
508 self.require_connected()?;
509 let model = model.map(serde_json::Value::String);
510 send_control_request_with_callbacks(
511 self.transport.as_mut(),
512 serde_json::json!({
513 "subtype": "set_model",
514 "model": model.unwrap_or(serde_json::Value::Null),
515 }),
516 &self.control_callbacks,
517 )
518 .await?;
519 Ok(())
520 }
521
522 pub async fn rewind_files(&mut self, user_message_id: impl Into<String>) -> Result<()> {
523 self.require_connected()?;
524 send_control_request_with_callbacks(
525 self.transport.as_mut(),
526 serde_json::json!({
527 "subtype": "rewind_files",
528 "user_message_id": user_message_id.into(),
529 }),
530 &self.control_callbacks,
531 )
532 .await?;
533 Ok(())
534 }
535
536 pub async fn reconnect_mcp_server(&mut self, server_name: impl Into<String>) -> Result<()> {
537 self.require_connected()?;
538 send_control_request_with_callbacks(
539 self.transport.as_mut(),
540 serde_json::json!({
541 "subtype": "mcp_reconnect",
542 "serverName": server_name.into(),
543 }),
544 &self.control_callbacks,
545 )
546 .await?;
547 Ok(())
548 }
549
550 pub async fn toggle_mcp_server(
551 &mut self,
552 server_name: impl Into<String>,
553 enabled: bool,
554 ) -> Result<()> {
555 self.require_connected()?;
556 send_control_request_with_callbacks(
557 self.transport.as_mut(),
558 serde_json::json!({
559 "subtype": "mcp_toggle",
560 "serverName": server_name.into(),
561 "enabled": enabled,
562 }),
563 &self.control_callbacks,
564 )
565 .await?;
566 Ok(())
567 }
568
569 pub async fn stop_task(&mut self, task_id: impl Into<String>) -> Result<()> {
570 self.require_connected()?;
571 send_control_request_with_callbacks(
572 self.transport.as_mut(),
573 serde_json::json!({
574 "subtype": "stop_task",
575 "task_id": task_id.into(),
576 }),
577 &self.control_callbacks,
578 )
579 .await?;
580 Ok(())
581 }
582
583 pub async fn get_mcp_status(&mut self) -> Result<MCPStatusResponse> {
584 self.require_connected()?;
585 let response = send_control_request_with_callbacks(
586 self.transport.as_mut(),
587 serde_json::json!({"subtype": "mcp_status"}),
588 &self.control_callbacks,
589 )
590 .await?;
591 let value = serde_json::Value::Object(response);
592 Ok(serde_json::from_value(value)?)
593 }
594
595 pub async fn get_context_usage(&mut self) -> Result<ContextUsageResponse> {
596 self.require_connected()?;
597 let response = send_control_request_with_callbacks(
598 self.transport.as_mut(),
599 serde_json::json!({"subtype": "get_context_usage"}),
600 &self.control_callbacks,
601 )
602 .await?;
603 Ok(serde_json::from_value(serde_json::Value::Object(response))?)
604 }
605
606 pub fn get_server_info(&self) -> Option<&serde_json::Map<String, serde_json::Value>> {
607 self.initialization_result.as_ref()
608 }
609
610 fn build_user_payload(
611 &self,
612 content: &str,
613 session_id: Option<&str>,
614 ) -> Result<serde_json::Map<String, serde_json::Value>> {
615 let mut payload = serde_json::Map::new();
616 payload.insert(
617 "type".to_string(),
618 serde_json::Value::String("user".to_string()),
619 );
620 payload.insert(
621 "session_id".to_string(),
622 serde_json::Value::String(
623 session_id
624 .map(String::from)
625 .unwrap_or_else(|| self.session_id.clone()),
626 ),
627 );
628 let message = serde_json::json!({"role": "user", "content": content});
629 payload.insert("message".to_string(), message);
630 Ok(payload)
631 }
632}