use std::fmt;
use hirn_core::HirnError;
use hirn_core::record::MemoryRecord;
use hirn_core::types::Origin;
/// Relative weights of the seven retrieval signals combined by [`composite_score`].
///
/// Each field multiplies the matching (clamped) signal inside [`composite_score`].
/// [`ScoringWeights::validate`] rejects weights outside `[0.0, 1.0]` and weight sets whose
/// sum differs from `1.0` by more than `1e-4`. [`Default`] and
/// [`ScoringWeights::PURE_SIMILARITY`] are the presets defined in this file.
#[derive(Debug, Clone, Copy)]
pub struct ScoringWeights {
/// Weight of the `similarity` signal.
pub similarity: f32,
/// Weight of the `importance` signal.
pub importance: f32,
/// Weight of the recency signal that [`composite_score`] derives via [`fade_mem_recency`].
pub recency: f32,
/// Weight of the `activation` signal.
pub activation: f32,
/// Weight of the `causal_rel` signal.
pub causal_relevance: f32,
/// Weight of the `surprise` signal.
pub surprise: f32,
/// Weight of the `source_rel` signal ([`source_reliability_for_origin`] maps origins to such values).
pub source_reliability: f32,
}
impl ScoringWeights {
    /// Checks that these weights form a convex combination.
    ///
    /// Every weight must lie in `[0.0, 1.0]` (NaN is rejected). The seven weights must also
    /// sum to `1.0` within an absolute tolerance of `1e-4`.
    ///
    /// # Errors
    ///
    /// Returns `HirnError::InvalidInput` in two cases:
    ///
    /// - The first weight, in declaration order, that is NaN or outside `[0.0, 1.0]`. The
    ///   error names that weight.
    /// - A sum that is off by more than the tolerance. The error reports the actual sum.
    pub fn validate(&self) -> Result<(), HirnError> {
        for (name, w) in [
            ("similarity", self.similarity),
            ("importance", self.importance),
            ("recency", self.recency),
            ("activation", self.activation),
            ("causal_relevance", self.causal_relevance),
            ("surprise", self.surprise),
            ("source_reliability", self.source_reliability),
        ] {
            // `RangeInclusive::contains` is built on `<=`, so NaN is *not* contained and gets
            // rejected here. A bare `w < 0.0 || w > 1.0` test is false for NaN and would let
            // it through; the sum below would then also be NaN and slip past its check.
            if !(0.0..=1.0).contains(&w) {
                return Err(HirnError::InvalidInput(format!(
                    "scoring weight '{name}' must be in [0.0, 1.0], got {w}"
                )));
            }
        }
        let sum = self.similarity
            + self.importance
            + self.recency
            + self.activation
            + self.causal_relevance
            + self.surprise
            + self.source_reliability;
        // Tolerance absorbs f32 rounding in presets such as the default weights.
        if (sum - 1.0).abs() > 1e-4 {
            return Err(HirnError::InvalidInput(format!(
                "scoring weights must sum to 1.0, got {sum}"
            )));
        }
        Ok(())
    }

    /// Weights that rank by similarity alone; every other signal gets weight `0.0`.
    pub const PURE_SIMILARITY: Self = Self {
        similarity: 1.0,
        importance: 0.0,
        recency: 0.0,
        activation: 0.0,
        causal_relevance: 0.0,
        surprise: 0.0,
        source_reliability: 0.0,
    };
}
impl Default for ScoringWeights {
    /// Balanced preset that sums to `1.0` and gives similarity the largest share (`0.30`).
    fn default() -> Self {
        // Top tier: how well the memory matches, followed by importance and recency.
        let (similarity, importance, recency) = (0.30, 0.20, 0.20);
        // Middle tier.
        let (activation, source_reliability) = (0.10, 0.10);
        // Smallest contributions.
        let (causal_relevance, surprise) = (0.05, 0.05);
        Self {
            similarity,
            importance,
            recency,
            activation,
            causal_relevance,
            surprise,
            source_reliability,
        }
    }
}
#[cfg(test)]
mod weight_tests {
    use super::*;

