// lsp_proxy/proxy.rs

1use crate::Message;
2use crate::hooks::{Hook, HookError};
3use crate::message::Direction;
4use crate::processed_message::ProcessedMessage;
5use crate::transport::{read_message, write_message};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::select;
10use tokio::sync::Mutex;
11use tokio::sync::mpsc::{self, UnboundedSender};
12
13pub struct Proxy {
14    hooks: HashMap<String, Arc<dyn Hook>>,
15    pending_requests: HashMap<i64, String>,
16}
17
18impl Proxy {
19    fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
20        Self {
21            hooks,
22            pending_requests: HashMap::new(),
23        }
24    }
25
26    pub async fn forward<SR, SW, CR, CW>(
27        self,
28        server_reader: SR,
29        server_writer: SW,
30        client_reader: CR,
31        client_writer: CW,
32    ) -> std::io::Result<()>
33    where
34        SR: AsyncReadExt + Unpin + Send + 'static,
35        SW: AsyncWriteExt + Unpin + Send + 'static,
36        CR: AsyncReadExt + Unpin + Send + 'static,
37        CW: AsyncWriteExt + Unpin + Send + 'static,
38    {
39        let hooks = Arc::new(self.hooks);
40        let pending_requests = Arc::new(Mutex::new(self.pending_requests));
41
42        let (client_sender, mut client_receiver) = mpsc::unbounded_channel::<Message>();
43        let (server_sender, mut server_receiver) = mpsc::unbounded_channel::<Message>();
44
45        let server_message_sender = server_sender.clone();
46        let client_message_sender = client_sender.clone();
47        let hooks_client = Arc::clone(&hooks);
48        let pending_requests_client = Arc::clone(&pending_requests);
49        let server_to_client_task = tokio::spawn(async move {
50            forward_to_client(
51                hooks_client,
52                pending_requests_client,
53                server_reader,
54                server_message_sender,
55                client_message_sender,
56            )
57            .await
58        });
59
60        let hooks_server = Arc::clone(&hooks);
61        let pending_requests_server = Arc::clone(&pending_requests);
62        let client_to_server_task = tokio::spawn(async move {
63            forward_to_server(
64                hooks_server,
65                pending_requests_server,
66                client_reader,
67                server_sender,
68                client_sender,
69            )
70            .await
71        });
72
73        let write_to_server = tokio::spawn(async move {
74            let mut writer = server_writer;
75            while let Some(msg) = server_receiver.recv().await {
76                if write_message(&mut writer, &msg.to_value()).await.is_err() {
77                    break;
78                }
79            }
80            Ok::<(), std::io::Error>(())
81        });
82
83        let write_to_client = tokio::spawn(async move {
84            let mut writer = client_writer;
85            while let Some(msg) = client_receiver.recv().await {
86                if write_message(&mut writer, &msg.to_value()).await.is_err() {
87                    break;
88                }
89            }
90            Ok::<(), std::io::Error>(())
91        });
92
93        select! {
94            client_to_server = client_to_server_task => {
95                client_to_server?
96            },
97            server_to_client = server_to_client_task => {
98                server_to_client?
99            },
100            write_server = write_to_server => {
101                write_server?
102            },
103            write_client = write_to_client => {
104                write_client?
105            }
106        }
107    }
108}
109
110async fn process_message(
111    hooks: &HashMap<String, Arc<dyn Hook>>,
112    pending_requests: &Mutex<HashMap<i64, String>>,
113    message: Message,
114) -> Result<ProcessedMessage, HookError> {
115    match message {
116        Message::Request(request) => match hooks.get(&request.method) {
117            Some(hook) => {
118                pending_requests
119                    .lock()
120                    .await
121                    .insert(request.id, request.method.clone());
122
123                Ok(hook.on_request(request).await?.as_processed())
124            }
125            None => Ok(ProcessedMessage::Forward(Message::Request(request))),
126        },
127        Message::Notification(notification) => match hooks.get(&notification.method) {
128            Some(hook) => Ok(hook.on_notification(notification).await?.as_processed()),
129            None => Ok(ProcessedMessage::Forward(Message::Notification(
130                notification,
131            ))),
132        },
133        Message::Response(response) => {
134            let method = { pending_requests.lock().await.remove(&response.id) };
135
136            if let Some(method) = method
137                && let Some(hook) = hooks.get(&method)
138            {
139                return Ok(hook.on_response(response).await?.as_processed());
140            }
141
142            Ok(ProcessedMessage::Forward(Message::Response(response)))
143        }
144    }
145}
146
147impl Default for Proxy {
148    fn default() -> Self {
149        Self::new(HashMap::new())
150    }
151}
152
153async fn forward_to_server<R>(
154    hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
155    pending_requests: Arc<Mutex<HashMap<i64, String>>>,
156    mut client_reader: R,
157    server_message_sender: UnboundedSender<Message>,
158    client_message_sender: UnboundedSender<Message>,
159) -> std::io::Result<()>
160where
161    R: AsyncReadExt + Unpin,
162{
163    loop {
164        let message = match read_message(&mut client_reader).await {
165            Ok(msg) => Message::from_value(msg),
166            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
167                break;
168            }
169            Err(e) => return Err(e),
170        };
171
172        if let Ok(message) = message {
173            match process_message(&hooks, &pending_requests, message).await {
174                Ok(processed) => {
175                    let (main_message, generated_messages) = processed.into_parts();
176
177                    if let Some(main_message) = main_message
178                        && server_message_sender.send(main_message).is_err()
179                    {
180                        return Err(std::io::Error::new(
181                            std::io::ErrorKind::BrokenPipe,
182                            "Message channel closed",
183                        ));
184                    }
185
186                    for (direction, message) in generated_messages {
187                        let result = match direction {
188                            Direction::ToClient => client_message_sender.send(message),
189                            Direction::ToServer => server_message_sender.send(message),
190                        };
191
192                        if result.is_err() {
193                            return Err(std::io::Error::new(
194                                std::io::ErrorKind::BrokenPipe,
195                                "Notification channel closed",
196                            ));
197                        }
198                    }
199                }
200                Err(e) => {
201                    eprintln!("Error processing message: {}", e);
202                }
203            }
204        }
205    }
206
207    Ok(())
208}
209
210async fn forward_to_client<R>(
211    hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
212    pending_requests: Arc<Mutex<HashMap<i64, String>>>,
213    mut server_reader: R,
214    server_message_sender: UnboundedSender<Message>,
215    client_message_sender: UnboundedSender<Message>,
216) -> std::io::Result<()>
217where
218    R: AsyncReadExt + Unpin,
219{
220    loop {
221        let message = match read_message(&mut server_reader).await {
222            Ok(msg) => Message::from_value(msg),
223            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
224                break;
225            }
226            Err(e) => return Err(e),
227        };
228
229        if let Ok(message) = message {
230            match process_message(&hooks, &pending_requests, message).await {
231                Ok(processed) => {
232                    let (main_message, generated_messages) = processed.into_parts();
233
234                    if let Some(main_message) = main_message
235                        && client_message_sender.send(main_message).is_err()
236                    {
237                        return Err(std::io::Error::new(
238                            std::io::ErrorKind::BrokenPipe,
239                            "Message channel closed",
240                        ));
241                    }
242
243                    for (direction, message) in generated_messages {
244                        let result = match direction {
245                            Direction::ToClient => client_message_sender.send(message),
246                            Direction::ToServer => server_message_sender.send(message),
247                        };
248
249                        if result.is_err() {
250                            return Err(std::io::Error::new(
251                                std::io::ErrorKind::BrokenPipe,
252                                "Notification channel closed",
253                            ));
254                        }
255                    }
256                }
257                Err(e) => {
258                    eprintln!("Error processing message: {}", e);
259                }
260            }
261        }
262    }
263
264    Ok(())
265}
266
267pub struct ProxyBuilder {
268    hooks: HashMap<String, Arc<dyn Hook>>,
269}
270
271impl ProxyBuilder {
272    pub fn new() -> Self {
273        Self {
274            hooks: HashMap::new(),
275        }
276    }
277
278    pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
279        self.hooks.insert(method.to_owned(), hook);
280        self
281    }
282
283    pub fn build(self) -> Proxy {
284        Proxy::new(self.hooks)
285    }
286}
287
288impl Default for ProxyBuilder {
289    fn default() -> Self {
290        Self::new()
291    }
292}