1use std::{fmt, fmt::Debug};
2
3use chrono::{DateTime, Utc};
4use serde::{
5 de::{self, Deserialize, DeserializeOwned, Deserializer, MapAccess, Visitor},
6 ser::{Serialize, SerializeMap, Serializer},
7};
8
9use crate::{search::*, util::*};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
14#[serde(rename_all = "snake_case")]
15pub enum FunctionScoreMode {
16 Multiply,
18
19 Sum,
21
22 Avg,
24
25 First,
27
28 Max,
30
31 Min,
33}
34
35impl Default for FunctionScoreMode {
36 fn default() -> Self {
37 Self::Multiply
38 }
39}
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
44#[serde(rename_all = "snake_case")]
45pub enum FunctionBoostMode {
46 Multiply,
48
49 Replace,
51
52 Sum,
54
55 Avg,
57
58 Max,
60
61 Min,
63}
64
65impl Default for FunctionBoostMode {
66 fn default() -> Self {
67 Self::Multiply
68 }
69}
70
// Generates an untagged enum named `$name` with one tuple variant per listed
// `Variant(PayloadType)` pair, plus `From<PayloadType>` conversions into both
// `$name` and `Option<$name>` for every payload type.
macro_rules! function {
    ($name:ident { $($variant:ident($query:ty)),+ $(,)? }) => {
        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        #[allow(missing_docs)]
        #[serde(untagged)]
        pub enum $name {
            $($variant($query),)+
        }

        $(
            impl From<$query> for $name {
                fn from(value: $query) -> Self {
                    Self::$variant(value)
                }
            }

            impl From<$query> for Option<$name> {
                fn from(value: $query) -> Self {
                    Some($name::$variant(value))
                }
            }
        )+
    };
}
100
101function!(Function {
102 Weight(Weight),
103 RandomScore(RandomScore),
104 FieldValueFactor(FieldValueFactor),
105 DecayDateTime(Decay<DateTime<Utc>>),
106 DecayLocation(Decay<GeoLocation>),
107 DecayI8(Decay<i8>),
108 DecayI16(Decay<i16>),
109 DecayI32(Decay<i32>),
110 DecayI64(Decay<i64>),
111 DecayU8(Decay<u8>),
112 DecayU16(Decay<u16>),
113 DecayU32(Decay<u32>),
114 DecayU64(Decay<u64>),
115 Script(Script),
116});
117
118impl Function {
119 pub fn weight(weight: f32) -> Weight {
121 Weight::new(weight)
122 }
123
124 pub fn random_score() -> RandomScore {
126 RandomScore::new()
127 }
128
129 pub fn field_value_factor<T>(field: T) -> FieldValueFactor
133 where
134 T: ToString,
135 {
136 FieldValueFactor::new(field)
137 }
138
139 pub fn decay<T, O>(
155 function: DecayFunction,
156 field: T,
157 origin: O,
158 scale: <O as Origin>::Scale,
159 ) -> Decay<O>
160 where
161 T: ToString,
162 O: Origin,
163 {
164 Decay::new(function, field, origin, scale)
165 }
166
167 pub fn script<T>(source: T) -> FunctionScoreScript
171 where
172 T: ToString,
173 {
174 FunctionScoreScript::new(source)
175 }
176}
177
178#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
183pub struct Weight {
184 weight: f32,
185 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
186 filter: Option<Query>,
187}
188
189impl Weight {
190 pub fn new(weight: f32) -> Self {
192 Self {
193 weight,
194 filter: None,
195 }
196 }
197
198 pub fn filter<T>(mut self, filter: T) -> Self
200 where
201 T: Into<Option<Query>>,
202 {
203 self.filter = filter.into();
204 self
205 }
206}
207
208#[derive(Debug, Default, Clone, PartialEq, Deserialize, Serialize)]
227pub struct RandomScore {
228 random_score: RandomScoreInner,
229
230 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
231 filter: Option<Query>,
232
233 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
234 weight: Option<f32>,
235}
236
237#[derive(Debug, Default, Clone, PartialEq, Deserialize, Serialize)]
238struct RandomScoreInner {
239 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
240 seed: Option<Term>,
241
242 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
243 field: Option<String>,
244}
245
246impl RandomScore {
247 pub fn new() -> Self {
249 Default::default()
250 }
251
252 pub fn filter<T>(mut self, filter: T) -> Self
254 where
255 T: Into<Option<Query>>,
256 {
257 self.filter = filter.into();
258 self
259 }
260
261 pub fn weight<T>(mut self, weight: T) -> Self
266 where
267 T: num_traits::AsPrimitive<f32>,
268 {
269 self.weight = Some(weight.as_());
270 self
271 }
272
273 pub fn seed<T>(mut self, seed: T) -> Self
275 where
276 T: Serialize,
277 {
278 self.random_score.seed = Term::new(seed);
279 self
280 }
281
282 pub fn field<T>(mut self, field: T) -> Self
284 where
285 T: ToString,
286 {
287 self.random_score.field = Some(field.to_string());
288 self
289 }
290}
291
292#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
315pub struct FieldValueFactor {
316 field_value_factor: FieldValueFactorInner,
317
318 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
319 filter: Option<Query>,
320
321 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
322 weight: Option<f32>,
323}
324
325#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
326struct FieldValueFactorInner {
327 field: String,
328
329 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
330 factor: Option<f32>,
331
332 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
333 modifier: Option<FieldValueFactorModifier>,
334
335 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
336 missing: Option<f32>,
337}
338
339impl FieldValueFactor {
340 pub fn new<T>(field: T) -> Self
344 where
345 T: ToString,
346 {
347 Self {
348 field_value_factor: FieldValueFactorInner {
349 field: field.to_string(),
350 factor: None,
351 modifier: None,
352 missing: None,
353 },
354 filter: None,
355 weight: None,
356 }
357 }
358
359 pub fn filter<T>(mut self, filter: T) -> Self
361 where
362 T: Into<Option<Query>>,
363 {
364 self.filter = filter.into();
365 self
366 }
367
368 pub fn weight<T>(mut self, weight: T) -> Self
373 where
374 T: num_traits::AsPrimitive<f32>,
375 {
376 self.weight = Some(weight.as_());
377 self
378 }
379
380 pub fn factor(mut self, factor: f32) -> Self {
382 self.field_value_factor.factor = Some(factor);
383 self
384 }
385
386 pub fn modifier(mut self, modifier: FieldValueFactorModifier) -> Self {
388 self.field_value_factor.modifier = Some(modifier);
389 self
390 }
391
392 pub fn missing(mut self, missing: f32) -> Self {
395 self.field_value_factor.missing = Some(missing);
396 self
397 }
398}
399
400#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
404#[serde(rename_all = "snake_case")]
405pub enum FieldValueFactorModifier {
406 None,
408
409 Log,
416
417 Log1P,
419
420 Log2P,
422
423 Ln,
430
431 Ln1P,
433
434 Ln2P,
436
437 Square,
439
440 Sqrt,
442
443 Reciprocal,
446}
447
448#[doc(hidden)]
449pub trait Origin: Debug + PartialEq + DeserializeOwned + Serialize + Clone {
450 type Scale: Debug + PartialEq + DeserializeOwned + Serialize + Clone;
451 type Offset: Debug + PartialEq + DeserializeOwned + Serialize + Clone;
452}
453
454impl Origin for DateTime<Utc> {
455 type Offset = Time;
456 type Scale = Time;
457}
458
459impl Origin for GeoLocation {
460 type Offset = Distance;
461 type Scale = Distance;
462}
463
464macro_rules! impl_origin_for_numbers {
465 ($($name:ident ),+) => {
466 $(
467 impl Origin for $name {
468 type Scale = Self;
469 type Offset = Self;
470 }
471 )+
472 }
473}
474
475impl_origin_for_numbers![i8, i16, i32, i64, u8, u16, u32, u64, f32, f64];
476
477#[derive(Debug, Clone, PartialEq)]
487pub struct Decay<T: Origin + DeserializeOwned> {
488 function: DecayFunction,
489
490 inner: DecayFieldInner<T>,
491
492 filter: Option<Query>,
493
494 weight: Option<f32>,
495}
496
497impl<T: Origin> Serialize for Decay<T> {
498 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
499 where
500 S: Serializer,
501 {
502 let mut map = serializer.serialize_map(Some(3))?;
503
504 map.serialize_entry(&self.function, &self.inner)?;
505
506 if let Some(filter) = &self.filter {
507 map.serialize_entry("filter", filter)?;
508 }
509
510 if let Some(weight) = &self.weight {
511 map.serialize_entry("weight", weight)?;
512 }
513
514 map.end()
515 }
516}
517
518impl<'de, T: Origin + DeserializeOwned> Deserialize<'de> for Decay<T> {
519 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
520 where
521 D: Deserializer<'de>,
522 {
523 struct DecayVisitor<T: Origin + DeserializeOwned> {
524 _marker: std::marker::PhantomData<T>,
525 }
526
527 impl<'de, T: Origin + DeserializeOwned> Visitor<'de> for DecayVisitor<T> {
528 type Value = Decay<T>;
529
530 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
531 formatter.write_str("struct Decay")
532 }
533
534 fn visit_map<A>(self, mut map: A) -> Result<Decay<T>, A::Error>
535 where
536 A: MapAccess<'de>,
537 {
538 let mut function = None;
539 let inner = None;
540 let mut filter = None;
541 let mut weight = None;
542
543 while let Some(key) = map.next_key()? {
544 match key {
545 "function" => function = Some(map.next_value()?),
546 "filter" => filter = Some(map.next_value()?),
549 "weight" => weight = Some(map.next_value()?),
550 _ => {
551 return Err(de::Error::unknown_field(
552 key,
553 &["function", "inner", "filter", "weight"],
554 ))
555 }
556 }
557 }
558
559 let function = function.ok_or_else(|| de::Error::missing_field("function"))?;
560 let inner = inner.ok_or_else(|| de::Error::missing_field("inner"))?;
561
562 Ok(Decay {
563 function,
564 inner,
565 filter,
566 weight,
567 })
568 }
569 }
570
571 deserializer.deserialize_struct(
572 "Decay",
573 &["function", "inner", "filter", "weight"],
574 DecayVisitor {
575 _marker: std::marker::PhantomData,
576 },
577 )
578 }
579}
580
581#[derive(Debug, Clone, PartialEq)]
582struct DecayFieldInner<T: Origin + DeserializeOwned> {
583 field: String,
584 inner: DecayInner<T>,
585}
586
587#[derive(Debug, Clone, PartialEq, Serialize)]
588struct DecayInner<O>
589where
590 O: Origin + DeserializeOwned,
591{
592 origin: O,
593
594 scale: <O as Origin>::Scale,
595
596 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
597 offset: Option<<O as Origin>::Offset>,
598
599 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
600 decay: Option<f32>,
601}
602
603impl<T: Origin> Serialize for DecayFieldInner<T> {
604 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
605 where
606 S: Serializer,
607 {
608 let mut map = serializer.serialize_map(Some(1))?;
609
610 map.serialize_entry(&self.field, &self.inner)?;
611
612 map.end()
613 }
614}
615
616impl<'de, O: Origin + DeserializeOwned> Deserialize<'de> for DecayInner<O> {
617 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
618 where
619 D: Deserializer<'de>,
620 {
621 struct DecayInnerVisitor<O: Origin + DeserializeOwned> {
622 _marker: std::marker::PhantomData<O>,
623 }
624
625 impl<'de, O: Origin + DeserializeOwned> Visitor<'de> for DecayInnerVisitor<O> {
626 type Value = DecayInner<O>;
627
628 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
629 formatter.write_str("struct DecayInner")
630 }
631
632 fn visit_map<A>(self, mut map: A) -> Result<DecayInner<O>, A::Error>
633 where
634 A: MapAccess<'de>,
635 {
636 let mut origin = None;
637 let mut scale = None;
638 let mut offset = None;
639 let mut decay = None;
640
641 while let Some(key) = map.next_key()? {
642 match key {
643 "origin" => origin = Some(map.next_value()?),
644 "scale" => scale = Some(map.next_value()?),
645 "offset" => offset = map.next_value()?,
646 "decay" => decay = map.next_value()?,
647 _ => {
648 return Err(de::Error::unknown_field(
649 key,
650 &["origin", "scale", "offset", "decay"],
651 ))
652 }
653 }
654 }
655
656 let origin = origin.ok_or_else(|| de::Error::missing_field("origin"))?;
657 let scale = scale.ok_or_else(|| de::Error::missing_field("scale"))?;
658
659 Ok(DecayInner {
660 origin,
661 scale,
662 offset,
663 decay,
664 })
665 }
666 }
667
668 deserializer.deserialize_struct(
669 "DecayInner",
670 &["origin", "scale", "offset", "decay"],
671 DecayInnerVisitor {
672 _marker: std::marker::PhantomData,
673 },
674 )
675 }
676}
677
678impl<O> Decay<O>
679where
680 O: Origin,
681{
682 pub fn new<T>(function: DecayFunction, field: T, origin: O, scale: <O as Origin>::Scale) -> Self
698 where
699 T: ToString,
700 {
701 Self {
702 function,
703 inner: DecayFieldInner {
704 field: field.to_string(),
705 inner: DecayInner {
706 origin,
707 scale,
708 offset: None,
709 decay: None,
710 },
711 },
712 filter: None,
713 weight: None,
714 }
715 }
716
717 pub fn filter<T>(mut self, filter: T) -> Self
719 where
720 T: Into<Option<Query>>,
721 {
722 self.filter = filter.into();
723 self
724 }
725
726 pub fn weight<T>(mut self, weight: T) -> Self
731 where
732 T: num_traits::AsPrimitive<f32>,
733 {
734 self.weight = Some(weight.as_());
735 self
736 }
737
738 pub fn offset(mut self, offset: <O as Origin>::Offset) -> Self {
744 self.inner.inner.offset = Some(offset);
745 self
746 }
747
748 pub fn decay(mut self, decay: f32) -> Self {
752 self.inner.inner.decay = Some(decay);
753 self
754 }
755}
756
757#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
761#[serde(rename_all = "snake_case")]
762pub enum DecayFunction {
763 Linear,
765
766 Exp,
768
769 Gauss,
771}
772
773#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
777pub struct FunctionScoreScript {
778 script_score: ScriptInnerWrapper,
779}
780
781#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
782struct ScriptInnerWrapper {
783 script: ScriptInner,
784}
785
786#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
787struct ScriptInner {
788 source: String,
789
790 #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
791 params: Option<serde_json::Value>,
792}
793
794impl FunctionScoreScript {
795 pub fn new<T>(source: T) -> Self
799 where
800 T: ToString,
801 {
802 Self {
803 script_score: ScriptInnerWrapper {
804 script: ScriptInner {
805 source: source.to_string(),
806 params: None,
807 },
808 },
809 }
810 }
811
812 pub fn params(mut self, params: serde_json::Value) -> Self {
814 self.script_score.script.params = Some(params);
815 self
816 }
817}
818
#[cfg(test)]
mod tests {
    use chrono::prelude::*;

    use super::*;

    /// Checks the serialized shape of decay functions for a date-time origin
    /// and for an integer origin.
    #[test]
    fn serialization() {
        // Date-time origin with a `Time` scale.
        let origin = Utc.with_ymd_and_hms(2014, 7, 8, 9, 1, 0).single().unwrap();
        let gauss = Decay::new(DecayFunction::Gauss, "test", origin, Time::Days(7));

        assert_serialize(
            gauss,
            json!({
                "gauss": {
                    "test": {
                        "origin": "2014-07-08T09:01:00Z",
                        "scale": "7d",
                    }
                }
            }),
        );

        // Integer origin with an integer scale.
        let linear = Decay::new(DecayFunction::Linear, "test", 1, 2);

        assert_serialize(
            linear,
            json!({
                "linear": {
                    "test": {
                        "origin": 1,
                        "scale": 2,
                    }
                }
            }),
        );
    }
}