use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use crate::error::MemoryError;
use crate::types::{ConversationId, MessageId};
/// A single message considered for promotion, paired with its optional
/// embedding vector.
#[derive(Debug, Clone)]
pub struct PromotionInput {
    /// Identifier of the source message.
    pub message_id: MessageId,
    /// Conversation (session) the message belongs to.
    pub conversation_id: ConversationId,
    /// Raw message text.
    pub content: String,
    /// Embedding of `content`; inputs with `None` are skipped by `scan`.
    pub embedding: Option<Vec<f32>>,
}
/// A cluster of messages that met the promotion thresholds and is ready to
/// be persisted as a skill.
#[derive(Debug, Clone)]
pub struct PromotionCandidate {
    /// Hex signature derived from the cluster centroid; used for
    /// deduplication and as part of the skill directory name.
    pub signature: String,
    /// Messages that were assigned to this cluster.
    pub member_ids: Vec<MessageId>,
    /// Distinct conversations represented in the cluster.
    pub session_ids: Vec<ConversationId>,
    /// Mean embedding of all cluster members.
    pub centroid: Vec<f32>,
}
/// Thresholds controlling when a cluster of messages qualifies for
/// promotion to a skill.
///
/// `PartialEq` is derived (all fields are primitives) so configs can be
/// compared directly in tests and change detection.
#[derive(Debug, Clone, PartialEq)]
pub struct PromotionConfig {
    /// Minimum total cluster members required to qualify.
    pub min_occurrences: u32,
    /// Minimum distinct conversations that must appear in the cluster.
    pub min_sessions: u32,
    /// Cosine-similarity threshold for joining an existing cluster.
    pub cluster_threshold: f32,
}
impl Default for PromotionConfig {
fn default() -> Self {
Self {
min_occurrences: 3,
min_sessions: 2,
cluster_threshold: 0.85,
}
}
}
/// Sink that persists a promoted skill.
///
/// Returns a boxed future (rather than using `async fn`) so the trait stays
/// object-safe and usable behind `Arc<dyn SkillWriter>`.
pub trait SkillWriter: Send + Sync {
    /// Writes a skill described by `description`, keyed by `signature`.
    /// Resolves to `Err(String)` describing the failure, if any.
    fn write_skill(
        &self,
        description: String,
        signature: String,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), String>> + Send + '_>>;
}
/// Clusters message embeddings from a window of inputs and promotes
/// recurring patterns into skills via the injected [`SkillWriter`].
pub struct PromotionEngine {
    /// Persists qualifying candidates as skills.
    writer: Arc<dyn SkillWriter>,
    /// Thresholds governing cluster qualification.
    config: PromotionConfig,
    /// Directory under which promoted skill directories are created.
    output_dir: PathBuf,
}
impl PromotionEngine {
    /// Creates an engine with the given writer, thresholds, and output
    /// directory for promoted skills.
    #[must_use]
    pub fn new(writer: Arc<dyn SkillWriter>, config: PromotionConfig, output_dir: PathBuf) -> Self {
        Self {
            writer,
            config,
            output_dir,
        }
    }

