use ndarray::{Array2, ArrayView2};
use rand::{seq::SliceRandom, Rng};
use crate::constants::CHANNEL_GROUPS;
use crate::data::preprocessing::{channel_stats, denormalized};
use crate::data::captioning::templates::LOW_LEVEL_TEMPLATES;
use crate::error::Result;
/// Builds a multi-line caption describing per-channel summary statistics of a window.
///
/// `x_norm` is converted back to physical units with [`denormalized`] and, when `mask` is
/// provided, masked in place with `apply_mask`. Per-channel statistics are then computed with
/// [`channel_stats`] and unpacked as `(mean, max, min, std)`.
///
/// For every group in [`CHANNEL_GROUPS`]:
/// - each channel in `group.primary` is described, in declaration order;
/// - when `group.random_k > 0` and `group.random` is non-empty, up to `group.random_k`
///   channels are sampled from `group.random` without repetition and described as well.
///
/// Channels whose statistics contain any NaN are skipped. A group with at least one described
/// channel contributes the line `"For {category}, {sentence} {sentence} ...\n"`. Groups with
/// no usable channels are omitted entirely.
///
/// Template choice and channel sampling draw only from `rng`, so re-running with an
/// identically seeded RNG yields the same caption.
///
/// # Errors
///
/// Propagates any error returned by [`denormalized`] or `apply_mask`.
///
/// # Panics
///
/// Indexing `stats[ch_idx]` may panic if a group references a channel index that
/// [`channel_stats`] did not produce (exact behavior depends on the type it returns).
pub fn generate_statistical_caption<R: Rng>(
    x_norm: &ArrayView2<f64>,
    mask: Option<&Array2<u8>>,
    rng: &mut R,
) -> Result<String> {
    use crate::data::preprocessing::apply_mask;
    use std::fmt::Write as _;

    let mut x_phys = denormalized(x_norm)?;
    if let Some(m) = mask {
        apply_mask(&mut x_phys, m)?;
    }
    let stats = channel_stats(&x_phys);

    let mut caption = String::new();
    for group in CHANNEL_GROUPS {
        let mut group_parts = Vec::new();
        // Keep this order stable (primary channels first, then the sampled ones): every
        // described channel consumes RNG state for its template choice, so reordering would
        // change the caption produced for a given seed.
        for &(display_name, ch_idx) in group.primary {
            group_parts.extend(describe_channel(display_name, stats[ch_idx], rng));
        }
        if group.random_k > 0 && !group.random.is_empty() {
            // The iterator returned by `choose_multiple` borrows only `group.random`, not
            // `rng`, so the loop body is free to keep drawing from `rng`.
            for &(display_name, ch_idx) in group.random.choose_multiple(rng, group.random_k) {
                group_parts.extend(describe_channel(display_name, stats[ch_idx], rng));
            }
        }
        if !group_parts.is_empty() {
            writeln!(caption, "For {}, {}", group.category, group_parts.join(" "))
                .expect("writing to a String cannot fail");
        }
    }
    Ok(caption)
}

/// Describes one channel from its `(mean, max, min, std)` statistics, or returns `None` when
/// any of the four values is NaN so that the channel is left out of the caption.
fn describe_channel<R: Rng>(
    display_name: &str,
    (mean, max, min, std): (f64, f64, f64, f64),
    rng: &mut R,
) -> Option<String> {
    if [mean, max, min, std].iter().any(|v| v.is_nan()) {
        return None;
    }
    Some(describe_low_level(display_name, mean, max, min, std, rng))
}
/// Renders one channel's summary sentence from a randomly chosen entry of
/// `LOW_LEVEL_TEMPLATES`.
///
/// In the chosen template, `{name}` is replaced by `name`. The literal placeholders
/// `{mean_val:.1}`, `{max_val:.1}`, `{min_val:.1}` and `{std_val:.1}` are replaced by the
/// corresponding statistic formatted with one decimal place. Placeholders written in any
/// other form are left untouched. Consumes RNG state for the template choice.
///
/// # Panics
///
/// Panics if `LOW_LEVEL_TEMPLATES` is empty.
fn describe_low_level<R: Rng>(
    name: &str,
    mean_val: f64,
    max_val: f64,
    min_val: f64,
    std_val: f64,
    rng: &mut R,
) -> String {
    // `choose` returns `None` only for an empty slice, where no fallback template exists
    // either. State that invariant explicitly instead of an eagerly-evaluated
    // `LOW_LEVEL_TEMPLATES[0]` fallback that would itself panic in exactly that case.
    let tmpl = LOW_LEVEL_TEMPLATES
        .choose(rng)
        .copied()
        .expect("LOW_LEVEL_TEMPLATES must contain at least one template");
    tmpl.replace("{name}", name)
        .replace("{mean_val:.1}", &format!("{mean_val:.1}"))
        .replace("{max_val:.1}", &format!("{max_val:.1}"))
        .replace("{min_val:.1}", &format!("{min_val:.1}"))
        .replace("{std_val:.1}", &format!("{std_val:.1}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::SeedableRng;
    use rand::rngs::StdRng;
    use crate::constants::NUM_CHANNELS;

    /// An all-zero window produces a non-empty caption that includes the Heart group.
    #[test]
    fn test_statistical_caption_runs() {
        let x = Array2::<f64>::zeros((1440, NUM_CHANNELS));
        let mut rng = StdRng::seed_from_u64(42);
        let cap = generate_statistical_caption(&x.view(), None, &mut rng).unwrap();
        assert!(!cap.is_empty(), "Caption must be non-empty");
        assert!(cap.contains("Heart"), "Caption must mention Heart group");
    }

    /// Template choice and channel sampling come from the caller's RNG, so two RNGs seeded
    /// identically must yield the same caption for the same input.
    #[test]
    fn test_statistical_caption_reproducible_with_same_seed() {
        let x = Array2::<f64>::zeros((1440, NUM_CHANNELS));
        let first =
            generate_statistical_caption(&x.view(), None, &mut StdRng::seed_from_u64(7)).unwrap();
        let second =
            generate_statistical_caption(&x.view(), None, &mut StdRng::seed_from_u64(7)).unwrap();
        assert_eq!(first, second, "Same seed must produce the same caption");
    }
}