use super::*;
#[derive(Debug, Clone)]
pub struct NarrativeThread {
pub title: String,
pub segment_indices: Vec<usize>,
pub record_ids: Vec<MemoryId>,
pub contents: Vec<String>,
pub summaries: Vec<String>,
pub start_time: Timestamp,
pub end_time: Timestamp,
pub entities: Vec<String>,
pub sub_threads: Vec<Self>,
pub embedding: Option<Vec<f32>>,
}
struct CondensedMatrix {
data: Vec<f32>,
n: usize,
}
impl CondensedMatrix {
fn new(segments: &[EpisodeSegment]) -> Self {
let n = segments.len();
let size = n * (n - 1) / 2;
let mut data = Vec::with_capacity(size);
let entity_sets: Vec<HashSet<&str>> = segments
.iter()
.map(|s| s.dominant_entities.iter().map(String::as_str).collect())
.collect();
for i in 0..n {
for j in (i + 1)..n {
let embedding_sim =
match (&segments[i].topic_embedding, &segments[j].topic_embedding) {
(Some(ea), Some(eb)) => {
1.0 - lance_linalg::distance::cosine_distance(ea, eb)
}
_ => 0.0,
};
let intersection = entity_sets[i].intersection(&entity_sets[j]).count();
let union = entity_sets[i].union(&entity_sets[j]).count();
let entity_sim = if union > 0 {
intersection as f32 / union as f32
} else {
0.0
};
data.push(embedding_sim * 0.6 + entity_sim * 0.4);
}
}
Self { data, n }
}
#[inline]
fn get(&self, i: usize, j: usize) -> f32 {
let (a, b) = if i < j { (i, j) } else { (j, i) };
self.data[a * self.n - a * (a + 1) / 2 + b - a - 1]
}
}
struct UnionFind {
parent: Vec<usize>,
rank: Vec<usize>,
}
impl UnionFind {
fn new(n: usize) -> Self {
Self {
parent: (0..n).collect(),
rank: vec![0; n],
}
}
fn find(&mut self, x: usize) -> usize {
if self.parent[x] != x {
self.parent[x] = self.find(self.parent[x]);
}
self.parent[x]
}
fn union(&mut self, x: usize, y: usize) -> bool {
let rx = self.find(x);
let ry = self.find(y);
if rx == ry {
return false;
}
match self.rank[rx].cmp(&self.rank[ry]) {
std::cmp::Ordering::Less => self.parent[rx] = ry,
std::cmp::Ordering::Greater => self.parent[ry] = rx,
std::cmp::Ordering::Equal => {
self.parent[ry] = rx;
self.rank[rx] += 1;
}
}
true
}
}
pub fn form_narrative_threads(
segments: &[EpisodeSegment],
_patterns: &DetectedPatterns,
config: &ConsolidationConfig,
) -> Vec<NarrativeThread> {
if segments.is_empty() {
return Vec::new();
}
if segments.len() == 1 {
return vec![thread_from_segment_group(segments, &[0])];
}
let n = segments.len();
let matrix = CondensedMatrix::new(segments);
let mut edges: Vec<(f32, usize, usize)> = Vec::with_capacity(n * (n - 1) / 2);
for i in 0..n {
for j in (i + 1)..n {
let sim = matrix.get(i, j);
if sim >= config.thread_similarity_threshold {
edges.push((sim, i, j));
}
}
}
edges.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let mut uf = UnionFind::new(n);
for &(_sim, i, j) in &edges {
uf.union(i, j);
}
let mut cluster_map: HashMap<usize, Vec<usize>> = HashMap::new();
for i in 0..n {
cluster_map.entry(uf.find(i)).or_default().push(i);
}
let mut threads: Vec<NarrativeThread> = Vec::new();
for (_root, cluster) in &cluster_map {
let mut thread = thread_from_segment_group(segments, cluster);
if cluster.len() >= 4 {
let sub_threads = detect_sub_threads(segments, cluster, config, &matrix);
if sub_threads.len() > 1 {
thread.sub_threads = sub_threads;
}
}
threads.push(thread);
}
threads.sort_by_key(|thread| thread.start_time);
threads
}
fn thread_from_segment_group(segments: &[EpisodeSegment], indices: &[usize]) -> NarrativeThread {
let mut all_records: Vec<&EpisodicRecord> = Vec::new();
let mut all_entities: HashMap<String, usize> = HashMap::new();
let mut embeddings: Vec<&Vec<f32>> = Vec::new();
for &idx in indices {
let seg = &segments[idx];
for rec in &seg.records {
all_records.push(rec);
for ent in &rec.entities {
*all_entities.entry(ent.name.clone()).or_default() += 1;
}
}
if let Some(ref emb) = seg.topic_embedding {
embeddings.push(emb);
}
}
all_records.sort_by_key(|r| r.timestamp);
let mut entity_list: Vec<(String, usize)> = all_entities.into_iter().collect();
entity_list.sort_by_key(|item| std::cmp::Reverse(item.1));
let top_entities: Vec<String> = entity_list
.iter()
.take(3)
.map(|(name, _)| name.clone())
.collect();
let title = if top_entities.is_empty() {
"Unnamed Thread".to_string()
} else {
top_entities.join(", ")
};
let entities: Vec<String> = entity_list.into_iter().map(|(name, _)| name).collect();
let embedding = if embeddings.is_empty() {
None
} else {
let dims = embeddings[0].len();
let mut mean = vec![0.0f32; dims];
for emb in &embeddings {
for (i, v) in emb.iter().enumerate() {
mean[i] += v;
}
}
let n = embeddings.len() as f32;
for v in &mut mean {
*v /= n;
}
let norm: f32 = mean.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for v in &mut mean {
*v /= norm;
}
}
Some(mean)
};
let (start_time, end_time) = match (all_records.first(), all_records.last()) {
(Some(first), Some(last)) => (first.timestamp, last.timestamp),
_ => (Timestamp::default(), Timestamp::default()),
};
NarrativeThread {
title,
segment_indices: indices.to_vec(),
record_ids: all_records.iter().map(|r| r.id).collect(),
contents: all_records.iter().map(|r| r.content.clone()).collect(),
summaries: all_records.iter().map(|r| r.summary.clone()).collect(),
start_time,
end_time,
entities,
sub_threads: Vec::new(),
embedding,
}
}
fn detect_sub_threads(
segments: &[EpisodeSegment],
cluster: &[usize],
config: &ConsolidationConfig,
matrix: &CondensedMatrix,
) -> Vec<NarrativeThread> {
let tighter_threshold = config.thread_similarity_threshold + 0.15;
let m = cluster.len();
let mut edges: Vec<(f32, usize, usize)> = Vec::new();
for ci in 0..m {
for cj in (ci + 1)..m {
let sim = matrix.get(cluster[ci], cluster[cj]);
if sim >= tighter_threshold {
edges.push((sim, ci, cj));
}
}
}
edges.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let mut uf = UnionFind::new(m);
for &(_sim, i, j) in &edges {
uf.union(i, j);
}
let mut sub_map: HashMap<usize, Vec<usize>> = HashMap::new();
for ci in 0..m {
sub_map.entry(uf.find(ci)).or_default().push(cluster[ci]);
}
sub_map
.into_values()
.map(|c| thread_from_segment_group(segments, &c))
.collect()
}