use std::{fmt, fmt::Debug};
use chrono::{DateTime, Utc};
use serde::{
de::{self, Deserialize, DeserializeOwned, Deserializer, MapAccess, Visitor},
ser::{Serialize, SerializeMap, Serializer},
};
use crate::{search::*, util::*};
/// Score mode selector.
///
/// Variants are (de)serialized using their `snake_case` names, e.g.
/// `FunctionScoreMode::Multiply` <-> `"multiply"`.
///
/// NOTE(review): presumably models the `score_mode` parameter of the Elasticsearch
/// `function_score` query (how the scores of several functions are combined) — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FunctionScoreMode {
/// Serialized as `"multiply"`. This is the [`Default`] value.
Multiply,
/// Serialized as `"sum"`.
Sum,
/// Serialized as `"avg"`.
Avg,
/// Serialized as `"first"`.
First,
/// Serialized as `"max"`.
Max,
/// Serialized as `"min"`.
Min,
}
impl Default for FunctionScoreMode {
/// Returns [`FunctionScoreMode::Multiply`].
fn default() -> Self {
Self::Multiply
}
}
/// Boost mode selector.
///
/// Variants are (de)serialized using their `snake_case` names, e.g.
/// `FunctionBoostMode::Replace` <-> `"replace"`.
///
/// NOTE(review): presumably models the `boost_mode` parameter of the Elasticsearch
/// `function_score` query (how the function score is combined with the query score) — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FunctionBoostMode {
/// Serialized as `"multiply"`. This is the [`Default`] value.
Multiply,
/// Serialized as `"replace"`.
Replace,
/// Serialized as `"sum"`.
Sum,
/// Serialized as `"avg"`.
Avg,
/// Serialized as `"max"`.
Max,
/// Serialized as `"min"`.
Min,
}
impl Default for FunctionBoostMode {
/// Returns [`FunctionBoostMode::Multiply`].
fn default() -> Self {
Self::Multiply
}
}
// Generates a public, serde-`untagged` enum named `$name` with one tuple variant per
// `$variant($query)` pair, plus two conversions for every payload type:
//
// * `From<$query> for $name`          — wraps the payload in its variant;
// * `From<$query> for Option<$name>`  — same, additionally wrapped in `Some`.
//
// Because the enum is `untagged`, serialization emits the payload as-is and deserialization
// tries the variants in declaration order, keeping the first one that succeeds.
// Each payload type may therefore appear only once (otherwise the `From` impls would overlap).
macro_rules! function {
($name:ident { $($variant:ident($query:ty)),+ $(,)? }) => {
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[allow(missing_docs)]
#[serde(untagged)]
pub enum $name {
$(
$variant($query),
)*
}
// Payload -> enum conversion, one impl per variant.
$(
impl From<$query> for $name {
fn from(q: $query) -> Self {
$name::$variant(q)
}
}
)+
// Payload -> `Some(enum)` conversion, one impl per variant.
$(
impl From<$query> for Option<$name> {
fn from(q: $query) -> Self {
Some($name::$variant(q))
}
}
)+
};
}
// `Function` is the untagged union of every scoring function defined in this file.
//
// Variant order matters for deserialization: serde tries untagged variants top to bottom and
// keeps the first that parses.
//
// * `Weight` is listed LAST. Its derived `Deserialize` only requires a `weight` key and ignores
//   unknown keys, so when it was listed first every function object that also carried a
//   `weight` (e.g. `{"random_score": {}, "weight": 2}`) was silently read as a bare `Weight`.
// * `DecayF64` precedes `DecayF32` so that deserialized floating-point origins are read without
//   first being narrowed to `f32`.
// * `DecayF32` / `DecayF64` exist because `Origin` is implemented for `f32` and `f64`; without
//   these variants `Function::decay(.., 1.5, ..)` produced a value that could not be converted
//   into a `Function`.
//
// NOTE(review): `Function::script` returns `FunctionScoreScript`, which is not the payload of
// the `Script` variant below (that is whatever `Script` resolves to through the glob imports).
// Confirm which type is intended before relying on `Function::script(..).into()`.
function!(Function {
    RandomScore(RandomScore),
    FieldValueFactor(FieldValueFactor),
    DecayDateTime(Decay<DateTime<Utc>>),
    DecayLocation(Decay<GeoLocation>),
    DecayI8(Decay<i8>),
    DecayI16(Decay<i16>),
    DecayI32(Decay<i32>),
    DecayI64(Decay<i64>),
    DecayU8(Decay<u8>),
    DecayU16(Decay<u16>),
    DecayU32(Decay<u32>),
    DecayU64(Decay<u64>),
    DecayF64(Decay<f64>),
    DecayF32(Decay<f32>),
    Script(Script),
    Weight(Weight),
});
/// Convenience constructors for the individual function types.
///
/// Each method returns the concrete builder type (not a `Function`), so the builder's own
/// chaining methods remain available to the caller.
impl Function {
    /// Builds a [`Weight`] holding `weight` and no filter.
    pub fn weight(weight: f32) -> Weight {
        Weight::new(weight)
    }

