use crate::ev_formats::dataframe_builder::EventDataFrameBuilder;
use crate::ev_formats::{EventFormat, LoadConfig};
use polars::prelude::*;
use std::fs::File;
use std::io::Read;
use std::path::Path;
/// Options controlling how raw AER event words are decoded.
#[derive(Debug, Clone)]
pub struct AerConfig {
    /// Decode event words as big-endian (`true`) or little-endian (`false`).
    pub big_endian: bool,
    /// When set, events with `x > max_x` or `y > max_y` are reported as invalid.
    pub validate_coordinates: bool,
    /// Largest accepted x coordinate (inclusive) when coordinate validation is on.
    pub max_x: u16,
    /// Largest accepted y coordinate (inclusive) when coordinate validation is on.
    pub max_y: u16,
    /// Drop events that fail to decode instead of aborting the whole read.
    pub skip_invalid_events: bool,
    /// Assign synthetic timestamps; decoded events otherwise all carry `0.0`.
    pub generate_timestamps: bool,
    /// Strategy used to synthesize timestamps when `generate_timestamps` is set.
    pub timestamp_mode: TimestampMode,
    /// Timestamp origin used by the generated-timestamp modes.
    pub start_timestamp: f64,
    /// Step (or mean interval) used by the generated-timestamp modes.
    pub time_increment: f64,
    /// Upper bound on the number of events read from the start of a file.
    pub max_events: Option<usize>,
    /// Size of one encoded event in bytes; the decoder accepts 2 or 4.
    pub bytes_per_event: usize,
}
/// How synthetic timestamps are assigned to decoded events.
#[derive(Debug, Clone, PartialEq)]
pub enum TimestampMode {
    /// `start + i * increment` for the i-th decoded event.
    Sequential,
    /// `start + (i / n) * (n * increment)`; numerically the same spacing as `Sequential`
    /// up to floating-point rounding.
    Uniform,
    /// Pseudo-random, approximately exponentially distributed intervals whose mean is
    /// the configured increment, accumulated from the start time.
    Exponential,
    /// Explicit timestamps; the count must equal the number of decoded events.
    Custom(Vec<f64>),
}
/// Errors produced while reading or decoding AER data.
#[derive(Debug)]
pub enum AerError {
    /// Underlying I/O failure (also used to wrap DataFrame construction failures).
    Io(std::io::Error),
    /// `(file_size, bytes_per_event)`: the file length is not a multiple of the event size.
    InvalidFileSize(u64, usize),
    /// `(x, y, max_x, max_y)`: a decoded coordinate exceeded the configured bounds.
    InvalidCoordinate(u16, u16, u16, u16),
    /// `(byte_offset, raw_bytes)` of an undecodable event.
    /// NOTE(review): not constructed by the code visible in this module.
    InvalidEventData(usize, Vec<u8>),
    /// `(expected_bytes, actual_bytes)` for a truncated event record.
    InsufficientData(usize, usize),
    /// The configured bytes-per-event value is not 2 or 4.
    InvalidBytesPerEvent(usize),
    /// The input file has zero length.
    EmptyFile,
    /// Free-form validation failure (e.g. custom timestamp count mismatch).
    ValidationFailed(String),
}
impl std::fmt::Display for AerError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AerError::Io(e) => write!(f, "I/O error: {e}"),
AerError::InvalidFileSize(size, expected) => write!(
f,
"Invalid file size: {size} bytes, expected multiple of {expected}"
),
AerError::InvalidCoordinate(x, y, max_x, max_y) => write!(
f,
"Invalid coordinate: x={x}, y={y}, max_x={max_x}, max_y={max_y}"
),
AerError::InvalidEventData(offset, data) => {
write!(f, "Invalid event data at byte {offset}: {data:02X?}")
}
AerError::InsufficientData(expected, actual) => write!(
f,
"Insufficient data: expected {expected} bytes, got {actual}"
),
AerError::InvalidBytesPerEvent(bytes) => {
write!(f, "Invalid bytes per event: {bytes}, expected 2 or 4")
}
AerError::EmptyFile => write!(f, "File is empty"),
AerError::ValidationFailed(msg) => write!(f, "Event validation failed: {msg}"),
}
}
}
impl std::error::Error for AerError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
AerError::Io(e) => Some(e),
_ => None,
}
}
}
impl From<std::io::Error> for AerError {
fn from(error: std::io::Error) -> Self {
AerError::Io(error)
}
}
/// Shorthand for results whose error type is [`AerError`].
pub type AerResult<T> = std::result::Result<T, AerError>;
/// Summary information gathered while decoding an AER file.
#[derive(Debug, Clone)]
pub struct AerMetadata {
    /// Length of the input file in bytes.
    pub file_size: u64,
    /// Number of events successfully decoded (invalid events skipped are not counted).
    pub event_count: usize,
    /// Event record size used for decoding.
    pub bytes_per_event: usize,
    /// `"big"` or `"little"`, mirroring the configured byte order.
    pub endianness: String,
    /// `(min_x, min_y, max_x, max_y)` over decoded events; `None` when there are none.
    pub coordinate_bounds: Option<(u16, u16, u16, u16)>,
    /// `(start, end)` timestamps of the decoded events; `None` when there are none.
    pub timestamp_range: Option<(f64, f64)>,
    /// `(positive, negative)` polarity counts.
    pub polarity_distribution: Option<(usize, usize)>,
}
impl Default for AerConfig {
    /// Little-endian 4-byte events, coordinates validated against 0..=511 on both axes,
    /// invalid events abort the read, and sequential timestamps from 0.0 in 1e-6 steps.
    fn default() -> Self {
        // Coordinates are 9-bit fields in the decoder, so 511 is the largest value.
        const MAX_COORDINATE: u16 = 511;
        Self {
            bytes_per_event: 4,
            big_endian: false,
            validate_coordinates: true,
            skip_invalid_events: false,
            max_x: MAX_COORDINATE,
            max_y: MAX_COORDINATE,
            generate_timestamps: true,
            timestamp_mode: TimestampMode::Sequential,
            start_timestamp: 0.0,
            time_increment: 1e-6,
            max_events: None,
        }
    }
}
impl AerConfig {
    /// Equivalent to [`AerConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Selects big-endian (`true`) or little-endian (`false`) decoding.
    pub fn with_endianness(self, big_endian: bool) -> Self {
        Self { big_endian, ..self }
    }

