use chrono::{DateTime, Utc};
use shape_ast::error::{Result, ShapeError};
use std::sync::RwLock;
/// Policy controlling whether a `LookAheadGuard` permits access to data
/// stamped after the guard's current time.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so a mode can be used as a
/// set/map key; every variant is fieldless, so these agree trivially with
/// `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataAccessMode {
    /// Every access is allowed.
    Unrestricted,
    /// Accesses to data later than the current time (or positive row offsets)
    /// are violations.
    Restricted,
    /// Checked exactly like `Restricted` by the guard in this file.
    /// NOTE(review): the name suggests distinct semantics — confirm intent.
    ForwardOnly,
}
/// Runtime guard that detects reads of data stamped later than the current
/// evaluation time.
///
/// The mutable parts (current time and access log) sit behind `RwLock`s so the
/// guard's methods can update them through `&self`.
#[derive(Debug)]
pub struct LookAheadGuard {
    /// Policy deciding which accesses are allowed.
    mode: DataAccessMode,
    /// The evaluation's notion of "now"; `None` until `set_current_time` runs.
    current_time: RwLock<Option<DateTime<Utc>>>,
    /// `true`: violations become errors. `false`: violations only print a warning.
    strict_mode: bool,
    /// Every access checked by `check_access`, whether allowed or not.
    access_log: RwLock<Vec<DataAccess>>,
}
/// A single entry in a guard's access log.
#[derive(Debug, Clone)]
pub struct DataAccess {
    /// The guard's current time at the moment the access was checked.
    pub timestamp: DateTime<Utc>,
    /// The time of the data that was requested.
    pub accessed_time: DateTime<Utc>,
    /// Caller-supplied label describing the kind of access.
    pub access_type: String,
    /// Whether the guard's mode permitted the access.
    pub allowed: bool,
}
impl LookAheadGuard {
    /// Builds a guard operating under `mode`.
    ///
    /// `strict_mode` selects how violations surface: `true` turns them into
    /// errors, `false` prints a warning to stderr and lets the access proceed.
    /// The guard starts with no current time and an empty access log.
    pub fn new(mode: DataAccessMode, strict_mode: bool) -> Self {
        let current_time = RwLock::new(None);
        let access_log = RwLock::new(Vec::new());
        Self {
            mode,
            current_time,
            strict_mode,
            access_log,
        }
    }

    /// Records `time` as the guard's "now", replacing any earlier value.
    ///
    /// # Panics
    ///
    /// Panics if the time lock is poisoned.
    pub fn set_current_time(&self, time: DateTime<Utc>) {
        let mut slot = self.current_time.write().unwrap();
        *slot = Some(time);
    }

    /// Validates and logs an access to data stamped `access_time`.
    ///
    /// Every call appends a `DataAccess` entry to the log, allowed or not.
    /// `Unrestricted` allows everything. `Restricted` and `ForwardOnly` allow
    /// the access only when `access_time` is not after the current time.
    ///
    /// # Errors
    ///
    /// Returns `ShapeError::RuntimeError` in two cases:
    /// - no current time has been set (nothing is logged in that case);
    /// - the access is disallowed and the guard is strict.
    ///
    /// A disallowed access in non-strict mode prints a warning and returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock is poisoned.
    pub fn check_access(&self, access_time: DateTime<Utc>, access_type: &str) -> Result<()> {
        let current = match *self.current_time.read().unwrap() {
            Some(now) => now,
            None => {
                return Err(ShapeError::RuntimeError {
                    message: "Current time not set in LookAheadGuard".to_string(),
                    location: None,
                });
            }
        };

        let allowed = match self.mode {
            DataAccessMode::Restricted | DataAccessMode::ForwardOnly => access_time <= current,
            DataAccessMode::Unrestricted => true,
        };

        let entry = DataAccess {
            timestamp: current,
            accessed_time: access_time,
            access_type: access_type.to_owned(),
            allowed,
        };
        self.access_log.write().unwrap().push(entry);

        if allowed {
            return Ok(());
        }

        if self.strict_mode {
            Err(ShapeError::RuntimeError {
                message: format!(
                    "Future data access violation: Attempted to access data at {} while current time is {}",
                    access_time, current
                ),
                location: None,
            })
        } else {
            eprintln!(
                "WARNING: Future data access - accessing {} at current time {}",
                access_time, current
            );
            Ok(())
        }
    }

