cdns_rs/a_sync/tokio_exc/
mod.rs1
2pub mod async_intrf;
3
4use std::io::ErrorKind;
5use std::{sync::Arc, time::Duration};
6
7
8use async_trait::async_trait;
9
10use tokio::net::{TcpStream, UdpSocket};
11use tokio::io::{AsyncWriteExt, AsyncReadExt};
12use tokio::time::timeout;
13use tokio::net::{TcpSocket};
14
15
16
17use crate::a_sync::network::SocketTap;
18use crate::network_common::SocketTapCommon;
19use crate::{internal_error, internal_error_map, CDnsErrorType};
20use crate::{a_sync::{network::{NetworkTap, NetworkTapType}, SocketTaps}, cfg_resolv_parser::ResolveConfEntry, CDnsResult};
21
22
23#[derive(Clone, Debug)]
24pub struct TokioSocketBase;
25
26impl SocketTaps<TokioSocketBase> for TokioSocketBase
27{
28 type TcpSock = TcpStream;
29
30 type UdpSock = UdpSocket;
31
32 #[cfg(feature = "use_async_tokio_tls")]
33 type TlsSock = self::with_tls::TcpTlsConnection;
34
35 #[inline]
36 fn new_tcp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
37 {
38 return NetworkTap::<Self::TcpSock, TokioSocketBase>::new(resolver, timeout)
39 }
40
41 #[inline]
42 fn new_udp_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
43 {
44 return NetworkTap::<Self::UdpSock, TokioSocketBase>::new(resolver, timeout)
45 }
46
47 #[cfg(feature = "use_async_tokio_tls")]
48 #[inline]
49 fn new_tls_socket(resolver: Arc<ResolveConfEntry>, timeout: Duration) -> CDnsResult<Box<NetworkTapType<TokioSocketBase>>>
50 {
51 return NetworkTap::<Self::TlsSock, TokioSocketBase>::new(resolver, timeout)
52 }
53}
54
55#[cfg(feature = "use_async_tokio_tls")]
56pub mod with_tls
57{
58 use std::io::ErrorKind;
59 use std::os::fd::{AsFd, BorrowedFd};
60 use std::sync::Arc;
61 use std::time::Duration;
62
63 use async_trait::async_trait;
64 use rustls::pki_types::ServerName;
65 use rustls::RootCertStore;
66 use tokio::io::{AsyncReadExt, AsyncWriteExt};
67 use tokio::net::TcpStream;
68 use tokio::time::timeout;
69 use tokio_rustls::client::TlsStream;
70
71 use crate::a_sync::network::{NetworkTap, SocketTap};
72 use crate::a_sync::tokio_exc::new_tcp_stream;
73 use crate::a_sync::TokioSocketBase;
74 use crate::cfg_resolv_parser::ResolveConfEntry;
75 use crate::network_common::SocketTapCommon;
76 use crate::{internal_error, internal_error_map, CDnsErrorType, CDnsResult};
77
78 #[derive(Debug)]
79 pub struct TcpTlsConnection
80 {
81 stream: TlsStream<TcpStream>,
82 }
83
84 impl AsFd for TcpTlsConnection
85 {
86 fn as_fd(&self) -> BorrowedFd<'_>
87 {
88 return self.stream.get_ref().0.as_fd();
89 }
90 }
91
92 impl TcpTlsConnection
93 { async
94 fn connect(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>) -> CDnsResult<Self>
95 {
96 let domain_name =
100 if let Some(domainname) = cfg.get_tls_domain()
101 {
102 ServerName::try_from(domainname.clone())
104 .map_err(|e|
105 internal_error_map!(CDnsErrorType::InternalError, "{}", e)
106 )?
107 }
108 else
109 {
110 internal_error!(CDnsErrorType::InternalError, "no domain is set for TLS conncection");
111 };
112
113 let config =
116 rustls
117 ::ClientConfig
118 ::builder_with_protocol_versions(&[&rustls::version::TLS12])
119 .with_root_certificates(RootCertStore{roots: webpki_roots::TLS_SERVER_ROOTS.into()})
120 .with_no_client_auth();
121
122 let conn =
123 tokio_rustls::TlsConnector::from(Arc::new(config));
124
125 let socket = new_tcp_stream(&cfg, conn_timeout).await?;
126
127
128 let mut stream_tls =
129 conn
130 .connect(domain_name, socket)
131 .await
132 .map_err(|e|
133 internal_error_map!(CDnsErrorType::IoError, "{}", e)
134 )?;
135
136 stream_tls.flush().await.map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
137
138 return Ok( Self{ stream: stream_tls } );
139 }
140
141 async
142 fn internal_poll_read(&self, timeout_dur: Duration) -> CDnsResult<()>
143 {
144 timeout(timeout_dur, self.stream.get_ref().0.readable())
145 .await
146 .map_err(|e|
147 internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
148 )?
149 .map_err(|e|
150 internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
151 )
152 }
153
154 }
155
156
157 #[async_trait]
158 impl SocketTap<TokioSocketBase> for NetworkTap<TcpTlsConnection, TokioSocketBase>
159 {
160 async
161 fn connect(&mut self, conn_timeout: Option<Duration>) -> CDnsResult<()>
162 {
163 if self.sock.is_some() == true
164 {
165 return Ok(());
167 }
168
169 let socket=
170 TcpTlsConnection::connect(self.cfg.as_ref(), conn_timeout).await?;
171
172 self.sock = Some(socket);
173
174 return Ok(());
175 }
176
177 fn is_encrypted(&self) -> bool
178 {
179 return true;
180 }
181
182 fn is_tcp(&self) -> bool
183 {
184 return true;
185 }
186
187 fn should_append_len(&self) -> bool
188 {
189 return true;
190 }
191
192 async
193 fn poll_read(&self) -> CDnsResult<()>
194 {
195 return self.sock.as_ref().unwrap().internal_poll_read(self.timeout).await;
196 }
197
198 async
199 fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
200 {
201 return
202 self
203 .sock
204 .as_mut()
205 .unwrap()
206 .stream
207 .write_all(sndbuf)
208 .await
209 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))
210 .map(|_| sndbuf.len());
211 }
212
213 async
214 fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
215 {
216 loop
217 {
218 match self.sock.as_mut().unwrap().stream.read(rcvbuf).await
219 {
220 Ok(n) =>
221 {
222 return Ok(n);
223 },
224 Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
225 {
226 internal_error!(CDnsErrorType::RequestTimeout, "request timeout from: '{}'", self.get_remote_addr());
227 },
228 Err(ref e) if e.kind() == ErrorKind::Interrupted =>
229 {
230 continue;
231 },
232 Err(e) =>
233 {
234 internal_error!(CDnsErrorType::IoError, "{}", e);
235 }
236 }
237 }
238 }
239 }
240} #[async_trait]
243impl SocketTap<TokioSocketBase> for NetworkTap<UdpSocket, TokioSocketBase>
244{
245 async
246 fn connect(&mut self, _conn_timeout: Option<Duration>) -> CDnsResult<()>
247 {
248 if self.sock.is_some() == true
249 {
250 return Ok(());
252 }
253
254 let socket =
255 UdpSocket::bind(self.cfg.get_adapter_ip())
256 .await
257 .map_err(|e| internal_error_map!(CDnsErrorType::InternalError, "{}", e))?;
258
259 socket.connect(self.cfg.get_resolver_sa())
260 .await
261 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
262
263 self.sock = Some(socket);
264
265 return Ok(());
266 }
267
268 fn is_encrypted(&self) -> bool
269 {
270 return false;
271 }
272
273 fn is_tcp(&self) -> bool
274 {
275 return false;
276 }
277
278 fn should_append_len(&self) -> bool
279 {
280 return false;
281 }
282
283 async
284 fn poll_read(&self) -> CDnsResult<()>
285 {
286 timeout(self.timeout, self.sock.as_ref().unwrap().readable())
287 .await
288 .map_err(|e|
289 internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
290 )?
291 .map_err(|e|
292 internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
293 )
294 }
295
296 async
297 fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
298 {
299 return
300 self.sock.as_mut()
301 .unwrap()
302 .send(sndbuf)
303 .await
304 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
305 }
306
307 async
308 fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
309 {
310 async
311 fn sub_recv(this: &mut NetworkTap<UdpSocket, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
312 {
313 loop
314 {
315 match this.sock.as_mut().unwrap().recv_from(rcvbuf).await
316 {
317 Ok((rcv_len, rcv_src)) =>
318 {
319 if &rcv_src != this.get_remote_addr()
321 {
322 internal_error!(
323 CDnsErrorType::DnsResponse,
324 "received answer from unknown host: '{}' exp: '{}'",
325 this.get_remote_addr(),
326 rcv_src
327 );
328 }
329
330 return Ok(rcv_len);
331 },
332 Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
333 {
334 continue;
335 },
336 Err(ref e) if e.kind() == ErrorKind::Interrupted =>
337 {
338 continue;
339 },
340 Err(e) =>
341 {
342 internal_error!(CDnsErrorType::IoError, "{}", e);
343 }
344 } } }
348
349 match timeout(self.timeout, sub_recv(self, rcvbuf)).await
351 {
352 Ok(r) =>
353 return r,
354 Err(e) =>
355 internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
356 }
357 }
358}
359
360async
361fn new_tcp_stream(cfg: &ResolveConfEntry, conn_timeout: Option<Duration>) -> CDnsResult<TcpStream>
362{
363 let socket =
365 if cfg.get_resolver_ip().is_ipv4() == true
366 {
367 TcpSocket::new_v4()
368 }
369 else
370 {
371 TcpSocket::new_v6()
372 }
373 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
374
375 socket.bind(*cfg.get_adapter_ip()).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
377
378 socket.set_keepalive(false).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
379
380 socket.set_nodelay(true).map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?;
381
382 let tcpstream =
384 if let Some(c_timeout) = conn_timeout
385 {
386 timeout(c_timeout, socket.connect(*cfg.get_resolver_sa()))
387 .await
388 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
389 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
390 }
391 else
392 {
393 socket
394 .connect(*cfg.get_resolver_sa())
395 .await
396 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e))?
397 };
398
399 return Ok(tcpstream);
400}
401
402#[async_trait]
403impl SocketTap<TokioSocketBase> for NetworkTap<TcpStream, TokioSocketBase>
404{
405 async
406 fn connect(&mut self, conn_timeout: Option<Duration>) -> CDnsResult<()>
407 {
408 if self.sock.is_some() == true
409 {
410 return Ok(());
412 }
413
414 let tcpstream = new_tcp_stream(&self.cfg, conn_timeout).await?;
416
417 self.sock = Some(tcpstream);
418
419 return Ok(());
420 }
421
422 fn is_encrypted(&self) -> bool
423 {
424 return false;
425 }
426
427 fn is_tcp(&self) -> bool
428 {
429 return true;
430 }
431
432 fn should_append_len(&self) -> bool
433 {
434 return true;
435 }
436
437 async
438 fn poll_read(&self) -> CDnsResult<()>
439 {
440 timeout(self.timeout, self.sock.as_ref().unwrap().readable())
441 .await
442 .map_err(|e|
443 internal_error_map!(CDnsErrorType::IoError, "Timeout {}", e)
444 )?
445 .map_err(|e|
446 internal_error_map!(CDnsErrorType::IoError, "socket poll error {}", e)
447 )
448 }
449
450 async
451 fn send(&mut self, sndbuf: &[u8]) -> CDnsResult<usize>
452 {
453 return
454 self
455 .sock
456 .as_mut()
457 .unwrap()
458 .write(sndbuf)
459 .await
460 .map_err(|e| internal_error_map!(CDnsErrorType::IoError, "{}", e));
461 }
462
463 async
464 fn recv(&mut self, rcvbuf: &mut [u8]) -> CDnsResult<usize>
465 {
466 async
467 fn sub_recv(this: &mut NetworkTap<TcpStream, TokioSocketBase>, rcvbuf: &mut [u8]) -> CDnsResult<usize>
468 {
469 loop
470 {
471 match this.sock.as_mut().unwrap().read(rcvbuf).await
472 {
473 Ok(n) =>
474 {
475 return Ok(n);
476 },
477 Err(ref e) if e.kind() == ErrorKind::WouldBlock =>
478 {
479 continue;
480 },
481 Err(ref e) if e.kind() == ErrorKind::Interrupted =>
482 {
483 continue;
484 },
485 Err(e) =>
486 {
487 internal_error!(CDnsErrorType::IoError, "{}", e);
488 }
489 } } }
492
493 match timeout(self.timeout, sub_recv(self, rcvbuf)).await
495 {
496 Ok(r) => return r,
497 Err(e) => internal_error!(CDnsErrorType::RequestTimeout, "{}", e)
498 }
499 }
500}
501
502
503
#[cfg(test)]
mod tests
{
    use std::{net::{IpAddr, SocketAddr}, sync::Arc, time::Duration};

    use tokio::net::UdpSocket;

    use crate::{a_sync::{network::NetworkTap, TokioSocketBase}, cfg_resolv_parser::ResolveConfEntry, common::IPV4_BIND_ALL};

    /// Constructing a UDP tap for a loopback resolver on port 53 must succeed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_struct()
    {
        let resolver_ip: IpAddr = "127.0.0.1".parse().unwrap();
        let adapter_sa = SocketAddr::from((IPV4_BIND_ALL, 0));
        let resolver_sa = SocketAddr::new(resolver_ip, 53);

        let entry = ResolveConfEntry::new(resolver_sa, None, adapter_sa).unwrap();

        let tap = NetworkTap::<UdpSocket, TokioSocketBase>::new(Arc::new(entry), Duration::from_secs(5));

        // On failure the construction error itself is printed.
        assert_eq!(tap.is_ok(), true, "{}", tap.err().unwrap());

        let _tap = tap.unwrap();
    }
}
527