ombrac_server/connection/mod.rs

1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::future::Future;
6use std::io;
7use std::sync::Arc;
8use std::sync::Weak;
9
10use futures::{SinkExt, StreamExt};
11use tokio::task::JoinHandle;
12use tokio_util::codec::Framed;
13use tokio_util::sync::CancellationToken;
14#[cfg(feature = "tracing")]
15use tracing::Instrument;
16
17use ombrac::codec;
18use ombrac::protocol;
19use ombrac_macros::{debug, warn};
20use ombrac_transport::Connection;
21
22pub struct ConnectionHandle<C> {
23    inner: Arc<C>,
24}
25
26impl<C: Connection> ConnectionHandle<C> {
27    pub fn downgrade_inner(&self) -> Weak<C> {
28        Arc::downgrade(&self.inner)
29    }
30
31    pub fn close(&self, error_code: u32, reason: &[u8]) {
32        self.inner.close(error_code, reason);
33    }
34}
35
36pub trait ConnectionHandler<T>: Send + Sync {
37    type Context: Send;
38
39    fn verify(
40        &self,
41        hello: &protocol::ClientHello,
42    ) -> impl Future<Output = Result<Self::Context, protocol::HandshakeError>> + Send;
43
44    fn accept(
45        &self,
46        output: Self::Context,
47        connection: ConnectionHandle<T>,
48    ) -> impl Future<Output = ()> + Send;
49}
50
51impl<T: Send + Sync> ConnectionHandler<T> for ombrac::protocol::Secret {
52    type Context = ();
53
54    async fn verify(&self, hello: &protocol::ClientHello) -> Result<(), protocol::HandshakeError> {
55        if &hello.secret == self {
56            Ok(())
57        } else {
58            Err(protocol::HandshakeError::InvalidSecret)
59        }
60    }
61
62    async fn accept(&self, _output: Self::Context, _connection: ConnectionHandle<T>) {
63        ()
64    }
65}
66
67pub struct ConnectionDriver<C: Connection> {
68    client_connection: Arc<C>,
69    shutdown_token: CancellationToken,
70}
71
72impl<C: Connection> ConnectionDriver<C> {
73    pub async fn handle<V>(connection: C, validator: &V) -> io::Result<()>
74    where
75        V: ConnectionHandler<C>,
76    {
77        let (validation_ctx, connection) = Self::perform_handshake(connection, validator).await?;
78
79        let client_connection = Arc::new(connection);
80
81        validator
82            .accept(
83                validation_ctx,
84                ConnectionHandle {
85                    inner: client_connection.clone(),
86                },
87            )
88            .await;
89
90        let handler = Self {
91            client_connection,
92            shutdown_token: CancellationToken::new(),
93        };
94
95        handler.run_acceptor_loops().await;
96
97        Ok(())
98    }
99
100    async fn perform_handshake<V>(connection: C, validator: &V) -> io::Result<(V::Context, C)>
101    where
102        V: ConnectionHandler<C>,
103    {
104        let mut control_stream = connection.accept_bidirectional().await?;
105        let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
106
107        let payload = match control_frame.next().await {
108            Some(Ok(bytes)) => bytes,
109            Some(Err(e)) => return Err(e),
110            None => {
111                return Err(io::Error::new(
112                    io::ErrorKind::UnexpectedEof,
113                    "Stream closed before hello",
114                ));
115            }
116        };
117
118        let message: codec::UpstreamMessage = protocol::decode(&payload)?;
119
120        let hello = match message {
121            codec::UpstreamMessage::Hello(h) => h,
122            _ => {
123                return Err(io::Error::new(
124                    io::ErrorKind::InvalidData,
125                    "Expected Hello message",
126                ));
127            }
128        };
129
130        #[cfg(feature = "tracing")]
131        Self::trace_handshake(&hello);
132
133        let validation_result = if hello.version != protocol::PROTOCOLS_VERSION {
134            Err(protocol::HandshakeError::UnsupportedVersion)
135        } else {
136            validator.verify(&hello).await
137        };
138
139        let response = match validation_result {
140            Ok(_) => protocol::ServerHandshakeResponse::Ok,
141            Err(ref e) => protocol::ServerHandshakeResponse::Err(e.clone()),
142        };
143
144        control_frame.send(protocol::encode(&response)?).await?;
145
146        match validation_result {
147            Ok(ctx) => Ok((ctx, connection)),
148            Err(e) => Err(io::Error::new(
149                io::ErrorKind::PermissionDenied,
150                format!("Handshake failed: {:?}", e),
151            )),
152        }
153    }
154
155    async fn run_acceptor_loops(&self) {
156        let connect_acceptor = self.spawn_client_connect_acceptor();
157        #[cfg(feature = "datagram")]
158        let datagram_acceptor = self.spawn_client_datagram_acceptor();
159
160        #[cfg(not(feature = "datagram"))]
161        let result = connect_acceptor.await;
162
163        #[cfg(feature = "datagram")]
164        let result = tokio::select! {
165            res = connect_acceptor => res,
166            res = datagram_acceptor => res,
167        };
168
169        // Signal all related tasks to shut down
170        self.shutdown_token.cancel();
171
172        match result {
173            Ok(Ok(_)) => debug!("Connection closed gracefully."),
174            Ok(Err(e)) => debug!("Connection closed with internal error: {}", e),
175            Err(e) => warn!("Connection handler task panicked or failed: {}", e),
176        }
177    }
178
179    fn spawn_client_connect_acceptor(&self) -> JoinHandle<io::Result<()>> {
180        use crate::connection::stream::StreamTunnel;
181
182        let connection = Arc::clone(&self.client_connection);
183        let shutdown = self.shutdown_token.child_token();
184        let tunnel = StreamTunnel::new(connection, shutdown);
185
186        #[cfg(not(feature = "tracing"))]
187        let handle = tokio::spawn(tunnel.accept_loop());
188        #[cfg(feature = "tracing")]
189        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
190
191        handle
192    }
193
194    #[cfg(feature = "datagram")]
195    fn spawn_client_datagram_acceptor(&self) -> JoinHandle<io::Result<()>> {
196        use crate::connection::datagram::DatagramTunnel;
197
198        let connection = Arc::clone(&self.client_connection);
199        let shutdown = self.shutdown_token.child_token();
200        let tunnel = DatagramTunnel::new(connection, shutdown);
201
202        #[cfg(not(feature = "tracing"))]
203        let handle = tokio::spawn(tunnel.accept_loop());
204        #[cfg(feature = "tracing")]
205        let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
206
207        handle
208    }
209
210    #[cfg(feature = "tracing")]
211    fn trace_handshake(hello: &protocol::ClientHello) {
212        let secret_hex = hello
213            .secret
214            .iter()
215            .map(|b| format!("{:02x}", b))
216            .collect::<String>();
217        tracing::Span::current().record("secret", &secret_hex);
218    }
219}