ombrac_server/connection/
mod.rs1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::io;
6use std::sync::Arc;
7
8use futures::{SinkExt, StreamExt};
9use tokio::task::JoinHandle;
10use tokio_util::codec::Framed;
11use tokio_util::sync::CancellationToken;
12#[cfg(feature = "tracing")]
13use tracing::Instrument;
14
15use ombrac::codec;
16use ombrac::protocol;
17use ombrac_macros::{debug, warn};
18use ombrac_transport::Connection;
19
20pub struct ConnectionHandle<C> {
21 inner: Arc<C>,
22}
23
24impl<C: Connection> ConnectionHandle<C> {
25 pub fn close(&self, error_code: u32, reason: &[u8]) {
26 self.inner.close(error_code, reason);
27 }
28}
29
30pub trait ConnectionHandler<T>: Send + Sync {
31 type Context: Send;
32
33 fn verify(
34 &self,
35 hello: &protocol::ClientHello,
36 ) -> Result<Self::Context, protocol::HandshakeError>;
37
38 fn accept(&self, output: Self::Context, connection: ConnectionHandle<T>);
39}
40
41impl<T> ConnectionHandler<T> for ombrac::protocol::Secret {
42 type Context = ();
43
44 fn verify(&self, hello: &protocol::ClientHello) -> Result<(), protocol::HandshakeError> {
45 if &hello.secret == self {
46 Ok(())
47 } else {
48 Err(protocol::HandshakeError::InvalidSecret)
49 }
50 }
51
52 fn accept(&self, _output: Self::Context, _connection: ConnectionHandle<T>) {
53 ()
54 }
55}
56
57pub struct ClientConnection<C: Connection> {
58 client_connection: Arc<C>,
59 shutdown_token: CancellationToken,
60}
61
62impl<C: Connection> ClientConnection<C> {
63 pub async fn handle<V>(connection: C, validator: &V) -> io::Result<()>
64 where
65 V: ConnectionHandler<C>,
66 {
67 let (validation_ctx, connection) = Self::perform_handshake(connection, validator).await?;
68
69 let client_connection = Arc::new(connection);
70
71 validator.accept(
72 validation_ctx,
73 ConnectionHandle {
74 inner: client_connection.clone(),
75 },
76 );
77
78 let handler = Self {
79 client_connection,
80 shutdown_token: CancellationToken::new(),
81 };
82
83 handler.run_acceptor_loops().await;
84
85 Ok(())
86 }
87
88 async fn perform_handshake<V>(connection: C, validator: &V) -> io::Result<(V::Context, C)>
89 where
90 V: ConnectionHandler<C>,
91 {
92 let mut control_stream = connection.accept_bidirectional().await?;
93 let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
94
95 let payload = match control_frame.next().await {
96 Some(Ok(bytes)) => bytes,
97 Some(Err(e)) => return Err(e),
98 None => {
99 return Err(io::Error::new(
100 io::ErrorKind::UnexpectedEof,
101 "Stream closed before hello",
102 ));
103 }
104 };
105
106 let message: codec::UpstreamMessage = protocol::decode(&payload)?;
107
108 let hello = match message {
109 codec::UpstreamMessage::Hello(h) => h,
110 _ => {
111 return Err(io::Error::new(
112 io::ErrorKind::InvalidData,
113 "Expected Hello message",
114 ));
115 }
116 };
117
118 #[cfg(feature = "tracing")]
119 Self::trace_handshake(&hello);
120
121 let validation_result = if hello.version != protocol::PROTOCOLS_VERSION {
122 Err(protocol::HandshakeError::UnsupportedVersion)
123 } else {
124 validator.verify(&hello)
125 };
126
127 let response = match validation_result {
128 Ok(_) => protocol::ServerHandshakeResponse::Ok,
129 Err(ref e) => protocol::ServerHandshakeResponse::Err(e.clone()),
130 };
131
132 control_frame.send(protocol::encode(&response)?).await?;
133
134 match validation_result {
135 Ok(ctx) => Ok((ctx, connection)),
136 Err(e) => Err(io::Error::new(
137 io::ErrorKind::PermissionDenied,
138 format!("Handshake failed: {:?}", e),
139 )),
140 }
141 }
142
143 async fn run_acceptor_loops(&self) {
144 let connect_acceptor = self.spawn_client_connect_acceptor();
145 #[cfg(feature = "datagram")]
146 let datagram_acceptor = self.spawn_client_datagram_acceptor();
147
148 #[cfg(not(feature = "datagram"))]
149 let result = connect_acceptor.await;
150
151 #[cfg(feature = "datagram")]
152 let result = tokio::select! {
153 res = connect_acceptor => res,
154 res = datagram_acceptor => res,
155 };
156
157 self.shutdown_token.cancel();
159
160 match result {
161 Ok(Ok(_)) => debug!("Connection closed gracefully."),
162 Ok(Err(e)) => debug!("Connection closed with internal error: {}", e),
163 Err(e) => warn!("Connection handler task panicked or failed: {}", e),
164 }
165 }
166
167 fn spawn_client_connect_acceptor(&self) -> JoinHandle<io::Result<()>> {
168 use crate::connection::stream::StreamTunnel;
169
170 let connection = Arc::clone(&self.client_connection);
171 let shutdown = self.shutdown_token.child_token();
172 let tunnel = StreamTunnel::new(connection, shutdown);
173
174 #[cfg(not(feature = "tracing"))]
175 let handle = tokio::spawn(tunnel.accept_loop());
176 #[cfg(feature = "tracing")]
177 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
178
179 handle
180 }
181
182 #[cfg(feature = "datagram")]
183 fn spawn_client_datagram_acceptor(&self) -> JoinHandle<io::Result<()>> {
184 use crate::connection::datagram::DatagramTunnel;
185
186 let connection = Arc::clone(&self.client_connection);
187 let shutdown = self.shutdown_token.child_token();
188 let tunnel = DatagramTunnel::new(connection, shutdown);
189
190 #[cfg(not(feature = "tracing"))]
191 let handle = tokio::spawn(tunnel.accept_loop());
192 #[cfg(feature = "tracing")]
193 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
194
195 handle
196 }
197
198 #[cfg(feature = "tracing")]
199 fn trace_handshake(hello: &protocol::ClientHello) {
200 let secret_hex = hello
201 .secret
202 .iter()
203 .map(|b| format!("{:02x}", b))
204 .collect::<String>();
205 tracing::Span::current().record("secret", &secret_hex);
206 }
207}