elasticsearch_dsl/search/queries/params/
function_score_query.rs

1use crate::search::*;
2use crate::util::*;
3use chrono::{DateTime, Utc};
4use serde::ser::{Serialize, SerializeMap, Serializer};
5use std::fmt::Debug;
6
7/// Each document is scored by the defined functions. The parameter `score_mode` specifies how
8/// the computed scores are combined
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Default)]
10#[serde(rename_all = "snake_case")]
11pub enum FunctionScoreMode {
12    /// Scores are multiplied (default)
13    #[default]
14    Multiply,
15
16    /// Scores are summed
17    Sum,
18
19    /// Scores are averaged
20    Avg,
21
22    /// The first function that has a matching filter is applied
23    First,
24
25    /// Maximum score is used
26    Max,
27
28    /// Minimum score is used
29    Min,
30}
31
32/// The newly computed score is combined with the score of the query. The parameter
33/// `boost_mode` defines how.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Default)]
35#[serde(rename_all = "snake_case")]
36pub enum FunctionBoostMode {
37    /// Query score and function score is multiplied (default)
38    #[default]
39    Multiply,
40
41    /// Only function score is used, the query score is ignored
42    Replace,
43
44    /// Query score and function score are added
45    Sum,
46
47    /// Average
48    Avg,
49
50    /// Max of query score and function score
51    Max,
52
53    /// Min of query score and function score
54    Min,
55}
56
57macro_rules! function {
58    ($name:ident { $($variant:ident($query:ty)),+ $(,)? }) => {
59        /// Functions available for use in [FunctionScoreQuery](crate::FunctionScoreQuery)
60        #[derive(Debug, Clone, PartialEq, Serialize)]
61        #[allow(missing_docs)]
62        #[serde(untagged)]
63        pub enum $name {
64            $(
65                $variant($query),
66            )*
67        }
68
69        $(
70            impl From<$query> for $name {
71                fn from(q: $query) -> Self {
72                    $name::$variant(q)
73                }
74            }
75        )+
76
77        $(
78            impl From<$query> for Option<$name> {
79                fn from(q: $query) -> Self {
80                    Some($name::$variant(q))
81                }
82            }
83        )+
84    };
85}
86
87function!(Function {
88    Weight(Weight),
89    RandomScore(RandomScore),
90    FieldValueFactor(FieldValueFactor),
91    DecayDateTime(Decay<DateTime<Utc>>),
92    DecayLocation(Decay<GeoLocation>),
93    DecayI8(Decay<i8>),
94    DecayI16(Decay<i16>),
95    DecayI32(Decay<i32>),
96    DecayI64(Decay<i64>),
97    DecayU8(Decay<u8>),
98    DecayU16(Decay<u16>),
99    DecayU32(Decay<u32>),
100    DecayU64(Decay<u64>),
101    DecayF32(Decay<f32>),
102    DecayF64(Decay<f64>),
103    ScriptScore(ScriptScore),
104});
105
106impl Function {
107    /// Creates an instance of [Weight](Weight)
108    pub fn weight(weight: f32) -> Weight {
109        Weight::new(weight)
110    }
111
112    /// Creates an instance of [RandomScore](RandomScore)
113    pub fn random_score() -> RandomScore {
114        RandomScore::new()
115    }
116
117    /// Creates an instance of [FieldValueFactor](FieldValueFactor)
118    ///
119    /// - `field` - Field to be extracted from the document.
120    pub fn field_value_factor<T>(field: T) -> FieldValueFactor
121    where
122        T: ToString,
123    {
124        FieldValueFactor::new(field)
125    }
126
127    /// Creates an instance of [Decay](Decay)
128    ///
129    /// - `function` - Decay function variant
130    /// - `field` - Field to apply function to
131    /// - `origin` - The point of origin used for calculating distance. Must be given as a number
132    ///   for numeric field, date for date fields and geo point for geo fields. Required for geo and
133    ///   numeric field. For date fields the default is `now`. Date math (for example now-1h) is
134    ///   supported for origin.
135    /// - `scale` - Required for all types. Defines the distance from origin + offset at which the
136    ///   computed score will equal `decay` parameter. For geo fields: Can be defined as number+unit
137    ///   (1km, 12m,…​). Default unit is meters. For date fields: Can to be defined as a number+unit
138    ///   ("1h", "10d",…​). Default unit is milliseconds. For numeric field: Any number.
139    pub fn decay<T, O>(
140        function: DecayFunction,
141        field: T,
142        origin: O,
143        scale: <O as Origin>::Scale,
144    ) -> Decay<O>
145    where
146        T: ToString,
147        O: Origin,
148    {
149        Decay::new(function, field, origin, scale)
150    }
151
152    /// Creates an instance of script
153    ///
154    /// - `source` - script source
155    pub fn script(source: Script) -> ScriptScore {
156        ScriptScore::new(source)
157    }
158}
159
160/// The `weight` score allows you to multiply the score by the provided weight.
161///
162/// This can sometimes be desired since boost value set on specific queries gets normalized, while
163/// for this score function it does not
164#[derive(Debug, Clone, PartialEq, Serialize)]
165pub struct Weight {
166    weight: f32,
167    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
168    filter: Option<Query>,
169}
170
171impl Weight {
172    /// Creates an instance of [Weight](Weight)
173    pub fn new(weight: f32) -> Self {
174        Self {
175            weight,
176            filter: None,
177        }
178    }
179
180    /// Add function filter
181    pub fn filter<T>(mut self, filter: T) -> Self
182    where
183        T: Into<Option<Query>>,
184    {
185        self.filter = filter.into();
186        self
187    }
188}
189
190/// The `random_score` generates scores that are uniformly distributed from `0` up to but not
191/// including `1`.
192///
193/// By default, it uses the internal Lucene doc ids as a source of randomness, which is very
194/// efficient but unfortunately not reproducible since documents might be renumbered by merges.
195///
196/// In case you want scores to be reproducible, it is possible to provide a `seed` and `field`. The
197/// final score will then be computed based on this seed, the minimum value of `field` for the
198/// considered document and a salt that is computed based on the index name and shard id so that
199/// documents that have the same value but are stored in different indexes get different scores.
200/// Note that documents that are within the same shard and have the same value for `field` will
201/// however get the same score, so it is usually desirable to use a field that has unique values
202/// for all documents. A good default choice might be to use the `_seq_no` field, whose only
203/// drawback is that scores will change if the document is updated since update operations also
204/// update the value of the `_seq_no` field.
205#[derive(Debug, Default, Clone, PartialEq, Serialize)]
206pub struct RandomScore {
207    random_score: RandomScoreInner,
208
209    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
210    filter: Option<Query>,
211
212    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
213    weight: Option<f32>,
214}
215
216#[derive(Debug, Default, Clone, PartialEq, Serialize)]
217struct RandomScoreInner {
218    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
219    seed: Option<Term>,
220
221    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
222    field: Option<String>,
223}
224
225impl RandomScore {
226    /// Creates an instance of [RandomScore](RandomScore)
227    pub fn new() -> Self {
228        Default::default()
229    }
230
231    /// Add function filter
232    pub fn filter<T>(mut self, filter: T) -> Self
233    where
234        T: Into<Option<Query>>,
235    {
236        self.filter = filter.into();
237        self
238    }
239
240    /// The `weight` score allows you to multiply the score by the provided `weight`. This can sometimes be desired
241    /// since boost value set on specific queries gets normalized, while for this score function it does not.
242    /// The number value is of type float.
243    pub fn weight<T>(mut self, weight: T) -> Self
244    where
245        T: num_traits::AsPrimitive<f32>,
246    {
247        self.weight = Some(weight.as_());
248        self
249    }
250
251    /// Sets seed value
252    pub fn seed<T>(mut self, seed: T) -> Self
253    where
254        T: Serialize,
255    {
256        self.random_score.seed = Term::new(seed);
257        self
258    }
259
260    /// Sets field value
261    pub fn field<T>(mut self, field: T) -> Self
262    where
263        T: ToString,
264    {
265        self.random_score.field = Some(field.to_string());
266        self
267    }
268}
269
270/// The `field_value_factor` function allows you to use a field from a document to influence the
271/// score.
272/// It’s similar to using the `script_score` function, however, it avoids the overhead of scripting.
273/// If used on a multi-valued field, only the first value of the field is used in calculations.
274///
275/// As an example, imagine you have a document indexed with a numeric `my-int` field and wish to
276/// influence the score of a document with this field, an example doing so would look like:
277/// ```
278/// # use elasticsearch_dsl::{FieldValueFactor, FieldValueFactorModifier};
279/// # fn main() {
280/// # let _ =
281/// FieldValueFactor::new("my-int")
282///     .factor(1.2)
283///     .modifier(FieldValueFactorModifier::Sqrt)
284///     .missing(1.0)
285/// # ;}
286/// ```
287/// Which will translate into the following formula for scoring:
288/// ```text
289/// sqrt(1.2 * doc['my-int'].value)
290/// ```
291#[derive(Debug, Clone, PartialEq, Serialize)]
292pub struct FieldValueFactor {
293    field_value_factor: FieldValueFactorInner,
294
295    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
296    filter: Option<Query>,
297
298    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
299    weight: Option<f32>,
300}
301
302#[derive(Debug, Clone, PartialEq, Serialize)]
303struct FieldValueFactorInner {
304    field: String,
305
306    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
307    factor: Option<f32>,
308
309    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
310    modifier: Option<FieldValueFactorModifier>,
311
312    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
313    missing: Option<f32>,
314}
315
316impl FieldValueFactor {
317    /// Creates an instance of [FieldValueFactor](FieldValueFactor)
318    ///
319    /// - `field` - Field to be extracted from the document.
320    pub fn new<T>(field: T) -> Self
321    where
322        T: ToString,
323    {
324        Self {
325            field_value_factor: FieldValueFactorInner {
326                field: field.to_string(),
327                factor: None,
328                modifier: None,
329                missing: None,
330            },
331            filter: None,
332            weight: None,
333        }
334    }
335
336    /// Add function filter
337    pub fn filter<T>(mut self, filter: T) -> Self
338    where
339        T: Into<Option<Query>>,
340    {
341        self.filter = filter.into();
342        self
343    }
344
345    /// The `weight` score allows you to multiply the score by the provided `weight`. This can sometimes be desired
346    /// since boost value set on specific queries gets normalized, while for this score function it does not.
347    /// The number value is of type float.
348    pub fn weight<T>(mut self, weight: T) -> Self
349    where
350        T: num_traits::AsPrimitive<f32>,
351    {
352        self.weight = Some(weight.as_());
353        self
354    }
355
356    /// Factor to multiply the field value with
357    pub fn factor(mut self, factor: f32) -> Self {
358        self.field_value_factor.factor = Some(factor);
359        self
360    }
361
362    /// Modifier to apply to the field value
363    pub fn modifier(mut self, modifier: FieldValueFactorModifier) -> Self {
364        self.field_value_factor.modifier = Some(modifier);
365        self
366    }
367
368    /// Value used if the document doesn’t have that field. The modifier and factor are still
369    /// applied to it as though it were read from the document
370    pub fn missing(mut self, missing: f32) -> Self {
371        self.field_value_factor.missing = Some(missing);
372        self
373    }
374}
375
376/// Modifier to apply to the field value
377///
378/// Defaults to [none](FieldValueFactorModifier::None)
379#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
380#[serde(rename_all = "lowercase")]
381pub enum FieldValueFactorModifier {
382    /// Do not apply any multiplier to the field value
383    None,
384
385    /// Take the [common logarithm](https://en.wikipedia.org/wiki/Common_logarithm) of the field
386    /// value
387    ///
388    /// Because this function will return a negative value and cause an error if used on values
389    /// between `0` and `1`, it is recommended to use [log1p](FieldValueFactorModifier::Log1P)
390    /// instead.
391    Log,
392
393    /// Add 1 to the field value and take the common logarithm
394    Log1P,
395
396    /// Add 2 to the field value and take the common logarithm
397    Log2P,
398
399    /// Take the [natural logarithm](https://en.wikipedia.org/wiki/Natural_logarithm) of the field
400    /// value.
401    ///
402    /// Because this function will return a negative value and cause an error if used on values
403    /// between `0` and `1`, it is recommended to use [ln1p](FieldValueFactorModifier::Ln1P)
404    /// instead.
405    Ln,
406
407    /// Add 1 to the field value and take the natural logarithm
408    Ln1P,
409
410    /// Add 2 to the field value and take the natural logarithm
411    Ln2P,
412
413    /// Square the field value (multiply it by itself)
414    Square,
415
416    /// Take the [square root](https://en.wikipedia.org/wiki/Square_root) of the field value
417    Sqrt,
418
419    /// [Reciprocate](https://en.wikipedia.org/wiki/Multiplicative_inverse) the field value, same
420    /// as `1/x` where `x` is the field’s value
421    Reciprocal,
422}
423
424#[doc(hidden)]
425pub trait Origin: Debug + PartialEq + Serialize + Clone {
426    type Scale: Debug + PartialEq + Serialize + Clone;
427    type Offset: Debug + PartialEq + Serialize + Clone;
428}
429
430impl Origin for DateTime<Utc> {
431    type Scale = Time;
432    type Offset = Time;
433}
434
435impl Origin for GeoLocation {
436    type Scale = Distance;
437    type Offset = Distance;
438}
439
440macro_rules! impl_origin_for_numbers {
441    ($($name:ident ),+) => {
442        $(
443            impl Origin for $name {
444                type Scale = Self;
445                type Offset = Self;
446            }
447        )+
448    }
449}
450
451impl_origin_for_numbers![i8, i16, i32, i64, u8, u16, u32, u64, f32, f64];
452
453/// Decay functions score a document with a function that decays depending on the distance of a
454/// numeric field value of the document from a user given origin. This is similar to a range query,
455/// but with smooth edges instead of boxes.
456///
457/// To use distance scoring on a query that has numerical fields, the user has to define an
458/// `origin` and a `scale` for each field. The `origin` is needed to define the “central point”
459/// from which the distance is calculated, and the `scale` to define the rate of decay.
460#[derive(Debug, Clone, PartialEq)]
461pub struct Decay<T: Origin> {
462    function: DecayFunction,
463
464    inner: DecayFieldInner<T>,
465
466    filter: Option<Query>,
467
468    weight: Option<f32>,
469}
470
471#[derive(Debug, Clone, PartialEq)]
472struct DecayFieldInner<T: Origin> {
473    field: String,
474    inner: DecayInner<T>,
475}
476
477#[derive(Debug, Clone, PartialEq, Serialize)]
478struct DecayInner<O>
479where
480    O: Origin,
481{
482    origin: O,
483
484    scale: <O as Origin>::Scale,
485
486    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
487    offset: Option<<O as Origin>::Offset>,
488
489    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
490    decay: Option<f32>,
491}
492
493impl<O> Decay<O>
494where
495    O: Origin,
496{
497    /// Creates an instance of [Decay](Decay)
498    ///
499    /// - `function` - Decay function variant
500    /// - `field` - Field to apply function to
501    /// - `origin` - The point of origin used for calculating distance. Must be given as a number
502    ///   for numeric field, date for date fields and geo point for geo fields. Required for geo and
503    ///   numeric field. For date fields the default is `now`. Date math (for example now-1h) is
504    ///   supported for origin.
505    /// - `scale` - Required for all types. Defines the distance from origin + offset at which the
506    ///   computed score will equal `decay` parameter. For geo fields: Can be defined as number+unit
507    ///   (1km, 12m,…​). Default unit is meters. For date fields: Can to be defined as a number+unit
508    ///   ("1h", "10d",…​). Default unit is milliseconds. For numeric field: Any number.
509    pub fn new<T>(function: DecayFunction, field: T, origin: O, scale: <O as Origin>::Scale) -> Self
510    where
511        T: ToString,
512    {
513        Self {
514            function,
515            inner: DecayFieldInner {
516                field: field.to_string(),
517                inner: DecayInner {
518                    origin,
519                    scale,
520                    offset: None,
521                    decay: None,
522                },
523            },
524            filter: None,
525            weight: None,
526        }
527    }
528
529    /// Add function filter
530    pub fn filter<T>(mut self, filter: T) -> Self
531    where
532        T: Into<Option<Query>>,
533    {
534        self.filter = filter.into();
535        self
536    }
537
538    /// The `weight` score allows you to multiply the score by the provided `weight`. This can sometimes be desired
539    /// since boost value set on specific queries gets normalized, while for this score function it does not.
540    /// The number value is of type float.
541    pub fn weight<T>(mut self, weight: T) -> Self
542    where
543        T: num_traits::AsPrimitive<f32>,
544    {
545        self.weight = Some(weight.as_());
546        self
547    }
548
549    /// If an `offset` is defined, the decay function will only compute the decay function for
550    /// documents with a distance greater than the defined `offset`.
551    ///
552    /// The default is `0`.
553    pub fn offset(mut self, offset: <O as Origin>::Offset) -> Self {
554        self.inner.inner.offset = Some(offset);
555        self
556    }
557
558    /// The `decay` parameter defines how documents are scored at the distance given at `scale`. If
559    /// no `decay` is defined, documents at the distance `scale` will be scored `0.5`.
560    pub fn decay(mut self, decay: f32) -> Self {
561        self.inner.inner.decay = Some(decay);
562        self
563    }
564}
565
566impl<T: Origin> Serialize for Decay<T> {
567    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
568    where
569        S: Serializer,
570    {
571        let mut map = serializer.serialize_map(Some(3))?;
572
573        map.serialize_entry(&self.function, &self.inner)?;
574
575        if let Some(filter) = &self.filter {
576            map.serialize_entry("filter", filter)?;
577        }
578
579        if let Some(weight) = &self.weight {
580            map.serialize_entry("weight", weight)?;
581        }
582
583        map.end()
584    }
585}
586
587impl<T: Origin> Serialize for DecayFieldInner<T> {
588    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
589    where
590        S: Serializer,
591    {
592        let mut map = serializer.serialize_map(Some(1))?;
593
594        map.serialize_entry(&self.field, &self.inner)?;
595
596        map.end()
597    }
598}
599
600/// Decay function variants
601///
602/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-function-score-query.html#_supported_decay_functions>
603#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
604#[serde(rename_all = "snake_case")]
605pub enum DecayFunction {
606    /// Linear decay
607    Linear,
608
609    /// Exponential decay
610    Exp,
611
612    /// Gauss decay
613    Gauss,
614}
615
616/// The script_score function allows you to wrap another query and customize the scoring of it
617/// optionally with a computation derived from other numeric field values in the doc using a script
618/// expression
619#[derive(Debug, Clone, PartialEq, Serialize)]
620pub struct ScriptScore {
621    script_score: ScriptWrapper,
622}
623
624#[derive(Debug, Clone, PartialEq, Serialize)]
625struct ScriptWrapper {
626    script: Script,
627}
628
629impl ScriptScore {
630    /// Creates an instance of [Script]
631    ///
632    /// - `script` - script source
633    pub fn new(script: Script) -> Self {
634        Self {
635            script_score: ScriptWrapper { script },
636        }
637    }
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::prelude::*;

    // Covers a date based decay, an integer based decay and a script score function.
    #[test]
    fn serialization() {
        let origin = Utc.with_ymd_and_hms(2014, 7, 8, 9, 1, 0).single().unwrap();
        let date_decay = Decay::new(DecayFunction::Gauss, "test", origin, Time::Days(7));
        assert_serialize(
            date_decay,
            json!({
                "gauss": {
                    "test": {
                        "origin": "2014-07-08T09:01:00Z",
                        "scale": "7d",
                    }
                }
            }),
        );

        let integer_decay = Decay::new(DecayFunction::Linear, "test", 1, 2);
        assert_serialize(
            integer_decay,
            json!({
                "linear": {
                    "test": {
                        "origin": 1,
                        "scale": 2,
                    }
                }
            }),
        );

        let script = Script::source("Math.log(2 + doc['my-int'].value)");
        assert_serialize(
            ScriptScore::new(script),
            json!({
                "script_score": {
                    "script": {
                        "source": "Math.log(2 + doc['my-int'].value)"
                    }
                }
            }),
        );
    }

    // Covers a decay whose origin and scale are floating point numbers.
    #[test]
    fn float_decay() {
        let float_decay = Decay::new(DecayFunction::Linear, "test", 0.1, 0.5);
        assert_serialize(
            float_decay,
            json!({
                "linear": {
                    "test": {
                        "origin": 0.1,
                        "scale": 0.5
                    }
                }
            }),
        );
    }
}