1pub mod batch;
25pub mod blas_ops;
26pub mod cache_friendly;
27pub mod checkpoint;
28pub mod compression;
29pub mod dynamic_quantization;
30mod error;
31pub mod factory;
32pub mod huggingface;
33pub mod huggingface_loader;
34pub mod loader;
35pub mod mixed_precision;
36pub mod moe;
37pub mod parallel_multihead;
38pub mod profiling;
39pub mod pytorch_compat;
40pub mod quantization;
41pub mod simd_ops;
42pub mod training;
43
44#[cfg(feature = "mamba")]
45pub mod mamba;
46
47#[cfg(feature = "mamba")]
48pub mod mamba2;
49
50pub mod rwkv;
51
52pub mod rwkv7;
53
54pub mod s4;
55
56pub mod s5;
57
58pub mod h3;
59
60pub mod hybrid;
61
62pub mod transformer;
63
64pub use error::{ModelError, ModelResult};
65pub use loader::{ModelLoader, TensorInfo, WeightLoader};
66
67pub use blas_ops::{
69 axpy, batch_matmul_vec, dot, matmul_mat, matmul_vec, norm_frobenius, norm_l2, transpose,
70 BlasConfig,
71};
72
73pub use profiling::{
75 BottleneckInfo, BottleneckSeverity, ComprehensiveComparison, ComprehensiveProfiler,
76 ModelBottleneckAnalysis,
77};
78
79pub use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
81pub use scirs2_core::ndarray::{Array1, Array2};
82
83pub trait AutoregressiveModel: SignalPredictor + Send {
85 fn hidden_dim(&self) -> usize;
87
88 fn state_dim(&self) -> usize;
90
91 fn num_layers(&self) -> usize;
93
94 fn model_type(&self) -> ModelType;
96
97 fn get_states(&self) -> Vec<HiddenState>;
99
100 fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()>;
102}
103
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
106pub enum ModelType {
107 Mamba,
109 Mamba2,
111 Rwkv,
113 S4,
115 S4D,
117 Transformer,
119}
120
121impl std::fmt::Display for ModelType {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 match self {
124 ModelType::Mamba => write!(f, "Mamba"),
125 ModelType::Mamba2 => write!(f, "Mamba2"),
126 ModelType::Rwkv => write!(f, "RWKV"),
127 ModelType::S4 => write!(f, "S4"),
128 ModelType::S4D => write!(f, "S4D"),
129 ModelType::Transformer => write!(f, "Transformer"),
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ModelType` variant renders its expected display name.
    #[test]
    fn test_model_type_display() {
        let expected = [
            (ModelType::Mamba, "Mamba"),
            (ModelType::Mamba2, "Mamba2"),
            (ModelType::Rwkv, "RWKV"),
            (ModelType::S4, "S4"),
            (ModelType::S4D, "S4D"),
            (ModelType::Transformer, "Transformer"),
        ];

        for (model_type, name) in expected {
            assert_eq!(model_type.to_string(), name);
        }
    }
}