lsp_proxy/
hooks.rs

1use async_trait::async_trait;
2use serde_json::Value;
3use std::fmt::Display;
4
/// Which side of the proxy a generated message should be sent to.
///
/// Field-less, so it is `Copy` and comparable: callers can route on it
/// (`if dir == Direction::ToServer`) or use it as a map key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Send towards the client peer.
    ToClient,
    /// Send towards the server peer.
    ToServer,
}
10
11#[derive(Debug, Clone)]
12pub struct Request {
13    pub id: i64,
14    pub method: String,
15    pub params: Option<Value>,
16}
17
18#[derive(Debug, Clone)]
19pub struct Response {
20    pub id: i64,
21    pub result: Option<Value>,
22    pub error: Option<Value>,
23}
24
25#[derive(Debug, Clone)]
26pub struct Notification {
27    pub method: String,
28    pub params: Option<Value>,
29}
30
31#[derive(Debug, Clone)]
32pub enum Message {
33    Request(Request),
34    Response(Response),
35    Notification(Notification),
36}
37
38impl Message {
39    pub fn from_value(value: Value) -> Result<Self, String> {
40        let obj = value.as_object().ok_or("Message must be an object")?;
41
42        let id = obj.get("id").and_then(|id| id.as_i64());
43        let method = obj.get("method").and_then(|m| m.as_str()).map(String::from);
44        let params = obj.get("params").cloned();
45        let result = obj.get("result").cloned();
46        let error = obj.get("error").cloned();
47
48        match (id, method, result.is_some() || error.is_some()) {
49            (Some(id), Some(method), false) => Ok(Message::Request(Request { id, method, params })),
50            (Some(id), None, true) => Ok(Message::Response(Response { id, result, error })),
51            (None, Some(method), false) => {
52                Ok(Message::Notification(Notification { method, params }))
53            }
54            _ => Err("Invalid message format".to_string()),
55        }
56    }
57
58    pub fn to_value(&self) -> Value {
59        match self {
60            Message::Request(Request { id, method, params }) => {
61                let mut obj = serde_json::json!({
62                    "jsonrpc": "2.0",
63                    "id": id,
64                    "method": method,
65                });
66                if let Some(params) = params {
67                    obj["params"] = params.clone();
68                }
69                obj
70            }
71            Message::Response(Response { id, result, error }) => {
72                let mut obj = serde_json::json!({
73                    "jsonrpc": "2.0",
74                    "id": id,
75                });
76                if let Some(result) = result {
77                    obj["result"] = result.clone();
78                }
79                if let Some(error) = error {
80                    obj["error"] = error.clone();
81                }
82                obj
83            }
84            Message::Notification(Notification { method, params }) => {
85                let mut obj = serde_json::json!({
86                    "jsonrpc": "2.0",
87                    "method": method,
88                });
89                if let Some(params) = params {
90                    obj["params"] = params.clone();
91                }
92                obj
93            }
94        }
95    }
96
97    pub fn get_method(&self) -> Option<&str> {
98        match self {
99            Message::Request(Request { method, .. }) => Some(method),
100            Message::Response(Response { .. }) => None,
101            Message::Notification(Notification { method, .. }) => Some(method),
102        }
103    }
104
105    pub fn get_id(&self) -> Option<&i64> {
106        match self {
107            Message::Request(Request { id, .. }) => Some(id),
108            Message::Response(Response { id, .. }) => Some(id),
109            Message::Notification(Notification { .. }) => None,
110        }
111    }
112
113    pub fn notification(method: &str, params: Option<Value>) -> Self {
114        Message::Notification(Notification {
115            method: method.to_owned(),
116            params,
117        })
118    }
119}
120
/// Error a hook reports when it cannot handle a message.
#[derive(Debug)]
pub enum HookError {
    /// Processing failed; the payload is a human-readable reason.
    ProcessingFailed(String),
}

impl Display for HookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: the pattern is irrefutable, and adding a variant later
        // turns this into a compile error that forces the message to be extended.
        let HookError::ProcessingFailed(reason) = self;
        write!(f, "Hook processing failed: {}", reason)
    }
}

impl std::error::Error for HookError {}
135
136#[derive(Debug)]
137pub struct HookOutput {
138    pub message: Message,
139    pub generated_messages: Vec<(Direction, Message)>,
140}
141
142impl HookOutput {
143    pub fn new(message: Message) -> Self {
144        Self {
145            message,
146            generated_messages: Vec::new(),
147        }
148    }
149
150    pub fn with_message(mut self, direction: Direction, message: Message) -> Self {
151        self.generated_messages.push((direction, message));
152        self
153    }
154
155    pub fn with_messages(mut self, messages: Vec<(Direction, Message)>) -> Self {
156        self.generated_messages.extend(messages);
157        self
158    }
159}
160
161pub type HookResult = Result<HookOutput, HookError>;
162
163#[async_trait]
164pub trait Hook: Send + Sync {
165    async fn on_request(&self, request: Request) -> HookResult {
166        Ok(HookOutput::new(Message::Request(request)))
167    }
168
169    async fn on_response(&self, response: Response) -> HookResult {
170        Ok(HookOutput::new(Message::Response(response)))
171    }
172
173    async fn on_notification(&self, notification: Notification) -> HookResult {
174        Ok(HookOutput::new(Message::Notification(notification)))
175    }
176}