#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
use crate::embedding::{EmbeddingProvider, Projection};
use crate::encoder::TextEncoder;
use crate::error::Result;
use crate::hyperdim::HVec10240;
/// Embedding provider that turns text into a binary hypervector.
///
/// The heavy lifting is done by the wrapped [`TextEncoder`]. The
/// `EmbeddingProvider` impl exposes its packed bits as a dense `f32` vector.
#[derive(Debug, Clone)]
pub struct HdcTextProvider {
    /// Encoder that maps text to a packed hypervector.
    encoder: TextEncoder,
}
impl Default for HdcTextProvider {
fn default() -> Self {
Self {
encoder: TextEncoder::new(),
}
}
}
impl HdcTextProvider {
    /// Creates a provider whose encoder is `TextEncoder::new()`.
    #[must_use]
    pub fn new() -> Self {
        let encoder = TextEncoder::new();
        Self { encoder }
    }

    /// Creates a provider whose encoder is built from `config`.
    ///
    /// This is a `const fn`, so it can be used to build providers in constant contexts.
    #[must_use]
    pub const fn with_config(config: crate::encoder::TextEncoderConfig) -> Self {
        Self {
            encoder: TextEncoder::with_config(config),
        }
    }
}
/// Width, in bits, of each packed word in `HVec10240::data`.
const WORD_BITS: usize = u128::BITS as usize;

/// Dimensionality of the dense embeddings produced by [`HdcTextProvider`].
const HDC_DIM: usize = 10240;

/// Dense components strictly greater than this value become set bits when a
/// vector is packed back into a hypervector by the identity projection.
const BIT_THRESHOLD: f32 = 0.5;

#[async_trait::async_trait]
impl EmbeddingProvider for HdcTextProvider {
    fn name(&self) -> &str {
        "hdc-text"
    }

    /// Length of the dense vectors this provider is designed to emit: one
    /// `f32` per hypervector bit.
    fn native_dim(&self) -> usize {
        HDC_DIM
    }

    /// Encodes `text` and unpacks the hypervector into a dense `0.0`/`1.0` vector.
    ///
    /// Words are visited in order, and each word's bits are emitted least-significant
    /// first. Component `i` is therefore bit `i % WORD_BITS` of word `i / WORD_BITS`.
    /// `project` with an empty projection uses the same layout.
    ///
    /// # Errors
    ///
    /// This implementation never returns `Err`.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let hv = self.encoder.encode(text);
        let mut result: Vec<f32> = Vec::with_capacity(hv.data.len() * WORD_BITS);
        for &word in &hv.data {
            let bits = (0..WORD_BITS).map(|bit| if (word >> bit) & 1 == 1 { 1.0 } else { 0.0 });
            result.extend(bits);
        }
        Ok(result)
    }

    /// Embeds each text sequentially via `embed`, preserving input order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `embed`; later texts are not processed.
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        for &text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    /// Packs a dense vector into an [`HVec10240`].
    ///
    /// With a non-empty projection (`nnz() != 0`), the work is delegated to
    /// `projection.project`.
    ///
    /// With an empty projection, this is the inverse of `embed`:
    /// - At most the first `HDC_DIM` components are read.
    /// - Each component is thresholded at `BIT_THRESHOLD` (NaN counts as unset).
    /// - Bits are packed least-significant first.
    /// - Missing trailing components leave their bits clear.
    fn project(&self, vec: &[f32], projection: &Projection) -> HVec10240 {
        if projection.nnz() != 0 {
            return projection.project(vec);
        }
        let mut hv = HVec10240::zero();
        let dense = &vec[..vec.len().min(HDC_DIM)];
        // Zipping over `iter_mut` instead of indexing keeps every write in bounds,
        // whatever the number of words `data` holds.
        for (word, chunk) in hv.data.iter_mut().zip(dense.chunks(WORD_BITS)) {
            for (bit, &v) in chunk.iter().enumerate() {
                if v > BIT_THRESHOLD {
                    *word |= 1u128 << bit;
                }
            }
        }
        hv
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::Projection;

    /// With an empty projection, `project(embed(text))` must be non-zero and
    /// must equal the hypervector produced directly by a default `TextEncoder`.
    #[tokio::test]
    async fn test_hdc_provider_roundtrip() {
        let text = "hello world";
        let hdc = HdcTextProvider::new();

        let dense = hdc.embed(text).await.unwrap();
        assert_eq!(dense.len(), 10240);

        let packed = hdc.project(&dense, &Projection::empty());
        assert_ne!(packed, HVec10240::zero(), "HDC vector should not be zero");

        let expected = TextEncoder::new().encode(text);
        assert_eq!(packed, expected);
    }
}