kizzasi_model/
lib.rs

1//! # kizzasi-model
2//!
3//! Model architectures for Kizzasi AGSP (Autoregressive General-Purpose Signal Predictor).
4//!
5//! This crate implements various State Space Model architectures optimized for
6//! continuous signal prediction with O(1) inference step complexity:
7//!
8//! - **Mamba/Mamba2**: Selective State Space Models with input-dependent dynamics
9//! - **RWKV**: Linear attention with time-mixing and channel-mixing
10//! - **S4/S4D**: Structured State Space Models with diagonal state matrices
11//! - **Transformer**: Standard attention for comparison (O(N) per step)
12//!
13//! ## COOLJAPAN Ecosystem
14//!
15//! This crate follows KIZZASI_POLICY.md and uses `scirs2-core` for all
16//! array and numerical operations.
17//!
18//! ## Architecture Philosophy
19//!
20//! As described in the AGSP concept, these models treat all signals
21//! (audio, video, sensors, actions) as equivalent tokenized sequences,
22//! enabling cross-modal prediction and world model construction.
23
24pub mod batch;
25pub mod blas_ops;
26pub mod cache_friendly;
27pub mod checkpoint;
28pub mod compression;
29pub mod dynamic_quantization;
30mod error;
31pub mod factory;
32pub mod huggingface;
33pub mod huggingface_loader;
34pub mod loader;
35pub mod mixed_precision;
36pub mod moe;
37pub mod parallel_multihead;
38pub mod profiling;
39pub mod pytorch_compat;
40pub mod quantization;
41pub mod simd_ops;
42pub mod training;
43
44#[cfg(feature = "mamba")]
45pub mod mamba;
46
47#[cfg(feature = "mamba")]
48pub mod mamba2;
49
50pub mod rwkv;
51
52pub mod rwkv7;
53
54pub mod s4;
55
56pub mod s5;
57
58pub mod h3;
59
60pub mod hybrid;
61
62pub mod transformer;
63
64pub use error::{ModelError, ModelResult};
65pub use loader::{ModelLoader, TensorInfo, WeightLoader};
66
67// Re-export BLAS operations for convenience
68pub use blas_ops::{
69    axpy, batch_matmul_vec, dot, matmul_mat, matmul_vec, norm_frobenius, norm_l2, transpose,
70    BlasConfig,
71};
72
73// Re-export profiling utilities
74pub use profiling::{
75    BottleneckInfo, BottleneckSeverity, ComprehensiveComparison, ComprehensiveProfiler,
76    ModelBottleneckAnalysis,
77};
78
79// Re-export core types
80pub use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
81pub use scirs2_core::ndarray::{Array1, Array2};
82
/// Trait for model architectures that support autoregressive prediction.
///
/// Both [`SignalPredictor`] and [`Send`] are supertraits, so every
/// implementor must provide them as well. On top of that, this trait
/// exposes structural metadata (dimensions, layer count, architecture tag)
/// through `&self` accessors, plus a getter/setter pair for the hidden
/// states of all layers.
pub trait AutoregressiveModel: SignalPredictor + Send {
    /// Get the model's hidden dimension.
    fn hidden_dim(&self) -> usize;

    /// Get the model's state dimension (for SSMs).
    ///
    /// NOTE(review): what a non-SSM implementor should return is not
    /// specified here — confirm against the implementations.
    fn state_dim(&self) -> usize;

    /// Get the number of layers.
    fn num_layers(&self) -> usize;

    /// Get the model type identifier as a [`ModelType`] tag.
    fn model_type(&self) -> ModelType;

    /// Get the current hidden states for all layers.
    ///
    /// The states are returned by value in a `Vec`.
    /// NOTE(review): presumably one entry per layer, in layer order —
    /// confirm against the implementations.
    fn get_states(&self) -> Vec<HiddenState>;

    /// Set the hidden states for all layers.
    ///
    /// Takes ownership of `states` and needs `&mut self`.
    ///
    /// # Errors
    ///
    /// Returns a [`ModelResult`] error when the implementor rejects the
    /// supplied states. NOTE(review): the exact conditions (e.g. a length
    /// that differs from `num_layers()`) are implementation-defined —
    /// confirm against the implementations.
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()>;
}
103
/// Enumeration of supported model architectures.
///
/// A field-less tag returned by [`AutoregressiveModel::model_type`]. The
/// derives make it `Copy`, comparable with `==`, usable as a hash key and
/// (de)serializable through serde. Its `Display` output is the architecture
/// name; note that [`ModelType::Rwkv`] is rendered as `"RWKV"`.
///
/// NOTE(review): the crate root also declares `rwkv7`, `s5`, `h3` and
/// `hybrid` modules that have no dedicated variant here — confirm whether
/// those models are meant to report one of the existing variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum ModelType {
    /// Mamba: Selective State Space Model.
    Mamba,
    /// Mamba2: Enhanced selective SSM with SSD.
    Mamba2,
    /// RWKV: Linear attention with time-mixing.
    Rwkv,
    /// S4: Structured State Space Model.
    S4,
    /// S4D: S4 with diagonal state matrix.
    S4D,
    /// Standard Transformer (for comparison).
    Transformer,
}
120
121impl std::fmt::Display for ModelType {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            ModelType::Mamba => write!(f, "Mamba"),
125            ModelType::Mamba2 => write!(f, "Mamba2"),
126            ModelType::Rwkv => write!(f, "RWKV"),
127            ModelType::S4 => write!(f, "S4"),
128            ModelType::S4D => write!(f, "S4D"),
129            ModelType::Transformer => write!(f, "Transformer"),
130        }
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` must render the expected architecture names.
    #[test]
    fn test_model_type_display() {
        // (variant, expected text) pairs; `Rwkv` is the case whose display
        // string differs from its identifier.
        let cases = [(ModelType::Mamba2, "Mamba2"), (ModelType::Rwkv, "RWKV")];
        for (model, expected) in cases {
            assert_eq!(model.to_string(), expected);
        }
    }
}