use std::collections::HashMap;
/// Identifier of a model endpoint, stored as a plain string.
///
/// An id that contains a `:` is split at the first `:`; the part after it is
/// reported as the adapter name. An id without any `:` is a base id.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelId(pub String);

impl ModelId {
    /// Builds an id from anything convertible into a `String`.
    pub fn new(id: impl Into<String>) -> Self {
        ModelId(id.into())
    }

    /// Borrows the raw identifier text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Returns `true` when the id contains no `:` separator.
    pub fn is_base(&self) -> bool {
        // No separator means there is nothing to split off as an adapter.
        self.adapter_name().is_none()
    }

    /// Returns everything after the first `:`, or `None` when there is no `:`.
    ///
    /// The returned slice may be empty (e.g. for `"model:"`).
    pub fn adapter_name(&self) -> Option<&str> {
        let (_prefix, adapter) = self.0.split_once(':')?;
        Some(adapter)
    }
}

impl std::fmt::Display for ModelId {
    /// Writes the raw identifier text, ignoring width/padding flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Lifecycle state of a model endpoint.
///
/// Only [`EndpointStatus::Ready`] counts as available; see
/// [`EndpointStatus::is_available`].
// `Eq` and `Hash` are derived alongside `PartialEq` so the status can be used
// as a map/set key and in `assert_eq!`-style total comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EndpointStatus {
    /// The endpoint can serve requests.
    Ready,
    /// Reported by [`EndpointStatus::name`] as `"loading"`; not available.
    Loading,
    /// Reported by [`EndpointStatus::name`] as `"error"`; not available.
    Error,
    /// Reported by [`EndpointStatus::name`] as `"disabled"`; not available.
    Disabled,
}

impl EndpointStatus {
    /// Returns `true` only for [`EndpointStatus::Ready`].
    pub fn is_available(&self) -> bool {
        matches!(self, Self::Ready)
    }

    /// Returns the lowercase, static label of this status.
    pub fn name(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a label choice.
        match self {
            Self::Ready => "ready",
            Self::Loading => "loading",
            Self::Error => "error",
            Self::Disabled => "disabled",
        }
    }
}
/// Description of one servable model endpoint.
///
/// Create it with [`ModelEndpoint::new`] and refine it with the chainable
/// `with_*` / `set_default` methods, each of which consumes and returns the
/// endpoint.
#[derive(Debug, Clone)]
pub struct ModelEndpoint {
    /// Identifier under which the endpoint is registered and looked up.
    pub id: ModelId,
    /// Human-facing name; `new` initialises it to the id text.
    pub display_name: String,
    /// Free-form description; empty until `with_description` is called.
    pub description: String,
    /// Name of the underlying base model, as passed to `new`.
    pub base_model: String,
    /// Adapter name, if one was attached with `with_adapter`.
    pub adapter: Option<String>,
    /// Largest context length this endpoint accepts (defaults to 4096).
    // NOTE(review): unit is presumably tokens — confirm against callers.
    pub max_context_length: usize,
    /// When `true`, registering this endpoint makes it the registry default.
    pub is_default: bool,
    /// Current lifecycle state; `new` starts endpoints as `Ready`.
    pub status: EndpointStatus,
}

impl ModelEndpoint {
    /// Creates an endpoint with the given id and base model.
    ///
    /// Defaults: display name equal to the id, empty description, no adapter,
    /// a context length of 4096, not the default model, status `Ready`.
    pub fn new(id: impl Into<String>, base_model: impl Into<String>) -> Self {
        let id = ModelId::new(id);
        Self {
            display_name: id.as_str().to_owned(),
            id,
            description: String::new(),
            base_model: base_model.into(),
            adapter: None,
            max_context_length: 4096,
            is_default: false,
            status: EndpointStatus::Ready,
        }
    }

    /// Returns the endpoint with its adapter set to `adapter`.
    pub fn with_adapter(self, adapter: impl Into<String>) -> Self {
        Self {
            adapter: Some(adapter.into()),
            ..self
        }
    }

