use std::convert::Infallible;
use anybytes::View;
use triblespace_core::blob::{Blob, BlobSchema, ToBlob, TryFromBlob};
use triblespace_core::id::Id;
use triblespace_core::id_hex;
use triblespace_core::metadata::ConstId;
use triblespace_core::value::{ToValue, TryFromValue, Value, ValueSchema};
/// Value schema for a single `f32` packed into a 32-byte value.
///
/// Layout, as written by the `ToValue` impl and read back by the
/// `TryFromValue` impl:
/// - bytes 0..4 hold the float's IEEE-754 bit pattern in little-endian order;
/// - bytes 4..32 are zero padding.
///
/// Because the raw bit pattern is stored, `0.0` and `-0.0` (and NaNs with
/// different payloads) encode to different values.
///
/// The enum has no variants: it is only used as a type-level schema marker.
pub enum F32LE {}
impl ConstId for F32LE {
    // Fixed schema identifier; presumably persisted alongside data, so it
    // should never change.
    const ID: Id = id_hex!("816B4751EA8C12644CCB572F36188EBA");
}
impl ValueSchema for F32LE {
    // Validation cannot fail: every 32-byte value decodes to some f32.
    // NOTE(review): nothing checks that bytes 4..32 are zero, so non-canonical
    // values are accepted and decode to the same float — confirm intended.
    type ValidationError = Infallible;
}
impl ToValue<F32LE> for f32 {
    /// Packs the float's bit pattern, little-endian, into bytes 0..4 of an
    /// otherwise all-zero 32-byte value.
    fn to_value(self) -> Value<F32LE> {
        let le_bits = self.to_le_bytes();
        let mut buf = [0u8; 32];
        buf[..le_bits.len()].copy_from_slice(&le_bits);
        Value::new(buf)
    }
}
impl ToValue<F32LE> for &f32 {
    /// Copies the referenced float and encodes it exactly like the by-value impl.
    fn to_value(self) -> Value<F32LE> {
        let owned: f32 = *self;
        <f32 as ToValue<F32LE>>::to_value(owned)
    }
}
impl TryFromValue<'_, F32LE> for f32 {
    type Error = Infallible;
    /// Reconstructs the float from the little-endian bits in bytes 0..4.
    /// The padding bytes 4..32 are not inspected.
    fn try_from_value(value: &Value<F32LE>) -> Result<Self, Self::Error> {
        let mut le_bits = [0u8; 4];
        le_bits.copy_from_slice(&value.raw[..4]);
        Ok(f32::from_le_bytes(le_bits))
    }
}
/// Blob schema for an embedding vector stored as packed `f32` components.
///
/// The `Vec<f32>` and `&[f32]` encoders write each component as 4
/// little-endian bytes, back to back. Decoding goes through `View<[f32]>`.
///
/// NOTE(review): unlike `F32LE` (an uninhabited enum) this is a constructible
/// `struct {}`; an uninhabited marker may have been intended — confirm.
pub struct Embedding {}
impl ConstId for Embedding {
    // Fixed schema identifier for this blob schema; should never change.
    const ID: Id = id_hex!("EEC5DFDEA2FFCED70850DF83B03CB62B");
}
impl BlobSchema for Embedding {}
// Empty impl: relies entirely on the trait's provided items.
impl triblespace_core::metadata::ConstDescribe for Embedding {}
/// Blake3-hashed handle referencing an [`Embedding`] blob.
///
/// Identical embeddings encode to identical blobs and therefore share a handle.
pub type EmbHandle =
triblespace_core::value::schemas::hash::Handle<
triblespace_core::value::schemas::hash::Blake3,
Embedding,
>;
impl TryFromBlob<Embedding> for View<[f32]> {
type Error = anybytes::view::ViewError;
fn try_from_blob(b: Blob<Embedding>) -> Result<Self, Self::Error> {
b.bytes.view()
}
}
impl ToBlob<Embedding> for View<[f32]> {
    /// Wraps the view's underlying bytes as a blob without re-encoding them.
    ///
    /// NOTE(review): the bytes pass through unchanged, so their order is
    /// whatever the view holds (presumably native-endian), while the
    /// `Vec<f32>`/`&[f32]` encoders write little-endian explicitly.
    fn to_blob(self) -> Blob<Embedding> {
        let raw = self.bytes();
        Blob::new(raw)
    }
}
impl ToBlob<Embedding> for Vec<f32> {
    /// Encodes the vector as consecutive 4-byte little-endian floats.
    ///
    /// Delegates to the `&[f32]` impl so owned and borrowed inputs share a
    /// single definition of the on-blob encoding (and therefore always produce
    /// identical bytes and handles for identical contents).
    fn to_blob(self) -> Blob<Embedding> {
        <&[f32] as ToBlob<Embedding>>::to_blob(self.as_slice())
    }
}
impl ToBlob<Embedding> for &[f32] {
    /// Serializes the slice as consecutive 4-byte little-endian floats, in order.
    fn to_blob(self) -> Blob<Embedding> {
        let mut encoded = vec![0u8; self.len() * 4];
        for (dst, component) in encoded.chunks_exact_mut(4).zip(self.iter()) {
            dst.copy_from_slice(&component.to_le_bytes());
        }
        Blob::new(encoded.into())
    }
}
/// Scales `vec` in place to unit Euclidean (L2) length.
///
/// The squared norm is accumulated in `f64`. With an `f32` accumulator, large
/// components (around 1e19 and up) would square to infinity and zero the whole
/// vector, and tiny components (around 1e-23 and below) would square to zero
/// and leave it unnormalized. Each component is divided in `f64` and rounded
/// back to `f32` once.
///
/// Vectors whose norm is zero (empty or all-zero) or non-finite (any NaN or
/// infinite component) have no meaningful direction and are left unchanged.
pub fn l2_normalize(vec: &mut [f32]) {
    let norm = vec
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if norm.is_finite() && norm > 0.0 {
        for v in vec.iter_mut() {
            *v = (f64::from(*v) / norm) as f32;
        }
    }
}
/// L2-normalizes `vec` and stores it in `store` as an [`Embedding`] blob.
///
/// Normalization is done with [`l2_normalize`] before encoding, so what gets
/// stored (and hashed into the returned handle) is the normalized vector, not
/// the input as given. Vectors `l2_normalize` leaves untouched (e.g. all-zero)
/// are stored as-is.
///
/// # Errors
///
/// Propagates the `B::PutError` returned by the store's `put`.
pub fn put_embedding<B, H>(
    store: &mut B,
    vec: Vec<f32>,
) -> Result<triblespace_core::value::Value<triblespace_core::value::schemas::hash::Handle<H, Embedding>>, B::PutError>
where
    H: triblespace_core::value::schemas::hash::HashProtocol,
    B: triblespace_core::repo::BlobStorePut<H>,
    triblespace_core::value::schemas::hash::Handle<H, Embedding>:
        triblespace_core::value::ValueSchema,
{
    let mut unit = vec;
    l2_normalize(&mut unit);
    store.put::<Embedding, _>(unit)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::PI;

    #[test]
    fn round_trip_positive() {
        let original: f32 = 0.123;
        let v: Value<F32LE> = original.to_value();
        let back: f32 = f32::try_from_value(&v).unwrap();
        assert_eq!(original, back);
    }

    #[test]
    fn round_trip_negative() {
        let original: f32 = -42.75;
        let v: Value<F32LE> = original.to_value();
        let back: f32 = f32::try_from_value(&v).unwrap();
        assert_eq!(original, back);
    }

    #[test]
    fn round_trip_zero() {
        let original: f32 = 0.0;
        let v: Value<F32LE> = original.to_value();
        let back: f32 = f32::try_from_value(&v).unwrap();
        assert_eq!(original.to_bits(), back.to_bits());
    }

    #[test]
    fn round_trip_nan() {
        let original: f32 = f32::NAN;
        let v: Value<F32LE> = original.to_value();
        let back: f32 = f32::try_from_value(&v).unwrap();
        assert!(back.is_nan());
    }

    #[test]
    fn padding_is_zero() {
        let v: Value<F32LE> = 2.5f32.to_value();
        assert_eq!(&v.raw[4..32], &[0u8; 28]);
    }

    #[test]
    fn deterministic_same_input_same_value() {
        let a: Value<F32LE> = 1.5f32.to_value();
        let b: Value<F32LE> = 1.5f32.to_value();
        assert_eq!(a.raw, b.raw);
    }

    #[test]
    fn embedding_blob_round_trip() {
        let original: Vec<f32> = vec![0.1, -0.5, 3.25, PI];
        let blob: Blob<Embedding> = original.clone().to_blob();
        let view: View<[f32]> = TryFromBlob::try_from_blob(blob).unwrap();
        assert_eq!(view.as_ref(), original.as_slice());
    }

    #[test]
    fn put_embedding_roundtrips_through_memory_store() {
        use triblespace_core::blob::MemoryBlobStore;
        use triblespace_core::repo::{BlobStore, BlobStoreGet};
        use triblespace_core::value::schemas::hash::Blake3;
        let mut store = MemoryBlobStore::<Blake3>::new();
        let vec = vec![1.0_f32, 0.0, 0.0];
        let handle = put_embedding::<_, Blake3>(&mut store, vec.clone()).unwrap();
        let reader = store.reader().unwrap();
        let view: View<[f32]> = reader.get::<View<[f32]>, Embedding>(handle).unwrap();
        assert_eq!(view.as_ref(), &[1.0_f32, 0.0, 0.0]);
    }

    #[test]
    fn put_embedding_stores_normalized_vector() {
        use triblespace_core::blob::MemoryBlobStore;
        use triblespace_core::repo::{BlobStore, BlobStoreGet};
        use triblespace_core::value::schemas::hash::Blake3;
        let mut store = MemoryBlobStore::<Blake3>::new();
        let handle = put_embedding::<_, Blake3>(&mut store, vec![3.0_f32, 4.0]).unwrap();
        let reader = store.reader().unwrap();
        let view: View<[f32]> = reader.get::<View<[f32]>, Embedding>(handle).unwrap();
        let stored: &[f32] = view.as_ref();
        assert_eq!(stored.len(), 2);
        assert!((stored[0] - 0.6).abs() < 1e-6);
        assert!((stored[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn embedding_handle_is_content_addressed() {
        use triblespace_core::value::schemas::hash::{Blake3, Handle};
        let v1: Vec<f32> = vec![1.0, 2.0, 3.0];
        let v2: Vec<f32> = vec![1.0, 2.0, 3.0];
        let v3: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let h1: Value<Handle<Blake3, Embedding>> = v1.to_blob().get_handle();
        let h2: Value<Handle<Blake3, Embedding>> = v2.to_blob().get_handle();
        let h3: Value<Handle<Blake3, Embedding>> = v3.to_blob().get_handle();
        assert_eq!(h1, h2, "identical vectors must dedup by handle");
        assert_ne!(h1, h3, "different vectors must have different handles");
    }

    #[test]
    fn l2_normalize_scales_to_unit_length() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);
        assert!((v[0] - 0.6).abs() < 1e-6);
        assert!((v[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_leaves_zero_and_empty_untouched() {
        let mut zeros = vec![0.0_f32; 4];
        l2_normalize(&mut zeros);
        assert_eq!(zeros, vec![0.0_f32; 4]);
        let mut empty: Vec<f32> = Vec::new();
        l2_normalize(&mut empty);
        assert!(empty.is_empty());
    }
}