1use crate::error::Result;
7use crate::message::Message;
8use crate::types::ToolDefinition;
9use serde::{Deserialize, Serialize};
10use uuid::Uuid;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub struct RequestId(String);
17
18impl RequestId {
19 pub fn new() -> Self {
21 Self(Uuid::new_v4().to_string())
22 }
23
24 pub fn from_string(id: impl Into<String>) -> Self {
26 Self(id.into())
27 }
28
29 pub fn base(&self) -> String {
31 self.0.split('.').next().unwrap_or(&self.0).to_string()
32 }
33
34 pub fn with_sequence(&self, sequence: usize) -> RequestId {
36 RequestId(format!("{}.{}", self.0, sequence))
37 }
38
39 pub fn as_str(&self) -> &str {
41 &self.0
42 }
43
44 pub fn matches_base(&self, other: &RequestId) -> bool {
46 self.base() == other.base()
47 }
48}
49
50impl Default for RequestId {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl std::fmt::Display for RequestId {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct QueryRequest {
67 pub query: String,
69
70 pub system_prompt: Option<String>,
72
73 pub model: String,
75
76 pub max_tokens: u32,
78
79 pub tools: Vec<ToolDefinition>,
81
82 pub messages: Vec<Message>,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct QueryResponse {
91 pub message: Message,
93
94 pub is_complete: bool,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct HookRequest {
103 pub event_type: String,
105
106 pub data: serde_json::Value,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
114pub struct HookResponse {
115 #[serde(rename = "continue")]
117 pub continue_: bool,
118
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub modified_inputs: Option<ModifiedInputs>,
122
123 #[serde(skip_serializing_if = "Option::is_none")]
125 pub context: Option<serde_json::Value>,
126
127 #[serde(skip_serializing_if = "Option::is_none")]
131 pub permission_decision: Option<String>,
132
133 #[serde(skip_serializing_if = "Option::is_none")]
135 pub permission_decision_reason: Option<String>,
136
137 #[serde(skip_serializing_if = "Option::is_none")]
142 pub additional_context: Option<serde_json::Value>,
143
144 #[serde(skip_serializing_if = "Option::is_none")]
146 pub continue_reason: Option<String>,
147
148 #[serde(skip_serializing_if = "Option::is_none")]
150 pub stop_reason: Option<String>,
151
152 #[serde(skip_serializing_if = "Option::is_none")]
154 pub system_message: Option<String>,
155
156 #[serde(skip_serializing_if = "Option::is_none")]
158 pub reason: Option<String>,
159
160 #[serde(skip_serializing_if = "Option::is_none")]
162 pub suppress_output: Option<bool>,
163}
164
165impl HookResponse {
166 pub fn continue_exec() -> Self {
168 Self {
169 continue_: true,
170 modified_inputs: None,
171 context: None,
172 permission_decision: None,
173 permission_decision_reason: None,
174 additional_context: None,
175 continue_reason: None,
176 stop_reason: None,
177 system_message: None,
178 reason: None,
179 suppress_output: None,
180 }
181 }
182
183 pub fn stop() -> Self {
185 Self {
186 continue_: false,
187 modified_inputs: None,
188 context: None,
189 permission_decision: None,
190 permission_decision_reason: None,
191 additional_context: None,
192 continue_reason: None,
193 stop_reason: None,
194 system_message: None,
195 reason: None,
196 suppress_output: None,
197 }
198 }
199
200 pub fn with_permission_decision(mut self, decision: impl Into<String>) -> Self {
202 self.permission_decision = Some(decision.into());
203 self
204 }
205
206 pub fn with_permission_reason(mut self, reason: impl Into<String>) -> Self {
208 self.permission_decision_reason = Some(reason.into());
209 self
210 }
211
212 pub fn with_additional_context(mut self, context: serde_json::Value) -> Self {
214 self.additional_context = Some(context);
215 self
216 }
217
218 pub fn with_continue_reason(mut self, reason: impl Into<String>) -> Self {
220 self.continue_reason = Some(reason.into());
221 self
222 }
223
224 pub fn with_stop_reason(mut self, reason: impl Into<String>) -> Self {
226 self.stop_reason = Some(reason.into());
227 self
228 }
229
230 pub fn with_system_message(mut self, message: impl Into<String>) -> Self {
232 self.system_message = Some(message.into());
233 self
234 }
235
236 pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
238 self.reason = Some(reason.into());
239 self
240 }
241
242 pub fn with_suppress_output(mut self, suppress: bool) -> Self {
244 self.suppress_output = Some(suppress);
245 self
246 }
247}
248
249#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
251pub struct ModifiedInputs {
252 pub tool_name: Option<String>,
254
255 pub input: Option<serde_json::Value>,
257}
258
259#[derive(Debug, Clone, Serialize, Deserialize)]
263pub struct PermissionCheckRequest {
264 pub tool: String,
266
267 pub input: serde_json::Value,
269
270 pub suggestion: String,
272}
273
274#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
278pub struct PermissionResponse {
279 pub allow: bool,
281
282 pub modified_input: Option<serde_json::Value>,
284
285 pub reason: Option<String>,
287}
288
289#[derive(Debug, Clone, Serialize, Deserialize)]
293#[serde(tag = "command", content = "payload")]
294pub enum ControlCommand {
295 #[serde(rename = "interrupt")]
297 Interrupt,
298
299 #[serde(rename = "set_model")]
301 SetModel(String),
302
303 #[serde(rename = "set_permission_mode")]
305 SetPermissionMode(String),
306
307 #[serde(rename = "get_state")]
309 GetState,
310}
311
312#[derive(Debug, Clone, Serialize, Deserialize)]
314pub struct ControlRequest {
315 #[serde(flatten)]
317 pub command: ControlCommand,
318}
319
320#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct ControlResponse {
325 pub success: bool,
327
328 pub message: Option<String>,
330
331 pub data: Option<serde_json::Value>,
333}
334
335#[derive(Debug, Clone, Serialize, Deserialize)]
339pub struct ProtocolErrorMessage {
340 pub code: String,
342
343 pub message: String,
345
346 pub details: Option<serde_json::Value>,
348}
349
350#[derive(Debug, Clone, Serialize, Deserialize)]
354#[serde(tag = "type", content = "payload")]
355pub enum ProtocolMessage {
356 #[serde(rename = "query")]
358 Query(QueryRequest),
359
360 #[serde(rename = "response")]
362 Response(QueryResponse),
363
364 #[serde(rename = "hook_request")]
366 HookRequest(HookRequest),
367
368 #[serde(rename = "hook_response")]
370 HookResponse(Box<HookResponse>),
371
372 #[serde(rename = "permission_check")]
374 PermissionCheck(PermissionCheckRequest),
375
376 #[serde(rename = "permission_response")]
378 PermissionResponse(PermissionResponse),
379
380 #[serde(rename = "control_request")]
382 ControlRequest(ControlRequest),
383
384 #[serde(rename = "control_response")]
386 ControlResponse(ControlResponse),
387
388 #[serde(rename = "error")]
390 Error(ProtocolErrorMessage),
391}
392
393impl ProtocolMessage {
394 pub fn to_json(&self) -> Result<String> {
396 serde_json::to_string(self)
397 .map_err(|e| crate::error::ProtocolError::SerializationError(e.to_string()))
398 }
399
400 pub fn from_json(json: &str) -> Result<Self> {
402 serde_json::from_str(json)
403 .map_err(|e| crate::error::ProtocolError::SerializationError(e.to_string()))
404 }
405
406 pub fn request_id(&self) -> Option<RequestId> {
408 None
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_request_id_generation() {
        let id = RequestId::new();
        assert!(!id.as_str().is_empty());
        assert_eq!(id.base(), id.as_str());
        // A fresh ID is a plain UUID string with no sequence suffix.
        assert!(uuid::Uuid::parse_str(id.as_str()).is_ok());
    }

    #[test]
    fn test_request_id_sequence() {
        let id = RequestId::from_string("550e8400");
        let seq = id.with_sequence(1);
        assert_eq!(seq.as_str(), "550e8400.1");
        assert_eq!(seq.base(), "550e8400");
    }

    #[test]
    fn test_request_id_base_edge_cases() {
        // No separator: the whole ID is the base.
        assert_eq!(RequestId::from_string("abc").base(), "abc");

        // Sequences nest on the full ID, but the base is always before the first '.'.
        let nested = RequestId::from_string("abc").with_sequence(1).with_sequence(2);
        assert_eq!(nested.as_str(), "abc.1.2");
        assert_eq!(nested.base(), "abc");

        // Empty IDs and a leading separator both yield an empty base.
        assert_eq!(RequestId::from_string("").base(), "");
        assert_eq!(RequestId::from_string(".5").base(), "");
        assert!(RequestId::from_string("").matches_base(&RequestId::from_string(".5")));
    }

    #[test]
    fn test_request_id_matches_base() {
        let id1 = RequestId::from_string("550e8400");
        let id2 = RequestId::from_string("550e8400.1");
        let id3 = RequestId::from_string("other");

        assert!(id1.matches_base(&id2));
        assert!(id2.matches_base(&id1));
        assert!(!id1.matches_base(&id3));
    }

    #[test]
    fn test_request_id_display_and_serde_are_bare_strings() {
        let id = RequestId::from_string("550e8400.1");
        assert_eq!(id.to_string(), "550e8400.1");

        let json = serde_json::to_string(&id).unwrap();
        assert_eq!(json, r#""550e8400.1""#);

        let back: RequestId = serde_json::from_str(&json).unwrap();
        assert_eq!(back, id);
    }

    #[test]
    fn test_hook_request_serialization() {
        let hook = HookRequest {
            event_type: "PreToolUse".to_string(),
            data: serde_json::json!({ "tool": "search" }),
        };

        let json = serde_json::to_string(&hook).unwrap();
        let deserialized: HookRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.event_type, "PreToolUse");
        assert_eq!(deserialized.data["tool"], "search");
    }

    #[test]
    fn test_permission_check_serialization() {
        let check = PermissionCheckRequest {
            tool: "web_search".to_string(),
            input: serde_json::json!({ "query": "test" }),
            suggestion: "Use web_search? (yes/no)".to_string(),
        };

        let json = serde_json::to_string(&check).unwrap();
        let deserialized: PermissionCheckRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.tool, "web_search");
        assert_eq!(deserialized.input["query"], "test");
        assert_eq!(deserialized.suggestion, "Use web_search? (yes/no)");
    }

    #[test]
    fn test_hook_response_serialization() {
        let response = HookResponse::continue_exec();

        let json = serde_json::to_string(&response).unwrap();
        // Unset optional fields are skipped entirely; only "continue" remains.
        assert_eq!(json, r#"{"continue":true}"#);

        let deserialized: HookResponse = serde_json::from_str(&json).unwrap();
        assert!(deserialized.continue_);
        assert_eq!(deserialized, response);
    }

    #[test]
    fn test_hook_response_builders_serialize_only_set_fields() {
        let response = HookResponse::stop()
            .with_stop_reason("blocked")
            .with_suppress_output(true);

        let json = serde_json::to_string(&response).unwrap();
        assert_eq!(
            json,
            r#"{"continue":false,"stop_reason":"blocked","suppress_output":true}"#
        );

        let deserialized: HookResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, response);
    }

    #[test]
    fn test_permission_response_serialization() {
        let response = PermissionResponse {
            allow: true,
            modified_input: None,
            reason: Some("User approved".to_string()),
        };

        let json = serde_json::to_string(&response).unwrap();
        let deserialized: PermissionResponse = serde_json::from_str(&json).unwrap();

        assert!(deserialized.allow);
        assert_eq!(deserialized, response);
    }

    #[test]
    fn test_protocol_message_query_roundtrip() {
        let request = QueryRequest {
            query: "What is the capital of France?".to_string(),
            system_prompt: None,
            model: "claude-3-5-sonnet-20241022".to_string(),
            max_tokens: 1024,
            tools: vec![],
            messages: vec![],
        };

        let msg = ProtocolMessage::Query(request.clone());
        let json = msg.to_json().unwrap();
        assert!(json.starts_with(r#"{"type":"query","payload":"#));

        let deserialized = ProtocolMessage::from_json(&json).unwrap();

        match deserialized {
            ProtocolMessage::Query(q) => {
                assert_eq!(q.query, request.query);
                assert_eq!(q.model, request.model);
                assert_eq!(q.max_tokens, request.max_tokens);
                assert!(q.system_prompt.is_none());
            }
            other => panic!("Expected Query message, got {:?}", other),
        }
    }

    #[test]
    fn test_protocol_message_hook_request_roundtrip() {
        let hook = HookRequest {
            event_type: "PreToolUse".to_string(),
            data: serde_json::json!({ "tool": "search", "step": 1 }),
        };

        let msg = ProtocolMessage::HookRequest(hook.clone());
        let json = msg.to_json().unwrap();
        assert!(json.starts_with(r#"{"type":"hook_request","payload":"#));

        let deserialized = ProtocolMessage::from_json(&json).unwrap();

        match deserialized {
            ProtocolMessage::HookRequest(h) => {
                assert_eq!(h.event_type, "PreToolUse");
                assert_eq!(h.data["step"], 1);
            }
            other => panic!("Expected HookRequest message, got {:?}", other),
        }
    }

    #[test]
    fn test_protocol_message_hook_response_envelope() {
        let msg = ProtocolMessage::HookResponse(Box::new(HookResponse::continue_exec()));
        let json = msg.to_json().unwrap();
        // The Box is transparent on the wire.
        assert_eq!(json, r#"{"type":"hook_response","payload":{"continue":true}}"#);

        match ProtocolMessage::from_json(&json).unwrap() {
            ProtocolMessage::HookResponse(h) => assert_eq!(*h, HookResponse::continue_exec()),
            other => panic!("Expected HookResponse message, got {:?}", other),
        }
    }

    #[test]
    fn test_control_command_interrupt() {
        let json = serde_json::to_string(&ControlCommand::Interrupt).unwrap();
        // Unit variants carry only the tag, no "payload" key.
        assert_eq!(json, r#"{"command":"interrupt"}"#);
    }

    #[test]
    fn test_control_command_set_model() {
        let cmd = ControlCommand::SetModel("claude-3-5-haiku-20241022".to_string());
        let json = serde_json::to_string(&cmd).unwrap();
        assert_eq!(
            json,
            r#"{"command":"set_model","payload":"claude-3-5-haiku-20241022"}"#
        );
    }

    #[test]
    fn test_control_command_unit_variant_wire_names() {
        let json = serde_json::to_string(&ControlCommand::GetState).unwrap();
        assert_eq!(json, r#"{"command":"get_state"}"#);
    }

    #[test]
    fn test_protocol_message_control_request_roundtrip() {
        let msg = ProtocolMessage::ControlRequest(ControlRequest {
            command: ControlCommand::SetPermissionMode("plan".to_string()),
        });

        let json = msg.to_json().unwrap();
        // The flattened command appears directly inside the envelope payload.
        assert_eq!(
            json,
            r#"{"type":"control_request","payload":{"command":"set_permission_mode","payload":"plan"}}"#
        );

        match ProtocolMessage::from_json(&json).unwrap() {
            ProtocolMessage::ControlRequest(ControlRequest {
                command: ControlCommand::SetPermissionMode(mode),
            }) => assert_eq!(mode, "plan"),
            other => panic!("Expected ControlRequest message, got {:?}", other),
        }
    }

    #[test]
    fn test_protocol_error_message_serialization() {
        let error = ProtocolErrorMessage {
            code: "parse_error".to_string(),
            message: "Invalid JSON".to_string(),
            details: Some(serde_json::json!({ "line": 5 })),
        };

        let json = serde_json::to_string(&error).unwrap();
        let deserialized: ProtocolErrorMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.code, "parse_error");
        assert_eq!(deserialized.details, Some(serde_json::json!({ "line": 5 })));
    }

    #[test]
    fn test_protocol_message_error_envelope() {
        let msg = ProtocolMessage::Error(ProtocolErrorMessage {
            code: "x".to_string(),
            message: "y".to_string(),
            details: None,
        });

        let json = msg.to_json().unwrap();
        assert_eq!(
            json,
            r#"{"type":"error","payload":{"code":"x","message":"y","details":null}}"#
        );
        assert!(msg.request_id().is_none());
    }
}
580}