distx_schema/
embedder.rs

1//! Structured Embedder
2//!
3//! Converts payload fields into a single composite vector based on the similarity schema.
4//! This enables vector search over structured/tabular data without requiring external embeddings.
5
6use crate::schema::{SimilaritySchema, FieldType};
7use crate::distance::{hash_text_to_vector, hash_categorical_to_vector};
8use distx_core::Vector;
9use serde_json::Value;
10
/// Number of vector components allotted to each text field when the caller
/// does not choose a custom text dimension.
pub const DEFAULT_TEXT_DIM: usize = 64;
13
/// Number of vector components allotted to each categorical field when the
/// caller does not choose a custom categorical dimension.
pub const DEFAULT_CATEGORICAL_DIM: usize = 64;
16
17/// Structured embedder that converts payloads to vectors based on schema
18#[derive(Debug, Clone)]
19pub struct StructuredEmbedder {
20    schema: SimilaritySchema,
21    text_dim: usize,
22    categorical_dim: usize,
23}
24
25impl StructuredEmbedder {
26    /// Create a new structured embedder with the given schema
27    pub fn new(schema: SimilaritySchema) -> Self {
28        Self {
29            schema,
30            text_dim: DEFAULT_TEXT_DIM,
31            categorical_dim: DEFAULT_CATEGORICAL_DIM,
32        }
33    }
34
35    /// Create embedder with custom dimensions
36    pub fn with_dimensions(schema: SimilaritySchema, text_dim: usize, categorical_dim: usize) -> Self {
37        Self {
38            schema,
39            text_dim,
40            categorical_dim,
41        }
42    }
43
44    /// Get the total vector dimension for this embedder
45    pub fn vector_dim(&self) -> usize {
46        use crate::schema::FieldType;
47        self.schema.fields.values().map(|config| {
48            match config.field_type {
49                FieldType::Text => self.text_dim,
50                FieldType::Number => 1,
51                FieldType::Categorical => self.categorical_dim,
52                FieldType::Boolean => 1,
53            }
54        }).sum()
55    }
56
57    /// Get a reference to the schema
58    pub fn schema(&self) -> &SimilaritySchema {
59        &self.schema
60    }
61
62    /// Convert a payload to a composite vector
63    /// 
64    /// The vector is constructed by:
65    /// 1. Iterating through schema fields in sorted order
66    /// 2. Extracting values from the payload
67    /// 3. Embedding each field according to its type
68    /// 4. Applying weights
69    /// 5. Concatenating all embeddings
70    pub fn embed(&self, payload: &Value) -> Vector {
71        let mut components: Vec<f32> = Vec::with_capacity(self.vector_dim());
72        
73        // Process fields in sorted order for consistency
74        for field_name in self.schema.sorted_field_names() {
75            let config = self.schema.get_field(field_name).unwrap();
76            let weight_sqrt = config.weight.sqrt(); // Apply sqrt of weight to vector
77            
78            let field_vector = self.embed_field(payload, field_name, config);
79            
80            // Apply weight and extend
81            components.extend(field_vector.iter().map(|v| v * weight_sqrt));
82        }
83        
84        // Create and normalize the vector
85        let mut vector = Vector::new(components);
86        vector.normalize();
87        vector
88    }
89
90    /// Embed a single field from the payload
91    fn embed_field(&self, payload: &Value, field_name: &str, config: &crate::schema::FieldConfig) -> Vec<f32> {
92        let value = payload.get(field_name);
93        
94        match config.field_type {
95            FieldType::Text => self.embed_text(value),
96            FieldType::Number => self.embed_number(value),
97            FieldType::Categorical => self.embed_categorical(value),
98            FieldType::Boolean => self.embed_boolean(value),
99        }
100    }
101
102    /// Embed a text field
103    fn embed_text(&self, value: Option<&Value>) -> Vec<f32> {
104        match value.and_then(|v| v.as_str()) {
105            Some(text) => hash_text_to_vector(text, self.text_dim),
106            None => vec![0.0; self.text_dim], // Missing field gets zero vector
107        }
108    }
109
110    /// Embed a number field
111    fn embed_number(&self, value: Option<&Value>) -> Vec<f32> {
112        match value {
113            Some(v) => {
114                let num = v.as_f64().unwrap_or(0.0);
115                // Normalize using sigmoid-like function for unbounded values
116                let normalized = num.tanh() as f32;
117                vec![normalized]
118            }
119            None => vec![0.0], // Missing field
120        }
121    }
122
123    /// Embed a categorical field
124    fn embed_categorical(&self, value: Option<&Value>) -> Vec<f32> {
125        match value.and_then(|v| v.as_str()) {
126            Some(category) => hash_categorical_to_vector(category, self.categorical_dim),
127            None => vec![0.0; self.categorical_dim],
128        }
129    }
130
131    /// Embed a boolean field
132    fn embed_boolean(&self, value: Option<&Value>) -> Vec<f32> {
133        match value.and_then(|v| v.as_bool()) {
134            Some(true) => vec![1.0],
135            Some(false) => vec![-1.0],
136            None => vec![0.0],
137        }
138    }
139}
140
141/// Builder for creating StructuredEmbedder with custom options
142#[derive(Debug, Clone)]
143pub struct EmbedderBuilder {
144    schema: SimilaritySchema,
145    text_dim: usize,
146    categorical_dim: usize,
147}
148
149impl EmbedderBuilder {
150    pub fn new(schema: SimilaritySchema) -> Self {
151        Self {
152            schema,
153            text_dim: DEFAULT_TEXT_DIM,
154            categorical_dim: DEFAULT_CATEGORICAL_DIM,
155        }
156    }
157
158    pub fn text_dim(mut self, dim: usize) -> Self {
159        self.text_dim = dim;
160        self
161    }
162
163    pub fn categorical_dim(mut self, dim: usize) -> Self {
164        self.categorical_dim = dim;
165        self
166    }
167
168    pub fn build(self) -> StructuredEmbedder {
169        StructuredEmbedder::with_dimensions(self.schema, self.text_dim, self.categorical_dim)
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{DistanceType, FieldConfig};
    use serde_json::json;
    use std::collections::HashMap;

    /// Schema with one text, one number and one categorical field, shared by
    /// every test in this module.
    fn create_test_schema() -> SimilaritySchema {
        let fields: HashMap<String, FieldConfig> = vec![
            ("name".to_string(), FieldConfig::text(0.5)),
            ("price".to_string(), FieldConfig::number(0.3, DistanceType::Relative)),
            ("category".to_string(), FieldConfig::categorical(0.2)),
        ]
        .into_iter()
        .collect();
        SimilaritySchema::new(fields)
    }

    /// Embedder over the shared test schema with default dimensions.
    fn default_embedder() -> StructuredEmbedder {
        StructuredEmbedder::new(create_test_schema())
    }

    #[test]
    fn test_embedder_creation() {
        // text(64) + number(1) + categorical(64) = 129 components
        assert_eq!(default_embedder().vector_dim(), 64 + 1 + 64);
    }

    #[test]
    fn test_embed_complete_payload() {
        let embedder = default_embedder();

        let embedded = embedder.embed(&json!({
            "name": "Prosciutto cotto",
            "price": 1.99,
            "category": "salumi"
        }));
        assert_eq!(embedded.dim(), embedder.vector_dim());

        // The result must have unit length.
        let squared_norm: f32 = embedded.as_slice().iter().map(|c| c * c).sum();
        let magnitude = squared_norm.sqrt();
        assert!((magnitude - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_embed_partial_payload() {
        let embedder = default_embedder();

        // Only "name" is supplied; price and category are absent.
        let embedded = embedder.embed(&json!({ "name": "Prosciutto" }));
        assert_eq!(embedded.dim(), embedder.vector_dim());
    }

    #[test]
    fn test_same_payload_same_vector() {
        let embedder = default_embedder();

        let payload = json!({
            "name": "Product A",
            "price": 10.0,
            "category": "electronics"
        });

        // Embedding is deterministic: two runs give identical components.
        let first = embedder.embed(&payload);
        let second = embedder.embed(&payload);
        assert_eq!(first.as_slice(), second.as_slice());
    }

    #[test]
    fn test_similar_payloads_close_vectors() {
        let embedder = default_embedder();

        let cotto = embedder.embed(&json!({
            "name": "Prosciutto cotto",
            "price": 1.99,
            "category": "salumi"
        }));
        let crudo = embedder.embed(&json!({
            "name": "Prosciutto crudo",
            "price": 2.49,
            "category": "salumi"
        }));

        // Near-identical products are expected to score high.
        let similarity = cotto.cosine_similarity(&crudo);
        assert!(similarity > 0.5, "Expected similarity > 0.5, got {}", similarity);
    }

    #[test]
    fn test_different_payloads_different_vectors() {
        let embedder = default_embedder();

        let phone = embedder.embed(&json!({
            "name": "Apple iPhone",
            "price": 999.0,
            "category": "electronics"
        }));
        let bananas = embedder.embed(&json!({
            "name": "Organic Bananas",
            "price": 1.99,
            "category": "food"
        }));

        // Unrelated products are expected to score low.
        let similarity = phone.cosine_similarity(&bananas);
        assert!(similarity < 0.5, "Expected similarity < 0.5, got {}", similarity);
    }

    #[test]
    fn test_builder_pattern() {
        let embedder = EmbedderBuilder::new(create_test_schema())
            .text_dim(128)
            .categorical_dim(32)
            .build();

        // text(128) + number(1) + categorical(32) = 161 components
        assert_eq!(embedder.vector_dim(), 128 + 1 + 32);
    }
}