1use std::{collections::HashSet, fs, path::Path};
25
26use serde::{Deserialize, Serialize};
27
28use crate::error::{Error, Result};
29
30#[derive(Debug, Clone, Serialize)]
33#[serde(untagged)]
34pub enum Prefix {
35 Single(String),
37 Multi(Vec<String>),
39}
40
41impl<'de> Deserialize<'de> for Prefix {
42 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
43 where
44 D: serde::Deserializer<'de>,
45 {
46 use serde::de;
47
48 struct PrefixVisitor;
49
50 impl<'de> de::Visitor<'de> for PrefixVisitor {
51 type Value = Prefix;
52
53 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 formatter.write_str("a string or a list of strings")
55 }
56
57 fn visit_str<E: de::Error>(self, v: &str) -> std::result::Result<Prefix, E> {
58 Ok(Prefix::Single(v.to_owned()))
59 }
60
61 fn visit_string<E: de::Error>(self, v: String) -> std::result::Result<Prefix, E> {
62 Ok(Prefix::Single(v))
63 }
64
65 fn visit_seq<A: de::SeqAccess<'de>>(
66 self,
67 mut seq: A,
68 ) -> std::result::Result<Prefix, A::Error> {
69 let mut items = Vec::new();
70 while let Some(item) = seq.next_element::<String>()? {
71 items.push(item);
72 }
73 Ok(Prefix::Multi(items))
74 }
75 }
76
77 deserializer.deserialize_any(PrefixVisitor)
78 }
79}
80
81impl Default for Prefix {
82 fn default() -> Self {
83 Prefix::Single("rvl".to_owned())
84 }
85}
86
87impl Prefix {
88 pub fn first(&self) -> &str {
90 match self {
91 Prefix::Single(s) => s,
92 Prefix::Multi(v) => v.first().map(String::as_str).unwrap_or(""),
93 }
94 }
95
96 pub fn all(&self) -> Vec<&str> {
98 match self {
99 Prefix::Single(s) => vec![s.as_str()],
100 Prefix::Multi(v) => v.iter().map(String::as_str).collect(),
101 }
102 }
103
104 pub fn len(&self) -> usize {
106 match self {
107 Prefix::Single(_) => 1,
108 Prefix::Multi(v) => v.len(),
109 }
110 }
111
112 pub fn is_empty(&self) -> bool {
114 match self {
115 Prefix::Single(s) => s.is_empty(),
116 Prefix::Multi(v) => v.is_empty(),
117 }
118 }
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct IndexSchema {
124 pub index: IndexDefinition,
126 #[serde(default)]
128 pub fields: Vec<Field>,
129}
130
131impl IndexSchema {
132 pub fn from_yaml_str(input: &str) -> Result<Self> {
134 let schema: Self = serde_yaml::from_str(input)?;
135 schema.validate()?;
136 Ok(schema)
137 }
138
139 pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self> {
141 let contents = fs::read_to_string(path)?;
142 Self::from_yaml_str(&contents)
143 }
144
145 pub fn from_json_value(value: serde_json::Value) -> Result<Self> {
147 let schema: Self = serde_json::from_value(value)?;
148 schema.validate()?;
149 Ok(schema)
150 }
151
152 pub fn to_json_value(&self) -> Result<serde_json::Value> {
154 Ok(serde_json::to_value(self)?)
155 }
156
157 pub fn to_yaml_string(&self) -> Result<String> {
159 Ok(serde_yaml::to_string(self)?)
160 }
161
162 pub fn to_yaml_file(&self, path: impl AsRef<Path>) -> Result<()> {
164 fs::write(path, self.to_yaml_string()?)?;
165 Ok(())
166 }
167
168 pub fn validate(&self) -> Result<()> {
170 if self.index.name.trim().is_empty() {
171 return Err(Error::SchemaValidation(
172 "index name cannot be empty".to_owned(),
173 ));
174 }
175 let mut seen = HashSet::new();
176 for field in &self.fields {
177 if !seen.insert(field.name.clone()) {
178 return Err(Error::SchemaValidation(format!(
179 "duplicate field name '{}'",
180 field.name
181 )));
182 }
183
184 if field.name.trim().is_empty() {
185 return Err(Error::SchemaValidation(
186 "field names cannot be empty".to_owned(),
187 ));
188 }
189
190 if let FieldKind::Vector { attrs } = &field.kind {
191 if attrs.dims == 0 {
192 return Err(Error::SchemaValidation(format!(
193 "vector field '{}' must use dims > 0",
194 field.name
195 )));
196 }
197 attrs.validate_svs()?;
198 }
199 }
200
201 Ok(())
202 }
203
204 pub fn field(&self, name: &str) -> Option<&Field> {
206 self.fields.iter().find(|field| field.name == name)
207 }
208
209 pub fn add_field(&mut self, field: Field) -> Result<()> {
219 if self.fields.iter().any(|f| f.name == field.name) {
220 return Err(Error::SchemaValidation(format!(
221 "duplicate field name '{}'",
222 field.name
223 )));
224 }
225 if field.name.trim().is_empty() {
226 return Err(Error::SchemaValidation(
227 "field names cannot be empty".to_owned(),
228 ));
229 }
230 if let FieldKind::Vector { attrs } = &field.kind {
231 if attrs.dims == 0 {
232 return Err(Error::SchemaValidation(format!(
233 "vector field '{}' must use dims > 0",
234 field.name
235 )));
236 }
237 }
238 self.fields.push(field);
239 Ok(())
240 }
241
242 pub fn add_fields(&mut self, fields: Vec<Field>) -> Result<()> {
247 for field in fields {
248 self.add_field(field)?;
249 }
250 Ok(())
251 }
252
253 pub fn remove_field(&mut self, name: &str) -> bool {
255 let before = self.fields.len();
256 self.fields.retain(|f| f.name != name);
257 self.fields.len() != before
258 }
259
260 pub(crate) fn redis_schema_args(&self) -> Vec<String> {
262 self.fields
263 .iter()
264 .flat_map(|field| field.redis_args(self.index.storage_type))
265 .collect()
266 }
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct IndexDefinition {
272 pub name: String,
274 #[serde(default)]
279 pub prefix: Prefix,
280 #[serde(default = "default_key_separator")]
282 pub key_separator: String,
283 #[serde(default = "default_storage_type")]
285 pub storage_type: StorageType,
286 #[serde(default)]
288 pub stopwords: Vec<String>,
289}
290
291#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
293#[serde(rename_all = "lowercase")]
294pub enum StorageType {
295 Hash,
297 Json,
299}
300
301impl StorageType {
302 pub(crate) fn redis_name(self) -> &'static str {
303 match self {
304 Self::Hash => "HASH",
305 Self::Json => "JSON",
306 }
307 }
308}
309
/// Serde default for `IndexDefinition::key_separator`: a single colon.
fn default_key_separator() -> String {
    String::from(":")
}
313
314fn default_storage_type() -> StorageType {
315 StorageType::Hash
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct Field {
321 pub name: String,
323 #[serde(default)]
325 pub path: Option<String>,
326 #[serde(flatten)]
328 pub kind: FieldKind,
329}
330
331impl Field {
332 pub(crate) fn redis_args(&self, storage_type: StorageType) -> Vec<String> {
333 let mut args = Vec::new();
334 match (storage_type, self.path.as_deref()) {
335 (StorageType::Json, Some(path)) => {
336 args.push(path.to_owned());
337 args.push("AS".to_owned());
338 args.push(self.name.clone());
339 }
340 _ => args.push(self.name.clone()),
341 }
342 self.kind.push_redis_args(&mut args);
343 args
344 }
345}
346
347#[derive(Debug, Clone, Serialize, Deserialize)]
349#[serde(tag = "type", rename_all = "snake_case")]
350pub enum FieldKind {
351 Tag {
353 #[serde(default)]
355 attrs: TagFieldAttributes,
356 },
357 Text {
359 #[serde(default)]
361 attrs: TextFieldAttributes,
362 },
363 Numeric {
365 #[serde(default)]
367 attrs: NumericFieldAttributes,
368 },
369 Geo {
371 #[serde(default)]
373 attrs: GeoFieldAttributes,
374 },
375 Timestamp {
377 #[serde(default)]
379 attrs: TimestampFieldAttributes,
380 },
381 Vector {
383 attrs: VectorFieldAttributes,
385 },
386}
387
388impl FieldKind {
389 fn push_redis_args(&self, args: &mut Vec<String>) {
390 match self {
391 Self::Tag { attrs } => {
392 args.push("TAG".to_owned());
393 attrs.push_redis_args(args);
394 }
395 Self::Text { attrs } => {
396 args.push("TEXT".to_owned());
397 attrs.push_redis_args(args);
398 }
399 Self::Numeric { attrs } => {
400 args.push("NUMERIC".to_owned());
401 attrs.push_redis_args(args);
402 }
403 Self::Geo { attrs } => {
404 args.push("GEO".to_owned());
405 attrs.push_redis_args(args);
406 }
407 Self::Timestamp { attrs } => {
408 args.push("NUMERIC".to_owned());
409 attrs.push_redis_args(args);
410 }
411 Self::Vector { attrs } => {
412 args.push("VECTOR".to_owned());
413 args.push(attrs.algorithm.redis_name().to_owned());
414 let vector_args = attrs.redis_attribute_pairs();
415 args.push(vector_args.len().to_string());
416 args.extend(vector_args);
417 }
418 }
419 }
420}
421
422#[derive(Debug, Clone, Default, Serialize, Deserialize)]
424pub struct TagFieldAttributes {
425 pub separator: Option<String>,
427 #[serde(default)]
429 pub case_sensitive: bool,
430 #[serde(default)]
432 pub sortable: bool,
433 #[serde(default)]
435 pub no_index: bool,
436 #[serde(default)]
438 pub index_missing: bool,
439 #[serde(default)]
441 pub index_empty: bool,
442}
443
444impl TagFieldAttributes {
445 fn push_redis_args(&self, args: &mut Vec<String>) {
446 if let Some(separator) = &self.separator {
447 args.push("SEPARATOR".to_owned());
448 args.push(separator.clone());
449 }
450 if self.case_sensitive {
451 args.push("CASESENSITIVE".to_owned());
452 }
453 if self.sortable {
454 args.push("SORTABLE".to_owned());
455 }
456 if self.no_index {
457 args.push("NOINDEX".to_owned());
458 }
459 if self.index_missing {
460 args.push("INDEXMISSING".to_owned());
461 }
462 if self.index_empty {
463 args.push("INDEXEMPTY".to_owned());
464 }
465 }
466}
467
468#[derive(Debug, Clone, Default, Serialize, Deserialize)]
470pub struct TextFieldAttributes {
471 pub weight: Option<f32>,
473 #[serde(default)]
475 pub sortable: bool,
476 #[serde(default)]
478 pub no_stem: bool,
479 #[serde(default)]
481 pub no_index: bool,
482 pub phonetic: Option<String>,
484 #[serde(default)]
486 pub with_suffix_trie: bool,
487 #[serde(default)]
489 pub index_missing: bool,
490 #[serde(default)]
492 pub index_empty: bool,
493}
494
495impl TextFieldAttributes {
496 fn push_redis_args(&self, args: &mut Vec<String>) {
497 if let Some(weight) = self.weight {
498 args.push("WEIGHT".to_owned());
499 args.push(weight.to_string());
500 }
501 if self.sortable {
502 args.push("SORTABLE".to_owned());
503 }
504 if self.no_stem {
505 args.push("NOSTEM".to_owned());
506 }
507 if self.no_index {
508 args.push("NOINDEX".to_owned());
509 }
510 if let Some(phonetic) = &self.phonetic {
511 args.push("PHONETIC".to_owned());
512 args.push(phonetic.clone());
513 }
514 if self.with_suffix_trie {
515 args.push("WITHSUFFIXTRIE".to_owned());
516 }
517 if self.index_missing {
518 args.push("INDEXMISSING".to_owned());
519 }
520 if self.index_empty {
521 args.push("INDEXEMPTY".to_owned());
522 }
523 }
524}
525
526#[derive(Debug, Clone, Default, Serialize, Deserialize)]
528pub struct NumericFieldAttributes {
529 #[serde(default)]
531 pub sortable: bool,
532 #[serde(default)]
534 pub no_index: bool,
535 #[serde(default)]
537 pub index_missing: bool,
538 #[serde(default)]
540 pub index_empty: bool,
541}
542
543impl NumericFieldAttributes {
544 fn push_redis_args(&self, args: &mut Vec<String>) {
545 if self.sortable {
546 args.push("SORTABLE".to_owned());
547 }
548 if self.no_index {
549 args.push("NOINDEX".to_owned());
550 }
551 if self.index_missing {
552 args.push("INDEXMISSING".to_owned());
553 }
554 if self.index_empty {
555 args.push("INDEXEMPTY".to_owned());
556 }
557 }
558}
559
560#[derive(Debug, Clone, Default, Serialize, Deserialize)]
562pub struct GeoFieldAttributes {
563 #[serde(default)]
565 pub sortable: bool,
566 #[serde(default)]
568 pub no_index: bool,
569 #[serde(default)]
571 pub index_missing: bool,
572 #[serde(default)]
574 pub index_empty: bool,
575}
576
577impl GeoFieldAttributes {
578 fn push_redis_args(&self, args: &mut Vec<String>) {
579 if self.sortable {
580 args.push("SORTABLE".to_owned());
581 }
582 if self.no_index {
583 args.push("NOINDEX".to_owned());
584 }
585 if self.index_missing {
586 args.push("INDEXMISSING".to_owned());
587 }
588 if self.index_empty {
589 args.push("INDEXEMPTY".to_owned());
590 }
591 }
592}
593
594#[derive(Debug, Clone, Default, Serialize, Deserialize)]
596pub struct TimestampFieldAttributes {
597 #[serde(default)]
599 pub sortable: bool,
600 #[serde(default)]
602 pub no_index: bool,
603 #[serde(default)]
605 pub index_missing: bool,
606 #[serde(default)]
608 pub index_empty: bool,
609}
610
611impl TimestampFieldAttributes {
612 fn push_redis_args(&self, args: &mut Vec<String>) {
613 if self.sortable {
614 args.push("SORTABLE".to_owned());
615 }
616 if self.no_index {
617 args.push("NOINDEX".to_owned());
618 }
619 if self.index_missing {
620 args.push("INDEXMISSING".to_owned());
621 }
622 if self.index_empty {
623 args.push("INDEXEMPTY".to_owned());
624 }
625 }
626}
627
628#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
630pub enum VectorAlgorithm {
631 #[serde(alias = "flat", alias = "FLAT")]
633 Flat,
634 #[serde(alias = "hnsw", alias = "HNSW")]
636 Hnsw,
637 #[serde(
639 alias = "svs-vamana",
640 alias = "SVS-VAMANA",
641 alias = "svs_vamana",
642 alias = "SVS_VAMANA"
643 )]
644 SvsVamana,
645}
646
647impl VectorAlgorithm {
648 fn redis_name(self) -> &'static str {
649 match self {
650 Self::Flat => "FLAT",
651 Self::Hnsw => "HNSW",
652 Self::SvsVamana => "SVS-VAMANA",
653 }
654 }
655}
656
657#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
659pub enum SvsCompressionType {
660 #[serde(alias = "lvq4", alias = "LVQ4")]
662 Lvq4,
663 #[serde(alias = "lvq4x4", alias = "LVQ4x4")]
665 Lvq4x4,
666 #[serde(alias = "lvq4x8", alias = "LVQ4x8")]
668 Lvq4x8,
669 #[serde(alias = "lvq8", alias = "LVQ8")]
671 Lvq8,
672 #[serde(alias = "leanvec4x8", alias = "LeanVec4x8")]
674 LeanVec4x8,
675 #[serde(alias = "leanvec8x8", alias = "LeanVec8x8")]
677 LeanVec8x8,
678}
679
680impl SvsCompressionType {
681 fn redis_name(self) -> &'static str {
682 match self {
683 Self::Lvq4 => "LVQ4",
684 Self::Lvq4x4 => "LVQ4x4",
685 Self::Lvq4x8 => "LVQ4x8",
686 Self::Lvq8 => "LVQ8",
687 Self::LeanVec4x8 => "LeanVec4x8",
688 Self::LeanVec8x8 => "LeanVec8x8",
689 }
690 }
691
692 fn is_lean_vec(self) -> bool {
693 matches!(self, Self::LeanVec4x8 | Self::LeanVec8x8)
694 }
695}
696
697#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
699#[serde(rename_all = "UPPERCASE")]
700pub enum VectorDataType {
701 #[serde(alias = "bfloat16", alias = "Bfloat16")]
703 Bfloat16,
704 #[serde(alias = "float16", alias = "Float16")]
706 Float16,
707 #[serde(alias = "float32", alias = "Float32")]
709 Float32,
710 #[serde(alias = "float64", alias = "Float64")]
712 Float64,
713}
714
715impl VectorDataType {
716 fn redis_name(self) -> &'static str {
717 match self {
718 Self::Bfloat16 => "BFLOAT16",
719 Self::Float16 => "FLOAT16",
720 Self::Float32 => "FLOAT32",
721 Self::Float64 => "FLOAT64",
722 }
723 }
724
725 pub fn as_str(self) -> &'static str {
730 match self {
731 Self::Bfloat16 => "bfloat16",
732 Self::Float16 => "float16",
733 Self::Float32 => "float32",
734 Self::Float64 => "float64",
735 }
736 }
737}
738
739impl std::fmt::Display for VectorDataType {
740 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
741 f.write_str(self.as_str())
742 }
743}
744
745impl std::str::FromStr for VectorDataType {
746 type Err = crate::Error;
747
748 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
749 match s.to_lowercase().as_str() {
750 "bfloat16" => Ok(Self::Bfloat16),
751 "float16" => Ok(Self::Float16),
752 "float32" => Ok(Self::Float32),
753 "float64" => Ok(Self::Float64),
754 other => Err(crate::Error::InvalidInput(format!(
755 "unknown vector data type '{other}'; expected bfloat16, float16, float32, or float64"
756 ))),
757 }
758 }
759}
760
761impl Default for VectorDataType {
762 fn default() -> Self {
763 Self::Float32
764 }
765}
766
767#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
769#[serde(rename_all = "UPPERCASE")]
770pub enum VectorDistanceMetric {
771 #[serde(alias = "cosine", alias = "Cosine")]
773 Cosine,
774 #[serde(alias = "l2", alias = "L2")]
776 L2,
777 #[serde(alias = "ip", alias = "Ip")]
779 Ip,
780}
781
782impl VectorDistanceMetric {
783 fn redis_name(self) -> &'static str {
784 match self {
785 Self::Cosine => "COSINE",
786 Self::L2 => "L2",
787 Self::Ip => "IP",
788 }
789 }
790}
791
792#[derive(Debug, Clone, Serialize, Deserialize)]
794pub struct VectorFieldAttributes {
795 pub algorithm: VectorAlgorithm,
797 pub dims: usize,
799 pub distance_metric: VectorDistanceMetric,
801 pub datatype: VectorDataType,
803 pub initial_cap: Option<usize>,
805 pub block_size: Option<usize>,
807 pub m: Option<usize>,
809 pub ef_construction: Option<usize>,
811 pub ef_runtime: Option<usize>,
813 pub epsilon: Option<f32>,
815 pub graph_max_degree: Option<usize>,
818 pub construction_window_size: Option<usize>,
820 pub search_window_size: Option<usize>,
822 pub compression: Option<SvsCompressionType>,
824 pub reduce: Option<usize>,
826 pub training_threshold: Option<usize>,
828}
829
830impl VectorFieldAttributes {
831 fn redis_attribute_pairs(&self) -> Vec<String> {
832 let mut args = vec![
833 "TYPE".to_owned(),
834 self.datatype.redis_name().to_owned(),
835 "DIM".to_owned(),
836 self.dims.to_string(),
837 "DISTANCE_METRIC".to_owned(),
838 self.distance_metric.redis_name().to_owned(),
839 ];
840
841 if let Some(initial_cap) = self.initial_cap {
842 args.push("INITIAL_CAP".to_owned());
843 args.push(initial_cap.to_string());
844 }
845 if let Some(block_size) = self.block_size {
847 args.push("BLOCK_SIZE".to_owned());
848 args.push(block_size.to_string());
849 }
850 if let Some(m) = self.m {
852 args.push("M".to_owned());
853 args.push(m.to_string());
854 }
855 if let Some(ef_construction) = self.ef_construction {
856 args.push("EF_CONSTRUCTION".to_owned());
857 args.push(ef_construction.to_string());
858 }
859 if let Some(ef_runtime) = self.ef_runtime {
860 args.push("EF_RUNTIME".to_owned());
861 args.push(ef_runtime.to_string());
862 }
863 if let Some(epsilon) = self.epsilon {
864 args.push("EPSILON".to_owned());
865 args.push(epsilon.to_string());
866 }
867 if let Some(graph_max_degree) = self.graph_max_degree {
869 args.push("GRAPH_MAX_DEGREE".to_owned());
870 args.push(graph_max_degree.to_string());
871 }
872 if let Some(construction_window_size) = self.construction_window_size {
873 args.push("CONSTRUCTION_WINDOW_SIZE".to_owned());
874 args.push(construction_window_size.to_string());
875 }
876 if let Some(search_window_size) = self.search_window_size {
877 args.push("SEARCH_WINDOW_SIZE".to_owned());
878 args.push(search_window_size.to_string());
879 }
880 if let Some(compression) = self.compression {
881 args.push("COMPRESSION".to_owned());
882 args.push(compression.redis_name().to_owned());
883 }
884 if let Some(reduce) = self.reduce {
885 args.push("REDUCE".to_owned());
886 args.push(reduce.to_string());
887 }
888 if let Some(training_threshold) = self.training_threshold {
889 args.push("TRAINING_THRESHOLD".to_owned());
890 args.push(training_threshold.to_string());
891 }
892
893 args
894 }
895
896 pub fn validate_svs(&self) -> Result<()> {
902 if self.algorithm != VectorAlgorithm::SvsVamana {
903 return Ok(());
904 }
905 if !matches!(
907 self.datatype,
908 VectorDataType::Float16 | VectorDataType::Float32
909 ) {
910 return Err(Error::SchemaValidation(format!(
911 "SVS-VAMANA only supports FLOAT16 and FLOAT32 datatypes, got {}",
912 self.datatype
913 )));
914 }
915 if let Some(reduce) = self.reduce {
917 match self.compression {
918 None => {
919 return Err(Error::SchemaValidation(
920 "reduce parameter requires compression to be set".to_owned(),
921 ));
922 }
923 Some(c) if !c.is_lean_vec() => {
924 return Err(Error::SchemaValidation(format!(
925 "reduce parameter is only supported with LeanVec compression types, got {:?}",
926 c
927 )));
928 }
929 _ => {}
930 }
931 if reduce >= self.dims {
932 return Err(Error::SchemaValidation(format!(
933 "reduce ({reduce}) must be less than dims ({})",
934 self.dims
935 )));
936 }
937 }
938 Ok(())
939 }
940}
941
#[cfg(test)]
mod tests {
    use serde_json::{json, Value};

