ironsbe_client/
builder.rs1use crate::error::ClientError;
4use crate::reconnect::{ReconnectConfig, ReconnectState};
5use crate::session::ClientSession;
6use ironsbe_channel::spsc;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::Notify;
11
12pub struct ClientBuilder {
14 server_addr: SocketAddr,
15 connect_timeout: Duration,
16 reconnect_config: ReconnectConfig,
17 channel_capacity: usize,
18}
19
20impl ClientBuilder {
21 #[must_use]
23 pub fn new(server_addr: SocketAddr) -> Self {
24 Self {
25 server_addr,
26 connect_timeout: Duration::from_secs(5),
27 reconnect_config: ReconnectConfig::default(),
28 channel_capacity: 4096,
29 }
30 }
31
32 #[must_use]
34 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
35 self.connect_timeout = timeout;
36 self
37 }
38
39 #[must_use]
41 pub fn reconnect(mut self, enabled: bool) -> Self {
42 self.reconnect_config.enabled = enabled;
43 self
44 }
45
46 #[must_use]
48 pub fn reconnect_delay(mut self, delay: Duration) -> Self {
49 self.reconnect_config.initial_delay = delay;
50 self
51 }
52
53 #[must_use]
55 pub fn max_reconnect_attempts(mut self, max: usize) -> Self {
56 self.reconnect_config.max_attempts = max;
57 self
58 }
59
60 #[must_use]
62 pub fn channel_capacity(mut self, capacity: usize) -> Self {
63 self.channel_capacity = capacity;
64 self
65 }
66
67 #[must_use]
69 pub fn build(self) -> (Client, ClientHandle) {
70 let (cmd_tx, cmd_rx) = spsc::channel(self.channel_capacity);
71 let (event_tx, event_rx) = spsc::channel(self.channel_capacity);
72
73 let cmd_notify = Arc::new(Notify::new());
74 let event_notify = Arc::new(Notify::new());
75
76 let client = Client {
77 server_addr: self.server_addr,
78 connect_timeout: self.connect_timeout,
79 reconnect_state: ReconnectState::new(self.reconnect_config),
80 cmd_rx,
81 event_tx,
82 cmd_notify: Arc::clone(&cmd_notify),
83 event_notify: Arc::clone(&event_notify),
84 };
85
86 let handle = ClientHandle {
87 cmd_tx,
88 event_rx,
89 cmd_notify,
90 event_notify,
91 };
92
93 (client, handle)
94 }
95}
96
97pub struct Client {
99 server_addr: SocketAddr,
100 connect_timeout: Duration,
101 reconnect_state: ReconnectState,
102 cmd_rx: spsc::SpscReceiver<ClientCommand>,
103 event_tx: spsc::SpscSender<ClientEvent>,
104 cmd_notify: Arc<Notify>,
105 event_notify: Arc<Notify>,
106}
107
108impl Client {
109 pub async fn run(&mut self) -> Result<(), ClientError> {
114 loop {
115 match self.connect_and_run().await {
116 Ok(()) => {
117 return Ok(());
119 }
120 Err(e) => {
121 tracing::error!("Connection error: {:?}", e);
122
123 if let Some(delay) = self.reconnect_state.on_failure() {
124 let _ = self.event_tx.send(ClientEvent::Disconnected);
125 self.event_notify.notify_one();
126 tracing::info!("Reconnecting in {:?}...", delay);
127 tokio::time::sleep(delay).await;
128 } else {
129 tracing::error!("Max reconnect attempts reached");
130 return Err(ClientError::MaxReconnectAttempts);
131 }
132 }
133 }
134 }
135 }
136
137 async fn connect_and_run(&mut self) -> Result<(), ClientError> {
138 let stream = tokio::time::timeout(
139 self.connect_timeout,
140 tokio::net::TcpStream::connect(self.server_addr),
141 )
142 .await
143 .map_err(|_| ClientError::ConnectTimeout)?
144 .map_err(ClientError::Io)?;
145
146 stream.set_nodelay(true)?;
147 self.reconnect_state.on_success();
148
149 let _ = self.event_tx.send(ClientEvent::Connected);
150 self.event_notify.notify_one();
151 tracing::info!("Connected to {}", self.server_addr);
152
153 let mut session = ClientSession::new(stream);
154
155 loop {
156 tokio::select! {
157 _ = self.cmd_notify.notified() => {
158 while let Some(cmd) = self.cmd_rx.recv() {
160 match cmd {
161 ClientCommand::Send(msg) => {
162 session.send(&msg).await?;
163 }
164 ClientCommand::Disconnect => {
165 return Ok(());
166 }
167 }
168 }
169 }
170
171 result = session.recv() => {
172 match result {
173 Ok(Some(msg)) => {
174 let _ = self.event_tx.send(ClientEvent::Message(msg.to_vec()));
175 self.event_notify.notify_one();
176 }
177 Ok(None) => {
178 return Err(ClientError::ConnectionClosed);
179 }
180 Err(e) => {
181 return Err(ClientError::Io(e));
182 }
183 }
184 }
185 }
186 }
187 }
188}
189
190pub struct ClientHandle {
192 cmd_tx: spsc::SpscSender<ClientCommand>,
193 event_rx: spsc::SpscReceiver<ClientEvent>,
194 cmd_notify: Arc<Notify>,
195 event_notify: Arc<Notify>,
196}
197
198impl ClientHandle {
199 #[inline]
204 pub fn send(&mut self, message: Vec<u8>) -> Result<(), ClientError> {
205 self.cmd_tx
206 .send(ClientCommand::Send(message))
207 .map_err(|_| ClientError::Channel)?;
208 self.cmd_notify.notify_one();
209 Ok(())
210 }
211
212 pub fn disconnect(&mut self) {
214 let _ = self.cmd_tx.send(ClientCommand::Disconnect);
215 self.cmd_notify.notify_one();
216 }
217
218 #[inline]
220 pub fn poll(&mut self) -> Option<ClientEvent> {
221 self.event_rx.recv()
222 }
223
224 #[inline]
226 pub fn poll_spin(&mut self) -> ClientEvent {
227 self.event_rx.recv_spin()
228 }
229
230 pub fn drain(&mut self) -> impl Iterator<Item = ClientEvent> + '_ {
232 self.event_rx.drain()
233 }
234
235 pub async fn wait_event(&mut self) -> Option<ClientEvent> {
240 loop {
241 if let Some(event) = self.event_rx.recv() {
242 return Some(event);
243 }
244 if !self.event_rx.is_connected() {
245 return None;
246 }
247 self.event_notify.notified().await;
248 }
249 }
250
251 #[must_use]
257 pub fn event_notifier(&self) -> Arc<Notify> {
258 Arc::clone(&self.event_notify)
259 }
260}
261
/// Instruction queued by a [`ClientHandle`] for its [`Client`].
#[derive(Debug)]
pub enum ClientCommand {
    /// Write these bytes to the server through the active session.
    Send(Vec<u8>),
    /// Stop serving the current connection and return `Ok(())` from `Client::run`.
    Disconnect,
}
270
/// Notification published by a [`Client`] for its [`ClientHandle`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare events directly
/// (e.g. `assert_eq!(handle.poll(), Some(ClientEvent::Connected))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientEvent {
    /// A TCP connection to the server was established.
    Connected,
    /// The connection ended and a reconnect attempt has been scheduled.
    Disconnected,
    /// Bytes received from the server, copied out of the session buffer.
    Message(Vec<u8>),
    /// A failure, described as text.
    Error(String),
}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Loopback address used by every test; nothing ever connects to it.
    fn test_addr() -> SocketAddr {
        "127.0.0.1:9000".parse().unwrap()
    }

    #[test]
    fn test_client_builder_new() {
        let builder = ClientBuilder::new(test_addr());
        assert_eq!(builder.server_addr, test_addr());
        assert_eq!(builder.connect_timeout, Duration::from_secs(5));
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_client_builder_connect_timeout() {
        let builder = ClientBuilder::new(test_addr()).connect_timeout(Duration::from_secs(10));
        assert_eq!(builder.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_client_builder_reconnect() {
        assert!(ClientBuilder::new(test_addr()).reconnect(true).reconnect_config.enabled);
        assert!(!ClientBuilder::new(test_addr()).reconnect(false).reconnect_config.enabled);
    }

    #[test]
    fn test_client_builder_reconnect_delay() {
        let builder = ClientBuilder::new(test_addr()).reconnect_delay(Duration::from_millis(500));
        assert_eq!(builder.reconnect_config.initial_delay, Duration::from_millis(500));
    }

    #[test]
    fn test_client_builder_max_reconnect_attempts() {
        let builder = ClientBuilder::new(test_addr()).max_reconnect_attempts(5);
        assert_eq!(builder.reconnect_config.max_attempts, 5);
    }

    #[test]
    fn test_client_builder_channel_capacity() {
        let builder = ClientBuilder::new(test_addr()).channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_client_builder_build() {
        let (client, _handle) = ClientBuilder::new(test_addr())
            .connect_timeout(Duration::from_millis(250))
            .build();
        assert_eq!(client.server_addr, test_addr());
        assert_eq!(client.connect_timeout, Duration::from_millis(250));
    }

    #[test]
    fn test_client_command_debug() {
        assert_eq!(format!("{:?}", ClientCommand::Send(vec![1, 2, 3])), "Send([1, 2, 3])");
        assert_eq!(format!("{:?}", ClientCommand::Disconnect), "Disconnect");
    }

    #[test]
    fn test_client_event_clone_debug() {
        let events = [
            ClientEvent::Connected,
            ClientEvent::Disconnected,
            ClientEvent::Message(vec![1, 2, 3]),
            ClientEvent::Error("test error".to_string()),
        ];
        for event in &events {
            // A clone must be indistinguishable from the original.
            assert_eq!(format!("{:?}", event.clone()), format!("{:?}", event));
        }
        assert_eq!(format!("{:?}", events[0]), "Connected");
        assert_eq!(format!("{:?}", events[2]), "Message([1, 2, 3])");
        assert!(format!("{:?}", events[3]).contains("test error"));
    }

    #[test]
    fn test_client_handle_send() {
        let (mut client, mut handle) = ClientBuilder::new(test_addr()).build();
        assert!(handle.send(vec![1, 2, 3]).is_ok());
        match client.cmd_rx.recv() {
            Some(ClientCommand::Send(bytes)) => assert_eq!(bytes, vec![1, 2, 3]),
            other => panic!("unexpected command: {:?}", other),
        }
    }

    #[test]
    fn test_client_handle_disconnect() {
        let (mut client, mut handle) = ClientBuilder::new(test_addr()).build();
        handle.disconnect();
        assert!(matches!(client.cmd_rx.recv(), Some(ClientCommand::Disconnect)));
        // Nothing is running the client, so no event can have been published.
        assert!(handle.poll().is_none());
    }

    #[test]
    fn test_client_handle_poll() {
        let (_client, mut handle) = ClientBuilder::new(test_addr()).build();
        assert!(handle.poll().is_none());
    }

    #[test]
    fn test_client_handle_poll_receives_published_event() {
        let (mut client, mut handle) = ClientBuilder::new(test_addr()).build();
        assert!(client.event_tx.send(ClientEvent::Connected).is_ok());
        assert!(matches!(handle.poll(), Some(ClientEvent::Connected)));
        assert!(handle.poll().is_none());
    }
}