Skip to main content

opensearch_dsl/search/queries/compound/
function_score_query.rs

1use crate::{search::*, util::*};
2
3/// The `function_score` allows you to modify the score of documents that are
4/// retrieved by a query.
5///
6/// This can be useful if, for example, a score function is computationally
7/// expensive and it is sufficient to compute the score on a filtered set of
8/// documents.
9///
10/// To use `function_score`, the user has to define a query and one or more
11/// functions, that compute a new score for each document returned by the query.
12///
13/// To create function_score query:
14/// ```
15/// # use opensearch_dsl::queries::*;
16/// # use opensearch_dsl::queries::params::*;
17/// # let query =
18/// Query::function_score()
19///   .query(Query::term("test", 1))
20///   .function(
21///     RandomScore::new()
22///       .filter(Query::term("test", 1))
23///       .weight(2.0),
24///   )
25///   .function(Weight::new(2.0))
26///   .max_boost(2.2)
27///   .min_score(2.3)
28///   .score_mode(FunctionScoreMode::Avg)
29///   .boost_mode(FunctionBoostMode::Max)
30///   .boost(1.1)
31///   .name("test");
32/// ```
33/// <https://www.elastic.co/guide/en/opensearch/reference/current/query-dsl-function-score-query.html>
34#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
35#[serde(remote = "Self")]
36pub struct FunctionScoreQuery {
37    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
38    query: Option<Box<Query>>,
39
40    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
41    functions: Vec<Function>,
42
43    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
44    max_boost: Option<f32>,
45
46    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
47    min_score: Option<f32>,
48
49    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
50    score_mode: Option<FunctionScoreMode>,
51
52    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
53    boost_mode: Option<FunctionBoostMode>,
54
55    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
56    boost: Option<f32>,
57
58    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
59    _name: Option<String>,
60}
61
62impl Query {
63    /// Creates an instance of [`FunctionScoreQuery`]
64    pub fn function_score() -> FunctionScoreQuery {
65        FunctionScoreQuery {
66            query: None,
67            functions: Default::default(),
68            max_boost: None,
69            min_score: None,
70            score_mode: None,
71            boost_mode: None,
72            boost: None,
73            _name: None,
74        }
75    }
76}
77
78impl FunctionScoreQuery {
79    add_boost_and_name!();
80
81    /// Base function score query
82    pub fn query<T>(mut self, query: T) -> Self
83    where
84        T: Into<Option<Query>>,
85    {
86        self.query = query.into().map(Box::new);
87        self
88    }
89
90    /// Push function to the list
91    pub fn function<T>(mut self, function: T) -> Self
92    where
93        T: Into<Option<Function>>,
94    {
95        let function = function.into();
96
97        if let Some(function) = function {
98            self.functions.push(function);
99        }
100
101        self
102    }
103
104    /// Maximum score value after applying all the functions
105    pub fn max_boost<T>(mut self, max_boost: T) -> Self
106    where
107        T: num_traits::AsPrimitive<f32>,
108    {
109        self.max_boost = Some(max_boost.as_());
110        self
111    }
112
113    /// By default, modifying the score does not change which documents match. To
114    /// exclude documents
115
116    /// that do not meet a certain score threshold the `min_score` parameter can
117    /// be set to the desired score threshold.
118    pub fn min_score<T>(mut self, min_score: T) -> Self
119    where
120        T: Into<f32>,
121    {
122        self.min_score = Some(min_score.into());
123        self
124    }
125
126    /// Each document is scored by the defined functions. The parameter
127    /// `score_mode` specifies how the computed scores are combined
128    pub fn score_mode(mut self, score_mode: FunctionScoreMode) -> Self {
129        self.score_mode = Some(score_mode);
130        self
131    }
132
133    /// The newly computed score is combined with the score of the query. The
134    /// parameter `boost_mode` defines how.
135    pub fn boost_mode(mut self, boost_mode: FunctionBoostMode) -> Self {
136        self.boost_mode = Some(boost_mode);
137        self
138    }
139}
140
141impl ShouldSkip for FunctionScoreQuery {
142    fn should_skip(&self) -> bool {
143        self.query.should_skip() || self.functions.should_skip()
144    }
145}
146
147serialize_with_root!("function_score": FunctionScoreQuery);
148deserialize_with_root!("function_score": FunctionScoreQuery);
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn serialization() {
        assert_serialize_query(
            Query::function_score().function(RandomScore::new()),
            json!({
                "function_score": {
                    "functions": [
                        {
                            "random_score": {}
                        }
                    ]
                }
            }),
        );

        assert_serialize_query(
            Query::function_score()
                .query(Query::term("test", 1))
                .function(RandomScore::new())
                .function(Weight::new(2.0))
                .max_boost(2.2)
                .min_score(2.3)
                .score_mode(FunctionScoreMode::Avg)
                .boost_mode(FunctionBoostMode::Max)
                .boost(1.1)
                .name("test"),
            json!({
                "function_score": {
                    "query": {
                        "term": {
                            "test": {
                                "value": 1
                            }
                        }
                    },
                    "functions": [
                        {
                            "random_score": {}
                        },
                        {
                            "weight": 2.0
                        }
                    ],
                    "max_boost": 2.2,
                    "min_score": 2.3,
                    "score_mode": "avg",
                    "boost_mode": "max",
                    "boost": 1.1,
                    "_name": "test"
                }
            }),
        );
    }

    /// The request from issue #24, written both as raw JSON and with the
    /// builder API.
    ///
    /// Both values are discarded and nothing is asserted, so this test only
    /// checks that the builder chain compiles and runs; the JSON is kept as a
    /// readable reference for the intended request.
    #[test]
    fn issue_24() {
        let _ = json!({
            "function_score": {
                "boost_mode": "replace",
                "functions": [
                    {
                        "filter": { "term": { "type": "stop" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        // One `terms` filter matching either value, mirroring the
                        // builder below; a repeated `filter` key would be
                        // silently overwritten by `json!`.
                        "filter": { "terms": { "type": ["address", "addr"] } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        "filter": { "term": { "type": "admin" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        "filter": { "term": { "type": "poi" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        "filter": { "term": { "type": "street" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    }
                ]
            }
        });

        let _ = Query::function_score()
            .boost_mode(FunctionBoostMode::Replace)
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "stop")),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::terms("type", ["address", "addr"])),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "admin")),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "poi")),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "street")),
            );
    }

    #[test]
    fn should_not_skip_serializing_function_score_with_empty_query_gh_257() {
        assert_serialize(
            Query::bool().should(
                Query::function_score()
                    .function(
                        Function::field_value_factor("weight")
                            .factor(10.0)
                            .missing(0.0)
                            .modifier(FieldValueFactorModifier::Log1P)
                            .weight(0.3),
                    )
                    .score_mode(FunctionScoreMode::Max)
                    .boost_mode(FunctionBoostMode::Replace),
            ),
            json!({
                "bool": {
                    "should": [
                        {
                            "function_score": {
                                "boost_mode": "replace",
                                "functions": [
                                    {
                                        "field_value_factor": {
                                            "factor": 10.0,
                                            "field": "weight",
                                            "missing": 0.0,
                                            "modifier": "log1p"
                                        },
                                        "weight": 0.3
                                    }
                                ],
                                "score_mode": "max"
                            }
                        }
                    ]
                }
            }),
        )
    }
}