1use async_trait::async_trait;
2use serde_json::Value;
3use std::fmt::Display;
4
/// Direction tag attached to a message.
///
/// Within this file it is only used as the first element of the
/// `(Direction, Message)` pairs stored in `HookOutput::generated_messages`.
/// NOTE(review): presumably it selects which peer a generated message is sent
/// to; confirm against the code that consumes `HookOutput`.
///
/// The enum carries no data, so it is `Copy` and comparable by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Message is addressed to the client side.
    ToClient,
    /// Message is addressed to the server side.
    ToServer,
}
10
11#[derive(Debug, Clone)]
12pub struct Request {
13 pub id: i64,
14 pub method: String,
15 pub params: Option<Value>,
16}
17
18#[derive(Debug, Clone)]
19pub struct Response {
20 pub id: i64,
21 pub result: Option<Value>,
22 pub error: Option<Value>,
23}
24
25#[derive(Debug, Clone)]
26pub struct Notification {
27 pub method: String,
28 pub params: Option<Value>,
29}
30
31#[derive(Debug, Clone)]
32pub enum Message {
33 Request(Request),
34 Response(Response),
35 Notification(Notification),
36}
37
38impl Message {
39 pub fn from_value(value: Value) -> Result<Self, String> {
40 let obj = value.as_object().ok_or("Message must be an object")?;
41
42 let id = obj.get("id").and_then(|id| id.as_i64());
43 let method = obj.get("method").and_then(|m| m.as_str()).map(String::from);
44 let params = obj.get("params").cloned();
45 let result = obj.get("result").cloned();
46 let error = obj.get("error").cloned();
47
48 match (id, method, result.is_some() || error.is_some()) {
49 (Some(id), Some(method), false) => Ok(Message::Request(Request { id, method, params })),
50 (Some(id), None, true) => Ok(Message::Response(Response { id, result, error })),
51 (None, Some(method), false) => {
52 Ok(Message::Notification(Notification { method, params }))
53 }
54 _ => Err("Invalid message format".to_string()),
55 }
56 }
57
58 pub fn to_value(&self) -> Value {
59 match self {
60 Message::Request(Request { id, method, params }) => {
61 let mut obj = serde_json::json!({
62 "jsonrpc": "2.0",
63 "id": id,
64 "method": method,
65 });
66 if let Some(params) = params {
67 obj["params"] = params.clone();
68 }
69 obj
70 }
71 Message::Response(Response { id, result, error }) => {
72 let mut obj = serde_json::json!({
73 "jsonrpc": "2.0",
74 "id": id,
75 });
76 if let Some(result) = result {
77 obj["result"] = result.clone();
78 }
79 if let Some(error) = error {
80 obj["error"] = error.clone();
81 }
82 obj
83 }
84 Message::Notification(Notification { method, params }) => {
85 let mut obj = serde_json::json!({
86 "jsonrpc": "2.0",
87 "method": method,
88 });
89 if let Some(params) = params {
90 obj["params"] = params.clone();
91 }
92 obj
93 }
94 }
95 }
96
97 pub fn get_method(&self) -> Option<&str> {
98 match self {
99 Message::Request(Request { method, .. }) => Some(method),
100 Message::Response(Response { .. }) => None,
101 Message::Notification(Notification { method, .. }) => Some(method),
102 }
103 }
104
105 pub fn get_id(&self) -> Option<&i64> {
106 match self {
107 Message::Request(Request { id, .. }) => Some(id),
108 Message::Response(Response { id, .. }) => Some(id),
109 Message::Notification(Notification { .. }) => None,
110 }
111 }
112
113 pub fn notification(method: &str, params: Option<Value>) -> Self {
114 Message::Notification(Notification {
115 method: method.to_owned(),
116 params,
117 })
118 }
119}
120
/// Error value returned by a hook callback.
#[derive(Debug)]
pub enum HookError {
    /// The hook could not process the message; the payload is a
    /// human-readable reason that is included in the `Display` output.
    ProcessingFailed(String),
}

impl Display for HookError {
    /// Renders the error as `Hook processing failed: <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The enum has a single variant, so this binding is irrefutable; it
        // stops compiling (on purpose) as soon as a second variant is added.
        let HookError::ProcessingFailed(reason) = self;
        write!(f, "Hook processing failed: {}", reason)
    }
}

impl std::error::Error for HookError {}
135
136#[derive(Debug)]
137pub struct HookOutput {
138 pub message: Message,
139 pub generated_messages: Vec<(Direction, Message)>,
140}
141
142impl HookOutput {
143 pub fn new(message: Message) -> Self {
144 Self {
145 message,
146 generated_messages: Vec::new(),
147 }
148 }
149
150 pub fn with_message(mut self, direction: Direction, message: Message) -> Self {
151 self.generated_messages.push((direction, message));
152 self
153 }
154
155 pub fn with_messages(mut self, messages: Vec<(Direction, Message)>) -> Self {
156 self.generated_messages.extend(messages);
157 self
158 }
159}
160
161pub type HookResult = Result<HookOutput, HookError>;
162
163#[async_trait]
164pub trait Hook: Send + Sync {
165 async fn on_request(&self, request: Request) -> HookResult {
166 Ok(HookOutput::new(Message::Request(request)))
167 }
168
169 async fn on_response(&self, response: Response) -> HookResult {
170 Ok(HookOutput::new(Message::Response(response)))
171 }
172
173 async fn on_notification(&self, notification: Notification) -> HookResult {
174 Ok(HookOutput::new(Message::Notification(notification)))
175 }
176}