use crate::error::ClientError;
use crate::reconnect::{ReconnectConfig, ReconnectState};
use crate::session::ClientSession;
use ironsbe_channel::spsc;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
/// Builder for a connected [`Client`] / [`ClientHandle`] pair.
///
/// Unless overridden, a builder created with [`ClientBuilder::new`] uses a
/// 5 second connect timeout, `ReconnectConfig::default()` as its reconnect
/// policy, and a capacity of 4096 for both the command and the event queue.
pub struct ClientBuilder {
    /// Address the client connects to.
    server_addr: SocketAddr,
    /// Upper bound applied to each TCP connect attempt.
    connect_timeout: Duration,
    /// Reconnect policy handed to `ReconnectState::new` in [`ClientBuilder::build`].
    reconnect_config: ReconnectConfig,
    /// Capacity passed to `spsc::channel` for both the command and event queues.
    channel_capacity: usize,
}

impl ClientBuilder {
    /// Connect timeout used when none is configured.
    const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
    /// Queue capacity used when none is configured.
    const DEFAULT_CHANNEL_CAPACITY: usize = 4096;

    /// Creates a builder targeting `server_addr` with default settings.
    #[must_use]
    pub fn new(server_addr: SocketAddr) -> Self {
        let reconnect_config = ReconnectConfig::default();
        Self {
            server_addr,
            reconnect_config,
            connect_timeout: Self::DEFAULT_CONNECT_TIMEOUT,
            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
        }
    }

    /// Sets the timeout applied to every TCP connect attempt.
    #[must_use]
    pub fn connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Enables or disables automatic reconnection (`ReconnectConfig::enabled`).
    #[must_use]
    pub fn reconnect(self, enabled: bool) -> Self {
        self.with_reconnect_config(|cfg| cfg.enabled = enabled)
    }

    /// Sets the first reconnect delay (`ReconnectConfig::initial_delay`).
    #[must_use]
    pub fn reconnect_delay(self, delay: Duration) -> Self {
        self.with_reconnect_config(|cfg| cfg.initial_delay = delay)
    }

    /// Sets the reconnect attempt limit (`ReconnectConfig::max_attempts`).
    #[must_use]
    pub fn max_reconnect_attempts(self, max: usize) -> Self {
        self.with_reconnect_config(|cfg| cfg.max_attempts = max)
    }

    /// Sets the capacity of both the command and the event queue.
    #[must_use]
    pub fn channel_capacity(self, capacity: usize) -> Self {
        Self {
            channel_capacity: capacity,
            ..self
        }
    }

    /// Applies `update` to the reconnect configuration and hands the builder back.
    fn with_reconnect_config(mut self, update: impl FnOnce(&mut ReconnectConfig)) -> Self {
        update(&mut self.reconnect_config);
        self
    }

