use crate::context::Context;
use crate::duration::{format_duration, parse_duration};
use crate::magic::{Arguments, Identifier, This};
use crate::objects::{Value, ValueType};
use crate::resolvers::{Argument, Resolver};
use crate::ExecutionError;
use cel_parser::Expression;
use chrono::{DateTime, Duration, FixedOffset};
use std::cmp::Ordering;
use std::convert::TryInto;
use std::sync::Arc;
type Result<T> = std::result::Result<T, ExecutionError>;
/// Per-call state handed to a built-in CEL function implementation.
#[derive(Clone)]
pub struct FunctionContext<'context> {
/// Name of the function being invoked; `error` uses it to tag error messages.
pub name: Arc<String>,
/// Optional receiver value (`this`). NOTE(review): presumably set for method-style
/// calls such as `x.f(...)` — confirm at the call sites.
pub this: Option<Value>,
/// Evaluation context the call runs in; comprehension functions (`map`, `filter`,
/// `all`, ...) clone it as the scope in which their iteration variable is bound.
pub ptx: &'context Context<'context>,
/// The call's argument expressions, resolved on demand (e.g. via `Argument(n)`).
pub args: Vec<Expression>,
/// Argument cursor; `new` starts it at 0. NOTE(review): not read in this file —
/// presumably advanced by argument extractors elsewhere.
pub arg_idx: usize,
}
impl<'context> FunctionContext<'context> {
    /// Builds the context for one invocation of function `name`, with the
    /// argument cursor positioned at the first argument.
    pub fn new(
        name: Arc<String>,
        this: Option<Value>,
        ptx: &'context Context<'context>,
        args: Vec<Expression>,
    ) -> Self {
        let arg_idx = 0;
        Self {
            name,
            this,
            ptx,
            args,
            arg_idx,
        }
    }

    /// Runs `resolver` against this function context and returns the value it produces.
    pub fn resolve<R: Resolver>(&self, resolver: R) -> Result<Value> {
        resolver.resolve(self)
    }

