// geph5-client 0.2.102
//
// Geph5 client
// Documentation
use std::io;

use async_trait::async_trait;
use sillad::{Pipe, dialer::Dialer};

/// Wraps `inner` in a [`LoggedDialer`] that traces every dial attempt at
/// `debug` level.
///
/// - `name`: stage label attached to each event as the `stage` field.
/// - `route_subtree_json`: JSON text describing this stage's part of the
///   route; it is attached (as `route_subtree`) only to events for failed
///   dials.
/// - `inner`: the dialer whose result is logged and then passed through
///   unchanged.
pub fn logged<D: Dialer>(
    name: &'static str,
    route_subtree_json: String,
    inner: D,
) -> LoggedDialer<D> {
    LoggedDialer {
        inner,
        route_subtree_json,
        name,
    }
}

/// A dialer wrapper that emits a `debug`-level tracing event for each dial
/// attempt, tagged with a stage name.
///
/// Build one with [`logged`]. The `Dialer` bound lives on the trait impl
/// rather than on the struct definition, so the type itself places no
/// requirements on `D`.
pub struct LoggedDialer<D> {
    /// Stage label recorded as the `stage` field of every event.
    name: &'static str,
    /// JSON description of this stage's route, logged when a dial fails.
    route_subtree_json: String,
    /// The wrapped dialer that performs the actual connection.
    inner: D,
}

#[async_trait]
impl<D: Dialer> Dialer for LoggedDialer<D> {
    type P = D::P;

    /// Dials through the wrapped dialer and records the outcome as a `debug`
    /// event before returning it unchanged.
    ///
    /// On success the event carries `stage` and the pipe's `remote_addr()`;
    /// on failure it carries `stage`, `route_subtree` (the stored JSON) and
    /// the error's `Debug` representation as `err`.
    async fn dial(&self) -> io::Result<Self::P> {
        let outcome = self.inner.dial().await;
        match &outcome {
            Ok(pipe) => tracing::debug!(
                stage = self.name,
                remote_addr = pipe.remote_addr(),
                "dial stage succeeded"
            ),
            Err(err) => tracing::debug!(
                stage = self.name,
                route_subtree = %self.route_subtree_json,
                ?err,
                "dial stage failed"
            ),
        }
        outcome
    }
}

#[cfg(test)]
mod tests {
    use std::{
        io,
        pin::Pin,
        sync::{Arc, Mutex},
        task::{Context, Poll},
    };

    use async_trait::async_trait;
    use futures_util::{AsyncRead, AsyncWrite};
    use tracing_subscriber::layer::SubscriberExt as _;

    use super::*;

    /// Minimal in-memory pipe: reads return EOF immediately, writes report
    /// zero bytes written, and only `remote_addr` varies between tests.
    #[derive(Clone)]
    struct TestPipe {
        remote_addr: Option<&'static str>,
    }

