// oxibonsai_runtime/multi_model.rs
1//! Multi-model serving: manage base models + LoRA adapters with smart routing.
2//!
3//! Supports:
4//! - Multiple base model configurations
5//! - Hot-swappable LoRA adapter registry
6//! - Request routing by model ID
7//! - Model alias resolution
8//! - Adapter composition (stacking multiple LoRAs)
9
10use std::collections::HashMap;
11
12// ─────────────────────────────────────────────────────────────────────────────
13// ModelId
14// ─────────────────────────────────────────────────────────────────────────────
15
/// A model endpoint identifier.
///
/// Uses a convention where `"base_name"` denotes a base model and
/// `"base_name:adapter_name"` denotes a base model with a LoRA adapter applied.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelId(pub String);

impl ModelId {
    /// Create a new model identifier from any string-like value.
    pub fn new(id: impl Into<String>) -> Self {
        ModelId(id.into())
    }

    /// Return the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Returns `true` if this is a base model (no `":"` separator).
    pub fn is_base(&self) -> bool {
        self.0.split_once(':').is_none()
    }

    /// If the identifier has the form `"base:adapter"`, return `Some("adapter")`.
    /// Otherwise return `None`.
    pub fn adapter_name(&self) -> Option<&str> {
        match self.0.split_once(':') {
            Some((_base, adapter)) => Some(adapter),
            None => None,
        }
    }
}

impl std::fmt::Display for ModelId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate straight to the underlying string.
        f.write_str(&self.0)
    }
}
51
52// ─────────────────────────────────────────────────────────────────────────────
53// EndpointStatus
54// ─────────────────────────────────────────────────────────────────────────────
55
/// Status of a model endpoint.
//
// `Eq` and `Hash` are derived alongside `PartialEq`: the enum carries no
// data, so the derives are free, and they let statuses be used as
// `HashMap`/`HashSet` keys (e.g. for status counters in metrics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EndpointStatus {
    /// The model is loaded and ready to serve requests.
    Ready,
    /// The model is currently being loaded.
    Loading,
    /// The model encountered an error and is unavailable.
    Error,
    /// The model has been explicitly disabled by an administrator.
    Disabled,
}

impl EndpointStatus {
    /// Returns `true` if the endpoint is available for serving requests.
    ///
    /// Only `Ready` counts as available; `Loading`, `Error`, and `Disabled`
    /// all reject traffic.
    pub fn is_available(&self) -> bool {
        matches!(self, Self::Ready)
    }

