use crate::graph::investment::solve_investment_order_for_model;
use crate::graph::validate::validate_commodity_graphs_for_model;
use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model};
use crate::id::{HasID, IDLike};
use crate::model::{Model, ModelParameters};
use crate::region::RegionID;
use crate::units::UnitType;
use anyhow::{Context, Result, bail, ensure};
use float_cmp::approx_eq;
use indexmap::IndexMap;
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use std::collections::HashMap;
use std::fmt::{self, Write};
use std::fs;
use std::hash::Hash;
use std::path::Path;
mod agent;
use agent::read_agents;
mod asset;
use asset::read_user_assets;
mod commodity;
use commodity::read_commodities;
mod process;
use process::read_processes;
mod region;
use region::read_regions;
mod time_slice;
use time_slice::read_time_slice_info;
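/// A trait providing a common `insert` operation for map types.
///
/// Implemented for [`HashMap`] and [`IndexMap`] so that helpers such as [`try_insert`] can
/// accept either.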
pub trait Insert<K, V> {
fn insert(&mut self, key: K, value: V) -> Option<V>;
}
impl<K: Eq + Hash, V> Insert<K, V> for HashMap<K, V> {
fn insert(&mut self, key: K, value: V) -> Option<V> {
HashMap::insert(self, key, value)
}
}
impl<K: Eq + Hash, V> Insert<K, V> for IndexMap<K, V> {
fn insert(&mut self, key: K, value: V) -> Option<V> {
IndexMap::insert(self, key, value)
}
}
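/// Read a series of records of type `T` from the CSV file at `file_path`.
///
/// Returns an error if the file is missing, cannot be parsed or contains no records.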
pub fn read_csv<'a, T: DeserializeOwned + 'a>(
file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
let vec = read_csv_internal(file_path)?;
if vec.is_empty() {
bail!("CSV file {} cannot be empty", file_path.display());
}
Ok(vec.into_iter())
}
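/// Like [`read_csv`], but returns an empty iterator if the file does not exist.
///
/// Unlike [`read_csv`], a file containing no records is not an error here.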
pub fn read_csv_optional<'a, T: DeserializeOwned + 'a>(
file_path: &'a Path,
) -> Result<impl Iterator<Item = T> + 'a> {
if !file_path.exists() {
return Ok(Vec::new().into_iter());
}
let vec = read_csv_internal(file_path)?;
Ok(vec.into_iter())
}
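/// Read records from the CSV file at `file_path`, trimming whitespace around headers and fields.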
fn read_csv_internal<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> Result<Vec<T>> {
let vec = csv::ReaderBuilder::new()
.trim(csv::Trim::All)
.from_path(file_path)
.with_context(|| input_err_msg(file_path))?
.into_deserialize()
.process_results(|iter| iter.collect_vec())
.with_context(|| input_err_msg(file_path))?;
Ok(vec)
}
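/// Parse the TOML file at `file_path` into a value of type `T`.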
pub fn read_toml<T: DeserializeOwned>(file_path: &Path) -> Result<T> {
let toml_str = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?;
let toml_data = toml::from_str(&toml_str).with_context(|| input_err_msg(file_path))?;
Ok(toml_data)
}
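/// Custom deserialiser for proportions, requiring the value to lie in the range (0, 1].
///
/// Zero, negative, non-finite and out-of-range values are rejected.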
pub fn deserialise_proportion_nonzero<'de, D, T>(deserialiser: D) -> Result<T, D::Error>
where
T: UnitType,
D: Deserializer<'de>,
{
let value = f64::deserialize(deserialiser)?;
    if !(value > 0.0 && value <= 1.0) {
        return Err(serde::de::Error::custom("Value must be > 0 and <= 1"));
    }
Ok(T::new(value))
}
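/// Format a standard error message for a file that could not be read.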
pub fn input_err_msg<P: AsRef<Path>>(file_path: P) -> String {
format!("Error reading {}", file_path.as_ref().display())
}
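/// Read a CSV file of records implementing [`HasID`] into an [`IndexMap`] keyed by ID.
///
/// Returns an error if the file cannot be read, contains no records or contains duplicate IDs.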
fn read_csv_id_file<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
where
T: HasID<ID> + DeserializeOwned,
{
fn fill_and_validate_map<T, ID: IDLike>(file_path: &Path) -> Result<IndexMap<ID, T>>
where
T: HasID<ID> + DeserializeOwned,
{
let mut map = IndexMap::new();
for record in read_csv::<T>(file_path)? {
let id = record.get_id().clone();
let existing = map.insert(id.clone(), record).is_some();
ensure!(!existing, "Duplicate ID found: {id}");
}
ensure!(!map.is_empty(), "CSV file is empty");
Ok(map)
}
fill_and_validate_map(file_path).with_context(|| input_err_msg(file_path))
}
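/// Check that the given fractions sum to approximately one (within an epsilon of 1e-5).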
fn check_values_sum_to_one_approx<I, T>(fractions: I) -> Result<()>
where
T: UnitType,
I: Iterator<Item = T>,
{
let sum = fractions.sum();
ensure!(
approx_eq!(T, sum, T::new(1.0), epsilon = 1e-5),
"Sum of fractions does not equal one (actual: {sum})"
);
Ok(())
}
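/// Returns `true` if the elements of `iter` are in strictly ascending order, i.e. sorted with
/// no duplicate values.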
pub fn is_sorted_and_unique<T, I>(iter: I) -> bool
where
T: PartialOrd + Clone,
I: IntoIterator<Item = T>,
{
iter.into_iter().tuple_windows().all(|(a, b)| a < b)
}
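/// Insert a key/value pair into a map, returning an error if the key is already present.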
pub fn try_insert<M, K, V>(map: &mut M, key: &K, value: V) -> Result<()>
where
M: Insert<K, V>,
K: Eq + Hash + Clone + std::fmt::Debug,
{
let existing = map.insert(key.clone(), value).is_some();
ensure!(!existing, "Key {key:?} already exists in the map");
Ok(())
}
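/// Format a collection of items for display, showing at most the first ten and noting how many
/// more were omitted.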
pub fn format_items_with_cap<I, J, T>(items: I) -> String
where
I: IntoIterator<Item = T, IntoIter = J>,
J: ExactSizeIterator<Item = T>,
T: fmt::Debug,
{
const MAX_DISPLAY: usize = 10;
let items = items.into_iter();
let total_count = items.len();
    let formatted_items = items
        .take(MAX_DISPLAY)
        .format_with(", ", |item, f| f(&format_args!("{item:?}")));
let mut out = format!("[{formatted_items}]");
if total_count > MAX_DISPLAY {
write!(&mut out, " and {} more", total_count - MAX_DISPLAY).unwrap();
}
out
}
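/// Read a model from the specified directory.
///
/// Loads the model parameters, time slices, regions, commodities, processes, agents and user
/// assets, builds and validates the commodity graphs and solves the investment order.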
pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<Model> {
let model_params = ModelParameters::from_path(&model_dir)?;
let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
let regions = read_regions(model_dir.as_ref())?;
let region_ids = regions.keys().cloned().collect();
let years = &model_params.milestone_years;
    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let agents = read_agents(
        model_dir.as_ref(),
        &commodities,
        &processes,
        &region_ids,
        years,
    )?;
    let agent_ids = agents.keys().cloned().collect();
    let user_assets = read_user_assets(model_dir.as_ref(), &agent_ids, &processes, &region_ids)?;
    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
validate_commodity_graphs_for_model(
&commodity_graphs,
&processes,
&commodities,
&time_slice_info,
)?;
let investment_order =
solve_investment_order_for_model(&commodity_graphs, &commodities, years)?;
let model_path = model_dir
.as_ref()
.canonicalize()
.context("Could not parse path to model")?;
let model = Model {
model_path,
parameters: model_params,
agents,
commodities,
processes,
time_slice_info,
regions,
user_assets,
investment_order,
};
Ok(model)
}
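/// Read just enough of a model from `model_dir` to build its commodity graphs, keyed by region
/// and year.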
pub fn load_commodity_graphs<P: AsRef<Path>>(
model_dir: P,
) -> Result<IndexMap<(RegionID, u32), CommoditiesGraph>> {
let model_params = ModelParameters::from_path(&model_dir)?;
let time_slice_info = read_time_slice_info(model_dir.as_ref())?;
let regions = read_regions(model_dir.as_ref())?;
let region_ids = regions.keys().cloned().collect();
let years = &model_params.milestone_years;
    let commodities = read_commodities(model_dir.as_ref(), &region_ids, &time_slice_info, years)?;
    let processes = read_processes(
        model_dir.as_ref(),
        &commodities,
        &region_ids,
        &time_slice_info,
        years,
    )?;
    let commodity_graphs = build_commodity_graphs_for_model(&processes, &region_ids, years);
Ok(commodity_graphs)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::id::GenericID;
use crate::units::Dimensionless;
use rstest::rstest;
use serde::Deserialize;
use serde::de::IntoDeserializer;
use serde::de::value::{Error as ValueError, F64Deserializer};
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use tempfile::tempdir;
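    /// A simple record type for testing CSV and TOML deserialisation.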
#[derive(Debug, PartialEq, Deserialize)]
struct Record {
id: GenericID,
value: u32,
}
impl HasID<GenericID> for Record {
fn get_id(&self) -> &GenericID {
&self.id
}
}
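    /// Create a CSV file named `test.csv` in `dir_path` with the given contents.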
fn create_csv_file(dir_path: &Path, contents: &str) -> PathBuf {
let file_path = dir_path.join("test.csv");
let mut file = File::create(&file_path).unwrap();
writeln!(file, "{contents}").unwrap();
file_path
}
#[test]
fn read_csv_works() {
let dir = tempdir().unwrap();
let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
assert_eq!(
records,
&[
Record {
id: "hello".into(),
value: 1,
},
Record {
id: "world".into(),
value: 2,
}
]
);
let dir = tempdir().unwrap();
let file_path = create_csv_file(dir.path(), "id , value\t\n hello\t ,1\n world ,2\n");
let records: Vec<Record> = read_csv(&file_path).unwrap().collect();
assert_eq!(
records,
&[
Record {
id: "hello".into(),
value: 1,
},
Record {
id: "world".into(),
value: 2,
}
]
);
let file_path = create_csv_file(dir.path(), "id,value\n");
assert!(read_csv::<Record>(&file_path).is_err());
assert!(
read_csv_optional::<Record>(&file_path)
.unwrap()
.next()
.is_none()
);
let dir = tempdir().unwrap();
let file_path = dir.path().join("a_missing_file.csv");
assert!(!file_path.exists());
assert!(read_csv::<Record>(&file_path).is_err());
assert!(
read_csv_optional::<Record>(&file_path)
.unwrap()
.next()
.is_none()
);
}
#[test]
fn read_toml_works() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.toml");
{
let mut file = File::create(&file_path).unwrap();
writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
}
assert_eq!(
read_toml::<Record>(&file_path).unwrap(),
Record {
id: "hello".into(),
value: 1,
}
);
{
let mut file = File::create(&file_path).unwrap();
writeln!(file, "bad toml syntax").unwrap();
}
read_toml::<Record>(&file_path).unwrap_err();
}
fn deserialise_f64(value: f64) -> Result<Dimensionless, ValueError> {
let deserialiser: F64Deserializer<ValueError> = value.into_deserializer();
deserialise_proportion_nonzero(deserialiser)
}
#[test]
fn deserialise_proportion_nonzero_works() {
assert_eq!(deserialise_f64(0.01), Ok(Dimensionless(0.01)));
assert_eq!(deserialise_f64(0.5), Ok(Dimensionless(0.5)));
assert_eq!(deserialise_f64(1.0), Ok(Dimensionless(1.0)));
deserialise_f64(0.0).unwrap_err();
deserialise_f64(-1.0).unwrap_err();
deserialise_f64(2.0).unwrap_err();
deserialise_f64(f64::NAN).unwrap_err();
deserialise_f64(f64::INFINITY).unwrap_err();
}
#[test]
fn check_values_sum_to_one_approx_works() {
check_values_sum_to_one_approx([Dimensionless(1.0)].into_iter()).unwrap();
check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.6)].into_iter())
.unwrap();
assert!(check_values_sum_to_one_approx([Dimensionless(0.5)].into_iter()).is_err());
assert!(
check_values_sum_to_one_approx([Dimensionless(0.4), Dimensionless(0.3)].into_iter())
.is_err()
);
assert!(
check_values_sum_to_one_approx([Dimensionless(f64::INFINITY)].into_iter()).is_err()
);
assert!(check_values_sum_to_one_approx([Dimensionless(f64::NAN)].into_iter()).is_err());
}
#[rstest]
#[case(&[], true)]
#[case(&[1], true)]
#[case(&[1,2], true)]
#[case(&[1,2,3,4], true)]
#[case(&[2,1],false)]
#[case(&[1,1],false)]
#[case(&[1,3,2,4], false)]
fn is_sorted_and_unique_works(#[case] values: &[u32], #[case] expected: bool) {
assert_eq!(is_sorted_and_unique(values), expected);
}
#[test]
fn format_items_with_cap_works() {
let items = vec!["a", "b", "c"];
assert_eq!(format_items_with_cap(&items), r#"["a", "b", "c"]"#);
let many_items = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];
assert_eq!(
format_items_with_cap(&many_items),
r#"["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] and 2 more"#
);
}
}