1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct SimilaritySchema {
17 #[serde(default = "default_version")]
19 pub version: u32,
20
21 pub fields: HashMap<String, FieldConfig>,
23}
24
/// Version given to a deserialized schema whose `version` key is missing.
fn default_version() -> u32 {
    const FALLBACK_VERSION: u32 = 1;
    FALLBACK_VERSION
}
28
29impl SimilaritySchema {
30 pub fn new(fields: HashMap<String, FieldConfig>) -> Self {
32 Self {
33 version: 1,
34 fields,
35 }
36 }
37
38 pub fn validate_and_normalize(&mut self) -> Result<(), SchemaError> {
40 if self.fields.is_empty() {
41 return Err(SchemaError::EmptySchema);
42 }
43
44 for (name, config) in &self.fields {
46 if config.weight < 0.0 {
47 return Err(SchemaError::NegativeWeight(name.clone()));
48 }
49 }
50
51 let weight_sum: f32 = self.fields.values().map(|f| f.weight).sum();
53
54 if weight_sum <= 0.0 {
55 return Err(SchemaError::ZeroTotalWeight);
56 }
57
58 if (weight_sum - 1.0).abs() > 0.001 {
60 for config in self.fields.values_mut() {
61 config.weight /= weight_sum;
62 }
63 }
64
65 Ok(())
66 }
67
68 pub fn sorted_field_names(&self) -> Vec<&String> {
70 let mut names: Vec<_> = self.fields.keys().collect();
71 names.sort();
72 names
73 }
74
75 pub fn get_field(&self, name: &str) -> Option<&FieldConfig> {
77 self.fields.get(name)
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub struct FieldConfig {
84 #[serde(rename = "type")]
86 pub field_type: FieldType,
87
88 #[serde(default)]
90 pub distance: DistanceType,
91
92 #[serde(default = "default_weight")]
94 pub weight: f32,
95}
96
/// Weight given to a deserialized field whose `weight` key is missing.
fn default_weight() -> f32 {
    const UNIT_WEIGHT: f32 = 1.0;
    UNIT_WEIGHT
}
100
101impl FieldConfig {
102 pub fn text(weight: f32) -> Self {
107 Self {
108 field_type: FieldType::Text,
109 distance: DistanceType::Semantic,
110 weight,
111 }
112 }
113
114 pub fn number(weight: f32, distance: DistanceType) -> Self {
119 Self {
120 field_type: FieldType::Number,
121 distance,
122 weight,
123 }
124 }
125
126 pub fn categorical(weight: f32) -> Self {
131 Self {
132 field_type: FieldType::Categorical,
133 distance: DistanceType::Exact,
134 weight,
135 }
136 }
137
138 pub fn boolean(weight: f32) -> Self {
140 Self {
141 field_type: FieldType::Boolean,
142 distance: DistanceType::Exact,
143 weight,
144 }
145 }
146}
147
148#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
150#[serde(rename_all = "lowercase")]
151pub enum FieldType {
152 Text,
154 Number,
156 Categorical,
158 Boolean,
160}
161
162#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
164#[serde(rename_all = "lowercase")]
165pub enum DistanceType {
166 #[default]
168 Semantic,
169 Absolute,
171 Relative,
173 Exact,
175 Overlap,
177}
178
179#[derive(Debug, Clone, thiserror::Error)]
181pub enum SchemaError {
182 #[error("Schema cannot be empty")]
183 EmptySchema,
184
185 #[error("Field '{0}' has negative weight")]
186 NegativeWeight(String),
187
188 #[error("Total weight cannot be zero")]
189 ZeroTotalWeight,
190
191 #[error("Field '{0}' not found in schema")]
192 FieldNotFound(String),
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a schema from `(name, config)` pairs.
    fn schema_from(entries: &[(&str, FieldConfig)]) -> SimilaritySchema {
        SimilaritySchema::new(
            entries
                .iter()
                .map(|(name, config)| (name.to_string(), config.clone()))
                .collect(),
        )
    }

    #[test]
    fn test_schema_creation() {
        let schema = schema_from(&[
            ("name", FieldConfig::text(0.5)),
            ("price", FieldConfig::number(0.3, DistanceType::Relative)),
            ("category", FieldConfig::categorical(0.2)),
        ]);

        assert_eq!(schema.version, 1);
        assert_eq!(schema.fields.len(), 3);
    }

    #[test]
    fn test_schema_normalization() {
        let mut schema = schema_from(&[
            ("a", FieldConfig::text(2.0)),
            ("b", FieldConfig::number(2.0, DistanceType::Absolute)),
        ]);
        schema.validate_and_normalize().unwrap();

        let weight_sum: f32 = schema.fields.values().map(|f| f.weight).sum();
        assert!((weight_sum - 1.0).abs() < 0.001);

        // 2.0 / 4.0 is exactly representable, so the comparison is exact.
        assert_eq!(schema.get_field("a").map(|f| f.weight), Some(0.5));
        assert_eq!(schema.get_field("b").map(|f| f.weight), Some(0.5));
    }

    #[test]
    fn test_normalized_weights_are_left_untouched() {
        let mut schema = schema_from(&[
            ("a", FieldConfig::text(0.25)),
            ("b", FieldConfig::categorical(0.75)),
        ]);
        let before = schema.clone();

        schema.validate_and_normalize().unwrap();

        assert_eq!(schema, before);
    }

    #[test]
    fn test_empty_schema_error() {
        let mut schema = SimilaritySchema::new(HashMap::new());
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::EmptySchema)
        ));
    }

    #[test]
    fn test_negative_weight_error() {
        let mut schema = schema_from(&[
            ("good", FieldConfig::text(1.0)),
            ("bad", FieldConfig::categorical(-0.5)),
        ]);

        match schema.validate_and_normalize() {
            Err(SchemaError::NegativeWeight(name)) => assert_eq!(name, "bad"),
            other => panic!("expected NegativeWeight, got {:?}", other),
        }
    }

    #[test]
    fn test_zero_total_weight_error() {
        let mut schema = schema_from(&[
            ("a", FieldConfig::text(0.0)),
            ("b", FieldConfig::boolean(0.0)),
        ]);
        assert!(matches!(
            schema.validate_and_normalize(),
            Err(SchemaError::ZeroTotalWeight)
        ));
    }

    #[test]
    fn test_sorted_field_names() {
        let schema = schema_from(&[
            ("b", FieldConfig::text(1.0)),
            ("c", FieldConfig::text(1.0)),
            ("a", FieldConfig::text(1.0)),
        ]);

        let names: Vec<&str> = schema
            .sorted_field_names()
            .into_iter()
            .map(String::as_str)
            .collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_get_field() {
        let schema = schema_from(&[("name", FieldConfig::text(1.0))]);

        assert_eq!(schema.get_field("name"), Some(&FieldConfig::text(1.0)));
        assert_eq!(schema.get_field("missing"), None);
    }

    #[test]
    fn test_serde_roundtrip() {
        let schema = schema_from(&[
            ("name", FieldConfig::text(0.5)),
            ("price", FieldConfig::number(0.5, DistanceType::Relative)),
        ]);

        let json = serde_json::to_string(&schema).unwrap();
        let parsed: SimilaritySchema = serde_json::from_str(&json).unwrap();

        assert_eq!(schema, parsed);
    }

    #[test]
    fn test_serde_defaults() {
        // `version`, `distance` and `weight` are all omitted on purpose.
        let parsed: SimilaritySchema =
            serde_json::from_str(r#"{"fields":{"title":{"type":"text"}}}"#).unwrap();

        assert_eq!(parsed.version, 1);
        assert_eq!(parsed.get_field("title"), Some(&FieldConfig::text(1.0)));
    }
}