1use std::{collections::HashMap, hash::Hash};
2use std::sync::Arc;
3
4use task::JoinHandle;
5use tokio::{io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, sync::Mutex, task};
6
7use crate::{errors::{ProtocolError, UserError}, protocol::{Header, MessageType}, proxy_factory::ProxyFactory, proxy_parser::{ProxyStringType, parse_proxy_string}, transport::Transport};
8use crate::protocol::{ReplyData, RequestData, Identity, Encapsulation};
9use crate::encoding::{ToBytes, FromBytes};
10
/// Parser type derived from the `proxystring.pest` grammar file (see the `#[grammar]` attribute).
///
/// NOTE(review): `Parser` is not imported in this file; presumably it is the `pest_derive`
/// derive macro made available at the crate root — confirm. The code visible here parses
/// proxy strings through `proxy_parser::parse_proxy_string` rather than this type directly.
#[derive(Parser)]
#[grammar = "proxystring.pest"]
pub struct ProxyParser;
14
/// A proxy that sends requests over a single transport connection.
///
/// `Proxy::new` splits the transport: the write half is stored here for sending, and the
/// read half is moved into a spawned reader task that pushes decoded messages onto
/// `message_queue`, where the `await_*` methods pick them up.
pub struct Proxy {
    /// Write half of the split transport; requests and the close message are written here.
    pub write: WriteHalf<Box<dyn Transport + Send + Sync + Unpin>>,
    /// Id of the most recently created request; advanced by `create_request`.
    pub request_id: i32,
    /// Identity name targeted by `dispatch`; also used to rebuild the proxy string.
    pub ident: String,
    /// Host this proxy was created for; used to rebuild the proxy string in `ice_context`.
    pub host: String,
    /// Port this proxy was created for; used to rebuild the proxy string in `ice_context`.
    pub port: i32,
    /// Default request context, used when a call does not supply its own.
    pub context: Option<HashMap<String, String>>,
    /// Handle of the background reader task spawned in `new`; aborted when the proxy is dropped.
    pub handle: Option<JoinHandle<Result<(), Box<dyn std::error::Error + Sync + Send>>>>,
    /// Messages decoded by the reader task, waiting to be removed by the `await_*` methods.
    pub message_queue: Arc<Mutex<Vec<MessageType>>>,
    /// Transport type name reported by the stream (`transport_type()`); part of the proxy string.
    pub stream_type: String
}
26
27
28impl Drop for Proxy {
29 fn drop(&mut self) {
30 tokio::task::block_in_place(|| {
31 futures::executor::block_on(async {
32 self.close_connection().await
33 })
34 }).expect("Could not close connection");
35 match &self.handle {
36 Some(handle) => handle.abort(),
37 None => {}
38 };
39 }
40}
41
42impl Proxy {
43 async fn read_thread(mut rx: ReadHalf<Box<dyn Transport + Send + Sync + Unpin>>, message_queue: Arc<Mutex<Vec<MessageType>>>) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
44 let mut buffer = vec![0; 2048];
45 loop {
46 let bytes = rx.read(&mut buffer).await?;
47 let mut read: i32 = 0;
48 let header = Header::from_bytes(&buffer[read as usize..bytes], &mut read)?;
49
50 let message = match header.message_type {
51 2 => {
52 let reply = ReplyData::from_bytes(&buffer[read as usize..bytes as usize], &mut read)?;
53 MessageType::Reply(header, reply)
54 }
55 3 => {
56 MessageType::ValidateConnection(header)
57 },
58 _ => return Err(Box::new(ProtocolError::new(&format!("TCP: Unsuppored reply message type: {}", header.message_type))))
59 };
60
61 {
62 let mut lock = message_queue.lock().await;
63 lock.push(message);
64 }
65
66 std::thread::sleep(std::time::Duration::from_millis(1));
67 }
68 }
69
70 pub fn new(stream: Box<dyn Transport + Send + Sync + Unpin>, ident: &str, host: &str, port: i32, context: Option<HashMap<String, String>>) -> Proxy {
71 let stream_type = stream.transport_type();
72 let (rx, tx) = tokio::io::split(stream);
73 let mut proxy = Proxy {
74 write: tx,
75 request_id: 0,
76 ident: String::from(ident),
77 host: String::from(host),
78 port,
79 context: context,
80 handle: None,
81 message_queue: Arc::new(Mutex::new(Vec::new())),
82 stream_type
83 };
84 let message_queue = proxy.message_queue.clone();
85 proxy.handle = Some(task::spawn(async move {
86 Proxy::read_thread(rx, message_queue).await
87 }));
88
89 proxy
90 }
91
92 async fn close_connection(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
93 {
94 let header = Header::new(4, 14);
95 let mut bytes = header.to_bytes()?;
96 let written = self.write.write(&mut bytes).await?;
97 if written != header.message_size as usize {
98 return Err(Box::new(ProtocolError::new("TCP: Could not validate connection")))
99 }
100
101 Ok(())
102 }
103
104 pub async fn ice_context(&mut self, context: HashMap<String, String>) -> Result<Proxy, Box<dyn std::error::Error + Send + Sync>> {
105 let init_data = crate::communicator::INITDATA.lock().unwrap();
106 let proxy_string = format!("{}:{} -h {} -p {}", self.ident, self.stream_type, self.host, self.port);
107 match parse_proxy_string(&proxy_string)? {
108 ProxyStringType::DirectProxy(data) => {
109 ProxyFactory::create_proxy(data, init_data.properties(), Some(context)).await
110 }
111 _ => {
112 Err(Box::new(ProtocolError::new("ice_context() - could not create proxy")))
113 }
114 }
115 }
116
117 pub async fn dispatch<
118 T: 'static + std::fmt::Debug + std::fmt::Display + FromBytes + Send + Sync,
119 >(
120 &mut self,
121 op: &str,
122 mode: u8,
123 params: &Encapsulation,
124 context: Option<HashMap<String, String>>,
125 ) -> Result<ReplyData, Box<dyn std::error::Error + Send + Sync>> {
126 let id = String::from(self.ident.clone());
127 let req = self.create_request(&id, op, mode, params, context);
128 self.make_request::<T>(&req).await
129 }
130
131 pub fn create_request(&mut self, identity_name: &str, operation: &str, mode: u8, params: &Encapsulation, context: Option<HashMap<String, String>>) -> RequestData {
132 let context = match context {
133 Some(context) => context,
134 None => {
135 match self.context.as_ref() {
136 Some(context) => context.clone(),
137 None => HashMap::new()
138 }
139 }
140 };
141 self.request_id = self.request_id + 1;
142 RequestData {
143 request_id: self.request_id,
144 id: Identity::new(identity_name),
145 facet: Vec::new(),
146 operation: String::from(operation),
147 mode: mode,
148 context: context,
149 params: params.clone()
150 }
151 }
152
153 async fn send_request(&mut self, request: &RequestData) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
154 let req_bytes = request.to_bytes()?;
155 let header = Header::new(0, 14 + req_bytes.len() as i32);
156 let mut bytes = header.to_bytes()?;
157 bytes.extend(req_bytes);
158
159 let written = self.write.write(&mut bytes).await?;
160 if written != header.message_size as usize {
161 return Err(Box::new(ProtocolError::new(&format!("TCP: Error writing request {}", request.request_id))))
162 }
163 Ok(())
164 }
165
166 pub async fn await_validate_connection_message(&mut self) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
167 let timeout = std::time::Duration::from_secs(30); let now = std::time::Instant::now();
169
170 loop {
171 {
172 let mut lock = self.message_queue.lock().await;
173 let index = lock.iter().position(|i| {
174 match i {
175 MessageType::ValidateConnection(_) => true,
176 _ => false
177 }
178 });
179 match index {
180 Some(index) => {
181 lock.swap_remove(index);
182 break;
183 },
184 None => {}
185 }
186 }
187
188 if now.elapsed() >= timeout {
189 return Err(Box::new(ProtocolError::new("Timeout waiting for response")));
190 }
191
192 std::thread::sleep(std::time::Duration::from_millis(1));
193 }
194 Ok(())
195 }
196
197 pub async fn await_reply_message(&mut self, request_id: i32) -> Result<MessageType, Box<dyn std::error::Error + Sync + Send>> {
198 let timeout = std::time::Duration::from_secs(30); let now = std::time::Instant::now();
200
201 loop {
202 {
203 let mut lock = self.message_queue.lock().await;
204 let index = lock.iter().position(|i| {
205 match i {
206 MessageType::Reply(_, data) => {
207 if data.request_id == request_id {
208 true
209 } else {
210 false
211 }
212 },
213 _ => false
214 }
215 });
216 match index {
217 Some(index) => {
218 let result = lock.swap_remove(index);
219 return Ok(result)
220 },
221 None => {}
222 }
223 }
224
225 if now.elapsed() >= timeout {
226 return Err(Box::new(ProtocolError::new("Timeout waiting for response")));
227 }
228
229 std::thread::sleep(std::time::Duration::from_millis(1));
230 }
231 }
232
233 async fn read_response<T: 'static + std::fmt::Debug + std::fmt::Display + FromBytes + Send + Sync>(&mut self, request_id: i32) -> Result<ReplyData, Box<dyn std::error::Error + Sync + Send>> {
234 let message = self.await_reply_message(request_id).await?;
235 match message {
236 MessageType::Reply(_header, reply) => {
237 match reply.status {
238 1 => {
239 let mut read = 0;
240 Err(Box::new(UserError {
241 exception: T::from_bytes(&reply.body.data, &mut read)?
242 }))
243 }
244 _ => Ok(reply)
245 }
246 },
247 _ => Err(Box::new(ProtocolError::new(&format!("Unsupported message type: {:?}", message))))
248 }
249 }
250
251 pub async fn make_request<T: 'static + std::fmt::Debug + std::fmt::Display + FromBytes + Send + Sync>(&mut self, request: &RequestData) -> Result<ReplyData, Box<dyn std::error::Error + Sync + Send>>
252 {
253 self.send_request(request).await?;
254 self.read_response::<T>(request.request_id).await
255 }
256}