elasticsearch_dsl/search/queries/params/function_score_query.rs
use crate::search::*;
use crate::util::*;
use chrono::{DateTime, Utc};
use serde::ser::{Serialize, SerializeMap, Serializer};
use std::fmt::Debug;

/// Each document is scored by the defined functions. The parameter `score_mode` specifies how
/// the computed scores are combined
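///
/// A minimal serialization sketch (assuming `FunctionScoreMode` is re-exported at the crate
/// root, as `FieldValueFactor` is in the example further down, and that `serde_json` is
/// available as a dependency):
/// ```
/// # use elasticsearch_dsl::FunctionScoreMode;
/// // variants serialize to their snake_case names
/// assert_eq!(serde_json::to_string(&FunctionScoreMode::Sum).unwrap(), r#""sum""#);
/// ```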
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FunctionScoreMode {
    /// Scores are multiplied (default)
    Multiply,

    /// Scores are summed
    Sum,

    /// Scores are averaged
    Avg,

    /// The first function that has a matching filter is applied
    First,

    /// Maximum score is used
    Max,

    /// Minimum score is used
    Min,
}

impl Default for FunctionScoreMode {
    fn default() -> Self {
        Self::Multiply
    }
}

/// The newly computed score is combined with the score of the query. The parameter
/// `boost_mode` defines how.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FunctionBoostMode {
    /// Query score and function score are multiplied (default)
    Multiply,

    /// Only function score is used, the query score is ignored
    Replace,

    /// Query score and function score are added
    Sum,

    /// Average of query score and function score
    Avg,

    /// Max of query score and function score
    Max,

    /// Min of query score and function score
    Min,
}

impl Default for FunctionBoostMode {
    fn default() -> Self {
        Self::Multiply
    }
}

macro_rules! function {
    ($name:ident { $($variant:ident($query:ty)),+ $(,)? }) => {
        /// Functions available for use in [FunctionScoreQuery](crate::FunctionScoreQuery)
        #[derive(Debug, Clone, PartialEq, Serialize)]
        #[allow(missing_docs)]
        #[serde(untagged)]
        pub enum $name {
            $(
                $variant($query),
            )*
        }

        $(
            impl From<$query> for $name {
                fn from(q: $query) -> Self {
                    $name::$variant(q)
                }
            }
        )+

        $(
            impl From<$query> for Option<$name> {
                fn from(q: $query) -> Self {
                    Some($name::$variant(q))
                }
            }
        )+
    };
}

function!(Function {
    Weight(Weight),
    RandomScore(RandomScore),
    FieldValueFactor(FieldValueFactor),
    DecayDateTime(Decay<DateTime<Utc>>),
    DecayLocation(Decay<GeoLocation>),
    DecayI8(Decay<i8>),
    DecayI16(Decay<i16>),
    DecayI32(Decay<i32>),
    DecayI64(Decay<i64>),
    DecayU8(Decay<u8>),
    DecayU16(Decay<u16>),
    DecayU32(Decay<u32>),
    DecayU64(Decay<u64>),
    DecayF32(Decay<f32>),
    DecayF64(Decay<f64>),
    ScriptScore(ScriptScore),
});

impl Function {
    /// Creates an instance of [Weight](Weight)
    pub fn weight(weight: f32) -> Weight {
        Weight::new(weight)
    }

    /// Creates an instance of [RandomScore](RandomScore)
    pub fn random_score() -> RandomScore {
        RandomScore::new()
    }

    /// Creates an instance of [FieldValueFactor](FieldValueFactor)
    ///
    /// - `field` - Field to be extracted from the document.
    pub fn field_value_factor<T>(field: T) -> FieldValueFactor
    where
        T: ToString,
    {
        FieldValueFactor::new(field)
    }

    /// Creates an instance of [Decay](Decay)
    ///
    /// - `function` - Decay function variant
    /// - `field` - Field to apply function to
    /// - `origin` - The point of origin used for calculating distance. Must be given as a number
    /// for numeric fields, a date for date fields and a geo point for geo fields. Required for
    /// geo and numeric fields. For date fields the default is `now`. Date math (for example
    /// `now-1h`) is supported for origin.
    /// - `scale` - Required for all types. Defines the distance from origin + offset at which the
    /// computed score will equal the `decay` parameter. For geo fields: can be defined as
    /// number+unit (1km, 12m,…). Default unit is meters. For date fields: can be defined as a
    /// number+unit ("1h", "10d",…). Default unit is milliseconds. For numeric fields: any number.
    pub fn decay<T, O>(
        function: DecayFunction,
        field: T,
        origin: O,
        scale: <O as Origin>::Scale,
    ) -> Decay<O>
    where
        T: ToString,
        O: Origin,
    {
        Decay::new(function, field, origin, scale)
    }

    /// Creates an instance of [ScriptScore](ScriptScore)
    ///
    /// - `source` - script source
    pub fn script(source: Script) -> ScriptScore {
        ScriptScore::new(source)
    }
}

/// The `weight` score allows you to multiply the score by the provided weight.
///
/// This can sometimes be desired since the boost value set on specific queries gets normalized,
/// while for this score function it does not.
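///
/// A minimal construction sketch (assuming `Weight` is re-exported at the crate root like the
/// other types in this module):
/// ```
/// # use elasticsearch_dsl::Weight;
/// # fn main() {
/// # let _ =
/// Weight::new(2.0)
/// # ;}
/// ```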
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Weight {
    weight: f32,
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,
}

impl Weight {
    /// Creates an instance of [Weight](Weight)
    pub fn new(weight: f32) -> Self {
        Self {
            weight,
            filter: None,
        }
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }
}

/// The `random_score` generates scores that are uniformly distributed from `0` up to but not
/// including `1`.
///
/// By default, it uses the internal Lucene doc ids as a source of randomness, which is very
/// efficient but unfortunately not reproducible since documents might be renumbered by merges.
///
/// In case you want scores to be reproducible, it is possible to provide a `seed` and `field`. The
/// final score will then be computed based on this seed, the minimum value of `field` for the
/// considered document and a salt that is computed based on the index name and shard id so that
/// documents that have the same value but are stored in different indexes get different scores.
/// Note that documents that are within the same shard and have the same value for `field` will
/// however get the same score, so it is usually desirable to use a field that has unique values
/// for all documents. A good default choice might be to use the `_seq_no` field, whose only
/// drawback is that scores will change if the document is updated since update operations also
/// update the value of the `_seq_no` field.
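///
/// A minimal construction sketch for the reproducible case (assuming `RandomScore` is
/// re-exported at the crate root like the other types in this module):
/// ```
/// # use elasticsearch_dsl::RandomScore;
/// # fn main() {
/// # let _ =
/// RandomScore::new().seed(42).field("_seq_no")
/// # ;}
/// ```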
#[derive(Debug, Default, Clone, PartialEq, Serialize)]
pub struct RandomScore {
    random_score: RandomScoreInner,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    weight: Option<f32>,
}

#[derive(Debug, Default, Clone, PartialEq, Serialize)]
struct RandomScoreInner {
    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    seed: Option<Term>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    field: Option<String>,
}

impl RandomScore {
    /// Creates an instance of [RandomScore](RandomScore)
    pub fn new() -> Self {
        Default::default()
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }

    /// The `weight` score allows you to multiply the score by the provided `weight`. This can
    /// sometimes be desired since the boost value set on specific queries gets normalized, while
    /// for this score function it does not. The number value is of type float.
    pub fn weight<T>(mut self, weight: T) -> Self
    where
        T: num_traits::AsPrimitive<f32>,
    {
        self.weight = Some(weight.as_());
        self
    }

    /// Sets seed value
    pub fn seed<T>(mut self, seed: T) -> Self
    where
        T: Serialize,
    {
        self.random_score.seed = Term::new(seed);
        self
    }

    /// Sets field value
    pub fn field<T>(mut self, field: T) -> Self
    where
        T: ToString,
    {
        self.random_score.field = Some(field.to_string());
        self
    }
}

/// The `field_value_factor` function allows you to use a field from a document to influence the
/// score.
///
/// It’s similar to using the `script_score` function; however, it avoids the overhead of
/// scripting. If used on a multi-valued field, only the first value of the field is used in
/// calculations.
///
/// As an example, imagine you have a document indexed with a numeric `my-int` field and wish to
/// influence the score of a document with this field. An example doing so would look like:
/// ```
/// # use elasticsearch_dsl::{FieldValueFactor, FieldValueFactorModifier};
/// # fn main() {
/// # let _ =
/// FieldValueFactor::new("my-int")
///     .factor(1.2)
///     .modifier(FieldValueFactorModifier::Sqrt)
///     .missing(1.0)
/// # ;}
/// ```
/// Which will translate into the following formula for scoring:
/// ```text
/// sqrt(1.2 * doc['my-int'].value)
/// ```
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct FieldValueFactor {
    field_value_factor: FieldValueFactorInner,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    weight: Option<f32>,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
struct FieldValueFactorInner {
    field: String,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    factor: Option<f32>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    modifier: Option<FieldValueFactorModifier>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    missing: Option<f32>,
}

impl FieldValueFactor {
    /// Creates an instance of [FieldValueFactor](FieldValueFactor)
    ///
    /// - `field` - Field to be extracted from the document.
    pub fn new<T>(field: T) -> Self
    where
        T: ToString,
    {
        Self {
            field_value_factor: FieldValueFactorInner {
                field: field.to_string(),
                factor: None,
                modifier: None,
                missing: None,
            },
            filter: None,
            weight: None,
        }
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }

    /// The `weight` score allows you to multiply the score by the provided `weight`. This can
    /// sometimes be desired since the boost value set on specific queries gets normalized, while
    /// for this score function it does not. The number value is of type float.
    pub fn weight<T>(mut self, weight: T) -> Self
    where
        T: num_traits::AsPrimitive<f32>,
    {
        self.weight = Some(weight.as_());
        self
    }

    /// Factor to multiply the field value with
    pub fn factor(mut self, factor: f32) -> Self {
        self.field_value_factor.factor = Some(factor);
        self
    }

    /// Modifier to apply to the field value
    pub fn modifier(mut self, modifier: FieldValueFactorModifier) -> Self {
        self.field_value_factor.modifier = Some(modifier);
        self
    }

    /// Value used if the document doesn’t have that field. The modifier and factor are still
    /// applied to it as though it were read from the document
    pub fn missing(mut self, missing: f32) -> Self {
        self.field_value_factor.missing = Some(missing);
        self
    }
}

/// Modifier to apply to the field value
///
/// Defaults to [none](FieldValueFactorModifier::None)
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum FieldValueFactorModifier {
    /// Do not apply any multiplier to the field value
    None,

    /// Take the [common logarithm](https://en.wikipedia.org/wiki/Common_logarithm) of the field
    /// value
    ///
    /// Because this function will return a negative value and cause an error if used on values
    /// between `0` and `1`, it is recommended to use [log1p](FieldValueFactorModifier::Log1P)
    /// instead.
    Log,

    /// Add 1 to the field value and take the common logarithm
    Log1P,

    /// Add 2 to the field value and take the common logarithm
    Log2P,

    /// Take the [natural logarithm](https://en.wikipedia.org/wiki/Natural_logarithm) of the field
    /// value.
    ///
    /// Because this function will return a negative value and cause an error if used on values
    /// between `0` and `1`, it is recommended to use [ln1p](FieldValueFactorModifier::Ln1P)
    /// instead.
    Ln,

    /// Add 1 to the field value and take the natural logarithm
    Ln1P,

    /// Add 2 to the field value and take the natural logarithm
    Ln2P,

    /// Square the field value (multiply it by itself)
    Square,

    /// Take the [square root](https://en.wikipedia.org/wiki/Square_root) of the field value
    Sqrt,

    /// [Reciprocate](https://en.wikipedia.org/wiki/Multiplicative_inverse) the field value, same
    /// as `1/x` where `x` is the field’s value
    Reciprocal,
}

#[doc(hidden)]
pub trait Origin: Debug + PartialEq + Serialize + Clone {
    type Scale: Debug + PartialEq + Serialize + Clone;
    type Offset: Debug + PartialEq + Serialize + Clone;
}

impl Origin for DateTime<Utc> {
    type Scale = Time;
    type Offset = Time;
}

impl Origin for GeoLocation {
    type Scale = Distance;
    type Offset = Distance;
}

macro_rules! impl_origin_for_numbers {
    ($($name:ident ),+) => {
        $(
            impl Origin for $name {
                type Scale = Self;
                type Offset = Self;
            }
        )+
    }
}

impl_origin_for_numbers![i8, i16, i32, i64, u8, u16, u32, u64, f32, f64];

/// Decay functions score a document with a function that decays depending on the distance of a
/// numeric field value of the document from a user given origin. This is similar to a range query,
/// but with smooth edges instead of boxes.
///
/// To use distance scoring on a query that has numerical fields, the user has to define an
/// `origin` and a `scale` for each field. The `origin` is needed to define the “central point”
/// from which the distance is calculated, and the `scale` to define the rate of decay.
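///
/// A minimal construction sketch for a numeric field (field name and values are illustrative,
/// and the example assumes `Decay` and `DecayFunction` are re-exported at the crate root like
/// the other types in this module):
/// ```
/// # use elasticsearch_dsl::{Decay, DecayFunction};
/// # fn main() {
/// # let _ =
/// Decay::new(DecayFunction::Gauss, "price", 100.0, 10.0)
///     .offset(5.0)
///     .decay(0.33)
/// # ;}
/// ```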
#[derive(Debug, Clone, PartialEq)]
pub struct Decay<T: Origin> {
    function: DecayFunction,

    inner: DecayFieldInner<T>,

    filter: Option<Query>,

    weight: Option<f32>,
}

#[derive(Debug, Clone, PartialEq)]
struct DecayFieldInner<T: Origin> {
    field: String,
    inner: DecayInner<T>,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
struct DecayInner<O>
where
    O: Origin,
{
    origin: O,

    scale: <O as Origin>::Scale,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    offset: Option<<O as Origin>::Offset>,

    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
    decay: Option<f32>,
}

impl<O> Decay<O>
where
    O: Origin,
{
    /// Creates an instance of [Decay](Decay)
    ///
    /// - `function` - Decay function variant
    /// - `field` - Field to apply function to
    /// - `origin` - The point of origin used for calculating distance. Must be given as a number
    /// for numeric fields, a date for date fields and a geo point for geo fields. Required for
    /// geo and numeric fields. For date fields the default is `now`. Date math (for example
    /// `now-1h`) is supported for origin.
    /// - `scale` - Required for all types. Defines the distance from origin + offset at which the
    /// computed score will equal the `decay` parameter. For geo fields: can be defined as
    /// number+unit (1km, 12m,…). Default unit is meters. For date fields: can be defined as a
    /// number+unit ("1h", "10d",…). Default unit is milliseconds. For numeric fields: any number.
    pub fn new<T>(function: DecayFunction, field: T, origin: O, scale: <O as Origin>::Scale) -> Self
    where
        T: ToString,
    {
        Self {
            function,
            inner: DecayFieldInner {
                field: field.to_string(),
                inner: DecayInner {
                    origin,
                    scale,
                    offset: None,
                    decay: None,
                },
            },
            filter: None,
            weight: None,
        }
    }

    /// Add function filter
    pub fn filter<T>(mut self, filter: T) -> Self
    where
        T: Into<Option<Query>>,
    {
        self.filter = filter.into();
        self
    }

    /// The `weight` score allows you to multiply the score by the provided `weight`. This can
    /// sometimes be desired since the boost value set on specific queries gets normalized, while
    /// for this score function it does not. The number value is of type float.
    pub fn weight<T>(mut self, weight: T) -> Self
    where
        T: num_traits::AsPrimitive<f32>,
    {
        self.weight = Some(weight.as_());
        self
    }

    /// If an `offset` is defined, the decay function will only be computed for documents with a
    /// distance greater than the defined `offset`.
    ///
    /// The default is `0`.
    pub fn offset(mut self, offset: <O as Origin>::Offset) -> Self {
        self.inner.inner.offset = Some(offset);
        self
    }

    /// The `decay` parameter defines how documents are scored at the distance given at `scale`. If
    /// no `decay` is defined, documents at the distance `scale` will be scored `0.5`.
    pub fn decay(mut self, decay: f32) -> Self {
        self.inner.inner.decay = Some(decay);
        self
    }
}

impl<T: Origin> Serialize for Decay<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut map = serializer.serialize_map(Some(3))?;

        map.serialize_entry(&self.function, &self.inner)?;

        if let Some(filter) = &self.filter {
            map.serialize_entry("filter", filter)?;
        }

        if let Some(weight) = &self.weight {
            map.serialize_entry("weight", weight)?;
        }

        map.end()
    }
}

impl<T: Origin> Serialize for DecayFieldInner<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut map = serializer.serialize_map(Some(1))?;

        map.serialize_entry(&self.field, &self.inner)?;

        map.end()
    }
}

/// Decay function variants
///
/// <https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-function-score-query.html#_supported_decay_functions>
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum DecayFunction {
    /// Linear decay
    Linear,

    /// Exponential decay
    Exp,

    /// Gauss decay
    Gauss,
}

/// The `script_score` function allows you to wrap another query and customize its scoring,
/// optionally with a computation derived from other numeric field values in the doc, using a
/// script expression.
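///
/// A minimal construction sketch (assuming `Script` and `ScriptScore` are re-exported at the
/// crate root like the other types in this module):
/// ```
/// # use elasticsearch_dsl::{Script, ScriptScore};
/// # fn main() {
/// # let _ =
/// ScriptScore::new(Script::source("Math.log(2 + doc['my-int'].value)"))
/// # ;}
/// ```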
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ScriptScore {
    script_score: ScriptWrapper,
}

#[derive(Debug, Clone, PartialEq, Serialize)]
struct ScriptWrapper {
    script: Script,
}

impl ScriptScore {
    /// Creates an instance of [ScriptScore](ScriptScore)
    ///
    /// - `script` - script source
    pub fn new(script: Script) -> Self {
        Self {
            script_score: ScriptWrapper { script },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chrono::prelude::*;

    #[test]
    fn serialization() {
        assert_serialize(
            Decay::new(
                DecayFunction::Gauss,
                "test",
                Utc.with_ymd_and_hms(2014, 7, 8, 9, 1, 0).single().unwrap(),
                Time::Days(7),
            ),
            json!({
                "gauss": {
                    "test": {
                        "origin": "2014-07-08T09:01:00Z",
                        "scale": "7d",
                    }
                }
            }),
        );

        assert_serialize(
            Decay::new(DecayFunction::Linear, "test", 1, 2),
            json!({
                "linear": {
                    "test": {
                        "origin": 1,
                        "scale": 2,
                    }
                }
            }),
        );

        assert_serialize(
            ScriptScore::new(Script::source("Math.log(2 + doc['my-int'].value)")),
            json!({
                "script_score": {
                    "script": {
                        "source": "Math.log(2 + doc['my-int'].value)"
                    }
                }
            }),
        );
    }

    #[test]
    fn float_decay() {
        assert_serialize(
            Decay::new(DecayFunction::Linear, "test", 0.1, 0.5),
            json!({
                "linear": {
                    "test": {
                        "origin": 0.1,
                        "scale": 0.5
                    }
                }
            }),
        );
    }
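
    // A supplementary sketch of the JSON shapes produced by the remaining functions, derived
    // from the serde attributes above; field names and values are illustrative only.
    #[test]
    fn other_functions_serialization() {
        assert_serialize(Weight::new(2.0), json!({ "weight": 2.0 }));

        assert_serialize(
            FieldValueFactor::new("my-int")
                .factor(1.5)
                .modifier(FieldValueFactorModifier::Sqrt)
                .missing(1.0),
            json!({
                "field_value_factor": {
                    "field": "my-int",
                    "factor": 1.5,
                    "modifier": "sqrt",
                    "missing": 1.0
                }
            }),
        );

        assert_serialize(
            Decay::new(DecayFunction::Linear, "test", 10, 5)
                .offset(1)
                .decay(0.5)
                .weight(2),
            json!({
                "linear": {
                    "test": {
                        "origin": 10,
                        "scale": 5,
                        "offset": 1,
                        "decay": 0.5
                    }
                },
                "weight": 2.0
            }),
        );
    }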
}