// --- Optional Hugging Face Hub integration (compiled only with `--features hf-hub`) ---
#[cfg(feature = "hf-hub")]
pub mod hf_hub;
#[cfg(feature = "hf-hub")]
pub use hf_hub::{load_from_hub, HfHubClient, HfHubConfig, HfModelInfo};

// --- Neural architecture search (grid / random / evolutionary searchers) ---
pub mod arch_search;
pub use arch_search::{
    search_best_arch, ArchCandidate, ArchSearchConfig, ArchSearchResult, ArchSearchSpace,
    EvolutionarySearcher, GridSearcher, RandomArchSearcher,
};

// --- Training, optimization, loading and model-management infrastructure ---
pub mod backprop;
pub mod backprop_ssm;
pub mod batch;
pub mod blas_ops;
pub mod cache_friendly;
pub mod checkpoint;
pub mod compression;
pub mod curriculum;
pub mod distributed;
pub mod dynamic_quantization;
pub mod early_exit;
mod error; // private: its types are re-exported below as `ModelError` / `ModelResult`
pub mod factory;
pub mod flash_linear_attn;
pub mod gguf;
pub(crate) mod gguf_dequant; // crate-internal GGUF dequantization helpers
pub mod gradient_checkpoint;
pub mod huggingface;
pub mod huggingface_loader;
pub mod incremental_loader;
pub mod loader;
pub mod lora;
pub mod mixed_precision;
pub mod moe;
pub mod onnx_export;
pub mod parallel_multihead;
pub mod profiling;
pub mod prune;
pub mod pytorch_compat;
pub mod quantization;
pub mod quantize;
pub mod simd_ops;
pub mod speculative;
pub mod state_io;
pub mod training;
pub mod training_loop;
pub mod visualization;

// --- Architectures gated behind the `mamba` feature ---
#[cfg(feature = "mamba")]
pub mod mamba;

#[cfg(feature = "mamba")]
pub mod mamba2;

pub mod interpretability;

// --- Always-available model architectures ---
pub mod rwkv;

pub mod rwkv5;

pub mod rwkv7;

pub mod s4;

pub mod s5;

pub mod h3;

pub mod hybrid;

pub mod multimodal;

pub mod neural_ode;

pub mod spiking;

pub mod temporal_multiscale;

pub mod transformer;

// --- Flat re-exports so downstream crates can write `crate_name::Foo` directly ---
pub use error::{ModelError, ModelResult};

pub use backprop::{
    layer_norm_backward, linear_backward, silu_backward, softmax_backward, GradAccumulator,
    GradientTape, SsmBackward, SsmGradients, Tensor,
};
pub use gguf::{GgufFile, GgufInspection, GgufMetaValue, GgufQuantType, GgufTensorInfo};
pub use incremental_loader::{
    GgufFileSource, IncrementalModelLoader, SafeTensorsSource, WeightSource,
};
pub use loader::{ModelLoader, TensorInfo, WeightLoader};
pub use lora::{LoraAdapter, LoraAdapterSummary, LoraConfig, LoraLinear, QLoraLinear};
pub use multimodal::{
    FusionStrategy, Modality, ModalityAligner, MultiModalConfig, MultiModalModel,
};
pub use neural_ode::{
    AugmentedNeuralOde, NeuralOdeConfig, NeuralOdeModel, OdeIntegrator, OdeSolver,
};
pub use spiking::{
    LifLayer, MembranePotential, ResetMode, SpikingConfig, SpikingNeuralNetwork, StdpConfig,
};
pub use temporal_multiscale::{MultiScaleConfig, MultiScaleModel, ScaleFusion};

pub use training_loop::{
    AdamOptimizer, ArrayDataProvider, ConstantScheduler, DataProvider, ExponentialScheduler,
    LrScheduler, Optimizer, SgdOptimizer, StepDecayScheduler, TrainingCallback, TrainingConfig,
    TrainingLoop, TrainingResult,
};

pub use distributed::{
    average_gradients, partition_indices, run_parallel_workers, sgd_step, CommBackend,
    DataParallelModel, DistributedConfig, GradientBuffer, GradientStrategy, GradientSync,
    LocalGradientSync, SharedGradientStore, ThreadedGradientSync,
};

pub use curriculum::{CurriculumDataProvider, CurriculumScheduler, CurriculumStrategy};

pub use gradient_checkpoint::{ActivationCheckpointer, CheckpointConfig};

