use std::path::Path;
use burn::data::dataset::Dataset;
use crate::constants::{NUM_CHANNELS, TIME_STEPS};
use crate::data::download::{generate_synthetic_dataset, SyntheticDataConfig};
use crate::error::{Result, SensorLMError};
/// One sensor recording paired with its encoded caption.
///
/// `PartialEq` is derived (in addition to `Debug` and `Clone`) so that items
/// can be compared directly, e.g. with `assert_eq!` in tests. `Eq` is not
/// derived because `sensor` holds `f32` values.
#[derive(Debug, Clone, PartialEq)]
pub struct SensorTextItem {
    /// Flattened sensor readings. The loaders in this file produce
    /// `TIME_STEPS * NUM_CHANNELS` values per item.
    pub sensor: Vec<f32>,
    /// Caption token ids. The loaders in this file truncate or pad this
    /// (with id `1`) to a fixed sequence length.
    pub token_ids: Vec<i32>,
    /// Mask aligned with `token_ids`: `1` marks a real token, `0` marks padding.
    pub attention_mask: Vec<i32>,
    /// The caption text that `token_ids` was derived from.
    pub caption_text: String,
}
/// In-memory dataset of synthetic sensor/caption pairs.
///
/// All items are generated up front by `SyntheticSensorDataset::new` and
/// served by index through the `Dataset` implementation.
pub struct SyntheticSensorDataset {
// Fully materialised items; `get` hands out clones of these.
items: Vec<SensorTextItem>,
}
impl SyntheticSensorDataset {
    /// Generates `num_samples` synthetic samples and encodes each one as a
    /// [`SensorTextItem`].
    ///
    /// The generator is called with the given `seed`, with `add_circadian` and
    /// `add_missingness` enabled and a `missingness_rate` of `0.05`.
    ///
    /// Captions are encoded per character: each `char` becomes its code point
    /// modulo 32 000. At most `max_seq_len` characters are kept; shorter
    /// captions are padded with id `1`. The attention mask is `1` for encoded
    /// characters and `0` for padding, so both vectors always have length
    /// `max_seq_len`.
    pub fn new(num_samples: usize, seed: u64, max_seq_len: usize) -> Self {
        let config = SyntheticDataConfig {
            num_samples,
            seed,
            add_circadian: true,
            add_missingness: true,
            missingness_rate: 0.05,
        };

        let items = generate_synthetic_dataset(&config)
            .into_iter()
            .map(|sample| {
                let mut token_ids: Vec<i32> = sample
                    .caption
                    .chars()
                    .take(max_seq_len)
                    .map(|ch| ch as i32 % 32_000)
                    .collect();
                // Remember how many ids are real before padding is appended.
                let real_len = token_ids.len();
                token_ids.resize(max_seq_len, 1);

                // `real_len <= max_seq_len` because of the `take` above.
                let attention_mask: Vec<i32> = (0..max_seq_len)
                    .map(|pos| i32::from(pos < real_len))
                    .collect();

                SensorTextItem {
                    sensor: sample.sensor.iter().copied().collect::<Vec<f32>>(),
                    token_ids,
                    attention_mask,
                    caption_text: sample.caption,
                }
            })
            .collect();

        Self { items }
    }
}
/// Serves the pre-generated synthetic items through the `Dataset` trait.
impl Dataset<SensorTextItem> for SyntheticSensorDataset {
    /// Returns a clone of the item at `index`, or `None` when `index` is out
    /// of range.
    fn get(&self, index: usize) -> Option<SensorTextItem> {
        let item = self.items.get(index)?;
        Some(item.clone())
    }

