use ndarray::{Array2, ArrayView2};
use rand::{seq::SliceRandom, Rng};
use crate::constants::CHANNEL_GROUPS;
use crate::data::captioning::templates::{ANOMALY_TEMPLATES, TREND_TEMPLATES};
use crate::data::preprocessing::{average_downsample_ct, denormalized};
use crate::error::Result;
/// Builds a textual caption describing trends and spikes/drops per channel group.
///
/// The input is passed through `denormalized`, transposed, and reduced with
/// `average_downsample_ct` to `TARGET_T` columns. For every entry of
/// `CHANNEL_GROUPS`, each primary channel row is analysed with
/// `identify_trends` and `detect_peaks_valleys`, and every finding is rendered
/// into a sentence through a randomly chosen template.
///
/// One line is emitted per group, formatted as `"<category>: <sentences>\n"`,
/// with the sentences joined by single spaces.
///
/// # Parameters
/// * `x_norm` - normalized input matrix (presumably `(time, channels)`, since
///   rows of the transposed matrix are indexed by channel - TODO confirm).
/// * `max_per_category` - upper bound on the number of sentences per group.
/// * `rng` - randomness source for template choice and sub-sampling.
///
/// # Errors
/// Propagates any error returned by `denormalized`.
pub fn generate_structural_caption<R: Rng>(
    x_norm: &ArrayView2<f64>,
    max_per_category: usize,
    rng: &mut R,
) -> Result<String> {
    // Number of time steps kept after average down-sampling.
    const TARGET_T: usize = 36;
    // Multiplier turning a down-sampled index into the reported time label.
    // NOTE(review): 36 * 40 = 1440, so this presumably assumes 1440 input
    // samples - confirm against callers.
    const DOWNSAMPLE_SCALE: usize = 40;

    let x_phys = denormalized(x_norm)?;
    let ct: Array2<f64> = x_phys.t().to_owned();
    let ct_ds = average_downsample_ct(&ct, TARGET_T);

    let mut caption = String::new();
    for group in CHANNEL_GROUPS {
        // (time label, rendered sentence) for every finding in this group.
        let mut insights: Vec<(usize, String)> = Vec::new();

        for &(display_name, ch_idx) in group.primary {
            // Channels that are not present in the matrix are skipped silently.
            if ch_idx < ct_ds.nrows() {
                let series: Vec<f64> = ct_ds.row(ch_idx).iter().copied().collect();

                for (start, end, kind, ..) in identify_trends(&series, DOWNSAMPLE_SCALE) {
                    let sentence = describe_trend(display_name, &kind, start, end, rng);
                    insights.push((start, sentence));
                }

                for (minute, kind) in detect_peaks_valleys(&series, DOWNSAMPLE_SCALE) {
                    let sentence = describe_anomaly(display_name, &kind, minute, rng);
                    insights.push((minute, sentence));
                }
            }
        }

        // Too many findings: keep a random subset and order it by time label.
        // NOTE(review): groups at or below the limit keep detection order and
        // are not sorted by time - confirm this is intended.
        if insights.len() > max_per_category {
            insights.shuffle(rng);
            insights.truncate(max_per_category);
            insights.sort_by_key(|entry| entry.0);
        }

        let sentences = insights
            .iter()
            .map(|(_, text)| text.as_str())
            .collect::<Vec<&str>>()
            .join(" ");
        caption += &format!("{}: {}\n", group.category, sentences);
    }

    Ok(caption)
}
/// One trend segment as produced by `identify_trends`, laid out as
/// `(start, end, kind, slope, score, window_len)`:
///
/// * `start` / `end` - window bounds, already multiplied by the down-sample scale.
/// * `kind` - `"increasing"`, `"decreasing"` or `"stable"`.
/// * `slope` - least-squares slope of the window (per down-sampled step).
/// * `score` - value used for ranking: the absolute first-to-last change for
///   increasing/decreasing segments, the last sample of the window for stable ones.
/// * `window_len` - number of down-sampled samples in the window.
type TrendResult = (usize, usize, String, f64, f64, usize);
/// Finds up to three mutually non-overlapping trend segments in `data`.
///
/// The series is scanned with half-overlapping windows of 6, 8 and 12 samples.
/// Each window is classified from its regression slope and its first-to-last
/// change, both measured relative to the full value range of `data`:
///
/// * `"increasing"` / `"decreasing"` - slope beyond the window-specific
///   threshold and a change larger than 20 % of the range;
/// * `"stable"` - near-zero slope and a window spread below 10 % of the range.
///
/// Candidates are ranked by their score (field 4 of [`TrendResult`]) in
/// descending order and picked greedily, rejecting any candidate that overlaps
/// an already chosen segment by more than 30 % of the shorter one.
///
/// `downsample_scale` converts sample indices into the reported start/end
/// labels. Returns an empty vector for empty input.
fn identify_trends(data: &[f64], downsample_scale: usize) -> Vec<TrendResult> {
    // (window length in samples, multiplier applied to the slope threshold)
    const WINDOW_CONFIGS: [(usize, f64); 3] = [(6, 1.5), (8, 1.3), (12, 1.0)];
    const MAX_TRENDS: usize = 3;
    const MAX_OVERLAP: f64 = 0.3;

    if data.is_empty() {
        return Vec::new();
    }

    let (lowest, highest) = data
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    // Floor the range so a constant series does not yield zero thresholds.
    let range = (highest - lowest).max(1e-9);
    let stable_threshold = 0.01 * range;

    let mut candidates: Vec<TrendResult> = Vec::new();
    for &(seg, scale) in &WINDOW_CONFIGS {
        let slope_thresh = scale * range / 40.0;

        // Windows start at 0, seg/2, seg, ... as long as they fit in `data`.
        for (start, window) in data.windows(seg).enumerate().step_by(seg / 2) {
            let slope = linear_regression_slope(window);
            let last = window[seg - 1];
            let delta = last - window[0];
            let start_label = (start + 1) * downsample_scale;
            let end_label = (start + seg) * downsample_scale;

            let classified: Option<(&str, f64)> =
                if slope > slope_thresh && delta > 0.2 * range {
                    Some(("increasing", delta))
                } else if slope < -slope_thresh && -delta > 0.2 * range {
                    Some(("decreasing", -delta))
                } else if slope.abs() < stable_threshold {
                    let (seg_lo, seg_hi) = window
                        .iter()
                        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                            (lo.min(v), hi.max(v))
                        });
                    if seg_hi - seg_lo < 0.1 * range {
                        // Stable segments carry their last value as the score.
                        Some(("stable", last))
                    } else {
                        None
                    }
                } else {
                    None
                };

            if let Some((label, score)) = classified {
                candidates.push((start_label, end_label, label.to_string(), slope, score, seg));
            }
        }
    }

    // Highest score first; incomparable (NaN) scores are treated as equal.
    candidates.sort_by(|a, b| b.4.partial_cmp(&a.4).unwrap_or(std::cmp::Ordering::Equal));

    let mut selected: Vec<TrendResult> = Vec::new();
    for cand in candidates {
        if selected.len() == MAX_TRENDS {
            break;
        }
        let clashes = selected
            .iter()
            .any(|kept| overlap_fraction(cand.0, cand.1, kept.0, kept.1) > MAX_OVERLAP);
        if !clashes {
            selected.push(cand);
        }
    }
    selected
}
/// Returns how much of the shorter interval is covered by the intersection of
/// `[s1, e1]` and `[s2, e2]`, as a value in `0.0..=1.0`.
///
/// Disjoint intervals and a zero-length shorter interval both give `0.0`.
/// Each interval is expected to satisfy `start <= end`.
fn overlap_fraction(s1: usize, e1: usize, s2: usize, e2: usize) -> f64 {
    let shorter = (e1 - s1).min(e2 - s2);
    if shorter == 0 {
        return 0.0;
    }
    // Saturates to zero when the intervals do not intersect.
    let shared = e1.min(e2).saturating_sub(s1.max(s2));
    shared as f64 / shorter as f64
}
/// Least-squares slope of `y` against the sample indices `0, 1, 2, ...`.
///
/// Returns `0.0` when the slope is undefined, i.e. for empty or
/// single-element input where the index variance is zero.
fn linear_regression_slope(y: &[f64]) -> f64 {
    let count = y.len();
    let n = count as f64;
    let x_mean = (n - 1.0) / 2.0;
    let centred_x = |i: usize| i as f64 - x_mean;

    // Sum of squared index deviations; zero means there is no spread in x.
    let den: f64 = (0..count).map(|i| centred_x(i).powi(2)).sum();
    if den == 0.0 {
        return 0.0;
    }

    let y_mean = y.iter().sum::<f64>() / n;
    let num: f64 = y
        .iter()
        .enumerate()
        .map(|(i, &yi)| centred_x(i) * (yi - y_mean))
        .sum();
    num / den
}
/// Detects pronounced spikes and drops in `data`.
///
/// Spikes are local maxima found by `find_peaks` on `data`; drops are local
/// maxima found on the negated series. Both use a prominence threshold of half
/// the value range and a minimum spacing of five samples.
///
/// Returns `(label, kind)` pairs where `label` is `(index + 1) * downsample_scale`
/// and `kind` is `"spike"` or `"drop"`. All spikes come first, followed by all
/// drops. Series shorter than three samples yield an empty vector.
fn detect_peaks_valleys(data: &[f64], downsample_scale: usize) -> Vec<(usize, String)> {
    const PROMINENCE_THRESHOLD: f64 = 0.5;
    const HEIGHT_THRESHOLD: f64 = 0.6;
    const DISTANCE: usize = 5;

    if data.len() < 3 {
        return Vec::new();
    }

    let (min_v, max_v) = data
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
            (lo.min(v), hi.max(v))
        });
    let mean_v: f64 = data.iter().sum::<f64>() / data.len() as f64;
    let range = (max_v - min_v).max(1e-9);

    let prom_thresh = PROMINENCE_THRESHOLD * range;
    // A spike must reach at least `mean + 0.6 * range`.
    let height_thresh = HEIGHT_THRESHOLD * range + mean_v;
    // Applied to the negated series: a drop must be at or below `mean + 0.4 * range`.
    // NOTE(review): this is not the mirror image of the spike threshold
    // (`mean - 0.6 * range`) - confirm it is intended.
    let valley_thresh = -(mean_v + (1.0 - HEIGHT_THRESHOLD) * range);

    let label_of = |idx: usize| (idx + 1) * downsample_scale;

    let mut results: Vec<(usize, String)> =
        find_peaks(data, prom_thresh, Some(height_thresh), DISTANCE)
            .into_iter()
            .map(|idx| (label_of(idx), "spike".to_string()))
            .collect();

    let negated: Vec<f64> = data.iter().map(|&v| -v).collect();
    results.extend(
        find_peaks(&negated, prom_thresh, Some(valley_thresh), DISTANCE)
            .into_iter()
            .map(|idx| (label_of(idx), "drop".to_string())),
    );

    results
}
/// Returns the indices of strict local maxima in `data` that pass the given filters.
///
/// A sample qualifies when it is strictly greater than both neighbours, is not
/// below `height_threshold` (if one is given), and has a prominence of at least
/// `prominence_threshold`. Prominence is measured here as the peak value minus
/// the larger of the two global minima on its left and right side.
///
/// Qualifying peaks are then thinned greedily from highest to lowest: a peak is
/// kept only if it lies at least `min_distance` samples away from every peak
/// already kept. The result is therefore ordered by descending peak height,
/// not by index.
///
/// Inputs with fewer than three samples (including empty input) cannot contain
/// an interior local maximum and yield an empty vector.
fn find_peaks(
    data: &[f64],
    prominence_threshold: f64,
    height_threshold: Option<f64>,
    min_distance: usize,
) -> Vec<usize> {
    // Guard explicitly: the previous `1..n - 1` range underflowed for empty input.
    if data.len() < 3 {
        return Vec::new();
    }

    let mut candidates: Vec<(usize, f64)> = Vec::new();
    for (offset, triple) in data.windows(3).enumerate() {
        let idx = offset + 1;
        let value = triple[1];

        let is_local_max = value > triple[0] && value > triple[2];
        if !is_local_max {
            continue;
        }
        if matches!(height_threshold, Some(min_height) if value < min_height) {
            continue;
        }

        let left_min = data[..idx].iter().copied().fold(f64::INFINITY, f64::min);
        let right_min = data[idx + 1..].iter().copied().fold(f64::INFINITY, f64::min);
        let prominence = value - left_min.max(right_min);
        if prominence >= prominence_threshold {
            candidates.push((idx, value));
        }
    }

    // Tallest first; the stable sort keeps index order among equal heights.
    candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let mut selected: Vec<usize> = Vec::new();
    for (idx, _) in candidates {
        let far_enough = selected.iter().all(|&kept| idx.abs_diff(kept) >= min_distance);
        if far_enough {
            selected.push(idx);
        }
    }
    selected
}
/// Renders one trend sentence from a randomly chosen entry of `TREND_TEMPLATES`.
///
/// The placeholders `{sensor_name}`, `{trend_type}`, `{start}` and `{end}` in
/// the template are replaced, in that order, with the corresponding arguments.
///
/// # Panics
/// Panics if `TREND_TEMPLATES` is empty. The table is a compile-time constant,
/// so an empty table is a programming error rather than a runtime condition.
fn describe_trend<R: Rng>(
    sensor_name: &str,
    trend_type: &str,
    start: usize,
    end: usize,
    rng: &mut R,
) -> String {
    // `choose` yields `None` only for an empty slice, so the former
    // `unwrap_or(TREND_TEMPLATES[0])` fallback could never help: it would
    // itself panic with an opaque index error. State the invariant instead.
    let tmpl = TREND_TEMPLATES
        .choose(rng)
        .copied()
        .expect("TREND_TEMPLATES must contain at least one template");

    tmpl.replace("{sensor_name}", sensor_name)
        .replace("{trend_type}", trend_type)
        .replace("{start}", &start.to_string())
        .replace("{end}", &end.to_string())
}
/// Renders one anomaly sentence from a randomly chosen entry of `ANOMALY_TEMPLATES`.
///
/// The placeholders `{sensor_name}`, `{anomaly}` and `{time}` in the template
/// are replaced, in that order, with the corresponding arguments.
///
/// # Panics
/// Panics if `ANOMALY_TEMPLATES` is empty. The table is a compile-time
/// constant, so an empty table is a programming error rather than a runtime
/// condition.
fn describe_anomaly<R: Rng>(
    sensor_name: &str,
    anomaly: &str,
    time: usize,
    rng: &mut R,
) -> String {
    // `choose` yields `None` only for an empty slice, so the former
    // `unwrap_or(ANOMALY_TEMPLATES[0])` fallback could never help: it would
    // itself panic with an opaque index error. State the invariant instead.
    let tmpl = ANOMALY_TEMPLATES
        .choose(rng)
        .copied()
        .expect("ANOMALY_TEMPLATES must contain at least one template");

    tmpl.replace("{sensor_name}", sensor_name)
        .replace("{anomaly}", anomaly)
        .replace("{time}", &time.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::{rngs::StdRng, SeedableRng};
    use crate::constants::NUM_CHANNELS;

    /// Smoke test: an all-zero input must still produce a non-empty caption.
    #[test]
    fn test_structural_caption_runs() {
        let zeros: Array2<f64> = Array2::zeros((1440, NUM_CHANNELS));
        let mut seeded = StdRng::seed_from_u64(7);
        let caption = generate_structural_caption(&zeros.view(), 7, &mut seeded).unwrap();
        assert!(!caption.is_empty());
    }

    /// A unit ramp has a regression slope of exactly one.
    #[test]
    fn test_linreg_slope() {
        let ramp: Vec<f64> = (0..10).map(f64::from).collect();
        let slope = linear_regression_slope(&ramp);
        assert!((slope - 1.0).abs() < 1e-9, "slope should be 1.0, got {slope}");
    }

    /// Two well-separated unit bumps are both reported as peaks.
    #[test]
    fn test_find_peaks() {
        let bumps = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let found = find_peaks(&bumps, 0.5, None, 2);
        assert_eq!(found.len(), 2);
    }
}