use crate::core::error::Result;
use crate::core::types::{EmbeddingResult, Sector};
use crate::memory::embed::EmbeddingProvider;
use crate::utils::text::{add_synonym_tokens, canonical_tokens_from_text};
use async_trait::async_trait;
use std::collections::HashMap;
/// Deterministic, hash-based embedding provider.
///
/// Produces embeddings purely from feature hashing of the input text, so it
/// needs no model weights or network access — presumably intended for tests
/// and offline configurations (TODO confirm against call sites).
pub struct SyntheticProvider {
    // Number of components in every vector this provider emits.
    dim: usize,
}

impl SyntheticProvider {
    /// Build a provider whose vectors have `dim` components.
    pub fn new(dim: usize) -> Self {
        SyntheticProvider { dim }
    }

    /// Compute the synthetic embedding of `text` for the given `sector`.
    pub fn generate(&self, text: &str, sector: &Sector) -> Vec<f32> {
        gen_synthetic_embedding(text, sector, self.dim)
    }
}
#[async_trait]
impl EmbeddingProvider for SyntheticProvider {
    /// Embed a single text. Purely deterministic CPU work; never fails.
    async fn embed(&self, text: &str, sector: &Sector) -> Result<EmbeddingResult> {
        let vector = self.generate(text, sector);
        // Read the length before moving the vector into the result, instead
        // of cloning the whole Vec just to call `.len()` afterwards.
        let dim = vector.len();
        Ok(EmbeddingResult {
            sector: *sector,
            vector,
            dim,
        })
    }

    /// Dimensionality of every vector this provider returns.
    fn dimensions(&self) -> usize {
        self.dim
    }

