lsp_proxy/
proxy.rs

1use crate::hooks::{Direction, Hook, HookError, Message};
2use crate::transport::{read_message, write_message};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::join;
7use tokio::sync::Mutex;
8use tokio::sync::mpsc::{self, Sender};
9
10pub struct Proxy {
11    hooks: HashMap<String, Arc<dyn Hook>>,
12    pending_requests: Mutex<HashMap<i64, String>>,
13}
14
15impl Proxy {
16    fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
17        Self {
18            hooks,
19            pending_requests: Mutex::new(HashMap::new()),
20        }
21    }
22
23    async fn process_client_message(
24        &self,
25        message: Message,
26    ) -> Result<ProcessedMessage, HookError> {
27        match message {
28            Message::Request(request) => match self.hooks.get(&request.method) {
29                Some(hook) => {
30                    {
31                        let mut pending = self.pending_requests.lock().await;
32                        pending.insert(request.id, request.method.clone());
33                    }
34                    let output = hook.on_request(request).await?;
35                    Ok(ProcessedMessage::WithMessages {
36                        message: output.message,
37                        generated_messages: output.generated_messages,
38                    })
39                }
40                None => Ok(ProcessedMessage::Forward(Message::Request(request))),
41            },
42            Message::Notification(notification) => match self.hooks.get(&notification.method) {
43                Some(hook) => {
44                    let output = hook.on_notification(notification).await?;
45                    Ok(ProcessedMessage::Forward(output.message))
46                }
47                None => Ok(ProcessedMessage::Forward(Message::Notification(
48                    notification,
49                ))),
50            },
51            Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
52        }
53    }
54
55    async fn process_server_message(
56        &self,
57        message: Message,
58    ) -> Result<ProcessedMessage, HookError> {
59        match message {
60            Message::Response(response) => {
61                let method = {
62                    let mut pending = self.pending_requests.lock().await;
63                    pending.remove(&response.id)
64                };
65
66                if let Some(method) = method
67                    && let Some(hook) = self.hooks.get(&method)
68                {
69                    let output = hook.on_response(response).await?;
70                    return Ok(ProcessedMessage::WithMessages {
71                        message: output.message,
72                        generated_messages: output.generated_messages,
73                    });
74                }
75
76                Ok(ProcessedMessage::Forward(Message::Response(response)))
77            }
78            Message::Notification(notification) => match self.hooks.get(&notification.method) {
79                Some(hook) => {
80                    let output = hook.on_notification(notification).await?;
81                    Ok(ProcessedMessage::Forward(output.message))
82                }
83                None => Ok(ProcessedMessage::Forward(Message::Notification(
84                    notification,
85                ))),
86            },
87            Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
88        }
89    }
90
91    pub async fn forward<SR, SW, CR, CW>(
92        self,
93        server_reader: SR,
94        server_writer: SW,
95        client_reader: CR,
96        client_writer: CW,
97    ) -> std::io::Result<()>
98    where
99        SR: AsyncReadExt + Unpin + Send + 'static,
100        SW: AsyncWriteExt + Unpin + Send + 'static,
101        CR: AsyncReadExt + Unpin + Send + 'static,
102        CW: AsyncWriteExt + Unpin + Send + 'static,
103    {
104        let proxy = Arc::new(self);
105        let server_to_client_proxy = proxy.clone();
106
107        let (client_sender, mut client_receiver) = mpsc::channel::<Message>(10);
108        let (server_sender, mut server_receiver) = mpsc::channel::<Message>(10);
109
110        let server_message_sender = server_sender.clone();
111        let client_message_sender = client_sender.clone();
112        let server_to_client = tokio::spawn(async move {
113            forward_to_client(
114                &server_to_client_proxy,
115                server_reader,
116                server_message_sender,
117                client_message_sender,
118            )
119            .await
120        });
121
122        let client_to_server = tokio::spawn(async move {
123            forward_to_server(&proxy, client_reader, server_sender, client_sender).await
124        });
125
126        let write_to_server = tokio::spawn(async move {
127            let mut writer = server_writer;
128            while let Some(msg) = server_receiver.recv().await {
129                if write_message(&mut writer, &msg.to_value()).await.is_err() {
130                    break;
131                }
132            }
133            Ok::<(), std::io::Error>(())
134        });
135
136        let write_to_client = tokio::spawn(async move {
137            let mut writer = client_writer;
138            while let Some(msg) = client_receiver.recv().await {
139                if write_message(&mut writer, &msg.to_value()).await.is_err() {
140                    break;
141                }
142            }
143            Ok::<(), std::io::Error>(())
144        });
145
146        _ = join!(client_to_server, server_to_client);
147
148        drop(write_to_server);
149        drop(write_to_client);
150        Ok(())
151    }
152}
153
154impl Default for Proxy {
155    fn default() -> Self {
156        Self::new(HashMap::new())
157    }
158}
159
/// Outcome of running one incoming message through the proxy's hooks.
#[derive(Debug)]
pub enum ProcessedMessage {
    /// A single message to forward to the opposite peer.
    Forward(Message),
    /// A message to forward to the opposite peer, plus extra messages produced by a
    /// hook.
    WithMessages {
        /// The (possibly hook-modified) message to forward.
        message: Message,
        /// Hook-generated messages, each tagged with the [`Direction`] it is sent in.
        generated_messages: Vec<(Direction, Message)>,
    },
}
168
169impl ProcessedMessage {
170    pub fn get_message(&self) -> &Message {
171        match self {
172            ProcessedMessage::Forward(msg) => msg,
173            ProcessedMessage::WithMessages { message, .. } => message,
174        }
175    }
176
177    pub fn get_generated_messages(&self) -> &[(Direction, Message)] {
178        match self {
179            ProcessedMessage::Forward(_) => &[],
180            ProcessedMessage::WithMessages {
181                generated_messages: messages,
182                ..
183            } => messages,
184        }
185    }
186
187    pub fn into_parts(self) -> (Message, Vec<(Direction, Message)>) {
188        match self {
189            ProcessedMessage::Forward(msg) => (msg, Vec::new()),
190            ProcessedMessage::WithMessages {
191                message,
192                generated_messages: messages,
193            } => (message, messages),
194        }
195    }
196}
197
198async fn forward_to_server<R>(
199    proxy: &Proxy,
200    mut client_reader: R,
201    server_message_sender: Sender<Message>,
202    client_message_sender: Sender<Message>,
203) -> std::io::Result<()>
204where
205    R: AsyncReadExt + Unpin,
206{
207    loop {
208        let message = match read_message(&mut client_reader).await {
209            Ok(msg) => Message::from_value(msg),
210            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
211                break;
212            }
213            Err(e) => return Err(e),
214        };
215
216        if let Ok(message) = message {
217            match proxy.process_client_message(message).await {
218                Ok(processed) => {
219                    let (main_message, generated_messages) = processed.into_parts();
220
221                    if server_message_sender.send(main_message).await.is_err() {
222                        return Err(std::io::Error::new(
223                            std::io::ErrorKind::BrokenPipe,
224                            "Message channel closed",
225                        ));
226                    }
227
228                    for (direction, message) in generated_messages {
229                        let result = match direction {
230                            Direction::ToClient => client_message_sender.send(message),
231                            Direction::ToServer => server_message_sender.send(message),
232                        };
233
234                        if result.await.is_err() {
235                            return Err(std::io::Error::new(
236                                std::io::ErrorKind::BrokenPipe,
237                                "Notification channel closed",
238                            ));
239                        }
240                    }
241                }
242                Err(e) => {
243                    eprintln!("Error processing message: {}", e);
244                }
245            }
246        }
247    }
248
249    Ok(())
250}
251
252async fn forward_to_client<R>(
253    proxy: &Proxy,
254    mut server_reader: R,
255    server_message_sender: Sender<Message>,
256    client_message_sender: Sender<Message>,
257) -> std::io::Result<()>
258where
259    R: AsyncReadExt + Unpin,
260{
261    loop {
262        let message = match read_message(&mut server_reader).await {
263            Ok(msg) => Message::from_value(msg),
264            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
265                break;
266            }
267            Err(e) => return Err(e),
268        };
269
270        if let Ok(message) = message {
271            match proxy.process_server_message(message).await {
272                Ok(processed) => {
273                    let (main_message, generated_messages) = processed.into_parts();
274
275                    if client_message_sender.send(main_message).await.is_err() {
276                        return Err(std::io::Error::new(
277                            std::io::ErrorKind::BrokenPipe,
278                            "Message channel closed",
279                        ));
280                    }
281
282                    for (direction, message) in generated_messages {
283                        let result = match direction {
284                            Direction::ToClient => client_message_sender.send(message),
285                            Direction::ToServer => server_message_sender.send(message),
286                        };
287
288                        if result.await.is_err() {
289                            return Err(std::io::Error::new(
290                                std::io::ErrorKind::BrokenPipe,
291                                "Notification channel closed",
292                            ));
293                        }
294                    }
295                }
296                Err(e) => {
297                    eprintln!("Error processing message: {}", e);
298                }
299            }
300        }
301    }
302
303    Ok(())
304}
305
306pub struct ProxyBuilder {
307    hooks: HashMap<String, Arc<dyn Hook>>,
308}
309
310impl ProxyBuilder {
311    pub fn new() -> Self {
312        Self {
313            hooks: HashMap::new(),
314        }
315    }
316
317    pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
318        self.hooks.insert(method.to_owned(), hook);
319        self
320    }
321
322    pub fn build(self) -> Proxy {
323        Proxy::new(self.hooks)
324    }
325}
326
327impl Default for ProxyBuilder {
328    fn default() -> Self {
329        Self::new()
330    }
331}