use std::fs;
use serde::{Deserialize, Serialize};
use tempfile::TempDir;
use pondrs::app::config::{apply_overrides, deserialize_config, load_yaml};
use pondrs::app::App;
use pondrs::datasets::{MemoryDataset, Param};
use pondrs::error::PondError;
use pondrs::graph::build_pipeline_graph;
use pondrs::hooks::LoggingHook;
use pondrs::runners::{SequentialRunner, ParallelRunner};
use pondrs::{Dataset, Node, Pipeline, PipelineInfo, Steps};
/// Test catalog holding three in-memory `i32` datasets.
///
/// Deserialized in these tests from YAML of the form `a: {}\nb: {}\nc: {}\n`.
#[derive(Deserialize, Serialize)]
struct TestCatalog {
    /// First intermediate dataset.
    a: MemoryDataset<i32>,
    /// Second intermediate dataset.
    b: MemoryDataset<i32>,
    /// Final output dataset checked by the App tests.
    c: MemoryDataset<i32>,
}
/// Flat parameter set with two integer knobs, loaded from e.g. `scale: 5\noffset: 10\n`.
#[derive(Deserialize, Serialize)]
struct TestParams {
    /// Multiplicative factor; the wrapped value is read via `.0`.
    scale: Param<i32>,
    /// Additive offset; the wrapped value is read via `.0`.
    offset: Param<i32>,
}
/// Parameter set with a nested `model` section, used to exercise dot-notation overrides
/// such as `model.epochs=50`.
#[derive(Deserialize, Serialize)]
struct NestedParams {
    /// Nested section deserialized from the `model:` mapping.
    model: ModelParams,
    /// Top-level floating-point parameter.
    threshold: Param<f64>,
}
/// Contents of the `model:` section of [`NestedParams`].
#[derive(Deserialize, Serialize)]
struct ModelParams {
    /// Floating-point parameter addressed as `model.learning_rate`.
    learning_rate: Param<f64>,
    /// Unsigned parameter addressed as `model.epochs`.
    epochs: Param<usize>,
}
/// Writes `content` into a file named `name` inside the temporary directory `dir`
/// and returns that file's path as an owned `String`.
///
/// Panics if the file cannot be written or if the path is not valid UTF-8.
fn write_yaml(dir: &TempDir, name: &str, content: &str) -> String {
    let target = dir.path().join(name);
    fs::write(&target, content).unwrap();
    target.to_str().map(String::from).unwrap()
}
/// A flat params YAML file loads and deserializes into `TestParams`.
#[test]
fn test_load_yaml_and_deserialize_params() {
    let tmp = TempDir::new().unwrap();
    let yaml_path = write_yaml(&tmp, "params.yml", "scale: 5\noffset: 10\n");
    let params: TestParams = {
        let raw = load_yaml(&yaml_path).unwrap();
        deserialize_config(raw).unwrap()
    };
    assert_eq!((params.scale.0, params.offset.0), (5, 10));
}
/// A catalog YAML with empty entries for `a`, `b` and `c` deserializes into `TestCatalog`.
#[test]
fn test_load_yaml_and_deserialize_catalog() {
    let tmp = TempDir::new().unwrap();
    let catalog_file = write_yaml(&tmp, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let raw = load_yaml(&catalog_file).unwrap();
    // Successful deserialization is the whole assertion here.
    let _: TestCatalog = deserialize_config(raw).unwrap();
}
/// Loading a path that does not exist reports an error rather than panicking.
#[test]
fn test_load_yaml_missing_file() {
    const MISSING: &str = "/nonexistent/path/missing.yml";
    assert!(load_yaml(MISSING).is_err());
}
/// A `key=value` override replaces a top-level key and leaves other keys untouched.
#[test]
fn test_apply_overrides_flat_key() {
    let tmp = TempDir::new().unwrap();
    let yaml_path = write_yaml(&tmp, "params.yml", "scale: 5\noffset: 10\n");
    let overrides = ["scale=99".to_string()];

    let mut tree = load_yaml(&yaml_path).unwrap();
    apply_overrides(&mut tree, &overrides);

    let params: TestParams = deserialize_config(tree).unwrap();
    // Overridden key.
    assert_eq!(params.scale.0, 99);
    // Untouched key keeps its file value.
    assert_eq!(params.offset.0, 10);
}
/// Dot-separated override keys reach into nested mappings; plain keys still hit the top level.
#[test]
fn test_apply_overrides_nested_dot_notation() {
    // Tolerance for comparing parsed floats.
    const EPS: f64 = 1e-9;

    let tmp = TempDir::new().unwrap();
    let yaml_path = write_yaml(
        &tmp,
        "params.yml",
        "model:\n learning_rate: 0.001\n epochs: 10\nthreshold: 0.5\n",
    );
    let overrides = [
        "model.learning_rate=0.01".to_string(),
        "model.epochs=50".to_string(),
        "threshold=0.9".to_string(),
    ];

    let mut tree = load_yaml(&yaml_path).unwrap();
    apply_overrides(&mut tree, &overrides);

    let NestedParams { model, threshold }: NestedParams = deserialize_config(tree).unwrap();
    assert!((model.learning_rate.0 - 0.01).abs() < EPS);
    assert_eq!(model.epochs.0, 50);
    assert!((threshold.0 - 0.9).abs() < EPS);
}
/// Override values are parsed as typed YAML scalars: booleans, integers and `null`,
/// rather than being stored as strings.
#[test]
fn test_apply_overrides_bool_and_null_parsing() {
    let dir = TempDir::new().unwrap();
    let path = write_yaml(&dir, "conf.yml", "flag: false\ncount: 0\nlabel: x\n");
    let mut value = load_yaml(&path).unwrap();
    apply_overrides(
        &mut value,
        &[
            "flag=true".to_string(),
            "count=42".to_string(),
            // The test name promises null parsing; previously nothing exercised it.
            "label=null".to_string(),
        ],
    );
    assert_eq!(value["flag"], serde_yaml::Value::Bool(true));
    assert_eq!(
        value["count"],
        serde_yaml::Value::Number(serde_yaml::Number::from(42))
    );
    // Use `get` rather than indexing: `Value`'s `Index` impl also yields `Null` for a
    // missing key, which would hide an override that dropped the entry.
    assert_eq!(value.get("label"), Some(&serde_yaml::Value::Null));
}
/// Builds a three-node linear pipeline over `cat` and `params`.
///
/// Data flow: `a = offset * scale`, then `b = a + offset`, then `c = b * b`.
/// Each node consumes the dataset written by the node before it.
fn seq_pipeline<'a>(
    cat: &'a TestCatalog,
    params: &'a TestParams,
) -> impl Steps<PondError> + 'a {
    (
        Node {
            name: "multiply",
            // Inputs are (offset, scale), so this computes offset * scale.
            func: |off: i32, scale: i32| (off * scale,),
            input: (&params.offset, &params.scale),
            output: (&cat.a,),
        },
        Node {
            name: "add",
            func: |a: i32, off: i32| (a + off,),
            input: (&cat.a, &params.offset),
            output: (&cat.b,),
        },
        Node {
            name: "square",
            func: |b: i32| (b * b,),
            input: (&cat.b,),
            output: (&cat.c,),
        },
    )
}
/// Builds a diamond-shaped pipeline: two independent producers feeding one combiner.
///
/// `make_a` writes `a = scale * 2` and `make_b` writes `b = offset + 100`. They share no
/// inputs or outputs, so neither depends on the other. `combine` then writes `c = a + b`.
fn par_pipeline<'a>(
    cat: &'a TestCatalog,
    params: &'a TestParams,
) -> impl Steps<PondError> + 'a {
    (
        Node {
            name: "make_a",
            func: |v: i32| (v * 2,),
            input: (&params.scale,),
            output: (&cat.a,),
        },
        Node {
            name: "make_b",
            func: |v: i32| (v + 100,),
            input: (&params.offset,),
            output: (&cat.b,),
        },
        Node {
            name: "combine",
            func: |a: i32, b: i32| (a + b,),
            input: (&cat.a, &cat.b),
            output: (&cat.c,),
        },
    )
}
/// Builds a pipeline containing a nested sub-pipeline.
///
/// The top-level node `init` writes `a = scale`. The sub-pipeline `transform` declares
/// inputs `(a, offset)` and output `c`. Internally it writes `b = a + offset` and then
/// `c = b * 2`. `b` is an intermediate that the sub-pipeline does not list among its
/// outputs.
fn nested_pipeline<'a>(
    cat: &'a TestCatalog,
    params: &'a TestParams,
) -> impl Steps<PondError> + 'a {
    (
        Node {
            name: "init",
            func: |s: i32| (s,),
            input: (&params.scale,),
            output: (&cat.a,),
        },
        Pipeline {
            name: "transform",
            steps: (
                Node {
                    name: "add_offset",
                    func: |a: i32, off: i32| (a + off,),
                    input: (&cat.a, &params.offset),
                    output: (&cat.b,),
                },
                Node {
                    name: "double",
                    func: |b: i32| (b * 2,),
                    input: (&cat.b,),
                    output: (&cat.c,),
                },
            ),
            input: (&cat.a, &params.offset),
            output: (&cat.c,),
        },
    )
}
/// Builds a two-node pipeline whose second node always fails.
///
/// `init` copies `scale` into `a`. `fail` reads `a` and unconditionally returns
/// `PondError::Custom("intentional failure")`, so `b` is never produced by this pipeline.
fn error_pipeline<'a>(
    cat: &'a TestCatalog,
    params: &'a TestParams,
) -> impl Steps<PondError> + 'a {
    (
        Node {
            name: "init",
            func: |v: i32| (v,),
            input: (&params.scale,),
            output: (&cat.a,),
        },
        Node {
            name: "fail",
            func: |_a: i32| -> Result<(i32,), PondError> {
                Err(PondError::Custom("intentional failure".to_string()))
            },
            input: (&cat.a,),
            output: (&cat.b,),
        },
    )
}
/// Runs `seq_pipeline` through the App with a sequential runner and a logging hook.
#[test]
fn test_app_run_sequential() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 3\noffset: 10\n");
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(load_yaml(&params_path).unwrap()).unwrap();
    let app = App::new(catalog, params)
        .with_hooks((LoggingHook::new(),))
        .with_runners((SequentialRunner,));
    app.execute(seq_pipeline).unwrap();
    // a = 10 * 3 = 30, b = 30 + 10 = 40, c = 40^2 = 1600
    assert_eq!(app.catalog().c.load().unwrap(), 1600);
}
/// CLI-style overrides applied before deserialization change the pipeline's result.
#[test]
fn test_app_run_with_param_overrides() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 3\noffset: 10\n");
    let mut params_value = load_yaml(&params_path).unwrap();
    apply_overrides(&mut params_value, &["scale=5".to_string(), "offset=2".to_string()]);
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(params_value).unwrap();
    let app = App::new(catalog, params)
        .with_hooks((LoggingHook::new(),))
        .with_runners((SequentialRunner,));
    app.execute(seq_pipeline).unwrap();
    // With overrides: a = 2 * 5 = 10, b = 10 + 2 = 12, c = 12^2 = 144
    assert_eq!(app.catalog().c.load().unwrap(), 144);
}
/// Runs the diamond-shaped `par_pipeline` through a parallel runner.
#[test]
fn test_app_run_parallel() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 7\noffset: 3\n");
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(load_yaml(&params_path).unwrap()).unwrap();
    let app = App::new(catalog, params)
        .with_runners((ParallelRunner::default(),));
    app.execute(par_pipeline).unwrap();
    // a = 7 * 2 = 14, b = 3 + 100 = 103, c = 14 + 103 = 117
    assert_eq!(app.catalog().c.load().unwrap(), 117);
}
/// `seq_pipeline` passes validation, and its graph contains one entry per node.
#[test]
fn test_app_check_valid() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 1\noffset: 1\n");
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(load_yaml(&params_path).unwrap()).unwrap();
    let pipeline = seq_pipeline(&catalog, &params);
    assert!(pipeline.check().is_ok());
    let graph = build_pipeline_graph(&pipeline, &catalog, &params);
    // seq_pipeline declares three nodes: multiply, add, square.
    assert_eq!(graph.node_indices.len(), 3);
}
/// A pipeline with a nested sub-pipeline validates and runs sequentially.
#[test]
fn test_app_nested_pipeline_check_and_run() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 4\noffset: 6\n");
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(load_yaml(&params_path).unwrap()).unwrap();
    let app = App::new(catalog, params)
        .with_hooks((LoggingHook::new(),))
        .with_runners((SequentialRunner,));
    {
        // Scoped so the borrows of app's catalog/params end before execute().
        let pipeline = nested_pipeline(app.catalog(), app.params());
        assert!(pipeline.check().is_ok());
    }
    app.execute(nested_pipeline).unwrap();
    // a = 4, b = 4 + 6 = 10, c = 10 * 2 = 20
    assert_eq!(app.catalog().c.load().unwrap(), 20);
}
/// The nested-pipeline variant produces the same result under a parallel runner.
#[test]
fn test_app_nested_pipeline_parallel() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 4\noffset: 6\n");
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(load_yaml(&params_path).unwrap()).unwrap();
    let app = App::new(catalog, params)
        .with_runners((ParallelRunner::default(),));
    app.execute(nested_pipeline).unwrap();
    // a = 4, b = 4 + 6 = 10, c = 10 * 2 = 20
    assert_eq!(app.catalog().c.load().unwrap(), 20);
}
/// A node error surfaces from `App::execute` under the sequential runner, with its message intact.
#[test]
fn test_app_error_propagation_sequential() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 1\noffset: 1\n");
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(load_yaml(&params_path).unwrap()).unwrap();
    let app = App::new(catalog, params)
        .with_runners((SequentialRunner,));
    let err = app
        .execute(error_pipeline)
        .expect_err("error_pipeline's `fail` node should make execute() return Err");
    let msg = err.to_string();
    assert!(
        msg.contains("intentional failure"),
        "unexpected error message: {msg}"
    );
}
/// A node error surfaces from `App::execute` under the parallel runner, with its message intact.
#[test]
fn test_app_error_propagation_parallel() {
    let dir = TempDir::new().unwrap();
    let catalog_path = write_yaml(&dir, "catalog.yml", "a: {}\nb: {}\nc: {}\n");
    let params_path = write_yaml(&dir, "params.yml", "scale: 1\noffset: 1\n");
    let catalog: TestCatalog = deserialize_config(load_yaml(&catalog_path).unwrap()).unwrap();
    let params: TestParams = deserialize_config(load_yaml(&params_path).unwrap()).unwrap();
    let app = App::new(catalog, params)
        .with_runners((ParallelRunner::default(),));
    let err = app
        .execute(error_pipeline)
        .expect_err("error_pipeline's `fail` node should make execute() return Err");
    let msg = err.to_string();
    assert!(
        msg.contains("intentional failure"),
        "unexpected error message: {msg}"
    );
}