Skip to main content

diskann_tools/utils/
multi_label.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

6use std::collections::HashSet;
7
/// Labels attached to a base (indexed) point, or a label filter attached to a query.
///
/// A value is used in one of two roles:
/// * built with [`MultiLabel::from_base`], only the base label set is populated;
/// * built with [`MultiLabel::from_query`], only the query clause is populated.
#[derive(Debug, Default)]
pub struct MultiLabel {
    /// Labels of a base point, parsed from a comma-separated list.
    base_clause: HashSet<String>,
    /// Query filter: the outer `Vec` is an AND of groups, and each inner `Vec`
    /// is an OR of labels.
    query_clause: Vec<Vec<String>>,
}

impl MultiLabel {
    /// Creates an empty `MultiLabel` with no base labels and no query groups.
    pub fn new() -> Self {
        Self::default()
    }

    /// Parses a base label string such as `"BRAND=Caltric,CAT=Automotive"`.
    ///
    /// The input is split on `,` and each token is trimmed of surrounding
    /// whitespace. Empty tokens (e.g. from `""` or a trailing comma) are kept
    /// as the empty string.
    pub fn from_base(base_label: &str) -> Self {
        Self {
            base_clause: base_label
                .split(',')
                .map(|token| token.trim().to_string())
                .collect(),
            ..Self::default()
        }
    }

    /// Parses a query filter such as `"CAT=Automotive&RATING=4|RATING=5"`.
    ///
    /// `&` separates AND groups and `|` separates the alternatives inside a
    /// group, so `|` binds tighter than `&`. Every label is trimmed of
    /// surrounding whitespace. Parentheses and escaping are not supported.
    pub fn from_query(query_label: &str) -> Self {
        Self {
            query_clause: query_label
                .split('&')
                .map(|group| {
                    group
                        .split('|')
                        .map(|term| term.trim().to_string())
                        .collect()
                })
                .collect(),
            ..Self::default()
        }
    }

    /// Renders the query clause back into `a|b&c` form (no trailing newline).
    ///
    /// Labels are emitted as stored, i.e. already trimmed by
    /// [`MultiLabel::from_query`].
    pub fn to_query_string(&self) -> String {
        self.query_clause
            .iter()
            .map(|group| group.join("|"))
            .collect::<Vec<_>>()
            .join("&")
    }

    /// Prints the query clause to stdout in `a|b&c` form, followed by a newline.
    pub fn print_query(&self) {
        println!("{}", self.to_query_string());
    }

    /// Returns `true` if `base_label` satisfies this query.
    ///
    /// Every AND group of `self` must contain at least one label present in
    /// `base_label`'s base label set. A query with no groups (e.g. from
    /// [`MultiLabel::new`]) is vacuously satisfied. Only this value's query
    /// clause and `base_label`'s base labels are consulted.
    pub fn is_subset_of(&self, base_label: &MultiLabel) -> bool {
        self.query_clause.iter().all(|group| {
            group
                .iter()
                .any(|term| base_label.base_clause.contains(term))
        })
    }
}
80
#[cfg(test)]
mod tests {
    use super::MultiLabel;

    const BASE_LABEL: &str = "BRAND=Caltric,CAT=Automotive,CAT=MotorcyclePowersports,CAT=Parts,CAT=Filters,CAT=OilFilters,RATING=5";

    /// Parses `query_label`, echoes it via `print_query`, and reports whether
    /// it is satisfied by `BASE_LABEL`.
    fn matches_base(query_label: &str) -> bool {
        let base_ml = MultiLabel::from_base(BASE_LABEL);
        let query_ml = MultiLabel::from_query(query_label);
        query_ml.print_query();
        query_ml.is_subset_of(&base_ml)
    }

    #[test]
    fn test_subset1() {
        // The first AND group ("CAT=ExteriorAccessories") has no label in the base set.
        assert!(!matches_base("CAT=ExteriorAccessories&RATING=4|RATING=5"));

        // Every AND group is satisfied; the rating group via its second alternative.
        assert!(matches_base("CAT=Automotive&RATING=4|RATING=5"));

        // The category group matches, but no alternative of the rating group does.
        assert!(!matches_base("CAT=Automotive&RATING=3|RATING=4"));
    }

    #[test]
    fn test_tokens_are_trimmed() {
        let base_ml = MultiLabel::from_base(" a , b ");

        assert!(MultiLabel::from_query(" b | z & a ").is_subset_of(&base_ml));
        assert!(!MultiLabel::from_query("a & c").is_subset_of(&base_ml));
    }

    #[test]
    fn test_empty_query_is_vacuously_satisfied() {
        let base_ml = MultiLabel::from_base(BASE_LABEL);

        // `new()` has no AND groups, so there is nothing left to fail.
        assert!(MultiLabel::new().is_subset_of(&base_ml));
    }
}
112}