ombrac_client/connection/
mod.rs1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::future::Future;
6use std::io;
7use std::sync::Arc;
8use std::time::Duration;
9
10use arc_swap::{ArcSwap, Guard};
11use bytes::Bytes;
12use futures::{SinkExt, StreamExt};
13use tokio::io::AsyncWriteExt;
14use tokio::sync::Mutex;
15use tokio::time::Instant;
16use tokio_util::codec::Framed;
17
18use ombrac::codec::{ClientMessage, ServerMessage, length_codec};
19use ombrac::protocol::{
20 self, Address, ClientConnect, ClientHello, ConnectErrorKind, PROTOCOL_VERSION, Secret,
21 ServerAuthResponse, ServerConnectResponse,
22};
23use ombrac_macros::{error, warn};
24use ombrac_transport::{Connection, Initiator};
25
26pub use stream::BufferedStream;
27
28#[cfg(feature = "datagram")]
29pub use datagram::{UdpDispatcher, UdpSession};
30
/// Upper bound on the whole authentication handshake, including the transport `connect`.
const AUTH_TIMEOUT: Duration = Duration::from_secs(10);

/// Starting value of the reconnect backoff; restored after every successful reconnect.
const INITIAL_RECONNECT_BACKOFF: Duration = Duration::from_secs(1);

/// Ceiling for the reconnect backoff, which doubles after each failed attempt.
const MAX_RECONNECT_BACKOFF: Duration = Duration::from_secs(60);

