use crate::{
errors::Result,
types::{ControlRequest, ControlResponse, Message},
};
use async_trait::async_trait;
use futures::stream::Stream;
use serde_json::Value as JsonValue;
use std::pin::Pin;
use tokio::sync::mpsc::Receiver;
pub mod subprocess;
pub mod mock;
#[cfg(feature = "websocket")]
pub mod websocket;
pub use subprocess::SubprocessTransport;
/// A message handed to a transport through [`Transport::send_message`].
///
/// The `Serialize` derive emits the fields in declaration order under the keys
/// `type`, `message`, `parent_tool_use_id` and `session_id`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct InputMessage {
/// Message kind, serialized under the key `type` (`type` is a Rust keyword,
/// hence the raw identifier plus the serde rename). The constructors in this
/// module always set it to `"user"`.
#[serde(rename = "type")]
pub r#type: String,
/// JSON payload of the message. The constructors in this module fill it with
/// an object holding `role` and `content`.
pub message: serde_json::Value,
/// Id of the tool use this message belongs to, if any. No
/// `skip_serializing_if` is configured, so `None` is serialized as `null`.
pub parent_tool_use_id: Option<String>,
/// Identifier of the session the message is sent in.
pub session_id: String,
}
impl InputMessage {
    /// Creates a plain user message.
    ///
    /// The payload is `{"role": "user", "content": <content>}`; the message
    /// carries no parent tool-use id.
    pub fn user(content: String, session_id: String) -> Self {
        let message = serde_json::json!({
            "role": "user",
            "content": content
        });

        Self {
            r#type: String::from("user"),
            message,
            parent_tool_use_id: None,
            session_id,
        }
    }

    /// Creates a user message that reports the outcome of a tool invocation.
    ///
    /// The payload's `content` is a one-element array holding a `tool_result`
    /// entry with `tool_use_id`, `content` and `is_error`. The same
    /// `tool_use_id` is also stored in `parent_tool_use_id`.
    pub fn tool_result(
        tool_use_id: String,
        content: String,
        session_id: String,
        is_error: bool,
    ) -> Self {
        // `json!` only borrows the expressions it is given, so `tool_use_id`
        // is still available to be moved into the struct afterwards.
        let result_entry = serde_json::json!({
            "type": "tool_result",
            "tool_use_id": tool_use_id,
            "content": content,
            "is_error": is_error
        });
        let message = serde_json::json!({
            "role": "user",
            "content": [result_entry]
        });

        Self {
            r#type: String::from("user"),
            message,
            parent_tool_use_id: Some(tool_use_id),
            session_id,
        }
    }
}
/// Asynchronous transport interface.
///
/// Implementors must be `Send + Sync`. The async methods are provided through
/// the `async_trait` macro. What each operation does on the wire is up to the
/// implementor; the notes below describe only what the signatures and the
/// default bodies establish.
#[async_trait]
pub trait Transport: Send + Sync {
/// Returns `self` as `&mut dyn Any`, which lets a caller attempt a downcast
/// to the concrete transport type.
fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
/// Establishes the connection. Failures are reported through the `Result`.
async fn connect(&mut self) -> Result<()>;
/// Sends one [`InputMessage`] over the transport.
async fn send_message(&mut self, message: InputMessage) -> Result<()>;
/// Returns a boxed, `Send + 'static` stream whose items are `Result<Message>`.
///
/// This method is synchronous; only polling the returned stream is async.
// NOTE(review): whether calling this more than once is supported depends on
// the implementor — confirm before relying on it.
fn receive_messages(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + 'static>>;
/// Sends a typed [`ControlRequest`].
async fn send_control_request(&mut self, request: ControlRequest) -> Result<()>;
/// Receives a [`ControlResponse`], if there is one.
// NOTE(review): `Ok(None)` presumably means "no response available"; the
// exact meaning is implementor-defined — confirm.
async fn receive_control_response(&mut self) -> Result<Option<ControlResponse>>;
/// Sends an SDK control request given as raw JSON.
async fn send_sdk_control_request(&mut self, request: JsonValue) -> Result<()>;
/// Sends an SDK control response given as raw JSON.
async fn send_sdk_control_response(&mut self, response: JsonValue) -> Result<()>;
/// Hands out the receiving end of a channel of raw JSON SDK control values.
///
/// The default implementation returns `None`; implementors that have such a
/// channel override this.
fn take_sdk_control_receiver(&mut self) -> Option<Receiver<JsonValue>> {
None
}
/// Reports whether the transport currently considers itself connected.
// NOTE(review): the reason for this `allow` is not documented; presumably
// the method is not called in every build configuration — confirm.
#[allow(dead_code)]
fn is_connected(&self) -> bool;
/// Closes the connection. Failures are reported through the `Result`.
async fn disconnect(&mut self) -> Result<()>;
/// Signals that no further input will be sent.
///
/// The default implementation does nothing and returns `Ok(())`.
async fn end_input(&mut self) -> Result<()> {
Ok(())
}
}
/// Connection lifecycle states for a transport.
///
/// A small `Copy` enum that can be compared with `==`.
// NOTE(review): nothing in the visible part of this file reads or writes this
// enum; the variant descriptions below are inferred from the variant names —
// confirm against the concrete transport implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportState {
/// No connection is established.
Disconnected,
/// A connection attempt is in progress.
Connecting,
/// The connection is established.
Connected,
/// The connection is being closed.
Disconnecting,
/// The transport has failed.
// NOTE(review): the `allow` suggests this variant may never be constructed
// in some builds — confirm.
#[allow(dead_code)]
Error,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `InputMessage::user` yields a `"user"` message without a parent tool-use
    /// id, and its serialized form contains the type tag and the text content.
    #[test]
    fn test_input_message_user() {
        let input = InputMessage::user(String::from("Hello"), String::from("session-123"));

        assert_eq!(input.r#type, "user");
        assert_eq!(input.session_id, "session-123");
        assert_eq!(input.parent_tool_use_id, None);

        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains(r#""type":"user""#));
        assert!(serialized.contains(r#""content":"Hello""#));
    }

    /// `InputMessage::tool_result` records the tool-use id as the parent id and
    /// serializes both the id and the error flag into the payload.
    #[test]
    fn test_input_message_tool_result() {
        let tool_use_id = String::from("tool-123");
        let content = String::from("Result");
        let session_id = String::from("session-456");
        let input = InputMessage::tool_result(tool_use_id, content, session_id, false);

        assert_eq!(input.r#type, "user");
        assert_eq!(input.parent_tool_use_id.as_deref(), Some("tool-123"));

        let serialized = serde_json::to_string(&input).unwrap();
        assert!(serialized.contains(r#""tool_use_id":"tool-123""#));
        assert!(serialized.contains(r#""is_error":false"#));
    }
}