    /// Validates a relative row offset such as `data[index]`.
    ///
    /// In `Restricted` and `ForwardOnly` modes a positive `index` is reported
    /// as a future-data violation. Unlike `check_access`, this writes nothing
    /// to the access log, and `_access_type` is currently unused.
    ///
    /// # Errors
    ///
    /// Returns `ShapeError::RuntimeError` for a positive index when the guard is
    /// restricted and strict. In non-strict mode the violation prints a warning
    /// to stderr and `Ok(())` is returned.
    pub fn check_row_index(&self, index: i32, _access_type: &str) -> Result<()> {
        let restricted = match self.mode {
            DataAccessMode::Unrestricted => false,
            DataAccessMode::Restricted | DataAccessMode::ForwardOnly => true,
        };
        if !restricted || index <= 0 {
            return Ok(());
        }

        let msg = format!(
            "Future data access violation: Attempted to access data[{}] in restricted mode",
            index
        );
        if self.strict_mode {
            return Err(ShapeError::RuntimeError {
                message: msg,
                location: None,
            });
        }
        eprintln!("WARNING: {}", msg);
        Ok(())
    }

    /// Returns a copy of every access recorded so far, in recording order.
    pub fn get_access_log(&self) -> Vec<DataAccess> {
        let log = self.access_log.read().unwrap();
        log.to_vec()
    }

    /// Discards all recorded accesses.
    pub fn clear_log(&self) {
        let mut log = self.access_log.write().unwrap();
        log.clear();
    }

    /// Summarises the log: the total entry count plus copies of the disallowed
    /// entries, kept in log order.
    pub fn get_violation_summary(&self) -> LookAheadSummary {
        let log = self.access_log.read().unwrap();
        let mut violation_details = Vec::new();
        for access in log.iter() {
            if !access.allowed {
                violation_details.push(access.clone());
            }
        }
        LookAheadSummary {
            total_accesses: log.len(),
            violations: violation_details.len(),
            violation_details,
        }
    }
}
impl Clone for LookAheadGuard {
    /// Produces an independent guard.
    ///
    /// The clone has the same mode and strictness, plus copies of the current
    /// time and the access log taken while both read locks are held. It owns
    /// fresh locks, so later updates to one guard do not affect the other.
    fn clone(&self) -> Self {
        let time = self.current_time.read().unwrap();
        let log = self.access_log.read().unwrap();
        Self {
            mode: self.mode,
            strict_mode: self.strict_mode,
            current_time: RwLock::new(*time),
            access_log: RwLock::new(log.to_vec()),
        }
    }
}
/// Aggregate view of a guard's access log, as built by `get_violation_summary`.
#[derive(Debug, Clone)]
pub struct LookAheadSummary {
    /// Number of logged accesses.
    pub total_accesses: usize,
    /// Number of logged accesses that were not allowed.
    pub violations: usize,
    /// Copies of the disallowed log entries, in log order.
    pub violation_details: Vec<DataAccess>,
}
impl LookAheadSummary {
    /// Prints a human-readable validation report to stdout.
    ///
    /// The header and both counts are always printed. The numbered violation
    /// list appears only when `violations` is non-zero. Timestamps use the
    /// `%Y-%m-%d %H:%M:%S` pattern.
    pub fn print_report(&self) {
        const STAMP: &str = "%Y-%m-%d %H:%M:%S";

        println!("=== Data Access Validation Report ===");
        println!("Total data accesses: {}", self.total_accesses);
        println!("Violations found: {}", self.violations);

        if self.violations == 0 {
            return;
        }

        println!("\nViolation Details:");
        for (idx, entry) in self.violation_details.iter().enumerate() {
            let ordinal = idx + 1;
            println!(
                " {}. At {}: Tried to access {} (type: {})",
                ordinal,
                entry.timestamp.format(STAMP),
                entry.accessed_time.format(STAMP),
                entry.access_type
            );
        }
    }
}