ombrac_client/connection/
mod.rs

1#[cfg(feature = "datagram")]
2mod datagram;
3mod stream;
4
5use std::future::Future;
6use std::io;
7use std::sync::Arc;
8use std::time::Duration;
9
10use arc_swap::{ArcSwap, Guard};
11use bytes::Bytes;
12use futures::{SinkExt, StreamExt};
13use tokio::io::AsyncWriteExt;
14use tokio::sync::Mutex;
15use tokio::time::Instant;
16use tokio_util::codec::Framed;
17
18use ombrac::codec::{ClientMessage, ServerMessage, length_codec};
19use ombrac::protocol::{
20    self, Address, ClientConnect, ClientHello, ConnectErrorKind, PROTOCOL_VERSION, Secret,
21    ServerAuthResponse, ServerConnectResponse,
22};
23use ombrac_macros::{error, warn};
24use ombrac_transport::{Connection, Initiator};
25
26pub use stream::BufferedStream;
27
28#[cfg(feature = "datagram")]
29pub use datagram::{UdpDispatcher, UdpSession};
30
// --- Authentication & Connection ---
/// Timeout for the initial authentication with the server [default: 10 seconds]
///
/// `authenticate` applies this to the whole handshake: establishing the transport
/// connection, opening the stream, sending the hello and waiting for the reply.
const AUTH_TIMEOUT: Duration = Duration::from_secs(10);

// --- Reconnection Strategy ---
/// Initial backoff duration for reconnection attempts [default: 1 second]
///
/// Also the value the backoff is reset to after a successful reconnect.
const INITIAL_RECONNECT_BACKOFF: Duration = Duration::from_secs(1);