    /// Human-readable name for this status (lowercase, stable strings
    /// suitable for logs and API responses).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Ready => "ready",
            Self::Loading => "loading",
            Self::Error => "error",
            Self::Disabled => "disabled",
        }
    }
}
85
86// ─────────────────────────────────────────────────────────────────────────────
87// ModelEndpoint
88// ─────────────────────────────────────────────────────────────────────────────
89
90/// Metadata for a served model variant.
91///
92/// Each endpoint represents a unique model configuration that can receive
93/// inference requests. A base model may have multiple endpoints, each with
94/// a different LoRA adapter applied.
95#[derive(Debug, Clone)]
96pub struct ModelEndpoint {
97 /// Unique identifier for this endpoint.
98 pub id: ModelId,
99 /// Human-readable display name.
100 pub display_name: String,
101 /// Longer description of what this endpoint provides.
102 pub description: String,
103 /// Name of the underlying base model.
104 pub base_model: String,
105 /// Optional LoRA adapter name applied on top of the base model.
106 pub adapter: Option<String>,
107 /// Maximum context length (in tokens) this endpoint supports.
108 pub max_context_length: usize,
109 /// Whether this endpoint is the default when no model is specified.
110 pub is_default: bool,
111 /// Current operational status.
112 pub status: EndpointStatus,
113}
114
115impl ModelEndpoint {
116 /// Create a new endpoint with sensible defaults.
117 ///
118 /// Status is set to `Ready`, no adapter, default context length of 4096.
119 pub fn new(id: impl Into<String>, base_model: impl Into<String>) -> Self {
120 let id_str: String = id.into();
121 let base: String = base_model.into();
122 Self {
123 display_name: id_str.clone(),
124 id: ModelId::new(id_str),
125 description: String::new(),
126 base_model: base,
127 adapter: None,
128 max_context_length: 4096,
129 is_default: false,
130 status: EndpointStatus::Ready,
131 }
132 }
133
134 /// Attach a LoRA adapter to this endpoint.
135 pub fn with_adapter(mut self, adapter: impl Into<String>) -> Self {
136 self.adapter = Some(adapter.into());
137 self
138 }
139
140 /// Set a human-readable description.
141 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
142 self.description = desc.into();
143 self
144 }
145
146 /// Set the maximum context length.
147 pub fn with_context_length(mut self, ctx: usize) -> Self {
148 self.max_context_length = ctx;
149 self
150 }
151
152 /// Mark this endpoint as the default.
153 pub fn set_default(mut self) -> Self {
154 self.is_default = true;
155 self
156 }
157}
158
159// ─────────────────────────────────────────────────────────────────────────────
160// ModelRegistry
161// ─────────────────────────────────────────────────────────────────────────────
162
163/// The multi-model registry.
164///
165/// Manages a collection of [`ModelEndpoint`] instances and supports alias
166/// resolution so that clients can refer to models by friendly names.
167pub struct ModelRegistry {
168 endpoints: HashMap<ModelId, ModelEndpoint>,
169 aliases: HashMap<String, ModelId>,
170 default_model: Option<ModelId>,
171}
172
173impl ModelRegistry {
174 /// Create an empty registry.
175 pub fn new() -> Self {
176 Self {
177 endpoints: HashMap::new(),
178 aliases: HashMap::new(),
179 default_model: None,
180 }
181 }
182
183 /// Register a model endpoint.
184 ///
185 /// If the endpoint has `is_default` set, it becomes the default model.
186 /// Replaces any existing endpoint with the same ID.
187 pub fn register(&mut self, endpoint: ModelEndpoint) {
188 if endpoint.is_default {
189 self.default_model = Some(endpoint.id.clone());
190 }
191 self.endpoints.insert(endpoint.id.clone(), endpoint);
192 }
193
194 /// Remove an endpoint from the registry.
195 ///
196 /// Also clears the default-model pointer if it pointed to the removed
197 /// endpoint, and removes any aliases that targeted this ID.
198 pub fn unregister(&mut self, id: &ModelId) -> Option<ModelEndpoint> {
199 let removed = self.endpoints.remove(id);
200 if removed.is_some() {
201 // Clear default if it was this model.
202 if self.default_model.as_ref() == Some(id) {
203 self.default_model = None;
204 }
205 // Remove aliases pointing to this model.
206 self.aliases.retain(|_, target| target != id);
207 }
208 removed
209 }
210
211 /// Add an alias: e.g. `"gpt-4"` maps to `ModelId("bonsai-8b")`.
212 pub fn add_alias(&mut self, alias: impl Into<String>, target: ModelId) {
213 self.aliases.insert(alias.into(), target);
214 }
215
216 /// Resolve a model identifier (checks ID first, then aliases).
217 ///
218 /// Returns `None` if neither a direct ID nor an alias matches.
219 pub fn resolve(&self, id_or_alias: &str) -> Option<&ModelEndpoint> {
220 let model_id = ModelId::new(id_or_alias);
221 if let Some(ep) = self.endpoints.get(&model_id) {
222 return Some(ep);
223 }
224 // Try alias resolution.
225 if let Some(target_id) = self.aliases.get(id_or_alias) {
226 return self.endpoints.get(target_id);
227 }
228 None
229 }
230
231 /// Get the default model endpoint.
232 pub fn default_endpoint(&self) -> Option<&ModelEndpoint> {
233 self.default_model
234 .as_ref()
235 .and_then(|id| self.endpoints.get(id))
236 }
237
238 /// List all available (Ready) endpoints.
239 pub fn available_endpoints(&self) -> Vec<&ModelEndpoint> {
240 self.endpoints
241 .values()
242 .filter(|ep| ep.status.is_available())
243 .collect()
244 }
245
246 /// List all registered endpoints (including non-ready ones).
247 pub fn all_endpoints(&self) -> Vec<&ModelEndpoint> {
248 self.endpoints.values().collect()
249 }
250
251 /// Update an endpoint's status.
252 ///
253 /// Returns `true` if the endpoint was found and updated.
254 pub fn set_status(&mut self, id: &ModelId, status: EndpointStatus) -> bool {
255 if let Some(ep) = self.endpoints.get_mut(id) {
256 ep.status = status;
257 true
258 } else {
259 false
260 }
261 }
262
263 /// Number of registered endpoints.
264 pub fn len(&self) -> usize {
265 self.endpoints.len()
266 }
267
268 /// Is the registry empty?
269 pub fn is_empty(&self) -> bool {
270 self.endpoints.is_empty()
271 }
272}
273
274impl Default for ModelRegistry {
275 fn default() -> Self {
276 Self::new()
277 }
278}
279
280// ─────────────────────────────────────────────────────────────────────────────
281// RoutingError
282// ─────────────────────────────────────────────────────────────────────────────
283
/// Errors that can occur when routing a request to a model endpoint.
///
/// `Display` (and the error messages shown below) are generated by
/// `thiserror` from the `#[error(...)]` attributes on each variant.
#[derive(Debug, thiserror::Error)]
pub enum RoutingError {
    /// The requested model was not found in the registry.
    /// Carries the requested model name as given by the client.
    #[error("model '{0}' not found")]
    ModelNotFound(String),

