kizzasi_inference/
registry.rs

1//! Model registry for loading and managing different architectures
2//!
3//! This module provides a centralized registry for loading various model
4//! architectures supported by kizzasi-model:
5//! - Mamba/Mamba2: Selective State Space Models
6//! - RWKV: Linear attention models
7//! - S4/S4D: Structured State Space Models
8//! - Transformer: Standard attention models
9
10use crate::error::{InferenceError, InferenceResult};
11use kizzasi_model::{AutoregressiveModel, ModelType};
12use std::path::Path;
13
14/// Configuration for model loading
15#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
16pub struct ModelConfig {
17    /// Type of model architecture
18    pub model_type: ModelType,
19    /// Input dimension
20    pub input_dim: usize,
21    /// Hidden dimension
22    pub hidden_dim: usize,
23    /// Number of layers
24    pub num_layers: usize,
25    /// State dimension (for SSMs)
26    pub state_dim: usize,
27    /// Output dimension
28    pub output_dim: usize,
29    /// Optional path to pretrained weights
30    pub weights_path: Option<String>,
31}
32
33impl ModelConfig {
34    /// Create a new model configuration
35    pub fn new(model_type: ModelType) -> Self {
36        Self {
37            model_type,
38            input_dim: 1,
39            hidden_dim: 256,
40            num_layers: 4,
41            state_dim: 16,
42            output_dim: 1,
43            weights_path: None,
44        }
45    }
46
47    /// Set input dimension
48    pub fn input_dim(mut self, dim: usize) -> Self {
49        self.input_dim = dim;
50        self
51    }
52
53    /// Set hidden dimension
54    pub fn hidden_dim(mut self, dim: usize) -> Self {
55        self.hidden_dim = dim;
56        self
57    }
58
59    /// Set number of layers
60    pub fn num_layers(mut self, n: usize) -> Self {
61        self.num_layers = n;
62        self
63    }
64
65    /// Set state dimension
66    pub fn state_dim(mut self, dim: usize) -> Self {
67        self.state_dim = dim;
68        self
69    }
70
71    /// Set output dimension
72    pub fn output_dim(mut self, dim: usize) -> Self {
73        self.output_dim = dim;
74        self
75    }
76
77    /// Set weights path
78    pub fn weights_path(mut self, path: impl Into<String>) -> Self {
79        self.weights_path = Some(path.into());
80        self
81    }
82}
83
84/// Model registry for creating and managing model instances
85pub struct ModelRegistry {
86    /// Available model configurations
87    configs: std::collections::HashMap<String, ModelConfig>,
88}
89
90impl ModelRegistry {
91    /// Create a new model registry
92    pub fn new() -> Self {
93        Self {
94            configs: std::collections::HashMap::new(),
95        }
96    }
97
98    /// Register a model configuration with a name
99    pub fn register(&mut self, name: impl Into<String>, config: ModelConfig) {
100        self.configs.insert(name.into(), config);
101    }
102
103    /// Get a registered configuration
104    pub fn get_config(&self, name: &str) -> Option<&ModelConfig> {
105        self.configs.get(name)
106    }
107
108    /// List all registered model names
109    pub fn list_models(&self) -> Vec<String> {
110        self.configs.keys().cloned().collect()
111    }
112
113    /// Create a model wrapper from a configuration
114    pub fn create_model(&self, name: &str) -> InferenceResult<Box<dyn AutoregressiveModel>> {
115        let config = self
116            .get_config(name)
117            .ok_or_else(|| InferenceError::PipelineConfig(format!("Model '{}' not found", name)))?;
118
119        self.create_from_config(config)
120    }
121
122    /// Create a model from configuration
123    fn create_from_config(
124        &self,
125        config: &ModelConfig,
126    ) -> InferenceResult<Box<dyn AutoregressiveModel>> {
127        match config.model_type {
128            ModelType::Mamba2 => {
129                // TODO: Fix Mamba2 model exports in kizzasi-model
130                Err(InferenceError::PipelineConfig(
131                    "Mamba2 not yet supported - models not exported".into(),
132                ))
133            }
134            ModelType::Rwkv => {
135                use kizzasi_model::rwkv::{Rwkv, RwkvConfig};
136                let num_heads = (config.hidden_dim / 64).max(1);
137                let model_config = RwkvConfig {
138                    input_dim: config.input_dim,
139                    hidden_dim: config.hidden_dim,
140                    intermediate_dim: config.hidden_dim * 4,
141                    num_layers: config.num_layers,
142                    num_heads,
143                    head_dim: config.hidden_dim / num_heads,
144                    dropout: 0.0,
145                    time_decay_init: -5.0,
146                    use_rms_norm: true,
147                };
148                let model = Rwkv::new(model_config).map_err(InferenceError::ModelError)?;
149                Ok(Box::new(model))
150            }
151            ModelType::S4 | ModelType::S4D => {
152                use kizzasi_model::s4::{S4Config, S4D};
153                let model_config = S4Config {
154                    input_dim: config.input_dim,
155                    hidden_dim: config.hidden_dim,
156                    state_dim: config.state_dim,
157                    num_layers: config.num_layers,
158                    dropout: 0.0,
159                    dt_min: 0.001,
160                    dt_max: 0.1,
161                    use_diagonal: config.model_type == ModelType::S4D,
162                    use_rms_norm: true,
163                };
164                let model = S4D::new(model_config).map_err(InferenceError::ModelError)?;
165                Ok(Box::new(model))
166            }
167            ModelType::Transformer => {
168                use kizzasi_model::transformer::{Transformer, TransformerConfig};
169                let num_heads = (config.hidden_dim / 64).max(1);
170                let model_config = TransformerConfig {
171                    input_dim: config.input_dim,
172                    hidden_dim: config.hidden_dim,
173                    num_heads,
174                    head_dim: config.hidden_dim / num_heads,
175                    ff_dim: config.hidden_dim * 4,
176                    num_layers: config.num_layers,
177                    max_seq_len: 8192,
178                    dropout: 0.1,
179                    use_rms_norm: true,
180                    causal: true,
181                };
182                let model = Transformer::new(model_config).map_err(InferenceError::ModelError)?;
183                Ok(Box::new(model))
184            }
185            ModelType::Mamba => {
186                // TODO: Fix Mamba model exports in kizzasi-model
187                Err(InferenceError::PipelineConfig(
188                    "Mamba not yet supported - models not exported".into(),
189                ))
190            }
191        }
192    }
193
194    /// Load weights from a file
195    pub fn load_weights(
196        &self,
197        _model: &mut dyn AutoregressiveModel,
198        path: impl AsRef<Path>,
199    ) -> InferenceResult<()> {
200        // Placeholder for weight loading
201        // Will integrate with kizzasi_model::loader once implemented
202        let _path = path.as_ref();
203        tracing::info!("Weight loading not yet implemented");
204        Ok(())
205    }
206}
207
208impl Default for ModelRegistry {
209    fn default() -> Self {
210        Self::new()
211    }
212}
213
214/// Builder for common model configurations
215pub struct ModelBuilder {
216    config: ModelConfig,
217}
218
219impl ModelBuilder {
220    /// Start building a Mamba2 model
221    pub fn mamba2() -> Self {
222        Self {
223            config: ModelConfig::new(ModelType::Mamba2),
224        }
225    }
226
227    /// Start building an RWKV model
228    pub fn rwkv() -> Self {
229        Self {
230            config: ModelConfig::new(ModelType::Rwkv),
231        }
232    }
233
234    /// Start building an S4 model
235    pub fn s4() -> Self {
236        Self {
237            config: ModelConfig::new(ModelType::S4),
238        }
239    }
240
241    /// Start building an S4D model
242    pub fn s4d() -> Self {
243        Self {
244            config: ModelConfig::new(ModelType::S4D),
245        }
246    }
247
248    /// Start building a Transformer model
249    pub fn transformer() -> Self {
250        Self {
251            config: ModelConfig::new(ModelType::Transformer),
252        }
253    }
254
255    /// Set dimensions
256    pub fn dims(mut self, input: usize, hidden: usize, output: usize) -> Self {
257        self.config.input_dim = input;
258        self.config.hidden_dim = hidden;
259        self.config.output_dim = output;
260        self
261    }
262
263    /// Set number of layers
264    pub fn layers(mut self, n: usize) -> Self {
265        self.config.num_layers = n;
266        self
267    }
268
269    /// Set state dimension
270    pub fn state_dim(mut self, dim: usize) -> Self {
271        self.config.state_dim = dim;
272        self
273    }
274
275    /// Build the configuration
276    pub fn build(self) -> ModelConfig {
277        self.config
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Every chained setter on `ModelConfig` must land in its own field.
    #[test]
    fn test_model_config_builder() {
        let config = ModelConfig::new(ModelType::Mamba2)
            .input_dim(3)
            .hidden_dim(128)
            .num_layers(2)
            .output_dim(3);

        assert_eq!(config.model_type, ModelType::Mamba2);
        assert_eq!(config.input_dim, 3);
        assert_eq!(config.hidden_dim, 128);
        assert_eq!(config.num_layers, 2);
        // Previously set but never checked.
        assert_eq!(config.output_dim, 3);
    }

    /// A registered configuration is retrievable and listed by name.
    #[test]
    fn test_registry_register() {
        let mut registry = ModelRegistry::new();
        let config = ModelConfig::new(ModelType::Rwkv);

        registry.register("test_model", config);

        assert!(registry.get_config("test_model").is_some());
        assert_eq!(registry.list_models(), vec!["test_model".to_string()]);
    }

    /// Requesting a name that was never registered is an error.
    #[test]
    fn test_create_unknown_model_fails() {
        let registry = ModelRegistry::new();

        assert!(registry.get_config("missing").is_none());
        assert!(registry.create_model("missing").is_err());
    }

    /// `ModelBuilder` forwards every value it is given into the configuration.
    #[test]
    fn test_model_builder() {
        let config = ModelBuilder::s4d()
            .dims(1, 256, 1)
            .layers(4)
            .state_dim(16)
            .build();

        assert_eq!(config.model_type, ModelType::S4D);
        assert_eq!(config.input_dim, 1);
        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.output_dim, 1);
        assert_eq!(config.num_layers, 4);
        assert_eq!(config.state_dim, 16);
    }

    #[test]
    fn test_create_rwkv_model() {
        let mut registry = ModelRegistry::new();
        let config = ModelBuilder::rwkv().dims(1, 64, 10).layers(2).build();

        registry.register("rwkv_test", config);

        let model = registry
            .create_model("rwkv_test")
            .expect("RWKV model should be constructible from a valid config");
        assert_eq!(model.model_type(), ModelType::Rwkv);
        assert_eq!(model.hidden_dim(), 64);
    }

    #[test]
    fn test_create_s4_model() {
        let mut registry = ModelRegistry::new();
        let config = ModelBuilder::s4d().dims(1, 128, 10).layers(3).build();

        registry.register("s4_test", config);

        let model = registry
            .create_model("s4_test")
            .expect("S4D model should be constructible from a valid config");
        assert_eq!(model.model_type(), ModelType::S4D);
    }

    #[test]
    fn test_create_transformer_model() {
        let mut registry = ModelRegistry::new();
        let config = ModelBuilder::transformer()
            .dims(1, 128, 10)
            .layers(2)
            .build();

        registry.register("transformer_test", config);

        let model = registry
            .create_model("transformer_test")
            .expect("Transformer model should be constructible from a valid config");
        assert_eq!(model.model_type(), ModelType::Transformer);
    }
}