// neuronbox_runtime/model_loader.rs

//! Tracks the logical active model for swap / display purposes.
//!
//! Model weights are loaded by Python / user code, not in Rust.
2
use std::fmt;
use std::sync::Arc;

use tokio::sync::RwLock;
6
/// Snapshot of the model that is currently marked as active.
///
/// This is bookkeeping only: it records which model is logically active and its
/// quantization label; it does not own or load any weights.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ActiveModel {
    /// Reference identifying the active model (format is defined by callers).
    /// Empty in the `Default` value.
    pub model_ref: String,
    /// Optional quantization label attached to the model; `None` in the `Default` value.
    pub quantization: Option<String>,
}
12
13#[derive(Clone)]
14pub struct ModelLoader {
15    active: Arc<RwLock<ActiveModel>>,
16}
17
18impl Default for ModelLoader {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl ModelLoader {
25    pub fn new() -> Self {
26        Self {
27            active: Arc::new(RwLock::new(ActiveModel::default())),
28        }
29    }
30
31    pub async fn swap(&self, model_ref: String, quantization: Option<String>) {
32        let mut a = self.active.write().await;
33        a.model_ref = model_ref;
34        a.quantization = quantization;
35    }
36
37    pub async fn get(&self) -> ActiveModel {
38        self.active.read().await.clone()
39    }
40}