use crate::model::Bar;
use chrono::{DateTime, Utc};
/// Ordering that consecutive bar timestamps are expected to follow.
///
/// The same value is stored in [`ValidationResult::TsMismatch`] to report which
/// ordering was violated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MismatchDirection {
    /// Each bar's timestamp should be later than the previous bar's.
    ExpectedIncrement,
    /// Each bar's timestamp should be earlier than the previous bar's.
    ExpectedDecrement,
}
/// Outcome of validating a single bar or a pair of neighbouring bars.
///
/// Every variant except [`ValidationResult::Valid`] describes a problem.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationResult {
    /// No problem was found.
    Valid,
    /// Two neighbouring bars are ordered against the expected direction.
    TsMismatch {
        /// The ordering the validator was configured to expect.
        expected: MismatchDirection,
        /// Timestamp of the earlier-positioned bar.
        prev_time: DateTime<Utc>,
        /// Timestamp of the later-positioned bar.
        curr_time: DateTime<Utc>,
        /// Index supplied by the validator: the later bar's position when
        /// produced by `validate_sequence`, `0` when produced by
        /// `validate_new_bar`.
        index: usize,
    },
    /// Two neighbouring bars carry exactly the same timestamp.
    DuplicateTs {
        /// The repeated timestamp.
        ts: DateTime<Utc>,
        /// Index supplied by the validator (see `TsMismatch::index`).
        index: usize,
    },
    /// A bar's open/high/low/close/volume values are inconsistent.
    InvalidOHLC {
        /// Index supplied by the validator (see `TsMismatch::index`).
        index: usize,
        /// Human-readable description of the violated rule.
        reason: String,
    },
}
/// Configurable checker for bar data: timestamp ordering and OHLC consistency.
///
/// Build one with [`DataValidator::new`] and adjust it with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct DataValidator {
    /// Ordering that consecutive bar timestamps must follow.
    expected_direction: MismatchDirection,
    /// When `true`, every detected problem is also printed to stderr.
    log_warnings: bool,
    /// When `true`, each bar's OHLC/volume values are checked as well.
    validate_ohlc: bool,
}
impl Default for DataValidator {
fn default() -> Self {
Self::new()
}
}
impl DataValidator {
    /// Creates a validator that expects increasing timestamps, prints warnings
    /// to stderr and checks OHLC consistency.
    pub fn new() -> Self {
        Self {
            expected_direction: MismatchDirection::ExpectedIncrement,
            log_warnings: true,
            validate_ohlc: true,
        }
    }

    /// Sets the timestamp ordering that consecutive bars must follow.
    pub fn with_direction(mut self, direction: MismatchDirection) -> Self {
        self.expected_direction = direction;
        self
    }

    /// Enables or disables printing of detected problems to stderr.
    pub fn with_logging(mut self, log_warnings: bool) -> Self {
        self.log_warnings = log_warnings;
        self
    }

    /// Enables or disables the per-bar OHLC/volume consistency checks.
    pub fn with_ohlc_validation(mut self, validate: bool) -> Self {
        self.validate_ohlc = validate;
        self
    }

    /// Validates a whole slice of bars.
    ///
    /// Returns every problem found, in bar order; for a given bar its OHLC
    /// error (if any) comes before its timestamp error. Timestamps are compared
    /// between each bar and its immediate predecessor only. When nothing is
    /// wrong — including for an empty slice — the result is a single
    /// [`ValidationResult::Valid`].
    pub fn validate_sequence(&self, bars: &[Bar]) -> Vec<ValidationResult> {
        let mut results = Vec::new();
        for (i, bar) in bars.iter().enumerate() {
            if self.validate_ohlc
                && let Some(error) = self.validate_ohlc_bar(bar, i)
            {
                if self.log_warnings {
                    eprintln!("OHLC validation error at index {}: {}", i, error.reason());
                }
                results.push(error);
            }
            if i > 0 {
                let prev_bar = &bars[i - 1];
                if let Some(error) = self.validate_ts_sequence(prev_bar, bar, i) {
                    if self.log_warnings {
                        eprintln!("Ts mismatch at index {i}: {error:?}");
                    }
                    results.push(error);
                }
            }
        }
        if results.is_empty() {
            results.push(ValidationResult::Valid);
        }
        results
    }

