use std::path::PathBuf;
use thiserror::Error;
/// Shorthand for a `Result` whose error type is [`TrainError`].
pub type TrainResult<T> = std::result::Result<T, TrainError>;
/// Top-level error type for training.
///
/// Wraps the sub-errors [`ConfigError`], [`DatasetError`] and [`MaeError`], plus
/// `serde_json` errors. The `#[from]` attributes generate `From` impls, so `?`
/// converts those errors into the matching variant automatically.
///
/// NOTE(review): the wrapping variants both interpolate the inner error (`{0}`)
/// in their message *and* expose it as `source()` (implied by `#[from]`). A
/// reporter that walks the source chain will therefore print the inner message
/// twice. Consider `#[error(transparent)]` or dropping `{0}`, but confirm first
/// that no caller matches on these strings.
#[derive(Debug, Error)]
pub enum TrainError {
    /// Configuration problem; see [`ConfigError`].
    #[error("Configuration error: {0}")]
    Config(#[from] ConfigError),
    /// Dataset problem; see [`DatasetError`].
    #[error("Dataset error: {0}")]
    Dataset(#[from] DatasetError),
    /// MAE pretraining problem; see [`MaeError`].
    #[error("MAE pretraining error: {0}")]
    Mae(#[from] MaeError),
    /// JSON (de)serialization failure reported by `serde_json`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// The dataset contains no samples.
    #[error("Dataset is empty")]
    EmptyDataset,
    /// `index` is not a valid position in a dataset of `len` samples.
    #[error("Index {index} is out of bounds for dataset of length {len}")]
    IndexOutOfBounds {
        index: usize,
        len: usize,
    },
    /// A shape (list of dimension sizes) differs from the expected one.
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        expected: Vec<usize>,
        actual: Vec<usize>,
    },
    /// A training step failed; the string describes why.
    #[error("Training step failed: {0}")]
    TrainingStep(String),
    /// A checkpoint operation involving `path` failed.
    ///
    /// The path is rendered with `Debug` formatting (`{path:?}`) in the message.
    #[error("Checkpoint error: {message} (path: {path:?})")]
    Checkpoint {
        message: String,
        path: PathBuf,
    },
    /// The requested functionality is not implemented; the string names it.
    #[error("Not implemented: {0}")]
    NotImplemented(String),
}
/// Convenience constructors for [`TrainError`].
///
/// They only build the value and never report it. They are `#[must_use]` so that
/// a stray `TrainError::training_step("...");` statement, written without
/// `return`/`Err(...)`/`?`, triggers a warning instead of silently discarding
/// the failure.
impl TrainError {
    /// Builds [`TrainError::TrainingStep`] from any string-like message.
    #[must_use]
    pub fn training_step<S: Into<String>>(msg: S) -> Self {
        TrainError::TrainingStep(msg.into())
    }

    /// Builds [`TrainError::Checkpoint`] from a message and the path involved.
    #[must_use]
    pub fn checkpoint<S: Into<String>>(msg: S, path: impl Into<PathBuf>) -> Self {
        TrainError::Checkpoint {
            message: msg.into(),
            path: path.into(),
        }
    }

    /// Builds [`TrainError::NotImplemented`] describing the missing functionality.
    #[must_use]
    pub fn not_implemented<S: Into<String>>(msg: S) -> Self {
        TrainError::NotImplemented(msg.into())
    }

    /// Builds [`TrainError::ShapeMismatch`] from the expected and actual shapes.
    #[must_use]
    pub fn shape_mismatch(expected: Vec<usize>, actual: Vec<usize>) -> Self {
        TrainError::ShapeMismatch { expected, actual }
    }
}
/// Errors raised while reading, parsing, or validating configuration.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// The config field named `field` holds an unacceptable value; `reason` explains why.
    #[error("Invalid value for `{field}`: {reason}")]
    InvalidValue {
        field: &'static str,
        reason: String,
    },
    /// The config file at `path` could not be read.
    ///
    /// NOTE(review): `source` is printed in the message *and* exposed via
    /// `#[source]`, so chain-walking reporters will show it twice.
    #[error("Cannot read config file `{path}`: {source}")]
    FileRead {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },
    /// The config file at `path` could not be parsed as JSON.
    ///
    /// NOTE(review): same source-duplication caveat as `FileRead`.
    #[error("Cannot parse config file `{path}`: {source}")]
    ParseError {
        path: PathBuf,
        #[source]
        source: serde_json::Error,
    },
    /// A path named in the configuration does not exist.
    #[error("Path `{path}` in config does not exist")]
    PathNotFound {
        path: PathBuf,
    },
}
/// Convenience constructors for [`ConfigError`].
impl ConfigError {
    /// Builds [`ConfigError::InvalidValue`] for the config field `field`.
    ///
    /// `#[must_use]` because dropping the returned error silently discards the
    /// validation failure.
    #[must_use]
    pub fn invalid_value<S: Into<String>>(field: &'static str, reason: S) -> Self {
        ConfigError::InvalidValue {
            field,
            reason: reason.into(),
        }
    }
}
/// Errors raised while locating, reading, or indexing dataset files.
#[derive(Debug, Error)]
pub enum DatasetError {
    /// No data was found at `path`; `message` adds detail.
    #[error("Data not found at `{path}`: {message}")]
    DataNotFound {
        path: PathBuf,
        message: String,
    },
    /// The data at `path` is not in the expected format.
    #[error("Invalid data format in `{path}`: {message}")]
    InvalidFormat {
        path: PathBuf,
        message: String,
    },
    /// An I/O failure while reading `path`.
    ///
    /// Unlike [`DatasetError::Io`], this variant records which file was involved.
    /// NOTE(review): `source` is both displayed and exposed via `#[source]`, so
    /// chain-walking reporters will show it twice.
    #[error("I/O error reading `{path}`: {source}")]
    IoError {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },
    /// The file at `path` has `found` subcarriers where `expected` were required.
    #[error("Subcarrier count mismatch in `{path}`: file has {found}, expected {expected}")]
    SubcarrierMismatch {
        path: PathBuf,
        found: usize,
        expected: usize,
    },
    /// `idx` is not a valid sample index for a dataset of `len` samples.
    ///
    /// NOTE(review): the field is `idx` here but `index` in
    /// `TrainError::IndexOutOfBounds`; renaming would break callers.
    #[error("Index {idx} out of bounds (dataset has {len} samples)")]
    IndexOutOfBounds {
        idx: usize,
        len: usize,
    },
    /// Reading the NumPy file at `path` failed; `message` describes why.
    #[error("NumPy read error in `{path}`: {message}")]
    NpyReadError {
        path: PathBuf,
        message: String,
    },
    /// A metadata problem for the subject with id `subject_id`.
    #[error("Metadata error for subject {subject_id}: {message}")]
    MetadataError {
        subject_id: u32,
        message: String,
    },
    /// A file-format error that carries no path context.
    #[error("File format error: {0}")]
    Format(String),
    /// A required directory does not exist.
    ///
    /// NOTE(review): `path` is a `String` here, while every other path-carrying
    /// variant uses `PathBuf`.
    #[error("Directory not found: {path}")]
    DirectoryNotFound {
        path: String,
    },
    /// None of the `requested` subject ids were found under `data_dir`.
    #[error("No subjects found in `{data_dir}` for IDs: {requested:?}")]
    NoSubjectsFound {
        data_dir: PathBuf,
        requested: Vec<u32>,
    },
    /// A bare I/O error with no path attached.
    ///
    /// This is the target of `?` on `std::io::Error` (via `#[from]`). Use
    /// `DatasetError::io_error` instead when the file path is known.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The split is invalid; the string gives details.
    #[error("Invalid split: {0}")]
    InvalidSplit(String),
}
/// Convenience constructors for the path-carrying [`DatasetError`] variants.
///
/// All constructors are `#[must_use]`: they only build the error. Dropping the
/// result (for example, forgetting `return Err(...)`) would silently lose the
/// failure.
impl DatasetError {
    /// Builds [`DatasetError::DataNotFound`] for `path` with detail `msg`.
    #[must_use]
    pub fn not_found<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
        DatasetError::DataNotFound {
            path: path.into(),
            message: msg.into(),
        }
    }

