distx_schema/
schema.rs

//! Similarity schema definitions.
//!
//! Defines the declarative schema for structured similarity queries.
//! The schema specifies which fields matter for similarity, which kind of
//! similarity to use for each field, and how much weight each field carries.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Similarity schema version 1
11/// 
12/// A declarative schema that defines how to compute similarity for tabular rows.
13/// Stored at collection level and used for both embedding and reranking.
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15pub struct SimilaritySchema {
16    /// Schema version for future compatibility
17    #[serde(default = "default_version")]
18    pub version: u32,
19    
20    /// Field configurations keyed by field name
21    pub fields: HashMap<String, FieldConfig>,
22}
23
/// Serde default for `SimilaritySchema::version`: the current schema version.
fn default_version() -> u32 {
    const CURRENT_SCHEMA_VERSION: u32 = 1;
    CURRENT_SCHEMA_VERSION
}
27
28impl SimilaritySchema {
29    /// Create a new similarity schema with the given fields
30    pub fn new(fields: HashMap<String, FieldConfig>) -> Self {
31        Self {
32            version: 1,
33            fields,
34        }
35    }
36
37    /// Validate the schema
38    /// - Checks that weights are positive
39    /// - Normalizes weights to sum to 1.0 if they don't
40    pub fn validate_and_normalize(&mut self) -> Result<(), SchemaError> {
41        if self.fields.is_empty() {
42            return Err(SchemaError::EmptySchema);
43        }
44
45        // Check for negative weights
46        for (name, config) in &self.fields {
47            if config.weight < 0.0 {
48                return Err(SchemaError::NegativeWeight(name.clone()));
49            }
50        }
51
52        // Calculate sum of weights
53        let weight_sum: f32 = self.fields.values().map(|f| f.weight).sum();
54        
55        if weight_sum <= 0.0 {
56            return Err(SchemaError::ZeroTotalWeight);
57        }
58
59        // Normalize weights to sum to 1.0
60        if (weight_sum - 1.0).abs() > 0.001 {
61            for config in self.fields.values_mut() {
62                config.weight /= weight_sum;
63            }
64        }
65
66        Ok(())
67    }
68
69    /// Get the total number of dimensions needed for the composite vector
70    /// Each field type contributes a fixed number of dimensions
71    pub fn compute_vector_dim(&self, text_dim: usize) -> usize {
72        self.fields.values().map(|config| {
73            match config.field_type {
74                FieldType::Text => text_dim,
75                FieldType::Number => 1,
76                FieldType::Categorical => 64, // Hash-based encoding
77                FieldType::Boolean => 1,
78            }
79        }).sum()
80    }
81
82    /// Get field names in a deterministic order (sorted)
83    pub fn sorted_field_names(&self) -> Vec<&String> {
84        let mut names: Vec<_> = self.fields.keys().collect();
85        names.sort();
86        names
87    }
88
89    /// Get a field config by name
90    pub fn get_field(&self, name: &str) -> Option<&FieldConfig> {
91        self.fields.get(name)
92    }
93}
94
95/// Configuration for a single field in the similarity schema
96#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
97pub struct FieldConfig {
98    /// The type of the field (text, number, categorical, boolean)
99    #[serde(rename = "type")]
100    pub field_type: FieldType,
101    
102    /// The distance/similarity metric to use for this field
103    #[serde(default)]
104    pub distance: DistanceType,
105    
106    /// Weight of this field in the overall similarity score (0.0 to 1.0)
107    #[serde(default = "default_weight")]
108    pub weight: f32,
109    
110    /// Embedding type for text fields (optional)
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub embedding: Option<EmbeddingType>,
113}
114
/// Serde default for `FieldConfig::weight`: a full, un-normalised weight.
fn default_weight() -> f32 {
    const FULL_WEIGHT: f32 = 1.0;
    FULL_WEIGHT
}
118
119impl FieldConfig {
120    /// Create a new text field configuration
121    pub fn text(weight: f32) -> Self {
122        Self {
123            field_type: FieldType::Text,
124            distance: DistanceType::Semantic,
125            weight,
126            embedding: Some(EmbeddingType::Semantic),
127        }
128    }
129
130    /// Create a new number field configuration
131    pub fn number(weight: f32, distance: DistanceType) -> Self {
132        Self {
133            field_type: FieldType::Number,
134            distance,
135            weight,
136            embedding: None,
137        }
138    }
139
140    /// Create a new categorical field configuration
141    pub fn categorical(weight: f32) -> Self {
142        Self {
143            field_type: FieldType::Categorical,
144            distance: DistanceType::Exact,
145            weight,
146            embedding: None,
147        }
148    }
149
150    /// Create a new boolean field configuration
151    pub fn boolean(weight: f32) -> Self {
152        Self {
153            field_type: FieldType::Boolean,
154            distance: DistanceType::Exact,
155            weight,
156            embedding: None,
157        }
158    }
159}
160
161/// Field type enumeration
162#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
163#[serde(rename_all = "lowercase")]
164pub enum FieldType {
165    /// Text field - uses semantic or exact matching
166    Text,
167    /// Numeric field - uses absolute or relative distance
168    Number,
169    /// Categorical field - uses exact or overlap matching
170    Categorical,
171    /// Boolean field - uses exact matching
172    Boolean,
173}
174
175/// Distance/similarity type for fields
176#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
177#[serde(rename_all = "lowercase")]
178pub enum DistanceType {
179    /// Semantic similarity (for text) - uses embeddings
180    #[default]
181    Semantic,
182    /// Absolute distance: 1 - |a - b| / max_range
183    Absolute,
184    /// Relative distance: 1 - |a - b| / max(|a|, |b|)
185    Relative,
186    /// Exact match: 1 if equal, 0 otherwise
187    Exact,
188    /// Overlap/Jaccard similarity for sets or tokens
189    Overlap,
190}
191
192/// Embedding type for text fields
193#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
194#[serde(rename_all = "lowercase")]
195pub enum EmbeddingType {
196    /// Semantic embeddings (hash-based in v1, can be extended to ML)
197    Semantic,
198    /// Exact string matching only
199    Exact,
200}
201
202/// Errors that can occur during schema validation
203#[derive(Debug, Clone, thiserror::Error)]
204pub enum SchemaError {
205    #[error("Schema cannot be empty")]
206    EmptySchema,
207    
208    #[error("Field '{0}' has negative weight")]
209    NegativeWeight(String),
210    
211    #[error("Total weight cannot be zero")]
212    ZeroTotalWeight,
213    
214    #[error("Field '{0}' not found in schema")]
215    FieldNotFound(String),
216    
217    #[error("Invalid field type for distance metric")]
218    InvalidDistanceForType,
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles a schema from `(field name, config)` pairs.
    fn schema_with(entries: Vec<(&str, FieldConfig)>) -> SimilaritySchema {
        let fields: HashMap<String, FieldConfig> = entries
            .into_iter()
            .map(|(name, config)| (name.to_string(), config))
            .collect();
        SimilaritySchema::new(fields)
    }

    /// A freshly built schema is version 1 and keeps every field.
    #[test]
    fn test_schema_creation() {
        let schema = schema_with(vec![
            ("name", FieldConfig::text(0.5)),
            ("price", FieldConfig::number(0.3, DistanceType::Relative)),
            ("category", FieldConfig::categorical(0.2)),
        ]);

        assert_eq!(schema.version, 1);
        assert_eq!(schema.fields.len(), 3);
    }

    /// Weights that sum to 4.0 are rescaled to sum to 1.0.
    #[test]
    fn test_schema_normalization() {
        let mut schema = schema_with(vec![
            ("a", FieldConfig::text(2.0)),
            ("b", FieldConfig::number(2.0, DistanceType::Absolute)),
        ]);
        schema.validate_and_normalize().unwrap();

        let total: f32 = schema.fields.values().map(|f| f.weight).sum();
        assert!((total - 1.0).abs() < 0.001);
    }

    /// A schema without fields is rejected.
    #[test]
    fn test_empty_schema_error() {
        let mut schema = SimilaritySchema::new(HashMap::new());
        let outcome = schema.validate_and_normalize();
        assert!(matches!(outcome, Err(SchemaError::EmptySchema)));
    }

    /// A negative weight is rejected.
    #[test]
    fn test_negative_weight_error() {
        let negative = FieldConfig {
            field_type: FieldType::Text,
            distance: DistanceType::Semantic,
            weight: -0.5,
            embedding: None,
        };
        let mut schema = schema_with(vec![("a", negative)]);

        let outcome = schema.validate_and_normalize();
        assert!(matches!(outcome, Err(SchemaError::NegativeWeight(_))));
    }

    /// Text contributes `text_dim`; number and boolean contribute one each.
    #[test]
    fn test_compute_vector_dim() {
        let schema = schema_with(vec![
            ("name", FieldConfig::text(0.5)),
            ("price", FieldConfig::number(0.3, DistanceType::Relative)),
            ("active", FieldConfig::boolean(0.2)),
        ]);

        let dim = schema.compute_vector_dim(64); // 64-dim text embedding
        assert_eq!(dim, 64 + 1 + 1); // text + number + boolean
    }

    /// Serializing to JSON and back yields an equal schema.
    #[test]
    fn test_serde_roundtrip() {
        let schema = schema_with(vec![
            ("name", FieldConfig::text(0.5)),
            ("price", FieldConfig::number(0.5, DistanceType::Relative)),
        ]);

        let json = serde_json::to_string(&schema).unwrap();
        let parsed: SimilaritySchema = serde_json::from_str(&json).unwrap();

        assert_eq!(schema, parsed);
    }
}