1use std::{
2 io,
3 pin::{Pin, pin},
4 sync::{Arc, atomic::Ordering},
5 task::{Context, Poll},
6 time::Instant,
7};
8
9use ic_bn_lib_common::types::http::{Stats, TlsInfo};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio_rustls::{TlsAcceptor, server::TlsStream};
12
13pub mod listener;
14
15pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
17impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}
18
19pub async fn tls_handshake<T: AsyncReadWrite>(
21 rustls_cfg: Arc<rustls::ServerConfig>,
22 stream: T,
23) -> io::Result<(TlsStream<T>, TlsInfo)> {
24 let tls_acceptor = TlsAcceptor::from(rustls_cfg);
25
26 let start = Instant::now();
28 let stream = tls_acceptor.accept(stream).await?;
29 let duration = start.elapsed();
30
31 let conn = stream.get_ref().1;
33 let mut tls_info = TlsInfo::try_from(conn).map_err(io::Error::other)?;
34 tls_info.handshake_dur = duration;
35
36 Ok((stream, tls_info))
37}
38
39pub struct AsyncCounter<T: AsyncReadWrite> {
41 inner: T,
42 stats: Arc<Stats>,
43}
44
45impl<T: AsyncReadWrite> AsyncCounter<T> {
46 pub fn new(inner: T) -> (Self, Arc<Stats>) {
48 let stats = Arc::new(Stats::new());
49
50 (
51 Self {
52 inner,
53 stats: stats.clone(),
54 },
55 stats,
56 )
57 }
58}
59
60impl<T: AsyncReadWrite> AsyncRead for AsyncCounter<T> {
61 fn poll_read(
62 mut self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 buf: &mut ReadBuf<'_>,
65 ) -> Poll<io::Result<()>> {
66 let size_before = buf.filled().len();
67 let poll = pin!(&mut self.inner).poll_read(cx, buf);
68 if matches!(&poll, Poll::Ready(Ok(()))) {
69 let rcvd = buf.filled().len() - size_before;
70 self.stats.rcvd.fetch_add(rcvd as u64, Ordering::SeqCst);
71 }
72
73 poll
74 }
75}
76
77impl<T: AsyncReadWrite> AsyncWrite for AsyncCounter<T> {
78 fn poll_write(
79 mut self: Pin<&mut Self>,
80 cx: &mut Context<'_>,
81 buf: &[u8],
82 ) -> Poll<io::Result<usize>> {
83 let poll = pin!(&mut self.inner).poll_write(cx, buf);
84 if let Poll::Ready(Ok(v)) = &poll {
85 self.stats.sent.fetch_add(*v as u64, Ordering::SeqCst);
86 }
87
88 poll
89 }
90
91 fn poll_shutdown(
92 mut self: Pin<&mut Self>,
93 cx: &mut Context<'_>,
94 ) -> Poll<Result<(), io::Error>> {
95 pin!(&mut self.inner).poll_shutdown(cx)
96 }
97
98 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
99 pin!(&mut self.inner).poll_flush(cx)
100 }
101}