    use super::{
        Field, FieldKind, IndexSchema, Prefix, StorageType, SvsCompressionType,
        TagFieldAttributes, TextFieldAttributes, VectorAlgorithm, VectorDataType,
        VectorDistanceMetric, VectorFieldAttributes,
    };

    /// Wraps `attrs` into a schema document with one vector field named `vec`
    /// on the index `test-svs-index`.
    fn svs_schema_json(attrs: Value) -> Value {
        json!({
            "index": { "name": "test-svs-index" },
            "fields": [{
                "name": "vec",
                "type": "vector",
                "attrs": attrs
            }]
        })
    }

    /// Returns the attributes of the first field, which must be a vector field.
    fn first_vector_attrs(schema: &IndexSchema) -> &VectorFieldAttributes {
        match &schema.fields[0].kind {
            FieldKind::Vector { attrs } => attrs,
            _ => panic!("expected vector field"),
        }
    }

    /// Renders the first field's arguments for hash storage.
    fn first_field_hash_args(schema: &IndexSchema) -> Vec<String> {
        schema.fields[0].redis_args(StorageType::Hash)
    }

    /// Reports whether `keyword` appears verbatim in `args`.
    fn has_arg(args: &[String], keyword: &str) -> bool {
        args.iter().any(|arg| arg == keyword)
    }