    fn name(&self) -> &'static str {
        "synthetic"
    }

    /// Batching is always supported: each item is independent CPU work.
    fn supports_batch(&self) -> bool {
        true
    }

    /// Embed many `(text, sector)` pairs; output order matches input order.
    async fn embed_batch(&self, texts: &[(&str, &Sector)]) -> Result<Vec<EmbeddingResult>> {
        let results = texts
            .iter()
            .map(|(text, sector)| {
                let vector = self.generate(text, sector);
                // Same as `embed`: take the length, then move — no clone.
                let dim = vector.len();
                EmbeddingResult {
                    sector: **sector,
                    vector,
                    dim,
                }
            })
            .collect();
        Ok(results)
    }
}
/// Per-sector scaling factor applied to every feature weight, biasing how
/// strongly a sector's features register in the final embedding.
fn sector_weight(sector: &Sector) -> f32 {
    // Arms ordered by ascending weight for easy scanning.
    match sector {
        Sector::Reflective => 0.9,
        Sector::Semantic => 1.0,
        Sector::Procedural => 1.2,
        Sector::Episodic => 1.3,
        Sector::Emotional => 1.4,
    }
}
/// 32-bit FNV-1a hash of `s` (standard offset basis and prime).
fn hash_fnv1a(s: &str) -> u32 {
    const FNV_OFFSET_BASIS: u32 = 0x811c9dc5;
    const FNV_PRIME: u32 = 16777619;
    s.bytes()
        .fold(FNV_OFFSET_BASIS, |h, b| (h ^ u32::from(b)).wrapping_mul(FNV_PRIME))
}
/// Seeded secondary hash (MurmurHash2-style multiply + shift-xor mixing),
/// used to derive a second, independent slot index for feature hashing.
fn hash_secondary(s: &str, seed: u32) -> u32 {
    s.bytes().fold(seed, |h, b| {
        let mixed = (h ^ u32::from(b)).wrapping_mul(0x5bd1e995);
        (mixed >> 13) ^ mixed
    })
}
/// Hash `key` into two slots of `vec` (feature hashing with a secondary
/// probe at half weight), so two features rarely collide in both slots.
fn add_feature(vec: &mut [f32], key: &str, weight: f32) {
    let dim = vec.len();
    if dim == 0 {
        return;
    }
    let h1 = hash_fnv1a(key);
    let h2 = hash_secondary(key, 0xdeadbeef);
    // Low bit of h1 gives each feature a pseudo-random sign, so unrelated
    // features cancel rather than all accumulating in the same direction.
    let sign = if h1 & 1 == 0 { 1.0 } else { -1.0 };
    let val = weight * sign;
    // For power-of-two dimensions a mask replaces the slower modulo.
    // (The previous `dim > 0 &&` guard here was dead: dim == 0 already
    // returned early above.)
    if dim & (dim - 1) == 0 {
        vec[(h1 as usize) & (dim - 1)] += val;
        vec[(h2 as usize) & (dim - 1)] += val * 0.5;
    } else {
        vec[(h1 as usize) % dim] += val;
        vec[(h2 as usize) % dim] += val * 0.5;
    }
}
/// Add a sinusoidal positional signal for token position `pos` into two
/// adjacent slots of `vec` (sin into one, cos into the next, wrapping).
fn add_positional_feature(vec: &mut [f32], pos: usize, weight: f32) {
    if vec.is_empty() {
        return;
    }
    let dim = vec.len();
    let idx = pos % dim;
    // Transformer-style frequency scaling: lower slots oscillate faster.
    let denom = 10000.0_f32.powf((2 * idx) as f32 / dim as f32);
    let angle = pos as f32 / denom;
    vec[idx] += weight * angle.sin();
    vec[(idx + 1) % dim] += weight * angle.cos();
}
/// Scale `vec` in place to unit L2 norm; near-zero vectors are left as-is
/// to avoid dividing by (almost) zero.
fn normalize(vec: &mut [f32]) {
    let magnitude = vec.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();
    if magnitude > 1e-10 {
        let scale = magnitude.recip();
        vec.iter_mut().for_each(|x| *x *= scale);
    }
}
/// Build a deterministic, L2-normalized synthetic embedding of `text` for
/// `sector`.
///
/// Features hashed into the vector: unigrams with a TF-IDF-like weight,
/// character 3-/4-grams, position-weighted bigrams, trigrams, skip-grams,
/// sinusoidal positional signals, and length/density bucket features.
/// Empty or token-free input yields a uniform unit vector.
pub fn gen_synthetic_embedding(text: &str, sector: &Sector, dim: usize) -> Vec<f32> {
    let mut vec = vec![0.0f32; dim];
    let canonical_tokens = canonical_tokens_from_text(text);
    if canonical_tokens.is_empty() {
        // No tokens: return a uniform vector that already has unit L2 norm.
        let default_val = 1.0 / (dim as f32).sqrt();
        return vec![default_val; dim];
    }
    // Synonym expansion broadens term overlap between related texts.
    let expanded_tokens: Vec<String> = add_synonym_tokens(
        canonical_tokens.iter().map(|s| s.as_str()),
    )
    .into_iter()
    .collect();
    let mut term_freq: HashMap<&str, usize> = HashMap::new();
    for tok in &expanded_tokens {
        *term_freq.entry(tok.as_str()).or_insert(0) += 1;
    }
    // std HashMap iteration order is randomized per process, and f32
    // addition is not associative — iterating the map directly made the
    // resulting embedding differ slightly from run to run. Sort the terms
    // so accumulation order (and thus the output) is fully deterministic.
    let mut terms: Vec<(&str, usize)> = term_freq.iter().map(|(&t, &c)| (t, c)).collect();
    terms.sort_unstable();
    let sector_w = sector_weight(sector);
    let doc_length = expanded_tokens.len() as f32;
    let doc_log = (1.0 + doc_length).ln();
    let sector_str = sector.as_str();
    for &(tok, count) in &terms {
        let tf = count as f32 / doc_length;
        // Smoothed, single-document IDF surrogate: rarer terms in this
        // document get a larger boost.
        let idf = (1.0 + doc_length / count as f32).ln();
        let weight = (tf * idf + 1.0) * sector_w;
        add_feature(&mut vec, &format!("{}|tok|{}", sector_str, tok), weight);
        // Character n-grams give partial credit for morphological overlap.
        // Collect the chars once — the previous code rebuilt this Vec a
        // second time for the 4-gram pass.
        if tok.len() >= 3 {
            let chars: Vec<char> = tok.chars().collect();
            for i in 0..chars.len().saturating_sub(2) {
                let trigram: String = chars[i..i + 3].iter().collect();
                add_feature(
                    &mut vec,
                    &format!("{}|c3|{}", sector_str, trigram),
                    weight * 0.4,
                );
            }
            // len() >= 4 implies len() >= 3, so nesting preserves behavior.
            if tok.len() >= 4 {
                for i in 0..chars.len().saturating_sub(3) {
                    let fourgram: String = chars[i..i + 4].iter().collect();
                    add_feature(
                        &mut vec,
                        &format!("{}|c4|{}", sector_str, fourgram),
                        weight * 0.3,
                    );
                }
            }
        }
    }
    // Bigrams over the original (un-expanded) token order, weighted to favor
    // pairs near the start of the text.
    for i in 0..canonical_tokens.len().saturating_sub(1) {
        let a = &canonical_tokens[i];
        let b = &canonical_tokens[i + 1];
        let position_weight = 1.0 / (1.0 + i as f32 * 0.1);
        add_feature(
            &mut vec,
            &format!("{}|bi|{}_{}", sector_str, a, b),
            1.4 * sector_w * position_weight,
        );
    }
    // Token trigrams capture short phrase structure.
    for i in 0..canonical_tokens.len().saturating_sub(2) {
        let a = &canonical_tokens[i];
        let b = &canonical_tokens[i + 1];
        let c = &canonical_tokens[i + 2];
        add_feature(
            &mut vec,
            &format!("{}|tri|{}_{}_{}", sector_str, a, b, c),
            1.0 * sector_w,
        );
    }
    // Skip-grams (token i with token i+2), capped at the first 20 positions.
    for i in 0..canonical_tokens.len().saturating_sub(2).min(20) {
        let a = &canonical_tokens[i];
        let c = &canonical_tokens[i + 2];
        add_feature(
            &mut vec,
            &format!("{}|skip|{}_{}", sector_str, a, c),
            0.7 * sector_w,
        );
    }
    // Sinusoidal positional features for the first 50 tokens, damped by
    // document length so long texts aren't dominated by position.
    for i in 0..canonical_tokens.len().min(50) {
        add_positional_feature(&mut vec, i, 0.5 * sector_w / doc_log);
    }
    // Coarse document-length bucket (log2-spaced, capped at 10).
    let length_bucket = ((doc_length + 1.0).log2() as usize).min(10);
    add_feature(
        &mut vec,
        &format!("{}|len|{}", sector_str, length_bucket),
        0.6 * sector_w,
    );
    // Lexical-density bucket: ratio of unique terms to total tokens.
    let density = term_freq.len() as f32 / doc_length;
    let density_bucket = (density * 10.0) as usize;
    add_feature(
        &mut vec,
        &format!("{}|dens|{}", sector_str, density_bucket),
        0.5 * sector_w,
    );
    normalize(&mut vec);
    vec
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dot product of two equal-length embedding vectors.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn test_synthetic_provider() {
        let provider = SyntheticProvider::new(256);
        assert_eq!(provider.dimensions(), 256);
        assert_eq!(provider.name(), "synthetic");
    }

    #[test]
    fn test_gen_embedding() {
        let emb = gen_synthetic_embedding("Hello world", &Sector::Semantic, 256);
        assert_eq!(emb.len(), 256);
        // Output must be L2-normalized.
        let norm = dot(&emb, &emb).sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_text() {
        // Token-free input falls back to a uniform vector.
        let emb = gen_synthetic_embedding("", &Sector::Semantic, 256);
        assert_eq!(emb.len(), 256);
        let first = emb[0];
        assert!(emb.iter().all(|&v| (v - first).abs() < 1e-6));
    }

    #[test]
    fn test_different_sectors() {
        // The same text embedded under different sectors should not be
        // (near-)identical.
        let text = "This is a test sentence.";
        let episodic = gen_synthetic_embedding(text, &Sector::Episodic, 256);
        let semantic = gen_synthetic_embedding(text, &Sector::Semantic, 256);
        assert!(dot(&episodic, &semantic) < 0.99);
    }

    #[test]
    fn test_similar_texts() {
        let e1 = gen_synthetic_embedding("I love programming", &Sector::Semantic, 256);
        let e2 = gen_synthetic_embedding("I enjoy coding", &Sector::Semantic, 256);
        let e3 = gen_synthetic_embedding("The weather is nice", &Sector::Semantic, 256);
        let sim_12 = dot(&e1, &e2);
        let sim_13 = dot(&e1, &e3);
        assert!(sim_12.abs() > 0.0 || sim_13.abs() > 0.0);
    }

    #[test]
    fn test_hash_functions() {
        // Deterministic for equal input, distinct for different input.
        assert_eq!(hash_fnv1a("test"), hash_fnv1a("test"));
        assert_ne!(hash_fnv1a("test"), hash_fnv1a("different"));
    }
}