1use std::collections::HashMap;
20
21use serde::{Serialize, Serializer};
22
23use crate::{
24 json::{FieldBased, NoOuter, ShouldSkip},
25 units::{Distance, Duration, JsonVal, Location},
26};
27
28#[derive(Debug, Serialize)]
30pub enum Function {
31 #[serde(rename = "script_score")]
32 ScriptScore(ScriptScore),
33 #[serde(rename = "weight")]
34 Weight(Weight),
35 #[serde(rename = "random_score")]
36 RandomScore(RandomScore),
37 #[serde(rename = "field_value_factor")]
38 FieldValueFactor(FieldValueFactor),
39 #[serde(rename = "linear")]
40 Linear(Decay),
41 #[serde(rename = "exp")]
42 Exp(Decay),
43 #[serde(rename = "gauss")]
44 Gauss(Decay),
45}
46
47#[derive(Debug, Default, Serialize)]
49pub struct ScriptScore {
50 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
51 lang: Option<String>,
52 params: HashMap<String, JsonVal>,
53 inline: String,
54}
55
56impl Function {
57 pub fn build_script_score<A>(script: A) -> ScriptScore
58 where
59 A: Into<String>,
60 {
61 ScriptScore {
62 inline: script.into(),
63 ..Default::default()
64 }
65 }
66}
67
68impl ScriptScore {
69 add_field!(with_lang, lang, String);
70
71 pub fn with_params<A>(mut self, params: A) -> Self
72 where
73 A: IntoIterator<Item = (String, JsonVal)>,
74 {
75 self.params.extend(params);
76 self
77 }
78
79 pub fn add_param<A, B>(mut self, key: A, value: B) -> Self
80 where
81 A: Into<String>,
82 B: Into<JsonVal>,
83 {
84 self.params.insert(key.into(), value.into());
85 self
86 }
87
88 pub fn build(self) -> Function {
89 Function::ScriptScore(self)
90 }
91}
92
93#[derive(Debug, Default, Serialize)]
95pub struct Weight(f64);
96
97impl Function {
98 pub fn build_weight<A>(weight: A) -> Weight
99 where
100 A: Into<f64>,
101 {
102 Weight(weight.into())
103 }
104}
105
106impl Weight {
107 pub fn build(self) -> Function {
108 Function::Weight(self)
109 }
110}
111
112#[derive(Debug, Default, Serialize)]
114pub struct RandomScore(i64);
115
116impl Function {
117 pub fn build_random_score<A>(seed: A) -> RandomScore
118 where
119 A: Into<i64>,
120 {
121 RandomScore(seed.into())
122 }
123}
124
125impl RandomScore {
126 pub fn build(self) -> Function {
127 Function::RandomScore(self)
128 }
129}
130
131#[derive(Debug, Default, Serialize)]
133pub struct FieldValueFactor {
134 field: String,
135 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
136 factor: Option<f64>,
137 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
138 modifier: Option<Modifier>,
139 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
140 missing: Option<JsonVal>,
141}
142
143impl Function {
144 pub fn build_field_value_factor<A>(field: A) -> FieldValueFactor
145 where
146 A: Into<String>,
147 {
148 FieldValueFactor {
149 field: field.into(),
150 ..Default::default()
151 }
152 }
153}
154
155impl FieldValueFactor {
156 add_field!(with_factor, factor, f64);
157 add_field!(with_modifier, modifier, Modifier);
158 add_field!(with_missing, missing, JsonVal);
159
160 pub fn build(self) -> Function {
161 Function::FieldValueFactor(self)
162 }
163}
164
/// The `modifier` setting of a `field_value_factor` function.
///
/// Serialized as the lowercase variant name (e.g. `Log1p` becomes `"log1p"`).
/// The enum is field-less, so it cheaply derives `Copy`, equality and
/// hashing, which lets callers compare, store and reuse values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modifier {
    None,
    Log,
    Log1p,
    Log2p,
    Ln,
    Ln1p,
    Ln2p,
    Square,
    Sqrt,
    Reciprocal,
}
179
180impl Serialize for Modifier {
181 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
182 where
183 S: Serializer,
184 {
185 match self {
186 Modifier::None => "none".serialize(serializer),
187 Modifier::Log => "log".serialize(serializer),
188 Modifier::Log1p => "log1p".serialize(serializer),
189 Modifier::Log2p => "log2p".serialize(serializer),
190 Modifier::Ln => "ln".serialize(serializer),
191 Modifier::Ln1p => "ln1p".serialize(serializer),
192 Modifier::Ln2p => "ln2p".serialize(serializer),
193 Modifier::Square => "square".serialize(serializer),
194 Modifier::Sqrt => "sqrt".serialize(serializer),
195 Modifier::Reciprocal => "reciprocal".serialize(serializer),
196 }
197 }
198}
199
200#[derive(Debug, Default, Serialize)]
201pub struct DecayOptions {
202 origin: Origin,
203 scale: Scale,
204 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
205 offset: Option<Scale>,
206 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
207 decay: Option<f64>,
208 #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
209 multi_value_mode: Option<MultiValueMode>,
210}
211
212impl DecayOptions {
213 pub fn new<A, B>(origin: A, scale: B) -> DecayOptions
214 where
215 A: Into<Origin>,
216 B: Into<Scale>,
217 {
218 DecayOptions {
219 origin: origin.into(),
220 scale: scale.into(),
221 offset: None,
222 decay: None,
223 multi_value_mode: None,
224 }
225 }
226
227 add_field!(with_offset, offset, Scale);
228 add_field!(with_decay, decay, f64);
229 add_field!(with_multi_value_mode, multi_value_mode, MultiValueMode);
230
231 pub fn with_scale(mut self, val: Scale) -> Self {
232 self.scale = val;
233 self
234 }
235
236 pub fn with_origin(mut self, val: Origin) -> Self {
237 self.origin = val;
238 self
239 }
240
241 pub fn build<A: Into<String>>(self, field: A) -> Decay {
242 Decay(FieldBased::new(
243 field.into(),
244 self,
245 NoOuter,
246 ))
247 }
248}
249
250#[derive(Debug, Serialize)]
252pub struct Decay(FieldBased<String, DecayOptions, NoOuter>);
253
254impl Function {
255 pub fn build_decay<A, B, C>(field: A, origin: B, scale: C) -> Decay
256 where
257 A: Into<String>,
258 B: Into<Origin>,
259 C: Into<Scale>,
260 {
261 Decay(FieldBased::new(
262 field.into(),
263 DecayOptions {
264 origin: origin.into(),
265 scale: scale.into(),
266 ..Default::default()
267 },
268 NoOuter,
269 ))
270 }
271
272 pub fn build_decay_from_options<A: Into<String>>(field: A, options: DecayOptions) -> Decay {
273 options.build(field)
274 }
275}
276
277impl Decay {
278 pub fn build_linear(self) -> Function {
279 Function::Linear(self)
280 }
281
282 pub fn build_exp(self) -> Function {
283 Function::Exp(self)
284 }
285
286 pub fn build_gauss(self) -> Function {
287 Function::Gauss(self)
288 }
289}
290
291#[derive(Debug)]
295pub enum Origin {
296 I64(i64),
297 U64(u64),
298 F64(f64),
299 Location(Location),
300 Date(String),
301}
302
303impl Default for Origin {
304 fn default() -> Origin {
305 Origin::I64(0)
306 }
307}
308
309from!(i64, Origin, I64);
310from!(u64, Origin, U64);
311from!(f64, Origin, F64);
312from!(Location, Origin, Location);
313from!(String, Origin, Date);
314
315impl Serialize for Origin {
316 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
317 where
318 S: Serializer,
319 {
320 match self {
321 Origin::I64(orig) => orig.serialize(serializer),
322 Origin::U64(orig) => orig.serialize(serializer),
323 Origin::F64(orig) => orig.serialize(serializer),
324 Origin::Location(ref orig) => orig.serialize(serializer),
325 Origin::Date(ref orig) => orig.serialize(serializer),
326 }
327 }
328}
329
330#[derive(Debug)]
332pub enum Scale {
333 I64(i64),
334 U64(u64),
335 F64(f64),
336 Distance(Distance),
337 Duration(Duration),
338}
339
340impl Default for Scale {
341 fn default() -> Self {
342 Scale::I64(0)
343 }
344}
345
346from!(i64, Scale, I64);
347from!(u64, Scale, U64);
348from!(f64, Scale, F64);
349from!(Distance, Scale, Distance);
350from!(Duration, Scale, Duration);
351
352impl Serialize for Scale {
353 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
354 where
355 S: Serializer,
356 {
357 match self {
358 Scale::I64(s) => s.serialize(serializer),
359 Scale::U64(s) => s.serialize(serializer),
360 Scale::F64(s) => s.serialize(serializer),
361 Scale::Distance(ref s) => s.serialize(serializer),
362 Scale::Duration(ref s) => s.serialize(serializer),
363 }
364 }
365}
366
/// The `multi_value_mode` setting of a decay function.
///
/// Serialized as the lowercase variant name (`"min"`, `"max"`, `"avg"`,
/// `"sum"`). The enum is field-less, so it derives `Copy`, equality and
/// hashing, which lets callers compare, store and reuse values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiValueMode {
    Min,
    Max,
    Avg,
    Sum,
}
375
376impl Serialize for MultiValueMode {
377 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
378 where
379 S: Serializer,
380 {
381 use self::MultiValueMode::*;
382 match self {
383 Min => "min",
384 Max => "max",
385 Avg => "avg",
386 Sum => "sum",
387 }
388 .serialize(serializer)
389 }
390}
391
#[cfg(test)]
pub mod tests {
    use serde_json;

    /// A Gaussian decay on a geo field serializes with the field as the key
    /// and skips every unset optional setting.
    #[test]
    fn test_decay_query() {
        use crate::units::*;

        let origin = Location::LatLon(42., 24.);
        let scale = Distance::new(3., DistanceUnit::Kilometer);
        let gauss = super::Function::build_decay("my_field", origin, scale).build_gauss();

        let json = serde_json::to_string(&gauss).unwrap();
        assert_eq!(
            r#"{"gauss":{"my_field":{"origin":{"lat":42.0,"lon":24.0},"scale":"3km"}}}"#,
            json
        );
    }
}