use std::fmt;
use std::str::FromStr;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryRoute {
Keyword,
Semantic,
#[default]
Hybrid,
Graph,
Episodic,
}
/// Result of routing a query: the chosen [`MemoryRoute`] plus optional metadata.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Route selected for the query.
    pub route: MemoryRoute,
    /// Router confidence in `route`. The default
    /// `MemoryRouter::route_with_confidence` always reports `1.0`; the scale used by
    /// other routers is not fixed here (presumably `0.0..=1.0` — confirm).
    pub confidence: f32,
    /// Optional human-readable explanation of the choice; the default
    /// `MemoryRouter::route_with_confidence` leaves this `None`.
    pub reasoning: Option<String>,
}
/// Synchronous strategy for picking a [`MemoryRoute`] for a query.
pub trait MemoryRouter: Send + Sync {
    /// Chooses the route to use for `query`.
    fn route(&self, query: &str) -> MemoryRoute;

    /// Chooses a route and reports it as a [`RoutingDecision`].
    ///
    /// The provided implementation delegates to [`MemoryRouter::route`] and wraps
    /// the result with full confidence (`1.0`) and no reasoning text. Routers that
    /// can estimate their certainty should override this.
    fn route_with_confidence(&self, query: &str) -> RoutingDecision {
        let chosen = self.route(query);
        RoutingDecision {
            route: chosen,
            confidence: 1.0,
            reasoning: None,
        }
    }
}
/// A [`MemoryRouter`] that can also produce its decision asynchronously.
///
/// The returned future is boxed and `Send`, borrowing `self` and `query` for `'a`,
/// so the trait can be used as `dyn AsyncMemoryRouter` (as
/// `ContextMemoryBackend::recall` does).
pub trait AsyncMemoryRouter: MemoryRouter {
    /// Asynchronously routes `query`, resolving to a full [`RoutingDecision`].
    fn route_async<'a>(
        &'a self,
        query: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = RoutingDecision> + Send + 'a>>;
}
/// Granularity at which graph recall results are presented.
///
/// Consumed via `GraphRecallParams::view`. The exact rendering of each view is
/// decided by the backend; names suggest a head-only vs. detail vs. wider-context
/// view (ZoomOut presumably pulls in neighbors, cf.
/// `GraphRecallParams::zoom_out_neighbor_cap` — confirm).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RecallView {
    /// Default view.
    #[default]
    Head,
    /// Zoomed-in view.
    ZoomIn,
    /// Zoomed-out view.
    ZoomOut,
}
/// Abstraction tier of a stored memory.
///
/// The derived `Ord` follows declaration order, so
/// `Episodic < Procedural < Declarative`; [`CompressionLevel::cost_factor`]
/// decreases along that same order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompressionLevel {
    /// Cost factor `1.0`.
    Episodic,
    /// Cost factor `0.6`.
    Procedural,
    /// Cost factor `0.3`.
    Declarative,
}

