cdns_rs/a_sync/tokio_exc/
mod.rs

1
2pub mod async_intrf;
3
4use std::io::ErrorKind;
5use std::{sync::Arc, time::Duration};
6
7
8use async_trait::async_trait;
9
10use tokio::net::{TcpStream, UdpSocket};
11use tokio::io::{AsyncWriteExt, AsyncReadExt};
12use tokio::time::timeout;
13use tokio::net::{TcpSocket};
14
15
16
17use crate::a_sync::network::SocketTap;
18use crate::network_common::SocketTapCommon;
19use crate::{internal_error, internal_error_map, CDnsErrorType};
20use crate::{a_sync::{network::{NetworkTap, NetworkTapType}, SocketTaps}, cfg_resolv_parser::ResolveConfEntry, CDnsResult};
21
22
/// Zero-sized marker type that selects the tokio-based socket implementations
/// (`TcpStream`, `UdpSocket` and, with the `use_async_tokio_tls` feature, TLS)
/// for the async resolver.
#[derive(Debug, Clone)]
pub struct TokioSocketBase;
25
26impl SocketTaps<TokioSocketBase> for TokioSocketBase
27{
28    type TcpSock = TcpStream;
29
30    type UdpSock = UdpSocket;
31
32    #[cfg(feature = "use_async_tokio_tls")]
33    type TlsSock = self::with_tls::TcpTlsConnection;
34    
35    #[inline]
36    fn new_tcp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
37    {
38        return NetworkTap::<Self::TcpSock, TokioSocketBase>::new(resolver, timeout)
39    }
40    
41    #[inline]
42    fn new_udp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
43    {
44        return NetworkTap::<Self::UdpSock, TokioSocketBase>::new(resolver, timeout)
45    }
46    
47    #[cfg(feature = "use_async_tokio_tls")]
48    #[inline]
49    fn new_tls_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>> 
50    {
51        return NetworkTap::<Self::TlsSock, TokioSocketBase>::new(resolver, timeout)
52    }
53}
54
#[cfg(feature = "use_async_tokio_tls")]
pub mod with_tls
{
    use std::io::ErrorKind;
    use std::os::fd::{AsFd, BorrowedFd};
    use std::sync::Arc;
    use std::time::Duration;

    use async_trait::async_trait;
    use rustls::pki_types::ServerName;
    use rustls::RootCertStore;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::time::timeout;
    use tokio_rustls::client::TlsStream;

    use crate::a_sync::network::{NetworkTap, SocketTap};
    use crate::a_sync::tokio_exc::new_tcp_stream;
    use crate::a_sync::TokioSocketBase;
    use crate::cfg_resolv_parser::ResolveConfEntry;
    use crate::network_common::SocketTapCommon;
    use crate::{internal_error, internal_error_map, CDnsErrorType, CDnsResult};

    /// A DNS-over-TLS client connection: a rustls client session layered on
    /// top of a tokio [`TcpStream`].
    #[derive(Debug)]
    pub struct TcpTlsConnection
    {
        stream: TlsStream<TcpStream>,
    }

    impl AsFd for TcpTlsConnection
    {
        /// Exposes the file descriptor of the underlying TCP socket.
        fn as_fd(&self) -> BorrowedFd<'_>
        {
            return self.stream.get_ref().0.as_fd();
        }
    }

