use olympian::checks::series::{
SPIKE_LEADING_PER_RUN, SPIKE_TRAILING_PER_RUN, STEP_LEADING_PER_RUN,
};
use serde::Deserialize;
use std::{collections::HashMap, path::Path};
use thiserror::Error;
/// An ordered list of check steps, deserialized from a pipeline definition.
///
/// The two `num_*_required` fields are never read from the serialized form
/// (`#[serde(skip)]`); `load_pipelines` fills them in from
/// `derive_num_leading_trailing` after deserialization.
#[derive(Debug, Deserialize, PartialEq, Clone)]
pub struct Pipeline {
/// Steps of the pipeline, read from the serialized key `step`.
#[serde(rename = "step")]
pub steps: Vec<PipelineStep>,
/// Largest leading count reported by any step's check
/// (see `derive_num_leading_trailing`). Skipped by serde.
#[serde(skip)]
pub num_leading_required: u8,
/// Largest trailing count reported by any step's check
/// (see `derive_num_leading_trailing`). Skipped by serde.
#[serde(skip)]
pub num_trailing_required: u8,
}
/// One step of a [`Pipeline`]: a name together with the configuration of a check.
#[derive(Debug, Deserialize, PartialEq, Clone)]
pub struct PipelineStep {
/// Name of the step, read from the `name` key.
pub name: String,
/// Which check this step refers to, and its parameters.
///
/// Flattened: the key identifying the [`CheckConf`] variant sits next to
/// `name` in the serialized step instead of under a separate `check` key.
#[serde(flatten)]
pub check: CheckConf,
}
/// Configuration of a single check, one variant per supported check kind.
///
/// Variant names are converted to snake_case for (de)serialization
/// (e.g. `RangeCheck` is keyed as `range_check`), and each variant wraps the
/// struct holding that check's parameters.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "snake_case")]
#[allow(missing_docs)]
pub enum CheckConf {
SpecialValueCheck(SpecialValueCheckConf),
RangeCheck(RangeCheckConf),
RangeCheckDynamic(RangeCheckDynamicConf),
StepCheck(StepCheckConf),
SpikeCheck(SpikeCheckConf),
FlatlineCheck(FlatlineCheckConf),
BuddyCheck(BuddyCheckConf),
Sct(SctConf),
ModelConsistencyCheck(ModelConsistencyCheckConf),
// Never produced by deserialization (`#[serde(skip)]`); it can only be
// constructed in code.
// NOTE(review): presumably a placeholder for tests — confirm against users of this variant.
#[serde(skip)]
Dummy,
}
impl CheckConf {
    /// Returns the `(leading, trailing)` pair this check contributes to a
    /// pipeline's `num_leading_required` / `num_trailing_required`.
    ///
    /// Only the step, spike and flatline checks report non-zero counts; every
    /// other variant yields `(0, 0)`. The match is kept exhaustive (no `_`
    /// arm) so that adding a variant forces a decision here.
    fn get_num_leading_trailing(&self) -> (u8, u8) {
        match self {
            // Checks with fixed per-run requirements taken from the check library.
            Self::StepCheck(_) => (STEP_LEADING_PER_RUN, 0),
            Self::SpikeCheck(_) => (SPIKE_LEADING_PER_RUN, SPIKE_TRAILING_PER_RUN),
            // The flatline requirement is configurable: its `max` is the leading count.
            Self::FlatlineCheck(flatline) => (flatline.max, 0),
            // Everything else reports no leading or trailing requirement.
            Self::SpecialValueCheck(_)
            | Self::RangeCheck(_)
            | Self::RangeCheckDynamic(_)
            | Self::BuddyCheck(_)
            | Self::Sct(_)
            | Self::ModelConsistencyCheck(_)
            | Self::Dummy => (0, 0),
        }
    }
}
/// Parameters carried by [`CheckConf::SpecialValueCheck`].
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct SpecialValueCheckConf {
// NOTE(review): presumably the sentinel values the check looks for — confirm
// against the check implementation.
pub special_values: Vec<f64>,
}
/// Parameters carried by [`CheckConf::RangeCheck`].
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct RangeCheckConf {
// NOTE(review): presumably the upper and lower limits of the accepted range;
// inclusivity and units are not visible in this file — confirm.
pub max: f64,
pub min: f64,
}
/// Parameters carried by [`CheckConf::RangeCheckDynamic`].
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct RangeCheckDynamicConf {
// NOTE(review): presumably identifies where the dynamic limits are fetched
// from; the accepted values are not visible in this file — confirm.
pub source: String,
}
/// Parameters carried by [`CheckConf::StepCheck`].
///
/// The leading count for this check does not come from this struct; it is
/// the constant `STEP_LEADING_PER_RUN` (see `CheckConf::get_num_leading_trailing`).
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct StepCheckConf {
// NOTE(review): presumably the threshold passed to the step check — confirm.
pub max: f64,
}
/// Parameters carried by [`CheckConf::SpikeCheck`].
///
/// The leading/trailing counts for this check do not come from this struct;
/// they are the constants `SPIKE_LEADING_PER_RUN` and `SPIKE_TRAILING_PER_RUN`
/// (see `CheckConf::get_num_leading_trailing`).
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct SpikeCheckConf {
// NOTE(review): presumably the threshold passed to the spike check — confirm.
pub max: f64,
}
/// Parameters carried by [`CheckConf::FlatlineCheck`].
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct FlatlineCheckConf {
// Besides configuring the check itself, this value is used directly as the
// leading count for the check in `CheckConf::get_num_leading_trailing`.
pub max: u8,
}
/// Parameters carried by [`CheckConf::BuddyCheck`].
///
/// All fields are required when deserializing (none is optional or defaulted).
// NOTE(review): the meaning and units of the individual fields are defined by
// the buddy check implementation, which is not visible in this file — document
// them here once confirmed.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct BuddyCheckConf {
pub radii: f64,
pub min_buddies: u32,
pub threshold: f64,
pub max_elev_diff: f64,
pub elev_gradient: f64,
pub min_std: f64,
pub num_iterations: u32,
}
/// Parameters carried by [`CheckConf::Sct`].
///
/// Every field except `obs_to_check` is required when deserializing.
// NOTE(review): the meaning and units of the individual fields are defined by
// the SCT implementation, which is not visible in this file — document them
// here once confirmed.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct SctConf {
pub num_min: usize,
pub num_max: usize,
pub inner_radius: f64,
pub outer_radius: f64,
pub num_iterations: u32,
pub num_min_prof: usize,
pub min_elev_diff: f64,
pub min_horizontal_scale: f64,
pub vertical_scale: f64,
pub pos: f64,
pub neg: f64,
pub eps2: f64,
// Optional: deserializes to `None` when the key is absent.
pub obs_to_check: Option<Vec<bool>>,
}
/// Parameters carried by [`CheckConf::ModelConsistencyCheck`].
// NOTE(review): the accepted values of `model_source` / `model_args` and the
// units of `threshold` are not visible in this file — confirm against the
// check implementation.
#[derive(Debug, Deserialize, PartialEq, Clone)]
#[allow(missing_docs)]
pub struct ModelConsistencyCheckConf {
pub model_source: String,
pub model_args: String,
pub threshold: f64,
}
/// Errors that can occur while loading pipeline definitions (see `load_pipelines`).
#[derive(Error, Debug)]
pub enum Error {
/// An I/O operation failed: listing the directory, querying an entry's
/// file type, or reading a file's contents.
#[error("io error: {0}")]
Io(#[from] std::io::Error),
/// A file's contents could not be deserialized as a TOML pipeline.
#[error("failed to deserialize toml: {0}")]
TomlDeserialize(#[from] toml::de::Error),
/// A directory entry was not a regular file (for example a subdirectory).
#[error("the directory contained something that wasn't a file")]
DirectoryStructure,
/// A file name was not valid Unicode, so no pipeline name could be derived from it.
#[error("pipeline filename could not be parsed as a unicode string")]
InvalidFilename,
}
/// Computes the `(leading, trailing)` counts required by a whole pipeline.
///
/// Each component is the maximum of that component over all steps' checks,
/// taken independently. A pipeline without steps yields `(0, 0)`.
pub fn derive_num_leading_trailing(pipeline: &Pipeline) -> (u8, u8) {
    let mut max_leading: u8 = 0;
    let mut max_trailing: u8 = 0;

    for step in &pipeline.steps {
        let (leading, trailing) = step.check.get_num_leading_trailing();
        max_leading = max_leading.max(leading);
        max_trailing = max_trailing.max(trailing);
    }

    (max_leading, max_trailing)
}
pub fn load_pipelines(path: impl AsRef<Path>) -> Result<HashMap<String, Pipeline>, Error> {
std::fs::read_dir(path)?
.map(|entry| {
let entry = entry?;
if !entry.file_type()?.is_file() {
return Err(Error::DirectoryStructure);
}
let name = entry
.file_name()
.to_str()
.ok_or(Error::InvalidFilename)?
.trim_end_matches(".toml")
.to_string();
let mut pipeline = toml::from_str(&std::fs::read_to_string(entry.path())?)?;
(
pipeline.num_leading_required,
pipeline.num_trailing_required,
) = derive_num_leading_trailing(&pipeline);
Ok(Some((name, pipeline)))
})
.filter_map(Result::transpose)
.collect()
}
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: the sample pipeline directory loads without error and
    /// contains a pipeline keyed `TA_PT1H`.
    #[test]
    fn test_deserialize_fresh() {
        let pipelines = load_pipelines("sample_pipelines/fresh").unwrap();
        assert!(pipelines.contains_key("TA_PT1H"));
    }
}