1use crate::hooks::{Direction, Hook, HookError, Message};
2use crate::transport::{read_message, write_message};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::join;
7use tokio::sync::Mutex;
8use tokio::sync::mpsc::{self, Sender};
9
/// Sits between a client and a server, relaying messages in both directions and
/// passing those whose method has a registered hook through that hook.
pub struct Proxy {
    /// Hooks keyed by method name; messages whose method has no entry are forwarded unchanged.
    hooks: HashMap<String, Arc<dyn Hook>>,
    /// For each hooked client request still awaiting a server response: request id -> method,
    /// so the response can be handed to the same hook's `on_response`.
    pending_requests: Mutex<HashMap<i64, String>>,
}
14
15impl Proxy {
16 fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
17 Self {
18 hooks,
19 pending_requests: Mutex::new(HashMap::new()),
20 }
21 }
22
23 async fn process_client_message(
24 &self,
25 message: Message,
26 ) -> Result<ProcessedMessage, HookError> {
27 match message {
28 Message::Request(request) => match self.hooks.get(&request.method) {
29 Some(hook) => {
30 {
31 let mut pending = self.pending_requests.lock().await;
32 pending.insert(request.id, request.method.clone());
33 }
34 let output = hook.on_request(request).await?;
35 Ok(ProcessedMessage::WithMessages {
36 message: output.message,
37 generated_messages: output.generated_messages,
38 })
39 }
40 None => Ok(ProcessedMessage::Forward(Message::Request(request))),
41 },
42 Message::Notification(notification) => match self.hooks.get(¬ification.method) {
43 Some(hook) => {
44 let output = hook.on_notification(notification).await?;
45 Ok(ProcessedMessage::Forward(output.message))
46 }
47 None => Ok(ProcessedMessage::Forward(Message::Notification(
48 notification,
49 ))),
50 },
51 Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
52 }
53 }
54
55 async fn process_server_message(
56 &self,
57 message: Message,
58 ) -> Result<ProcessedMessage, HookError> {
59 match message {
60 Message::Response(response) => {
61 let method = {
62 let mut pending = self.pending_requests.lock().await;
63 pending.remove(&response.id)
64 };
65
66 if let Some(method) = method
67 && let Some(hook) = self.hooks.get(&method)
68 {
69 let output = hook.on_response(response).await?;
70 return Ok(ProcessedMessage::WithMessages {
71 message: output.message,
72 generated_messages: output.generated_messages,
73 });
74 }
75
76 Ok(ProcessedMessage::Forward(Message::Response(response)))
77 }
78 Message::Notification(notification) => match self.hooks.get(¬ification.method) {
79 Some(hook) => {
80 let output = hook.on_notification(notification).await?;
81 Ok(ProcessedMessage::Forward(output.message))
82 }
83 None => Ok(ProcessedMessage::Forward(Message::Notification(
84 notification,
85 ))),
86 },
87 Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
88 }
89 }
90
91 pub async fn forward<SR, SW, CR, CW>(
92 self,
93 server_reader: SR,
94 server_writer: SW,
95 client_reader: CR,
96 client_writer: CW,
97 ) -> std::io::Result<()>
98 where
99 SR: AsyncReadExt + Unpin + Send + 'static,
100 SW: AsyncWriteExt + Unpin + Send + 'static,
101 CR: AsyncReadExt + Unpin + Send + 'static,
102 CW: AsyncWriteExt + Unpin + Send + 'static,
103 {
104 let proxy = Arc::new(self);
105 let server_to_client_proxy = proxy.clone();
106
107 let (client_sender, mut client_receiver) = mpsc::channel::<Message>(10);
108 let (server_sender, mut server_receiver) = mpsc::channel::<Message>(10);
109
110 let server_message_sender = server_sender.clone();
111 let client_message_sender = client_sender.clone();
112 let server_to_client = tokio::spawn(async move {
113 forward_to_client(
114 &server_to_client_proxy,
115 server_reader,
116 server_message_sender,
117 client_message_sender,
118 )
119 .await
120 });
121
122 let client_to_server = tokio::spawn(async move {
123 forward_to_server(&proxy, client_reader, server_sender, client_sender).await
124 });
125
126 let write_to_server = tokio::spawn(async move {
127 let mut writer = server_writer;
128 while let Some(msg) = server_receiver.recv().await {
129 if write_message(&mut writer, &msg.to_value()).await.is_err() {
130 break;
131 }
132 }
133 Ok::<(), std::io::Error>(())
134 });
135
136 let write_to_client = tokio::spawn(async move {
137 let mut writer = client_writer;
138 while let Some(msg) = client_receiver.recv().await {
139 if write_message(&mut writer, &msg.to_value()).await.is_err() {
140 break;
141 }
142 }
143 Ok::<(), std::io::Error>(())
144 });
145
146 _ = join!(client_to_server, server_to_client);
147
148 drop(write_to_server);
149 drop(write_to_client);
150 Ok(())
151 }
152}
153
154impl Default for Proxy {
155 fn default() -> Self {
156 Self::new(HashMap::new())
157 }
158}
159
/// Outcome of running one message through the proxy's hooks.
#[derive(Debug)]
pub enum ProcessedMessage {
    /// A single message to pass on to the opposite peer; nothing else was generated.
    Forward(Message),
    /// A message to pass on, plus extra messages produced by a hook, each tagged with
    /// the [`Direction`] it should be sent in.
    WithMessages {
        message: Message,
        generated_messages: Vec<(Direction, Message)>,
    },
}
168
169impl ProcessedMessage {
170 pub fn get_message(&self) -> &Message {
171 match self {
172 ProcessedMessage::Forward(msg) => msg,
173 ProcessedMessage::WithMessages { message, .. } => message,
174 }
175 }
176
177 pub fn get_generated_messages(&self) -> &[(Direction, Message)] {
178 match self {
179 ProcessedMessage::Forward(_) => &[],
180 ProcessedMessage::WithMessages {
181 generated_messages: messages,
182 ..
183 } => messages,
184 }
185 }
186
187 pub fn into_parts(self) -> (Message, Vec<(Direction, Message)>) {
188 match self {
189 ProcessedMessage::Forward(msg) => (msg, Vec::new()),
190 ProcessedMessage::WithMessages {
191 message,
192 generated_messages: messages,
193 } => (message, messages),
194 }
195 }
196}
197
198async fn forward_to_server<R>(
199 proxy: &Proxy,
200 mut client_reader: R,
201 server_message_sender: Sender<Message>,
202 client_message_sender: Sender<Message>,
203) -> std::io::Result<()>
204where
205 R: AsyncReadExt + Unpin,
206{
207 loop {
208 let message = match read_message(&mut client_reader).await {
209 Ok(msg) => Message::from_value(msg),
210 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
211 break;
212 }
213 Err(e) => return Err(e),
214 };
215
216 if let Ok(message) = message {
217 match proxy.process_client_message(message).await {
218 Ok(processed) => {
219 let (main_message, generated_messages) = processed.into_parts();
220
221 if server_message_sender.send(main_message).await.is_err() {
222 return Err(std::io::Error::new(
223 std::io::ErrorKind::BrokenPipe,
224 "Message channel closed",
225 ));
226 }
227
228 for (direction, message) in generated_messages {
229 let result = match direction {
230 Direction::ToClient => client_message_sender.send(message),
231 Direction::ToServer => server_message_sender.send(message),
232 };
233
234 if result.await.is_err() {
235 return Err(std::io::Error::new(
236 std::io::ErrorKind::BrokenPipe,
237 "Notification channel closed",
238 ));
239 }
240 }
241 }
242 Err(e) => {
243 eprintln!("Error processing message: {}", e);
244 }
245 }
246 }
247 }
248
249 Ok(())
250}
251
252async fn forward_to_client<R>(
253 proxy: &Proxy,
254 mut server_reader: R,
255 server_message_sender: Sender<Message>,
256 client_message_sender: Sender<Message>,
257) -> std::io::Result<()>
258where
259 R: AsyncReadExt + Unpin,
260{
261 loop {
262 let message = match read_message(&mut server_reader).await {
263 Ok(msg) => Message::from_value(msg),
264 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
265 break;
266 }
267 Err(e) => return Err(e),
268 };
269
270 if let Ok(message) = message {
271 match proxy.process_server_message(message).await {
272 Ok(processed) => {
273 let (main_message, generated_messages) = processed.into_parts();
274
275 if client_message_sender.send(main_message).await.is_err() {
276 return Err(std::io::Error::new(
277 std::io::ErrorKind::BrokenPipe,
278 "Message channel closed",
279 ));
280 }
281
282 for (direction, message) in generated_messages {
283 let result = match direction {
284 Direction::ToClient => client_message_sender.send(message),
285 Direction::ToServer => server_message_sender.send(message),
286 };
287
288 if result.await.is_err() {
289 return Err(std::io::Error::new(
290 std::io::ErrorKind::BrokenPipe,
291 "Notification channel closed",
292 ));
293 }
294 }
295 }
296 Err(e) => {
297 eprintln!("Error processing message: {}", e);
298 }
299 }
300 }
301 }
302
303 Ok(())
304}
305
/// Builder that collects per-method hooks and produces a [`Proxy`].
pub struct ProxyBuilder {
    /// Hooks registered so far, keyed by method name.
    hooks: HashMap<String, Arc<dyn Hook>>,
}
309
310impl ProxyBuilder {
311 pub fn new() -> Self {
312 Self {
313 hooks: HashMap::new(),
314 }
315 }
316
317 pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
318 self.hooks.insert(method.to_owned(), hook);
319 self
320 }
321
322 pub fn build(self) -> Proxy {
323 Proxy::new(self.hooks)
324 }
325}
326
327impl Default for ProxyBuilder {
328 fn default() -> Self {
329 Self::new()
330 }
331}