use serde::Serialize;
use serde::ser::SerializeStruct;
/// Comparison operator applied to a single screener field.
///
/// Wire names are the lowercased variant names (`"eq"`, `"gt"`, `"gte"`,
/// `"lt"`, `"lte"`), except [`Operator::Between`], which is spelled `"btwn"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Operator {
    /// Equal to (`"eq"`).
    Eq,
    /// Strictly greater than (`"gt"`).
    Gt,
    /// Greater than or equal to (`"gte"`).
    Gte,
    /// Strictly less than (`"lt"`).
    Lt,
    /// Less than or equal to (`"lte"`).
    Lte,
    /// Range test with a lower and an upper bound; the wire name is `"btwn"`.
    #[serde(rename = "btwn")]
    Between,
}
impl std::str::FromStr for Operator {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"eq" | "=" | "==" => Ok(Operator::Eq),
"gt" | ">" => Ok(Operator::Gt),
"gte" | ">=" => Ok(Operator::Gte),
"lt" | "<" => Ok(Operator::Lt),
"lte" | "<=" => Ok(Operator::Lte),
"btwn" | "between" => Ok(Operator::Between),
_ => Err(()),
}
}
}
/// Boolean connective that joins the operands of a [`QueryGroup`].
///
/// Serialized in lowercase (`"and"` / `"or"`); the default is
/// [`LogicalOperator::And`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogicalOperator {
    /// Conjunction (`"and"`); the `Default` variant.
    #[default]
    And,
    /// Disjunction (`"or"`).
    Or,
}
impl std::str::FromStr for LogicalOperator {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"and" | "&&" => Ok(LogicalOperator::And),
"or" | "||" => Ok(LogicalOperator::Or),
_ => Err(()),
}
}
}
/// A field that screener conditions can filter on.
///
/// Implementors provide the identifier that is written as the first element
/// of a condition's `operands` array when the condition is serialized.
pub trait ScreenerField: Serialize + Clone + 'static {
    /// Returns this field's wire identifier.
    fn as_str(&self) -> &'static str;
}
/// Right-hand side of a [`QueryCondition`]: the value(s) the field is compared against.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ConditionValue {
    /// A single numeric operand.
    Number(f64),
    /// Two numeric operands, serialized in order: first, then second
    /// (`ScreenerFieldExt::between` passes `min`, then `max`).
    Between(f64, f64),
    /// A single string operand (produced by `ScreenerFieldExt::eq_str`).
    StringEq(String),
}
/// A single comparison: `field <operator> value`.
///
/// Uses a hand-written `Serialize` impl that emits
/// `{"operator": ..., "operands": [field, value...]}`.
#[derive(Debug, Clone)]
pub struct QueryCondition<F: ScreenerField> {
    /// The field being tested.
    pub field: F,
    /// The comparison to apply.
    pub operator: Operator,
    /// The operand(s) the field is compared against.
    pub value: ConditionValue,
}
impl<F: ScreenerField> Serialize for QueryCondition<F> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut s = serializer.serialize_struct("QueryCondition", 2)?;
s.serialize_field("operator", &self.operator)?;
let field_str = self.field.as_str();
let operands: serde_json::Value = match &self.value {
ConditionValue::Number(v) => serde_json::json!([field_str, v]),
ConditionValue::Between(min, max) => serde_json::json!([field_str, min, max]),
ConditionValue::StringEq(v) => serde_json::json!([field_str, v]),
};
s.serialize_field("operands", &operands)?;
s.end()
}
}
/// A list of operands (leaf conditions or nested groups) joined by one
/// [`LogicalOperator`].
///
/// Serialized as `{"operator": ..., "operands": [...]}`.
#[derive(Debug, Clone, Serialize)]
pub struct QueryGroup<F: ScreenerField> {
    /// How the operands are combined.
    pub operator: LogicalOperator,
    /// Conditions and/or nested groups, in insertion order.
    pub operands: Vec<QueryOperand<F>>,
}
impl<F: ScreenerField> QueryGroup<F> {
    /// Creates an empty group whose operands will be joined by `operator`.
    pub fn new(operator: LogicalOperator) -> Self {
        let operands = Vec::default();
        Self { operator, operands }
    }

