mod private
{
use serde::{ Deserialize, Serialize };
use core::time::Duration;
use std::sync::{ Arc, Mutex };
use tokio::sync::{ mpsc, broadcast };
/// Lifecycle state of a [`WebSocketConnection`].
///
/// Every transition made by this module is also published on the connection's
/// broadcast channel (see [`WebSocketConnection::subscribe_state_changes`]).
#[ derive( Clone, Debug, Eq, PartialEq, Serialize, Deserialize ) ]
pub enum WebSocketConnectionState
{
  /// Initial state of a freshly created connection; also published by `connect`.
  Connecting,
  /// Set by `connect`; `send_message` only succeeds in this state.
  Connected,
  /// Not assigned by any code in this module.
  Disconnected,
  /// Published by `close` before the close frame is queued.
  Closing,
  /// Published by `close` once the close frame has been queued.
  Closed,
  /// Not assigned by any code in this module.
  Error,
}
/// Settings for a single WebSocket connection.
///
/// Construct via `Default` or, for validated values, via [`WebSocketConfig::builder`].
#[ derive( Clone, Debug, PartialEq, Serialize, Deserialize ) ]
pub struct WebSocketConfig
{
  /// Heartbeat interval (default 30 s); the builder rejects a zero duration.
  pub heartbeat_interval : Duration,
  /// Maximum message size in bytes (default 1 MiB); the builder requires 1..=5 MiB.
  pub max_message_size : usize,
  /// Compression flag (default `true`). NOTE(review): not read in this module.
  pub enable_compression : bool,
  /// Reconnect attempt budget (default 3); the builder rejects 0.
  pub reconnect_attempts : u32,
  /// Connection timeout (default 10 s); the builder rejects a zero duration.
  pub connection_timeout : Duration,
  /// HTTP fallback flag (default `true`); mirrored by `WebSocketStreamBuilder`.
  pub fallback_to_http : bool,
}
impl Default for WebSocketConfig
{
  /// Default settings: 30 s heartbeat, 1 MiB message cap, compression on,
  /// 3 reconnect attempts, 10 s connection timeout and HTTP fallback enabled.
  #[ inline ]
  fn default() -> Self
  {
    const ONE_MIB : usize = 1 << 20;
    let heartbeat_interval = Duration::from_secs( 30 );
    let connection_timeout = Duration::from_secs( 10 );
    Self
    {
      heartbeat_interval,
      max_message_size : ONE_MIB,
      enable_compression : true,
      reconnect_attempts : 3,
      connection_timeout,
      fallback_to_http : true,
    }
  }
}
/// Fluent builder for [`WebSocketConfig`] that validates values on `build`.
///
/// Starts from [`WebSocketConfig::default`].
#[ derive( Clone, Debug ) ]
pub struct WebSocketConfigBuilder
{
  /// Configuration accumulated so far.
  config : WebSocketConfig,
}

impl Default for WebSocketConfigBuilder
{
  /// Creates a builder seeded with the default configuration.
  #[ inline ]
  fn default() -> Self
  {
    let config = WebSocketConfig::default();
    Self { config }
  }
}
impl WebSocketConfigBuilder
{
#[ inline ]
#[ must_use ]
pub fn new() -> Self
{
Self {
config : WebSocketConfig::default(),
}
}
#[ inline ]
#[ must_use ]
pub fn heartbeat_interval( mut self, interval : Duration ) -> Self
{
self.config.heartbeat_interval = interval;
self
}
#[ inline ]
#[ must_use ]
pub fn max_message_size( mut self, size : usize ) -> Self
{
self.config.max_message_size = size;
self
}
#[ inline ]
#[ must_use ]
pub fn enable_compression( mut self, enable : bool ) -> Self
{
self.config.enable_compression = enable;
self
}
#[ inline ]
#[ must_use ]
pub fn reconnect_attempts( mut self, attempts : u32 ) -> Self
{
self.config.reconnect_attempts = attempts;
self
}
#[ inline ]
#[ must_use ]
pub fn connection_timeout( mut self, timeout : Duration ) -> Self
{
self.config.connection_timeout = timeout;
self
}
#[ inline ]
#[ must_use ]
pub fn fallback_to_http( mut self, fallback : bool ) -> Self
{
self.config.fallback_to_http = fallback;
self
}
#[ inline ]
pub fn build( self ) -> Result< WebSocketConfig, crate::error::Error >
{
if self.config.heartbeat_interval.is_zero()
{
return Err( crate::error::Error::ConfigurationError(
"Heartbeat interval must be greater than 0".to_string()
) );
}
if self.config.max_message_size == 0
{
return Err( crate::error::Error::ConfigurationError(
"Max message size must be greater than 0".to_string()
) );
}
if self.config.max_message_size > 5 * 1024 * 1024 {
return Err( crate::error::Error::ConfigurationError(
"Max message size cannot exceed 5MB".to_string()
) );
}
if self.config.reconnect_attempts == 0
{
return Err( crate::error::Error::ConfigurationError(
"Reconnect attempts must be greater than 0".to_string()
) );
}
if self.config.connection_timeout.is_zero()
{
return Err( crate::error::Error::ConfigurationError(
"Connection timeout must be greater than 0".to_string()
) );
}
Ok( self.config )
}
}
impl WebSocketConfig
{
  /// Returns a [`WebSocketConfigBuilder`] seeded with the default settings.
  #[ inline ]
  #[ must_use ]
  pub fn builder() -> WebSocketConfigBuilder
  {
    WebSocketConfigBuilder::default()
  }
}
/// Settings for a pool of WebSocket connections.
#[ derive( Clone, Debug ) ]
pub struct WebSocketPoolConfig
{
  /// Maximum connection count (default 10); `WebSocketPoolConfigBuilder::build` rejects 0.
  pub max_connections : usize,
  /// Connection timeout (default 30 s).
  pub connection_timeout : Duration,
  /// Idle timeout (default 5 minutes).
  pub idle_timeout : Duration,
}

impl Default for WebSocketPoolConfig
{
  /// Ten connections, 30 s connection timeout, 5 minute idle timeout.
  #[ inline ]
  fn default() -> Self
  {
    let idle_timeout = Duration::from_secs( 5 * 60 );
    Self
    {
      max_connections : 10,
      connection_timeout : Duration::from_secs( 30 ),
      idle_timeout,
    }
  }
}
/// Fluent builder for [`WebSocketPoolConfig`] that validates values on `build`.
///
/// Starts from [`WebSocketPoolConfig::default`].
#[ derive( Debug, Clone ) ]
pub struct WebSocketPoolConfigBuilder
{
  /// Configuration accumulated so far.
  config : WebSocketPoolConfig,
}

// Mirrors `WebSocketConfigBuilder`: a public zero-argument `new` should be
// paired with `Default` (clippy::new_without_default).
impl Default for WebSocketPoolConfigBuilder
{
  #[ inline ]
  fn default() -> Self
  {
    Self::new()
  }
}

impl WebSocketPoolConfigBuilder
{
  /// Starts a builder from [`WebSocketPoolConfig::default`].
  #[ inline ]
  #[ must_use ]
  pub fn new() -> Self
  {
    Self {
      config : WebSocketPoolConfig::default(),
    }
  }

