use cobre_core::{EntityId, System};
use cobre_stochastic::StochasticContext;
use crate::error::SddpError;
/// Environment variable that switches on the noise-key diagnostic.
///
/// Only the variable's presence is checked; its value is never inspected.
const DIAG_ENV_VAR: &str = "COBRE_W1_DIAG";

/// Owned copy of a noise-key table, indexed as `[stage][opening]`.
#[derive(Debug, Clone)]
pub struct NoiseKeyDiag {
    /// `keys[stage][omega]` is the scalar key of opening `omega` at `stage`.
    keys: Vec<Vec<f64>>,
}

impl NoiseKeyDiag {
    /// Copies `keys` into a new diagnostic when `COBRE_W1_DIAG` is set.
    ///
    /// Returns `None`, and copies nothing, when the variable is absent from
    /// the process environment. Any value, including an empty one, counts as
    /// "set".
    pub(crate) fn from_keys_if_enabled(keys: &[Vec<f64>]) -> Option<Self> {
        if std::env::var_os(DIAG_ENV_VAR).is_none() {
            return None;
        }
        let owned = keys.to_vec();
        Some(Self { keys: owned })
    }

    /// Looks up the key of opening `omega` at stage `stage`.
    ///
    /// Returns `None` when either index is out of range.
    #[must_use]
    pub(crate) fn key(&self, stage: usize, omega: usize) -> Option<f64> {
        let per_stage = self.keys.get(stage)?;
        per_stage.get(omega).copied()
    }
}
pub(crate) fn build_noise_key_table(
system: &System,
stochastic: &StochasticContext,
) -> Result<Vec<Vec<f64>>, SddpError> {
let n_hydros = stochastic.n_hydros();
let hydro_ids: Vec<EntityId> = system.hydros().iter().map(|h| h.id).collect();
let study_stage_ids: Vec<i32> = system
.stages()
.iter()
.filter(|s| s.id >= 0)
.map(|s| s.id)
.collect();
let n_stages = study_stage_ids.len();
let sigma = build_sigma_table(system, &hydro_ids, &study_stage_ids, n_hydros);
let tree = stochastic.tree_view();
let mut keys = Vec::with_capacity(n_stages);
for stage in 0..n_stages {
let n_openings = tree.n_openings(stage);
let sigma_stage = &sigma[stage * n_hydros..stage * n_hydros + n_hydros];
let mut stage_keys = Vec::with_capacity(n_openings);
for omega in 0..n_openings {
let raw_noise = tree.opening(stage, omega);
stage_keys.push(noise_key(sigma_stage, raw_noise)?);
}
keys.push(stage_keys);
}
Ok(keys)
}
/// Collapses one opening's raw noise vector into a scalar σ-weighted key.
///
/// Returns `Σ_h sigma[h] · raw_noise[h]` over the first `sigma.len()` noise
/// components. Any trailing noise components beyond `sigma.len()` are ignored.
///
/// # Errors
///
/// Returns `SddpError::Validation` when `raw_noise` is shorter than `sigma`,
/// rather than truncating σ or zero-padding the noise.
pub(crate) fn noise_key(sigma: &[f64], raw_noise: &[f64]) -> Result<f64, SddpError> {
    let needed = sigma.len();
    let available = raw_noise.len();
    if available >= needed {
        // Only the leading `needed` noise components carry a σ weight.
        let weighted: f64 = sigma
            .iter()
            .zip(&raw_noise[..needed])
            .map(|(s, n)| s * n)
            .sum();
        return Ok(weighted);
    }
    Err(SddpError::Validation(format!(
        "noise_key σ-layout mismatch: σ length {} exceeds opening noise dimension {}; \
         refusing to truncate or zero-pad",
        needed, available,
    )))
}
/// Builds a flat, row-major σ table with `study_stage_ids.len()` rows of
/// `n_hydros` columns.
///
/// Entry `[row * n_hydros + col]` is the `std_m3s` of the inflow model keyed
/// by `(hydro_ids[col].0, study_stage_ids[row])`, or `0.0` when there is no
/// such model. When several models share a key, the last one in
/// `system.inflow_models()` wins.
///
/// NOTE(review): this expects `hydro_ids.len() <= n_hydros`. Otherwise a
/// column index spills into the next row, or out of bounds.
fn build_sigma_table(
    system: &System,
    hydro_ids: &[EntityId],
    study_stage_ids: &[i32],
    n_hydros: usize,
) -> Vec<f64> {
    use std::collections::HashMap;

    // (hydro id, stage id) -> inflow standard deviation; later models overwrite earlier ones.
    let mut std_by_key: HashMap<(i32, i32), f64> = HashMap::new();
    for model in system.inflow_models().iter() {
        std_by_key.insert((model.hydro_id.0, model.stage_id), model.std_m3s);
    }

    let mut table = vec![0.0_f64; study_stage_ids.len() * n_hydros];
    for (row, stage_id) in study_stage_ids.iter().copied().enumerate() {
        let offset = row * n_hydros;
        for (col, hydro) in hydro_ids.iter().enumerate() {
            // Missing models keep their 0.0 default.
            if let Some(&deviation) = std_by_key.get(&(hydro.0, stage_id)) {
                table[offset + col] = deviation;
            }
        }
    }
    table
}
#[cfg(test)]
mod tests {
    use super::noise_key;

    /// 30·1.5 + 20·(−2) + 10·0.5 = 10.
    #[test]
    fn test_noise_key_sums_sigma_weighted_components() {
        let sigma = [30.0, 20.0, 10.0];
        let raw_noise = [1.5, -2.0, 0.5];
        let key = noise_key(&sigma, &raw_noise).expect("dims aligned");
        assert!((key - 10.0).abs() < 1e-12, "expected 10.0, got {key}");
    }

    /// Noise components beyond `sigma.len()` carry no weight: 2·1 + 4·1 = 6.
    #[test]
    fn test_noise_key_ignores_trailing_noise_components() {
        let sigma = [2.0, 4.0];
        let raw_noise = [1.0, 1.0, 100.0, -50.0];
        let key = noise_key(&sigma, &raw_noise).expect("dims aligned");
        assert!((key - 6.0).abs() < 1e-12, "expected 6.0, got {key}");
    }

    /// An empty σ never exceeds the noise dimension, so the key is an empty sum.
    #[test]
    fn test_noise_key_empty_sigma_yields_zero() {
        let raw_noise = [0.5, -1.0];
        let key = noise_key(&[], &raw_noise).expect("empty σ is always aligned");
        assert!(key.abs() < 1e-12, "expected 0.0, got {key}");
    }

    /// The error must name both lengths. Check the labelled values, not bare
    /// digits, which could match anywhere in the message.
    #[test]
    fn test_noise_key_hard_errors_on_sigma_longer_than_noise() {
        let sigma = [1.0, 2.0, 3.0];
        let raw_noise = [1.0, 1.0];
        let err = noise_key(&sigma, &raw_noise).expect_err("must reject mismatch");
        let msg = err.to_string();
        assert!(msg.contains("σ length 3"), "message must name σ length 3: {msg}");
        assert!(msg.contains("dimension 2"), "message must name noise dim 2: {msg}");
    }
}