// kizzasi_model/lib.rs
1//! # kizzasi-model
2//!
3//! Model architectures for Kizzasi AGSP (Autoregressive General-Purpose Signal Predictor).
4//!
5//! This crate implements various State Space Model architectures optimized for
6//! continuous signal prediction with O(1) inference step complexity:
7//!
8//! - **Mamba/Mamba2**: Selective State Space Models with input-dependent dynamics
9//! - **RWKV**: Linear attention with time-mixing and channel-mixing
10//! - **S4/S4D**: Structured State Space Models with diagonal state matrices
11//! - **Transformer**: Standard attention for comparison (O(N) per step)
12//!
13//! ## COOLJAPAN Ecosystem
14//!
15//! This crate follows KIZZASI_POLICY.md and uses `scirs2-core` for all
16//! array and numerical operations.
17//!
18//! ## Architecture Philosophy
19//!
20//! As described in the AGSP concept, these models treat all signals
21//! (audio, video, sensors, actions) as equivalent tokenized sequences,
22//! enabling cross-modal prediction and world model construction.
23
24#[cfg(feature = "hf-hub")]
25pub mod hf_hub;
26#[cfg(feature = "hf-hub")]
27pub use hf_hub::{load_from_hub, HfHubClient, HfHubConfig, HfModelInfo};
28
29pub mod arch_search;
30pub use arch_search::{
31    search_best_arch, ArchCandidate, ArchSearchConfig, ArchSearchResult, ArchSearchSpace,
32    EvolutionarySearcher, GridSearcher, RandomArchSearcher,
33};
34
35pub mod backprop;
36pub mod backprop_ssm;
37pub mod batch;
38pub mod blas_ops;
39pub mod cache_friendly;
40pub mod checkpoint;
41pub mod compression;
42pub mod curriculum;
43pub mod distributed;
44pub mod dynamic_quantization;
45pub mod early_exit;
46mod error;
47pub mod factory;
48pub mod flash_linear_attn;
49pub mod gguf;
50pub(crate) mod gguf_dequant;
51pub mod gradient_checkpoint;
52pub mod huggingface;
53pub mod huggingface_loader;
54pub mod incremental_loader;
55pub mod loader;
56pub mod lora;
57pub mod mixed_precision;
58pub mod moe;
59pub mod onnx_export;
60pub mod parallel_multihead;
61pub mod profiling;
62pub mod prune;
63pub mod pytorch_compat;
64pub mod quantization;
65pub mod quantize;
66pub mod simd_ops;
67pub mod speculative;
68pub mod state_io;
69pub mod training;
70pub mod training_loop;
71pub mod visualization;
72
73#[cfg(feature = "mamba")]
74pub mod mamba;
75
76#[cfg(feature = "mamba")]
77pub mod mamba2;
78
79pub mod interpretability;
80
81pub mod rwkv;
82
83pub mod rwkv5;
84
85pub mod rwkv7;
86
87pub mod s4;
88
89pub mod s5;
90
91pub mod h3;
92
93pub mod hybrid;
94
95pub mod multimodal;
96
97pub mod neural_ode;
98
99pub mod spiking;
100
101pub mod temporal_multiscale;
102
103pub mod transformer;
104
105pub use error::{ModelError, ModelResult};
106
107// Re-export backward pass / autograd types
108pub use backprop::{
109    layer_norm_backward, linear_backward, silu_backward, softmax_backward, GradAccumulator,
110    GradientTape, SsmBackward, SsmGradients, Tensor,
111};
112pub use gguf::{GgufFile, GgufInspection, GgufMetaValue, GgufQuantType, GgufTensorInfo};
113pub use incremental_loader::{
114    GgufFileSource, IncrementalModelLoader, SafeTensorsSource, WeightSource,
115};
116pub use loader::{ModelLoader, TensorInfo, WeightLoader};
117pub use lora::{LoraAdapter, LoraAdapterSummary, LoraConfig, LoraLinear, QLoraLinear};
118pub use multimodal::{
119    FusionStrategy, Modality, ModalityAligner, MultiModalConfig, MultiModalModel,
120};
121pub use neural_ode::{
122    AugmentedNeuralOde, NeuralOdeConfig, NeuralOdeModel, OdeIntegrator, OdeSolver,
123};
124pub use spiking::{
125    LifLayer, MembranePotential, ResetMode, SpikingConfig, SpikingNeuralNetwork, StdpConfig,
126};
127pub use temporal_multiscale::{MultiScaleConfig, MultiScaleModel, ScaleFusion};
128
129// Re-export training loop types
130pub use training_loop::{
131    AdamOptimizer, ArrayDataProvider, ConstantScheduler, DataProvider, ExponentialScheduler,
132    LrScheduler, Optimizer, SgdOptimizer, StepDecayScheduler, TrainingCallback, TrainingConfig,
133    TrainingLoop, TrainingResult,
134};
135
136// Re-export distributed training types
137pub use distributed::{
138    average_gradients, partition_indices, run_parallel_workers, sgd_step, CommBackend,
139    DataParallelModel, DistributedConfig, GradientBuffer, GradientStrategy, GradientSync,
140    LocalGradientSync, SharedGradientStore, ThreadedGradientSync,
141};
142
143// Re-export curriculum learning types
144pub use curriculum::{CurriculumDataProvider, CurriculumScheduler, CurriculumStrategy};
145
146// Re-export gradient checkpointing types
147pub use gradient_checkpoint::{ActivationCheckpointer, CheckpointConfig};
148
149// Re-export speculative decoding types
150pub use speculative::{SpeculativeDecoder, SpeculativeResult};
151
152// Re-export adaptive early exit types
153pub use early_exit::{AdaptiveComputation, EarlyExitConfig, ExitCriterion, ExitStats};
154
155// Re-export BLAS operations for convenience
156pub use blas_ops::{
157    axpy, batch_matmul_vec, dot, matmul_mat, matmul_vec, norm_frobenius, norm_l2, transpose,
158    BlasConfig,
159};
160
161// Re-export profiling utilities
162pub use profiling::{
163    BottleneckInfo, BottleneckSeverity, ComprehensiveComparison, ComprehensiveProfiler,
164    ModelBottleneckAnalysis,
165};
166
167// Re-export core types
168pub use interpretability::{
169    ActivationStats, CompressionAnalysis, GatingAnalysis, InterpretabilityReport, LayerProbe,
170    SensitivityAnalyzer, StateTrajectory,
171};
172// Re-export visualization types
173pub use visualization::{
174    matrix_to_csv, signal_to_svg_sparkline, ActivationHistogram, GatingPatternRecorder,
175    PhasePortrait,
176};
177
178// Re-export new compression types
179pub use compression::{CompressionReport, LowRankApprox, MagnitudePruner, StructuredPruner};
180
181// Re-export full state I/O types
182pub use state_io::{decode_f32_slice, encode_f32_slice, ModelSnapshot};
183
184pub use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
185pub use rwkv5::{Rwkv5Config, Rwkv5Model, Rwkv5State};
186pub use rwkv7::{Rwkv7Config, Rwkv7Model, Rwkv7State, Rwkv7TimeMixing};
187pub use scirs2_core::ndarray::{Array1, Array2};
188
189/// Trait for model architectures that support autoregressive prediction
190pub trait AutoregressiveModel: SignalPredictor + Send {
191    /// Get the model's hidden dimension
192    fn hidden_dim(&self) -> usize;
193
194    /// Get the model's state dimension (for SSMs)
195    fn state_dim(&self) -> usize;
196
197    /// Get number of layers
198    fn num_layers(&self) -> usize;
199
200    /// Get model type identifier
201    fn model_type(&self) -> ModelType;
202
203    /// Get current hidden states for all layers
204    fn get_states(&self) -> Vec<HiddenState>;
205
206    /// Set hidden states for all layers
207    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()>;
208
209    /// Load weights from a JSON file (`HashMap<String, Vec<f32>>` format).
210    ///
211    /// Override this method in model implementations that support weight loading.
212    /// The default implementation returns an error indicating the model does not
213    /// support JSON weight loading.
214    fn load_weights_json(&mut self, _path: &std::path::Path) -> ModelResult<()> {
215        Err(ModelError::unsupported_operation(
216            "load_weights_json",
217            format!(
218                "{} (model does not implement JSON weight loading)",
219                std::any::type_name::<Self>()
220            ),
221        ))
222    }
223
224    /// Save weights to a JSON file (`HashMap<String, Vec<f32>>` format).
225    ///
226    /// Override this method in model implementations that support weight saving.
227    /// The default implementation returns an error indicating the model does not
228    /// support JSON weight saving.
229    fn save_weights_json(&self, _path: &std::path::Path) -> ModelResult<()> {
230        Err(ModelError::unsupported_operation(
231            "save_weights_json",
232            format!(
233                "{} (model does not implement JSON weight saving)",
234                std::any::type_name::<Self>()
235            ),
236        ))
237    }
238}
239
/// Enumeration of supported model architectures.
///
/// A lightweight `Copy` tag returned by [`AutoregressiveModel::model_type`]
/// to identify the concrete architecture at runtime. Serde derives allow the
/// tag to be embedded in serialized configs and snapshots.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum ModelType {
    /// Mamba: Selective State Space Model
    Mamba,
    /// Mamba2: Enhanced selective SSM with SSD (State Space Duality)
    Mamba2,
    /// RWKV: Linear attention with time-mixing (v6)
    Rwkv,
    /// RWKV v5: Multi-head WKV with static time decay
    Rwkv5,
    /// S4: Structured State Space Model
    S4,
    /// S4D: S4 with diagonal state matrix
    S4D,
    /// Standard Transformer (for comparison; O(N) per inference step)
    Transformer,
    /// Neural ODE: Continuous-time dynamics via ODE solver
    NeuralOde,
    /// Multi-modal fusion model
    MultiModal,
    /// Spiking Neural Network: biologically-inspired LIF neurons
    Snn,
    /// Multi-Scale Temporal Model: multiple temporal resolutions
    MultiScale,
}
266
267impl std::fmt::Display for ModelType {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        match self {
270            ModelType::Mamba => write!(f, "Mamba"),
271            ModelType::Mamba2 => write!(f, "Mamba2"),
272            ModelType::Rwkv => write!(f, "RWKV"),
273            ModelType::Rwkv5 => write!(f, "RWKV5"),
274            ModelType::S4 => write!(f, "S4"),
275            ModelType::S4D => write!(f, "S4D"),
276            ModelType::Transformer => write!(f, "Transformer"),
277            ModelType::NeuralOde => write!(f, "NeuralODE"),
278            ModelType::MultiModal => write!(f, "MultiModal"),
279            ModelType::Snn => write!(f, "SNN"),
280            ModelType::MultiScale => write!(f, "MultiScale"),
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` should render the canonical short name for each variant,
    /// including the non-obvious casing of "RWKV".
    #[test]
    fn test_model_type_display() {
        assert_eq!(ModelType::Mamba2.to_string(), "Mamba2");
        assert_eq!(ModelType::Rwkv.to_string(), "RWKV");
    }
}
294}