1use crate::{
7 errors::Result,
8 types::{ControlRequest, ControlResponse, Message},
9};
10use async_trait::async_trait;
11use futures::stream::Stream;
12use serde_json::Value as JsonValue;
13use std::pin::Pin;
14use tokio::sync::mpsc::Receiver;
15
16pub mod subprocess;
17pub mod mock;
18#[cfg(feature = "websocket")]
19pub mod websocket;
20
21pub use subprocess::SubprocessTransport;
22
23#[derive(Debug, Clone, serde::Serialize)]
25pub struct InputMessage {
26 #[serde(rename = "type")]
28 pub r#type: String,
29 pub message: serde_json::Value,
31 pub parent_tool_use_id: Option<String>,
33 pub session_id: String,
35}
36
37impl InputMessage {
38 pub fn user(content: String, session_id: String) -> Self {
40 Self {
41 r#type: "user".to_string(),
42 message: serde_json::json!({
43 "role": "user",
44 "content": content
45 }),
46 parent_tool_use_id: None,
47 session_id,
48 }
49 }
50
51 pub fn tool_result(
53 tool_use_id: String,
54 content: String,
55 session_id: String,
56 is_error: bool,
57 ) -> Self {
58 Self {
59 r#type: "user".to_string(),
60 message: serde_json::json!({
61 "role": "user",
62 "content": [{
63 "type": "tool_result",
64 "tool_use_id": tool_use_id,
65 "content": content,
66 "is_error": is_error
67 }]
68 }),
69 parent_tool_use_id: Some(tool_use_id),
70 session_id,
71 }
72 }
73}
74
75#[async_trait]
77pub trait Transport: Send + Sync {
78 fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
80
81 async fn connect(&mut self) -> Result<()>;
83
84 async fn send_message(&mut self, message: InputMessage) -> Result<()>;
86
87 fn receive_messages(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + 'static>>;
89
90 async fn send_control_request(&mut self, request: ControlRequest) -> Result<()>;
92
93 async fn receive_control_response(&mut self) -> Result<Option<ControlResponse>>;
95
96 async fn send_sdk_control_request(&mut self, request: JsonValue) -> Result<()>;
98
99 async fn send_sdk_control_response(&mut self, response: JsonValue) -> Result<()>;
101
102 fn take_sdk_control_receiver(&mut self) -> Option<Receiver<JsonValue>> {
106 None
107 }
108
109 #[allow(dead_code)]
111 fn is_connected(&self) -> bool;
112
113 async fn disconnect(&mut self) -> Result<()>;
115
116 async fn end_input(&mut self) -> Result<()> {
118 Ok(())
119 }
120}
121
/// Connection state of a transport.
///
/// Variant order is kept stable so any discriminant-based use is unaffected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportState {
    /// Not connected.
    Disconnected,
    /// Connection in progress.
    Connecting,
    /// Connected.
    Connected,
    /// Disconnection in progress.
    Disconnecting,
    /// Failed state.
    // NOTE(review): the allowance suggests this variant is presumably not
    // constructed yet — confirm and remove the attribute once it is used.
    #[allow(dead_code)]
    Error,
}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// `InputMessage::user` wraps plain text in a `user`-role payload with no
    /// parent tool use.
    #[test]
    fn test_input_message_user() {
        let msg = InputMessage::user("Hello".to_string(), "session-123".to_string());
        assert_eq!(msg.r#type, "user");
        assert_eq!(msg.session_id, "session-123");
        assert!(msg.parent_tool_use_id.is_none());

        // Compare structurally instead of by substring: this is independent of
        // formatting/key order and pins every field at the correct nesting level.
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "user",
                "message": { "role": "user", "content": "Hello" },
                "parent_tool_use_id": null,
                "session_id": "session-123"
            })
        );
    }

    /// `InputMessage::tool_result` emits a single `tool_result` block and
    /// records the tool id as `parent_tool_use_id`.
    #[test]
    fn test_input_message_tool_result() {
        let msg = InputMessage::tool_result(
            "tool-123".to_string(),
            "Result".to_string(),
            "session-456".to_string(),
            false,
        );
        assert_eq!(msg.r#type, "user");
        assert_eq!(msg.parent_tool_use_id, Some("tool-123".to_string()));
        assert_eq!(msg.session_id, "session-456");

        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "user",
                "message": {
                    "role": "user",
                    "content": [{
                        "type": "tool_result",
                        "tool_use_id": "tool-123",
                        "content": "Result",
                        "is_error": false
                    }]
                },
                "parent_tool_use_id": "tool-123",
                "session_id": "session-456"
            })
        );
    }

    /// The `is_error` flag is forwarded verbatim when it is `true`.
    #[test]
    fn test_input_message_tool_result_error_flag() {
        let msg = InputMessage::tool_result(
            "tool-9".to_string(),
            "boom".to_string(),
            "session-1".to_string(),
            true,
        );
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value["message"]["content"][0]["is_error"], true);
        assert_eq!(value["message"]["content"][0]["content"], "boom");
    }
}