rs_puff/
rank_by.rs

1use serde::ser::{SerializeSeq, Serializer};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "lowercase")]
6pub enum Order {
7    Asc,
8    Desc,
9}
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub struct Bm25Params {
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub last_as_prefix: Option<bool>,
15}
16
17#[derive(Debug, Clone, PartialEq)]
18pub enum RankBy {
19    // Vector search: ["attr", "ANN", [vector]]
20    Vector { attr: String, query: Vec<f32> },
21    // Exact kNN: ["attr", "kNN", [vector]]
22    VectorKnn { attr: String, query: Vec<f32> },
23    // BM25 text search: ["attr", "BM25", "query"]
24    Bm25 { attr: String, query: String, params: Option<Bm25Params> },
25    // Attribute ordering: ["attr", "asc"|"desc"]
26    Attribute { attr: String, order: Order },
27    // Combinators
28    Sum(Vec<RankBy>),
29    Max(Vec<RankBy>),
30    Product { weight: f64, subquery: Box<RankBy> },
31}
32
33impl RankBy {
34    pub fn vector(attr: impl Into<String>, query: Vec<f32>) -> Self {
35        RankBy::Vector { attr: attr.into(), query }
36    }
37
38    pub fn vector_knn(attr: impl Into<String>, query: Vec<f32>) -> Self {
39        RankBy::VectorKnn { attr: attr.into(), query }
40    }
41
42    pub fn bm25(attr: impl Into<String>, query: impl Into<String>) -> Self {
43        RankBy::Bm25 { attr: attr.into(), query: query.into(), params: None }
44    }
45
46    pub fn bm25_with_params(attr: impl Into<String>, query: impl Into<String>, params: Bm25Params) -> Self {
47        RankBy::Bm25 { attr: attr.into(), query: query.into(), params: Some(params) }
48    }
49
50    pub fn attribute(attr: impl Into<String>, order: Order) -> Self {
51        RankBy::Attribute { attr: attr.into(), order }
52    }
53
54    pub fn asc(attr: impl Into<String>) -> Self {
55        RankBy::Attribute { attr: attr.into(), order: Order::Asc }
56    }
57
58    pub fn desc(attr: impl Into<String>) -> Self {
59        RankBy::Attribute { attr: attr.into(), order: Order::Desc }
60    }
61
62    pub fn sum(subqueries: Vec<RankBy>) -> Self {
63        RankBy::Sum(subqueries)
64    }
65
66    pub fn max(subqueries: Vec<RankBy>) -> Self {
67        RankBy::Max(subqueries)
68    }
69
70    pub fn product(weight: f64, subquery: RankBy) -> Self {
71        RankBy::Product { weight, subquery: Box::new(subquery) }
72    }
73}
74
75impl Serialize for RankBy {
76    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77    where
78        S: Serializer,
79    {
80        match self {
81            RankBy::Vector { attr, query } => {
82                let mut seq = serializer.serialize_seq(Some(3))?;
83                seq.serialize_element(attr)?;
84                seq.serialize_element("ANN")?;
85                seq.serialize_element(query)?;
86                seq.end()
87            }
88            RankBy::VectorKnn { attr, query } => {
89                let mut seq = serializer.serialize_seq(Some(3))?;
90                seq.serialize_element(attr)?;
91                seq.serialize_element("kNN")?;
92                seq.serialize_element(query)?;
93                seq.end()
94            }
95            RankBy::Bm25 { attr, query, params } => {
96                if let Some(p) = params {
97                    let mut seq = serializer.serialize_seq(Some(4))?;
98                    seq.serialize_element(attr)?;
99                    seq.serialize_element("BM25")?;
100                    seq.serialize_element(query)?;
101                    seq.serialize_element(p)?;
102                    seq.end()
103                } else {
104                    let mut seq = serializer.serialize_seq(Some(3))?;
105                    seq.serialize_element(attr)?;
106                    seq.serialize_element("BM25")?;
107                    seq.serialize_element(query)?;
108                    seq.end()
109                }
110            }
111            RankBy::Attribute { attr, order } => {
112                let mut seq = serializer.serialize_seq(Some(2))?;
113                seq.serialize_element(attr)?;
114                seq.serialize_element(order)?;
115                seq.end()
116            }
117            RankBy::Sum(subqueries) => {
118                let mut seq = serializer.serialize_seq(Some(2))?;
119                seq.serialize_element("Sum")?;
120                seq.serialize_element(subqueries)?;
121                seq.end()
122            }
123            RankBy::Max(subqueries) => {
124                let mut seq = serializer.serialize_seq(Some(2))?;
125                seq.serialize_element("Max")?;
126                seq.serialize_element(subqueries)?;
127                seq.end()
128            }
129            RankBy::Product { weight, subquery } => {
130                let mut seq = serializer.serialize_seq(Some(3))?;
131                seq.serialize_element("Product")?;
132                seq.serialize_element(weight)?;
133                seq.serialize_element(subquery)?;
134                seq.end()
135            }
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders a ranking expression as compact JSON; panics if serialization fails.
    fn to_json(rank_by: &RankBy) -> String {
        serde_json::to_string(rank_by).unwrap()
    }

    #[test]
    fn test_vector_serialization() {
        let rank_by = RankBy::vector("vector", vec![0.1, 0.2, 0.3]);
        assert_eq!(to_json(&rank_by), r#"["vector","ANN",[0.1,0.2,0.3]]"#);
    }

    #[test]
    fn test_vector_knn_serialization() {
        let rank_by = RankBy::vector_knn("embedding", vec![1.0, 2.0, 3.0]);
        assert_eq!(to_json(&rank_by), r#"["embedding","kNN",[1.0,2.0,3.0]]"#);
    }

    #[test]
    fn test_bm25_serialization() {
        let rank_by = RankBy::bm25("content", "quick fox");
        assert_eq!(to_json(&rank_by), r#"["content","BM25","quick fox"]"#);
    }

    #[test]
    fn test_bm25_with_params_serialization() {
        let params = Bm25Params {
            last_as_prefix: Some(true),
        };
        let rank_by = RankBy::bm25_with_params("content", "quick", params);
        assert_eq!(
            to_json(&rank_by),
            r#"["content","BM25","quick",{"last_as_prefix":true}]"#
        );
    }

    #[test]
    fn test_attribute_asc_serialization() {
        let rank_by = RankBy::asc("timestamp");
        assert_eq!(to_json(&rank_by), r#"["timestamp","asc"]"#);
    }

    #[test]
    fn test_attribute_desc_serialization() {
        let rank_by = RankBy::desc("timestamp");
        assert_eq!(to_json(&rank_by), r#"["timestamp","desc"]"#);
    }

    #[test]
    fn test_attribute_with_order_serialization() {
        let rank_by = RankBy::attribute("score", Order::Desc);
        assert_eq!(to_json(&rank_by), r#"["score","desc"]"#);
    }

    #[test]
    fn test_sum_serialization() {
        let parts = vec![
            RankBy::bm25("title", "fox"),
            RankBy::bm25("content", "fox"),
        ];
        let rank_by = RankBy::sum(parts);
        assert_eq!(
            to_json(&rank_by),
            r#"["Sum",[["title","BM25","fox"],["content","BM25","fox"]]]"#
        );
    }

    #[test]
    fn test_max_serialization() {
        let parts = vec![
            RankBy::bm25("title", "query"),
            RankBy::vector("vec", vec![0.1, 0.2]),
        ];
        let rank_by = RankBy::max(parts);
        assert_eq!(
            to_json(&rank_by),
            r#"["Max",[["title","BM25","query"],["vec","ANN",[0.1,0.2]]]]"#
        );
    }

    #[test]
    fn test_product_serialization() {
        let rank_by = RankBy::product(2.0, RankBy::bm25("title", "fox"));
        assert_eq!(to_json(&rank_by), r#"["Product",2.0,["title","BM25","fox"]]"#);
    }

    #[test]
    fn test_nested_combinators() {
        let weighted_title = RankBy::product(2.0, RankBy::bm25("title", "query"));
        let weighted_content = RankBy::product(1.0, RankBy::bm25("content", "query"));
        let rank_by = RankBy::sum(vec![weighted_title, weighted_content]);
        assert_eq!(
            to_json(&rank_by),
            r#"["Sum",[["Product",2.0,["title","BM25","query"]],["Product",1.0,["content","BM25","query"]]]]"#
        );
    }

    #[test]
    fn test_empty_vector() {
        let rank_by = RankBy::vector("vec", vec![]);
        assert_eq!(to_json(&rank_by), r#"["vec","ANN",[]]"#);
    }
}