pub mod types;
use crate::types::{InstanceStatus, LLMInstance, ModelCapability, RouterError, RoutingStrategy};
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::sync::{oneshot, RwLock}; use tracing::{debug, error, info, instrument, warn}; use futures::{stream, StreamExt};
// Default settings copied into a fresh `RouterBuilder` by its `Default` impl.
// Period between two health check cycles.
const DEFAULT_HEALTH_CHECK_INTERVAL: Duration = Duration::from_secs(5);
// Request timeout configured on the HTTP client used for health checks.
const DEFAULT_HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(2);
// Path appended to an instance's base URL for health checks, followed by the
// duration an instance is kept in timeout by `Router::timeout_instance`.
const DEFAULT_HEALTH_CHECK_PATH: &str = "/health"; const DEFAULT_INSTANCE_TIMEOUT: Duration = Duration::from_secs(30);
/// State shared by every [`Router`] handle and by the background health check task.
#[derive(Debug)]
struct RouterInternalState {
/// Registered instances, guarded by an async read/write lock.
instances: RwLock<Vec<LLMInstance>>,
/// Strategy the `select_*` methods use to choose among available instances.
strategy: RoutingStrategy,
// `next_instance_index`: counter advanced on every round-robin selection.
// `_http_client`: client built (or supplied) in `RouterBuilder::build`; not read anywhere in this file.
// `instance_timeout_duration`: how long `Router::timeout_instance` keeps an instance in timeout.
next_instance_index: AtomicUsize, _http_client: reqwest::Client, instance_timeout_duration: Duration, }
/// Cloneable handle to the router; all clones share the same internal state.
#[derive(Debug, Clone)] pub struct Router {
/// Shared router state (instances, strategy, counters).
internal: Arc<RouterInternalState>,
/// Sender half of the health check stop channel; it is taken and fired in `Drop`.
stop_health_check_tx: Arc<RwLock<Option<oneshot::Sender<()>>>>,
}
/// A model name paired with the capabilities an instance offers for it.
#[derive(Debug, Clone)]
pub struct ModelInstanceConfig {
/// Name used as the key in an instance's supported-models map.
pub model_name: String,
/// Capabilities stored under `model_name`.
pub capabilities: Vec<ModelCapability>,
}
/// Collects configuration for a [`Router`]; consumed by `RouterBuilder::build`.
#[derive(Debug)]
pub struct RouterBuilder {
/// Instances the router starts with.
initial_instances: Vec<LLMInstance>,
/// Routing strategy handed to the router.
strategy: RoutingStrategy,
/// Period between health check cycles.
health_check_interval: Duration,
/// Request timeout for the health check HTTP client.
health_check_timeout: Duration,
/// Path probed on each instance; the setter ensures it starts with `/`.
health_check_path: String,
/// How long an instance stays in timeout once `Router::timeout_instance` is called.
instance_timeout_duration: Duration,
/// Optional caller-supplied HTTP client; `build` creates a default one when `None`.
reqwest_client: Option<reqwest::Client>,
}
impl Default for RouterBuilder {
    /// Builds the baseline configuration: no instances, load-based routing,
    /// the module's default health check settings and no custom HTTP client.
    fn default() -> Self {
        Self {
            initial_instances: Vec::new(),
            strategy: RoutingStrategy::LoadBased,
            health_check_interval: DEFAULT_HEALTH_CHECK_INTERVAL,
            health_check_timeout: DEFAULT_HEALTH_CHECK_TIMEOUT,
            health_check_path: String::from(DEFAULT_HEALTH_CHECK_PATH),
            instance_timeout_duration: DEFAULT_INSTANCE_TIMEOUT,
            reqwest_client: None,
        }
    }
}
impl RouterBuilder {
    /// Returns a builder populated with the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the whole list of initial instances.
    pub fn initial_instances(self, instances: Vec<LLMInstance>) -> Self {
        Self {
            initial_instances: instances,
            ..self
        }
    }

