1use crate::error::{ModelError, ModelResult};
37use crate::{AutoregressiveModel, ModelType};
38use kizzasi_core::{CoreResult, HiddenState, SignalPredictor};
39use scirs2_core::ndarray::{Array1, Array2};
40use serde::{Deserialize, Serialize};
41
42#[allow(unused_imports)]
43use tracing::{debug, instrument, trace};
44
/// Minimal deterministic xorshift64 generator used for reproducible weight
/// initialisation without an external RNG dependency.
struct SeededRng {
    /// Current generator state; never zero, because zero is a fixed point of
    /// the xorshift update.
    state: u64,
}

impl SeededRng {
    /// Creates a generator from `seed`, promoting a zero seed to 1.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self { state }
    }

    /// Advances the generator (shift triple 13 / 7 / 17) and maps the new state
    /// linearly onto `[-1.0, 1.0]`.
    fn next_f32(&mut self) -> f32 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;

        let unit = x as f64 / u64::MAX as f64;
        (unit * 2.0 - 1.0) as f32
    }
}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
72pub enum ScaleFusion {
73 Concatenate,
75 Weighted,
77 Attention,
79}
80
/// Configuration for [`MultiScaleModel`].
///
/// Checked by [`MultiScaleConfig::validate`] before a model is built.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiScaleConfig {
    /// Length of each input sample vector; must be > 0.
    pub input_dim: usize,
    /// Width of every scale's recurrent state and of the fused vector; must be > 0.
    pub hidden_dim: usize,
    /// Length of each prediction vector; must be > 0.
    pub output_dim: usize,
    /// Number of temporal scales; must be > 0 and equal `scale_factors.len()`.
    pub num_scales: usize,
    /// Decimation factor per scale: scale `i` updates its state once every
    /// `scale_factors[i]` input ticks. Every entry must be > 0.
    pub scale_factors: Vec<usize>,
    /// How the per-scale states are combined.
    pub fusion: ScaleFusion,
    /// Value reported by `SignalPredictor::context_window`.
    // NOTE(review): not otherwise read by the model code visible here — confirm
    // whether anything else depends on it.
    pub context_length: usize,
}
104
105impl MultiScaleConfig {
106 pub fn validate(&self) -> ModelResult<()> {
108 if self.input_dim == 0 {
109 return Err(ModelError::invalid_config("input_dim must be > 0"));
110 }
111 if self.hidden_dim == 0 {
112 return Err(ModelError::invalid_config("hidden_dim must be > 0"));
113 }
114 if self.output_dim == 0 {
115 return Err(ModelError::invalid_config("output_dim must be > 0"));
116 }
117 if self.num_scales == 0 {
118 return Err(ModelError::invalid_config("num_scales must be > 0"));
119 }
120 if self.scale_factors.len() != self.num_scales {
121 return Err(ModelError::invalid_config(
122 "scale_factors.len() must equal num_scales",
123 ));
124 }
125 for &sf in &self.scale_factors {
126 if sf == 0 {
127 return Err(ModelError::invalid_config("all scale_factors must be > 0"));
128 }
129 }
130 Ok(())
131 }
132}
133
134pub struct TemporalScale {
143 hidden_dim: usize,
144 decimation: usize,
146 projection: Array2<f32>,
148 recurrent: Array2<f32>,
150 bias: Array1<f32>,
152 tick_counter: usize,
154 state: Array1<f32>,
156}
157
158impl TemporalScale {
159 pub fn new(input_dim: usize, hidden_dim: usize, decimation: usize) -> ModelResult<Self> {
161 if input_dim == 0 || hidden_dim == 0 || decimation == 0 {
162 return Err(ModelError::invalid_config(
163 "TemporalScale dimensions and decimation must be > 0",
164 ));
165 }
166
167 let scale_input = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
168 let scale_rec = (2.0 / (hidden_dim + hidden_dim) as f32).sqrt();
169 let seed = ((input_dim + hidden_dim * 37 + decimation * 997) as u64)
170 .wrapping_mul(6364136223846793005);
171 let mut rng = SeededRng::new(seed);
172
173 let projection =
174 Array2::from_shape_fn((hidden_dim, input_dim), |_| rng.next_f32() * scale_input);
175 let recurrent =
176 Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale_rec);
177 let bias = Array1::from_shape_fn(hidden_dim, |_| rng.next_f32() * 0.01);
178
179 Ok(Self {
180 hidden_dim,
181 decimation,
182 projection,
183 recurrent,
184 bias,
185 tick_counter: 0,
186 state: Array1::zeros(hidden_dim),
187 })
188 }
189
190 #[instrument(skip(self, input), fields(decimation = self.decimation, tick = self.tick_counter))]
195 pub fn step(&mut self, input: &Array1<f32>) -> ModelResult<Option<Array1<f32>>> {
196 self.tick_counter += 1;
197
198 if !self.tick_counter.is_multiple_of(self.decimation) {
199 return Ok(None);
200 }
201
202 let proj_out = self.projection.dot(input);
204 let rec_out = self.recurrent.dot(&self.state);
205 let pre_act = proj_out + rec_out + &self.bias;
206 let new_state = pre_act.mapv(f32::tanh);
207
208 if new_state.iter().any(|v| !v.is_finite()) {
209 return Err(ModelError::numerical_instability(
210 "TemporalScale::step",
211 "NaN or Inf in state update",
212 ));
213 }
214
215 self.state = new_state.clone();
216 Ok(Some(new_state))
217 }
218
219 pub fn current_state(&self) -> &Array1<f32> {
221 &self.state
222 }
223
224 pub fn reset(&mut self) {
226 self.tick_counter = 0;
227 self.state.fill(0.0);
228 }
229
230 pub fn hidden_dim(&self) -> usize {
232 self.hidden_dim
233 }
234
235 pub fn decimation(&self) -> usize {
237 self.decimation
238 }
239}
240
241struct ScaleFusionLayer {
247 fusion: ScaleFusion,
248 concat_proj: Option<Array2<f32>>,
250 scale_weights: Option<Array1<f32>>,
252 attn_q: Option<Array2<f32>>,
254 attn_k: Option<Array2<f32>>,
256 attn_v: Option<Array2<f32>>,
258 num_scales: usize,
259 hidden_dim: usize,
260}
261
262impl ScaleFusionLayer {
263 fn new(
264 fusion: ScaleFusion,
265 num_scales: usize,
266 hidden_dim: usize,
267 seed: u64,
268 ) -> ModelResult<Self> {
269 if num_scales == 0 || hidden_dim == 0 {
270 return Err(ModelError::invalid_config(
271 "ScaleFusionLayer: num_scales and hidden_dim must be > 0",
272 ));
273 }
274
275 let mut rng = SeededRng::new(seed);
276 let scale = (2.0 / (hidden_dim * 2) as f32).sqrt();
277
278 let (concat_proj, scale_weights, attn_q, attn_k, attn_v) = match &fusion {
279 ScaleFusion::Concatenate => {
280 let in_dim = num_scales * hidden_dim;
281 let proj_scale = (2.0 / (in_dim + hidden_dim) as f32).sqrt();
282 let proj =
283 Array2::from_shape_fn((hidden_dim, in_dim), |_| rng.next_f32() * proj_scale);
284 (Some(proj), None, None, None, None)
285 }
286 ScaleFusion::Weighted => {
287 let weights = Array1::zeros(num_scales);
289 (None, Some(weights), None, None, None)
290 }
291 ScaleFusion::Attention => {
292 let q = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale);
293 let k = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale);
294 let v = Array2::from_shape_fn((hidden_dim, hidden_dim), |_| rng.next_f32() * scale);
295 (None, None, Some(q), Some(k), Some(v))
296 }
297 };
298
299 Ok(Self {
300 fusion,
301 concat_proj,
302 scale_weights,
303 attn_q,
304 attn_k,
305 attn_v,
306 num_scales,
307 hidden_dim,
308 })
309 }
310
311 fn fuse(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
313 if scale_states.len() != self.num_scales {
314 return Err(ModelError::dimension_mismatch(
315 "ScaleFusionLayer::fuse",
316 self.num_scales,
317 scale_states.len(),
318 ));
319 }
320
321 match &self.fusion {
322 ScaleFusion::Concatenate => self.fuse_concatenate(scale_states),
323 ScaleFusion::Weighted => self.fuse_weighted(scale_states),
324 ScaleFusion::Attention => self.fuse_attention(scale_states),
325 }
326 }
327
328 fn fuse_concatenate(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
329 let proj = self.concat_proj.as_ref().ok_or_else(|| {
330 ModelError::not_initialized("concat_proj missing for Concatenate fusion")
331 })?;
332
333 let total_dim = self.num_scales * self.hidden_dim;
335 let mut concat = Array1::<f32>::zeros(total_dim);
336 for (i, state) in scale_states.iter().enumerate() {
337 let start = i * self.hidden_dim;
338 let end = start + self.hidden_dim;
339 if state.len() != self.hidden_dim {
340 return Err(ModelError::dimension_mismatch(
341 format!("scale {i} state"),
342 self.hidden_dim,
343 state.len(),
344 ));
345 }
346 concat
347 .slice_mut(scirs2_core::ndarray::s![start..end])
348 .assign(state);
349 }
350
351 Ok(proj.dot(&concat))
352 }
353
354 fn fuse_weighted(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
355 let log_weights = self.scale_weights.as_ref().ok_or_else(|| {
356 ModelError::not_initialized("scale_weights missing for Weighted fusion")
357 })?;
358
359 let max_w = log_weights
361 .iter()
362 .cloned()
363 .fold(f32::NEG_INFINITY, f32::max);
364 let exp_w: Vec<f32> = log_weights.iter().map(|&w| (w - max_w).exp()).collect();
365 let sum_exp: f32 = exp_w.iter().sum();
366 let norm_weights: Vec<f32> = exp_w.iter().map(|&e| e / sum_exp).collect();
367
368 let mut result = Array1::<f32>::zeros(self.hidden_dim);
369 for (state, &w) in scale_states.iter().zip(norm_weights.iter()) {
370 if state.len() != self.hidden_dim {
371 return Err(ModelError::dimension_mismatch(
372 "weighted scale state",
373 self.hidden_dim,
374 state.len(),
375 ));
376 }
377 result = result + state * w;
378 }
379 Ok(result)
380 }
381
382 fn fuse_attention(&self, scale_states: &[Array1<f32>]) -> ModelResult<Array1<f32>> {
383 let q_proj = self
384 .attn_q
385 .as_ref()
386 .ok_or_else(|| ModelError::not_initialized("attn_q missing for Attention fusion"))?;
387 let k_proj = self
388 .attn_k
389 .as_ref()
390 .ok_or_else(|| ModelError::not_initialized("attn_k missing for Attention fusion"))?;
391 let v_proj = self
392 .attn_v
393 .as_ref()
394 .ok_or_else(|| ModelError::not_initialized("attn_v missing for Attention fusion"))?;
395
396 let mut mean_state = Array1::<f32>::zeros(self.hidden_dim);
398 for state in scale_states {
399 if state.len() != self.hidden_dim {
400 return Err(ModelError::dimension_mismatch(
401 "attention scale state",
402 self.hidden_dim,
403 state.len(),
404 ));
405 }
406 mean_state += state;
407 }
408 mean_state.mapv_inplace(|v| v / self.num_scales as f32);
409
410 let query = q_proj.dot(&mean_state); let scale_factor = (self.hidden_dim as f32).sqrt();
412
413 let mut scores = Vec::with_capacity(self.num_scales);
415 for state in scale_states {
416 let key_i = k_proj.dot(state);
417 let score = query.dot(&key_i) / scale_factor;
418 scores.push(score);
419 }
420
421 let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
423 let exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
424 let sum_exp: f32 = exp_scores.iter().sum();
425 let attn_weights: Vec<f32> = exp_scores.iter().map(|&e| e / sum_exp).collect();
426
427 let mut result = Array1::<f32>::zeros(self.hidden_dim);
429 for (state, &w) in scale_states.iter().zip(attn_weights.iter()) {
430 let value_i = v_proj.dot(state);
431 result = result + value_i * w;
432 }
433 Ok(result)
434 }
435}
436
437pub struct MultiScaleModel {
443 pub config: MultiScaleConfig,
445 scales: Vec<TemporalScale>,
447 fusion_layer: ScaleFusionLayer,
449 output_proj: Array2<f32>,
451 output_bias: Array1<f32>,
453 last_scale_outputs: Vec<Array1<f32>>,
455}
456
457impl MultiScaleModel {
458 #[instrument(skip(config), fields(scales = config.num_scales, hidden = config.hidden_dim))]
460 pub fn new(config: MultiScaleConfig) -> ModelResult<Self> {
461 config.validate()?;
462 debug!(
463 "Building MultiScaleModel: {} scales at {:?}",
464 config.num_scales, config.scale_factors
465 );
466
467 let mut scales = Vec::with_capacity(config.num_scales);
468 for (i, &decimation) in config.scale_factors.iter().enumerate() {
469 let seed = ((i + 1) as u64).wrapping_mul(6364136223846793005);
470 let _ = seed; scales.push(TemporalScale::new(
472 config.input_dim,
473 config.hidden_dim,
474 decimation,
475 )?);
476 }
477
478 let fusion_seed = (config.num_scales as u64 * 1000 + config.hidden_dim as u64)
479 .wrapping_mul(2862933555777941757);
480 let fusion_layer = ScaleFusionLayer::new(
481 config.fusion.clone(),
482 config.num_scales,
483 config.hidden_dim,
484 fusion_seed,
485 )?;
486
487 let out_scale = (2.0 / (config.hidden_dim + config.output_dim) as f32).sqrt();
488 let mut rng = SeededRng::new(
489 ((config.hidden_dim * 7919 + config.output_dim) as u64)
490 .wrapping_mul(6364136223846793005),
491 );
492 let output_proj = Array2::from_shape_fn((config.output_dim, config.hidden_dim), |_| {
493 rng.next_f32() * out_scale
494 });
495 let output_bias = Array1::from_shape_fn(config.output_dim, |_| rng.next_f32() * 0.01);
496
497 let last_scale_outputs = vec![Array1::zeros(config.hidden_dim); config.num_scales];
498
499 debug!("MultiScaleModel built successfully");
500 Ok(Self {
501 config,
502 scales,
503 fusion_layer,
504 output_proj,
505 output_bias,
506 last_scale_outputs,
507 })
508 }
509
510 pub fn small() -> ModelResult<Self> {
512 let config = MultiScaleConfig {
513 input_dim: 1,
514 hidden_dim: 32,
515 output_dim: 1,
516 num_scales: 3,
517 scale_factors: vec![1, 4, 16],
518 fusion: ScaleFusion::Concatenate,
519 context_length: 512,
520 };
521 Self::new(config)
522 }
523
524 pub fn base() -> ModelResult<Self> {
526 let config = MultiScaleConfig {
527 input_dim: 1,
528 hidden_dim: 64,
529 output_dim: 1,
530 num_scales: 4,
531 scale_factors: vec![1, 2, 8, 32],
532 fusion: ScaleFusion::Weighted,
533 context_length: 2048,
534 };
535 Self::new(config)
536 }
537
538 fn forward_step(&mut self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
540 if input.len() != self.config.input_dim {
541 return Err(ModelError::dimension_mismatch(
542 "MultiScaleModel input",
543 self.config.input_dim,
544 input.len(),
545 ));
546 }
547
548 for (i, scale) in self.scales.iter_mut().enumerate() {
550 if let Some(new_state) = scale.step(input)? {
551 self.last_scale_outputs[i] = new_state;
552 }
553 }
554
555 let fused = self.fusion_layer.fuse(&self.last_scale_outputs)?;
557
558 let output = self.output_proj.dot(&fused) + &self.output_bias;
560
561 if output.iter().any(|v| !v.is_finite()) {
562 return Err(ModelError::numerical_instability(
563 "MultiScaleModel output",
564 "NaN or Inf detected",
565 ));
566 }
567
568 Ok(output)
569 }
570}
571
572impl SignalPredictor for MultiScaleModel {
573 #[instrument(skip(self, input))]
574 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
575 self.forward_step(input)
576 .map_err(|e| kizzasi_core::CoreError::Generic(e.to_string()))
577 }
578
579 #[instrument(skip(self))]
580 fn reset(&mut self) {
581 debug!("Resetting MultiScaleModel state");
582 for scale in &mut self.scales {
583 scale.reset();
584 }
585 for output in &mut self.last_scale_outputs {
586 output.fill(0.0);
587 }
588 }
589
590 fn context_window(&self) -> usize {
591 self.config.context_length
592 }
593}
594
595impl AutoregressiveModel for MultiScaleModel {
596 fn hidden_dim(&self) -> usize {
597 self.config.hidden_dim
598 }
599
600 fn state_dim(&self) -> usize {
601 self.config.hidden_dim * self.config.num_scales
603 }
604
605 fn num_layers(&self) -> usize {
606 self.config.num_scales
607 }
608
609 fn model_type(&self) -> ModelType {
610 ModelType::MultiScale
611 }
612
613 fn get_states(&self) -> Vec<HiddenState> {
614 self.scales
615 .iter()
616 .map(|scale| {
617 let state = scale.current_state().clone();
618 let dim = state.len();
619 let state_2d = state.insert_axis(scirs2_core::ndarray::Axis(0));
620 let mut hidden = HiddenState::new(dim, 1);
621 hidden.update(state_2d);
622 hidden
623 })
624 .collect()
625 }
626
627 fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
628 if states.len() != self.config.num_scales {
629 return Err(ModelError::state_count_mismatch(
630 "MultiScale",
631 self.config.num_scales,
632 states.len(),
633 ));
634 }
635 for (scale, hidden) in self.scales.iter_mut().zip(states.iter()) {
636 let state_2d = hidden.state();
637 if state_2d.nrows() > 0 && state_2d.ncols() > 0 {
638 scale.state = state_2d.row(0).to_owned();
639 }
640 }
641 Ok(())
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared test configuration: 4-dim input and output, 8-dim hidden state,
    /// three scales updating every 1, 2 and 4 ticks.
    fn make_config(fusion: ScaleFusion) -> MultiScaleConfig {
        MultiScaleConfig {
            input_dim: 4,
            hidden_dim: 8,
            output_dim: 4,
            num_scales: 3,
            scale_factors: vec![1, 2, 4],
            fusion,
            context_length: 64,
        }
    }

    fn make_concat_config() -> MultiScaleConfig {
        make_config(ScaleFusion::Concatenate)
    }

    fn make_weighted_config() -> MultiScaleConfig {
        make_config(ScaleFusion::Weighted)
    }

    fn make_attention_config() -> MultiScaleConfig {
        make_config(ScaleFusion::Attention)
    }

    #[test]
    fn test_temporal_scale_decimation() {
        let decimation = 4;
        let mut scale =
            TemporalScale::new(4, 8, decimation).expect("TemporalScale creation failed");

        let input = Array1::from_vec(vec![1.0_f32; 4]);

        let r1 = scale.step(&input).expect("step 1 failed");
        let r2 = scale.step(&input).expect("step 2 failed");
        let r3 = scale.step(&input).expect("step 3 failed");
        assert!(r1.is_none(), "step 1 should be None");
        assert!(r2.is_none(), "step 2 should be None");
        assert!(r3.is_none(), "step 3 should be None");

        let r4 = scale.step(&input).expect("step 4 failed");
        assert!(r4.is_some(), "step 4 should return Some(state)");
        assert_eq!(r4.as_ref().map(|s| s.len()), Some(8));
    }

    #[test]
    fn test_temporal_scale_continuous_state() {
        let mut scale = TemporalScale::new(4, 8, 1).expect("TemporalScale creation failed");
        let input = Array1::from_vec(vec![0.5_f32; 4]);

        let r1 = scale.step(&input).expect("step 1 failed");
        assert!(r1.is_some(), "decimation=1 should always return Some");
        let state_after_step1 = scale.current_state().clone();

        let r2 = scale.step(&input).expect("step 2 failed");
        assert!(r2.is_some());
        let state_after_step2 = scale.current_state().clone();

        assert_eq!(state_after_step1.len(), 8);
        assert_eq!(state_after_step2.len(), 8);

        // Both steps see the same input, so any change must come from the
        // recurrent term. That proves the state carries over between updates.
        let diff: f32 = (&state_after_step2 - &state_after_step1)
            .iter()
            .map(|v| v.abs())
            .sum();
        assert!(
            diff > 0.0,
            "recurrent state should evolve between consecutive updates"
        );
    }

    #[test]
    fn test_multiscale_small() {
        let mut model = MultiScaleModel::small().expect("small model creation failed");

        let input = Array1::from_vec(vec![0.3_f32; 1]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), 1);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_multiscale_base() {
        let mut model = MultiScaleModel::base().expect("base model creation failed");

        let input = Array1::from_vec(vec![0.1_f32; 1]);
        for _ in 0..10 {
            let output = model.forward_step(&input).expect("forward failed");
            assert_eq!(output.len(), 1);
            assert!(output.iter().all(|v| v.is_finite()));
        }
    }

    #[test]
    fn test_multiscale_fusion_concat() {
        let config = make_concat_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.5_f32; 4]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_multiscale_fusion_weighted() {
        let config = make_weighted_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.5_f32; 4]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_multiscale_fusion_attention() {
        let config = make_attention_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.5_f32; 4]);
        let output = model.forward_step(&input).expect("forward failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_multiscale_signal_predictor() {
        let config = make_concat_config();
        let output_dim = config.output_dim;
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let input = Array1::from_vec(vec![0.2_f32; 4]);
        let output = model.step(&input).expect("SignalPredictor::step failed");

        assert_eq!(output.len(), output_dim);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_multiscale_numerical_stability() {
        let config = make_weighted_config();
        let mut model = MultiScaleModel::new(config).expect("model creation failed");

        let zero_input = Array1::zeros(4);
        let out_zero = model.forward_step(&zero_input).expect("zero input failed");
        assert!(
            out_zero.iter().all(|v| v.is_finite()),
            "zero input should produce finite output"
        );

        // tanh saturates, so large inputs should stay finite; an explicit
        // numerical-instability error is also acceptable.
        let large_input = Array1::from_vec(vec![100.0_f32; 4]);
        let out_large = model.forward_step(&large_input);
        match out_large {
            Ok(o) => assert!(
                o.iter().all(|v| v.is_finite()),
                "large input should produce finite output"
            ),
            Err(ModelError::NumericalInstability { .. }) => {}
            Err(e) => panic!("unexpected error: {e}"),
        }

        let tiny_input = Array1::from_vec(vec![1e-30_f32; 4]);
        let out_tiny = model.forward_step(&tiny_input).expect("tiny input failed");
        assert!(
            out_tiny.iter().all(|v| v.is_finite()),
            "tiny input should produce finite output"
        );
    }

    #[test]
    fn test_multiscale_autoregressive_model() {
        let config = make_concat_config();
        let model = MultiScaleModel::new(config).expect("model creation failed");

        assert_eq!(model.model_type(), ModelType::MultiScale);
        assert_eq!(model.num_layers(), 3);
        assert_eq!(model.hidden_dim(), 8);
        assert_eq!(model.state_dim(), 24);

        let states = model.get_states();
        assert_eq!(states.len(), 3);
    }

    #[test]
    fn test_config_validation() {
        assert!(make_concat_config().validate().is_ok());

        let mut zero_factor = make_concat_config();
        zero_factor.scale_factors = vec![1, 0, 4];
        assert!(zero_factor.validate().is_err());

        let mut wrong_len = make_concat_config();
        wrong_len.scale_factors = vec![1, 2];
        assert!(wrong_len.validate().is_err());

        let mut zero_hidden = make_concat_config();
        zero_hidden.hidden_dim = 0;
        assert!(MultiScaleModel::new(zero_hidden).is_err());
    }

    #[test]
    fn test_forward_step_rejects_wrong_input_len() {
        let mut model =
            MultiScaleModel::new(make_concat_config()).expect("model creation failed");
        let short_input = Array1::from_vec(vec![0.5_f32; 3]);
        assert!(model.forward_step(&short_input).is_err());
    }

    #[test]
    fn test_reset_restores_initial_output() {
        let mut model =
            MultiScaleModel::new(make_attention_config()).expect("model creation failed");
        let input = Array1::from_vec(vec![0.4_f32; 4]);

        let first = model.forward_step(&input).expect("first forward failed");
        for _ in 0..5 {
            model.forward_step(&input).expect("warm-up forward failed");
        }

        // After a reset every state and tick clock is back at zero, so the
        // same input must reproduce the very first output exactly.
        SignalPredictor::reset(&mut model);
        let again = model.forward_step(&input).expect("post-reset forward failed");
        assert_eq!(first, again);
    }
}