use std::collections::BTreeMap;
use crate::tensor::Tensor;
use crate::model::tribe::TribeV2;
/// A labelled time window of the input series, expressed in seconds.
#[derive(Debug, Clone)]
pub struct Segment {
    /// Start of the window in seconds (TR index scaled by the TR duration).
    pub start: f64,
    /// Length of the window in seconds.
    pub duration: f64,
    /// True when at least one feature value inside this window is non-zero
    /// (set from `has_nonzero_features`).
    pub has_events: bool,
}
/// Controls how a long time series is cut into (optionally overlapping) segments.
#[derive(Debug, Clone)]
pub struct SegmentConfig {
    /// Segment length, in TRs (input timesteps).
    pub duration_trs: usize,
    /// Number of TRs shared between consecutive segments; the stride is
    /// `duration_trs - overlap_trs`, clamped to at least 1.
    pub overlap_trs: usize,
    /// Duration of one TR in seconds; used to convert TR indices to times.
    pub tr: f64,
    /// When true, segments/timesteps whose input features are all zero are dropped.
    pub remove_empty_segments: bool,
    /// Feature sampling frequency. NOTE(review): not read anywhere in this
    /// file — presumably consumed by callers; confirm before removing.
    pub feature_frequency: f64,
    /// When true, a trailing segment shorter than `duration_trs` is discarded.
    pub stride_drop_incomplete: bool,
}
impl Default for SegmentConfig {
    /// Defaults: non-overlapping 100-TR segments at 0.5 s/TR, with
    /// empty-segment filtering enabled and incomplete tails kept.
    fn default() -> Self {
        Self {
            duration_trs: 100,
            overlap_trs: 0,
            tr: 0.5,
            remove_empty_segments: true,
            feature_frequency: 2.0,
            stride_drop_incomplete: false,
        }
    }
}
/// Result of `predict_segmented`: per-timestep model outputs plus bookkeeping.
#[derive(Debug)]
pub struct SegmentedPrediction {
    /// One row per kept output timestep; each row holds the model's outputs
    /// for that timestep.
    pub predictions: Vec<Vec<f32>>,
    /// One `Segment` per kept output timestep, parallel to `predictions`.
    pub segments: Vec<Segment>,
    /// Total number of output timesteps produced, before any filtering.
    pub total_segments: usize,
    /// Number of output timesteps kept after empty-segment filtering.
    pub kept_segments: usize,
}
/// Computes half-open `(start, end)` TR windows covering `total_timesteps`.
///
/// Consecutive windows advance by `duration_trs - overlap_trs` (never less
/// than one TR). The final window is truncated at `total_timesteps`; when
/// `config.stride_drop_incomplete` is set, a truncated window is discarded
/// instead. Iteration stops as soon as a window reaches the end of the
/// series, so overlapping strides never emit duplicate tail windows.
pub fn compute_segment_boundaries(
    total_timesteps: usize,
    config: &SegmentConfig,
) -> Vec<(usize, usize)> {
    let step = config.duration_trs.saturating_sub(config.overlap_trs).max(1);
    let mut bounds = Vec::new();
    let mut lo = 0usize;
    loop {
        if lo >= total_timesteps {
            break;
        }
        let hi = usize::min(lo + config.duration_trs, total_timesteps);
        let truncated = hi - lo < config.duration_trs;
        if truncated && config.stride_drop_incomplete {
            break;
        }
        bounds.push((lo, hi));
        if hi == total_timesteps {
            // Reached the end of the series; a further stride would only
            // re-cover timesteps already emitted.
            break;
        }
        lo += step;
    }
    bounds
}
/// Returns true if any feature tensor holds a non-zero value anywhere in the
/// half-open timestep window `[start, end)`, clamped to each tensor's time
/// axis (the last dimension). Out-of-range data indices are treated as zero.
fn has_nonzero_features(features: &BTreeMap<String, Tensor>, start: usize, end: usize) -> bool {
    features.values().any(|tensor| {
        let t_dim = *tensor.shape.last().unwrap();
        let lo = start.min(t_dim);
        let hi = end.min(t_dim);
        if lo >= hi {
            // Window lies entirely past this tensor's time axis.
            return false;
        }
        // Flatten all leading dimensions into a single "row" count.
        let rows: usize = tensor.shape[..tensor.shape.len() - 1].iter().product();
        (0..rows).any(|row| {
            (lo..hi).any(|t| {
                tensor
                    .data
                    .get(row * t_dim + t)
                    .map_or(false, |&v| v != 0.0)
            })
        })
    })
}
/// Copies the timestep window `[start, end)` out of every feature tensor.
///
/// Windows extending past a tensor's time axis are zero-padded on the right,
/// so every returned tensor has exactly `end - start` timesteps while keeping
/// its leading (batch) dimensions unchanged.
fn slice_features(
    features: &BTreeMap<String, Tensor>,
    start: usize,
    end: usize,
) -> BTreeMap<String, Tensor> {
    let window = end - start;
    features
        .iter()
        .map(|(name, tensor)| {
            let t_dim = *tensor.shape.last().unwrap();
            let lead = &tensor.shape[..tensor.ndim() - 1];
            let rows: usize = lead.iter().product();
            // Portion of the window that actually exists in this tensor.
            let lo = start.min(t_dim);
            let valid = end.min(t_dim).saturating_sub(lo);
            let mut data = vec![0.0f32; rows * window];
            for row in 0..rows {
                let src = row * t_dim + lo;
                let dst = row * window;
                data[dst..dst + valid].copy_from_slice(&tensor.data[src..src + valid]);
            }
            let mut shape = lead.to_vec();
            shape.push(window);
            (name.clone(), Tensor::from_vec(data, shape))
        })
        .collect()
}
/// Runs the model over each segment of `features` and flattens the
/// per-segment outputs into one prediction row per output timestep.
///
/// When `config.remove_empty_segments` is set, an output timestep is dropped
/// unless the input window it maps back to contains at least one non-zero
/// feature value. Counters for total and kept timesteps are reported in the
/// returned `SegmentedPrediction`.
pub fn predict_segmented(
    model: &TribeV2,
    features: &BTreeMap<String, Tensor>,
    config: &SegmentConfig,
) -> SegmentedPrediction {
    // Length of the time axis, taken from the first feature tensor.
    // NOTE(review): assumes all feature tensors share the same time length.
    let total_timesteps = features
        .values()
        .next()
        .map(|t| *t.shape.last().unwrap())
        .unwrap_or(0);
    let boundaries = compute_segment_boundaries(total_timesteps, config);
    let mut all_predictions: Vec<Vec<f32>> = Vec::new();
    let mut kept_segments: Vec<Segment> = Vec::new();
    let mut total_trs = 0usize;
    let mut kept_trs = 0usize;
    for (start, end) in &boundaries {
        let segment_len = end - start;
        let seg_features = slice_features(features, *start, *end);
        // NOTE(review): indexing below assumes the output is laid out as
        // [batch, n_outputs, time] with batch == 1 — confirm against forward().
        let output = model.forward(&seg_features, None, true);
        let n_outputs = output.shape[1];
        let n_out_ts = output.shape[2];
        for ti in 0..n_out_ts {
            total_trs += 1;
            // Keep unconditionally when filtering is disabled; otherwise keep
            // only if the input window this output timestep maps back to has
            // any non-zero data. The `.max(input_start + 1)` guarantees a
            // window of at least one input timestep.
            let keep = if config.remove_empty_segments {
                let input_start = *start + (ti * segment_len) / n_out_ts;
                let input_end = *start + ((ti + 1) * segment_len) / n_out_ts;
                has_nonzero_features(features, input_start, input_end.max(input_start + 1))
            } else {
                true
            };
            if keep {
                let row: Vec<f32> = (0..n_outputs)
                    .map(|di| output.data[di * n_out_ts + ti])
                    .collect();
                all_predictions.push(row);
                kept_segments.push(Segment {
                    // NOTE(review): assumes one output timestep per input TR
                    // (n_out_ts == segment_len); if the model resamples the
                    // time axis this start time diverges from the input
                    // window used for the `keep` test above — confirm.
                    start: (*start + ti) as f64 * config.tr,
                    duration: config.tr,
                    // Always true on this path: either the window had events,
                    // or filtering was disabled (matching prior behavior).
                    has_events: keep,
                });
                kept_trs += 1;
            }
        }
    }
    SegmentedPrediction {
        predictions: all_predictions,
        segments: kept_segments,
        total_segments: total_trs,
        kept_segments: kept_trs,
    }
}
/// Runs the model once per segment and returns each raw output tensor paired
/// with its `Segment` metadata.
///
/// Segments whose input window is entirely zero are skipped when
/// `config.remove_empty_segments` is set; otherwise they are kept with
/// `has_events == false`.
pub fn predict_segments_batched(
    model: &TribeV2,
    features: &BTreeMap<String, Tensor>,
    config: &SegmentConfig,
) -> Vec<(Tensor, Segment)> {
    // Time-axis length from the first feature tensor (0 when there are none).
    let total_timesteps = match features.values().next() {
        Some(t) => *t.shape.last().unwrap(),
        None => 0,
    };
    compute_segment_boundaries(total_timesteps, config)
        .into_iter()
        .filter_map(|(lo, hi)| {
            let active = has_nonzero_features(features, lo, hi);
            if config.remove_empty_segments && !active {
                return None;
            }
            let meta = Segment {
                start: lo as f64 * config.tr,
                duration: (hi - lo) as f64 * config.tr,
                has_events: active,
            };
            let window = slice_features(features, lo, hi);
            Some((model.forward(&window, None, true), meta))
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // 25 TRs / 10-TR windows, no overlap: two full segments plus a 5-TR tail.
    #[test]
    fn test_compute_segment_boundaries_no_overlap() {
        let config = SegmentConfig {
            duration_trs: 10,
            overlap_trs: 0,
            ..Default::default()
        };
        let segs = compute_segment_boundaries(25, &config);
        assert_eq!(segs, vec![(0, 10), (10, 20), (20, 25)]);
    }

    // 5-TR overlap gives a stride of 5; iteration stops once a window
    // reaches the end, so no duplicate tail windows are emitted.
    #[test]
    fn test_compute_segment_boundaries_with_overlap() {
        let config = SegmentConfig {
            duration_trs: 10,
            overlap_trs: 5,
            ..Default::default()
        };
        let segs = compute_segment_boundaries(20, &config);
        assert_eq!(segs, vec![(0, 10), (5, 15), (10, 20)]);
    }

    // With stride_drop_incomplete the truncated 5-TR tail is discarded.
    #[test]
    fn test_compute_segment_boundaries_drop_incomplete() {
        let config = SegmentConfig {
            duration_trs: 10,
            overlap_trs: 0,
            stride_drop_incomplete: true,
            ..Default::default()
        };
        let segs = compute_segment_boundaries(25, &config);
        assert_eq!(segs, vec![(0, 10), (10, 20)]);
    }

    // Series length divisible by window length: no truncated tail at all.
    #[test]
    fn test_compute_segment_boundaries_exact() {
        let config = SegmentConfig {
            duration_trs: 10,
            overlap_trs: 0,
            ..Default::default()
        };
        let segs = compute_segment_boundaries(20, &config);
        assert_eq!(segs, vec![(0, 10), (10, 20)]);
    }

    // Slicing a [2, 10] tensor over [3, 7) keeps both batch rows and copies
    // the same 4-TR window from each row.
    #[test]
    fn test_slice_features() {
        let mut features = BTreeMap::new();
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        features.insert("text".to_string(), Tensor::from_vec(data, vec![2, 10]));
        let sliced = slice_features(&features, 3, 7);
        let t = sliced.get("text").unwrap();
        assert_eq!(t.shape, vec![2, 4]);
        assert_eq!(t.data[0], 3.0);
        assert_eq!(t.data[1], 4.0);
        assert_eq!(t.data[2], 5.0);
        assert_eq!(t.data[3], 6.0);
        assert_eq!(t.data[4], 13.0);
        assert_eq!(t.data[5], 14.0);
    }

    // A window extending past the time axis ([2, 5) over a 3-TR tensor) is
    // right-padded with zeros to the full requested length.
    #[test]
    fn test_slice_features_with_padding() {
        let mut features = BTreeMap::new();
        let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
        features.insert("a".to_string(), Tensor::from_vec(data, vec![2, 3]));
        let sliced = slice_features(&features, 2, 5);
        let t = sliced.get("a").unwrap();
        assert_eq!(t.shape, vec![2, 3]);
        assert_eq!(t.data[0], 2.0);
        assert_eq!(t.data[1], 0.0);
        assert_eq!(t.data[2], 0.0);
    }
}