    /// Builds a function error carrying `message`, tagged with this function's name.
    pub fn error<M: ToString>(&self, message: M) -> ExecutionError {
        let function_name: &str = self.name.as_str();
        ExecutionError::function_error(function_name, message)
    }
}
/// Implements CEL `size` for lists, maps, strings and bytes.
///
/// Lists and maps are measured in elements and bytes in bytes. Strings are measured
/// in Unicode code points, as the CEL specification requires, not in UTF-8 bytes.
///
/// # Errors
/// Returns a function error for any other value type.
pub fn size(ftx: &FunctionContext, value: Value) -> Result<i64> {
    let size = match value {
        Value::List(l) => l.len(),
        Value::Map(m) => m.map.len(),
        // `str::len` counts UTF-8 bytes (so `size('é')` would be 2); CEL counts code points.
        Value::String(s) => s.chars().count(),
        Value::Bytes(b) => b.len(),
        value => return Err(ftx.error(format!("cannot determine the size of {:?}", value))),
    };
    Ok(size as i64)
}
/// Implements `contains` on lists, maps, strings and bytes.
///
/// - list: element equality membership;
/// - map: key membership (the argument is converted to a map key);
/// - string: substring test (a non-string argument yields `false`);
/// - bytes: sub-sequence test (a non-bytes argument yields `false`);
/// - any other receiver yields `false`.
///
/// # Errors
/// Returns `ExecutionError::UnsupportedKeyType` when a map is probed with a value
/// that cannot be used as a key.
pub fn contains(This(this): This<Value>, arg: Value) -> Result<Value> {
    Ok(match this {
        Value::List(v) => v.contains(&arg),
        Value::Map(v) => v
            .map
            .contains_key(&arg.try_into().map_err(ExecutionError::UnsupportedKeyType)?),
        Value::String(s) => {
            if let Value::String(arg) = arg {
                s.contains(arg.as_str())
            } else {
                false
            }
        }
        Value::Bytes(b) => {
            if let Value::Bytes(arg) = arg {
                let needle = arg.as_slice();
                // `slice::windows` panics on a window size of 0; every byte string
                // trivially contains the empty one (mirroring `str::contains("")`).
                needle.is_empty() || b.windows(needle.len()).any(|w| w == needle)
            } else {
                false
            }
        }
        _ => false,
    }
    .into())
}
/// Converts the receiver to its string representation.
///
/// Strings pass through unchanged. Timestamps render as RFC 3339, durations via
/// `format_duration`, and numbers via `Display`. Bytes are decoded as UTF-8, with
/// invalid sequences replaced by U+FFFD.
///
/// # Errors
/// Returns a function error for any other value type.
pub fn string(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
    let rendered: Arc<String> = match this {
        Value::String(text) => text,
        Value::Timestamp(ts) => ts.to_rfc3339().into(),
        Value::Duration(d) => format_duration(&d).into(),
        Value::Int(n) => n.to_string().into(),
        Value::UInt(n) => n.to_string().into(),
        Value::Float(n) => n.to_string().into(),
        Value::Bytes(raw) => Arc::new(String::from_utf8_lossy(raw.as_slice()).into()),
        other => return Err(ftx.error(format!("cannot convert {:?} to string", other))),
    };
    Ok(Value::String(rendered))
}
/// Converts the receiver to a double.
///
/// Strings are parsed as `f64`; ints and uints are widened with `as`; doubles pass through.
///
/// # Errors
/// Returns a function error when string parsing fails or the type is unsupported.
pub fn double(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
    let as_float: f64 = match this {
        Value::String(text) => text
            .parse::<f64>()
            .map_err(|e| ftx.error(format!("string parse error: {e}")))?,
        Value::Float(f) => f,
        Value::Int(i) => i as f64,
        Value::UInt(u) => u as f64,
        other => return Err(ftx.error(format!("cannot convert {:?} to double", other))),
    };
    Ok(Value::Float(as_float))
}
/// Converts the receiver to an unsigned integer.
///
/// Strings are parsed as `u64`. Doubles are truncated toward zero. Ints must be
/// non-negative. Uints pass through.
///
/// # Errors
/// Returns a function error when string parsing fails, when the value lies outside
/// the `u64` range (including NaN and infinities for doubles), or when the type is
/// unsupported.
pub fn uint(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
    Ok(match this {
        Value::String(v) => v
            .parse::<u64>()
            .map(Value::UInt)
            .map_err(|e| ftx.error(format!("string parse error: {e}")))?,
        Value::Float(v) => {
            // `u64::MAX as f64` rounds up to exactly 2^64, which is itself out of range,
            // so the upper bound must be exclusive. NaN compares false against every bound
            // and would otherwise slip through and silently become 0 via `as`.
            if v.is_nan() || v < 0.0 || v >= u64::MAX as f64 {
                return Err(ftx.error("unsigned integer overflow"));
            }
            Value::UInt(v as u64)
        }
        Value::Int(v) => Value::UInt(
            v.try_into()
                .map_err(|_| ftx.error("unsigned integer overflow"))?,
        ),
        Value::UInt(v) => Value::UInt(v),
        v => return Err(ftx.error(format!("cannot convert {:?} to uint", v))),
    })
}
/// Converts the receiver to a signed integer.
///
/// Strings are parsed as `i64`. Doubles are truncated toward zero. Uints must fit in
/// `i64`. Ints pass through.
///
/// # Errors
/// Returns a function error when string parsing fails, when the value lies outside
/// the `i64` range (including NaN and infinities for doubles), or when the type is
/// unsupported.
pub fn int(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
    Ok(match this {
        Value::String(v) => v
            .parse::<i64>()
            .map(Value::Int)
            .map_err(|e| ftx.error(format!("string parse error: {e}")))?,
        Value::Float(v) => {
            // `i64::MAX as f64` rounds up to exactly 2^63, which does not fit, so the upper
            // bound is exclusive. `i64::MIN as f64` is exactly -2^63 and does fit. NaN
            // compares false against every bound and would otherwise become 0 via `as`.
            if v.is_nan() || v < i64::MIN as f64 || v >= i64::MAX as f64 {
                return Err(ftx.error("integer overflow"));
            }
            Value::Int(v as i64)
        }
        Value::Int(v) => Value::Int(v),
        Value::UInt(v) => Value::Int(v.try_into().map_err(|_| ftx.error("integer overflow"))?),
        v => return Err(ftx.error(format!("cannot convert {:?} to int", v))),
    })
}
/// Implements `startsWith`: reports whether the receiver string begins with `prefix`.
pub fn starts_with(This(this): This<Arc<String>>, prefix: Arc<String>) -> bool {
    let haystack: &str = &this;
    let needle: &str = &prefix;
    haystack.starts_with(needle)
}
/// Implements the `has(...)` macro: resolves the first argument and reports whether it exists.
///
/// A `NoSuchKey` failure while resolving is treated as "absent" (`false`); any other
/// resolution error is propagated unchanged.
pub fn has(ftx: &FunctionContext) -> Result<Value> {
    let present = match ftx.resolve(Argument(0)) {
        Ok(_) => true,
        Err(ExecutionError::NoSuchKey(_)) => false,
        Err(other) => return Err(other),
    };
    Value::Bool(present).into()
}
/// Implements the `list.map(ident, expr)` macro.
///
/// Evaluates `expr` once per element, with `ident` bound to that element in a clone of
/// the caller's context, and collects the results into a new list.
///
/// # Errors
/// Propagates the first evaluation error, or reports a type error if the receiver is
/// not a list.
pub fn map(
    ftx: &FunctionContext,
    This(this): This<Value>,
    ident: Identifier,
    expr: Expression,
) -> Result<Value> {
    let source = match this {
        Value::List(list) => list,
        other => return Err(other.error_expected_type(ValueType::List)),
    };
    let mut scope = ftx.ptx.clone();
    let mut mapped = Vec::with_capacity(source.len());
    for element in source.iter() {
        scope.add_variable_from_value(ident.clone(), element.clone());
        mapped.push(scope.resolve(&expr)?);
    }
    Value::List(Arc::new(mapped)).into()
}
/// Implements the `list.filter(ident, expr)` macro.
///
/// Keeps the elements for which `expr` evaluates to exactly `true`, with `ident` bound
/// to the element. Any other result, including non-bool values, drops the element.
///
/// # Errors
/// Propagates the first evaluation error, or reports a type error if the receiver is
/// not a list.
pub fn filter(
    ftx: &FunctionContext,
    This(this): This<Value>,
    ident: Identifier,
    expr: Expression,
) -> Result<Value> {
    let source = match this {
        Value::List(list) => list,
        other => return Err(other.error_expected_type(ValueType::List)),
    };
    let mut scope = ftx.ptx.clone();
    let mut kept = Vec::with_capacity(source.len());
    for element in source.iter() {
        scope.add_variable_from_value(ident.clone(), element.clone());
        if matches!(scope.resolve(&expr)?, Value::Bool(true)) {
            kept.push(element.clone());
        }
    }
    Value::List(Arc::new(kept)).into()
}
/// Implements the `all(ident, expr)` macro over list elements or map keys.
///
/// Returns `false` as soon as `expr` evaluates to `false`; otherwise returns `true`.
/// Non-bool results do not short-circuit.
///
/// # Errors
/// Propagates evaluation errors, or reports a type error for non-list, non-map receivers.
pub fn all(
    ftx: &FunctionContext,
    This(this): This<Value>,
    ident: Identifier,
    expr: Expression,
) -> Result<bool> {
    match this {
        Value::List(items) => {
            let mut scope = ftx.ptx.clone();
            for element in items.iter() {
                scope.add_variable_from_value(&ident, element);
                if matches!(scope.resolve(&expr)?, Value::Bool(false)) {
                    return Ok(false);
                }
            }
        }
        Value::Map(value) => {
            let mut scope = ftx.ptx.clone();
            for key in value.map.keys() {
                scope.add_variable_from_value(&ident, key);
                if matches!(scope.resolve(&expr)?, Value::Bool(false)) {
                    return Ok(false);
                }
            }
        }
        other => return Err(other.error_expected_type(ValueType::List)),
    }
    Ok(true)
}
/// Implements the `exists(ident, expr)` macro over list elements or map keys.
///
/// Returns `true` as soon as `expr` evaluates to `true`; otherwise returns `false`.
///
/// # Errors
/// Propagates evaluation errors, or reports a type error for non-list, non-map receivers.
pub fn exists(
    ftx: &FunctionContext,
    This(this): This<Value>,
    ident: Identifier,
    expr: Expression,
) -> Result<bool> {
    match this {
        Value::List(items) => {
            let mut scope = ftx.ptx.clone();
            for element in items.iter() {
                scope.add_variable_from_value(&ident, element);
                if matches!(scope.resolve(&expr)?, Value::Bool(true)) {
                    return Ok(true);
                }
            }
        }
        Value::Map(value) => {
            let mut scope = ftx.ptx.clone();
            for key in value.map.keys() {
                scope.add_variable_from_value(&ident, key);
                if matches!(scope.resolve(&expr)?, Value::Bool(true)) {
                    return Ok(true);
                }
            }
        }
        other => return Err(other.error_expected_type(ValueType::List)),
    }
    Ok(false)
}
/// Implements the `exists_one(ident, expr)` macro over list elements or map keys.
///
/// Returns `true` only if exactly one element/key makes `expr` evaluate to `true`.
/// It stops early, returning `false`, on the second match.
///
/// # Errors
/// Propagates evaluation errors, or reports a type error for non-list, non-map receivers.
pub fn exists_one(
    ftx: &FunctionContext,
    This(this): This<Value>,
    ident: Identifier,
    expr: Expression,
) -> Result<bool> {
    match this {
        Value::List(items) => {
            let mut scope = ftx.ptx.clone();
            let mut found = false;
            for element in items.iter() {
                scope.add_variable_from_value(&ident, element);
                let hit = matches!(scope.resolve(&expr)?, Value::Bool(true));
                if hit && found {
                    return Ok(false);
                }
                found |= hit;
            }
            Ok(found)
        }
        Value::Map(value) => {
            let mut scope = ftx.ptx.clone();
            let mut found = false;
            for key in value.map.keys() {
                scope.add_variable_from_value(&ident, key);
                let hit = matches!(scope.resolve(&expr)?, Value::Bool(true));
                if hit && found {
                    return Ok(false);
                }
                found |= hit;
            }
            Ok(found)
        }
        other => Err(other.error_expected_type(ValueType::List)),
    }
}
/// Implements CEL `duration(string)`: parses the string with `parse_duration`.
///
/// # Errors
/// Returns a `duration` function error when the string cannot be parsed.
pub fn duration(value: Arc<String>) -> Result<Value> {
    _duration(value.as_str()).map(Value::Duration)
}
pub fn timestamp(value: Arc<String>) -> Result<Value> {
Ok(Value::Timestamp(
DateTime::parse_from_rfc3339(value.as_str())
.map_err(|e| ExecutionError::function_error("timestamp", e.to_string().as_str()))?,
))
}
/// Implements `max`, which takes either variadic arguments or a single list.
///
/// - A single list argument: returns the largest element of that list.
/// - A single non-list argument: returns it unchanged.
/// - Otherwise: returns the largest of the arguments.
///
/// Returns `null` when there are no candidates. Among equal values, the later one wins.
///
/// # Errors
/// Returns `ExecutionError::ValuesNotComparable` when two candidates cannot be ordered.
pub fn max(Arguments(args): Arguments) -> Result<Value> {
    let candidates: &[Value] = match args.as_slice() {
        [Value::List(values)] => values.as_slice(),
        [single] => return Ok(single.clone()),
        everything => everything,
    };

    let mut remaining = candidates.iter();
    let mut best = match remaining.next() {
        Some(first) => first,
        None => return Ok(Value::Null),
    };
    for candidate in remaining {
        best = match best.partial_cmp(candidate) {
            Some(Ordering::Greater) => best,
            Some(_) => candidate,
            None => {
                return Err(ExecutionError::ValuesNotComparable(
                    best.clone(),
                    candidate.clone(),
                ))
            }
        };
    }
    Ok(best.clone())
}
/// Parses a duration string with `parse_duration`, discarding the unparsed remainder.
///
/// # Errors
/// Wraps any parse failure in a `duration` function error.
fn _duration(i: &str) -> Result<Duration> {
    parse_duration(i)
        .map(|(_rest, parsed)| parsed)
        .map_err(|e| ExecutionError::function_error("duration", &e.to_string()))
}
/// Parses an RFC 3339 timestamp, keeping its original UTC offset.
///
/// # Errors
/// Wraps any parse failure in a `timestamp` function error.
fn _timestamp(i: &str) -> Result<DateTime<FixedOffset>> {
    match DateTime::parse_from_rfc3339(i) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(ExecutionError::function_error("timestamp", &e.to_string())),
    }
}
#[cfg(test)]
mod tests {
    use crate::context::Context;
    use crate::testing::test_script;
    use std::collections::HashMap;

