ombrac_server/connection/
mod.rs1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::future::Future;
6use std::io;
7use std::sync::Arc;
8use std::sync::Weak;
9
10use futures::{SinkExt, StreamExt};
11use tokio::task::JoinHandle;
12use tokio_util::codec::Framed;
13use tokio_util::sync::CancellationToken;
14#[cfg(feature = "tracing")]
15use tracing::Instrument;
16
17use ombrac::codec;
18use ombrac::protocol;
19use ombrac_macros::{debug, warn};
20use ombrac_transport::Connection;
21
/// Handle to an accepted client connection, handed to
/// [`ConnectionHandler::accept`] once the handshake has succeeded.
///
/// The handle wraps a reference-counted pointer to the connection; the
/// connection itself is not exposed directly.
pub struct ConnectionHandle<C> {
    /// Reference-counted pointer to the underlying connection.
    inner: Arc<C>,
}
25
26impl<C: Connection> ConnectionHandle<C> {
27 pub fn downgrade_inner(&self) -> Weak<C> {
28 Arc::downgrade(&self.inner)
29 }
30
31 pub fn close(&self, error_code: u32, reason: &[u8]) {
32 self.inner.close(error_code, reason);
33 }
34}
35
36pub trait ConnectionHandler<T>: Send + Sync {
37 type Context: Send;
38
39 fn verify(
40 &self,
41 hello: &protocol::ClientHello,
42 ) -> impl Future<Output = Result<Self::Context, protocol::HandshakeError>> + Send;
43
44 fn accept(
45 &self,
46 output: Self::Context,
47 connection: ConnectionHandle<T>,
48 ) -> impl Future<Output = ()> + Send;
49}
50
51impl<T: Send + Sync> ConnectionHandler<T> for ombrac::protocol::Secret {
52 type Context = ();
53
54 async fn verify(&self, hello: &protocol::ClientHello) -> Result<(), protocol::HandshakeError> {
55 if &hello.secret == self {
56 Ok(())
57 } else {
58 Err(protocol::HandshakeError::InvalidSecret)
59 }
60 }
61
62 async fn accept(&self, _output: Self::Context, _connection: ConnectionHandle<T>) {
63 ()
64 }
65}
66
67pub struct ConnectionDriver<C: Connection> {
68 client_connection: Arc<C>,
69 shutdown_token: CancellationToken,
70}
71
72impl<C: Connection> ConnectionDriver<C> {
73 pub async fn handle<V>(connection: C, validator: &V) -> io::Result<()>
74 where
75 V: ConnectionHandler<C>,
76 {
77 let (validation_ctx, connection) = Self::perform_handshake(connection, validator).await?;
78
79 let client_connection = Arc::new(connection);
80
81 validator
82 .accept(
83 validation_ctx,
84 ConnectionHandle {
85 inner: client_connection.clone(),
86 },
87 )
88 .await;
89
90 let handler = Self {
91 client_connection,
92 shutdown_token: CancellationToken::new(),
93 };
94
95 handler.run_acceptor_loops().await;
96
97 Ok(())
98 }
99
100 async fn perform_handshake<V>(connection: C, validator: &V) -> io::Result<(V::Context, C)>
101 where
102 V: ConnectionHandler<C>,
103 {
104 let mut control_stream = connection.accept_bidirectional().await?;
105 let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
106
107 let payload = match control_frame.next().await {
108 Some(Ok(bytes)) => bytes,
109 Some(Err(e)) => return Err(e),
110 None => {
111 return Err(io::Error::new(
112 io::ErrorKind::UnexpectedEof,
113 "Stream closed before hello",
114 ));
115 }
116 };
117
118 let message: codec::UpstreamMessage = protocol::decode(&payload)?;
119
120 let hello = match message {
121 codec::UpstreamMessage::Hello(h) => h,
122 _ => {
123 return Err(io::Error::new(
124 io::ErrorKind::InvalidData,
125 "Expected Hello message",
126 ));
127 }
128 };
129
130 #[cfg(feature = "tracing")]
131 Self::trace_handshake(&hello);
132
133 let validation_result = if hello.version != protocol::PROTOCOLS_VERSION {
134 Err(protocol::HandshakeError::UnsupportedVersion)
135 } else {
136 validator.verify(&hello).await
137 };
138
139 let response = match validation_result {
140 Ok(_) => protocol::ServerHandshakeResponse::Ok,
141 Err(ref e) => protocol::ServerHandshakeResponse::Err(e.clone()),
142 };
143
144 control_frame.send(protocol::encode(&response)?).await?;
145
146 match validation_result {
147 Ok(ctx) => Ok((ctx, connection)),
148 Err(e) => Err(io::Error::new(
149 io::ErrorKind::PermissionDenied,
150 format!("Handshake failed: {:?}", e),
151 )),
152 }
153 }
154
155 async fn run_acceptor_loops(&self) {
156 let connect_acceptor = self.spawn_client_connect_acceptor();
157 #[cfg(feature = "datagram")]
158 let datagram_acceptor = self.spawn_client_datagram_acceptor();
159
160 #[cfg(not(feature = "datagram"))]
161 let result = connect_acceptor.await;
162
163 #[cfg(feature = "datagram")]
164 let result = tokio::select! {
165 res = connect_acceptor => res,
166 res = datagram_acceptor => res,
167 };
168
169 self.shutdown_token.cancel();
171
172 match result {
173 Ok(Ok(_)) => debug!("Connection closed gracefully."),
174 Ok(Err(e)) => debug!("Connection closed with internal error: {}", e),
175 Err(e) => warn!("Connection handler task panicked or failed: {}", e),
176 }
177 }
178
179 fn spawn_client_connect_acceptor(&self) -> JoinHandle<io::Result<()>> {
180 use crate::connection::stream::StreamTunnel;
181
182 let connection = Arc::clone(&self.client_connection);
183 let shutdown = self.shutdown_token.child_token();
184 let tunnel = StreamTunnel::new(connection, shutdown);
185
186 #[cfg(not(feature = "tracing"))]
187 let handle = tokio::spawn(tunnel.accept_loop());
188 #[cfg(feature = "tracing")]
189 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
190
191 handle
192 }
193
194 #[cfg(feature = "datagram")]
195 fn spawn_client_datagram_acceptor(&self) -> JoinHandle<io::Result<()>> {
196 use crate::connection::datagram::DatagramTunnel;
197
198 let connection = Arc::clone(&self.client_connection);
199 let shutdown = self.shutdown_token.child_token();
200 let tunnel = DatagramTunnel::new(connection, shutdown);
201
202 #[cfg(not(feature = "tracing"))]
203 let handle = tokio::spawn(tunnel.accept_loop());
204 #[cfg(feature = "tracing")]
205 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
206
207 handle
208 }
209
210 #[cfg(feature = "tracing")]
211 fn trace_handshake(hello: &protocol::ClientHello) {
212 let secret_hex = hello
213 .secret
214 .iter()
215 .map(|b| format!("{:02x}", b))
216 .collect::<String>();
217 tracing::Span::current().record("secret", &secret_hex);
218 }
219}