  /// Sets the maximum connection count; `build` rejects 0.
  #[ inline ]
  #[ must_use ]
  pub fn max_connections( mut self, max : usize ) -> Self
  {
    self.config.max_connections = max;
    self
  }

  /// Sets the connection timeout.
  #[ inline ]
  #[ must_use ]
  pub fn connection_timeout( mut self, timeout : Duration ) -> Self
  {
    self.config.connection_timeout = timeout;
    self
  }

  /// Sets the idle timeout.
  #[ inline ]
  #[ must_use ]
  pub fn idle_timeout( mut self, timeout : Duration ) -> Self
  {
    self.config.idle_timeout = timeout;
    self
  }

  /// Validates the accumulated settings and returns the pool configuration.
  ///
  /// # Errors
  ///
  /// Returns `ConfigurationError` when `max_connections` is 0.
  #[ inline ]
  pub fn build( self ) -> Result< WebSocketPoolConfig, crate::error::Error >
  {
    if self.config.max_connections == 0
    {
      return Err( crate::error::Error::ConfigurationError(
        "Max connections must be greater than 0".to_string()
      ) );
    }
    Ok( self.config )
  }
}
impl WebSocketPoolConfig
{
  /// Returns a [`WebSocketPoolConfigBuilder`] seeded with the default pool settings.
  ///
  /// Attributes match `WebSocketConfig::builder`.
  #[ inline ]
  #[ must_use ]
  pub fn builder() -> WebSocketPoolConfigBuilder
  {
    WebSocketPoolConfigBuilder::new()
  }
}
/// A message exchanged over a [`WebSocketConnection`].
#[ derive( Clone, Debug ) ]
pub enum WebSocketMessage
{
  /// Text payload.
  Text( String ),
  /// Raw byte payload.
  Binary( Vec< u8 > ),
  /// Ping frame with its payload bytes.
  Ping( Vec< u8 > ),
  /// Pong frame with its payload bytes.
  Pong( Vec< u8 > ),
  /// Close frame with an optional reason; `close` sends `Some( "Normal closure" )`.
  Close( Option< String > ),
}
/// Running counters for a [`WebSocketConnection`].
///
/// `Default` yields every counter at zero. It is derived because the former
/// hand-written impl was exactly the derived one (clippy::derivable_impls).
#[ derive( Debug, Clone, Default, Serialize, Deserialize, PartialEq ) ]
pub struct WebSocketMetrics
{
  /// Incremented each time a message is successfully queued for sending.
  pub messages_sent : u64,
  /// Incremented by `receive_message` for each message it returns.
  pub messages_received : u64,
  /// Not updated by any code in this module.
  pub bytes_sent : u64,
  /// Not updated by any code in this module.
  pub bytes_received : u64,
  /// Incremented by the `connect` functions.
  pub connection_count : u32,
  /// Not updated by any code in this module.
  pub reconnection_count : u32,
  /// Not updated by any code in this module.
  pub error_count : u32,
}
/// Handle for a WebSocket connection: lifecycle state, metrics and a message queue.
///
/// NOTE(review): no network I/O happens in this module. `send_message` pushes into
/// `message_tx` and `receive_message` pops from `message_rx`, which are the two ends
/// of the same channel, so sent messages loop back to this handle. Confirm whether
/// this is an intentional stub.
// Field order is also drop order; leave it unchanged.
pub struct WebSocketConnection
{
// Current lifecycle state, shared behind a mutex.
state : Arc< Mutex< WebSocketConnectionState > >,
// Settings this connection was created with.
config : WebSocketConfig,
// Counters updated on send, receive and connect.
metrics : Arc< Mutex< WebSocketMetrics > >,
// Sending half of the internal unbounded message channel.
message_tx : mpsc::UnboundedSender< WebSocketMessage >,
// Receiving half of the same channel (see the type-level note).
message_rx : mpsc::UnboundedReceiver< WebSocketMessage >,
// Publishes state transitions to receivers from `subscribe_state_changes`.
state_tx : broadcast::Sender< WebSocketConnectionState >,
}
impl WebSocketConnection
{
pub fn new( config : WebSocketConfig ) -> Self
{
let ( message_tx, message_rx ) = mpsc::unbounded_channel();
let ( state_tx, _state_rx ) = broadcast::channel( 16 );
Self {
state : Arc::new( Mutex::new( WebSocketConnectionState::Connecting ) ),
config,
metrics : Arc::new( Mutex::new( WebSocketMetrics::default() ) ),
message_tx,
message_rx,
state_tx,
}
}
pub fn state( &self ) -> WebSocketConnectionState
{
self.state.lock().unwrap().clone()
}
pub fn is_connected( &self ) -> bool
{
matches!( self.state(), WebSocketConnectionState::Connected )
}
pub fn get_metrics( &self ) -> WebSocketMetrics
{
self.metrics.lock().unwrap().clone()
}
pub async fn send_message( &self, message : WebSocketMessage ) -> Result< (), crate::error::Error >
{
if !self.is_connected()
{
return Err( crate::error::Error::ApiError(
"WebSocket is not connected".to_string()
) );
}
self.message_tx.send( message )
.map_err( |_| crate::error::Error::ApiError( "Failed to send message".to_string() ) )?;
let mut metrics = self.metrics.lock().unwrap();
metrics.messages_sent += 1;
Ok( () )
}
pub async fn receive_message( &mut self ) -> Option< WebSocketMessage >
{
let message = self.message_rx.recv().await;
if message.is_some()
{
let mut metrics = self.metrics.lock().unwrap();
metrics.messages_received += 1;
}
message
}
pub fn subscribe_state_changes( &self ) -> broadcast::Receiver< WebSocketConnectionState >
{
self.state_tx.subscribe()
}
pub async fn connect( _endpoint : &str, config : WebSocketConfig ) -> Result< Self, crate::error::Error >
{
let connection = Self::new( config );
*connection.state.lock().unwrap() = WebSocketConnectionState::Connecting;
connection.state_tx.send( WebSocketConnectionState::Connecting ).ok();
*connection.state.lock().unwrap() = WebSocketConnectionState::Connected;
connection.state_tx.send( WebSocketConnectionState::Connected ).ok();
{
let mut metrics = connection.metrics.lock().unwrap();
metrics.connection_count += 1;
}
Ok( connection )
}
pub async fn close( &self ) -> Result< (), crate::error::Error >
{
*self.state.lock().unwrap() = WebSocketConnectionState::Closing;
self.state_tx.send( WebSocketConnectionState::Closing ).ok();
self.send_message( WebSocketMessage::Close( Some( "Normal closure".to_string() ) ) ).await?;
*self.state.lock().unwrap() = WebSocketConnectionState::Closed;
self.state_tx.send( WebSocketConnectionState::Closed ).ok();
Ok( () )
}
}
impl std::fmt::Debug for WebSocketConnection
{
  /// Shows the state, config and a metrics snapshot; channel endpoints are omitted.
  fn fmt( &self, formatter : &mut std::fmt::Formatter< '_ > ) -> std::fmt::Result
  {
    let mut out = formatter.debug_struct( "WebSocketConnection" );
    let current_state = self.state();
    out.field( "state", &current_state );
    out.field( "config", &self.config );
    let snapshot = self.get_metrics();
    out.field( "metrics", &snapshot );
    out.finish_non_exhaustive()
  }
}
/// Builder that prepares a [`WebSocketConnection`] for streaming from a model.
///
/// NOTE(review): `connect` currently uses only `config`. `model`, `message` and
/// `reconnect` are stored but not consulted yet.
#[ derive( Debug ) ]
pub struct WebSocketStreamBuilder< 'a >
{
  /// Model API the stream is built for.
  // Not read anywhere yet (`connect` ignores it), hence the lint allowance.
  #[ allow( dead_code ) ]
  model : &'a crate::models::api::ModelApi< 'a >,
  /// Message set by `with_message`.
  message : Option< String >,
  /// Settings passed to the connection on `connect`.
  config : WebSocketConfig,
  /// Explicit keepalive set by `with_keepalive`; cleared when `with_config` replaces the config.
  keepalive : Option< Duration >,
  /// Reconnect preference (default `true`).
  reconnect : bool,
  /// Kept equal to `config.fallback_to_http`.
  fallback_to_http : bool,
}

impl< 'a > WebSocketStreamBuilder< 'a >
{
  /// Creates a builder for `model` with default config, reconnect on and HTTP fallback on.
  #[ must_use ]
  pub fn new( model : &'a crate::models::api::ModelApi< 'a > ) -> Self
  {
    Self {
      model,
      message : None,
      config : WebSocketConfig::default(),
      keepalive : None,
      reconnect : true,
      fallback_to_http : true,
    }
  }