    /// Validates one incoming bar against the optional previous bar.
    ///
    /// Returns the first problem found (OHLC checks run before the timestamp
    /// check) or [`ValidationResult::Valid`]. Any `index` field in the returned
    /// value is always `0`, because no sequence position is known here.
    pub fn validate_new_bar(&self, last_bar: Option<&Bar>, new_bar: &Bar) -> ValidationResult {
        if self.validate_ohlc
            && let Some(error) = self.validate_ohlc_bar(new_bar, 0)
        {
            if self.log_warnings {
                eprintln!("OHLC validation error: {}", error.reason());
            }
            return error;
        }
        if let Some(prev) = last_bar
            && let Some(error) = self.validate_ts_sequence(prev, new_bar, 0)
        {
            if self.log_warnings {
                eprintln!("Ts mismatch: {error:?}");
            }
            return error;
        }
        ValidationResult::Valid
    }

    /// Checks a single bar's prices and volume, returning the first violated
    /// rule as [`ValidationResult::InvalidOHLC`] or `None` when the bar is sane.
    fn validate_ohlc_bar(&self, bar: &Bar, index: usize) -> Option<ValidationResult> {
        let invalid = |reason: String| Some(ValidationResult::InvalidOHLC { index, reason });

        // Every `<` / `>` comparison against NaN is false, so without this
        // guard a bar containing NaN would slip through all range checks below.
        // Infinite prices are rejected for the same reason: they are not real
        // market data even when they happen to satisfy the ordering.
        let prices_finite = bar.open.is_finite()
            && bar.high.is_finite()
            && bar.low.is_finite()
            && bar.close.is_finite();
        if !prices_finite {
            return invalid(format!(
                "Non-finite price: Open ({}), High ({}), Low ({}), Close ({})",
                bar.open, bar.high, bar.low, bar.close
            ));
        }
        if bar.high < bar.open || bar.high < bar.close || bar.high < bar.low {
            return invalid(format!(
                "High ({}) is less than Open ({}), Close ({}), or Low ({})",
                bar.high, bar.open, bar.close, bar.low
            ));
        }
        if bar.low > bar.open || bar.low > bar.close || bar.low > bar.high {
            return invalid(format!(
                "Low ({}) is greater than Open ({}), Close ({}), or High ({})",
                bar.low, bar.open, bar.close, bar.high
            ));
        }
        // Same NaN concern as above: `NaN < 0.0` is false.
        if !bar.volume.is_finite() {
            return invalid(format!("Volume ({}) is not a finite number", bar.volume));
        }
        if bar.volume < 0.0 {
            return invalid(format!("Volume ({}) is negative", bar.volume));
        }
        None
    }

    /// Compares the timestamps of two neighbouring bars.
    ///
    /// Equal timestamps yield [`ValidationResult::DuplicateTs`]; timestamps
    /// ordered against the configured direction yield
    /// [`ValidationResult::TsMismatch`]; otherwise `None`.
    fn validate_ts_sequence(
        &self,
        prev: &Bar,
        current: &Bar,
        index: usize,
    ) -> Option<ValidationResult> {
        if prev.time == current.time {
            return Some(ValidationResult::DuplicateTs {
                ts: current.time,
                index,
            });
        }
        let out_of_order = match self.expected_direction {
            MismatchDirection::ExpectedIncrement => current.time < prev.time,
            MismatchDirection::ExpectedDecrement => current.time > prev.time,
        };
        out_of_order.then(|| ValidationResult::TsMismatch {
            expected: self.expected_direction,
            prev_time: prev.time,
            curr_time: current.time,
            index,
        })
    }
}
impl ValidationResult {
    /// Returns `true` only for the `Valid` variant.
    pub fn is_valid(&self) -> bool {
        matches!(self, Self::Valid)
    }

    /// Returns `true` for every variant other than `Valid`.
    pub fn is_error(&self) -> bool {
        !matches!(self, Self::Valid)
    }

