#![allow(dead_code)]
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use parking_lot::Mutex;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use zeph_common::ToolName;
use zeph_tools::{ToolError, ToolOutput};
/// Identity of a speculative execution: which tool was run and with which arguments.
///
/// Two handles with equal keys are interchangeable from the cache's point of view;
/// `SpeculativeCache::insert` replaces (and cancels) an existing handle with the same key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct HandleKey {
/// Tool that was speculatively invoked.
pub tool_id: ToolName,
/// Digest of the invocation arguments (presumably produced by `hash_args`; verify at call sites).
pub args_hash: blake3::Hash,
}
/// A speculatively started tool invocation that can later be committed or cancelled.
pub struct SpeculativeHandle {
/// Cache key this handle is stored under.
pub key: HandleKey,
/// Spawned task producing the tool result; awaited by `commit`, aborted by `cancel`.
pub join: JoinHandle<Result<Option<ToolOutput>, ToolError>>,
/// Token signalled by `cancel` so the task can stop cooperatively.
pub cancel: CancellationToken,
/// Deadline after which `SpeculativeCache::sweep_expired_inner` discards the handle.
pub ttl_deadline: tokio::time::Instant,
/// Start time; `SpeculativeCache::insert` evicts the handle with the earliest value when full.
pub started_at: Instant,
}
impl std::fmt::Debug for SpeculativeHandle {
    // Only the identifying and timing fields are printed; the join handle and
    // cancellation token are skipped, which `finish_non_exhaustive` marks with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SpeculativeHandle");
        builder.field("key", &self.key);
        builder.field("ttl_deadline", &self.ttl_deadline);
        builder.field("started_at", &self.started_at);
        builder.finish_non_exhaustive()
    }
}
impl SpeculativeHandle {
    /// Stops the speculative task: signals its cancellation token first, then
    /// aborts the spawned task outright.
    pub fn cancel(self) {
        let token = &self.cancel;
        token.cancel();
        self.join.abort();
    }

    /// Waits for the speculative task and hands back whatever it produced.
    ///
    /// # Errors
    ///
    /// Returns the task's own `ToolError` if it failed. If the task was
    /// cancelled or panicked, the `JoinError` is mapped to
    /// `ToolError::Execution`.
    pub async fn commit(self) -> Result<Option<ToolOutput>, ToolError> {
        let join_error = match self.join.await {
            Ok(result) => return result,
            Err(error) => error,
        };
        let io_error = if join_error.is_cancelled() {
            std::io::Error::other("speculative task cancelled")
        } else {
            std::io::Error::other(join_error.to_string())
        };
        Err(ToolError::Execution(io_error))
    }
}
/// Lock-protected state of a `SpeculativeCache`.
pub struct CacheInner {
/// In-flight handles, each stored under a clone of its own `key` field.
pub handles: HashMap<HandleKey, SpeculativeHandle>,
}
/// Bounded, shareable collection of in-flight speculative tool executions.
pub struct SpeculativeCache {
/// Shared state; also exposed through `shared_inner` for use without `&self`.
pub(crate) inner: Arc<Mutex<CacheInner>>,
/// Maximum number of handles kept at once (clamped to 1..=16 in `new`, enforced by `insert`).
max: usize,
}
impl SpeculativeCache {
    /// Creates an empty cache that holds at most `max_in_flight` handles.
    ///
    /// The bound is clamped to `1..=16`: `0` behaves like `1`, and values above
    /// 16 behave like 16.
    #[must_use]
    pub fn new(max_in_flight: usize) -> Self {
        Self {
            inner: Arc::new(Mutex::new(CacheInner {
                handles: HashMap::new(),
            })),
            max: max_in_flight.clamp(1, 16),
        }
    }

    /// Returns another reference-counted pointer to the shared handle map.
    ///
    /// Intended for code that must touch the cache without borrowing `self`,
    /// e.g. together with [`Self::sweep_expired_inner`].
    #[must_use]
    pub fn shared_inner(&self) -> Arc<Mutex<CacheInner>> {
        Arc::clone(&self.inner)
    }

    /// Removes and cancels every handle whose `ttl_deadline` is at or before
    /// the current `tokio::time::Instant`.
    pub fn sweep_expired_inner(inner: &Mutex<CacheInner>) {
        let now = tokio::time::Instant::now();
        let mut g = inner.lock();
        let expired: Vec<HandleKey> = g
            .handles
            .iter()
            .filter(|(_, h)| h.ttl_deadline <= now)
            .map(|(k, _)| k.clone())
            .collect();
        Self::remove_and_cancel(&mut g.handles, expired);
    }

    /// Stores `handle`, keeping the cache within its capacity.
    ///
    /// If a handle with the same key already exists, it is replaced and
    /// cancelled. Otherwise, when the cache is full, the handle with the
    /// earliest `started_at` is evicted and cancelled to make room.
    pub fn insert(&self, handle: SpeculativeHandle) {
        let mut g = self.inner.lock();
        // A same-key insert replaces in place and does not grow the map, so
        // evicting an unrelated in-flight handle would be pure waste.
        let is_new_key = !g.handles.contains_key(&handle.key);
        if is_new_key && g.handles.len() >= self.max {
            let oldest_key = g
                .handles
                .iter()
                .min_by_key(|(_, h)| h.started_at)
                .map(|(k, _)| k.clone());
            if let Some(key) = oldest_key
                && let Some(evicted) = g.handles.remove(&key)
            {
                evicted.cancel();
            }
        }
        if let Some(displaced) = g.handles.insert(handle.key.clone(), handle) {
            displaced.cancel();
        }
    }

