lsp_proxy/
proxy.rs

1use crate::hooks::{Hook, HookError, Message};
2use crate::transport::{read_message, write_message};
3use serde_json::Value;
4use smol::io::{AsyncReadExt, AsyncWriteExt};
5use smol::process::{Command, Stdio};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9pub struct Proxy {
10    hooks: HashMap<String, Arc<dyn Hook>>,
11    pending_requests: Arc<Mutex<HashMap<Value, String>>>,
12}
13
14impl Proxy {
15    fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
16        Self {
17            hooks,
18            pending_requests: Arc::new(Mutex::new(HashMap::new())),
19        }
20    }
21
22    async fn process_client_message(
23        &self,
24        message: Message,
25    ) -> Result<ProcessedMessage, HookError> {
26        match &message {
27            Message::Request { id, method, .. } => {
28                if let Some(hook) = self.hooks.get(method) {
29                    {
30                        let mut pending = self.pending_requests.lock().unwrap();
31                        pending.insert(id.clone(), method.clone());
32                    }
33
34                    let output = hook.on_request(message).await?;
35                    Ok(ProcessedMessage::WithNotifications {
36                        message: output.message,
37                        notifications: output.notifications,
38                    })
39                } else {
40                    {
41                        let mut pending = self.pending_requests.lock().unwrap();
42                        pending.insert(id.clone(), method.clone());
43                    }
44                    Ok(ProcessedMessage::Forward(message))
45                }
46            }
47            Message::Notification { .. } => Ok(ProcessedMessage::Forward(message)),
48            Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
49        }
50    }
51
52    async fn process_server_message(
53        &self,
54        message: Message,
55    ) -> Result<ProcessedMessage, HookError> {
56        match &message {
57            Message::Response { id, .. } => {
58                let method = {
59                    let mut pending = self.pending_requests.lock().unwrap();
60                    pending.remove(id)
61                };
62
63                if let Some(method) = method
64                    && let Some(hook) = self.hooks.get(&method)
65                {
66                    let output = hook.on_response(message).await?;
67                    return Ok(ProcessedMessage::WithNotifications {
68                        message: output.message,
69                        notifications: output.notifications,
70                    });
71                }
72
73                Ok(ProcessedMessage::Forward(message))
74            }
75            Message::Notification { .. } => Ok(ProcessedMessage::Forward(message)),
76            Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
77        }
78    }
79
80    async fn process_message(
81        &self,
82        message: Value,
83        from_client: bool,
84    ) -> Result<ProcessedMessage, String> {
85        let parsed_message = Message::from_value(message)?;
86
87        let result = if from_client {
88            self.process_client_message(parsed_message).await
89        } else {
90            self.process_server_message(parsed_message).await
91        };
92
93        result.map_err(|e| e.to_string())
94    }
95
96    pub async fn spawn<R, W>(
97        self,
98        command: &str,
99        args: &[&str],
100        client_reader: R,
101        client_writer: W,
102    ) -> std::io::Result<()>
103    where
104        R: AsyncReadExt + Unpin + Send + 'static,
105        W: AsyncWriteExt + Unpin + Send + 'static,
106    {
107        let proxy = Arc::new(self);
108        let mut child = Command::new(command)
109            .args(args)
110            .stdin(Stdio::piped())
111            .stdout(Stdio::piped())
112            .stderr(Stdio::inherit())
113            .spawn()?;
114
115        let server_writer = child.stdin.take().unwrap();
116        let server_reader = child.stdout.take().unwrap();
117
118        let client_to_server_proxy = proxy.clone();
119        let server_to_client_proxy = proxy.clone();
120
121        let (client_sender, client_receiver) = smol::channel::unbounded::<Message>();
122        let (server_sender, server_receiver) = smol::channel::unbounded::<Message>();
123
124        let server_message_sender = server_sender.clone();
125        let server_notification_sender = server_sender.clone();
126        let client_message_sender = client_sender.clone();
127        let client_notification_sender = client_sender.clone();
128
129        let client_to_server = smol::spawn(async move {
130            forward_messages(
131                &client_to_server_proxy,
132                client_reader,
133                server_message_sender,
134                server_notification_sender,
135                true,
136            )
137            .await
138        });
139
140        let server_to_client = smol::spawn(async move {
141            forward_messages(
142                &server_to_client_proxy,
143                server_reader,
144                client_message_sender,
145                client_notification_sender,
146                false,
147            )
148            .await
149        });
150
151        let write_to_server = smol::spawn(async move {
152            let mut writer = server_writer;
153            while let Ok(msg) = server_receiver.recv().await {
154                if write_message(&mut writer, &msg.to_value()).await.is_err() {
155                    break;
156                }
157            }
158            Ok::<(), std::io::Error>(())
159        });
160
161        let write_to_client = smol::spawn(async move {
162            let mut writer = client_writer;
163            while let Ok(msg) = client_receiver.recv().await {
164                if write_message(&mut writer, &msg.to_value()).await.is_err() {
165                    break;
166                }
167            }
168            Ok::<(), std::io::Error>(())
169        });
170
171        _ = smol::future::zip(client_to_server, server_to_client).await;
172
173        child.kill()?;
174
175        drop(write_to_server);
176        drop(write_to_client);
177
178        // forward_result.map_err(std::io::Error::other)
179        Ok(())
180    }
181}
182
183impl Default for Proxy {
184    fn default() -> Self {
185        Self::new(HashMap::new())
186    }
187}
188
189#[derive(Debug)]
190pub enum ProcessedMessage {
191    Forward(Message),
192    WithNotifications {
193        message: Message,
194        notifications: Vec<Message>,
195    },
196}
197
198impl ProcessedMessage {
199    pub fn get_message(&self) -> &Message {
200        match self {
201            ProcessedMessage::Forward(msg) => msg,
202            ProcessedMessage::WithNotifications { message, .. } => message,
203        }
204    }
205
206    pub fn get_notifications(&self) -> &[Message] {
207        match self {
208            ProcessedMessage::Forward(_) => &[],
209            ProcessedMessage::WithNotifications { notifications, .. } => notifications,
210        }
211    }
212
213    pub fn into_parts(self) -> (Message, Vec<Message>) {
214        match self {
215            ProcessedMessage::Forward(msg) => (msg, Vec::new()),
216            ProcessedMessage::WithNotifications {
217                message,
218                notifications,
219            } => (message, notifications),
220        }
221    }
222}
223
224async fn forward_messages<R>(
225    proxy: &Proxy,
226    mut reader: R,
227    message_sender: smol::channel::Sender<Message>,
228    notification_sender: smol::channel::Sender<Message>,
229    from_client: bool,
230) -> std::io::Result<()>
231where
232    R: AsyncReadExt + Unpin,
233{
234    loop {
235        let message = match read_message(&mut reader).await {
236            Ok(msg) => msg,
237            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
238                break;
239            }
240            Err(e) => return Err(e),
241        };
242
243        match proxy.process_message(message, from_client).await {
244            Ok(processed) => {
245                let (main_message, notifications) = processed.into_parts();
246
247                if message_sender.send(main_message).await.is_err() {
248                    return Err(std::io::Error::new(
249                        std::io::ErrorKind::BrokenPipe,
250                        "Message channel closed",
251                    ));
252                }
253
254                for notification in notifications {
255                    if notification_sender.send(notification).await.is_err() {
256                        return Err(std::io::Error::new(
257                            std::io::ErrorKind::BrokenPipe,
258                            "Notification channel closed",
259                        ));
260                    }
261                }
262            }
263            Err(e) => {
264                eprintln!("Error processing message: {}", e);
265            }
266        }
267    }
268
269    Ok(())
270}
271
272pub struct ProxyBuilder {
273    hooks: HashMap<String, Arc<dyn Hook>>,
274}
275
276impl ProxyBuilder {
277    pub fn new() -> Self {
278        Self {
279            hooks: HashMap::new(),
280        }
281    }
282
283    pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
284        self.hooks.insert(method.to_owned(), hook);
285        self
286    }
287
288    pub fn build(self) -> Proxy {
289        Proxy::new(self.hooks)
290    }
291}
292
293impl Default for ProxyBuilder {
294    fn default() -> Self {
295        Self::new()
296    }
297}