    /// Number of items held by the dataset.
    fn len(&self) -> usize {
        let Self { items } = self;
        items.len()
    }
}
/// In-memory dataset of sensor/caption pairs loaded from a CSV file.
///
/// Built by `CsvSensorDataset::from_csv`, which reads and encodes every
/// record eagerly.
pub struct CsvSensorDataset {
// Fully materialised items; `get` hands out clones of these.
items: Vec<SensorTextItem>,
}
impl CsvSensorDataset {
    /// Loads every record of the CSV file at `path` into memory.
    ///
    /// Each record must have at least `TIME_STEPS * NUM_CHANNELS + 1` columns:
    /// the leading `TIME_STEPS * NUM_CHANNELS` columns are parsed as `f32`
    /// sensor values and the next column is the caption. Any further columns
    /// are ignored.
    ///
    /// The trimmed caption is passed to `tokenize`; the resulting ids are cut
    /// or padded (with id `1`) to exactly `max_seq_len`, and the attention
    /// mask is `1` for kept ids and `0` for padding.
    ///
    /// # Errors
    ///
    /// Returns `SensorLMError::DatasetError` when the reader cannot be created
    /// for `path`, when a record fails to read, or when a record has too few
    /// columns.
    ///
    /// NOTE(review): a sensor cell that does not parse as `f32` silently
    /// becomes `0.0` — confirm this is the intended handling of missing or
    /// malformed values.
    ///
    /// NOTE(review): the reader is built with the `csv` crate's defaults,
    /// which presumably treat the first row as a header — confirm the input
    /// files carry one.
    pub fn from_csv<F>(path: &Path, max_seq_len: usize, tokenize: F) -> Result<Self>
    where
        F: Fn(&str) -> Vec<i32>,
    {
        let sensor_width = TIME_STEPS * NUM_CHANNELS;
        let min_columns = sensor_width + 1;

        let mut reader = csv::Reader::from_path(path)
            .map_err(|err| SensorLMError::DatasetError(err.to_string()))?;
        let mut items = Vec::new();

        for row in reader.records() {
            let row = row.map_err(|err| SensorLMError::DatasetError(err.to_string()))?;

            let column_count = row.len();
            if column_count < min_columns {
                return Err(SensorLMError::DatasetError(format!(
                    "Expected at least {} columns, got {}",
                    min_columns, column_count
                )));
            }

            // The length check above guarantees at least `sensor_width` fields.
            let sensor: Vec<f32> = row
                .iter()
                .take(sensor_width)
                .map(|cell| cell.trim().parse::<f32>().unwrap_or(0.0))
                .collect();

            // In bounds for the same reason: `row.len() >= sensor_width + 1`.
            let caption = row[sensor_width].trim().to_string();

            let mut token_ids = tokenize(&caption);
            let real_len = token_ids.len().min(max_seq_len);
            // `resize` both drops surplus ids and appends padding as needed.
            token_ids.resize(max_seq_len, 1);
            let attention_mask: Vec<i32> = (0..max_seq_len)
                .map(|pos| i32::from(pos < real_len))
                .collect();

            items.push(SensorTextItem {
                sensor,
                token_ids,
                attention_mask,
                caption_text: caption,
            });
        }

        Ok(Self { items })
    }
}
/// Serves the rows loaded from CSV through the `Dataset` trait.
impl Dataset<SensorTextItem> for CsvSensorDataset {
    /// Returns a clone of the item at `index`, or `None` when `index` is out
    /// of range.
    fn get(&self, index: usize) -> Option<SensorTextItem> {
        let item = self.items.get(index)?;
        Some(item.clone())
    }

    /// Number of items held by the dataset.
    fn len(&self) -> usize {
        let Self { items } = self;
        items.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The synthetic dataset yields the requested number of items, and the
    /// first item has the expected sensor, token and mask lengths.
    #[test]
    fn test_synthetic_dataset() {
        let num_samples = 16;
        let max_seq_len = 256;

        let dataset = SyntheticSensorDataset::new(num_samples, 99, max_seq_len);
        assert_eq!(dataset.len(), num_samples);

        let first = dataset.get(0).expect("first item");
        assert_eq!(first.sensor.len(), TIME_STEPS * NUM_CHANNELS);
        assert_eq!(first.token_ids.len(), max_seq_len);
        assert_eq!(first.attention_mask.len(), max_seq_len);
    }
}