elasticsearch_dsl/search/queries/params/function_score_query.rs

use crate::search::*;
use crate::util::*;
use chrono::{DateTime, Utc};
use serde::ser::{Serialize, SerializeMap, Serializer};
use std::fmt::Debug;

/// Each document is scored by the defined functions. The parameter `score_mode` specifies how
/// the computed scores are combined.
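///
/// The variants serialize to their snake_case names; a minimal sketch (assuming
/// `FunctionScoreMode` is re-exported at the crate root like `FieldValueFactor`, and that
/// `serde_json` is available, as the unit tests in this module suggest):
/// ```
/// # use elasticsearch_dsl::FunctionScoreMode;
/// assert_eq!(
///     serde_json::to_string(&FunctionScoreMode::Multiply).unwrap(),
///     r#""multiply""#,
/// );
/// ```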
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FunctionScoreMode {
    /// Scores are multiplied (default)
    Multiply,

    /// Scores are summed
    Sum,

    /// Scores are averaged
    Avg,

    /// The first function that has a matching filter is applied
    First,

    /// Maximum score is used
    Max,

    /// Minimum score is used
    Min,
}

impl Default for FunctionScoreMode {
    fn default() -> Self {
        Self::Multiply
    }
}

/// The newly computed score is combined with the score of the query. The parameter
/// `boost_mode` defines how.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FunctionBoostMode {
    /// Query score and function score are multiplied (default)
    Multiply,

    /// Only function score is used, the query score is ignored
    Replace,

    /// Query score and function score are added
    Sum,

    /// Query score and function score are averaged
    Avg,

    /// Max of query score and function score
    Max,

    /// Min of query score and function score
    Min,
}

impl Default for FunctionBoostMode {
    fn default() -> Self {
        Self::Multiply
    }
}

macro_rules! function {
    ($name:ident { $($variant:ident($query:ty)),+ $(,)? }) => {
        /// Functions available for use in [FunctionScoreQuery](crate::FunctionScoreQuery)
        #[derive(Debug, Clone, PartialEq, Serialize)]
        #[allow(missing_docs)]
        #[serde(untagged)]
        pub enum $name {
            $(
                $variant($query),
            )*
        }

        $(
            impl From<$query> for $name {
                fn from(q: $query) -> Self {
                    $name::$variant(q)
                }
            }
        )+

        $(
            impl From<$query> for Option<$name> {
                fn from(q: $query) -> Self {
                    Some($name::$variant(q))
                }
            }
        )+
    };
}

function!(Function {
    Weight(Weight),
    RandomScore(RandomScore),
    FieldValueFactor(FieldValueFactor),
    DecayDateTime(Decay<DateTime<Utc>>),
    DecayLocation(Decay<GeoLocation>),
    DecayI8(Decay<i8>),
    DecayI16(Decay<i16>),
    DecayI32(Decay<i32>),
    DecayI64(Decay<i64>),
    DecayU8(Decay<u8>),
    DecayU16(Decay<u16>),
    DecayU32(Decay<u32>),
    DecayU64(Decay<u64>),
    DecayF32(Decay<f32>),
    DecayF64(Decay<f64>),
    ScriptScore(ScriptScore),
});

impl Function {
    /// Creates an instance of [Weight](Weight)
    pub fn weight(weight: f32) -> Weight {
        Weight::new(weight)
    }

    /// Creates an instance of [RandomScore](RandomScore)
    pub fn random_score() -> RandomScore {
        RandomScore::new()
    }

    /// Creates an instance of [FieldValueFactor](FieldValueFactor)
    ///
    /// - `field` - Field to be extracted from the document.
    pub fn field_value_factor<T>(field: T) -> FieldValueFactor
    where
        T: ToString,
    {
        FieldValueFactor::new(field)
    }

    /// Creates an instance of [Decay](Decay)
    ///
    /// - `function` - Decay function variant
    /// - `field` - Field to apply function to
    /// - `origin` - The point of origin used for calculating distance. Must be given as a number
    ///   for numeric fields, a date for date fields and a geo point for geo fields. Required for
    ///   geo and numeric fields. For date fields the default is `now`. Date math (for example
    ///   now-1h) is supported for origin.
    /// - `scale` - Required for all types. Defines the distance from origin + offset at which the
    ///   computed score will equal the `decay` parameter. For geo fields: can be defined as
    ///   number+unit (1km, 12m, …). Default unit is meters. For date fields: can be defined as a
    ///   number+unit ("1h", "10d", …). Default unit is milliseconds. For numeric fields: any
    ///   number. See the construction sketch below.
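    ///
    /// A minimal construction sketch with a numeric field (assuming `Function` and
    /// `DecayFunction` are re-exported at the crate root like `FieldValueFactor`):
    /// ```
    /// # use elasticsearch_dsl::{DecayFunction, Function};
    /// # fn main() {
    /// # let _ =
    /// Function::decay(DecayFunction::Linear, "likes", 0, 20)
    /// # ;}
    /// ```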
    pub fn decay<T, O>(
        function: DecayFunction,
        field: T,
        origin: O,
        scale: <O as Origin>::Scale,
    ) -> Decay<O>
    where
        T: ToString,
        O: Origin,
    {
        Decay::new(function, field, origin, scale)
    }

    /// Creates an instance of [ScriptScore](ScriptScore)
    ///
    /// - `source` - script source
    pub fn script(source: Script) -> ScriptScore {
        ScriptScore::new(source)
    }
}

/// The `weight` score allows you to multiply the score by the provided weight.
///
/// This can sometimes be desired since the boost value set on specific queries gets normalized,
/// while for this score function it does not.
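///
/// A minimal construction sketch (assuming `Weight` is re-exported at the crate root like
/// `FieldValueFactor`):
/// ```
/// # use elasticsearch_dsl::Weight;
/// # fn main() {
/// # let _ =
/// Weight::new(2.0)
/// # ;}
/// ```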
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Weight {
    weight: f32,
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,
}

impl Weight {
    /// Creates an instance of [Weight](Weight)
    pub fn new(weight: f32) -> Self {
        Self {
            weight,
            filter: None,
        }
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }
}

/// The `random_score` generates scores that are uniformly distributed from `0` up to but not
/// including `1`.
///
/// By default, it uses the internal Lucene doc ids as a source of randomness, which is very
/// efficient but unfortunately not reproducible since documents might be renumbered by merges.
///
/// In case you want scores to be reproducible, it is possible to provide a `seed` and `field`. The
/// final score will then be computed based on this seed, the minimum value of `field` for the
/// considered document and a salt that is computed based on the index name and shard id so that
/// documents that have the same value but are stored in different indexes get different scores.
/// Note that documents that are within the same shard and have the same value for `field` will
/// however get the same score, so it is usually desirable to use a field that has unique values
/// for all documents. A good default choice might be to use the `_seq_no` field, whose only
/// drawback is that scores will change if the document is updated since update operations also
/// update the value of the `_seq_no` field.
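///
/// A minimal construction sketch for the reproducible variant (assuming `RandomScore` is
/// re-exported at the crate root like `FieldValueFactor`):
/// ```
/// # use elasticsearch_dsl::RandomScore;
/// # fn main() {
/// # let _ =
/// RandomScore::new()
///     .seed(10)
///     .field("_seq_no")
/// # ;}
/// ```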
#[derive(Debug, Default, Clone, PartialEq, Serialize)]
pub struct RandomScore {
    random_score: RandomScoreInner,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    weight: Option<f32>,
}

#[derive(Debug, Default, Clone, PartialEq, Serialize)]
struct RandomScoreInner {
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    seed: Option<Term>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    field: Option<String>,
}

impl RandomScore {
    /// Creates an instance of [RandomScore](RandomScore)
    pub fn new() -> Self {
        Default::default()
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }

    /// The `weight` score allows you to multiply the score by the provided `weight`. This can
    /// sometimes be desired since the boost value set on specific queries gets normalized, while
    /// for this score function it does not. The number value is of type float.
    pub fn weight<T>(mut self, weight: T) -> Self
    where
        T: num_traits::AsPrimitive<f32>,
    {
        self.weight = Some(weight.as_());
        self
    }

    /// Sets seed value
    pub fn seed<T>(mut self, seed: T) -> Self
    where
        T: Serialize,
    {
        self.random_score.seed = Term::new(seed);
        self
    }

    /// Sets field value
    pub fn field<T>(mut self, field: T) -> Self
    where
        T: ToString,
    {
        self.random_score.field = Some(field.to_string());
        self
    }
}

/// The `field_value_factor` function allows you to use a field from a document to influence the
/// score.
/// It’s similar to using the `script_score` function, however, it avoids the overhead of scripting.
/// If used on a multi-valued field, only the first value of the field is used in calculations.
///
/// As an example, imagine you have a document indexed with a numeric `my-int` field and wish to
/// influence the score of a document with this field; an example doing so would look like:
/// ```
/// # use elasticsearch_dsl::{FieldValueFactor, FieldValueFactorModifier};
/// # fn main() {
/// # let _ =
/// FieldValueFactor::new("my-int")
///     .factor(1.2)
///     .modifier(FieldValueFactorModifier::Sqrt)
///     .missing(1.0)
/// # ;}
/// ```
/// Which will translate into the following formula for scoring:
/// ```text
/// sqrt(1.2 * doc['my-int'].value)
/// ```
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct FieldValueFactor {
    field_value_factor: FieldValueFactorInner,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    weight: Option<f32>,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
struct FieldValueFactorInner {
    field: String,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    factor: Option<f32>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    modifier: Option<FieldValueFactorModifier>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    missing: Option<f32>,
}

impl FieldValueFactor {
    /// Creates an instance of [FieldValueFactor](FieldValueFactor)
    ///
    /// - `field` - Field to be extracted from the document.
    pub fn new<T>(field: T) -> Self
    where
        T: ToString,
    {
        Self {
            field_value_factor: FieldValueFactorInner {
                field: field.to_string(),
                factor: None,
                modifier: None,
                missing: None,
            },
            filter: None,
            weight: None,
        }
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }

    /// The `weight` score allows you to multiply the score by the provided `weight`. This can
    /// sometimes be desired since the boost value set on specific queries gets normalized, while
    /// for this score function it does not. The number value is of type float.
    pub fn weight<T>(mut self, weight: T) -> Self
    where
        T: num_traits::AsPrimitive<f32>,
    {
        self.weight = Some(weight.as_());
        self
    }

    /// Factor to multiply the field value with
    pub fn factor(mut self, factor: f32) -> Self {
        self.field_value_factor.factor = Some(factor);
        self
    }

    /// Modifier to apply to the field value
    pub fn modifier(mut self, modifier: FieldValueFactorModifier) -> Self {
        self.field_value_factor.modifier = Some(modifier);
        self
    }

    /// Value used if the document doesn’t have that field. The modifier and factor are still
    /// applied to it as though it were read from the document
    pub fn missing(mut self, missing: f32) -> Self {
        self.field_value_factor.missing = Some(missing);
        self
    }
}

/// Modifier to apply to the field value
///
/// Defaults to [none](FieldValueFactorModifier::None)
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum FieldValueFactorModifier {
    /// Do not apply any multiplier to the field value
    None,

    /// Take the [common logarithm](https://en.wikipedia.org/wiki/Common_logarithm) of the field
    /// value
    ///
    /// Because this function will return a negative value and cause an error if used on values
    /// between `0` and `1`, it is recommended to use [log1p](FieldValueFactorModifier::Log1P)
    /// instead.
    Log,

    /// Add 1 to the field value and take the common logarithm
    Log1P,

    /// Add 2 to the field value and take the common logarithm
    Log2P,

    /// Take the [natural logarithm](https://en.wikipedia.org/wiki/Natural_logarithm) of the field
    /// value.
    ///
    /// Because this function will return a negative value and cause an error if used on values
    /// between `0` and `1`, it is recommended to use [ln1p](FieldValueFactorModifier::Ln1P)
    /// instead.
    Ln,

    /// Add 1 to the field value and take the natural logarithm
    Ln1P,

    /// Add 2 to the field value and take the natural logarithm
    Ln2P,

    /// Square the field value (multiply it by itself)
    Square,

    /// Take the [square root](https://en.wikipedia.org/wiki/Square_root) of the field value
    Sqrt,

    /// [Reciprocate](https://en.wikipedia.org/wiki/Multiplicative_inverse) the field value, same
    /// as `1/x` where `x` is the field’s value
    Reciprocal,
}

#[doc(hidden)]
pub trait Origin: Debug + PartialEq + Serialize + Clone {
    type Scale: Debug + PartialEq + Serialize + Clone;
    type Offset: Debug + PartialEq + Serialize + Clone;
}

impl Origin for DateTime<Utc> {
    type Scale = Time;
    type Offset = Time;
}

impl Origin for GeoLocation {
    type Scale = Distance;
    type Offset = Distance;
}

macro_rules! impl_origin_for_numbers {
    ($($name:ident ),+) => {
        $(
            impl Origin for $name {
                type Scale = Self;
                type Offset = Self;
            }
        )+
    }
}

impl_origin_for_numbers![i8, i16, i32, i64, u8, u16, u32, u64, f32, f64];

/// Decay functions score a document with a function that decays depending on the distance of a
/// numeric field value of the document from a user given origin. This is similar to a range query,
/// but with smooth edges instead of boxes.
///
/// To use distance scoring on a query that has numerical fields, the user has to define an
/// `origin` and a `scale` for each field. The `origin` is needed to define the “central point”
/// from which the distance is calculated, and the `scale` to define the rate of decay.
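///
/// A minimal construction sketch with a numeric field (assuming `Decay` and `DecayFunction`
/// are re-exported at the crate root like `FieldValueFactor`):
/// ```
/// # use elasticsearch_dsl::{Decay, DecayFunction};
/// # fn main() {
/// # let _ =
/// Decay::new(DecayFunction::Gauss, "likes", 0, 20)
///     .offset(5)
///     .decay(0.33)
/// # ;}
/// ```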
#[derive(Debug, Clone, PartialEq)]
pub struct Decay<T: Origin> {
    function: DecayFunction,

    inner: DecayFieldInner<T>,

    filter: Option<Query>,

    weight: Option<f32>,
}

#[derive(Debug, Clone, PartialEq)]
struct DecayFieldInner<T: Origin> {
    field: String,
    inner: DecayInner<T>,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
struct DecayInner<O>
where
    O: Origin,
{
    origin: O,

    scale: <O as Origin>::Scale,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    offset: Option<<O as Origin>::Offset>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    decay: Option<f32>,
}

impl<O> Decay<O>
where
    O: Origin,
{
    /// Creates an instance of [Decay](Decay)
    ///
    /// - `function` - Decay function variant
    /// - `field` - Field to apply function to
    /// - `origin` - The point of origin used for calculating distance. Must be given as a number
    ///   for numeric fields, a date for date fields and a geo point for geo fields. Required for
    ///   geo and numeric fields. For date fields the default is `now`. Date math (for example
    ///   now-1h) is supported for origin.
    /// - `scale` - Required for all types. Defines the distance from origin + offset at which the
    ///   computed score will equal the `decay` parameter. For geo fields: can be defined as
    ///   number+unit (1km, 12m, …). Default unit is meters. For date fields: can be defined as a
    ///   number+unit ("1h", "10d", …). Default unit is milliseconds. For numeric fields: any
    ///   number.
    pub fn new<T>(function: DecayFunction, field: T, origin: O, scale: <O as Origin>::Scale) -> Self
    where
        T: ToString,
    {
        Self {
            function,
            inner: DecayFieldInner {
                field: field.to_string(),
                inner: DecayInner {
                    origin,
                    scale,
                    offset: None,
                    decay: None,
                },
            },
            filter: None,
            weight: None,
        }
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }

    /// The `weight` score allows you to multiply the score by the provided `weight`. This can
    /// sometimes be desired since the boost value set on specific queries gets normalized, while
    /// for this score function it does not. The number value is of type float.
    pub fn weight<T>(mut self, weight: T) -> Self
    where
        T: num_traits::AsPrimitive<f32>,
    {
        self.weight = Some(weight.as_());
        self
    }

    /// If an `offset` is defined, the decay function will only compute the decay function for
    /// documents with a distance greater than the defined `offset`.
    ///
    /// The default is `0`.
    pub fn offset(mut self, offset: <O as Origin>::Offset) -> Self {
        self.inner.inner.offset = Some(offset);
        self
    }

    /// The `decay` parameter defines how documents are scored at the distance given at `scale`. If
    /// no `decay` is defined, documents at the distance `scale` will be scored `0.5`.
    pub fn decay(mut self, decay: f32) -> Self {
        self.inner.inner.decay = Some(decay);
        self
    }
}

impl<T: Origin> Serialize for Decay<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut map = serializer.serialize_map(Some(3))?;

        map.serialize_entry(&self.function, &self.inner)?;

        if let Some(filter) = &self.filter {
            map.serialize_entry("filter", filter)?;
        }

        if let Some(weight) = &self.weight {
            map.serialize_entry("weight", weight)?;
        }

        map.end()
    }
}

impl<T: Origin> Serialize for DecayFieldInner<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut map = serializer.serialize_map(Some(1))?;

        map.serialize_entry(&self.field, &self.inner)?;

        map.end()
    }
}

/// Decay function variants
///
/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-function-score-query.html#_supported_decay_functions>
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum DecayFunction {
    /// Linear decay
    Linear,

    /// Exponential decay
    Exp,

    /// Gauss decay
    Gauss,
}

/// The `script_score` function allows you to wrap another query and customize the scoring of it,
/// optionally with a computation derived from other numeric field values in the doc using a
/// script expression.
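///
/// A minimal construction sketch (assuming `Script` and `ScriptScore` are re-exported at the
/// crate root like `FieldValueFactor`):
/// ```
/// # use elasticsearch_dsl::{Script, ScriptScore};
/// # fn main() {
/// # let _ =
/// ScriptScore::new(Script::source("Math.log(2 + doc['my-int'].value)"))
/// # ;}
/// ```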
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ScriptScore {
    script_score: ScriptWrapper,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
struct ScriptWrapper {
    script: Script,
}

impl ScriptScore {
    /// Creates an instance of [ScriptScore](ScriptScore)
    ///
    /// - `script` - script source
    pub fn new(script: Script) -> Self {
        Self {
            script_score: ScriptWrapper { script },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::prelude::*;

    #[test]
    fn serialization() {
        assert_serialize(
            Decay::new(
                DecayFunction::Gauss,
                "test",
                Utc.with_ymd_and_hms(2014, 7, 8, 9, 1, 0).single().unwrap(),
                Time::Days(7),
            ),
            json!({
                "gauss": {
                    "test": {
                        "origin": "2014-07-08T09:01:00Z",
                        "scale": "7d",
                    }
                }
            }),
        );

        assert_serialize(
            Decay::new(DecayFunction::Linear, "test", 1, 2),
            json!({
                "linear": {
                    "test": {
                        "origin": 1,
                        "scale": 2,
                    }
                }
            }),
        );

        assert_serialize(
            ScriptScore::new(Script::source("Math.log(2 + doc['my-int'].value)")),
            json!({
                "script_score": {
                    "script": {
                        "source": "Math.log(2 + doc['my-int'].value)"
                    }
                }
            }),
        );
    }

    #[test]
    fn float_decay() {
        assert_serialize(
            Decay::new(DecayFunction::Linear, "test", 0.1, 0.5),
            json!({
                "linear": {
                    "test": {
                        "origin": 0.1,
                        "scale": 0.5
                    }
                }
            }),
        );
    }
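
    // A sketch of extra coverage for the remaining function kinds, assuming `assert_serialize`
    // accepts any `Serialize` value as the existing tests suggest. The `Weight` case goes
    // through `Function::from` to exercise the untagged enum generated by the `function!` macro.
    #[test]
    fn other_functions_serialization() {
        assert_serialize(
            Function::from(Weight::new(2.0)),
            json!({ "weight": 2.0 }),
        );

        assert_serialize(
            RandomScore::new().field("_seq_no"),
            json!({
                "random_score": {
                    "field": "_seq_no"
                }
            }),
        );

        assert_serialize(
            FieldValueFactor::new("my-int")
                .factor(1.5)
                .modifier(FieldValueFactorModifier::Sqrt)
                .missing(0.5),
            json!({
                "field_value_factor": {
                    "field": "my-int",
                    "factor": 1.5,
                    "modifier": "sqrt",
                    "missing": 0.5
                }
            }),
        );
    }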
}