impl CompressionLevel {
    /// Relative cost multiplier for this level: `1.0` for `Episodic`, `0.6` for
    /// `Procedural`, `0.3` for `Declarative`.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn cost_factor(self) -> f32 {
        match self {
            CompressionLevel::Declarative => 0.3,
            CompressionLevel::Procedural => 0.6,
            CompressionLevel::Episodic => 1.0,
        }
    }
}
/// Structured session summary with fixed sections.
///
/// Serializable with serde (and, with the `jsonschema` feature, describable via
/// `schemars::JsonSchema`). Rendered to Markdown by `to_markdown`, size-checked by
/// `validate`, and considered usable by `is_complete` when `session_intent` is
/// non-blank and `next_steps` is non-empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub struct AnchoredSummary {
    /// What the session is trying to accomplish.
    pub session_intent: String,
    /// Files touched during the session, one entry per file.
    pub files_modified: Vec<String>,
    /// Decisions taken during the session.
    pub decisions_made: Vec<String>,
    /// Questions still unresolved.
    pub open_questions: Vec<String>,
    /// Planned follow-up actions.
    pub next_steps: Vec<String>,
}
impl AnchoredSummary {
    /// Returns `true` when the summary carries the minimum useful content: a
    /// `session_intent` that is not blank after trimming, and at least one
    /// `next_steps` entry.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        !self.session_intent.trim().is_empty() && !self.next_steps.is_empty()
    }

    /// Renders the summary as Markdown.
    ///
    /// The output starts with an `[anchored summary]` marker line and a
    /// `## Session Intent` section (always present, even if the intent is empty).
    /// Each list field then becomes its own `##` section of `- ` bullets; empty
    /// lists are omitted entirely. Leading `"- "` prefixes already present on an
    /// entry are stripped (repeatedly) so bullets are never doubled.
    #[must_use]
    pub fn to_markdown(&self) -> String {
        let mut out = String::with_capacity(512);
        out.push_str("[anchored summary]\n");
        out.push_str("## Session Intent\n");
        out.push_str(&self.session_intent);
        out.push('\n');
        Self::push_bullet_section(&mut out, "Files Modified", &self.files_modified);
        Self::push_bullet_section(&mut out, "Decisions Made", &self.decisions_made);
        Self::push_bullet_section(&mut out, "Open Questions", &self.open_questions);
        Self::push_bullet_section(&mut out, "Next Steps", &self.next_steps);
        out
    }

    /// Appends `"\n## {heading}\n"` followed by one `- ` bullet line per entry;
    /// appends nothing when `entries` is empty.
    fn push_bullet_section(out: &mut String, heading: &str, entries: &[String]) {
        if entries.is_empty() {
            return;
        }
        out.push_str("\n## ");
        out.push_str(heading);
        out.push('\n');
        for entry in entries {
            out.push_str("- ");
            out.push_str(entry.trim_start_matches("- "));
            out.push('\n');
        }
    }

    /// Checks the summary against size limits.
    ///
    /// Limits are counted in Unicode scalar values (`char`s), matching the wording
    /// of the error messages, so non-ASCII text is not penalised for its UTF-8
    /// byte width:
    /// - `session_intent`: at most 2 000 chars;
    /// - each list field: at most 50 entries;
    /// - each list entry: at most 500 chars.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first limit exceeded,
    /// checking `session_intent` first, then `files_modified`, `decisions_made`,
    /// `open_questions` and `next_steps` in that order.
    pub fn validate(&self) -> Result<(), String> {
        const MAX_INTENT: usize = 2_000;
        const MAX_ENTRY: usize = 500;
        const MAX_VEC_LEN: usize = 50;
        if let Some(got) = Self::char_count_over(&self.session_intent, MAX_INTENT) {
            return Err(format!(
                "session_intent exceeds {MAX_INTENT} chars (got {got})"
            ));
        }
        for (field, entries) in [
            ("files_modified", &self.files_modified),
            ("decisions_made", &self.decisions_made),
            ("open_questions", &self.open_questions),
            ("next_steps", &self.next_steps),
        ] {
            if entries.len() > MAX_VEC_LEN {
                return Err(format!(
                    "{field} has {} entries (max {MAX_VEC_LEN})",
                    entries.len()
                ));
            }
            for entry in entries {
                if let Some(got) = Self::char_count_over(entry, MAX_ENTRY) {
                    return Err(format!(
                        "{field} entry exceeds {MAX_ENTRY} chars (got {got})"
                    ));
                }
            }
        }
        Ok(())
    }

    /// Returns `Some(char_count)` when `s` holds more than `max` chars, else `None`.
    fn char_count_over(s: &str, max: usize) -> Option<usize> {
        // A UTF-8 string never has more chars than bytes, so the O(1) byte length
        // settles the common short case without scanning the string.
        if s.len() <= max {
            return None;
        }
        let count = s.chars().count();
        (count > max).then_some(count)
    }

    /// Serializes the summary to a compact JSON string.
    ///
    /// # Panics
    ///
    /// Never in practice: every field is a `String` or `Vec<String>`, which
    /// `serde_json` always serializes successfully.
    #[must_use]
    pub fn to_json(&self) -> String {
        serde_json::to_string(self).expect("AnchoredSummary serialization is infallible")
    }
}
/// Tuning knobs for spreading-activation graph recall.
///
/// Carried by `GraphRecallParams::spreading_activation`; the algorithm that reads
/// these values lives outside this file, so the field notes below describe the
/// evident intent and should be confirmed against that consumer.
#[derive(Debug, Clone)]
pub struct SpreadingActivationParams {
    // presumably the per-hop activation decay factor — confirm
    pub decay_lambda: f32,
    // presumably the maximum traversal depth from the seed nodes — confirm
    pub max_hops: u32,
    // presumably the minimum activation for a node to be kept — confirm
    pub activation_threshold: f32,
    // presumably the activation above which lateral inhibition applies — confirm
    pub inhibition_threshold: f32,
    // presumably a cap on the number of activated nodes returned — confirm
    pub max_activated_nodes: usize,
    // presumably the rate of time-based decay; units not visible here — confirm
    pub temporal_decay_rate: f64,
    // presumably how strongly graph structure weights seed selection — confirm
    pub seed_structural_weight: f32,
    // presumably a cap on seeds drawn from one community — confirm
    pub seed_community_cap: usize,
}
/// Kind of relation an edge in the memory graph represents.
///
/// Serializes as snake_case (`"semantic"`, `"temporal"`, `"causal"`, `"entity"`),
/// the same strings produced by `as_str`/`Display` and accepted by `FromStr`.
/// Defaults to [`EdgeType::Semantic`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EdgeType {
    /// Semantic relation; the `Default`.
    #[default]
    Semantic,
    /// Temporal relation.
    Temporal,
    /// Causal relation.
    Causal,
    /// Entity relation.
    Entity,
}
impl EdgeType {
    /// Returns the canonical lowercase name of this edge type.
    ///
    /// These are the strings produced by `Display` and accepted by `FromStr`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Semantic => "semantic",
            Self::Temporal => "temporal",
            Self::Causal => "causal",
            Self::Entity => "entity",
        }
    }
}

