use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::debug;
/// One predicted tool invocation, as produced by `ToolPredictor::predict`.
#[derive(Debug, Clone)]
pub struct ToolPrediction {
/// Name of the tool that is predicted to be called.
pub tool_name: String,
/// Predicted call parameters; `ToolPredictor::predict` always fills in an empty JSON object.
pub predicted_params: serde_json::Value,
/// Score used for threshold filtering and for descending ordering of predictions.
pub confidence: f64,
}
/// Hashable map key identifying a tool call by tool name and serialized parameters.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct SpeculationKey {
/// Name of the tool.
pub tool_name: String,
/// Parameters rendered to a string (`SpeculationKey::new` uses `Value::to_string`).
pub params_json: String,
}
impl SpeculationKey {
    /// Builds a key from a tool name and its JSON parameters.
    ///
    /// The parameters are kept in their serialized string form, which lets the
    /// key derive `Hash` and `Eq`; two keys are equal exactly when both the tool
    /// name and the serialized text match.
    #[must_use]
    pub fn new(tool_name: &str, params: &serde_json::Value) -> Self {
        let tool_name = String::from(tool_name);
        let params_json = params.to_string();
        Self {
            tool_name,
            params_json,
        }
    }
}
/// Value stored in `SpeculationCache` for one tool-name/parameter key.
#[derive(Debug, Clone)]
pub struct SpeculativeResult {
/// Output string associated with the cached call.
pub output: String,
/// Optional JSON metadata attached to the result.
pub metadata: Option<serde_json::Value>,
/// Instant supplied by whoever constructs the result.
// NOTE(review): presumably meant for age/expiry checks — nothing in this file reads it; confirm.
pub created_at: std::time::Instant,
}
/// Cache of speculative results combined with a counter that caps concurrent speculations.
#[derive(Debug)]
pub struct SpeculationCache {
/// Cached results keyed by `SpeculationKey`, guarded by an async (tokio) mutex.
cache: Arc<Mutex<HashMap<SpeculationKey, SpeculativeResult>>>,
/// Maximum number of speculation slots that may be held at the same time.
max_concurrent: usize,
/// Number of slots currently held; the `Arc` is cloned into every `SpeculationSlotGuard`.
active_count: Arc<std::sync::atomic::AtomicUsize>,
}
/// Handle for one reserved speculation slot.
///
/// The slot is given back when the guard is dropped, including when the
/// owning task is cancelled or unwinds.
pub struct SpeculationSlotGuard {
    /// Counter of held slots, shared with the cache that issued this guard.
    active_count: Arc<std::sync::atomic::AtomicUsize>,
}

