pub mod embed;
pub mod extract;
pub mod resolver;
pub mod signals;
pub mod store;
use std::sync::Arc;
use anyhow::Result;
use chrono::Utc;
use rusqlite::Transaction;
use serde::{Deserialize, Serialize};
use self::extract::{EntityExtractor, ExtractedEntities};
use self::resolver::{canonicalise, CanonicalEntity};
use self::signals::{ScoreSignals, SignalWeights};
use crate::memory::chunks::{approx_token_count, Chunk, SourceKind};
use crate::memory::config::MemoryConfig;
/// Default value for `ScoringConfig::drop_threshold`: a chunk is kept only
/// when its total score is greater than or equal to this value.
pub const DEFAULT_DROP_THRESHOLD: f32 = 0.3;
/// Default value for `ScoringConfig::definite_keep_threshold`: a cheap score
/// at or above this value is not referred to the optional LLM extractor.
pub const DEFAULT_DEFINITE_KEEP: f32 = 0.85;
/// Default value for `ScoringConfig::definite_drop_threshold`: a cheap score
/// at or below this value is not referred to the optional LLM extractor.
pub const DEFAULT_DEFINITE_DROP: f32 = 0.15;
/// Metadata tag that marks a chunk as high priority for `score_chunk`.
pub const PRIORITY_TAG: &str = "priority_high";
/// Amount added to the total score of a chunk tagged with `PRIORITY_TAG`
/// (the boosted value is capped at 1.0).
pub const PRIORITY_BOOST: f32 = 0.25;
/// Outcome of scoring one chunk with `score_chunk`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ScoreResult {
/// Id of the chunk that was scored (copied from `Chunk::id`).
pub chunk_id: String,
/// Final combined score, including the priority boost when it applied.
pub total: f32,
/// Per-signal values produced by `signals::compute` that fed into `total`.
pub signals: ScoreSignals,
/// `true` when the chunk passed the drop threshold and was not rejected as
/// a tiny, entity-free chunk.
pub kept: bool,
/// Human-readable explanation of why the chunk was dropped; `None` when
/// `kept` is `true`.
pub drop_reason: Option<String>,
/// Entities found by the configured extractor, merged with the LLM
/// extractor's output when that extractor was consulted successfully.
pub extracted: ExtractedEntities,
/// Result of running `canonicalise` over `extracted`.
pub canonical_entities: Vec<CanonicalEntity>,
}
/// Settings that control how `score_chunk` extracts entities, combines
/// signals and decides whether a chunk is kept.
pub struct ScoringConfig {
/// Extractor that is always run on the chunk's scoring content.
pub extractor: Arc<dyn EntityExtractor>,
/// Weights handed to `signals::combine` / `signals::combine_cheap_only`.
pub weights: SignalWeights,
/// A chunk is kept only when its total score is `>=` this value.
pub drop_threshold: f32,
/// Optional second extractor, consulted only when the cheap score lies
/// strictly between `definite_drop_threshold` and `definite_keep_threshold`.
pub llm_extractor: Option<Arc<dyn EntityExtractor>>,
/// Exclusive upper bound of the band in which `llm_extractor` is consulted.
pub definite_keep_threshold: f32,
/// Exclusive lower bound of the band in which `llm_extractor` is consulted.
pub definite_drop_threshold: f32,
}
impl ScoringConfig {
pub fn default_regex_only() -> Self {
Self {
extractor: Arc::new(extract::CompositeExtractor::regex_only()),
weights: SignalWeights::default(),
drop_threshold: DEFAULT_DROP_THRESHOLD,
llm_extractor: None,
definite_keep_threshold: DEFAULT_DEFINITE_KEEP,
definite_drop_threshold: DEFAULT_DEFINITE_DROP,
}
}
pub fn with_llm_extractor(llm: Arc<dyn EntityExtractor>) -> Self {
Self {
extractor: Arc::new(extract::CompositeExtractor::regex_only()),
weights: SignalWeights::with_llm_enabled(),
drop_threshold: DEFAULT_DROP_THRESHOLD,
llm_extractor: Some(llm),
definite_keep_threshold: DEFAULT_DEFINITE_KEEP,
definite_drop_threshold: DEFAULT_DEFINITE_DROP,
}
}
}
/// Scores a single chunk and decides whether it should be kept.
///
/// The pipeline is:
/// 1. Derive the text to score via `scoring_content_for_chunk` and run
///    `cfg.extractor` over it.
/// 2. Compute the signals and a cheap total (`signals::combine_cheap_only`).
/// 3. If the cheap total lies strictly between `cfg.definite_drop_threshold`
///    and `cfg.definite_keep_threshold` and an LLM extractor is configured,
///    run it, merge its entities, recompute the signals and use
///    `signals::combine` for the total instead.
/// 4. Add `PRIORITY_BOOST` (capped at 1.0) for chunks tagged `PRIORITY_TAG`.
/// 5. Drop the chunk when it is tiny and entity-free (unless it is priority)
///    or when the total is below `cfg.drop_threshold`.
///
/// # Errors
///
/// Returns an error only when `cfg.extractor` fails. A failure of the
/// optional LLM extractor is not propagated; scoring falls back to the cheap
/// total.
pub async fn score_chunk(chunk: &Chunk, cfg: &ScoringConfig) -> Result<ScoreResult> {
    let scoring_content = scoring_content_for_chunk(chunk);
    let scoring_token_count = approx_token_count(&scoring_content);
    let mut extracted = cfg.extractor.extract(&scoring_content).await?;
    let mut signals = self::signals::compute(
        &chunk.metadata,
        &scoring_content,
        scoring_token_count,
        &extracted,
    );
    let cheap_total = self::signals::combine_cheap_only(&signals, &cfg.weights);

    // Only scores strictly inside the (definite_drop, definite_keep) band are
    // worth a second opinion from the LLM extractor.
    let in_band =
        cheap_total > cfg.definite_drop_threshold && cheap_total < cfg.definite_keep_threshold;
    let llm_for_band = if in_band {
        cfg.llm_extractor.as_ref()
    } else {
        None
    };

    let mut llm_consulted = false;
    if let Some(llm) = llm_for_band {
        // Best-effort: an LLM failure must not fail scoring, so the error is
        // deliberately discarded and the cheap total is used instead.
        if let Ok(more) = llm.extract(&scoring_content).await {
            extracted.merge(more);
            signals = self::signals::compute(
                &chunk.metadata,
                &scoring_content,
                scoring_token_count,
                &extracted,
            );
            llm_consulted = true;
        }
    }

    // When the LLM was not consulted, `signals` is unchanged since
    // `cheap_total` was computed, so that value is reused as-is.
    let mut total = if llm_consulted {
        self::signals::combine(&signals, &cfg.weights)
    } else {
        cheap_total
    };

    let priority = chunk.metadata.tags.iter().any(|t| t == PRIORITY_TAG);
    if priority {
        // The boost is capped at 1.0 but must never lower an existing total.
        let boosted = (total + PRIORITY_BOOST).min(1.0);
        if boosted > total {
            total = boosted;
        }
    }

    // Very short chunks without any extracted entity are dropped regardless
    // of their score, unless they were explicitly marked as priority.
    let tiny_entity_free = !priority
        && scoring_token_count < self::signals::token_count::TOKEN_MIN
        && extracted.is_empty();
    let kept = !tiny_entity_free && total >= cfg.drop_threshold;
    let drop_reason = if kept {
        None
    } else if tiny_entity_free {
        Some(format!(
            "token_count {} < minimum {} and no entities extracted",
            scoring_token_count,
            self::signals::token_count::TOKEN_MIN
        ))
    } else {
        Some(format!(
            "total {total:.3} < threshold {:.3}",
            cfg.drop_threshold
        ))
    };

    let canonical_entities = canonicalise(&extracted);
    Ok(ScoreResult {
        chunk_id: chunk.id.clone(),
        total,
        signals,
        kept,
        drop_reason,
        extracted,
        canonical_entities,
    })
}
/// Returns the text that should be scored for `chunk`.
///
/// Non-chat chunks are scored on their full content. For chat chunks, every
/// line whose left-trimmed text starts with `# Chat transcript` or `## ` is
/// removed and the remaining lines are re-joined with `\n`.
fn scoring_content_for_chunk(chunk: &Chunk) -> String {
    let is_chat = chunk.metadata.source_kind == SourceKind::Chat;
    if !is_chat {
        return chunk.content.clone();
    }

    let mut body = String::with_capacity(chunk.content.len());
    let mut wrote_line = false;
    for line in chunk.content.lines() {
        let unindented = line.trim_start();
        if unindented.starts_with("# Chat transcript") || unindented.starts_with("## ") {
            continue;
        }
        // Separator goes between kept lines only, never after the last one.
        if wrote_line {
            body.push('\n');
        }
        body.push_str(line);
        wrote_line = true;
    }
    body
}
/// Scores every chunk in `chunks` one after another with `score_chunk`,
/// returning the results in input order.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `score_chunk`.
pub async fn score_chunks(chunks: &[Chunk], cfg: &ScoringConfig) -> Result<Vec<ScoreResult>> {
    let mut results = Vec::with_capacity(chunks.len());
    for chunk in chunks.iter() {
        let scored = score_chunk(chunk, cfg).await?;
        results.push(scored);
    }
    Ok(results)
}
pub async fn score_chunks_fast(chunks: &[Chunk], cfg: &ScoringConfig) -> Result<Vec<ScoreResult>> {
let fast_cfg = ScoringConfig {
extractor: cfg.extractor.clone(),
weights: cfg.weights.clone(),
drop_threshold: cfg.drop_threshold,
llm_extractor: None,
definite_keep_threshold: cfg.definite_keep_threshold,
definite_drop_threshold: cfg.definite_drop_threshold,
};
score_chunks(chunks, &fast_cfg).await
}
/// Persists a scoring result through the `store` module using `config`.
///
/// The score row is always upserted. For kept chunks the existing entity
/// index entries for the chunk are cleared and, when there are canonical
/// entities, they are indexed again with the literal node kind `"leaf"`,
/// `timestamp_ms` and `tree_id`. Dropped chunks leave the entity index
/// untouched.
///
/// # Errors
///
/// Propagates the first error returned by any `store` call.
pub fn persist_score(
    config: &MemoryConfig,
    result: &ScoreResult,
    timestamp_ms: i64,
    tree_id: Option<&str>,
) -> Result<()> {
    store::upsert_score(config, &score_row(result))?;

    if !result.kept {
        return Ok(());
    }

    let node_id = &result.chunk_id;
    store::clear_entity_index_for_node(config, node_id)?;

    let entities = &result.canonical_entities;
    if entities.is_empty() {
        return Ok(());
    }

    store::index_entities(config, entities, node_id, "leaf", timestamp_ms, tree_id)?;
    Ok(())
}
/// Transactional variant of `persist_score`: performs the same writes through
/// the `*_tx` functions of the `store` module on the caller-supplied `tx`.
///
/// The score row is always upserted. For kept chunks the existing entity
/// index entries for the chunk are cleared and, when there are canonical
/// entities, they are indexed again with the literal node kind `"leaf"`,
/// `timestamp_ms` and `tree_id`. This function does not commit `tx`.
///
/// # Errors
///
/// Propagates the first error returned by any `store` call.
pub fn persist_score_tx(
    tx: &Transaction<'_>,
    result: &ScoreResult,
    timestamp_ms: i64,
    tree_id: Option<&str>,
) -> Result<()> {
    store::upsert_score_tx(tx, &score_row(result))?;

    if !result.kept {
        return Ok(());
    }

    let node_id = &result.chunk_id;
    store::clear_entity_index_for_node_tx(tx, node_id)?;

    let entities = &result.canonical_entities;
    if entities.is_empty() {
        return Ok(());
    }

    store::index_entities_tx(tx, entities, node_id, "leaf", timestamp_ms, tree_id)?;
    Ok(())
}
/// Converts a `ScoreResult` into the `store::ScoreRow` that gets persisted.
///
/// `dropped` is the negation of `kept`, and `computed_at_ms` is stamped with
/// the current UTC time at the moment of the call (it is not taken from the
/// result).
fn score_row(result: &ScoreResult) -> store::ScoreRow {
    let computed_at_ms = Utc::now().timestamp_millis();
    let dropped = !result.kept;
    let llm_importance_reason = result.extracted.llm_importance_reason.clone();

    store::ScoreRow {
        chunk_id: result.chunk_id.clone(),
        total: result.total,
        signals: result.signals.clone(),
        dropped,
        reason: result.drop_reason.clone(),
        computed_at_ms,
        llm_importance_reason,
    }
}
#[cfg(test)]
#[path = "mod_tests.rs"]
mod tests;