    /// Builds a [`RandomScore`] with every option left unset.
    pub fn random_score() -> RandomScore {
        RandomScore::default()
    }

    /// Builds a [`FieldValueFactor`] for `field` (converted with `to_string`).
    pub fn field_value_factor<T: ToString>(field: T) -> FieldValueFactor {
        FieldValueFactor::new(field)
    }

    /// Builds a [`Decay`] of kind `function` over `field`, centred on `origin` with the given
    /// `scale`. The scale type is dictated by the origin type through [`Origin::Scale`].
    pub fn decay<T: ToString, O: Origin>(
        function: DecayFunction,
        field: T,
        origin: O,
        scale: O::Scale,
    ) -> Decay<O> {
        Decay::new(function, field, origin, scale)
    }

    /// Builds a [`FunctionScoreScript`] from the script `source` (converted with `to_string`).
    pub fn script<T: ToString>(source: T) -> FunctionScoreScript {
        FunctionScoreScript::new(source)
    }
}
/// A function made of a `weight` and an optional `filter` query.
///
/// `filter` is left out of the serialized form when `ShouldSkip::should_skip` reports it as
/// skippable, and defaults to `None` when absent on deserialization.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct Weight {
    weight: f32,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,
}

impl Weight {
    /// Creates a weight function with the given `weight` and no filter.
    pub fn new(weight: f32) -> Self {
        Weight {
            weight,
            filter: None,
        }
    }

    /// Replaces the filter with `filter`; passing `None` clears it.
    pub fn filter<T: Into<Option<Query>>>(self, filter: T) -> Self {
        Self {
            filter: filter.into(),
            ..self
        }
    }
}
/// A function whose options live under the `random_score` key, with an optional `filter`
/// and `weight` next to it.
///
/// Optional members are omitted from the serialized form when `ShouldSkip::should_skip`
/// reports them as skippable.
#[derive(Debug, Default, Clone, PartialEq, Deserialize, Serialize)]
pub struct RandomScore {
    random_score: RandomScoreInner,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    weight: Option<f32>,
}

/// Body of the `random_score` object: an optional `seed` and an optional `field`.
#[derive(Debug, Default, Clone, PartialEq, Deserialize, Serialize)]
struct RandomScoreInner {
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    seed: Option<Term>,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    field: Option<String>,
}

impl RandomScore {
    /// Creates an instance with every option unset (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the filter with `filter`; passing `None` clears it.
    pub fn filter<T: Into<Option<Query>>>(self, filter: T) -> Self {
        Self {
            filter: filter.into(),
            ..self
        }
    }

    /// Sets the weight, converting `weight` to `f32` with `AsPrimitive::as_`.
    pub fn weight<T: num_traits::AsPrimitive<f32>>(self, weight: T) -> Self {
        Self {
            weight: Some(weight.as_()),
            ..self
        }
    }

    /// Sets the seed to whatever `Term::new(seed)` produces.
    pub fn seed<T: Serialize>(mut self, seed: T) -> Self {
        let options = &mut self.random_score;
        options.seed = Term::new(seed);
        self
    }

    /// Sets the field name (converted with `to_string`).
    pub fn field<T: ToString>(mut self, field: T) -> Self {
        let name = field.to_string();
        self.random_score.field = Some(name);
        self
    }
}
/// A function whose options live under the `field_value_factor` key, with an optional
/// `filter` and `weight` next to it.
///
/// Optional members are omitted from the serialized form when `ShouldSkip::should_skip`
/// reports them as skippable.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct FieldValueFactor {
    field_value_factor: FieldValueFactorInner,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    filter: Option<Query>,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    weight: Option<f32>,
}

/// Body of the `field_value_factor` object: the mandatory `field` plus the optional
/// `factor`, `modifier` and `missing` settings.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
struct FieldValueFactorInner {
    field: String,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    factor: Option<f32>,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    modifier: Option<FieldValueFactorModifier>,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    missing: Option<f32>,
}