pub use speculative::{SpeculativeDecoder, SpeculativeResult};

pub use early_exit::{AdaptiveComputation, EarlyExitConfig, ExitCriterion, ExitStats};

pub use blas_ops::{
    axpy, batch_matmul_vec, dot, matmul_mat, matmul_vec, norm_frobenius, norm_l2, transpose,
    BlasConfig,
};

pub use profiling::{
    BottleneckInfo, BottleneckSeverity, ComprehensiveComparison, ComprehensiveProfiler,
    ModelBottleneckAnalysis,
};

pub use interpretability::{
    ActivationStats, CompressionAnalysis, GatingAnalysis, InterpretabilityReport, LayerProbe,
    SensitivityAnalyzer, StateTrajectory,
};
pub use visualization::{
    matrix_to_csv, signal_to_svg_sparkline, ActivationHistogram, GatingPatternRecorder,
    PhasePortrait,
};

pub use compression::{CompressionReport, LowRankApprox, MagnitudePruner, StructuredPruner};

pub use state_io::{decode_f32_slice, encode_f32_slice, ModelSnapshot};

// Re-export foundation types from sibling crates so downstream code does not need
// to depend on `kizzasi_core` / `scirs2_core` directly.
pub use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
pub use rwkv5::{Rwkv5Config, Rwkv5Model, Rwkv5State};
pub use rwkv7::{Rwkv7Config, Rwkv7Model, Rwkv7State, Rwkv7TimeMixing};
pub use scirs2_core::ndarray::{Array1, Array2};
188
189pub trait AutoregressiveModel: SignalPredictor + Send {
191 fn hidden_dim(&self) -> usize;
193
194 fn state_dim(&self) -> usize;
196
197 fn num_layers(&self) -> usize;
199
200 fn model_type(&self) -> ModelType;
202
203 fn get_states(&self) -> Vec<HiddenState>;
205
206 fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()>;
208
209 fn load_weights_json(&mut self, _path: &std::path::Path) -> ModelResult<()> {
215 Err(ModelError::unsupported_operation(
216 "load_weights_json",
217 format!(
218 "{} (model does not implement JSON weight loading)",
219 std::any::type_name::<Self>()
220 ),
221 ))
222 }
223
224 fn save_weights_json(&self, _path: &std::path::Path) -> ModelResult<()> {
230 Err(ModelError::unsupported_operation(
231 "save_weights_json",
232 format!(
233 "{} (model does not implement JSON weight saving)",
234 std::any::type_name::<Self>()
235 ),
236 ))
237 }
238}
239
/// Architecture family of a model, as reported by
/// [`AutoregressiveModel::model_type`].
///
/// Serializable so the type tag can be stored in checkpoints / snapshots.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum ModelType {
    /// Mamba (selective state-space) architecture.
    Mamba,
    /// Mamba-2 architecture.
    Mamba2,
    /// RWKV architecture.
    Rwkv,
    /// RWKV-5 architecture.
    Rwkv5,
    /// S4 structured state-space architecture.
    S4,
    /// S4D (diagonal S4) architecture.
    S4D,
    /// Transformer architecture.
    Transformer,
    /// Neural ODE architecture.
    NeuralOde,
    /// Multi-modal architecture.
    MultiModal,
    /// Spiking neural network architecture.
    Snn,
    /// Multi-scale temporal architecture.
    MultiScale,
}
266
267impl std::fmt::Display for ModelType {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 match self {
270 ModelType::Mamba => write!(f, "Mamba"),
271 ModelType::Mamba2 => write!(f, "Mamba2"),
272 ModelType::Rwkv => write!(f, "RWKV"),
273 ModelType::Rwkv5 => write!(f, "RWKV5"),
274 ModelType::S4 => write!(f, "S4"),
275 ModelType::S4D => write!(f, "S4D"),
276 ModelType::Transformer => write!(f, "Transformer"),
277 ModelType::NeuralOde => write!(f, "NeuralODE"),
278 ModelType::MultiModal => write!(f, "MultiModal"),
279 ModelType::Snn => write!(f, "SNN"),
280 ModelType::MultiScale => write!(f, "MultiScale"),
281 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` must produce the canonical architecture names.
    #[test]
    fn test_model_type_display() {
        assert_eq!(ModelType::Mamba2.to_string(), "Mamba2");
        assert_eq!(ModelType::Rwkv.to_string(), "RWKV");
    }
}