// kizzasi_model — temporal_multiscale.rs

//! Multi-Scale Temporal Modeling
//!
//! This module implements temporal modeling at multiple resolutions simultaneously.
//! Each "scale" processes the input at a different decimation rate, allowing the
//! model to capture both fine-grained and coarse temporal dynamics.
//!
//! # Architecture
//!
//! ```text
//! Input ──→ [Scale 1, dt=1]  ──→ state_1 ──┐
//!        ├─→ [Scale 2, dt=4]  ──→ state_2 ──┤ → [Fusion] → [Output Proj] → Prediction
//!        └─→ [Scale 3, dt=16] ──→ state_3 ──┘
//! ```
//!
//! Each `TemporalScale` updates its state only every `decimation` steps,
//! emulating processing at different temporal resolutions.
//!
//! # Fusion Strategies
//!
//! - **Concatenate**: Concatenate all scale outputs, then project linearly.
//! - **Weighted**: Learned scalar weights (softmax-normalized) per scale.
//! - **Attention**: Cross-scale attention — query from mean, keys/values from scales.
//!
//! # GRU-style Update
//!
//! Each scale uses a simplified GRU recurrence:
//! ```text
//! state = tanh(W_proj @ input + W_rec @ state + bias)
//! ```
//!
//! # References
//!
//! - "Temporal Convolutional Networks" (Bai et al., 2018)
//! - "Multi-Scale RNNs" (El Hihi & Bengio, 1996)

use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};

use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};

#[allow(unused_imports)]
use tracing::{debug, instrument, trace};
44
// ---------------------------------------------------------------------------
// Seeded deterministic RNG
// ---------------------------------------------------------------------------

/// Minimal xorshift64 PRNG used for reproducible weight initialization.
/// Determinism — not statistical quality — is the only requirement here.
struct SeededRng {
    // Current generator state; kept non-zero (zero is a fixed point of xorshift).
    state: u64,
}

impl SeededRng {
    /// Create a generator from `seed`; a zero seed is remapped to 1 because
    /// xorshift64 would otherwise emit zeros forever.
    fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 { 1 } else { seed },
        }
    }

    /// Advance the generator and return a value approximately uniform in [-1, 1].
    fn next_f32(&mut self) -> f32 {
        // xorshift64 step (Marsaglia, 2003).
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        // Map the full u64 range to [0, 1], then shift/scale to [-1, 1].
        let unit = x as f64 / u64::MAX as f64;
        (unit * 2.0 - 1.0) as f32
    }
}
65
// ---------------------------------------------------------------------------
// Scale Fusion Strategy
// ---------------------------------------------------------------------------

/// Strategy for combining outputs from multiple temporal scales.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScaleFusion {
    /// Concatenate all scale hidden states, then linearly project back to hidden_dim
    Concatenate,
    /// Learned softmax-normalized scalar weight per scale
    Weighted,
    /// Cross-scale attention: query = mean of states, keys/values = per-scale states
    Attention,
}
80
// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Configuration for Multi-Scale Temporal Model.
///
/// See [`MultiScaleConfig::validate`] for the constraints each field must satisfy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiScaleConfig {
    /// Input signal dimension
    pub input_dim: usize,
    /// Hidden dimension per temporal scale
    pub hidden_dim: usize,
    /// Output dimension
    pub output_dim: usize,
    /// Number of temporal scales
    pub num_scales: usize,
    /// Decimation factor per scale (e.g., [1, 4, 16] means scale 0 updates every step,
    /// scale 1 every 4 steps, scale 2 every 16 steps); must have `num_scales` entries
    pub scale_factors: Vec<usize>,
    /// Fusion strategy for combining scale outputs
    pub fusion: ScaleFusion,
    /// Nominal context window (informational only; recurrence is unbounded)
    pub context_length: usize,
}
104
105impl MultiScaleConfig {
106    /// Validate configuration
107    pub fn validate(&self) -> ModelResult<()> {
108        if self.input_dim == 0 {
109            return Err(ModelError::invalid_config("input_dim must be > 0"));
110        }
111        if self.hidden_dim == 0 {
112            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
113        }
114        if self.output_dim == 0 {
115            return Err(ModelError::invalid_config("output_dim must be > 0"));
116        }
117        if self.num_scales == 0 {
118            return Err(ModelError::invalid_config("num_scales must be > 0"));
119        }
120        if self.scale_factors.len() != self.num_scales {
121            return Err(ModelError::invalid_config(
122                "scale_factors.len() must equal num_scales",
123            ));
124        }
125        for &sf in &self.scale_factors {
126            if sf == 0 {
127                return Err(ModelError::invalid_config("all scale_factors must be > 0"));
128            }
129        }
130        Ok(())
131    }
132}
133
// ---------------------------------------------------------------------------
// TemporalScale
// ---------------------------------------------------------------------------

