1#![cfg_attr(not(feature = "std"), no_std)]
14
15#[cfg(not(feature = "std"))]
16extern crate alloc;
17
18#[cfg(not(feature = "std"))]
19use alloc::vec::Vec;
20
21pub mod attention;
22mod config;
23pub mod conv;
24pub mod dataloader;
25pub mod device;
26pub mod efficient_attention;
27pub mod embedded_alloc;
28mod embedding;
29mod error;
30pub mod fixed_point;
31pub mod flash_attention;
32pub mod gpu_utils;
33pub mod h3;
34pub mod kernel_fusion;
35pub mod lora;
36pub mod mamba2;
37pub mod metrics;
38pub mod nn;
39pub mod numerics;
40pub mod optimizations;
41pub mod parallel;
42pub mod pool;
43pub mod profiling;
44pub mod pruning;
45pub mod pytorch_compat;
46pub mod quantization;
47pub mod retnet;
48pub mod rwkv7;
49pub mod s4d;
50pub mod s5;
51pub mod scan;
52pub mod scheduler;
53pub mod sequences;
54pub mod simd;
55pub mod simd_avx512;
56pub mod simd_neon;
57mod ssm;
58mod state;
59pub mod training;
60pub mod weights;
61
62pub use attention::{GatedLinearAttention, MultiHeadSSMAttention, MultiHeadSSMConfig};
63pub use config::{KizzasiConfig, ModelType};
64pub use conv::{CausalConv1d, DepthwiseCausalConv1d, DilatedCausalConv1d, DilatedStack, ShortConv};
65pub use dataloader::{
66 BatchIterator, DataLoaderConfig, TimeSeriesAugmentation, TimeSeriesDataLoader,
67};
68pub use device::{
69 get_best_device, is_cuda_available, is_metal_available, list_devices, DeviceConfig, DeviceInfo,
70 DeviceType,
71};
72pub use efficient_attention::{
73 EfficientAttentionConfig, EfficientMultiHeadAttention, FusedAttentionKernel,
74};
75pub use embedded_alloc::{BumpAllocator, EmbeddedAllocator, FixedPool, StackAllocator, StackGuard};
76pub use embedding::ContinuousEmbedding;
77pub use error::{CoreError, CoreResult};
78pub use flash_attention::{flash_attention_fused, FlashAttention, FlashAttentionConfig};
79pub use gpu_utils::{GPUMemoryPool, MemoryStats, TensorPrefetch, TensorTransfer, TransferBatch};
80pub use h3::{DiagonalSSM, H3Config, H3Layer, H3Model, ShiftSSM};
81pub use kernel_fusion::{
82 fused_ffn_gelu, fused_layernorm_gelu, fused_layernorm_silu, fused_linear_activation,
83 fused_multihead_output, fused_qkv_projection, fused_quantize_dequantize, fused_softmax_attend,
84 fused_ssm_step,
85};
86pub use lora::{LoRAAdapter, LoRAConfig, LoRALayer};
87pub use mamba2::{Mamba2Config, Mamba2Layer, Mamba2Model};
88pub use metrics::{MetricsLogger, MetricsSummary, TrainingMetrics};
89pub use nn::{
90 gelu, gelu_fast, layer_norm, leaky_relu, log_softmax, relu, rms_norm, sigmoid, silu, softmax,
91 tanh, Activation, ActivationType, GatedLinearUnit, LayerNorm, NormType,
92};
93pub use optimizations::{
94 acquire_workspace, ilp, prefetch, release_workspace, CacheAligned, DiscretizationCache,
95 SSMWorkspace, WorkspaceGuard,
96};
97pub use parallel::{BatchProcessor, ParallelConfig};
98pub use pool::{ArrayPool, MultiArrayPool, PoolStats, PooledArray};
99pub use profiling::{
100 CounterStats, MemoryProfiler, PerfCounter, ProfilerMemoryStats, ProfilingSession, ScopeTimer,
101 Timer,
102};
103pub use pruning::{
104 GradientPruner, PruningConfig, PruningGranularity, PruningMask, PruningStrategy,
105 StructuredPruner,
106};
107pub use pytorch_compat::{
108 detect_checkpoint_architecture, load_pytorch_checkpoint, PyTorchCheckpoint, PyTorchConverter,
109 WeightMapping,
110};
111pub use quantization::{
112 DynamicQuantizer, QuantizationParams, QuantizationScheme, QuantizationType, QuantizedTensor,
113};
114pub use retnet::{MultiScaleRetention, RetNetConfig, RetNetLayer, RetNetModel};
115pub use rwkv7::{ChannelMixing, RWKV7Config, RWKV7Layer, RWKV7Model, TimeMixing};
116pub use s4d::{S4DConfig, S4DLayer, S4DModel};
117pub use s5::{S5Config, S5Layer, S5Model};
118pub use scan::{
119 parallel_scan, parallel_ssm_batch, parallel_ssm_scan, segmented_scan, AssociativeOp,
120 SSMElement, SSMScanOp,
121};
122pub use scheduler::{
123 ConstantScheduler, CosineScheduler, ExponentialScheduler, LRScheduler, LinearScheduler,
124 OneCycleScheduler, PolynomialScheduler, StepScheduler,
125};
126pub use sequences::{
127 apply_mask, masked_mean, masked_sum, pad_sequences, PackedSequence, PaddingStrategy,
128 SequenceMask,
129};
130pub use ssm::{SelectiveSSM, StateSpaceModel};
131pub use state::HiddenState;
132pub use training::{
133 CheckpointMetadata, ConstraintLoss, Loss, MixedPrecision, SchedulerType, TrainableSSM, Trainer,
134 TrainingConfig,
135};
136pub use weights::{WeightFormat, WeightLoadConfig, WeightLoader, WeightPruner};
137
138pub use scirs2_core::ndarray::Array1;
140
/// Common interface for stateful, step-at-a-time signal predictors.
///
/// Implementors take one input vector per [`SignalPredictor::step`] call and
/// produce one output vector; `step` takes `&mut self`, so implementors may
/// carry hidden state between calls (see [`reset`](SignalPredictor::reset)).
pub trait SignalPredictor {
    /// Processes a single input vector and returns the predicted output vector.
    ///
    /// # Errors
    ///
    /// Returns a [`CoreError`] when the step fails; the exact failure
    /// conditions (e.g. dimension mismatch) depend on the implementor.
    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>>;

    /// Resets any internal state so the predictor can start a new sequence.
    fn reset(&mut self);

    /// Returns the predictor's context window size.
    // NOTE(review): presumably the number of past steps the model can attend
    // to — confirm against implementors (e.g. the SSM/attention modules).
    fn context_window(&self) -> usize;
}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that values set through the `KizzasiConfig` builder-style
    /// setters are reported back by the matching getters.
    #[test]
    fn test_config_builder() {
        let cfg = KizzasiConfig::new()
            .model_type(ModelType::Mamba2)
            .context_window(8192)
            .hidden_dim(256);

        assert_eq!(cfg.get_hidden_dim(), 256);
        assert_eq!(cfg.get_context_window(), 8192);
    }
}
167}