use crate::error::Result;
use crate::message::Message;
use crate::types::ToolDefinition;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Identifier for a protocol request, stored as a plain string.
///
/// [`RequestId::new`] fills it with a freshly generated UUID string, and
/// [`RequestId::with_sequence`] appends a `.`-separated numeric suffix.
/// Equality and hashing are those of the full underlying string.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RequestId(String);
impl RequestId {
    /// Creates an identifier from a newly generated version-4 UUID, rendered as a string.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string())
    }

    /// Wraps an existing identifier string as-is; no validation or normalisation is applied.
    pub fn from_string(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    /// Returns an owned copy of the text before the first `.`.
    ///
    /// An identifier that contains no `.` is returned whole.
    pub fn base(&self) -> String {
        self.base_str().to_string()
    }

    /// Builds a new identifier by appending `.{sequence}` to the *full* text of this one.
    ///
    /// Because the suffix is added to the complete string, sequencing an id that
    /// already carries a suffix nests it (`"a.1"` becomes `"a.1.2"`).
    pub fn with_sequence(&self, sequence: usize) -> RequestId {
        RequestId(format!("{}.{}", self.0, sequence))
    }

    /// Borrows the complete identifier text, including any sequence suffix.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Reports whether `self` and `other` share the same base (text before the first `.`).
    pub fn matches_base(&self, other: &RequestId) -> bool {
        // Compare borrowed slices: no need to allocate two `String`s for a comparison.
        self.base_str() == other.base_str()
    }

    /// Borrowed form of [`RequestId::base`]: everything before the first `.`,
    /// or the whole identifier when there is none.
    fn base_str(&self) -> &str {
        match self.0.split_once('.') {
            Some((head, _)) => head,
            None => &self.0,
        }
    }
}
impl Default for RequestId {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for RequestId {
    /// Writes the full identifier text.
    ///
    /// Formatting is forwarded to the inner string's `Display` implementation so
    /// that width, alignment and precision flags (e.g. `{:>40}` or `{:.8}`) are
    /// honoured; `write!(f, "{}", …)` would silently discard them. Plain `{}`
    /// output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Payload carried by [`ProtocolMessage::Query`].
///
/// All fields are required when deserializing except `system_prompt`, which
/// is an `Option`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRequest {
/// Text of the query.
pub query: String,
/// Optional system prompt; `None` when no prompt is supplied.
pub system_prompt: Option<String>,
/// Model identifier, passed through as an opaque string.
pub model: String,
/// Token limit for the request.
// NOTE(review): presumably a cap on generated tokens — confirm against the consumer.
pub max_tokens: u32,
/// Tool definitions sent along with the query (may be empty).
pub tools: Vec<ToolDefinition>,
/// Messages sent along with the query (may be empty).
// NOTE(review): presumably prior conversation history — confirm with the caller.
pub messages: Vec<Message>,
}
/// Payload carried by [`ProtocolMessage::Response`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResponse {
/// The message returned for the query.
pub message: Message,
/// Completion flag for the response.
// NOTE(review): presumably `false` marks an intermediate/partial reply — confirm with the producer.
pub is_complete: bool,
}
/// Payload carried by [`ProtocolMessage::HookRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookRequest {
/// Name of the hook event as a free-form string (this type does not restrict the set of names).
pub event_type: String,
/// Arbitrary JSON data accompanying the event; its schema is not constrained here.
pub data: serde_json::Value,
}
/// Payload carried by [`ProtocolMessage::HookResponse`].
///
/// Only `continue` is mandatory on the wire. Every other field is an `Option`
/// that is left out of the serialized output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct HookResponse {
/// Continue/stop flag, serialized under the JSON key `continue`
/// (the trailing underscore only avoids the Rust keyword).
#[serde(rename = "continue")]
pub continue_: bool,
/// Optional replacement tool name and/or input; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub modified_inputs: Option<ModifiedInputs>,
/// Optional free-form JSON context; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<serde_json::Value>,
/// Optional permission decision as a free-form string; omitted when `None`.
// NOTE(review): the accepted values are not defined in this file — confirm with the consumer.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_decision: Option<String>,
/// Optional explanation accompanying `permission_decision`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub permission_decision_reason: Option<String>,
/// Optional additional free-form JSON context; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub additional_context: Option<serde_json::Value>,
/// Optional reason text for continuing; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub continue_reason: Option<String>,
/// Optional reason text for stopping; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<String>,
/// Optional system message text; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub system_message: Option<String>,
/// Optional general reason text; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
/// Optional output-suppression flag; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub suppress_output: Option<bool>,
}
impl HookResponse {
    /// Starts a response with `continue_` set to `true` and every optional field empty.
    pub fn continue_exec() -> Self {
        Self {
            continue_: true,
            modified_inputs: None,
            context: None,
            permission_decision: None,
            permission_decision_reason: None,
            additional_context: None,
            continue_reason: None,
            stop_reason: None,
            system_message: None,
            reason: None,
            suppress_output: None,
        }
    }

