use std::fmt;
use serde::{
de::{self, Deserialize, Deserializer, Visitor},
Serialize, Serializer,
};
use crate::{util::*, Map};
/// A script description: where the script comes from, an optional language
/// and a map of named parameters.
///
/// Values are created with [`Script::source`] or [`Script::id`] and refined
/// with the chainable [`Script::lang`] and [`Script::param`] methods.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct Script {
/// Origin of the script. Flattened, so the `source` / `id` key of
/// [`ScriptSource`] appears directly among this struct's serialized keys.
#[serde(flatten)]
source: ScriptSource,
/// Optional script language. Falls back to the default (`None`) when absent
/// on input; omitted from output whenever `ShouldSkip::should_skip` says so.
// NOTE(review): `ShouldSkip` is defined outside this file; presumably it
// skips `None` here — confirm against its implementation.
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
lang: Option<ScriptLang>,
/// Named parameters stored as JSON values. Falls back to the default when
/// absent on input; omitted from output whenever `ShouldSkip::should_skip`
/// says so.
// NOTE(review): presumably skipped when the map is empty — confirm.
#[serde(default, skip_serializing_if = "ShouldSkip::should_skip")]
params: Map<String, serde_json::Value>,
}
/// Where a script's code comes from.
///
/// Variant names are (de)serialized in `snake_case`, i.e. under the keys
/// `source` and `id`.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ScriptSource {
/// Script text carried inline; serialized under the `source` key.
Source(String),
/// Identifier of a script; serialized under the `id` key.
// NOTE(review): presumably the id of a script stored elsewhere — confirm.
Id(String),
}
impl Script {
pub fn source<S>(source: S) -> Self
where
S: ToString,
{
Self {
source: ScriptSource::Source(source.to_string()),
lang: None,
params: Map::new(),
}
}
pub fn id<S>(id: S) -> Self
where
S: ToString,
{
Self {
source: ScriptSource::Id(id.to_string()),
lang: None,
params: Map::new(),
}
}
pub fn lang<S>(mut self, lang: S) -> Self
where
S: Into<ScriptLang>,
{
self.lang = Some(lang.into());
self
}
pub fn param<T, S>(mut self, name: S, param: T) -> Self
where
S: ToString,
T: Serialize,
{
if let Ok(param) = serde_json::to_value(param) {
let _ = self.params.entry(name.to_string()).or_insert(param);
}
self
}
}
/// Language a script is written in.
///
/// The three named variants are (de)serialized as the strings `"painless"`,
/// `"expression"` and `"mustache"`; any other string is carried verbatim by
/// [`ScriptLang::Custom`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScriptLang {
/// Serialized as `"painless"`.
Painless,
/// Serialized as `"expression"`.
Expression,
/// Serialized as `"mustache"`.
Mustache,
/// Any other language name, serialized as the contained string.
Custom(String),
}
/// Serializes a [`ScriptLang`] as a plain string: the lowercase name for the
/// built-in variants, or the wrapped string for [`ScriptLang::Custom`].
impl Serialize for ScriptLang {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Resolve the textual name first so there is a single write below.
        let name: &str = match self {
            Self::Custom(custom) => custom.as_str(),
            Self::Painless => "painless",
            Self::Expression => "expression",
            Self::Mustache => "mustache",
        };
        serializer.serialize_str(name)
    }
}
/// Deserializes a [`ScriptLang`] from a string.
///
/// `"painless"`, `"expression"` and `"mustache"` (exact, case-sensitive
/// matches) become the corresponding variants; every other string becomes
/// [`ScriptLang::Custom`].
impl<'de> Deserialize<'de> for ScriptLang {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Stateless visitor that turns a string into a `ScriptLang`.
        struct LangVisitor;

        impl<'de> Visitor<'de> for LangVisitor {
            type Value = ScriptLang;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a string representing a script language")
            }

            fn visit_str<E>(self, text: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                // Never fails: unknown names fall through to `Custom`.
                let lang = match text {
                    "painless" => ScriptLang::Painless,
                    "expression" => ScriptLang::Expression,
                    "mustache" => ScriptLang::Mustache,
                    other => ScriptLang::Custom(other.to_owned()),
                };
                Ok(lang)
            }
        }

        deserializer.deserialize_str(LangVisitor)
    }
}
/// Converts anything that implements [`ToString`] into a [`ScriptLang`].
///
/// The value is stringified first; `"painless"`, `"expression"` and
/// `"mustache"` (exact, case-sensitive matches) yield the built-in variants,
/// and any other text is wrapped in [`ScriptLang::Custom`].
impl<T> From<T> for ScriptLang
where
    T: ToString,
{
    fn from(value: T) -> Self {
        let name = value.to_string();

        // Look for a built-in language first; the borrow of `name` ends with
        // the match, so the string can be moved into `Custom` afterwards.
        let builtin = match name.as_str() {
            "painless" => Some(Self::Painless),
            "expression" => Some(Self::Expression),
            "mustache" => Some(Self::Mustache),
            _ => None,
        };

        builtin.unwrap_or_else(|| Self::Custom(name))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the JSON produced for inline and id-based scripts, with
    /// built-in and custom languages and with parameters attached.
    #[test]
    fn serialization() {
        // Inline source, array parameter, built-in language.
        let painless_script = Script::source("Math.log(_score * 2) * params['multiplier'].len()")
            .param("multiplier", [1, 2, 3])
            .lang(ScriptLang::Painless);
        assert_serialize(
            painless_script,
            json!({
                "source": "Math.log(_score * 2) * params['multiplier'].len()",
                "lang": "painless",
                "params": {
                    "multiplier": [1, 2, 3]
                }
            }),
        );

        // Inline source, scalar parameter, language given as a plain string.
        let custom_lang_script = Script::source("doc['my_field'].value * params['multiplier']")
            .param("multiplier", 1)
            .lang("my_lang");
        assert_serialize(
            custom_lang_script,
            json!({
                "source": "doc['my_field'].value * params['multiplier']",
                "lang": "my_lang",
                "params": {
                    "multiplier": 1
                }
            }),
        );

        // Script referenced by a numeric id; no language, so `lang` is absent.
        let id_script = Script::id(123).param("multiplier", [1, 2, 3]);
        assert_serialize(
            id_script,
            json!({
                "id": "123",
                "params": {
                    "multiplier": [1, 2, 3]
                }
            }),
        );
    }
}