impl FieldValueFactor {
    /// Creates an instance for `field` (converted with `to_string`); every optional setting
    /// starts out unset.
    pub fn new<T: ToString>(field: T) -> Self {
        let options = FieldValueFactorInner {
            field: field.to_string(),
            factor: None,
            modifier: None,
            missing: None,
        };
        FieldValueFactor {
            field_value_factor: options,
            filter: None,
            weight: None,
        }
    }

    /// Replaces the filter with `filter`; passing `None` clears it.
    pub fn filter<T: Into<Option<Query>>>(self, filter: T) -> Self {
        Self {
            filter: filter.into(),
            ..self
        }
    }

    /// Sets the weight, converting `weight` to `f32` with `AsPrimitive::as_`.
    pub fn weight<T: num_traits::AsPrimitive<f32>>(self, weight: T) -> Self {
        Self {
            weight: Some(weight.as_()),
            ..self
        }
    }

    /// Sets the `factor` option.
    pub fn factor(mut self, factor: f32) -> Self {
        let options = &mut self.field_value_factor;
        options.factor = Some(factor);
        self
    }

    /// Sets the `modifier` option.
    pub fn modifier(mut self, modifier: FieldValueFactorModifier) -> Self {
        let options = &mut self.field_value_factor;
        options.modifier = Some(modifier);
        self
    }

    /// Sets the `missing` option.
    pub fn missing(mut self, missing: f32) -> Self {
        let options = &mut self.field_value_factor;
        options.missing = Some(missing);
        self
    }
}
/// Modifier applied by a [`FieldValueFactor`] function.
///
/// Variants are (de)serialized in lower case. The four `…1P` / `…2P` variants carry an explicit
/// `rename`, because serde's `snake_case` rule inserts an underscore before every upper-case
/// letter and would otherwise emit `log1_p`, `log2_p`, `ln1_p` and `ln2_p`, whereas the
/// Elasticsearch `field_value_factor` documentation spells them `log1p`, `log2p`, `ln1p` and
/// `ln2p`. The previously emitted spellings are still accepted on input through `alias`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FieldValueFactorModifier {
    /// Serialized as `"none"`.
    None,
    /// Serialized as `"log"`.
    Log,
    /// Serialized as `"log1p"`.
    #[serde(rename = "log1p", alias = "log1_p")]
    Log1P,
    /// Serialized as `"log2p"`.
    #[serde(rename = "log2p", alias = "log2_p")]
    Log2P,
    /// Serialized as `"ln"`.
    Ln,
    /// Serialized as `"ln1p"`.
    #[serde(rename = "ln1p", alias = "ln1_p")]
    Ln1P,
    /// Serialized as `"ln2p"`.
    #[serde(rename = "ln2p", alias = "ln2_p")]
    Ln2P,
    /// Serialized as `"square"`.
    Square,
    /// Serialized as `"sqrt"`.
    Sqrt,
    /// Serialized as `"reciprocal"`.
    Reciprocal,
}
/// Ties a decay origin type to the types used for its `scale` and `offset` parameters.
///
/// Implementors (and both associated types) must be debuggable, comparable, cloneable and
/// (de)serializable, which is what lets [`Decay`] derive or implement those traits generically.
#[doc(hidden)]
pub trait Origin: Debug + PartialEq + DeserializeOwned + Serialize + Clone {
/// Type of the `scale` parameter for this origin type.
type Scale: Debug + PartialEq + DeserializeOwned + Serialize + Clone;
/// Type of the optional `offset` parameter for this origin type.
type Offset: Debug + PartialEq + DeserializeOwned + Serialize + Clone;
}
// Date-time origins use `Time` for both scale and offset.
impl Origin for DateTime<Utc> {
type Offset = Time;
type Scale = Time;
}
// Geo-location origins use `Distance` for both scale and offset.
impl Origin for GeoLocation {
type Offset = Distance;
type Scale = Distance;
}
// Implements `Origin` for each listed type, using the type itself as both `Scale` and `Offset`.
macro_rules! impl_origin_for_numbers {
($($name:ident ),+) => {
$(
impl Origin for $name {
type Scale = Self;
type Offset = Self;
}
)+
}
}
// Numeric origins: every listed integer and floating-point primitive.
impl_origin_for_numbers![i8, i16, i32, i64, u8, u16, u32, u64, f32, f64];
/// A decay function over a single field, generic over the origin type `T`.
///
/// `Serialize` is implemented by hand and produces a map whose first key is the serialized
/// [`DecayFunction`] and whose value is `{ <field>: { origin, scale, .. } }`, followed by the
/// optional `filter` and `weight` entries when they are set.
#[derive(Debug, Clone, PartialEq)]
pub struct Decay<T: Origin + DeserializeOwned> {
// Which decay curve to use; serialized as the outer map key.
function: DecayFunction,
// Field name plus the origin/scale/offset/decay parameters.
inner: DecayFieldInner<T>,
// Optional filter query; emitted only when `Some`.
filter: Option<Query>,
// Optional weight; emitted only when `Some`.
weight: Option<f32>,
}
impl<T: Origin> Serialize for Decay<T> {
    /// Serializes the decay as a map:
    /// `{ <function>: { <field>: { .. } }, "filter": .., "weight": .. }`.
    ///
    /// `filter` and `weight` are written only when they are `Some`. The length passed to
    /// `serialize_map` is the number of entries actually written (1 to 3); the previous
    /// hard-coded `Some(3)` was wrong whenever an optional entry was absent, which matters for
    /// serializers that rely on the announced length.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let entries =
            1 + usize::from(self.filter.is_some()) + usize::from(self.weight.is_some());

