use super::*;
/// A contiguous run of episodic records treated as one episode, together with
/// summary data derived from those records.
///
/// The field descriptions below reflect how `EpisodeSegment::from_records`
/// populates them; the fields are public, so other code may set them differently.
#[derive(Debug, Clone)]
pub struct EpisodeSegment {
/// The records that make up this segment, in the order they were supplied.
pub records: Vec<EpisodicRecord>,
/// Timestamp of the first record in `records`.
pub start_time: Timestamp,
/// Timestamp of the last record in `records`.
pub end_time: Timestamp,
/// Up to five entity names, ranked by how often they occur across `records`
/// (most frequent first).
pub dominant_entities: Vec<String>,
/// Mean of the record embeddings, scaled to unit length when its norm is
/// non-zero; `None` when no record carries an embedding.
pub topic_embedding: Option<Vec<f32>>,
}
impl EpisodeSegment {
    /// Builds a segment from a run of records, deriving its summary fields.
    ///
    /// Returns `None` when `records` is empty. Otherwise:
    ///
    /// - `start_time` / `end_time` are the timestamps of the first and last record.
    /// - `dominant_entities` holds up to five entity names ranked by how often
    ///   they occur across the records. Ties are broken alphabetically so the
    ///   result is deterministic for identical input.
    /// - `topic_embedding` is the mean of the record embeddings, scaled to unit
    ///   length when its norm is non-zero, or `None` when no record has an
    ///   embedding. The first embedding found fixes the dimensionality;
    ///   embeddings of any other length are skipped rather than partially
    ///   summed (shorter) or indexed out of bounds (longer).
    pub(super) fn from_records(records: Vec<EpisodicRecord>) -> Option<Self> {
        let start_time = records.first()?.timestamp;
        let end_time = records.last()?.timestamp;

        // Count by borrowed name; only the selected top entries are cloned.
        let mut entity_counts: HashMap<&str, usize> = HashMap::new();
        for record in &records {
            for entity in &record.entities {
                *entity_counts.entry(entity.name.as_str()).or_default() += 1;
            }
        }

        let mut ranked: Vec<(&str, usize)> = entity_counts.into_iter().collect();
        // HashMap iteration order is arbitrary, so sorting on the count alone
        // would leave equally frequent entities in a run-dependent order.
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        let dominant_entities: Vec<String> = ranked
            .into_iter()
            .take(5)
            .map(|(name, _)| name.to_owned())
            .collect();

        let reference_dims = records
            .iter()
            .find_map(|r| r.embedding.as_ref())
            .map(|emb| emb.len());

        let topic_embedding = reference_dims.map(|dims| {
            let mut mean = vec![0.0f32; dims];
            let mut used = 0usize;

            let compatible = records
                .iter()
                .filter_map(|r| r.embedding.as_ref())
                .filter(|emb| emb.len() == dims);
            for emb in compatible {
                for (acc, v) in mean.iter_mut().zip(emb) {
                    *acc += v;
                }
                used += 1;
            }

            // `used` is at least 1: the embedding that defined `dims` always matches.
            let n = used as f32;
            for v in &mut mean {
                *v /= n;
            }

            // Leave an all-zero mean untouched instead of dividing by zero.
            let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for v in &mut mean {
                    *v /= norm;
                }
            }
            mean
        });