/// Maximum backoff duration for reconnection attempts [default: 60 seconds]
///
/// The backoff is doubled after each failed attempt and capped at this value.
const MAX_RECONNECT_BACKOFF: Duration = Duration::from_secs(60);
41
/// Throttling bookkeeping for reconnection attempts.
///
/// Stored behind `ClientConnection::reconnect_lock`, so it is only read or
/// modified by the task that currently holds that lock.
struct ReconnectState {
    /// Start time of the most recent reconnection attempt. `None` until the first
    /// attempt, and cleared again after a successful reconnect.
    last_attempt: Option<Instant>,
    /// Minimum time that must elapse since `last_attempt` before another attempt
    /// is made. Doubled (up to `MAX_RECONNECT_BACKOFF`) after a failure and reset
    /// to `INITIAL_RECONNECT_BACKOFF` after a success.
    backoff: Duration,
}
46
47impl Default for ReconnectState {
48    fn default() -> Self {
49        Self {
50            last_attempt: None,
51            backoff: INITIAL_RECONNECT_BACKOFF,
52        }
53    }
54}
55
/// Manages the connection to the server, including authentication and reconnection logic.
///
/// This struct handles the lifecycle of a connection to the server, including
/// initial authentication, automatic reconnection on failures, and connection state management.
pub struct ClientConnection<T, C>
where
    T: Initiator<Connection = C>,
    C: Connection,
{
    /// Transport used to establish connections and to rebind before reconnecting.
    transport: T,
    /// The currently active, authenticated connection. Replaced with a new
    /// connection when a reconnect succeeds.
    connection: ArcSwap<C>,
    /// Serializes reconnection attempts and holds the backoff bookkeeping.
    reconnect_lock: Mutex<ReconnectState>,
    /// Secret sent in the `ClientHello` each time the client authenticates.
    secret: Secret,
    /// Option bytes sent in the `ClientHello`; empty when the caller supplied none.
    options: Bytes,
}
71
72impl<T, C> ClientConnection<T, C>
73where
74    T: Initiator<Connection = C>,
75    C: Connection,
76{
77    /// Creates a new `ClientConnection` and establishes a connection to the server.
78    ///
79    /// This involves performing authentication with the server.
80    pub async fn new(transport: T, secret: Secret, options: Option<Bytes>) -> io::Result<Self> {
81        let options = options.unwrap_or_default();
82        let connection = match authenticate(&transport, secret, options.clone()).await {
83            Ok(conn) => conn,
84            Err(err) => {
85                error!(
86                    error = %err,
87                    error_kind = ?err.kind(),
88                    "failed to initialize connection"
89                );
90                return Err(err);
91            }
92        };
93
94        Ok(Self {
95            transport,
96            connection: ArcSwap::new(Arc::new(connection)),
97            reconnect_lock: Mutex::new(ReconnectState::default()),
98            secret,
99            options,
100        })
101    }
102
103    /// Opens a new bidirectional stream for TCP-like communication.
104    ///
105    /// This method negotiates a new stream with the server, which will then
106    /// connect to the specified destination address. It waits for the server's
107    /// connection response before returning, ensuring proper TCP state handling.
108    ///
109    /// The returned stream is wrapped in a `BufferedStream` to ensure that any
110    /// data remaining in the protocol framing buffer is read first, preventing
111    /// data loss when transitioning from message-based to raw stream communication.
112    pub async fn open_bidirectional(
113        &self,
114        dest_addr: Address,
115    ) -> io::Result<BufferedStream<C::Stream>> {
116        let mut stream = self
117            .with_retry(|conn| async move { conn.open_bidirectional().await })
118            .await?;
119
120        // Use Framed codec for consistent message framing
121        let mut framed = Framed::new(&mut stream, length_codec());
122
123        // Send connection request
124        let connect_message = ClientMessage::Connect(ClientConnect {
125            address: dest_addr.clone(),
126        });
127        framed.send(protocol::encode(&connect_message)?).await?;
128
129        // Wait for the server's connection response
130        // Framed automatically reads the length prefix and validates frame size
131        let payload = match framed.next().await {
132            Some(Ok(payload)) => payload,
133            Some(Err(e)) => {
134                return Err(io::Error::new(
135                    io::ErrorKind::InvalidData,
136                    format!("failed to read server response: {}", e),
137                ));
138            }
139            None => {
140                return Err(io::Error::new(
141                    io::ErrorKind::UnexpectedEof,
142                    "stream closed before receiving server response",
143                ));
144            }
145        };
146
147        let message: ServerMessage = protocol::decode(&payload)?;
148        let response = match message {
149            ServerMessage::ConnectResponse(response) => response,
150            #[allow(unreachable_patterns)]
151            _ => {
152                return Err(io::Error::new(
153                    io::ErrorKind::InvalidData,
154                    "expected connect response message",
155                ));
156            }
157        };
158
159        // Extract any buffered data from Framed before dropping it
160        // This ensures we don't lose any data that might be in the read/write buffers.
161        // In this request-response protocol, we expect:
162        // - read_buf to be empty (we've read the complete response)
163        // - write_buf to be empty (send() should have flushed the request)
164        let parts = framed.into_parts();
165
166        // Verify write buffer is empty (send() should have flushed, but verify for safety)
167        if !parts.write_buf.is_empty() {
168            // This indicates send() didn't complete properly - this is a serious error
169            return Err(io::Error::other(format!(
170                "write buffer not empty after send: {} bytes remaining - data may be lost",
171                parts.write_buf.len()
172            )));
173        }
174
175        // Extract any remaining buffered read data
176        // This data may be present if the server sent additional data immediately after
177        // the connection response. We preserve it by wrapping the stream in BufferedStream.
178        let buffered_data = if !parts.read_buf.is_empty() {
179            Bytes::copy_from_slice(&parts.read_buf)
180        } else {
181            Bytes::new()
182        };
183
184        match response {
185            ServerConnectResponse::Ok => {
186                // Connection successful - return the stream wrapped in BufferedStream
187                // to ensure any buffered data is read first
188                Ok(BufferedStream::new(stream, buffered_data))
189            }
190            ServerConnectResponse::Err { kind, message } => {
191                // Connection failed - return appropriate error
192                let error_kind = match kind {
193                    ConnectErrorKind::ConnectionRefused => io::ErrorKind::ConnectionRefused,
194                    ConnectErrorKind::NetworkUnreachable => io::ErrorKind::NetworkUnreachable,
195                    ConnectErrorKind::HostUnreachable => io::ErrorKind::HostUnreachable,
196                    ConnectErrorKind::TimedOut => io::ErrorKind::TimedOut,
197                    ConnectErrorKind::Other => io::ErrorKind::Other,
198                };
199                Err(io::Error::new(error_kind, message))
200            }
201        }
202    }
203
    /// Gets a reference to the current connection.
    ///
    /// The returned `Guard` comes from loading the internal `ArcSwap` and
    /// dereferences to the `Arc<C>` held at the time of the call. `with_retry`
    /// loads the connection again after a reconnect rather than reusing an
    /// earlier guard.
    pub fn connection(&self) -> Guard<Arc<C>> {
        self.connection.load()
    }
208
    /// Rebind the transport to a new socket to ensure a clean state for reconnection.
    ///
    /// This only forwards to the transport's `rebind`; it does not take the
    /// reconnect lock, re-authenticate, or replace the current connection.
    ///
    /// # Errors
    ///
    /// Returns whatever error the transport's `rebind` reports.
    pub async fn rebind(&self) -> io::Result<()> {
        self.transport.rebind().await
    }
213
214    /// A wrapper function that adds retry/reconnect logic to a connection operation.
215    ///
216    /// It executes the provided `operation`. If the operation fails with a
217    /// connection-related error, it attempts to reconnect and retries the
218    /// operation once.
219    ///
220    /// # Errors
221    ///
222    /// Returns the original error if it's not a connection error, or the error
223    /// from the retry attempt if reconnection fails.
224    pub(crate) async fn with_retry<F, Fut, R>(&self, operation: F) -> io::Result<R>
225    where
226        F: Fn(Guard<Arc<C>>) -> Fut,
227        Fut: Future<Output = io::Result<R>>,
228    {
229        let connection = self.connection.load();
230        // Use the pointer address as a unique ID for the connection instance.
231        let old_conn_id = Arc::as_ptr(&connection) as usize;
232
233        match operation(connection).await {
234            Ok(result) => Ok(result),
235            Err(e) if is_connection_error(&e) => {
236                // Log the connection error before attempting reconnection
237                // This is a system-level error that should be logged
238                log_connection_error(
239                    ErrorContext::new("with_retry")
240                        .with_details("attempting to reconnect".to_string()),
241                    &e,
242                );
243                // Attempt reconnection - if it fails, return the reconnection error
244                self.reconnect(old_conn_id).await?;
245                // Retry the operation with the new connection
246                let new_connection = self.connection.load();
247                operation(new_connection).await
248            }
249            Err(e) => Err(e),
250        }
251    }
252
253    /// Handles the reconnection logic with exponential backoff.
254    ///
255    /// It uses a mutex to prevent multiple tasks from trying to reconnect simultaneously.
256    /// If another task has already reconnected, this function returns immediately.
257    ///
258    /// # Errors
259    ///
260    /// Returns an error if:
261    /// - Reconnection is throttled (too many attempts)
262    /// - Transport rebind fails
263    /// - Authentication fails
264    async fn reconnect(&self, old_conn_id: usize) -> io::Result<()> {
265        let mut state = self.reconnect_lock.lock().await;
266
267        // Check if another task has already reconnected
268        let current_conn = self.connection.load();
269        let current_conn_id = Arc::as_ptr(&current_conn) as usize;
270        if current_conn_id != old_conn_id {
271            // Another task already reconnected, we're done
272            return Ok(());
273        }
274
275        // Apply exponential backoff if we've attempted recently
276        if let Some(last) = state.last_attempt {
277            let elapsed = last.elapsed();
278            if elapsed < state.backoff {
279                let wait_time = state.backoff - elapsed;
280                let backoff_secs = state.backoff.as_secs();
281                drop(state);
282                tokio::time::sleep(wait_time).await;
283                let err = io::Error::other("reconnect throttled");
284                log_reconnect_error(ErrorContext::new("reconnect"), &err, Some(backoff_secs));
285                return Err(err);
286            }
287        }
288
289        state.last_attempt = Some(Instant::now());
290
291        if let Err(e) = self.transport.rebind().await {
292            state.backoff = (state.backoff * 2).min(MAX_RECONNECT_BACKOFF);
293            log_reconnect_error(
294                ErrorContext::new("reconnect").with_details("transport rebind failed".to_string()),
295                &e,
296                Some(state.backoff.as_secs()),
297            );
298            return Err(e);
299        }
300
301        match authenticate(&self.transport, self.secret, self.options.clone()).await {
302            Ok(new_connection) => {
303                state.backoff = INITIAL_RECONNECT_BACKOFF;
304                state.last_attempt = None;
305
306                self.connection.store(Arc::new(new_connection));
307                Ok(())
308            }
309            Err(e) => {
310                state.backoff = (state.backoff * 2).min(MAX_RECONNECT_BACKOFF);
311                log_reconnect_error(
312                    ErrorContext::new("reconnect")
313                        .with_details("authentication failed".to_string()),
314                    &e,
315                    Some(state.backoff.as_secs()),
316                );
317                Err(e)
318            }
319        }
320    }
321}
322
323/// Performs the initial authentication with the server.
324async fn authenticate<T, C>(transport: &T, secret: Secret, options: Bytes) -> io::Result<C>
325where
326    T: Initiator<Connection = C>,
327    C: Connection,
328{
329    let do_auth = async {
330        let connection = transport.connect().await?;
331        let mut stream = connection.open_bidirectional().await?;
332
333        let hello_message = ClientMessage::Hello(ClientHello {
334            version: PROTOCOL_VERSION,
335            secret,
336            options,
337        });
338
339        let encoded_bytes = protocol::encode(&hello_message)?;
340        let mut framed = Framed::new(&mut stream, length_codec());
341
342        framed.send(encoded_bytes).await?;
343
344        match framed.next().await {
345            Some(Ok(payload)) => {
346                let response: ServerAuthResponse = protocol::decode(&payload)?;
347                match response {
348                    ServerAuthResponse::Ok => {
349                        stream.shutdown().await?;
350                        Ok(connection)
351                    }
352                    ServerAuthResponse::Err => Err(io::Error::other("authentication failed")),
353                }
354            }
355            Some(Err(e)) => Err(e),
356            None => Err(io::Error::new(
357                io::ErrorKind::UnexpectedEof,
358                "connection closed by server during authentication",
359            )),
360        }
361    };
362
363    match tokio::time::timeout(AUTH_TIMEOUT, do_auth).await {
364        Ok(result) => result,
365        Err(_) => Err(io::Error::new(
366            io::ErrorKind::TimedOut,
367            format!(
368                "client authentication timed out after {}s",
369                AUTH_TIMEOUT.as_secs()
370            ),
371        )),
372    }
373}
374
375// --- Error Handling ---
/// Describes where a connection-related error occurred, for structured logging.
struct ErrorContext {
    /// Name of the operation that was running when the error occurred.
    operation: &'static str,
    /// Optional free-form description of what was being attempted.
    details: Option<String>,
}

