1use crate::schema::{SimilaritySchema, FieldType};
7use crate::distance::{hash_text_to_vector, hash_categorical_to_vector};
8use distx_core::Vector;
9use serde_json::Value;
10
/// Components emitted per `Text` field when the embedder is created without an
/// explicit text dimension (`StructuredEmbedder::new`, `EmbedderBuilder::new`).
pub const DEFAULT_TEXT_DIM: usize = 64;

/// Components emitted per `Categorical` field when the embedder is created without an
/// explicit categorical dimension (`StructuredEmbedder::new`, `EmbedderBuilder::new`).
pub const DEFAULT_CATEGORICAL_DIM: usize = 64;
16
17#[derive(Debug, Clone)]
19pub struct StructuredEmbedder {
20 schema: SimilaritySchema,
21 text_dim: usize,
22 categorical_dim: usize,
23}
24
25impl StructuredEmbedder {
26 pub fn new(schema: SimilaritySchema) -> Self {
28 Self {
29 schema,
30 text_dim: DEFAULT_TEXT_DIM,
31 categorical_dim: DEFAULT_CATEGORICAL_DIM,
32 }
33 }
34
35 pub fn with_dimensions(schema: SimilaritySchema, text_dim: usize, categorical_dim: usize) -> Self {
37 Self {
38 schema,
39 text_dim,
40 categorical_dim,
41 }
42 }
43
44 pub fn vector_dim(&self) -> usize {
46 use crate::schema::FieldType;
47 self.schema.fields.values().map(|config| {
48 match config.field_type {
49 FieldType::Text => self.text_dim,
50 FieldType::Number => 1,
51 FieldType::Categorical => self.categorical_dim,
52 FieldType::Boolean => 1,
53 }
54 }).sum()
55 }
56
57 pub fn schema(&self) -> &SimilaritySchema {
59 &self.schema
60 }
61
62 pub fn embed(&self, payload: &Value) -> Vector {
71 let mut components: Vec<f32> = Vec::with_capacity(self.vector_dim());
72
73 for field_name in self.schema.sorted_field_names() {
75 let config = self.schema.get_field(field_name).unwrap();
76 let weight_sqrt = config.weight.sqrt(); let field_vector = self.embed_field(payload, field_name, config);
79
80 components.extend(field_vector.iter().map(|v| v * weight_sqrt));
82 }
83
84 let mut vector = Vector::new(components);
86 vector.normalize();
87 vector
88 }
89
90 fn embed_field(&self, payload: &Value, field_name: &str, config: &crate::schema::FieldConfig) -> Vec<f32> {
92 let value = payload.get(field_name);
93
94 match config.field_type {
95 FieldType::Text => self.embed_text(value),
96 FieldType::Number => self.embed_number(value),
97 FieldType::Categorical => self.embed_categorical(value),
98 FieldType::Boolean => self.embed_boolean(value),
99 }
100 }
101
102 fn embed_text(&self, value: Option<&Value>) -> Vec<f32> {
104 match value.and_then(|v| v.as_str()) {
105 Some(text) => hash_text_to_vector(text, self.text_dim),
106 None => vec![0.0; self.text_dim], }
108 }
109
110 fn embed_number(&self, value: Option<&Value>) -> Vec<f32> {
112 match value {
113 Some(v) => {
114 let num = v.as_f64().unwrap_or(0.0);
115 let normalized = num.tanh() as f32;
117 vec![normalized]
118 }
119 None => vec![0.0], }
121 }
122
123 fn embed_categorical(&self, value: Option<&Value>) -> Vec<f32> {
125 match value.and_then(|v| v.as_str()) {
126 Some(category) => hash_categorical_to_vector(category, self.categorical_dim),
127 None => vec![0.0; self.categorical_dim],
128 }
129 }
130
131 fn embed_boolean(&self, value: Option<&Value>) -> Vec<f32> {
133 match value.and_then(|v| v.as_bool()) {
134 Some(true) => vec![1.0],
135 Some(false) => vec![-1.0],
136 None => vec![0.0],
137 }
138 }
139}
140
141#[derive(Debug, Clone)]
143pub struct EmbedderBuilder {
144 schema: SimilaritySchema,
145 text_dim: usize,
146 categorical_dim: usize,
147}
148
149impl EmbedderBuilder {
150 pub fn new(schema: SimilaritySchema) -> Self {
151 Self {
152 schema,
153 text_dim: DEFAULT_TEXT_DIM,
154 categorical_dim: DEFAULT_CATEGORICAL_DIM,
155 }
156 }
157
158 pub fn text_dim(mut self, dim: usize) -> Self {
159 self.text_dim = dim;
160 self
161 }
162
163 pub fn categorical_dim(mut self, dim: usize) -> Self {
164 self.categorical_dim = dim;
165 self
166 }
167
168 pub fn build(self) -> StructuredEmbedder {
169 StructuredEmbedder::with_dimensions(self.schema, self.text_dim, self.categorical_dim)
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{DistanceType, FieldConfig};
    use serde_json::json;
    use std::collections::HashMap;

    /// Schema with one field of each hashed/numeric kind used throughout these tests.
    ///
    /// The fields are text `name` (0.5), number `price` (0.3) and categorical
    /// `category` (0.2).
    fn create_test_schema() -> SimilaritySchema {
        let fields = HashMap::from([
            ("name".to_string(), FieldConfig::text(0.5)),
            (
                "price".to_string(),
                FieldConfig::number(0.3, DistanceType::Relative),
            ),
            ("category".to_string(), FieldConfig::categorical(0.2)),
        ]);
        SimilaritySchema::new(fields)
    }

    /// Embedder over [`create_test_schema`] with default dimensions.
    fn default_embedder() -> StructuredEmbedder {
        StructuredEmbedder::new(create_test_schema())
    }

    #[test]
    fn test_embedder_creation() {
        assert_eq!(default_embedder().vector_dim(), 64 + 1 + 64);
    }

    #[test]
    fn test_embed_complete_payload() {
        let embedder = default_embedder();
        let full = json!({
            "name": "Prosciutto cotto",
            "price": 1.99,
            "category": "salumi"
        });

        let embedded = embedder.embed(&full);
        assert_eq!(embedded.dim(), embedder.vector_dim());

        // The output should be (approximately) unit length after normalization.
        let norm = embedded
            .as_slice()
            .iter()
            .map(|x| x * x)
            .sum::<f32>()
            .sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_embed_partial_payload() {
        let embedder = default_embedder();
        let partial = json!({
            "name": "Prosciutto"
        });

        assert_eq!(embedder.embed(&partial).dim(), embedder.vector_dim());
    }

    #[test]
    fn test_same_payload_same_vector() {
        let embedder = default_embedder();
        let product = json!({
            "name": "Product A",
            "price": 10.0,
            "category": "electronics"
        });

        let first = embedder.embed(&product);
        let second = embedder.embed(&product);
        assert_eq!(first.as_slice(), second.as_slice());
    }

    #[test]
    fn test_similar_payloads_close_vectors() {
        let embedder = default_embedder();
        let cooked = json!({
            "name": "Prosciutto cotto",
            "price": 1.99,
            "category": "salumi"
        });
        let cured = json!({
            "name": "Prosciutto crudo",
            "price": 2.49,
            "category": "salumi"
        });

        let similarity = embedder.embed(&cooked).cosine_similarity(&embedder.embed(&cured));
        assert!(similarity > 0.5, "Expected similarity > 0.5, got {}", similarity);
    }

    #[test]
    fn test_different_payloads_different_vectors() {
        let embedder = default_embedder();
        let phone = json!({
            "name": "Apple iPhone",
            "price": 999.0,
            "category": "electronics"
        });
        let fruit = json!({
            "name": "Organic Bananas",
            "price": 1.99,
            "category": "food"
        });

        let similarity = embedder.embed(&phone).cosine_similarity(&embedder.embed(&fruit));
        assert!(similarity < 0.5, "Expected similarity < 0.5, got {}", similarity);
    }

    #[test]
    fn test_builder_pattern() {
        let embedder = EmbedderBuilder::new(create_test_schema())
            .text_dim(128)
            .categorical_dim(32)
            .build();

        assert_eq!(embedder.vector_dim(), 128 + 1 + 32);
    }
}