use std::sync::Arc;
use tracing::{Instrument, Level, event, info_span};
use crate::jobs::Job;
use crate::store::{MemoryStore, StoreError};
use super::ClientInner;
/// Candidate labels offered to the NLI classifier, in the order they are passed to it.
pub const CATEGORY_LABELS: &[&str] = &[
    "preference",
    "identity",
    "workflow",
    "factual",
    "transient",
];

/// Label stored when no candidate scores at least [`MIN_CATEGORY_SCORE`].
const FALLBACK_CATEGORY: &str = "transient";

/// Lowest classifier score accepted as a confident match.
const MIN_CATEGORY_SCORE: f32 = 0.5;

/// Hypothesis handed to the classifier; `{}` is the slot for a candidate label.
const HYPOTHESIS_TEMPLATE: &str = "This memory is about a {}.";
/// Ways a categorization job can fail, listed in the order the pipeline
/// stages run: lookup, then classification, then persistence.
#[derive(Debug, thiserror::Error)]
pub(super) enum CategorizeError {
    /// Loading the source memory from the store failed.
    ///
    /// `#[from]` lets `?` convert a [`StoreError`] straight into this variant.
    #[error("source lookup failed: {0}")]
    SourceLookup(#[from] StoreError),
    /// Running the classifier failed; carries a rendered error message.
    #[error("classification failed: {0}")]
    Classify(String),
    /// Writing the chosen category back failed; carries a rendered error message.
    #[error("persist failed: {0}")]
    Persist(String),
}
impl ClientInner {
    /// Runs a categorization job inside a `memoir.categorize` tracing span.
    ///
    /// The span records the job's `source_pid`. All of the actual work is
    /// delegated to [`Self::run_categorize_inner`].
    ///
    /// # Errors
    ///
    /// Propagates any [`CategorizeError`] produced by the inner pipeline.
    pub(super) async fn run_categorize(self: &Arc<Self>, job: Job) -> Result<(), CategorizeError> {
        let span = info_span!("memoir.categorize", source_pid = %job.source_pid);
        // `Instrument` wraps the inner future directly; no extra `async move`
        // block is needed to attach the span.
        self.run_categorize_inner(job).instrument(span).await
    }

    /// Classifies the memory identified by `job.source_pid` and stores the
    /// resulting category.
    ///
    /// Pipeline:
    /// 1. If `self.nli` is `None` (no classifier configured), emits a warning
    ///    event and returns `Ok(())` without touching the store.
    /// 2. Loads the memory via `self.store.recall`. A `StoreError::NotFound`
    ///    is treated as benign and returns `Ok(())`.
    /// 3. Runs the classifier on the memory content against
    ///    [`CATEGORY_LABELS`] using [`HYPOTHESIS_TEMPLATE`], on a blocking
    ///    thread via `tokio::task::spawn_blocking`.
    /// 4. Picks the highest-scoring label. Falls back to [`FALLBACK_CATEGORY`]
    ///    when there are no (non-NaN) scores or the best one is below
    ///    [`MIN_CATEGORY_SCORE`].
    /// 5. Persists the chosen category with `self.store.set_category`.
    ///
    /// # Errors
    ///
    /// - [`CategorizeError::SourceLookup`] for any `recall` error other than `NotFound`.
    /// - [`CategorizeError::Classify`] if the blocking task does not complete
    ///   or the classifier itself returns an error.
    /// - [`CategorizeError::Persist`] if storing the category fails.
    async fn run_categorize_inner(self: &Arc<Self>, job: Job) -> Result<(), CategorizeError> {
        let pid = job.source_pid.clone();

        // The classifier is cloned so it can be moved into the `'static`
        // closure handed to `spawn_blocking` below.
        let Some(classifier) = self.nli.clone() else {
            event!(
                name: "memoir.categorize.skipped",
                Level::WARN,
                source_pid = %pid,
                "no NLI classifier configured; treating job as no-op",
            );
            return Ok(());
        };

        let memory = match self.store.recall(&pid).await {
            Ok(memory) => memory,
            Err(StoreError::NotFound(_)) => {
                // NOTE(review): `{{`/`}}` render as literal braces in the
                // message; presumably a template for the log backend — confirm.
                event!(
                    name: "memoir.categorize.source_missing",
                    Level::INFO,
                    source_pid = %pid,
                    "target memory absent for {{source_pid}} (cascade delete race); skipping",
                );
                return Ok(());
            }
            Err(err) => return Err(CategorizeError::SourceLookup(err)),
        };

        // Classification runs under `spawn_blocking` so it cannot stall the
        // async executor (presumably it is CPU-bound model inference).
        let content = memory.content.clone();
        let scored = tokio::task::spawn_blocking(move || {
            classifier.classify(&content, CATEGORY_LABELS, HYPOTHESIS_TEMPLATE)
        })
        .await
        .map_err(|join_err| CategorizeError::Classify(format!("classify task panicked: {join_err}")))?
        .map_err(|nli_err| CategorizeError::Classify(nli_err.to_string()))?;

        // Select the best score explicitly instead of trusting `classify` to
        // return its results sorted. Comparing `b` against `a` inside `min_by`
        // yields the *first* entry with the highest score, so sorted output
        // picks the same label as before. NaN scores are skipped so they
        // cannot shadow a real match.
        let category = scored
            .iter()
            .filter(|candidate| !candidate.score.is_nan())
            .min_by(|a, b| b.score.total_cmp(&a.score))
            .filter(|top| top.score >= MIN_CATEGORY_SCORE)
            .map_or_else(|| FALLBACK_CATEGORY.to_string(), |top| top.label.clone());

        self.store
            .set_category(&pid, &category)
            .await
            .map_err(|err| CategorizeError::Persist(err.to_string()))?;

        event!(
            name: "memoir.categorize.done",
            Level::INFO,
            source_pid = %pid,
            category = %category,
            "categorized {{source_pid}} as {{category}}",
        );
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// The taxonomy holds exactly five labels with no repeats.
    #[test]
    fn should_have_five_distinct_category_labels() {
        let label_count = CATEGORY_LABELS.len();
        assert_eq!(label_count, 5);
        let distinct: HashSet<&str> = CATEGORY_LABELS.iter().copied().collect();
        assert_eq!(distinct.len(), 5, "category labels must be distinct");
    }

    /// The fallback label must itself be one of the taxonomy labels.
    #[test]
    fn should_include_fallback_in_taxonomy() {
        let fallback_is_known = CATEGORY_LABELS
            .iter()
            .any(|&label| label == FALLBACK_CATEGORY);
        assert!(
            fallback_is_known,
            "the fallback category must be a valid taxonomy label"
        );
    }

    /// Substituting a label into the `{}` slot yields the expected hypothesis.
    #[test]
    fn should_fill_hypothesis_template_per_label() {
        let hypothesis = HYPOTHESIS_TEMPLATE.replace("{}", "preference");
        assert_eq!(hypothesis, "This memory is about a preference.");
    }
}