    /// Appends `operand` to the end of this group's operand list.
    pub fn add_operand(&mut self, operand: QueryOperand<F>) {
        let list = &mut self.operands;
        list.push(operand);
    }
}
/// One element of a [`QueryGroup`]: either a leaf condition or a nested group.
///
/// Marked `#[serde(untagged)]`, so each variant serializes exactly as its
/// payload, with no wrapping variant key.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum QueryOperand<F: ScreenerField> {
    /// A single field comparison.
    Condition(QueryCondition<F>),
    /// A nested group of operands.
    Group(QueryGroup<F>),
}
pub trait ScreenerFieldExt: ScreenerField + Sized {
fn gt(self, v: f64) -> QueryCondition<Self> {
QueryCondition {
field: self,
operator: Operator::Gt,
value: ConditionValue::Number(v),
}
}
fn lt(self, v: f64) -> QueryCondition<Self> {
QueryCondition {
field: self,
operator: Operator::Lt,
value: ConditionValue::Number(v),
}
}
fn gte(self, v: f64) -> QueryCondition<Self> {
QueryCondition {
field: self,
operator: Operator::Gte,
value: ConditionValue::Number(v),
}
}
fn lte(self, v: f64) -> QueryCondition<Self> {
QueryCondition {
field: self,
operator: Operator::Lte,
value: ConditionValue::Number(v),
}
}
fn eq_num(self, v: f64) -> QueryCondition<Self> {
QueryCondition {
field: self,
operator: Operator::Eq,
value: ConditionValue::Number(v),
}
}
fn between(self, min: f64, max: f64) -> QueryCondition<Self> {
QueryCondition {
field: self,
operator: Operator::Between,
value: ConditionValue::Between(min, max),
}
}
fn eq_str(self, v: impl Into<String>) -> QueryCondition<Self> {
QueryCondition {
field: self,
operator: Operator::Eq,
value: ConditionValue::StringEq(v.into()),
}
}
}
impl<T: ScreenerField + Sized> ScreenerFieldExt for T {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::screeners::fields::EquityField;

    #[test]
    fn test_condition_gt_serializes_correctly() {
        let condition = EquityField::AvgDailyVol3M.gt(200_000.0);
        let json = serde_json::to_value(&condition).unwrap();
        assert_eq!(json["operator"], "gt");
        assert_eq!(json["operands"][0], "avgdailyvol3m");
        assert_eq!(json["operands"][1], 200_000.0);
    }

    #[test]
    fn test_condition_lt_serializes_correctly() {
        let condition = EquityField::PeRatio.lt(30.0);
        let json = serde_json::to_value(&condition).unwrap();
        assert_eq!(json["operator"], "lt");
        assert_eq!(json["operands"][0], "peratio.lasttwelvemonths");
        assert_eq!(json["operands"][1], 30.0);
    }

    #[test]
    fn test_condition_between_serializes_correctly() {
        let condition = EquityField::PeRatio.between(10.0, 25.0);
        let json = serde_json::to_value(&condition).unwrap();
        assert_eq!(json["operator"], "btwn");
        assert_eq!(json["operands"][0], "peratio.lasttwelvemonths");
        assert_eq!(json["operands"][1], 10.0);
        assert_eq!(json["operands"][2], 25.0);
    }

    #[test]
    fn test_condition_eq_str_serializes_correctly() {
        let condition = EquityField::Region.eq_str("us");
        let json = serde_json::to_value(&condition).unwrap();
        assert_eq!(json["operator"], "eq");
        assert_eq!(json["operands"][0], "region");
        assert_eq!(json["operands"][1], "us");
    }

