Skip to main content

rs_es/query/
functions.rs

1/*
2 * Copyright 2016-2018 Ben Ashford
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! Specific options for the Function option of various queries
18
19use std::collections::HashMap;
20
21use serde::{Serialize, Serializer};
22
23use crate::{
24    json::{FieldBased, NoOuter, ShouldSkip},
25    units::{Distance, Duration, JsonVal, Location},
26};
27
28/// Function
29#[derive(Debug, Serialize)]
30pub enum Function {
31    #[serde(rename = "script_score")]
32    ScriptScore(ScriptScore),
33    #[serde(rename = "weight")]
34    Weight(Weight),
35    #[serde(rename = "random_score")]
36    RandomScore(RandomScore),
37    #[serde(rename = "field_value_factor")]
38    FieldValueFactor(FieldValueFactor),
39    #[serde(rename = "linear")]
40    Linear(Decay),
41    #[serde(rename = "exp")]
42    Exp(Decay),
43    #[serde(rename = "gauss")]
44    Gauss(Decay),
45}
46
47/// ScriptScore function
48#[derive(Debug, Default, Serialize)]
49pub struct ScriptScore {
50    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
51    lang: Option<String>,
52    params: HashMap<String, JsonVal>,
53    inline: String,
54}
55
56impl Function {
57    pub fn build_script_score<A>(script: A) -> ScriptScore
58    where
59        A: Into<String>,
60    {
61        ScriptScore {
62            inline: script.into(),
63            ..Default::default()
64        }
65    }
66}
67
68impl ScriptScore {
69    add_field!(with_lang, lang, String);
70
71    pub fn with_params<A>(mut self, params: A) -> Self
72    where
73        A: IntoIterator<Item = (String, JsonVal)>,
74    {
75        self.params.extend(params);
76        self
77    }
78
79    pub fn add_param<A, B>(mut self, key: A, value: B) -> Self
80    where
81        A: Into<String>,
82        B: Into<JsonVal>,
83    {
84        self.params.insert(key.into(), value.into());
85        self
86    }
87
88    pub fn build(self) -> Function {
89        Function::ScriptScore(self)
90    }
91}
92
93/// Weight function
94#[derive(Debug, Default, Serialize)]
95pub struct Weight(f64);
96
97impl Function {
98    pub fn build_weight<A>(weight: A) -> Weight
99    where
100        A: Into<f64>,
101    {
102        Weight(weight.into())
103    }
104}
105
106impl Weight {
107    pub fn build(self) -> Function {
108        Function::Weight(self)
109    }
110}
111
112/// Random score function
113#[derive(Debug, Default, Serialize)]
114pub struct RandomScore(i64);
115
116impl Function {
117    pub fn build_random_score<A>(seed: A) -> RandomScore
118    where
119        A: Into<i64>,
120    {
121        RandomScore(seed.into())
122    }
123}
124
125impl RandomScore {
126    pub fn build(self) -> Function {
127        Function::RandomScore(self)
128    }
129}
130
131/// Field value factor function
132#[derive(Debug, Default, Serialize)]
133pub struct FieldValueFactor {
134    field: String,
135    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
136    factor: Option<f64>,
137    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
138    modifier: Option<Modifier>,
139    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
140    missing: Option<JsonVal>,
141}
142
143impl Function {
144    pub fn build_field_value_factor<A>(field: A) -> FieldValueFactor
145    where
146        A: Into<String>,
147    {
148        FieldValueFactor {
149            field: field.into(),
150            ..Default::default()
151        }
152    }
153}
154
155impl FieldValueFactor {
156    add_field!(with_factor, factor, f64);
157    add_field!(with_modifier, modifier, Modifier);
158    add_field!(with_missing, missing, JsonVal);
159
160    pub fn build(self) -> Function {
161        Function::FieldValueFactor(self)
162    }
163}
164
165/// Modifier for the FieldValueFactor function
166#[derive(Debug)]
167pub enum Modifier {
168    None,
169    Log,
170    Log1p,
171    Log2p,
172    Ln,
173    Ln1p,
174    Ln2p,
175    Square,
176    Sqrt,
177    Reciprocal,
178}
179
180impl Serialize for Modifier {
181    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
182    where
183        S: Serializer,
184    {
185        match self {
186            Modifier::None => "none".serialize(serializer),
187            Modifier::Log => "log".serialize(serializer),
188            Modifier::Log1p => "log1p".serialize(serializer),
189            Modifier::Log2p => "log2p".serialize(serializer),
190            Modifier::Ln => "ln".serialize(serializer),
191            Modifier::Ln1p => "ln1p".serialize(serializer),
192            Modifier::Ln2p => "ln2p".serialize(serializer),
193            Modifier::Square => "square".serialize(serializer),
194            Modifier::Sqrt => "sqrt".serialize(serializer),
195            Modifier::Reciprocal => "reciprocal".serialize(serializer),
196        }
197    }
198}
199
200#[derive(Debug, Default, Serialize)]
201pub struct DecayOptions {
202    origin: Origin,
203    scale: Scale,
204    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
205    offset: Option<Scale>,
206    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
207    decay: Option<f64>,
208    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
209    multi_value_mode: Option<MultiValueMode>,
210}
211
212impl DecayOptions {
213    pub fn new<A, B>(origin: A, scale: B) -> DecayOptions
214    where
215        A: Into<Origin>,
216        B: Into<Scale>,
217    {
218        DecayOptions {
219            origin: origin.into(),
220            scale: scale.into(),
221            offset: None,
222            decay: None,
223            multi_value_mode: None,
224        }
225    }
226
227    add_field!(with_offset, offset, Scale);
228    add_field!(with_decay, decay, f64);
229    add_field!(with_multi_value_mode, multi_value_mode, MultiValueMode);
230
231    pub fn with_scale(mut self, val: Scale) -> Self {
232        self.scale = val;
233        self
234    }
235
236    pub fn with_origin(mut self, val: Origin) -> Self {
237        self.origin = val;
238        self
239    }
240
241    pub fn build<A: Into<String>>(self, field: A) -> Decay {
242        Decay(FieldBased::new(
243            field.into(),
244            self,
245            NoOuter,
246        ))
247    }
248}
249
250/// Decay functions
251#[derive(Debug, Serialize)]
252pub struct Decay(FieldBased<String, DecayOptions, NoOuter>);
253
254impl Function {
255    pub fn build_decay<A, B, C>(field: A, origin: B, scale: C) -> Decay
256    where
257        A: Into<String>,
258        B: Into<Origin>,
259        C: Into<Scale>,
260    {
261        Decay(FieldBased::new(
262            field.into(),
263            DecayOptions {
264                origin: origin.into(),
265                scale: scale.into(),
266                ..Default::default()
267            },
268            NoOuter,
269        ))
270    }
271
272    pub fn build_decay_from_options<A: Into<String>>(field: A, options: DecayOptions) -> Decay {
273        options.build(field)
274    }
275}
276
277impl Decay {
278    pub fn build_linear(self) -> Function {
279        Function::Linear(self)
280    }
281
282    pub fn build_exp(self) -> Function {
283        Function::Exp(self)
284    }
285
286    pub fn build_gauss(self) -> Function {
287        Function::Gauss(self)
288    }
289}
290
291// options used by decay functions
292
293/// Origin for decay function
294#[derive(Debug)]
295pub enum Origin {
296    I64(i64),
297    U64(u64),
298    F64(f64),
299    Location(Location),
300    Date(String),
301}
302
303impl Default for Origin {
304    fn default() -> Origin {
305        Origin::I64(0)
306    }
307}
308
309from!(i64, Origin, I64);
310from!(u64, Origin, U64);
311from!(f64, Origin, F64);
312from!(Location, Origin, Location);
313from!(String, Origin, Date);
314
315impl Serialize for Origin {
316    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
317    where
318        S: Serializer,
319    {
320        match self {
321            Origin::I64(orig) => orig.serialize(serializer),
322            Origin::U64(orig) => orig.serialize(serializer),
323            Origin::F64(orig) => orig.serialize(serializer),
324            Origin::Location(ref orig) => orig.serialize(serializer),
325            Origin::Date(ref orig) => orig.serialize(serializer),
326        }
327    }
328}
329
330/// Scale used by decay function
331#[derive(Debug)]
332pub enum Scale {
333    I64(i64),
334    U64(u64),
335    F64(f64),
336    Distance(Distance),
337    Duration(Duration),
338}
339
340impl Default for Scale {
341    fn default() -> Self {
342        Scale::I64(0)
343    }
344}
345
346from!(i64, Scale, I64);
347from!(u64, Scale, U64);
348from!(f64, Scale, F64);
349from!(Distance, Scale, Distance);
350from!(Duration, Scale, Duration);
351
352impl Serialize for Scale {
353    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
354    where
355        S: Serializer,
356    {
357        match self {
358            Scale::I64(s) => s.serialize(serializer),
359            Scale::U64(s) => s.serialize(serializer),
360            Scale::F64(s) => s.serialize(serializer),
361            Scale::Distance(ref s) => s.serialize(serializer),
362            Scale::Duration(ref s) => s.serialize(serializer),
363        }
364    }
365}
366
367/// Values for multi_value_mode
368#[derive(Debug)]
369pub enum MultiValueMode {
370    Min,
371    Max,
372    Avg,
373    Sum,
374}
375
376impl Serialize for MultiValueMode {
377    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
378    where
379        S: Serializer,
380    {
381        use self::MultiValueMode::*;
382        match self {
383            Min => "min",
384            Max => "max",
385            Avg => "avg",
386            Sum => "sum",
387        }
388        .serialize(serializer)
389    }
390}
391
#[cfg(test)]
pub mod tests {
    // `serde_json` is reachable through the extern prelude, so no `use serde_json;` is needed.
    use super::{DecayOptions, Function, Modifier, MultiValueMode};

    #[test]
    fn test_decay_query() {
        use crate::units::*;
        let gauss_decay_query = Function::build_decay(
            "my_field",
            Location::LatLon(42., 24.),
            Distance::new(3., DistanceUnit::Kilometer),
        )
        .build_gauss();

        assert_eq!(
            r#"{"gauss":{"my_field":{"origin":{"lat":42.0,"lon":24.0},"scale":"3km"}}}"#,
            serde_json::to_string(&gauss_decay_query).unwrap()
        );
    }

    #[test]
    fn test_weight() {
        let weight = Function::build_weight(2.0f64).build();

        assert_eq!(
            r#"{"weight":2.0}"#,
            serde_json::to_string(&weight).unwrap()
        );
    }

    #[test]
    fn test_field_value_factor() {
        let fvf = Function::build_field_value_factor("popularity")
            .with_factor(1.2f64)
            .with_modifier(Modifier::Log1p)
            .build();

        // `missing` was never set, so it must be left out of the output.
        assert_eq!(
            r#"{"field_value_factor":{"field":"popularity","factor":1.2,"modifier":"log1p"}}"#,
            serde_json::to_string(&fvf).unwrap()
        );
    }

    #[test]
    fn test_decay_options() {
        let linear = DecayOptions::new(5i64, 2i64)
            .with_decay(0.5f64)
            .with_multi_value_mode(MultiValueMode::Avg)
            .build("my_field")
            .build_linear();

        // `offset` was never set, so it must be left out of the output.
        assert_eq!(
            r#"{"linear":{"my_field":{"origin":5,"scale":2,"decay":0.5,"multi_value_mode":"avg"}}}"#,
            serde_json::to_string(&linear).unwrap()
        );
    }
}