// ombrac_server/connection/mod.rs

1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::io;
6use std::sync::Arc;
7
8use futures::{SinkExt, StreamExt};
9use tokio::task::JoinHandle;
10use tokio_util::codec::Framed;
11use tokio_util::sync::CancellationToken;
12#[cfg(feature = "tracing")]
13use tracing::Instrument;
14
15use ombrac::codec;
16use ombrac::protocol;
17use ombrac_macros::{debug, warn};
18use ombrac_transport::Connection;
19
/// A handle to an authenticated client connection, given to
/// [`ConnectionHandler::accept`] after a successful handshake.
///
/// It shares the connection through an `Arc`, so it refers to the same live
/// connection the server's acceptor tasks are serving, not to a copy.
#[derive(Debug)]
pub struct ConnectionHandle<C> {
    inner: Arc<C>,
}
23
24impl<C: Connection> ConnectionHandle<C> {
25    pub fn close(&self, error_code: u32, reason: &[u8]) {
26        self.inner.close(error_code, reason);
27    }
28}
29
30pub trait ConnectionHandler<T>: Send + Sync {
31    type Context: Send;
32
33    fn verify(
34        &self,
35        hello: &protocol::ClientHello,
36    ) -> Result<Self::Context, protocol::HandshakeError>;
37
38    fn accept(&self, output: Self::Context, connection: ConnectionHandle<T>);
39}
40
41impl<T> ConnectionHandler<T> for ombrac::protocol::Secret {
42    type Context = ();
43
44    fn verify(&self, hello: &protocol::ClientHello) -> Result<(), protocol::HandshakeError> {
45        if &hello.secret == self {
46            Ok(())
47        } else {
48            Err(protocol::HandshakeError::InvalidSecret)
49        }
50    }
51
52    fn accept(&self, _output: Self::Context, _connection: ConnectionHandle<T>) {
53        ()
54    }
55}
56
57pub struct ClientConnection<C: Connection> {
58    client_connection: Arc<C>,
59    shutdown_token: CancellationToken,
60}
61
62impl<C: Connection> ClientConnection<C> {
63    pub async fn handle<V>(connection: C, validator: &V) -> io::Result<()>
64    where
65        V: ConnectionHandler<C>,
66    {
67        let (validation_ctx, connection) = Self::perform_handshake(connection, validator).await?;
68
69        let client_connection = Arc::new(connection);
70
71        validator.accept(
72            validation_ctx,
73            ConnectionHandle {
74                inner: client_connection.clone(),
75            },
76        );
77
78        let handler = Self {
79            client_connection,
80            shutdown_token: CancellationToken::new(),
81        };
82
83        handler.run_acceptor_loops().await;
84
85        Ok(())
86    }
87
88    async fn perform_handshake<V>(connection: C, validator: &V) -> io::Result<(V::Context, C)>
89    where
90        V: ConnectionHandler<C>,
91    {
92        let mut control_stream = connection.accept_bidirectional().await?;
93        let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
94
95        let payload = match control_frame.next().await {
96            Some(Ok(bytes)) => bytes,
97            Some(Err(e)) => return Err(e),
98            None => {
99                return Err(io::Error::new(
100                    io::ErrorKind::UnexpectedEof,
101                    "Stream closed before hello",
102                ));
103            }
104        };
105
106        let message: codec::UpstreamMessage = protocol::decode(&payload)?;
107
108        let hello = match message {
109            codec::UpstreamMessage::Hello(h) => h,
110            _ => {
111                return Err(io::Error::new(
112                    io::ErrorKind::InvalidData,
113                    "Expected Hello message",
114                ));
115            }
116        };
117
118        #[cfg(feature = "tracing")]
119        Self::trace_handshake(&hello);
120
121        let validation_result = if hello.version != protocol::PROTOCOLS_VERSION {
122            Err(protocol::HandshakeError::UnsupportedVersion)
123        } else {
124            validator.verify(&hello)
125        };
126
127        let response = match validation_result {
128            Ok(_) => protocol::ServerHandshakeResponse::Ok,
129            Err(ref e) => protocol::ServerHandshakeResponse::Err(e.clone()),
130        };
131
132        control_frame.send(protocol::encode(&response)?).await?;
133
134        match validation_result {
135            Ok(ctx) => Ok((ctx, connection)),
136            Err(e) => Err(io::Error::new(
137                io::ErrorKind::PermissionDenied,
138                format!("Handshake failed: {:?}", e),
139            )),
140        }
141    }
142
143    async fn run_acceptor_loops(&self) {
144        let connect_acceptor = self.spawn_client_connect_acceptor();
145        #[cfg(feature = "datagram")]
146        let datagram_acceptor = self.spawn_client_datagram_acceptor();
147
148        #[cfg(not(feature = "datagram"))]
149        let result = connect_acceptor.await;
150
151        #[cfg(feature = "datagram")]
152        let result = tokio::select! {
153            res = connect_acceptor => res,
154            res = datagram_acceptor => res,
155        };
156
157        // Signal all related tasks to shut down
158        self.shutdown_token.cancel();
159
160        match result {
161            Ok(Ok(_)) => debug!("Connection closed gracefully."),
162            Ok(Err(e)) => debug!("Connection closed with internal error: {}", e),
163            Err(e) => warn!("Connection handler task panicked or failed: {}", e),
164        }
165    }
166
167    fn spawn_client_connect_acceptor(&self) -> JoinHandle<io::Result<()>> {
168        use crate::connection::stream::StreamTunnel;
169
170        let connection = Arc::clone(&self.client_connection);
171        let shutdown = self.shutdown_token.child_token();
172        let tunnel = StreamTunnel::new(connection, shutdown);
173
174        #[cfg(not(feature = "tracing"))]
175        let handle = tokio::spawn(tunnel.accept_loop());
176        #[cfg(feature = "tracing")]
177        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
178
179        handle
180    }
181
182    #[cfg(feature = "datagram")]
183    fn spawn_client_datagram_acceptor(&self) -> JoinHandle<io::Result<()>> {
184        use crate::connection::datagram::DatagramTunnel;
185
186        let connection = Arc::clone(&self.client_connection);
187        let shutdown = self.shutdown_token.child_token();
188        let tunnel = DatagramTunnel::new(connection, shutdown);
189
190        #[cfg(not(feature = "tracing"))]
191        let handle = tokio::spawn(tunnel.accept_loop());
192        #[cfg(feature = "tracing")]
193        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
194
195        handle
196    }
197
198    #[cfg(feature = "tracing")]
199    fn trace_handshake(hello: &protocol::ClientHello) {
200        let secret_hex = hello
201            .secret
202            .iter()
203            .map(|b| format!("{:02x}", b))
204            .collect::<String>();
205        tracing::Span::current().record("secret", &secret_hex);
206    }
207}