#[cfg(feature = "hf-hub")]
pub mod hf_hub;
#[cfg(feature = "hf-hub")]
pub use hf_hub::{load_from_hub, HfHubClient, HfHubConfig, HfModelInfo};
pub mod arch_search;
pub use arch_search::{
    search_best_arch, ArchCandidate, ArchSearchConfig, ArchSearchResult, ArchSearchSpace,
    EvolutionarySearcher, GridSearcher, RandomArchSearcher,
};
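
// Core infrastructure: training, optimization, compression, quantization,
// model I/O, and profiling utilities.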
pub mod backprop;
pub mod backprop_ssm;
pub mod batch;
pub mod blas_ops;
pub mod cache_friendly;
pub mod checkpoint;
pub mod compression;
pub mod curriculum;
pub mod distributed;
pub mod dynamic_quantization;
pub mod early_exit;
mod error;
pub mod factory;
pub mod flash_linear_attn;
pub mod gguf;
pub(crate) mod gguf_dequant;
pub mod gradient_checkpoint;
pub mod huggingface;
pub mod huggingface_loader;
pub mod incremental_loader;
pub mod loader;
pub mod lora;
pub mod mixed_precision;
pub mod moe;
pub mod onnx_export;
pub mod parallel_multihead;
pub mod profiling;
pub mod prune;
pub mod pytorch_compat;
pub mod quantization;
pub mod quantize;
pub mod simd_ops;
pub mod speculative;
pub mod state_io;
pub mod training;
pub mod training_loop;
pub mod visualization;
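
// Model architectures and analysis modules. The Mamba variants are gated
// behind the `mamba` Cargo feature (e.g. `cargo build --features mamba`).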
#[cfg(feature = "mamba")]
pub mod mamba;
#[cfg(feature = "mamba")]
pub mod mamba2;
pub mod h3;
pub mod hybrid;
pub mod interpretability;
pub mod multimodal;
pub mod neural_ode;
pub mod rwkv;
pub mod rwkv5;
pub mod rwkv7;
pub mod s4;
pub mod s5;
pub mod spiking;
pub mod temporal_multiscale;
pub mod transformer;
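
// Flat re-exports of commonly used types, so downstream code can avoid deep
// module paths.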
pub use error::{ModelError, ModelResult};
pub use backprop::{
    layer_norm_backward, linear_backward, silu_backward, softmax_backward, GradAccumulator,
    GradientTape, SsmBackward, SsmGradients, Tensor,
};
pub use gguf::{GgufFile, GgufInspection, GgufMetaValue, GgufQuantType, GgufTensorInfo};
pub use incremental_loader::{
    GgufFileSource, IncrementalModelLoader, SafeTensorsSource, WeightSource,
};
pub use loader::{ModelLoader, TensorInfo, WeightLoader};
pub use lora::{LoraAdapter, LoraAdapterSummary, LoraConfig, LoraLinear, QLoraLinear};
pub use multimodal::{
    FusionStrategy, Modality, ModalityAligner, MultiModalConfig, MultiModalModel,
};
pub use neural_ode::{
    AugmentedNeuralOde, NeuralOdeConfig, NeuralOdeModel, OdeIntegrator, OdeSolver,
};
pub use spiking::{
    LifLayer, MembranePotential, ResetMode, SpikingConfig, SpikingNeuralNetwork, StdpConfig,
};
pub use temporal_multiscale::{MultiScaleConfig, MultiScaleModel, ScaleFusion};
pub use training_loop::{
    AdamOptimizer, ArrayDataProvider, ConstantScheduler, DataProvider, ExponentialScheduler,
    LrScheduler, Optimizer, SgdOptimizer, StepDecayScheduler, TrainingCallback, TrainingConfig,
    TrainingLoop, TrainingResult,
};
pub use distributed::{
    average_gradients, partition_indices, run_parallel_workers, sgd_step, CommBackend,
    DataParallelModel, DistributedConfig, GradientBuffer, GradientStrategy, GradientSync,
    LocalGradientSync, SharedGradientStore, ThreadedGradientSync,
};
pub use curriculum::{CurriculumDataProvider, CurriculumScheduler, CurriculumStrategy};
pub use gradient_checkpoint::{ActivationCheckpointer, CheckpointConfig};
pub use speculative::{SpeculativeDecoder, SpeculativeResult};
pub use early_exit::{AdaptiveComputation, EarlyExitConfig, ExitCriterion, ExitStats};
pub use blas_ops::{
    axpy, batch_matmul_vec, dot, matmul_mat, matmul_vec, norm_frobenius, norm_l2, transpose,
    BlasConfig,
};
pub use profiling::{
    BottleneckInfo, BottleneckSeverity, ComprehensiveComparison, ComprehensiveProfiler,
    ModelBottleneckAnalysis,
};
pub use interpretability::{
    ActivationStats, CompressionAnalysis, GatingAnalysis, InterpretabilityReport, LayerProbe,
    SensitivityAnalyzer, StateTrajectory,
};
pub use visualization::{
    matrix_to_csv, signal_to_svg_sparkline, ActivationHistogram, GatingPatternRecorder,
    PhasePortrait,
};
pub use compression::{CompressionReport, LowRankApprox, MagnitudePruner, StructuredPruner};
pub use state_io::{decode_f32_slice, encode_f32_slice, ModelSnapshot};
pub use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
pub use rwkv5::{Rwkv5Config, Rwkv5Model, Rwkv5State};
pub use rwkv7::{Rwkv7Config, Rwkv7Model, Rwkv7State, Rwkv7TimeMixing};
pub use scirs2_core::ndarray::{Array1, Array2};
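/// Common interface for the autoregressive sequence models in this crate,
/// layered on top of `kizzasi_core::SignalPredictor` (re-exported above).
///
/// # Example
///
/// A minimal sketch of consuming the trait generically, using only the
/// methods defined below (`ignore`d because running it requires a concrete
/// model type):
///
/// ```ignore
/// fn describe<M: AutoregressiveModel>(m: &M) -> String {
///     format!(
///         "{} ({} layers, hidden_dim={}, state_dim={})",
///         m.model_type(),
///         m.num_layers(),
///         m.hidden_dim(),
///         m.state_dim()
///     )
/// }
/// ```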
pub trait AutoregressiveModel: SignalPredictor + Send {
    /// Dimensionality of the model's hidden representation.
    fn hidden_dim(&self) -> usize;
    /// Dimensionality of the recurrent state.
    fn state_dim(&self) -> usize;
    /// Number of stacked layers in the model.
    fn num_layers(&self) -> usize;
    /// Which architecture this model implements.
    fn model_type(&self) -> ModelType;
    /// Snapshot of the model's current hidden states.
    fn get_states(&self) -> Vec<HiddenState>;
    /// Restore previously captured hidden states.
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()>;
    /// Load weights from a JSON file. The default implementation reports the
    /// operation as unsupported; models that support it override this method.
    fn load_weights_json(&mut self, _path: &std::path::Path) -> ModelResult<()> {
        Err(ModelError::unsupported_operation(
            "load_weights_json",
            format!(
                "{} (model does not implement JSON weight loading)",
                std::any::type_name::<Self>()
            ),
        ))
    }
    /// Save weights to a JSON file. The default implementation reports the
    /// operation as unsupported; models that support it override this method.
    fn save_weights_json(&self, _path: &std::path::Path) -> ModelResult<()> {
        Err(ModelError::unsupported_operation(
            "save_weights_json",
            format!(
                "{} (model does not implement JSON weight saving)",
                std::any::type_name::<Self>()
            ),
        ))
    }
}
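/// Identifies which architecture an [`AutoregressiveModel`] implements.
///
/// With the derived serde impls, each variant serializes as its bare name,
/// e.g. `ModelType::Mamba2` becomes the JSON string `"Mamba2"`. A sketch,
/// assuming `serde_json` is available:
///
/// ```ignore
/// let json = serde_json::to_string(&ModelType::Mamba2).unwrap();
/// assert_eq!(json, "\"Mamba2\"");
/// ```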
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum ModelType {
    Mamba,
    Mamba2,
    Rwkv,
    Rwkv5,
    S4,
    S4D,
    Transformer,
    NeuralOde,
    MultiModal,
    Snn,
    MultiScale,
}
impl std::fmt::Display for ModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ModelType::Mamba => write!(f, "Mamba"),
            ModelType::Mamba2 => write!(f, "Mamba2"),
            ModelType::Rwkv => write!(f, "RWKV"),
            ModelType::Rwkv5 => write!(f, "RWKV5"),
            ModelType::S4 => write!(f, "S4"),
            ModelType::S4D => write!(f, "S4D"),
            ModelType::Transformer => write!(f, "Transformer"),
            ModelType::NeuralOde => write!(f, "NeuralODE"),
            ModelType::MultiModal => write!(f, "MultiModal"),
            ModelType::Snn => write!(f, "SNN"),
            ModelType::MultiScale => write!(f, "MultiScale"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_type_display() {
        assert_eq!(format!("{}", ModelType::Mamba2), "Mamba2");
        assert_eq!(format!("{}", ModelType::Rwkv), "RWKV");
    }
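
    // Exhaustive check of the Display impl above; the expected strings mirror
    // the match arms in `fmt` exactly.
    #[test]
    fn test_model_type_display_all_variants() {
        let cases = [
            (ModelType::Mamba, "Mamba"),
            (ModelType::Mamba2, "Mamba2"),
            (ModelType::Rwkv, "RWKV"),
            (ModelType::Rwkv5, "RWKV5"),
            (ModelType::S4, "S4"),
            (ModelType::S4D, "S4D"),
            (ModelType::Transformer, "Transformer"),
            (ModelType::NeuralOde, "NeuralODE"),
            (ModelType::MultiModal, "MultiModal"),
            (ModelType::Snn, "SNN"),
            (ModelType::MultiScale, "MultiScale"),
        ];
        for (ty, expected) in cases {
            assert_eq!(ty.to_string(), expected);
        }
    }

    // `ModelType` derives `Eq` and `Hash`, so it can key hash-based
    // collections; duplicate inserts collapse to a single entry.
    #[test]
    fn test_model_type_eq_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(ModelType::Mamba);
        set.insert(ModelType::Mamba);
        set.insert(ModelType::S4);
        assert_eq!(set.len(), 2);
        assert!(set.contains(&ModelType::Mamba));
    }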
}