        Some(Self {
            records,
            start_time,
            end_time,
            dominant_entities,
            topic_embedding,
        })
    }
}
/// Splits `records` into consecutive [`EpisodeSegment`]s.
///
/// Records are processed in the order given; no sorting is performed. A new
/// segment starts at record `i` (for `i >= 1`) when at least one of the
/// following holds:
///
/// - the dissimilarity between record `i - 1` and record `i` exceeds an
///   adaptive threshold floored at `config.topic_similarity_threshold`. The
///   dissimilarity is derived from the cosine distance of the two embeddings
///   and is `0.0` when either record has no embedding;
/// - the record's `surprise` exceeds an adaptive threshold floored at
///   `config.surprise_threshold`;
/// - the number of whole seconds between the two records' timestamps exceeds
///   `config.temporal_gap_seconds`.
///
/// Both adaptive thresholds use `config.segmentation_lookback` and
/// `config.segmentation_gamma`. An empty input yields an empty vector.
pub fn segment_episodes(
    records: &[EpisodicRecord],
    config: &ConsolidationConfig,
) -> Vec<EpisodeSegment> {
    if records.is_empty() {
        return Vec::new();
    }

    // Per-record signals, index-aligned with `records`.
    let surprises: Vec<f64> = records.iter().map(|r| r.surprise as f64).collect();

    // Index 0 has no predecessor, so its dissimilarity is a 0.0 placeholder.
    let dissimilarities: Vec<f64> = std::iter::once(0.0)
        .chain(records.windows(2).map(|pair| {
            match (&pair[0].embedding, &pair[1].embedding) {
                (Some(emb_prev), Some(emb_curr)) => {
                    let sim = 1.0 - lance_linalg::distance::cosine_distance(emb_prev, emb_curr);
                    (1.0 - sim) as f64
                }
                _ => 0.0,
            }
        }))
        .collect();

    // Indices are visited in increasing order, so `cut_points` is already
    // sorted and free of duplicates.
    let mut cut_points: Vec<usize> = Vec::new();
    for (offset, pair) in records.windows(2).enumerate() {
        let i = offset + 1;
        let (prev, curr) = (&pair[0], &pair[1]);

        let topic_limit = adaptive_threshold(
            &dissimilarities,
            i,
            config.segmentation_lookback,
            config.segmentation_gamma,
            config.topic_similarity_threshold as f64,
        );
        let topic_shift = dissimilarities[i] > topic_limit;

        let surprise_limit = adaptive_threshold(
            &surprises,
            i,
            config.segmentation_lookback,
            config.segmentation_gamma,
            config.surprise_threshold as f64,
        );
        let surprise_spike = surprises[i] > surprise_limit;

        let gap_secs = curr
            .timestamp
            .as_datetime()
            .signed_duration_since(prev.timestamp.as_datetime())
            .num_seconds();
        let long_pause = gap_secs > config.temporal_gap_seconds;

        if topic_shift || surprise_spike || long_pause {
            cut_points.push(i);
        }
    }

    // Every cut point lies strictly between the previous one and
    // `records.len()`, so each slice below is non-empty.
    let mut segments = Vec::new();
    let mut segment_start = 0;
    for cut in cut_points {
        segments.extend(EpisodeSegment::from_records(
            records[segment_start..cut].to_vec(),
        ));
        segment_start = cut;
    }
    segments.extend(EpisodeSegment::from_records(
        records[segment_start..].to_vec(),
    ));
    segments
}
/// Computes a boundary threshold for `signal[current_idx]` from the values
/// that precede it.
///
/// The history window is the up-to-`lookback` values immediately before
/// `current_idx` (the current value itself is excluded). The result is
/// `mean + gamma * stddev` of that window, using the population standard
/// deviation, and never less than `floor`.
///
/// `floor` is returned unchanged when `current_idx` is below 2 or the window
/// is empty (for example when `lookback` is 0).
///
/// # Panics
///
/// Panics if `current_idx` is at least 2 and greater than `signal.len()`.
fn adaptive_threshold(
    signal: &[f64],
    current_idx: usize,
    lookback: usize,
    gamma: f64,
    floor: f64,
) -> f64 {
    let history = match current_idx {
        // Not enough preceding samples for meaningful statistics.
        0 | 1 => return floor,
        idx => &signal[idx.saturating_sub(lookback)..idx],
    };
    if history.is_empty() {
        return floor;
    }

    let count = history.len() as f64;
    let mean = history.iter().sum::<f64>() / count;
    let squared_deviation = history.iter().map(|x| (x - mean).powi(2)).sum::<f64>();
    let stddev = (squared_deviation / count).sqrt();

    let candidate = mean + gamma * stddev;
    candidate.max(floor)
}