use crate::core::config::Config;
use crate::core::db::Database;
use crate::core::error::Result;
use crate::core::types::{HsgQueryResult, MemRow, Sector, SectorClassification, Tier};
use crate::memory::embed::EmbeddingProvider;
use crate::utils::{
canonical_token_set, cosine_similarity, keyword_filter_memories, now_ms, sigmoid, token_overlap,
};
use lazy_static::lazy_static;
use regex::Regex;
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
/// Relative weights of the individual components that make up a memory's
/// hybrid retrieval score (see `HsgEngine::query`).
///
/// The defaults sum to 1.0 so the weighted sum stays in a comparable range
/// before the sigmoid squash.
#[derive(Debug, Clone)]
pub struct ScoringWeights {
    /// Cosine similarity between query and memory embeddings.
    pub similarity: f64,
    /// Token overlap between query text and memory content.
    pub overlap: f64,
    /// Waypoint-graph connectivity weight.
    pub waypoint: f64,
    /// Exponential recency decay score.
    pub recency: f64,
    /// Tag/query-token match score.
    pub tag_match: f64,
}

impl Default for ScoringWeights {
    fn default() -> Self {
        ScoringWeights {
            similarity: 0.40,
            overlap: 0.20,
            waypoint: 0.15,
            recency: 0.15,
            tag_match: 0.10,
        }
    }
}
/// Tunable parameters of the hybrid scoring/decay model.
#[derive(Debug, Clone)]
pub struct HybridParams {
    /// Sigmoid temperature applied to the raw weighted score.
    pub tau: f64,
    pub beta: f64,
    pub eta: f64,
    pub gamma: f64,
    /// Reinforcement step size on access.
    pub alpha_reinforce: f64,
    /// Recency half-life-style time constant, in days.
    pub t_days: f64,
    /// Upper bound on the decay horizon, in days.
    pub t_max_days: f64,
    /// Short-horizon time constant, in hours.
    pub tau_hours: f64,
    /// Small constant guarding divisions against zero.
    pub epsilon: f64,
}

