lsp_proxy/
proxy.rs

1use crate::hooks::{Direction, Hook, HookError, Message};
2use crate::transport::{read_message, write_message};
3use dashmap::DashMap;
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::join;
7use tokio::sync::mpsc::{self, Sender};
8
9pub struct Proxy {
10    hooks: DashMap<String, Arc<dyn Hook>>,
11    pending_requests: DashMap<i64, String>,
12}
13
14impl Proxy {
15    fn new(hooks: DashMap<String, Arc<dyn Hook>>) -> Self {
16        Self {
17            hooks,
18            pending_requests: DashMap::new(),
19        }
20    }
21
22    async fn process_client_message(
23        &self,
24        message: Message,
25    ) -> Result<ProcessedMessage, HookError> {
26        match message {
27            Message::Request(request) => match self.hooks.get(&request.method) {
28                Some(hook) => {
29                    self.pending_requests
30                        .insert(request.id, request.method.clone());
31                    let output = hook.on_request(request).await?;
32                    Ok(ProcessedMessage::WithMessages {
33                        message: output.message,
34                        generated_messages: output.generated_messages,
35                    })
36                }
37                None => Ok(ProcessedMessage::Forward(Message::Request(request))),
38            },
39            Message::Notification(notification) => match self.hooks.get(&notification.method) {
40                Some(hook) => {
41                    let output = hook.on_notification(notification).await?;
42                    Ok(ProcessedMessage::Forward(output.message))
43                }
44                None => Ok(ProcessedMessage::Forward(Message::Notification(
45                    notification,
46                ))),
47            },
48            Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
49        }
50    }
51
52    async fn process_server_message(
53        &self,
54        message: Message,
55    ) -> Result<ProcessedMessage, HookError> {
56        match message {
57            Message::Response(response) => {
58                let method = self
59                    .pending_requests
60                    .remove(&response.id)
61                    .map(|(_, method)| method);
62
63                if let Some(method) = method
64                    && let Some(hook) = self.hooks.get(&method)
65                {
66                    let output = hook.on_response(response).await?;
67                    return Ok(ProcessedMessage::WithMessages {
68                        message: output.message,
69                        generated_messages: output.generated_messages,
70                    });
71                }
72
73                Ok(ProcessedMessage::Forward(Message::Response(response)))
74            }
75            Message::Notification(notification) => match self.hooks.get(&notification.method) {
76                Some(hook) => {
77                    let output = hook.on_notification(notification).await?;
78                    Ok(ProcessedMessage::Forward(output.message))
79                }
80                None => Ok(ProcessedMessage::Forward(Message::Notification(
81                    notification,
82                ))),
83            },
84            Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
85        }
86    }
87
88    pub async fn forward<SR, SW, CR, CW>(
89        self,
90        server_reader: SR,
91        server_writer: SW,
92        client_reader: CR,
93        client_writer: CW,
94    ) -> std::io::Result<()>
95    where
96        SR: AsyncReadExt + Unpin + Send + 'static,
97        SW: AsyncWriteExt + Unpin + Send + 'static,
98        CR: AsyncReadExt + Unpin + Send + 'static,
99        CW: AsyncWriteExt + Unpin + Send + 'static,
100    {
101        let proxy = Arc::new(self);
102        let server_to_client_proxy = proxy.clone();
103
104        let (client_sender, mut client_receiver) = mpsc::channel::<Message>(100);
105        let (server_sender, mut server_receiver) = mpsc::channel::<Message>(100);
106
107        let server_message_sender = server_sender.clone();
108        let client_message_sender = client_sender.clone();
109        let server_to_client = tokio::spawn(async move {
110            forward_to_client(
111                &server_to_client_proxy,
112                server_reader,
113                server_message_sender,
114                client_message_sender,
115            )
116            .await
117        });
118
119        let client_to_server = tokio::spawn(async move {
120            forward_to_server(&proxy, client_reader, server_sender, client_sender).await
121        });
122
123        let write_to_server = tokio::spawn(async move {
124            let mut writer = server_writer;
125            while let Some(msg) = server_receiver.recv().await {
126                if write_message(&mut writer, &msg.to_value()).await.is_err() {
127                    break;
128                }
129            }
130            Ok::<(), std::io::Error>(())
131        });
132
133        let write_to_client = tokio::spawn(async move {
134            let mut writer = client_writer;
135            while let Some(msg) = client_receiver.recv().await {
136                if write_message(&mut writer, &msg.to_value()).await.is_err() {
137                    break;
138                }
139            }
140            Ok::<(), std::io::Error>(())
141        });
142
143        _ = join!(client_to_server, server_to_client);
144
145        drop(write_to_server);
146        drop(write_to_client);
147        Ok(())
148    }
149}
150
151impl Default for Proxy {
152    fn default() -> Self {
153        Self::new(DashMap::new())
154    }
155}
156
157#[derive(Debug)]
158pub enum ProcessedMessage {
159    Forward(Message),
160    WithMessages {
161        message: Message,
162        generated_messages: Vec<(Direction, Message)>,
163    },
164}
165
166impl ProcessedMessage {
167    pub fn get_message(&self) -> &Message {
168        match self {
169            ProcessedMessage::Forward(msg) => msg,
170            ProcessedMessage::WithMessages { message, .. } => message,
171        }
172    }
173
174    pub fn get_generated_messages(&self) -> &[(Direction, Message)] {
175        match self {
176            ProcessedMessage::Forward(_) => &[],
177            ProcessedMessage::WithMessages {
178                generated_messages: messages,
179                ..
180            } => messages,
181        }
182    }
183
184    pub fn into_parts(self) -> (Message, Vec<(Direction, Message)>) {
185        match self {
186            ProcessedMessage::Forward(msg) => (msg, Vec::new()),
187            ProcessedMessage::WithMessages {
188                message,
189                generated_messages: messages,
190            } => (message, messages),
191        }
192    }
193}
194
195async fn forward_to_server<R>(
196    proxy: &Proxy,
197    mut client_reader: R,
198    server_message_sender: Sender<Message>,
199    client_message_sender: Sender<Message>,
200) -> std::io::Result<()>
201where
202    R: AsyncReadExt + Unpin,
203{
204    loop {
205        let message = match read_message(&mut client_reader).await {
206            Ok(msg) => Message::from_value(msg),
207            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
208                break;
209            }
210            Err(e) => return Err(e),
211        };
212
213        if let Ok(message) = message {
214            match proxy.process_client_message(message).await {
215                Ok(processed) => {
216                    let (main_message, generated_messages) = processed.into_parts();
217
218                    if server_message_sender.send(main_message).await.is_err() {
219                        return Err(std::io::Error::new(
220                            std::io::ErrorKind::BrokenPipe,
221                            "Message channel closed",
222                        ));
223                    }
224
225                    for (direction, message) in generated_messages {
226                        let result = match direction {
227                            Direction::ToClient => client_message_sender.send(message),
228                            Direction::ToServer => server_message_sender.send(message),
229                        };
230
231                        if result.await.is_err() {
232                            return Err(std::io::Error::new(
233                                std::io::ErrorKind::BrokenPipe,
234                                "Notification channel closed",
235                            ));
236                        }
237                    }
238                }
239                Err(e) => {
240                    eprintln!("Error processing message: {}", e);
241                }
242            }
243        }
244    }
245
246    Ok(())
247}
248
249async fn forward_to_client<R>(
250    proxy: &Proxy,
251    mut server_reader: R,
252    server_message_sender: Sender<Message>,
253    client_message_sender: Sender<Message>,
254) -> std::io::Result<()>
255where
256    R: AsyncReadExt + Unpin,
257{
258    loop {
259        let message = match read_message(&mut server_reader).await {
260            Ok(msg) => Message::from_value(msg),
261            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
262                break;
263            }
264            Err(e) => return Err(e),
265        };
266
267        if let Ok(message) = message {
268            match proxy.process_server_message(message).await {
269                Ok(processed) => {
270                    let (main_message, generated_messages) = processed.into_parts();
271
272                    if client_message_sender.send(main_message).await.is_err() {
273                        return Err(std::io::Error::new(
274                            std::io::ErrorKind::BrokenPipe,
275                            "Message channel closed",
276                        ));
277                    }
278
279                    for (direction, message) in generated_messages {
280                        let result = match direction {
281                            Direction::ToClient => client_message_sender.send(message),
282                            Direction::ToServer => server_message_sender.send(message),
283                        };
284
285                        if result.await.is_err() {
286                            return Err(std::io::Error::new(
287                                std::io::ErrorKind::BrokenPipe,
288                                "Notification channel closed",
289                            ));
290                        }
291                    }
292                }
293                Err(e) => {
294                    eprintln!("Error processing message: {}", e);
295                }
296            }
297        }
298    }
299
300    Ok(())
301}
302
303pub struct ProxyBuilder {
304    hooks: DashMap<String, Arc<dyn Hook>>,
305}
306
307impl ProxyBuilder {
308    pub fn new() -> Self {
309        Self {
310            hooks: DashMap::new(),
311        }
312    }
313
314    pub fn with_hook(self, method: &str, hook: Arc<dyn Hook>) -> Self {
315        self.hooks.insert(method.to_owned(), hook);
316        self
317    }
318
319    pub fn build(self) -> Proxy {
320        Proxy::new(self.hooks)
321    }
322}
323
324impl Default for ProxyBuilder {
325    fn default() -> Self {
326        Self::new()
327    }
328}