    /// Returns the endpoint with its description replaced by `desc`.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: desc.into(),
            ..self
        }
    }

    /// Returns the endpoint with its maximum context length set to `ctx`.
    pub fn with_context_length(self, ctx: usize) -> Self {
        Self {
            max_context_length: ctx,
            ..self
        }
    }

    /// Returns the endpoint flagged as the default model.
    ///
    /// Despite the name this is a consuming builder step, not an in-place
    /// setter.
    pub fn set_default(self) -> Self {
        Self {
            is_default: true,
            ..self
        }
    }
}
/// In-memory catalogue of model endpoints, their aliases and the default model.
#[derive(Debug)]
pub struct ModelRegistry {
    /// Registered endpoints keyed by their id.
    endpoints: HashMap<ModelId, ModelEndpoint>,
    /// Alternative names mapped to the id they stand for.
    aliases: HashMap<String, ModelId>,
    /// Id of the endpoint returned by `default_endpoint`, if one was chosen.
    default_model: Option<ModelId>,
}

impl ModelRegistry {
    /// Creates an empty registry with no default model.
    pub fn new() -> Self {
        Self {
            endpoints: HashMap::new(),
            aliases: HashMap::new(),
            default_model: None,
        }
    }

    /// Adds `endpoint`, replacing any endpoint already stored under the same id.
    ///
    /// If `endpoint.is_default` is set, it becomes the registry default and the
    /// previously registered default (if it is a different endpoint) has its
    /// `is_default` flag cleared, so at most one stored endpoint registered as
    /// default keeps the flag.
    ///
    /// Re-registering the current default's id with `is_default == false` does
    /// not change which id the registry treats as default.
    pub fn register(&mut self, endpoint: ModelEndpoint) {
        if endpoint.is_default {
            if let Some(previous) = self.default_model.replace(endpoint.id.clone()) {
                // The same id is about to be overwritten below; only a
                // *different* former default needs to be demoted.
                if previous != endpoint.id {
                    if let Some(stale) = self.endpoints.get_mut(&previous) {
                        stale.is_default = false;
                    }
                }
            }
        }
        self.endpoints.insert(endpoint.id.clone(), endpoint);
    }

    /// Removes the endpoint stored under `id` and returns it.
    ///
    /// When an endpoint was removed, the default is cleared if it pointed at
    /// `id`, and every alias targeting `id` is dropped. Returns `None` (and
    /// changes nothing) when `id` is not registered.
    pub fn unregister(&mut self, id: &ModelId) -> Option<ModelEndpoint> {
        let removed = self.endpoints.remove(id)?;
        if self.default_model.as_ref() == Some(id) {
            self.default_model = None;
        }
        self.aliases.retain(|_, target| target != id);
        Some(removed)
    }

    /// Maps `alias` to `target`, overwriting any existing mapping for `alias`.
    ///
    /// The target is not checked against the registered endpoints; an alias
    /// whose target is missing simply fails to resolve.
    pub fn add_alias(&mut self, alias: impl Into<String>, target: ModelId) {
        self.aliases.insert(alias.into(), target);
    }

    /// Looks up an endpoint by id first, then by alias.
    ///
    /// A registered id always wins over an alias with the same text.
    pub fn resolve(&self, id_or_alias: &str) -> Option<&ModelEndpoint> {
        self.endpoints
            .get(&ModelId::new(id_or_alias))
            .or_else(|| {
                self.aliases
                    .get(id_or_alias)
                    .and_then(|target| self.endpoints.get(target))
            })
    }

    /// Returns the endpoint currently designated as default, if any.
    pub fn default_endpoint(&self) -> Option<&ModelEndpoint> {
        self.default_model
            .as_ref()
            .and_then(|id| self.endpoints.get(id))
    }

    /// Returns every endpoint whose status is available, in map iteration order.
    pub fn available_endpoints(&self) -> Vec<&ModelEndpoint> {
        self.endpoints
            .values()
            .filter(|ep| ep.status.is_available())
            .collect()
    }

    /// Returns every registered endpoint, in map iteration order.
    pub fn all_endpoints(&self) -> Vec<&ModelEndpoint> {
        self.endpoints.values().collect()
    }

    /// Sets the status of the endpoint stored under `id`.
    ///
    /// Returns `true` when the endpoint exists and was updated, `false`
    /// otherwise.
    pub fn set_status(&mut self, id: &ModelId, status: EndpointStatus) -> bool {
        match self.endpoints.get_mut(id) {
            Some(ep) => {
                ep.status = status;
                true
            }
            None => false,
        }
    }

