kizzasi_core/lib.rs

//! # kizzasi-core
//!
//! Core SSM (State Space Model) engine for Kizzasi AGSP.
//!
//! Implements linear-time State Space Models (Mamba/S4/RWKV) for efficient
//! processing of continuous signal streams with O(1) complexity per
//! inference step.
//!
//! ## COOLJAPAN Ecosystem
//!
//! This crate follows the KIZZASI_POLICY.md and uses `scirs2-core` for all
//! array and numerical operations.
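//!
//! ## Example
//!
//! A minimal sketch of the intended streaming loop, written against the
//! [`SignalPredictor`] trait defined below. Construction of a concrete
//! model is elided (the block is `ignore`d) since it depends on which
//! backbone you choose:
//!
//! ```ignore
//! use kizzasi_core::{Array1, SignalPredictor};
//!
//! fn stream(predictor: &mut impl SignalPredictor, frames: &[Array1<f32>]) {
//!     predictor.reset();
//!     for frame in frames {
//!         // Each call advances the hidden state in O(1) time and memory,
//!         // independent of how many frames were consumed before it.
//!         let _next = predictor.step(frame).expect("prediction step failed");
//!     }
//! }
//! ```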

#![cfg_attr(not(feature = "std"), no_std)]

#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

pub mod attention;
mod config;
pub mod conv;
pub mod dataloader;
pub mod device;
pub mod efficient_attention;
pub mod embedded_alloc;
mod embedding;
mod error;
pub mod fixed_point;
pub mod flash_attention;
pub mod gpu_utils;
pub mod h3;
pub mod kernel_fusion;
pub mod lora;
pub mod mamba2;
pub mod metrics;
pub mod nn;
pub mod numerics;
pub mod optimizations;
pub mod parallel;
pub mod pool;
pub mod profiling;
pub mod pruning;
pub mod pytorch_compat;
pub mod quantization;
pub mod retnet;
pub mod rwkv7;
pub mod s4d;
pub mod s5;
pub mod scan;
pub mod scheduler;
pub mod sequences;
pub mod simd;
pub mod simd_avx512;
pub mod simd_neon;
mod ssm;
mod state;
pub mod training;
pub mod weights;

pub use attention::{GatedLinearAttention, MultiHeadSSMAttention, MultiHeadSSMConfig};
pub use config::{KizzasiConfig, ModelType};
pub use conv::{CausalConv1d, DepthwiseCausalConv1d, DilatedCausalConv1d, DilatedStack, ShortConv};
pub use dataloader::{
    BatchIterator, DataLoaderConfig, TimeSeriesAugmentation, TimeSeriesDataLoader,
};
pub use device::{
    get_best_device, is_cuda_available, is_metal_available, list_devices, DeviceConfig, DeviceInfo,
    DeviceType,
};
pub use efficient_attention::{
    EfficientAttentionConfig, EfficientMultiHeadAttention, FusedAttentionKernel,
};
pub use embedded_alloc::{BumpAllocator, EmbeddedAllocator, FixedPool, StackAllocator, StackGuard};
pub use embedding::ContinuousEmbedding;
pub use error::{CoreError, CoreResult};
pub use flash_attention::{flash_attention_fused, FlashAttention, FlashAttentionConfig};
pub use gpu_utils::{GPUMemoryPool, MemoryStats, TensorPrefetch, TensorTransfer, TransferBatch};
pub use h3::{DiagonalSSM, H3Config, H3Layer, H3Model, ShiftSSM};
pub use kernel_fusion::{
    fused_ffn_gelu, fused_layernorm_gelu, fused_layernorm_silu, fused_linear_activation,
    fused_multihead_output, fused_qkv_projection, fused_quantize_dequantize, fused_softmax_attend,
    fused_ssm_step,
};
pub use lora::{LoRAAdapter, LoRAConfig, LoRALayer};
pub use mamba2::{Mamba2Config, Mamba2Layer, Mamba2Model};
pub use metrics::{MetricsLogger, MetricsSummary, TrainingMetrics};
pub use nn::{
    gelu, gelu_fast, layer_norm, leaky_relu, log_softmax, relu, rms_norm, sigmoid, silu, softmax,
    tanh, Activation, ActivationType, GatedLinearUnit, LayerNorm, NormType,
};
pub use optimizations::{
    acquire_workspace, ilp, prefetch, release_workspace, CacheAligned, DiscretizationCache,
    SSMWorkspace, WorkspaceGuard,
};
pub use parallel::{BatchProcessor, ParallelConfig};
pub use pool::{ArrayPool, MultiArrayPool, PoolStats, PooledArray};
pub use profiling::{
    CounterStats, MemoryProfiler, PerfCounter, ProfilerMemoryStats, ProfilingSession, ScopeTimer,
    Timer,
};
pub use pruning::{
    GradientPruner, PruningConfig, PruningGranularity, PruningMask, PruningStrategy,
    StructuredPruner,
};
pub use pytorch_compat::{
    detect_checkpoint_architecture, load_pytorch_checkpoint, PyTorchCheckpoint, PyTorchConverter,
    WeightMapping,
};
pub use quantization::{
    DynamicQuantizer, QuantizationParams, QuantizationScheme, QuantizationType, QuantizedTensor,
};
pub use retnet::{MultiScaleRetention, RetNetConfig, RetNetLayer, RetNetModel};
pub use rwkv7::{ChannelMixing, RWKV7Config, RWKV7Layer, RWKV7Model, TimeMixing};
pub use s4d::{S4DConfig, S4DLayer, S4DModel};
pub use s5::{S5Config, S5Layer, S5Model};
pub use scan::{
    parallel_scan, parallel_ssm_batch, parallel_ssm_scan, segmented_scan, AssociativeOp,
    SSMElement, SSMScanOp,
};
pub use scheduler::{
    ConstantScheduler, CosineScheduler, ExponentialScheduler, LRScheduler, LinearScheduler,
    OneCycleScheduler, PolynomialScheduler, StepScheduler,
};
pub use sequences::{
    apply_mask, masked_mean, masked_sum, pad_sequences, PackedSequence, PaddingStrategy,
    SequenceMask,
};
pub use ssm::{SelectiveSSM, StateSpaceModel};
pub use state::HiddenState;
pub use training::{
    CheckpointMetadata, ConstraintLoss, Loss, MixedPrecision, SchedulerType, TrainableSSM, Trainer,
    TrainingConfig,
};
pub use weights::{WeightFormat, WeightLoadConfig, WeightLoader, WeightPruner};

// Re-export the `scirs2-core` array type used throughout the public API.
pub use scirs2_core::ndarray::Array1;

/// Core trait for autoregressive signal prediction.
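///
/// The contract: `step` consumes one input frame, advances the recurrent
/// hidden state, and returns the prediction for the next frame in O(1)
/// time per call.
///
/// A minimal sketch of an implementor (the identity "predictor" below is
/// purely illustrative, not a type exported by this crate):
///
/// ```ignore
/// use kizzasi_core::{Array1, CoreResult, SignalPredictor};
///
/// struct Identity {
///     window: usize,
/// }
///
/// impl SignalPredictor for Identity {
///     fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
///         // A real SSM backbone would update its hidden state here;
///         // this stateless sketch just echoes the input back.
///         Ok(input.clone())
///     }
///
///     fn reset(&mut self) {}
///
///     fn context_window(&self) -> usize {
///         self.window
///     }
/// }
/// ```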
pub trait SignalPredictor {
    /// Update the hidden state and predict the next signal vector.
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>>;

    /// Reset the hidden state to its initial values.
    fn reset(&mut self);

    /// Get the current context window size.
    fn context_window(&self) -> usize;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_builder() {
        let config = KizzasiConfig::new()
            .model_type(ModelType::Mamba2)
            .context_window(8192)
            .hidden_dim(256);

        assert_eq!(config.get_context_window(), 8192);
        assert_eq!(config.get_hidden_dim(), 256);
    }
}