    /// Renders this result as a human-readable message.
    ///
    /// `Valid` yields the literal `"Valid"`; the error variants embed their
    /// index and the offending timestamps or reason text.
    pub fn reason(&self) -> String {
        match self {
            Self::Valid => String::from("Valid"),
            Self::DuplicateTs { ts, index } => {
                format!("Duplicate ts {ts} at index {index}")
            }
            Self::InvalidOHLC { index, reason } => {
                format!("Invalid OHLC at index {index}: {reason}")
            }
            Self::TsMismatch {
                expected,
                prev_time,
                curr_time,
                index,
            } => {
                // Word describing the ordering the validator was expecting.
                let direction = match expected {
                    MismatchDirection::ExpectedDecrement => "decreasing",
                    MismatchDirection::ExpectedIncrement => "increasing",
                };
                format!(
                    "Ts mismatch at index {index}: expected {direction} ts, but {prev_time} came before {curr_time}"
                )
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// Builds a bar with the given timestamp and prices and a fixed volume of 1000.
    fn create_bar(ts: DateTime<Utc>, open: f64, high: f64, low: f64, close: f64) -> Bar {
        Bar {
            time: ts,
            open,
            high,
            low,
            close,
            volume: 1000.0,
        }
    }

    /// 2024-01-01 at `hour`:00:00 UTC.
    fn at_hour(hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, 1, hour, 0, 0).unwrap()
    }

    #[test]
    fn test_valid_sequence() {
        let validator = DataValidator::new();
        let bars = [
            create_bar(at_hour(10), 100.0, 105.0, 95.0, 102.0),
            create_bar(at_hour(11), 102.0, 108.0, 100.0, 105.0),
            create_bar(at_hour(12), 105.0, 110.0, 103.0, 107.0),
        ];

        let results = validator.validate_sequence(&bars);

        assert_eq!(results.len(), 1);
        assert!(results[0].is_valid());
    }

    #[test]
    fn test_ts_mismatch() {
        let validator = DataValidator::new().with_logging(false);
        // Second bar is one hour earlier than the first.
        let bars = [
            create_bar(at_hour(10), 100.0, 105.0, 95.0, 102.0),
            create_bar(at_hour(9), 102.0, 108.0, 100.0, 105.0),
        ];

        let results = validator.validate_sequence(&bars);

        assert!(
            results
                .iter()
                .any(|r| matches!(r, ValidationResult::TsMismatch { .. }))
        );
    }

    #[test]
    fn test_duplicate_ts() {
        let validator = DataValidator::new().with_logging(false);
        // Both bars share the same timestamp.
        let bars = [
            create_bar(at_hour(10), 100.0, 105.0, 95.0, 102.0),
            create_bar(at_hour(10), 102.0, 108.0, 100.0, 105.0),
        ];

        let results = validator.validate_sequence(&bars);

        assert!(
            results
                .iter()
                .any(|r| matches!(r, ValidationResult::DuplicateTs { .. }))
        );
    }

    #[test]
    fn test_invalid_ohlc_high_too_low() {
        let validator = DataValidator::new().with_logging(false);
        // High (95) is below both open (100) and close (98).
        let bar = create_bar(at_hour(10), 100.0, 95.0, 90.0, 98.0);

        let result = validator.validate_new_bar(None, &bar);

        assert!(matches!(result, ValidationResult::InvalidOHLC { .. }));
    }

    #[test]
    fn test_invalid_ohlc_low_too_high() {
        let validator = DataValidator::new().with_logging(false);
        // Low (105) is above both open (100) and close (102).
        let bar = create_bar(at_hour(10), 100.0, 110.0, 105.0, 102.0);

        let result = validator.validate_new_bar(None, &bar);

        assert!(matches!(result, ValidationResult::InvalidOHLC { .. }));
    }

    #[test]
    fn test_descending_sequence() {
        let validator = DataValidator::new()
            .with_direction(MismatchDirection::ExpectedDecrement)
            .with_logging(false);
        let bars = [
            create_bar(at_hour(12), 100.0, 105.0, 95.0, 102.0),
            create_bar(at_hour(11), 102.0, 108.0, 100.0, 105.0),
            create_bar(at_hour(10), 105.0, 110.0, 103.0, 107.0),
        ];

        let results = validator.validate_sequence(&bars);

        assert_eq!(results.len(), 1);
        assert!(results[0].is_valid());
    }

    #[test]
    fn test_validation_result_reason() {
        let result = ValidationResult::TsMismatch {
            expected: MismatchDirection::ExpectedIncrement,
            prev_time: at_hour(10),
            curr_time: at_hour(9),
            index: 1,
        };

        let reason = result.reason();

        assert!(reason.contains("mismatch"));
        assert!(reason.contains("increasing"));
    }
}