    /// Evaluates the script in a `(name, script)` pair and asserts it yields `true`;
    /// the name labels the failure message.
    fn assert_script(input: &(&str, &str)) {
        assert_eq!(test_script(input.1, None), Ok(true.into()), "{}", input.0);
    }

    #[test]
    fn test_size() {
        [
            ("size of list", "size([1, 2, 3]) == 3"),
            ("size of map", "size({'a': 1, 'b': 2, 'c': 3}) == 3"),
            ("size of string", "size('foo') == 3"),
            ("size of bytes", "size(b'foo') == 3"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_has() {
        let tests = vec![
            ("map has", "has(foo.bar) == true"),
            ("map not has", "has(foo.baz) == false"),
            ("map deep not has", "has(foo.baz.bar) == false"),
        ];
        for (name, script) in tests {
            let mut ctx = Context::default();
            ctx.add_variable_from_value("foo", HashMap::from([("bar", 1)]));
            assert_eq!(test_script(script, Some(ctx)), Ok(true.into()), "{}", name);
        }
    }

    #[test]
    fn test_map() {
        [
            ("map list", "[1, 2, 3].map(x, x * 2) == [2, 4, 6]"),
            ("map list 2", "[1, 2, 3].map(y, y + 1) == [2, 3, 4]"),
            (
                "nested map",
                "[[1, 2], [2, 3]].map(x, x.map(x, x * 2)) == [[2, 4], [4, 6]]",
            ),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_filter() {
        [("filter list", "[1, 2, 3].filter(x, x > 2) == [3]")]
            .iter()
            .for_each(assert_script);
    }

    #[test]
    fn test_all() {
        [
            ("all list #1", "[0, 1, 2].all(x, x >= 0)"),
            ("all list #2", "[0, 1, 2].all(x, x > 0) == false"),
            ("all map", "{0: 0, 1:1, 2:2}.all(x, x >= 0) == true"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_exists() {
        [
            ("exist list #1", "[0, 1, 2].exists(x, x > 0)"),
            ("exist list #2", "[0, 1, 2].exists(x, x == 3) == false"),
            ("exist list #3", "[0, 1, 2, 2].exists(x, x == 2)"),
            ("exist map", "{0: 0, 1:1, 2:2}.exists(x, x > 0)"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_exists_one() {
        [
            ("exist list #1", "[0, 1, 2].exists_one(x, x > 0) == false"),
            ("exist list #2", "[0, 1, 2].exists_one(x, x == 0)"),
            ("exist map", "{0: 0, 1:1, 2:2}.exists_one(x, x == 2)"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_max() {
        [
            ("max single", "max(1) == 1"),
            ("max multiple", "max(1, 2, 3) == 3"),
            ("max negative", "max(-1, 0) == 0"),
            ("max float", "max(-1.0, 0.0) == 0.0"),
            ("max list", "max([1, 2, 3]) == 3"),
            ("max empty list", "max([]) == null"),
            ("max no args", "max() == null"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_duration() {
        [
            ("duration equal 1", "duration('1s') == duration('1000ms')"),
            ("duration equal 2", "duration('1m') == duration('60s')"),
            ("duration equal 3", "duration('1h') == duration('60m')"),
            ("duration comparison 1", "duration('1m') > duration('1s')"),
            ("duration comparison 2", "duration('1m') < duration('1h')"),
            (
                "duration subtraction",
                "duration('1h') - duration('1m') == duration('59m')",
            ),
            (
                "duration addition",
                "duration('1h') + duration('1m') == duration('1h1m')",
            ),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_starts_with() {
        [
            ("starts with true", "'foobar'.startsWith('foo') == true"),
            ("starts with false", "'foobar'.startsWith('bar') == false"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_timestamp() {
        [
            (
                "comparison",
                "timestamp('2023-05-29T00:00:00Z') > timestamp('2023-05-28T00:00:00Z')",
            ),
            (
                "comparison",
                "timestamp('2023-05-29T00:00:00Z') < timestamp('2023-05-30T00:00:00Z')",
            ),
            (
                "subtracting duration",
                "timestamp('2023-05-29T00:00:00Z') - duration('24h') == timestamp('2023-05-28T00:00:00Z')",
            ),
            (
                "subtracting date",
                "timestamp('2023-05-29T00:00:00Z') - timestamp('2023-05-28T00:00:00Z') == duration('24h')",
            ),
            (
                "adding duration",
                "timestamp('2023-05-28T00:00:00Z') + duration('24h') == timestamp('2023-05-29T00:00:00Z')",
            ),
            (
                "timestamp string",
                "timestamp('2023-05-28T00:00:00Z').string() == '2023-05-28T00:00:00+00:00'",
            ),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_string() {
        [
            ("duration", "duration('1h30m').string() == '1h30m0s'"),
            (
                "timestamp",
                "timestamp('2023-05-29T00:00:00Z').string() == '2023-05-29T00:00:00+00:00'",
            ),
            ("string", "'foo'.string() == 'foo'"),
            ("int", "10.string() == '10'"),
            ("float", "10.5.string() == '10.5'"),
            ("bytes", "b'foo'.string() == 'foo'"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_double() {
        [
            ("string", "'10'.double() == 10.0"),
            ("int", "10.double() == 10.0"),
            ("double", "10.0.double() == 10.0"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_uint() {
        [
            ("string", "'10'.uint() == 10.uint()"),
            ("double", "10.5.uint() == 10.uint()"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_int() {
        [
            ("string", "'10'.int() == 10"),
            ("int", "10.int() == 10"),
            ("uint", "10.uint().int() == 10"),
            ("double", "10.5.int() == 10"),
        ]
        .iter()
        .for_each(assert_script);
    }
}