1use crate::error::ClientError;
4use crate::reconnect::{ReconnectConfig, ReconnectState};
5use crate::session::ClientSession;
6use ironsbe_channel::spsc;
7use ironsbe_transport::traits::Transport;
8use std::marker::PhantomData;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::Notify;
13
#[cfg(feature = "tcp-tokio")]
/// Builder for a [`Client`] and its paired [`ClientHandle`].
///
/// With the `tcp-tokio` feature enabled, the transport parameter `T` defaults
/// to `ironsbe_transport::DefaultTransport`.
pub struct ClientBuilder<T: Transport = ironsbe_transport::DefaultTransport> {
    /// Server address to connect to; also used to derive a connect config
    /// when none was supplied explicitly.
    server_addr: SocketAddr,
    /// Transport-specific connect configuration, if one was set explicitly.
    connect_config: Option<T::ConnectConfig>,
    /// Timeout the client wraps around each transport connect attempt.
    connect_timeout: Duration,
    /// Reconnect policy handed to the client's `ReconnectState` at build time.
    reconnect_config: ReconnectConfig,
    /// Capacity used for both the command and the event SPSC channels.
    channel_capacity: usize,
    /// Records the transport type without storing a value of it.
    _transport: PhantomData<T>,
}
30
#[cfg(not(feature = "tcp-tokio"))]
/// Builder for a [`Client`] and its paired [`ClientHandle`].
///
/// Without the `tcp-tokio` feature there is no default transport, so `T`
/// must be named explicitly.
pub struct ClientBuilder<T: Transport> {
    /// Server address to connect to; also used to derive a connect config
    /// when none was supplied explicitly.
    server_addr: SocketAddr,
    /// Transport-specific connect configuration, if one was set explicitly.
    connect_config: Option<T::ConnectConfig>,
    /// Timeout the client wraps around each transport connect attempt.
    connect_timeout: Duration,
    /// Reconnect policy handed to the client's `ReconnectState` at build time.
    reconnect_config: ReconnectConfig,
    /// Capacity used for both the command and the event SPSC channels.
    channel_capacity: usize,
    /// Records the transport type without storing a value of it.
    _transport: PhantomData<T>,
}
44
45impl<T: Transport> ClientBuilder<T> {
46 #[must_use]
48 pub fn new(server_addr: SocketAddr) -> Self {
49 Self {
50 server_addr,
51 connect_config: None,
52 connect_timeout: Duration::from_secs(5),
53 reconnect_config: ReconnectConfig::default(),
54 channel_capacity: 4096,
55 _transport: PhantomData,
56 }
57 }
58
59 #[must_use]
65 pub fn connect_config(mut self, config: T::ConnectConfig) -> Self {
66 self.connect_config = Some(config);
67 self
68 }
69
70 #[must_use]
79 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
80 self.connect_timeout = timeout;
81 self
82 }
83
84 #[must_use]
86 pub fn reconnect(mut self, enabled: bool) -> Self {
87 self.reconnect_config.enabled = enabled;
88 self
89 }
90
91 #[must_use]
93 pub fn reconnect_delay(mut self, delay: Duration) -> Self {
94 self.reconnect_config.initial_delay = delay;
95 self
96 }
97
98 #[must_use]
100 pub fn max_reconnect_attempts(mut self, max: usize) -> Self {
101 self.reconnect_config.max_attempts = max;
102 self
103 }
104
105 #[must_use]
107 pub fn channel_capacity(mut self, capacity: usize) -> Self {
108 self.channel_capacity = capacity;
109 self
110 }
111
112 #[must_use]
114 pub fn build(self) -> (Client<T>, ClientHandle) {
115 let (cmd_tx, cmd_rx) = spsc::channel(self.channel_capacity);
116 let (event_tx, event_rx) = spsc::channel(self.channel_capacity);
117
118 let cmd_notify = Arc::new(Notify::new());
119 let event_notify = Arc::new(Notify::new());
120
121 let client = Client {
122 server_addr: self.server_addr,
123 connect_config: Some(
124 self.connect_config
125 .unwrap_or_else(|| T::ConnectConfig::from(self.server_addr)),
126 ),
127 connect_timeout: self.connect_timeout,
128 reconnect_state: ReconnectState::new(self.reconnect_config),
129 cmd_rx,
130 event_tx,
131 cmd_notify: Arc::clone(&cmd_notify),
132 event_notify: Arc::clone(&event_notify),
133 _transport: PhantomData,
134 };
135
136 let handle = ClientHandle {
137 cmd_tx,
138 event_rx,
139 cmd_notify,
140 event_notify,
141 };
142
143 (client, handle)
144 }
145}
146
#[cfg(feature = "tcp-tokio")]
impl ClientBuilder {
    /// Creates a builder for the default transport.
    ///
    /// This is equivalent to `ClientBuilder::new(server_addr)` but needs no
    /// turbofish.
    #[must_use]
    pub fn with_default_transport(server_addr: SocketAddr) -> Self {
        Self::new(server_addr)
    }