    #[test]
    fn schema_from_yaml_should_parse_json_storage() {
        let yaml = r"
index:
  name: docs
  prefix: doc
  storage_type: json
fields:
  - name: title
    path: $.title
    type: text
  - name: embedding
    path: $.embedding
    type: vector
    attrs:
      algorithm: HNSW
      dims: 3
      datatype: FLOAT32
      distance_metric: COSINE
";
        let schema = IndexSchema::from_yaml_str(yaml).expect("schema should parse");

        assert!(matches!(schema.index.storage_type, StorageType::Json));
        assert_eq!(schema.fields.len(), 2);
    }

    #[test]
    fn schema_should_apply_defaults_like_python_unit_tests() {
        let document = json!({ "index": { "name": "test" } });
        let schema = IndexSchema::from_json_value(document).expect("schema should parse");
        let index = &schema.index;

        assert_eq!(index.prefix.first(), "rvl");
        assert_eq!(index.key_separator, ":");
        assert!(matches!(index.storage_type, StorageType::Hash));
        assert!(schema.fields.is_empty());
    }

    #[test]
    fn schema_should_accept_multi_prefix_list_like_python_multi_prefix_tests() {
        let document = json!({
            "index": {
                "name": "test",
                "prefix": ["pfx_a", "pfx_b"]
            }
        });
        let schema = IndexSchema::from_json_value(document).expect("schema should parse");
        let prefix = &schema.index.prefix;

        assert_eq!(prefix.len(), 2);
        assert_eq!(prefix.first(), "pfx_a");
        assert_eq!(prefix.all(), vec!["pfx_a", "pfx_b"]);
        assert!(matches!(prefix, Prefix::Multi(_)));
    }

