Skip to main content

ironsbe_client/
builder.rs

1//! Client builder and main client implementation.
2
3use crate::error::ClientError;
4use crate::reconnect::{ReconnectConfig, ReconnectState};
5use crate::session::ClientSession;
6use ironsbe_channel::spsc;
7use ironsbe_transport::traits::Transport;
8use std::marker::PhantomData;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::Notify;
13
/// Builder for configuring and creating a client.
///
/// `T` picks the transport backend. With the `tcp-tokio` feature enabled
/// (the default), `T` falls back to [`ironsbe_transport::DefaultTransport`],
/// so existing call-sites keep compiling unchanged. With the feature
/// disabled, `T` has to be named explicitly, which lets downstream crates
/// plug in their own backend.
#[cfg(feature = "tcp-tokio")]
pub struct ClientBuilder<T: Transport = ironsbe_transport::DefaultTransport> {
    /// Target server address; also used to derive a default connect config.
    server_addr: SocketAddr,
    /// Upper bound on a single connect attempt in the reconnect loop.
    connect_timeout: Duration,
    /// Optional backend tunables; `None` means "derive from `server_addr`".
    connect_config: Option<T::ConnectConfig>,
    /// Capacity used for both the command and the event channel.
    channel_capacity: usize,
    /// Reconnection policy handed to the client's reconnect state.
    reconnect_config: ReconnectConfig,
    /// Ties the builder to its transport backend without storing one.
    _transport: PhantomData<T>,
}
30
31/// Builder for configuring and creating a client.
32///
33/// With the `tcp-tokio` feature disabled, the transport backend must be
34/// specified explicitly via the `T` type parameter.
35#[cfg(not(feature = "tcp-tokio"))]
36pub struct ClientBuilder<T: Transport> {
37    server_addr: SocketAddr,
38    connect_config: Option<T::ConnectConfig>,
39    connect_timeout: Duration,
40    reconnect_config: ReconnectConfig,
41    channel_capacity: usize,
42    _transport: PhantomData<T>,
43}
44
45impl<T: Transport> ClientBuilder<T> {
46    /// Creates a new client builder for the specified server address.
47    #[must_use]
48    pub fn new(server_addr: SocketAddr) -> Self {
49        Self {
50            server_addr,
51            connect_config: None,
52            connect_timeout: Duration::from_secs(5),
53            reconnect_config: ReconnectConfig::default(),
54            channel_capacity: 4096,
55            _transport: PhantomData,
56        }
57    }
58
59    /// Supplies a backend-specific connect configuration.
60    ///
61    /// Use this to override transport tunables (frame size, NODELAY, socket
62    /// buffer sizes, …).  When unset, the backend builds a default config
63    /// from the server address.
64    #[must_use]
65    pub fn connect_config(mut self, config: T::ConnectConfig) -> Self {
66        self.connect_config = Some(config);
67        self
68    }
69
70    /// Sets the outer connection timeout used by the reconnect loop.
71    ///
72    /// This bounds how long [`Client::run`] waits for a single connect
73    /// attempt before mapping the failure to [`ClientError::ConnectTimeout`].
74    /// It does **not** mutate any [`connect_config`](Self::connect_config)
75    /// the user may have supplied — to also override the backend's internal
76    /// connect timeout for the Tokio TCP backend, use
77    /// [`tcp_connect_timeout`](Self::tcp_connect_timeout).
78    #[must_use]
79    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
80        self.connect_timeout = timeout;
81        self
82    }
83
84    /// Enables or disables automatic reconnection.
85    #[must_use]
86    pub fn reconnect(mut self, enabled: bool) -> Self {
87        self.reconnect_config.enabled = enabled;
88        self
89    }
90
91    /// Sets the reconnection delay.
92    #[must_use]
93    pub fn reconnect_delay(mut self, delay: Duration) -> Self {
94        self.reconnect_config.initial_delay = delay;
95        self
96    }
97
98    /// Sets the maximum reconnection attempts.
99    #[must_use]
100    pub fn max_reconnect_attempts(mut self, max: usize) -> Self {
101        self.reconnect_config.max_attempts = max;
102        self
103    }
104
105    /// Sets the channel capacity.
106    #[must_use]
107    pub fn channel_capacity(mut self, capacity: usize) -> Self {
108        self.channel_capacity = capacity;
109        self
110    }
111
112    /// Builds the client and handle.
113    #[must_use]
114    pub fn build(self) -> (Client<T>, ClientHandle) {
115        let (cmd_tx, cmd_rx) = spsc::channel(self.channel_capacity);
116        let (event_tx, event_rx) = spsc::channel(self.channel_capacity);
117
118        let cmd_notify = Arc::new(Notify::new());
119        let event_notify = Arc::new(Notify::new());
120
121        let client = Client {
122            server_addr: self.server_addr,
123            connect_config: Some(
124                self.connect_config
125                    .unwrap_or_else(|| T::ConnectConfig::from(self.server_addr)),
126            ),
127            connect_timeout: self.connect_timeout,
128            reconnect_state: ReconnectState::new(self.reconnect_config),
129            cmd_rx,
130            event_tx,
131            cmd_notify: Arc::clone(&cmd_notify),
132            event_notify: Arc::clone(&event_notify),
133            _transport: PhantomData,
134        };
135
136        let handle = ClientHandle {
137            cmd_tx,
138            event_rx,
139            cmd_notify,
140            event_notify,
141        };
142
143        (client, handle)
144    }
145}
146
#[cfg(feature = "tcp-tokio")]
impl ClientBuilder {
    /// Creates a builder bound to [`ironsbe_transport::DefaultTransport`].
    ///
    /// Convenience constructor that fixes the transport type parameter, so
    /// call-sites such as `ClientBuilder::with_default_transport(addr).build()`
    /// need no turbofish.
    #[must_use]
    pub fn with_default_transport(server_addr: SocketAddr) -> Self {
        Self::new(server_addr)
    }