    /// The requested model cannot accommodate the required context length.
    #[error("model '{model}' cannot handle context length {required} (max: {available})")]
    ContextTooLong {
        // Endpoint ID of the model that was too small.
        model: String,
        // Context length (in tokens) the request needed.
        required: usize,
        // The model's actual maximum context length.
        available: usize,
    },

    /// No models are currently available in the registry.
    /// Returned when no model was requested and no default is set.
    #[error("no models are currently available")]
    NoModelsAvailable,

    /// The model was found but is not in a ready state.
    /// Fields: endpoint ID, then the status name (e.g. "loading").
    #[error("model '{0}' is not ready (status: {1})")]
    ModelNotReady(String, String),
}
307
308// ─────────────────────────────────────────────────────────────────────────────
309// ModelRouter
310// ─────────────────────────────────────────────────────────────────────────────
311
312/// Smart request router: selects the best model endpoint for a request.
313///
314/// Wraps a [`ModelRegistry`] and adds routing logic including fallback to
315/// the default model, context-length awareness, and OpenAI-compatible
316/// model listing.
317pub struct ModelRouter {
318 registry: ModelRegistry,
319}
320
321impl ModelRouter {
322 /// Create a new router backed by the given registry.
323 pub fn new(registry: ModelRegistry) -> Self {
324 Self { registry }
325 }
326
327 /// Route a request: resolve `model_id` from the request.
328 ///
329 /// Falls back to the default model if `requested_model` is `None`.
330 /// Returns an error if the resolved endpoint is not in a `Ready` state.
331 pub fn route(&self, requested_model: Option<&str>) -> Result<&ModelEndpoint, RoutingError> {
332 let endpoint = match requested_model {
333 Some(model_name) => self
334 .registry
335 .resolve(model_name)
336 .ok_or_else(|| RoutingError::ModelNotFound(model_name.to_string()))?,
337 None => self
338 .registry
339 .default_endpoint()
340 .ok_or(RoutingError::NoModelsAvailable)?,
341 };
342
343 if !endpoint.status.is_available() {
344 return Err(RoutingError::ModelNotReady(
345 endpoint.id.to_string(),
346 endpoint.status.name().to_string(),
347 ));
348 }
349
350 Ok(endpoint)
351 }
352
353 /// Route with context-length awareness: pick a model that can accommodate
354 /// the required context length.
355 ///
356 /// If a specific model is requested, validates it has sufficient context.
357 /// If no model is specified, finds the default model that fits, or falls
358 /// back to any available model with sufficient context capacity.
359 pub fn route_for_context(
360 &self,
361 requested_model: Option<&str>,
362 required_context: usize,
363 ) -> Result<&ModelEndpoint, RoutingError> {
364 let endpoint = self.route(requested_model)?;
365
366 if endpoint.max_context_length < required_context {
367 // If a specific model was requested but is too small, error out.
368 if requested_model.is_some() {
369 return Err(RoutingError::ContextTooLong {
370 model: endpoint.id.to_string(),
371 required: required_context,
372 available: endpoint.max_context_length,
373 });
374 }
375
376 // No specific model requested — try to find any available endpoint
377 // with sufficient context capacity.
378 let fallback = self
379 .registry
380 .available_endpoints()
381 .into_iter()
382 .filter(|ep| ep.max_context_length >= required_context)
383 .max_by_key(|ep| ep.max_context_length);
384
385 return fallback.ok_or(RoutingError::ContextTooLong {
386 model: endpoint.id.to_string(),
387 required: required_context,
388 available: endpoint.max_context_length,
389 });
390 }
391
392 Ok(endpoint)
393 }
394
395 /// OpenAI-compatible `/v1/models` list.
396 ///
397 /// Returns an entry for every available endpoint in the registry.
398 pub fn models_list(&self) -> Vec<ModelListEntry> {
399 let created = std::time::SystemTime::now()
400 .duration_since(std::time::UNIX_EPOCH)
401 .unwrap_or_default()
402 .as_secs();
403
404 self.registry
405 .available_endpoints()
406 .into_iter()
407 .map(|ep| ModelListEntry {
408 id: ep.id.to_string(),
409 object: "model".to_string(),
410 owned_by: "oxibonsai".to_string(),
411 created,
412 })
413 .collect()
414 }
415
416 /// Immutable access to the underlying registry.
417 pub fn registry(&self) -> &ModelRegistry {
418 &self.registry
419 }
420
421 /// Mutable access to the underlying registry.
422 pub fn registry_mut(&mut self) -> &mut ModelRegistry {
423 &mut self.registry
424 }
425}
426
427// ─────────────────────────────────────────────────────────────────────────────
428// ModelListEntry
429// ─────────────────────────────────────────────────────────────────────────────
430
/// Entry for an OpenAI-compatible `/v1/models` response.
//
// `PartialEq`/`Eq` are derived so entries can be compared directly in
// tests and deduplicated; all fields are plain `String`/`u64`, so the
// derives are free.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelListEntry {
    /// Model identifier string.
    pub id: String,
    /// Object type — always `"model"`.
    pub object: String,
    /// Organisation that owns the model.
    pub owned_by: String,
    /// Unix timestamp when the model was created/registered.
    pub created: u64,
}
443
444// ─────────────────────────────────────────────────────────────────────────────
445// AdapterRef / AdapterStack
446// ─────────────────────────────────────────────────────────────────────────────
447
/// A reference to a single LoRA adapter with a blending weight.
#[derive(Debug, Clone)]
pub struct AdapterRef {
    /// Name of the LoRA adapter.
    pub name: String,
    /// Blending weight in the range `[0.0, 1.0]`.
    pub weight: f32,
}