    /// Sets the maximum frame size on the TCP connect config.
    ///
    /// If no config has been supplied yet, one is created for the server
    /// address first.
    #[must_use]
    pub fn max_frame_size(mut self, size: usize) -> Self {
        let updated = self.take_tcp_config().max_frame_size(size);
        self.connect_config = Some(updated);
        self
    }

    /// Sets the connect timeout in both places it applies.
    ///
    /// It is set on the client-side wrapper (as `connect_timeout` does) and
    /// inside the TCP connect config. If no config exists yet, one is created
    /// for the server address.
    #[must_use]
    pub fn tcp_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        let updated = self.take_tcp_config().connect_timeout(timeout);
        self.connect_config = Some(updated);
        self
    }

    /// Moves the current TCP connect config out of the builder.
    ///
    /// Falls back to a fresh config for `server_addr` when none is set.
    fn take_tcp_config(&mut self) -> ironsbe_transport::tcp::TcpClientConfig {
        let addr = self.server_addr;
        self.connect_config
            .take()
            .unwrap_or_else(|| ironsbe_transport::tcp::TcpClientConfig::new(addr))
    }
}
193
#[cfg(feature = "tcp-tokio")]
/// Connection driver produced by [`ClientBuilder::build`].
///
/// It connects, services the session, and reconnects according to its
/// reconnect state. With the `tcp-tokio` feature enabled, `T` defaults to
/// `ironsbe_transport::DefaultTransport`.
pub struct Client<T: Transport = ironsbe_transport::DefaultTransport> {
    /// Server address, used for logging and as a fallback connect config.
    server_addr: SocketAddr,
    /// Connect configuration, cloned for every connection attempt.
    connect_config: Option<T::ConnectConfig>,
    /// Timeout wrapped around each transport connect attempt.
    connect_timeout: Duration,
    /// Reconnect bookkeeping: reset on success, consulted for a retry delay
    /// on failure.
    reconnect_state: ReconnectState,
    /// Commands enqueued by the paired [`ClientHandle`].
    cmd_rx: spsc::SpscReceiver<ClientCommand>,
    /// Events delivered to the paired [`ClientHandle`].
    event_tx: spsc::SpscSender<ClientEvent>,
    /// Signalled by the handle after it enqueues a command.
    cmd_notify: Arc<Notify>,
    /// Signalled by the client after it enqueues an event.
    event_notify: Arc<Notify>,
    /// Records the transport type without storing a value of it.
    _transport: PhantomData<T>,
}
209
#[cfg(not(feature = "tcp-tokio"))]
/// Connection driver produced by [`ClientBuilder::build`].
///
/// It connects, services the session, and reconnects according to its
/// reconnect state. Without the `tcp-tokio` feature, `T` must be named
/// explicitly.
pub struct Client<T: Transport> {
    /// Server address, used for logging and as a fallback connect config.
    server_addr: SocketAddr,
    /// Connect configuration, cloned for every connection attempt.
    connect_config: Option<T::ConnectConfig>,
    /// Timeout wrapped around each transport connect attempt.
    connect_timeout: Duration,
    /// Reconnect bookkeeping: reset on success, consulted for a retry delay
    /// on failure.
    reconnect_state: ReconnectState,
    /// Commands enqueued by the paired [`ClientHandle`].
    cmd_rx: spsc::SpscReceiver<ClientCommand>,
    /// Events delivered to the paired [`ClientHandle`].
    event_tx: spsc::SpscSender<ClientEvent>,
    /// Signalled by the handle after it enqueues a command.
    cmd_notify: Arc<Notify>,
    /// Signalled by the client after it enqueues an event.
    event_notify: Arc<Notify>,
    /// Records the transport type without storing a value of it.
    _transport: PhantomData<T>,
}
225
226impl<T: Transport> Client<T> {
227 pub async fn run(&mut self) -> Result<(), ClientError> {
232 loop {
233 match self.connect_and_run().await {
234 Ok(()) => {
235 return Ok(());
237 }
238 Err(e) => {
239 tracing::error!("Connection error: {:?}", e);
240
241 if let Some(delay) = self.reconnect_state.on_failure() {
242 let _ = self.event_tx.send(ClientEvent::Disconnected);
243 self.event_notify.notify_one();
244 tracing::info!("Reconnecting in {:?}...", delay);
245 tokio::time::sleep(delay).await;
246 } else {
247 tracing::error!("Max reconnect attempts reached");
248 return Err(ClientError::MaxReconnectAttempts);
249 }
250 }
251 }
252 }
253 }
254
255 async fn connect_and_run(&mut self) -> Result<(), ClientError> {
256 let connect_config = self
258 .connect_config
259 .clone()
260 .unwrap_or_else(|| T::ConnectConfig::from(self.server_addr));
261 let conn = tokio::time::timeout(self.connect_timeout, T::connect_with(connect_config))
262 .await
263 .map_err(|_| ClientError::ConnectTimeout)?
264 .map_err(|e| {
265 let io_err = std::io::Error::other(e);
266 if io_err.kind() == std::io::ErrorKind::TimedOut {
270 ClientError::ConnectTimeout
271 } else {
272 ClientError::Io(io_err)
273 }
274 })?;
275
276 self.reconnect_state.on_success();
277
278 let _ = self.event_tx.send(ClientEvent::Connected);
279 self.event_notify.notify_one();
280 tracing::info!("Connected to {}", self.server_addr);
281
282 let mut session = ClientSession::new(conn);
283
284 loop {
285 tokio::select! {
286 _ = self.cmd_notify.notified() => {
287 while let Some(cmd) = self.cmd_rx.recv() {
289 match cmd {
290 ClientCommand::Send(msg) => {
291 session.send(&msg).await?;
292 }
293 ClientCommand::Disconnect => {
294 return Ok(());
295 }
296 }
297 }
298 }
299
300 result = session.recv() => {
301 match result {
302 Ok(Some(msg)) => {
303 let _ = self.event_tx.send(ClientEvent::Message(msg.to_vec()));
304 self.event_notify.notify_one();
305 }
306 Ok(None) => {
307 return Err(ClientError::ConnectionClosed);
308 }
309 Err(e) => {
310 return Err(ClientError::Io(e));
311 }
312 }
313 }
314 }
315 }
316 }
317}
318
/// Application-side endpoint of a [`Client`].
///
/// It queues commands for the client and receives the events the client
/// produces.
pub struct ClientHandle {
    /// Command queue into the client.
    cmd_tx: spsc::SpscSender<ClientCommand>,
    /// Event queue out of the client.
    event_rx: spsc::SpscReceiver<ClientEvent>,
    /// Notified after a command is enqueued, waking the client's run loop.
    cmd_notify: Arc<Notify>,
    /// Notified by the client after it enqueues an event.
    event_notify: Arc<Notify>,
}
326
327impl ClientHandle {
328 pub(crate) fn new(
334 cmd_tx: spsc::SpscSender<ClientCommand>,
335 event_rx: spsc::SpscReceiver<ClientEvent>,
336 cmd_notify: Arc<Notify>,
337 event_notify: Arc<Notify>,
338 ) -> Self {
339 Self {
340 cmd_tx,
341 event_rx,
342 cmd_notify,
343 event_notify,
344 }
345 }
346
347 #[inline]
352 pub fn send(&mut self, message: Vec<u8>) -> Result<(), ClientError> {
353 self.cmd_tx
354 .send(ClientCommand::Send(message))
355 .map_err(|_| ClientError::Channel)?;
356 self.cmd_notify.notify_one();
357 Ok(())
358 }
359
360 pub fn disconnect(&mut self) {
362 let _ = self.cmd_tx.send(ClientCommand::Disconnect);
363 self.cmd_notify.notify_one();
364 }
365
366 #[inline]
368 pub fn poll(&mut self) -> Option<ClientEvent> {
369 self.event_rx.recv()
370 }
371
372 #[inline]
374 pub fn poll_spin(&mut self) -> ClientEvent {
375 self.event_rx.recv_spin()
376 }
377
378 pub fn drain(&mut self) -> impl Iterator<Item = ClientEvent> + '_ {
380 self.event_rx.drain()
381 }
382
383 pub async fn wait_event(&mut self) -> Option<ClientEvent> {
388 loop {
389 if let Some(event) = self.event_rx.recv() {
390 return Some(event);
391 }
392 if !self.event_rx.is_connected() {
393 return None;
394 }
395 self.event_notify.notified().await;
396 }
397 }
398
399 #[must_use]
405 pub fn event_notifier(&self) -> Arc<Notify> {
406 Arc::clone(&self.event_notify)
407 }
408}
409
#[derive(Debug)]
/// Command sent from a [`ClientHandle`] to its [`Client`].
pub enum ClientCommand {
    /// Send these bytes to the server over the current session.
    Send(Vec<u8>),
    /// Close the session; the client's `run` then returns `Ok(())`.
    Disconnect,
}
418
#[derive(Debug, Clone)]
/// Event delivered from a [`Client`] to its [`ClientHandle`].
pub enum ClientEvent {
    /// A connection to the server was established.
    Connected,
    /// A connection attempt or an established session failed.
    Disconnected,
    /// Bytes received from the server, copied into an owned buffer.
    Message(Vec<u8>),
    /// An error description.
    ///
    /// Nothing in this module constructs this variant outside of tests.
    Error(String),
}
431
#[cfg(all(test, feature = "tcp-tokio"))]
mod tests {
    use super::*;
    use std::time::Duration;