    #[test]
    fn schema_should_accept_single_string_prefix_like_python_tests() {
        let document = json!({
            "index": {
                "name": "test",
                "prefix": "my_prefix"
            }
        });
        let schema = IndexSchema::from_json_value(document).expect("schema should parse");
        let prefix = &schema.index.prefix;

        assert_eq!(prefix.first(), "my_prefix");
        assert_eq!(prefix.len(), 1);
        assert_eq!(prefix.all(), vec!["my_prefix"]);
        assert!(matches!(prefix, Prefix::Single(_)));
    }

    #[test]
    fn schema_multi_prefix_yaml_should_parse() {
        let yaml = r"
index:
  name: multi
  prefix:
    - alpha
    - beta
fields:
  - name: tag
    type: tag
";
        let schema = IndexSchema::from_yaml_str(yaml).expect("schema should parse");

        assert_eq!(schema.index.prefix.len(), 2);
        assert_eq!(schema.index.prefix.all(), vec!["alpha", "beta"]);
    }

    #[test]
    fn tag_field_index_missing_should_render_indexmissing_arg() {
        let document = json!({
            "index": { "name": "test_missing" },
            "fields": [
                { "name": "brand", "type": "tag", "attrs": { "index_missing": true } }
            ]
        });
        let schema = IndexSchema::from_json_value(document).expect("schema should parse");

        let args = first_field_hash_args(&schema);
        assert!(has_arg(&args, "INDEXMISSING"));
    }

