ice_rs/
proxy.rs

1use std::{collections::HashMap, hash::Hash};
2use std::sync::Arc;
3
4use task::JoinHandle;
5use tokio::{io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}, sync::Mutex, task};
6
7use crate::{errors::{ProtocolError, UserError}, protocol::{Header, MessageType}, proxy_factory::ProxyFactory, proxy_parser::{ProxyStringType, parse_proxy_string}, transport::Transport};
8use crate::protocol::{ReplyData, RequestData, Identity, Encapsulation};
9use crate::encoding::{ToBytes, FromBytes};
10
/// Grammar-derived parser for proxy strings, generated from the `proxystring.pest` grammar file
/// via the `#[derive(Parser)]` / `#[grammar = ...]` attributes below.
///
/// NOTE(review): not referenced anywhere in this file; proxy strings here are parsed with
/// `parse_proxy_string` from `crate::proxy_parser`. Confirm whether this duplicate is still needed.
#[derive(Parser)]
#[grammar = "proxystring.pest"]
pub struct ProxyParser;
14
/// A connection used to invoke operations on one remote identity.
///
/// Requests are written through `write`. A background reader task (spawned in `Proxy::new`)
/// decodes incoming messages into `message_queue`, where the `await_*` methods pick them up.
pub struct Proxy {
    /// Write half of the transport; requests and the close-connection message are sent here.
    pub write: WriteHalf<Box<dyn Transport + Send + Sync + Unpin>>,
    /// Id of the most recently created request; incremented by `create_request`.
    pub request_id: i32,
    /// Identity name used for dispatched requests and when rebuilding the proxy string.
    pub ident: String,
    /// Remote host, used when rebuilding the proxy string in `ice_context`.
    pub host: String,
    /// Remote port, used when rebuilding the proxy string in `ice_context`.
    pub port: i32,
    /// Default request context, used by `create_request` when no explicit context is given.
    pub context: Option<HashMap<String, String>>,
    /// Handle of the background reader task; aborted when the proxy is dropped.
    pub handle: Option<JoinHandle<Result<(), Box<dyn std::error::Error + Sync + Send>>>>,
    /// Messages decoded by the reader task that have not yet been claimed by a waiter.
    pub message_queue: Arc<Mutex<Vec<MessageType>>>,
    /// Transport type reported by `Transport::transport_type`, used as the endpoint
    /// protocol when the proxy string is rebuilt.
    pub stream_type: String
}
26
27
28impl Drop for Proxy {
29    fn drop(&mut self) {
30        tokio::task::block_in_place(|| {
31            futures::executor::block_on(async {
32                self.close_connection().await
33            })
34        }).expect("Could not close connection");
35        match &self.handle {
36            Some(handle) => handle.abort(),
37            None => {}
38        };
39    }
40}
41
42impl Proxy {
43    async fn read_thread(mut rx: ReadHalf<Box<dyn Transport + Send + Sync + Unpin>>, message_queue: Arc<Mutex<Vec<MessageType>>>) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
44        let mut buffer = vec![0; 2048];
45        loop {
46            let bytes = rx.read(&mut buffer).await?;
47            let mut read: i32 = 0;
48            let header = Header::from_bytes(&buffer[read as usize..bytes], &mut read)?;
49
50            let message = match header.message_type {
51                2 => {
52                    let reply = ReplyData::from_bytes(&buffer[read as usize..bytes as usize], &mut read)?;                    
53                    MessageType::Reply(header, reply)
54                }
55                3 => {
56                    MessageType::ValidateConnection(header)
57                },
58                _ => return Err(Box::new(ProtocolError::new(&format!("TCP: Unsuppored reply message type: {}", header.message_type))))
59            };
60
61            {
62                let mut lock = message_queue.lock().await;
63                lock.push(message);
64            }
65
66            std::thread::sleep(std::time::Duration::from_millis(1));
67        }
68    }
69
70    pub fn new(stream: Box<dyn Transport + Send + Sync + Unpin>, ident: &str, host: &str, port: i32, context: Option<HashMap<String, String>>) -> Proxy {
71        let stream_type = stream.transport_type();
72        let (rx, tx) = tokio::io::split(stream);
73        let mut proxy = Proxy {
74            write: tx,
75            request_id: 0,
76            ident: String::from(ident),
77            host: String::from(host),
78            port,
79            context: context,
80            handle: None,
81            message_queue: Arc::new(Mutex::new(Vec::new())),
82            stream_type
83        };
84        let message_queue = proxy.message_queue.clone();
85        proxy.handle = Some(task::spawn(async move {
86            Proxy::read_thread(rx, message_queue).await
87        }));
88
89        proxy
90    }
91
92    async fn close_connection(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
93    {
94        let header = Header::new(4, 14);
95        let mut bytes = header.to_bytes()?;
96        let written = self.write.write(&mut bytes).await?;
97        if written != header.message_size as usize {
98            return Err(Box::new(ProtocolError::new("TCP: Could not validate connection")))
99        }
100
101        Ok(())
102    }
103
104    pub async fn ice_context(&mut self, context: HashMap<String, String>) -> Result<Proxy, Box<dyn std::error::Error + Send + Sync>> {
105        let init_data = crate::communicator::INITDATA.lock().unwrap();
106        let proxy_string = format!("{}:{} -h {} -p {}", self.ident, self.stream_type, self.host, self.port);
107        match parse_proxy_string(&proxy_string)? {
108            ProxyStringType::DirectProxy(data) => {
109                ProxyFactory::create_proxy(data, init_data.properties(), Some(context)).await
110            }
111            _ => {
112                Err(Box::new(ProtocolError::new("ice_context() - could not create proxy")))
113            }
114        }
115    }
116
117    pub async fn dispatch<
118        T: 'static + std::fmt::Debug + std::fmt::Display + FromBytes + Send + Sync,
119    >(
120        &mut self,
121        op: &str,
122        mode: u8,
123        params: &Encapsulation,
124        context: Option<HashMap<String, String>>,
125    ) -> Result<ReplyData, Box<dyn std::error::Error + Send + Sync>> {
126        let id = String::from(self.ident.clone());
127        let req = self.create_request(&id, op, mode, params, context);
128        self.make_request::<T>(&req).await
129    }
130
131    pub fn create_request(&mut self, identity_name: &str, operation: &str, mode: u8, params: &Encapsulation, context: Option<HashMap<String, String>>) -> RequestData {
132        let context = match context {
133            Some(context) => context,
134            None => {
135                match self.context.as_ref() {
136                    Some(context) => context.clone(),
137                    None => HashMap::new()
138                }
139            }
140        };
141        self.request_id = self.request_id + 1;
142        RequestData {
143            request_id: self.request_id,
144            id: Identity::new(identity_name),
145            facet: Vec::new(),
146            operation: String::from(operation),
147            mode: mode,
148            context: context,
149            params: params.clone()
150        }
151    }
152
153    async fn send_request(&mut self, request: &RequestData) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
154        let req_bytes = request.to_bytes()?;
155        let header = Header::new(0, 14 + req_bytes.len() as i32);
156        let mut bytes = header.to_bytes()?;
157        bytes.extend(req_bytes);
158
159        let written = self.write.write(&mut bytes).await?;
160        if written != header.message_size as usize {
161            return Err(Box::new(ProtocolError::new(&format!("TCP: Error writing request {}", request.request_id))))
162        }
163        Ok(())
164    }
165
166    pub async fn await_validate_connection_message(&mut self) -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
167        let timeout = std::time::Duration::from_secs(30); // TODO: read from ice config
168        let now = std::time::Instant::now();
169
170        loop {
171            {
172                let mut lock = self.message_queue.lock().await;
173                let index = lock.iter().position(|i| {
174                    match i {
175                        MessageType::ValidateConnection(_) => true,
176                        _ => false
177                    }
178                });
179                match index {
180                    Some(index) => {
181                        lock.swap_remove(index);
182                        break;
183                    },
184                    None => {}
185                }
186            }
187
188            if now.elapsed() >= timeout {
189                return Err(Box::new(ProtocolError::new("Timeout waiting for response")));
190            }
191
192            std::thread::sleep(std::time::Duration::from_millis(1));
193        }
194        Ok(())
195    }
196
197    pub async fn await_reply_message(&mut self, request_id: i32) -> Result<MessageType, Box<dyn std::error::Error + Sync + Send>> {
198        let timeout = std::time::Duration::from_secs(30); // TODO: read from ice config
199        let now = std::time::Instant::now();
200
201        loop {
202            {
203                let mut lock = self.message_queue.lock().await;
204                let index = lock.iter().position(|i| {
205                    match i {
206                        MessageType::Reply(_, data) => {
207                            if data.request_id == request_id {
208                                true
209                            } else {
210                                false
211                            }
212                        },
213                        _ => false
214                    }
215                });
216                match index {
217                    Some(index) => {
218                        let result = lock.swap_remove(index);
219                        return Ok(result)
220                    },
221                    None => {}
222                }
223            }
224
225            if now.elapsed() >= timeout {
226                return Err(Box::new(ProtocolError::new("Timeout waiting for response")));
227            }
228
229            std::thread::sleep(std::time::Duration::from_millis(1));
230        }
231    }
232
233    async fn read_response<T: 'static + std::fmt::Debug + std::fmt::Display + FromBytes + Send + Sync>(&mut self, request_id: i32) -> Result<ReplyData, Box<dyn std::error::Error + Sync + Send>> {
234        let message = self.await_reply_message(request_id).await?;
235        match message {
236            MessageType::Reply(_header, reply) => {
237                match reply.status {
238                    1 => {
239                        let mut read = 0;
240                        Err(Box::new(UserError {
241                            exception: T::from_bytes(&reply.body.data, &mut read)?
242                        }))
243                    }
244                    _ => Ok(reply)
245                }
246            },
247            _ => Err(Box::new(ProtocolError::new(&format!("Unsupported message type: {:?}", message))))
248        }
249    }
250
251    pub async fn make_request<T: 'static + std::fmt::Debug + std::fmt::Display + FromBytes + Send + Sync>(&mut self, request: &RequestData) -> Result<ReplyData, Box<dyn std::error::Error + Sync + Send>>
252    {
253        self.send_request(request).await?;
254        self.read_response::<T>(request.request_id).await
255    }
256}