42struct ReconnectState {
43 last_attempt: Option<Instant>,
44 backoff: Duration,
45}
46
47impl Default for ReconnectState {
48 fn default() -> Self {
49 Self {
50 last_attempt: None,
51 backoff: INITIAL_RECONNECT_BACKOFF,
52 }
53 }
54}
55
56pub struct ClientConnection<T, C>
61where
62 T: Initiator<Connection = C>,
63 C: Connection,
64{
65 transport: T,
66 connection: ArcSwap<C>,
67 reconnect_lock: Mutex<ReconnectState>,
68 secret: Secret,
69 options: Bytes,
70}
71
72impl<T, C> ClientConnection<T, C>
73where
74 T: Initiator<Connection = C>,
75 C: Connection,
76{
77 pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
81 let options = options.unwrap_or_default();
82 let connection = match authenticate(&transport, secret, options.clone()).await {
83 Ok(conn) => conn,
84 Err(err) => {
85 error!(
86 error = %err,
87 error_kind = ?err.kind(),
88 "failed to initialize connection"
89 );
90 return Err(err);
91 }
92 };
93
94 Ok(Self {
95 transport,
96 connection: ArcSwap::new(Arc::new(connection)),
97 reconnect_lock: Mutex::new(ReconnectState::default()),
98 secret,
99 options,
100 })
101 }
102
103 pub async fn open_bidirectional(
113 &self,
114 dest_addr: Address,
115 ) -> io::Result<BufferedStream<C::Stream>> {
116 let mut stream = self
117 .with_retry(|conn| async move { conn.open_bidirectional().await })
118 .await?;
119
120 let mut framed = Framed::new(&mut stream, length_codec());
122
123 let connect_message = ClientMessage::Connect(ClientConnect {
125 address: dest_addr.clone(),
126 });
127 framed.send(protocol::encode(&connect_message)?).await?;
128
129 let payload = match framed.next().await {
132 Some(Ok(payload)) => payload,
133 Some(Err(e)) => {
134 return Err(io::Error::new(
135 io::ErrorKind::InvalidData,
136 format!("failed to read server response: {}", e),
137 ));
138 }
139 None => {
140 return Err(io::Error::new(
141 io::ErrorKind::UnexpectedEof,
142 "stream closed before receiving server response",
143 ));
144 }
145 };
146
147 let message: ServerMessage = protocol::decode(&payload)?;
148 let response = match message {
149 ServerMessage::ConnectResponse(response) => response,
150 #[allow(unreachable_patterns)]
151 _ => {
152 return Err(io::Error::new(
153 io::ErrorKind::InvalidData,
154 "expected connect response message",
155 ));
156 }
157 };
158
159 let parts = framed.into_parts();
165
166 if !parts.write_buf.is_empty() {
168 return Err(io::Error::other(format!(
170 "write buffer not empty after send: {} bytes remaining - data may be lost",
171 parts.write_buf.len()
172 )));
173 }
174
175 let buffered_data = if !parts.read_buf.is_empty() {
179 Bytes::copy_from_slice(&parts.read_buf)
180 } else {
181 Bytes::new()
182 };
183
184 match response {
185 ServerConnectResponse::Ok => {
186 Ok(BufferedStream::new(stream, buffered_data))
189 }
190 ServerConnectResponse::Err { kind, message } => {
191 let error_kind = match kind {
193 ConnectErrorKind::ConnectionRefused => io::ErrorKind::ConnectionRefused,
194 ConnectErrorKind::NetworkUnreachable => io::ErrorKind::NetworkUnreachable,
195 ConnectErrorKind::HostUnreachable => io::ErrorKind::HostUnreachable,
196 ConnectErrorKind::TimedOut => io::ErrorKind::TimedOut,
197 ConnectErrorKind::Other => io::ErrorKind::Other,
198 };
199 Err(io::Error::new(error_kind, message))
200 }
201 }
202 }
203
204 pub fn connection(&self) -> Guard<Arc<C>> {
206 self.connection.load()
207 }
208
209 pub async fn rebind(&self) -> io::Result<()> {
211 self.transport.rebind().await
212 }
213
214 pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
225 where
226 F: Fn(Guard<Arc<C>>) -> Fut,
227 Fut: Future<Output = io::Result<R>>,
228 {
229 let connection = self.connection.load();
230 let old_conn_id = Arc::as_ptr(&connection) as usize;
232
233 match operation(connection).await {
234 Ok(result) => Ok(result),
235 Err(e) if is_connection_error(&e) => {
236 log_connection_error(
239 ErrorContext::new("with_retry")
240 .with_details("attempting to reconnect".to_string()),
241 &e,
242 );
243 self.reconnect(old_conn_id).await?;
245 let new_connection = self.connection.load();
247 operation(new_connection).await
248 }
249 Err(e) => Err(e),
250 }
251 }
252
253 async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
265 let mut state = self.reconnect_lock.lock().await;
266
267 let current_conn = self.connection.load();
269 let current_conn_id = Arc::as_ptr(¤t_conn) as usize;
270 if current_conn_id != old_conn_id {
271 return Ok(());
273 }
274
275 if let Some(last) = state.last_attempt {
277 let elapsed = last.elapsed();
278 if elapsed < state.backoff {
279 let wait_time = state.backoff - elapsed;
280 let backoff_secs = state.backoff.as_secs();
281 drop(state);
282 tokio::time::sleep(wait_time).await;
283 let err = io::Error::other("reconnect throttled");
284 log_reconnect_error(ErrorContext::new("reconnect"), &err, Some(backoff_secs));
285 return Err(err);
286 }
287 }
288
289 state.last_attempt = Some(Instant::now());
290
291 if let Err(e) = self.transport.rebind().await {
292 state.backoff = (state.backoff * 2).min(MAX_RECONNECT_BACKOFF);
293 log_reconnect_error(
294 ErrorContext::new("reconnect").with_details("transport rebind failed".to_string()),
295 &e,
296 Some(state.backoff.as_secs()),
297 );
298 return Err(e);
299 }
300
301 match authenticate(&self.transport, self.secret, self.options.clone()).await {
302 Ok(new_connection) => {
303 state.backoff = INITIAL_RECONNECT_BACKOFF;
304 state.last_attempt = None;
305
306 self.connection.store(Arc::new(new_connection));
307 Ok(())
308 }
309 Err(e) => {
310 state.backoff = (state.backoff * 2).min(MAX_RECONNECT_BACKOFF);
311 log_reconnect_error(
312 ErrorContext::new("reconnect")
313 .with_details("authentication failed".to_string()),
314 &e,
315 Some(state.backoff.as_secs()),
316 );
317 Err(e)
318 }
319 }
320 }
321}
322
323async fn authenticate<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
325where
326 T: Initiator<Connection = C>,
327 C: Connection,
328{
329 let do_auth = async {
330 let connection = transport.connect().await?;
331 let mut stream = connection.open_bidirectional().await?;
332
333 let hello_message = ClientMessage::Hello(ClientHello {
334 version: PROTOCOL_VERSION,
335 secret,
336 options,
337 });
338
339 let encoded_bytes = protocol::encode(&hello_message)?;
340 let mut framed = Framed::new(&mut stream, length_codec());
341
342 framed.send(encoded_bytes).await?;
343
344 match framed.next().await {
345 Some(Ok(payload)) => {
346 let response: ServerAuthResponse = protocol::decode(&payload)?;
347 match response {
348 ServerAuthResponse::Ok => {
349 stream.shutdown().await?;
350 Ok(connection)
351 }
352 ServerAuthResponse::Err => Err(io::Error::other("authentication failed")),
353 }
354 }
355 Some(Err(e)) => Err(e),
356 None => Err(io::Error::new(
357 io::ErrorKind::UnexpectedEof,
358 "connection closed by server during authentication",
359 )),
360 }
361 };
362
363 match tokio::time::timeout(AUTH_TIMEOUT, do_auth).await {
364 Ok(result) => result,
365 Err(_) => Err(io::Error::new(
366 io::ErrorKind::TimedOut,
367 format!(
368 "client authentication timed out after {}s",
369 AUTH_TIMEOUT.as_secs()
370 ),
371 )),
372 }
373}
374
/// Structured context attached to connection and reconnect log events.
struct ErrorContext {
    /// Name of the operation that hit the error (logged as `operation`).
    operation: &'static str,
    /// Optional free-form detail (logged as `details` when present).
    details: Option<String>,
}