    #[test]
    fn numeric_field_index_empty_should_render_indexempty_arg() {
        let document = json!({
            "index": { "name": "test_empty" },
            "fields": [
                { "name": "price", "type": "numeric", "attrs": { "index_empty": true } }
            ]
        });
        let schema = IndexSchema::from_json_value(document).expect("schema should parse");

        let args = first_field_hash_args(&schema);
        assert!(has_arg(&args, "INDEXEMPTY"));
    }

    #[test]
    fn text_field_both_index_missing_and_index_empty() {
        let document = json!({
            "index": { "name": "test_both" },
            "fields": [
                {
                    "name": "description",
                    "type": "text",
                    "attrs": { "index_missing": true, "index_empty": true }
                }
            ]
        });
        let schema = IndexSchema::from_json_value(document).expect("schema should parse");

        let args = first_field_hash_args(&schema);
        assert!(has_arg(&args, "INDEXMISSING"));
        assert!(has_arg(&args, "INDEXEMPTY"));
    }

    #[test]
    fn fields_default_to_no_index_missing_or_empty() {
        let yaml = r"
index:
  name: test_defaults
fields:
  - name: brand
    type: tag
";
        let schema = IndexSchema::from_yaml_str(yaml).expect("schema should parse");

        let args = first_field_hash_args(&schema);
        assert!(!has_arg(&args, "INDEXMISSING"));
        assert!(!has_arg(&args, "INDEXEMPTY"));
    }

