1use crate::model::Bar;
15use chrono::{DateTime, Utc};
16
/// Timestamp ordering that a `DataValidator` expects between consecutive bars.
///
/// The same value is carried inside `ValidationResult::TsMismatch` to record
/// which ordering the offending bar violated.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MismatchDirection {
    /// Each bar's timestamp must be later than the one before it.
    ExpectedIncrement,
    /// Each bar's timestamp must be earlier than the one before it.
    ExpectedDecrement,
}
25
26#[derive(Debug, Clone, PartialEq)]
28pub enum ValidationResult {
29 Valid,
31 TsMismatch {
33 expected: MismatchDirection,
35 prev_time: DateTime<Utc>,
37 curr_time: DateTime<Utc>,
39 index: usize,
41 },
42 DuplicateTs {
44 ts: DateTime<Utc>,
46 index: usize,
48 },
49 InvalidOHLC {
51 index: usize,
53 reason: String,
55 },
56}
57
58pub struct DataValidator {
69 expected_direction: MismatchDirection,
71 log_warnings: bool,
73 validate_ohlc: bool,
75}
76
77impl Default for DataValidator {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83impl DataValidator {
84 pub fn new() -> Self {
86 Self {
87 expected_direction: MismatchDirection::ExpectedIncrement,
88 log_warnings: true,
89 validate_ohlc: true,
90 }
91 }
92
93 pub fn with_direction(mut self, direction: MismatchDirection) -> Self {
95 self.expected_direction = direction;
96 self
97 }
98
99 pub fn with_logging(mut self, log_warnings: bool) -> Self {
101 self.log_warnings = log_warnings;
102 self
103 }
104
105 pub fn with_ohlc_validation(mut self, validate: bool) -> Self {
107 self.validate_ohlc = validate;
108 self
109 }
110
111 pub fn validate_sequence(&self, bars: &[Bar]) -> Vec<ValidationResult> {
113 let mut results = Vec::new();
114
115 for (i, bar) in bars.iter().enumerate() {
116 if self.validate_ohlc
118 && let Some(error) = self.validate_ohlc_bar(bar, i)
119 {
120 if self.log_warnings {
121 eprintln!("OHLC validation error at index {}: {}", i, error.reason());
122 }
123 results.push(error);
124 }
125
126 if i > 0 {
128 let prev_bar = &bars[i - 1];
129 if let Some(error) = self.validate_ts_sequence(prev_bar, bar, i) {
130 if self.log_warnings {
131 eprintln!("Ts mismatch at index {i}: {error:?}");
132 }
133 results.push(error);
134 }
135 }
136 }
137
138 if results.is_empty() {
139 results.push(ValidationResult::Valid);
140 }
141
142 results
143 }
144
145 pub fn validate_new_bar(&self, last_bar: Option<&Bar>, new_bar: &Bar) -> ValidationResult {
147 if self.validate_ohlc
149 && let Some(error) = self.validate_ohlc_bar(new_bar, 0)
150 {
151 if self.log_warnings {
152 eprintln!("OHLC validation error: {}", error.reason());
153 }
154 return error;
155 }
156
157 if let Some(prev) = last_bar
159 && let Some(error) = self.validate_ts_sequence(prev, new_bar, 0)
160 {
161 if self.log_warnings {
162 eprintln!("Ts mismatch: {error:?}");
163 }
164 return error;
165 }
166
167 ValidationResult::Valid
168 }
169
170 fn validate_ohlc_bar(&self, bar: &Bar, index: usize) -> Option<ValidationResult> {
172 if bar.high < bar.open || bar.high < bar.close || bar.high < bar.low {
174 return Some(ValidationResult::InvalidOHLC {
175 index,
176 reason: format!(
177 "High ({}) is less than Open ({}), Close ({}), or Low ({})",
178 bar.high, bar.open, bar.close, bar.low
179 ),
180 });
181 }
182
183 if bar.low > bar.open || bar.low > bar.close || bar.low > bar.high {
185 return Some(ValidationResult::InvalidOHLC {
186 index,
187 reason: format!(
188 "Low ({}) is greater than Open ({}), Close ({}), or High ({})",
189 bar.low, bar.open, bar.close, bar.high
190 ),
191 });
192 }
193
194 if bar.volume < 0.0 {
196 return Some(ValidationResult::InvalidOHLC {
197 index,
198 reason: format!("Volume ({}) is negative", bar.volume),
199 });
200 }
201
202 None
203 }
204
205 fn validate_ts_sequence(
207 &self,
208 prev: &Bar,
209 current: &Bar,
210 index: usize,
211 ) -> Option<ValidationResult> {
212 if prev.time == current.time {
214 return Some(ValidationResult::DuplicateTs {
215 ts: current.time,
216 index,
217 });
218 }
219
220 match self.expected_direction {
222 MismatchDirection::ExpectedIncrement => {
223 if current.time < prev.time {
224 return Some(ValidationResult::TsMismatch {
225 expected: MismatchDirection::ExpectedIncrement,
226 prev_time: prev.time,
227 curr_time: current.time,
228 index,
229 });
230 }
231 }
232 MismatchDirection::ExpectedDecrement => {
233 if current.time > prev.time {
234 return Some(ValidationResult::TsMismatch {
235 expected: MismatchDirection::ExpectedDecrement,
236 prev_time: prev.time,
237 curr_time: current.time,
238 index,
239 });
240 }
241 }
242 }
243
244 None
245 }
246}
247
248impl ValidationResult {
249 pub fn is_valid(&self) -> bool {
251 matches!(self, ValidationResult::Valid)
252 }
253
254 pub fn is_error(&self) -> bool {
256 !self.is_valid()
257 }
258
259 pub fn reason(&self) -> String {
261 match self {
262 ValidationResult::Valid => "Valid".to_string(),
263 ValidationResult::TsMismatch {
264 expected,
265 prev_time,
266 curr_time,
267 index,
268 } => {
269 let direction = match expected {
270 MismatchDirection::ExpectedIncrement => "increasing",
271 MismatchDirection::ExpectedDecrement => "decreasing",
272 };
273 format!(
274 "Ts mismatch at index {index}: expected {direction} ts, but {prev_time} came before {curr_time}"
275 )
276 }
277 ValidationResult::DuplicateTs { ts, index } => {
278 format!("Duplicate ts {ts} at index {index}")
279 }
280 ValidationResult::InvalidOHLC { index, reason } => {
281 format!("Invalid OHLC at index {index}: {reason}")
282 }
283 }
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// Timestamp on 2024-01-01 at `hour`:00:00 UTC.
    fn at(hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, 1, hour, 0, 0).unwrap()
    }

    /// Builds a bar with a fixed, valid volume of 1000.
    fn create_bar(ts: DateTime<Utc>, open: f64, high: f64, low: f64, close: f64) -> Bar {
        Bar {
            time: ts,
            open,
            high,
            low,
            close,
            volume: 1000.0,
        }
    }

    /// Validator with logging disabled so failure-path tests don't spam stderr.
    fn quiet() -> DataValidator {
        DataValidator::new().with_logging(false)
    }

    #[test]
    fn test_valid_sequence() {
        let validator = DataValidator::new();
        let bars = vec![
            create_bar(at(10), 100.0, 105.0, 95.0, 102.0),
            create_bar(at(11), 102.0, 108.0, 100.0, 105.0),
            create_bar(at(12), 105.0, 110.0, 103.0, 107.0),
        ];

        let results = validator.validate_sequence(&bars);
        assert_eq!(results.len(), 1);
        assert!(results[0].is_valid());
    }

    #[test]
    fn test_empty_sequence_is_valid() {
        let results = quiet().validate_sequence(&[]);
        assert_eq!(results, vec![ValidationResult::Valid]);
    }

    #[test]
    fn test_ts_mismatch() {
        let bars = vec![
            create_bar(at(10), 100.0, 105.0, 95.0, 102.0),
            create_bar(at(9), 102.0, 108.0, 100.0, 105.0),
        ];

        let results = quiet().validate_sequence(&bars);
        assert!(
            results
                .iter()
                .any(|r| matches!(r, ValidationResult::TsMismatch { index: 1, .. }))
        );
    }

    #[test]
    fn test_duplicate_ts() {
        let bars = vec![
            create_bar(at(10), 100.0, 105.0, 95.0, 102.0),
            create_bar(at(10), 102.0, 108.0, 100.0, 105.0),
        ];

        let results = quiet().validate_sequence(&bars);
        assert!(
            results
                .iter()
                .any(|r| matches!(r, ValidationResult::DuplicateTs { index: 1, .. }))
        );
    }

    #[test]
    fn test_invalid_ohlc_high_too_low() {
        let bar = create_bar(at(10), 100.0, 95.0, 90.0, 98.0);

        let result = quiet().validate_new_bar(None, &bar);
        assert!(matches!(result, ValidationResult::InvalidOHLC { .. }));
    }

    #[test]
    fn test_invalid_ohlc_low_too_high() {
        let bar = create_bar(at(10), 100.0, 110.0, 105.0, 102.0);

        let result = quiet().validate_new_bar(None, &bar);
        assert!(matches!(result, ValidationResult::InvalidOHLC { .. }));
    }

    #[test]
    fn test_negative_volume_rejected() {
        let bar = Bar {
            volume: -1.0,
            ..create_bar(at(10), 100.0, 105.0, 95.0, 102.0)
        };

        let result = quiet().validate_new_bar(None, &bar);
        assert!(matches!(result, ValidationResult::InvalidOHLC { .. }));
    }

    #[test]
    fn test_ohlc_validation_can_be_disabled() {
        let validator = quiet().with_ohlc_validation(false);
        let bar = create_bar(at(10), 100.0, 95.0, 90.0, 98.0);

        assert!(validator.validate_new_bar(None, &bar).is_valid());
    }

    #[test]
    fn test_new_bar_ordering_checks() {
        let validator = quiet();
        let last = create_bar(at(10), 100.0, 105.0, 95.0, 102.0);
        let later = create_bar(at(11), 102.0, 108.0, 100.0, 105.0);
        let earlier = create_bar(at(9), 102.0, 108.0, 100.0, 105.0);
        let same_time = create_bar(at(10), 102.0, 108.0, 100.0, 105.0);

        assert!(validator.validate_new_bar(Some(&last), &later).is_valid());
        assert!(matches!(
            validator.validate_new_bar(Some(&last), &earlier),
            ValidationResult::TsMismatch { .. }
        ));
        assert!(matches!(
            validator.validate_new_bar(Some(&last), &same_time),
            ValidationResult::DuplicateTs { .. }
        ));
    }

    #[test]
    fn test_descending_sequence() {
        let validator = quiet().with_direction(MismatchDirection::ExpectedDecrement);
        let bars = vec![
            create_bar(at(12), 100.0, 105.0, 95.0, 102.0),
            create_bar(at(11), 102.0, 108.0, 100.0, 105.0),
            create_bar(at(10), 105.0, 110.0, 103.0, 107.0),
        ];

        let results = validator.validate_sequence(&bars);
        assert_eq!(results.len(), 1);
        assert!(results[0].is_valid());
    }

    #[test]
    fn test_descending_validator_rejects_ascending_pair() {
        let validator = quiet().with_direction(MismatchDirection::ExpectedDecrement);
        let bars = vec![
            create_bar(at(10), 100.0, 105.0, 95.0, 102.0),
            create_bar(at(11), 102.0, 108.0, 100.0, 105.0),
        ];

        let results = validator.validate_sequence(&bars);
        assert!(results.iter().any(|r| matches!(
            r,
            ValidationResult::TsMismatch {
                expected: MismatchDirection::ExpectedDecrement,
                index: 1,
                ..
            }
        )));
    }

    #[test]
    fn test_validation_result_reason() {
        let result = ValidationResult::TsMismatch {
            expected: MismatchDirection::ExpectedIncrement,
            prev_time: at(10),
            curr_time: at(9),
            index: 1,
        };

        let reason = result.reason();
        assert!(reason.contains("mismatch"));
        assert!(reason.contains("increasing"));
    }

    #[test]
    fn test_duplicate_reason_and_error_flags() {
        let dup = ValidationResult::DuplicateTs {
            ts: at(10),
            index: 3,
        };

        let reason = dup.reason();
        assert!(reason.starts_with("Duplicate ts"));
        assert!(reason.ends_with("at index 3"));
        assert!(dup.is_error());
        assert!(!dup.is_valid());
        assert!(!ValidationResult::Valid.is_error());
    }
}