ombrac_server/connection/
mod.rs1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::io;
6use std::sync::Arc;
7
8use futures::{SinkExt, StreamExt};
9use tokio::task::JoinHandle;
10use tokio_util::codec::Framed;
11use tokio_util::sync::CancellationToken;
12#[cfg(feature = "tracing")]
13use tracing::Instrument;
14
15use ombrac::codec;
16use ombrac::protocol;
17use ombrac_macros::{debug, warn};
18use ombrac_transport::Connection;
19
20pub trait HandshakeValidator: Send + Sync {
21 fn validate_hello(&self, hello: &protocol::ClientHello)
22 -> Result<(), protocol::HandshakeError>;
23}
24
25impl HandshakeValidator for ombrac::protocol::Secret {
26 fn validate_hello(
27 &self,
28 hello: &protocol::ClientHello,
29 ) -> Result<(), protocol::HandshakeError> {
30 if &hello.secret == self {
31 Ok(())
32 } else {
33 Err(protocol::HandshakeError::InvalidSecret)
34 }
35 }
36}
37
38pub struct ClientConnection<C: Connection> {
39 client_connection: Arc<C>,
40 shutdown_token: CancellationToken,
41}
42
43impl<C: Connection> ClientConnection<C> {
44 pub async fn handle<V: HandshakeValidator>(connection: C, validator: &V) -> io::Result<()> {
45 let mut control_stream = connection.accept_bidirectional().await?;
46 let mut control_frame = Framed::new(&mut control_stream, codec::length_codec());
47
48 match control_frame.next().await {
49 Some(Ok(payload)) => {
50 let hello_message: codec::UpstreamMessage = protocol::decode(&payload)?;
51
52 if let codec::UpstreamMessage::Hello(hello) = &hello_message {
53 #[cfg(feature = "tracing")]
54 {
55 let secret_hex = hello
56 .secret
57 .iter()
58 .map(|b| format!("{:02x}", b))
59 .collect::<String>();
60 tracing::span::Span::current().record("secret", &secret_hex);
61 }
62
63 let response = if hello.version != protocol::PROTOCOLS_VERSION {
64 protocol::ServerHandshakeResponse::Err(
65 protocol::HandshakeError::UnsupportedVersion,
66 )
67 } else {
68 match validator.validate_hello(hello) {
69 Ok(_) => protocol::ServerHandshakeResponse::Ok,
70 Err(e) => protocol::ServerHandshakeResponse::Err(e),
71 }
72 };
73
74 let response_payload = protocol::encode(&response)?;
75 control_frame.send(response_payload.into()).await?;
76
77 if let protocol::ServerHandshakeResponse::Err(e) = response {
78 return Err(io::Error::new(
79 io::ErrorKind::PermissionDenied,
80 format!("handshake validation failed: {:?}", e),
81 ));
82 }
83 }
84 }
85 _ => {
86 return Err(io::Error::new(
87 io::ErrorKind::InvalidData,
88 "failed to read hello message",
89 ));
90 }
91 }
92
93 let handler = Self {
94 client_connection: Arc::new(connection),
95 shutdown_token: CancellationToken::new(),
96 };
97
98 handler.manage_acceptor_loops().await;
99
100 Ok(())
101 }
102
103 async fn manage_acceptor_loops(&self) {
104 let connect_acceptor = self.spawn_client_connect_acceptor();
105 #[cfg(feature = "datagram")]
106 let datagram_acceptor = self.spawn_client_datagram_acceptor();
107
108 #[cfg(not(feature = "datagram"))]
109 let result = connect_acceptor.await;
110
111 #[cfg(feature = "datagram")]
112 let result = tokio::select! {
113 res = connect_acceptor => res,
114 res = datagram_acceptor => res,
115 };
116
117 self.shutdown_token.cancel();
119
120 match result {
121 Ok(Ok(_)) => {
122 debug!("connection closed gracefully.");
123 }
124 Ok(Err(_err)) => {
125 debug!("connection closed with an error: {_err}");
126 }
127 Err(_err) => {
128 warn!("connection handler task failed: {_err}");
129 }
130 }
131 }
132
133 fn spawn_client_connect_acceptor(&self) -> JoinHandle<io::Result<()>> {
134 use crate::connection::stream::StreamTunnel;
135
136 let connection = Arc::clone(&self.client_connection);
137 let shutdown = self.shutdown_token.child_token();
138 let tunnel = StreamTunnel::new(connection, shutdown);
139
140 #[cfg(not(feature = "tracing"))]
141 let handle = tokio::spawn(tunnel.accept_loop());
142 #[cfg(feature = "tracing")]
143 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
144
145 handle
146 }
147
148 #[cfg(feature = "datagram")]
149 fn spawn_client_datagram_acceptor(&self) -> JoinHandle<io::Result<()>> {
150 use crate::connection::datagram::DatagramTunnel;
151
152 let connection = Arc::clone(&self.client_connection);
153 let shutdown = self.shutdown_token.child_token();
154 let tunnel = DatagramTunnel::new(connection, shutdown);
155
156 #[cfg(not(feature = "tracing"))]
157 let handle = tokio::spawn(tunnel.accept_loop());
158 #[cfg(feature = "tracing")]
159 let handle = tokio::spawn(tunnel.accept_loop().in_current_span());
160
161 handle
162 }
163}