1use async_trait::async_trait;
2use serde::Serialize;
3use serde_json::Value;
4use std::fmt::Display;
5
/// A JSON-RPC style message, classified by which fields are present
/// (see [`Message::from_value`] for the exact classification rules).
///
/// NOTE: the derived `Serialize` impl carries no `#[serde(...)]` attributes, so serde
/// uses its default externally-tagged enum representation (e.g. `{"Request": {...}}`).
/// That is *not* the JSON-RPC wire shape — use [`Message::to_value`] for that.
/// Because of the derive, reordering variants or fields can change serialized output.
#[derive(Debug, Clone, Serialize)]
pub enum Message {
    /// A message with both an `id` and a string `method`, and no `result`/`error`.
    Request {
        id: Value,
        method: String,
        // `None` when the incoming object had no `params` key.
        params: Option<Value>,
    },
    /// A message with an `id`, no string `method`, and at least one of `result`/`error`.
    Response {
        id: Value,
        result: Option<Value>,
        error: Option<Value>,
    },
    /// A message with a string `method` but no `id`, and no `result`/`error`.
    Notification {
        method: String,
        params: Option<Value>,
    },
}
23
24impl Message {
25 pub fn from_value(value: Value) -> Result<Self, String> {
26 let obj = value.as_object().ok_or("Message must be an object")?;
27
28 let id = obj.get("id").cloned();
29 let method = obj.get("method").and_then(|m| m.as_str()).map(String::from);
30 let params = obj.get("params").cloned();
31 let result = obj.get("result").cloned();
32 let error = obj.get("error").cloned();
33
34 match (id, method, result.is_some() || error.is_some()) {
35 (Some(id), Some(method), false) => Ok(Message::Request { id, method, params }),
36 (Some(id), None, true) => Ok(Message::Response { id, result, error }),
37 (None, Some(method), false) => Ok(Message::Notification { method, params }),
38 _ => Err("Invalid message format".to_string()),
39 }
40 }
41
42 pub fn to_value(&self) -> Value {
43 match self {
44 Message::Request { id, method, params } => {
45 let mut obj = serde_json::json!({
46 "jsonrpc": "2.0",
47 "id": id,
48 "method": method,
49 });
50 if let Some(params) = params {
51 obj["params"] = params.clone();
52 }
53 obj
54 }
55 Message::Response { id, result, error } => {
56 let mut obj = serde_json::json!({
57 "jsonrpc": "2.0",
58 "id": id,
59 });
60 if let Some(result) = result {
61 obj["result"] = result.clone();
62 }
63 if let Some(error) = error {
64 obj["error"] = error.clone();
65 }
66 obj
67 }
68 Message::Notification { method, params } => {
69 let mut obj = serde_json::json!({
70 "jsonrpc": "2.0",
71 "method": method,
72 });
73 if let Some(params) = params {
74 obj["params"] = params.clone();
75 }
76 obj
77 }
78 }
79 }
80
81 pub fn get_method(&self) -> Option<&str> {
82 match self {
83 Message::Request { method, .. } => Some(method),
84 Message::Response { .. } => None,
85 Message::Notification { method, .. } => Some(method),
86 }
87 }
88
89 pub fn get_id(&self) -> Option<&Value> {
90 match self {
91 Message::Request { id, .. } => Some(id),
92 Message::Response { id, .. } => Some(id),
93 Message::Notification { .. } => None,
94 }
95 }
96
97 pub fn notification(method: &str, params: Option<Value>) -> Self {
98 Message::Notification {
99 method: method.to_owned(),
100 params,
101 }
102 }
103}
104
/// Error returned by a hook callback.
///
/// Derives `Clone`/`PartialEq`/`Eq` in addition to `Debug` so callers can
/// compare and duplicate errors (e.g. in tests or when fanning out results).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookError {
    /// Processing failed; the payload is a description appended to the
    /// `Display` output.
    ProcessingFailed(String),
}

impl Display for HookError {
    /// Formats as `Hook processing failed: <description>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HookError::ProcessingFailed(msg) => write!(f, "Hook processing failed: {}", msg),
        }
    }
}

// No wrapped cause is stored, so the default `source()` (returning `None`) applies.
impl std::error::Error for HookError {}
119
120#[derive(Debug)]
121pub struct HookOutput {
122 pub message: Message,
123 pub notifications: Vec<Message>,
124}
125
126impl HookOutput {
127 pub fn new(message: Message) -> Self {
128 Self {
129 message,
130 notifications: Vec::new(),
131 }
132 }
133
134 pub fn with_notification(mut self, notification: Message) -> Self {
135 self.notifications.push(notification);
136 self
137 }
138
139 pub fn with_notifications(mut self, notifications: Vec<Message>) -> Self {
140 self.notifications.extend(notifications);
141 self
142 }
143
144 pub fn add_notification(&mut self, notification: Message) {
145 self.notifications.push(notification);
146 }
147}
148
/// Return type of every `Hook` callback: either the hook's [`HookOutput`] or a [`HookError`].
pub type HookResult = Result<HookOutput, HookError>;
150
151#[async_trait]
152pub trait Hook: Send + Sync {
153 async fn on_request(&self, message: Message) -> HookResult {
154 Ok(HookOutput::new(message))
155 }
156
157 async fn on_response(&self, message: Message) -> HookResult {
158 Ok(HookOutput::new(message))
159 }
160}