use std::path::Path;
use anyhow::{Context, Result};
use serde::Deserialize;
use crate::{config::MAX_SERIES_LEN, data::batch::{normalize, Sample}};
/// Sleep-stage class names listed to the model in the post-prompt.
///
/// The names are joined with `", "` in this exact order when the prompt is built,
/// so reordering or renaming an entry changes the generated prompt text.
const SLEEP_STAGES: &[&str] = &[
    "Wake",
    "Non-REM stage 1",
    "Non-REM stage 2",
    "Non-REM stage 3",
    "REM sleep",
    "Movement",
];
/// One record of a split file: each non-blank JSONL line is deserialized into this.
///
/// Field names double as the JSON keys serde expects, so they must not be renamed
/// without a matching `#[serde(rename = ...)]`.
#[derive(Debug, Deserialize)]
struct SleepRow {
/// Raw signal values. Only the first `MAX_SERIES_LEN` points are used, and they are
/// normalized before being placed in the sample.
time_series: Vec<f32>,
/// Class name for the segment; whitespace is trimmed before use.
/// NOTE(review): presumably one of `SLEEP_STAGES`, but nothing here checks that.
label: String,
/// Reasoning text used (trimmed) as the sample's target answer.
rationale: String,
}
/// The three dataset partitions produced by `load_sleep_splits`.
pub struct SleepSplits {
/// Samples loaded from `train.jsonl`.
pub train: Vec<Sample>,
/// Samples loaded from `val.jsonl`.
pub val: Vec<Sample>,
/// Samples loaded from `test.jsonl`.
pub test: Vec<Sample>,
}
/// Loads the train, validation and test splits of the SleepEDF chain-of-thought data.
///
/// Each split is read via `load_split` from `train.jsonl`, `val.jsonl` and
/// `test.jsonl` (in that order) under `data_dir`.
///
/// # Errors
/// Returns the first error hit while loading a split. Later splits are not
/// attempted once an earlier one fails.
pub fn load_sleep_splits(data_dir: &Path) -> Result<SleepSplits> {
    let train = load_split(data_dir, "train.jsonl")?;
    let val = load_split(data_dir, "val.jsonl")?;
    let test = load_split(data_dir, "test.jsonl")?;
    Ok(SleepSplits { train, val, test })
}
/// Reads `<base>/sleep_cot/<filename>` as JSONL and converts every non-blank line
/// into a [`Sample`].
///
/// # Errors
/// Fails if the file cannot be read, if a non-blank line does not deserialize into
/// a [`SleepRow`], or if converting a row into a sample fails. The first failing
/// line stops the load. Parse and conversion errors name the file and the 1-based
/// physical line number, counting blank lines, so the message points straight at
/// the offending line.
fn load_split(base: &Path, filename: &str) -> Result<Vec<Sample>> {
    let file = base.join("sleep_cot").join(filename);
    let text = std::fs::read_to_string(&file)
        .with_context(|| format!("Cannot read SleepEDF file {file:?}"))?;
    text.lines()
        // Enumerate BEFORE filtering so the index still matches the physical line
        // in the file even when blank lines are skipped.
        .enumerate()
        .filter(|(_, line)| !line.trim().is_empty())
        .map(|(idx, line)| {
            // Editors and humans count lines from 1.
            let line_no = idx + 1;
            let row: SleepRow = serde_json::from_str(line)
                .with_context(|| format!("{filename} line {line_no}: parse error"))?;
            row_to_sample(&row)
                .with_context(|| format!("{filename} line {line_no}: invalid sample"))
        })
        .collect()
}
/// Builds a prompt/answer [`Sample`] from one parsed [`SleepRow`].
///
/// The raw series is truncated to at most `MAX_SERIES_LEN` points and passed through
/// `normalize`. The normalized series goes into the sample, and the mean/std that
/// `normalize` returns are quoted (4 decimals) in the series caption. The
/// post-prompt lists `SLEEP_STAGES`. The trimmed rationale becomes the answer and
/// the trimmed label is attached as `label`.
///
/// # Errors
/// Currently always returns `Ok`; the `Result` is kept for the caller's `?` chain.
fn row_to_sample(row: &SleepRow) -> Result<Sample> {
    // Spelled out with explicit `\n` so the text cannot silently change when the
    // surrounding code is re-indented.
    const PRE_PROMPT: &str = concat!(
        "You are given a 30-second EEG time series segment. Your task is to classify the sleep stage ",
        "based on analysis of the data.\n",
        "Instructions:\n",
        "- Analyze the data objectively without presuming a particular label.\n",
        "- Reason carefully and methodically about what the signal patterns suggest regarding sleep stage.\n",
        "- Write your reasoning as a single, coherent paragraph. Do not use bullet points, lists, or section headers.\n",
        "- Only reveal the correct class at the very end.\n",
        "- Never state that you are uncertain or unable to classify the data. You must always provide a ",
        "rationale and a final answer."
    );

    let kept = std::cmp::min(row.time_series.len(), MAX_SERIES_LEN);
    let (normed, mean, std_dev) = normalize(&row.time_series[..kept]);

    let post_prompt = format!(
        "Possible sleep stages are:\n{stages}\n\n\
         Please now write your rationale. Make sure that your last word is the answer. \
         You MUST end your response with \"Answer: \"",
        stages = SLEEP_STAGES.join(", "),
    );

    let caption = format!(
        "The following is the EEG time series, it has mean {mean:.4} and std {std_dev:.4}:"
    );

    Ok(Sample {
        pre_prompt: PRE_PROMPT.to_owned(),
        time_series_text: vec![caption],
        time_series: vec![normed],
        post_prompt,
        answer: row.rationale.trim().to_owned(),
        label: Some(row.label.trim().to_owned()),
    })
}