#![allow(dead_code)]
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use thiserror::Error;
use tokio::sync::Mutex as AsyncMutex;
use tracing::{debug, warn};
use zeph_db::DbPool;
use super::prediction::{Prediction, PredictionSource};
use crate::agent::speculative::cache::{args_template, hash_args};
/// Standard-normal z-score for a one-sided 95% confidence level.
const Z: f64 = 1.645;
/// Whether an observed tool invocation succeeded or failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolOutcome {
    Success,
    Failure,
}
/// Errors produced by pattern-store operations.
#[derive(Debug, Error)]
pub enum PatternError {
    /// Underlying database failure.
    #[error("database error: {0}")]
    Db(#[from] zeph_db::sqlx::Error),
    /// Argument payload was not valid JSON.
    #[error("json error: {0}")]
    Json(#[from] serde_json::Error),
}
/// Per-context debounce bookkeeping for prediction refreshes.
struct RefreshState {
    // `None` until the first refresh has been attempted for this key.
    last_refresh: Option<std::time::Instant>,
}
/// Persistent store of observed tool-call transitions, used to precompute
/// next-tool predictions per skill and previous tool.
pub struct PatternStore {
    /// Database handle shared with the rest of the application.
    pool: DbPool,
    /// Half-life, in days, of the exponential decay applied to counts.
    half_life_days: f64,
    /// Debounce state keyed by `"{skill_hash}:{prev_tool}"`; limits how
    /// often predictions are rebuilt for the same context.
    refresh_debounce: Arc<AsyncMutex<std::collections::HashMap<String, RefreshState>>>,
    /// Minimum raw observation count before a transition may be predicted.
    min_observations: u32,
}
impl PatternStore {
    /// Creates a store backed by `pool`.
    ///
    /// `half_life_days` sets the exponential-decay half-life applied to
    /// observation counts: a transition's decayed count halves after that
    /// many days without being re-observed.
    #[must_use]
    pub fn new(pool: DbPool, half_life_days: f64) -> Self {
        Self {
            pool,
            half_life_days,
            refresh_debounce: Arc::new(AsyncMutex::new(std::collections::HashMap::new())),
            // Default threshold; override with `with_min_observations`.
            min_observations: 5,
        }
    }

    /// Sets the minimum raw observation count a transition needs before it
    /// is eligible to appear in the prediction table (see `refresh`).
    #[must_use]
    pub fn with_min_observations(mut self, n: u32) -> Self {
        self.min_observations = n;
        self
    }

    /// Records one observed `prev_tool -> next_tool` transition for a skill.
    ///
    /// Ages the matching row's decayed count by the time elapsed since it
    /// was last seen, increments the raw/success counters and the running
    /// average latency (inserting the row if absent), then triggers a
    /// debounced rebuild of the prediction table.
    ///
    /// `prev_tool` is `None` for the first tool after skill activation.
    ///
    /// # Errors
    ///
    /// Returns [`PatternError::Json`] if `args_json` is not valid JSON, or
    /// [`PatternError::Db`] on a database failure.
    #[allow(clippy::too_many_arguments)]
    pub async fn observe(
        &self,
        skill_name: &str,
        skill_hash: &str,
        prev_tool: Option<&str>,
        next_tool: &str,
        args_json: &str,
        outcome: ToolOutcome,
        latency_ms: u64,
    ) -> Result<(), PatternError> {
        let now = unix_now();
        let half_life_secs = self.half_life_days * 86_400.0;
        let success_delta = i64::from(outcome == ToolOutcome::Success);
        let args: serde_json::Value = serde_json::from_str(args_json)?;
        // Non-object payloads (arrays, scalars) degrade to an empty map so
        // they still produce a stable fingerprint and template.
        let args_obj = args.as_object().cloned().unwrap_or_default();
        let args_fingerprint = {
            let h = hash_args(&args_obj);
            h.to_hex().to_string()
        };
        let tmpl = args_template(&args_obj);
        #[allow(clippy::cast_possible_wrap)]
        let latency_i64 = latency_ms as i64;
        // `prev_tool` is bound twice so one statement matches both a
        // concrete previous tool and the NULL (activation) case.
        // NOTE(review): this select-then-write sequence is not wrapped in a
        // transaction, so two concurrent observers of the same transition
        // can lose one increment. Tolerable for statistics — confirm.
        let existing = zeph_db::query_as::<_, (f64, i64, i64, i64)>(
            r"
            SELECT count_decayed, last_seen_at, count_raw, avg_latency_ms
            FROM tool_pattern_transitions
            WHERE skill_name = ? AND skill_hash = ?
            AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
            AND next_tool = ? AND args_fingerprint = ?
            ",
        )
        .bind(skill_name)
        .bind(skill_hash)
        .bind(prev_tool)
        .bind(prev_tool)
        .bind(next_tool)
        .bind(&args_fingerprint)
        .fetch_optional(&self.pool)
        .await?;
        if let Some((old_decayed, last_seen_at, old_count_raw, old_avg_latency)) = existing {
            #[allow(clippy::cast_precision_loss)]
            let elapsed = (now - last_seen_at).max(0) as f64;
            // Age the stored decayed count, then add this observation.
            let new_decayed = old_decayed * 0.5f64.powf(elapsed / half_life_secs) + 1.0;
            let new_count_raw = old_count_raw + 1;
            // Integer running average; millisecond truncation is acceptable.
            let new_avg_latency = (old_avg_latency * old_count_raw + latency_i64) / new_count_raw;
            zeph_db::query(
                r"
                UPDATE tool_pattern_transitions SET
                count_decayed = ?,
                count_raw = ?,
                success_raw = success_raw + ?,
                last_seen_at = ?,
                avg_latency_ms = ?
                WHERE skill_name = ? AND skill_hash = ?
                AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
                AND next_tool = ? AND args_fingerprint = ?
                ",
            )
            .bind(new_decayed)
            .bind(new_count_raw)
            .bind(success_delta)
            .bind(now)
            .bind(new_avg_latency)
            .bind(skill_name)
            .bind(skill_hash)
            .bind(prev_tool)
            .bind(prev_tool)
            .bind(next_tool)
            .bind(&args_fingerprint)
            .execute(&self.pool)
            .await?;
        } else {
            // First sighting: raw count 1, decayed count starts at 1.0.
            zeph_db::query(
                r"
                INSERT INTO tool_pattern_transitions
                (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
                args_template, count_raw, success_raw, count_decayed, last_seen_at, avg_latency_ms)
                VALUES (?, ?, ?, ?, ?, ?, 1, ?, 1.0, ?, ?)
                ",
            )
            .bind(skill_name)
            .bind(skill_hash)
            .bind(prev_tool)
            .bind(next_tool)
            .bind(&args_fingerprint)
            .bind(&tmpl)
            .bind(success_delta)
            .bind(now)
            .bind(latency_i64)
            .execute(&self.pool)
            .await?;
        }
        self.debounced_refresh(skill_name, skill_hash, prev_tool)
            .await;
        Ok(())
    }

    /// Returns up to `k` precomputed next-tool predictions for a skill and
    /// previous tool, best first.
    ///
    /// Only rows whose Wilson lower bound is at least 0.5 are considered.
    /// Rows whose stored `args_template` fails to parse as a JSON object
    /// are silently skipped, so the emitted `rank` is the position within
    /// the returned list and may differ from the stored rank column.
    ///
    /// # Errors
    ///
    /// Returns [`PatternError::Db`] on a database failure.
    pub async fn predict(
        &self,
        skill_name: &str,
        skill_hash: &str,
        prev_tool: Option<&str>,
        k: u8,
    ) -> Result<Vec<Prediction>, PatternError> {
        let rows = zeph_db::query_as::<_, (String, String, f64, f64, i64)>(
            r"
            SELECT next_tool, args_template, score, wilson_lower_bound, rank
            FROM tool_pattern_predictions
            WHERE skill_name = ? AND skill_hash = ?
            AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
            AND wilson_lower_bound >= 0.5
            ORDER BY rank ASC
            LIMIT ?
            ",
        )
        .bind(skill_name)
        .bind(skill_hash)
        .bind(prev_tool)
        .bind(prev_tool)
        .bind(i64::from(k))
        .fetch_all(&self.pool)
        .await?;
        let predictions = rows
            .into_iter()
            .enumerate()
            .filter_map(|(i, (next_tool, args_template, score, _wilson, _rank))| {
                let args: serde_json::Map<String, serde_json::Value> =
                    serde_json::from_str(&args_template).ok()?;
                Some(Prediction {
                    tool_id: zeph_common::ToolName::new(next_tool),
                    args,
                    #[allow(clippy::cast_possible_truncation)]
                    confidence: score as f32,
                    source: PredictionSource::HistoryPattern {
                        skill: skill_name.to_owned(),
                        #[allow(clippy::cast_possible_truncation)]
                        rank: i as u8,
                    },
                })
            })
            .collect();
        Ok(predictions)
    }

    /// Rebuilds the prediction rows for one `(skill, prev_tool)` context
    /// from current transition statistics.
    ///
    /// Transitions with fewer than `min_observations` raw sightings are
    /// ignored. Each surviving transition is scored as its share of the
    /// total decayed count multiplied by its Wilson lower bound; the top
    /// ten are written back (delete + insert) in a single transaction.
    ///
    /// # Errors
    ///
    /// Returns [`PatternError::Db`] on a database failure.
    pub async fn refresh(
        &self,
        skill_name: &str,
        skill_hash: &str,
        prev_tool: Option<&str>,
    ) -> Result<(), PatternError> {
        let min_obs = self.min_observations;
        let half_life_secs = self.half_life_days * 86_400.0;
        let now = unix_now();
        let rows = zeph_db::query_as::<_, (String, String, String, i64, i64, f64, i64)>(
            r"
            SELECT next_tool, args_fingerprint, args_template,
            count_raw, success_raw, count_decayed, last_seen_at
            FROM tool_pattern_transitions
            WHERE skill_name = ? AND skill_hash = ?
            AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))
            AND count_raw >= ?
            ",
        )
        .bind(skill_name)
        .bind(skill_hash)
        .bind(prev_tool)
        .bind(prev_tool)
        .bind(i64::from(min_obs))
        .fetch_all(&self.pool)
        .await?;
        if rows.is_empty() {
            return Ok(());
        }
        let scored = score_rows(rows, now, half_life_secs);
        if scored.is_empty() {
            return Ok(());
        }
        // Replace the whole prediction set for this context atomically.
        let mut tx = zeph_db::begin(&self.pool).await?;
        zeph_db::query(
            "DELETE FROM tool_pattern_predictions \
             WHERE skill_name = ? AND skill_hash = ? \
             AND (prev_tool = ? OR (prev_tool IS NULL AND ? IS NULL))",
        )
        .bind(skill_name)
        .bind(skill_hash)
        .bind(prev_tool)
        .bind(prev_tool)
        .execute(&mut *tx)
        .await?;
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        for (rank, (next_tool, args_fp, tmpl, score, wilson)) in scored.iter().enumerate().take(10)
        {
            let rank_i64 = rank as i64;
            zeph_db::query(
                r"
                INSERT OR REPLACE INTO tool_pattern_predictions
                (skill_name, skill_hash, prev_tool, next_tool, args_fingerprint,
                args_template, score, wilson_lower_bound, rank)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                ",
            )
            .bind(skill_name)
            .bind(skill_hash)
            .bind(prev_tool)
            .bind(next_tool)
            .bind(args_fp)
            .bind(tmpl)
            .bind(score)
            .bind(wilson)
            .bind(rank_i64)
            .execute(&mut *tx)
            .await?;
        }
        tx.commit().await?;
        debug!(
            skill = skill_name,
            prev_tool = prev_tool.unwrap_or("<activation>"),
            "PASTE: refreshed {} predictions",
            scored.len().min(10)
        );
        Ok(())
    }

