kizzasi_core/
embedding.rs1use crate::error::{CoreError, CoreResult};
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::thread_rng;
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ContinuousEmbedding {
14 weights: Array2<f32>,
16 bias: Array1<f32>,
18 input_dim: usize,
20 embed_dim: usize,
22}
23
24impl ContinuousEmbedding {
25 pub fn new(input_dim: usize, embed_dim: usize) -> Self {
27 let mut rng = thread_rng();
28 let scale = (2.0 / (input_dim + embed_dim) as f32).sqrt();
30 let weights = Array2::from_shape_fn((input_dim, embed_dim), |_| {
31 (rng.random::<f32>() - 0.5) * 2.0 * scale
32 });
33 let bias = Array1::zeros(embed_dim);
34
35 Self {
36 weights,
37 bias,
38 input_dim,
39 embed_dim,
40 }
41 }
42
43 pub fn embed(&self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
45 if input.len() != self.input_dim {
46 return Err(CoreError::DimensionMismatch {
47 expected: self.input_dim,
48 got: input.len(),
49 });
50 }
51
52 let output = input.dot(&self.weights) + &self.bias;
54 Ok(output)
55 }
56
57 pub fn embed_dim(&self) -> usize {
59 self.embed_dim
60 }
61
62 pub fn input_dim(&self) -> usize {
64 self.input_dim
65 }
66
67 pub fn layer_norm(input: &Array1<f32>, eps: f32) -> Array1<f32> {
69 let n = input.len() as f32;
70 let mean = if n > 0.0 { input.sum() / n } else { 0.0 };
71 let var = if n > 0.0 {
72 input.mapv(|x| (x - mean).powi(2)).sum() / n
73 } else {
74 1.0
75 };
76 let std = (var + eps).sqrt();
77 input.mapv(|x| (x - mean) / std)
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// A correctly sized input yields a vector of the configured width, and
    /// the getters report the dimensions passed to `new`.
    #[test]
    fn test_continuous_embedding() {
        let embed = ContinuousEmbedding::new(3, 64);
        assert_eq!(embed.input_dim(), 3);
        assert_eq!(embed.embed_dim(), 64);

        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let output = embed.embed(&input).expect("embedding should succeed");
        assert_eq!(output.len(), 64);
    }

    /// An input of the wrong length is rejected with the expected/got sizes
    /// instead of reaching the matrix product.
    #[test]
    fn test_embed_rejects_wrong_input_length() {
        let embed = ContinuousEmbedding::new(3, 64);
        let input = Array1::from_vec(vec![0.1, 0.2]);

        match embed.embed(&input) {
            Err(CoreError::DimensionMismatch { expected, got }) => {
                assert_eq!(expected, 3);
                assert_eq!(got, 2);
            }
            _ => panic!("expected a DimensionMismatch error"),
        }
    }

    /// The normalised output keeps its length, has (near) zero mean and
    /// (near) unit variance.
    #[test]
    fn test_layer_norm() {
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let normed = ContinuousEmbedding::layer_norm(&input, 1e-5);
        assert_eq!(normed.len(), input.len());

        let n = normed.len() as f32;
        let mean = normed.sum() / n;
        assert!(mean.abs() < 1e-5);

        // Checking the mean alone would also accept an all-zeros output, so
        // pin the variance too. `eps` keeps it slightly below 1.0.
        let var = normed.mapv(|x| (x - mean).powi(2)).sum() / n;
        assert!((var - 1.0).abs() < 1e-3);
    }

    /// Empty input takes the fallback branch and produces an empty result
    /// rather than dividing by zero.
    #[test]
    fn test_layer_norm_empty_input() {
        let input: Array1<f32> = Array1::zeros(0);
        let normed = ContinuousEmbedding::layer_norm(&input, 1e-5);
        assert_eq!(normed.len(), 0);
    }
}