    /// Sets the inclusive upper bounds used by coordinate validation.
    pub fn with_coordinate_bounds(self, max_x: u16, max_y: u16) -> Self {
        Self {
            max_x,
            max_y,
            ..self
        }
    }

    /// Configures synthetic timestamp generation in one call.
    pub fn with_timestamp_generation(
        self,
        generate: bool,
        mode: TimestampMode,
        start: f64,
        increment: f64,
    ) -> Self {
        Self {
            generate_timestamps: generate,
            timestamp_mode: mode,
            start_timestamp: start,
            time_increment: increment,
            ..self
        }
    }

    /// Sets the event record size in bytes (the decoder accepts 2 or 4).
    pub fn with_bytes_per_event(self, bytes: usize) -> Self {
        Self {
            bytes_per_event: bytes,
            ..self
        }
    }

    /// Toggles coordinate validation and whether invalid events are skipped.
    pub fn with_validation(self, validate: bool, skip_invalid: bool) -> Self {
        Self {
            validate_coordinates: validate,
            skip_invalid_events: skip_invalid,
            ..self
        }
    }

    /// Caps how many events are read from the start of a file (`None` = no cap).
    pub fn with_max_events(self, max_events: Option<usize>) -> Self {
        Self { max_events, ..self }
    }
}
/// Decodes raw AER files according to an [`AerConfig`].
#[derive(Debug, Clone)]
pub struct AerReader {
    /// Decoding options applied to every file this reader opens.
    config: AerConfig,
}
impl AerReader {
    /// Creates a reader using [`AerConfig::default`].
    pub fn new() -> Self {
        Self {
            config: AerConfig::default(),
        }
    }

    /// Creates a reader that decodes according to `config`.
    pub fn with_config(config: AerConfig) -> Self {
        Self { config }
    }