        let mut map = serializer.serialize_map(Some(entries))?;
        // The decay function itself is the key; its serialized name is e.g. `"gauss"`.
        map.serialize_entry(&self.function, &self.inner)?;
        if let Some(filter) = &self.filter {
            map.serialize_entry("filter", filter)?;
        }
        if let Some(weight) = &self.weight {
            map.serialize_entry("weight", weight)?;
        }
        map.end()
    }
}
impl<'de, T: Origin + DeserializeOwned> Deserialize<'de> for Decay<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct DecayVisitor<T: Origin + DeserializeOwned> {
_marker: std::marker::PhantomData<T>,
}
impl<'de, T: Origin + DeserializeOwned> Visitor<'de> for DecayVisitor<T> {
type Value = Decay<T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct Decay")
}
fn visit_map<A>(self, mut map: A) -> Result<Decay<T>, A::Error>
where
A: MapAccess<'de>,
{
let mut function = None;
let inner = None;
let mut filter = None;
let mut weight = None;
while let Some(key) = map.next_key()? {
match key {
"function" => function = Some(map.next_value()?),
"filter" => filter = Some(map.next_value()?),
"weight" => weight = Some(map.next_value()?),
_ => {
return Err(de::Error::unknown_field(
key,
&["function", "inner", "filter", "weight"],
))
}
}
}
let function = function.ok_or_else(|| de::Error::missing_field("function"))?;
let inner = inner.ok_or_else(|| de::Error::missing_field("inner"))?;
Ok(Decay {
function,
inner,
filter,
weight,
})
}
}
deserializer.deserialize_struct(
"Decay",
&["function", "inner", "filter", "weight"],
DecayVisitor {
_marker: std::marker::PhantomData,
},
)
}
}
// Pairs the target field name with the decay parameters. Its hand-written `Serialize` impl
// emits the single-entry map `{ <field>: <inner> }`.
// NOTE(review): no `Deserialize` impl for this type is visible in this file.
#[derive(Debug, Clone, PartialEq)]
struct DecayFieldInner<T: Origin + DeserializeOwned> {
// Name of the field; used as the map key when serializing.
field: String,
// Origin/scale/offset/decay parameters; used as the map value when serializing.
inner: DecayInner<T>,
}
// Decay parameters. `origin` and `scale` are mandatory; `offset` and `decay` are optional and
// are left out of the serialized form when `ShouldSkip::should_skip` reports them as skippable.
// `Serialize` is derived, `Deserialize` is implemented by hand.
#[derive(Debug, Clone, PartialEq, Serialize)]
struct DecayInner<O>
where
O: Origin + DeserializeOwned,
{
origin: O,
// Type chosen by the origin type through `Origin::Scale`.
scale: <O as Origin>::Scale,
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
// Type chosen by the origin type through `Origin::Offset`.
offset: Option<<O as Origin>::Offset>,
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
decay: Option<f32>,
}
impl<T: Origin> Serialize for DecayFieldInner<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_entry(&self.field, &self.inner)?;
map.end()
}
}
impl<'de, O: Origin + DeserializeOwned> Deserialize<'de> for DecayInner<O> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct DecayInnerVisitor<O: Origin + DeserializeOwned> {
_marker: std::marker::PhantomData<O>,
}
impl<'de, O: Origin + DeserializeOwned> Visitor<'de> for DecayInnerVisitor<O> {
type Value = DecayInner<O>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct DecayInner")
}
fn visit_map<A>(self, mut map: A) -> Result<DecayInner<O>, A::Error>
where
A: MapAccess<'de>,
{
let mut origin = None;
let mut scale = None;
let mut offset = None;
let mut decay = None;
while let Some(key) = map.next_key()? {
match key {
"origin" => origin = Some(map.next_value()?),
"scale" => scale = Some(map.next_value()?),
"offset" => offset = map.next_value()?,
"decay" => decay = map.next_value()?,
_ => {
return Err(de::Error::unknown_field(
key,
&["origin", "scale", "offset", "decay"],
))
}
}
}
let origin = origin.ok_or_else(|| de::Error::missing_field("origin"))?;
let scale = scale.ok_or_else(|| de::Error::missing_field("scale"))?;
Ok(DecayInner {
origin,
scale,
offset,
decay,
})
}
}
deserializer.deserialize_struct(
"DecayInner",
&["origin", "scale", "offset", "decay"],
DecayInnerVisitor {
_marker: std::marker::PhantomData,
},
)
}
}
impl<O> Decay<O>
where
O: Origin,
{
pub fn new<T>(function: DecayFunction, field: T, origin: O, scale: <O as Origin>::Scale) -> Self
where
T: ToString,
{
Self {
function,
inner: DecayFieldInner {
field: field.to_string(),
inner: DecayInner {
origin,
scale,
offset: None,
decay: None,
},
},
filter: None,
weight: None,
}
}
pub fn filter<T>(mut self, filter: T) -> Self
where
T: Into<Option<Query>>,
{
self.filter = filter.into();
self
}
pub fn weight<T>(mut self, weight: T) -> Self
where
T: num_traits::AsPrimitive<f32>,
{
self.weight = Some(weight.as_());
self
}
pub fn offset(mut self, offset: <O as Origin>::Offset) -> Self {
self.inner.inner.offset = Some(offset);
self
}
pub fn decay(mut self, decay: f32) -> Self {
self.inner.inner.decay = Some(decay);
self
}
}
/// Kind of decay curve used by a [`Decay`] function.
///
/// Variants are (de)serialized using their `snake_case` names; the serialized name is used as
/// the outer map key when a [`Decay`] is serialized.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum DecayFunction {
/// Serialized as `"linear"`.
Linear,
/// Serialized as `"exp"`.
Exp,
/// Serialized as `"gauss"`.
Gauss,
}
/// A script-based function, (de)serialized as
/// `{ "script_score": { "script": { "source": .., "params": .. } } }`.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
pub struct FunctionScoreScript {
    script_score: ScriptInnerWrapper,
}