    /// The built-in default weighting must itself pass `validate`.
    #[test]
    fn scoring_weights_default_sum_to_one() {
        let defaults = ScoringWeights::default();
        defaults.validate().unwrap();
    }
}
/// One value per scoring signal, mirroring the fields of [`ScoringWeights`].
///
/// The `Display` impl prints all seven values on one line with three decimals
/// (`sim=... imp=... rec=... act=... caus=... sur=... src=...`). The struct derives serde
/// `Serialize`/`Deserialize` with no renames, so the field names are the serialized keys.
///
/// NOTE(review): nothing in this file builds a `ScoreBreakdown`. Whether the fields hold raw
/// signals or weight-multiplied contributions is therefore not visible here; confirm at the
/// construction site.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoreBreakdown {
/// Similarity component.
pub similarity: f32,
/// Importance component.
pub importance: f32,
/// Recency component.
pub recency: f32,
/// Activation component.
pub activation: f32,
/// Causal-relevance component.
pub causal_relevance: f32,
/// Surprise component.
pub surprise: f32,
/// Source-reliability component.
pub source_reliability: f32,
}
impl fmt::Display for ScoreBreakdown {
    /// Writes every component on a single line with three decimal places, using the short
    /// labels `sim`, `imp`, `rec`, `act`, `caus`, `sur` and `src`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            similarity,
            importance,
            recency,
            activation,
            causal_relevance,
            surprise,
            source_reliability,
        } = *self;
        f.write_fmt(format_args!(
            "sim={:.3} imp={:.3} rec={:.3} act={:.3} caus={:.3} sur={:.3} src={:.3}",
            similarity,
            importance,
            recency,
            activation,
            causal_relevance,
            surprise,
            source_reliability,
        ))
    }
}
/// Reliability score in `[0.0, 1.0]` for a memory record.
///
/// Episodic and semantic records are scored from the origin stored in their `provenance`
/// via [`source_reliability_for_origin`]. Working and procedural records are not inspected
/// for provenance here. They get a fixed `0.8`, the same value as `Origin::LlmExtraction`.
#[must_use]
pub fn source_reliability_for_record(record: &MemoryRecord) -> f32 {
    match record {
        MemoryRecord::Working(_) | MemoryRecord::Procedural(_) => 0.8,
        MemoryRecord::Episodic(episode) => {
            source_reliability_for_origin(*episode.provenance.origin())
        }
        MemoryRecord::Semantic(fact) => source_reliability_for_origin(*fact.provenance.origin()),
    }
}
/// Fixed reliability score in `[0.0, 1.0]` for each [`Origin`].
///
/// Scores ascend from `0.5` (`CrossAgent`) through `0.6` (`Consolidation`, `DreamReplay`)
/// and `0.8` (`LlmExtraction`) to `1.0` (`UserProvided`, `DirectObservation`).
#[must_use]
pub fn source_reliability_for_origin(origin: Origin) -> f32 {
    match origin {
        Origin::CrossAgent => 0.5,
        Origin::Consolidation => 0.6,
        Origin::DreamReplay => 0.6,
        Origin::LlmExtraction => 0.8,
        Origin::UserProvided => 1.0,
        Origin::DirectObservation => 1.0,
    }
}
/// Combines the retrieval signals of one memory into a single ranking score in `[0.0, 1.0]`.
///
/// The recency signal is not passed in directly. It is computed from `importance`,
/// `age_hours`, `decay_lambda` and `access_freq` via [`fade_mem_recency`].
///
/// Scoring proceeds in three steps:
///
/// 1. Every signal is clamped to `[0.0, 1.0]`.
/// 2. Each signal is multiplied by the matching field of `weights` (`causal_rel` pairs with
///    `causal_relevance`, `source_rel` with `source_reliability`).
/// 3. The weighted sum is clamped to `[0.0, 1.0]` again.
///
/// A NaN signal contributes `0.0` instead of turning the whole score into NaN. A NaN total
/// can only arise from non-finite weights, and it is also reported as `0.0`. The result is
/// therefore always an ordinary number that can be compared when ranking.
///
/// `weights` is not validated here; use [`ScoringWeights::validate`] for that.
#[must_use]
#[allow(clippy::too_many_arguments)] // one scalar per signal; the signature is public API
pub fn composite_score(
    similarity: f32,
    importance: f32,
    age_hours: f64,
    decay_lambda: f64,
    access_freq: u64,
    activation: f32,
    causal_rel: f32,
    surprise: f32,
    source_rel: f32,
    weights: &ScoringWeights,
) -> f32 {
    /// Clamps to `[0.0, 1.0]`, mapping NaN to `0.0` (`f32::clamp` alone would keep NaN).
    fn unit(x: f32) -> f32 {
        if x.is_nan() {
            0.0
        } else {
            x.clamp(0.0, 1.0)
        }
    }

    let recency = fade_mem_recency(importance, age_hours, decay_lambda, access_freq);
    let score = weights.similarity * unit(similarity)
        + weights.importance * unit(importance)
        + weights.recency * unit(recency)
        + weights.activation * unit(activation)
        + weights.causal_relevance * unit(causal_rel)
        + weights.surprise * unit(surprise)
        + weights.source_reliability * unit(source_rel);
    unit(score)
}
/// Exponential recency decay that slows down for important and frequently accessed memories.
///
/// Computes `exp(-rate * age_hours)` with
/// `rate = decay_lambda / ((1 + importance) * (1 + access_freq))`. Before this, `importance`
/// is clamped to `[0.0, 1.0]`. For a positive `decay_lambda`:
///
/// - A memory with `importance = 1.0` decays half as fast as one with `importance = 0.0`.
/// - Every additional access slows decay further.
///
/// A negative `age_hours` (a timestamp ahead of the clock) or a NaN `age_hours` is treated as
/// `0.0`. Such inputs therefore yield `1.0` rather than a value above `1.0`. The result is
/// clamped to `[0.0, 1.0]`, which also bounds the output for a negative `decay_lambda`. A NaN
/// `importance` or `decay_lambda` still produces NaN.
#[must_use]
pub fn fade_mem_recency(
    importance: f32,
    age_hours: f64,
    decay_lambda: f64,
    access_freq: u64,
) -> f32 {
    let imp = f64::from(importance.clamp(0.0, 1.0));
    // `u64 -> f64` only rounds above 2^53, which is negligible for this rate term.
    let freq = access_freq as f64;
    // `f64::max` returns the non-NaN operand, so NaN ages collapse to 0.0 as well.
    let age = age_hours.max(0.0);
    let adaptive_rate = decay_lambda * (1.0 / (1.0 + imp)) * (1.0 / (1.0 + freq));
    (-adaptive_rate * age).exp().clamp(0.0, 1.0) as f32
}
pub use hirn_core::embed::{NoopReranker, RerankResult, Reranker};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pure_similarity() {
        let score = composite_score(
            0.9,
            0.5,
            1.0,
            0.01,
            0,
            0.0,
            0.0,
            0.0,
            0.0,
            &ScoringWeights::PURE_SIMILARITY,
        );
        assert!((score - 0.9).abs() < 1e-4);
    }

    #[test]
    fn higher_importance_ranks_higher() {
        let w = ScoringWeights {
            similarity: 0.5,
            importance: 0.5,
            recency: 0.0,
            activation: 0.0,
            causal_relevance: 0.0,
            surprise: 0.0,
            source_reliability: 0.0,
        };
        let low = composite_score(0.8, 0.2, 0.0, 0.01, 0, 0.0, 0.0, 0.0, 0.0, &w);
        let high = composite_score(0.8, 0.9, 0.0, 0.01, 0, 0.0, 0.0, 0.0, 0.0, &w);
        assert!(high > low);
    }

    #[test]
    fn more_recent_ranks_higher() {
        let w = ScoringWeights {
            similarity: 0.5,
            importance: 0.0,
            recency: 0.5,
            activation: 0.0,
            causal_relevance: 0.0,
            surprise: 0.0,
            source_reliability: 0.0,
        };
        let old = composite_score(0.8, 0.5, 720.0, 0.01, 0, 0.0, 0.0, 0.0, 0.0, &w);
        let recent = composite_score(0.8, 0.5, 1.0, 0.01, 0, 0.0, 0.0, 0.0, 0.0, &w);
        assert!(recent > old);
    }

    /// With a zero recency weight, the age of a memory must not change its score.
    #[test]
    fn pure_similarity_ignores_recency() {
        let w = ScoringWeights::PURE_SIMILARITY;
        let s1 = composite_score(0.9, 0.5, 1.0, 0.01, 0, 0.0, 0.0, 0.0, 0.0, &w);
        let s2 = composite_score(0.9, 0.5, 720.0, 0.01, 0, 0.0, 0.0, 0.0, 0.0, &w);
        assert!((s1 - s2).abs() < 1e-4);
    }

    #[test]
    fn recency_is_one_at_age_zero() {
        assert_eq!(fade_mem_recency(0.5, 0.0, 0.01, 0), 1.0);
    }

    /// Recency decays with age, and more slowly for important or frequently accessed memories.
    #[test]
    fn recency_decay_rates() {
        let baseline = fade_mem_recency(0.0, 100.0, 0.01, 0);
        assert!(baseline < 1.0);
        assert!(fade_mem_recency(0.0, 720.0, 0.01, 0) < baseline);
        assert!(fade_mem_recency(1.0, 100.0, 0.01, 0) > baseline);
        assert!(fade_mem_recency(0.0, 100.0, 0.01, 5) > baseline);
    }

    #[test]
    fn score_in_range() {
        let w = ScoringWeights::default();
        for sim in [0.0, 0.1, 0.5, 0.9, 1.0] {
            for imp in [0.0, 0.5, 1.0] {
                for age in [0.0, 1.0, 24.0, 720.0] {
                    let s = composite_score(sim, imp, age, 0.01, 0, 0.0, 0.0, 0.0, 0.0, &w);
                    assert!(
                        (0.0..=1.0).contains(&s),
                        "score {s} out of range for sim={sim}, imp={imp}, age={age}"
                    );
                }
            }
        }
    }

    #[test]
    fn invalid_weights() {
        let w = ScoringWeights {
            similarity: 0.5,
            importance: 0.5,
            recency: 0.5,
            activation: 0.0,
            causal_relevance: 0.0,
            surprise: 0.0,
            source_reliability: 0.0,
        };
        assert!(w.validate().is_err());
    }

    /// Weights that sum to 1.0 are still rejected when one of them leaves `[0.0, 1.0]`.
    #[test]
    fn out_of_range_weight_is_rejected() {
        let w = ScoringWeights {
            similarity: 1.5,
            importance: -0.5,
            ..ScoringWeights::PURE_SIMILARITY
        };
        assert!(w.validate().is_err());
    }

    #[test]
    fn valid_weights() {
        ScoringWeights::default().validate().unwrap();
        ScoringWeights::PURE_SIMILARITY.validate().unwrap();
    }

    #[test]
    fn source_reliability_orders_origins() {
        assert_eq!(source_reliability_for_origin(Origin::DirectObservation), 1.0);
        assert_eq!(source_reliability_for_origin(Origin::UserProvided), 1.0);
        assert!(
            source_reliability_for_origin(Origin::LlmExtraction)
                > source_reliability_for_origin(Origin::Consolidation)
        );
        assert!(
            source_reliability_for_origin(Origin::DreamReplay)
                > source_reliability_for_origin(Origin::CrossAgent)
        );
    }

    #[test]
    fn score_breakdown_display() {
        let breakdown = ScoreBreakdown {
            similarity: 0.5,
            importance: 0.25,
            recency: 1.0,
            activation: 0.0,
            causal_relevance: 0.125,
            surprise: 0.0,
            source_reliability: 1.0,
        };
        assert_eq!(
            breakdown.to_string(),
            "sim=0.500 imp=0.250 rec=1.000 act=0.000 caus=0.125 sur=0.000 src=1.000"
        );
    }
}