use serde::{Deserialize, Serialize};
use std::io::{Read, Write};
/// A concept identified by an id, together with its labels and links to other concepts.
///
/// `ConceptGraph` (later in this file) stores these keyed by `id` and indexes them by
/// lowercased label.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CanonicalConcept {
    /// Identifier of the concept; `ConceptGraph` uses it as the map key.
    pub id: String,
    /// Version number; `CanonicalConcept::new` starts it at 1.
    pub version: u32,
    /// Labels attached to the concept, stored exactly as given.
    pub labels: Vec<String>,
    /// Ids of other concepts that `ConceptGraph::expand` follows from this one.
    pub related: Vec<String>,
}
impl CanonicalConcept {
pub fn new(id: impl Into<String>) -> Self {
Self {
id: id.into(),
version: 1,
labels: Vec::new(),
related: Vec::new(),
}
}
pub fn with_label(mut self, label: impl Into<String>) -> Self {
self.labels.push(label.into());
self
}
pub fn with_related(mut self, related_id: impl Into<String>) -> Self {
self.related.push(related_id.into());
self
}
}
/// Limits and score weights used to configure the bridge.
///
/// NOTE(review): nothing in this file reads these fields, so the per-field notes below are
/// inferred from the field names and should be confirmed against the consuming code.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BridgeConfig {
    /// Presumably the depth limit handed to `ConceptGraph::expand` (default 2) — confirm.
    pub max_expansion_depth: u8,
    /// Presumably an upper bound on `MemoryPacket::facts` (default 20) — confirm.
    pub max_packet_facts: usize,
    /// Presumably compared against `MemoryPacket::estimated_tokens` (default 1000) — confirm.
    pub token_budget: usize,
    /// Weight associated with the deterministic score component (default 0.6).
    pub deterministic_weight: f32,
    /// Weight associated with the concept score component (default 0.3).
    pub concept_weight: f32,
    /// Weight associated with the semantic score component (default 0.1).
    pub semantic_weight: f32,
}
impl Default for BridgeConfig {
fn default() -> Self {
Self {
max_expansion_depth: 2,
max_packet_facts: 20,
token_budget: 1000,
deterministic_weight: 0.6,
concept_weight: 0.3,
semantic_weight: 0.1,
}
}
}
/// Per-component scores for a single hit, plus the combined score and supporting evidence.
///
/// NOTE(review): how `final_score` is derived from the three components is not shown in
/// this file, apart from `ScoreBreakdown::deterministic_only`, which copies the
/// deterministic score into it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ScoreBreakdown {
    /// Score from the deterministic component.
    pub deterministic: f32,
    /// Score from the concept component.
    pub concept: f32,
    /// Score from the semantic component.
    pub semantic: f32,
    /// Overall score for the hit.
    pub final_score: f32,
    /// Free-form tags naming what contributed to the score (e.g. `"deterministic_recall"`).
    pub evidence: Vec<String>,
}
impl ScoreBreakdown {
pub fn deterministic_only(score: f32) -> Self {
Self {
deterministic: score,
concept: 0.0,
semantic: 0.0,
final_score: score,
evidence: vec!["deterministic_recall".to_string()],
}
}
}
/// A single scored result; slices of these are what `SemanticReranker::rerank` receives.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BridgeHit {
    /// Identifier of the item that was hit.
    pub id: String,
    /// Optional preview text for the hit; `None` when no preview is available.
    pub text_preview: Option<String>,
    /// Score components and evidence for this hit.
    pub scores: ScoreBreakdown,
}
/// A bundle of facts with their sources and a confidence value, serializable to JSON.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct MemoryPacket {
    /// Text describing the intent of the query this packet answers.
    pub query_intent: String,
    /// Fact strings; these are the only input to `MemoryPacket::estimated_tokens`.
    pub facts: Vec<String>,
    /// Source identifiers for the facts.
    /// NOTE(review): whether these line up one-to-one with `facts` is not visible here.
    pub sources: Vec<String>,
    /// Confidence value for the packet.
    /// NOTE(review): presumably in `0.0..=1.0` — confirm against the producer.
    pub confidence: f32,
}
impl MemoryPacket {
    /// Estimates how many tokens the packet's facts occupy.
    ///
    /// Counts whitespace-separated words across all `facts` and assumes 0.75 words per
    /// token (i.e. 4 tokens per 3 words), rounding up. `query_intent` and `sources` are
    /// not counted.
    pub fn estimated_tokens(&self) -> usize {
        let word_count: usize = self
            .facts
            .iter()
            .map(|fact| fact.split_whitespace().count())
            .sum();
        // ceil(words / 0.75) == ceil(4 * words / 3). Integer arithmetic keeps the result
        // exact for every count; an `f32` round-trip has only 24 bits of mantissa and
        // drifts once the count grows past a few million words.
        word_count.saturating_mul(4).saturating_add(2) / 3
    }

