use crate::error::Result;
use crate::io::reconnect::ReconnectConfig;
use crate::io::sync_client::SyncTcpClient;
use crate::io::unified_async_client::UnifiedAsyncClient;
use crate::io::unified_client::{AsyncIgtlClient, SyncIgtlClient};
use crate::io::UdpClient;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio_rustls::rustls;
/// Placeholder state for a builder type parameter that has not been chosen yet.
///
/// It is the default for both the `Protocol` and `Mode` parameters of
/// `ClientBuilder`, so a builder from `ClientBuilder::new()` has neither selected.
pub struct Unspecified;
/// Protocol state recorded by `ClientBuilder::tcp`: the builder will produce a TCP client.
// NOTE(review): `addr` is read by both TCP `build` methods in this file, so this
// `dead_code` allow looks unnecessary — confirm no cfg/feature combination needs it
// before removing it.
#[allow(dead_code)]
pub struct TcpConfigured {
/// Address string given to `tcp()`; handed to the connect calls when `build` runs.
pub(crate) addr: String,
}
/// Protocol state recorded by `ClientBuilder::udp`: the builder will produce a `UdpClient`.
// NOTE(review): `addr` is read by the UDP `build` method in this file, so this
// `dead_code` allow looks unnecessary — confirm before removing it.
#[allow(dead_code)]
pub struct UdpConfigured {
/// Address string given to `udp()`; passed to `UdpClient::bind` when `build` runs
/// (presumably a local bind address — confirm against `UdpClient::bind`).
pub(crate) addr: String,
}
/// Mode marker: the builder will produce a blocking (`SyncIgtlClient` / `UdpClient`) client.
pub struct SyncMode;
/// Mode marker: the builder will produce an `AsyncIgtlClient` whose `build` is `async`.
pub struct AsyncMode;
/// Typestate builder for IGTL clients.
///
/// `Protocol` records the chosen transport (`Unspecified`, `TcpConfigured`,
/// `UdpConfigured`) and `Mode` records the I/O mode (`Unspecified`, `SyncMode`,
/// `AsyncMode`). Each state exposes only the methods valid for it, so `build` is
/// available only once a complete combination has been chosen.
///
/// Every builder method consumes `self` and returns a new builder. Dropping the
/// result is therefore always a mistake, hence `#[must_use]`.
#[must_use = "a ClientBuilder does nothing until `build()` is called"]
pub struct ClientBuilder<Protocol = Unspecified, Mode = Unspecified> {
    /// Transport state; for TCP/UDP this carries the address given by the caller.
    protocol: Protocol,
    /// Zero-sized marker for the selected I/O mode.
    mode: PhantomData<Mode>,
    /// Optional TLS configuration; it can only be set on async TCP builders via `with_tls`.
    tls_config: Option<Arc<rustls::ClientConfig>>,
    /// Optional reconnect policy; it can only be set on async TCP builders via `with_reconnect`.
    reconnect_config: Option<ReconnectConfig>,
    /// Whether built clients verify message CRCs (`true` unless changed with `verify_crc`).
    verify_crc: bool,
}
impl ClientBuilder {
    /// Starts a new builder with neither a transport nor a mode selected.
    ///
    /// No TLS configuration or reconnect policy is set, and CRC verification
    /// starts out enabled.
    pub fn new() -> Self {
        let verify_crc = true;
        Self {
            verify_crc,
            reconnect_config: None,
            tls_config: None,
            mode: PhantomData,
            protocol: Unspecified,
        }
    }
}
impl Default for ClientBuilder<Unspecified, Unspecified> {
fn default() -> Self {
Self::new()
}
}
impl ClientBuilder {
    /// Selects TCP to `addr`. The I/O mode is still open: call `sync()` or
    /// `async_mode()` next.
    pub fn tcp(self, addr: impl Into<String>) -> ClientBuilder<TcpConfigured, Unspecified> {
        let protocol = TcpConfigured { addr: addr.into() };
        self.into_protocol(protocol)
    }

    /// Selects UDP on `addr`. UDP has only a blocking client, so the builder
    /// moves straight to `SyncMode`.
    pub fn udp(self, addr: impl Into<String>) -> ClientBuilder<UdpConfigured, SyncMode> {
        let protocol = UdpConfigured { addr: addr.into() };
        self.into_protocol(protocol)
    }

