use anyhow::{Context, Result};
use chrono::{Duration, NaiveDateTime};
use indicatif::{ProgressBar, ProgressStyle};
use std::collections::HashMap;
use std::sync::Arc;
mod cli;
mod config;
mod io;
mod network;
mod routing;
mod state;
pub mod kernel {
pub mod muskingum;
}
use cli::get_args;
use config::{ChannelParams, ColumnConfig, OutputFormat};
use io::netcdf::init_netcdf_output;
use network::build_network_topology;
use routing::process_routing_parallel;
static OUTPUT_TYPE: &str = "NetCDF";
fn main() -> Result<()> {
let config: cli::Config = get_args()?;
run_routing(config, false)
}
fn run_routing(config: cli::Config, quiet: bool) -> Result<()> {
let dt: f32 = config.internal_timestep_seconds as f32;
let db_path: std::path::PathBuf = config.gpkg_file;
let csv_dir: std::path::PathBuf = config.csv_dir;
let _config_dir: std::path::PathBuf = config.config_dir;
let output_format: OutputFormat = match OUTPUT_TYPE {
"CSV" => OutputFormat::Csv,
"NetCDF" => OutputFormat::NetCdf,
"Both" => OutputFormat::Both,
_ => return Err(anyhow::anyhow!("Invalid output type: {}", OUTPUT_TYPE)),
};
let conn = rusqlite::Connection::open(&db_path)
.with_context(|| format!("Failed to open database: {:?}", db_path))?;
let column_config = ColumnConfig::new();
println!("Building network topology...");
let topology = build_network_topology(&conn, &column_config, &csv_dir)?;
println!("Loading channel parameters...");
let channel_params_map = network::load_channel_parameters(&conn, &topology, &column_config)?;
let csv_writer = if matches!(output_format, OutputFormat::Csv | OutputFormat::Both) {
Some(io::csv::create_csv_writer("network_routing_results.csv")?)
} else {
None
};
let (max_external_steps, reference_time) =
get_simulation_params(&csv_dir, &channel_params_map)?;
let start_time = reference_time;
let end_time = start_time + Duration::seconds((3600 * max_external_steps) as i64);
let external_timestep_seconds = 3600;
let downsampling = external_timestep_seconds / config.internal_timestep_seconds;
let total_timesteps = max_external_steps * downsampling;
println!("\nSimulation Configuration:");
println!(" Period: {} to {}", start_time, end_time);
println!(
" Internal timestep: {} seconds",
config.internal_timestep_seconds
);
println!(" Network nodes: {}", topology.routing_order.len());
println!(" Total timesteps: {}", total_timesteps);
let timesteps: Vec<f64> = (1..=max_external_steps)
.map(|step| (step * external_timestep_seconds) as f64)
.collect();
let nc_filename = format!("troute_output_{}.nc", reference_time.format("%Y%m%d%H%M"));
let netcdf_writer = init_netcdf_output(
config.output_dir,
&nc_filename,
topology.routing_order.len(),
timesteps,
&reference_time,
)?;
let pb = ProgressBar::new(topology.routing_order.len() as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} nodes ({eta})")?
.progress_chars("#>-")
);
if quiet {
pb.set_draw_target(indicatif::ProgressDrawTarget::hidden());
}
println!("\nStarting parallel wave-front routing...");
process_routing_parallel(
config.kernel,
Arc::new(topology),
Arc::new(channel_params_map),
total_timesteps,
dt,
downsampling,
netcdf_writer,
Arc::new(pb),
config.num_threads,
)?;
if let Some(mut wtr) = csv_writer {
wtr.flush().context("Failed to flush CSV writer")?;
println!("CSV results saved to network_routing_results.csv");
}
println!(
"\nNetwork routing complete. Output saved to {}",
nc_filename
);
Ok(())
}
fn get_simulation_params(
csv_dir: &std::path::PathBuf,
features: &HashMap<u32, ChannelParams>,
) -> Result<(usize, NaiveDateTime)> {
let first_id = features
.keys()
.next()
.ok_or_else(|| anyhow::anyhow!("No features found"))?;
let file_name = csv_dir.join(format!("cat-{}.csv", first_id));
let content = std::fs::read_to_string(&file_name)
.with_context(|| format!("Failed to read file: {:?}", file_name))?;
if content.lines().count() == 0 {
return Err(anyhow::anyhow!("CSV file is empty: {:?}", file_name))
.with_context(|| format!("Failed to read CSV file: {:?}", file_name));
} else if content.lines().count() == 1 {
return Err(anyhow::anyhow!(
"CSV file only contains header: {:?}",
file_name
))
.with_context(|| format!("Failed to read CSV file: {:?}", file_name));
}
let max_external_steps = content.lines().count().saturating_sub(2);
let line = content
.lines()
.nth(1)
.with_context(|| format!("Failed to read second line of CSV file: {:?}", file_name))?;
let time = line
.split(',')
.nth(1)
.with_context(|| format!("Failed to parse time from CSV line: {:?}", line))?;
let reference_time = NaiveDateTime::parse_from_str(time, "%Y-%m-%d %H:%M:%S")
.context("Failed to parse reference time")?;
Ok((max_external_steps, reference_time))
}
#[cfg(test)]
mod tests {
use crate::kernel::muskingum;
use crate::kernel::muskingum::MuskingumCungeKernel;
use super::*;
fn setup_test_config() -> cli::Config {
cli::Config {
config_dir: std::path::PathBuf::from("./tests/one_cat/config"),
csv_dir: std::path::PathBuf::from("./tests/one_cat/outputs/ngen"),
gpkg_file: std::path::PathBuf::from("./tests/one_cat/config/cat-486888_subset.gpkg"),
internal_timestep_seconds: 300,
output_dir: std::path::PathBuf::from("./tests/one_cat/outputs/troute"),
kernel: muskingum::MuskingumCungeKernel::TRouteModernized,
num_threads: 1,
}
}
#[test]
fn test_get_simulation_params() {
let config: cli::Config = setup_test_config();
let conn: rusqlite::Connection = rusqlite::Connection::open(&config.gpkg_file).unwrap();
let column_config: ColumnConfig = ColumnConfig::new();
let topology: network::NetworkTopology =
build_network_topology(&conn, &column_config, &config.csv_dir).unwrap();
let channel_params_map: HashMap<u32, ChannelParams> =
network::load_channel_parameters(&conn, &topology, &column_config).unwrap();
let (max_external_steps, reference_time) =
get_simulation_params(&config.csv_dir, &channel_params_map).unwrap();
assert_eq!(max_external_steps, 24);
assert_eq!(
reference_time,
NaiveDateTime::parse_from_str("2010-01-01 00:00:00", "%Y-%m-%d %H:%M:%S").unwrap()
);
}
#[test]
fn test_invalid_csv_file() {
let config: cli::Config = setup_test_config();
let conn: rusqlite::Connection = rusqlite::Connection::open(&config.gpkg_file).unwrap();
let column_config: ColumnConfig = ColumnConfig::new();
let topology: network::NetworkTopology =
build_network_topology(&conn, &column_config, &config.csv_dir).unwrap();
let channel_params_map: HashMap<u32, ChannelParams> =
network::load_channel_parameters(&conn, &topology, &column_config).unwrap();
let invalid_csv_dir = std::path::PathBuf::from("./tests/one_cat/outputs/invalid_csv");
let result: std::result::Result<(usize, NaiveDateTime), anyhow::Error> =
get_simulation_params(&invalid_csv_dir, &channel_params_map);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("Failed to read file:")
);
}
#[test]
fn test_empty_features() {
let empty_features: HashMap<u32, ChannelParams> = HashMap::new();
let result = get_simulation_params(
&std::path::PathBuf::from("./tests/one_cat/outputs/ngen"),
&empty_features,
);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("No features found")
);
}
#[test]
fn test_invalid_feature() {
let config: cli::Config = setup_test_config();
let mut invalid_features: HashMap<u32, ChannelParams> = HashMap::new();
invalid_features.insert(
999999,
ChannelParams {
dx: 0.0,
n: 0.0,
ncc: 0.0,
s0: 0.0,
bw: 0.0,
tw: 0.0,
twcc: 0.0,
cs: 0.0,
},
);
let result = get_simulation_params(&config.csv_dir, &invalid_features);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("Failed to read file:")
);
}
#[test]
fn test_run_routing() {
let tmp_dir = std::env::temp_dir().join("rs_route_test_run_routing");
std::fs::create_dir_all(&tmp_dir).unwrap();
let config = cli::Config {
output_dir: tmp_dir.clone(),
..setup_test_config()
};
let result = run_routing(config, true);
let _ = std::fs::remove_dir_all(&tmp_dir);
assert!(result.is_ok());
}
fn make_test_input() -> muskingum::MuskingumCungeInput {
muskingum::MuskingumCungeInput {
dt: 300.0,
qup: 5.0,
quc: 6.0,
qdp: 4.5,
ql: 0.5,
dx: 5000.0,
bw: 10.0,
tw: 100.0,
tw_cc: 120.0,
n: 0.06,
n_cc: 0.12,
cs: 1.0,
s0: 0.001,
velp: 0.0,
depthp: 0.5,
}
}
fn assert_results_close(
a: &muskingum::MuskingumCungeResult,
b: &muskingum::MuskingumCungeResult,
tolerance: f32,
label: &str,
) {
let flow_diff = if a.qdc == 0.0 && b.qdc == 0.0 {
0.0
} else {
((a.qdc - b.qdc) / a.qdc).abs()
};
let vel_diff = if a.velc == 0.0 && b.velc == 0.0 {
0.0
} else {
((a.velc - b.velc) / a.velc).abs()
};
let depth_diff = if a.depthc == 0.0 && b.depthc == 0.0 {
0.0
} else {
((a.depthc - b.depthc) / a.depthc).abs()
};
assert!(
flow_diff < tolerance,
"{}: flow differs by {:.1}% (got {}, expected {})",
label,
flow_diff * 100.0,
b.qdc,
a.qdc,
);
assert!(
vel_diff < tolerance,
"{}: velocity differs by {:.1}% (got {}, expected {})",
label,
vel_diff * 100.0,
b.velc,
a.velc,
);
assert!(
depth_diff < tolerance,
"{}: depth differs by {:.1}% (got {}, expected {})",
label,
depth_diff * 100.0,
b.depthc,
a.depthc,
);
}
#[test]
fn test_kernel_equivalence_fortran_kernels() {
let input = make_test_input();
let f_mod = MuskingumCungeKernel::TRouteModernized.exec(&input, false);
let f_leg = MuskingumCungeKernel::TRouteLegacy.exec(&input, false);
let tolerance = 0.0001;
assert_results_close(&f_mod, &f_leg, tolerance, "Fortran Modernized vs Legacy");
}
#[test]
fn test_kernel_equivalence_rust_vs_fortran() {
let input = make_test_input();
let rs = MuskingumCungeKernel::RouteRs.exec(&input, false);
let f_mod = MuskingumCungeKernel::TRouteModernized.exec(&input, false);
let tolerance = 0.0001; assert_results_close(&rs, &f_mod, tolerance, "RouteRs vs TRouteModernized");
}
#[test]
fn test_c_kernel_produces_output() {
let input = make_test_input();
let result = MuskingumCungeKernel::CMuskingumCunge.exec(&input, false);
assert!(
result.qdc > 0.0,
"C kernel flow should be positive, got {}",
result.qdc
);
assert!(
result.velc > 0.0,
"C kernel velocity should be positive, got {}",
result.velc
);
assert!(
result.depthc > 0.0,
"C kernel depth should be positive, got {}",
result.depthc
);
}
#[test]
fn test_kernel_equivalence_full_routing() {
let kernels = [
MuskingumCungeKernel::RouteRs,
MuskingumCungeKernel::TRouteModernized,
MuskingumCungeKernel::TRouteLegacy,
];
let mut flow_sums: Vec<(MuskingumCungeKernel, f64)> = Vec::new();
for kernel in kernels {
let tmp_dir = std::env::temp_dir().join(format!("rs_route_test_{:?}", kernel));
std::fs::create_dir_all(&tmp_dir).unwrap();
let config = cli::Config {
kernel,
output_dir: tmp_dir.clone(),
..setup_test_config()
};
let result = run_routing(config, true);
assert!(
result.is_ok(),
"Kernel {:?} failed: {:?}",
kernel,
result.err()
);
let nc_path = tmp_dir.join("troute_output_201001010000.nc");
println!("{}", nc_path.display());
let file = netcdf::open(&nc_path).unwrap();
let flow_var = file.variable("flow").unwrap();
let flow_data = flow_var.get::<f32, _>(..).unwrap();
let total_flow: f64 = flow_data.iter().map(|&v| v as f64).sum();
flow_sums.push((kernel, total_flow));
let _ = std::fs::remove_dir_all(&tmp_dir);
}
let tolerance = 0.0001; let (ref_name, ref_flow) = flow_sums[0];
for &(name, flow) in &flow_sums[1..] {
let diff = ((flow - ref_flow) / ref_flow).abs();
assert!(
diff < tolerance,
"{:?} vs {:?}: total flow differs by {:.1}% ({} vs {})",
name,
ref_name,
diff * 100.0,
flow,
ref_flow,
);
}
}
#[test]
fn test_run_invalid_config() {
let invalid_config = cli::Config {
config_dir: std::path::PathBuf::from("./tests/invalid_test/config"),
csv_dir: std::path::PathBuf::from("./tests/invalid_test/outputs/ngen"),
gpkg_file: std::path::PathBuf::from(
"./tests/invalid_test/config/cat-486888_subset.gpkg",
),
internal_timestep_seconds: 300,
output_dir: std::path::PathBuf::from("./tests/invalid_test/outputs/troute"),
kernel: muskingum::MuskingumCungeKernel::TRouteModernized,
num_threads: 1,
};
let result = run_routing(invalid_config, true);
assert!(result.is_err());
assert!(
result
.err()
.unwrap()
.to_string()
.contains("Failed to open database:")
);
}
}