    #[test]
    fn vector_data_type_from_str_roundtrip() {
        let cases: &[(&str, VectorDataType)] = &[
            ("bfloat16", VectorDataType::Bfloat16),
            ("float16", VectorDataType::Float16),
            ("float32", VectorDataType::Float32),
            ("float64", VectorDataType::Float64),
            ("BFLOAT16", VectorDataType::Bfloat16),
            ("FLOAT16", VectorDataType::Float16),
            ("FLOAT32", VectorDataType::Float32),
            ("FLOAT64", VectorDataType::Float64),
            ("Float32", VectorDataType::Float32),
        ];
        for &(input, expected) in cases {
            let parsed = input
                .parse::<VectorDataType>()
                .unwrap_or_else(|_| panic!("should parse '{input}'"));
            assert_eq!(parsed, expected, "mismatch for input '{input}'");
        }

        assert!("int8".parse::<VectorDataType>().is_err());
        assert!("".parse::<VectorDataType>().is_err());
    }

    #[test]
    fn vector_data_type_as_str_and_display() {
        assert_eq!(VectorDataType::Bfloat16.as_str(), "bfloat16");
        assert_eq!(VectorDataType::Float16.as_str(), "float16");
        assert_eq!(VectorDataType::Float32.as_str(), "float32");
        assert_eq!(VectorDataType::Float64.as_str(), "float64");

        assert_eq!(VectorDataType::Float32.to_string(), "float32");
        assert_eq!(VectorDataType::Bfloat16.to_string(), "bfloat16");
    }