impl fmt::Display for EdgeType {
    /// Writes [`EdgeType::as_str`], honouring width, fill/alignment and precision
    /// flags (e.g. `format!("{:<8}|", EdgeType::Causal)` yields `"causal  |"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies the formatter's flags; `write_str` would silently drop them.
        f.pad(self.as_str())
    }
}

impl FromStr for EdgeType {
    type Err = String;

    /// Parses the exact lowercase names returned by [`EdgeType::as_str`].
    ///
    /// # Errors
    ///
    /// Returns `"unknown edge type: <input>"` for any other string, including
    /// differently-cased or whitespace-padded spellings.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "semantic" => Ok(Self::Semantic),
            "temporal" => Ok(Self::Temporal),
            "causal" => Ok(Self::Causal),
            "entity" => Ok(Self::Entity),
            other => Err(format!("unknown edge type: {other}")),
        }
    }
}
/// Lowercase substrings that make `classify_graph_subgraph` include
/// [`EdgeType::Causal`].
///
/// Matched by plain substring search on the ASCII-lowercased query, so a marker
/// also fires inside longer words (e.g. `"cause"` inside `"because"`).
pub const CAUSAL_MARKERS: &[&str] = &[
    "why",
    "because",
    "caused",
    "cause",
    "reason",
    "result",
    "led to",
    "consequence",
    "trigger",
    "effect",
    "blame",
    "fault",
];
/// Lowercase substrings that make `classify_graph_subgraph` include
/// [`EdgeType::Temporal`]; matched by plain substring search like
/// [`CAUSAL_MARKERS`].
pub const TEMPORAL_MARKERS: &[&str] = &[
    "before", "after", "first", "then", "timeline", "sequence", "preceded", "followed", "started",
    "ended", "during", "prior",
];
/// Lowercase substrings that make `classify_graph_subgraph` include
/// [`EdgeType::Entity`]; matched by plain substring search like
/// [`CAUSAL_MARKERS`].
pub const ENTITY_MARKERS: &[&str] = &[
    "is a",
    "type of",
    "kind of",
    "part of",
    "instance",
    "same as",
    "alias",
    "subtype",
    "subclass",
    "belongs to",
];
/// Temporal markers kept apart from [`TEMPORAL_MARKERS`].
///
/// Not consulted by `classify_graph_subgraph`. Judging by the name, a consumer
/// elsewhere presumably matches these only as whole words (substring search
/// would hit `"ago"` inside e.g. `"chicago"`) — confirm with that consumer.
pub const WORD_BOUNDARY_TEMPORAL: &[&str] = &["ago"];
/// Chooses which graph edge types to traverse for `query`.
///
/// The query is ASCII-lowercased and tested against [`CAUSAL_MARKERS`],
/// [`TEMPORAL_MARKERS`] and [`ENTITY_MARKERS`] by plain substring search. Each
/// marker list with a hit contributes its edge type, in the order Causal,
/// Temporal, Entity. [`EdgeType::Semantic`] is always appended last, so the
/// result is never empty and contains no duplicates.
///
/// NOTE(review): substring matching also fires inside longer words or across
/// word boundaries (e.g. `"authentication"` contains `"then"`, `"this area"`
/// contains `"is a"`), and [`WORD_BOUNDARY_TEMPORAL`] is not consulted here.
/// Confirm whether that is intended before tightening the matcher.
#[must_use]
pub fn classify_graph_subgraph(query: &str) -> Vec<EdgeType> {
    let lower = query.to_ascii_lowercase();
    let hits = |markers: &[&str]| markers.iter().any(|m| lower.contains(m));

    // At most four edge types can be selected.
    let mut types: Vec<EdgeType> = Vec::with_capacity(4);
    if hits(CAUSAL_MARKERS) {
        types.push(EdgeType::Causal);
    }
    if hits(TEMPORAL_MARKERS) {
        types.push(EdgeType::Temporal);
    }
    if hits(ENTITY_MARKERS) {
        types.push(EdgeType::Entity);
    }
    // Semantic is never pushed above, so it can be appended without a duplicate check.
    types.push(EdgeType::Semantic);
    types
}
/// Parses a route name into a [`MemoryRoute`], returning `fallback` for anything
/// unrecognised.
///
/// Accepted names are `"keyword"`, `"semantic"`, `"hybrid"`, `"graph"` and
/// `"episodic"`. Matching is exact and case-sensitive, so `"Graph"` or `" graph"`
/// yield `fallback`.
#[must_use]
pub fn parse_route_str(s: &str, fallback: MemoryRoute) -> MemoryRoute {
    let recognised = match s {
        "keyword" => Some(MemoryRoute::Keyword),
        "semantic" => Some(MemoryRoute::Semantic),
        "hybrid" => Some(MemoryRoute::Hybrid),
        "graph" => Some(MemoryRoute::Graph),
        "episodic" => Some(MemoryRoute::Episodic),
        _ => None,
    };
    recognised.unwrap_or(fallback)
}
/// Token-counting service used to budget context.
///
/// The tokenizer, and therefore the exact counts, is chosen by the implementor.
pub trait TokenCounting: Send + Sync {
    /// Number of tokens in `text`.
    fn count_tokens(&self, text: &str) -> usize;
    /// Number of tokens attributed to a tool schema given as JSON; how the schema
    /// is serialized or weighted for counting is implementation-defined.
    fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize;
}
/// Persona fact returned by `ContextMemoryBackend::load_persona_facts`.
#[derive(Debug, Clone)]
pub struct MemPersonaFact {
    /// Category label of the fact (vocabulary defined by the backend).
    pub category: String,
    /// Fact text.
    pub content: String,
}
/// Summary-tree node returned by `ContextMemoryBackend::load_tree_nodes`.
#[derive(Debug, Clone)]
pub struct MemTreeNode {
    /// Node text.
    pub content: String,
}
/// Conversation summary returned by `ContextMemoryBackend::load_summaries`.
#[derive(Debug, Clone)]
pub struct MemSummary {
    // presumably the first/last message ids covered by this summary — confirm
    pub first_message_id: Option<i64>,
    pub last_message_id: Option<i64>,
    /// Summary text.
    pub content: String,
}
/// Reasoning strategy returned by
/// `ContextMemoryBackend::retrieve_reasoning_strategies`.
#[derive(Debug, Clone)]
pub struct MemReasoningStrategy {
    /// Identifier; the same ids are passed to `mark_reasoning_used`.
    pub id: String,
    /// Recorded outcome of the strategy (format defined by the backend).
    pub outcome: String,
    /// Strategy summary text.
    pub summary: String,
}
/// Correction returned by `ContextMemoryBackend::retrieve_corrections`.
#[derive(Debug, Clone)]
pub struct MemCorrection {
    /// Correction text.
    pub correction_text: String,
}
/// Message returned by `ContextMemoryBackend::recall`.
#[derive(Debug, Clone)]
pub struct MemRecalledMessage {
    /// Message role as stored by the backend.
    pub role: String,
    /// Message text.
    pub content: String,
    /// Retrieval score; scale and ordering are backend-defined.
    pub score: f32,
}
/// Neighboring fact attached to a [`MemGraphFact`].
#[derive(Debug, Clone)]
pub struct MemGraphNeighbor {
    /// Fact text.
    pub fact: String,
    /// Confidence attached to the fact (scale defined by the backend).
    pub confidence: f32,
}
/// Graph fact returned by `ContextMemoryBackend::recall_graph_facts`.
#[derive(Debug, Clone)]
pub struct MemGraphFact {
    /// Fact text.
    pub fact: String,
    /// Confidence attached to the fact (scale defined by the backend).
    pub confidence: f32,
    // presumably set only when spreading activation was requested via
    // GraphRecallParams::spreading_activation — confirm
    pub activation_score: Option<f32>,
    /// Neighboring facts (presumably populated for `RecallView::ZoomOut` — confirm).
    pub neighbors: Vec<MemGraphNeighbor>,
    /// Optional snippet of the source the fact came from.
    pub provenance_snippet: Option<String>,
}
/// Session summary returned by `ContextMemoryBackend::search_session_summaries`.
#[derive(Debug, Clone)]
pub struct MemSessionSummary {
    /// Summary text.
    pub summary_text: String,
    /// Search score; scale is backend-defined.
    pub score: f32,
}
/// Document chunk returned by `ContextMemoryBackend::search_document_collection`.
#[derive(Debug, Clone)]
pub struct MemDocumentChunk {
    /// Chunk text.
    pub text: String,
}
/// Trajectory entry returned by `ContextMemoryBackend::load_trajectory_entries`.
#[derive(Debug, Clone)]
pub struct MemTrajectoryEntry {
    /// Intent recorded for the entry.
    pub intent: String,
    /// Outcome recorded for the entry.
    pub outcome: String,
    /// Confidence attached to the entry (scale defined by the backend).
    pub confidence: f64,
}
/// Parameters for `ContextMemoryBackend::recall_graph_facts`.
///
/// Borrows its edge-type filter for `'a`. How each knob is applied is decided by
/// the backend; field notes below are the evident intent, to be confirmed there.
#[derive(Debug)]
pub struct GraphRecallParams<'a> {
    /// Maximum number of facts to return (presumably — confirm).
    pub limit: usize,
    /// Presentation granularity for the recalled facts.
    pub view: RecallView,
    // presumably caps MemGraphFact::neighbors when view is RecallView::ZoomOut — confirm
    pub zoom_out_neighbor_cap: usize,
    /// Maximum traversal depth (presumably — confirm).
    pub max_hops: u32,
    /// Time-decay rate applied to results; units not visible here.
    pub temporal_decay_rate: f64,
    /// Edge types the traversal may follow (e.g. from `classify_graph_subgraph`).
    pub edge_types: &'a [EdgeType],
    /// When `Some`, presumably switches recall to spreading activation — confirm.
    pub spreading_activation: Option<SpreadingActivationParams>,
}
/// Boxed error type returned by every [`ContextMemoryBackend`] operation.
pub type MemBackendError = Box<dyn std::error::Error + Send + Sync>;