    impl TcpTlsConnection
    {
        /// Opens a TCP connection to the resolver described by `cfg` and performs
        /// the TLS handshake. The server certificate is verified against the
        /// `webpki_roots` trust anchors and the name returned by `cfg.get_tls_domain()`.
        ///
        /// `conn_timeout`, when set, bounds the TCP connect and, separately, the
        /// TLS handshake.
        ///
        /// # Errors
        /// * `InternalError` if no TLS domain is configured or it is not a valid server name.
        /// * `IoError` if the TCP connect, the handshake or the post-handshake flush
        ///   fails or times out.
        async
        fn connect(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>) -> CDnsResult<Self>
        {
            let domain_name =
                if let Some(domainname) = cfg.get_tls_domain()
                {
                    ServerName::try_from(domainname.clone())
                        .map_err(|e|
                            internal_error_map!(CDnsErrorType::InternalError, "{}", e)
                        )?
                }
                else
                {
                    internal_error!(CDnsErrorType::InternalError, "no domain is set for TLS conncection");
                };

            // NOTE(review): only TLS 1.2 is offered here, so servers that speak only
            // TLS 1.3 cannot be reached. Confirm whether this restriction is intended.
            let config =
                rustls
                    ::ClientConfig
                    ::builder_with_protocol_versions(&[&rustls::version::TLS12])
                        .with_root_certificates(RootCertStore{roots: webpki_roots::TLS_SERVER_ROOTS.into()})
                        .with_no_client_auth();

            let conn =
                tokio_rustls::TlsConnector::from(Arc::new(config));

            let socket = new_tcp_stream(cfg, conn_timeout).await?;

            // The handshake is a separate network exchange. Bound it as well, so a peer
            // that accepts TCP but never answers the ClientHello cannot hang the caller.
            let handshake = conn.connect(domain_name, socket);

            let handshake_res =
                match conn_timeout
                {
                    Some(c_timeout) =>
                        timeout(c_timeout, handshake)
                            .await
                            .map_err(|e|
                                internal_error_map!(CDnsErrorType::IoError, "{}", e)
                            )?,
                    None =>
                        handshake.await,
                };

            let mut stream_tls =
                handshake_res
                    .map_err(|e|
                        internal_error_map!(CDnsErrorType::IoError, "{}", e)
                    )?;

            stream_tls.flush().await.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

            return Ok( Self{ stream: stream_tls } );
        }

        /// Waits up to `timeout_dur` for the underlying TCP socket to become readable.
        ///
        /// NOTE(review): this watches the raw TCP socket. Plaintext that rustls has
        /// already decrypted and buffered will not make the socket readable, so a
        /// caller relying on this before `recv` might see a spurious timeout. Confirm
        /// how `poll_read` is used.
        async
        fn internal_poll_read(&self, timeout_dur: Duration) -> CDnsResult<()>
        {
            timeout(timeout_dur, self.stream.get_ref().0.readable())
                .await
                .map_err(|e|
                    internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
                )?
                .map_err(|e|
                    internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
                )
        }
    }

    #[async_trait]
    impl SocketTap<TokioSocketBase> for NetworkTap<TcpTlsConnection, TokioSocketBase>
    {
        /// Establishes the TLS connection. Does nothing if one already exists.
        async
        fn connect(&mut self, conn_timeout: Option<Duration>) -> CDnsResult<()>
        {
            if self.sock.is_some() == true
            {
                // already connected
                return Ok(());
            }

            let socket =
                TcpTlsConnection::connect(self.cfg.as_ref(), conn_timeout).await?;

            self.sock = Some(socket);

            return Ok(());
        }

        fn is_encrypted(&self) -> bool
        {
            return true;
        }

        fn is_tcp(&self) -> bool
        {
            return true;
        }

        fn should_append_len(&self) -> bool
        {
            return true;
        }

        /// Waits up to `self.timeout` for the connection to become readable.
        /// Panics if `connect` has not succeeded yet.
        async
        fn poll_read(&self) -> CDnsResult<()>
        {
            return self.sock.as_ref().unwrap().internal_poll_read(self.timeout).await;
        }

        /// Writes the whole of `sndbuf` through the TLS session and flushes it,
        /// returning `sndbuf.len()`. Panics if `connect` has not succeeded yet.
        async
        fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
        {
            let tls = &mut self.sock.as_mut().unwrap().stream;

            tls
                .write_all(sndbuf)
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

            // tokio-rustls may keep encrypted records buffered when the socket is not
            // immediately writable. Flush so the query actually reaches the server
            // before the caller starts waiting for the answer.
            tls
                .flush()
                .await
                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;

            return Ok(sndbuf.len());
        }