/// A single temporal scale that processes input at a given decimation rate.
///
/// The scale runs a GRU-style recurrent update every `decimation` steps.
/// Between updates, it holds its previous state constant.
pub struct TemporalScale {
    /// Hidden state dimensionality of this scale
    hidden_dim: usize,
    /// Number of steps between state updates
    decimation: usize,
    /// Input projection: (hidden_dim, input_dim)
    projection: Array2<f32>,
    /// Recurrent weight: (hidden_dim, hidden_dim)
    recurrent: Array2<f32>,
    /// Bias
    bias: Array1<f32>,
    /// Internal step counter used by `step` to decide when this scale fires
    tick_counter: usize,
    /// Current hidden state
    state: Array1<f32>,
}
157
158impl TemporalScale {
159    /// Create a new temporal scale
160    pub fn new(input_dim: usize, hidden_dim: usize, decimation: usize) -> ModelResult<Self> {
161        if input_dim == 0 || hidden_dim == 0 || decimation == 0 {
162            return Err(ModelError::invalid_config(
163                "TemporalScale dimensions and decimation must be > 0",
164            ));
165        }
166
167        let scale_input = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
168        let scale_rec = (2.0 / (hidden_dim + hidden_dim) as f32).sqrt();
169        let seed = ((input_dim + hidden_dim * 37 + decimation * 997) as u64)
170            .wrapping_mul(6364136223846793005);
171        let mut rng = SeededRng::new(seed);
172
173        let projection =
174            Array2::from_shape_fn((hidden_dim, input_dim), |_| rng.next_f32() * scale_input);
175        let recurrent =
176            Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale_rec);
177        let bias = Array1::from_shape_fn(hidden_dim, |_| rng.next_f32() * 0.01);
178
179        Ok(Self {
180            hidden_dim,
181            decimation,
182            projection,
183            recurrent,
184            bias,
185            tick_counter: 0,
186            state: Array1::zeros(hidden_dim),
187        })
188    }
189
190    /// Process one time step.
191    ///
192    /// Returns `Some(state)` when the scale updates (every `decimation` steps),
193    /// or `None` if this step is a no-op for this scale.
194    #[instrument(skip(self, input), fields(decimation = self.decimation, tick = self.tick_counter))]
195    pub fn step(&mut self, input: &Array1<f32>) -> ModelResult<Option<Array1<f32>>> {
196        self.tick_counter += 1;
197
198        if !self.tick_counter.is_multiple_of(self.decimation) {
199            return Ok(None);
200        }
201
202        // GRU-style update: state = tanh(W_proj @ input + W_rec @ state + bias)
203        let proj_out = self.projection.dot(input);
204        let rec_out = self.recurrent.dot(&self.state);
205        let pre_act = proj_out + rec_out + &self.bias;
206        let new_state = pre_act.mapv(f32::tanh);
207
208        if new_state.iter().any(|v| !v.is_finite()) {
209            return Err(ModelError::numerical_instability(
210                "TemporalScale::step",
211                "NaN or Inf in state update",
212            ));
213        }
214
215        self.state = new_state.clone();
216        Ok(Some(new_state))
217    }
218
219    /// Get the current hidden state (does not advance the tick counter)
220    pub fn current_state(&self) -> &Array1<f32> {
221        &self.state
222    }
223
224    /// Reset tick counter and hidden state to zero
225    pub fn reset(&mut self) {
226        self.tick_counter = 0;
227        self.state.fill(0.0);
228    }
229
230    /// Get hidden dimension
231    pub fn hidden_dim(&self) -> usize {
232        self.hidden_dim
233    }
234
235    /// Get decimation factor
236    pub fn decimation(&self) -> usize {
237        self.decimation
238    }
239}
240
// ---------------------------------------------------------------------------
// ScaleFusionLayer
// ---------------------------------------------------------------------------