    /// Number of registered endpoints (aliases are not counted).
    pub fn len(&self) -> usize {
        self.endpoints.len()
    }

    /// Returns `true` when no endpoint is registered.
    pub fn is_empty(&self) -> bool {
        self.endpoints.is_empty()
    }
}

impl Default for ModelRegistry {
    /// Equivalent to [`ModelRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Reasons why `ModelRouter` could not select an endpoint for a request.
#[derive(Debug, thiserror::Error)]
pub enum RoutingError {
/// The requested name resolved to neither a registered model id nor an alias.
/// Carries the name that was requested.
#[error("model '{0}' not found")]
ModelNotFound(String),
/// The selected model's maximum context length is smaller than required and
/// no suitable replacement was chosen.
#[error("model '{model}' cannot handle context length {required} (max: {available})")]
ContextTooLong {
// Id of the endpoint that was too small.
model: String,
// Context length the request needs.
required: usize,
// Maximum context length of `model`.
available: usize,
},
/// Returned by `ModelRouter::route` when no model was requested and the
/// registry has no default endpoint.
#[error("no models are currently available")]
NoModelsAvailable,
/// The selected endpoint exists but its status is not available.
/// Fields are the endpoint id and the status label, in that order.
#[error("model '{0}' is not ready (status: {1})")]
ModelNotReady(String, String),
}
/// Selects the endpoint that should serve a request, backed by a
/// [`ModelRegistry`].
pub struct ModelRouter {
    registry: ModelRegistry,
}

impl ModelRouter {
    /// Wraps `registry` in a router.
    pub fn new(registry: ModelRegistry) -> Self {
        Self { registry }
    }

    /// Picks the endpoint for `requested_model`, or the registry default when
    /// no model is named.
    ///
    /// # Errors
    ///
    /// * [`RoutingError::ModelNotFound`] – a name was given but resolves to
    ///   neither an id nor an alias.
    /// * [`RoutingError::NoModelsAvailable`] – no name was given and the
    ///   registry has no default endpoint.
    /// * [`RoutingError::ModelNotReady`] – the selected endpoint's status is
    ///   not available.
    pub fn route(&self, requested_model: Option<&str>) -> Result<&ModelEndpoint, RoutingError> {
        let endpoint = match requested_model {
            Some(model_name) => self
                .registry
                .resolve(model_name)
                .ok_or_else(|| RoutingError::ModelNotFound(model_name.to_string()))?,
            None => self
                .registry
                .default_endpoint()
                .ok_or(RoutingError::NoModelsAvailable)?,
        };
        if endpoint.status.is_available() {
            Ok(endpoint)
        } else {
            Err(RoutingError::ModelNotReady(
                endpoint.id.to_string(),
                endpoint.status.name().to_string(),
            ))
        }
    }

    /// Like [`ModelRouter::route`], but also requires the endpoint to accept
    /// `required_context`.
    ///
    /// An explicitly requested model is never substituted. When no model was
    /// requested and the default is too small, the available endpoint with the
    /// largest sufficient context length is used instead; ties are broken by
    /// the lexicographically smallest id so the choice does not depend on hash
    /// map iteration order.
    ///
    /// # Errors
    ///
    /// Every error of [`ModelRouter::route`], plus
    /// [`RoutingError::ContextTooLong`] (describing the originally selected
    /// endpoint) when no acceptable endpoint exists.
    pub fn route_for_context(
        &self,
        requested_model: Option<&str>,
        required_context: usize,
    ) -> Result<&ModelEndpoint, RoutingError> {
        let endpoint = self.route(requested_model)?;
        if endpoint.max_context_length >= required_context {
            return Ok(endpoint);
        }

        // Built lazily: the id string is only allocated if the error is returned.
        let too_long = || RoutingError::ContextTooLong {
            model: endpoint.id.to_string(),
            required: required_context,
            available: endpoint.max_context_length,
        };

        if requested_model.is_some() {
            return Err(too_long());
        }

        self.registry
            .available_endpoints()
            .into_iter()
            .filter(|ep| ep.max_context_length >= required_context)
            .max_by(|a, b| {
                a.max_context_length
                    .cmp(&b.max_context_length)
                    // Reversed id comparison: among equal context lengths the
                    // smallest id ranks highest.
                    .then_with(|| b.id.as_str().cmp(a.id.as_str()))
            })
            .ok_or_else(too_long)
    }

