1use crate::{
7 errors::Result,
8 types::{ControlRequest, ControlResponse, Message},
9};
10use async_trait::async_trait;
11use futures::stream::Stream;
12use serde_json::Value as JsonValue;
13use std::pin::Pin;
14use tokio::sync::mpsc::Receiver;
15
16pub mod subprocess;
17pub mod mock;
18
19pub use subprocess::SubprocessTransport;
20
21#[derive(Debug, Clone, serde::Serialize)]
23pub struct InputMessage {
24 #[serde(rename = "type")]
26 pub r#type: String,
27 pub message: serde_json::Value,
29 pub parent_tool_use_id: Option<String>,
31 pub session_id: String,
33}
34
35impl InputMessage {
36 pub fn user(content: String, session_id: String) -> Self {
38 Self {
39 r#type: "user".to_string(),
40 message: serde_json::json!({
41 "role": "user",
42 "content": content
43 }),
44 parent_tool_use_id: None,
45 session_id,
46 }
47 }
48
49 pub fn tool_result(
51 tool_use_id: String,
52 content: String,
53 session_id: String,
54 is_error: bool,
55 ) -> Self {
56 Self {
57 r#type: "user".to_string(),
58 message: serde_json::json!({
59 "role": "user",
60 "content": [{
61 "type": "tool_result",
62 "tool_use_id": tool_use_id,
63 "content": content,
64 "is_error": is_error
65 }]
66 }),
67 parent_tool_use_id: Some(tool_use_id),
68 session_id,
69 }
70 }
71}
72
#[async_trait]
/// Bidirectional channel used to send user input and control traffic and to
/// receive messages and control responses.
///
/// Implementors must be `Send + Sync`. The `async fn`s are desugared by
/// `#[async_trait]`. NOTE(review): concrete implementations presumably live
/// in the `subprocess` and `mock` submodules; confirm.
pub trait Transport: Send + Sync {
    /// Returns `self` as `&mut dyn Any`, so code holding a `dyn Transport`
    /// can downcast to the concrete type (e.g. with `downcast_mut`).
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Establishes the connection.
    ///
    /// NOTE(review): presumably required before sending or receiving;
    /// confirm in implementations.
    async fn connect(&mut self) -> Result<()>;

    /// Sends one [`InputMessage`].
    async fn send_message(&mut self, message: InputMessage) -> Result<()>;

    /// Returns a stream of incoming messages, each wrapped in `Result`.
    ///
    /// The stream is `Send + 'static`, so it does not borrow the transport
    /// once this call returns.
    fn receive_messages(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + 'static>>;

    /// Sends a typed control request.
    async fn send_control_request(&mut self, request: ControlRequest) -> Result<()>;

    /// Receives a control response, if one is available.
    ///
    /// NOTE(review): what `Ok(None)` means (nothing pending vs. channel
    /// closed) is implementation-defined; confirm.
    async fn receive_control_response(&mut self) -> Result<Option<ControlResponse>>;

    /// Sends an SDK control request given as raw JSON.
    async fn send_sdk_control_request(&mut self, request: JsonValue) -> Result<()>;

    /// Sends an SDK control response given as raw JSON.
    async fn send_sdk_control_response(&mut self, response: JsonValue) -> Result<()>;

    /// Hands over a receiver of raw-JSON SDK control messages.
    ///
    /// The default implementation returns `None`. NOTE(review): the `take_`
    /// prefix suggests implementations yield the receiver at most once;
    /// confirm.
    fn take_sdk_control_receiver(&mut self) -> Option<Receiver<JsonValue>> {
        None
    }

    /// Reports whether the transport considers itself connected.
    // `dead_code` is allowed presumably because nothing in the crate calls
    // this method yet. Verify before removing the attribute.
    #[allow(dead_code)]
    fn is_connected(&self) -> bool;

    /// Closes the connection.
    async fn disconnect(&mut self) -> Result<()>;

    /// Marks the end of outbound input.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    /// NOTE(review): the exact effect (e.g. closing an input pipe) is
    /// inferred from the name and is up to each implementation; confirm.
    async fn end_input(&mut self) -> Result<()> {
        Ok(())
    }
}
119
/// Coarse lifecycle state of a transport connection.
///
/// Variant order is significant for `as` casts and must stay as declared.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportState {
    /// Not connected.
    Disconnected,
    /// Connecting.
    Connecting,
    /// Connected.
    Connected,
    /// Disconnecting.
    Disconnecting,
    /// Error state.
    // `dead_code` is allowed because this variant may not be constructed
    // anywhere in the crate. NOTE(review): confirm whether any transport
    // ever sets it.
    #[allow(dead_code)]
    Error,
}
135
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Checks a plain user turn field by field and against its full
    /// serialized JSON shape.
    #[test]
    fn test_input_message_user() {
        let msg = InputMessage::user("Hello".to_string(), "session-123".to_string());
        assert_eq!(msg.r#type, "user");
        assert_eq!(msg.session_id, "session-123");
        assert!(msg.parent_tool_use_id.is_none());

        // Structural comparison instead of substring matching: this pins the
        // nesting of `message`, the `type` rename and the explicit `null`
        // emitted for a missing parent id.
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "user",
                "message": { "role": "user", "content": "Hello" },
                "parent_tool_use_id": null,
                "session_id": "session-123"
            })
        );
    }

    /// Checks a successful tool result against its full serialized shape.
    #[test]
    fn test_input_message_tool_result() {
        let msg = InputMessage::tool_result(
            "tool-123".to_string(),
            "Result".to_string(),
            "session-456".to_string(),
            false,
        );
        assert_eq!(msg.r#type, "user");
        assert_eq!(msg.parent_tool_use_id, Some("tool-123".to_string()));
        assert_eq!(msg.session_id, "session-456");

        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "user",
                "message": {
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": "tool-123",
                        "content": "Result",
                        "is_error": false
                    }]
                },
                "parent_tool_use_id": "tool-123",
                "session_id": "session-456"
            })
        );
    }

    /// The `is_error` flag must be forwarded verbatim when it is `true`.
    #[test]
    fn test_input_message_tool_result_error_flag() {
        let msg = InputMessage::tool_result(
            "tool-9".to_string(),
            "boom".to_string(),
            "session-9".to_string(),
            true,
        );

        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value["message"]["content"][0]["is_error"], json!(true));
        assert_eq!(value["message"]["content"][0]["content"], json!("boom"));
        assert_eq!(value["parent_tool_use_id"], json!("tool-9"));
    }
}