/// Internal fusion module that combines outputs from multiple temporal scales.
///
/// Only the parameter fields relevant to the configured `fusion` strategy are
/// populated; the rest stay `None`.
struct ScaleFusionLayer {
    /// Which fusion strategy this layer implements
    fusion: ScaleFusion,
    /// For Concatenate: projects [num_scales * hidden_dim] → hidden_dim
    concat_proj: Option<Array2<f32>>,
    /// For Weighted: log-space unnormalized weights (one per scale)
    scale_weights: Option<Array1<f32>>,
    /// For Attention: query projection (hidden_dim, hidden_dim)
    attn_q: Option<Array2<f32>>,
    /// For Attention: key projection (hidden_dim, hidden_dim)
    attn_k: Option<Array2<f32>>,
    /// For Attention: value projection (hidden_dim, hidden_dim)
    attn_v: Option<Array2<f32>>,
    /// Expected number of scale states per `fuse` call
    num_scales: usize,
    /// Dimension of each scale state and of the fused output
    hidden_dim: usize,
}
261
262impl ScaleFusionLayer {
263    fn new(
264        fusion: ScaleFusion,
265        num_scales: usize,
266        hidden_dim: usize,
267        seed: u64,
268    ) -> ModelResult<Self> {
269        if num_scales == 0 || hidden_dim == 0 {
270            return Err(ModelError::invalid_config(
271                "ScaleFusionLayer: num_scales and hidden_dim must be > 0",
272            ));
273        }
274
275        let mut rng = SeededRng::new(seed);
276        let scale = (2.0 / (hidden_dim * 2) as f32).sqrt();
277
278        let (concat_proj, scale_weights, attn_q, attn_k, attn_v) = match &fusion {
279            ScaleFusion::Concatenate => {
280                let in_dim = num_scales * hidden_dim;
281                let proj_scale = (2.0 / (in_dim + hidden_dim) as f32).sqrt();
282                let proj =
283                    Array2::from_shape_fn((hidden_dim, in_dim), |_| rng.next_f32() * proj_scale);
284                (Some(proj), None, None, None, None)
285            }
286            ScaleFusion::Weighted => {
287                // Initialize log-weights to zero (uniform after softmax)
288                let weights = Array1::zeros(num_scales);
289                (None, Some(weights), None, None, None)
290            }
291            ScaleFusion::Attention => {
292                let q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale);
293                let k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale);
294                let v = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale);
295                (None, None, Some(q), Some(k), Some(v))
296            }
297        };
298
299        Ok(Self {
300            fusion,
301            concat_proj,
302            scale_weights,
303            attn_q,
304            attn_k,
305            attn_v,
306            num_scales,
307            hidden_dim,
308        })
309    }
310
311    /// Fuse scale states into a single hidden_dim vector.
312    fn fuse(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
313        if scale_states.len() != self.num_scales {
314            return Err(ModelError::dimension_mismatch(
315                "ScaleFusionLayer::fuse",
316                self.num_scales,
317                scale_states.len(),
318            ));
319        }
320
321        match &self.fusion {
322            ScaleFusion::Concatenate => self.fuse_concatenate(scale_states),
323            ScaleFusion::Weighted => self.fuse_weighted(scale_states),
324            ScaleFusion::Attention => self.fuse_attention(scale_states),
325        }
326    }
327
328    fn fuse_concatenate(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
329        let proj = self.concat_proj.as_ref().ok_or_else(|| {
330            ModelError::not_initialized("concat_proj missing for Concatenate fusion")
331        })?;
332
333        // Concatenate all states into a single vector
334        let total_dim = self.num_scales * self.hidden_dim;
335        let mut concat = Array1::<f32>::zeros(total_dim);
336        for (i, state) in scale_states.iter().enumerate() {
337            let start = i * self.hidden_dim;
338            let end = start + self.hidden_dim;
339            if state.len() != self.hidden_dim {
340                return Err(ModelError::dimension_mismatch(
341                    format!("scale {i} state"),
342                    self.hidden_dim,
343                    state.len(),
344                ));
345            }
346            concat
347                .slice_mut(scirs2_core::ndarray::s![start..end])
348                .assign(state);
349        }
350
351        Ok(proj.dot(&concat))
352    }
353
354    fn fuse_weighted(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
355        let log_weights = self.scale_weights.as_ref().ok_or_else(|| {
356            ModelError::not_initialized("scale_weights missing for Weighted fusion")
357        })?;
358
359        // Softmax normalization for numerical stability
360        let max_w = log_weights
361            .iter()
362            .cloned()
363            .fold(f32::NEG_INFINITY, f32::max);
364        let exp_w: Vec<f32> = log_weights.iter().map(|&w| (w - max_w).exp()).collect();
365        let sum_exp: f32 = exp_w.iter().sum();
366        let norm_weights: Vec<f32> = exp_w.iter().map(|&e| e / sum_exp).collect();
367
368        let mut result = Array1::<f32>::zeros(self.hidden_dim);
369        for (state, &w) in scale_states.iter().zip(norm_weights.iter()) {
370            if state.len() != self.hidden_dim {
371                return Err(ModelError::dimension_mismatch(
372                    "weighted scale state",
373                    self.hidden_dim,
374                    state.len(),
375                ));
376            }
377            result = result + state * w;
378        }
379        Ok(result)
380    }
381
382    fn fuse_attention(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
383        let q_proj = self
384            .attn_q
385            .as_ref()
386            .ok_or_else(|| ModelError::not_initialized("attn_q missing for Attention fusion"))?;
387        let k_proj = self
388            .attn_k
389            .as_ref()
390            .ok_or_else(|| ModelError::not_initialized("attn_k missing for Attention fusion"))?;
391        let v_proj = self
392            .attn_v
393            .as_ref()
394            .ok_or_else(|| ModelError::not_initialized("attn_v missing for Attention fusion"))?;
395
396        // Query: mean of all scale states
397        let mut mean_state = Array1::<f32>::zeros(self.hidden_dim);
398        for state in scale_states {
399            if state.len() != self.hidden_dim {
400                return Err(ModelError::dimension_mismatch(
401                    "attention scale state",
402                    self.hidden_dim,
403                    state.len(),
404                ));
405            }
406            mean_state += state;
407        }
408        mean_state.mapv_inplace(|v| v / self.num_scales as f32);
409
410        let query = q_proj.dot(&mean_state); // (hidden_dim,)
411        let scale_factor = (self.hidden_dim as f32).sqrt();
412
413        // Compute attention scores: q · k_i / sqrt(d)
414        let mut scores = Vec::with_capacity(self.num_scales);
415        for state in scale_states {
416            let key_i = k_proj.dot(state);
417            let score = query.dot(&key_i) / scale_factor;
418            scores.push(score);
419        }
420
421        // Softmax over scores
422        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
423        let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
424        let sum_exp: f32 = exp_scores.iter().sum();
425        let attn_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
426
427        // Weighted sum of values
428        let mut result = Array1::<f32>::zeros(self.hidden_dim);
429        for (state, &w) in scale_states.iter().zip(attn_weights.iter()) {
430            let value_i = v_proj.dot(state);
431            result = result + value_i * w;
432        }
433        Ok(result)
434    }
435}
436
// ---------------------------------------------------------------------------
// MultiScaleModel
// ---------------------------------------------------------------------------