    #[test]
    fn vector_data_type_default_is_float32() {
        assert_eq!(VectorDataType::default(), VectorDataType::Float32);
    }

    #[test]
    fn vector_data_type_serde_uppercase() {
        let serialize = |data_type: VectorDataType| serde_json::to_string(&data_type).unwrap();

        assert_eq!(serialize(VectorDataType::Bfloat16), "\"BFLOAT16\"");
        assert_eq!(serialize(VectorDataType::Float16), "\"FLOAT16\"");

        let deserialized: VectorDataType = serde_json::from_str("\"FLOAT64\"").unwrap();
        assert_eq!(deserialized, VectorDataType::Float64);
    }

    #[test]
    fn vector_data_type_serde_lowercase_aliases() {
        let cases: &[(&str, VectorDataType)] = &[
            ("\"float32\"", VectorDataType::Float32),
            ("\"float64\"", VectorDataType::Float64),
            ("\"float16\"", VectorDataType::Float16),
            ("\"bfloat16\"", VectorDataType::Bfloat16),
            ("\"Float32\"", VectorDataType::Float32),
            ("\"Bfloat16\"", VectorDataType::Bfloat16),
        ];
        for &(input, expected) in cases {
            let deserialized: VectorDataType = serde_json::from_str(input)
                .unwrap_or_else(|e| panic!("should deserialize {input}: {e}"));
            assert_eq!(deserialized, expected, "mismatch for input {input}");
        }
    }

    #[test]
    fn vector_distance_metric_serde_lowercase_aliases() {
        let cases: &[(&str, &str)] = &[
            ("\"COSINE\"", "COSINE"),
            ("\"cosine\"", "COSINE"),
            ("\"Cosine\"", "COSINE"),
            ("\"L2\"", "L2"),
            ("\"l2\"", "L2"),
            ("\"IP\"", "IP"),
            ("\"ip\"", "IP"),
        ];
        for &(input, expected_name) in cases {
            let deserialized: VectorDistanceMetric = serde_json::from_str(input)
                .unwrap_or_else(|e| panic!("should deserialize {input}: {e}"));
            assert_eq!(
                deserialized.redis_name(),
                expected_name,
                "mismatch for input {input}"
            );
        }
    }

    #[test]
    fn schema_from_json_with_lowercase_dtype() {
        let document = json!({
            "index": { "name": "lc_test", "prefix": "lc" },
            "fields": [{
                "name": "vec",
                "type": "vector",
                "attrs": {
                    "algorithm": "flat",
                    "dims": 3,
                    "datatype": "float32",
                    "distance_metric": "cosine"
                }
            }]
        });
        let schema = IndexSchema::from_json_value(document)
            .expect("schema with lowercase dtype/distance_metric should parse");

        assert_eq!(first_vector_attrs(&schema).datatype, VectorDataType::Float32);
    }

    #[test]
    fn schema_from_yaml_bfloat16_vector() {
        let yaml = r"
index:
  name: bf16test
  prefix: bf16
fields:
  - name: vec
    type: vector
    attrs:
      algorithm: FLAT
      dims: 4
      datatype: BFLOAT16
      distance_metric: COSINE
";
        let schema = IndexSchema::from_yaml_str(yaml).expect("schema with BFLOAT16 should parse");

        assert_eq!(schema.index.name, "bf16test");
        assert_eq!(
            first_vector_attrs(&schema).datatype,
            VectorDataType::Bfloat16
        );
    }

    #[test]
    fn schema_from_yaml_float16_vector() {
        let yaml = r"
index:
  name: f16test
  prefix: f16
fields:
  - name: vec
    type: vector
    attrs:
      algorithm: HNSW
      dims: 8
      datatype: FLOAT16
      distance_metric: L2
";
        let schema = IndexSchema::from_yaml_str(yaml).expect("schema with FLOAT16 should parse");

        assert_eq!(first_vector_attrs(&schema).datatype, VectorDataType::Float16);
    }

    #[test]
    fn add_field_should_append_and_validate() {
        let document = json!({
            "index": { "name": "test" },
            "fields": [
                { "name": "title", "type": "text" }
            ]
        });
        let mut schema = IndexSchema::from_json_value(document).expect("schema should parse");
        assert_eq!(schema.fields.len(), 1);

        let brand = Field {
            name: "brand".to_owned(),
            path: None,
            kind: FieldKind::Tag {
                attrs: TagFieldAttributes::default(),
            },
        };
        schema.add_field(brand).expect("add_field should succeed");

        assert_eq!(schema.fields.len(), 2);
        assert!(schema.field("brand").is_some());
    }