  /// Stores `message` (copied into an owned `String`).
  #[ must_use ]
  pub fn with_message( mut self, message : &str ) -> Self
  {
    self.message = Some( message.to_string() );
    self
  }

  /// Sets the keepalive interval, which also becomes `config.heartbeat_interval`.
  #[ must_use ]
  pub fn with_keepalive( mut self, interval : Duration ) -> Self
  {
    self.keepalive = Some( interval );
    self.config.heartbeat_interval = interval;
    self
  }

  /// Stores the reconnect preference.
  #[ must_use ]
  pub fn with_reconnect( mut self, reconnect : bool ) -> Self
  {
    self.reconnect = reconnect;
    self
  }

  /// Sets the HTTP fallback flag on both the builder and its config.
  #[ must_use ]
  pub fn with_fallback_to_http( mut self, fallback : bool ) -> Self
  {
    self.fallback_to_http = fallback;
    self.config.fallback_to_http = fallback;
    self
  }

  /// Replaces the whole configuration; the last call wins.
  ///
  /// Earlier `with_keepalive` / `with_fallback_to_http` values are superseded by
  /// `config`. The builder's mirror fields are re-synchronised so they never
  /// contradict it.
  #[ must_use ]
  pub fn with_config( mut self, config : WebSocketConfig ) -> Self
  {
    // Previously these mirrors kept their old values and disagreed with `config`.
    self.keepalive = None;
    self.fallback_to_http = config.fallback_to_http;
    self.config = config;
    self
  }

