1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15pub struct SimilaritySchema {
16 #[serde(default = "default_version")]
18 pub version: u32,
19
20 pub fields: HashMap<String, FieldConfig>,
22}
23
/// Serde default for `SimilaritySchema::version`.
fn default_version() -> u32 {
    const INITIAL_SCHEMA_VERSION: u32 = 1;
    INITIAL_SCHEMA_VERSION
}
27
28impl SimilaritySchema {
29 pub fn new(fields: HashMap<String, FieldConfig>) -> Self {
31 Self {
32 version: 1,
33 fields,
34 }
35 }
36
37 pub fn validate_and_normalize(&mut self) -> Result<(), SchemaError> {
41 if self.fields.is_empty() {
42 return Err(SchemaError::EmptySchema);
43 }
44
45 for (name, config) in &self.fields {
47 if config.weight < 0.0 {
48 return Err(SchemaError::NegativeWeight(name.clone()));
49 }
50 }
51
52 let weight_sum: f32 = self.fields.values().map(|f| f.weight).sum();
54
55 if weight_sum <= 0.0 {
56 return Err(SchemaError::ZeroTotalWeight);
57 }
58
59 if (weight_sum - 1.0).abs() > 0.001 {
61 for config in self.fields.values_mut() {
62 config.weight /= weight_sum;
63 }
64 }
65
66 Ok(())
67 }
68
69 pub fn compute_vector_dim(&self, text_dim: usize) -> usize {
72 self.fields.values().map(|config| {
73 match config.field_type {
74 FieldType::Text => text_dim,
75 FieldType::Number => 1,
76 FieldType::Categorical => 64, FieldType::Boolean => 1,
78 }
79 }).sum()
80 }
81
82 pub fn sorted_field_names(&self) -> Vec<&String> {
84 let mut names: Vec<_> = self.fields.keys().collect();
85 names.sort();
86 names
87 }
88
89 pub fn get_field(&self, name: &str) -> Option<&FieldConfig> {
91 self.fields.get(name)
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
97pub struct FieldConfig {
98 #[serde(rename = "type")]
100 pub field_type: FieldType,
101
102 #[serde(default)]
104 pub distance: DistanceType,
105
106 #[serde(default = "default_weight")]
108 pub weight: f32,
109
110 #[serde(skip_serializing_if = "Option::is_none")]
112 pub embedding: Option<EmbeddingType>,
113}
114
/// Serde default for `FieldConfig::weight`.
fn default_weight() -> f32 {
    const UNIT_WEIGHT: f32 = 1.0;
    UNIT_WEIGHT
}
118
119impl FieldConfig {
120 pub fn text(weight: f32) -> Self {
122 Self {
123 field_type: FieldType::Text,
124 distance: DistanceType::Semantic,
125 weight,
126 embedding: Some(EmbeddingType::Semantic),
127 }
128 }
129
130 pub fn number(weight: f32, distance: DistanceType) -> Self {
132 Self {
133 field_type: FieldType::Number,
134 distance,
135 weight,
136 embedding: None,
137 }
138 }
139
140 pub fn categorical(weight: f32) -> Self {
142 Self {
143 field_type: FieldType::Categorical,
144 distance: DistanceType::Exact,
145 weight,
146 embedding: None,
147 }
148 }
149
150 pub fn boolean(weight: f32) -> Self {
152 Self {
153 field_type: FieldType::Boolean,
154 distance: DistanceType::Exact,
155 weight,
156 embedding: None,
157 }
158 }
159}
160
161#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
163#[serde(rename_all = "lowercase")]
164pub enum FieldType {
165 Text,
167 Number,
169 Categorical,
171 Boolean,
173}
174
175#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
177#[serde(rename_all = "lowercase")]
178pub enum DistanceType {
179 #[default]
181 Semantic,
182 Absolute,
184 Relative,
186 Exact,
188 Overlap,
190}
191
192#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
194#[serde(rename_all = "lowercase")]
195pub enum EmbeddingType {
196 Semantic,
198 Exact,
200}
201
202#[derive(Debug, Clone, thiserror::Error)]
204pub enum SchemaError {
205 #[error("Schema cannot be empty")]
206 EmptySchema,
207
208 #[error("Field '{0}' has negative weight")]
209 NegativeWeight(String),
210
211 #[error("Total weight cannot be zero")]
212 ZeroTotalWeight,
213
214 #[error("Field '{0}' not found in schema")]
215 FieldNotFound(String),
216
217 #[error("Invalid field type for distance metric")]
218 InvalidDistanceForType,
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_schema_creation() {
        let mut fields = HashMap::new();
        fields.insert("name".to_string(), FieldConfig::text(0.5));
        fields.insert("price".to_string(), FieldConfig::number(0.3, DistanceType::Relative));
        fields.insert("category".to_string(), FieldConfig::categorical(0.2));

        let schema = SimilaritySchema::new(fields);
        assert_eq!(schema.version, 1);
        assert_eq!(schema.fields.len(), 3);
    }

    #[test]
    fn test_schema_normalization() {
        let mut fields = HashMap::new();
        fields.insert("a".to_string(), FieldConfig::text(2.0));
        fields.insert("b".to_string(), FieldConfig::number(2.0, DistanceType::Absolute));

        let mut schema = SimilaritySchema::new(fields);
        schema.validate_and_normalize().unwrap();

        let weight_sum: f32 = schema.fields.values().map(|f| f.weight).sum();
        assert!((weight_sum - 1.0).abs() < 0.001);
    }

    /// Weights that already sum to 1.0 are left exactly as given.
    #[test]
    fn test_normalization_skipped_when_sum_is_one() {
        let mut fields = HashMap::new();
        fields.insert("a".to_string(), FieldConfig::text(0.5));
        fields.insert("b".to_string(), FieldConfig::number(0.5, DistanceType::Absolute));

        let mut schema = SimilaritySchema::new(fields);
        schema.validate_and_normalize().unwrap();

        assert_eq!(schema.fields["a"].weight, 0.5);
        assert_eq!(schema.fields["b"].weight, 0.5);
    }

    #[test]
    fn test_empty_schema_error() {
        let mut schema = SimilaritySchema::new(HashMap::new());
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::EmptySchema)
        ));
    }

    #[test]
    fn test_negative_weight_error() {
        let mut fields = HashMap::new();
        fields.insert(
            "a".to_string(),
            FieldConfig {
                field_type: FieldType::Text,
                distance: DistanceType::Semantic,
                weight: -0.5,
                embedding: None,
            },
        );

        let mut schema = SimilaritySchema::new(fields);
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::NegativeWeight(_))
        ));
    }

    #[test]
    fn test_zero_total_weight_error() {
        let mut fields = HashMap::new();
        fields.insert("a".to_string(), FieldConfig::categorical(0.0));
        fields.insert("b".to_string(), FieldConfig::boolean(0.0));

        let mut schema = SimilaritySchema::new(fields);
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::ZeroTotalWeight)
        ));
    }

    #[test]
    fn test_compute_vector_dim() {
        let mut fields = HashMap::new();
        fields.insert("name".to_string(), FieldConfig::text(0.5));
        fields.insert("price".to_string(), FieldConfig::number(0.3, DistanceType::Relative));
        fields.insert("active".to_string(), FieldConfig::boolean(0.2));

        let schema = SimilaritySchema::new(fields);
        let dim = schema.compute_vector_dim(64);
        assert_eq!(dim, 64 + 1 + 1);
    }

    #[test]
    fn test_compute_vector_dim_categorical() {
        let mut fields = HashMap::new();
        fields.insert("title".to_string(), FieldConfig::text(0.5));
        fields.insert("color".to_string(), FieldConfig::categorical(0.5));

        let schema = SimilaritySchema::new(fields);
        assert_eq!(schema.compute_vector_dim(32), 32 + 64);
    }

    #[test]
    fn test_sorted_field_names() {
        let mut fields = HashMap::new();
        for name in ["zeta", "alpha", "mid"] {
            fields.insert(name.to_string(), FieldConfig::boolean(1.0));
        }

        let schema = SimilaritySchema::new(fields);
        let names: Vec<&str> = schema
            .sorted_field_names()
            .into_iter()
            .map(String::as_str)
            .collect();
        assert_eq!(names, ["alpha", "mid", "zeta"]);
    }

    #[test]
    fn test_get_field() {
        let mut fields = HashMap::new();
        fields.insert(
            "price".to_string(),
            FieldConfig::number(1.0, DistanceType::Relative),
        );

        let schema = SimilaritySchema::new(fields);
        assert_eq!(
            schema.get_field("price"),
            Some(&FieldConfig::number(1.0, DistanceType::Relative))
        );
        assert!(schema.get_field("missing").is_none());
    }

    #[test]
    fn test_serde_roundtrip() {
        let mut fields = HashMap::new();
        fields.insert("name".to_string(), FieldConfig::text(0.5));
        fields.insert("price".to_string(), FieldConfig::number(0.5, DistanceType::Relative));

        let schema = SimilaritySchema::new(fields);
        let json = serde_json::to_string(&schema).unwrap();
        let parsed: SimilaritySchema = serde_json::from_str(&json).unwrap();

        assert_eq!(schema, parsed);
    }

    /// Absent `version`, `distance`, `weight` and `embedding` take their serde defaults.
    #[test]
    fn test_deserialize_applies_defaults() {
        let schema: SimilaritySchema =
            serde_json::from_str(r#"{"fields":{"price":{"type":"number"}}}"#).unwrap();

        assert_eq!(schema.version, 1);
        assert_eq!(
            schema.get_field("price"),
            Some(&FieldConfig {
                field_type: FieldType::Number,
                distance: DistanceType::Semantic,
                weight: 1.0,
                embedding: None,
            })
        );
    }
}