1use crate::errors::Result;
2use crate::utils::{content_hash, pack_embedding};
3
4pub trait EmbeddingProvider: Send + Sync {
6 fn model_name(&self) -> &'static str {
7 "custom"
8 }
9 fn content_dim(&self) -> usize;
10 fn trigger_dim(&self) -> usize;
11 fn embed_content(&self, text: &str) -> Result<Vec<f32>>;
12 fn embed_trigger(&self, text: &str) -> Result<Vec<f32>>;
13
14 fn embed_both(&self, text: &str) -> Result<(Vec<f32>, Vec<f32>)> {
19 Ok((self.embed_content(text)?, self.embed_trigger(text)?))
20 }
21}
22
/// A model-free embedding provider with configurable output dimensions.
///
/// It stores only the lengths of the content and trigger vectors it should
/// produce. It derives the common value traits so it can be logged (`Debug`),
/// duplicated (`Clone`) and compared (`PartialEq`/`Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DummyEmbeddingProvider {
    /// Number of components in content embeddings.
    content_dim: usize,
    /// Number of components in trigger embeddings.
    trigger_dim: usize,
}
28
29impl DummyEmbeddingProvider {
30 pub fn new(content_dim: usize, trigger_dim: usize) -> Self {
31 Self {
32 content_dim,
33 trigger_dim,
34 }
35 }
36}
37
38impl Default for DummyEmbeddingProvider {
39 fn default() -> Self {
40 Self::new(1024, 256)
41 }
42}
43
44impl EmbeddingProvider for DummyEmbeddingProvider {
45 fn model_name(&self) -> &'static str {
46 "DummyEmbeddingProvider"
47 }
48
49 fn content_dim(&self) -> usize {
50 self.content_dim
51 }
52 fn trigger_dim(&self) -> usize {
53 self.trigger_dim
54 }
55
56 fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
57 Ok(hash_to_vec(text, self.content_dim))
58 }
59
60 fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
61 Ok(hash_to_vec(text, self.trigger_dim))
62 }
63}
64
65fn hash_to_vec(text: &str, dim: usize) -> Vec<f32> {
66 let h = content_hash(text);
67 let bytes = h.as_bytes();
68 let mut v: Vec<f32> = (0..dim)
69 .map(|i| {
70 let b = bytes[i % bytes.len()] as f32;
71 (b / 255.0) * 2.0 - 1.0
72 })
73 .collect();
74 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
76 if norm > 0.0 {
77 for x in &mut v {
78 *x /= norm;
79 }
80 }
81 v
82}
83
84pub fn embed_to_bytes(
86 provider: &dyn EmbeddingProvider,
87 text: &str,
88 trigger: bool,
89) -> Result<Vec<u8>> {
90 let vec = if trigger {
91 provider.embed_trigger(text)?
92 } else {
93 provider.embed_content(text)?
94 };
95 Ok(pack_embedding(&vec))
96}