impl ErrorContext {
    /// Creates a context for `operation` without details.
    fn new(operation: &'static str) -> Self {
        ErrorContext {
            details: None,
            operation,
        }
    }

    /// Returns this context with `details` attached, replacing any earlier details.
    fn with_details(self, details: String) -> Self {
        ErrorContext {
            details: Some(details),
            ..self
        }
    }
}

396fn log_connection_error(ctx: ErrorContext, err: &io::Error) {
401 if is_connection_error(err) {
402 warn!(
403 error = %err,
404 error_kind = ?err.kind(),
405 operation = ctx.operation,
406 details = ctx.details.as_deref(),
407 "connection error detected"
408 );
409 }
410}
411
412fn log_reconnect_error(ctx: ErrorContext, err: &io::Error, backoff_secs: Option<u64>) {
417 if err.kind() == io::ErrorKind::Other && err.to_string() == "reconnect throttled" {
418 warn!(
419 operation = ctx.operation,
420 backoff_secs = backoff_secs,
421 "reconnect throttled, too many attempts"
422 );
423 } else {
424 error!(
425 error = %err,
426 error_kind = ?err.kind(),
427 operation = ctx.operation,
428 backoff_secs = backoff_secs,
429 details = ctx.details.as_deref(),
430 "reconnection failed"
431 );
432 }
433}
434
/// Returns `true` if `e` means the underlying connection itself is unusable
/// (reset, aborted/closed, broken, timed out or unreachable). This is the kind
/// of failure that `with_retry` answers with a reconnect.
///
/// `ConnectionAborted` is included because a connection closed by the peer or
/// application is commonly surfaced with that kind (e.g. quinn maps
/// `ApplicationClosed` / `ConnectionClosed` to it). Without it, a closed
/// connection would never be replaced.
fn is_connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::BrokenPipe
            | io::ErrorKind::NotConnected
            | io::ErrorKind::TimedOut
            | io::ErrorKind::UnexpectedEof
            | io::ErrorKind::NetworkUnreachable
    )
}