1use crate::errors::Result;
2use crate::utils::{content_hash, pack_embedding};
3
4pub trait EmbeddingProvider: Send + Sync {
6 fn model_name(&self) -> &'static str {
7 "custom"
8 }
9 fn content_dim(&self) -> usize;
10 fn trigger_dim(&self) -> usize;
11 fn embed_content(&self, text: &str) -> Result<Vec<f32>>;
12 fn embed_trigger(&self, text: &str) -> Result<Vec<f32>>;
13}
14
/// Embedding provider that derives its vectors from a hash of the input text
/// rather than from a real model.
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so instances can be
/// logged, duplicated and compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DummyEmbeddingProvider {
    /// Length of the vectors produced by `embed_content`.
    content_dim: usize,
    /// Length of the vectors produced by `embed_trigger`.
    trigger_dim: usize,
}
20
21impl DummyEmbeddingProvider {
22 pub fn new(content_dim: usize, trigger_dim: usize) -> Self {
23 Self {
24 content_dim,
25 trigger_dim,
26 }
27 }
28}
29
30impl Default for DummyEmbeddingProvider {
31 fn default() -> Self {
32 Self::new(1024, 256)
33 }
34}
35
36impl EmbeddingProvider for DummyEmbeddingProvider {
37 fn model_name(&self) -> &'static str {
38 "DummyEmbeddingProvider"
39 }
40
41 fn content_dim(&self) -> usize {
42 self.content_dim
43 }
44 fn trigger_dim(&self) -> usize {
45 self.trigger_dim
46 }
47
48 fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
49 Ok(hash_to_vec(text, self.content_dim))
50 }
51
52 fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
53 Ok(hash_to_vec(text, self.trigger_dim))
54 }
55}
56
57fn hash_to_vec(text: &str, dim: usize) -> Vec<f32> {
58 let h = content_hash(text);
59 let bytes = h.as_bytes();
60 let mut v: Vec<f32> = (0..dim)
61 .map(|i| {
62 let b = bytes[i % bytes.len()] as f32;
63 (b / 255.0) * 2.0 - 1.0
64 })
65 .collect();
66 let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
68 if norm > 0.0 {
69 for x in &mut v {
70 *x /= norm;
71 }
72 }
73 v
74}
75
76pub fn embed_to_bytes(
78 provider: &dyn EmbeddingProvider,
79 text: &str,
80 trigger: bool,
81) -> Result<Vec<u8>> {
82 let vec = if trigger {
83 provider.embed_trigger(text)?
84 } else {
85 provider.embed_content(text)?
86 };
87 Ok(pack_embedding(&vec))
88}