// elasticsearch_dsl/search/queries/compound/function_score_query.rs

1use crate::search::*;
2use crate::util::*;
3
4/// The `function_score` allows you to modify the score of documents that are retrieved by a query.
5///
6/// This can be useful if, for example, a score function is computationally expensive and it is
7/// sufficient to compute the score on a filtered set of documents.
8///
9/// To use `function_score`, the user has to define a query and one or more functions, that compute
10/// a new score for each document returned by the query.
11///
12/// To create function_score query:
13/// ```
14/// # use elasticsearch_dsl::queries::*;
15/// # use elasticsearch_dsl::queries::params::*;
16/// # let query =
17/// Query::function_score()
18///     .query(Query::term("test", 1))
19///     .function(RandomScore::new().filter(Query::term("test", 1)).weight(2.0))
20///     .function(Weight::new(2.0))
21///     .max_boost(2.2)
22///     .min_score(2.3)
23///     .score_mode(FunctionScoreMode::Avg)
24///     .boost_mode(FunctionBoostMode::Max)
25///     .boost(1.1)
26///     .name("test");
27/// ```
28/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-function-score-query.html>
29#[derive(Debug, Clone, PartialEq, Serialize)]
30#[serde(remote = "Self")]
31pub struct FunctionScoreQuery {
32    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
33    query: Option<Box<Query>>,
34
35    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
36    functions: Vec<Function>,
37
38    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
39    max_boost: Option<f32>,
40
41    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
42    min_score: Option<f32>,
43
44    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
45    score_mode: Option<FunctionScoreMode>,
46
47    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
48    boost_mode: Option<FunctionBoostMode>,
49
50    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
51    boost: Option<f32>,
52
53    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
54    _name: Option<String>,
55}
56
57impl Query {
58    /// Creates an instance of [`FunctionScoreQuery`]
59    pub fn function_score() -> FunctionScoreQuery {
60        FunctionScoreQuery {
61            query: None,
62            functions: Default::default(),
63            max_boost: None,
64            min_score: None,
65            score_mode: None,
66            boost_mode: None,
67            boost: None,
68            _name: None,
69        }
70    }
71}
72
73impl FunctionScoreQuery {
74    /// Base function score query
75    pub fn query<T>(mut self, query: T) -> Self
76    where
77        T: Into<Option<Query>>,
78    {
79        self.query = query.into().map(Box::new);
80        self
81    }
82
83    /// Push function to the list
84    pub fn function<T>(mut self, function: T) -> Self
85    where
86        T: Into<Option<Function>>,
87    {
88        let function = function.into();
89
90        if let Some(function) = function {
91            self.functions.push(function);
92        }
93
94        self
95    }
96
97    /// Maximum score value after applying all the functions
98    pub fn max_boost<T>(mut self, max_boost: T) -> Self
99    where
100        T: num_traits::AsPrimitive<f32>,
101    {
102        self.max_boost = Some(max_boost.as_());
103        self
104    }
105
106    /// By default, modifying the score does not change which documents match. To exclude documents
107    /// that do not meet a certain score threshold the `min_score` parameter can be set to the
108    /// desired score threshold.
109    pub fn min_score<T>(mut self, min_score: T) -> Self
110    where
111        T: Into<f32>,
112    {
113        self.min_score = Some(min_score.into());
114        self
115    }
116
117    /// Each document is scored by the defined functions. The parameter `score_mode` specifies how
118    /// the computed scores are combined
119    pub fn score_mode(mut self, score_mode: FunctionScoreMode) -> Self {
120        self.score_mode = Some(score_mode);
121        self
122    }
123
124    /// The newly computed score is combined with the score of the query. The parameter
125    /// `boost_mode` defines how.
126    pub fn boost_mode(mut self, boost_mode: FunctionBoostMode) -> Self {
127        self.boost_mode = Some(boost_mode);
128        self
129    }
130
131    add_boost_and_name!();
132}
133
134impl ShouldSkip for FunctionScoreQuery {
135    fn should_skip(&self) -> bool {
136        self.functions.should_skip()
137    }
138}
139
// Serializes the query wrapped under a `"function_score"` root key, as the expected JSON in the
// tests below shows. NOTE(review): presumably the macro delegates the inner object to the
// `#[serde(remote = "Self")]` derive on the struct — confirm against the macro definition.
serialize_with_root!("function_score": FunctionScoreQuery);
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks serialization of a minimal query (a single function) and of a query with every
    /// builder parameter set.
    #[test]
    fn serialization() {
        assert_serialize_query(
            Query::function_score().function(RandomScore::new()),
            json!({
                "function_score": {
                    "functions": [
                        {
                            "random_score": {}
                        }
                    ]
                }
            }),
        );

        assert_serialize_query(
            Query::function_score()
                .query(Query::term("test", 1))
                .function(RandomScore::new())
                .function(Weight::new(2.0))
                .max_boost(2.2)
                .min_score(2.3)
                .score_mode(FunctionScoreMode::Avg)
                .boost_mode(FunctionBoostMode::Max)
                .boost(1.1)
                .name("test"),
            json!({
                "function_score": {
                    "query": {
                        "term": {
                            "test": {
                                "value": 1
                            }
                        }
                    },
                    "functions": [
                        {
                            "random_score": {}
                        },
                        {
                            "weight": 2.0
                        }
                    ],
                    "max_boost": 2.2,
                    "min_score": 2.3,
                    "score_mode": "avg",
                    "boost_mode": "max",
                    "boost": 1.1,
                    "_name": "test"
                }
            }),
        );
    }

    /// Regression test for issue #24: builds a `function_score` with several filtered
    /// `field_value_factor` functions that also carry a `weight`.
    ///
    /// NOTE(review): this is compile-only. Both the `json!` value and the built query are bound
    /// to `_`, so nothing is asserted. The JSON also repeats the `"filter"` key in its second
    /// function (presumably only one entry survives in the resulting map). The Rust side instead
    /// uses a single `terms` filter for `["address", "addr"]`, so the two would not match if
    /// they were ever compared.
    #[test]
    fn issue_24() {
        let _ = json!({
            "function_score": {
                "boost_mode": "replace",
                "functions": [
                    {
                        "filter": { "term": { "type": "stop" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        "filter": { "term": { "type": "address" } },
                        "filter": { "term": { "type": "addr" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        "filter": { "term": { "type": "admin" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        "filter": { "term": { "type": "poi" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    },
                    {
                        "filter": { "term": { "type": "street" } },
                        "field_value_factor": {
                            "field": "weight",
                            "factor": 1.0,
                            "missing": 1.0
                        },
                        "weight": 1.0
                    }
                ]
            }
        });

        let _ = Query::function_score()
            .boost_mode(FunctionBoostMode::Replace)
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "stop")),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::terms("type", ["address", "addr"])),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "admin")),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "poi")),
            )
            .function(
                FieldValueFactor::new("weight")
                    .factor(1.0)
                    .missing(1.0)
                    .weight(1.0)
                    .filter(Query::term("type", "street")),
            );
    }

    /// Regression test for GH-257: a `function_score` query that has functions but no base query
    /// must still be serialized when it is nested inside a `bool` query's `should` clause.
    #[test]
    fn should_not_skip_serializing_function_score_with_empty_query_gh_257() {
        assert_serialize(
            Query::bool().should(
                Query::function_score()
                    .function(
                        Function::field_value_factor("weight")
                            .factor(10.0)
                            .missing(0.0)
                            .modifier(FieldValueFactorModifier::Log1P)
                            .weight(0.3),
                    )
                    .score_mode(FunctionScoreMode::Max)
                    .boost_mode(FunctionBoostMode::Replace),
            ),
            json!( {
                     "bool": {
                       "should": [
                         {
                           "function_score": {
                             "boost_mode": "replace",
                             "functions": [
                               {
                                 "field_value_factor": {
                                   "factor": 10.0,
                                   "field": "weight",
                                   "missing": 0.0,
                                   "modifier": "log1p"
                                 },
                                 "weight": 0.3
                               }
                             ],
                             "score_mode": "max"
                           }
                         }
                       ]
                    }
            }),
        )
    }
}