/// Boxed, `Send` future returned by every [`ContextMemoryBackend`] operation.
///
/// Resolves to `Result<T, MemBackendError>` and may borrow from the backend and
/// from the call's arguments for `'a`. This is a plain alias, so implementations
/// that spell out the full `Pin<Box<dyn Future<…> + Send + 'a>>` type still match.
pub type MemBackendFuture<'a, T> =
    std::pin::Pin<Box<dyn std::future::Future<Output = Result<T, MemBackendError>> + Send + 'a>>;

/// Storage backend queried when assembling context from memory.
///
/// Every method returns a boxed [`MemBackendFuture`] (no generic methods), so the
/// trait can be used as `dyn ContextMemoryBackend`. Parameter semantics (limits,
/// thresholds, filters) are fixed by implementations; the per-method notes
/// describe the evident intent and should be confirmed against them.
pub trait ContextMemoryBackend: Send + Sync {
    /// Loads persona facts; `min_confidence` is presumably a lower confidence bound.
    fn load_persona_facts<'a>(
        &'a self,
        min_confidence: f64,
    ) -> MemBackendFuture<'a, Vec<MemPersonaFact>>;

    /// Loads trajectory entries, optionally filtered by `tier`; `top_k` presumably
    /// caps the count.
    fn load_trajectory_entries<'a>(
        &'a self,
        tier: Option<&'a str>,
        top_k: usize,
    ) -> MemBackendFuture<'a, Vec<MemTrajectoryEntry>>;

    /// Loads summary-tree nodes at `level`; `top_k` presumably caps the count.
    fn load_tree_nodes<'a>(
        &'a self,
        level: u32,
        top_k: usize,
    ) -> MemBackendFuture<'a, Vec<MemTreeNode>>;

    /// Loads the stored summaries for conversation `conversation_id`.
    fn load_summaries<'a>(
        &'a self,
        conversation_id: i64,
    ) -> MemBackendFuture<'a, Vec<MemSummary>>;

    /// Retrieves reasoning strategies relevant to `query`; `top_k` presumably caps
    /// the count.
    fn retrieve_reasoning_strategies<'a>(
        &'a self,
        query: &'a str,
        top_k: usize,
    ) -> MemBackendFuture<'a, Vec<MemReasoningStrategy>>;

    /// Records that the strategies with the given `ids` (see
    /// [`MemReasoningStrategy::id`]) were used.
    fn mark_reasoning_used<'a>(&'a self, ids: &'a [String]) -> MemBackendFuture<'a, ()>;

    /// Retrieves corrections relevant to `query`; `limit` presumably caps the
    /// count and `min_score` filters weak matches.
    fn retrieve_corrections<'a>(
        &'a self,
        query: &'a str,
        limit: usize,
        min_score: f32,
    ) -> MemBackendFuture<'a, Vec<MemCorrection>>;

    /// Recalls messages relevant to `query`, optionally letting `router` choose
    /// the retrieval route; `limit` presumably caps the count.
    fn recall<'a>(
        &'a self,
        query: &'a str,
        limit: usize,
        router: Option<&'a dyn AsyncMemoryRouter>,
    ) -> MemBackendFuture<'a, Vec<MemRecalledMessage>>;

    /// Recalls graph facts relevant to `query` according to `params`.
    fn recall_graph_facts<'a>(
        &'a self,
        query: &'a str,
        params: GraphRecallParams<'a>,
    ) -> MemBackendFuture<'a, Vec<MemGraphFact>>;

    /// Searches stored session summaries for `query`; `current_conversation_id`
    /// is presumably excluded from results — confirm.
    fn search_session_summaries<'a>(
        &'a self,
        query: &'a str,
        limit: usize,
        current_conversation_id: Option<i64>,
    ) -> MemBackendFuture<'a, Vec<MemSessionSummary>>;

    /// Searches the document collection named `collection` for `query`; `top_k`
    /// presumably caps the count.
    fn search_document_collection<'a>(
        &'a self,
        collection: &'a str,
        query: &'a str,
        top_k: usize,
    ) -> MemBackendFuture<'a, Vec<MemDocumentChunk>>;
}
#[cfg(test)]
mod tests {
    use super::MemoryRoute;

    /// Each variant must deserialize from its snake_case JSON string and survive a
    /// serialize → deserialize round trip unchanged.
    #[test]
    fn memory_route_serde_roundtrip() {
        const CASES: [(&str, MemoryRoute); 5] = [
            ("\"keyword\"", MemoryRoute::Keyword),
            ("\"semantic\"", MemoryRoute::Semantic),
            ("\"hybrid\"", MemoryRoute::Hybrid),
            ("\"graph\"", MemoryRoute::Graph),
            ("\"episodic\"", MemoryRoute::Episodic),
        ];
        for &(json_str, expected) in &CASES {
            let decoded = serde_json::from_str::<MemoryRoute>(json_str).unwrap();
            assert_eq!(decoded, expected);
            let reencoded = serde_json::to_string(&decoded).unwrap();
            let reparsed = serde_json::from_str::<MemoryRoute>(&reencoded).unwrap();
            assert_eq!(reparsed, expected);
        }
    }

    /// `Hybrid` is the `#[default]` variant.
    #[test]
    fn memory_route_default_is_hybrid() {
        let fallback = MemoryRoute::default();
        assert_eq!(fallback, MemoryRoute::Hybrid);
    }
}