1use crate::Message;
2use crate::hooks::{Hook, HookError};
3use crate::message::Direction;
4use crate::processed_message::ProcessedMessage;
5use crate::transport::{read_message, write_message};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::select;
10use tokio::sync::Mutex;
11use tokio::sync::mpsc::{self, UnboundedSender};
12
13pub struct Proxy {
14 hooks: HashMap<String, Arc<dyn Hook>>,
15 pending_requests: HashMap<i64, String>,
16}
17
18impl Proxy {
19 fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
20 Self {
21 hooks,
22 pending_requests: HashMap::new(),
23 }
24 }
25
26 pub async fn forward<SR, SW, CR, CW>(
27 self,
28 server_reader: SR,
29 server_writer: SW,
30 client_reader: CR,
31 client_writer: CW,
32 ) -> std::io::Result<()>
33 where
34 SR: AsyncReadExt + Unpin + Send + 'static,
35 SW: AsyncWriteExt + Unpin + Send + 'static,
36 CR: AsyncReadExt + Unpin + Send + 'static,
37 CW: AsyncWriteExt + Unpin + Send + 'static,
38 {
39 let hooks = Arc::new(self.hooks);
40 let pending_requests = Arc::new(Mutex::new(self.pending_requests));
41
42 let (client_sender, mut client_receiver) = mpsc::unbounded_channel::<Message>();
43 let (server_sender, mut server_receiver) = mpsc::unbounded_channel::<Message>();
44
45 let server_message_sender = server_sender.clone();
46 let client_message_sender = client_sender.clone();
47 let hooks_client = Arc::clone(&hooks);
48 let pending_requests_client = Arc::clone(&pending_requests);
49 let server_to_client_task = tokio::spawn(async move {
50 forward_to_client(
51 hooks_client,
52 pending_requests_client,
53 server_reader,
54 server_message_sender,
55 client_message_sender,
56 )
57 .await
58 });
59
60 let hooks_server = Arc::clone(&hooks);
61 let pending_requests_server = Arc::clone(&pending_requests);
62 let client_to_server_task = tokio::spawn(async move {
63 forward_to_server(
64 hooks_server,
65 pending_requests_server,
66 client_reader,
67 server_sender,
68 client_sender,
69 )
70 .await
71 });
72
73 let write_to_server = tokio::spawn(async move {
74 let mut writer = server_writer;
75 while let Some(msg) = server_receiver.recv().await {
76 if write_message(&mut writer, &msg.to_value()).await.is_err() {
77 break;
78 }
79 }
80 Ok::<(), std::io::Error>(())
81 });
82
83 let write_to_client = tokio::spawn(async move {
84 let mut writer = client_writer;
85 while let Some(msg) = client_receiver.recv().await {
86 if write_message(&mut writer, &msg.to_value()).await.is_err() {
87 break;
88 }
89 }
90 Ok::<(), std::io::Error>(())
91 });
92
93 select! {
94 client_to_server = client_to_server_task => {
95 client_to_server?
96 },
97 server_to_client = server_to_client_task => {
98 server_to_client?
99 },
100 write_server = write_to_server => {
101 write_server?
102 },
103 write_client = write_to_client => {
104 write_client?
105 }
106 }
107 }
108}
109
110async fn process_server_message(
111 hooks: &HashMap<String, Arc<dyn Hook>>,
112 pending_requests: &Mutex<HashMap<i64, String>>,
113 message: Message,
114) -> Result<ProcessedMessage, HookError> {
115 match message {
116 Message::Response(response) => {
117 let method = { pending_requests.lock().await.remove(&response.id) };
118
119 if let Some(method) = method
120 && let Some(hook) = hooks.get(&method)
121 {
122 return Ok(hook.on_response(response).await?.as_processed());
123 }
124
125 Ok(ProcessedMessage::Forward(Message::Response(response)))
126 }
127 Message::Notification(notification) => match hooks.get(¬ification.method) {
128 Some(hook) => Ok(hook.on_notification(notification).await?.as_processed()),
129 None => Ok(ProcessedMessage::Forward(Message::Notification(
130 notification,
131 ))),
132 },
133 Message::Request { .. } => Ok(ProcessedMessage::Forward(message)),
134 }
135}
136
137async fn process_client_message(
138 hooks: &HashMap<String, Arc<dyn Hook>>,
139 pending_requests: &Mutex<HashMap<i64, String>>,
140 message: Message,
141) -> Result<ProcessedMessage, HookError> {
142 match message {
143 Message::Request(request) => match hooks.get(&request.method) {
144 Some(hook) => {
145 pending_requests
146 .lock()
147 .await
148 .insert(request.id, request.method.clone());
149
150 Ok(hook.on_request(request).await?.as_processed())
151 }
152 None => Ok(ProcessedMessage::Forward(Message::Request(request))),
153 },
154 Message::Notification(notification) => match hooks.get(¬ification.method) {
155 Some(hook) => Ok(hook.on_notification(notification).await?.as_processed()),
156 None => Ok(ProcessedMessage::Forward(Message::Notification(
157 notification,
158 ))),
159 },
160 Message::Response { .. } => Ok(ProcessedMessage::Forward(message)),
161 }
162}
163
164impl Default for Proxy {
165 fn default() -> Self {
166 Self::new(HashMap::new())
167 }
168}
169
170async fn forward_to_server<R>(
171 hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
172 pending_requests: Arc<Mutex<HashMap<i64, String>>>,
173 mut client_reader: R,
174 server_message_sender: UnboundedSender<Message>,
175 client_message_sender: UnboundedSender<Message>,
176) -> std::io::Result<()>
177where
178 R: AsyncReadExt + Unpin,
179{
180 loop {
181 let message = match read_message(&mut client_reader).await {
182 Ok(msg) => Message::from_value(msg),
183 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
184 break;
185 }
186 Err(e) => return Err(e),
187 };
188
189 if let Ok(message) = message {
190 match process_client_message(&hooks, &pending_requests, message).await {
191 Ok(processed) => {
192 let (main_message, generated_messages) = processed.into_parts();
193
194 if let Some(main_message) = main_message
195 && server_message_sender.send(main_message).is_err()
196 {
197 return Err(std::io::Error::new(
198 std::io::ErrorKind::BrokenPipe,
199 "Message channel closed",
200 ));
201 }
202
203 for (direction, message) in generated_messages {
204 let result = match direction {
205 Direction::ToClient => client_message_sender.send(message),
206 Direction::ToServer => server_message_sender.send(message),
207 };
208
209 if result.is_err() {
210 return Err(std::io::Error::new(
211 std::io::ErrorKind::BrokenPipe,
212 "Notification channel closed",
213 ));
214 }
215 }
216 }
217 Err(e) => {
218 eprintln!("Error processing message: {}", e);
219 }
220 }
221 }
222 }
223
224 Ok(())
225}
226
227async fn forward_to_client<R>(
228 hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
229 pending_requests: Arc<Mutex<HashMap<i64, String>>>,
230 mut server_reader: R,
231 server_message_sender: UnboundedSender<Message>,
232 client_message_sender: UnboundedSender<Message>,
233) -> std::io::Result<()>
234where
235 R: AsyncReadExt + Unpin,
236{
237 loop {
238 let message = match read_message(&mut server_reader).await {
239 Ok(msg) => Message::from_value(msg),
240 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
241 break;
242 }
243 Err(e) => return Err(e),
244 };
245
246 if let Ok(message) = message {
247 match process_server_message(&hooks, &pending_requests, message).await {
248 Ok(processed) => {
249 let (main_message, generated_messages) = processed.into_parts();
250
251 if let Some(main_message) = main_message
252 && client_message_sender.send(main_message).is_err()
253 {
254 return Err(std::io::Error::new(
255 std::io::ErrorKind::BrokenPipe,
256 "Message channel closed",
257 ));
258 }
259
260 for (direction, message) in generated_messages {
261 let result = match direction {
262 Direction::ToClient => client_message_sender.send(message),
263 Direction::ToServer => server_message_sender.send(message),
264 };
265
266 if result.is_err() {
267 return Err(std::io::Error::new(
268 std::io::ErrorKind::BrokenPipe,
269 "Notification channel closed",
270 ));
271 }
272 }
273 }
274 Err(e) => {
275 eprintln!("Error processing message: {}", e);
276 }
277 }
278 }
279 }
280
281 Ok(())
282}
283
284pub struct ProxyBuilder {
285 hooks: HashMap<String, Arc<dyn Hook>>,
286}
287
288impl ProxyBuilder {
289 pub fn new() -> Self {
290 Self {
291 hooks: HashMap::new(),
292 }
293 }
294
295 pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
296 self.hooks.insert(method.to_owned(), hook);
297 self
298 }
299
300 pub fn build(self) -> Proxy {
301 Proxy::new(self.hooks)
302 }
303}
304
305impl Default for ProxyBuilder {
306 fn default() -> Self {
307 Self::new()
308 }
309}