impl Drop for SpeculationSlotGuard {
    /// Releases the slot held by this guard.
    ///
    /// The decrement is checked: if the counter is already zero (for example
    /// because the slot was also released manually) it is left at zero instead
    /// of wrapping around to `usize::MAX`, which would block every later
    /// reservation. Debug builds still flag that situation with an assertion.
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        let released = self
            .active_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                count.checked_sub(1)
            });
        debug_assert!(
            released.is_ok(),
            "SpeculationSlotGuard dropped with count already 0"
        );
    }
}
impl SpeculationCache {
    /// Creates an empty cache that allows at most `max_concurrent` speculation
    /// slots to be held at once.
    #[must_use]
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            max_concurrent,
            active_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        }
    }

    /// Returns a clone of the cached result for `tool_name` + `params`, if any.
    pub async fn get(
        &self,
        tool_name: &str,
        params: &serde_json::Value,
    ) -> Option<SpeculativeResult> {
        let key = SpeculationKey::new(tool_name, params);
        let cache = self.cache.lock().await;
        cache.get(&key).cloned()
    }

    /// Stores `result` under `tool_name` + `params`, replacing any previous
    /// entry for the same key.
    pub async fn insert(
        &self,
        tool_name: &str,
        params: &serde_json::Value,
        result: SpeculativeResult,
    ) {
        let key = SpeculationKey::new(tool_name, params);
        let mut cache = self.cache.lock().await;
        cache.insert(key, result);
    }

    /// Removes every cached result; emits a debug log when something was removed.
    pub async fn clear(&self) {
        let mut cache = self.cache.lock().await;
        let count = cache.len();
        cache.clear();
        if count > 0 {
            debug!(cleared = count, "speculation cache cleared");
        }
    }

    /// Number of cached results.
    pub async fn size(&self) -> usize {
        self.cache.lock().await.len()
    }

    /// Reports whether a slot appeared free at the moment of the call.
    ///
    /// This is only a snapshot: it reserves nothing, so a later
    /// `start_speculation` / `reserve_slot` may still be refused.
    pub fn can_speculate(&self) -> bool {
        self.active_count.load(std::sync::atomic::Ordering::Acquire) < self.max_concurrent
    }

    /// Claims one slot if fewer than `max_concurrent` are held.
    ///
    /// A single compare-and-swap update is used so the counter never exceeds
    /// the limit, not even transiently. An increment-then-roll-back scheme can
    /// make a concurrent caller observe an inflated count and be refused while
    /// a slot is actually free.
    fn try_acquire_slot(&self) -> bool {
        use std::sync::atomic::Ordering;

        let limit = self.max_concurrent;
        self.active_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |active| {
                // Lazy closure: `active + 1` is only evaluated when `active < limit`.
                (active < limit).then(|| active + 1)
            })
            .is_ok()
    }

    /// Tries to claim a slot; returns `true` on success.
    ///
    /// A successful call must be balanced by one call to `end_speculation`.
    pub fn start_speculation(&self) -> bool {
        self.try_acquire_slot()
    }

    /// Tries to claim a slot and wraps it in a guard that releases it on drop.
    ///
    /// Returns `None` when all `max_concurrent` slots are already held.
    pub fn reserve_slot(&self) -> Option<SpeculationSlotGuard> {
        self.try_acquire_slot().then(|| SpeculationSlotGuard {
            active_count: Arc::clone(&self.active_count),
        })
    }

    /// Releases one slot claimed with `start_speculation`.
    ///
    /// When no slot is held the counter is left at zero and a debug message is
    /// logged instead of underflowing.
    pub fn end_speculation(&self) {
        use std::sync::atomic::Ordering;

        let released = self
            .active_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |active| {
                active.checked_sub(1)
            });
        if released.is_err() {
            tracing::debug!("end_speculation called with no active slots — no-op");
        }
    }

    /// Current number of held slots.
    pub fn active_count(&self) -> usize {
        self.active_count.load(std::sync::atomic::Ordering::Acquire)
    }
}
/// Heuristic predictor of the next tool call based on recent tool usage.
pub struct ToolPredictor {
/// Predictions whose confidence is below this threshold are discarded by `predict`.
min_confidence: f64,
}
impl ToolPredictor {
    /// Creates a predictor that drops predictions scoring below `min_confidence`.
    #[must_use]
    pub fn new(min_confidence: f64) -> Self {
        Self { min_confidence }
    }

