use crate::error::ForgeResult;
use crate::types::{ColumnValue, ParsedModel, Variable};
use serde_yaml_ng::Value;
use std::collections::HashMap;
use std::fs;
use std::hash::BuildHasher;
use std::path::Path;
/// Rewrites the YAML file at `path`, storing each calculated number under the `value` key of
/// the entry addressed by its dot-separated path (for example `"summary.total"`).
///
/// Paths that do not resolve are skipped. So are entries that are not mappings already holding
/// a `value` key. The file is rebuilt from the parsed `Value`, so comments and the original
/// layout are not kept.
///
/// # Errors
///
/// Returns an error if the file cannot be read, parsed as YAML, serialized, or written.
pub fn update_yaml_file<S: BuildHasher>(
    path: &Path,
    calculated_values: &HashMap<String, f64, S>,
) -> ForgeResult<()> {
    let source = fs::read_to_string(path)?;
    let mut document: Value = serde_yaml_ng::from_str(&source)?;

    for (dotted_path, &number) in calculated_values {
        update_value_in_yaml(&mut document, dotted_path, number);
    }

    let serialized = serde_yaml_ng::to_string(&document)?;
    fs::write(path, serialized)?;
    Ok(())
}
/// Writes the calculated results in `result` back into the YAML model file at `path`.
///
/// When the document root is a mapping:
/// * for every table in `result.tables` whose name is a top-level mapping in the file and that
///   has a numeric `value` column, the mapping's `value` entry is replaced by that column;
/// * every scalar in `result.scalars` with a computed value is written through
///   [`update_value_in_yaml`], i.e. only into entries that already have a `value` key.
///
/// Before writing, the original file is copied to `path` with its extension replaced by
/// `yaml.bak`. The file is rebuilt from the parsed `Value`, so comments and the original layout
/// are not kept.
///
/// Returns `Ok(false)` when the file contains more than one YAML document. In that case no
/// backup is created and the file is not touched, because the file is parsed and re-emitted as
/// a single document. Returns `Ok(true)` after a successful write.
///
/// # Errors
///
/// Returns an error if the file cannot be read, parsed, copied to the backup path, serialized,
/// or written.
pub fn write_calculated_results(path: &Path, result: &ParsedModel) -> ForgeResult<bool> {
    let content = fs::read_to_string(path)?;
    if is_multi_document(&content) {
        return Ok(false);
    }

    // Parse before backing up so a malformed file fails without leaving a stray backup.
    let mut yaml: Value = serde_yaml_ng::from_str(&content)?;
    let backup_path = path.with_extension("yaml.bak");
    fs::copy(path, &backup_path)?;

    if let Value::Mapping(ref mut root) = yaml {
        for (table_name, table) in &result.tables {
            if let Some(Value::Mapping(table_map)) = root.get_mut(Value::String(table_name.clone()))
            {
                if let Some(col) = table.columns.get("value") {
                    if let ColumnValue::Number(values) = &col.values {
                        let yaml_values: Vec<Value> =
                            values.iter().copied().map(to_yaml_number).collect();
                        table_map.insert(
                            Value::String("value".to_string()),
                            Value::Sequence(yaml_values),
                        );
                    }
                }
            }
        }
        // `root` is not used past this point, so `yaml` can be borrowed mutably again.
        for (name, var) in &result.scalars {
            if let Some(value) = var.value {
                update_value_in_yaml(&mut yaml, name, value);
            }
        }
    }

    let updated_content = serde_yaml_ng::to_string(&yaml)?;
    fs::write(path, updated_content)?;
    Ok(true)
}

/// Converts a computed column value to a YAML number.
///
/// Whole numbers below `1e10` in magnitude become integers, so `100.0` is emitted as `100`.
/// Everything else (fractions, huge values, NaN, infinities) stays a float.
fn to_yaml_number(v: f64) -> Value {
    if v.fract() == 0.0 && v.abs() < 1e10 {
        // The guard keeps `v` integral and far inside the `i64` range, so the cast is exact.
        #[allow(clippy::cast_possible_truncation)]
        let whole = v as i64;
        Value::Number(serde_yaml_ng::Number::from(whole))
    } else {
        Value::Number(serde_yaml_ng::Number::from(v))
    }
}

