#[ cfg( feature = "websocket_streaming" ) ]
mod private
{
use serde::{ Serialize, Deserialize };
use core::time::Duration;
use error_tools::untyped::{ format_err, Result };
use std::sync::{ Arc, Mutex, RwLock };
use std::time::Instant;
use crate::chat::{ ChatRequest, ChatResponse };
use crate::messages::{ ChatMessage, MessageRole };
include!("websocket_types.rs");
impl WebSocketChatStream
{
#[ inline ]
pub async fn next( &mut self ) -> Option< Result< ChatResponse > >
{
Some( Ok( ChatResponse
{
model : Some( self.request.model.clone() ),
created_at : Some( "2024-01-01T00:00:00Z".to_string() ),
message : ChatMessage
{
role : MessageRole::Assistant,
content : "Test response".to_string(),
#[ cfg( feature = "vision_support" ) ]
images : None,
#[ cfg( feature = "tool_calling" ) ]
tool_calls : None,
},
done : true,
done_reason : Some( "stop".to_string() ),
total_duration : Some( 1000 ),
load_duration : Some( 100 ),
prompt_eval_count : Some( 10 ),
prompt_eval_duration : Some( 200 ),
eval_count : Some( 20 ),
eval_duration : Some( 300 ),
} ) )
}
}
impl MessageQueue
{
#[ inline ]
#[ must_use ]
pub fn new( max_size : usize ) -> Self
{
Self
{
queue : Arc::new( Mutex::new( Vec::new() ) ),
max_size,
metrics : Arc::new( RwLock::new( WebSocketMetrics::default() ) ),
}
}
#[ inline ]
pub fn enqueue( &self, message : QueuedMessage ) -> Result< () >
{
let mut queue = self.queue.lock().unwrap();
if queue.len() >= self.max_size
{
return Err( format_err!( "Message queue is full" ) );
}
queue.push( message );
Ok( () )
}
#[ inline ]
#[ must_use ]
pub fn dequeue( &self ) -> Option< QueuedMessage >
{
let mut queue = self.queue.lock().unwrap();
queue.pop()
}
#[ inline ]
#[ must_use ]
pub fn size( &self ) -> usize
{
let queue = self.queue.lock().unwrap();
queue.len()
}
#[ inline ]
#[ must_use ]
pub fn get_metrics( &self ) -> WebSocketMetrics
{
self.metrics.read().unwrap().clone()
}
#[ inline ]
#[ must_use ]
pub fn capacity( &self ) -> usize
{
self.max_size
}
#[ inline ]
#[ must_use ]
pub fn len( &self ) -> usize
{
self.size()
}
#[ inline ]
#[ must_use ]
pub fn is_empty( &self ) -> bool
{
self.size() == 0
}
#[ inline ]
pub fn push( &self, message : &WebSocketMessage ) -> Result< () >
{
let queued_msg = QueuedMessage
{
id : format!( "msg-{}", std::time::SystemTime::now().duration_since( std::time::UNIX_EPOCH ).unwrap().as_nanos() ),
content : format!( "{message:?}" ),
priority : 1,
timestamp : std::time::Instant::now(),
retry_count : 0,
max_retries : 3,
};
self.enqueue( queued_msg )
}
#[ inline ]
#[ must_use ]
pub fn pop( &self ) -> Option< WebSocketMessage >
{
if let Some( queued_msg ) = self.dequeue()
{
if queued_msg.content.contains( "Text(" )
{
let start = queued_msg.content.find( "Text(\"" ).unwrap() + 6;
let end = queued_msg.content.rfind( "\")" ).unwrap();
let text_content = &queued_msg.content[ start..end ];
Some( WebSocketMessage::Text( text_content.to_string() ) )
}
else
{
Some( WebSocketMessage::Text( "Unknown message".to_string() ) )
}
}
else
{
None
}
}
}
impl core::fmt::Display for WebSocketError
{
#[ inline ]
fn fmt( &self, f : &mut core::fmt::Formatter< '_ > ) -> core::fmt::Result
{
match self
{
WebSocketError::ConnectionFailed( msg ) => write!( f, "Connection failed : {msg}" ),
WebSocketError::ProtocolError( msg ) => write!( f, "Protocol error : {msg}" ),
WebSocketError::AuthenticationFailed( msg ) => write!( f, "Authentication failed : {msg}" ),
WebSocketError::CompressionError( msg ) => write!( f, "Compression error : {msg}" ),
WebSocketError::QueueOverflow => write!( f, "Message queue overflow" ),
WebSocketError::HeartbeatTimeout => write!( f, "Heartbeat timeout" ),
WebSocketError::InvalidMessage( msg ) => write!( f, "Invalid message : {msg}" ),
WebSocketError::PoolExhausted => write!( f, "Connection pool exhausted" ),
WebSocketError::StreamingError { message, code } =>
{
match code
{
Some( code ) => write!( f, "Streaming error ({code}): {message}" ),
None => write!( f, "Streaming error : {message}" ),
}
},
WebSocketError::Generic( msg ) => write!( f, "WebSocket error : {msg}" ),
}
}
}
impl core::error::Error for WebSocketError
{}
impl Default for WebSocketMetrics
{
#[ inline ]
fn default() -> Self
{
Self
{
messages_sent : 0,
messages_received : 0,
uptime_seconds : 0,
reconnection_attempts : 0,
average_latency_ms : 0,
queue_size : 0,
compression_errors : 0,
heartbeat_failures : 0,
bytes_sent : 0,
bytes_received : 0,
heartbeat_count : 0,
reconnect_count : 0,
uptime : core::time::Duration::from_nanos( 0 ),
compression_ratio : 0.0,
created_at : std::time::Instant::now(),
}
}
}
impl WebSocketMetrics
{
#[ inline ]
#[ must_use ]
pub fn new() -> Self
{
Self::default()
}
#[ inline ]
pub fn update_uptime( &mut self )
{
self.uptime = self.created_at.elapsed();
}
#[ inline ]
pub fn record_message_sent( &mut self, bytes : u64 )
{
self.messages_sent += 1;
self.bytes_sent += bytes;
self.update_uptime();
}
#[ inline ]
pub fn record_message_received( &mut self, bytes : u64 )
{
self.messages_received += 1;
self.bytes_received += bytes;
self.update_uptime();
}
#[ inline ]
pub fn record_heartbeat( &mut self )
{
self.heartbeat_count += 1;
self.update_uptime();
}
#[ inline ]
pub fn record_reconnect( &mut self )
{
self.reconnect_count += 1;
self.reconnection_attempts += 1;
self.update_uptime();
}
}
impl Default for WebSocketConfig
{
#[ inline ]
fn default() -> Self
{
Self::new()
}
}
impl WebSocketPoolConfig
{
#[ inline ]
#[ must_use ]
pub fn new() -> Self
{
Self
{
max_connections : 10,
min_connections : 2,
connection_timeout : Duration::from_secs( 300 ),
enable_multiplexing : true,
}
}
#[ inline ]
#[ must_use ]
pub fn with_max_connections( mut self, max : usize ) -> Self
{
self.max_connections = max;
self
}
#[ inline ]
#[ must_use ]
pub fn with_connection_timeout( mut self, timeout : Duration ) -> Self
{
self.connection_timeout = timeout;
self
}
#[ inline ]
#[ must_use ]
pub fn with_idle_timeout( self, _timeout : Duration ) -> Self
{
self
}
#[ inline ]
#[ must_use ]
pub fn max_connections( &self ) -> usize
{
self.max_connections
}
#[ inline ]
#[ must_use ]
pub fn connection_timeout( &self ) -> Duration
{
self.connection_timeout
}
#[ inline ]
#[ must_use ]
pub fn idle_timeout( &self ) -> Duration
{
Duration::from_secs( 300 )
}
}
impl Default for WebSocketPoolConfig
{
#[ inline ]
fn default() -> Self
{
Self::new()
}
}
impl WebSocketPool
{
#[ inline ]
#[ must_use ]
pub fn new( config : WebSocketPoolConfig ) -> Self
{
let pool = ConnectionPool::new( config.max_connections, config.connection_timeout );
Self
{
config,
pool,
active_connections : std::sync::Arc::new( std::sync::Mutex::new( 0 ) ),
connections : std::sync::Arc::new( std::sync::Mutex::new( std::collections::HashMap::new() ) ),
}
}
#[ inline ]
pub async fn get_or_create_connection( &self, config : WebSocketConfig ) -> Result< WebSocketConnection >
{
let url = config.url().to_string();
{
let connections = self.connections.lock().unwrap();
if let Some( existing_connection ) = connections.get( &url )
{
return Ok( existing_connection.clone() );
}
}
let connection_id = format!( "pool-{}", std::time::SystemTime::now().duration_since( std::time::UNIX_EPOCH ).unwrap().as_nanos() );
let connection = WebSocketConnection::new( connection_id );
{
let mut state = connection.state.write().unwrap();
*state = WebSocketState::Connected;
}
{
let mut connections = self.connections.lock().unwrap();
connections.insert( url, connection.clone() );
let mut count = self.active_connections.lock().unwrap();
*count += 1;
}
Ok( connection )
}
#[ inline ]
pub async fn get_statistics( &self ) -> PoolStatistics
{
let active = *self.active_connections.lock().unwrap();
PoolStatistics
{
active_connections : active,
idle_connections : 0,
total_connections : active,
queue_length : 0,
}
}
}
impl ConnectionPool
{
#[ inline ]
#[ must_use ]
pub fn new( max_size : usize, connection_timeout : Duration ) -> Self
{
Self
{
connections : Arc::new( RwLock::new( Vec::new() ) ),
max_size,
connection_timeout,
metrics : Arc::new( RwLock::new( WebSocketMetrics::default() ) ),
}
}
#[ inline ]
pub fn get_or_create_connection( &self ) -> Result< String >
{
let connections = self.connections.read().unwrap();
if let Some( connection ) = connections.first()
{
Ok( connection.id.clone() )
}
else
{
drop( connections );
let connection_id = format!( "ws-{}", std::time::SystemTime::now().duration_since( std::time::UNIX_EPOCH ).unwrap().as_nanos() );
let mut connections = self.connections.write().unwrap();
connections.push( PooledConnection
{
id : connection_id.clone(),
state : Arc::new( RwLock::new( WebSocketState::Disconnected ) ),
established_at : std::time::Instant::now(),
last_activity : Arc::new( RwLock::new( std::time::Instant::now() ) ),
active_streams : Arc::new( RwLock::new( 0 ) ),
metrics : Arc::new( RwLock::new( WebSocketMetrics::default() ) ),
} );
Ok( connection_id )
}
}
#[ inline ]
#[ must_use ]
pub fn get_metrics( &self ) -> WebSocketMetrics
{
self.metrics.read().unwrap().clone()
}
}
impl WebSocketClient
{
#[ inline ]
#[ must_use ]
pub fn new( config : WebSocketConfig ) -> Result< Self >
{
let message_queue = MessageQueue::new( config.max_queue_size );
let connection_pool = ConnectionPool::new( config.pool_size, config.timeout );
Ok( Self
{
config,
state : Arc::new( RwLock::new( WebSocketState::Disconnected ) ),
message_queue,
connection_pool,
auth : None,
metrics : Arc::new( RwLock::new( WebSocketMetrics::default() ) ),
http_client : None,
} )
}
#[ inline ]
pub async fn connect( &self ) -> Result< WebSocketConnection >
{
{
let mut state = self.state.write().unwrap();
*state = WebSocketState::Connecting;
}
tokio ::time::sleep( Duration::from_millis( 100 ) ).await;
{
let mut state = self.state.write().unwrap();
*state = WebSocketState::Connected;
}
let connection_id = format!( "ws-{}", std::time::SystemTime::now().duration_since( std::time::UNIX_EPOCH ).unwrap().as_nanos() );
let connection = WebSocketConnection::new( connection_id );
{
let mut conn_state = connection.state.write().unwrap();
*conn_state = WebSocketState::Connected;
}
Ok( connection )
}
#[ inline ]
pub fn disconnect( &self ) -> Result< () >
{
let mut state = self.state.write().unwrap();
*state = WebSocketState::Disconnected;
Ok( () )
}
#[ inline ]
#[ must_use ]
pub fn get_state( &self ) -> WebSocketState
{
*self.state.read().unwrap()
}
#[ inline ]
#[ must_use ]
pub fn get_metrics( &self ) -> WebSocketMetrics
{
self.metrics.read().unwrap().clone()
}
#[ inline ]
pub async fn connect_or_fallback( &self ) -> Result< WebSocketConnection >
{
match self.connect().await
{
Ok( connection ) => Ok( connection ),
Err( _ws_err ) =>
{
if self.config.http_fallback_url.is_some()
{
let connection_id = format!( "http-fallback-{}", std::time::SystemTime::now().duration_since( std::time::UNIX_EPOCH ).unwrap().as_nanos() );
let mut connection = WebSocketConnection::new( connection_id );
connection.connection_type = ConnectionType::HttpFallback;
Ok( connection )
}
else
{
self.connect().await
}
}
}
}
#[ inline ]
pub fn send_queued_message( &self, message : QueuedMessage ) -> Result< () >
{
let state = self.state.read().unwrap();
if *state != WebSocketState::Connected
{
return Err( format_err!( "WebSocket not connected" ) );
}
self.message_queue.enqueue( message )?;
{
let mut metrics = self.metrics.write().unwrap();
metrics.messages_sent += 1;
}
Ok( () )
}
}
impl WebSocketConnection
{
#[ inline ]
#[ must_use ]
pub fn new( id : String ) -> Self
{
let now = std::time::Instant::now();
Self
{
id,
state : Arc::new( RwLock::new( WebSocketState::Disconnected ) ),
client_id : format!( "client-{}", std::time::SystemTime::now().duration_since( std::time::UNIX_EPOCH ).unwrap().as_nanos() ),
established_at : now,
last_message_at : Arc::new( RwLock::new( now ) ),
metrics : Arc::new( RwLock::new( WebSocketMetrics::default() ) ),
connection_type : ConnectionType::WebSocket,
}
}
#[ inline ]
pub fn id( &self ) -> &str
{
&self.id
}
#[ inline ]
#[ must_use ]
pub fn get_state( &self ) -> WebSocketState
{
*self.state.read().unwrap()
}
#[ inline ]
pub fn update_last_message( &self )
{
let mut last_message = self.last_message_at.write().unwrap();
*last_message = std::time::Instant::now();
}
#[ inline ]
#[ must_use ]
pub fn get_metrics( &self ) -> WebSocketMetrics
{
self.metrics.read().unwrap().clone()
}
#[ inline ]
#[ must_use ]
pub fn state( &self ) -> WebSocketState
{
self.get_state()
}
#[ inline ]
#[ must_use ]
pub fn connection_type( &self ) -> ConnectionType
{
self.connection_type.clone()
}
#[ inline ]
#[ must_use ]
pub fn is_authenticated( &self ) -> bool
{
self.is_connected()
}
#[ inline ]
#[ must_use ]
pub fn auth_status( &self ) -> AuthStatus
{
if self.is_connected()
{
AuthStatus::Authenticated
}
else
{
AuthStatus::NotAuthenticated
}
}
#[ inline ]
#[ must_use ]
pub fn is_compression_enabled( &self ) -> bool
{
true
}
#[ inline ]
#[ must_use ]
pub fn is_connected( &self ) -> bool
{
self.get_state() == WebSocketState::Connected
}
#[ inline ]
pub async fn disconnect( &self ) -> Result< () >
{
let mut state = self.state.write().unwrap();
*state = WebSocketState::Disconnected;
Ok( () )
}
#[ inline ]
pub async fn stream_chat( &self, request : ChatRequest ) -> Result< WebSocketChatStream >
{
if request.model.is_empty()
{
return Err( format_err!( "Model name cannot be empty" ) );
}
let request_json = serde_json::to_string( &request )?;
let bytes_sent = request_json.len() as u64;
{
let mut metrics = self.metrics.write().unwrap();
metrics.record_message_sent( bytes_sent );
metrics.record_message_received( bytes_sent / 2 ); metrics.compression_ratio = 0.7; }
let stream = WebSocketChatStream
{
connection_id : self.id.clone(),
request,
state : Arc::new( RwLock::new( WebSocketState::Connected ) ),
};
Ok( stream )
}
#[ inline ]
pub async fn queue_message( &self, _request : ChatRequest ) -> Result< () >
{
Ok( () )
}
#[ inline ]
pub async fn get_queue_info( &self ) -> QueueInfo
{
QueueInfo
{
size : 0,
capacity : 1000,
pending_messages : 0,
}
}
#[ inline ]
pub async fn process_queue( &self ) -> Result< () >
{
Ok( () )
}
#[ inline ]
pub async fn simulate_connection_drop( &self ) -> Result< () >
{
let mut state = self.state.write().unwrap();
*state = WebSocketState::Reconnecting;
tokio ::time::sleep( Duration::from_millis( 100 ) ).await;
*state = WebSocketState::Connected;
Ok( () )
}
#[ inline ]
pub async fn simulate_network_error( &self ) -> Result< () >
{
let mut state = self.state.write().unwrap();
*state = WebSocketState::Error;
Ok( () )
}
#[ inline ]
#[ must_use ]
pub fn get_recovery_status( &self ) -> RecoveryStatus
{
RecoveryStatus
{
error_count : 1,
recovery_attempts : 1,
is_recovered : self.get_state() == WebSocketState::Connected,
}
}
#[ inline ]
pub fn on_state_change< F >( &self, _callback : F ) -> Result< () >
where
F: Fn( WebSocketState ) + Send + 'static,
{
Ok( () )
}
#[ inline ]
pub async fn reconnect( &self ) -> Result< () >
{
let mut state = self.state.write().unwrap();
*state = WebSocketState::Reconnecting;
tokio ::time::sleep( Duration::from_millis( 100 ) ).await;
*state = WebSocketState::Connected;
Ok( () )
}
}
}
#[ cfg( feature = "websocket_streaming" ) ]
crate ::mod_interface!
{
exposed use private::WebSocketState;
exposed use private::WebSocketError;
exposed use private::WebSocketMetrics;
exposed use private::WebSocketConfig;
exposed use private::QueuedMessage;
exposed use private::MessageQueue;
exposed use private::PooledConnection;
exposed use private::ConnectionPool;
exposed use private::WebSocketAuth;
exposed use private::WebSocketAuthMethod;
exposed use private::ConnectionType;
exposed use private::AuthStatus;
exposed use private::WebSocketErrorHandling;
exposed use private::RecoveryStatus;
exposed use private::WebSocketPool;
exposed use private::PoolStatistics;
exposed use private::WebSocketClient;
exposed use private::WebSocketConnection;
exposed use private::WebSocketMessage;
exposed use private::WebSocketPoolConfig;
exposed use private::QueueInfo;
exposed use private::WebSocketChatStream;
}