Skip to main content

ironsbe_client/
builder.rs

1//! Client builder and main client implementation.
2
3use crate::error::ClientError;
4use crate::reconnect::{ReconnectConfig, ReconnectState};
5use crate::session::ClientSession;
6use ironsbe_channel::spsc;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::Notify;
11
12/// Builder for configuring and creating a client.
13pub struct ClientBuilder {
14    server_addr: SocketAddr,
15    connect_timeout: Duration,
16    reconnect_config: ReconnectConfig,
17    channel_capacity: usize,
18}
19
20impl ClientBuilder {
21    /// Creates a new client builder for the specified server address.
22    #[must_use]
23    pub fn new(server_addr: SocketAddr) -> Self {
24        Self {
25            server_addr,
26            connect_timeout: Duration::from_secs(5),
27            reconnect_config: ReconnectConfig::default(),
28            channel_capacity: 4096,
29        }
30    }
31
32    /// Sets the connection timeout.
33    #[must_use]
34    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
35        self.connect_timeout = timeout;
36        self
37    }
38
39    /// Enables or disables automatic reconnection.
40    #[must_use]
41    pub fn reconnect(mut self, enabled: bool) -> Self {
42        self.reconnect_config.enabled = enabled;
43        self
44    }
45
46    /// Sets the reconnection delay.
47    #[must_use]
48    pub fn reconnect_delay(mut self, delay: Duration) -> Self {
49        self.reconnect_config.initial_delay = delay;
50        self
51    }
52
53    /// Sets the maximum reconnection attempts.
54    #[must_use]
55    pub fn max_reconnect_attempts(mut self, max: usize) -> Self {
56        self.reconnect_config.max_attempts = max;
57        self
58    }
59
60    /// Sets the channel capacity.
61    #[must_use]
62    pub fn channel_capacity(mut self, capacity: usize) -> Self {
63        self.channel_capacity = capacity;
64        self
65    }
66
67    /// Builds the client and handle.
68    #[must_use]
69    pub fn build(self) -> (Client, ClientHandle) {
70        let (cmd_tx, cmd_rx) = spsc::channel(self.channel_capacity);
71        let (event_tx, event_rx) = spsc::channel(self.channel_capacity);
72
73        let cmd_notify = Arc::new(Notify::new());
74        let event_notify = Arc::new(Notify::new());
75
76        let client = Client {
77            server_addr: self.server_addr,
78            connect_timeout: self.connect_timeout,
79            reconnect_state: ReconnectState::new(self.reconnect_config),
80            cmd_rx,
81            event_tx,
82            cmd_notify: Arc::clone(&cmd_notify),
83            event_notify: Arc::clone(&event_notify),
84        };
85
86        let handle = ClientHandle {
87            cmd_tx,
88            event_rx,
89            cmd_notify,
90            event_notify,
91        };
92
93        (client, handle)
94    }
95}
96
97/// The main client instance.
98pub struct Client {
99    server_addr: SocketAddr,
100    connect_timeout: Duration,
101    reconnect_state: ReconnectState,
102    cmd_rx: spsc::SpscReceiver<ClientCommand>,
103    event_tx: spsc::SpscSender<ClientEvent>,
104    cmd_notify: Arc<Notify>,
105    event_notify: Arc<Notify>,
106}
107
108impl Client {
109    /// Runs the client, connecting to the server and processing messages.
110    ///
111    /// # Errors
112    /// Returns `ClientError` if the client fails to connect or encounters an error.
113    pub async fn run(&mut self) -> Result<(), ClientError> {
114        loop {
115            match self.connect_and_run().await {
116                Ok(()) => {
117                    // Normal shutdown
118                    return Ok(());
119                }
120                Err(e) => {
121                    tracing::error!("Connection error: {:?}", e);
122
123                    if let Some(delay) = self.reconnect_state.on_failure() {
124                        let _ = self.event_tx.send(ClientEvent::Disconnected);
125                        self.event_notify.notify_one();
126                        tracing::info!("Reconnecting in {:?}...", delay);
127                        tokio::time::sleep(delay).await;
128                    } else {
129                        tracing::error!("Max reconnect attempts reached");
130                        return Err(ClientError::MaxReconnectAttempts);
131                    }
132                }
133            }
134        }
135    }
136
137    async fn connect_and_run(&mut self) -> Result<(), ClientError> {
138        let stream = tokio::time::timeout(
139            self.connect_timeout,
140            tokio::net::TcpStream::connect(self.server_addr),
141        )
142        .await
143        .map_err(|_| ClientError::ConnectTimeout)?
144        .map_err(ClientError::Io)?;
145
146        stream.set_nodelay(true)?;
147        self.reconnect_state.on_success();
148
149        let _ = self.event_tx.send(ClientEvent::Connected);
150        self.event_notify.notify_one();
151        tracing::info!("Connected to {}", self.server_addr);
152
153        let mut session = ClientSession::new(stream);
154
155        loop {
156            tokio::select! {
157                _ = self.cmd_notify.notified() => {
158                    // Drain all available commands after notification.
159                    while let Some(cmd) = self.cmd_rx.recv() {
160                        match cmd {
161                            ClientCommand::Send(msg) => {
162                                session.send(&msg).await?;
163                            }
164                            ClientCommand::Disconnect => {
165                                return Ok(());
166                            }
167                        }
168                    }
169                }
170
171                result = session.recv() => {
172                    match result {
173                        Ok(Some(msg)) => {
174                            let _ = self.event_tx.send(ClientEvent::Message(msg.to_vec()));
175                            self.event_notify.notify_one();
176                        }
177                        Ok(None) => {
178                            return Err(ClientError::ConnectionClosed);
179                        }
180                        Err(e) => {
181                            return Err(ClientError::Io(e));
182                        }
183                    }
184                }
185            }
186        }
187    }
188}
189
190/// Handle for sending messages and receiving events.
191pub struct ClientHandle {
192    cmd_tx: spsc::SpscSender<ClientCommand>,
193    event_rx: spsc::SpscReceiver<ClientEvent>,
194    cmd_notify: Arc<Notify>,
195    event_notify: Arc<Notify>,
196}
197
198impl ClientHandle {
199    /// Sends an SBE message to the server (non-blocking).
200    ///
201    /// # Errors
202    /// Returns error if the channel is disconnected.
203    #[inline]
204    pub fn send(&mut self, message: Vec<u8>) -> Result<(), ClientError> {
205        self.cmd_tx
206            .send(ClientCommand::Send(message))
207            .map_err(|_| ClientError::Channel)?;
208        self.cmd_notify.notify_one();
209        Ok(())
210    }
211
212    /// Disconnects from the server.
213    pub fn disconnect(&mut self) {
214        let _ = self.cmd_tx.send(ClientCommand::Disconnect);
215        self.cmd_notify.notify_one();
216    }
217
218    /// Polls for events (non-blocking).
219    #[inline]
220    pub fn poll(&mut self) -> Option<ClientEvent> {
221        self.event_rx.recv()
222    }
223
224    /// Busy-poll for next event (for hot path).
225    #[inline]
226    pub fn poll_spin(&mut self) -> ClientEvent {
227        self.event_rx.recv_spin()
228    }
229
230    /// Drains all available events.
231    pub fn drain(&mut self) -> impl Iterator<Item = ClientEvent> + '_ {
232        self.event_rx.drain()
233    }
234
235    /// Asynchronously waits for the next event.
236    ///
237    /// Returns `Some(event)` when an event is available, or keeps waiting.
238    /// Returns `None` only if the sender (client) has been dropped.
239    pub async fn wait_event(&mut self) -> Option<ClientEvent> {
240        loop {
241            if let Some(event) = self.event_rx.recv() {
242                return Some(event);
243            }
244            if !self.event_rx.is_connected() {
245                return None;
246            }
247            self.event_notify.notified().await;
248        }
249    }
250
251    /// Returns a clone of the event notification handle.
252    ///
253    /// Use this to await event availability when holding the handle behind
254    /// a `Mutex` — await the notifier *outside* the lock, then lock and
255    /// drain with \[`poll`\].
256    #[must_use]
257    pub fn event_notifier(&self) -> Arc<Notify> {
258        Arc::clone(&self.event_notify)
259    }
260}
261
/// Instructions a [`ClientHandle`] queues for the running [`Client`].
#[derive(Debug)]
pub enum ClientCommand {
    /// Write the contained encoded bytes to the server.
    Send(Vec<u8>),
    /// Close the current session and stop the client.
    Disconnect,
}
270
/// Events emitted by the [`Client`] for its [`ClientHandle`].
///
/// `PartialEq`/`Eq` are derived so events can be compared directly (for
/// example in tests or when filtering), not only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientEvent {
    /// A TCP connection to the server was established.
    Connected,
    /// The client is no longer connected to the server.
    Disconnected,
    /// Raw bytes of a message received from the server.
    Message(Vec<u8>),
    /// An error occurred, described as text.
    ///
    /// NOTE(review): not constructed anywhere in this file.
    Error(String),
}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Loopback address shared by every test; nothing ever connects to it.
    fn addr() -> SocketAddr {
        "127.0.0.1:9000"
            .parse()
            .expect("literal is a valid socket address")
    }

    #[test]
    fn test_client_builder_new() {
        let builder = ClientBuilder::new(addr());
        assert_eq!(builder.server_addr, addr());
        assert_eq!(builder.connect_timeout, Duration::from_secs(5));
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_client_builder_connect_timeout() {
        let builder = ClientBuilder::new(addr()).connect_timeout(Duration::from_secs(10));
        assert_eq!(builder.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_client_builder_reconnect() {
        assert!(ClientBuilder::new(addr()).reconnect(true).reconnect_config.enabled);
        assert!(!ClientBuilder::new(addr()).reconnect(false).reconnect_config.enabled);
    }

    #[test]
    fn test_client_builder_reconnect_delay() {
        let builder = ClientBuilder::new(addr()).reconnect_delay(Duration::from_millis(500));
        assert_eq!(
            builder.reconnect_config.initial_delay,
            Duration::from_millis(500)
        );
    }

    #[test]
    fn test_client_builder_max_reconnect_attempts() {
        let builder = ClientBuilder::new(addr()).max_reconnect_attempts(5);
        assert_eq!(builder.reconnect_config.max_attempts, 5);
    }

    #[test]
    fn test_client_builder_channel_capacity() {
        let builder = ClientBuilder::new(addr()).channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_client_builder_build() {
        let (client, _handle) = ClientBuilder::new(addr())
            .connect_timeout(Duration::from_secs(3))
            .build();
        assert_eq!(client.server_addr, addr());
        assert_eq!(client.connect_timeout, Duration::from_secs(3));
    }

    #[test]
    fn test_client_command_debug() {
        let cmd = ClientCommand::Send(vec![1, 2, 3]);
        assert_eq!(format!("{:?}", cmd), "Send([1, 2, 3])");

        let cmd2 = ClientCommand::Disconnect;
        assert_eq!(format!("{:?}", cmd2), "Disconnect");
    }

    #[test]
    fn test_client_event_clone_debug() {
        let event = ClientEvent::Connected;
        let cloned = event.clone();
        assert!(matches!(cloned, ClientEvent::Connected));
        assert!(format!("{:?}", event).contains("Connected"));

        let event2 = ClientEvent::Message(vec![1, 2, 3]);
        match event2.clone() {
            ClientEvent::Message(bytes) => assert_eq!(bytes, vec![1, 2, 3]),
            other => panic!("expected Message, got {:?}", other),
        }
        assert!(format!("{:?}", event2).contains("Message"));

        let event3 = ClientEvent::Error("test error".to_string());
        assert!(format!("{:?}", event3).contains("Error"));
    }

    #[test]
    fn test_client_handle_send_reaches_client() {
        let (mut client, mut handle) = ClientBuilder::new(addr()).build();
        handle
            .send(vec![1, 2, 3])
            .expect("an empty channel accepts a command");

        match client.cmd_rx.recv() {
            Some(ClientCommand::Send(bytes)) => assert_eq!(bytes, vec![1, 2, 3]),
            other => panic!("expected Send command, got {:?}", other),
        }
        assert!(client.cmd_rx.recv().is_none());
    }

    #[test]
    fn test_client_handle_disconnect() {
        let (mut client, mut handle) = ClientBuilder::new(addr()).build();
        handle.disconnect();
        assert!(matches!(
            client.cmd_rx.recv(),
            Some(ClientCommand::Disconnect)
        ));
    }

    #[test]
    fn test_client_handle_poll() {
        let (mut client, mut handle) = ClientBuilder::new(addr()).build();
        assert!(handle.poll().is_none());

        assert!(client.event_tx.send(ClientEvent::Connected).is_ok());
        assert!(matches!(handle.poll(), Some(ClientEvent::Connected)));
        assert!(handle.poll().is_none());
    }
}