distx_schema/
schema.rs

1//! Similarity Schema definitions
2//!
3//! Defines the schema for structured similarity reranking.
4//! The schema specifies field types, distance metrics, and weights
5//! used to rerank vector search results.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Similarity schema for structured reranking
11/// 
12/// Defines how payload fields should be compared when reranking
13/// vector search results. Each field has a type, distance metric,
14/// and weight in the final score.
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct SimilaritySchema {
17    /// Schema version for future compatibility
18    #[serde(default = "default_version")]
19    pub version: u32,
20    
21    /// Field configurations keyed by field name
22    pub fields: HashMap<String, FieldConfig>,
23}
24
/// Serde fallback for `SimilaritySchema::version` when the field is missing
/// from the serialized schema.
fn default_version() -> u32 {
    const INITIAL_SCHEMA_VERSION: u32 = 1;
    INITIAL_SCHEMA_VERSION
}
28
29impl SimilaritySchema {
30    /// Create a new similarity schema with the given fields
31    pub fn new(fields: HashMap<String, FieldConfig>) -> Self {
32        Self {
33            version: 1,
34            fields,
35        }
36    }
37
38    /// Validate the schema and normalize weights to sum to 1.0
39    pub fn validate_and_normalize(&mut self) -> Result<(), SchemaError> {
40        if self.fields.is_empty() {
41            return Err(SchemaError::EmptySchema);
42        }
43
44        // Check for negative weights
45        for (name, config) in &self.fields {
46            if config.weight < 0.0 {
47                return Err(SchemaError::NegativeWeight(name.clone()));
48            }
49        }
50
51        // Calculate sum of weights
52        let weight_sum: f32 = self.fields.values().map(|f| f.weight).sum();
53        
54        if weight_sum <= 0.0 {
55            return Err(SchemaError::ZeroTotalWeight);
56        }
57
58        // Normalize weights to sum to 1.0
59        if (weight_sum - 1.0).abs() > 0.001 {
60            for config in self.fields.values_mut() {
61                config.weight /= weight_sum;
62            }
63        }
64
65        Ok(())
66    }
67
68    /// Get field names in a deterministic order (sorted)
69    pub fn sorted_field_names(&self) -> Vec<&String> {
70        let mut names: Vec<_> = self.fields.keys().collect();
71        names.sort();
72        names
73    }
74
75    /// Get a field config by name
76    pub fn get_field(&self, name: &str) -> Option<&FieldConfig> {
77        self.fields.get(name)
78    }
79}
80
81/// Configuration for a single field in the similarity schema
82#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub struct FieldConfig {
84    /// The type of the field (text, number, categorical, boolean)
85    #[serde(rename = "type")]
86    pub field_type: FieldType,
87    
88    /// The distance/similarity metric to use for this field
89    #[serde(default)]
90    pub distance: DistanceType,
91    
92    /// Weight of this field in the overall similarity score (0.0 to 1.0)
93    #[serde(default = "default_weight")]
94    pub weight: f32,
95}
96
/// Serde fallback for `FieldConfig::weight` when the field is missing from
/// the serialized config.
fn default_weight() -> f32 {
    const UNIT_WEIGHT: f32 = 1.0;
    UNIT_WEIGHT
}
100
101impl FieldConfig {
102    /// Create a text field config
103    /// 
104    /// Text fields are compared using trigram similarity during reranking.
105    /// For semantic search, use client-side embeddings + vector search.
106    pub fn text(weight: f32) -> Self {
107        Self {
108            field_type: FieldType::Text,
109            distance: DistanceType::Semantic,
110            weight,
111        }
112    }
113
114    /// Create a number field config
115    /// 
116    /// Use Relative for comparing values of different magnitudes (e.g., prices).
117    /// Use Absolute for values in the same range (e.g., ratings 1-5).
118    pub fn number(weight: f32, distance: DistanceType) -> Self {
119        Self {
120            field_type: FieldType::Number,
121            distance,
122            weight,
123        }
124    }
125
126    /// Create a categorical field config
127    /// 
128    /// Categorical fields use exact match by default.
129    /// Use Overlap for multi-value categories.
130    pub fn categorical(weight: f32) -> Self {
131        Self {
132            field_type: FieldType::Categorical,
133            distance: DistanceType::Exact,
134            weight,
135        }
136    }
137
138    /// Create a boolean field config
139    pub fn boolean(weight: f32) -> Self {
140        Self {
141            field_type: FieldType::Boolean,
142            distance: DistanceType::Exact,
143            weight,
144        }
145    }
146}
147
148/// Field type enumeration
149#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
150#[serde(rename_all = "lowercase")]
151pub enum FieldType {
152    /// Text field - uses trigram or exact matching for reranking
153    Text,
154    /// Numeric field - uses absolute or relative distance
155    Number,
156    /// Categorical field - uses exact or overlap matching
157    Categorical,
158    /// Boolean field - uses exact matching
159    Boolean,
160}
161
162/// Distance/similarity type for fields
163#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
164#[serde(rename_all = "lowercase")]
165pub enum DistanceType {
166    /// Trigram similarity for text fields
167    #[default]
168    Semantic,
169    /// Absolute distance: exp(-|a - b| / scale)
170    Absolute,
171    /// Relative distance: 1 - |a - b| / max(|a|, |b|)
172    Relative,
173    /// Exact match: 1 if equal, 0 otherwise
174    Exact,
175    /// Overlap/Jaccard similarity for token sets
176    Overlap,
177}
178
179/// Errors that can occur during schema validation
180#[derive(Debug, Clone, thiserror::Error)]
181pub enum SchemaError {
182    #[error("Schema cannot be empty")]
183    EmptySchema,
184    
185    #[error("Field '{0}' has negative weight")]
186    NegativeWeight(String),
187    
188    #[error("Total weight cannot be zero")]
189    ZeroTotalWeight,
190    
191    #[error("Field '{0}' not found in schema")]
192    FieldNotFound(String),
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a schema from `(name, config)` pairs to keep the tests short.
    fn schema_from(entries: Vec<(&str, FieldConfig)>) -> SimilaritySchema {
        SimilaritySchema::new(
            entries
                .into_iter()
                .map(|(name, config)| (name.to_string(), config))
                .collect(),
        )
    }

    #[test]
    fn test_schema_creation() {
        let schema = schema_from(vec![
            ("name", FieldConfig::text(0.5)),
            ("price", FieldConfig::number(0.3, DistanceType::Relative)),
            ("category", FieldConfig::categorical(0.2)),
        ]);

        assert_eq!(schema.version, 1);
        assert_eq!(schema.fields.len(), 3);
    }

    #[test]
    fn test_schema_normalization() {
        let mut schema = schema_from(vec![
            ("a", FieldConfig::text(2.0)),
            ("b", FieldConfig::number(2.0, DistanceType::Absolute)),
        ]);
        schema.validate_and_normalize().unwrap();

        let weight_sum: f32 = schema.fields.values().map(|f| f.weight).sum();
        assert!((weight_sum - 1.0).abs() < 0.001);

        // Equal raw weights must end up as equal shares, not merely a unit sum.
        for name in ["a", "b"] {
            let weight = schema.get_field(name).unwrap().weight;
            assert!((weight - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_already_normalized_weights_are_unchanged() {
        let mut schema = schema_from(vec![
            ("a", FieldConfig::text(0.5)),
            ("b", FieldConfig::boolean(0.5)),
        ]);
        let before = schema.clone();

        schema.validate_and_normalize().unwrap();
        assert_eq!(schema, before);
    }

    #[test]
    fn test_empty_schema_error() {
        let mut schema = SimilaritySchema::new(HashMap::new());
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::EmptySchema)
        ));
    }

    #[test]
    fn test_negative_weight_error() {
        let mut schema = schema_from(vec![
            ("bad", FieldConfig::text(-0.5)),
            ("ok", FieldConfig::categorical(1.0)),
        ]);
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::NegativeWeight(name)) if name == "bad"
        ));
    }

    #[test]
    fn test_zero_total_weight_error() {
        let mut schema = schema_from(vec![
            ("a", FieldConfig::text(0.0)),
            ("b", FieldConfig::boolean(0.0)),
        ]);
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::ZeroTotalWeight)
        ));
    }

    #[test]
    fn test_sorted_field_names() {
        let schema = schema_from(vec![
            ("price", FieldConfig::number(1.0, DistanceType::Relative)),
            ("category", FieldConfig::categorical(1.0)),
            ("name", FieldConfig::text(1.0)),
        ]);

        let names: Vec<&str> = schema
            .sorted_field_names()
            .into_iter()
            .map(|name| name.as_str())
            .collect();
        assert_eq!(names, ["category", "name", "price"]);
    }

    #[test]
    fn test_get_field() {
        let schema = schema_from(vec![("name", FieldConfig::text(1.0))]);

        assert_eq!(schema.get_field("name"), Some(&FieldConfig::text(1.0)));
        assert!(schema.get_field("missing").is_none());
    }

    #[test]
    fn test_serde_roundtrip() {
        let schema = schema_from(vec![
            ("name", FieldConfig::text(0.5)),
            ("price", FieldConfig::number(0.5, DistanceType::Relative)),
        ]);

        let json = serde_json::to_string(&schema).unwrap();
        let parsed: SimilaritySchema = serde_json::from_str(&json).unwrap();

        assert_eq!(schema, parsed);
    }

    #[test]
    fn test_serde_defaults() {
        // Only the mandatory keys are present; version, distance and weight
        // must come from their serde defaults.
        let json = r#"{"fields":{"title":{"type":"text"}}}"#;
        let parsed: SimilaritySchema = serde_json::from_str(json).unwrap();

        assert_eq!(parsed.version, 1);
        assert_eq!(parsed.get_field("title"), Some(&FieldConfig::text(1.0)));
    }
}