1use serde::ser::{SerializeSeq, Serializer};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "lowercase")]
6pub enum Order {
7 Asc,
8 Desc,
9}
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub struct Bm25Params {
13 #[serde(skip_serializing_if = "Option::is_none")]
14 pub last_as_prefix: Option<bool>,
15}
16
17#[derive(Debug, Clone, PartialEq)]
18pub enum RankBy {
19 Vector { attr: String, query: Vec<f32> },
21 VectorKnn { attr: String, query: Vec<f32> },
23 Bm25 { attr: String, query: String, params: Option<Bm25Params> },
25 Attribute { attr: String, order: Order },
27 Sum(Vec<RankBy>),
29 Max(Vec<RankBy>),
30 Product { weight: f64, subquery: Box<RankBy> },
31}
32
33impl RankBy {
34 pub fn vector(attr: impl Into<String>, query: Vec<f32>) -> Self {
35 RankBy::Vector { attr: attr.into(), query }
36 }
37
38 pub fn vector_knn(attr: impl Into<String>, query: Vec<f32>) -> Self {
39 RankBy::VectorKnn { attr: attr.into(), query }
40 }
41
42 pub fn bm25(attr: impl Into<String>, query: impl Into<String>) -> Self {
43 RankBy::Bm25 { attr: attr.into(), query: query.into(), params: None }
44 }
45
46 pub fn bm25_with_params(attr: impl Into<String>, query: impl Into<String>, params: Bm25Params) -> Self {
47 RankBy::Bm25 { attr: attr.into(), query: query.into(), params: Some(params) }
48 }
49
50 pub fn attribute(attr: impl Into<String>, order: Order) -> Self {
51 RankBy::Attribute { attr: attr.into(), order }
52 }
53
54 pub fn asc(attr: impl Into<String>) -> Self {
55 RankBy::Attribute { attr: attr.into(), order: Order::Asc }
56 }
57
58 pub fn desc(attr: impl Into<String>) -> Self {
59 RankBy::Attribute { attr: attr.into(), order: Order::Desc }
60 }
61
62 pub fn sum(subqueries: Vec<RankBy>) -> Self {
63 RankBy::Sum(subqueries)
64 }
65
66 pub fn max(subqueries: Vec<RankBy>) -> Self {
67 RankBy::Max(subqueries)
68 }
69
70 pub fn product(weight: f64, subquery: RankBy) -> Self {
71 RankBy::Product { weight, subquery: Box::new(subquery) }
72 }
73}
74
75impl Serialize for RankBy {
76 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: Serializer,
79 {
80 match self {
81 RankBy::Vector { attr, query } => {
82 let mut seq = serializer.serialize_seq(Some(3))?;
83 seq.serialize_element(attr)?;
84 seq.serialize_element("ANN")?;
85 seq.serialize_element(query)?;
86 seq.end()
87 }
88 RankBy::VectorKnn { attr, query } => {
89 let mut seq = serializer.serialize_seq(Some(3))?;
90 seq.serialize_element(attr)?;
91 seq.serialize_element("kNN")?;
92 seq.serialize_element(query)?;
93 seq.end()
94 }
95 RankBy::Bm25 { attr, query, params } => {
96 if let Some(p) = params {
97 let mut seq = serializer.serialize_seq(Some(4))?;
98 seq.serialize_element(attr)?;
99 seq.serialize_element("BM25")?;
100 seq.serialize_element(query)?;
101 seq.serialize_element(p)?;
102 seq.end()
103 } else {
104 let mut seq = serializer.serialize_seq(Some(3))?;
105 seq.serialize_element(attr)?;
106 seq.serialize_element("BM25")?;
107 seq.serialize_element(query)?;
108 seq.end()
109 }
110 }
111 RankBy::Attribute { attr, order } => {
112 let mut seq = serializer.serialize_seq(Some(2))?;
113 seq.serialize_element(attr)?;
114 seq.serialize_element(order)?;
115 seq.end()
116 }
117 RankBy::Sum(subqueries) => {
118 let mut seq = serializer.serialize_seq(Some(2))?;
119 seq.serialize_element("Sum")?;
120 seq.serialize_element(subqueries)?;
121 seq.end()
122 }
123 RankBy::Max(subqueries) => {
124 let mut seq = serializer.serialize_seq(Some(2))?;
125 seq.serialize_element("Max")?;
126 seq.serialize_element(subqueries)?;
127 seq.end()
128 }
129 RankBy::Product { weight, subquery } => {
130 let mut seq = serializer.serialize_seq(Some(3))?;
131 seq.serialize_element("Product")?;
132 seq.serialize_element(weight)?;
133 seq.serialize_element(subquery)?;
134 seq.end()
135 }
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    //! Wire-format tests: each case pins the exact JSON that `serde_json`
    //! produces (or accepts) for ranking expressions and their parameter types.

    use super::*;

    /// Serializes `value` to compact JSON; panics if serialization fails.
    fn to_json(value: &impl Serialize) -> String {
        serde_json::to_string(value).expect("serialization to JSON should not fail")
    }

    #[test]
    fn test_vector_serialization() {
        let r = RankBy::vector("vector", vec![0.1, 0.2, 0.3]);
        assert_eq!(to_json(&r), r#"["vector","ANN",[0.1,0.2,0.3]]"#);
    }

    #[test]
    fn test_vector_knn_serialization() {
        let r = RankBy::vector_knn("embedding", vec![1.0, 2.0, 3.0]);
        assert_eq!(to_json(&r), r#"["embedding","kNN",[1.0,2.0,3.0]]"#);
    }

    #[test]
    fn test_bm25_serialization() {
        let r = RankBy::bm25("content", "quick fox");
        assert_eq!(to_json(&r), r#"["content","BM25","quick fox"]"#);
    }

    #[test]
    fn test_bm25_with_params_serialization() {
        let r = RankBy::bm25_with_params(
            "content",
            "quick",
            Bm25Params {
                last_as_prefix: Some(true),
            },
        );
        assert_eq!(
            to_json(&r),
            r#"["content","BM25","quick",{"last_as_prefix":true}]"#
        );
    }

    #[test]
    fn test_bm25_with_unset_params_serializes_empty_object() {
        // `skip_serializing_if` drops `None` fields, but the params slot itself
        // is still emitted because `params` is `Some`.
        let r = RankBy::bm25_with_params(
            "content",
            "quick",
            Bm25Params {
                last_as_prefix: None,
            },
        );
        assert_eq!(to_json(&r), r#"["content","BM25","quick",{}]"#);
    }

    #[test]
    fn test_attribute_asc_serialization() {
        let r = RankBy::asc("timestamp");
        assert_eq!(to_json(&r), r#"["timestamp","asc"]"#);
    }

    #[test]
    fn test_attribute_desc_serialization() {
        let r = RankBy::desc("timestamp");
        assert_eq!(to_json(&r), r#"["timestamp","desc"]"#);
    }

    #[test]
    fn test_attribute_with_order_serialization() {
        let r = RankBy::attribute("score", Order::Desc);
        assert_eq!(to_json(&r), r#"["score","desc"]"#);
    }

    #[test]
    fn test_order_round_trip() {
        assert_eq!(to_json(&Order::Asc), "\"asc\"");
        assert_eq!(to_json(&Order::Desc), "\"desc\"");
        assert_eq!(
            serde_json::from_str::<Order>("\"asc\"").unwrap(),
            Order::Asc
        );
        assert_eq!(
            serde_json::from_str::<Order>("\"desc\"").unwrap(),
            Order::Desc
        );
        // Only the lowercase spelling is accepted.
        assert!(serde_json::from_str::<Order>("\"Desc\"").is_err());
    }

    #[test]
    fn test_bm25_params_deserialization() {
        // A missing `Option` field deserializes as `None`.
        let p: Bm25Params = serde_json::from_str("{}").unwrap();
        assert_eq!(
            p,
            Bm25Params {
                last_as_prefix: None
            }
        );
        let p: Bm25Params = serde_json::from_str(r#"{"last_as_prefix":false}"#).unwrap();
        assert_eq!(p.last_as_prefix, Some(false));
    }

    #[test]
    fn test_sum_serialization() {
        let r = RankBy::sum(vec![
            RankBy::bm25("title", "fox"),
            RankBy::bm25("content", "fox"),
        ]);
        assert_eq!(
            to_json(&r),
            r#"["Sum",[["title","BM25","fox"],["content","BM25","fox"]]]"#
        );
    }

    #[test]
    fn test_empty_sum_serialization() {
        assert_eq!(to_json(&RankBy::sum(vec![])), r#"["Sum",[]]"#);
    }

    #[test]
    fn test_max_serialization() {
        let r = RankBy::max(vec![
            RankBy::bm25("title", "query"),
            RankBy::vector("vec", vec![0.1, 0.2]),
        ]);
        assert_eq!(
            to_json(&r),
            r#"["Max",[["title","BM25","query"],["vec","ANN",[0.1,0.2]]]]"#
        );
    }

    #[test]
    fn test_product_serialization() {
        let r = RankBy::product(2.0, RankBy::bm25("title", "fox"));
        assert_eq!(to_json(&r), r#"["Product",2.0,["title","BM25","fox"]]"#);
    }

    #[test]
    fn test_nested_combinators() {
        let r = RankBy::sum(vec![
            RankBy::product(2.0, RankBy::bm25("title", "query")),
            RankBy::product(1.0, RankBy::bm25("content", "query")),
        ]);
        assert_eq!(
            to_json(&r),
            r#"["Sum",[["Product",2.0,["title","BM25","query"]],["Product",1.0,["content","BM25","query"]]]]"#
        );
    }

    #[test]
    fn test_params_and_attribute_inside_combinator() {
        let r = RankBy::sum(vec![
            RankBy::bm25_with_params(
                "title",
                "fox",
                Bm25Params {
                    last_as_prefix: Some(false),
                },
            ),
            RankBy::desc("ts"),
        ]);
        assert_eq!(
            to_json(&r),
            r#"["Sum",[["title","BM25","fox",{"last_as_prefix":false}],["ts","desc"]]]"#
        );
    }

    #[test]
    fn test_empty_vector() {
        let r = RankBy::vector("vec", vec![]);
        assert_eq!(to_json(&r), r#"["vec","ANN",[]]"#);
    }
}