    /// Sets the maximum SBE frame size in bytes (Tokio TCP backend only).
    ///
    /// Shortcut that edits the underlying
    /// [`ironsbe_transport::tcp::TcpClientConfig`], creating one for the
    /// server address if none has been supplied yet.
    #[must_use]
    pub fn max_frame_size(mut self, size: usize) -> Self {
        let tcp = self.take_tcp_config().max_frame_size(size);
        self.connect_config = Some(tcp);
        self
    }

    /// Sets the outer reconnect-loop timeout **and** forwards the same value
    /// into the underlying
    /// [`TcpClientConfig`](ironsbe_transport::tcp::TcpClientConfig), so the
    /// backend's internal timeout matches.
    ///
    /// Equivalent to [`connect_timeout`](Self::connect_timeout) followed by
    /// setting `TcpClientConfig::connect_timeout` on a custom `connect_config`.
    #[must_use]
    pub fn tcp_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        let tcp = self.take_tcp_config().connect_timeout(timeout);
        self.connect_config = Some(tcp);
        self
    }

    /// Moves the current TCP connect config out of the builder, or creates a
    /// fresh one for the server address when none is set.
    fn take_tcp_config(&mut self) -> ironsbe_transport::tcp::TcpClientConfig {
        let addr = self.server_addr;
        self.connect_config
            .take()
            .unwrap_or_else(|| ironsbe_transport::tcp::TcpClientConfig::new(addr))
    }
}
193
/// The main client instance.
///
/// Generic over transport backend `T`, which defaults to
/// [`ironsbe_transport::DefaultTransport`] when the `tcp-tokio` feature is
/// enabled. Created by [`ClientBuilder::build`] and driven by
/// [`Client::run`], which owns the connection and the reconnect loop.
#[cfg(feature = "tcp-tokio")]
pub struct Client<T: Transport = ironsbe_transport::DefaultTransport> {
    /// Address of the server to connect to.
    server_addr: SocketAddr,
    /// Backend connect config, cloned for every connection attempt.
    connect_config: Option<T::ConnectConfig>,
    /// Outer timeout applied to each connect attempt.
    connect_timeout: Duration,
    /// Reconnect bookkeeping; yields the delay before the next attempt.
    reconnect_state: ReconnectState,
    /// Commands queued by the [`ClientHandle`].
    cmd_rx: spsc::SpscReceiver<ClientCommand>,
    /// Events delivered to the [`ClientHandle`].
    event_tx: spsc::SpscSender<ClientEvent>,
    /// Notified by the handle after it queues a command.
    cmd_notify: Arc<Notify>,
    /// Notified after an event is pushed so a waiting handle wakes up.
    event_notify: Arc<Notify>,
    /// Binds the transport type parameter without storing a value.
    _transport: PhantomData<T>,
}
209
/// The main client instance.
///
/// Generic over transport backend `T`, which must be named explicitly when
/// the `tcp-tokio` feature is disabled. Created by [`ClientBuilder::build`]
/// and driven by [`Client::run`], which owns the connection and the
/// reconnect loop.
#[cfg(not(feature = "tcp-tokio"))]
pub struct Client<T: Transport> {
    /// Address of the server to connect to.
    server_addr: SocketAddr,
    /// Backend connect config, cloned for every connection attempt.
    connect_config: Option<T::ConnectConfig>,
    /// Outer timeout applied to each connect attempt.
    connect_timeout: Duration,
    /// Reconnect bookkeeping; yields the delay before the next attempt.
    reconnect_state: ReconnectState,
    /// Commands queued by the [`ClientHandle`].
    cmd_rx: spsc::SpscReceiver<ClientCommand>,
    /// Events delivered to the [`ClientHandle`].
    event_tx: spsc::SpscSender<ClientEvent>,
    /// Notified by the handle after it queues a command.
    cmd_notify: Arc<Notify>,
    /// Notified after an event is pushed so a waiting handle wakes up.
    event_notify: Arc<Notify>,
    /// Binds the transport type parameter without storing a value.
    _transport: PhantomData<T>,
}
225
226impl<T: Transport> Client<T> {
227    /// Runs the client, connecting to the server and processing messages.
228    ///
229    /// # Errors
230    /// Returns `ClientError` if the client fails to connect or encounters an error.
231    pub async fn run(&mut self) -> Result<(), ClientError> {
232        loop {
233            match self.connect_and_run().await {
234                Ok(()) => {
235                    // Normal shutdown
236                    return Ok(());
237                }
238                Err(e) => {
239                    tracing::error!("Connection error: {:?}", e);
240
241                    if let Some(delay) = self.reconnect_state.on_failure() {
242                        let _ = self.event_tx.send(ClientEvent::Disconnected);
243                        self.event_notify.notify_one();
244                        tracing::info!("Reconnecting in {:?}...", delay);
245                        tokio::time::sleep(delay).await;
246                    } else {
247                        tracing::error!("Max reconnect attempts reached");
248                        return Err(ClientError::MaxReconnectAttempts);
249                    }
250                }
251            }
252        }
253    }
254
255    async fn connect_and_run(&mut self) -> Result<(), ClientError> {
256        // Reconnect attempts share the same connect_config; clone on each attempt.
257        let connect_config = self
258            .connect_config
259            .clone()
260            .unwrap_or_else(|| T::ConnectConfig::from(self.server_addr));
261        let conn = tokio::time::timeout(self.connect_timeout, T::connect_with(connect_config))
262            .await
263            .map_err(|_| ClientError::ConnectTimeout)?
264            .map_err(|e| {
265                let io_err = std::io::Error::other(e);
266                // Normalise backend-internal timeouts to ConnectTimeout so the
267                // outer error is consistent regardless of which timer fired
268                // first (the outer wrapper here or the backend's own).
269                if io_err.kind() == std::io::ErrorKind::TimedOut {
270                    ClientError::ConnectTimeout
271                } else {
272                    ClientError::Io(io_err)
273                }
274            })?;
275
276        self.reconnect_state.on_success();
277
278        let _ = self.event_tx.send(ClientEvent::Connected);
279        self.event_notify.notify_one();
280        tracing::info!("Connected to {}", self.server_addr);
281
282        let mut session = ClientSession::new(conn);
283
284        loop {
285            tokio::select! {
286                _ = self.cmd_notify.notified() => {
287                    // Drain all available commands after notification.
288                    while let Some(cmd) = self.cmd_rx.recv() {
289                        match cmd {
290                            ClientCommand::Send(msg) => {
291                                session.send(&msg).await?;
292                            }
293                            ClientCommand::Disconnect => {
294                                return Ok(());
295                            }
296                        }
297                    }
298                }
299
300                result = session.recv() => {
301                    match result {
302                        Ok(Some(msg)) => {
303                            let _ = self.event_tx.send(ClientEvent::Message(msg.to_vec()));
304                            self.event_notify.notify_one();
305                        }
306                        Ok(None) => {
307                            return Err(ClientError::ConnectionClosed);
308                        }
309                        Err(e) => {
310                            return Err(ClientError::Io(e));
311                        }
312                    }
313                }
314            }
315        }
316    }
317}
318
/// Handle for sending messages and receiving events.
///
/// Returned by `ClientBuilder::build` alongside the [`Client`]. Commands
/// travel to the client over `cmd_tx`; events come back on `event_rx`.
pub struct ClientHandle {
    /// Producer side of the command channel read by the client.
    cmd_tx: spsc::SpscSender<ClientCommand>,
    /// Consumer side of the event channel written by the client.
    event_rx: spsc::SpscReceiver<ClientEvent>,
    /// Notified after each command is queued, to wake the client.
    cmd_notify: Arc<Notify>,
    /// Notified by the client after it pushes an event.
    event_notify: Arc<Notify>,
}
326
327impl ClientHandle {
328    /// Constructs a [`ClientHandle`] from its raw plumbing.
329    ///
330    /// Used internally by both the multi-threaded [`Client`] builder and
331    /// the single-threaded `LocalClient` builder so both client flavours
332    /// can hand back the same handle type.
333    pub(crate) fn new(
334        cmd_tx: spsc::SpscSender<ClientCommand>,
335        event_rx: spsc::SpscReceiver<ClientEvent>,
336        cmd_notify: Arc<Notify>,
337        event_notify: Arc<Notify>,
338    ) -> Self {
339        Self {
340            cmd_tx,
341            event_rx,
342            cmd_notify,
343            event_notify,
344        }
345    }
346
347    /// Sends an SBE message to the server (non-blocking).
348    ///
349    /// # Errors
350    /// Returns error if the channel is disconnected.
351    #[inline]
352    pub fn send(&mut self, message: Vec<u8>) -> Result<(), ClientError> {
353        self.cmd_tx
354            .send(ClientCommand::Send(message))
355            .map_err(|_| ClientError::Channel)?;
356        self.cmd_notify.notify_one();
357        Ok(())
358    }
359
360    /// Disconnects from the server.
361    pub fn disconnect(&mut self) {
362        let _ = self.cmd_tx.send(ClientCommand::Disconnect);
363        self.cmd_notify.notify_one();
364    }
365
366    /// Polls for events (non-blocking).
367    #[inline]
368    pub fn poll(&mut self) -> Option<ClientEvent> {
369        self.event_rx.recv()
370    }
371
372    /// Busy-poll for next event (for hot path).
373    #[inline]
374    pub fn poll_spin(&mut self) -> ClientEvent {
375        self.event_rx.recv_spin()
376    }
377
378    /// Drains all available events.
379    pub fn drain(&mut self) -> impl Iterator<Item = ClientEvent> + '_ {
380        self.event_rx.drain()
381    }
382
383    /// Asynchronously waits for the next event.
384    ///
385    /// Returns `Some(event)` when an event is available, or keeps waiting.
386    /// Returns `None` only if the sender (client) has been dropped.
387    pub async fn wait_event(&mut self) -> Option<ClientEvent> {
388        loop {
389            if let Some(event) = self.event_rx.recv() {
390                return Some(event);
391            }
392            if !self.event_rx.is_connected() {
393                return None;
394            }
395            self.event_notify.notified().await;
396        }
397    }
398
399    /// Returns a clone of the event notification handle.
400    ///
401    /// Use this to await event availability when holding the handle behind
402    /// a `Mutex` — await the notifier *outside* the lock, then lock and
403    /// drain with \[`poll`\].
404    #[must_use]
405    pub fn event_notifier(&self) -> Arc<Notify> {
406        Arc::clone(&self.event_notify)
407    }
408}
409
/// Commands sent from a [`ClientHandle`] to the running [`Client`].
///
/// Derives the common comparison/cloning traits so commands can be
/// compared in tests and duplicated by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientCommand {
    /// Send a message to the server.
    Send(Vec<u8>),
    /// Disconnect from the server.
    Disconnect,
}
418
/// Events emitted by the client and delivered to its [`ClientHandle`].
///
/// Derives `PartialEq`/`Eq` so events can be compared directly, e.g.
/// `assert_eq!(handle.poll(), Some(ClientEvent::Connected))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientEvent {
    /// Connected to the server.
    Connected,
    /// Disconnected from the server.
    Disconnected,
    /// Received a message from the server.
    Message(Vec<u8>),
    /// An error occurred, described as text.
    Error(String),
}
431
#[cfg(all(test, feature = "tcp-tokio"))]
mod tests {
    use super::*;
    use std::time::Duration;