    /// Consumes the builder and creates the client driver plus its handle.
    ///
    /// The two halves share a command queue (handle -> client) and an event
    /// queue (client -> handle), each paired with a `Notify` used for wake-ups.
    #[must_use]
    pub fn build(self) -> (Client, ClientHandle) {
        let Self {
            server_addr,
            connect_timeout,
            reconnect_config,
            channel_capacity,
        } = self;

        let (cmd_tx, cmd_rx) = spsc::channel(channel_capacity);
        let (event_tx, event_rx) = spsc::channel(channel_capacity);
        let cmd_notify = Arc::new(Notify::new());
        let event_notify = Arc::new(Notify::new());

        let handle = ClientHandle {
            cmd_tx,
            event_rx,
            cmd_notify: Arc::clone(&cmd_notify),
            event_notify: Arc::clone(&event_notify),
        };
        let client = Client {
            server_addr,
            connect_timeout,
            reconnect_state: ReconnectState::new(reconnect_config),
            cmd_rx,
            event_tx,
            cmd_notify,
            event_notify,
        };
        (client, handle)
    }
}
pub struct Client {
server_addr: SocketAddr,
connect_timeout: Duration,
reconnect_state: ReconnectState,
cmd_rx: spsc::SpscReceiver<ClientCommand>,
event_tx: spsc::SpscSender<ClientEvent>,
cmd_notify: Arc<Notify>,
event_notify: Arc<Notify>,
}
impl Client {
pub async fn run(&mut self) -> Result<(), ClientError> {
loop {
match self.connect_and_run().await {
Ok(()) => {
return Ok(());
}
Err(e) => {
tracing::error!("Connection error: {:?}", e);
if let Some(delay) = self.reconnect_state.on_failure() {
let _ = self.event_tx.send(ClientEvent::Disconnected);
self.event_notify.notify_one();
tracing::info!("Reconnecting in {:?}...", delay);
tokio::time::sleep(delay).await;
} else {
tracing::error!("Max reconnect attempts reached");
return Err(ClientError::MaxReconnectAttempts);
}
}
}
}
}
async fn connect_and_run(&mut self) -> Result<(), ClientError> {
let stream = tokio::time::timeout(
self.connect_timeout,
tokio::net::TcpStream::connect(self.server_addr),
)
.await
.map_err(|_| ClientError::ConnectTimeout)?
.map_err(ClientError::Io)?;
stream.set_nodelay(true)?;
self.reconnect_state.on_success();
let _ = self.event_tx.send(ClientEvent::Connected);
self.event_notify.notify_one();
tracing::info!("Connected to {}", self.server_addr);
let mut session = ClientSession::new(stream);
loop {
tokio::select! {
_ = self.cmd_notify.notified() => {
while let Some(cmd) = self.cmd_rx.recv() {
match cmd {
ClientCommand::Send(msg) => {
session.send(&msg).await?;
}
ClientCommand::Disconnect => {
return Ok(());
}
}
}
}
result = session.recv() => {
match result {
Ok(Some(msg)) => {
let _ = self.event_tx.send(ClientEvent::Message(msg.to_vec()));
self.event_notify.notify_one();
}
Ok(None) => {
return Err(ClientError::ConnectionClosed);
}
Err(e) => {
return Err(ClientError::Io(e));
}
}
}
}
}
}
}
pub struct ClientHandle {
cmd_tx: spsc::SpscSender<ClientCommand>,
event_rx: spsc::SpscReceiver<ClientEvent>,
cmd_notify: Arc<Notify>,
event_notify: Arc<Notify>,
}
impl ClientHandle {
#[inline]
pub fn send(&mut self, message: Vec<u8>) -> Result<(), ClientError> {
self.cmd_tx
.send(ClientCommand::Send(message))
.map_err(|_| ClientError::Channel)?;
self.cmd_notify.notify_one();
Ok(())
}
pub fn disconnect(&mut self) {
let _ = self.cmd_tx.send(ClientCommand::Disconnect);
self.cmd_notify.notify_one();
}
#[inline]
pub fn poll(&mut self) -> Option<ClientEvent> {
self.event_rx.recv()
}
#[inline]
pub fn poll_spin(&mut self) -> ClientEvent {
self.event_rx.recv_spin()
}
pub fn drain(&mut self) -> impl Iterator<Item = ClientEvent> + '_ {
self.event_rx.drain()
}
pub async fn wait_event(&mut self) -> Option<ClientEvent> {
loop {
if let Some(event) = self.event_rx.recv() {
return Some(event);
}
if !self.event_rx.is_connected() {
return None;
}
self.event_notify.notified().await;
}
}
#[must_use]
pub fn event_notifier(&self) -> Arc<Notify> {
Arc::clone(&self.event_notify)
}
}
/// Command sent from a [`ClientHandle`] to its [`Client`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientCommand {
    /// Write these bytes to the server session.
    Send(Vec<u8>),
    /// Close the connection and return from `Client::run`.
    Disconnect,
}
/// Event sent from a [`Client`] to its [`ClientHandle`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientEvent {
    /// A TCP connection to the server was established.
    Connected,
    /// A connection attempt failed or an established connection was lost.
    Disconnected,
    /// A message received from the server, copied into an owned buffer.
    Message(Vec<u8>),
    /// An error description.
    Error(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Address used by every test; nothing actually connects to it.
    fn addr() -> SocketAddr {
        "127.0.0.1:9000".parse().unwrap()
    }

    #[test]
    fn test_client_builder_new() {
        let builder = ClientBuilder::new(addr());
        assert_eq!(builder.server_addr, addr());
        assert_eq!(builder.connect_timeout, Duration::from_secs(5));
        assert_eq!(builder.channel_capacity, 4096);
    }

    #[test]
    fn test_client_builder_connect_timeout() {
        let builder = ClientBuilder::new(addr()).connect_timeout(Duration::from_secs(10));
        assert_eq!(builder.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_client_builder_reconnect() {
        assert!(ClientBuilder::new(addr()).reconnect(true).reconnect_config.enabled);
        assert!(!ClientBuilder::new(addr()).reconnect(false).reconnect_config.enabled);
    }

    #[test]
    fn test_client_builder_reconnect_delay() {
        let builder = ClientBuilder::new(addr()).reconnect_delay(Duration::from_millis(500));
        assert_eq!(
            builder.reconnect_config.initial_delay,
            Duration::from_millis(500)
        );
    }

    #[test]
    fn test_client_builder_max_reconnect_attempts() {
        let builder = ClientBuilder::new(addr()).max_reconnect_attempts(5);
        assert_eq!(builder.reconnect_config.max_attempts, 5);
    }

    #[test]
    fn test_client_builder_channel_capacity() {
        let builder = ClientBuilder::new(addr()).channel_capacity(8192);
        assert_eq!(builder.channel_capacity, 8192);
    }

    #[test]
    fn test_client_builder_build() {
        let (client, _handle) = ClientBuilder::new(addr())
            .connect_timeout(Duration::from_secs(2))
            .build();
        assert_eq!(client.server_addr, addr());
        assert_eq!(client.connect_timeout, Duration::from_secs(2));
    }

    #[test]
    fn test_client_command_debug() {
        let cmd = ClientCommand::Send(vec![1, 2, 3]);
        let debug_str = format!("{:?}", cmd);
        assert!(debug_str.contains("Send"));
        let cmd2 = ClientCommand::Disconnect;
        let debug_str2 = format!("{:?}", cmd2);
        assert!(debug_str2.contains("Disconnect"));
    }

    #[test]
    fn test_client_event_clone_debug() {
        let event = ClientEvent::Connected;
        let cloned = event.clone();
        assert!(matches!(cloned, ClientEvent::Connected));
        let debug_str = format!("{:?}", event);
        assert!(debug_str.contains("Connected"));
        let event2 = ClientEvent::Message(vec![1, 2, 3]);
        let debug_str2 = format!("{:?}", event2);
        assert!(debug_str2.contains("Message"));
        let event3 = ClientEvent::Error("test error".to_string());
        let debug_str3 = format!("{:?}", event3);
        assert!(debug_str3.contains("Error"));
    }

    #[test]
    fn test_client_handle_send_reaches_client() {
        let (mut client, mut handle) = ClientBuilder::new(addr()).build();
        assert!(handle.send(vec![1, 2, 3]).is_ok());
        match client.cmd_rx.recv() {
            Some(ClientCommand::Send(msg)) => assert_eq!(msg, vec![1, 2, 3]),
            other => panic!("unexpected command: {:?}", other),
        }
        assert!(client.cmd_rx.recv().is_none());
    }

    #[test]
    fn test_client_handle_disconnect() {
        let (mut client, mut handle) = ClientBuilder::new(addr()).build();
        handle.disconnect();
        assert!(matches!(
            client.cmd_rx.recv(),
            Some(ClientCommand::Disconnect)
        ));
    }

    #[test]
    fn test_client_handle_poll() {
        let (_client, mut handle) = ClientBuilder::new(addr()).build();
        assert!(handle.poll().is_none());
    }

    #[test]
    fn test_client_handle_poll_receives_event() {
        let (mut client, mut handle) = ClientBuilder::new(addr()).build();
        assert!(client.event_tx.send(ClientEvent::Connected).is_ok());
        assert!(matches!(handle.poll(), Some(ClientEvent::Connected)));
        assert!(handle.poll().is_none());
    }
}