Skip to main content

ic_bn_lib/network/
mod.rs

1use std::{
2    io,
3    pin::{Pin, pin},
4    sync::{Arc, atomic::Ordering},
5    task::{Context, Poll},
6    time::Instant,
7};
8
9use ic_bn_lib_common::types::http::{Stats, TlsInfo};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio_rustls::{TlsAcceptor, server::TlsStream};
12
13pub mod listener;
14
15/// Blanket async read+write trait for streams `Box`-ing
16pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
17impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}
18
19/// Performs TLS handshake on the given stream
20pub async fn tls_handshake<T: AsyncReadWrite>(
21    rustls_cfg: Arc<rustls::ServerConfig>,
22    stream: T,
23) -> io::Result<(TlsStream<T>, TlsInfo)> {
24    let tls_acceptor = TlsAcceptor::from(rustls_cfg);
25
26    // Perform the TLS handshake
27    let start = Instant::now();
28    let stream = tls_acceptor.accept(stream).await?;
29    let duration = start.elapsed();
30
31    // Obtain TLS info
32    let conn = stream.get_ref().1;
33    let mut tls_info = TlsInfo::try_from(conn).map_err(io::Error::other)?;
34    tls_info.handshake_dur = duration;
35
36    Ok((stream, tls_info))
37}
38
39/// Async read+write wrapper that counts bytes read/written
40pub struct AsyncCounter<T: AsyncReadWrite> {
41    inner: T,
42    stats: Arc<Stats>,
43}
44
45impl<T: AsyncReadWrite> AsyncCounter<T> {
46    /// Create new `AsyncCounter`
47    pub fn new(inner: T) -> (Self, Arc<Stats>) {
48        let stats = Arc::new(Stats::new());
49
50        (
51            Self {
52                inner,
53                stats: stats.clone(),
54            },
55            stats,
56        )
57    }
58}
59
60impl<T: AsyncReadWrite> AsyncRead for AsyncCounter<T> {
61    fn poll_read(
62        mut self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &mut ReadBuf<'_>,
65    ) -> Poll<io::Result<()>> {
66        let size_before = buf.filled().len();
67        let poll = pin!(&mut self.inner).poll_read(cx, buf);
68        if matches!(&poll, Poll::Ready(Ok(()))) {
69            let rcvd = buf.filled().len() - size_before;
70            self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst);
71        }
72
73        poll
74    }
75}
76
77impl<T: AsyncReadWrite> AsyncWrite for AsyncCounter<T> {
78    fn poll_write(
79        mut self: Pin<&mut Self>,
80        cx: &mut Context<'_>,
81        buf: &[u8],
82    ) -> Poll<io::Result<usize>> {
83        let poll = pin!(&mut self.inner).poll_write(cx, buf);
84        if let Poll::Ready(Ok(v)) = &poll {
85            self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
86        }
87
88        poll
89    }
90
91    fn poll_shutdown(
92        mut self: Pin<&mut Self>,
93        cx: &mut Context<'_>,
94    ) -> Poll<Result<(), io::Error>> {
95        pin!(&mut self.inner).poll_shutdown(cx)
96    }
97
98    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
99        pin!(&mut self.inner).poll_flush(cx)
100    }
101}