    /// Predicts likely next tool calls from the most recent tool usage.
    ///
    /// Two heuristics are combined:
    /// 1. the follow-up table for the most recently used tool, and
    /// 2. a "streak" rule: when the last tool was used at least twice in a row
    ///    it is predicted again with confidence `0.6 + min(0.05 * streak, 0.2)`.
    ///
    /// Every prediction must reach `min_confidence` and name a tool contained in
    /// `available_tools`; this also applies to the streak rule, so a tool that is
    /// no longer available is never predicted. If the follow-up table already
    /// produced an entry for the repeated tool, that entry is kept and the streak
    /// entry is not added. Predicted parameters are always an empty JSON object.
    ///
    /// Returns the predictions sorted by descending confidence; the result is
    /// empty when either input slice is empty.
    pub fn predict(
        &self,
        recent_tools: &[String],
        available_tools: &[String],
    ) -> Vec<ToolPrediction> {
        if available_tools.is_empty() {
            return Vec::new();
        }
        let last_tool = match recent_tools.last() {
            Some(tool) => tool,
            None => return Vec::new(),
        };

        let mut predictions: Vec<ToolPrediction> = common_follow_ups(last_tool)
            .into_iter()
            .filter(|(tool, confidence)| {
                *confidence >= self.min_confidence && available_tools.contains(tool)
            })
            .map(|(tool_name, confidence)| ToolPrediction {
                tool_name,
                predicted_params: serde_json::Value::Object(serde_json::Map::new()),
                confidence,
            })
            .collect();

        // Length of the trailing run of identical tool names.
        let repeat_count = recent_tools
            .iter()
            .rev()
            .take_while(|tool| *tool == last_tool)
            .count();
        if repeat_count >= 2 {
            let confidence = 0.6 + (repeat_count as f64 * 0.05).min(0.2);
            if confidence >= self.min_confidence
                // The streak rule must respect tool availability just like follow-ups do.
                && available_tools.contains(last_tool)
                && !predictions.iter().any(|p| p.tool_name == *last_tool)
            {
                predictions.push(ToolPrediction {
                    tool_name: last_tool.clone(),
                    predicted_params: serde_json::Value::Object(serde_json::Map::new()),
                    confidence,
                });
            }
        }

        // Highest confidence first; incomparable values (NaN) are treated as equal.
        predictions.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        predictions
    }
}
/// Looks up the tools that commonly follow `tool_name`, each paired with a
/// confidence score.
///
/// Tool names without a table entry yield an empty list.
fn common_follow_ups(tool_name: &str) -> Vec<(String, f64)> {
    let table: &[(&str, f64)] = match tool_name {
        "file_read" => &[("file_read", 0.7), ("memory_search", 0.4)],
        "memory_search" => &[("memory_search", 0.5), ("file_read", 0.3)],
        "http_get" => &[("http_get", 0.6)],
        "list_directory" => &[("file_read", 0.7), ("list_directory", 0.4)],
        _ => &[],
    };
    table
        .iter()
        .map(|&(name, score)| (name.to_owned(), score))
        .collect()
}
pub fn is_safe_for_speculation(risk: &roboticus_core::RiskLevel) -> bool {
matches!(risk, roboticus_core::RiskLevel::Safe)
}
// Unit tests for the speculation cache, slot accounting and tool predictor.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn speculation_key_hashing() {
        let key1 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/a.txt"}));
        let key2 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/a.txt"}));
        let key3 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/b.txt"}));
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[tokio::test]
    async fn cache_insert_and_get() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"path": "/tmp/test.txt"});
        cache
            .insert(
                "file_read",
                &params,
                SpeculativeResult {
                    output: "file contents".to_string(),
                    metadata: None,
                    created_at: std::time::Instant::now(),
                },
            )
            .await;
        let result = cache.get("file_read", &params).await;
        assert!(result.is_some());
        assert_eq!(result.unwrap().output, "file contents");
    }

    #[tokio::test]
    async fn cache_miss() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"path": "/tmp/missing.txt"});
        let result = cache.get("file_read", &params).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_clear() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"key": "value"});
        cache
            .insert(
                "tool1",
                &params,
                SpeculativeResult {
                    output: "result".to_string(),
                    metadata: None,
                    created_at: std::time::Instant::now(),
                },
            )
            .await;
        assert_eq!(cache.size().await, 1);
        cache.clear().await;
        assert_eq!(cache.size().await, 0);
    }

    #[test]
    fn concurrency_limit() {
        let cache = SpeculationCache::new(2);
        assert!(cache.can_speculate());
        assert!(cache.start_speculation());
        assert!(cache.start_speculation());
        assert!(!cache.start_speculation());
        assert_eq!(cache.active_count(), 2);
        cache.end_speculation();
        assert!(cache.can_speculate());
        assert_eq!(cache.active_count(), 1);
    }

    #[test]
    fn predictor_no_history() {
        let predictor = ToolPredictor::new(0.3);
        let predictions = predictor.predict(&[], &["file_read".to_string()]);
        assert!(predictions.is_empty());
    }

    #[test]
    fn predictor_known_sequence() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec!["list_directory".to_string()];
        let available = vec!["file_read".to_string(), "list_directory".to_string()];
        let predictions = predictor.predict(&recent, &available);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].tool_name, "file_read");
        assert!(predictions[0].confidence >= 0.7);
    }

    #[test]
    fn predictor_repeated_tool() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec![
            "file_read".to_string(),
            "file_read".to_string(),
            "file_read".to_string(),
        ];
        let available = vec!["file_read".to_string(), "memory_search".to_string()];
        let predictions = predictor.predict(&recent, &available);
        assert!(predictions.iter().any(|p| p.tool_name == "file_read"));
    }

    #[test]
    fn predictor_confidence_filter() {
        let predictor = ToolPredictor::new(0.9);
        let recent = vec!["memory_search".to_string()];
        let available = vec!["memory_search".to_string(), "file_read".to_string()];
        let predictions = predictor.predict(&recent, &available);
        assert!(predictions.is_empty() || predictions.iter().all(|p| p.confidence >= 0.9));
    }

    #[test]
    fn predictor_unavailable_tool_filtered() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec!["list_directory".to_string()];
        let available = vec!["memory_search".to_string()];
        let predictions = predictor.predict(&recent, &available);
        assert!(!predictions.iter().any(|p| p.tool_name == "file_read"));
    }

    #[test]
    fn safe_for_speculation() {
        assert!(is_safe_for_speculation(&roboticus_core::RiskLevel::Safe));
        assert!(!is_safe_for_speculation(
            &roboticus_core::RiskLevel::Caution
        ));
        assert!(!is_safe_for_speculation(
            &roboticus_core::RiskLevel::Dangerous
        ));
        assert!(!is_safe_for_speculation(
            &roboticus_core::RiskLevel::Forbidden
        ));
    }

    #[test]
    fn speculation_policy_gate_never_allows_approval_or_forbidden_risks() {
        let risky = [
            roboticus_core::RiskLevel::Caution,
            roboticus_core::RiskLevel::Dangerous,
            roboticus_core::RiskLevel::Forbidden,
        ];
        for risk in risky {
            assert!(
                !is_safe_for_speculation(&risk),
                "speculative execution must remain Safe-only; got {risk:?}"
            );
        }
    }

    #[test]
    fn predictions_sorted_by_confidence() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec!["list_directory".to_string()];
        let available = vec!["file_read".to_string(), "list_directory".to_string()];
        let predictions = predictor.predict(&recent, &available);
        for i in 1..predictions.len() {
            assert!(predictions[i - 1].confidence >= predictions[i].confidence);
        }
    }

    #[test]
    fn common_follow_ups_http_get() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec!["http_get".to_string()];
        let available = vec!["http_get".to_string()];
        let predictions = predictor.predict(&recent, &available);
        assert!(
            predictions.iter().any(|p| p.tool_name == "http_get"),
            "http_get should predict a follow-up http_get"
        );
    }

    #[test]
    fn common_follow_ups_unknown_tool_returns_empty() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec!["unknown_exotic_tool".to_string()];
        let available = vec!["unknown_exotic_tool".to_string(), "file_read".to_string()];
        let predictions = predictor.predict(&recent, &available);
        assert!(
            predictions.is_empty(),
            "unknown tool with single call should produce no predictions"
        );
    }

    #[test]
    fn predict_empty_available_tools() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec!["file_read".to_string()];
        let predictions = predictor.predict(&recent, &[]);
        assert!(
            predictions.is_empty(),
            "no available tools means no predictions"
        );
    }

    #[test]
    fn predict_empty_recent_tools() {
        let predictor = ToolPredictor::new(0.3);
        let available = vec!["file_read".to_string()];
        let predictions = predictor.predict(&[], &available);
        assert!(
            predictions.is_empty(),
            "no recent tools means no predictions"
        );
    }

    #[test]
    fn start_speculation_exhaustion_and_recovery() {
        let cache = SpeculationCache::new(1);
        assert!(cache.start_speculation(), "first slot should succeed");
        assert!(!cache.start_speculation(), "second slot should fail");
        assert_eq!(
            cache.active_count(),
            1,
            "count should remain 1 after failed attempt"
        );
        cache.end_speculation();
        assert_eq!(cache.active_count(), 0);
        assert!(cache.start_speculation(), "slot should be available again");
    }

    #[test]
    fn reserve_slot_guard_releases_on_drop() {
        let cache = SpeculationCache::new(1);
        let guard = cache.reserve_slot().expect("first reserve should succeed");
        assert_eq!(cache.active_count(), 1);
        drop(guard);
        assert_eq!(
            cache.active_count(),
            0,
            "dropping guard must release speculation slot"
        );
    }

    #[tokio::test]
    async fn reserve_slot_guard_releases_on_task_abort() {
        let cache = Arc::new(SpeculationCache::new(1));
        let cache_for_task = Arc::clone(&cache);
        let task = tokio::spawn(async move {
            let _guard = cache_for_task
                .reserve_slot()
                .expect("slot should be available");
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
        });
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert_eq!(cache.active_count(), 1);
        task.abort();
        let _ = task.await;
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        assert_eq!(
            cache.active_count(),
            0,
            "aborted task must not leak active speculation slots"
        );
    }

    #[test]
    fn memory_search_follow_ups() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec!["memory_search".to_string()];
        let available = vec!["memory_search".to_string(), "file_read".to_string()];
        let predictions = predictor.predict(&recent, &available);
        assert!(
            predictions.iter().any(|p| p.tool_name == "memory_search"),
            "memory_search should predict memory_search follow-up"
        );
    }

    #[test]
    fn repeated_tool_no_duplicate_with_follow_up() {
        let predictor = ToolPredictor::new(0.3);
        let recent = vec![
            "file_read".to_string(),
            "file_read".to_string(),
            "file_read".to_string(),
        ];
        let available = vec!["file_read".to_string(), "memory_search".to_string()];
        let predictions = predictor.predict(&recent, &available);
        let file_read_count = predictions
            .iter()
            .filter(|p| p.tool_name == "file_read")
            .count();
        assert_eq!(
            file_read_count, 1,
            "file_read should appear exactly once (no duplicate from repeat heuristic)"
        );
    }

    #[tokio::test]
    async fn cache_different_tools_same_params() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"path": "/tmp/test.txt"});
        cache
            .insert(
                "file_read",
                &params,
                SpeculativeResult {
                    output: "read result".to_string(),
                    metadata: None,
                    created_at: std::time::Instant::now(),
                },
            )
            .await;
        cache
            .insert(
                "file_write",
                &params,
                SpeculativeResult {
                    output: "write result".to_string(),
                    metadata: None,
                    created_at: std::time::Instant::now(),
                },
            )
            .await;
        assert_eq!(cache.size().await, 2);
        let read_result = cache.get("file_read", &params).await.unwrap();
        assert_eq!(read_result.output, "read result");
        let write_result = cache.get("file_write", &params).await.unwrap();
        assert_eq!(write_result.output, "write result");
    }

    #[test]
    fn speculation_key_different_tool_names() {
        let params = serde_json::json!({"key": "value"});
        let key1 = SpeculationKey::new("tool_a", &params);
        let key2 = SpeculationKey::new("tool_b", &params);
        assert_ne!(
            key1, key2,
            "different tool names should produce different keys"
        );
    }

    #[test]
    fn speculative_result_metadata() {
        let result = SpeculativeResult {
            output: "data".to_string(),
            metadata: Some(serde_json::json!({"source": "cache"})),
            created_at: std::time::Instant::now(),
        };
        assert_eq!(result.metadata.unwrap()["source"], "cache");
    }
}