1use crate::hooks::{Direction, Hook, HookError, Message};
2use crate::transport::{read_message, write_message};
3use dashmap::DashMap;
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::join;
7use tokio::sync::mpsc::{self, Sender};
8
9pub struct Proxy {
10 hooks: DashMap<String, Arc<dyn Hook>>,
11 pending_requests: DashMap<i64, String>,
12}
13
14impl Proxy {
15 fn new(hooks: DashMap<String, Arc<dyn Hook>>) -> Self {
16 Self {
17 hooks,
18 pending_requests: DashMap::new(),
19 }
20 }
21
22 async fn process_client_message(
23 &self,
24 message: Message,
25 ) -> Result<ProcessedMessage, HookError> {
26 match message {
27 Message::Request(request) => match self.hooks.get(&request.method) {
28 Some(hook) => {
29 self.pending_requests
30 .insert(request.id, request.method.clone());
31 let output = hook.on_request(request).await?;
32 Ok(ProcessedMessage::WithMessages {
33 message: output.message,
34 generated_messages: output.generated_messages,
35 })
36 }
37 None => Ok(ProcessedMessage::Forward(Message::Request(request))),
38 },
39 Message::Notification(notification) => match self.hooks.get(¬ification.method) {
40 Some(hook) => {
41 let output = hook.on_notification(notification).await?;
42 Ok(ProcessedMessage::Forward(output.message))
43 }
44 None => Ok(ProcessedMessage::Forward(Message::Notification(
45 notification,
46 ))),
47 },
48 Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
49 }
50 }
51
52 async fn process_server_message(
53 &self,
54 message: Message,
55 ) -> Result<ProcessedMessage, HookError> {
56 match message {
57 Message::Response(response) => {
58 let method = self
59 .pending_requests
60 .remove(&response.id)
61 .map(|(_, method)| method);
62
63 if let Some(method) = method
64 && let Some(hook) = self.hooks.get(&method)
65 {
66 let output = hook.on_response(response).await?;
67 return Ok(ProcessedMessage::WithMessages {
68 message: output.message,
69 generated_messages: output.generated_messages,
70 });
71 }
72
73 Ok(ProcessedMessage::Forward(Message::Response(response)))
74 }
75 Message::Notification(notification) => match self.hooks.get(¬ification.method) {
76 Some(hook) => {
77 let output = hook.on_notification(notification).await?;
78 Ok(ProcessedMessage::Forward(output.message))
79 }
80 None => Ok(ProcessedMessage::Forward(Message::Notification(
81 notification,
82 ))),
83 },
84 Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
85 }
86 }
87
88 pub async fn forward<SR, SW, CR, CW>(
89 self,
90 server_reader: SR,
91 server_writer: SW,
92 client_reader: CR,
93 client_writer: CW,
94 ) -> std::io::Result<()>
95 where
96 SR: AsyncReadExt + Unpin + Send + 'static,
97 SW: AsyncWriteExt + Unpin + Send + 'static,
98 CR: AsyncReadExt + Unpin + Send + 'static,
99 CW: AsyncWriteExt + Unpin + Send + 'static,
100 {
101 let proxy = Arc::new(self);
102 let server_to_client_proxy = proxy.clone();
103
104 let (client_sender, mut client_receiver) = mpsc::channel::<Message>(100);
105 let (server_sender, mut server_receiver) = mpsc::channel::<Message>(100);
106
107 let server_message_sender = server_sender.clone();
108 let client_message_sender = client_sender.clone();
109 let server_to_client = tokio::spawn(async move {
110 forward_to_client(
111 &server_to_client_proxy,
112 server_reader,
113 server_message_sender,
114 client_message_sender,
115 )
116 .await
117 });
118
119 let client_to_server = tokio::spawn(async move {
120 forward_to_server(&proxy, client_reader, server_sender, client_sender).await
121 });
122
123 let write_to_server = tokio::spawn(async move {
124 let mut writer = server_writer;
125 while let Some(msg) = server_receiver.recv().await {
126 if write_message(&mut writer, &msg.to_value()).await.is_err() {
127 break;
128 }
129 }
130 Ok::<(), std::io::Error>(())
131 });
132
133 let write_to_client = tokio::spawn(async move {
134 let mut writer = client_writer;
135 while let Some(msg) = client_receiver.recv().await {
136 if write_message(&mut writer, &msg.to_value()).await.is_err() {
137 break;
138 }
139 }
140 Ok::<(), std::io::Error>(())
141 });
142
143 _ = join!(client_to_server, server_to_client);
144
145 drop(write_to_server);
146 drop(write_to_client);
147 Ok(())
148 }
149}
150
151impl Default for Proxy {
152 fn default() -> Self {
153 Self::new(DashMap::new())
154 }
155}
156
157#[derive(Debug)]
158pub enum ProcessedMessage {
159 Forward(Message),
160 WithMessages {
161 message: Message,
162 generated_messages: Vec<(Direction, Message)>,
163 },
164}
165
166impl ProcessedMessage {
167 pub fn get_message(&self) -> &Message {
168 match self {
169 ProcessedMessage::Forward(msg) => msg,
170 ProcessedMessage::WithMessages { message, .. } => message,
171 }
172 }
173
174 pub fn get_generated_messages(&self) -> &[(Direction, Message)] {
175 match self {
176 ProcessedMessage::Forward(_) => &[],
177 ProcessedMessage::WithMessages {
178 generated_messages: messages,
179 ..
180 } => messages,
181 }
182 }
183
184 pub fn into_parts(self) -> (Message, Vec<(Direction, Message)>) {
185 match self {
186 ProcessedMessage::Forward(msg) => (msg, Vec::new()),
187 ProcessedMessage::WithMessages {
188 message,
189 generated_messages: messages,
190 } => (message, messages),
191 }
192 }
193}
194
195async fn forward_to_server<R>(
196 proxy: &Proxy,
197 mut client_reader: R,
198 server_message_sender: Sender<Message>,
199 client_message_sender: Sender<Message>,
200) -> std::io::Result<()>
201where
202 R: AsyncReadExt + Unpin,
203{
204 loop {
205 let message = match read_message(&mut client_reader).await {
206 Ok(msg) => Message::from_value(msg),
207 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
208 break;
209 }
210 Err(e) => return Err(e),
211 };
212
213 if let Ok(message) = message {
214 match proxy.process_client_message(message).await {
215 Ok(processed) => {
216 let (main_message, generated_messages) = processed.into_parts();
217
218 if server_message_sender.send(main_message).await.is_err() {
219 return Err(std::io::Error::new(
220 std::io::ErrorKind::BrokenPipe,
221 "Message channel closed",
222 ));
223 }
224
225 for (direction, message) in generated_messages {
226 let result = match direction {
227 Direction::ToClient => client_message_sender.send(message),
228 Direction::ToServer => server_message_sender.send(message),
229 };
230
231 if result.await.is_err() {
232 return Err(std::io::Error::new(
233 std::io::ErrorKind::BrokenPipe,
234 "Notification channel closed",
235 ));
236 }
237 }
238 }
239 Err(e) => {
240 eprintln!("Error processing message: {}", e);
241 }
242 }
243 }
244 }
245
246 Ok(())
247}
248
249async fn forward_to_client<R>(
250 proxy: &Proxy,
251 mut server_reader: R,
252 server_message_sender: Sender<Message>,
253 client_message_sender: Sender<Message>,
254) -> std::io::Result<()>
255where
256 R: AsyncReadExt + Unpin,
257{
258 loop {
259 let message = match read_message(&mut server_reader).await {
260 Ok(msg) => Message::from_value(msg),
261 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
262 break;
263 }
264 Err(e) => return Err(e),
265 };
266
267 if let Ok(message) = message {
268 match proxy.process_server_message(message).await {
269 Ok(processed) => {
270 let (main_message, generated_messages) = processed.into_parts();
271
272 if client_message_sender.send(main_message).await.is_err() {
273 return Err(std::io::Error::new(
274 std::io::ErrorKind::BrokenPipe,
275 "Message channel closed",
276 ));
277 }
278
279 for (direction, message) in generated_messages {
280 let result = match direction {
281 Direction::ToClient => client_message_sender.send(message),
282 Direction::ToServer => server_message_sender.send(message),
283 };
284
285 if result.await.is_err() {
286 return Err(std::io::Error::new(
287 std::io::ErrorKind::BrokenPipe,
288 "Notification channel closed",
289 ));
290 }
291 }
292 }
293 Err(e) => {
294 eprintln!("Error processing message: {}", e);
295 }
296 }
297 }
298 }
299
300 Ok(())
301}
302
303pub struct ProxyBuilder {
304 hooks: DashMap<String, Arc<dyn Hook>>,
305}
306
307impl ProxyBuilder {
308 pub fn new() -> Self {
309 Self {
310 hooks: DashMap::new(),
311 }
312 }
313
314 pub fn with_hook(self, method: &str, hook: Arc<dyn Hook>) -> Self {
315 self.hooks.insert(method.to_owned(), hook);
316 self
317 }
318
319 pub fn build(self) -> Proxy {
320 Proxy::new(self.hooks)
321 }
322}
323
324impl Default for ProxyBuilder {
325 fn default() -> Self {
326 Self::new()
327 }
328}