    type DefaultClientBuilder = ClientBuilder<ironsbe_transport::DefaultTransport>;

    /// Address shared by every test. Nothing in these tests connects to it.
    fn test_addr() -> SocketAddr {
        "127.0.0.1:9000".parse().unwrap()
    }

    #[test]
    fn test_client_builder_new() {
        let builder = DefaultClientBuilder::new(test_addr());
        assert_eq!(builder.server_addr, test_addr());
        assert!(builder.connect_config.is_none());
        assert_eq!(builder.connect_timeout, Duration::from_secs(5));
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_client_builder_with_default_transport() {
        let builder = DefaultClientBuilder::with_default_transport(test_addr());
        assert_eq!(builder.server_addr, test_addr());
        assert!(builder.connect_config.is_none());
    }

    #[test]
    fn test_client_builder_connect_timeout() {
        let builder =
            DefaultClientBuilder::new(test_addr()).connect_timeout(Duration::from_secs(10));
        assert_eq!(builder.connect_timeout, Duration::from_secs(10));
        // Only the client-side timeout changes; no transport config is created.
        assert!(builder.connect_config.is_none());
    }

    #[test]
    fn test_client_builder_tcp_connect_timeout() {
        let builder =
            DefaultClientBuilder::new(test_addr()).tcp_connect_timeout(Duration::from_secs(2));
        assert_eq!(builder.connect_timeout, Duration::from_secs(2));
        assert!(builder.connect_config.is_some());
    }

    #[test]
    fn test_client_builder_max_frame_size() {
        let builder = DefaultClientBuilder::new(test_addr()).max_frame_size(1024);
        assert!(builder.connect_config.is_some());
    }

    #[test]
    fn test_client_builder_reconnect() {
        let builder = DefaultClientBuilder::new(test_addr()).reconnect(true);
        assert!(builder.reconnect_config.enabled);
        let builder = builder.reconnect(false);
        assert!(!builder.reconnect_config.enabled);
    }

    #[test]
    fn test_client_builder_reconnect_delay() {
        let builder =
            DefaultClientBuilder::new(test_addr()).reconnect_delay(Duration::from_millis(500));
        assert_eq!(
            builder.reconnect_config.initial_delay,
            Duration::from_millis(500)
        );
    }

    #[test]
    fn test_client_builder_max_reconnect_attempts() {
        let builder = DefaultClientBuilder::new(test_addr()).max_reconnect_attempts(5);
        assert_eq!(builder.reconnect_config.max_attempts, 5);
    }

    #[test]
    fn test_client_builder_channel_capacity() {
        let builder = DefaultClientBuilder::new(test_addr()).channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_client_builder_build() {
        let (client, _handle) = DefaultClientBuilder::new(test_addr()).build();
        assert_eq!(client.server_addr, test_addr());
        assert_eq!(client.connect_timeout, Duration::from_secs(5));
        // `build` always resolves a connect config, deriving one if none was set.
        assert!(client.connect_config.is_some());
    }

    #[test]
    fn test_client_command_debug() {
        let cmd = ClientCommand::Send(vec![1, 2, 3]);
        assert!(format!("{cmd:?}").contains("Send"));

        let cmd2 = ClientCommand::Disconnect;
        assert!(format!("{cmd2:?}").contains("Disconnect"));
    }

    #[test]
    fn test_client_event_clone_debug() {
        let event = ClientEvent::Connected;
        assert!(matches!(event.clone(), ClientEvent::Connected));
        assert!(format!("{event:?}").contains("Connected"));

        let event2 = ClientEvent::Message(vec![1, 2, 3]);
        match event2.clone() {
            ClientEvent::Message(bytes) => assert_eq!(bytes, vec![1, 2, 3]),
            other => panic!("unexpected event: {other:?}"),
        }
        assert!(format!("{event2:?}").contains("Message"));

        let event3 = ClientEvent::Error("test error".to_string());
        assert!(format!("{event3:?}").contains("Error"));
    }

    #[test]
    fn test_client_handle_send_enqueues_command() {
        let (mut client, mut handle) = DefaultClientBuilder::new(test_addr()).build();
        assert!(handle.send(vec![1, 2, 3]).is_ok());
        match client.cmd_rx.recv() {
            Some(ClientCommand::Send(bytes)) => assert_eq!(bytes, vec![1, 2, 3]),
            other => panic!("unexpected command: {other:?}"),
        }
    }

    #[test]
    fn test_client_handle_disconnect() {
        let (mut client, mut handle) = DefaultClientBuilder::new(test_addr()).build();
        handle.disconnect();
        assert!(matches!(
            client.cmd_rx.recv(),
            Some(ClientCommand::Disconnect)
        ));
    }

    #[test]
    fn test_client_handle_poll() {
        let (_client, mut handle) = DefaultClientBuilder::new(test_addr()).build();
        assert!(handle.poll().is_none());
        assert_eq!(handle.drain().count(), 0);
    }
}