    /// Builds [`DatasetError::InvalidFormat`] for `path` with detail `msg`.
    #[must_use]
    pub fn invalid_format<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
        DatasetError::InvalidFormat {
            path: path.into(),
            message: msg.into(),
        }
    }

    /// Builds [`DatasetError::IoError`], attaching `path` to an I/O failure.
    ///
    /// Prefer this over the path-less [`DatasetError::Io`] when the file is known.
    #[must_use]
    pub fn io_error(path: impl Into<PathBuf>, source: std::io::Error) -> Self {
        DatasetError::IoError {
            path: path.into(),
            source,
        }
    }

    /// Builds [`DatasetError::SubcarrierMismatch`] for `path`.
    ///
    /// `found` is the count in the file; `expected` is the count required.
    #[must_use]
    pub fn subcarrier_mismatch(path: impl Into<PathBuf>, found: usize, expected: usize) -> Self {
        DatasetError::SubcarrierMismatch {
            path: path.into(),
            found,
            expected,
        }
    }

    /// Builds [`DatasetError::NpyReadError`] for `path` with detail `msg`.
    #[must_use]
    pub fn npy_read<S: Into<String>>(path: impl Into<PathBuf>, msg: S) -> Self {
        DatasetError::NpyReadError {
            path: path.into(),
            message: msg.into(),
        }
    }
}
/// Errors from subcarrier interpolation, i.e. changing a subcarrier count from
/// `src_n` to `dst_n`, as described in the messages below.
#[derive(Debug, Error)]
pub enum SubcarrierError {
    /// A subcarrier count below 1 was supplied; `count` is the rejected value.
    #[error("Subcarrier count must be >= 1, got {count}")]
    ZeroCount {
        count: usize,
    },
    /// The input's last dimension (`actual_sc`) differs from the declared
    /// source count (`expected_sc`, shown as `src_n`). `shape` is the full
    /// input shape.
    #[error(
        "Subcarrier shape mismatch: last dim is {actual_sc} but src_n={expected_sc} \
         (full shape: {shape:?})"
    )]
    InputShapeMismatch {
        expected_sc: usize,
        actual_sc: usize,
        shape: Vec<usize>,
    },
    /// The interpolation method named `method` is not implemented.
    #[error("Interpolation method `{method}` is not implemented")]
    MethodNotImplemented {
        method: String,
    },
    /// The source and destination counts are both `count`.
    ///
    /// Per the message, interpolation should only be requested when they differ
    /// (the `interpolate` function itself is not shown here).
    #[error("src_n == dst_n == {count}; call interpolate only when counts differ")]
    NopInterpolation {
        count: usize,
    },
    /// A numerical failure; the string describes it.
    #[error("Numerical error: {0}")]
    NumericalError(String),
}
/// Convenience constructors for [`SubcarrierError`].
impl SubcarrierError {
    /// Builds [`SubcarrierError::NumericalError`] from any string-like message.
    ///
    /// `#[must_use]` because dropping the returned error silently discards the
    /// failure.
    #[must_use]
    pub fn numerical<S: Into<String>>(msg: S) -> Self {
        SubcarrierError::NumericalError(msg.into())
    }
}
/// Errors reported by MAE pretraining: window-shape, patch-size, mask-ratio,
/// and finite-value checks, per the variants below.
#[derive(Debug, Error)]
pub enum MaeError {
    /// The window length `actual` differs from `time × subc` (`expected`).
    #[error(
        "Window length {actual} does not match time × subcarriers = \
         {time} × {subc} = {expected}"
    )]
    WindowShapeMismatch {
        time: usize,
        subc: usize,
        expected: usize,
        actual: usize,
    },
    /// Along `axis`, the patch extent `patch` exceeds the window extent `window`.
    #[error("Patch {axis} extent {patch} exceeds window {axis} extent {window}")]
    PatchExceedsWindow {
        axis: &'static str,
        patch: usize,
        window: usize,
    },
    /// Along `axis`, the window extent `window` is not a multiple of `patch`.
    ///
    /// `remainder` is the leftover. `crop` is the extent the message suggests
    /// cropping the window to (presumably `window - remainder`; confirm at the
    /// construction site).
    #[error(
        "Window {axis} extent {window} is not divisible by patch {axis} extent \
         {patch} (remainder {remainder}); crop the window to {crop} or change \
         the patch size"
    )]
    NotDivisible {
        axis: &'static str,
        window: usize,
        patch: usize,
        remainder: usize,
        crop: usize,
    },
    /// The mask ratio `ratio` is rejected. Per the message, it must be finite
    /// and strictly between 0 and 1.
    #[error("Invalid mask ratio {ratio}: must be finite and strictly inside (0, 1)")]
    InvalidMaskRatio {
        ratio: f64,
    },
    /// A non-finite CSI value `value` was found at time index `row` (`t`) and
    /// subcarrier index `col` (`sc`).
    #[error("Non-finite CSI value {value} at (t={row}, sc={col})")]
    NonFiniteValue {
        row: usize,
        col: usize,
        value: f32,
    },
}