        /// Reads decrypted data into `rcvbuf` and returns the number of bytes read.
        /// The read is bounded by `self.timeout`, like the UDP and TCP transports.
        /// Expiry is reported as `RequestTimeout`. Panics if `connect` has not
        /// succeeded yet.
        async
        fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
        {
            async
            fn sub_recv(this: &mut NetworkTap<TcpTlsConnection, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
            {
                loop
                {
                    match this.sock.as_mut().unwrap().stream.read(rcvbuf).await
                    {
                        Ok(n) =>
                        {
                            return Ok(n);
                        },
                        Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
                        {
                            internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", this.get_remote_addr());
                        },
                        Err(ref e) if e.kind() == ErrorKind::Interrupted =>
                        {
                            continue;
                        },
                        Err(e) =>
                        {
                            internal_error!(CDnsErrorType::IoError, "{}", e);
                        }
                    }
                }
            }

            // wait for timeout
            match timeout(self.timeout, sub_recv(self, rcvbuf)).await
            {
                Ok(r) => return r,
                Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
            }
        }
    }
} // with_tls
241
242#[async_trait]
243impl SocketTap<TokioSocketBase> for NetworkTap<UdpSocket, TokioSocketBase>
244{
245    async
246    fn connect(&mut self, _conn_timeout: Option<Duration>) -> CDnsResult<()>
247    {
248        if self.sock.is_some() == true
249        {
250            // ignore
251            return Ok(());
252        }
253
254        let socket = 
255            UdpSocket::bind(self.cfg.get_adapter_ip())
256                .await
257                .map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;
258
259        socket.connect(self.cfg.get_resolver_sa())
260            .await
261            .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
262
263        self.sock = Some(socket);
264
265        return Ok(());
266    }
267
268    fn is_encrypted(&self) -> bool 
269    {
270        return false;
271    }
272
273    fn is_tcp(&self) -> bool 
274    {
275        return false;
276    }
277
278    fn should_append_len(&self) -> bool
279    {
280        return false;
281    }
282
283    async 
284    fn poll_read(&self) -> CDnsResult<()>
285    {
286        timeout(self.timeout, self.sock.as_ref().unwrap().readable())
287            .await
288            .map_err(|e|
289                internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
290            )?
291            .map_err(|e|
292                internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
293            )
294    }
295
296    async
297    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize> 
298    {
299        return 
300            self.sock.as_mut()
301                .unwrap()
302                .send(sndbuf)
303                .await
304                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
305    }
306
307    async 
308    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
309    {
310        async 
311        fn sub_recv(this: &mut NetworkTap<UdpSocket, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
312        {
313            loop
314            {
315                match this.sock.as_mut().unwrap().recv_from(rcvbuf).await
316                {
317                    Ok((rcv_len, rcv_src)) =>
318                    {
319                        // this should not fail because socket is "connected"
320                        if &rcv_src != this.get_remote_addr()
321                        {
322                            internal_error!(
323                                CDnsErrorType::DnsResponse, 
324                                "received answer from unknown host: '{}' exp: '{}'", 
325                                this.get_remote_addr(), 
326                                rcv_src
327                            );
328                        }
329
330                        return Ok(rcv_len);
331                    },
332                    Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
333                    {
334                        continue;
335                    },
336                    Err(ref e) if e.kind() == ErrorKind::Interrupted =>
337                    {
338                        continue;
339                    },
340                    Err(e) =>
341                    {
342                        internal_error!(CDnsErrorType::IoError, "{}", e); 
343                    }
344                } // match
345            } // loop
346            
347        }
348
349        // wait for timeout
350        match timeout(self.timeout, sub_recv(self, rcvbuf)).await
351        {
352            Ok(r) => 
353                return r,
354            Err(e) => 
355                internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
356        }
357    }
358}
359
360async 
361fn new_tcp_stream(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>) -> CDnsResult<TcpStream> 
362{
363    // create socket
364    let socket = 
365        if cfg.get_resolver_ip().is_ipv4() == true
366        {
367            TcpSocket::new_v4()
368        }
369        else
370        {
371            TcpSocket::new_v6()
372        }
373        .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
374
375    // bind address
376    socket.bind(*cfg.get_adapter_ip()).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
377
378    socket.set_keepalive(false).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
379
380    socket.set_nodelay(true).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
381
382    // connect
383    let tcpstream = 
384        if let Some(c_timeout) = conn_timeout
385        {
386            timeout(c_timeout, socket.connect(*cfg.get_resolver_sa()))
387                .await
388                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
389                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
390        }
391        else
392        {
393            socket
394                .connect(*cfg.get_resolver_sa())
395                .await
396                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
397        };
398
399    return Ok(tcpstream);
400}
401
402#[async_trait]
403impl SocketTap<TokioSocketBase> for NetworkTap<TcpStream, TokioSocketBase>
404{
405    async 
406    fn connect(&mut self, conn_timeout: Option<Duration>) -> CDnsResult<()> 
407    {
408        if self.sock.is_some() == true
409        {
410            // ignore
411            return Ok(());
412        }
413
414        // create socket
415        let tcpstream = new_tcp_stream(&self.cfg, conn_timeout).await?;
416
417        self.sock = Some(tcpstream);
418
419        return Ok(());
420    }
421
422    fn is_encrypted(&self) -> bool 
423    {
424        return false;
425    }
426
427    fn is_tcp(&self) -> bool 
428    {
429        return true;
430    }
431
432    fn should_append_len(&self) -> bool
433    {
434        return true;
435    }
436
437    async 
438    fn poll_read(&self) -> CDnsResult<()>
439    {
440        timeout(self.timeout, self.sock.as_ref().unwrap().readable())
441            .await
442            .map_err(|e|
443                internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
444            )?
445            .map_err(|e|
446                internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
447            )
448    }
449
450    async 
451    fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>  
452    {
453        return 
454            self
455                .sock
456                .as_mut()
457                .unwrap()
458                .write(sndbuf)
459                .await
460                .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
461    }
462
463    async 
464    fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
465    {
466        async 
467        fn sub_recv(this: &mut NetworkTap<TcpStream, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize> 
468        {
469            loop
470            {
471                match this.sock.as_mut().unwrap().read(rcvbuf).await
472                {
473                    Ok(n) => 
474                    {
475                        return Ok(n);
476                    },
477                    Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
478                    {
479                        continue;
480                    },
481                    Err(ref e) if e.kind() == ErrorKind::Interrupted =>
482                    {
483                        continue;
484                    },
485                    Err(e) =>
486                    {
487                        internal_error!(CDnsErrorType::IoError, "{}", e); 
488                    }
489                } // match
490            } // loop
491        }
492
493        // wait for timeout
494        match timeout(self.timeout, sub_recv(self, rcvbuf)).await
495        {
496            Ok(r) => return r,
497            Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
498        }
499    }
500}
501
502
503
#[cfg(test)]
mod tests
{
    use std::{net::{IpAddr, SocketAddr}, sync::Arc, time::Duration};

    use tokio::net::UdpSocket;

    use crate::{a_sync::{network::NetworkTap, TokioSocketBase}, cfg_resolv_parser::ResolveConfEntry, common::IPV4_BIND_ALL};

    /// Constructing a UDP tap for a loopback resolver (port 53) bound to the
    /// IPv4 wildcard address must succeed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_struct()
    {
        let resolver_ip: IpAddr = "127.0.0.1".parse().unwrap();
        let resolver_sa = SocketAddr::new(resolver_ip, 53);
        let bind_sa = SocketAddr::from((IPV4_BIND_ALL, 0));

        let entry = ResolveConfEntry::new(resolver_sa, None, bind_sa).unwrap();
        let tap = NetworkTap::<UdpSocket, TokioSocketBase>::new(Arc::new(entry), Duration::from_secs(5));

        assert_eq!(tap.is_ok(), true, "{}", tap.err().unwrap());

        let _tap = tap.unwrap();
    }
}
527