    /// Reads and decodes an AER file into a DataFrame plus summary metadata.
    ///
    /// When `config.max_events` is set, only that many events are read from the start
    /// of the file.
    ///
    /// # Errors
    ///
    /// * [`AerError::Io`] if the file cannot be opened, stat'ed or read.
    /// * [`AerError::EmptyFile`] if the file has zero length.
    /// * [`AerError::InvalidBytesPerEvent`] if `config.bytes_per_event` is not 2 or 4.
    /// * [`AerError::InvalidFileSize`] if the length is not a multiple of `bytes_per_event`.
    /// * Any decoding error from `parse_events`.
    pub fn read_file<P: AsRef<Path>>(&self, path: P) -> AerResult<(DataFrame, AerMetadata)> {
        let mut file = File::open(path.as_ref())?;
        let file_size = file.metadata()?.len();
        if file_size == 0 {
            return Err(AerError::EmptyFile);
        }
        // Validate before the modulo below: a `bytes_per_event` of 0 would otherwise
        // panic with a division by zero instead of returning an error.
        self.validate_bytes_per_event()?;
        let bytes_per_event = self.config.bytes_per_event;
        if file_size % bytes_per_event as u64 != 0 {
            return Err(AerError::InvalidFileSize(file_size, bytes_per_event));
        }
        let available_events = (file_size / bytes_per_event as u64) as usize;
        let event_count = match self.config.max_events {
            Some(max) => available_events.min(max),
            None => available_events,
        };
        let mut buffer = vec![0u8; event_count * bytes_per_event];
        file.read_exact(&mut buffer)?;
        self.parse_events(&buffer, file_size)
    }

    /// Rejects any configured event size other than 2 or 4 bytes.
    fn validate_bytes_per_event(&self) -> AerResult<()> {
        match self.config.bytes_per_event {
            2 | 4 => Ok(()),
            other => Err(AerError::InvalidBytesPerEvent(other)),
        }
    }