    /// Serializes the packet to a compact JSON string.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if serialization fails.
    pub fn to_json(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }

    /// Serializes the packet to an indented, human-readable JSON string.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if serialization fails.
    pub fn to_json_pretty(&self) -> serde_json::Result<String> {
        serde_json::to_string_pretty(self)
    }
}
/// Pluggable reranking step over a slice of [`BridgeHit`]s.
///
/// Implementations must be `Send + Sync`.
pub trait SemanticReranker: Send + Sync {
    /// Returns a string identifying this reranker's version.
    fn version(&self) -> &str;

    /// Reranks `hits` for `query`, with mutable access to the slice.
    ///
    /// NOTE(review): the slice is mutable, so implementations can reorder the hits and
    /// edit their scores; which of the two is expected is not specified in this file.
    fn rerank(&self, query: &str, hits: &mut [BridgeHit]);
}
/// In-memory store of [`CanonicalConcept`]s with a case-insensitive label lookup.
#[derive(Clone, Debug, Default)]
pub struct ConceptGraph {
    /// Concepts keyed by their `id`.
    concepts: std::collections::HashMap<String, CanonicalConcept>,
    /// Lowercased label -> ids of the concepts carrying that label.
    label_index: std::collections::HashMap<String, Vec<String>>,
}
impl ConceptGraph {
pub fn new() -> Self {
Self::default()
}
pub fn add_concept(&mut self, concept: CanonicalConcept) {
let id = concept.id.clone();
let labels: Vec<String> = concept.labels.iter().map(|l| l.to_lowercase()).collect();
for label in &labels {
self.label_index
.entry(label.clone())
.or_default()
.push(id.clone());
}
self.concepts.insert(id, concept);
}
pub fn remove_concept(&mut self, id: &str) -> Option<CanonicalConcept> {
let concept = self.concepts.remove(id)?;
for label in &concept.labels {
if let Some(ids) = self.label_index.get_mut(&label.to_lowercase()) {
ids.retain(|i| i != id);
if ids.is_empty() {
self.label_index.remove(&label.to_lowercase());
}
}
}
Some(concept)
}
pub fn get_concept(&self, id: &str) -> Option<&CanonicalConcept> {
self.concepts.get(id)
}
pub fn match_tokens(&self, tokens: &[String]) -> Vec<String> {
let mut matched = std::collections::HashSet::new();
for token in tokens {
if let Some(ids) = self.label_index.get(&token.to_lowercase()) {
matched.extend(ids.clone());
}
}
matched.into_iter().collect()
}
pub fn expand(&self, concept_ids: &[String], max_depth: u8) -> Vec<String> {
let mut expanded = std::collections::HashSet::new();
let mut to_visit: Vec<(String, u8)> =
concept_ids.iter().map(|id| (id.clone(), 0)).collect();
let mut visited = std::collections::HashSet::new();
while let Some((id, depth)) = to_visit.pop() {
if visited.contains(&id) || depth > max_depth {
continue;
}
visited.insert(id.clone());
if let Some(concept) = self.concepts.get(&id) {
for label in &concept.labels {
expanded.insert(label.clone());
}
if depth < max_depth {
for related_id in &concept.related {
if !visited.contains(related_id) {
to_visit.push((related_id.clone(), depth + 1));
}
}
}
}
}
expanded.into_iter().collect()
}
pub fn load_from_json(reader: impl Read) -> crate::Result<Self> {
let concepts: Vec<CanonicalConcept> = serde_json::from_reader(reader)?;
let mut graph = Self::new();
for concept in concepts {
graph.add_concept(concept);
}
Ok(graph)
}
pub fn save_to_json(&self, writer: impl Write) -> crate::Result<()> {
let concepts: Vec<&CanonicalConcept> = self.concepts.values().collect();
serde_json::to_writer_pretty(writer, &concepts)?;
Ok(())
}
pub fn concept_count(&self) -> usize {
self.concepts.len()
}
pub fn label_count(&self) -> usize {
self.label_index.len()
}
pub fn all_concepts(&self) -> impl Iterator<Item = &CanonicalConcept> {
self.concepts.values()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // The builder records the id, labels and related ids it is given.
    #[test]
    fn test_canonical_concept_builder() {
        let built = CanonicalConcept::new("concept.test")
            .with_label("test label")
            .with_related("concept.other");

        assert_eq!(built.id, "concept.test");
        assert_eq!(built.labels, vec!["test label"]);
        assert_eq!(built.related, vec!["concept.other"]);
    }

    // Adding one concept with two labels yields one concept and two index entries.
    #[test]
    fn test_concept_graph_add_and_get() {
        let mut g = ConceptGraph::new();
        g.add_concept(
            CanonicalConcept::new("c1")
                .with_label("label1")
                .with_label("Label2"),
        );

        assert_eq!(g.concept_count(), 1);
        assert_eq!(g.label_count(), 2);
        assert!(g.get_concept("c1").is_some());
    }

    // Token matching ignores case on both the stored label and the query token.
    #[test]
    fn test_concept_graph_match_tokens_case_insensitive() {
        let mut g = ConceptGraph::new();
        g.add_concept(CanonicalConcept::new("c1").with_label("agent-memory"));
        g.add_concept(CanonicalConcept::new("c2").with_label("session"));

        let tokens = vec!["Agent-Memory".to_string(), "SESSION".to_string()];
        let hits = g.match_tokens(&tokens);

        assert_eq!(hits.len(), 2);
    }

    // Expansion at depth 1 includes the labels of directly related concepts.
    #[test]
    fn test_concept_graph_expand() {
        let mut g = ConceptGraph::new();
        g.add_concept(
            CanonicalConcept::new("c1")
                .with_label("agent memory")
                .with_related("c2"),
        );
        g.add_concept(CanonicalConcept::new("c2").with_label("session context"));

        let seeds = vec!["c1".to_string()];
        let labels = g.expand(&seeds, 1);

        assert!(labels.contains(&"agent memory".to_string()));
        assert!(labels.contains(&"session context".to_string()));
    }

    // A two-node cycle must terminate and still yield both labels.
    #[test]
    fn test_concept_graph_expand_no_cycle() {
        let mut g = ConceptGraph::new();
        g.add_concept(
            CanonicalConcept::new("c1")
                .with_label("label1")
                .with_related("c2"),
        );
        g.add_concept(
            CanonicalConcept::new("c2")
                .with_label("label2")
                .with_related("c1"),
        );

        let seeds = vec!["c1".to_string()];
        let labels = g.expand(&seeds, 10);

        assert!(labels.contains(&"label1".to_string()));
        assert!(labels.contains(&"label2".to_string()));
    }

    // Removing the only concept empties both the concept map and the label index.
    #[test]
    fn test_concept_graph_remove() {
        let mut g = ConceptGraph::new();
        g.add_concept(CanonicalConcept::new("c1").with_label("label1"));
        assert_eq!(g.label_count(), 1);

        g.remove_concept("c1");

        assert_eq!(g.concept_count(), 0);
        assert_eq!(g.label_count(), 0);
    }

    // Five words across the facts must estimate to at least five tokens.
    #[test]
    fn test_memory_packet_estimated_tokens() {
        let facts = vec!["hello world".to_string(), "foo bar baz".to_string()];
        let packet = MemoryPacket {
            query_intent: "test".to_string(),
            facts,
            sources: vec!["c1".to_string()],
            confidence: 0.9,
        };

        assert!(packet.estimated_tokens() >= 5);
    }

    // Spot-check the documented defaults.
    #[test]
    fn test_bridge_config_defaults() {
        let cfg = BridgeConfig::default();

        assert_eq!(cfg.max_expansion_depth, 2);
        assert_eq!(cfg.max_packet_facts, 20);
        assert!((cfg.deterministic_weight - 0.6).abs() < 0.01);
    }
}