use std::marker::PhantomData;
use candle_core::{Device, Tensor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Marker trait for type-level tags naming the vector space an embedding lives in
/// (presumably one tag per embedding model).
pub trait VectorSpace {}

/// Catch-all vector-space tag, presumably for embeddings whose source model is not known.
pub struct UnknownVectorSpace;

impl VectorSpace for UnknownVectorSpace {}
/// An embedding tensor tagged at the type level with the vector space `S` it belongs to.
///
/// `S` is never stored; it is carried only through `PhantomData`, so it exists purely to keep
/// embeddings from different spaces from being mixed (e.g. `cosine_similarity` only accepts
/// another `Embedding<S>` with the same `S`).
pub struct Embedding<S>
where
    S: VectorSpace,
{
    /// The raw embedding values.
    embedding: Tensor,
    /// Type-level marker for the vector space; holds no data.
    model: PhantomData<S>,
}
impl<S: VectorSpace> Embedding<S> {
pub fn cosine_similarity(&self, other: &Self) -> f32 {
let sum_ij = (&other.embedding * &self.embedding)
.unwrap()
.sum_all()
.unwrap()
.to_scalar::<f32>()
.unwrap();
let sum_i2 = (&other.embedding * &other.embedding)
.unwrap()
.sum_all()
.unwrap()
.to_scalar::<f32>()
.unwrap();
let sum_j2 = (&self.embedding * &self.embedding)
.unwrap()
.sum_all()
.unwrap()
.to_scalar::<f32>()
.unwrap();
sum_ij / (sum_i2 * sum_j2).sqrt()
}
}
impl<S: VectorSpace> std::fmt::Debug for Embedding<S> {
    /// Prints the tensor plus the type name of the vector-space marker `S`.
    /// `S` itself need not implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let space_name = std::any::type_name::<S>();
        let mut builder = f.debug_struct("Embedding");
        builder.field("embedding", &self.embedding);
        builder.field("model", &space_name);
        builder.finish()
    }
}
impl<S: VectorSpace> Clone for Embedding<S> {
fn clone(&self) -> Self {
Embedding {
embedding: self.embedding.clone(),
model: PhantomData,
}
}
}
impl<S: VectorSpace> Serialize for Embedding<S> {
fn serialize<Ser: Serializer>(&self, _serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
todo!()
}
}
impl<'de, S: VectorSpace> Deserialize<'de> for Embedding<S> {
fn deserialize<Des: Deserializer<'de>>(_deserializer: Des) -> Result<Self, Des::Error> {
todo!()
}
}
impl<S: VectorSpace, I: IntoIterator<Item = f32>> From<I> for Embedding<S> {
    /// Collects the values into a one-dimensional `f32` tensor on the CPU.
    fn from(iter: I) -> Self {
        let values = iter.into_iter().collect::<Vec<f32>>();
        let len = values.len();
        Self {
            embedding: Tensor::from_vec(values, &[len], &Device::Cpu).unwrap(),
            model: PhantomData,
        }
    }
}
impl<S1: VectorSpace> Embedding<S1> {
    /// Re-tags this embedding as belonging to vector space `S2`.
    ///
    /// Only the type-level marker changes; the tensor is moved across untouched.
    pub fn cast<S2: VectorSpace>(self) -> Embedding<S2> {
        let Embedding { embedding, .. } = self;
        Embedding {
            embedding,
            model: PhantomData,
        }
    }
}
impl<S: VectorSpace> Embedding<S> {
pub fn new(embedding: Tensor) -> Self {
Embedding {
embedding,
model: PhantomData,
}
}
pub fn vector(&self) -> &Tensor {
&self.embedding
}
pub fn to_vec(&self) -> Vec<f32> {
self.embedding
.flatten_to(1)
.unwrap()
.to_vec1::<f32>()
.unwrap()
}
}