kizzasi_core/embedding.rs

//! Continuous signal embeddings
//!
//! Unlike the discrete token embeddings used in LLMs, this module maps
//! continuous float values directly into latent space.
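//!
//! A minimal usage sketch (assuming this module is exposed as
//! `kizzasi_core::embedding`, per the file path):
//!
//! ```no_run
//! use kizzasi_core::embedding::ContinuousEmbedding;
//! use scirs2_core::ndarray::Array1;
//!
//! let layer = ContinuousEmbedding::new(3, 64);
//! let latent = layer.embed(&Array1::from_vec(vec![0.1, 0.2, 0.3])).unwrap();
//! assert_eq!(latent.len(), 64);
//! ```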

use crate::error::{CoreError, CoreResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};

/// Continuous embedding layer for signal values
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContinuousEmbedding {
    /// Linear projection weights
    weights: Array2<f32>,
    /// Bias terms
    bias: Array1<f32>,
    /// Input dimension
    input_dim: usize,
    /// Output (embedding) dimension
    embed_dim: usize,
}

impl ContinuousEmbedding {
    /// Create a new continuous embedding layer
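    ///
    /// Weights are drawn uniformly from `[-s, s]` with
    /// `s = sqrt(2 / (input_dim + embed_dim))` (a Xavier-style scheme);
    /// biases start at zero.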
    pub fn new(input_dim: usize, embed_dim: usize) -> Self {
        let mut rng = thread_rng();
        // Xavier-style initialization: uniform in [-scale, scale]
        let scale = (2.0 / (input_dim + embed_dim) as f32).sqrt();
        let weights = Array2::from_shape_fn((input_dim, embed_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let bias = Array1::zeros(embed_dim);

        Self {
            weights,
            bias,
            input_dim,
            embed_dim,
        }
    }

    /// Embed a continuous signal vector
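    ///
    /// # Errors
    ///
    /// Returns `CoreError::DimensionMismatch` when `input.len()` differs
    /// from the layer's `input_dim`.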
    pub fn embed(&self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        if input.len() != self.input_dim {
            return Err(CoreError::DimensionMismatch {
                expected: self.input_dim,
                got: input.len(),
            });
        }

        // Linear projection: output = input @ weights + bias
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }

    /// Get the embedding dimension
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Get the input dimension
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Apply layer normalization
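    ///
    /// Computes `y_i = (x_i - mean) / sqrt(var + eps)`, where `mean` and
    /// `var` are the biased statistics over the full vector.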
    pub fn layer_norm(input: &Array1<f32>, eps: f32) -> Array1<f32> {
        let n = input.len() as f32;
        let mean = if n > 0.0 { input.sum() / n } else { 0.0 };
        let var = if n > 0.0 {
            input.mapv(|x| (x - mean).powi(2)).sum() / n
        } else {
            1.0
        };
        let std = (var + eps).sqrt();
        input.mapv(|x| (x - mean) / std)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_continuous_embedding() {
        let embed = ContinuousEmbedding::new(3, 64);
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let output = embed.embed(&input).expect("embedding should succeed");
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_layer_norm() {
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let normed = ContinuousEmbedding::layer_norm(&input, 1e-5);
        // Check that mean is approximately 0
        let n = normed.len() as f32;
        let mean = if n > 0.0 { normed.sum() / n } else { 0.0 };
        assert!(mean.abs() < 1e-5);
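        // Added check (sketch): variance of the normalized vector should be
        // close to 1 (biased estimator, slightly damped by eps)
        let var = normed.mapv(|x| (x - mean).powi(2)).sum() / n;
        assert!((var - 1.0).abs() < 1e-3);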
    }
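
    // Hedged addition: `embed` should reject inputs whose length differs
    // from the configured `input_dim`.
    #[test]
    fn test_embed_dimension_mismatch() {
        let embed = ContinuousEmbedding::new(3, 8);
        let input = Array1::from_vec(vec![0.1, 0.2]); // len 2 != input_dim 3
        assert!(embed.embed(&input).is_err());
    }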
}