    /// Deletes transition rows not seen for 30 days, returning how many
    /// rows were removed.
    ///
    /// # Errors
    ///
    /// Returns [`PatternError::Db`] on a database failure.
    pub async fn vacuum(&self) -> Result<u64, PatternError> {
        // Fixed 30-day retention window, in Unix seconds.
        let cutoff = unix_now() - 30 * 86_400;
        let result = zeph_db::query("DELETE FROM tool_pattern_transitions WHERE last_seen_at < ?")
            .bind(cutoff)
            .execute(&self.pool)
            .await?;
        let rows = result.rows_affected();
        if rows > 0 {
            debug!("PASTE vacuum: removed {} stale rows", rows);
        }
        Ok(rows)
    }

    /// Runs `refresh` for this context at most once per minute.
    ///
    /// The debounce slot is claimed while the map lock is held, so
    /// concurrent callers inside the window cannot start duplicate
    /// refreshes. A failed refresh is logged and still consumes the slot,
    /// matching the previous behavior.
    async fn debounced_refresh(&self, skill_name: &str, skill_hash: &str, prev_tool: Option<&str>) {
        // NOTE(review): the key omits `skill_name` (skill_hash is assumed
        // unique per skill) and maps a `prev_tool` of `None` and `Some("")`
        // to the same slot — confirm tool names are never empty.
        let key = format!("{skill_hash}:{}", prev_tool.unwrap_or(""));
        let should_refresh = {
            let mut map = self.refresh_debounce.lock().await;
            let state = map
                .entry(key)
                .or_insert(RefreshState { last_refresh: None });
            let due = match state.last_refresh {
                None => true,
                // `Duration::from_mins` is unstable; spell out 60 seconds.
                Some(t) => t.elapsed() >= Duration::from_secs(60),
            };
            if due {
                // Claim the slot before running so a concurrent caller
                // cannot also pass the check (old code updated the map in a
                // second lock scope, leaving a race window).
                state.last_refresh = Some(std::time::Instant::now());
            }
            due
        };
        if should_refresh {
            if let Err(e) = self.refresh(skill_name, skill_hash, prev_tool).await {
                warn!("PASTE refresh failed: {e}");
            }
        }
    }
}
/// Scores transition rows for one `(skill, prev_tool)` context.
///
/// Input tuples are `(next_tool, args_fingerprint, args_template,
/// count_raw, success_raw, count_decayed, last_seen_at)`. Each row's
/// decayed count is aged to `now`, normalized against the context total,
/// and multiplied by the Wilson lower bound of its success rate. Output
/// tuples are `(next_tool, args_fingerprint, args_template, score,
/// wilson)`, sorted best-first; empty if the total decayed weight is zero.
fn score_rows(
    rows: Vec<(String, String, String, i64, i64, f64, i64)>,
    now: i64,
    half_life_secs: f64,
) -> Vec<(String, String, String, f64, f64)> {
    // First pass: apply exponential decay and compute each row's Wilson
    // lower bound, accumulating the total decayed weight as we go.
    let mut weighted = Vec::with_capacity(rows.len());
    let mut total_weight = 0.0_f64;
    for (tool, fp, tmpl, count_raw, success_raw, count_decayed, last_seen_at) in rows {
        #[allow(clippy::cast_precision_loss)]
        let age_secs = now.saturating_sub(last_seen_at) as f64;
        let weight = count_decayed * 0.5f64.powf(age_secs / half_life_secs);
        #[allow(clippy::cast_sign_loss)]
        let wilson = wilson_lower_bound(success_raw as u64, count_raw as u64);
        total_weight += weight;
        weighted.push((tool, fp, tmpl, weight, wilson));
    }
    if total_weight <= 0.0 {
        return vec![];
    }
    // Second pass: score = share of total decayed weight, discounted by
    // the reliability estimate.
    let mut out: Vec<(String, String, String, f64, f64)> = weighted
        .into_iter()
        .map(|(tool, fp, tmpl, weight, wilson)| {
            (tool, fp, tmpl, (weight / total_weight) * wilson, wilson)
        })
        .collect();
    // Stable descending sort on score; incomparable floats compare equal.
    out.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
    out
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reads earlier than the epoch.
#[allow(clippy::cast_possible_wrap)]
fn unix_now() -> i64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    since_epoch.as_secs() as i64
}
/// Lower bound of the Wilson score interval for `successes` out of `n`
/// trials at z = 1.645 (one-sided 95% confidence).
///
/// Returns 0.0 for zero trials; the result is clamped to be non-negative.
#[allow(clippy::cast_precision_loss)]
fn wilson_lower_bound(successes: u64, n: u64) -> f64 {
    if n == 0 {
        return 0.0;
    }
    // Same value as the file-level `Z`; kept local so the formula is
    // self-contained.
    const Z: f64 = 1.645;
    let trials = n as f64;
    let p_hat = successes as f64 / trials;
    let z_sq = Z * Z;
    let center = p_hat + z_sq / (2.0 * trials);
    let spread = Z * (p_hat * (1.0 - p_hat) / trials + z_sq / (4.0 * trials * trials)).sqrt();
    let lower = (center - spread) / (1.0 + z_sq / trials);
    lower.max(0.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wilson_zero_observations() {
        // Zero trials must yield a zero lower bound, not NaN.
        let lb = wilson_lower_bound(0, 0);
        assert!(lb.abs() < f64::EPSILON);
    }

    #[test]
    fn wilson_all_success_small_n() {
        // A perfect record on few trials stays strictly below certainty.
        let lb = wilson_lower_bound(3, 3);
        assert!(lb > 0.0, "got {lb}");
        assert!(lb < 1.0, "got {lb}");
    }

    #[test]
    fn wilson_zero_success() {
        // No successes in ten trials must give a near-zero bound.
        let lb = wilson_lower_bound(0, 10);
        assert!(lb < 0.1, "got {lb}");
    }

    #[test]
    fn fingerprint_deterministic_different_order() {
        // Key order in the JSON text must not affect the fingerprint.
        fn fp(json: &str) -> String {
            let value: serde_json::Value = serde_json::from_str(json).unwrap();
            let map = value.as_object().cloned().unwrap_or_default();
            hash_args(&map).to_hex().to_string()
        }
        assert_eq!(fp(r#"{"z": 1, "a": 2}"#), fp(r#"{"a": 2, "z": 1}"#));
    }
}