    #[test]
    fn add_field_duplicate_should_error() {
        let document = json!({
            "index": { "name": "test" },
            "fields": [
                { "name": "title", "type": "text" }
            ]
        });
        let mut schema = IndexSchema::from_json_value(document).expect("schema should parse");

        let duplicate = Field {
            name: "title".to_owned(),
            path: None,
            kind: FieldKind::Text {
                attrs: TextFieldAttributes::default(),
            },
        };
        assert!(schema.add_field(duplicate).is_err());
    }

    #[test]
    fn remove_field_should_drop_by_name() {
        let document = json!({
            "index": { "name": "test" },
            "fields": [
                { "name": "title", "type": "text" },
                { "name": "brand", "type": "tag" }
            ]
        });
        let mut schema = IndexSchema::from_json_value(document).expect("schema should parse");

        assert_eq!(schema.fields.len(), 2);
        assert!(schema.remove_field("title"));
        assert_eq!(schema.fields.len(), 1);
        assert!(schema.field("title").is_none());
        // A second removal finds nothing left to remove.
        assert!(!schema.remove_field("title"));
    }

    #[test]
    fn svs_vamana_schema_with_float32_should_parse() {
        let document = svs_schema_json(json!({
            "algorithm": "SvsVamana",
            "dims": 128,
            "distance_metric": "COSINE",
            "datatype": "FLOAT32"
        }));
        let schema = IndexSchema::from_json_value(document)
            .expect("SVS-VAMANA with float32 should parse");

        assert_eq!(
            first_vector_attrs(&schema).algorithm,
            VectorAlgorithm::SvsVamana
        );
    }

    #[test]
    fn svs_vamana_with_float64_should_fail_validation() {
        let outcome = IndexSchema::from_json_value(svs_schema_json(json!({
            "algorithm": "SvsVamana",
            "dims": 128,
            "distance_metric": "COSINE",
            "datatype": "FLOAT64"
        })));
        assert!(outcome.is_err(), "SVS-VAMANA should reject FLOAT64");
    }

    #[test]
    fn svs_vamana_with_compression_and_reduce() {
        let document = svs_schema_json(json!({
            "algorithm": "SvsVamana",
            "dims": 128,
            "distance_metric": "COSINE",
            "datatype": "FLOAT32",
            "compression": "LeanVec4x8",
            "reduce": 64
        }));
        let schema = IndexSchema::from_json_value(document)
            .expect("SVS-VAMANA with LeanVec + reduce should parse");

        let attrs = first_vector_attrs(&schema);
        assert_eq!(attrs.compression, Some(SvsCompressionType::LeanVec4x8));
        assert_eq!(attrs.reduce, Some(64));
    }

    #[test]
    fn svs_vamana_reduce_without_compression_should_fail() {
        let outcome = IndexSchema::from_json_value(svs_schema_json(json!({
            "algorithm": "SvsVamana",
            "dims": 128,
            "distance_metric": "COSINE",
            "datatype": "FLOAT32",
            "reduce": 64
        })));
        assert!(
            outcome.is_err(),
            "SVS-VAMANA reduce without compression should fail"
        );
    }

    #[test]
    fn svs_vamana_reduce_with_lvq4_should_fail() {
        let outcome = IndexSchema::from_json_value(svs_schema_json(json!({
            "algorithm": "SvsVamana",
            "dims": 128,
            "distance_metric": "COSINE",
            "datatype": "FLOAT32",
            "compression": "Lvq4",
            "reduce": 64
        })));
        assert!(outcome.is_err(), "SVS-VAMANA reduce with LVQ4 should fail");
    }

    #[test]
    fn svs_vamana_reduce_gte_dims_should_fail() {
        let outcome = IndexSchema::from_json_value(svs_schema_json(json!({
            "algorithm": "SvsVamana",
            "dims": 128,
            "distance_metric": "COSINE",
            "datatype": "FLOAT32",
            "compression": "LeanVec4x8",
            "reduce": 128
        })));
        assert!(outcome.is_err(), "SVS-VAMANA reduce >= dims should fail");
    }

    #[test]
    fn svs_vamana_redis_args_include_svs_params() {
        let document = svs_schema_json(json!({
            "algorithm": "SvsVamana",
            "dims": 128,
            "distance_metric": "COSINE",
            "datatype": "FLOAT32",
            "graph_max_degree": 40,
            "construction_window_size": 250,
            "search_window_size": 20,
            "compression": "Lvq8",
            "training_threshold": 10000
        }));
        let schema = IndexSchema::from_json_value(document).expect("SVS schema should parse");

        let args = first_field_hash_args(&schema);
        let expected_keywords = [
            "VECTOR",
            "SVS-VAMANA",
            "GRAPH_MAX_DEGREE",
            "40",
            "CONSTRUCTION_WINDOW_SIZE",
            "SEARCH_WINDOW_SIZE",
            "COMPRESSION",
            "LVQ8",
            "TRAINING_THRESHOLD",
        ];
        for keyword in expected_keywords {
            assert!(has_arg(&args, keyword));
        }
    }
}