1use std::collections::HashMap;
14use std::sync::Arc;
15
16use hirn_core::embed::{ChatMessage, LlmOptions, LlmProvider};
17use hirn_core::error::HirnResult;
18use hirn_core::id::MemoryId;
19use hirn_core::metadata::Metadata;
20use hirn_core::semantic::SemanticRecord;
21use hirn_core::types::{AgentId, EdgeRelation, KnowledgeType, Layer, Origin};
22
23use crate::db::HirnDB;
24use crate::graph_store::GraphStore;
25
/// Tuning parameters for hierarchical community detection.
#[derive(Clone, Debug)]
pub struct CommunityConfig {
    /// Multiplier on the degree-penalty term of the modularity gain.
    ///
    /// `detect_communities` uses this value only when `auto_resolution` is `false`.
    pub resolution: f64,
    /// When `true`, `detect_communities` ignores `resolution` and derives one from the
    /// graph's average weighted degree.
    pub auto_resolution: bool,
    /// Maximum number of local-moving sweeps performed for a single level.
    pub max_iterations: usize,
    /// Maximum number of hierarchy levels that detection will attempt to build.
    pub max_levels: usize,
    /// Leaf communities with fewer members than this are dropped from the result.
    pub min_community_size: usize,
}
49
50impl Default for CommunityConfig {
51 fn default() -> Self {
52 Self {
53 resolution: 1.0,
54 auto_resolution: true,
55 max_iterations: 10,
56 max_levels: 5,
57 min_community_size: 2,
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
68pub struct Community {
69 pub level: usize,
71 pub index: usize,
72 pub members: Vec<MemoryId>,
74 pub parent: Option<usize>,
76 pub children: Vec<usize>,
78}
79
80#[derive(Debug, Clone)]
82pub struct CommunityResult {
83 pub levels: Vec<Vec<Community>>,
86 pub node_to_community: HashMap<MemoryId, usize>,
88 pub total_communities: usize,
90}
91
92struct AdjacencyGraph {
99 n: usize,
101 adj: Vec<Vec<(usize, f64)>>,
103 total_weight: f64,
105 degree: Vec<f64>,
107 index_to_id: Vec<MemoryId>,
109}
110
111impl AdjacencyGraph {
112 async fn from_graph_store(store: &dyn GraphStore) -> HirnResult<Self> {
114 let node_ids = store.node_ids().await?;
115 let n = node_ids.len();
116
117 let id_to_index: HashMap<MemoryId, usize> = node_ids
118 .iter()
119 .enumerate()
120 .map(|(i, id)| (*id, i))
121 .collect();
122
123 let mut adj: Vec<Vec<(usize, f64)>> = vec![vec![]; n];
124
125 for edge in store.all_edges().await? {
126 let Some(&src) = id_to_index.get(&edge.source) else {
127 continue;
128 };
129 let Some(&tgt) = id_to_index.get(&edge.target) else {
130 continue;
131 };
132 if src == tgt {
133 continue;
134 }
135 let w = edge.weight as f64;
136 adj[src].push((tgt, w));
137 adj[tgt].push((src, w));
138 }
139
140 for neighbors in &mut adj {
141 neighbors.sort_by_key(|&(idx, _)| idx);
142 neighbors.dedup_by(|a, b| {
143 if a.0 == b.0 {
144 b.1 += a.1;
145 true
146 } else {
147 false
148 }
149 });
150 }
151
152 let mut total_weight = 0.0;
153 for neighbors in &adj {
154 for &(_, w) in neighbors {
155 total_weight += w;
156 }
157 }
158 total_weight /= 2.0;
159
160 let degree: Vec<f64> = adj
161 .iter()
162 .map(|ns| ns.iter().map(|&(_, w)| w).sum())
163 .collect();
164
165 drop(id_to_index);
166
167 Ok(Self {
168 n,
169 adj,
170 total_weight,
171 degree,
172 index_to_id: node_ids,
173 })
174 }
175
176 fn coarsen(&self, assignments: &[usize], num_communities: usize) -> AdjacencyGraph {
178 let mut adj: Vec<Vec<(usize, f64)>> = vec![vec![]; num_communities];
179
180 for (node, neighbors) in self.adj.iter().enumerate() {
181 let c1 = assignments[node];
182 for &(neighbor, w) in neighbors {
183 let c2 = assignments[neighbor];
184 if c1 != c2 {
185 adj[c1].push((c2, w));
186 }
187 }
188 }
189
190 for neighbors in &mut adj {
192 neighbors.sort_by_key(|&(idx, _)| idx);
193 neighbors.dedup_by(|a, b| {
194 if a.0 == b.0 {
195 b.1 += a.1;
196 true
197 } else {
198 false
199 }
200 });
201 }
202
203 let mut total_weight = 0.0;
204 for neighbors in &adj {
205 for &(_, w) in neighbors {
206 total_weight += w;
207 }
208 }
209 total_weight /= 2.0;
210
211 let degree: Vec<f64> = adj
212 .iter()
213 .map(|ns| ns.iter().map(|&(_, w)| w).sum())
214 .collect();
215
216 let index_to_id = (0..num_communities).map(|_| MemoryId::new()).collect();
218 AdjacencyGraph {
219 n: num_communities,
220 adj,
221 total_weight,
222 degree,
223 index_to_id,
224 }
225 }
226}
227
228pub async fn detect_communities(
234 store: &dyn GraphStore,
235 config: &CommunityConfig,
236) -> HirnResult<CommunityResult> {
237 let adj = AdjacencyGraph::from_graph_store(store).await?;
238
239 if adj.n == 0 {
240 return Ok(CommunityResult {
241 levels: vec![],
242 node_to_community: HashMap::new(),
243 total_communities: 0,
244 });
245 }
246
247 let effective_resolution = if config.auto_resolution && adj.n > 0 {
249 let avg_degree = 2.0 * adj.total_weight / adj.n as f64;
252 avg_degree.sqrt().max(0.1_f64).min(10.0_f64)
253 } else {
254 config.resolution
255 };
256 let effective_config = CommunityConfig {
257 resolution: effective_resolution,
258 auto_resolution: false, ..*config
260 };
261
262 let base_index_to_id = adj.index_to_id.clone();
263 let mut all_levels: Vec<Vec<usize>> = Vec::new();
264 let mut current_graph = adj;
265
266 for _level in 0..effective_config.max_levels {
267 if current_graph.n <= 1 {
268 break;
269 }
270 let assignments = leiden_one_level(¤t_graph, &effective_config);
271 let num_communities = *assignments.iter().max().unwrap_or(&0) + 1;
272 if num_communities >= current_graph.n {
273 break;
274 }
275 all_levels.push(assignments.clone());
276 current_graph = current_graph.coarsen(&assignments, num_communities);
277 }
278
279 Ok(build_community_result(
280 &base_index_to_id,
281 &all_levels,
282 &effective_config,
283 ))
284}
285
286fn leiden_one_level(graph: &AdjacencyGraph, config: &CommunityConfig) -> Vec<usize> {
288 let n = graph.n;
289 let mut assignment: Vec<usize> = (0..n).collect();
291 let mut num_communities = n;
292
293 for _iteration in 0..config.max_iterations {
294 let mut improved = false;
295
296 let mut comm_degree: HashMap<usize, f64> = HashMap::new();
298 for (i, &c) in assignment.iter().enumerate() {
299 *comm_degree.entry(c).or_default() += graph.degree[i];
300 }
301
302 let m2 = 2.0 * graph.total_weight;
303
304 for node in 0..n {
307 let current_comm = assignment[node];
308
309 let mut comm_weights: HashMap<usize, f64> = HashMap::new();
311 for &(neighbor, w) in &graph.adj[node] {
312 let nc = assignment[neighbor];
313 *comm_weights.entry(nc).or_default() += w;
314 }
315
316 let ki = graph.degree[node];
318 if m2 == 0.0 {
319 continue;
320 }
321
322 let mut best_comm = current_comm;
323 let mut best_delta = 0.0;
324
325 let w_in_current = comm_weights.get(¤t_comm).copied().unwrap_or(0.0);
327 let sigma_current = comm_degree.get(¤t_comm).copied().unwrap_or(0.0);
328
329 for (&candidate_comm, &w_to_candidate) in &comm_weights {
330 if candidate_comm == current_comm {
331 continue;
332 }
333 let sigma_candidate = comm_degree.get(&candidate_comm).copied().unwrap_or(0.0);
334
335 let delta = (w_to_candidate - w_in_current)
337 + config.resolution * ki * (sigma_current - ki - sigma_candidate) / m2;
338
339 if delta > best_delta {
340 best_delta = delta;
341 best_comm = candidate_comm;
342 }
343 }
344
345 if best_comm != current_comm {
346 *comm_degree.entry(current_comm).or_default() -= ki;
348 *comm_degree.entry(best_comm).or_default() += ki;
349 assignment[node] = best_comm;
350 improved = true;
351 }
352 }
353
354 if !improved {
355 break;
356 }
357
358 let (compacted, new_count) = compact_assignments(&assignment);
360 assignment = compacted;
361 num_communities = new_count;
362 }
363
364 let refined = refine_communities(&assignment, num_communities, graph, config);
367 refined
368}
369
/// Relabels community ids so they form the dense range `0..count`.
///
/// Labels are renumbered in order of first appearance. Returns the relabeled
/// assignments together with the number of distinct labels.
fn compact_assignments(assignments: &[usize]) -> (Vec<usize>, usize) {
    let mut dense_of: HashMap<usize, usize> = HashMap::new();
    let mut compacted = Vec::with_capacity(assignments.len());

    for &label in assignments {
        // The next unused dense id always equals the number of labels seen so far.
        let fresh = dense_of.len();
        compacted.push(*dense_of.entry(label).or_insert(fresh));
    }

    let count = dense_of.len();
    (compacted, count)
}
386
387fn refine_communities(
391 assignments: &[usize],
392 _num_communities: usize,
393 graph: &AdjacencyGraph,
394 _config: &CommunityConfig,
395) -> Vec<usize> {
396 let mut refined = assignments.to_vec();
397
398 for _pass in 0..3 {
399 let mut changed = false;
400
401 for node in 0..graph.n {
402 let my_comm = refined[node];
403
404 let mut w_internal = 0.0;
406 let mut best_external_comm = my_comm;
408 let mut best_external_weight = 0.0;
409 let mut ext_weights: HashMap<usize, f64> = HashMap::new();
410
411 for &(neighbor, w) in &graph.adj[node] {
412 if refined[neighbor] == my_comm {
413 w_internal += w;
414 } else {
415 *ext_weights.entry(refined[neighbor]).or_default() += w;
416 }
417 }
418
419 for (&c, &w) in &ext_weights {
420 if w > best_external_weight {
421 best_external_weight = w;
422 best_external_comm = c;
423 }
424 }
425
426 if best_external_weight > w_internal && best_external_comm != my_comm {
428 refined[node] = best_external_comm;
429 changed = true;
430 }
431 }
432
433 if !changed {
434 break;
435 }
436 }
437
438 let (compacted, _) = compact_assignments(&refined);
440 compacted
441}
442
443fn build_community_result(
448 base_index_to_id: &[MemoryId],
449 levels: &[Vec<usize>],
450 config: &CommunityConfig,
451) -> CommunityResult {
452 if levels.is_empty() {
453 let mut node_to_community = HashMap::new();
455 for (i, id) in base_index_to_id.iter().enumerate() {
456 node_to_community.insert(*id, i);
457 }
458 return CommunityResult {
459 levels: vec![],
460 node_to_community,
461 total_communities: 0,
462 };
463 }
464
465 let base_assignments = &levels[0];
466 let num_base = *base_assignments.iter().max().unwrap_or(&0) + 1;
467
468 let mut leaf_communities: Vec<Community> = (0..num_base)
470 .map(|idx| Community {
471 level: 0,
472 index: idx,
473 members: vec![],
474 parent: None,
475 children: vec![],
476 })
477 .collect();
478
479 let mut node_to_community = HashMap::new();
481 for (node_idx, &comm) in base_assignments.iter().enumerate() {
482 if comm < leaf_communities.len() {
483 let id = base_index_to_id[node_idx];
484 leaf_communities[comm].members.push(id);
485 node_to_community.insert(id, comm);
486 }
487 }
488
489 let mut valid_leaf: Vec<Community> = leaf_communities
491 .into_iter()
492 .filter(|c| c.members.len() >= config.min_community_size)
493 .collect();
494
495 for (i, c) in valid_leaf.iter_mut().enumerate() {
497 c.index = i;
498 }
499
500 node_to_community.clear();
502 for c in &valid_leaf {
503 for &member in &c.members {
504 node_to_community.insert(member, c.index);
505 }
506 }
507
508 let mut all_levels: Vec<Vec<Community>> = vec![valid_leaf];
509
510 for (level_idx, level_assignments) in levels.iter().skip(1).enumerate() {
512 let prev_level = &all_levels[level_idx];
513 let num_comms = *level_assignments.iter().max().unwrap_or(&0) + 1;
514
515 let mut higher: Vec<Community> = (0..num_comms)
516 .map(|idx| Community {
517 level: level_idx + 1,
518 index: idx,
519 members: vec![],
520 parent: None,
521 children: vec![],
522 })
523 .collect();
524
525 for (prev_idx, &parent_comm) in level_assignments.iter().enumerate() {
527 if prev_idx < prev_level.len() && parent_comm < higher.len() {
528 higher[parent_comm]
529 .children
530 .push(prev_level[prev_idx].index);
531 higher[parent_comm]
533 .members
534 .extend_from_slice(&prev_level[prev_idx].members);
535 }
536 }
537
538 for (prev_idx, &parent_comm) in level_assignments.iter().enumerate() {
541 if prev_idx < all_levels[level_idx].len() {
542 all_levels[level_idx][prev_idx].parent = Some(parent_comm);
543 }
544 }
545
546 let valid: Vec<Community> = higher
547 .into_iter()
548 .filter(|c| !c.members.is_empty())
549 .collect();
550 all_levels.push(valid);
551 }
552
553 let total = all_levels.iter().map(|l| l.len()).sum();
554
555 CommunityResult {
556 levels: all_levels,
557 node_to_community,
558 total_communities: total,
559 }
560}
561
/// Difference between two leaf-level community partitions, expressed as community indices.
#[derive(Clone, Debug)]
pub struct CommunityDelta {
    /// Indices (in the new result) of communities that share no member with any
    /// previous community.
    pub added: Vec<usize>,
    /// Indices (in the new result) of communities whose exact member set is new but
    /// which share at least one member with a previous community.
    pub modified: Vec<usize>,
    /// Indices (in the new result) of communities whose member set also existed before.
    pub unchanged: Vec<usize>,
    /// Indices (in the previous result) of communities whose exact member set no
    /// longer exists.
    pub removed: Vec<usize>,
}
578
579pub fn compute_community_delta(prev: &CommunityResult, new: &CommunityResult) -> CommunityDelta {
584 use std::collections::HashSet;
585
586 fn member_key(community: &Community) -> Vec<MemoryId> {
587 let mut ids = community.members.clone();
588 ids.sort();
589 ids
590 }
591
592 let prev_leaves = prev.levels.first().map(|l| l.as_slice()).unwrap_or(&[]);
593 let new_leaves = new.levels.first().map(|l| l.as_slice()).unwrap_or(&[]);
594
595 let prev_keys: HashMap<Vec<MemoryId>, usize> = prev_leaves
597 .iter()
598 .map(|c| (member_key(c), c.index))
599 .collect();
600
601 let new_keys: HashMap<Vec<MemoryId>, usize> = new_leaves
602 .iter()
603 .map(|c| (member_key(c), c.index))
604 .collect();
605
606 let mut added = Vec::new();
607 let mut modified = Vec::new();
608 let mut unchanged = Vec::new();
609
610 for community in new_leaves {
611 let key = member_key(community);
612 if prev_keys.contains_key(&key) {
613 unchanged.push(community.index);
614 } else {
615 let new_members: HashSet<_> = community.members.iter().collect();
618 let is_modified = prev_leaves
619 .iter()
620 .any(|pc| pc.members.iter().any(|m| new_members.contains(m)));
621 if is_modified {
622 modified.push(community.index);
623 } else {
624 added.push(community.index);
625 }
626 }
627 }
628
629 let removed: Vec<usize> = prev_leaves
631 .iter()
632 .filter(|pc| {
633 let key = member_key(pc);
634 !new_keys.contains_key(&key)
635 })
636 .map(|pc| pc.index)
637 .collect();
638
639 CommunityDelta {
640 added,
641 modified,
642 unchanged,
643 removed,
644 }
645}
646
/// Counters reported by a community-summary generation run.
#[derive(Clone, Debug)]
pub struct CommunitySummaryResult {
    /// Number of new summary records written to the store during the run.
    pub summaries_stored: usize,
    /// Number of summary/member edges reported as created during the run, including
    /// edges restored for summaries that already existed.
    pub edges_created: usize,
}
659
660async fn community_edge_exists(
661 db: &HirnDB,
662 source: MemoryId,
663 target: MemoryId,
664 relation: EdgeRelation,
665) -> bool {
666 match db.cached_graph().get_edges_between(source, target).await {
667 Ok(edges) => edges.iter().any(|edge| {
668 edge.relation == relation && edge.source == source && edge.target == target
669 }),
670 Err(error) => {
671 tracing::warn!(
672 source = %source,
673 target = %target,
674 relation = ?relation,
675 error = %error,
676 "failed to inspect community summary edge"
677 );
678 false
679 }
680 }
681}
682
683async fn ensure_community_edge(
684 db: &HirnDB,
685 source: MemoryId,
686 target: MemoryId,
687 relation: EdgeRelation,
688) -> bool {
689 if community_edge_exists(db, source, target, relation).await {
690 return false;
691 }
692
693 match db
694 .connect_with(source, target, relation, 1.0, Metadata::default())
695 .await
696 {
697 Ok(_) => true,
698 Err(hirn_core::HirnError::AlreadyExists(error)) => {
699 if community_edge_exists(db, source, target, relation).await {
700 true
701 } else {
702 tracing::warn!(
703 source = %source,
704 target = %target,
705 relation = ?relation,
706 error = %error,
707 "community edge write reported duplicate without leaving a visible edge"
708 );
709 false
710 }
711 }
712 Err(error) => {
713 tracing::warn!(
714 source = %source,
715 target = %target,
716 relation = ?relation,
717 error = %error,
718 "failed to create community summary edge"
719 );
720 false
721 }
722 }
723}
724
725async fn repair_community_membership_edges(
726 db: &HirnDB,
727 summary_id: MemoryId,
728 members: &[MemoryId],
729) -> usize {
730 let mut edges_created = 0;
731
732 for &member_id in members {
733 if ensure_community_edge(db, summary_id, member_id, EdgeRelation::DerivedFrom).await {
734 edges_created += 1;
735 }
736 if ensure_community_edge(db, member_id, summary_id, EdgeRelation::PartOf).await {
737 edges_created += 1;
738 }
739 }
740
741 edges_created
742}
743
744pub async fn generate_community_summaries(
750 db: &HirnDB,
751 llm: &Arc<dyn LlmProvider>,
752 communities: &CommunityResult,
753 max_members_per_prompt: usize,
754 llm_timeout: std::time::Duration,
755) -> HirnResult<CommunitySummaryResult> {
756 if communities.levels.is_empty() {
757 return Ok(CommunitySummaryResult {
758 summaries_stored: 0,
759 edges_created: 0,
760 });
761 }
762
763 let agent = AgentId::well_known("community");
764 let leaf_communities = &communities.levels[0];
765 let mut summaries_stored = 0;
766 let mut edges_created = 0;
767
768 for community in leaf_communities {
769 if community.members.is_empty() {
770 continue;
771 }
772
773 let descriptions =
775 collect_member_descriptions(db, &community.members, max_members_per_prompt).await;
776 if descriptions.is_empty() {
777 continue;
778 }
779
780 let concept_name = format!("community-{}-{}", community.level, community.index);
781
782 if let Ok(existing) = db.get_semantic_by_concept(&concept_name).await {
784 edges_created +=
785 repair_community_membership_edges(db, existing.id, &community.members).await;
786 continue;
787 }
788
789 let member_text = descriptions
791 .iter()
792 .enumerate()
793 .map(|(i, d)| format!("{}. {}", i + 1, d))
794 .collect::<Vec<_>>()
795 .join("\n");
796
797 let system = ChatMessage {
798 role: "system".to_string(),
799 content: "You are an analyst that produces concise community summaries. \
800 Given a list of related memory descriptions, produce a structured summary \
801 with the following format:\n\
802 THEME: <one-line theme>\n\
803 KEY_ENTITIES: <comma-separated key entities>\n\
804 SUMMARY: <2-4 sentence summary including representative examples>"
805 .to_string(),
806 };
807 let sanitized_member_text = hirn_core::sanitize::sanitize_for_llm(&member_text);
808 let user = ChatMessage {
809 role: "user".to_string(),
810 content: format!(
811 "Summarize the following {} related memories (community level {}, index {}) \
812 into a structured community summary:\n\n{}",
813 descriptions.len(),
814 community.level,
815 community.index,
816 sanitized_member_text
817 ),
818 };
819
820 let options = LlmOptions {
821 temperature: 0.3,
822 max_tokens: 256,
823 ..Default::default()
824 };
825
826 let summary =
827 super::generate_text_with_timeout(llm.as_ref(), &[system, user], &options, llm_timeout)
828 .await?;
829
830 let mut builder = SemanticRecord::builder()
832 .concept(&concept_name)
833 .knowledge_type(KnowledgeType::Community)
834 .description(&summary)
835 .confidence(0.7)
836 .agent_id(agent.clone())
837 .origin(Origin::Consolidation);
838
839 if let Ok(emb) = db.embed_text(&summary).await {
841 builder = builder.embedding(emb);
842 }
843
844 for &member_id in &community.members {
846 builder = builder.source_episode(member_id);
847 }
848
849 let record = builder.build()?;
850 let semantic_id = db.store_semantic(record).await?;
851 summaries_stored += 1;
852
853 edges_created +=
854 repair_community_membership_edges(db, semantic_id, &community.members).await;
855 }
856
857 Ok(CommunitySummaryResult {
858 summaries_stored,
859 edges_created,
860 })
861}
862
863pub async fn generate_community_summaries_incremental(
870 db: &HirnDB,
871 llm: &Arc<dyn LlmProvider>,
872 prev: &CommunityResult,
873 new: &CommunityResult,
874 max_members_per_prompt: usize,
875 llm_timeout: std::time::Duration,
876) -> HirnResult<CommunitySummaryResult> {
877 let delta = compute_community_delta(prev, new);
878
879 for &removed_idx in &delta.removed {
881 let concept_name = format!("community-0-{removed_idx}");
882 if let Ok(record) = db.get_semantic_by_concept(&concept_name).await {
883 db.purge_semantic(record.id).await?;
884 }
885 }
886
887 for &modified_idx in &delta.modified {
889 let concept_name = format!("community-0-{modified_idx}");
890 if let Ok(record) = db.get_semantic_by_concept(&concept_name).await {
891 db.purge_semantic(record.id).await?;
892 }
893 }
894
895 let needs_summary: std::collections::HashSet<usize> = delta
897 .added
898 .iter()
899 .chain(delta.modified.iter())
900 .copied()
901 .collect();
902
903 let filtered_leaves: Vec<Community> = new
904 .levels
905 .first()
906 .map(|l| {
907 l.iter()
908 .filter(|c| needs_summary.contains(&c.index))
909 .cloned()
910 .collect()
911 })
912 .unwrap_or_default();
913
914 let unchanged_leaves: Vec<Community> = new
915 .levels
916 .first()
917 .map(|l| {
918 l.iter()
919 .filter(|c| !needs_summary.contains(&c.index))
920 .cloned()
921 .collect()
922 })
923 .unwrap_or_default();
924
925 let mut result = if filtered_leaves.is_empty() {
926 CommunitySummaryResult {
927 summaries_stored: 0,
928 edges_created: 0,
929 }
930 } else {
931 let filtered = CommunityResult {
932 levels: vec![filtered_leaves],
933 node_to_community: new.node_to_community.clone(),
934 total_communities: needs_summary.len(),
935 };
936
937 generate_community_summaries(db, llm, &filtered, max_members_per_prompt, llm_timeout)
938 .await?
939 };
940
941 for community in unchanged_leaves {
942 let concept_name = format!("community-{}-{}", community.level, community.index);
943 if let Ok(existing) = db.get_semantic_by_concept(&concept_name).await {
944 result.edges_created +=
945 repair_community_membership_edges(db, existing.id, &community.members).await;
946 }
947 }
948
949 Ok(result)
950}
951
952async fn collect_member_descriptions(db: &HirnDB, members: &[MemoryId], max: usize) -> Vec<String> {
954 let graph = db.graph_store();
957 let mut member_layers: Vec<(MemoryId, Option<Layer>)> = Vec::new();
958 for &id in members.iter().take(max) {
959 let layer = graph.node_layer(id).await.ok().flatten();
960 member_layers.push((id, layer));
961 }
962
963 let mut descriptions = Vec::new();
964 for (id, layer) in member_layers {
965 let desc = match layer {
966 Some(Layer::Semantic) => db
967 .get_semantic(id)
968 .await
969 .ok()
970 .map(|r| format!("{}: {}", r.concept, r.description)),
971 Some(Layer::Episodic) => db.get_episode(id).await.ok().map(|r| r.content.clone()),
972 _ => None,
973 };
974
975 if let Some(d) = desc {
976 descriptions.push(d);
977 }
978 }
979
980 descriptions
981}
982
983#[cfg(test)]
988mod tests {
989 use super::*;
990
991 use std::sync::atomic::{AtomicUsize, Ordering};
994
995 struct MockCommunityLlm {
996 response: String,
997 calls: AtomicUsize,
998 }
999
1000 impl MockCommunityLlm {
1001 fn new(response: &str) -> Self {
1002 Self {
1003 response: response.to_string(),
1004 calls: AtomicUsize::new(0),
1005 }
1006 }
1007 }
1008
1009 #[async_trait::async_trait]
1010 impl LlmProvider for MockCommunityLlm {
1011 async fn generate_text(
1012 &self,
1013 _messages: &[ChatMessage],
1014 _options: &LlmOptions,
1015 ) -> hirn_core::HirnResult<String> {
1016 self.calls.fetch_add(1, Ordering::Relaxed);
1017 Ok(self.response.clone())
1018 }
1019
1020 fn model_id(&self) -> &str {
1021 "mock-community"
1022 }
1023 }
1024
1025 async fn test_db() -> HirnDB {
1026 let dir = tempfile::tempdir().unwrap();
1027 let db_path = dir.path().join("test");
1028 let lance_path = dir.path().join("lance");
1029 let mut config = hirn_core::HirnConfig::default();
1030 config.db_path = db_path;
1031 config.embedding_dimensions = hirn_core::EmbeddingDimension::new_const(3);
1032 let storage: Arc<dyn hirn_storage::PhysicalStore> = hirn_storage::HirnDb::open(
1033 hirn_storage::HirnDbConfig::local(lance_path.to_str().unwrap()),
1034 )
1035 .await
1036 .unwrap()
1037 .store_arc();
1038 let db = HirnDB::open_with_config(config, storage).await.unwrap();
1039 std::mem::forget(dir);
1040 db
1041 }
1042
1043 #[tokio::test(flavor = "multi_thread")]
1044 async fn summary_empty_communities() {
1045 let db = test_db().await;
1046 let llm: Arc<dyn LlmProvider> = Arc::new(MockCommunityLlm::new("summary text"));
1047 let empty = CommunityResult {
1048 levels: vec![],
1049 node_to_community: HashMap::new(),
1050 total_communities: 0,
1051 };
1052 let result =
1053 generate_community_summaries(&db, &llm, &empty, 50, std::time::Duration::from_secs(30))
1054 .await
1055 .unwrap();
1056 assert_eq!(result.summaries_stored, 0);
1057 assert_eq!(result.edges_created, 0);
1058 }
1059
1060 #[tokio::test(flavor = "multi_thread")]
1061 async fn summary_generated_and_stored() {
1062 let db = test_db().await;
1063 let llm: Arc<dyn LlmProvider> = Arc::new(MockCommunityLlm::new(
1064 "THEME: Testing patterns\n\
1065 KEY_ENTITIES: test-concept-0, test-concept-1, test-concept-2\n\
1066 SUMMARY: This community is about testing. It covers 3 related concepts.",
1067 ));
1068
1069 let agent = AgentId::new("test").unwrap();
1071 let mut member_ids = Vec::new();
1072 for i in 0..3 {
1073 let record = SemanticRecord::builder()
1074 .concept(&format!("test-concept-{i}"))
1075 .description(&format!("Description for concept {i}"))
1076 .agent_id(agent.clone())
1077 .origin(Origin::Consolidation)
1078 .build()
1079 .unwrap();
1080 let id = db.store_semantic(record).await.unwrap();
1081 member_ids.push(id);
1082 }
1083
1084 let mut node_to_community = HashMap::new();
1086 for &id in &member_ids {
1087 node_to_community.insert(id, 0);
1088 }
1089 let communities = CommunityResult {
1090 levels: vec![vec![Community {
1091 level: 0,
1092 index: 0,
1093 members: member_ids.clone(),
1094 parent: None,
1095 children: vec![],
1096 }]],
1097 node_to_community,
1098 total_communities: 1,
1099 };
1100
1101 let result = generate_community_summaries(
1102 &db,
1103 &llm,
1104 &communities,
1105 50,
1106 std::time::Duration::from_secs(30),
1107 )
1108 .await
1109 .unwrap();
1110 assert_eq!(result.summaries_stored, 1);
1111
1112 let stored = db.get_semantic_by_concept("community-0-0").await.unwrap();
1114 assert_eq!(stored.knowledge_type, KnowledgeType::Community);
1115 assert!(stored.description.contains("THEME:"));
1116 assert!(stored.description.contains("KEY_ENTITIES:"));
1117 assert_eq!(stored.source_episodes.len(), 3);
1119 }
1120
1121 #[tokio::test(flavor = "multi_thread")]
1122 async fn summary_idempotent() {
1123 let db = test_db().await;
1124 let mock = Arc::new(MockCommunityLlm::new("Summary."));
1125 let llm: Arc<dyn LlmProvider> = mock.clone();
1126
1127 let agent = AgentId::new("test").unwrap();
1128 let record = SemanticRecord::builder()
1129 .concept("member-x")
1130 .description("x desc")
1131 .agent_id(agent.clone())
1132 .origin(Origin::Consolidation)
1133 .build()
1134 .unwrap();
1135 let id = db.store_semantic(record).await.unwrap();
1136
1137 let mut ntc = HashMap::new();
1138 ntc.insert(id, 0);
1139 let communities = CommunityResult {
1140 levels: vec![vec![Community {
1141 level: 0,
1142 index: 0,
1143 members: vec![id],
1144 parent: None,
1145 children: vec![],
1146 }]],
1147 node_to_community: ntc,
1148 total_communities: 1,
1149 };
1150
1151 let r1 = generate_community_summaries(
1153 &db,
1154 &llm,
1155 &communities,
1156 50,
1157 std::time::Duration::from_secs(30),
1158 )
1159 .await
1160 .unwrap();
1161
1162 let r2 = generate_community_summaries(
1164 &db,
1165 &llm,
1166 &communities,
1167 50,
1168 std::time::Duration::from_secs(30),
1169 )
1170 .await
1171 .unwrap();
1172
1173 assert_eq!(r1.summaries_stored, 1);
1175 assert_eq!(r2.summaries_stored, 0);
1176 assert_eq!(mock.calls.load(Ordering::Relaxed), 1);
1177 }
1178
1179 #[tokio::test(flavor = "multi_thread")]
1180 async fn summary_rerun_repairs_missing_membership_edges() {
1181 let db = test_db().await;
1182 let mock = Arc::new(MockCommunityLlm::new("Summary."));
1183 let llm: Arc<dyn LlmProvider> = mock.clone();
1184
1185 let agent = AgentId::new("test").unwrap();
1186 let mut member_ids = Vec::new();
1187 for i in 0..3 {
1188 let record = SemanticRecord::builder()
1189 .concept(&format!("member-{i}"))
1190 .description("member")
1191 .agent_id(agent.clone())
1192 .origin(Origin::Consolidation)
1193 .build()
1194 .unwrap();
1195 member_ids.push(db.store_semantic(record).await.unwrap());
1196 }
1197
1198 let mut node_to_community = HashMap::new();
1199 for &id in &member_ids {
1200 node_to_community.insert(id, 0);
1201 }
1202 let communities = CommunityResult {
1203 levels: vec![vec![Community {
1204 level: 0,
1205 index: 0,
1206 members: member_ids.clone(),
1207 parent: None,
1208 children: vec![],
1209 }]],
1210 node_to_community,
1211 total_communities: 1,
1212 };
1213
1214 let first = generate_community_summaries(
1215 &db,
1216 &llm,
1217 &communities,
1218 50,
1219 std::time::Duration::from_secs(30),
1220 )
1221 .await
1222 .unwrap();
1223 assert_eq!(first.summaries_stored, 1);
1224
1225 let summary = db.get_semantic_by_concept("community-0-0").await.unwrap();
1226 for &member_id in &member_ids {
1227 let edges = db
1228 .cached_graph()
1229 .get_edges_between(summary.id, member_id)
1230 .await
1231 .unwrap();
1232 for edge in edges {
1233 if (edge.relation == EdgeRelation::DerivedFrom
1234 && edge.source == summary.id
1235 && edge.target == member_id)
1236 || (edge.relation == EdgeRelation::PartOf
1237 && edge.source == member_id
1238 && edge.target == summary.id)
1239 {
1240 db.cached_graph().remove_edge(edge.id).await.unwrap();
1241 }
1242 }
1243 }
1244
1245 let second = generate_community_summaries(
1246 &db,
1247 &llm,
1248 &communities,
1249 50,
1250 std::time::Duration::from_secs(30),
1251 )
1252 .await
1253 .unwrap();
1254
1255 assert_eq!(second.summaries_stored, 0);
1256 assert_eq!(second.edges_created, member_ids.len() * 2);
1257 assert_eq!(mock.calls.load(Ordering::Relaxed), 1);
1258
1259 for &member_id in &member_ids {
1260 assert!(
1261 community_edge_exists(&db, summary.id, member_id, EdgeRelation::DerivedFrom).await,
1262 "summary should regain DerivedFrom edge to member"
1263 );
1264 assert!(
1265 community_edge_exists(&db, member_id, summary.id, EdgeRelation::PartOf).await,
1266 "member should regain PartOf edge to summary"
1267 );
1268 }
1269 }
1270
1271 #[tokio::test(flavor = "multi_thread")]
1272 async fn incremental_only_affected_communities_regenerated() {
1273 let db = test_db().await;
1274 let mock = Arc::new(MockCommunityLlm::new(
1275 "THEME: Test\nKEY_ENTITIES: a\nSUMMARY: Test.",
1276 ));
1277 let llm: Arc<dyn LlmProvider> = mock.clone();
1278
1279 let agent = AgentId::new("test").unwrap();
1280
1281 let mut cluster_a = Vec::new();
1283 let mut cluster_b = Vec::new();
1284 for i in 0..3 {
1285 let record = SemanticRecord::builder()
1286 .concept(&format!("auth-concept-{i}"))
1287 .description(&format!("Auth pattern {i}"))
1288 .agent_id(agent.clone())
1289 .origin(Origin::Consolidation)
1290 .build()
1291 .unwrap();
1292 cluster_a.push(db.store_semantic(record).await.unwrap());
1293 }
1294 for i in 0..3 {
1295 let record = SemanticRecord::builder()
1296 .concept(&format!("cache-concept-{i}"))
1297 .description(&format!("Cache pattern {i}"))
1298 .agent_id(agent.clone())
1299 .origin(Origin::Consolidation)
1300 .build()
1301 .unwrap();
1302 cluster_b.push(db.store_semantic(record).await.unwrap());
1303 }
1304
1305 let mut ntc = HashMap::new();
1307 for &id in &cluster_a {
1308 ntc.insert(id, 0);
1309 }
1310 for &id in &cluster_b {
1311 ntc.insert(id, 1);
1312 }
1313 let prev = CommunityResult {
1314 levels: vec![vec![
1315 Community {
1316 level: 0,
1317 index: 0,
1318 members: cluster_a.clone(),
1319 parent: None,
1320 children: vec![],
1321 },
1322 Community {
1323 level: 0,
1324 index: 1,
1325 members: cluster_b.clone(),
1326 parent: None,
1327 children: vec![],
1328 },
1329 ]],
1330 node_to_community: ntc,
1331 total_communities: 2,
1332 };
1333
1334 let r1 =
1336 generate_community_summaries(&db, &llm, &prev, 50, std::time::Duration::from_secs(30))
1337 .await
1338 .unwrap();
1339 assert_eq!(r1.summaries_stored, 2);
1340 assert_eq!(mock.calls.load(Ordering::Relaxed), 2);
1341
1342 let mut cluster_c = Vec::new();
1344 for i in 0..5 {
1345 let record = SemanticRecord::builder()
1346 .concept(&format!("new-topic-{i}"))
1347 .description(&format!("New topic episode {i}"))
1348 .agent_id(agent.clone())
1349 .origin(Origin::Consolidation)
1350 .build()
1351 .unwrap();
1352 cluster_c.push(db.store_semantic(record).await.unwrap());
1353 }
1354
1355 let mut ntc_new = HashMap::new();
1357 for &id in &cluster_a {
1358 ntc_new.insert(id, 0);
1359 }
1360 for &id in &cluster_b {
1361 ntc_new.insert(id, 1);
1362 }
1363 for &id in &cluster_c {
1364 ntc_new.insert(id, 2);
1365 }
1366 let new = CommunityResult {
1367 levels: vec![vec![
1368 Community {
1369 level: 0,
1370 index: 0,
1371 members: cluster_a.clone(),
1372 parent: None,
1373 children: vec![],
1374 },
1375 Community {
1376 level: 0,
1377 index: 1,
1378 members: cluster_b.clone(),
1379 parent: None,
1380 children: vec![],
1381 },
1382 Community {
1383 level: 0,
1384 index: 2,
1385 members: cluster_c.clone(),
1386 parent: None,
1387 children: vec![],
1388 },
1389 ]],
1390 node_to_community: ntc_new,
1391 total_communities: 3,
1392 };
1393
1394 mock.calls.store(0, Ordering::Relaxed);
1396 let r2 = generate_community_summaries_incremental(
1397 &db,
1398 &llm,
1399 &prev,
1400 &new,
1401 50,
1402 std::time::Duration::from_secs(30),
1403 )
1404 .await
1405 .unwrap();
1406 assert_eq!(
1407 r2.summaries_stored, 1,
1408 "only the new community should be summarized"
1409 );
1410 assert_eq!(
1411 mock.calls.load(Ordering::Relaxed),
1412 1,
1413 "LLM called only for new community"
1414 );
1415
1416 assert!(db.get_semantic_by_concept("community-0-2").await.is_ok());
1418 assert!(db.get_semantic_by_concept("community-0-0").await.is_ok());
1420 assert!(db.get_semantic_by_concept("community-0-1").await.is_ok());
1421 }
1422
1423 #[tokio::test(flavor = "multi_thread")]
1424 async fn incremental_rerun_repairs_edges_for_unchanged_communities() {
1425 let db = test_db().await;
1426 let mock = Arc::new(MockCommunityLlm::new("Summary."));
1427 let llm: Arc<dyn LlmProvider> = mock.clone();
1428
1429 let agent = AgentId::new("test").unwrap();
1430 let mut member_ids = Vec::new();
1431 for i in 0..2 {
1432 let record = SemanticRecord::builder()
1433 .concept(&format!("inc-member-{i}"))
1434 .description("member")
1435 .agent_id(agent.clone())
1436 .origin(Origin::Consolidation)
1437 .build()
1438 .unwrap();
1439 member_ids.push(db.store_semantic(record).await.unwrap());
1440 }
1441
1442 let mut node_to_community = HashMap::new();
1443 for &id in &member_ids {
1444 node_to_community.insert(id, 0);
1445 }
1446 let communities = CommunityResult {
1447 levels: vec![vec![Community {
1448 level: 0,
1449 index: 0,
1450 members: member_ids.clone(),
1451 parent: None,
1452 children: vec![],
1453 }]],
1454 node_to_community,
1455 total_communities: 1,
1456 };
1457
1458 generate_community_summaries(
1459 &db,
1460 &llm,
1461 &communities,
1462 50,
1463 std::time::Duration::from_secs(30),
1464 )
1465 .await
1466 .unwrap();
1467
1468 let summary = db.get_semantic_by_concept("community-0-0").await.unwrap();
1469 for &member_id in &member_ids {
1470 let edges = db
1471 .cached_graph()
1472 .get_edges_between(summary.id, member_id)
1473 .await
1474 .unwrap();
1475 for edge in edges {
1476 if edge.relation == EdgeRelation::DerivedFrom
1477 && edge.source == summary.id
1478 && edge.target == member_id
1479 {
1480 db.cached_graph().remove_edge(edge.id).await.unwrap();
1481 }
1482 }
1483 }
1484
1485 let rerun = generate_community_summaries_incremental(
1486 &db,
1487 &llm,
1488 &communities,
1489 &communities,
1490 50,
1491 std::time::Duration::from_secs(30),
1492 )
1493 .await
1494 .unwrap();
1495
1496 assert_eq!(rerun.summaries_stored, 0);
1497 assert_eq!(rerun.edges_created, member_ids.len());
1498 assert_eq!(mock.calls.load(Ordering::Relaxed), 1);
1499
1500 for &member_id in &member_ids {
1501 assert!(
1502 community_edge_exists(&db, summary.id, member_id, EdgeRelation::DerivedFrom).await,
1503 "incremental rerun should repair unchanged community summary edge"
1504 );
1505 }
1506 }
1507
/// Checks that `compute_community_delta` classifies communities by comparing
/// a previous and a new result: index 0 keeps the same members, index 1
/// exists only in the previous result, and index 2 only in the new one.
#[test]
fn compute_delta_identifies_changes() {
    // Three fresh ids per community; the sets share no members.
    let fresh_ids = || -> Vec<MemoryId> { std::iter::repeat_with(MemoryId::new).take(3).collect() };
    let ids_a = fresh_ids();
    let ids_b = fresh_ids();
    let ids_c = fresh_ids();

    // A parentless, childless level-0 community with the given index.
    let root = |index: usize, members: &[MemoryId]| Community {
        level: 0,
        index,
        members: members.to_vec(),
        parent: None,
        children: vec![],
    };
    // Wraps one level of communities; the node map is left empty here.
    let snapshot = |communities: Vec<Community>| CommunityResult {
        total_communities: communities.len(),
        levels: vec![communities],
        node_to_community: HashMap::new(),
    };

    // Previous run: communities 0 (A) and 1 (B).
    let prev = snapshot(vec![root(0, &ids_a), root(1, &ids_b)]);
    // New run: community 0 (A, same members) and community 2 (C).
    let new = snapshot(vec![root(0, &ids_a), root(2, &ids_c)]);

    let delta = compute_community_delta(&prev, &new);

    assert!(
        delta.unchanged.contains(&0),
        "community 0 should be unchanged"
    );
    assert!(delta.added.contains(&2), "community 2 should be added");
    assert!(delta.removed.contains(&1), "community 1 should be removed");
    assert!(
        delta.modified.is_empty(),
        "no communities should be modified"
    );
}
1571}