impl Default for HybridParams {
    fn default() -> Self {
        HybridParams {
            tau: 3.0,
            beta: 2.0,
            eta: 0.1,
            gamma: 0.2,
            alpha_reinforce: 0.08,
            t_days: 7.0,
            t_max_days: 60.0,
            tau_hours: 1.0,
            epsilon: 1e-8,
        }
    }
}
/// Per-sector classification/decay configuration (see `SECTOR_CONFIGS`).
#[derive(Debug, Clone)]
pub struct SectorConfig {
    // Per-day(?) exponential decay rate for memories in this sector —
    // NOTE(review): unit not visible from this file, confirm against decay code.
    pub decay_lambda: f64,
    // Multiplier applied per regex hit when scoring content for this sector.
    pub weight: f64,
    // Regex cues whose matches vote for classifying content into this sector.
    pub patterns: Vec<Regex>,
}
lazy_static! {
    /// Static classification table: for each memory sector, the regex cues that
    /// vote for it, the per-hit weight, and the sector's decay rate.
    /// Used by `classify_content`.
    static ref SECTOR_CONFIGS: HashMap<Sector, SectorConfig> = {
        let mut m = HashMap::new();
        // Episodic: personal, time-anchored events. Decays fairly fast.
        m.insert(Sector::Episodic, SectorConfig {
            decay_lambda: 0.015,
            weight: 1.2,
            patterns: vec![
                Regex::new(r"(?i)\b(today|yesterday|last\s+week|remember\s+when|that\s+time)\b").unwrap(),
                Regex::new(r"(?i)\b(I\s+(did|went|saw|met|felt))\b").unwrap(),
                Regex::new(r"(?i)\b(at\s+\d+:\d+|on\s+\w+day|in\s+\d{4})\b").unwrap(),
                Regex::new(r"(?i)\b(happened|occurred|experience|event|moment)\b").unwrap(),
            ],
        });
        // Semantic: facts and concepts. Slowest-decaying "knowledge" sector.
        m.insert(Sector::Semantic, SectorConfig {
            decay_lambda: 0.005,
            weight: 1.0,
            patterns: vec![
                Regex::new(r"(?i)\b(define|definition|meaning|concept|theory)\b").unwrap(),
                Regex::new(r"(?i)\b(what\s+is|how\s+does|why\s+do|facts?\s+about)\b").unwrap(),
                Regex::new(r"(?i)\b(principle|rule|law|algorithm|method)\b").unwrap(),
                Regex::new(r"(?i)\b(knowledge|information|data|research|study)\b").unwrap(),
            ],
        });
        // Procedural: how-to / step-by-step instructions.
        m.insert(Sector::Procedural, SectorConfig {
            decay_lambda: 0.008,
            weight: 1.1,
            patterns: vec![
                Regex::new(r"(?i)\b(how\s+to|step\s+by\s+step|procedure|process)\b").unwrap(),
                Regex::new(r"(?i)\b(first|then|next|finally|afterwards)\b").unwrap(),
                Regex::new(r"(?i)\b(install|configure|setup|run|execute)\b").unwrap(),
                Regex::new(r"(?i)\b(tutorial|guide|instructions|manual)\b").unwrap(),
                Regex::new(r"(?i)\b(click|press|type|enter|select)\b").unwrap(),
            ],
        });
        // Emotional: affect-laden content. Highest weight and fastest decay.
        m.insert(Sector::Emotional, SectorConfig {
            decay_lambda: 0.02,
            weight: 1.3,
            patterns: vec![
                Regex::new(r"(?i)\b(feel|feeling|felt|emotion|mood)\b").unwrap(),
                Regex::new(r"(?i)\b(happy|sad|angry|excited|worried|anxious|calm)\b").unwrap(),
                Regex::new(r"(?i)\b(love|hate|like|dislike|enjoy|fear)\b").unwrap(),
                Regex::new(r"(?i)\b(amazing|terrible|wonderful|awful|fantastic|horrible)\b").unwrap(),
                // Repeated !/?! punctuation is treated as an emotional cue.
                Regex::new(r"[!]{2,}|[\?\!]{2,}").unwrap(),
            ],
        });
        // Reflective: introspection and lessons learned. Lowest weight, slowest decay.
        m.insert(Sector::Reflective, SectorConfig {
            decay_lambda: 0.001,
            weight: 0.8,
            patterns: vec![
                Regex::new(r"(?i)\b(think|thinking|thought|reflect|reflection)\b").unwrap(),
                Regex::new(r"(?i)\b(realize|understand|insight|conclusion|lesson)\b").unwrap(),
                Regex::new(r"(?i)\b(why|purpose|meaning|significance|impact)\b").unwrap(),
                Regex::new(r"(?i)\b(philosophy|wisdom|belief|value|principle)\b").unwrap(),
                Regex::new(r"(?i)\b(should\s+have|could\s+have|if\s+only|what\s+if)\b").unwrap(),
            ],
        });
        m
    };
    /// Cross-sector affinity matrix in [0,1]: how related a memory's sector is
    /// to the query's sector. Used as a similarity penalty in `HsgEngine::query`
    /// (missing pairs fall back to 0.3 there).
    static ref SECTOR_RELATIONSHIPS: HashMap<Sector, HashMap<Sector, f64>> = {
        let mut m = HashMap::new();
        let mut semantic = HashMap::new();
        semantic.insert(Sector::Procedural, 0.8);
        semantic.insert(Sector::Episodic, 0.6);
        semantic.insert(Sector::Reflective, 0.7);
        semantic.insert(Sector::Emotional, 0.4);
        m.insert(Sector::Semantic, semantic);
        let mut procedural = HashMap::new();
        procedural.insert(Sector::Semantic, 0.8);
        procedural.insert(Sector::Episodic, 0.6);
        procedural.insert(Sector::Reflective, 0.6);
        procedural.insert(Sector::Emotional, 0.3);
        m.insert(Sector::Procedural, procedural);
        let mut episodic = HashMap::new();
        episodic.insert(Sector::Reflective, 0.8);
        episodic.insert(Sector::Semantic, 0.6);
        episodic.insert(Sector::Procedural, 0.6);
        episodic.insert(Sector::Emotional, 0.7);
        m.insert(Sector::Episodic, episodic);
        let mut reflective = HashMap::new();
        reflective.insert(Sector::Episodic, 0.8);
        reflective.insert(Sector::Semantic, 0.7);
        reflective.insert(Sector::Procedural, 0.6);
        reflective.insert(Sector::Emotional, 0.6);
        m.insert(Sector::Reflective, reflective);
        let mut emotional = HashMap::new();
        emotional.insert(Sector::Episodic, 0.7);
        emotional.insert(Sector::Reflective, 0.6);
        emotional.insert(Sector::Semantic, 0.4);
        emotional.insert(Sector::Procedural, 0.3);
        m.insert(Sector::Emotional, emotional);
        m
    };
}
/// One node reached while expanding the waypoint graph from an initial hit.
#[derive(Debug, Clone)]
pub struct WaypointExpansion {
    // Memory id of the reached node.
    pub id: String,
    // Accumulated (damped) edge weight along the path from the seed, in [0,1].
    pub weight: f64,
    // Sequence of memory ids from the seed to this node, inclusive.
    pub path: Vec<String>,
}
/// Hybrid sector-graph retrieval engine: combines embedding similarity,
/// token overlap, recency, tags and waypoint-graph expansion to rank memories.
pub struct HsgEngine {
    // Backing store for memories, vectors and the waypoint graph.
    db: Arc<Database>,
    // Provider used to embed query text.
    embedder: Arc<dyn EmbeddingProvider>,
    // Component weights for the hybrid score.
    weights: ScoringWeights,
    // Model parameters (sigmoid temperature, decay constants, ...).
    params: HybridParams,
    // Operating tier; `Tier::Hybrid` enables the keyword-boost path.
    tier: Tier,
    // Multiplier applied to keyword-filter scores in hybrid tier.
    keyword_boost: f64,
}
impl HsgEngine {
pub fn new(
db: Arc<Database>,
embedder: Arc<dyn EmbeddingProvider>,
) -> Self {
Self {
db,
embedder,
weights: ScoringWeights::default(),
params: HybridParams::default(),
tier: Tier::default(),
keyword_boost: 0.3,
}
}
pub fn with_config(
db: Arc<Database>,
embedder: Arc<dyn EmbeddingProvider>,
config: &Config,
) -> Self {
Self {
db,
embedder,
weights: ScoringWeights::default(),
params: HybridParams::default(),
tier: config.tier,
keyword_boost: config.keyword_boost,
}
}
/// Hybrid memory retrieval: embeds `query`, scores candidates on embedding
/// similarity, token overlap, recency, tag match and waypoint-graph weight,
/// and returns the top `k` results (score = sigmoid of the weighted sum).
///
/// - `sectors`: restricts the search; when `None`, sectors are inferred
///   from the query text via `classify_content`.
/// - `min_salience`: when given, memories below this salience are skipped.
/// - `user_id`: when given, candidates come from that user's memories
///   instead of per-sector listings (see `get_candidates`).
///
/// Each returned result also has its feedback score written back to the DB
/// (best-effort; write errors are deliberately ignored).
pub async fn query(
    &self,
    query: &str,
    k: usize,
    sectors: Option<&[Sector]>,
    min_salience: Option<f64>,
    user_id: Option<&str>,
) -> Result<Vec<HsgQueryResult>> {
    // Infer target sectors from the query when the caller didn't pin them:
    // primary classification plus any additional sectors above threshold.
    let query_class = classify_content(query, None);
    let target_sectors = sectors.map(|s| s.to_vec()).unwrap_or_else(|| {
        let mut secs = vec![query_class.primary];
        secs.extend(query_class.additional.clone());
        secs
    });
    let query_embedding = self.embedder.embed(query, &query_class.primary).await?;
    // Over-fetch 3x the requested k so later filtering still leaves enough.
    let candidates = self.get_candidates(&target_sectors, k * 3, user_id)?;
    let query_tokens = canonical_token_set(query);
    let now = now_ms();
    // Pass 1: raw cosine similarity per candidate (skipping low-salience
    // memories and ones without a stored vector).
    let mut initial_results: Vec<(String, f64)> = Vec::new();
    let mut candidate_ids: HashSet<String> = HashSet::new();
    for mem in &candidates {
        if let Some(min_sal) = min_salience {
            if mem.salience < min_sal {
                continue;
            }
        }
        let mem_vec = self.db.get_vector(&mem.id, &mem.primary_sector)?;
        let mem_vec = match mem_vec {
            Some(v) => v,
            None => continue,
        };
        let similarity = cosine_similarity(&query_embedding.vector, &mem_vec) as f64;
        initial_results.push((mem.id.clone(), similarity));
        candidate_ids.insert(mem.id.clone());
    }
    // Sort best-first; NaN similarities compare as equal (kept in place).
    initial_results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    // Confidence gate: if the top-8 average similarity is already high,
    // skip the (more expensive) waypoint-graph expansion.
    let top_sims: Vec<f64> = initial_results.iter().take(8).map(|(_, s)| *s).collect();
    let avg_top = if !top_sims.is_empty() {
        top_sims.iter().sum::<f64>() / top_sims.len() as f64
    } else {
        0.0
    };
    let high_confidence = avg_top >= 0.55;
    let initial_ids: Vec<String> = initial_results.iter().map(|(id, _)| id.clone()).collect();
    let expansions = if !high_confidence {
        self.expand_via_waypoints(&initial_ids, k * 2)?
    } else {
        Vec::new()
    };
    for exp in &expansions {
        candidate_ids.insert(exp.id.clone());
    }
    // NOTE(review): when expansion ran, the seed ids themselves appear in
    // `expansion_map` with weight 1.0, so their waypoint component below is
    // 1.0 rather than `get_waypoint_weight` — confirm this is intended.
    let expansion_map: HashMap<String, &WaypointExpansion> = expansions
        .iter()
        .map(|e| (e.id.clone(), e))
        .collect();
    // Hybrid tier only: keyword-filter scores used as an additive boost.
    let keyword_scores: HashMap<String, f64> = if self.tier == Tier::Hybrid {
        let memory_contents: Vec<(String, String)> = candidate_ids
            .iter()
            .filter_map(|id| {
                // Expanded ids may not be in `candidates`; fetch their
                // content from the DB (errors/misses silently dropped).
                if let Some(mem) = candidates.iter().find(|m| &m.id == id) {
                    Some((id.clone(), mem.content.clone()))
                } else {
                    self.db.get_memory(id).ok().flatten().map(|m| (id.clone(), m.content))
                }
            })
            .collect();
        keyword_filter_memories(query, &memory_contents, Some(0.05), Some(3))
    } else {
        HashMap::new()
    };
    // Pass 2: full hybrid score for every candidate (initial + expanded).
    let mut scored: Vec<(HsgQueryResult, f64)> = Vec::new();
    for id in &candidate_ids {
        let mem = if let Some(m) = candidates.iter().find(|m| &m.id == id) {
            m.clone()
        } else if let Some(m) = self.db.get_memory(id)? {
            m
        } else {
            continue;
        };
        if let Some(min_sal) = min_salience {
            if mem.salience < min_sal {
                continue;
            }
        }
        let mem_vec = self.db.get_vector(&mem.id, &mem.primary_sector)?;
        let mem_vec = match mem_vec {
            Some(v) => v,
            None => continue,
        };
        let similarity = cosine_similarity(&query_embedding.vector, &mem_vec) as f64;
        let overlap = token_overlap(query, &mem.content);
        let recency = self.compute_recency_score(mem.last_seen_at, now);
        let tag_score = self.compute_tag_score(&mem, &query_tokens);
        // Waypoint component: expansion weight (and its path) when the id
        // was reached via the graph, otherwise the node's average edge weight.
        let (waypoint_weight, path) = if let Some(exp) = expansion_map.get(id) {
            (exp.weight.clamp(0.0, 1.0), exp.path.clone())
        } else {
            (self.get_waypoint_weight(&mem.id)?, vec![mem.id.clone()])
        };
        // Cross-sector penalty: scale similarity by the affinity between the
        // query's sector and the memory's sector (0.3 for unknown pairs).
        let mem_sector = mem.primary_sector;
        let query_sector = query_class.primary;
        let sector_penalty = if mem_sector != query_sector
            && !query_class.additional.contains(&mem_sector)
        {
            SECTOR_RELATIONSHIPS
                .get(&query_sector)
                .and_then(|m| m.get(&mem_sector))
                .copied()
                .unwrap_or(0.3)
        } else {
            1.0
        };
        let adjusted_similarity = similarity * sector_penalty;
        let keyword_boost = if self.tier == Tier::Hybrid {
            keyword_scores.get(id).copied().unwrap_or(0.0) * self.keyword_boost
        } else {
            0.0
        };
        // Weighted sum of all components, squashed by a temperature sigmoid.
        let raw_score = self.weights.similarity * adjusted_similarity
            + self.weights.overlap * overlap
            + self.weights.recency * recency
            + self.weights.tag_match * tag_score
            + self.weights.waypoint * waypoint_weight
            + keyword_boost;
        let final_score = sigmoid(raw_score * self.params.tau);
        let result = HsgQueryResult {
            id: mem.id.clone(),
            content: mem.content.clone(),
            score: final_score,
            sectors: vec![mem.primary_sector],
            primary_sector: mem.primary_sector,
            path,
            salience: mem.salience,
            last_seen_at: mem.last_seen_at,
        };
        scored.push((result, final_score));
    }
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let results: Vec<HsgQueryResult> = scored.into_iter().take(k).map(|(r, _)| r).collect();
    // Best-effort feedback write-back; failures are intentionally ignored.
    for result in &results {
        let _ = self.db.update_feedback_score(&result.id, result.score);
    }
    Ok(results)
}
/// Fetches candidate memories: all of `user_id`'s memories when given,
/// otherwise up to `limit` memories split evenly across `sectors`.
///
/// Fix: the per-sector share `limit / sectors.len()` was recomputed every
/// iteration and truncated to 0 whenever `limit < sectors.len()`, silently
/// yielding no candidates at all. The share is now hoisted out of the loop
/// and floored at 1; the hoisted division is also guarded against an empty
/// sector slice (the original only avoided that panic because the loop body
/// never ran).
fn get_candidates(
    &self,
    sectors: &[Sector],
    limit: usize,
    user_id: Option<&str>,
) -> Result<Vec<MemRow>> {
    let mut candidates = Vec::new();
    if let Some(uid) = user_id {
        candidates.extend(self.db.get_memories_by_user(uid, limit, 0)?);
    } else {
        // At least one memory per sector so small budgets don't round to zero.
        let per_sector = (limit / sectors.len().max(1)).max(1);
        for sector in sectors {
            let mems = self.db.get_memories_by_sector(sector, per_sector, 0)?;
            candidates.extend(mems);
        }
    }
    Ok(candidates)
}
/// Exponential recency decay: exp(-days_since / t_days), where timestamps
/// are in milliseconds. Approaches 1.0 for just-seen memories.
fn compute_recency_score(&self, last_seen_at: i64, now: i64) -> f64 {
    // Two-step conversion kept as-is: ms -> hours -> days.
    let elapsed_hours = (now - last_seen_at) as f64 / (1000.0 * 60.0 * 60.0);
    let elapsed_days = elapsed_hours / 24.0;
    (-elapsed_days / self.params.t_days).exp()
}
/// Scores how well a memory's tags match the query tokens, normalized to
/// [0, 1] against a maximum of 2 points per tag.
///
/// An exact (case-insensitive) tag match earns 2 points and suppresses the
/// substring pass for that tag; otherwise each query token that contains or
/// is contained by the tag earns 1 point.
fn compute_tag_score(&self, mem: &MemRow, query_tokens: &HashSet<String>) -> f64 {
    let tags = match mem.tags.as_ref() {
        Some(t) => t,
        None => return 0.0,
    };
    let mut points = 0;
    for raw_tag in tags {
        let tag = raw_tag.to_lowercase();
        if query_tokens.contains(&tag) {
            points += 2;
        } else {
            points += query_tokens
                .iter()
                .filter(|token| tag.contains(token.as_str()) || token.contains(&tag))
                .count();
        }
    }
    (points as f64 / (tags.len() * 2).max(1) as f64).min(1.0)
}
/// Average edge weight of a node's neighbors in the waypoint graph;
/// 0.0 for isolated nodes.
fn get_waypoint_weight(&self, id: &str) -> Result<f64> {
    let edges = self.db.get_neighbors(id)?;
    match edges.len() {
        0 => Ok(0.0),
        n => {
            let sum: f64 = edges.iter().map(|(_, w)| *w).sum();
            Ok(sum / n as f64)
        }
    }
}
/// Breadth-first expansion over the waypoint graph starting from `initial_ids`.
///
/// The seeds themselves are included in the result with weight 1.0 and are
/// not counted against `max_expansions`. Each hop multiplies the accumulated
/// weight by the neighbor's edge weight (clamped to [0,1]) and a 0.8 damping
/// factor; paths whose weight falls below 0.1 are pruned, and each node is
/// visited at most once.
///
/// Fix: the neighbor lookup read `get_neighbors(¤t.id)` — the `¤`
/// character is a mojibake of the HTML entity `&curren;`, i.e. a corrupted
/// `&current.id`. Restored the intended borrow so the code compiles.
pub fn expand_via_waypoints(
    &self,
    initial_ids: &[String],
    max_expansions: usize,
) -> Result<Vec<WaypointExpansion>> {
    let mut expansions: Vec<WaypointExpansion> = Vec::new();
    let mut visited: HashSet<String> = HashSet::new();
    let mut queue: VecDeque<WaypointExpansion> = VecDeque::new();
    // Seed the frontier with the initial hits at full weight.
    for id in initial_ids {
        let exp = WaypointExpansion {
            id: id.clone(),
            weight: 1.0,
            path: vec![id.clone()],
        };
        expansions.push(exp.clone());
        visited.insert(id.clone());
        queue.push_back(exp);
    }
    let mut expansion_count = 0;
    while let Some(current) = queue.pop_front() {
        if expansion_count >= max_expansions {
            break;
        }
        let neighbors = self.db.get_neighbors(&current.id)?;
        for (neighbor_id, neighbor_weight) in neighbors {
            if visited.contains(&neighbor_id) {
                continue;
            }
            let clamped_weight = neighbor_weight.clamp(0.0, 1.0);
            // Damp each hop so weight decays with path length.
            let expanded_weight = current.weight * clamped_weight * 0.8;
            if expanded_weight < 0.1 {
                continue;
            }
            let mut new_path = current.path.clone();
            new_path.push(neighbor_id.clone());
            let exp = WaypointExpansion {
                id: neighbor_id.clone(),
                weight: expanded_weight,
                path: new_path,
            };
            expansions.push(exp.clone());
            visited.insert(neighbor_id);
            queue.push_back(exp);
            expansion_count += 1;
            if expansion_count >= max_expansions {
                break;
            }
        }
    }
    Ok(expansions)
}
}
/// Classifies `content` into a primary sector (plus any additional sectors
/// scoring at least 30% of the primary, minimum 1.0) by counting regex hits
/// from `SECTOR_CONFIGS`, weighted per sector. A `"sector"` key in
/// `metadata` overrides classification with confidence 1.0. Ties break by a
/// fixed sector priority (Semantic first). With no hits at all the result
/// defaults to Semantic with confidence 0.2.
///
/// Fix: the `additional` mapping read `|(§or, _)| sector` — `§` is a
/// mojibake of the HTML entity `&sect;`, i.e. a corrupted `&sector` pattern.
/// The scoring tail is rewritten over owned `(Sector, f64)` pairs (Sector is
/// Copy), which removes the double-reference patterns entirely.
pub fn classify_content(content: &str, metadata: Option<&serde_json::Value>) -> SectorClassification {
    // Explicit sector in metadata wins outright.
    if let Some(meta) = metadata {
        if let Some(sector_str) = meta.get("sector").and_then(|s| s.as_str()) {
            if let Some(sector) = Sector::from_str(sector_str) {
                return SectorClassification {
                    primary: sector,
                    additional: vec![],
                    confidence: 1.0,
                };
            }
        }
    }
    // Score = total regex hits across the sector's patterns, times its weight.
    let mut ranked: Vec<(Sector, f64)> = SECTOR_CONFIGS
        .iter()
        .map(|(sector, config)| {
            let hits: usize = config
                .patterns
                .iter()
                .map(|pattern| pattern.find_iter(content).count())
                .sum();
            (*sector, hits as f64 * config.weight)
        })
        .collect();
    // Deterministic ordering: score descending, fixed sector priority on ties.
    let rank = |s: Sector| match s {
        Sector::Semantic => 0,
        Sector::Procedural => 1,
        Sector::Episodic => 2,
        Sector::Reflective => 3,
        Sector::Emotional => 4,
    };
    ranked.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| rank(a.0).cmp(&rank(b.0)))
    });
    let (primary, primary_score) = ranked.first().copied().unwrap_or((Sector::Semantic, 0.0));
    // Additional sectors must reach 30% of the primary score, and at least 1.0.
    let threshold = (primary_score * 0.3).max(1.0);
    let additional: Vec<Sector> = ranked
        .iter()
        .skip(1)
        .filter(|(_, score)| *score > 0.0 && *score >= threshold)
        .map(|(sector, _)| *sector)
        .collect();
    let secondary_score = ranked.get(1).map(|(_, s)| *s).unwrap_or(0.0);
    // Confidence grows with the primary's margin over the runner-up.
    let confidence = if primary_score > 0.0 {
        (primary_score / (primary_score + secondary_score + 1.0)).min(1.0)
    } else {
        0.2
    };
    SectorClassification {
        primary,
        additional,
        confidence,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each classify_* test pins the primary sector chosen for content that
    // clearly matches one sector's regex cues (see SECTOR_CONFIGS).
    #[test]
    fn test_classify_episodic() {
        let class = classify_content("Yesterday I went to the store", None);
        assert_eq!(class.primary, Sector::Episodic);
    }
    #[test]
    fn test_classify_semantic() {
        let class = classify_content("The definition of entropy in physics", None);
        assert_eq!(class.primary, Sector::Semantic);
    }
    #[test]
    fn test_classify_procedural() {
        let class = classify_content("How to install Python: first download, then run the installer", None);
        assert_eq!(class.primary, Sector::Procedural);
    }
    #[test]
    fn test_classify_emotional() {
        let class = classify_content("I feel so happy and excited about this!", None);
        assert_eq!(class.primary, Sector::Emotional);
    }
    #[test]
    fn test_classify_reflective() {
        let class = classify_content("I realize now that I should have done things differently", None);
        assert_eq!(class.primary, Sector::Reflective);
    }
    // With no pattern hits the classifier falls back to low confidence.
    #[test]
    fn test_classify_default() {
        let class = classify_content("random text without patterns", None);
        assert!(class.confidence < 0.5);
    }
    // The default scoring weights are expected to sum to ~1.0 so the raw
    // hybrid score stays in a comparable range before the sigmoid.
    #[test]
    fn test_scoring_weights_sum() {
        let weights = ScoringWeights::default();
        let sum = weights.similarity + weights.overlap + weights.waypoint + weights.recency + weights.tag_match;
        assert!((sum - 1.0).abs() < 0.01);
    }
}