// diskann_tools/utils/multi_label.rs
use std::collections::HashSet;

/// Label data parsed from text, holding either the labels of a base item or a
/// label filter for a query.
///
/// A value built with [`MultiLabel::from_base`] populates only `base_clause`;
/// a value built with [`MultiLabel::from_query`] populates only `query_clause`.
#[derive(Debug, Default)]
pub struct MultiLabel {
    /// Labels parsed from a comma-separated base label string.
    base_clause: HashSet<String>,
    /// Query filter in AND-of-ORs form: the outer `Vec` holds the groups that
    /// were separated by `&`, and each inner `Vec` holds the alternatives that
    /// were separated by `|` inside that group.
    query_clause: Vec<Vec<String>>,
}

impl MultiLabel {
    /// Creates a `MultiLabel` with no base labels and no query groups.
    pub fn new() -> Self {
        Self::default()
    }

    /// Parses a comma-separated label list into the base label set.
    ///
    /// Each token is trimmed of surrounding whitespace. Empty tokens are kept,
    /// so `from_base("")` stores a single empty-string label.
    pub fn from_base(base_label: &str) -> Self {
        MultiLabel {
            base_clause: base_label
                .split(',')
                .map(|token| token.trim().to_string())
                .collect(),
            query_clause: Vec::new(),
        }
    }

    /// Parses a query filter of the form `a&b|c&d` into AND-of-ORs form.
    ///
    /// The string is split on `&` into groups, and each group is split on `|`
    /// into alternatives. Every alternative is trimmed of surrounding
    /// whitespace. Empty tokens are kept, so `from_query("")` yields one group
    /// holding a single empty-string alternative.
    pub fn from_query(query_label: &str) -> Self {
        MultiLabel {
            base_clause: HashSet::new(),
            query_clause: query_label
                .split('&')
                .map(|group| {
                    group
                        .split('|')
                        .map(|alternative| alternative.trim().to_string())
                        .collect()
                })
                .collect(),
        }
    }

    /// Renders the query filter back into `a&b|c` text.
    ///
    /// Alternatives within a group are joined with `|` and groups are joined
    /// with `&`. An empty query renders as the empty string.
    fn format_query(&self) -> String {
        self.query_clause
            .iter()
            .map(|group| group.join("|"))
            .collect::<Vec<_>>()
            .join("&")
    }

    /// Prints the query filter to standard output, followed by a newline.
    pub fn print_query(&self) {
        // Build the whole line first so stdout is written once instead of
        // once per token and separator.
        println!("{}", self.format_query());
    }

    /// Returns `true` when `base_label` satisfies this query filter.
    ///
    /// Every group of the query must have at least one alternative present in
    /// `base_label`'s base label set. A query with no groups is satisfied by
    /// any base. Only `self.query_clause` and `base_label.base_clause` are
    /// consulted.
    pub fn is_subset_of(&self, base_label: &MultiLabel) -> bool {
        self.query_clause.iter().all(|group| {
            group
                .iter()
                .any(|alternative| base_label.base_clause.contains(alternative))
        })
    }
}

#[cfg(test)]
mod tests {
    use super::MultiLabel;

    /// Comma-separated base label string shared by every case below.
    const BASE_LABEL: &str = "BRAND=Caltric,CAT=Automotive,CAT=MotorcyclePowersports,CAT=Parts,CAT=Filters,CAT=OilFilters,RATING=5";

    /// Checks `is_subset_of` against one base for a query whose first group
    /// fails, a query where every group matches, and a query whose OR group
    /// fails.
    #[test]
    fn test_subset1() {
        let base_ml = MultiLabel::from_base(BASE_LABEL);

        // (query text, whether the base is expected to satisfy it)
        let cases = [
            // First AND group has no match in the base.
            ("CAT=ExteriorAccessories&RATING=4|RATING=5", false),
            // Both groups match; the OR group matches via its second alternative.
            ("CAT=Automotive&RATING=4|RATING=5", true),
            // First group matches but neither OR alternative is in the base.
            ("CAT=Automotive&RATING=3|RATING=4", false),
        ];

        for (query_label, expected) in cases {
            let query_ml = MultiLabel::from_query(query_label);

            query_ml.print_query();

            assert_eq!(
                query_ml.is_subset_of(&base_ml),
                expected,
                "query: {}",
                query_label
            );
        }
    }
}