    /// Greedily clusters the embedded inputs in `window` by cosine
    /// similarity against a running mean centroid, then returns the
    /// clusters that satisfy both `min_occurrences` (total members) and
    /// `min_sessions` (distinct conversations).
    ///
    /// Inputs without an embedding are skipped entirely.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::Promotion`] when an embedding's dimension
    /// differs from that of the first embedded input.
    #[tracing::instrument(name = "memory.compression.promote.scan", skip_all,
        fields(window_len = window.len()))]
    pub async fn scan(
        &self,
        window: &[PromotionInput],
    ) -> Result<Vec<PromotionCandidate>, MemoryError> {
        // Pair each input with its embedding up front so the clustering
        // loop never needs to unwrap an Option.
        let embeds: Vec<(&PromotionInput, &Vec<f32>)> = window
            .iter()
            .filter_map(|p| p.embedding.as_ref().map(|e| (p, e)))
            .collect();
        if embeds.is_empty() {
            return Ok(vec![]);
        }
        // All embeddings must share the dimension of the first one.
        let dim = embeds[0].1.len();

        struct Cluster {
            centroid: Vec<f32>,
            member_ids: Vec<MessageId>,
            session_ids: HashSet<ConversationId>,
        }

        let mut clusters: Vec<Cluster> = Vec::new();
        for (input, emb) in embeds {
            if emb.len() != dim {
                return Err(MemoryError::Promotion(format!(
                    "embedding dimension mismatch: expected {dim}, got {}",
                    emb.len()
                )));
            }
            let mut assigned = false;
            for cluster in &mut clusters {
                if cosine_similarity(emb, &cluster.centroid) >= self.config.cluster_threshold {
                    // Incremental mean update: c' = (c * n + v) / (n + 1).
                    #[allow(clippy::cast_precision_loss)]
                    let n = cluster.member_ids.len() as f32;
                    for (c, v) in cluster.centroid.iter_mut().zip(emb.iter()) {
                        *c = (*c * n + v) / (n + 1.0);
                    }
                    cluster.member_ids.push(input.message_id);
                    cluster.session_ids.insert(input.conversation_id);
                    assigned = true;
                    break;
                }
            }
            if !assigned {
                // Seed a new singleton cluster at this embedding.
                clusters.push(Cluster {
                    centroid: emb.clone(),
                    member_ids: vec![input.message_id],
                    session_ids: std::iter::once(input.conversation_id).collect(),
                });
            }
        }

        let candidates = clusters
            .into_iter()
            .filter(|c| {
                // Saturate to u32::MAX so oversized clusters still qualify.
                u32::try_from(c.member_ids.len()).unwrap_or(u32::MAX) >= self.config.min_occurrences
                    && u32::try_from(c.session_ids.len()).unwrap_or(u32::MAX)
                        >= self.config.min_sessions
            })
            .map(|c| PromotionCandidate {
                signature: cluster_signature(&c.centroid),
                member_ids: c.member_ids,
                session_ids: c.session_ids.into_iter().collect(),
                centroid: c.centroid,
            })
            .collect();
        Ok(candidates)
    }

    /// Persists `candidate` as a skill directory named after a signature
    /// prefix; silently skips candidates whose directory already exists.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::Promotion`] when the underlying writer fails.
    #[tracing::instrument(name = "memory.compression.promote.persist", skip_all,
        fields(signature = %candidate.signature))]
    pub async fn promote(&self, candidate: &PromotionCandidate) -> Result<(), MemoryError> {
        // Char-based truncation: the previous byte slice (`[..12]`) panicked
        // on signatures shorter than 12 bytes or cut mid-UTF-8 sequence, and
        // `PromotionCandidate.signature` is a public field with no length
        // guarantee.
        let prefix: String = candidate.signature.chars().take(12).collect();
        let skill_name = format!("promoted-pattern-{prefix}");
        let skill_dir = self.output_dir.join(&skill_name);
        // NOTE(review): `exists()` is a blocking filesystem call inside an
        // async fn; acceptable for a one-off stat, but consider
        // `spawn_blocking` if this runs on a latency-sensitive executor.
        if skill_dir.exists() {
            tracing::debug!(signature = %candidate.signature, "promotion candidate already exists, skipping");
            return Ok(());
        }
        let member_count = candidate.member_ids.len();
        let session_count = candidate.session_ids.len();
        let description = format!(
            "Recurring procedural pattern detected across {member_count} messages in \
             {session_count} sessions. Generate a concise SKILL.md capturing the common \
             tool-use pattern or workflow. Signature: {}.",
            candidate.signature
        );
        self.writer
            .write_skill(description, candidate.signature.clone())
            .await
            .map_err(MemoryError::Promotion)
    }
}
/// Cosine similarity of two equal-length vectors, clamped to `[-1, 1]`.
///
/// Returns `0.0` when either vector has (near-)zero magnitude, so
/// degenerate embeddings never match anything.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Accumulate the dot product and both squared norms in a single pass;
    // addition order per accumulator matches a left-to-right sum.
    let mut dot = 0.0_f32;
    let mut sq_a = 0.0_f32;
    let mut sq_b = 0.0_f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a < f32::EPSILON || norm_b < f32::EPSILON {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Derives a 16-hex-digit signature from a centroid by hashing the exact
/// bit patterns of its components.
///
/// NOTE(review): `DefaultHasher`'s algorithm is unspecified and may change
/// across Rust releases, so signatures persisted to disk may not be
/// reproducible after a toolchain upgrade — confirm this is acceptable.
fn cluster_signature(centroid: &[f32]) -> String {
    use std::hash::Hasher;
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    for component in centroid {
        // Feed raw bits (equivalent to u32's Hash impl) so -0.0 vs 0.0 and
        // NaN payloads remain distinguishable.
        hasher.write_u32(component.to_bits());
    }
    format!("{:016x}", hasher.finish())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    // Test double that records each description handed to `write_skill`.
    struct RecordingWriter {
        // Descriptions captured from `write_skill`, in call order.
        written: Mutex<Vec<String>>,
    }

    impl SkillWriter for RecordingWriter {
        fn write_skill(
            &self,
            description: String,
            _signature: String,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), String>> + Send + '_>>
        {
            // NOTE(review): the push happens eagerly, before the returned
            // future is polled — a caller that drops the future unawaited
            // would still be recorded as a write.
            self.written.lock().unwrap().push(description);
            Box::pin(async { Ok(()) })
        }
    }

    // Builds a PromotionInput whose embedding is always present.
    fn make_input(id: i64, cid: i64, content: &str, emb: Vec<f32>) -> PromotionInput {
        PromotionInput {
            message_id: MessageId(id),
            conversation_id: ConversationId(cid),
            content: content.to_string(),
            embedding: Some(emb),
        }
    }

    // Returns an n-dimensional unit vector pointing along axis 0.
    // (`val` cancels out during normalization; only its sign survives.)
    fn unit_vec(n: usize, val: f32) -> Vec<f32> {
        let mut v = vec![0.0_f32; n];
        v[0] = val;
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        v.iter_mut().for_each(|x| *x /= norm);
        v
    }

    // Identical vectors have similarity 1.
    #[test]
    fn cosine_similarity_identical() {
        let v = vec![1.0_f32, 0.0, 0.0];
        assert!((cosine_similarity(&v, &v) - 1.0).abs() < 1e-6);
    }

    // Orthogonal vectors have similarity 0.
    #[test]
    fn cosine_similarity_orthogonal() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        assert!((cosine_similarity(&a, &b) - 0.0).abs() < 1e-6);
    }

    // A window whose inputs all lack embeddings yields no candidates.
    #[tokio::test]
    async fn scan_returns_empty_for_no_embeddings() {
        let writer = Arc::new(RecordingWriter {
            written: Mutex::new(vec![]),
        });
        let engine =
            PromotionEngine::new(writer, PromotionConfig::default(), PathBuf::from("/tmp"));
        let window = vec![PromotionInput {
            message_id: MessageId(1),
            conversation_id: ConversationId(1),
            content: "hello".into(),
            embedding: None,
        }];
        let candidates = engine.scan(&window).await.unwrap();
        assert!(candidates.is_empty());
    }

    // Four identical embeddings spread over three sessions form one cluster
    // that clears both min_occurrences (3) and min_sessions (2).
    #[tokio::test]
    async fn scan_qualifies_cluster_meeting_thresholds() {
        let writer = Arc::new(RecordingWriter {
            written: Mutex::new(vec![]),
        });
        let config = PromotionConfig {
            min_occurrences: 3,
            min_sessions: 2,
            cluster_threshold: 0.90,
        };
        let engine = PromotionEngine::new(writer, config, PathBuf::from("/tmp"));
        let base = unit_vec(4, 1.0);
        let window = vec![
            make_input(1, 1, "a", base.clone()),
            make_input(2, 1, "b", base.clone()),
            make_input(3, 2, "c", base.clone()),
            make_input(4, 3, "d", base.clone()),
        ];
        let candidates = engine.scan(&window).await.unwrap();
        assert_eq!(candidates.len(), 1, "expected 1 qualifying cluster");
        let c = &candidates[0];
        assert_eq!(c.member_ids.len(), 4);
        assert_eq!(c.session_ids.len(), 3);
    }

    // Enough occurrences, but all in the same conversation: the session
    // threshold (2) must reject the cluster.
    #[tokio::test]
    async fn scan_rejects_cluster_below_min_sessions() {
        let writer = Arc::new(RecordingWriter {
            written: Mutex::new(vec![]),
        });
        let config = PromotionConfig {
            min_occurrences: 3,
            min_sessions: 2,
            cluster_threshold: 0.90,
        };
        let engine = PromotionEngine::new(writer, config, PathBuf::from("/tmp"));
        let base = unit_vec(4, 1.0);
        let window = (1..=4)
            .map(|i| make_input(i, 1, "x", base.clone()))
            .collect::<Vec<_>>();
        let candidates = engine.scan(&window).await.unwrap();
        assert!(
            candidates.is_empty(),
            "should reject cluster with only 1 session"
        );
    }

    // Mixed embedding dimensions (3 then 2) must surface as an error.
    #[tokio::test]
    async fn scan_errors_on_dimension_mismatch() {
        let writer = Arc::new(RecordingWriter {
            written: Mutex::new(vec![]),
        });
        let engine =
            PromotionEngine::new(writer, PromotionConfig::default(), PathBuf::from("/tmp"));
        let window = vec![
            make_input(1, 1, "a", vec![1.0, 0.0, 0.0]),
            make_input(2, 2, "b", vec![0.0, 1.0]), ];
        let result = engine.scan(&window).await;
        assert!(result.is_err(), "expected error on dimension mismatch");
    }
}