use anyhow::{Context, Result, ensure};
use csv::{ReaderBuilder, Trim, Writer};
use indexmap::IndexSet;
use std::fs;
use std::path::{Path, PathBuf};
/// A description of how to derive a new model directory from a base model directory.
///
/// Holds the location of the base model plus the CSV file patches and the optional TOML patch
/// that are applied when the patch is built.
pub struct ModelPatch {
/// Directory holding the base model (`model.toml` and the CSV files).
base_model_dir: PathBuf,
/// Patches for individual files, applied in the order they were added.
file_patches: Vec<FilePatch>,
/// Optional table of entries to merge into the base `model.toml`.
toml_patch: Option<toml::value::Table>,
}
impl ModelPatch {
pub fn new<P: Into<PathBuf>>(base_model_dir: P) -> Self {
ModelPatch {
base_model_dir: base_model_dir.into(),
file_patches: Vec::new(),
toml_patch: None,
}
}
pub fn from_example(name: &str) -> Self {
let base_model_dir = PathBuf::from("examples").join(name);
ModelPatch::new(base_model_dir)
}
pub fn with_file_patch(mut self, patch: FilePatch) -> Self {
self.file_patches.push(patch);
self
}
pub fn with_file_patches<I>(mut self, patches: I) -> Self
where
I: IntoIterator<Item = FilePatch>,
{
self.file_patches.extend(patches);
self
}
pub fn with_toml_patch(mut self, patch_str: impl AsRef<str>) -> Self {
assert!(
self.toml_patch.is_none(),
"TOML patch already set for this ModelPatch"
);
let s = patch_str.as_ref();
let patch: toml::value::Table =
toml::from_str(s).expect("Failed to parse string passed to with_toml_patch");
self.toml_patch = Some(patch);
self
}
pub fn build<O: AsRef<Path>>(&self, out_dir: O) -> Result<()> {
let base_dir = self.base_model_dir.as_path();
let out_path = out_dir.as_ref();
let base_toml_path = base_dir.join("model.toml");
let out_toml_path = out_path.join("model.toml");
if let Some(toml_patch) = &self.toml_patch {
let toml_content = fs::read_to_string(&base_toml_path)?;
let merged_toml = merge_model_toml(&toml_content, toml_patch)?;
fs::write(&out_toml_path, merged_toml)?;
} else {
fs::copy(&base_toml_path, &out_toml_path)?;
}
for entry in fs::read_dir(base_dir)? {
let entry = entry?;
let src_path = entry.path();
if src_path.is_file()
&& src_path
.extension()
.and_then(|e| e.to_str())
.is_some_and(|ext| ext.eq_ignore_ascii_case("csv"))
{
let dst_path = out_path.join(entry.file_name());
fs::copy(&src_path, &dst_path)?;
}
}
for patch in &self.file_patches {
patch.apply_and_save(base_dir, out_path)?;
}
Ok(())
}
pub fn build_to_tempdir(&self) -> Result<tempfile::TempDir> {
let temp_dir = tempfile::tempdir()?;
self.build(temp_dir.path())?;
Ok(temp_dir)
}
}
/// A set of CSV rows that keeps insertion order; each row is a vector of field strings.
type CSVTable = IndexSet<Vec<String>>;
/// A patch for a single CSV file of a model.
///
/// The patch is either a full replacement of the file's content, or a set of rows to delete
/// and rows to add, optionally guarded by an expected header.
#[derive(Clone)]
pub struct FilePatch {
/// File name, joined onto the base and output model directories.
filename: String,
/// Expected header fields; when set, they must equal the base file's header.
header_row: Option<Vec<String>>,
/// Complete replacement content; when set, it is written instead of a modified base file.
replacement_content: Option<String>,
/// Rows to remove from the base file.
to_delete: CSVTable,
/// Rows to add to the base file.
to_add: CSVTable,
}
impl FilePatch {
pub fn new(filename: impl Into<String>) -> Self {
FilePatch {
filename: filename.into(),
header_row: None,
replacement_content: None,
to_delete: IndexSet::new(),
to_add: IndexSet::new(),
}
}
pub fn with_header(mut self, header: impl Into<String>) -> Self {
assert!(
self.replacement_content.is_none(),
"Cannot set header when replacement content is set for this FilePatch",
);
assert!(
self.header_row.is_none(),
"Header already set for this FilePatch",
);
let s = header.into();
let v = s.split(',').map(|s| s.trim().to_string()).collect();
self.header_row = Some(v);
self
}
pub fn with_replacement(mut self, lines: &[&str]) -> Self {
assert!(
self.header_row.is_none(),
"Cannot set replacement content when header is set for this FilePatch",
);
assert!(
self.to_delete.is_empty() && self.to_add.is_empty(),
"Cannot set replacement content when additions/deletions are set for this FilePatch",
);
assert!(
self.replacement_content.is_none(),
"Replacement content already set for this FilePatch",
);
if !lines.is_empty() {
let first_col_count = lines[0].matches(',').count() + 1;
for (idx, line) in lines.iter().enumerate() {
let col_count = line.matches(',').count() + 1;
assert_eq!(
col_count, first_col_count,
"Line {idx} has {col_count} columns but line 0 has {first_col_count}: {line:?}"
);
}
}
let content = lines.join("\n") + "\n";
self.replacement_content = Some(content);
self
}
pub fn with_addition(mut self, row: impl Into<String>) -> Self {
assert!(
self.replacement_content.is_none(),
"Cannot add rows when replacement content is set for this FilePatch",
);
let s = row.into();
let v = s.split(',').map(|s| s.trim().to_string()).collect();
self.to_add.insert(v);
self
}
pub fn with_deletion(mut self, row: impl Into<String>) -> Self {
assert!(
self.replacement_content.is_none(),
"Cannot delete rows when replacement content is set for this FilePatch",
);
let s = row.into();
let v = s.split(',').map(|s| s.trim().to_string()).collect();
self.to_delete.insert(v);
self
}
fn apply(&self, base_model_dir: &Path) -> Result<String> {
let base_path = base_model_dir.join(&self.filename);
ensure!(
base_path.exists() && base_path.is_file(),
"Base file for patching does not exist: {}",
base_path.display()
);
if let Some(content) = &self.replacement_content {
return Ok(content.clone());
}
let base = fs::read_to_string(&base_path)?;
let modified = modify_base_with_patch(&base, self)
.with_context(|| format!("Error applying patch to file: {}", self.filename))?;
Ok(modified)
}
pub fn apply_and_save(&self, base_model_dir: &Path, out_model_dir: &Path) -> Result<()> {
let modified = self.apply(base_model_dir)?;
let new_path = out_model_dir.join(&self.filename);
fs::write(&new_path, modified)?;
Ok(())
}
}
/// Merges `patch` into the TOML document `base_toml` and returns the result as pretty-printed
/// TOML.
///
/// The merge is shallow: each top-level entry of `patch` is inserted into the base table,
/// replacing any existing entry with the same key.
///
/// # Errors
///
/// Returns an error if `base_toml` cannot be parsed, is not a table, or the merged document
/// cannot be serialised.
fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result<String> {
    let mut document: toml::Value = toml::from_str(base_toml)?;

    let table = document
        .as_table_mut()
        .context("Base model TOML must be a table")?;
    table.extend(
        patch
            .iter()
            .map(|(key, value)| (key.clone(), value.clone())),
    );

    Ok(toml::to_string_pretty(&document)?)
}
/// Applies the deletions and additions of `patch` to the CSV text `base` and returns the
/// resulting CSV text.
///
/// Rows keep their original order; added rows are appended in the order they were registered.
///
/// # Errors
///
/// Returns an error if the base CSV cannot be parsed, the patch header (when set) differs from
/// the base header, the base contains duplicate rows, a row is both added and deleted, a row to
/// delete is missing, a row to add is already present, or a resulting row does not have as many
/// columns as the header.
fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result<String> {
    let mut csv_reader = ReaderBuilder::new()
        .trim(Trim::All)
        .from_reader(base.as_bytes());

    // Take an owned copy of the header so the reader can be borrowed again for the records.
    let header: Vec<String> = csv_reader
        .headers()
        .context("Failed to read base file header")?
        .iter()
        .map(str::to_owned)
        .collect();

    if let Some(header_row_vec) = patch.header_row.as_ref() {
        ensure!(
            header == *header_row_vec,
            "Header mismatch: base file has [{}], patch has [{}]",
            header.join(", "),
            header_row_vec.join(", ")
        );
    }

    // Load the base rows, rejecting duplicates.
    let mut rows = CSVTable::new();
    for record in csv_reader.records() {
        let row_vec: Vec<String> = record?
            .iter()
            .map(|field| field.trim().to_owned())
            .collect();
        ensure!(
            !rows.contains(&row_vec),
            "Duplicate row in base file: {row_vec:?}",
        );
        rows.insert(row_vec);
    }

    // First pass: a row may not be both deleted and added.
    for del_row in &patch.to_delete {
        ensure!(
            !patch.to_add.contains(del_row),
            "Row appears in both deletions and additions: {del_row:?}",
        );
    }
    // Second pass: every deletion must match an existing row.
    for del_row in &patch.to_delete {
        ensure!(
            rows.contains(del_row),
            "Row to delete not present in base file: {del_row:?}"
        );
    }
    rows.retain(|existing| !patch.to_delete.contains(existing));

    for add_row in &patch.to_add {
        ensure!(
            !rows.contains(add_row),
            "Addition already present in base file: {add_row:?}"
        );
        rows.insert(add_row.clone());
    }

    // Every row of the final table must be as wide as the header.
    let expected_len = header.len();
    for row in &rows {
        ensure!(
            row.len() == expected_len,
            "Row has {} columns but header has {expected_len}: {row:?}",
            row.len(),
        );
    }

    let mut csv_writer = Writer::from_writer(Vec::new());
    csv_writer.write_record(&header)?;
    for row in &rows {
        csv_writer.write_record(row)?;
    }
    csv_writer.flush()?;
    let bytes = csv_writer.into_inner()?;
    Ok(String::from_utf8(bytes)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixture::assert_error;
    use crate::input::read_toml;
    use crate::model::ModelParameters;
    use crate::patch::{FilePatch, ModelPatch};

    /// A deleted row disappears and an added row follows the surviving base rows.
    #[test]
    fn modify_base_with_patch_works() {
        let patch = FilePatch::new("test.csv")
            .with_header("col1,col2")
            .with_deletion("value3,value4")
            .with_addition("value7,value8");
        let base = "col1,col2\nvalue1,value2\nvalue3,value4\nvalue5,value6\n";

        let modified = modify_base_with_patch(base, &patch).unwrap();

        let mut lines = modified.lines();
        assert_eq!(lines.next(), Some("col1,col2"));
        assert_eq!(lines.next(), Some("value1,value2"));
        assert_eq!(lines.next(), Some("value5,value6"));
        assert_eq!(lines.next(), Some("value7,value8"));
        assert!(!modified.contains("value3,value4"));
    }

    /// A patch header that differs from the base header is rejected.
    #[test]
    fn modify_base_with_patch_mismatched_header() {
        let patch = FilePatch::new("test.csv").with_header("col1,col3");
        let base = "col1,col2\nvalue1,value2\n";

        assert_error!(
            modify_base_with_patch(base, &patch),
            "Header mismatch: base file has [col1, col2], patch has [col1, col3]"
        );
    }

    /// Patch entries replace existing top-level keys and add new ones; other content is kept.
    #[test]
    fn merge_model_toml_basic() {
        let base = r#"
field = "data"
[section]
a = 1
"#;
        let patch: toml::value::Table = [("field", "patched"), ("new_field", "added")]
            .into_iter()
            .map(|(key, value)| (key.to_string(), toml::Value::String(value.to_string())))
            .collect();

        let merged = merge_model_toml(base, &patch).unwrap();

        for expected_fragment in [
            "field = \"patched\"",
            "[section]",
            "new_field = \"added\"",
        ] {
            assert!(merged.contains(expected_fragment));
        }
    }

    /// Deleting one row and adding another is reflected in the built model's file.
    #[test]
    fn file_patch() {
        let removed_row = "GASDRV,GBR,A0_GEX,4002.26,2020";
        let added_row = "GASDRV,GBR,A0_GEX,4003.26,2020";
        let patch = FilePatch::new("assets.csv")
            .with_deletion(removed_row)
            .with_addition(added_row);

        let built = ModelPatch::from_example("simple")
            .with_file_patch(patch)
            .build_to_tempdir()
            .unwrap();

        let content = std::fs::read_to_string(built.path().join("assets.csv")).unwrap();
        assert!(!content.contains(removed_row));
        assert!(content.contains(added_row));
    }

    /// Replacement content is written verbatim, one line per entry, with a trailing newline.
    #[test]
    fn file_patch_with_replacement() {
        let patch = FilePatch::new("assets.csv").with_replacement(&["col1,col2", "new1,new2"]);

        let built = ModelPatch::from_example("simple")
            .with_file_patch(patch)
            .build_to_tempdir()
            .unwrap();

        let content = std::fs::read_to_string(built.path().join("assets.csv")).unwrap();
        assert_eq!(content, "col1,col2\nnew1,new2\n");
    }

    #[test]
    #[should_panic(
        expected = "Cannot set replacement content when header is set for this FilePatch"
    )]
    fn file_patch_replacement_after_header_panics() {
        let with_header = FilePatch::new("assets.csv").with_header("col1,col2");
        let _ = with_header.with_replacement(&["col1,col2", "a,b"]);
    }

    #[test]
    #[should_panic(
        expected = "Cannot set replacement content when additions/deletions are set for this FilePatch"
    )]
    fn file_patch_replacement_after_addition_panics() {
        let with_addition = FilePatch::new("assets.csv").with_addition("a,b");
        let _ = with_addition.with_replacement(&["col1,col2", "a,b"]);
    }

    #[test]
    #[should_panic(expected = "Cannot add rows when replacement content is set for this FilePatch")]
    fn file_patch_addition_after_replacement_panics() {
        let with_replacement =
            FilePatch::new("assets.csv").with_replacement(&["col1,col2", "a,b"]);
        let _ = with_replacement.with_addition("c,d");
    }

    /// Even a full replacement requires the base file to exist.
    #[test]
    fn file_patch_with_replacement_missing_base_file_fails() {
        let missing_file = std::path::PathBuf::from("examples")
            .join("simple")
            .join("not_a_real_file.csv");
        let expected = format!(
            "Base file for patching does not exist: {}",
            missing_file.display()
        );

        let model_patch = ModelPatch::from_example("simple").with_file_patch(
            FilePatch::new("not_a_real_file.csv").with_replacement(&["x,y", "1,2"]),
        );

        assert_error!(model_patch.build_to_tempdir(), expected);
    }

    #[test]
    #[should_panic(expected = "Line 1 has 2 columns but line 0 has 3")]
    fn file_patch_replacement_column_count_mismatch_panics() {
        let lines = ["col1,col2,col3", "a,b"];
        let _ = FilePatch::new("test.csv").with_replacement(&lines);
    }

    /// A TOML patch overrides the corresponding value in the built `model.toml`.
    #[test]
    fn toml_patch() {
        let built = ModelPatch::from_example("simple")
            .with_toml_patch("milestone_years = [2020, 2030, 2040, 2050]\n")
            .build_to_tempdir()
            .unwrap();

        let toml: ModelParameters = read_toml(&built.path().join("model.toml")).unwrap();
        assert_eq!(toml.milestone_years, vec![2020, 2030, 2040, 2050]);
    }
}