    /// Pins the exact wire text, including field order and the array shape of `operands`.
    #[test]
    fn test_condition_serializes_to_exact_json_text() {
        let condition = EquityField::AvgDailyVol3M.gt(200_000.0);
        assert_eq!(
            serde_json::to_string(&condition).unwrap(),
            r#"{"operator":"gt","operands":["avgdailyvol3m",200000.0]}"#
        );
    }

    #[test]
    fn test_query_group_serializes_correctly() {
        let mut group = QueryGroup::new(LogicalOperator::And);
        group.add_operand(QueryOperand::Condition(EquityField::Region.eq_str("us")));
        group.add_operand(QueryOperand::Condition(
            EquityField::AvgDailyVol3M.gt(200_000.0),
        ));
        let json = serde_json::to_value(&group).unwrap();
        assert_eq!(json["operator"], "and");
        assert_eq!(json["operands"].as_array().unwrap().len(), 2);
    }

    /// Nested groups must serialize untagged, i.e. as a bare `{operator, operands}` object.
    #[test]
    fn test_nested_group_serializes_untagged() {
        let mut inner = QueryGroup::new(LogicalOperator::Or);
        inner.add_operand(QueryOperand::Condition(EquityField::PeRatio.lt(30.0)));
        inner.add_operand(QueryOperand::Condition(EquityField::PeRatio.gte(5.0)));

        let mut outer = QueryGroup::new(LogicalOperator::And);
        outer.add_operand(QueryOperand::Condition(EquityField::Region.eq_str("us")));
        outer.add_operand(QueryOperand::Group(inner));

        let json = serde_json::to_value(&outer).unwrap();
        assert_eq!(json["operator"], "and");
        assert_eq!(json["operands"][0]["operator"], "eq");
        assert_eq!(json["operands"][1]["operator"], "or");
        assert_eq!(json["operands"][1]["operands"].as_array().unwrap().len(), 2);
        assert_eq!(json["operands"][1]["operands"][1]["operator"], "gte");
    }

    #[test]
    fn test_operator_from_str_accepts_names_symbols_and_case() {
        assert_eq!("GT".parse::<Operator>(), Ok(Operator::Gt));
        assert_eq!(">=".parse::<Operator>(), Ok(Operator::Gte));
        assert_eq!("==".parse::<Operator>(), Ok(Operator::Eq));
        assert_eq!("=".parse::<Operator>(), Ok(Operator::Eq));
        assert_eq!("<".parse::<Operator>(), Ok(Operator::Lt));
        assert_eq!("lte".parse::<Operator>(), Ok(Operator::Lte));
        assert_eq!("Between".parse::<Operator>(), Ok(Operator::Between));
        assert_eq!("btwn".parse::<Operator>(), Ok(Operator::Between));
        assert!("ne".parse::<Operator>().is_err());
        assert!("".parse::<Operator>().is_err());
    }

    #[test]
    fn test_logical_operator_from_str_and_default() {
        assert_eq!("AND".parse::<LogicalOperator>(), Ok(LogicalOperator::And));
        assert_eq!("&&".parse::<LogicalOperator>(), Ok(LogicalOperator::And));
        assert_eq!("Or".parse::<LogicalOperator>(), Ok(LogicalOperator::Or));
        assert_eq!("||".parse::<LogicalOperator>(), Ok(LogicalOperator::Or));
        assert!("xor".parse::<LogicalOperator>().is_err());
        assert_eq!(LogicalOperator::default(), LogicalOperator::And);
    }

    #[test]
    fn test_operator_serde_round_trip_uses_wire_names() {
        assert_eq!(serde_json::to_value(Operator::Between).unwrap(), "btwn");
        assert_eq!(serde_json::to_value(Operator::Gte).unwrap(), "gte");
        assert_eq!(
            serde_json::from_str::<Operator>("\"btwn\"").unwrap(),
            Operator::Between
        );
        assert_eq!(
            serde_json::from_str::<Operator>("\"lte\"").unwrap(),
            Operator::Lte
        );
    }
}