    /// Removes and returns the handle for exactly this tool and argument hash.
    ///
    /// Ownership passes to the caller, who must `commit` or `cancel` it.
    #[must_use]
    pub fn take_match(
        &self,
        tool_id: &ToolName,
        args_hash: &blake3::Hash,
    ) -> Option<SpeculativeHandle> {
        let key = HandleKey {
            tool_id: tool_id.clone(),
            args_hash: *args_hash,
        };
        self.inner.lock().handles.remove(&key)
    }

    /// Removes and cancels every handle speculatively started for `tool_id`,
    /// regardless of its argument hash.
    pub fn cancel_by_tool_id(&self, tool_id: &ToolName) {
        let mut g = self.inner.lock();
        // Collect all matches: several handles may share a tool with different
        // args, and `HashMap` order gives no meaningful "first" one.
        let matching: Vec<HandleKey> = g
            .handles
            .keys()
            .filter(|k| &k.tool_id == tool_id)
            .cloned()
            .collect();
        Self::remove_and_cancel(&mut g.handles, matching);
    }

    /// Convenience wrapper around [`Self::sweep_expired_inner`] for this cache.
    pub fn sweep_expired(&self) {
        Self::sweep_expired_inner(&self.inner);
    }

    /// Removes and cancels every handle in the cache.
    pub fn cancel_all(&self) {
        let mut g = self.inner.lock();
        for (_, h) in g.handles.drain() {
            h.cancel();
        }
    }

    /// Number of handles currently held.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.lock().handles.len()
    }

    /// Whether the cache holds no handles.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes each key in `keys` from `handles` and cancels the handle if present.
    fn remove_and_cancel(
        handles: &mut HashMap<HandleKey, SpeculativeHandle>,
        keys: impl IntoIterator<Item = HandleKey>,
    ) {
        for key in keys {
            if let Some(h) = handles.remove(&key) {
                h.cancel();
            }
        }
    }
}
/// Computes an argument digest that does not depend on top-level key order.
///
/// Top-level entries are sorted by key. Each entry is then fed to BLAKE3 as
/// `key \0 value-json \0`, where the value is rendered with `Value::to_string`.
///
/// NOTE(review): nested objects are serialized as-is. If serde_json's
/// `preserve_order` feature is enabled, their key order therefore affects the
/// digest. Keys are also hashed raw, so a key containing a NUL byte could
/// collide with a different argument layout.
#[must_use]
pub fn hash_args(args: &serde_json::Map<String, serde_json::Value>) -> blake3::Hash {
    let mut entries: Vec<(&String, &serde_json::Value)> = args.iter().collect();
    entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
    entries
        .into_iter()
        .fold(blake3::Hasher::new(), |mut hasher, (key, value)| {
            hasher.update(key.as_bytes());
            hasher.update(b"\x00");
            hasher.update(value.to_string().as_bytes());
            hasher.update(b"\x00");
            hasher
        })
        .finalize()
}
/// Renders `args` as a JSON object whose values are replaced by type placeholders.
///
/// Strings, numbers, booleans, arrays and objects become `"<string>"`,
/// `"<number>"`, `"<bool>"`, `"<array>"` and `"<object>"`; `null` stays `null`.
/// Keys are emitted in sorted order, so the same set of arguments always yields
/// the same template regardless of insertion order. This matches the
/// order-independence of `hash_args`.
#[must_use]
pub fn args_template(args: &serde_json::Map<String, serde_json::Value>) -> String {
    // Sort explicitly: with serde_json's `preserve_order` feature the map keeps
    // insertion order, which would make reordered args yield different templates.
    let mut entries: Vec<(&String, &serde_json::Value)> = args.iter().collect();
    entries.sort_unstable_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
    let template: serde_json::Map<String, serde_json::Value> = entries
        .into_iter()
        .map(|(k, v)| {
            let placeholder = match v {
                serde_json::Value::String(_) => serde_json::json!("<string>"),
                serde_json::Value::Number(_) => serde_json::json!("<number>"),
                serde_json::Value::Bool(_) => serde_json::json!("<bool>"),
                serde_json::Value::Array(_) => serde_json::json!("<array>"),
                serde_json::Value::Object(_) => serde_json::json!("<object>"),
                serde_json::Value::Null => serde_json::json!(null),
            };
            (k.clone(), placeholder)
        })
        .collect();
    serde_json::Value::Object(template).to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an argument map, inserting `pairs` in exactly the given order.
    fn args_of(pairs: &[(&str, serde_json::Value)]) -> serde_json::Map<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(key, value)| ((*key).to_owned(), value.clone()))
            .collect()
    }

    #[test]
    fn hash_args_order_independent() {
        let forward = args_of(&[("z", serde_json::json!(1)), ("a", serde_json::json!(2))]);
        let reversed = args_of(&[("a", serde_json::json!(2)), ("z", serde_json::json!(1))]);
        assert_eq!(hash_args(&forward), hash_args(&reversed));
    }

    #[test]
    fn hash_args_different_values() {
        let one = args_of(&[("x", serde_json::json!(1))]);
        let two = args_of(&[("x", serde_json::json!(2))]);
        assert_ne!(hash_args(&one), hash_args(&two));
    }

    #[test]
    fn args_template_replaces_values_with_type_placeholders() {
        let args = args_of(&[
            ("cmd", serde_json::json!("ls -la")),
            ("timeout", serde_json::json!(30)),
            ("flag", serde_json::json!(true)),
        ]);
        let template = args_template(&args);
        for placeholder in ["<string>", "<number>", "<bool>"] {
            assert!(
                template.contains(placeholder),
                "missing {placeholder} in {template}"
            );
        }
    }
}