1use crate::Message;
2use crate::hooks::{Hook, HookError};
3use crate::message::Direction;
4use crate::processed_message::ProcessedMessage;
5use crate::transport::{read_message, write_message};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::select;
10use tokio::sync::Mutex;
11use tokio::sync::mpsc::{self, UnboundedSender};
12
13pub struct Proxy {
14 hooks: HashMap<String, Arc<dyn Hook>>,
15 pending_requests: HashMap<i64, String>,
16}
17
18impl Proxy {
19 fn new(hooks: HashMap<String, Arc<dyn Hook>>) -> Self {
20 Self {
21 hooks,
22 pending_requests: HashMap::new(),
23 }
24 }
25
26 pub async fn forward<SR, SW, CR, CW>(
27 self,
28 server_reader: SR,
29 server_writer: SW,
30 client_reader: CR,
31 client_writer: CW,
32 ) -> std::io::Result<()>
33 where
34 SR: AsyncReadExt + Unpin + Send + 'static,
35 SW: AsyncWriteExt + Unpin + Send + 'static,
36 CR: AsyncReadExt + Unpin + Send + 'static,
37 CW: AsyncWriteExt + Unpin + Send + 'static,
38 {
39 let hooks = Arc::new(self.hooks);
40 let pending_requests = Arc::new(Mutex::new(self.pending_requests));
41
42 let (client_sender, mut client_receiver) = mpsc::unbounded_channel::<Message>();
43 let (server_sender, mut server_receiver) = mpsc::unbounded_channel::<Message>();
44
45 let server_message_sender = server_sender.clone();
46 let client_message_sender = client_sender.clone();
47 let hooks_client = Arc::clone(&hooks);
48 let pending_requests_client = Arc::clone(&pending_requests);
49 let server_to_client_task = tokio::spawn(async move {
50 forward_to_client(
51 hooks_client,
52 pending_requests_client,
53 server_reader,
54 server_message_sender,
55 client_message_sender,
56 )
57 .await
58 });
59
60 let hooks_server = Arc::clone(&hooks);
61 let pending_requests_server = Arc::clone(&pending_requests);
62 let client_to_server_task = tokio::spawn(async move {
63 forward_to_server(
64 hooks_server,
65 pending_requests_server,
66 client_reader,
67 server_sender,
68 client_sender,
69 )
70 .await
71 });
72
73 let write_to_server = tokio::spawn(async move {
74 let mut writer = server_writer;
75 while let Some(msg) = server_receiver.recv().await {
76 if write_message(&mut writer, &msg.to_value()).await.is_err() {
77 break;
78 }
79 }
80 Ok::<(), std::io::Error>(())
81 });
82
83 let write_to_client = tokio::spawn(async move {
84 let mut writer = client_writer;
85 while let Some(msg) = client_receiver.recv().await {
86 if write_message(&mut writer, &msg.to_value()).await.is_err() {
87 break;
88 }
89 }
90 Ok::<(), std::io::Error>(())
91 });
92
93 select! {
94 client_to_server = client_to_server_task => {
95 client_to_server?
96 },
97 server_to_client = server_to_client_task => {
98 server_to_client?
99 },
100 write_server = write_to_server => {
101 write_server?
102 },
103 write_client = write_to_client => {
104 write_client?
105 }
106 }
107 }
108}
109
110async fn process_message(
111 hooks: &HashMap<String, Arc<dyn Hook>>,
112 pending_requests: &Mutex<HashMap<i64, String>>,
113 message: Message,
114) -> Result<ProcessedMessage, HookError> {
115 match message {
116 Message::Request(request) => match hooks.get(&request.method) {
117 Some(hook) => {
118 pending_requests
119 .lock()
120 .await
121 .insert(request.id, request.method.clone());
122
123 Ok(hook.on_request(request).await?.as_processed())
124 }
125 None => Ok(ProcessedMessage::Forward(Message::Request(request))),
126 },
127 Message::Notification(notification) => match hooks.get(¬ification.method) {
128 Some(hook) => Ok(hook.on_notification(notification).await?.as_processed()),
129 None => Ok(ProcessedMessage::Forward(Message::Notification(
130 notification,
131 ))),
132 },
133 Message::Response(response) => {
134 let method = { pending_requests.lock().await.remove(&response.id) };
135
136 if let Some(method) = method
137 && let Some(hook) = hooks.get(&method)
138 {
139 return Ok(hook.on_response(response).await?.as_processed());
140 }
141
142 Ok(ProcessedMessage::Forward(Message::Response(response)))
143 }
144 }
145}
146
147impl Default for Proxy {
148 fn default() -> Self {
149 Self::new(HashMap::new())
150 }
151}
152
153async fn forward_to_server<R>(
154 hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
155 pending_requests: Arc<Mutex<HashMap<i64, String>>>,
156 mut client_reader: R,
157 server_message_sender: UnboundedSender<Message>,
158 client_message_sender: UnboundedSender<Message>,
159) -> std::io::Result<()>
160where
161 R: AsyncReadExt + Unpin,
162{
163 loop {
164 let message = match read_message(&mut client_reader).await {
165 Ok(msg) => Message::from_value(msg),
166 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
167 break;
168 }
169 Err(e) => return Err(e),
170 };
171
172 if let Ok(message) = message {
173 match process_message(&hooks, &pending_requests, message).await {
174 Ok(processed) => {
175 let (main_message, generated_messages) = processed.into_parts();
176
177 if let Some(main_message) = main_message
178 && server_message_sender.send(main_message).is_err()
179 {
180 return Err(std::io::Error::new(
181 std::io::ErrorKind::BrokenPipe,
182 "Message channel closed",
183 ));
184 }
185
186 for (direction, message) in generated_messages {
187 let result = match direction {
188 Direction::ToClient => client_message_sender.send(message),
189 Direction::ToServer => server_message_sender.send(message),
190 };
191
192 if result.is_err() {
193 return Err(std::io::Error::new(
194 std::io::ErrorKind::BrokenPipe,
195 "Notification channel closed",
196 ));
197 }
198 }
199 }
200 Err(e) => {
201 eprintln!("Error processing message: {}", e);
202 }
203 }
204 }
205 }
206
207 Ok(())
208}
209
210async fn forward_to_client<R>(
211 hooks: Arc<HashMap<String, Arc<dyn Hook>>>,
212 pending_requests: Arc<Mutex<HashMap<i64, String>>>,
213 mut server_reader: R,
214 server_message_sender: UnboundedSender<Message>,
215 client_message_sender: UnboundedSender<Message>,
216) -> std::io::Result<()>
217where
218 R: AsyncReadExt + Unpin,
219{
220 loop {
221 let message = match read_message(&mut server_reader).await {
222 Ok(msg) => Message::from_value(msg),
223 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
224 break;
225 }
226 Err(e) => return Err(e),
227 };
228
229 if let Ok(message) = message {
230 match process_message(&hooks, &pending_requests, message).await {
231 Ok(processed) => {
232 let (main_message, generated_messages) = processed.into_parts();
233
234 if let Some(main_message) = main_message
235 && client_message_sender.send(main_message).is_err()
236 {
237 return Err(std::io::Error::new(
238 std::io::ErrorKind::BrokenPipe,
239 "Message channel closed",
240 ));
241 }
242
243 for (direction, message) in generated_messages {
244 let result = match direction {
245 Direction::ToClient => client_message_sender.send(message),
246 Direction::ToServer => server_message_sender.send(message),
247 };
248
249 if result.is_err() {
250 return Err(std::io::Error::new(
251 std::io::ErrorKind::BrokenPipe,
252 "Notification channel closed",
253 ));
254 }
255 }
256 }
257 Err(e) => {
258 eprintln!("Error processing message: {}", e);
259 }
260 }
261 }
262 }
263
264 Ok(())
265}
266
267pub struct ProxyBuilder {
268 hooks: HashMap<String, Arc<dyn Hook>>,
269}
270
271impl ProxyBuilder {
272 pub fn new() -> Self {
273 Self {
274 hooks: HashMap::new(),
275 }
276 }
277
278 pub fn with_hook(mut self, method: &str, hook: Arc<dyn Hook>) -> Self {
279 self.hooks.insert(method.to_owned(), hook);
280 self
281 }
282
283 pub fn build(self) -> Proxy {
284 Proxy::new(self.hooks)
285 }
286}
287
288impl Default for ProxyBuilder {
289 fn default() -> Self {
290 Self::new()
291 }
292}