/// Multi-Scale Temporal Model that processes signals at multiple resolutions.
///
/// Each input step is fanned out to every [`TemporalScale`]; their (possibly
/// stale) outputs are combined by the fusion layer and linearly projected.
pub struct MultiScaleModel {
    /// Model configuration
    pub config: MultiScaleConfig,
    /// Per-scale temporal processors
    scales: Vec<TemporalScale>,
    /// Fusion layer for combining scale outputs
    fusion_layer: ScaleFusionLayer,
    /// Output projection: (output_dim, hidden_dim)
    output_proj: Array2<f32>,
    /// Output bias
    output_bias: Array1<f32>,
    /// Most recent output from each scale (initialized to zeros)
    last_scale_outputs: Vec<Array1<f32>>,
}
456
457impl MultiScaleModel {
458    /// Create a new multi-scale model
459    #[instrument(skip(config), fields(scales = config.num_scales, hidden = config.hidden_dim))]
460    pub fn new(config: MultiScaleConfig) -> ModelResult<Self> {
461        config.validate()?;
462        debug!(
463            "Building MultiScaleModel: {} scales at {:?}",
464            config.num_scales, config.scale_factors
465        );
466
467        let mut scales = Vec::with_capacity(config.num_scales);
468        for (i, &decimation) in config.scale_factors.iter().enumerate() {
469            let seed = ((i + 1) as u64).wrapping_mul(6364136223846793005);
470            let _ = seed; // seed baked into TemporalScale::new via dimension-based seeding
471            scales.push(TemporalScale::new(
472                config.input_dim,
473                config.hidden_dim,
474                decimation,
475            )?);
476        }
477
478        let fusion_seed = (config.num_scales as u64 * 1000 + config.hidden_dim as u64)
479            .wrapping_mul(2862933555777941757);
480        let fusion_layer = ScaleFusionLayer::new(
481            config.fusion.clone(),
482            config.num_scales,
483            config.hidden_dim,
484            fusion_seed,
485        )?;
486
487        let out_scale = (2.0 / (config.hidden_dim + config.output_dim) as f32).sqrt();
488        let mut rng = SeededRng::new(
489            ((config.hidden_dim * 7919 + config.output_dim) as u64)
490                .wrapping_mul(6364136223846793005),
491        );
492        let output_proj = Array2::from_shape_fn((config.output_dim, config.hidden_dim), |_| {
493            rng.next_f32() * out_scale
494        });
495        let output_bias = Array1::from_shape_fn(config.output_dim, |_| rng.next_f32() * 0.01);
496
497        let last_scale_outputs = vec![Array1::zeros(config.hidden_dim); config.num_scales];
498
499        debug!("MultiScaleModel built successfully");
500        Ok(Self {
501            config,
502            scales,
503            fusion_layer,
504            output_proj,
505            output_bias,
506            last_scale_outputs,
507        })
508    }
509
510    /// Small preset: 3 scales at [1, 4, 16], input/output dim 1
511    pub fn small() -> ModelResult<Self> {
512        let config = MultiScaleConfig {
513            input_dim: 1,
514            hidden_dim: 32,
515            output_dim: 1,
516            num_scales: 3,
517            scale_factors: vec![1, 4, 16],
518            fusion: ScaleFusion::Concatenate,
519            context_length: 512,
520        };
521        Self::new(config)
522    }
523
524    /// Base preset: 4 scales at [1, 2, 8, 32], input/output dim 1
525    pub fn base() -> ModelResult<Self> {
526        let config = MultiScaleConfig {
527            input_dim: 1,
528            hidden_dim: 64,
529            output_dim: 1,
530            num_scales: 4,
531            scale_factors: vec![1, 2, 8, 32],
532            fusion: ScaleFusion::Weighted,
533            context_length: 2048,
534        };
535        Self::new(config)
536    }
537
538    /// Internal forward step
539    fn forward_step(&mut self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
540        if input.len() != self.config.input_dim {
541            return Err(ModelError::dimension_mismatch(
542                "MultiScaleModel input",
543                self.config.input_dim,
544                input.len(),
545            ));
546        }
547
548        // Step each scale; update last_scale_outputs when scale fires
549        for (i, scale) in self.scales.iter_mut().enumerate() {
550            if let Some(new_state) = scale.step(input)? {
551                self.last_scale_outputs[i] = new_state;
552            }
553        }
554
555        // Fuse all (possibly stale) scale outputs
556        let fused = self.fusion_layer.fuse(&self.last_scale_outputs)?;
557
558        // Output projection
559        let output = self.output_proj.dot(&fused) + &self.output_bias;
560
561        if output.iter().any(|v| !v.is_finite()) {
562            return Err(ModelError::numerical_instability(
563                "MultiScaleModel output",
564                "NaN or Inf detected",
565            ));
566        }
567
568        Ok(output)
569    }
570}
571
572impl SignalPredictor for MultiScaleModel {
573    #[instrument(skip(self, input))]
574    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
575        self.forward_step(input)
576            .map_err(|e| kizzasi_core::CoreError::Generic(e.to_string()))
577    }
578
579    #[instrument(skip(self))]
580    fn reset(&mut self) {
581        debug!("Resetting MultiScaleModel state");
582        for scale in &mut self.scales {
583            scale.reset();
584        }
585        for output in &mut self.last_scale_outputs {
586            output.fill(0.0);
587        }
588    }
589
590    fn context_window(&self) -> usize {
591        self.config.context_length
592    }
593}
594
595impl AutoregressiveModel for MultiScaleModel {
596    fn hidden_dim(&self) -> usize {
597        self.config.hidden_dim
598    }
599
600    fn state_dim(&self) -> usize {
601        // Total state = hidden_dim per scale
602        self.config.hidden_dim * self.config.num_scales
603    }
604
605    fn num_layers(&self) -> usize {
606        self.config.num_scales
607    }
608
609    fn model_type(&self) -> ModelType {
610        ModelType::MultiScale
611    }
612
613    fn get_states(&self) -> Vec<HiddenState> {
614        self.scales
615            .iter()
616            .map(|scale| {
617                let state = scale.current_state().clone();
618                let dim = state.len();
619                let state_2d = state.insert_axis(scirs2_core::ndarray::Axis(0));
620                let mut hidden = HiddenState::new(dim, 1);
621                hidden.update(state_2d);
622                hidden
623            })
624            .collect()
625    }
626
627    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
628        if states.len() != self.config.num_scales {
629            return Err(ModelError::state_count_mismatch(
630                "MultiScale",
631                self.config.num_scales,
632                states.len(),
633            ));
634        }
635        for (scale, hidden) in self.scales.iter_mut().zip(states.iter()) {
636            let state_2d = hidden.state();
637            if state_2d.nrows() > 0 && state_2d.ncols() > 0 {
638                scale.state = state_2d.row(0).to_owned();
639            }
640        }
641        Ok(())
642    }
643}
644
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Shared config: 3 scales at [1, 2, 4], dims 4→8→4, Concatenate fusion.
    fn make_concat_config() -> MultiScaleConfig {
        MultiScaleConfig {
            input_dim: 4,
            hidden_dim: 8,
            output_dim: 4,
            num_scales: 3,
            scale_factors: vec![1, 2, 4],
            fusion: ScaleFusion::Concatenate,
            context_length: 64,
        }
    }

    /// Shared config: 3 scales at [1, 2, 4], dims 4→8→4, Weighted fusion.
    fn make_weighted_config() -> MultiScaleConfig {
        MultiScaleConfig {
            input_dim: 4,
            hidden_dim: 8,
            output_dim: 4,
            num_scales: 3,
            scale_factors: vec![1, 2, 4],
            fusion: ScaleFusion::Weighted,
            context_length: 64,
        }
    }

    /// Shared config: 3 scales at [1, 2, 4], dims 4→8→4, Attention fusion.
    fn make_attention_config() -> MultiScaleConfig {
        MultiScaleConfig {
            input_dim: 4,
            hidden_dim: 8,
            output_dim: 4,
            num_scales: 3,
            scale_factors: vec![1, 2, 4],
            fusion: ScaleFusion::Attention,
            context_length: 64,
        }
    }

    // 9. test_temporal_scale_decimation
    // A scale with decimation=4 must be silent on steps 1-3 and fire on step 4.
    #[test]
    fn test_temporal_scale_decimation() {
        let decimation = 4;
        let mut scale =
            TemporalScale::new(4, 8, decimation).expect("TemporalScale creation failed");

        let input = Array1::from_vec(vec![1.0_f32; 4]);

        // Steps 1, 2, 3 should return None (not a multiple of 4)
        let r1 = scale.step(&input).expect("step 1 failed");
        let r2 = scale.step(&input).expect("step 2 failed");
        let r3 = scale.step(&input).expect("step 3 failed");
        assert!(r1.is_none(), "step 1 should be None");
        assert!(r2.is_none(), "step 2 should be None");
        assert!(r3.is_none(), "step 3 should be None");

        // Step 4 should return Some(state)
        let r4 = scale.step(&input).expect("step 4 failed");
        assert!(r4.is_some(), "step 4 should return Some(state)");
        assert_eq!(r4.as_ref().map(|s| s.len()), Some(8));
    }

    // 10. test_temporal_scale_continuous_state
    // With decimation=1 every step fires, and the state persists across steps.
    #[test]
    fn test_temporal_scale_continuous_state() {
        let mut scale = TemporalScale::new(4, 8, 1).expect("TemporalScale creation failed");

        let input = Array1::from_vec(vec![0.5_f32; 4]);

        let r1 = scale.step(&input).expect("step 1 failed");
        assert!(r1.is_some(), "decimation=1 should always return Some");

        let state_after_step1 = scale.current_state().clone();

        let r2 = scale.step(&input).expect("step 2 failed");
        assert!(r2.is_some());

        let state_after_step2 = scale.current_state().clone();

        // With non-zero input, state should change between steps
        let diff: f32 = (&state_after_step2 - &state_after_step1)
            .iter()
            .map(|v| v.abs())
            .sum();
        // State may or may not change (could converge), but state persists
        assert!(state_after_step1.len() == 8 && state_after_step2.len() == 8);
        let _ = diff; // diff may be 0 if converged, that's OK
    }

    // 11. test_multiscale_small
    // The small preset builds and produces a finite 1-D output.
    #[test]
    fn test_multiscale_small() {
        let mut model = MultiScaleModel::small().expect("small model creation failed");

        let input = Array1::from_vec(vec![0.3_f32; 1]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), 1);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    // 12. test_multiscale_base
    // The base preset stays finite over multiple consecutive steps.
    #[test]
    fn test_multiscale_base() {
        let mut model = MultiScaleModel::base().expect("base model creation failed");

        let input = Array1::from_vec(vec![0.1_f32; 1]);
        for _ in 0..10 {
            let output = model.forward_step(&input).expect("forward failed");
            assert_eq!(output.len(), 1);
            assert!(output.iter().all(|v| v.is_finite()));
        }
    }

    // 13. test_multiscale_fusion_concat
    #[test]
    fn test_multiscale_fusion_concat() {
        let config = make_concat_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.5_f32; 4]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    // 14. test_multiscale_fusion_weighted
    #[test]
    fn test_multiscale_fusion_weighted() {
        let config = make_weighted_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.5_f32; 4]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    // 14b. test_multiscale_fusion_attention
    #[test]
    fn test_multiscale_fusion_attention() {
        let config = make_attention_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.5_f32; 4]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    // 15. test_multiscale_signal_predictor
    // Exercises the SignalPredictor trait entry point rather than forward_step.
    #[test]
    fn test_multiscale_signal_predictor() {
        let config = make_concat_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.2_f32; 4]);
        let output = model.step(&input).expect("SignalPredictor::step failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    // 16. test_multiscale_numerical_stability
    // Zero, large, and tiny inputs must either produce finite outputs or
    // surface as an explicit NumericalInstability error.
    #[test]
    fn test_multiscale_numerical_stability() {
        let config = make_weighted_config();
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        // Test with zero input
        let zero_input = Array1::zeros(4);
        let out_zero = model.forward_step(&zero_input).expect("zero input failed");
        assert!(
            out_zero.iter().all(|v| v.is_finite()),
            "zero input should produce finite output"
        );

        // Test with moderate large input
        let large_input = Array1::from_vec(vec![100.0_f32; 4]);
        let out_large = model.forward_step(&large_input);
        match out_large {
            Ok(o) => assert!(
                o.iter().all(|v| v.is_finite()),
                "large input should produce finite output"
            ),
            Err(ModelError::NumericalInstability { .. }) => {
                // acceptable for extreme inputs
            }
            Err(e) => panic!("unexpected error: {e}"),
        }

        // Test with very small input
        let tiny_input = Array1::from_vec(vec![1e-30_f32; 4]);
        let out_tiny = model.forward_step(&tiny_input).expect("tiny input failed");
        assert!(
            out_tiny.iter().all(|v| v.is_finite()),
            "tiny input should produce finite output"
        );
    }

    // AutoregressiveModel trait test
    // Checks the reported dimensions and the state snapshot count.
    #[test]
    fn test_multiscale_autoregressive_model() {
        let config = make_concat_config();
        let model = MultiScaleModel::new(config).expect("model creation failed");

        assert_eq!(model.model_type(), ModelType::MultiScale);
        assert_eq!(model.num_layers(), 3);
        assert_eq!(model.hidden_dim(), 8);
        assert_eq!(model.state_dim(), 24); // 8 * 3

        let states = model.get_states();
        assert_eq!(states.len(), 3);
    }
}
871}