    /// Starts a response with `continue_` set to `false` and every optional field empty.
    pub fn stop() -> Self {
        // Same blank slate as `continue_exec`, differing only in the flag.
        Self {
            continue_: false,
            ..Self::continue_exec()
        }
    }

    /// Returns the response with `permission_decision` set to `decision`.
    pub fn with_permission_decision(self, decision: impl Into<String>) -> Self {
        Self {
            permission_decision: Some(decision.into()),
            ..self
        }
    }

    /// Returns the response with `permission_decision_reason` set to `reason`.
    pub fn with_permission_reason(self, reason: impl Into<String>) -> Self {
        Self {
            permission_decision_reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns the response with `additional_context` set to `context`.
    pub fn with_additional_context(self, context: serde_json::Value) -> Self {
        Self {
            additional_context: Some(context),
            ..self
        }
    }

    /// Returns the response with `continue_reason` set to `reason`.
    pub fn with_continue_reason(self, reason: impl Into<String>) -> Self {
        Self {
            continue_reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns the response with `stop_reason` set to `reason`.
    pub fn with_stop_reason(self, reason: impl Into<String>) -> Self {
        Self {
            stop_reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns the response with `system_message` set to `message`.
    pub fn with_system_message(self, message: impl Into<String>) -> Self {
        Self {
            system_message: Some(message.into()),
            ..self
        }
    }

    /// Returns the response with `reason` set to `reason`.
    pub fn with_reason(self, reason: impl Into<String>) -> Self {
        Self {
            reason: Some(reason.into()),
            ..self
        }
    }

    /// Returns the response with `suppress_output` set to `suppress`.
    pub fn with_suppress_output(self, suppress: bool) -> Self {
        Self {
            suppress_output: Some(suppress),
            ..self
        }
    }
}
/// Replacement tool name and/or input that a [`HookResponse`] may carry.
///
/// Both fields are optional. Unlike the `HookResponse` fields they have no
/// `skip_serializing_if`, so a `None` is still written out (as JSON `null`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModifiedInputs {
/// Replacement tool name, if any.
pub tool_name: Option<String>,
/// Replacement tool input as free-form JSON, if any.
pub input: Option<serde_json::Value>,
}
/// Payload carried by [`ProtocolMessage::PermissionCheck`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionCheckRequest {
/// Name of the tool the check is about.
pub tool: String,
/// Input for that tool, as free-form JSON.
pub input: serde_json::Value,
/// Suggestion text attached to the check.
// NOTE(review): looks like a human-readable prompt (the tests use "Use web_search? (yes/no)") — confirm.
pub suggestion: String,
}
/// Payload carried by [`ProtocolMessage::PermissionResponse`].
///
/// The optional fields have no `skip_serializing_if`, so `None` values are
/// still written out (as JSON `null`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PermissionResponse {
/// Whether the request is allowed.
pub allow: bool,
/// Optional replacement input as free-form JSON.
pub modified_input: Option<serde_json::Value>,
/// Optional explanation for the decision.
pub reason: Option<String>,
}
/// Control commands, using serde's adjacently tagged representation.
///
/// The variant name is written under the key `command` and, for variants that
/// carry a value, that value is written under the key `payload`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "command", content = "payload")]
pub enum ControlCommand {
/// Tagged `interrupt`; carries no payload.
#[serde(rename = "interrupt")]
Interrupt,
/// Tagged `set_model`; the payload is a model identifier string.
#[serde(rename = "set_model")]
SetModel(String),
/// Tagged `set_permission_mode`; the payload is a free-form mode string.
// NOTE(review): the accepted mode values are not defined in this file — confirm with the consumer.
#[serde(rename = "set_permission_mode")]
SetPermissionMode(String),
/// Tagged `get_state`; carries no payload.
#[serde(rename = "get_state")]
GetState,
}
/// Payload carried by [`ProtocolMessage::ControlRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ControlRequest {
/// The command to run. It is flattened, so the command's own `command` and
/// `payload` keys appear directly in this object with no extra nesting level.
#[serde(flatten)]
pub command: ControlCommand,
}
/// Payload carried by [`ProtocolMessage::ControlResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ControlResponse {
/// Whether the control command succeeded.
pub success: bool,
/// Optional message text accompanying the outcome.
pub message: Option<String>,
/// Optional free-form JSON data accompanying the outcome.
pub data: Option<serde_json::Value>,
}
/// Payload carried by [`ProtocolMessage::Error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtocolErrorMessage {
/// Error code as a free-form string (this type does not restrict the set of codes).
pub code: String,
/// Error message text.
pub message: String,
/// Optional free-form JSON with further details.
pub details: Option<serde_json::Value>,
}
/// Top-level protocol envelope, using serde's adjacently tagged representation.
///
/// The variant name is written under the key `type` and the variant's value
/// under the key `payload`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")]
pub enum ProtocolMessage {
/// Tagged `query`; carries a [`QueryRequest`].
#[serde(rename = "query")]
Query(QueryRequest),
/// Tagged `response`; carries a [`QueryResponse`].
#[serde(rename = "response")]
Response(QueryResponse),
/// Tagged `hook_request`; carries a [`HookRequest`].
#[serde(rename = "hook_request")]
HookRequest(HookRequest),
/// Tagged `hook_response`; carries a boxed [`HookResponse`].
#[serde(rename = "hook_response")]
HookResponse(Box<HookResponse>),
/// Tagged `permission_check`; carries a [`PermissionCheckRequest`].
#[serde(rename = "permission_check")]
PermissionCheck(PermissionCheckRequest),
/// Tagged `permission_response`; carries a [`PermissionResponse`].
#[serde(rename = "permission_response")]
PermissionResponse(PermissionResponse),
/// Tagged `control_request`; carries a [`ControlRequest`].
#[serde(rename = "control_request")]
ControlRequest(ControlRequest),
/// Tagged `control_response`; carries a [`ControlResponse`].
#[serde(rename = "control_response")]
ControlResponse(ControlResponse),
/// Tagged `error`; carries a [`ProtocolErrorMessage`].
#[serde(rename = "error")]
Error(ProtocolErrorMessage),
}
impl ProtocolMessage {
pub fn to_json(&self) -> Result<String> {
serde_json::to_string(self)
.map_err(|e| crate::error::ProtocolError::SerializationError(e.to_string()))
}
pub fn from_json(json: &str) -> Result<Self> {
serde_json::from_str(json)
.map_err(|e| crate::error::ProtocolError::SerializationError(e.to_string()))
}
pub fn request_id(&self) -> Option<RequestId> {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated id is non-empty and has no sequence suffix to strip.
    #[test]
    fn test_request_id_generation() {
        let fresh = RequestId::new();
        assert!(!fresh.as_str().is_empty());
        assert_eq!(fresh.base(), fresh.as_str());
    }

    /// Appending a sequence adds a `.N` suffix that `base` removes again.
    #[test]
    fn test_request_id_sequence() {
        let root = RequestId::from_string("550e8400");
        let sequenced = root.with_sequence(1);
        assert_eq!(sequenced.as_str(), "550e8400.1");
        assert_eq!(sequenced.base(), "550e8400");
    }

    /// Base matching is symmetric and ignores the sequence suffix.
    #[test]
    fn test_request_id_matches_base() {
        let plain = RequestId::from_string("550e8400");
        let suffixed = RequestId::from_string("550e8400.1");
        let unrelated = RequestId::from_string("other");
        assert!(plain.matches_base(&suffixed));
        assert!(suffixed.matches_base(&plain));
        assert!(!plain.matches_base(&unrelated));
    }

    /// `HookRequest` survives a JSON round trip.
    #[test]
    fn test_hook_request_serialization() {
        let original = HookRequest {
            event_type: "PreToolUse".to_string(),
            data: serde_json::json!({ "tool": "search" }),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: HookRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.event_type, "PreToolUse");
        assert_eq!(decoded.data["tool"], "search");
    }

    /// `PermissionCheckRequest` survives a JSON round trip.
    #[test]
    fn test_permission_check_serialization() {
        let original = PermissionCheckRequest {
            tool: "web_search".to_string(),
            input: serde_json::json!({ "query": "test" }),
            suggestion: "Use web_search? (yes/no)".to_string(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: PermissionCheckRequest = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.tool, "web_search");
    }

    /// `continue_` is written under the JSON key `continue` and read back from it.
    #[test]
    fn test_hook_response_serialization() {
        let boxed = Box::new(HookResponse {
            continue_: true,
            modified_inputs: None,
            context: None,
            permission_decision: None,
            permission_decision_reason: None,
            additional_context: None,
            continue_reason: None,
            stop_reason: None,
            system_message: None,
            reason: None,
            suppress_output: None,
        });
        let encoded = serde_json::to_string(&boxed).unwrap();
        assert!(encoded.contains(r#""continue":true"#));
        let decoded: HookResponse = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.continue_);
    }

    /// `PermissionResponse` survives a JSON round trip.
    #[test]
    fn test_permission_response_serialization() {
        let original = PermissionResponse {
            allow: true,
            modified_input: None,
            reason: Some("User approved".to_string()),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: PermissionResponse = serde_json::from_str(&encoded).unwrap();
        assert!(decoded.allow);
    }

    /// A `Query` envelope round-trips through `to_json` / `from_json`.
    #[test]
    fn test_protocol_message_query_roundtrip() {
        let request = QueryRequest {
            query: "What is the capital of France?".to_string(),
            system_prompt: None,
            model: "claude-3-5-sonnet-20241022".to_string(),
            max_tokens: 1024,
            tools: vec![],
            messages: vec![],
        };
        let envelope = ProtocolMessage::Query(request.clone());
        let encoded = envelope.to_json().unwrap();
        let roundtripped = ProtocolMessage::from_json(&encoded).unwrap();
        if let ProtocolMessage::Query(decoded) = roundtripped {
            assert_eq!(decoded.query, request.query);
            assert_eq!(decoded.model, request.model);
        } else {
            panic!("Expected Query message");
        }
    }

    /// A `HookRequest` envelope round-trips through `to_json` / `from_json`.
    #[test]
    fn test_protocol_message_hook_request_roundtrip() {
        let hook = HookRequest {
            event_type: "PreToolUse".to_string(),
            data: serde_json::json!({ "tool": "search", "step": 1 }),
        };
        let envelope = ProtocolMessage::HookRequest(hook.clone());
        let encoded = envelope.to_json().unwrap();
        let roundtripped = ProtocolMessage::from_json(&encoded).unwrap();
        if let ProtocolMessage::HookRequest(decoded) = roundtripped {
            assert_eq!(decoded.event_type, "PreToolUse");
        } else {
            panic!("Expected HookRequest message");
        }
    }

    /// The unit variant serializes with its renamed tag.
    #[test]
    fn test_control_command_interrupt() {
        let command = ControlCommand::Interrupt;
        let encoded = serde_json::to_string(&command).unwrap();
        assert!(encoded.contains("interrupt"));
    }

    /// A payload-carrying variant serializes both its tag and its value.
    #[test]
    fn test_control_command_set_model() {
        let command = ControlCommand::SetModel("claude-3-5-haiku-20241022".to_string());
        let encoded = serde_json::to_string(&command).unwrap();
        assert!(encoded.contains("set_model"));
        assert!(encoded.contains("claude-3-5-haiku-20241022"));
    }

    /// `ProtocolErrorMessage` survives a JSON round trip.
    #[test]
    fn test_protocol_error_message_serialization() {
        let original = ProtocolErrorMessage {
            code: "parse_error".to_string(),
            message: "Invalid JSON".to_string(),
            details: Some(serde_json::json!({ "line": 5 })),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ProtocolErrorMessage = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.code, "parse_error");
    }
}