/// Adapter composition: apply multiple LoRA adapters in sequence.
///
/// Allows stacking several adapters with independent blending weights.
/// Weights can be normalized so they sum to 1.0, which is useful for
/// even blending across adapters.
#[derive(Debug, Clone)]
pub struct AdapterStack {
    /// The ordered list of adapters to apply.
    pub adapters: Vec<AdapterRef>,
}

impl AdapterStack {
    /// Create an empty adapter stack.
    pub fn new() -> Self {
        Self { adapters: Vec::new() }
    }

    /// Add an adapter with the given blending weight (builder-style).
    pub fn add(mut self, name: impl Into<String>, weight: f32) -> Self {
        let adapter = AdapterRef {
            name: name.into(),
            weight,
        };
        self.adapters.push(adapter);
        self
    }

    /// Number of adapters in the stack.
    pub fn len(&self) -> usize {
        self.adapters.len()
    }

    /// Whether the stack is empty.
    pub fn is_empty(&self) -> bool {
        self.adapters.is_empty()
    }

    /// Sum of all adapter weights, accumulated in insertion order.
    pub fn total_weight(&self) -> f32 {
        let mut total = 0.0_f32;
        for adapter in &self.adapters {
            total += adapter.weight;
        }
        total
    }

    /// Normalize weights so they sum to 1.0.
    ///
    /// If the total weight is zero (or very close to it), weights are left
    /// unchanged to avoid division by zero.
    pub fn normalize_weights(&mut self) {
        let total = self.total_weight();
        if total.abs() >= f32::EPSILON {
            self.adapters
                .iter_mut()
                .for_each(|adapter| adapter.weight /= total);
        }
    }
}

impl Default for AdapterStack {
    fn default() -> Self {
        AdapterStack::new()
    }
}
520
521// ─────────────────────────────────────────────────────────────────────────────
522// Unit tests
523// ─────────────────────────────────────────────────────────────────────────────
524
#[cfg(test)]
mod tests {
    use super::*;

    // ModelId's Display impl should echo the raw identifier string.
    #[test]
    fn model_id_display() {
        assert_eq!(ModelId::new("bonsai-8b").to_string(), "bonsai-8b");
    }

    // Each status variant maps to its stable lowercase name.
    #[test]
    fn endpoint_status_name() {
        let cases = [
            (EndpointStatus::Ready, "ready"),
            (EndpointStatus::Loading, "loading"),
            (EndpointStatus::Error, "error"),
            (EndpointStatus::Disabled, "disabled"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.name(), expected);
        }
    }

    // A freshly constructed endpoint uses its ID as the display name.
    #[test]
    fn endpoint_display_name_defaults_to_id() {
        let endpoint = ModelEndpoint::new("bonsai-8b", "qwen3-8b");
        assert_eq!(endpoint.display_name, "bonsai-8b");
    }
}
548}