    /// Rebuilds this builder in a new protocol/mode state, carrying the TLS,
    /// reconnect and CRC settings over unchanged.
    fn into_protocol<P, M>(self, protocol: P) -> ClientBuilder<P, M> {
        let ClientBuilder {
            tls_config,
            reconnect_config,
            verify_crc,
            ..
        } = self;
        ClientBuilder {
            protocol,
            mode: PhantomData,
            tls_config,
            reconnect_config,
            verify_crc,
        }
    }
}
impl ClientBuilder<TcpConfigured, Unspecified> {
    /// Chooses the blocking TCP client; `build()` then returns a `SyncIgtlClient`.
    pub fn sync(self) -> ClientBuilder<TcpConfigured, SyncMode> {
        self.into_mode()
    }

    /// Chooses the async TCP client. This unlocks `with_tls` and `with_reconnect`,
    /// and `build()` becomes an `async fn` returning an `AsyncIgtlClient`.
    pub fn async_mode(self) -> ClientBuilder<TcpConfigured, AsyncMode> {
        self.into_mode()
    }

    /// Re-tags the builder with mode `M`; every stored setting is moved across as-is.
    fn into_mode<M>(self) -> ClientBuilder<TcpConfigured, M> {
        let ClientBuilder {
            protocol,
            tls_config,
            reconnect_config,
            verify_crc,
            ..
        } = self;
        ClientBuilder {
            protocol,
            mode: PhantomData,
            tls_config,
            reconnect_config,
            verify_crc,
        }
    }
}
impl ClientBuilder<TcpConfigured, SyncMode> {
pub fn build(self) -> Result<SyncIgtlClient> {
let mut client = SyncTcpClient::connect(&self.protocol.addr)?;
client.set_verify_crc(self.verify_crc);
Ok(SyncIgtlClient::TcpSync(client))
}
}
impl ClientBuilder<TcpConfigured, AsyncMode> {
pub fn with_tls(mut self, config: Arc<rustls::ClientConfig>) -> Self {
self.tls_config = Some(config);
self
}
pub fn with_reconnect(mut self, config: ReconnectConfig) -> Self {
self.reconnect_config = Some(config);
self
}
pub async fn build(self) -> Result<AsyncIgtlClient> {
let addr = self.protocol.addr;
let mut client = if let Some(tls_config) = self.tls_config {
let (hostname, port) = parse_addr(&addr)?;
UnifiedAsyncClient::connect_with_tls(&hostname, port, tls_config).await?
} else {
UnifiedAsyncClient::connect(&addr).await?
};
if let Some(reconnect_config) = self.reconnect_config {
client = client.with_reconnect(reconnect_config);
}
client.set_verify_crc(self.verify_crc);
Ok(AsyncIgtlClient::Unified(client))
}
}
impl<Protocol, Mode> ClientBuilder<Protocol, Mode> {
    /// Sets whether built clients verify message CRCs (enabled by default).
    ///
    /// This is available in every builder state, and the value survives the
    /// `tcp`, `udp`, `sync` and `async_mode` transitions. Both TCP `build` methods
    /// apply it; the UDP `build` method does not pass it to `UdpClient`.
    pub fn verify_crc(self, verify: bool) -> Self {
        Self {
            verify_crc: verify,
            ..self
        }
    }
}
fn parse_addr(addr: &str) -> Result<(String, u16)> {
let parts: Vec<&str> = addr.rsplitn(2, ':').collect();
if parts.len() != 2 {
return Err(crate::error::IgtlError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Invalid address format: {}", addr),
)));
}
let port = parts[0].parse::<u16>().map_err(|e| {
crate::error::IgtlError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Invalid port number: {}", e),
))
})?;
let hostname = parts[1].to_string();
if hostname.is_empty() {
return Err(crate::error::IgtlError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Hostname cannot be empty",
)));
}
Ok((hostname, port))
}
impl ClientBuilder<UdpConfigured, SyncMode> {
pub fn build(self) -> Result<UdpClient> {
UdpClient::bind(&self.protocol.addr)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mode marker adds no size; choosing a protocol with an address does.
    #[test]
    fn test_phantom_data_is_zero_size() {
        use std::mem::size_of;
        assert_eq!(size_of::<PhantomData<SyncMode>>(), 0);
        assert_eq!(size_of::<PhantomData<AsyncMode>>(), 0);
        let base_size = size_of::<ClientBuilder<Unspecified, Unspecified>>();
        let tcp_unspecified = size_of::<ClientBuilder<TcpConfigured, Unspecified>>();
        let tcp_sync = size_of::<ClientBuilder<TcpConfigured, SyncMode>>();
        let tcp_async = size_of::<ClientBuilder<TcpConfigured, AsyncMode>>();
        assert_eq!(tcp_unspecified, tcp_sync);
        assert_eq!(tcp_unspecified, tcp_async);
        assert!(tcp_unspecified > base_size);
    }

    /// Every legal state transition compiles and can be constructed.
    #[test]
    fn test_builder_state_transitions() {
        let builder = ClientBuilder::new();
        let builder = builder.tcp("127.0.0.1:18944");
        let _sync_builder = builder.sync();
        let builder = ClientBuilder::new().tcp("127.0.0.1:18944");
        let _async_builder = builder.async_mode();
        let _udp_builder = ClientBuilder::new().udp("127.0.0.1:18944");
    }

    /// Valid `host:port` strings split correctly; malformed ones are rejected.
    #[test]
    fn test_parse_addr() {
        assert_eq!(
            parse_addr("localhost:18944").unwrap(),
            ("localhost".to_string(), 18944)
        );
        assert_eq!(
            parse_addr("127.0.0.1:8080").unwrap(),
            ("127.0.0.1".to_string(), 8080)
        );
        assert_eq!(
            parse_addr("example.com:443").unwrap(),
            ("example.com".to_string(), 443)
        );
        assert!(parse_addr("invalid").is_err());
        assert!(parse_addr("localhost:").is_err());
        assert!(parse_addr(":18944").is_err());
        assert!(parse_addr("localhost:abc").is_err());
    }

    /// The setters store their values on the builder.
    #[test]
    fn test_builder_options() {
        let builder = ClientBuilder::new()
            .tcp("127.0.0.1:18944")
            .sync()
            .verify_crc(false);
        assert!(!builder.verify_crc);
        let tls_config = Arc::new(
            rustls::ClientConfig::builder()
                .with_root_certificates(rustls::RootCertStore::empty())
                .with_no_client_auth(),
        );
        let builder = ClientBuilder::new()
            .tcp("127.0.0.1:18944")
            .async_mode()
            .with_tls(tls_config.clone());
        assert!(builder.tls_config.is_some());
        let builder = ClientBuilder::new()
            .tcp("127.0.0.1:18944")
            .async_mode()
            .with_reconnect(ReconnectConfig::default());
        assert!(builder.reconnect_config.is_some());
    }

    /// Defaults match between `new()` and `Default`, and `verify_crc` and the
    /// address survive every typestate transition.
    #[test]
    fn test_defaults_and_setting_propagation() {
        let fresh: ClientBuilder = ClientBuilder::new();
        assert!(fresh.verify_crc);
        assert!(fresh.tls_config.is_none());
        assert!(fresh.reconnect_config.is_none());

        let via_default: ClientBuilder = ClientBuilder::default();
        assert!(via_default.verify_crc);
        assert!(via_default.tls_config.is_none());
        assert!(via_default.reconnect_config.is_none());

        let tcp = ClientBuilder::new()
            .verify_crc(false)
            .tcp("127.0.0.1:18944");
        assert!(!tcp.verify_crc);
        assert_eq!(tcp.protocol.addr, "127.0.0.1:18944");
        let sync = tcp.sync();
        assert!(!sync.verify_crc);
        assert_eq!(sync.protocol.addr, "127.0.0.1:18944");

        let async_builder = ClientBuilder::new()
            .verify_crc(false)
            .tcp("127.0.0.1:18944")
            .async_mode()
            .with_reconnect(ReconnectConfig::default());
        assert!(!async_builder.verify_crc);
        assert!(async_builder.reconnect_config.is_some());
        assert!(async_builder.tls_config.is_none());

        let udp = ClientBuilder::new()
            .verify_crc(false)
            .udp("127.0.0.1:18944");
        assert!(!udp.verify_crc);
        assert_eq!(udp.protocol.addr, "127.0.0.1:18944");
    }
}