1use crate::error::{InferenceError, InferenceResult};
11use kizzasi_model::{AutoregressiveModel, ModelType};
12use std::path::Path;
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
16pub struct ModelConfig {
17 pub model_type: ModelType,
19 pub input_dim: usize,
21 pub hidden_dim: usize,
23 pub num_layers: usize,
25 pub state_dim: usize,
27 pub output_dim: usize,
29 pub weights_path: Option<String>,
31}
32
33impl ModelConfig {
34 pub fn new(model_type: ModelType) -> Self {
36 Self {
37 model_type,
38 input_dim: 1,
39 hidden_dim: 256,
40 num_layers: 4,
41 state_dim: 16,
42 output_dim: 1,
43 weights_path: None,
44 }
45 }
46
47 pub fn input_dim(mut self, dim: usize) -> Self {
49 self.input_dim = dim;
50 self
51 }
52
53 pub fn hidden_dim(mut self, dim: usize) -> Self {
55 self.hidden_dim = dim;
56 self
57 }
58
59 pub fn num_layers(mut self, n: usize) -> Self {
61 self.num_layers = n;
62 self
63 }
64
65 pub fn state_dim(mut self, dim: usize) -> Self {
67 self.state_dim = dim;
68 self
69 }
70
71 pub fn output_dim(mut self, dim: usize) -> Self {
73 self.output_dim = dim;
74 self
75 }
76
77 pub fn weights_path(mut self, path: impl Into<String>) -> Self {
79 self.weights_path = Some(path.into());
80 self
81 }
82}
83
84pub struct ModelRegistry {
86 configs: std::collections::HashMap<String, ModelConfig>,
88}
89
90impl ModelRegistry {
91 pub fn new() -> Self {
93 Self {
94 configs: std::collections::HashMap::new(),
95 }
96 }
97
98 pub fn register(&mut self, name: impl Into<String>, config: ModelConfig) {
100 self.configs.insert(name.into(), config);
101 }
102
103 pub fn get_config(&self, name: &str) -> Option<&ModelConfig> {
105 self.configs.get(name)
106 }
107
108 pub fn list_models(&self) -> Vec<String> {
110 self.configs.keys().cloned().collect()
111 }
112
113 pub fn create_model(&self, name: &str) -> InferenceResult<Box<dyn AutoregressiveModel>> {
115 let config = self
116 .get_config(name)
117 .ok_or_else(|| InferenceError::PipelineConfig(format!("Model '{}' not found", name)))?;
118
119 self.create_from_config(config)
120 }
121
122 fn create_from_config(
124 &self,
125 config: &ModelConfig,
126 ) -> InferenceResult<Box<dyn AutoregressiveModel>> {
127 match config.model_type {
128 ModelType::Mamba2 => {
129 Err(InferenceError::PipelineConfig(
131 "Mamba2 not yet supported - models not exported".into(),
132 ))
133 }
134 ModelType::Rwkv => {
135 use kizzasi_model::rwkv::{Rwkv, RwkvConfig};
136 let num_heads = (config.hidden_dim / 64).max(1);
137 let model_config = RwkvConfig {
138 input_dim: config.input_dim,
139 hidden_dim: config.hidden_dim,
140 intermediate_dim: config.hidden_dim * 4,
141 num_layers: config.num_layers,
142 num_heads,
143 head_dim: config.hidden_dim / num_heads,
144 dropout: 0.0,
145 time_decay_init: -5.0,
146 use_rms_norm: true,
147 };
148 let model = Rwkv::new(model_config).map_err(InferenceError::ModelError)?;
149 Ok(Box::new(model))
150 }
151 ModelType::S4 | ModelType::S4D => {
152 use kizzasi_model::s4::{S4Config, S4D};
153 let model_config = S4Config {
154 input_dim: config.input_dim,
155 hidden_dim: config.hidden_dim,
156 state_dim: config.state_dim,
157 num_layers: config.num_layers,
158 dropout: 0.0,
159 dt_min: 0.001,
160 dt_max: 0.1,
161 use_diagonal: config.model_type == ModelType::S4D,
162 use_rms_norm: true,
163 };
164 let model = S4D::new(model_config).map_err(InferenceError::ModelError)?;
165 Ok(Box::new(model))
166 }
167 ModelType::Transformer => {
168 use kizzasi_model::transformer::{Transformer, TransformerConfig};
169 let num_heads = (config.hidden_dim / 64).max(1);
170 let model_config = TransformerConfig {
171 input_dim: config.input_dim,
172 hidden_dim: config.hidden_dim,
173 num_heads,
174 head_dim: config.hidden_dim / num_heads,
175 ff_dim: config.hidden_dim * 4,
176 num_layers: config.num_layers,
177 max_seq_len: 8192,
178 dropout: 0.1,
179 use_rms_norm: true,
180 causal: true,
181 };
182 let model = Transformer::new(model_config).map_err(InferenceError::ModelError)?;
183 Ok(Box::new(model))
184 }
185 ModelType::Mamba => {
186 Err(InferenceError::PipelineConfig(
188 "Mamba not yet supported - models not exported".into(),
189 ))
190 }
191 }
192 }
193
194 pub fn load_weights(
196 &self,
197 _model: &mut dyn AutoregressiveModel,
198 path: impl AsRef<Path>,
199 ) -> InferenceResult<()> {
200 let _path = path.as_ref();
203 tracing::info!("Weight loading not yet implemented");
204 Ok(())
205 }
206}
207
208impl Default for ModelRegistry {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
214pub struct ModelBuilder {
216 config: ModelConfig,
217}
218
219impl ModelBuilder {
220 pub fn mamba2() -> Self {
222 Self {
223 config: ModelConfig::new(ModelType::Mamba2),
224 }
225 }
226
227 pub fn rwkv() -> Self {
229 Self {
230 config: ModelConfig::new(ModelType::Rwkv),
231 }
232 }
233
234 pub fn s4() -> Self {
236 Self {
237 config: ModelConfig::new(ModelType::S4),
238 }
239 }
240
241 pub fn s4d() -> Self {
243 Self {
244 config: ModelConfig::new(ModelType::S4D),
245 }
246 }
247
248 pub fn transformer() -> Self {
250 Self {
251 config: ModelConfig::new(ModelType::Transformer),
252 }
253 }
254
255 pub fn dims(mut self, input: usize, hidden: usize, output: usize) -> Self {
257 self.config.input_dim = input;
258 self.config.hidden_dim = hidden;
259 self.config.output_dim = output;
260 self
261 }
262
263 pub fn layers(mut self, n: usize) -> Self {
265 self.config.num_layers = n;
266 self
267 }
268
269 pub fn state_dim(mut self, dim: usize) -> Self {
271 self.config.state_dim = dim;
272 self
273 }
274
275 pub fn build(self) -> ModelConfig {
277 self.config
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// The chained setters on `ModelConfig` store every value they are given.
    #[test]
    fn test_model_config_builder() {
        let config = ModelConfig::new(ModelType::Mamba2)
            .input_dim(3)
            .hidden_dim(128)
            .num_layers(2)
            .output_dim(3);

        assert_eq!(config.model_type, ModelType::Mamba2);
        assert_eq!(config.input_dim, 3);
        assert_eq!(config.hidden_dim, 128);
        assert_eq!(config.num_layers, 2);
        assert_eq!(config.output_dim, 3);
    }

    /// A registered configuration can be looked up and is listed.
    #[test]
    fn test_registry_register() {
        let mut registry = ModelRegistry::new();
        let config = ModelConfig::new(ModelType::Rwkv);

        registry.register("test_model", config);

        assert!(registry.get_config("test_model").is_some());
        assert_eq!(registry.list_models().len(), 1);
    }

    /// The default registry starts out empty.
    #[test]
    fn test_registry_default_is_empty() {
        let registry = ModelRegistry::default();

        assert!(registry.list_models().is_empty());
        assert!(registry.get_config("anything").is_none());
    }

    /// `ModelBuilder` forwards every setting into the built configuration.
    #[test]
    fn test_model_builder() {
        let config = ModelBuilder::s4d()
            .dims(1, 256, 1)
            .layers(4)
            .state_dim(16)
            .build();

        assert_eq!(config.model_type, ModelType::S4D);
        assert_eq!(config.input_dim, 1);
        assert_eq!(config.hidden_dim, 256);
        assert_eq!(config.output_dim, 1);
        assert_eq!(config.num_layers, 4);
        assert_eq!(config.state_dim, 16);
    }

    /// Asking for a name that was never registered is a configuration error.
    #[test]
    fn test_create_unknown_model_fails() {
        let registry = ModelRegistry::new();

        let result = registry.create_model("missing");

        assert!(matches!(result, Err(InferenceError::PipelineConfig(_))));
    }

    #[test]
    fn test_create_rwkv_model() {
        let mut registry = ModelRegistry::new();
        let config = ModelBuilder::rwkv().dims(1, 64, 10).layers(2).build();

        registry.register("rwkv_test", config);

        // `expect` reports the underlying error if creation fails.
        let model = registry
            .create_model("rwkv_test")
            .expect("RWKV model should be created from a valid config");
        assert_eq!(model.model_type(), ModelType::Rwkv);
        assert_eq!(model.hidden_dim(), 64);
    }

    #[test]
    fn test_create_s4_model() {
        let mut registry = ModelRegistry::new();
        let config = ModelBuilder::s4d().dims(1, 128, 10).layers(3).build();

        registry.register("s4_test", config);

        let model = registry
            .create_model("s4_test")
            .expect("S4D model should be created from a valid config");
        assert_eq!(model.model_type(), ModelType::S4D);
    }

    #[test]
    fn test_create_transformer_model() {
        let mut registry = ModelRegistry::new();
        let config = ModelBuilder::transformer()
            .dims(1, 128, 10)
            .layers(2)
            .build();

        registry.register("transformer_test", config);

        let model = registry
            .create_model("transformer_test")
            .expect("Transformer model should be created from a valid config");
        assert_eq!(model.model_type(), ModelType::Transformer);
    }
}