    /// Lists the currently available endpoints, sorted by id.
    ///
    /// Every entry shares the same `created` value: the number of whole seconds
    /// since the Unix epoch at the time of the call (0 if the system clock is
    /// before the epoch).
    pub fn models_list(&self) -> Vec<ModelListEntry> {
        let created = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let mut entries: Vec<ModelListEntry> = self
            .registry
            .available_endpoints()
            .into_iter()
            .map(|ep| ModelListEntry {
                id: ep.id.to_string(),
                object: "model".to_string(),
                owned_by: "oxibonsai".to_string(),
                created,
            })
            .collect();
        // Ids are unique map keys, so an unstable sort is fully deterministic.
        entries.sort_unstable_by(|a, b| a.id.cmp(&b.id));
        entries
    }

    /// Shared access to the underlying registry.
    pub fn registry(&self) -> &ModelRegistry {
        &self.registry
    }

    /// Exclusive access to the underlying registry.
    pub fn registry_mut(&mut self) -> &mut ModelRegistry {
        &mut self.registry
    }
}
/// One row of the listing produced by `ModelRouter::models_list`.
#[derive(Debug, Clone)]
pub struct ModelListEntry {
// Endpoint id rendered as a string.
pub id: String,
// Object kind; `models_list` always sets this to "model".
pub object: String,
// Owner label; `models_list` always sets this to "oxibonsai".
pub owned_by: String,
// Seconds since the Unix epoch at the moment the listing was built.
pub created: u64,
}
/// A named adapter together with the weight it carries inside an
/// [`AdapterStack`].
#[derive(Debug, Clone, PartialEq)]
pub struct AdapterRef {
    /// Name of the adapter.
    pub name: String,
    /// Weight assigned to the adapter.
    pub weight: f32,
}

/// Ordered list of weighted adapters, built by chaining [`AdapterStack::add`].
#[derive(Debug, Clone, Default, PartialEq)]
pub struct AdapterStack {
    /// Adapters in insertion order.
    pub adapters: Vec<AdapterRef>,
}

impl AdapterStack {
    /// Creates an empty stack.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends an adapter and returns the stack for further chaining.
    ///
    /// The stack is consumed, so the return value must be kept.
    #[must_use]
    pub fn add(mut self, name: impl Into<String>, weight: f32) -> Self {
        self.adapters.push(AdapterRef {
            name: name.into(),
            weight,
        });
        self
    }

    /// Number of adapters in the stack.
    pub fn len(&self) -> usize {
        self.adapters.len()
    }

    /// Returns `true` when the stack holds no adapters.
    pub fn is_empty(&self) -> bool {
        self.adapters.is_empty()
    }

    /// Sum of all adapter weights (0.0 for an empty stack).
    pub fn total_weight(&self) -> f32 {
        self.adapters.iter().map(|a| a.weight).sum()
    }

    /// Scales every weight so that the weights sum to 1.
    ///
    /// The stack is left untouched when normalisation is impossible: when the
    /// total is (almost) zero, or when it is NaN or infinite — dividing by such
    /// a total would turn every weight into NaN or 0.
    pub fn normalize_weights(&mut self) {
        let total = self.total_weight();
        if !total.is_finite() || total.abs() < f32::EPSILON {
            return;
        }
        for adapter in &mut self.adapters {
            adapter.weight /= total;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` for `ModelId` prints the raw identifier text.
    #[test]
    fn model_id_display() {
        let rendered = ModelId::new("bonsai-8b").to_string();
        assert_eq!(rendered, "bonsai-8b");
    }

    /// Every status variant maps to its lowercase label.
    #[test]
    fn endpoint_status_name() {
        let expected = [
            (EndpointStatus::Ready, "ready"),
            (EndpointStatus::Loading, "loading"),
            (EndpointStatus::Error, "error"),
            (EndpointStatus::Disabled, "disabled"),
        ];
        for (status, label) in expected {
            assert_eq!(status.name(), label);
        }
    }

    /// A freshly built endpoint uses its id as the display name.
    #[test]
    fn endpoint_display_name_defaults_to_id() {
        let endpoint = ModelEndpoint::new("bonsai-8b", "qwen3-8b");
        assert_eq!(endpoint.display_name, "bonsai-8b");
    }
}