    /// Sets the routing strategy.
    pub fn strategy(self, strategy: RoutingStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Sets the period between health check cycles.
    pub fn health_check_interval(self, interval: Duration) -> Self {
        Self {
            health_check_interval: interval,
            ..self
        }
    }

    /// Sets the request timeout of the health check HTTP client.
    pub fn health_check_timeout(self, timeout: Duration) -> Self {
        Self {
            health_check_timeout: timeout,
            ..self
        }
    }

    /// Sets the health check path, prefixing it with `/` when it lacks one.
    pub fn health_check_path(self, path: impl Into<String>) -> Self {
        let raw: String = path.into();
        let health_check_path = if raw.starts_with('/') {
            raw
        } else {
            format!("/{}", raw)
        };
        Self {
            health_check_path,
            ..self
        }
    }

    /// Sets how long an instance stays in timeout after `Router::timeout_instance`.
    pub fn instance_timeout_duration(self, duration: Duration) -> Self {
        Self {
            instance_timeout_duration: duration,
            ..self
        }
    }

    /// Appends an instance with the given models; it starts with status `Unknown`.
    ///
    /// When `models` names the same model twice, the later entry wins.
    pub fn instance_with_models(
        mut self,
        id: impl Into<String>,
        base_url: impl Into<String>,
        models: Vec<ModelInstanceConfig>,
    ) -> Self {
        let supported_models_map: HashMap<_, _> = models
            .into_iter()
            .map(|config| (config.model_name, config.capabilities))
            .collect();
        self.initial_instances.push(LLMInstance::new(
            id.into(),
            base_url.into(),
            InstanceStatus::Unknown,
            supported_models_map,
        ));
        self
    }

    /// Appends an instance that declares no supported models.
    pub fn instance(self, id: impl Into<String>, base_url: impl Into<String>) -> Self {
        let no_models = Vec::new();
        self.instance_with_models(id, base_url, no_models)
    }

    /// Appends one model-less instance per `(id, base_url)` pair, in order.
    pub fn instances(self, instances: Vec<(String, String)>) -> Self {
        instances
            .into_iter()
            .fold(self, |builder, (id, base_url)| builder.instance(id, base_url))
    }

    /// Supplies the HTTP client stored in the router instead of the default one.
    pub fn http_client(self, client: reqwest::Client) -> Self {
        Self {
            reqwest_client: Some(client),
            ..self
        }
    }

    /// Consumes the builder, creates the [`Router`] and starts its health check task.
    ///
    /// # Panics
    ///
    /// Panics if a reqwest client cannot be built. The health check task is
    /// started through `tokio::spawn`, which per the Tokio documentation panics
    /// when called outside a Tokio runtime.
    pub fn build(self) -> Router {
        let Self {
            initial_instances,
            strategy,
            health_check_interval,
            health_check_timeout,
            health_check_path,
            instance_timeout_duration,
            reqwest_client,
        } = self;

        let main_client = match reqwest_client {
            Some(client) => client,
            None => reqwest::Client::builder()
                .timeout(Duration::from_secs(150))
                .build()
                .expect("Failed to build default reqwest client"),
        };
        // Health checks use their own client so its timeout is independent of the main one.
        let health_check_client = reqwest::Client::builder()
            .timeout(health_check_timeout)
            .build()
            .expect("Failed to build health check reqwest client");

        let shared = Arc::new(RouterInternalState {
            instances: RwLock::new(initial_instances),
            strategy,
            next_instance_index: AtomicUsize::new(0),
            _http_client: main_client,
            instance_timeout_duration,
        });

        let (stop_tx, stop_rx) = oneshot::channel::<()>();
        let router = Router {
            internal: Arc::clone(&shared),
            stop_health_check_tx: Arc::new(RwLock::new(Some(stop_tx))),
        };
        HealthCheck::spawn(
            shared,
            stop_rx,
            health_check_client,
            health_check_interval,
            health_check_timeout,
            health_check_path,
        );
        router
    }
}
impl Router {
/// Entry point for configuring a router; returns a builder with default settings.
pub fn builder() -> RouterBuilder {
    let fresh = RouterBuilder::new();
    fresh
}
/// Returns a cloned snapshot of the currently registered instances.
pub async fn get_instances(&self) -> Vec<LLMInstance> {
    let registry = self.internal.instances.read().await;
    registry.iter().cloned().collect()
}
#[instrument(skip(self, id, base_url))]
pub async fn add_instance(
&self,
id: impl Into<String>,
base_url: impl Into<String>,
) -> Result<(), RouterError> {
self.add_instance_with_models(id, base_url, Vec::new()).await
}
#[instrument(skip(self, id, base_url, models))]
pub async fn add_instance_with_models(
&self,
id: impl Into<String>,
base_url: impl Into<String>,
models: Vec<ModelInstanceConfig>,
) -> Result<(), RouterError> {
let instance_id = id.into();
let base_url_str = base_url.into();
let _ = url::Url::parse(&base_url_str).map_err(RouterError::InvalidUrl)?;
let mut instances = self.internal.instances.write().await;
if instances.iter().any(|inst| inst.id == instance_id) {
warn!(instance_id = %instance_id, "Attempted to add duplicate instance");
return Err(RouterError::InstanceExists(instance_id));
}
let mut instance = LLMInstance {
id: instance_id.clone(),
base_url: base_url_str,
active_requests: Arc::new(AtomicUsize::new(0)),
status: Arc::new(RwLock::new(InstanceStatus::Unknown)), is_in_timeout: Arc::new(AtomicBool::new(false)),
timeout_until: Arc::new(RwLock::new(None)),
supported_models: Arc::new(RwLock::new(HashMap::new())), };
let mut supported_models_map = HashMap::new();
for config in models {
supported_models_map.insert(config.model_name.clone(), config.capabilities.clone());
info!(
instance_id = %instance.id,
model = %config.model_name,
capabilities = ?config.capabilities,
"Adding model support for new instance"
);
}
let supported_models_guard = Arc::get_mut(&mut instance.supported_models)
.expect("Failed to get mutable reference to supported_models");
*supported_models_guard = RwLock::new(supported_models_map);
instances.push(instance);
info!(instance_id = %instance_id, "Added new instance");
Ok(())
}
/// Adds a model to an existing instance, or replaces its capabilities if already present.
///
/// # Errors
///
/// Returns `RouterError::InstanceNotFound` when no instance has `instance_id`.
#[instrument(skip(self, instance_id, model_name, capabilities))]
pub async fn add_model_to_instance(
    &self,
    instance_id: &str,
    model_name: String,
    capabilities: Vec<ModelCapability>,
) -> Result<(), RouterError> {
    let registry = self.internal.instances.read().await;
    match registry.iter().find(|inst| inst.id == instance_id) {
        None => {
            error!(instance_id = %instance_id, "Instance not found when trying to add model");
            Err(RouterError::InstanceNotFound(instance_id.to_string()))
        }
        Some(target) => {
            let mut model_table = target.supported_models.write().await;
            info!(
                instance_id = %instance_id,
                model = %model_name,
                capabilities = ?capabilities,
                "Adding/Updating model support"
            );
            model_table.insert(model_name, capabilities);
            Ok(())
        }
    }
}
/// Removes every registered instance whose id equals `id`.
///
/// # Errors
///
/// Returns `RouterError::InstanceNotFound` when nothing was removed.
#[instrument(skip(self, id))]
pub async fn remove_instance(&self, id: &str) -> Result<(), RouterError> {
    let mut registry = self.internal.instances.write().await;
    let count_before = registry.len();
    registry.retain(|inst| inst.id != id);

    // `retain` can only shrink the list, so an unchanged length means no match.
    if registry.len() == count_before {
        error!(instance_id = %id, "Instance not found for removal");
        return Err(RouterError::InstanceNotFound(id.to_string()));
    }
    info!(instance_id = %id, "Removed instance");
    Ok(())
}
#[instrument(skip(self, model_name, capability))]
pub async fn select_instance_for_model(
&self,
model_name: &str,
capability: ModelCapability,
) -> Result<LLMInstance, RouterError> {
let instance_refs: Vec<LLMInstance> = {
self.internal.instances.read().await.clone()
};
let mut available_refs = Vec::new();
for instance in instance_refs.iter() {
let status = instance.status.read().await.clone();
let is_timed_out = instance.is_in_timeout.load(Ordering::SeqCst);
let timeout_expiry = *instance.timeout_until.read().await;
let now = Instant::now();
if is_timed_out {
if let Some(expiry) = timeout_expiry {
if now < expiry {
continue;
}
}
}
if status != InstanceStatus::Healthy {
continue;
}
let supports = self
.instance_supports_model(&instance.id, model_name, &capability)
.await?;
if supports {
available_refs.push(instance); } else {
debug!(instance_id = %instance.id, model = %model_name, capability = ?capability, "Instance does not support model/capability.");
}
}
if available_refs.is_empty() {
warn!(model = %model_name, capability = ?capability, "No healthy instances available for the required model and capability.");
return Err(RouterError::NoHealthyInstancesForModel(
model_name.to_string(),
capability,
));
}
match self.internal.strategy {
RoutingStrategy::RoundRobin => {
let index = self
.internal
.next_instance_index
.fetch_add(1, Ordering::SeqCst);
let selected_instance = available_refs[index % available_refs.len()].clone(); debug!(instance_id = %selected_instance.id, strategy = "RoundRobin", "Selected instance");
Ok(selected_instance)
}
RoutingStrategy::LoadBased => {
available_refs
.get(0) .map(|instance_ref_ref| (*instance_ref_ref).clone()) .ok_or_else(|| {
error!("LoadBased selection failed: available_refs was non-empty but get(0) yielded None.");
RouterError::NoHealthyInstancesForModel(model_name.to_string(), capability)
})
}
}
}
#[instrument(skip(self))]
pub async fn select_next_instance(&self) -> Result<LLMInstance, RouterError> {
let instance_refs: Vec<LLMInstance> = {
self.internal.instances.read().await.clone()
};
let mut available_refs = Vec::new();
for instance in instance_refs.iter() {
let status = instance.status.read().await.clone();
let is_timed_out = instance.is_in_timeout.load(Ordering::SeqCst);
let timeout_expiry = *instance.timeout_until.read().await;
let now = Instant::now();
if is_timed_out {
if let Some(expiry) = timeout_expiry {
if now < expiry {
continue;
}
}
}
if status == InstanceStatus::Healthy {
available_refs.push(instance); } else {
debug!(instance_id = %instance.id, status = ?status, "Instance is not healthy.");
}
}
if available_refs.is_empty() {
warn!("No healthy instances available for selection.");
return Err(RouterError::NoHealthyInstances);
}
match self.internal.strategy {
RoutingStrategy::RoundRobin => {
let index = self
.internal
.next_instance_index
.fetch_add(1, Ordering::SeqCst);
let selected_instance = available_refs[index % available_refs.len()].clone(); debug!(instance_id = %selected_instance.id, strategy = "RoundRobin", "Selected instance");
Ok(selected_instance)
}
RoutingStrategy::LoadBased => {
available_refs
.get(0)
.map(|instance_ref_ref| (*instance_ref_ref).clone())
.ok_or_else(|| {
error!("LoadBased selection failed: available_refs was non-empty but get(0) yielded None.");
RouterError::NoHealthyInstances
})
}
}
}
/// Adds one to the active-request counter of the instance with `instance_id`.
///
/// # Errors
///
/// Returns `RouterError::InstanceNotFound` when no such instance is registered.
pub async fn increment_request_count(&self, instance_id: &str) -> Result<(), RouterError> {
    let registry = self.internal.instances.read().await;
    let outcome = registry
        .iter()
        .find(|inst| inst.id == instance_id)
        .map(|inst| {
            inst.active_requests.fetch_add(1, Ordering::SeqCst);
        })
        .ok_or_else(|| RouterError::InstanceNotFound(instance_id.to_string()));
    outcome
}
pub async fn decrement_request_count(&self, instance_id: &str) -> Result<(), RouterError> {
let instances = self.internal.instances.read().await;
if let Some(instance) = instances.iter().find(|inst| inst.id == instance_id) {
instance.active_requests.fetch_sub(1, Ordering::SeqCst);
Ok(())
} else {
Err(RouterError::InstanceNotFound(instance_id.to_string()))
}
}
/// Puts an instance into timeout for the router's configured timeout duration.
///
/// Stores the expiry time, sets the status to `TimedOut` and raises the
/// `is_in_timeout` flag.
///
/// # Errors
///
/// Returns `RouterError::InstanceNotFound` when no such instance is registered.
pub async fn timeout_instance(&self, instance_id: &str) -> Result<(), RouterError> {
    let instances = self.internal.instances.read().await;
    let instance = match instances.iter().find(|inst| inst.id == instance_id) {
        Some(instance) => instance,
        None => {
            error!(instance_id = %instance_id, "Instance not found for timeout.");
            return Err(RouterError::InstanceNotFound(instance_id.to_string()));
        }
    };

    let timeout_duration = self.internal.instance_timeout_duration;
    let expiry_time = Instant::now() + timeout_duration;
    // Publish the expiry and status BEFORE raising the flag. The health check
    // clears the flag when it sees it set without a pending expiry, so setting
    // the flag first would let a concurrent check cancel this timeout.
    *instance.timeout_until.write().await = Some(expiry_time);
    *instance.status.write().await = InstanceStatus::TimedOut;
    instance.is_in_timeout.store(true, Ordering::SeqCst);
    warn!(instance_id = %instance_id, duration = ?timeout_duration, "Instance timed out.");
    Ok(())
}
/// Reports whether the instance with `instance_id` lists `capability` for `model_name`.
///
/// An unknown model yields `Ok(false)`.
///
/// # Errors
///
/// Returns `RouterError::InstanceNotFound` when no such instance is registered.
async fn instance_supports_model(
    &self,
    instance_id: &str,
    model_name: &str,
    capability: &ModelCapability,
) -> Result<bool, RouterError> {
    let registry = self.internal.instances.read().await;
    let target = match registry.iter().find(|inst| inst.id == instance_id) {
        Some(found) => found,
        None => return Err(RouterError::InstanceNotFound(instance_id.to_string())),
    };
    let model_table = target.supported_models.read().await;
    let supported = model_table
        .get(model_name)
        .map_or(false, |caps| caps.contains(capability));
    Ok(supported)
}
}
/// Guard that counts one active request against an instance for as long as it lives.
#[derive(Debug)]
pub struct RequestTracker {
// `router`: handle used to adjust the counter.
// `instance_id`: id of the instance this tracker counts against.
router: Router, instance_id: String,
/// Whether the increment in `new` succeeded; `Drop` only decrements when true.
incremented: bool,
}
impl RequestTracker {
pub async fn new(router: Router, instance_id: String) -> Self {
let increment_result = router.increment_request_count(&instance_id).await;
let incremented = match increment_result {
Ok(_) => true,
Err(e) => {
error!(instance_id = %instance_id, error = ?e, "Failed to increment request count in RequestTracker");
false
}
};
RequestTracker { router, instance_id, incremented }
}
}
impl Drop for RequestTracker {
fn drop(&mut self) {
if self.incremented {
let router = self.router.clone();
let instance_id = self.instance_id.clone();
tokio::spawn(async move {
if let Err(e) = router.decrement_request_count(&instance_id).await {
error!(instance_id = %instance_id, error = ?e, "Failed to decrement request count in RequestTracker drop");
}
});
}
}
}
/// Namespace for the background health check task; it carries no data.
struct HealthCheck;
impl HealthCheck {
fn spawn(
state: Arc<RouterInternalState>,
mut stop_rx: oneshot::Receiver<()>,
health_check_client: reqwest::Client, health_check_interval_duration: Duration,
_health_check_timeout: Duration, health_check_path: String,
) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(health_check_interval_duration);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
loop {
tokio::select! {
_ = interval.tick() => {
debug!("Running health check cycle...");
Self::perform_checks(&state, &health_check_client, &health_check_path, _health_check_timeout).await;
}
_ = &mut stop_rx => {
info!("Stopping health check task.");
break;
}
}
}
});
}
async fn perform_checks(
state: &Arc<RouterInternalState>,
health_check_client: &reqwest::Client, health_check_path: &str,
_health_check_timeout: Duration, ) {
let instances = state.instances.read().await.clone();
const CONCURRENCY_LIMIT: usize = 10;
stream::iter(instances.into_iter())
.map(|instance| {
let client = health_check_client.clone();
let path = health_check_path.to_string();
async move {
if instance.is_in_timeout.load(Ordering::SeqCst) {
let expiry = *instance.timeout_until.read().await;
if let Some(exp) = expiry {
if Instant::now() < exp {
debug!(instance_id=%instance.id, "Skipping health check, instance in timeout.");
return; } else {
instance.is_in_timeout.store(false, Ordering::SeqCst);
*instance.timeout_until.write().await = None;
info!(instance_id=%instance.id, "Instance timeout expired during health check cycle.");
}
} else {
instance.is_in_timeout.store(false, Ordering::SeqCst);
}
}
let url = format!("{}/{}", instance.base_url.trim_end_matches('/'), path.trim_start_matches('/'));
let result = client.get(&url).send().await;
let new_status = match result {
Ok(response) => {
if response.status().is_success() {
InstanceStatus::Healthy
} else {
warn!(instance_id=%instance.id, status=%response.status(), url=%url, "Health check failed with status");
InstanceStatus::Unhealthy
}
}
Err(e) => {
warn!(instance_id=%instance.id, error=?e, url=%url, "Health check request failed");
InstanceStatus::Unhealthy
}
};
let mut status_guard = instance.status.write().await;
if *status_guard != new_status {
info!(instance_id=%instance.id, old_status=?*status_guard, new_status=?new_status, "Instance health status changed");
*status_guard = new_status;
}
}
})
.buffer_unordered(CONCURRENCY_LIMIT)
.collect::<()>()
.await;
debug!("Finished health check cycle.");
}
}
impl Drop for Router {
fn drop(&mut self) {
if let Ok(mut guard) = self.stop_health_check_tx.try_write() {
if let Some(tx) = guard.take() {
let _ = tx.send(()); info!("Sent stop signal to health check task via try_write.");
}
} else {
}
}
}
// Unit tests for the builder and for instance/model registration on the router.
#[cfg(test)]
mod tests {
use super::*; use crate::types::{ModelCapability, InstanceStatus}; use std::time::Duration;
// Builds a round-robin router with no instances and otherwise default settings.
fn create_test_router() -> Router {
Router::builder()
.strategy(RoutingStrategy::RoundRobin)
.build()
}
// A fresh builder carries the module-level defaults.
#[tokio::test]
async fn test_builder_defaults() {
let builder = RouterBuilder::new();
assert_eq!(builder.strategy, RoutingStrategy::LoadBased); assert_eq!(builder.health_check_interval, DEFAULT_HEALTH_CHECK_INTERVAL);
assert_eq!(builder.health_check_timeout, DEFAULT_HEALTH_CHECK_TIMEOUT);
assert_eq!(builder.health_check_path, DEFAULT_HEALTH_CHECK_PATH);
assert_eq!(builder.instance_timeout_duration, DEFAULT_INSTANCE_TIMEOUT);
assert!(builder.initial_instances.is_empty());
assert!(builder.reqwest_client.is_none());
}
// Every setter stores the value it was given.
#[tokio::test]
async fn test_builder_custom_config() {
let client = reqwest::Client::new();
let builder = RouterBuilder::new()
.strategy(RoutingStrategy::RoundRobin)
.health_check_interval(Duration::from_secs(10))
.health_check_timeout(Duration::from_secs(5))
.health_check_path("/status")
.instance_timeout_duration(Duration::from_secs(60))
.http_client(client.clone())
.instance("test1", "http://localhost:8080");
assert_eq!(builder.strategy, RoutingStrategy::RoundRobin);
assert_eq!(builder.health_check_interval, Duration::from_secs(10));
assert_eq!(builder.health_check_timeout, Duration::from_secs(5));
assert_eq!(builder.health_check_path, "/status");
assert_eq!(builder.instance_timeout_duration, Duration::from_secs(60));
assert_eq!(builder.initial_instances.len(), 1);
assert!(builder.reqwest_client.is_some());
}
// Adding, rejecting a duplicate id, removing, and removing an unknown id.
#[tokio::test]
async fn test_add_remove_instance() {
let router = create_test_router();
assert!(router.get_instances().await.is_empty());
let add_result = router.add_instance("inst1", "http://127.0.0.1:1111").await;
assert!(add_result.is_ok());
let instances = router.get_instances().await;
assert_eq!(instances.len(), 1);
assert_eq!(instances[0].id, "inst1");
assert_eq!(*instances[0].status.read().await, InstanceStatus::Unknown);
let add_duplicate_result = router.add_instance("inst1", "http://127.0.0.1:2222").await;
assert!(add_duplicate_result.is_err());
assert!(matches!(add_duplicate_result.unwrap_err(), RouterError::InstanceExists(_)));
assert_eq!(router.get_instances().await.len(), 1);
let remove_result = router.remove_instance("inst1").await;
assert!(remove_result.is_ok());
assert!(router.get_instances().await.is_empty());
let remove_nonexistent_result = router.remove_instance("inst2").await;
assert!(remove_nonexistent_result.is_err());
assert!(matches!(remove_nonexistent_result.unwrap_err(), RouterError::InstanceNotFound(_)));
}
// Models passed at registration end up in the instance's supported-models map.
#[tokio::test]
async fn test_add_instance_with_models() {
let router = create_test_router();
let models = vec![
ModelInstanceConfig { model_name: "gpt-4".to_string(), capabilities: vec![ModelCapability::Chat] },
ModelInstanceConfig { model_name: "text-embed".to_string(), capabilities: vec![ModelCapability::Embedding] },
];
let add_result = router.add_instance_with_models("inst_models", "http://127.0.0.1:3333", models).await;
assert!(add_result.is_ok());
let instances = router.get_instances().await;
assert_eq!(instances.len(), 1);
let instance = &instances[0];
assert_eq!(instance.id, "inst_models");
let supported_models = instance.supported_models.read().await;
assert_eq!(supported_models.len(), 2);
assert!(supported_models.contains_key("gpt-4"));
assert!(supported_models.contains_key("text-embed"));
assert_eq!(supported_models.get("gpt-4").unwrap(), &vec![ModelCapability::Chat]);
}
// A model can be added to an existing instance; an unknown instance id is an error.
#[tokio::test]
async fn test_add_model_to_instance() {
let router = create_test_router();
router.add_instance("inst_add_model", "http://127.0.0.1:4444").await.unwrap();
let add_model_result = router.add_model_to_instance(
"inst_add_model",
"new-model".to_string(),
vec![ModelCapability::Completion]
).await;
assert!(add_model_result.is_ok());
let instances = router.get_instances().await;
let instance = instances.iter().find(|i| i.id == "inst_add_model").unwrap();
let supported_models = instance.supported_models.read().await;
assert_eq!(supported_models.len(), 1);
assert!(supported_models.contains_key("new-model"));
assert_eq!(supported_models.get("new-model").unwrap(), &vec![ModelCapability::Completion]);
let add_to_nonexistent = router.add_model_to_instance(
"nonexistent",
"test".to_string(),
vec![]
).await;
assert!(add_to_nonexistent.is_err());
assert!(matches!(add_to_nonexistent.unwrap_err(), RouterError::InstanceNotFound(_)));
}
}