1use std::collections::{BTreeMap, HashMap, HashSet};
2
3use arrow_array::Array;
4use futures::TryStreamExt;
5use hirn_core::HirnError;
6use hirn_core::revision::{LogicalMemoryId, RevisionId, RevisionOperation};
7use hirn_core::semantic::SemanticRecord;
8use hirn_storage::PhysicalStore;
9use hirn_storage::store::ScanOptions;
10
11use crate::db::HirnDB;
12
/// Outcome of [`check_integrity`].
#[derive(Debug, Clone)]
pub struct IntegrityReport {
    /// `true` when `issues` is empty.
    pub is_clean: bool,
    /// Every problem found, in the order the checks ran.
    pub issues: Vec<IntegrityIssue>,
}
21
/// A single problem reported by [`check_integrity`].
#[derive(Debug, Clone)]
pub struct IntegrityIssue {
    /// Category of the problem.
    pub kind: IssueKind,
    /// Human-readable detail naming the offending dataset, agent, or graph node.
    pub description: String,
}
30
/// Categories of problem detected by [`check_integrity`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IssueKind {
    /// A dataset batch has rows but its expected string column (e.g. `id`) is missing or is
    /// not a string array, so those rows could not be read.
    CorruptedRecord,
    /// An agent listed in `_agents` has no `private:<agent_id>` entry in `_namespaces`.
    AgentMissingNamespace,
    /// A `_graph_nodes` id matches no record id in the `episodic`, `semantic`, or
    /// `procedural` datasets.
    OrphanedGraphNode,
}
41
42impl std::fmt::Display for IntegrityIssue {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "[{:?}] {}", self.kind, self.description)
45 }
46}
47
/// Outcome of [`repair`].
#[derive(Debug, Clone)]
pub struct RepairReport {
    /// Human-readable summaries of the repairs that were applied.
    pub repaired: Vec<String>,
    /// Human-readable descriptions of problems that could not be repaired.
    pub failed: Vec<String>,
}
56
/// Outcome of [`check_semantic_revision_integrity`].
#[derive(Debug, Clone)]
pub struct SemanticRevisionIntegrityReport {
    /// `true` when `issues` is empty.
    pub is_clean: bool,
    /// Number of distinct logical memories found in the semantic dataset.
    pub logical_memory_count: usize,
    /// Total number of semantic records (revisions) scanned.
    pub revision_count: usize,
    /// Number of entries in the semantic head cache snapshot taken during the check.
    pub cached_head_entries: usize,
    /// Logical memories that have an authoritative head but no cache entry. Informational
    /// only: these do not produce an issue and do not clear `is_clean`.
    pub missing_cached_heads: usize,
    /// Every problem found, both structural and cache-related.
    pub issues: Vec<SemanticRevisionIntegrityIssue>,
}
73
/// A single problem reported by [`check_semantic_revision_integrity`].
#[derive(Debug, Clone)]
pub struct SemanticRevisionIntegrityIssue {
    /// Category of the problem.
    pub kind: SemanticRevisionIssueKind,
    /// Logical memory the problem belongs to, when it can be attributed to one.
    pub logical_memory_id: Option<LogicalMemoryId>,
    /// Specific revision involved, when the problem concerns a single revision.
    pub revision_id: Option<RevisionId>,
    /// Human-readable detail.
    pub description: String,
}
86
/// Categories of problem detected in semantic revision chains and the semantic head cache.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SemanticRevisionIssueKind {
    /// A record's `revision_id`, viewed as a memory id, differs from the record's own `id`.
    InvalidRevisionIdMapping,
    /// The same revision id is stored on more than one record.
    DuplicateRevisionId,
    /// A logical memory has no version-1 `Create` revision.
    InvalidRootRevision,
    /// More than one revision of a logical memory claims the same version number.
    DuplicateVersion,
    /// A logical memory's versions do not form a contiguous sequence starting at 1.
    NonContiguousVersionChain,
    /// A logical memory's head revision is both retracted and merged.
    ConflictingTerminalState,
    /// A logical memory's head revision claims to be merged into itself.
    SelfMergedLogicalHead,
    /// A cached head disagrees with the authoritative head, or exists for a logical memory
    /// with no stored chain. This is the only kind `repair_semantic_revision_integrity`
    /// treats as repaired rather than failed.
    StaleHeadCacheEntry,
}
107
108impl std::fmt::Display for SemanticRevisionIntegrityIssue {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 write!(f, "[{:?}] {}", self.kind, self.description)
111 }
112}
113
/// Outcome of [`repair_semantic_revision_integrity`].
#[derive(Debug, Clone)]
pub struct SemanticRevisionRepairReport {
    /// Number of authoritative heads written back into the cache, i.e. logical memories
    /// whose chain passed every structural check.
    pub refreshed_head_count: usize,
    /// Number of cache entries dropped because their logical memory had no trustworthy
    /// authoritative head (structurally corrupted, or no stored chain at all).
    pub evicted_head_count: usize,
    /// Human-readable summaries of the repairs that were applied.
    pub repaired: Vec<String>,
    /// Descriptions of structural problems that this repair cannot fix, de-duplicated.
    pub failed: Vec<String>,
}
126
127pub async fn check_integrity(storage: &dyn PhysicalStore) -> Result<IntegrityReport, HirnError> {
134 let mut issues = Vec::new();
135
136 let episodic_ids = collect_ids(storage, "episodic", &mut issues).await?;
138 let semantic_ids = collect_ids(storage, "semantic", &mut issues).await?;
139 let procedural_ids = collect_ids(storage, "procedural", &mut issues).await?;
140
141 let all_record_ids: HashSet<String> = episodic_ids
142 .iter()
143 .chain(semantic_ids.iter())
144 .chain(procedural_ids.iter())
145 .cloned()
146 .collect();
147
148 let agent_batches = storage
150 .scan(
151 "_agents",
152 ScanOptions {
153 columns: Some(vec!["id".into()]),
154 filter: None,
155 exact_filter: None,
156 order_by: None,
157 limit: None,
158 offset: None,
159 },
160 )
161 .await
162 .unwrap_or_default();
163
164 let ns_batches = storage
165 .scan(
166 "_namespaces",
167 ScanOptions {
168 columns: Some(vec!["name".into()]),
169 filter: None,
170 exact_filter: None,
171 order_by: None,
172 limit: None,
173 offset: None,
174 },
175 )
176 .await
177 .unwrap_or_default();
178
179 let mut namespace_names: HashSet<String> = HashSet::new();
180 for batch in &ns_batches {
181 if let Some(col) = batch
182 .column_by_name("name")
183 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>())
184 {
185 for i in 0..col.len() {
186 if !col.is_null(i) {
187 namespace_names.insert(col.value(i).to_string());
188 }
189 }
190 }
191 }
192
193 for batch in &agent_batches {
194 if let Some(col) = batch
195 .column_by_name("id")
196 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>())
197 {
198 for i in 0..col.len() {
199 if !col.is_null(i) {
200 let agent_id = col.value(i);
201 let private_ns = format!("private:{agent_id}");
202 if !namespace_names.contains(&private_ns) {
203 issues.push(IntegrityIssue {
204 kind: IssueKind::AgentMissingNamespace,
205 description: format!(
206 "agent '{agent_id}' has no private namespace '{private_ns}'"
207 ),
208 });
209 }
210 }
211 }
212 }
213 }
214
215 let graph_batches = storage
217 .scan(
218 "_graph_nodes",
219 ScanOptions {
220 columns: Some(vec!["id".into()]),
221 filter: None,
222 exact_filter: None,
223 order_by: None,
224 limit: None,
225 offset: None,
226 },
227 )
228 .await
229 .unwrap_or_default();
230
231 for batch in &graph_batches {
232 if let Some(col) = batch
233 .column_by_name("id")
234 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>())
235 {
236 for i in 0..col.len() {
237 if !col.is_null(i) {
238 let node_id = col.value(i);
239 if !all_record_ids.contains(node_id) {
240 issues.push(IntegrityIssue {
241 kind: IssueKind::OrphanedGraphNode,
242 description: format!("graph node {node_id} does not map to any record"),
243 });
244 }
245 }
246 }
247 }
248 }
249
250 let is_clean = issues.is_empty();
251 Ok(IntegrityReport { is_clean, issues })
252}
253
254pub async fn repair(storage: &dyn PhysicalStore) -> Result<RepairReport, HirnError> {
259 let mut repaired = Vec::new();
260 let failed = Vec::new();
261
262 let agent_batches = storage
264 .scan(
265 "_agents",
266 ScanOptions {
267 columns: Some(vec!["id".into()]),
268 filter: None,
269 exact_filter: None,
270 order_by: None,
271 limit: None,
272 offset: None,
273 },
274 )
275 .await
276 .unwrap_or_default();
277
278 let ns_batches = storage
279 .scan(
280 "_namespaces",
281 ScanOptions {
282 columns: Some(vec!["name".into()]),
283 filter: None,
284 exact_filter: None,
285 order_by: None,
286 limit: None,
287 offset: None,
288 },
289 )
290 .await
291 .unwrap_or_default();
292
293 let mut namespace_names: HashSet<String> = HashSet::new();
294 for batch in &ns_batches {
295 if let Some(col) = batch
296 .column_by_name("name")
297 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>())
298 {
299 for i in 0..col.len() {
300 if !col.is_null(i) {
301 namespace_names.insert(col.value(i).to_string());
302 }
303 }
304 }
305 }
306
307 let mut missing_agents: Vec<String> = Vec::new();
308 for batch in &agent_batches {
309 if let Some(col) = batch
310 .column_by_name("id")
311 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>())
312 {
313 for i in 0..col.len() {
314 if !col.is_null(i) {
315 let agent_id = col.value(i).to_string();
316 let private_ns = format!("private:{agent_id}");
317 if !namespace_names.contains(&private_ns) {
318 missing_agents.push(agent_id);
319 }
320 }
321 }
322 }
323 }
324
325 if !missing_agents.is_empty() {
326 for agent_id in &missing_agents {
327 if let Ok(aid) = hirn_core::types::AgentId::new(agent_id) {
328 let ns_rec = hirn_core::namespace::NamespaceRecord::private_for(&aid);
329 let batch = hirn_storage::datasets::namespace::to_batch(&[ns_rec])
330 .map_err(|e| HirnError::storage(e))?;
331 storage
332 .append("_namespaces", batch)
333 .await
334 .map_err(|e| HirnError::storage(e))?;
335 }
336 }
337 repaired.push(format!(
338 "created {} missing private namespace(s) for agents: {}",
339 missing_agents.len(),
340 missing_agents.join(", ")
341 ));
342 }
343
344 Ok(RepairReport { repaired, failed })
345}
346
347pub async fn check_semantic_revision_integrity(
349 db: &HirnDB,
350) -> Result<SemanticRevisionIntegrityReport, HirnError> {
351 Ok(collect_semantic_revision_state(db).await?.report)
352}
353
354pub async fn repair_semantic_revision_integrity(
358 db: &HirnDB,
359) -> Result<SemanticRevisionRepairReport, HirnError> {
360 let state = collect_semantic_revision_state(db).await?;
361
362 let safe_heads: HashMap<LogicalMemoryId, SemanticRecord> = state
363 .authoritative_heads
364 .iter()
365 .filter(|(logical_memory_id, _)| !state.structurally_corrupted.contains(logical_memory_id))
366 .map(|(logical_memory_id, record)| (*logical_memory_id, record.clone()))
367 .collect();
368
369 let stale_replacements = state
370 .cached_heads
371 .iter()
372 .filter(|(logical_memory_id, cached)| {
373 safe_heads
374 .get(logical_memory_id)
375 .is_some_and(|expected| expected.revision_id != cached.revision_id)
376 })
377 .count();
378 let warmed_missing = safe_heads
379 .keys()
380 .filter(|logical_memory_id| !state.cached_heads.contains_key(logical_memory_id))
381 .count();
382 let evicted_head_count = state
383 .cached_heads
384 .keys()
385 .filter(|logical_memory_id| !safe_heads.contains_key(logical_memory_id))
386 .count();
387
388 db.replace_semantic_heads(safe_heads.into_values());
389
390 let mut repaired = Vec::new();
391 if !state.authoritative_heads.is_empty() || !state.cached_heads.is_empty() {
392 repaired.push(format!(
393 "rebuilt semantic head cache with {} authoritative head(s); replaced {} stale entry(s), warmed {} missing entry(s), evicted {} unsafe entry(s)",
394 state
395 .authoritative_heads
396 .len()
397 .saturating_sub(state.structurally_corrupted.len()),
398 stale_replacements,
399 warmed_missing,
400 evicted_head_count,
401 ));
402 }
403
404 let mut failed = Vec::new();
405 let mut seen_failures = HashSet::new();
406 for issue in state
407 .report
408 .issues
409 .iter()
410 .filter(|issue| issue.kind != SemanticRevisionIssueKind::StaleHeadCacheEntry)
411 {
412 if seen_failures.insert(issue.description.clone()) {
413 failed.push(issue.description.clone());
414 }
415 }
416
417 Ok(SemanticRevisionRepairReport {
418 refreshed_head_count: state
419 .authoritative_heads
420 .len()
421 .saturating_sub(state.structurally_corrupted.len()),
422 evicted_head_count,
423 repaired,
424 failed,
425 })
426}
427
428async fn collect_ids(
430 storage: &dyn PhysicalStore,
431 dataset: &str,
432 issues: &mut Vec<IntegrityIssue>,
433) -> Result<HashSet<String>, HirnError> {
434 let mut ids = HashSet::new();
435 let batches = storage
436 .scan(
437 dataset,
438 ScanOptions {
439 columns: Some(vec!["id".into()]),
440 filter: None,
441 exact_filter: None,
442 order_by: None,
443 limit: None,
444 offset: None,
445 },
446 )
447 .await
448 .unwrap_or_default();
449
450 for batch in &batches {
451 if let Some(col) = batch
452 .column_by_name("id")
453 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>())
454 {
455 for i in 0..col.len() {
456 if !col.is_null(i) {
457 ids.insert(col.value(i).to_string());
458 }
459 }
460 } else if batch.num_rows() > 0 {
461 issues.push(IntegrityIssue {
462 kind: IssueKind::CorruptedRecord,
463 description: format!(
464 "{dataset} dataset has {n} rows but missing or invalid 'id' column",
465 n = batch.num_rows(),
466 ),
467 });
468 }
469 }
470
471 Ok(ids)
472}
473
/// Everything [`collect_semantic_revision_state`] learns in one pass. Both
/// `check_semantic_revision_integrity` and `repair_semantic_revision_integrity` consume it.
struct SemanticRevisionValidationState {
    /// The public report returned by the check.
    report: SemanticRevisionIntegrityReport,
    /// Newest stored revision of every logical memory, including structurally corrupted ones.
    authoritative_heads: HashMap<LogicalMemoryId, SemanticRecord>,
    /// Snapshot of the database's semantic head cache taken during validation.
    cached_heads: HashMap<LogicalMemoryId, SemanticRecord>,
    /// Logical memories that failed at least one structural check; their heads are not trusted.
    structurally_corrupted: HashSet<LogicalMemoryId>,
}
480
/// Scan the whole semantic dataset, validate each logical memory's revision chain, and
/// compare the result against the database's cached semantic heads.
///
/// Per-record checks: `InvalidRevisionIdMapping` and `DuplicateRevisionId`. Per-chain checks
/// are delegated to [`validate_semantic_chain`]. Finally each cached head is compared with
/// the authoritative one and reported as `StaleHeadCacheEntry` on divergence or when no
/// chain exists. Cache misses are only counted (`missing_cached_heads`); they are not issues.
///
/// # Errors
///
/// Propagates storage errors from opening the scan stream, reading a batch, or decoding a
/// batch into semantic records.
async fn collect_semantic_revision_state(
    db: &HirnDB,
) -> Result<SemanticRevisionValidationState, HirnError> {
    let mut issues = Vec::new();
    // Logical memories whose chain failed any structural check.
    let mut structurally_corrupted = HashSet::new();
    // revision id -> (logical memory id, record id) of the record that last claimed it.
    let mut revision_owners = HashMap::new();
    // All stored revisions grouped by logical memory, in scan order.
    let mut chains: HashMap<LogicalMemoryId, Vec<SemanticRecord>> = HashMap::new();

    let mut batches = db
        .storage_backend()
        .scan_stream(
            hirn_storage::datasets::semantic::DATASET_NAME,
            ScanOptions::default(),
        )
        .await
        .map_err(HirnError::storage)?;

    while let Some(batch) = batches.try_next().await.map_err(HirnError::storage)? {
        let records =
            hirn_storage::datasets::semantic::from_batch(&batch).map_err(HirnError::storage)?;
        for record in records {
            // A revision id must map back to the id of the record that carries it.
            if record.revision_id.as_memory_id() != record.id {
                structurally_corrupted.insert(record.logical_memory_id);
                issues.push(SemanticRevisionIntegrityIssue {
                    kind: SemanticRevisionIssueKind::InvalidRevisionIdMapping,
                    logical_memory_id: Some(record.logical_memory_id),
                    revision_id: Some(record.revision_id),
                    description: format!(
                        "logical memory {} has revision {} stored on mismatched record {}",
                        record.logical_memory_id, record.revision_id, record.id
                    ),
                });
            }

            // `insert` returns the previous owner, so any repeat claim is detected; the
            // newest claimant replaces it, so a third claim is reported against the second.
            // Both logical memories involved are marked corrupted.
            if let Some((other_logical_memory_id, other_record_id)) =
                revision_owners.insert(record.revision_id, (record.logical_memory_id, record.id))
            {
                structurally_corrupted.insert(record.logical_memory_id);
                structurally_corrupted.insert(other_logical_memory_id);
                issues.push(SemanticRevisionIntegrityIssue {
                    kind: SemanticRevisionIssueKind::DuplicateRevisionId,
                    logical_memory_id: Some(record.logical_memory_id),
                    revision_id: Some(record.revision_id),
                    description: format!(
                        "revision {} is claimed by records {} ({}) and {} ({})",
                        record.revision_id,
                        other_record_id,
                        other_logical_memory_id,
                        record.id,
                        record.logical_memory_id,
                    ),
                });
            }

            chains
                .entry(record.logical_memory_id)
                .or_default()
                .push(record);
        }
    }

    let revision_count = chains.values().map(Vec::len).sum();
    let logical_memory_count = chains.len();

    // Every chain yields a head (chains are never empty), even if it is corrupted; repair
    // filters those out using `structurally_corrupted`.
    let mut authoritative_heads = HashMap::with_capacity(chains.len());
    for (logical_memory_id, records) in &chains {
        if let Some(head) = validate_semantic_chain(
            *logical_memory_id,
            records,
            &mut issues,
            &mut structurally_corrupted,
        ) {
            authoritative_heads.insert(*logical_memory_id, head);
        }
    }

    let cached_heads = db.cached_semantic_heads_snapshot();
    // Counted for the report only; a missing cache entry is not treated as an issue.
    let missing_cached_heads = authoritative_heads
        .keys()
        .filter(|logical_memory_id| !cached_heads.contains_key(logical_memory_id))
        .count();

    for (logical_memory_id, cached_head) in &cached_heads {
        match authoritative_heads.get(logical_memory_id) {
            Some(authoritative_head)
                if authoritative_head.revision_id == cached_head.revision_id => {}
            Some(authoritative_head) => issues.push(SemanticRevisionIntegrityIssue {
                kind: SemanticRevisionIssueKind::StaleHeadCacheEntry,
                logical_memory_id: Some(*logical_memory_id),
                revision_id: Some(cached_head.revision_id),
                description: format!(
                    "logical memory {} cached head {} diverges from authoritative head {}",
                    logical_memory_id, cached_head.revision_id, authoritative_head.revision_id,
                ),
            }),
            None => issues.push(SemanticRevisionIntegrityIssue {
                kind: SemanticRevisionIssueKind::StaleHeadCacheEntry,
                logical_memory_id: Some(*logical_memory_id),
                revision_id: Some(cached_head.revision_id),
                description: format!(
                    "logical memory {} has cached head {} but no authoritative semantic chain",
                    logical_memory_id, cached_head.revision_id,
                ),
            }),
        }
    }

    let report = SemanticRevisionIntegrityReport {
        is_clean: issues.is_empty(),
        logical_memory_count,
        revision_count,
        cached_head_entries: cached_heads.len(),
        missing_cached_heads,
        issues,
    };

    Ok(SemanticRevisionValidationState {
        report,
        authoritative_heads,
        cached_heads,
        structurally_corrupted,
    })
}
604
605fn validate_semantic_chain(
606 logical_memory_id: LogicalMemoryId,
607 records: &[SemanticRecord],
608 issues: &mut Vec<SemanticRevisionIntegrityIssue>,
609 structurally_corrupted: &mut HashSet<LogicalMemoryId>,
610) -> Option<SemanticRecord> {
611 let mut head = None;
612 let mut versions: BTreeMap<u32, Vec<&SemanticRecord>> = BTreeMap::new();
613 let mut has_root_create = false;
614
615 for record in records {
616 if head
617 .as_ref()
618 .is_none_or(|current| semantic_revision_is_newer(record, current))
619 {
620 head = Some(record.clone());
621 }
622
623 versions.entry(record.version).or_default().push(record);
624 if record.version == 1 && record.revision_operation == RevisionOperation::Create {
625 has_root_create = true;
626 }
627 }
628
629 if !has_root_create {
630 structurally_corrupted.insert(logical_memory_id);
631 issues.push(SemanticRevisionIntegrityIssue {
632 kind: SemanticRevisionIssueKind::InvalidRootRevision,
633 logical_memory_id: Some(logical_memory_id),
634 revision_id: None,
635 description: format!(
636 "logical memory {} is missing a version-1 create revision",
637 logical_memory_id
638 ),
639 });
640 }
641
642 for (version, bucket) in &versions {
643 if bucket.len() > 1 {
644 structurally_corrupted.insert(logical_memory_id);
645 issues.push(SemanticRevisionIntegrityIssue {
646 kind: SemanticRevisionIssueKind::DuplicateVersion,
647 logical_memory_id: Some(logical_memory_id),
648 revision_id: None,
649 description: format!(
650 "logical memory {} has {} revisions claiming version {}",
651 logical_memory_id,
652 bucket.len(),
653 version,
654 ),
655 });
656 }
657 }
658
659 let expected_versions: Vec<u32> = (1..=records.len() as u32).collect();
660 let actual_versions: Vec<u32> = versions.keys().copied().collect();
661 if actual_versions != expected_versions {
662 structurally_corrupted.insert(logical_memory_id);
663 issues.push(SemanticRevisionIntegrityIssue {
664 kind: SemanticRevisionIssueKind::NonContiguousVersionChain,
665 logical_memory_id: Some(logical_memory_id),
666 revision_id: None,
667 description: format!(
668 "logical memory {} has non-contiguous versions {:?} (expected {:?})",
669 logical_memory_id, actual_versions, expected_versions,
670 ),
671 });
672 }
673
674 if let Some(head) = &head {
675 if head.is_retracted() && head.is_merged() {
676 structurally_corrupted.insert(logical_memory_id);
677 issues.push(SemanticRevisionIntegrityIssue {
678 kind: SemanticRevisionIssueKind::ConflictingTerminalState,
679 logical_memory_id: Some(logical_memory_id),
680 revision_id: Some(head.revision_id),
681 description: format!(
682 "logical memory {} head {} is both retracted and merged",
683 logical_memory_id, head.revision_id,
684 ),
685 });
686 }
687
688 if head.merged_into == Some(logical_memory_id) {
689 structurally_corrupted.insert(logical_memory_id);
690 issues.push(SemanticRevisionIntegrityIssue {
691 kind: SemanticRevisionIssueKind::SelfMergedLogicalHead,
692 logical_memory_id: Some(logical_memory_id),
693 revision_id: Some(head.revision_id),
694 description: format!(
695 "logical memory {} head {} claims a self-merge",
696 logical_memory_id, head.revision_id,
697 ),
698 });
699 }
700 }
701
702 head
703}
704
705fn semantic_revision_is_newer(candidate: &SemanticRecord, current: &SemanticRecord) -> bool {
706 candidate.version > current.version
707 || (candidate.version == current.version
708 && (candidate.created_at > current.created_at
709 || (candidate.created_at == current.created_at
710 && candidate.revision_id > current.revision_id)))
711}
712
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_storage::memory_store::MemoryStore;
    use std::sync::Arc;

    /// A fresh in-memory store that has had nothing written to it.
    fn empty_store() -> Arc<dyn hirn_storage::PhysicalStore> {
        Arc::new(MemoryStore::new())
    }

    #[tokio::test]
    async fn check_empty_database_is_clean() {
        let store = empty_store();
        let report = check_integrity(&*store).await.unwrap();
        assert!(report.is_clean, "empty DB should be clean: {:?}", report.issues);
    }

    #[tokio::test]
    async fn repair_on_empty_database_is_noop() {
        let store = empty_store();
        let RepairReport { repaired, failed } = repair(&*store).await.unwrap();
        assert!(repaired.is_empty(), "nothing to repair on empty DB");
        assert!(failed.is_empty());
    }
}