Skip to main content

oxibonsai_runtime/
multi_model.rs

1//! Multi-model serving: manage base models + LoRA adapters with smart routing.
2//!
3//! Supports:
4//! - Multiple base model configurations
5//! - Hot-swappable LoRA adapter registry
6//! - Request routing by model ID
7//! - Model alias resolution
8//! - Adapter composition (stacking multiple LoRAs)
9
10use std::collections::HashMap;
11
12// ─────────────────────────────────────────────────────────────────────────────
13// ModelId
14// ─────────────────────────────────────────────────────────────────────────────
15
/// A model endpoint identifier.
///
/// Uses a convention where `"base_name"` denotes a base model and
/// `"base_name:adapter_name"` denotes a base model with a LoRA adapter applied
/// (see [`ModelId::is_base`] and [`ModelId::adapter_name`]).
///
/// Derives `Eq` + `Hash` so it can be used directly as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelId(pub String);
22
23impl ModelId {
24    /// Create a new model identifier from any string-like value.
25    pub fn new(id: impl Into<String>) -> Self {
26        Self(id.into())
27    }
28
29    /// Return the identifier as a string slice.
30    pub fn as_str(&self) -> &str {
31        &self.0
32    }
33
34    /// Returns `true` if this is a base model (no `":"` separator).
35    pub fn is_base(&self) -> bool {
36        !self.0.contains(':')
37    }
38
39    /// If the identifier has the form `"base:adapter"`, return `Some("adapter")`.
40    /// Otherwise return `None`.
41    pub fn adapter_name(&self) -> Option<&str> {
42        self.0.split_once(':').map(|(_, adapter)| adapter)
43    }
44}
45
46impl std::fmt::Display for ModelId {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        write!(f, "{}", self.0)
49    }
50}
51
52// ─────────────────────────────────────────────────────────────────────────────
53// EndpointStatus
54// ─────────────────────────────────────────────────────────────────────────────
55
/// Status of a model endpoint.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` — the comparison is
/// structural and total for this fieldless enum, and the extra derives let
/// the status be used as a map/set key (clippy: `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EndpointStatus {
    /// The model is loaded and ready to serve requests.
    Ready,
    /// The model is currently being loaded.
    Loading,
    /// The model encountered an error and is unavailable.
    Error,
    /// The model has been explicitly disabled by an administrator.
    Disabled,
}
68
69impl EndpointStatus {
70    /// Returns `true` if the endpoint is available for serving requests.
71    pub fn is_available(&self) -> bool {
72        *self == Self::Ready
73    }
74
75    /// Human-readable name for this status.
76    pub fn name(&self) -> &'static str {
77        match self {
78            Self::Ready => "ready",
79            Self::Loading => "loading",
80            Self::Error => "error",
81            Self::Disabled => "disabled",
82        }
83    }
84}
85
86// ─────────────────────────────────────────────────────────────────────────────
87// ModelEndpoint
88// ─────────────────────────────────────────────────────────────────────────────
89
/// Metadata for a served model variant.
///
/// Each endpoint represents a unique model configuration that can receive
/// inference requests. A base model may have multiple endpoints, each with
/// a different LoRA adapter applied.
///
/// Construct with [`ModelEndpoint::new`] and customize via the builder-style
/// `with_*` / `set_default` methods.
#[derive(Debug, Clone)]
pub struct ModelEndpoint {
    /// Unique identifier for this endpoint.
    pub id: ModelId,
    /// Human-readable display name (defaults to the ID string in `new`).
    pub display_name: String,
    /// Longer description of what this endpoint provides.
    pub description: String,
    /// Name of the underlying base model.
    pub base_model: String,
    /// Optional LoRA adapter name applied on top of the base model.
    pub adapter: Option<String>,
    /// Maximum context length (in tokens) this endpoint supports.
    pub max_context_length: usize,
    /// Whether this endpoint is the default when no model is specified.
    pub is_default: bool,
    /// Current operational status.
    pub status: EndpointStatus,
}
114
115impl ModelEndpoint {
116    /// Create a new endpoint with sensible defaults.
117    ///
118    /// Status is set to `Ready`, no adapter, default context length of 4096.
119    pub fn new(id: impl Into<String>, base_model: impl Into<String>) -> Self {
120        let id_str: String = id.into();
121        let base: String = base_model.into();
122        Self {
123            display_name: id_str.clone(),
124            id: ModelId::new(id_str),
125            description: String::new(),
126            base_model: base,
127            adapter: None,
128            max_context_length: 4096,
129            is_default: false,
130            status: EndpointStatus::Ready,
131        }
132    }
133
134    /// Attach a LoRA adapter to this endpoint.
135    pub fn with_adapter(mut self, adapter: impl Into<String>) -> Self {
136        self.adapter = Some(adapter.into());
137        self
138    }
139
140    /// Set a human-readable description.
141    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
142        self.description = desc.into();
143        self
144    }
145
146    /// Set the maximum context length.
147    pub fn with_context_length(mut self, ctx: usize) -> Self {
148        self.max_context_length = ctx;
149        self
150    }
151
152    /// Mark this endpoint as the default.
153    pub fn set_default(mut self) -> Self {
154        self.is_default = true;
155        self
156    }
157}
158
159// ─────────────────────────────────────────────────────────────────────────────
160// ModelRegistry
161// ─────────────────────────────────────────────────────────────────────────────
162
/// The multi-model registry.
///
/// Manages a collection of [`ModelEndpoint`] instances and supports alias
/// resolution so that clients can refer to models by friendly names.
pub struct ModelRegistry {
    // Primary endpoint store, keyed by exact model ID.
    endpoints: HashMap<ModelId, ModelEndpoint>,
    // Friendly-name aliases resolved to canonical model IDs.
    aliases: HashMap<String, ModelId>,
    // ID of the endpoint used when a request names no model, if any.
    default_model: Option<ModelId>,
}
172
173impl ModelRegistry {
174    /// Create an empty registry.
175    pub fn new() -> Self {
176        Self {
177            endpoints: HashMap::new(),
178            aliases: HashMap::new(),
179            default_model: None,
180        }
181    }
182
183    /// Register a model endpoint.
184    ///
185    /// If the endpoint has `is_default` set, it becomes the default model.
186    /// Replaces any existing endpoint with the same ID.
187    pub fn register(&mut self, endpoint: ModelEndpoint) {
188        if endpoint.is_default {
189            self.default_model = Some(endpoint.id.clone());
190        }
191        self.endpoints.insert(endpoint.id.clone(), endpoint);
192    }
193
194    /// Remove an endpoint from the registry.
195    ///
196    /// Also clears the default-model pointer if it pointed to the removed
197    /// endpoint, and removes any aliases that targeted this ID.
198    pub fn unregister(&mut self, id: &ModelId) -> Option<ModelEndpoint> {
199        let removed = self.endpoints.remove(id);
200        if removed.is_some() {
201            // Clear default if it was this model.
202            if self.default_model.as_ref() == Some(id) {
203                self.default_model = None;
204            }
205            // Remove aliases pointing to this model.
206            self.aliases.retain(|_, target| target != id);
207        }
208        removed
209    }
210
211    /// Add an alias: e.g. `"gpt-4"` maps to `ModelId("bonsai-8b")`.
212    pub fn add_alias(&mut self, alias: impl Into<String>, target: ModelId) {
213        self.aliases.insert(alias.into(), target);
214    }
215
216    /// Resolve a model identifier (checks ID first, then aliases).
217    ///
218    /// Returns `None` if neither a direct ID nor an alias matches.
219    pub fn resolve(&self, id_or_alias: &str) -> Option<&ModelEndpoint> {
220        let model_id = ModelId::new(id_or_alias);
221        if let Some(ep) = self.endpoints.get(&model_id) {
222            return Some(ep);
223        }
224        // Try alias resolution.
225        if let Some(target_id) = self.aliases.get(id_or_alias) {
226            return self.endpoints.get(target_id);
227        }
228        None
229    }
230
231    /// Get the default model endpoint.
232    pub fn default_endpoint(&self) -> Option<&ModelEndpoint> {
233        self.default_model
234            .as_ref()
235            .and_then(|id| self.endpoints.get(id))
236    }
237
238    /// List all available (Ready) endpoints.
239    pub fn available_endpoints(&self) -> Vec<&ModelEndpoint> {
240        self.endpoints
241            .values()
242            .filter(|ep| ep.status.is_available())
243            .collect()
244    }
245
246    /// List all registered endpoints (including non-ready ones).
247    pub fn all_endpoints(&self) -> Vec<&ModelEndpoint> {
248        self.endpoints.values().collect()
249    }
250
251    /// Update an endpoint's status.
252    ///
253    /// Returns `true` if the endpoint was found and updated.
254    pub fn set_status(&mut self, id: &ModelId, status: EndpointStatus) -> bool {
255        if let Some(ep) = self.endpoints.get_mut(id) {
256            ep.status = status;
257            true
258        } else {
259            false
260        }
261    }
262
263    /// Number of registered endpoints.
264    pub fn len(&self) -> usize {
265        self.endpoints.len()
266    }
267
268    /// Is the registry empty?
269    pub fn is_empty(&self) -> bool {
270        self.endpoints.is_empty()
271    }
272}
273
274impl Default for ModelRegistry {
275    fn default() -> Self {
276        Self::new()
277    }
278}
279
280// ─────────────────────────────────────────────────────────────────────────────
281// RoutingError
282// ─────────────────────────────────────────────────────────────────────────────
283
/// Errors that can occur when routing a request to a model endpoint.
///
/// Produced by `ModelRouter::route` and `ModelRouter::route_for_context`.
#[derive(Debug, thiserror::Error)]
pub enum RoutingError {
    /// The requested model was not found in the registry.
    #[error("model '{0}' not found")]
    ModelNotFound(String),