/// The `script_score` object, which only wraps the `script` object.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
struct ScriptInnerWrapper {
    script: ScriptInner,
}

/// The `script` object: the mandatory `source` and optional `params`, the latter omitted from
/// the serialized form when `ShouldSkip::should_skip` reports it as skippable.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
struct ScriptInner {
    source: String,
    #[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
    params: Option<serde_json::Value>,
}

impl FunctionScoreScript {
    /// Creates a script function from `source` (converted with `to_string`) without params.
    pub fn new<T: ToString>(source: T) -> Self {
        let script = ScriptInner {
            source: source.to_string(),
            params: None,
        };
        FunctionScoreScript {
            script_score: ScriptInnerWrapper { script },
        }
    }

    /// Sets the script `params`, replacing any previously set value.
    pub fn params(mut self, params: serde_json::Value) -> Self {
        let script = &mut self.script_score.script;
        script.params = Some(params);
        self
    }
}
// Serialization tests for `Decay`.
// NOTE(review): only serialization is covered here; nothing in this module exercises the
// hand-written `Deserialize` impls.
#[cfg(test)]
mod tests {
use chrono::prelude::*;
use super::*;
#[test]
fn serialization() {
// Date-time origin: the function name is the outer key, the field name the inner key;
// unset `offset`, `decay`, `filter` and `weight` are absent from the output.
assert_serialize(
Decay::new(
DecayFunction::Gauss,
"test",
Utc.with_ymd_and_hms(2014, 7, 8, 9, 1, 0).single().unwrap(),
Time::Days(7),
),
json!({
"gauss": {
"test": {
"origin": "2014-07-08T09:01:00Z",
"scale": "7d",
}
}
}),
);
// Numeric origin: origin and scale are emitted as plain numbers.
assert_serialize(
Decay::new(DecayFunction::Linear, "test", 1, 2),
json!({
"linear": {
"test": {
"origin": 1,
"scale": 2,
}
}
}),
);
}
}