#[ cfg( any( feature = "failover", feature = "health_checks" ) ) ]
mod private
{
use core::time::Duration;
use std::sync::atomic::AtomicUsize;
use error_tools::untyped::{ format_err, Result as OllamaResult };
#[ cfg( feature = "failover" ) ]
/// Strategy used to select the next endpoint after a failure.
#[ derive( Debug, Clone, PartialEq ) ]
pub enum FailoverPolicy
{
/// Cycle through endpoints in list order, starting just after the current one.
RoundRobin,
/// Always prefer the earliest (highest-priority) healthy endpoint in the list.
Priority,
}
/// Health classification for a single endpoint.
#[ derive( Debug, Clone, PartialEq ) ]
#[ allow( dead_code ) ]
pub enum EndpointHealth
{
/// Endpoint is responding normally. New endpoints start in this state.
Healthy,
/// Endpoint is impaired but usable. NOTE(review): nothing in this module sets
/// this state - presumably reserved for external callers; confirm.
Degraded,
/// Endpoint failed its most recent request.
Unhealthy,
/// Health has not been determined. NOTE(review): also unset here, since new
/// endpoints begin as `Healthy` - presumably reserved for external callers.
Unknown,
}
#[ cfg( feature = "failover" ) ]
/// Per-endpoint bookkeeping: health state plus request/failure statistics.
#[ derive( Debug, Clone ) ]
#[ allow( dead_code ) ]
pub struct EndpointInfo
{
/// Endpoint URL (validated by `FailoverManager::new`).
pub url : String,
/// Current health classification; starts as `Healthy`.
pub health : EndpointHealth,
/// Instant of the most recent `mark_healthy` call, if any.
pub last_success : Option< std::time::Instant >,
/// Instant of the most recent `mark_unhealthy` call, if any.
pub last_failure : Option< std::time::Instant >,
/// Number of requests recorded via `update_request_stats`.
pub total_requests : u64,
/// Number of failures recorded via `mark_unhealthy`.
pub total_failures : u64,
/// Running mean response time across recorded requests.
pub avg_response_time : Duration,
}
#[ cfg( feature = "failover" ) ]
#[ allow( dead_code ) ]
impl EndpointInfo
{
/// Creates tracking state for `url`: initially `Healthy` with zeroed statistics.
#[ inline ]
#[ must_use ]
pub fn new( url : String ) -> Self
{
Self
{
url,
health : EndpointHealth::Healthy,
last_success : None,
last_failure : None,
total_requests : 0,
total_failures : 0,
avg_response_time : Duration::from_millis( 0 ),
}
}
/// Marks the endpoint `Healthy` and records the success time.
#[ inline ]
pub fn mark_healthy( &mut self )
{
self.health = EndpointHealth::Healthy;
self.last_success = Some( std::time::Instant::now() );
}
/// Marks the endpoint `Unhealthy`, records the failure time and bumps the failure count.
#[ inline ]
pub fn mark_unhealthy( &mut self )
{
self.health = EndpointHealth::Unhealthy;
self.last_failure = Some( std::time::Instant::now() );
self.total_failures += 1;
}
/// Records one completed request and folds `response_time` into the running average.
///
/// The mean is maintained in nanoseconds using 128-bit intermediate arithmetic.
/// The previous millisecond-based computation truncated sub-millisecond samples
/// to zero and could overflow `u64` when multiplying the average by a large
/// request count; neither can happen here.
#[ inline ]
pub fn update_request_stats( &mut self, response_time : Duration )
{
self.total_requests += 1;
if self.total_requests == 1
{
self.avg_response_time = response_time;
}
else
{
let previous_total = self.avg_response_time.as_nanos() * u128::from( self.total_requests - 1 );
let new_avg_ns = ( previous_total + response_time.as_nanos() ) / u128::from( self.total_requests );
// The mean never exceeds the largest sample, so it fits in u64 nanoseconds (~584 years).
self.avg_response_time = Duration::from_nanos( new_avg_ns as u64 );
}
}
/// Returns `true` when the endpoint is currently `Healthy`.
#[ inline ]
#[ must_use ]
pub fn is_healthy( &self ) -> bool
{
self.health == EndpointHealth::Healthy
}
}
#[ cfg( feature = "failover" ) ]
/// Snapshot of a `FailoverManager`'s aggregate counters.
#[ derive( Debug, Clone ) ]
pub struct FailoverStats
{
/// Times `select_next_healthy_endpoint` has run (incremented even when the
/// active endpoint did not actually change).
pub total_failovers : u64,
/// Total request count. NOTE(review): never incremented in the visible code -
/// presumably updated by callers; confirm.
pub total_requests : u64,
/// Index of the currently active endpoint at snapshot time.
pub active_endpoint_index : usize,
/// Number of endpoints managed.
pub total_endpoints : usize,
}
#[ cfg( feature = "failover" ) ]
impl FailoverStats
{
/// Builds a zeroed statistics record for a pool of `total_endpoints` endpoints,
/// with endpoint 0 reported as active.
#[ inline ]
#[ must_use ]
pub fn new( total_endpoints : usize ) -> Self
{
Self
{
total_endpoints,
total_failovers : 0,
total_requests : 0,
active_endpoint_index : 0,
}
}
}
#[ cfg( feature = "failover" ) ]
/// Tracks a pool of endpoints and switches the active one on failure,
/// according to the configured `FailoverPolicy`.
#[ derive( Debug ) ]
#[ allow( dead_code ) ]
pub struct FailoverManager
{
/// Endpoint pool; guaranteed non-empty by `new`.
endpoints : Vec< EndpointInfo >,
/// Index of the active endpoint; atomic so it can be read without a lock.
current_index : AtomicUsize,
/// Selection strategy applied when a failover is triggered.
policy : FailoverPolicy,
/// Aggregate counters, snapshot via `get_failover_stats`.
stats : std::sync::Mutex< FailoverStats >,
/// Request timeout. NOTE(review): stored but never read in this module -
/// presumably consumed by the request layer; confirm.
timeout : Duration,
}
#[ cfg( feature = "failover" ) ]
impl FailoverManager
{
/// Builds a manager over `endpoints`, validating every URL up front.
/// The first endpoint in the list starts as the active one.
///
/// # Errors
///
/// Returns an error when `endpoints` is empty or any entry fails URL parsing.
// `#[must_use]` removed: `Result` is already `#[must_use]` (clippy `double_must_use`).
#[ inline ]
pub fn new( endpoints : Vec< String >, policy : FailoverPolicy, timeout : Duration ) -> OllamaResult< Self >
{
if endpoints.is_empty()
{
return Err( format_err!( "At least one endpoint must be provided" ) );
}
for endpoint in &endpoints
{
url ::Url::parse( endpoint ).map_err( |e| format_err!( "Invalid URL {}: {}", endpoint, e ) )?;
}
let endpoint_infos : Vec< EndpointInfo > = endpoints.into_iter()
.map( EndpointInfo::new )
.collect();
let stats = FailoverStats::new( endpoint_infos.len() );
Ok( Self
{
endpoints : endpoint_infos,
current_index : AtomicUsize::new( 0 ),
policy,
stats : std::sync::Mutex::new( stats ),
timeout,
})
}
/// Returns the URL of the currently active endpoint.
/// Falls back to endpoint 0 if the stored index is somehow out of range.
#[ inline ]
#[ must_use ]
pub fn get_active_endpoint( &self ) -> String
{
let index = self.current_index.load( std::sync::atomic::Ordering::Acquire );
if index < self.endpoints.len()
{
self.endpoints[ index ].url.clone()
}
else
{
self.endpoints[ 0 ].url.clone()
}
}
/// Returns the number of managed endpoints.
#[ inline ]
#[ must_use ]
pub fn get_endpoint_count( &self ) -> usize
{
self.endpoints.len()
}
/// Returns `true` when an endpoint with `url` exists and is `Healthy`.
/// Unknown URLs report `false`.
#[ inline ]
#[ must_use ]
pub fn is_endpoint_healthy( &self, url : &str ) -> bool
{
// `any` replaces the former `find(..).map_or(false, ..)` chain (same behavior, idiomatic form).
self.endpoints.iter().any( |endpoint| endpoint.url == url && endpoint.is_healthy() )
}
/// Marks the endpoint with `url` healthy. Under the `Priority` policy this also
/// re-selects the active endpoint so a recovered higher-priority endpoint takes over.
#[ inline ]
pub fn mark_endpoint_healthy( &mut self, url : &str )
{
if let Some( endpoint ) = self.endpoints.iter_mut().find( |e| e.url == url )
{
endpoint.mark_healthy();
}
if self.policy == FailoverPolicy::Priority
{
self.select_next_healthy_endpoint();
}
}
/// Marks the endpoint with `url` unhealthy and switches to the next healthy endpoint.
/// Unknown URLs still trigger a re-selection.
#[ inline ]
pub fn mark_endpoint_unhealthy( &mut self, url : &str )
{
if let Some( endpoint ) = self.endpoints.iter_mut().find( |e| e.url == url )
{
endpoint.mark_unhealthy();
}
self.select_next_healthy_endpoint();
}
/// Advances the active endpoint to the next one in list order, regardless of health.
/// `endpoints` is non-empty (enforced in `new`), so the modulo cannot divide by zero.
#[ inline ]
pub fn rotate_endpoint( &mut self )
{
let current = self.current_index.load( std::sync::atomic::Ordering::Acquire );
let next = ( current + 1 ) % self.endpoints.len();
self.current_index.store( next, std::sync::atomic::Ordering::Release );
}
/// Re-selects the active endpoint according to `policy`.
///
/// `RoundRobin` scans forward from the current index (wrapping, and including the
/// current endpoint last); `Priority` picks the first healthy endpoint in list
/// order. When no endpoint is healthy the active index is left unchanged.
/// `total_failovers` is incremented on every call, even when the index does not move.
pub fn select_next_healthy_endpoint( &mut self )
{
match self.policy
{
FailoverPolicy::RoundRobin =>
{
let current = self.current_index.load( std::sync::atomic::Ordering::Acquire );
for i in 1..=self.endpoints.len()
{
let index = ( current + i ) % self.endpoints.len();
if self.endpoints[ index ].is_healthy()
{
self.current_index.store( index, std::sync::atomic::Ordering::Release );
break;
}
}
}
FailoverPolicy::Priority =>
{
for ( index, endpoint ) in self.endpoints.iter().enumerate()
{
if endpoint.is_healthy()
{
self.current_index.store( index, std::sync::atomic::Ordering::Release );
break;
}
}
}
}
if let Ok( mut stats ) = self.stats.lock()
{
stats.total_failovers += 1;
stats.active_endpoint_index = self.current_index.load( std::sync::atomic::Ordering::Acquire );
}
}
/// Returns a snapshot of the failover statistics.
/// If the stats mutex is poisoned, a fresh zeroed record is returned instead.
#[ inline ]
#[ must_use ]
pub fn get_failover_stats( &self ) -> FailoverStats
{
self.stats.lock().map( |stats| stats.clone() ).unwrap_or_else( |_|
{
FailoverStats::new( self.endpoints.len() )
})
}
}
}
// Expose the module's public surface via `mod_interface`. `EndpointHealth` is
// available under either feature; the failover types only when `failover` is enabled.
#[ cfg( any( feature = "failover", feature = "health_checks" ) ) ]
crate ::mod_interface!
{
exposed use private::EndpointHealth;
#[ cfg( feature = "failover" ) ]
exposed use private::FailoverPolicy;
#[ cfg( feature = "failover" ) ]
exposed use private::FailoverStats;
#[ cfg( feature = "failover" ) ]
exposed use private::EndpointInfo;
#[ cfg( feature = "failover" ) ]
exposed use private::FailoverManager;
}