    /// The requested model cannot accommodate the required context length.
    #[error("model '{model}' cannot handle context length {required} (max: {available})")]
    ContextTooLong {
        /// ID of the model that was considered.
        model: String,
        /// Context length (in tokens) the request needs.
        required: usize,
        /// Maximum context length the model supports.
        available: usize,
    },

    /// No models are currently available in the registry.
    #[error("no models are currently available")]
    NoModelsAvailable,

    /// The model was found but is not in a ready state.
    /// Fields: model ID, then status name (e.g. `"loading"`).
    #[error("model '{0}' is not ready (status: {1})")]
    ModelNotReady(String, String),
}
307
308// ─────────────────────────────────────────────────────────────────────────────
309// ModelRouter
310// ─────────────────────────────────────────────────────────────────────────────
311
/// Smart request router: selects the best model endpoint for a request.
///
/// Wraps a [`ModelRegistry`] and adds routing logic including fallback to
/// the default model, context-length awareness, and OpenAI-compatible
/// model listing.
pub struct ModelRouter {
    // The registry this router consults for every routing decision.
    registry: ModelRegistry,
}
320
321impl ModelRouter {
322    /// Create a new router backed by the given registry.
323    pub fn new(registry: ModelRegistry) -> Self {
324        Self { registry }
325    }
326
327    /// Route a request: resolve `model_id` from the request.
328    ///
329    /// Falls back to the default model if `requested_model` is `None`.
330    /// Returns an error if the resolved endpoint is not in a `Ready` state.
331    pub fn route(&self, requested_model: Option<&str>) -> Result<&ModelEndpoint, RoutingError> {
332        let endpoint = match requested_model {
333            Some(model_name) => self
334                .registry
335                .resolve(model_name)
336                .ok_or_else(|| RoutingError::ModelNotFound(model_name.to_string()))?,
337            None => self
338                .registry
339                .default_endpoint()
340                .ok_or(RoutingError::NoModelsAvailable)?,
341        };
342
343        if !endpoint.status.is_available() {
344            return Err(RoutingError::ModelNotReady(
345                endpoint.id.to_string(),
346                endpoint.status.name().to_string(),
347            ));
348        }
349
350        Ok(endpoint)
351    }
352
353    /// Route with context-length awareness: pick a model that can accommodate
354    /// the required context length.
355    ///
356    /// If a specific model is requested, validates it has sufficient context.
357    /// If no model is specified, finds the default model that fits, or falls
358    /// back to any available model with sufficient context capacity.
359    pub fn route_for_context(
360        &self,
361        requested_model: Option<&str>,
362        required_context: usize,
363    ) -> Result<&ModelEndpoint, RoutingError> {
364        let endpoint = self.route(requested_model)?;
365
366        if endpoint.max_context_length < required_context {
367            // If a specific model was requested but is too small, error out.
368            if requested_model.is_some() {
369                return Err(RoutingError::ContextTooLong {
370                    model: endpoint.id.to_string(),
371                    required: required_context,
372                    available: endpoint.max_context_length,
373                });
374            }
375
376            // No specific model requested — try to find any available endpoint
377            // with sufficient context capacity.
378            let fallback = self
379                .registry
380                .available_endpoints()
381                .into_iter()
382                .filter(|ep| ep.max_context_length >= required_context)
383                .max_by_key(|ep| ep.max_context_length);
384
385            return fallback.ok_or(RoutingError::ContextTooLong {
386                model: endpoint.id.to_string(),
387                required: required_context,
388                available: endpoint.max_context_length,
389            });
390        }
391
392        Ok(endpoint)
393    }
394
395    /// OpenAI-compatible `/v1/models` list.
396    ///
397    /// Returns an entry for every available endpoint in the registry.
398    pub fn models_list(&self) -> Vec<ModelListEntry> {
399        let created = std::time::SystemTime::now()
400            .duration_since(std::time::UNIX_EPOCH)
401            .unwrap_or_default()
402            .as_secs();
403
404        self.registry
405            .available_endpoints()
406            .into_iter()
407            .map(|ep| ModelListEntry {
408                id: ep.id.to_string(),
409                object: "model".to_string(),
410                owned_by: "oxibonsai".to_string(),
411                created,
412            })
413            .collect()
414    }
415
416    /// Immutable access to the underlying registry.
417    pub fn registry(&self) -> &ModelRegistry {
418        &self.registry
419    }
420
421    /// Mutable access to the underlying registry.
422    pub fn registry_mut(&mut self) -> &mut ModelRegistry {
423        &mut self.registry
424    }
425}
426
427// ─────────────────────────────────────────────────────────────────────────────
428// ModelListEntry
429// ─────────────────────────────────────────────────────────────────────────────
430
/// Entry for an OpenAI-compatible `/v1/models` response.
#[derive(Debug, Clone)]
pub struct ModelListEntry {
    /// Model identifier string.
    pub id: String,
    /// Object type — always `"model"`.
    pub object: String,
    /// Organisation that owns the model.
    pub owned_by: String,
    /// Unix timestamp. `ModelRouter::models_list` fills this with the time
    /// the listing was generated, not the model's registration time.
    pub created: u64,
}
443
444// ─────────────────────────────────────────────────────────────────────────────
445// AdapterRef / AdapterStack
446// ─────────────────────────────────────────────────────────────────────────────
447
/// A reference to a single LoRA adapter with a blending weight.
#[derive(Debug, Clone)]
pub struct AdapterRef {
    /// Name of the LoRA adapter.
    pub name: String,
    /// Blending weight, conventionally in `[0.0, 1.0]`; the range is not
    /// validated here.
    pub weight: f32,
}
456
/// Adapter composition: apply multiple LoRA adapters in sequence.
///
/// Allows stacking several adapters with independent blending weights.
/// Weights can be normalized (see `normalize_weights`) so they sum to 1.0,
/// which is useful for even blending across adapters.
#[derive(Debug, Clone)]
pub struct AdapterStack {
    /// The ordered list of adapters to apply.
    pub adapters: Vec<AdapterRef>,
}
467
468impl AdapterStack {
469    /// Create an empty adapter stack.
470    pub fn new() -> Self {
471        Self {
472            adapters: Vec::new(),
473        }
474    }
475
476    /// Add an adapter with the given blending weight.
477    pub fn add(mut self, name: impl Into<String>, weight: f32) -> Self {
478        self.adapters.push(AdapterRef {
479            name: name.into(),
480            weight,
481        });
482        self
483    }
484
485    /// Number of adapters in the stack.
486    pub fn len(&self) -> usize {
487        self.adapters.len()
488    }
489
490    /// Whether the stack is empty.
491    pub fn is_empty(&self) -> bool {
492        self.adapters.is_empty()
493    }
494
495    /// Sum of all adapter weights.
496    pub fn total_weight(&self) -> f32 {
497        self.adapters.iter().map(|a| a.weight).sum()
498    }
499
500    /// Normalize weights so they sum to 1.0.
501    ///
502    /// If the total weight is zero (or very close to it), weights are left
503    /// unchanged to avoid division by zero.
504    pub fn normalize_weights(&mut self) {
505        let total = self.total_weight();
506        if total.abs() < f32::EPSILON {
507            return;
508        }
509        for adapter in &mut self.adapters {
510            adapter.weight /= total;
511        }
512    }
513}
514
515impl Default for AdapterStack {
516    fn default() -> Self {
517        Self::new()
518    }
519}
520
521// ─────────────────────────────────────────────────────────────────────────────
522// Unit tests
523// ─────────────────────────────────────────────────────────────────────────────
524
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_id_display() {
        let id = ModelId::new("bonsai-8b");
        assert_eq!(format!("{id}"), "bonsai-8b");
    }

    // New: covers the base/adapter parsing convention on ModelId.
    #[test]
    fn model_id_base_and_adapter_parsing() {
        let base = ModelId::new("bonsai-8b");
        assert!(base.is_base());
        assert_eq!(base.adapter_name(), None);

        let with_adapter = ModelId::new("bonsai-8b:code");
        assert!(!with_adapter.is_base());
        assert_eq!(with_adapter.adapter_name(), Some("code"));
    }

    #[test]
    fn endpoint_status_name() {
        assert_eq!(EndpointStatus::Ready.name(), "ready");
        assert_eq!(EndpointStatus::Loading.name(), "loading");
        assert_eq!(EndpointStatus::Error.name(), "error");
        assert_eq!(EndpointStatus::Disabled.name(), "disabled");
    }

    #[test]
    fn endpoint_display_name_defaults_to_id() {
        let ep = ModelEndpoint::new("bonsai-8b", "qwen3-8b");
        assert_eq!(ep.display_name, "bonsai-8b");
    }

    // New: covers register/alias/resolve/default/unregister bookkeeping.
    #[test]
    fn registry_resolves_ids_aliases_and_default() {
        let mut reg = ModelRegistry::new();
        reg.register(ModelEndpoint::new("bonsai-8b", "qwen3-8b").set_default());
        reg.add_alias("gpt-4", ModelId::new("bonsai-8b"));

        assert_eq!(reg.len(), 1);
        assert!(reg.resolve("bonsai-8b").is_some());
        assert!(reg.resolve("gpt-4").is_some());
        assert!(reg.resolve("missing").is_none());
        assert_eq!(reg.default_endpoint().unwrap().id.as_str(), "bonsai-8b");

        // Unregistering must clear the default pointer and dangling aliases.
        reg.unregister(&ModelId::new("bonsai-8b"));
        assert!(reg.is_empty());
        assert!(reg.default_endpoint().is_none());
        assert!(reg.resolve("gpt-4").is_none());
    }

    // New: covers weight normalization, including proportional rescaling.
    #[test]
    fn adapter_stack_normalizes_weights() {
        let mut stack = AdapterStack::new().add("a", 1.0).add("b", 3.0);
        stack.normalize_weights();
        assert!((stack.total_weight() - 1.0).abs() < 1e-6);
        assert!((stack.adapters[0].weight - 0.25).abs() < 1e-6);
        assert!((stack.adapters[1].weight - 0.75).abs() < 1e-6);

        // Zero total: weights must be left unchanged (no division by zero).
        let mut empty_total = AdapterStack::new().add("a", 0.0);
        empty_total.normalize_weights();
        assert_eq!(empty_total.adapters[0].weight, 0.0);
    }
}