  /// Creates a [`WebSocketConnection`] from `config`, marks it `Connected` and counts it.
  ///
  /// NOTE(review): no network I/O happens. The connection is marked `Connected`
  /// immediately and the `Connecting` state is never published.
  ///
  /// # Errors
  ///
  /// Currently never returns `Err`.
  pub async fn connect( self ) -> Result< WebSocketConnection, crate::error::Error >
  {
    let connection = WebSocketConnection::new( self.config );
    *connection.state.lock().unwrap() = WebSocketConnectionState::Connected;
    // Sending fails only when nobody is subscribed; ignored deliberately.
    connection.state_tx.send( WebSocketConnectionState::Connected ).ok();
    connection.metrics.lock().unwrap().connection_count += 1;
    Ok( connection )
  }
}
}
// Re-exports the WebSocket types defined in `private` at the `exposed` layer of
// this module's `mod_interface` (the layer semantics are defined by the
// mod_interface crate; see its documentation).
::mod_interface::mod_interface!
{
// Connection lifecycle state and per-connection configuration.
exposed use private::WebSocketConnectionState;
exposed use private::WebSocketConfig;
exposed use private::WebSocketConfigBuilder;
// Pool configuration.
exposed use private::WebSocketPoolConfig;
exposed use private::WebSocketPoolConfigBuilder;
// Messages, counters, the connection handle and the streaming builder.
exposed use private::WebSocketMessage;
exposed use private::WebSocketMetrics;
exposed use private::WebSocketConnection;
exposed use private::WebSocketStreamBuilder;
}