impl ErrorContext {
    /// Starts a context for `operation` with no details attached.
    fn new(operation: &'static str) -> Self {
        let details = None;
        Self { operation, details }
    }

    /// Returns the context with `details` attached, replacing any earlier details.
    fn with_details(self, details: String) -> Self {
        Self {
            details: Some(details),
            ..self
        }
    }
}
395
396/// Logs and returns a connection error.
397///
398/// This is used for system-level connection errors that should be logged
399/// at the point of occurrence (e.g., during automatic reconnection).
400fn log_connection_error(ctx: ErrorContext, err: &io::Error) {
401    if is_connection_error(err) {
402        warn!(
403            error = %err,
404            error_kind = ?err.kind(),
405            operation = ctx.operation,
406            details = ctx.details.as_deref(),
407            "connection error detected"
408        );
409    }
410}
411
412/// Logs and returns a reconnection error.
413///
414/// This is used for system-level reconnection errors that should be logged
415/// at the point of occurrence.
416fn log_reconnect_error(ctx: ErrorContext, err: &io::Error, backoff_secs: Option<u64>) {
417    if err.kind() == io::ErrorKind::Other && err.to_string() == "reconnect throttled" {
418        warn!(
419            operation = ctx.operation,
420            backoff_secs = backoff_secs,
421            "reconnect throttled, too many attempts"
422        );
423    } else {
424        error!(
425            error = %err,
426            error_kind = ?err.kind(),
427            operation = ctx.operation,
428            backoff_secs = backoff_secs,
429            details = ctx.details.as_deref(),
430            "reconnection failed"
431        );
432    }
433}
434
/// Reports whether `e` has one of the error kinds treated as a lost connection.
///
/// These are the kinds for which `with_retry` reconnects and retries.
fn is_connection_error(e: &io::Error) -> bool {
    const LOST_CONNECTION_KINDS: [io::ErrorKind; 6] = [
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::NotConnected,
        io::ErrorKind::TimedOut,
        io::ErrorKind::UnexpectedEof,
        io::ErrorKind::NetworkUnreachable,
    ];

    LOST_CONNECTION_KINDS.contains(&e.kind())
}