#![ allow( clippy::unreadable_literal ) ]
#![ allow( clippy::uninlined_format_args ) ]
#![ allow( clippy::std_instead_of_core ) ]
#![ allow( clippy::useless_vec ) ]
#![ allow( clippy::unused_async ) ]
#![ allow( clippy::must_use_candidate ) ]
#![ allow( clippy::missing_panics_doc ) ]
#![ allow( clippy::missing_errors_doc ) ]
#![ allow( clippy::doc_markdown ) ]
#![ allow( clippy::needless_continue ) ]
#![ allow( clippy::redundant_else ) ]
use api_openai::
{
realtime ::WsSession,
error ::{ OpenAIError, Result },
};
use std::
{
collections ::VecDeque,
sync ::{ Arc, Mutex },
time ::{ Duration, Instant },
};
use tokio::
{
sync ::{ RwLock, Semaphore },
time ::{ timeout, sleep, interval },
};
#[ derive( Debug, Clone ) ]
pub struct WsReliabilityConfig
{
pub max_reconnection_attempts : usize,
pub initial_reconnection_delay : Duration,
pub max_reconnection_delay : Duration,
pub connection_timeout : Duration,
pub heartbeat_interval : Duration,
pub message_buffer_size : usize,
pub health_check_interval : Duration,
}
impl Default for WsReliabilityConfig
{
fn default() -> Self
{
Self
{
max_reconnection_attempts : 5,
initial_reconnection_delay : Duration::from_millis( 1000 ),
max_reconnection_delay : Duration::from_secs( 30 ),
connection_timeout : Duration::from_secs( 10 ),
heartbeat_interval : Duration::from_secs( 30 ),
message_buffer_size : 1000,
health_check_interval : Duration::from_secs( 5 ),
}
}
}
#[ derive( Debug, Clone ) ]
pub struct WsConnectionStats
{
pub connection_attempts : u64,
pub successful_connections : u64,
pub failed_connections : u64,
pub reconnections : u64,
pub messages_sent : u64,
pub messages_received : u64,
pub message_send_failures : u64,
pub connection_interruptions : u64,
pub average_connection_duration : Duration,
pub last_connection_time : Option< Instant >,
}
impl Default for WsConnectionStats
{
fn default() -> Self
{
Self
{
connection_attempts : 0,
successful_connections : 0,
failed_connections : 0,
reconnections : 0,
messages_sent : 0,
messages_received : 0,
message_send_failures : 0,
connection_interruptions : 0,
average_connection_duration : Duration::from_secs( 0 ),
last_connection_time : None,
}
}
}
#[ derive( Debug ) ]
pub struct EnhancedWsSession
{
current_session : Arc< RwLock< Option< WsSession > > >,
config : WsReliabilityConfig,
stats : Arc< RwLock< WsConnectionStats > >,
message_buffer : Arc< Mutex< VecDeque< String > > >,
url : String,
reconnection_count : Arc< Mutex< usize > >,
last_heartbeat : Arc< Mutex< Option< Instant > > >,
}
impl EnhancedWsSession
{
pub fn new( url : String, config : WsReliabilityConfig ) -> Self
{
Self
{
current_session : Arc::new( RwLock::new( None ) ),
config,
stats : Arc::new( RwLock::new( WsConnectionStats::default() ) ),
message_buffer : Arc::new( Mutex::new( VecDeque::new() ) ),
url,
reconnection_count : Arc::new( Mutex::new( 0 ) ),
last_heartbeat : Arc::new( Mutex::new( None ) ),
}
}
pub async fn connect_with_reliability( &self ) -> Result< () >
{
let mut attempts = 0;
let mut delay = self.config.initial_reconnection_delay;
loop
{
{
let mut stats = self.stats.write().await;
stats.connection_attempts += 1;
}
match timeout( self.config.connection_timeout, WsSession::connect( &self.url ) ).await
{
Ok( Ok( session ) ) =>
{
{
let mut current = self.current_session.write().await;
*current = Some( session );
}
{
let mut stats = self.stats.write().await;
stats.successful_connections += 1;
stats.last_connection_time = Some( Instant::now() );
}
{
let mut reconnection_count = self.reconnection_count.lock().unwrap();
*reconnection_count = 0;
}
return Ok( () );
},
Ok( Err( error ) ) =>
{
{
let mut stats = self.stats.write().await;
stats.failed_connections += 1;
}
attempts += 1;
if attempts >= self.config.max_reconnection_attempts
{
return Err( error );
}
sleep( delay ).await;
delay = std::cmp::min( delay * 2, self.config.max_reconnection_delay );
},
Err( _ ) =>
{
{
let mut stats = self.stats.write().await;
stats.failed_connections += 1;
}
attempts += 1;
if attempts >= self.config.max_reconnection_attempts
{
return Err( error_tools::Error::from( OpenAIError::Internal(
format!( "Failed to connect after {} attempts", attempts )
) ) );
}
sleep( delay ).await;
delay = std::cmp::min( delay * 2, self.config.max_reconnection_delay );
}
}
}
}
pub async fn send_message_reliable( &self, message : &str ) -> Result< () >
{
{
let mut buffer = self.message_buffer.lock().unwrap();
if buffer.len() >= self.config.message_buffer_size
{
buffer.pop_front(); }
buffer.push_back( message.to_string() );
}
let session_opt = self.current_session.read().await;
if let Some( ref session ) = *session_opt
{
match self.simulate_message_send( session, message ).await
{
Ok( () ) =>
{
let mut stats = self.stats.write().await;
stats.messages_sent += 1;
return Ok( () );
},
Err( _error ) =>
{
let mut stats = self.stats.write().await;
stats.message_send_failures += 1;
}
}
}
self.handle_connection_failure().await?;
let session_opt = self.current_session.read().await;
if let Some( ref session ) = *session_opt
{
self.simulate_message_send( session, message ).await?;
let mut stats = self.stats.write().await;
stats.messages_sent += 1;
}
Ok( () )
}
async fn simulate_message_send( &self, _session : &WsSession, _message : &str ) -> Result< () >
{
Ok( () )
}
async fn handle_connection_failure( &self ) -> Result< () >
{
{
let mut stats = self.stats.write().await;
stats.connection_interruptions += 1;
}
{
let mut current = self.current_session.write().await;
*current = None; }
{
let mut reconnection_count = self.reconnection_count.lock().unwrap();
*reconnection_count += 1;
if *reconnection_count > self.config.max_reconnection_attempts
{
return Err( error_tools::Error::from( OpenAIError::Internal(
"Maximum reconnection attempts exceeded".to_string()
) ) );
}
}
{
let mut stats = self.stats.write().await;
stats.reconnections += 1;
}
self.connect_with_reliability().await
}
pub async fn start_heartbeat_monitoring( &self )
{
let heartbeat_interval = self.config.heartbeat_interval;
let last_heartbeat = Arc::clone( &self.last_heartbeat );
let current_session = Arc::clone( &self.current_session );
tokio ::spawn( async move
{
let mut interval = interval( heartbeat_interval );
loop
{
interval.tick().await;
{
let mut last_hb = last_heartbeat.lock().unwrap();
*last_hb = Some( Instant::now() );
}
let session_opt = current_session.read().await;
if session_opt.is_some()
{
continue;
}
else
{
break;
}
}
});
}
pub async fn check_connection_health( &self ) -> bool
{
let session_opt = self.current_session.read().await;
if session_opt.is_none()
{
return false;
}
let last_hb = self.last_heartbeat.lock().unwrap();
if let Some( last_time ) = *last_hb
{
let elapsed = last_time.elapsed();
return elapsed < self.config.heartbeat_interval * 3; }
true }
pub async fn get_stats( &self ) -> WsConnectionStats
{
self.stats.read().await.clone()
}
pub fn get_buffered_message_count( &self ) -> usize
{
self.message_buffer.lock().unwrap().len()
}
pub async fn flush_message_buffer( &self ) -> Result< () >
{
let messages : Vec< String > = {
let mut buffer = self.message_buffer.lock().unwrap();
buffer.drain( .. ).collect()
};
for message in messages
{
self.send_message_reliable( &message ).await?;
}
Ok( () )
}
}
#[ derive( Debug ) ]
pub struct WsReliabilityTestUtils;
impl WsReliabilityTestUtils
{
pub fn create_test_config() -> WsReliabilityConfig
{
WsReliabilityConfig
{
max_reconnection_attempts : 3,
initial_reconnection_delay : Duration::from_millis( 100 ),
max_reconnection_delay : Duration::from_millis( 500 ),
connection_timeout : Duration::from_millis( 1000 ),
heartbeat_interval : Duration::from_millis( 500 ),
message_buffer_size : 10,
health_check_interval : Duration::from_millis( 200 ),
}
}
pub fn create_mock_ws_url() -> String
{
"wss://api.openai.com/v1/realtime/test".to_string()
}
pub async fn simulate_network_interruption( duration : Duration )
{
sleep( duration ).await;
}
pub async fn measure_connection_time< F, Fut >( connection_fn : F ) -> Duration
where
F : FnOnce() -> Fut,
Fut : std::future::Future< Output = Result< () > >,
{
let start = Instant::now();
let _ = connection_fn().await;
start.elapsed()
}
}
#[ tokio::test ]
async fn test_enhanced_ws_session_creation()
{
let config = WsReliabilityTestUtils::create_test_config();
let url = WsReliabilityTestUtils::create_mock_ws_url();
let session = EnhancedWsSession::new( url.clone(), config.clone() );
assert_eq!( session.url, url );
assert_eq!( session.config.max_reconnection_attempts, 3 );
assert_eq!( session.get_buffered_message_count(), 0 );
let stats = session.get_stats().await;
assert_eq!( stats.connection_attempts, 0 );
assert_eq!( stats.successful_connections, 0 );
}
#[ tokio::test ]
async fn test_ws_reliability_config_defaults()
{
let config = WsReliabilityConfig::default();
assert_eq!( config.max_reconnection_attempts, 5 );
assert_eq!( config.initial_reconnection_delay, Duration::from_millis( 1000 ) );
assert_eq!( config.max_reconnection_delay, Duration::from_secs( 30 ) );
assert_eq!( config.connection_timeout, Duration::from_secs( 10 ) );
assert_eq!( config.heartbeat_interval, Duration::from_secs( 30 ) );
assert_eq!( config.message_buffer_size, 1000 );
}
#[ tokio::test ]
async fn test_ws_connection_stats_initialization()
{
let stats = WsConnectionStats::default();
assert_eq!( stats.connection_attempts, 0 );
assert_eq!( stats.successful_connections, 0 );
assert_eq!( stats.failed_connections, 0 );
assert_eq!( stats.reconnections, 0 );
assert_eq!( stats.messages_sent, 0 );
assert_eq!( stats.messages_received, 0 );
assert_eq!( stats.message_send_failures, 0 );
assert_eq!( stats.connection_interruptions, 0 );
assert!(stats.last_connection_time.is_none());
}
#[ tokio::test ]
async fn test_message_buffering_functionality()
{
let config = WsReliabilityTestUtils::create_test_config();
let url = WsReliabilityTestUtils::create_mock_ws_url();
let session = EnhancedWsSession::new( url, config );
{
let mut buffer = session.message_buffer.lock().unwrap();
buffer.push_back( "test_message_1".to_string() );
buffer.push_back( "test_message_2".to_string() );
}
assert_eq!( session.get_buffered_message_count(), 2 );
for i in 0..20
{
let mut buffer = session.message_buffer.lock().unwrap();
if buffer.len() >= session.config.message_buffer_size
{
buffer.pop_front(); }
buffer.push_back( format!( "message_{}", i ) );
}
assert!( session.get_buffered_message_count() <= 10 );
}
#[ tokio::test ]
async fn test_connection_health_monitoring()
{
let config = WsReliabilityTestUtils::create_test_config();
let url = WsReliabilityTestUtils::create_mock_ws_url();
let session = EnhancedWsSession::new( url, config );
let health = session.check_connection_health().await;
assert!( !health );
{
let mut last_hb = session.last_heartbeat.lock().unwrap();
*last_hb = Some( Instant::now() );
}
let health = session.check_connection_health().await;
assert!( !health );
}
#[ tokio::test ]
async fn test_reliability_test_utils()
{
let config = WsReliabilityTestUtils::create_test_config();
assert_eq!( config.max_reconnection_attempts, 3 );
assert_eq!( config.initial_reconnection_delay, Duration::from_millis( 100 ) );
let url = WsReliabilityTestUtils::create_mock_ws_url();
assert!( url.starts_with( "wss://" ) );
assert!( url.contains( "openai.com" ) );
let start = Instant::now();
WsReliabilityTestUtils::simulate_network_interruption( Duration::from_millis( 50 ) ).await;
let elapsed = start.elapsed();
assert!( elapsed >= Duration::from_millis( 50 ) );
}
#[ tokio::test ]
async fn test_concurrent_reliability_sessions()
{
let config = WsReliabilityTestUtils::create_test_config();
let url = WsReliabilityTestUtils::create_mock_ws_url();
let session1 = Arc::new( EnhancedWsSession::new( url.clone(), config.clone() ) );
let session2 = Arc::new( EnhancedWsSession::new( url, config ) );
let semaphore = Arc::new( Semaphore::new( 2 ) );
let mut handles = Vec::new();
for session in vec![ session1, session2 ]
{
let semaphore_clone = Arc::clone( &semaphore );
let handle = tokio::spawn( async move
{
let _permit = semaphore_clone.acquire().await.unwrap();
let stats = session.get_stats().await;
assert_eq!( stats.connection_attempts, 0 );
let health = session.check_connection_health().await;
assert!( !health );
session.start_heartbeat_monitoring().await;
sleep( Duration::from_millis( 100 ) ).await;
});
handles.push( handle );
}
for handle in handles
{
handle.await.expect( "Task should complete successfully" );
}
}
#[ tokio::test ]
async fn test_message_reliability_under_failure()
{
let config = WsReliabilityTestUtils::create_test_config();
let url = WsReliabilityTestUtils::create_mock_ws_url();
let session = EnhancedWsSession::new( url, config );
{
let mut buffer = session.message_buffer.lock().unwrap();
buffer.push_back( "queued_message_1".to_string() );
buffer.push_back( "queued_message_2".to_string() );
}
assert_eq!( session.get_buffered_message_count(), 2 );
{
let mut stats = session.stats.write().await;
stats.message_send_failures += 1;
stats.connection_interruptions += 1;
}
let final_stats = session.get_stats().await;
assert_eq!( final_stats.message_send_failures, 1 );
assert_eq!( final_stats.connection_interruptions, 1 );
}
#[ tokio::test ]
async fn test_exponential_backoff_timing()
{
let config = WsReliabilityTestUtils::create_test_config();
let mut delay = config.initial_reconnection_delay;
assert_eq!( delay, Duration::from_millis( 100 ) );
delay = std::cmp::min( delay * 2, config.max_reconnection_delay );
assert_eq!( delay, Duration::from_millis( 200 ) );
delay = std::cmp::min( delay * 2, config.max_reconnection_delay );
assert_eq!( delay, Duration::from_millis( 400 ) );
delay = std::cmp::min( delay * 2, config.max_reconnection_delay );
assert_eq!( delay, Duration::from_millis( 500 ) );
delay = std::cmp::min( delay * 2, config.max_reconnection_delay );
assert_eq!( delay, Duration::from_millis( 500 ) ); }
#[ tokio::test ]
async fn test_connection_timeout_handling()
{
let config = WsReliabilityTestUtils::create_test_config();
assert_eq!( config.connection_timeout, Duration::from_millis( 1000 ) );
let start = Instant::now();
let result = timeout( config.connection_timeout, async move
{
sleep( Duration::from_millis( 1500 ) ).await;
Ok::<(), OpenAIError >( () )
}).await;
let elapsed = start.elapsed();
assert!( result.is_err() ); assert!( elapsed >= config.connection_timeout );
assert!( elapsed < Duration::from_millis( 1200 ) ); }
#[ tokio::test ]
async fn test_heartbeat_interval_configuration()
{
let config = WsReliabilityTestUtils::create_test_config();
let url = WsReliabilityTestUtils::create_mock_ws_url();
let session = EnhancedWsSession::new( url, config.clone() );
assert_eq!( config.heartbeat_interval, Duration::from_millis( 500 ) );
session.start_heartbeat_monitoring().await;
sleep( Duration::from_millis( 600 ) ).await;
let _last_hb = session.last_heartbeat.lock().unwrap();
}
#[ tokio::test ]
async fn test_statistics_accuracy_tracking()
{
let config = WsReliabilityTestUtils::create_test_config();
let url = WsReliabilityTestUtils::create_mock_ws_url();
let session = EnhancedWsSession::new( url, config );
{
let mut stats = session.stats.write().await;
stats.connection_attempts = 5;
stats.successful_connections = 3;
stats.failed_connections = 2;
stats.reconnections = 2;
stats.messages_sent = 100;
stats.messages_received = 95;
stats.message_send_failures = 5;
stats.connection_interruptions = 1;
stats.last_connection_time = Some( Instant::now() );
}
let final_stats = session.get_stats().await;
assert_eq!( final_stats.connection_attempts, 5 );
assert_eq!( final_stats.successful_connections, 3 );
assert_eq!( final_stats.failed_connections, 2 );
assert_eq!( final_stats.reconnections, 2 );
assert_eq!( final_stats.messages_sent, 100 );
assert_eq!( final_stats.messages_received, 95 );
assert_eq!( final_stats.message_send_failures, 5 );
assert_eq!( final_stats.connection_interruptions, 1 );
assert!( final_stats.last_connection_time.is_some() );
}