use crate::error::{CoreError, CoreResult};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
/// An affine projection that maps a continuous feature vector of length
/// `input_dim` to an embedding of length `embed_dim`
/// (`output = input · weights + bias`).
///
/// The dimension fields are stored (and serialized) alongside the arrays.
/// NOTE(review): deriving `Deserialize` means a deserialized value is not
/// re-checked for agreement between these fields and the array shapes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContinuousEmbedding {
    /// Projection matrix; shape `(input_dim, embed_dim)` when built via `new`.
    weights: Array2<f32>,
    /// Additive bias; length `embed_dim` and all zeros when built via `new`.
    bias: Array1<f32>,
    /// Required length of vectors passed to `embed`.
    input_dim: usize,
    /// Embedding width reported by the `embed_dim` getter.
    embed_dim: usize,
}
impl ContinuousEmbedding {
    /// Creates an embedding layer with random weights and a zero bias.
    ///
    /// Each weight is `(r - 0.5) * 2 * scale` with `scale = sqrt(2 / (input_dim + embed_dim))`,
    /// where `r` comes from the thread-local RNG (assumed to be uniform on `[0, 1)`), so the
    /// weights lie in `[-scale, scale)` and differ between calls.
    ///
    /// NOTE(review): `sqrt(2 / (fan_in + fan_out))` is the Glorot *normal* standard deviation;
    /// Glorot *uniform* uses a half-width of `sqrt(6 / (fan_in + fan_out))`. Confirm which
    /// initialisation is intended before changing it.
    pub fn new(input_dim: usize, embed_dim: usize) -> Self {
        let mut rng = thread_rng();
        let scale = (2.0 / (input_dim + embed_dim) as f32).sqrt();
        let weights = Array2::from_shape_fn((input_dim, embed_dim), |_| {
            // Shift/stretch a [0, 1) sample onto [-scale, scale).
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let bias = Array1::zeros(embed_dim);
        Self {
            weights,
            bias,
            input_dim,
            embed_dim,
        }
    }

    /// Projects `input` into the embedding space: `input · weights + bias`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionMismatch`] when `input.len()` differs from
    /// [`input_dim`](Self::input_dim), or when the stored `weights`/`bias` shapes are
    /// inconsistent (possible for a value obtained through `Deserialize`, which does not
    /// validate them). Such inconsistencies would otherwise panic inside `dot` or the
    /// bias addition.
    pub fn embed(&self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
        if input.len() != self.input_dim {
            return Err(CoreError::DimensionMismatch {
                expected: self.input_dim,
                got: input.len(),
            });
        }
        // Guard the stored shapes so malformed (e.g. deserialized) state yields an error
        // instead of an ndarray shape panic.
        let rows = self.weights.nrows();
        if rows != input.len() {
            return Err(CoreError::DimensionMismatch {
                expected: rows,
                got: input.len(),
            });
        }
        let cols = self.weights.ncols();
        if self.bias.len() != cols {
            return Err(CoreError::DimensionMismatch {
                expected: cols,
                got: self.bias.len(),
            });
        }
        let output = input.dot(&self.weights) + &self.bias;
        Ok(output)
    }

    /// Returns the configured embedding width.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Returns the input length that [`embed`](Self::embed) requires.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Normalises `input` to zero mean and (approximately) unit variance.
    ///
    /// Uses the population mean and variance (dividing by `n`, not `n - 1`) and returns
    /// `(x - mean) / sqrt(var + eps)` element-wise. There is no learned scale or shift.
    ///
    /// `eps` keeps the division finite when every element is equal (variance 0); pass a
    /// positive value, because `eps <= 0` can yield NaN or infinities for such input.
    /// An empty input produces an empty output.
    pub fn layer_norm(input: &Array1<f32>, eps: f32) -> Array1<f32> {
        let n = input.len() as f32;
        let mean = if n > 0.0 { input.sum() / n } else { 0.0 };
        // For empty input the value is irrelevant (nothing is mapped); 1.0 just avoids 0/0.
        let var = if n > 0.0 {
            input.mapv(|x| (x - mean).powi(2)).sum() / n
        } else {
            1.0
        };
        let std = (var + eps).sqrt();
        input.mapv(|x| (x - mean) / std)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: correct dimensions in, correctly sized finite embedding out.
    #[test]
    fn test_continuous_embedding() {
        let embed = ContinuousEmbedding::new(3, 64);
        assert_eq!(embed.input_dim(), 3);
        assert_eq!(embed.embed_dim(), 64);
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let output = embed.embed(&input).expect("embedding should succeed");
        assert_eq!(output.len(), 64);
        assert!(output.iter().all(|v| v.is_finite()));
    }

    /// A wrongly sized input is reported as an error, not a panic.
    #[test]
    fn test_embed_rejects_wrong_input_length() {
        let embed = ContinuousEmbedding::new(3, 8);
        let input = Array1::from_vec(vec![0.1, 0.2]);
        assert!(matches!(
            embed.embed(&input),
            Err(CoreError::DimensionMismatch { expected: 3, got: 2 })
        ));
    }

    /// A zero input contributes nothing through the weights, so the output equals the
    /// freshly initialised (all-zero) bias.
    #[test]
    fn test_zero_input_yields_bias() {
        let embed = ContinuousEmbedding::new(4, 5);
        let output = embed
            .embed(&Array1::zeros(4))
            .expect("embedding should succeed");
        assert_eq!(output.len(), 5);
        assert!(output.iter().all(|&v| v == 0.0));
    }

    /// Normalised output has (approximately) zero mean and unit variance.
    #[test]
    fn test_layer_norm() {
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let normed = ContinuousEmbedding::layer_norm(&input, 1e-5);
        let n = normed.len() as f32;
        let mean = normed.sum() / n;
        let var = normed.mapv(|x| (x - mean).powi(2)).sum() / n;
        assert!(mean.abs() < 1e-5);
        assert!((var - 1.0).abs() < 1e-3);
    }

    /// Constant input maps to all zeros (eps prevents 0/0); empty input stays empty.
    #[test]
    fn test_layer_norm_constant_and_empty() {
        let constant = ContinuousEmbedding::layer_norm(&Array1::from_vec(vec![2.0; 4]), 1e-5);
        assert_eq!(constant.len(), 4);
        assert!(constant.iter().all(|&v| v == 0.0));

        let empty = ContinuousEmbedding::layer_norm(&Array1::<f32>::zeros(0), 1e-5);
        assert!(empty.is_empty());
    }
}