    impl AsyncRead for TestPipe {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(0))
        }
    }

    impl AsyncWrite for TestPipe {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(0))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl Pipe for TestPipe {
        fn protocol(&self) -> &str {
            "test"
        }

        fn remote_addr(&self) -> Option<&str> {
            self.remote_addr
        }
    }

    /// Dialer that always succeeds with a `TestPipe` reporting `remote_addr`.
    struct OkDialer {
        remote_addr: Option<&'static str>,
    }

    #[async_trait]
    impl Dialer for OkDialer {
        type P = TestPipe;

        async fn dial(&self) -> io::Result<Self::P> {
            Ok(TestPipe {
                remote_addr: self.remote_addr,
            })
        }
    }

    /// Dialer that always fails with an `io::Error` whose message is "boom".
    struct ErrDialer;

    #[async_trait]
    impl Dialer for ErrDialer {
        type P = TestPipe;

        async fn dial(&self) -> io::Result<Self::P> {
            Err(io::Error::other("boom"))
        }
    }

    /// `io::Write` sink that appends everything into a shared buffer so the
    /// test can read back what the fmt layer wrote.
    #[derive(Clone, Default)]
    struct TestWriter {
        buf: Arc<Mutex<Vec<u8>>>,
    }

    impl io::Write for TestWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.buf.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Runs `f` with a subscriber installed as the default for the duration
    /// of the call, formatting events (no timestamps, no ANSI) into an
    /// in-memory buffer. Returns the captured text alongside `f`'s result.
    fn capture_logs<T>(f: impl FnOnce() -> T) -> (String, T) {
        let writer = TestWriter::default();
        let captured = writer.buf.clone();
        let subscriber = tracing_subscriber::registry().with(
            tracing_subscriber::fmt::layer()
                .without_time()
                .with_ansi(false)
                .compact()
                .with_writer(move || writer.clone()),
        );

        let result = tracing::subscriber::with_default(subscriber, f);
        let logs = String::from_utf8(captured.lock().unwrap().clone()).unwrap();
        (logs, result)
    }

    /// Number of non-overlapping occurrences of `needle` in `logs`.
    fn occurrences(logs: &str, needle: &str) -> usize {
        logs.matches(needle).count()
    }

    #[test]
    fn logged_dialer_logs_success() {
        let (logs, result) = capture_logs(|| {
            let dialer = logged(
                "tcp",
                r#"{"tcp":"127.0.0.1:9000"}"#.to_string(),
                OkDialer {
                    remote_addr: Some("127.0.0.1:9000"),
                },
            );
            smolscale::block_on(async move { dialer.dial().await })
        });

        assert!(result.is_ok());
        assert_eq!(occurrences(&logs, "dial stage succeeded"), 1);
        assert!(logs.contains("tcp"));
        assert!(logs.contains("127.0.0.1:9000"));
        // The route subtree is only attached to failure events.
        assert!(!logs.contains("route_subtree="));
    }

    #[test]
    fn logged_dialer_success_without_remote_addr() {
        let (logs, result) = capture_logs(|| {
            let dialer = logged(
                "tcp",
                r#"{"tcp":"127.0.0.1:9000"}"#.to_string(),
                OkDialer { remote_addr: None },
            );
            smolscale::block_on(async move { dialer.dial().await })
        });

        assert!(result.is_ok());
        assert_eq!(occurrences(&logs, "dial stage succeeded"), 1);
        // A `None` field value is not recorded, so the field is omitted.
        assert!(!logs.contains("remote_addr="));
    }

    #[test]
    fn logged_dialer_logs_failure() {
        let (logs, result) = capture_logs(|| {
            let dialer = logged(
                "tls",
                r#"{"plain_tls":{"lower":"..."}}"#.to_string(),
                ErrDialer,
            );
            smolscale::block_on(async move { dialer.dial().await })
        });

        assert!(result.is_err());
        assert_eq!(occurrences(&logs, "dial stage failed"), 1);
        assert!(logs.contains("tls"));
        assert!(logs.contains(r#"{"plain_tls":{"lower":"..."}}"#));
        assert!(logs.contains("boom"));
    }

    #[test]
    fn nested_logged_dialers_cascade_failures() {
        let (logs, result) = capture_logs(|| {
            let dialer = logged(
                "overall",
                r#"{"overall":true}"#.to_string(),
                logged(
                    "sosistab3",
                    r#"{"sosistab3":{"lower":"..."}}"#.to_string(),
                    logged("tcp", r#"{"tcp":"127.0.0.1:9000"}"#.to_string(), ErrDialer),
                ),
            );
            smolscale::block_on(async move { dialer.dial().await })
        });

        assert!(result.is_err());
        assert!(logs.contains("tcp"));
        assert!(logs.contains("sosistab3"));
        assert!(logs.contains("overall"));
        // Every layer logs the same propagated error exactly once.
        assert_eq!(occurrences(&logs, "dial stage failed"), 3);
        assert_eq!(occurrences(&logs, "boom"), 3);

        // Each layer's route JSON is unique, so its position orders the
        // events: the innermost stage fails (and logs) first.
        let tcp_at = logs
            .find(r#"{"tcp":"127.0.0.1:9000"}"#)
            .expect("tcp stage should log its route subtree");
        let sosistab3_at = logs
            .find(r#"{"sosistab3":{"lower":"..."}}"#)
            .expect("sosistab3 stage should log its route subtree");
        let overall_at = logs
            .find(r#"{"overall":true}"#)
            .expect("overall stage should log its route subtree");
        assert!(tcp_at < sosistab3_at);
        assert!(sosistab3_at < overall_at);
    }

    #[test]
    fn nested_logged_dialers_log_layered_success() {
        let (logs, result) = capture_logs(|| {
            let dialer = logged(
                "overall",
                r#"{"overall":true}"#.to_string(),
                logged(
                    "sosistab3",
                    r#"{"sosistab3":{"lower":"..."}}"#.to_string(),
                    logged(
                        "tcp",
                        r#"{"tcp":"127.0.0.1:9000"}"#.to_string(),
                        OkDialer {
                            remote_addr: Some("127.0.0.1:9000"),
                        },
                    ),
                ),
            );
            smolscale::block_on(async move { dialer.dial().await })
        });

        assert!(result.is_ok());
        assert!(logs.contains("tcp"));
        assert!(logs.contains("sosistab3"));
        assert!(logs.contains("overall"));
        // Each layer logs once, and each sees the same pipe's remote address.
        assert_eq!(occurrences(&logs, "dial stage succeeded"), 3);
        assert_eq!(occurrences(&logs, "127.0.0.1:9000"), 3);
        assert!(!logs.contains("route_subtree="));
    }
}