ombrac_server/connection/
mod.rs

1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::io;
6use std::sync::Arc;
7
8use futures::{SinkExt, StreamExt};
9use tokio::task::JoinHandle;
10use tokio_util::codec::Framed;
11use tokio_util::sync::CancellationToken;
12#[cfg(feature = "tracing")]
13use tracing::Instrument;
14
15use ombrac::codec;
16use ombrac::protocol;
17use ombrac_macros::{debug, warn};
18use ombrac_transport::Connection;
19
20pub trait HandshakeValidator: Send + Sync {
21    fn validate_hello(&self, hello: &protocol::ClientHello)
22    -> Result<(), protocol::HandshakeError>;
23}
24
25impl HandshakeValidator for ombrac::protocol::Secret {
26    fn validate_hello(
27        &self,
28        hello: &protocol::ClientHello,
29    ) -> Result<(), protocol::HandshakeError> {
30        if &hello.secret == self {
31            Ok(())
32        } else {
33            Err(protocol::HandshakeError::InvalidSecret)
34        }
35    }
36}
37
38pub struct ClientConnection<C: Connection> {
39    client_connection: Arc<C>,
40    shutdown_token: CancellationToken,
41}
42
43impl<C: Connection> ClientConnection<C> {
44    pub async fn handle<V: HandshakeValidator>(connection: C, validator: &V) -> io::Result<()> {
45        let mut control_stream = connection.accept_bidirectional().await?;
46        let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
47
48        match control_frame.next().await {
49            Some(Ok(payload)) => {
50                let hello_message: codec::UpstreamMessage = protocol::decode(&payload)?;
51
52                if let codec::UpstreamMessage::Hello(hello) = &hello_message {
53                    #[cfg(feature = "tracing")]
54                    {
55                        let secret_hex = hello
56                            .secret
57                            .iter()
58                            .map(|b| format!("{:02x}", b))
59                            .collect::<String>();
60                        tracing::span::Span::current().record("secret", &secret_hex);
61                    }
62
63                    let response = if hello.version != protocol::PROTOCOLS_VERSION {
64                        protocol::ServerHandshakeResponse::Err(
65                            protocol::HandshakeError::UnsupportedVersion,
66                        )
67                    } else {
68                        match validator.validate_hello(hello) {
69                            Ok(_) => protocol::ServerHandshakeResponse::Ok,
70                            Err(e) => protocol::ServerHandshakeResponse::Err(e),
71                        }
72                    };
73
74                    let response_payload = protocol::encode(&response)?;
75                    control_frame.send(response_payload).await?;
76
77                    if let protocol::ServerHandshakeResponse::Err(e) = response {
78                        return Err(io::Error::new(
79                            io::ErrorKind::PermissionDenied,
80                            format!("handshake validation failed: {:?}", e),
81                        ));
82                    }
83                }
84            }
85            _ => {
86                return Err(io::Error::new(
87                    io::ErrorKind::InvalidData,
88                    "failed to read hello message",
89                ));
90            }
91        }
92
93        let handler = Self {
94            client_connection: Arc::new(connection),
95            shutdown_token: CancellationToken::new(),
96        };
97
98        handler.manage_acceptor_loops().await;
99
100        Ok(())
101    }
102
103    async fn manage_acceptor_loops(&self) {
104        let connect_acceptor = self.spawn_client_connect_acceptor();
105        #[cfg(feature = "datagram")]
106        let datagram_acceptor = self.spawn_client_datagram_acceptor();
107
108        #[cfg(not(feature = "datagram"))]
109        let result = connect_acceptor.await;
110
111        #[cfg(feature = "datagram")]
112        let result = tokio::select! {
113            res = connect_acceptor => res,
114            res = datagram_acceptor => res,
115        };
116
117        // Signal all related tasks to shut down
118        self.shutdown_token.cancel();
119
120        match result {
121            Ok(Ok(_)) => {
122                debug!("connection closed gracefully.");
123            }
124            Ok(Err(_err)) => {
125                debug!("connection closed with an error: {_err}");
126            }
127            Err(_err) => {
128                warn!("connection handler task failed: {_err}");
129            }
130        }
131    }
132
133    fn spawn_client_connect_acceptor(&self) -> JoinHandle<io::Result<()>> {
134        use crate::connection::stream::StreamTunnel;
135
136        let connection = Arc::clone(&self.client_connection);
137        let shutdown = self.shutdown_token.child_token();
138        let tunnel = StreamTunnel::new(connection, shutdown);
139
140        #[cfg(not(feature = "tracing"))]
141        let handle = tokio::spawn(tunnel.accept_loop());
142        #[cfg(feature = "tracing")]
143        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
144
145        handle
146    }
147
148    #[cfg(feature = "datagram")]
149    fn spawn_client_datagram_acceptor(&self) -> JoinHandle<io::Result<()>> {
150        use crate::connection::datagram::DatagramTunnel;
151
152        let connection = Arc::clone(&self.client_connection);
153        let shutdown = self.shutdown_token.child_token();
154        let tunnel = DatagramTunnel::new(connection, shutdown);
155
156        #[cfg(not(feature = "tracing"))]
157        let handle = tokio::spawn(tunnel.accept_loop());
158        #[cfg(feature = "tracing")]
159        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
160
161        handle
162    }
163}