1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Lightweight, copyable handle for a field registered in a [`Schema`].
///
/// The wrapped `u32` is the field's position in the schema's field list:
/// [`SchemaBuilder`] hands out handles in registration order starting at 0,
/// and [`Schema::get_field_entry`] uses the number as an index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Field(pub u32);
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12pub enum FieldType {
13 #[serde(rename = "text")]
15 Text,
16 #[serde(rename = "u64")]
18 U64,
19 #[serde(rename = "i64")]
21 I64,
22 #[serde(rename = "f64")]
24 F64,
25 #[serde(rename = "bytes")]
27 Bytes,
28 #[serde(rename = "sparse_vector")]
30 SparseVector,
31}
32
/// Description of a single field as stored in a [`Schema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldEntry {
    /// Field name; this is the key used by `Schema::get_field` and by the
    /// JSON conversions on `Document`.
    pub name: String,
    /// Kind of value the field holds.
    pub field_type: FieldType,
    /// NOTE(review): presumably "is searchable"; this file only stores the
    /// flag, its consumers are elsewhere. Confirm against the indexing code.
    pub indexed: bool,
    /// NOTE(review): presumably "is kept for retrieval"; this file only
    /// stores the flag. Confirm against the storage code.
    pub stored: bool,
    /// Tokenizer name. The builder sets it for text fields (`"default"`
    /// unless one is given) and leaves it `None` for every other type.
    pub tokenizer: Option<String>,
    /// When `true`, `Document::to_json` emits this field as an array even if
    /// the document holds a single value. Defaults to `false` when absent
    /// from serialized data.
    #[serde(default)]
    pub multi: bool,
    /// Sparse-vector settings; set by the builder for sparse vector fields.
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sparse_vector_config: Option<crate::structures::SparseVectorConfig>,
}
49
50use super::query_field_router::QueryRouterRule;
51
/// Immutable description of the fields a document may contain.
///
/// Built through [`SchemaBuilder`]; a [`Field`] handle is the index of its
/// entry in `fields`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Schema {
    /// Field descriptions, indexed by `Field.0`.
    fields: Vec<FieldEntry>,
    /// Name -> handle lookup table, filled by `SchemaBuilder::build`.
    name_to_field: HashMap<String, Field>,
    /// Fields designated as defaults (resolved from names at build time).
    /// NOTE(review): presumably the fields searched when a query names
    /// none; the consumer is not in this file.
    #[serde(default)]
    default_fields: Vec<Field>,
    /// Query routing rules; their semantics live in `query_field_router`.
    #[serde(default)]
    query_routers: Vec<QueryRouterRule>,
}
64
65impl Schema {
66 pub fn builder() -> SchemaBuilder {
67 SchemaBuilder::default()
68 }
69
70 pub fn get_field(&self, name: &str) -> Option<Field> {
71 self.name_to_field.get(name).copied()
72 }
73
74 pub fn get_field_entry(&self, field: Field) -> Option<&FieldEntry> {
75 self.fields.get(field.0 as usize)
76 }
77
78 pub fn get_field_name(&self, field: Field) -> Option<&str> {
79 self.fields.get(field.0 as usize).map(|e| e.name.as_str())
80 }
81
82 pub fn fields(&self) -> impl Iterator<Item = (Field, &FieldEntry)> {
83 self.fields
84 .iter()
85 .enumerate()
86 .map(|(i, e)| (Field(i as u32), e))
87 }
88
89 pub fn num_fields(&self) -> usize {
90 self.fields.len()
91 }
92
93 pub fn default_fields(&self) -> &[Field] {
95 &self.default_fields
96 }
97
98 pub fn set_default_fields(&mut self, fields: Vec<Field>) {
100 self.default_fields = fields;
101 }
102
103 pub fn query_routers(&self) -> &[QueryRouterRule] {
105 &self.query_routers
106 }
107
108 pub fn set_query_routers(&mut self, rules: Vec<QueryRouterRule>) {
110 self.query_routers = rules;
111 }
112}
113
/// Incrementally assembles a [`Schema`].
///
/// Each `add_*` call appends a field and returns its [`Field`] handle;
/// `build` consumes the builder and produces the schema.
#[derive(Debug, Default)]
pub struct SchemaBuilder {
    /// Fields in registration order; a field's position becomes its handle.
    fields: Vec<FieldEntry>,
    /// Default fields by name; resolved to handles in `build`.
    default_fields: Vec<String>,
    /// Query routing rules passed through to the schema unchanged.
    query_routers: Vec<QueryRouterRule>,
}
121
122impl SchemaBuilder {
123 pub fn add_text_field(&mut self, name: &str, indexed: bool, stored: bool) -> Field {
124 self.add_field_with_tokenizer(
125 name,
126 FieldType::Text,
127 indexed,
128 stored,
129 Some("default".to_string()),
130 )
131 }
132
133 pub fn add_text_field_with_tokenizer(
134 &mut self,
135 name: &str,
136 indexed: bool,
137 stored: bool,
138 tokenizer: &str,
139 ) -> Field {
140 self.add_field_with_tokenizer(
141 name,
142 FieldType::Text,
143 indexed,
144 stored,
145 Some(tokenizer.to_string()),
146 )
147 }
148
149 pub fn add_u64_field(&mut self, name: &str, indexed: bool, stored: bool) -> Field {
150 self.add_field(name, FieldType::U64, indexed, stored)
151 }
152
153 pub fn add_i64_field(&mut self, name: &str, indexed: bool, stored: bool) -> Field {
154 self.add_field(name, FieldType::I64, indexed, stored)
155 }
156
157 pub fn add_f64_field(&mut self, name: &str, indexed: bool, stored: bool) -> Field {
158 self.add_field(name, FieldType::F64, indexed, stored)
159 }
160
161 pub fn add_bytes_field(&mut self, name: &str, stored: bool) -> Field {
162 self.add_field(name, FieldType::Bytes, false, stored)
163 }
164
165 pub fn add_sparse_vector_field(&mut self, name: &str, indexed: bool, stored: bool) -> Field {
170 self.add_sparse_vector_field_with_config(
171 name,
172 indexed,
173 stored,
174 crate::structures::SparseVectorConfig::default(),
175 )
176 }
177
178 pub fn add_sparse_vector_field_with_config(
183 &mut self,
184 name: &str,
185 indexed: bool,
186 stored: bool,
187 config: crate::structures::SparseVectorConfig,
188 ) -> Field {
189 let field = Field(self.fields.len() as u32);
190 self.fields.push(FieldEntry {
191 name: name.to_string(),
192 field_type: FieldType::SparseVector,
193 indexed,
194 stored,
195 tokenizer: None,
196 multi: false,
197 sparse_vector_config: Some(config),
198 });
199 field
200 }
201
202 pub fn set_sparse_vector_config(
204 &mut self,
205 field: Field,
206 config: crate::structures::SparseVectorConfig,
207 ) {
208 if let Some(entry) = self.fields.get_mut(field.0 as usize) {
209 entry.sparse_vector_config = Some(config);
210 }
211 }
212
213 fn add_field(
214 &mut self,
215 name: &str,
216 field_type: FieldType,
217 indexed: bool,
218 stored: bool,
219 ) -> Field {
220 self.add_field_with_tokenizer(name, field_type, indexed, stored, None)
221 }
222
223 fn add_field_with_tokenizer(
224 &mut self,
225 name: &str,
226 field_type: FieldType,
227 indexed: bool,
228 stored: bool,
229 tokenizer: Option<String>,
230 ) -> Field {
231 self.add_field_full(name, field_type, indexed, stored, tokenizer, false)
232 }
233
234 fn add_field_full(
235 &mut self,
236 name: &str,
237 field_type: FieldType,
238 indexed: bool,
239 stored: bool,
240 tokenizer: Option<String>,
241 multi: bool,
242 ) -> Field {
243 let field = Field(self.fields.len() as u32);
244 self.fields.push(FieldEntry {
245 name: name.to_string(),
246 field_type,
247 indexed,
248 stored,
249 tokenizer,
250 multi,
251 sparse_vector_config: None,
252 });
253 field
254 }
255
256 pub fn set_multi(&mut self, field: Field, multi: bool) {
258 if let Some(entry) = self.fields.get_mut(field.0 as usize) {
259 entry.multi = multi;
260 }
261 }
262
263 pub fn set_default_fields(&mut self, field_names: Vec<String>) {
265 self.default_fields = field_names;
266 }
267
268 pub fn set_query_routers(&mut self, rules: Vec<QueryRouterRule>) {
270 self.query_routers = rules;
271 }
272
273 pub fn build(self) -> Schema {
274 let mut name_to_field = HashMap::new();
275 for (i, entry) in self.fields.iter().enumerate() {
276 name_to_field.insert(entry.name.clone(), Field(i as u32));
277 }
278
279 let default_fields: Vec<Field> = self
281 .default_fields
282 .iter()
283 .filter_map(|name| name_to_field.get(name).copied())
284 .collect();
285
286 Schema {
287 fields: self.fields,
288 name_to_field,
289 default_fields,
290 query_routers: self.query_routers,
291 }
292 }
293}
294
295#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
297pub enum FieldValue {
298 #[serde(rename = "text")]
299 Text(String),
300 #[serde(rename = "u64")]
301 U64(u64),
302 #[serde(rename = "i64")]
303 I64(i64),
304 #[serde(rename = "f64")]
305 F64(f64),
306 #[serde(rename = "bytes")]
307 Bytes(Vec<u8>),
308 #[serde(rename = "sparse_vector")]
310 SparseVector { indices: Vec<u32>, values: Vec<f32> },
311}
312
313impl FieldValue {
314 pub fn as_text(&self) -> Option<&str> {
315 match self {
316 FieldValue::Text(s) => Some(s),
317 _ => None,
318 }
319 }
320
321 pub fn as_u64(&self) -> Option<u64> {
322 match self {
323 FieldValue::U64(v) => Some(*v),
324 _ => None,
325 }
326 }
327
328 pub fn as_i64(&self) -> Option<i64> {
329 match self {
330 FieldValue::I64(v) => Some(*v),
331 _ => None,
332 }
333 }
334
335 pub fn as_f64(&self) -> Option<f64> {
336 match self {
337 FieldValue::F64(v) => Some(*v),
338 _ => None,
339 }
340 }
341
342 pub fn as_bytes(&self) -> Option<&[u8]> {
343 match self {
344 FieldValue::Bytes(b) => Some(b),
345 _ => None,
346 }
347 }
348
349 pub fn as_sparse_vector(&self) -> Option<(&[u32], &[f32])> {
350 match self {
351 FieldValue::SparseVector { indices, values } => Some((indices, values)),
352 _ => None,
353 }
354 }
355}
356
/// An ordered collection of `(field, value)` pairs.
///
/// The same field may appear more than once; pairs are kept in the order
/// they were added.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Document {
    /// Field values in insertion order.
    field_values: Vec<(Field, FieldValue)>,
}
362
363impl Document {
364 pub fn new() -> Self {
365 Self::default()
366 }
367
368 pub fn add_text(&mut self, field: Field, value: impl Into<String>) {
369 self.field_values
370 .push((field, FieldValue::Text(value.into())));
371 }
372
373 pub fn add_u64(&mut self, field: Field, value: u64) {
374 self.field_values.push((field, FieldValue::U64(value)));
375 }
376
377 pub fn add_i64(&mut self, field: Field, value: i64) {
378 self.field_values.push((field, FieldValue::I64(value)));
379 }
380
381 pub fn add_f64(&mut self, field: Field, value: f64) {
382 self.field_values.push((field, FieldValue::F64(value)));
383 }
384
385 pub fn add_bytes(&mut self, field: Field, value: Vec<u8>) {
386 self.field_values.push((field, FieldValue::Bytes(value)));
387 }
388
389 pub fn add_sparse_vector(&mut self, field: Field, indices: Vec<u32>, values: Vec<f32>) {
390 debug_assert_eq!(
391 indices.len(),
392 values.len(),
393 "Sparse vector indices and values must have same length"
394 );
395 self.field_values
396 .push((field, FieldValue::SparseVector { indices, values }));
397 }
398
399 pub fn get_first(&self, field: Field) -> Option<&FieldValue> {
400 self.field_values
401 .iter()
402 .find(|(f, _)| *f == field)
403 .map(|(_, v)| v)
404 }
405
406 pub fn get_all(&self, field: Field) -> impl Iterator<Item = &FieldValue> {
407 self.field_values
408 .iter()
409 .filter(move |(f, _)| *f == field)
410 .map(|(_, v)| v)
411 }
412
413 pub fn field_values(&self) -> &[(Field, FieldValue)] {
414 &self.field_values
415 }
416
417 pub fn to_json(&self, schema: &Schema) -> serde_json::Value {
423 use std::collections::HashMap;
424
425 let mut field_values_map: HashMap<Field, (String, bool, Vec<serde_json::Value>)> =
427 HashMap::new();
428
429 for (field, value) in &self.field_values {
430 if let Some(entry) = schema.get_field_entry(*field) {
431 let json_value = match value {
432 FieldValue::Text(s) => serde_json::Value::String(s.clone()),
433 FieldValue::U64(n) => serde_json::Value::Number((*n).into()),
434 FieldValue::I64(n) => serde_json::Value::Number((*n).into()),
435 FieldValue::F64(n) => serde_json::json!(n),
436 FieldValue::Bytes(b) => {
437 use base64::Engine;
438 serde_json::Value::String(
439 base64::engine::general_purpose::STANDARD.encode(b),
440 )
441 }
442 FieldValue::SparseVector { indices, values } => {
443 serde_json::json!({
444 "indices": indices,
445 "values": values
446 })
447 }
448 };
449 field_values_map
450 .entry(*field)
451 .or_insert_with(|| (entry.name.clone(), entry.multi, Vec::new()))
452 .2
453 .push(json_value);
454 }
455 }
456
457 let mut map = serde_json::Map::new();
459 for (_field, (name, is_multi, values)) in field_values_map {
460 let json_value = if is_multi || values.len() > 1 {
461 serde_json::Value::Array(values)
462 } else {
463 values.into_iter().next().unwrap()
464 };
465 map.insert(name, json_value);
466 }
467
468 serde_json::Value::Object(map)
469 }
470
471 pub fn from_json(json: &serde_json::Value, schema: &Schema) -> Option<Self> {
480 let obj = json.as_object()?;
481 let mut doc = Document::new();
482
483 for (key, value) in obj {
484 if let Some(field) = schema.get_field(key) {
485 let field_entry = schema.get_field_entry(field)?;
486 Self::add_json_value(&mut doc, field, &field_entry.field_type, value);
487 }
488 }
489
490 Some(doc)
491 }
492
493 fn add_json_value(
495 doc: &mut Document,
496 field: Field,
497 field_type: &FieldType,
498 value: &serde_json::Value,
499 ) {
500 match value {
501 serde_json::Value::String(s) => {
502 if matches!(field_type, FieldType::Text) {
503 doc.add_text(field, s.clone());
504 }
505 }
506 serde_json::Value::Number(n) => {
507 match field_type {
508 FieldType::I64 => {
509 if let Some(i) = n.as_i64() {
510 doc.add_i64(field, i);
511 }
512 }
513 FieldType::U64 => {
514 if let Some(u) = n.as_u64() {
515 doc.add_u64(field, u);
516 } else if let Some(i) = n.as_i64() {
517 if i >= 0 {
519 doc.add_u64(field, i as u64);
520 }
521 }
522 }
523 FieldType::F64 => {
524 if let Some(f) = n.as_f64() {
525 doc.add_f64(field, f);
526 }
527 }
528 _ => {}
529 }
530 }
531 serde_json::Value::Array(arr) => {
533 for item in arr {
534 Self::add_json_value(doc, field, field_type, item);
535 }
536 }
537 serde_json::Value::Object(obj) if matches!(field_type, FieldType::SparseVector) => {
539 if let (Some(indices_val), Some(values_val)) =
540 (obj.get("indices"), obj.get("values"))
541 {
542 let indices: Vec<u32> = indices_val
543 .as_array()
544 .map(|arr| {
545 arr.iter()
546 .filter_map(|v| v.as_u64().map(|n| n as u32))
547 .collect()
548 })
549 .unwrap_or_default();
550 let values: Vec<f32> = values_val
551 .as_array()
552 .map(|arr| {
553 arr.iter()
554 .filter_map(|v| v.as_f64().map(|n| n as f32))
555 .collect()
556 })
557 .unwrap_or_default();
558 if indices.len() == values.len() {
559 doc.add_sparse_vector(field, indices, values);
560 }
561 }
562 }
563 serde_json::Value::Object(_) => {}
564 _ => {}
565 }
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Handles returned by the builder resolve by name after `build`.
    #[test]
    fn test_schema_builder() {
        let mut builder = Schema::builder();
        let title = builder.add_text_field("title", true, true);
        let body = builder.add_text_field("body", true, false);
        let count = builder.add_u64_field("count", true, true);
        let schema = builder.build();

        let registered = [("title", title), ("body", body), ("count", count)];
        for (name, handle) in registered.iter() {
            assert_eq!(schema.get_field(name), Some(*handle));
        }
        assert_eq!(schema.get_field("nonexistent"), None);
    }

    /// Values added to a document can be read back with typed accessors.
    #[test]
    fn test_document() {
        let mut builder = Schema::builder();
        let title = builder.add_text_field("title", true, true);
        let count = builder.add_u64_field("count", true, true);
        let _schema = builder.build();

        let mut doc = Document::new();
        doc.add_text(title, "Hello World");
        doc.add_u64(count, 42);

        let first_title = doc.get_first(title).unwrap();
        assert_eq!(first_title.as_text(), Some("Hello World"));
        let first_count = doc.get_first(count).unwrap();
        assert_eq!(first_count.as_u64(), Some(42));
    }

    /// A document survives a serde JSON round trip.
    #[test]
    fn test_document_serialization() {
        let mut builder = Schema::builder();
        let title = builder.add_text_field("title", true, true);
        let count = builder.add_u64_field("count", true, true);
        let _schema = builder.build();

        let mut original = Document::new();
        original.add_text(title, "Hello World");
        original.add_u64(count, 42);

        let json = serde_json::to_string(&original).unwrap();
        println!("Serialized doc: {}", json);

        let restored: Document = serde_json::from_str(&json).unwrap();
        assert_eq!(
            restored.field_values().len(),
            2,
            "Should have 2 field values after deserialization"
        );
        let restored_title = restored.get_first(title).unwrap();
        assert_eq!(restored_title.as_text(), Some("Hello World"));
        let restored_count = restored.get_first(count).unwrap();
        assert_eq!(restored_count.as_u64(), Some(42));
    }

    /// Repeated values are kept in order and exported as a JSON array.
    #[test]
    fn test_multivalue_field() {
        let mut builder = Schema::builder();
        let uris = builder.add_text_field("uris", true, true);
        let title = builder.add_text_field("title", true, true);
        let schema = builder.build();

        let mut doc = Document::new();
        doc.add_text(uris, "one");
        doc.add_text(uris, "two");
        doc.add_text(title, "Test Document");

        assert_eq!(doc.get_first(uris).unwrap().as_text(), Some("one"));

        let stored: Vec<_> = doc.get_all(uris).collect();
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0].as_text(), Some("one"));
        assert_eq!(stored[1].as_text(), Some("two"));

        let json = doc.to_json(&schema);

        let uris_json = json.get("uris").unwrap();
        assert!(uris_json.is_array(), "Multi-value field should be an array");
        let exported = uris_json.as_array().unwrap();
        assert_eq!(exported.len(), 2);
        assert_eq!(exported[0].as_str(), Some("one"));
        assert_eq!(exported[1].as_str(), Some("two"));

        let title_json = json.get("title").unwrap();
        assert!(
            title_json.is_string(),
            "Single-value field should be a string"
        );
        assert_eq!(title_json.as_str(), Some("Test Document"));
    }

    /// A JSON array becomes one document value per element.
    #[test]
    fn test_multivalue_from_json() {
        let mut builder = Schema::builder();
        let uris = builder.add_text_field("uris", true, true);
        let title = builder.add_text_field("title", true, true);
        let schema = builder.build();

        let input = serde_json::json!({
            "uris": ["one", "two"],
            "title": "Test Document"
        });

        let doc = Document::from_json(&input, &schema).unwrap();

        let parsed: Vec<_> = doc.get_all(uris).collect();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].as_text(), Some("one"));
        assert_eq!(parsed[1].as_text(), Some("two"));

        let parsed_title = doc.get_first(title).unwrap();
        assert_eq!(parsed_title.as_text(), Some("Test Document"));

        let output = doc.to_json(&schema);
        let exported = output.get("uris").unwrap().as_array().unwrap();
        assert_eq!(exported.len(), 2);
        assert_eq!(exported[0].as_str(), Some("one"));
        assert_eq!(exported[1].as_str(), Some("two"));
    }

    /// The `multi` flag forces array output even for a single value.
    #[test]
    fn test_multi_attribute_forces_array() {
        let mut builder = Schema::builder();
        let uris = builder.add_text_field("uris", true, true);
        builder.set_multi(uris, true);
        let title = builder.add_text_field("title", true, true);
        let schema = builder.build();

        assert!(schema.get_field_entry(uris).unwrap().multi);
        assert!(!schema.get_field_entry(title).unwrap().multi);

        let mut doc = Document::new();
        doc.add_text(uris, "only_one");
        doc.add_text(title, "Test Document");

        let json = doc.to_json(&schema);

        let uris_json = json.get("uris").unwrap();
        assert!(
            uris_json.is_array(),
            "Multi field should be array even with single value"
        );
        let exported = uris_json.as_array().unwrap();
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].as_str(), Some("only_one"));

        let title_json = json.get("title").unwrap();
        assert!(
            title_json.is_string(),
            "Non-multi single-value field should be a string"
        );
        assert_eq!(title_json.as_str(), Some("Test Document"));
    }

    /// Sparse vectors are stored, exported as an object and parsed back.
    #[test]
    fn test_sparse_vector_field() {
        let mut builder = Schema::builder();
        let embedding = builder.add_sparse_vector_field("embedding", true, true);
        let title = builder.add_text_field("title", true, true);
        let schema = builder.build();

        assert_eq!(schema.get_field("embedding"), Some(embedding));
        let embedding_entry = schema.get_field_entry(embedding).unwrap();
        assert_eq!(embedding_entry.field_type, FieldType::SparseVector);

        let mut doc = Document::new();
        doc.add_sparse_vector(embedding, vec![0, 5, 10], vec![1.0, 2.5, 0.5]);
        doc.add_text(title, "Test Document");

        let stored = doc.get_first(embedding).unwrap();
        let (indices, values) = stored.as_sparse_vector().unwrap();
        assert_eq!(indices, &[0, 5, 10]);
        assert_eq!(values, &[1.0, 2.5, 0.5]);

        let json = doc.to_json(&schema);
        let embedding_json = json.get("embedding").unwrap();
        assert!(embedding_json.is_object());
        let exported_indices = embedding_json.get("indices").unwrap().as_array().unwrap();
        assert_eq!(exported_indices.len(), 3);

        let reparsed = Document::from_json(&json, &schema).unwrap();
        let (indices2, values2) = reparsed
            .get_first(embedding)
            .unwrap()
            .as_sparse_vector()
            .unwrap();
        assert_eq!(indices2, &[0, 5, 10]);
        // Weights pass through f32 -> JSON -> f32, so compare with a tolerance.
        for (position, expected) in [1.0f32, 2.5, 0.5].iter().enumerate() {
            assert!((values2[position] - *expected).abs() < 1e-6);
        }
    }
}