lsp_proxy/
proxy.rs

1use crate::Message;
2use crate::hooks::{Hook, HookError};
3use crate::message::Direction;
4use crate::processed_message::ProcessedMessage;
5use crate::transport::{read_message, write_message};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::select;
10use tokio::sync::Mutex;
11use tokio::sync::mpsc::{self, UnboundedSender};
12
13pub struct Proxy {
14    hooks: HashMap<String, Arc<dyn Hook>>,
15    pending_requests: HashMap<i64, String>,
16}
17
18impl Proxy {
19    fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
20        Self {
21            hooks,
22            pending_requests: HashMap::new(),
23        }
24    }
25
26    pub async fn forward<SR, SW, CR, CW>(
27        self,
28        server_reader: SR,
29        server_writer: SW,
30        client_reader: CR,
31        client_writer: CW,
32    ) -> std::io::Result<()>
33    where
34        SR: AsyncReadExt + Unpin + Send + 'static,
35        SW: AsyncWriteExt + Unpin + Send + 'static,
36        CR: AsyncReadExt + Unpin + Send + 'static,
37        CW: AsyncWriteExt + Unpin + Send + 'static,
38    {
39        let hooks = Arc::new(self.hooks);
40        let pending_requests = Arc::new(Mutex::new(self.pending_requests));
41
42        let (client_sender, mut client_receiver) = mpsc::unbounded_channel::<Message>();
43        let (server_sender, mut server_receiver) = mpsc::unbounded_channel::<Message>();
44
45        let server_message_sender = server_sender.clone();
46        let client_message_sender = client_sender.clone();
47        let hooks_client = Arc::clone(&hooks);
48        let pending_requests_client = Arc::clone(&pending_requests);
49        let server_to_client_task = tokio::spawn(async move {
50            forward_to_client(
51                hooks_client,
52                pending_requests_client,
53                server_reader,
54                server_message_sender,
55                client_message_sender,
56            )
57            .await
58        });
59
60        let hooks_server = Arc::clone(&hooks);
61        let pending_requests_server = Arc::clone(&pending_requests);
62        let client_to_server_task = tokio::spawn(async move {
63            forward_to_server(
64                hooks_server,
65                pending_requests_server,
66                client_reader,
67                server_sender,
68                client_sender,
69            )
70            .await
71        });
72
73        let write_to_server = tokio::spawn(async move {
74            let mut writer = server_writer;
75            while let Some(msg) = server_receiver.recv().await {
76                if write_message(&mut writer, &msg.to_value()).await.is_err() {
77                    break;
78                }
79            }
80            Ok::<(), std::io::Error>(())
81        });
82
83        let write_to_client = tokio::spawn(async move {
84            let mut writer = client_writer;
85            while let Some(msg) = client_receiver.recv().await {
86                if write_message(&mut writer, &msg.to_value()).await.is_err() {
87                    break;
88                }
89            }
90            Ok::<(), std::io::Error>(())
91        });
92
93        select! {
94            client_to_server = client_to_server_task => {
95                client_to_server?
96            },
97            server_to_client = server_to_client_task => {
98                server_to_client?
99            },
100            write_server = write_to_server => {
101                write_server?
102            },
103            write_client = write_to_client => {
104                write_client?
105            }
106        }
107    }
108}
109
110async fn process_server_message(
111    hooks: &HashMap<String, Arc<dyn Hook>>,
112    pending_requests: &Mutex<HashMap<i64, String>>,
113    message: Message,
114) -> Result<ProcessedMessage, HookError> {
115    match message {
116        Message::Response(response) => {
117            let method = { pending_requests.lock().await.remove(&response.id) };
118
119            if let Some(method) = method
120                && let Some(hook) = hooks.get(&method)
121            {
122                return Ok(hook.on_response(response).await?.as_processed());
123            }
124
125            Ok(ProcessedMessage::Forward(Message::Response(response)))
126        }
127        Message::Notification(notification) => match hooks.get(&notification.method) {
128            Some(hook) => Ok(hook.on_notification(notification).await?.as_processed()),
129            None => Ok(ProcessedMessage::Forward(Message::Notification(
130                notification,
131            ))),
132        },
133        Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
134    }
135}
136
137async fn process_client_message(
138    hooks: &HashMap<String, Arc<dyn Hook>>,
139    pending_requests: &Mutex<HashMap<i64, String>>,
140    message: Message,
141) -> Result<ProcessedMessage, HookError> {
142    match message {
143        Message::Request(request) => match hooks.get(&request.method) {
144            Some(hook) => {
145                pending_requests
146                    .lock()
147                    .await
148                    .insert(request.id, request.method.clone());
149
150                Ok(hook.on_request(request).await?.as_processed())
151            }
152            None => Ok(ProcessedMessage::Forward(Message::Request(request))),
153        },
154        Message::Notification(notification) => match hooks.get(&notification.method) {
155            Some(hook) => Ok(hook.on_notification(notification).await?.as_processed()),
156            None => Ok(ProcessedMessage::Forward(Message::Notification(
157                notification,
158            ))),
159        },
160        Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
161    }
162}
163
164impl Default for Proxy {
165    fn default() -> Self {
166        Self::new(HashMap::new())
167    }
168}
169
170async fn forward_to_server<R>(
171    hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
172    pending_requests: Arc<Mutex<HashMap<i64, String>>>,
173    mut client_reader: R,
174    server_message_sender: UnboundedSender<Message>,
175    client_message_sender: UnboundedSender<Message>,
176) -> std::io::Result<()>
177where
178    R: AsyncReadExt + Unpin,
179{
180    loop {
181        let message = match read_message(&mut client_reader).await {
182            Ok(msg) => Message::from_value(msg),
183            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
184                break;
185            }
186            Err(e) => return Err(e),
187        };
188
189        if let Ok(message) = message {
190            match process_client_message(&hooks, &pending_requests, message).await {
191                Ok(processed) => {
192                    let (main_message, generated_messages) = processed.into_parts();
193
194                    if let Some(main_message) = main_message
195                        && server_message_sender.send(main_message).is_err()
196                    {
197                        return Err(std::io::Error::new(
198                            std::io::ErrorKind::BrokenPipe,
199                            "Message channel closed",
200                        ));
201                    }
202
203                    for (direction, message) in generated_messages {
204                        let result = match direction {
205                            Direction::ToClient => client_message_sender.send(message),
206                            Direction::ToServer => server_message_sender.send(message),
207                        };
208
209                        if result.is_err() {
210                            return Err(std::io::Error::new(
211                                std::io::ErrorKind::BrokenPipe,
212                                "Notification channel closed",
213                            ));
214                        }
215                    }
216                }
217                Err(e) => {
218                    eprintln!("Error processing message: {}", e);
219                }
220            }
221        }
222    }
223
224    Ok(())
225}
226
227async fn forward_to_client<R>(
228    hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
229    pending_requests: Arc<Mutex<HashMap<i64, String>>>,
230    mut server_reader: R,
231    server_message_sender: UnboundedSender<Message>,
232    client_message_sender: UnboundedSender<Message>,
233) -> std::io::Result<()>
234where
235    R: AsyncReadExt + Unpin,
236{
237    loop {
238        let message = match read_message(&mut server_reader).await {
239            Ok(msg) => Message::from_value(msg),
240            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
241                break;
242            }
243            Err(e) => return Err(e),
244        };
245
246        if let Ok(message) = message {
247            match process_server_message(&hooks, &pending_requests, message).await {
248                Ok(processed) => {
249                    let (main_message, generated_messages) = processed.into_parts();
250
251                    if let Some(main_message) = main_message
252                        && client_message_sender.send(main_message).is_err()
253                    {
254                        return Err(std::io::Error::new(
255                            std::io::ErrorKind::BrokenPipe,
256                            "Message channel closed",
257                        ));
258                    }
259
260                    for (direction, message) in generated_messages {
261                        let result = match direction {
262                            Direction::ToClient => client_message_sender.send(message),
263                            Direction::ToServer => server_message_sender.send(message),
264                        };
265
266                        if result.is_err() {
267                            return Err(std::io::Error::new(
268                                std::io::ErrorKind::BrokenPipe,
269                                "Notification channel closed",
270                            ));
271                        }
272                    }
273                }
274                Err(e) => {
275                    eprintln!("Error processing message: {}", e);
276                }
277            }
278        }
279    }
280
281    Ok(())
282}
283
284pub struct ProxyBuilder {
285    hooks: HashMap<String, Arc<dyn Hook>>,
286}
287
288impl ProxyBuilder {
289    pub fn new() -> Self {
290        Self {
291            hooks: HashMap::new(),
292        }
293    }
294
295    pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
296        self.hooks.insert(method.to_owned(), hook);
297        self
298    }
299
300    pub fn build(self) -> Proxy {
301        Proxy::new(self.hooks)
302    }
303}
304
305impl Default for ProxyBuilder {
306    fn default() -> Self {
307        Self::new()
308    }
309}