use crate::errors::Result;
use crate::utils::{content_hash, pack_embedding};
/// A source of vector embeddings for text.
///
/// Two separate embedding spaces are exposed: a *content* space and a
/// *trigger* space, each with its own dimensionality. The `Send + Sync`
/// supertraits allow an implementation to be shared across threads, e.g.
/// behind a `&dyn EmbeddingProvider`.
pub trait EmbeddingProvider: Send + Sync {
    /// Expected length of vectors returned by
    /// [`embed_content`](Self::embed_content).
    ///
    /// NOTE(review): the trait itself does not enforce this; it is up to
    /// each implementation to keep the two consistent.
    fn content_dim(&self) -> usize;

    /// Expected length of vectors returned by
    /// [`embed_trigger`](Self::embed_trigger) (not enforced by the trait).
    fn trigger_dim(&self) -> usize;

    /// Embeds `text` into the content space.
    ///
    /// # Errors
    /// Implementations may return an error if `text` cannot be embedded.
    fn embed_content(&self, text: &str) -> Result<Vec<f32>>;

    /// Embeds `text` into the trigger space.
    ///
    /// # Errors
    /// Implementations may return an error if `text` cannot be embedded.
    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>>;

    /// Human-readable identifier of the underlying model.
    ///
    /// Defaults to `"custom"` for implementations that do not override it.
    fn model_name(&self) -> &'static str {
        "custom"
    }
}
/// An [`EmbeddingProvider`] that derives vectors from a hash of the input
/// text instead of running a model.
///
/// The two fields only fix the lengths of the vectors it produces.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DummyEmbeddingProvider {
    /// Length of the vectors produced for content embeddings.
    content_dim: usize,
    /// Length of the vectors produced for trigger embeddings.
    trigger_dim: usize,
}

impl DummyEmbeddingProvider {
    /// Content dimensionality used by the [`Default`] implementation.
    pub const DEFAULT_CONTENT_DIM: usize = 1024;
    /// Trigger dimensionality used by the [`Default`] implementation.
    pub const DEFAULT_TRIGGER_DIM: usize = 256;

    /// Creates a provider that emits `content_dim`-long content vectors and
    /// `trigger_dim`-long trigger vectors.
    ///
    /// `const` so a provider can be built in constant/static contexts.
    #[must_use]
    pub const fn new(content_dim: usize, trigger_dim: usize) -> Self {
        Self {
            content_dim,
            trigger_dim,
        }
    }
}

impl Default for DummyEmbeddingProvider {
    /// Builds a provider with [`DummyEmbeddingProvider::DEFAULT_CONTENT_DIM`]
    /// (1024) and [`DummyEmbeddingProvider::DEFAULT_TRIGGER_DIM`] (256).
    fn default() -> Self {
        Self::new(Self::DEFAULT_CONTENT_DIM, Self::DEFAULT_TRIGGER_DIM)
    }
}
/// Hash-based embeddings: every call is infallible and deterministic for a
/// given `text` and dimensionality.
impl EmbeddingProvider for DummyEmbeddingProvider {
    fn content_dim(&self) -> usize {
        self.content_dim
    }

    fn trigger_dim(&self) -> usize {
        self.trigger_dim
    }

    /// Always reports the fixed identifier `"DummyEmbeddingProvider"`.
    fn model_name(&self) -> &'static str {
        "DummyEmbeddingProvider"
    }

    /// Returns a unit-normalised, hash-derived vector of length `content_dim`.
    /// Never returns `Err`.
    fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
        let embedding = hash_to_vec(text, self.content_dim());
        Ok(embedding)
    }

    /// Returns a unit-normalised, hash-derived vector of length `trigger_dim`.
    /// Never returns `Err`.
    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
        let embedding = hash_to_vec(text, self.trigger_dim());
        Ok(embedding)
    }
}
/// Deterministically expands `text` into a vector of `dim` values.
///
/// 1. The bytes of `content_hash(text)` are repeated cyclically to fill `dim`
///    slots.
/// 2. Each byte `b` is mapped linearly from `0..=255` onto `-1.0..=1.0`.
/// 3. The vector is L2-normalised, unless its norm is zero (e.g. `dim == 0`),
///    in which case it is returned unchanged.
///
/// NOTE(review): if `content_hash` returns hex text, every byte lies in
/// `b'0'..=b'f'`, so all components are negative before normalisation.
/// Confirm this is acceptable for a dummy provider.
///
/// # Panics
/// Panics if `dim > 0` and the hash yields no bytes (remainder by zero).
fn hash_to_vec(text: &str, dim: usize) -> Vec<f32> {
    let hash = content_hash(text);
    let digest = hash.as_bytes();
    let period = digest.len();

    let mut components = Vec::with_capacity(dim);
    for slot in 0..dim {
        let byte = f32::from(digest[slot % period]);
        components.push((byte / 255.0) * 2.0 - 1.0);
    }

    let norm = components.iter().map(|c| c * c).sum::<f32>().sqrt();
    if norm > 0.0 {
        components.iter_mut().for_each(|c| *c /= norm);
    }
    components
}
/// Embeds `text` and serialises the resulting vector with `pack_embedding`.
///
/// When `trigger` is `true` the provider's trigger embedding is used;
/// otherwise its content embedding is used.
///
/// # Errors
/// Propagates any error returned by the selected `embed_*` method.
pub fn embed_to_bytes(
    provider: &dyn EmbeddingProvider,
    text: &str,
    trigger: bool,
) -> Result<Vec<u8>> {
    let embedding = if trigger {
        provider.embed_trigger(text)
    } else {
        provider.embed_content(text)
    };
    embedding.map(|values| pack_embedding(&values))
}