    /// Decodes every complete event record in `data` and builds the DataFrame/metadata.
    ///
    /// `file_size` is only copied into the metadata. Invalid events either abort
    /// decoding or are dropped, depending on `config.skip_invalid_events`.
    fn parse_events(&self, data: &[u8], file_size: u64) -> AerResult<(DataFrame, AerMetadata)> {
        self.validate_bytes_per_event()?;
        let bytes_per_event = self.config.bytes_per_event;
        let event_count = data.len() / bytes_per_event;
        let mut builder = EventDataFrameBuilder::new(EventFormat::AER, event_count);

        // `chunks_exact` yields exactly `event_count` full records; any trailing partial
        // record is ignored.
        let mut parsed_events = Vec::with_capacity(event_count);
        for (index, record) in data.chunks_exact(bytes_per_event).enumerate() {
            match self.parse_single_event(record, index) {
                Ok(event) => parsed_events.push(event),
                Err(_) if self.config.skip_invalid_events => {}
                Err(e) => return Err(e),
            }
        }

        if self.config.generate_timestamps {
            self.generate_timestamps_for_parsed_events(&mut parsed_events)?;
        }

        for &(x, y, timestamp, polarity) in &parsed_events {
            builder.add_event(x, y, timestamp, polarity);
        }
        let events = builder.build().map_err(|e| {
            AerError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Failed to build DataFrame: {}", e),
            ))
        })?;

        let metadata = self.build_metadata(&parsed_events, file_size);
        Ok((events, metadata))
    }

    /// Summarizes decoded `(x, y, t, polarity)` events into [`AerMetadata`].
    fn build_metadata(&self, events: &[(u16, u16, f64, bool)], file_size: u64) -> AerMetadata {
        let coordinate_bounds = events.iter().fold(
            None,
            |bounds: Option<(u16, u16, u16, u16)>, &(x, y, _, _)| match bounds {
                None => Some((x, y, x, y)),
                Some((min_x, min_y, max_x, max_y)) => {
                    Some((min_x.min(x), min_y.min(y), max_x.max(x), max_y.max(y)))
                }
            },
        );
        // True (min, max) rather than (first, last): custom timestamps need not be sorted.
        let timestamp_range = events
            .iter()
            .map(|event| event.2)
            .fold(None, |range: Option<(f64, f64)>, t| match range {
                None => Some((t, t)),
                Some((lo, hi)) => Some((lo.min(t), hi.max(t))),
            });
        let positive_count = events.iter().filter(|event| event.3).count();
        let negative_count = events.len() - positive_count;

        AerMetadata {
            file_size,
            event_count: events.len(),
            bytes_per_event: self.config.bytes_per_event,
            endianness: if self.config.big_endian {
                "big".to_string()
            } else {
                "little".to_string()
            },
            coordinate_bounds,
            timestamp_range,
            polarity_distribution: Some((positive_count, negative_count)),
        }
    }

    /// Decodes one event record into `(x, y, t, polarity)` with `t` set to `0.0`.
    ///
    /// Bit layout of the raw word: bit 0 = polarity, bits 1..=9 = x, bits 10..=18 = y.
    /// NOTE(review): for 2-byte records only bits 10..=15 exist, so y is limited to
    /// 0..=63 under this layout — confirm this matches the 16-bit AER variant in use.
    ///
    /// # Errors
    ///
    /// * [`AerError::InsufficientData`] if `data` is shorter than one record.
    /// * [`AerError::InvalidBytesPerEvent`] if the configured size is not 2 or 4.
    /// * [`AerError::InvalidCoordinate`] if validation is on and x/y exceed the bounds.
    fn parse_single_event(
        &self,
        data: &[u8],
        _event_index: usize,
    ) -> AerResult<(u16, u16, f64, bool)> {
        let bytes_per_event = self.config.bytes_per_event;
        if data.len() < bytes_per_event {
            return Err(AerError::InsufficientData(bytes_per_event, data.len()));
        }
        let raw_event = match bytes_per_event {
            2 => {
                let bytes = [data[0], data[1]];
                let word = if self.config.big_endian {
                    u16::from_be_bytes(bytes)
                } else {
                    u16::from_le_bytes(bytes)
                };
                u32::from(word)
            }
            4 => {
                let bytes = [data[0], data[1], data[2], data[3]];
                if self.config.big_endian {
                    u32::from_be_bytes(bytes)
                } else {
                    u32::from_le_bytes(bytes)
                }
            }
            other => return Err(AerError::InvalidBytesPerEvent(other)),
        };
        let polarity = (raw_event & 0x1) == 1;
        let x = ((raw_event >> 1) & 0x1FF) as u16;
        let y = ((raw_event >> 10) & 0x1FF) as u16;
        if self.config.validate_coordinates && (x > self.config.max_x || y > self.config.max_y) {
            return Err(AerError::InvalidCoordinate(
                x,
                y,
                self.config.max_x,
                self.config.max_y,
            ));
        }
        Ok((x, y, 0.0, polarity))
    }

    /// Overwrites the timestamp (`.2`) of every event according to `config.timestamp_mode`.
    ///
    /// # Errors
    ///
    /// [`AerError::ValidationFailed`] if a `Custom` timestamp list does not have exactly
    /// one entry per event.
    fn generate_timestamps_for_parsed_events(
        &self,
        events: &mut [(u16, u16, f64, bool)],
    ) -> AerResult<()> {
        if events.is_empty() {
            return Ok(());
        }
        match &self.config.timestamp_mode {
            TimestampMode::Sequential => {
                for (i, event) in events.iter_mut().enumerate() {
                    event.2 = self.config.start_timestamp + (i as f64 * self.config.time_increment);
                }
            }
            TimestampMode::Uniform => {
                // NOTE(review): start + (i / n) * (n * increment) equals the Sequential
                // spacing up to rounding; confirm whether a different spread was intended.
                let event_count = events.len();
                let total_time = event_count as f64 * self.config.time_increment;
                for (i, event) in events.iter_mut().enumerate() {
                    event.2 =
                        self.config.start_timestamp + (i as f64 / event_count as f64) * total_time;
                }
            }
            TimestampMode::Exponential => {
                // Inverse-transform sampling of Exp(lambda) intervals, lambda = 1/increment.
                // fastrand::f64() is in [0, 1), so `1.0 - u` is in (0, 1] and its log is
                // <= 0: every interval is non-negative and timestamps never decrease.
                let lambda = 1.0 / self.config.time_increment;
                let mut current_time = self.config.start_timestamp;
                for event in events.iter_mut() {
                    let interval = -(1.0 - fastrand::f64()).ln() / lambda;
                    current_time += interval;
                    event.2 = current_time;
                }
            }
            TimestampMode::Custom(timestamps) => {
                if timestamps.len() != events.len() {
                    let timestamp_count = timestamps.len();
                    let event_count = events.len();
                    return Err(AerError::ValidationFailed(format!(
                        "Custom timestamp count ({timestamp_count}) doesn't match event count ({event_count})"
                    )));
                }
                for (event, &timestamp) in events.iter_mut().zip(timestamps.iter()) {
                    event.2 = timestamp;
                }
            }
        }
        Ok(())
    }

    /// Reads a file and applies the post-processing requested by `load_config`.
    ///
    /// Only `load_config.sort` is consulted here: when set, rows are sorted by `t`.
    pub fn read_with_config<P: AsRef<Path>>(
        &self,
        path: P,
        load_config: &LoadConfig,
    ) -> AerResult<DataFrame> {
        let (events, _metadata) = self.read_file(path)?;
        let mut lazy = events.lazy();
        if load_config.sort {
            lazy = lazy.sort(["t"], Default::default());
        }
        lazy.collect().map_err(|e| {
            AerError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Failed to process DataFrame: {}", e),
            ))
        })
    }

    /// Returns the configuration this reader decodes with.
    pub fn config(&self) -> &AerConfig {
        &self.config
    }
}
impl Default for AerReader {
fn default() -> Self {
Self::new()
}
}
/// Reads an AER file using the default configuration.
///
/// Convenience wrapper around [`AerReader::read_file`]; see it for error conditions.
pub fn read_aer_file<P: AsRef<Path>>(path: P) -> AerResult<(DataFrame, AerMetadata)> {
    AerReader::new().read_file(path)
}
/// Reads an AER file with an explicit [`AerConfig`].
///
/// Convenience wrapper around [`AerReader::read_file`]; see it for error conditions.
pub fn read_aer_file_with_config<P: AsRef<Path>>(
    path: P,
    config: AerConfig,
) -> AerResult<(DataFrame, AerMetadata)> {
    AerReader::with_config(config).read_file(path)
}
/// Heuristically checks whether `path` looks like a raw AER event file.
///
/// Returns `false` in any of these cases:
/// * the file cannot be opened or stat'ed;
/// * the file is empty;
/// * the file has an odd length;
/// * the first 32 bytes cannot be read;
/// * any of the first (up to) eight 4-byte words fails to decode with [`AerConfig::default`].
///
/// NOTE(review): with the default bounds (511) and the decoder's 9-bit coordinate masks,
/// a 4-byte word never fails to decode. In practice this therefore accepts any non-empty,
/// even-length file.
pub fn is_aer_format<P: AsRef<Path>>(path: P) -> bool {
    let mut file = match File::open(path.as_ref()) {
        Ok(f) => f,
        Err(_) => return false,
    };
    let file_size = match file.metadata() {
        Ok(m) => m.len(),
        Err(_) => return false,
    };
    // `AerReader::read_file` rejects empty files, so they are not AER here either.
    // Any even length is admitted so both 2- and 4-byte layouts pass this size check.
    if file_size == 0 || file_size % 2 != 0 {
        return false;
    }
    // Clamp while still u64 so very large files cannot truncate on 32-bit targets.
    let bytes_to_read = file_size.min(32) as usize;
    let mut buffer = vec![0u8; bytes_to_read];
    if file.read_exact(&mut buffer).is_err() {
        return false;
    }
    let reader = AerReader::with_config(AerConfig::default());
    // Trailing bytes that do not form a full 4-byte word are ignored.
    buffer
        .chunks_exact(4)
        .enumerate()
        .all(|(index, word)| reader.parse_single_event(word, index).is_ok())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Asserts two floats agree to well within generated-timestamp rounding error.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    /// Placeholder decoded events whose timestamps are about to be generated.
    fn blank_events(count: usize) -> Vec<(u16, u16, f64, bool)> {
        vec![(0, 0, 0.0, false); count]
    }

    /// Three little-endian 4-byte events:
    /// (x=50, y=100, +), (x=150, y=200, -), (x=250, y=300, +).
    fn three_events_le() -> Vec<u8> {
        let events_data = vec![
            ((100u32 << 10) | (50u32 << 1) | 1).to_le_bytes(),
            ((200u32 << 10) | (150u32 << 1)).to_le_bytes(),
            ((300u32 << 10) | (250u32 << 1) | 1).to_le_bytes(),
        ];
        events_data.into_iter().flatten().collect()
    }

    #[test]
    fn test_aer_config_default() {
        let config = AerConfig::default();
        assert!(!config.big_endian);
        assert!(config.validate_coordinates);
        assert_eq!(config.max_x, 511);
        assert_eq!(config.max_y, 511);
        assert_eq!(config.bytes_per_event, 4);
        assert!(config.generate_timestamps);
        assert_eq!(config.timestamp_mode, TimestampMode::Sequential);
    }

    #[test]
    fn test_aer_config_builder() {
        let config = AerConfig::new()
            .with_endianness(true)
            .with_coordinate_bounds(1023, 1023)
            .with_bytes_per_event(2)
            .with_validation(true, true);
        assert!(config.big_endian);
        assert_eq!(config.max_x, 1023);
        assert_eq!(config.max_y, 1023);
        assert_eq!(config.bytes_per_event, 2);
        assert!(config.skip_invalid_events);
    }

    #[test]
    fn test_parse_18bit_aer_event() {
        let reader = AerReader::with_config(AerConfig::default());
        // 205001 = (200 << 10) | (100 << 1) | 1
        let data = 205001u32.to_le_bytes();
        let (x, y, _t, polarity) = reader.parse_single_event(&data, 0).unwrap();
        assert_eq!(x, 100);
        assert_eq!(y, 200);
        assert!(polarity);
    }

    #[test]
    fn test_parse_negative_polarity() {
        let reader = AerReader::with_config(AerConfig::default());
        // 76900 = (75 << 10) | (50 << 1)
        let data = 76900u32.to_le_bytes();
        let (x, y, _t, polarity) = reader.parse_single_event(&data, 0).unwrap();
        assert_eq!(x, 50);
        assert_eq!(y, 75);
        assert!(!polarity);
    }

    #[test]
    fn test_coordinate_validation() {
        let config = AerConfig::default().with_coordinate_bounds(100, 100);
        let reader = AerReader::with_config(config);
        let raw_event = (200u32 << 10) | (150u32 << 1) | 1;
        let data = raw_event.to_le_bytes();
        let result = reader.parse_single_event(&data, 0);
        assert!(matches!(result, Err(AerError::InvalidCoordinate(_, _, _, _))));
    }

    #[test]
    fn test_skip_invalid_events() {
        let config = AerConfig::default()
            .with_coordinate_bounds(100, 100)
            .with_validation(true, true);
        let reader = AerReader::with_config(config);
        let mut data = Vec::new();
        let valid_event = (75u32 << 10) | (50u32 << 1) | 1;
        data.extend_from_slice(&valid_event.to_le_bytes());
        let invalid_event = (200u32 << 10) | (150u32 << 1);
        data.extend_from_slice(&invalid_event.to_le_bytes());
        let valid_event2 = (30u32 << 10) | (25u32 << 1);
        data.extend_from_slice(&valid_event2.to_le_bytes());
        let (events, metadata) = reader.parse_events(&data, data.len() as u64).unwrap();
        assert_eq!(events.height(), 2);
        assert_eq!(metadata.event_count, 2);
        // Survivors are (50, 75, +) and (25, 30, -).
        assert_eq!(metadata.coordinate_bounds, Some((25, 30, 50, 75)));
        assert_eq!(metadata.polarity_distribution, Some((1, 1)));
    }

    #[test]
    fn test_timestamp_generation_sequential() {
        let config = AerConfig::default().with_timestamp_generation(
            true,
            TimestampMode::Sequential,
            1.0,
            0.001,
        );
        let reader = AerReader::with_config(config);
        let data = three_events_le();
        let (events, metadata) = reader.parse_events(&data, data.len() as u64).unwrap();
        assert_eq!(events.height(), 3);
        let (start, end) = metadata.timestamp_range.unwrap();
        assert_close(start, 1.0);
        assert_close(end, 1.002);

        let mut generated = blank_events(3);
        reader
            .generate_timestamps_for_parsed_events(&mut generated)
            .unwrap();
        assert_close(generated[0].2, 1.0);
        assert_close(generated[1].2, 1.001);
        assert_close(generated[2].2, 1.002);
    }

    #[test]
    fn test_timestamp_generation_uniform() {
        let config = AerConfig::default().with_timestamp_generation(
            true,
            TimestampMode::Uniform,
            0.0,
            0.003,
        );
        let reader = AerReader::with_config(config);
        let data = three_events_le();
        let (events, _) = reader.parse_events(&data, data.len() as u64).unwrap();
        assert_eq!(events.height(), 3);

        let mut generated = blank_events(3);
        reader
            .generate_timestamps_for_parsed_events(&mut generated)
            .unwrap();
        assert_close(generated[0].2, 0.0);
        assert_close(generated[1].2, 0.003);
        assert_close(generated[2].2, 0.006);
    }

    #[test]
    fn test_big_endian_parsing() {
        let config = AerConfig::default().with_endianness(true);
        let reader = AerReader::with_config(config);
        let data = 205001u32.to_be_bytes();
        let (x, y, _t, polarity) = reader.parse_single_event(&data, 0).unwrap();
        assert_eq!(x, 100);
        assert_eq!(y, 200);
        assert!(polarity);
    }

    #[test]
    fn test_16bit_format() {
        let config = AerConfig::default().with_bytes_per_event(2);
        let reader = AerReader::with_config(config);
        let raw_event = (75u16 << 8) | (50u16 << 1) | 1;
        let data = raw_event.to_le_bytes();
        let (_x, _y, _t, polarity) = reader.parse_single_event(&data, 0).unwrap();
        assert!(polarity);
    }

    #[test]
    fn test_read_aer_file() {
        let reader = AerReader::with_config(AerConfig::default());
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(&three_events_le()).unwrap();
        let (events, metadata) = reader.read_file(temp_file.path()).unwrap();
        assert_eq!(events.height(), 3);
        assert_eq!(metadata.event_count, 3);
        assert_eq!(metadata.bytes_per_event, 4);
        assert_eq!(metadata.file_size, 12);
        assert_eq!(metadata.coordinate_bounds, Some((50, 100, 250, 300)));
        assert_eq!(metadata.polarity_distribution, Some((2, 1)));
    }

    #[test]
    fn test_empty_file() {
        let reader = AerReader::with_config(AerConfig::default());
        let temp_file = NamedTempFile::new().unwrap();
        let result = reader.read_file(temp_file.path());
        assert!(matches!(result, Err(AerError::EmptyFile)));
    }

    #[test]
    fn test_invalid_file_size() {
        let reader = AerReader::with_config(AerConfig::default());
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(&[1, 2, 3]).unwrap();
        let result = reader.read_file(temp_file.path());
        assert!(matches!(result, Err(AerError::InvalidFileSize(_, _))));
    }

    #[test]
    fn test_custom_timestamps() {
        let custom_timestamps = vec![0.1, 0.5, 1.2];
        let config = AerConfig::default().with_timestamp_generation(
            true,
            TimestampMode::Custom(custom_timestamps.clone()),
            0.0,
            0.001,
        );
        let reader = AerReader::with_config(config);
        let data = three_events_le();
        let (events, metadata) = reader.parse_events(&data, data.len() as u64).unwrap();
        assert_eq!(events.height(), 3);
        assert_eq!(metadata.timestamp_range, Some((0.1, 1.2)));

        let mut generated = blank_events(3);
        reader
            .generate_timestamps_for_parsed_events(&mut generated)
            .unwrap();
        let assigned: Vec<f64> = generated.iter().map(|event| event.2).collect();
        assert_eq!(assigned, custom_timestamps);
    }

    #[test]
    fn test_custom_timestamp_count_mismatch() {
        let config = AerConfig::default().with_timestamp_generation(
            true,
            TimestampMode::Custom(vec![0.1, 0.2]),
            0.0,
            0.001,
        );
        let reader = AerReader::with_config(config);
        let mut generated = blank_events(3);
        let result = reader.generate_timestamps_for_parsed_events(&mut generated);
        assert!(matches!(result, Err(AerError::ValidationFailed(_))));
    }

    #[test]
    fn test_is_aer_format() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let events_data = vec![
            ((100u32 << 10) | (50u32 << 1) | 1).to_le_bytes(),
            ((200u32 << 10) | (150u32 << 1)).to_le_bytes(),
        ];
        for event_bytes in events_data {
            temp_file.write_all(&event_bytes).unwrap();
        }
        assert!(is_aer_format(temp_file.path()));
    }

    #[test]
    fn test_max_events_limit() {
        let config = AerConfig::default().with_max_events(Some(2));
        let reader = AerReader::with_config(config);
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(&three_events_le()).unwrap();
        let (events, metadata) = reader.read_file(temp_file.path()).unwrap();
        assert_eq!(events.height(), 2);
        assert_eq!(metadata.event_count, 2);
    }
}