/// Reports whether `content` holds more than one YAML document.
///
/// A document is counted in two cases:
/// * for every `---` marker line;
/// * for bare content before the first marker or after a `...` end marker. Blank lines,
///   comments and `%` directives do not count as content.
///
/// Per the YAML spec, document markers only occur at column 0, so indented `---` text is
/// ordinary content.
fn is_multi_document(content: &str) -> bool {
    // A leading byte-order mark is not part of the first line's text.
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);
    let mut documents = 0_usize;
    let mut in_document = false;
    for line in content.lines() {
        if is_document_marker(line, "---") {
            documents += 1;
            in_document = true;
        } else if is_document_marker(line, "...") {
            in_document = false;
        } else if !in_document {
            let trimmed = line.trim();
            if !trimmed.is_empty() && !trimmed.starts_with('#') && !trimmed.starts_with('%') {
                // Bare document without an explicit `---` start marker.
                documents += 1;
                in_document = true;
            }
        }
        if documents > 1 {
            return true;
        }
    }
    false
}

/// Reports whether `line` starts with the document marker `marker` (`---` or `...`) followed
/// by end-of-line or whitespace.
fn is_document_marker(line: &str, marker: &str) -> bool {
    match line.strip_prefix(marker) {
        Some(rest) => rest.is_empty() || rest.starts_with(char::is_whitespace),
        None => false,
    }
}
/// Writes the computed value of every scalar in `scalars` back to the YAML file at `path`,
/// using each map key as the dot-separated path. Scalars whose `value` is `None` are skipped.
///
/// # Errors
///
/// Propagates any error from [`update_yaml_file`] (reading, parsing, serializing or writing).
pub fn update_scalars<S: BuildHasher>(
    path: &Path,
    scalars: &HashMap<String, Variable, S>,
) -> ForgeResult<()> {
    let computed: HashMap<String, f64> = scalars
        .iter()
        .filter_map(|(key, variable)| variable.value.map(|number| (key.clone(), number)))
        .collect();
    update_yaml_file(path, &computed)
}
/// Sets the `value` entry addressed by the dot-separated `path` (e.g. `"summary.total"`) to
/// `new_value`; a path without dots names a top-level key. Paths that do not resolve are ignored.
fn update_value_in_yaml(yaml: &mut Value, path: &str, new_value: f64) {
    update_value_recursive(yaml, &path.split('.').collect::<Vec<&str>>(), 0, new_value);
}
/// Walks `yaml` along `path_parts`, starting at segment `index`, and overwrites the `value`
/// key of the entry named by the last segment.
///
/// Every segment must name an existing key of a mapping; the walk stops silently at the first
/// one that does not. The final entry is only updated when it is itself a mapping that already
/// contains a `value` key. The new number is stored as a float.
fn update_value_recursive(yaml: &mut Value, path_parts: &[&str], index: usize, new_value: f64) {
    let Some(&segment) = path_parts.get(index) else {
        return;
    };
    let Value::Mapping(map) = yaml else {
        return;
    };
    let Some(child) = map.get_mut(Value::String(segment.to_string())) else {
        return;
    };

    if index + 1 < path_parts.len() {
        // More segments remain: keep descending.
        update_value_recursive(child, path_parts, index + 1, new_value);
    } else if let Value::Mapping(leaf) = child {
        if leaf.contains_key(Value::String("value".to_string())) {
            leaf.insert(
                Value::String("value".to_string()),
                Value::Number(serde_yaml_ng::Number::from(new_value)),
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Column, Table};
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file holding `content`.
    fn temp_yaml(content: &str) -> NamedTempFile {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(content.as_bytes()).unwrap();
        temp_file
    }

    /// Removes the backup that `write_calculated_results` creates next to `path`.
    fn remove_backup(path: &Path) {
        let _ = fs::remove_file(path.with_extension("yaml.bak"));
    }

    #[test]
    fn test_update_simple_value() {
        let temp_file = temp_yaml(
            r#"
gross_margin:
  value: 0.0
  formula: "=1 - platform_take_rate"
"#,
        );
        let mut values = HashMap::new();
        values.insert("gross_margin".to_string(), 0.90);
        update_yaml_file(temp_file.path(), &values).unwrap();
        let updated_content = fs::read_to_string(temp_file.path()).unwrap();
        assert!(updated_content.contains("0.9"));
    }

    #[test]
    fn test_update_nested_value() {
        let temp_file = temp_yaml(
            r#"
summary:
  total:
    value: 0.0
    formula: "=SUM(data.values)"
"#,
        );
        let mut values = HashMap::new();
        values.insert("summary.total".to_string(), 150.0);
        update_yaml_file(temp_file.path(), &values).unwrap();
        let updated_content = fs::read_to_string(temp_file.path()).unwrap();
        assert!(updated_content.contains("150"));
    }

    #[test]
    fn test_update_scalars() {
        let temp_file = temp_yaml(
            r"
revenue:
  value: 0.0
  formula: null
",
        );
        let mut scalars = HashMap::new();
        let var = Variable::new("revenue".to_string(), Some(1000.0), None);
        scalars.insert("revenue".to_string(), var);
        update_scalars(temp_file.path(), &scalars).unwrap();
        let updated_content = fs::read_to_string(temp_file.path()).unwrap();
        assert!(updated_content.contains("1000"));
    }

    #[test]
    fn test_update_scalars_with_no_value() {
        let temp_file = temp_yaml(
            r"
revenue:
  value: 100.0
  formula: null
",
        );
        let mut scalars = HashMap::new();
        let var = Variable::new("revenue".to_string(), None, None);
        scalars.insert("revenue".to_string(), var);
        update_scalars(temp_file.path(), &scalars).unwrap();
        let updated_content = fs::read_to_string(temp_file.path()).unwrap();
        assert!(updated_content.contains("100"));
    }

    #[test]
    fn test_write_calculated_results_skips_multidoc() {
        let temp_file = temp_yaml("---\nfirst: 1\n---\nsecond: 2\n");
        let path = temp_file.path();
        let model = ParsedModel::new();
        let result = write_calculated_results(path, &model).unwrap();
        let backup_exists = path.with_extension("yaml.bak").exists();
        remove_backup(path);
        assert!(!result, "Multi-doc YAML should be skipped");
        assert!(!backup_exists, "Skipped files should not be backed up");
    }

    #[test]
    fn test_write_calculated_results_creates_backup() {
        let temp_file = temp_yaml(
            r#"
_forge_version: "5.0.0"
test_scalar:
  value: 100.0
  formula: null
"#,
        );
        let path = temp_file.path();
        let model = ParsedModel::new();
        let result = write_calculated_results(path, &model).unwrap();
        // Record the outcome and clean up before asserting so a failure does not leak files.
        let backup_exists = path.with_extension("yaml.bak").exists();
        remove_backup(path);
        assert!(result, "Single-doc YAML should be written");
        assert!(backup_exists, "Backup file should be created");
    }

    #[test]
    fn test_write_calculated_results_with_scalars() {
        let temp_file = temp_yaml(
            r#"
profit:
  value: 0.0
  formula: "=revenue - costs"
"#,
        );
        let path = temp_file.path();
        let mut model = ParsedModel::new();
        let var = Variable::new("profit".to_string(), Some(500.0), None);
        model.scalars.insert("profit".to_string(), var);
        write_calculated_results(path, &model).unwrap();
        let updated_content = fs::read_to_string(path).unwrap();
        remove_backup(path);
        assert!(updated_content.contains("500"));
    }

    #[test]
    fn test_write_calculated_results_with_tables() {
        let temp_file = temp_yaml(
            r"
financials:
  value: [0, 0, 0]
",
        );
        let path = temp_file.path();
        let mut model = ParsedModel::new();
        let mut table = Table::new("financials".to_string());
        table.add_column(Column::new(
            "value".to_string(),
            ColumnValue::Number(vec![100.0, 200.0, 300.0]),
        ));
        model.tables.insert("financials".to_string(), table);
        write_calculated_results(path, &model).unwrap();
        let updated_content = fs::read_to_string(path).unwrap();
        remove_backup(path);
        assert!(updated_content.contains("100"));
        assert!(updated_content.contains("200"));
        assert!(updated_content.contains("300"));
    }

    #[test]
    fn test_update_value_empty_path() {
        let mut yaml: Value = serde_yaml_ng::from_str("test: 1").unwrap();
        let original = yaml.clone();
        update_value_in_yaml(&mut yaml, "", 0.0);
        assert_eq!(yaml, original, "An empty path must not modify the document");
    }

    #[test]
    fn test_update_value_nonexistent_path() {
        let mut yaml: Value = serde_yaml_ng::from_str("test: 1").unwrap();
        let original = yaml.clone();
        update_value_in_yaml(&mut yaml, "nonexistent.path", 99.0);
        assert_eq!(yaml, original, "An unresolved path must not modify the document");
    }

    #[test]
    fn test_update_value_requires_value_key() {
        // Leaves that are plain scalars, or mappings without `value`, are left untouched.
        let mut yaml: Value = serde_yaml_ng::from_str("rate: 5\nnested:\n  other: 1\n").unwrap();
        let original = yaml.clone();
        update_value_in_yaml(&mut yaml, "rate", 1.0);
        update_value_in_yaml(&mut yaml, "nested", 1.0);
        assert_eq!(yaml, original);
    }

    #[test]
    fn test_update_value_in_memory_nested() {
        let mut yaml: Value =
            serde_yaml_ng::from_str("summary:\n  total:\n    value: 0.0\n").unwrap();
        update_value_in_yaml(&mut yaml, "summary.total", 150.0);
        assert_eq!(yaml["summary"]["total"]["value"].as_f64(), Some(150.0));
    }

    #[test]
    fn test_update_multiple_values() {
        let temp_file = temp_yaml(
            r"
revenue:
  value: 0.0
costs:
  value: 0.0
profit:
  value: 0.0
",
        );
        let mut values = HashMap::new();
        values.insert("revenue".to_string(), 1000.0);
        values.insert("costs".to_string(), 400.0);
        values.insert("profit".to_string(), 600.0);
        update_yaml_file(temp_file.path(), &values).unwrap();
        let updated_content = fs::read_to_string(temp_file.path()).unwrap();
        assert!(updated_content.contains("1000"));
        assert!(updated_content.contains("400"));
        assert!(updated_content.contains("600"));
    }

    #[test]
    fn test_write_results_fractional_values() {
        let temp_file = temp_yaml(
            r"
rate:
  value: 0.0
",
        );
        let path = temp_file.path();
        let mut model = ParsedModel::new();
        let var = Variable::new("rate".to_string(), Some(0.05), None);
        model.scalars.insert("rate".to_string(), var);
        write_calculated_results(path, &model).unwrap();
        let updated_content = fs::read_to_string(path).unwrap();
        remove_backup(path);
        assert!(updated_content.contains("0.05"));
    }

    #[test]
    fn test_write_results_integer_values() {
        let temp_file = temp_yaml(
            r"
counts:
  value: [0, 0]
",
        );
        let path = temp_file.path();
        let mut model = ParsedModel::new();
        let mut table = Table::new("counts".to_string());
        table.add_column(Column::new(
            "value".to_string(),
            ColumnValue::Number(vec![10.0, 20.0]),
        ));
        model.tables.insert("counts".to_string(), table);
        write_calculated_results(path, &model).unwrap();
        let updated_content = fs::read_to_string(path).unwrap();
        remove_backup(path);
        assert!(updated_content.contains("10"));
        assert!(updated_content.contains("20"));
        // Whole table values are emitted as integers, not floats.
        assert!(!updated_content.contains("10.0"));
        assert!(!updated_content.contains("20.0"));
    }
}