    type DefaultClientBuilder = ClientBuilder<ironsbe_transport::DefaultTransport>;

    /// Fixed loopback address used by every test; nothing listens on it.
    fn addr() -> SocketAddr {
        "127.0.0.1:9000".parse().expect("static test address is valid")
    }

    #[test]
    fn test_client_builder_new() {
        let builder = DefaultClientBuilder::new(addr());
        assert_eq!(builder.server_addr, addr());
        assert!(builder.connect_config.is_none());
        assert_eq!(builder.connect_timeout, Duration::from_secs(5));
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_client_builder_connect_timeout() {
        let builder = DefaultClientBuilder::new(addr()).connect_timeout(Duration::from_secs(10));
        assert_eq!(builder.connect_timeout, Duration::from_secs(10));
        // The outer timeout alone must not materialise a backend config.
        assert!(builder.connect_config.is_none());
    }

    #[test]
    fn test_client_builder_reconnect() {
        let builder = DefaultClientBuilder::new(addr()).reconnect(false);
        assert!(!builder.reconnect_config.enabled);
        let builder = builder.reconnect(true);
        assert!(builder.reconnect_config.enabled);
    }

    #[test]
    fn test_client_builder_reconnect_delay() {
        let builder =
            DefaultClientBuilder::new(addr()).reconnect_delay(Duration::from_millis(500));
        assert_eq!(
            builder.reconnect_config.initial_delay,
            Duration::from_millis(500)
        );
    }

    #[test]
    fn test_client_builder_max_reconnect_attempts() {
        let builder = DefaultClientBuilder::new(addr()).max_reconnect_attempts(5);
        assert_eq!(builder.reconnect_config.max_attempts, 5);
    }

    #[test]
    fn test_client_builder_channel_capacity() {
        let builder = DefaultClientBuilder::new(addr()).channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_client_builder_max_frame_size_sets_connect_config() {
        let builder = DefaultClientBuilder::new(addr()).max_frame_size(1 << 16);
        assert!(builder.connect_config.is_some());
    }

    #[test]
    fn test_client_builder_tcp_connect_timeout() {
        let builder =
            DefaultClientBuilder::new(addr()).tcp_connect_timeout(Duration::from_secs(3));
        // Both the outer loop timeout and the backend config are updated.
        assert_eq!(builder.connect_timeout, Duration::from_secs(3));
        assert!(builder.connect_config.is_some());
    }

    #[test]
    fn test_client_builder_build() {
        let (client, _handle) = DefaultClientBuilder::new(addr())
            .connect_timeout(Duration::from_secs(7))
            .build();
        assert_eq!(client.server_addr, addr());
        assert_eq!(client.connect_timeout, Duration::from_secs(7));
        // `build` always resolves a connect config, falling back to the address.
        assert!(client.connect_config.is_some());
    }

    #[test]
    fn test_client_command_debug() {
        let cmd = ClientCommand::Send(vec![1, 2, 3]);
        let debug_str = format!("{:?}", cmd);
        assert!(debug_str.contains("Send"));

        let cmd2 = ClientCommand::Disconnect;
        let debug_str2 = format!("{:?}", cmd2);
        assert!(debug_str2.contains("Disconnect"));
    }

    #[test]
    fn test_client_event_clone_debug() {
        let event = ClientEvent::Connected;
        let cloned = event.clone();
        assert!(matches!(cloned, ClientEvent::Connected));

        let debug_str = format!("{:?}", event);
        assert!(debug_str.contains("Connected"));

        let event2 = ClientEvent::Message(vec![1, 2, 3]);
        let debug_str2 = format!("{:?}", event2);
        assert!(debug_str2.contains("Message"));

        let event3 = ClientEvent::Error("test error".to_string());
        let debug_str3 = format!("{:?}", event3);
        assert!(debug_str3.contains("Error"));
    }

    #[test]
    fn test_client_handle_send_reaches_client() {
        let (mut client, mut handle) = DefaultClientBuilder::new(addr()).build();
        assert!(handle.send(vec![1, 2, 3]).is_ok());
        match client.cmd_rx.recv() {
            Some(ClientCommand::Send(msg)) => assert_eq!(msg, vec![1, 2, 3]),
            other => panic!("unexpected command: {:?}", other),
        }
        assert!(client.cmd_rx.recv().is_none());
    }

    #[test]
    fn test_client_handle_disconnect() {
        let (mut client, mut handle) = DefaultClientBuilder::new(addr()).build();
        handle.disconnect();
        assert!(matches!(
            client.cmd_rx.recv(),
            Some(ClientCommand::Disconnect)
        ));
    }

    #[test]
    fn test_client_events_reach_handle() {
        let (mut client, mut handle) = DefaultClientBuilder::new(addr()).build();
        assert!(client.event_tx.send(ClientEvent::Connected).is_ok());
        assert!(matches!(handle.poll(), Some(ClientEvent::Connected)));
        assert!(handle.poll().is_none());
    }

    #[test]
    fn test_client_handle_poll() {
        let (_client, mut handle) = DefaultClientBuilder::new(addr()).build();
        assert!(handle.poll().is_none());
    }
}
529}