1use crate::hooks::{Hook, HookError, Message};
2use crate::transport::{read_message, write_message};
3use serde_json::Value;
4use smol::io::{AsyncReadExt, AsyncWriteExt};
5use smol::process::{Command, Stdio};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9pub struct Proxy {
10 hooks: HashMap<String, Arc<dyn Hook>>,
11 pending_requests: Arc<Mutex<HashMap<Value, String>>>,
12}
13
14impl Proxy {
15 fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
16 Self {
17 hooks,
18 pending_requests: Arc::new(Mutex::new(HashMap::new())),
19 }
20 }
21
22 async fn process_client_message(
23 &self,
24 message: Message,
25 ) -> Result<ProcessedMessage, HookError> {
26 match &message {
27 Message::Request { id, method, .. } => {
28 if let Some(hook) = self.hooks.get(method) {
29 {
30 let mut pending = self.pending_requests.lock().unwrap();
31 pending.insert(id.clone(), method.clone());
32 }
33
34 let output = hook.on_request(message).await?;
35 Ok(ProcessedMessage::WithNotifications {
36 message: output.message,
37 notifications: output.notifications,
38 })
39 } else {
40 {
41 let mut pending = self.pending_requests.lock().unwrap();
42 pending.insert(id.clone(), method.clone());
43 }
44 Ok(ProcessedMessage::Forward(message))
45 }
46 }
47 Message::Notification { .. } => Ok(ProcessedMessage::Forward(message)),
48 Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
49 }
50 }
51
52 async fn process_server_message(
53 &self,
54 message: Message,
55 ) -> Result<ProcessedMessage, HookError> {
56 match &message {
57 Message::Response { id, .. } => {
58 let method = {
59 let mut pending = self.pending_requests.lock().unwrap();
60 pending.remove(id)
61 };
62
63 if let Some(method) = method
64 && let Some(hook) = self.hooks.get(&method)
65 {
66 let output = hook.on_response(message).await?;
67 return Ok(ProcessedMessage::WithNotifications {
68 message: output.message,
69 notifications: output.notifications,
70 });
71 }
72
73 Ok(ProcessedMessage::Forward(message))
74 }
75 Message::Notification { .. } => Ok(ProcessedMessage::Forward(message)),
76 Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
77 }
78 }
79
80 async fn process_message(
81 &self,
82 message: Value,
83 from_client: bool,
84 ) -> Result<ProcessedMessage, String> {
85 let parsed_message = Message::from_value(message)?;
86
87 let result = if from_client {
88 self.process_client_message(parsed_message).await
89 } else {
90 self.process_server_message(parsed_message).await
91 };
92
93 result.map_err(|e| e.to_string())
94 }
95
96 pub async fn spawn<R, W>(
97 self,
98 command: &str,
99 args: &[&str],
100 client_reader: R,
101 client_writer: W,
102 ) -> std::io::Result<()>
103 where
104 R: AsyncReadExt + Unpin + Send + 'static,
105 W: AsyncWriteExt + Unpin + Send + 'static,
106 {
107 let proxy = Arc::new(self);
108 let mut child = Command::new(command)
109 .args(args)
110 .stdin(Stdio::piped())
111 .stdout(Stdio::piped())
112 .stderr(Stdio::inherit())
113 .spawn()?;
114
115 let server_writer = child.stdin.take().unwrap();
116 let server_reader = child.stdout.take().unwrap();
117
118 let client_to_server_proxy = proxy.clone();
119 let server_to_client_proxy = proxy.clone();
120
121 let (client_sender, client_receiver) = smol::channel::unbounded::<Message>();
122 let (server_sender, server_receiver) = smol::channel::unbounded::<Message>();
123
124 let server_message_sender = server_sender.clone();
125 let server_notification_sender = server_sender.clone();
126 let client_message_sender = client_sender.clone();
127 let client_notification_sender = client_sender.clone();
128
129 let client_to_server = smol::spawn(async move {
130 forward_messages(
131 &client_to_server_proxy,
132 client_reader,
133 server_message_sender,
134 server_notification_sender,
135 true,
136 )
137 .await
138 });
139
140 let server_to_client = smol::spawn(async move {
141 forward_messages(
142 &server_to_client_proxy,
143 server_reader,
144 client_message_sender,
145 client_notification_sender,
146 false,
147 )
148 .await
149 });
150
151 let write_to_server = smol::spawn(async move {
152 let mut writer = server_writer;
153 while let Ok(msg) = server_receiver.recv().await {
154 if write_message(&mut writer, &msg.to_value()).await.is_err() {
155 break;
156 }
157 }
158 Ok::<(), std::io::Error>(())
159 });
160
161 let write_to_client = smol::spawn(async move {
162 let mut writer = client_writer;
163 while let Ok(msg) = client_receiver.recv().await {
164 if write_message(&mut writer, &msg.to_value()).await.is_err() {
165 break;
166 }
167 }
168 Ok::<(), std::io::Error>(())
169 });
170
171 _ = smol::future::zip(client_to_server, server_to_client).await;
172
173 child.kill()?;
174
175 drop(write_to_server);
176 drop(write_to_client);
177
178 Ok(())
180 }
181}
182
183impl Default for Proxy {
184 fn default() -> Self {
185 Self::new(HashMap::new())
186 }
187}
188
189#[derive(Debug)]
190pub enum ProcessedMessage {
191 Forward(Message),
192 WithNotifications {
193 message: Message,
194 notifications: Vec<Message>,
195 },
196}
197
198impl ProcessedMessage {
199 pub fn get_message(&self) -> &Message {
200 match self {
201 ProcessedMessage::Forward(msg) => msg,
202 ProcessedMessage::WithNotifications { message, .. } => message,
203 }
204 }
205
206 pub fn get_notifications(&self) -> &[Message] {
207 match self {
208 ProcessedMessage::Forward(_) => &[],
209 ProcessedMessage::WithNotifications { notifications, .. } => notifications,
210 }
211 }
212
213 pub fn into_parts(self) -> (Message, Vec<Message>) {
214 match self {
215 ProcessedMessage::Forward(msg) => (msg, Vec::new()),
216 ProcessedMessage::WithNotifications {
217 message,
218 notifications,
219 } => (message, notifications),
220 }
221 }
222}
223
224async fn forward_messages<R>(
225 proxy: &Proxy,
226 mut reader: R,
227 message_sender: smol::channel::Sender<Message>,
228 notification_sender: smol::channel::Sender<Message>,
229 from_client: bool,
230) -> std::io::Result<()>
231where
232 R: AsyncReadExt + Unpin,
233{
234 loop {
235 let message = match read_message(&mut reader).await {
236 Ok(msg) => msg,
237 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
238 break;
239 }
240 Err(e) => return Err(e),
241 };
242
243 match proxy.process_message(message, from_client).await {
244 Ok(processed) => {
245 let (main_message, notifications) = processed.into_parts();
246
247 if message_sender.send(main_message).await.is_err() {
248 return Err(std::io::Error::new(
249 std::io::ErrorKind::BrokenPipe,
250 "Message channel closed",
251 ));
252 }
253
254 for notification in notifications {
255 if notification_sender.send(notification).await.is_err() {
256 return Err(std::io::Error::new(
257 std::io::ErrorKind::BrokenPipe,
258 "Notification channel closed",
259 ));
260 }
261 }
262 }
263 Err(e) => {
264 eprintln!("Error processing message: {}", e);
265 }
266 }
267 }
268
269 Ok(())
270}
271
272pub struct ProxyBuilder {
273 hooks: HashMap<String, Arc<dyn Hook>>,
274}
275
276impl ProxyBuilder {
277 pub fn new() -> Self {
278 Self {
279 hooks: HashMap::new(),
280 }
281 }
282
283 pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
284 self.hooks.insert(method.to_owned(), hook);
285 self
286 }
287
288 pub fn build(self) -> Proxy {
289 Proxy::new(self.hooks)
290 }
291}
292
293impl Default for ProxyBuilder {
294 fn default() -> Self {
295 Self::new()
296 }
297}