use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Tensor {
dimension: usize,
data: Vec<f32>,
}
impl Tensor {
    /// Creates a tensor that takes ownership of `values`; its dimension is `values.len()`.
    pub fn new(values: Vec<f32>) -> Self {
        let dimension = values.len();
        Self {
            dimension,
            data: values,
        }
    }

    /// Returns the tensor's dimension (its number of elements, as recorded at construction).
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Borrows the element values without copying.
    pub fn as_f32(&self) -> &[f32] {
        &self.data
    }

    /// Returns an owned copy of the element values.
    pub fn to_f32(&self) -> Vec<f32> {
        self.data.clone()
    }

    /// Cosine similarity: the dot product divided by the product of the two L2 norms.
    ///
    /// Returns `0.0` when either tensor has a zero norm (including empty tensors).
    /// Otherwise the result is clamped to `[-1.0, 1.0]`. NaN inputs still yield NaN.
    ///
    /// # Panics
    ///
    /// Panics with "Dimension mismatch" if the tensors hold different numbers of elements.
    pub fn cosine_similarity(&self, other: &Tensor) -> f32 {
        let (a, b) = (self.as_f32(), other.as_f32());
        // Check the slices actually being iterated: `zip` would otherwise silently
        // truncate to the shorter one, even if the stored `dimension` fields agree.
        assert_eq!(a.len(), b.len(), "Dimension mismatch");

        // Single pass accumulating the dot product and both squared norms.
        let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
            (0.0_f32, 0.0_f32, 0.0_f32),
            |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
        );
        let norm_a = sq_a.sqrt();
        let norm_b = sq_b.sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        // Rounding can push the ratio a few ULPs outside [-1, 1] for (anti-)parallel
        // vectors; clamp so downstream consumers such as `acos` get a valid input.
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }

    /// Euclidean (L2) distance between the two tensors.
    ///
    /// # Panics
    ///
    /// Panics with "Dimension mismatch" if the tensors hold different numbers of elements.
    pub fn l2_distance(&self, other: &Tensor) -> f32 {
        let (a, b) = (self.as_f32(), other.as_f32());
        // Same guard as `cosine_similarity`: compare real lengths to prevent `zip` truncation.
        assert_eq!(a.len(), b.len(), "Dimension mismatch");
        a.iter()
            .zip(b)
            .map(|(x, y)| (x - y).powi(2))
            .sum::<f32>()
            .sqrt()
    }

    /// Size in bytes of the stored element values.
    ///
    /// Excludes the `Vec` header and any spare capacity.
    pub fn memory_size(&self) -> usize {
        std::mem::size_of_val(self.as_f32())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let values = vec![1.0, 2.0, 3.0];
        let tensor = Tensor::new(values.clone());
        assert_eq!(tensor.dimension(), 3);
        // `to_f32` returns a copy of the stored data, so the round trip is exact.
        assert_eq!(tensor.to_f32(), values);
        assert_eq!(tensor.as_f32(), values.as_slice());
    }

    #[test]
    fn test_cosine_similarity() {
        let t1 = Tensor::new(vec![1.0, 0.0, 0.0]);
        let t2 = Tensor::new(vec![1.0, 0.0, 0.0]);
        let t3 = Tensor::new(vec![0.0, 1.0, 0.0]);
        assert!((t1.cosine_similarity(&t2) - 1.0).abs() < 0.01);
        assert!((t1.cosine_similarity(&t3) - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_cosine_similarity_opposite_vectors() {
        let t = Tensor::new(vec![1.0, 2.0, 3.0]);
        let neg = Tensor::new(vec![-1.0, -2.0, -3.0]);
        assert!((t.cosine_similarity(&neg) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // A zero-norm operand takes the explicit `0.0` branch instead of dividing by zero.
        let t = Tensor::new(vec![1.0, 2.0, 3.0]);
        let zero = Tensor::new(vec![0.0, 0.0, 0.0]);
        assert_eq!(t.cosine_similarity(&zero), 0.0);
        assert_eq!(zero.cosine_similarity(&t), 0.0);
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch")]
    fn test_cosine_similarity_dimension_mismatch() {
        let short = Tensor::new(vec![1.0, 2.0]);
        let long = Tensor::new(vec![1.0, 2.0, 3.0]);
        let _ = short.cosine_similarity(&long);
    }

    #[test]
    fn test_l2_distance() {
        let t1 = Tensor::new(vec![0.0, 0.0, 0.0]);
        let t2 = Tensor::new(vec![3.0, 4.0, 0.0]);
        assert!((t1.l2_distance(&t2) - 5.0).abs() < 0.001);
        assert_eq!(t2.l2_distance(&t2), 0.0);
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch")]
    fn test_l2_distance_dimension_mismatch() {
        let short = Tensor::new(vec![1.0, 2.0]);
        let long = Tensor::new(vec![1.0, 2.0, 3.0]);
        let _ = short.l2_distance(&long);
    }

    #[test]
    fn test_memory_size() {
        // Four `f32` values occupy 16 bytes; an empty tensor occupies none.
        assert_eq!(Tensor::new(vec![0.5; 4]).memory_size(), 16);
        assert_eq!(Tensor::new(Vec::new()).memory_size(), 0);
    }
}