1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::Mutex;
4use tracing::debug;
5
6#[derive(Debug, Clone)]
8pub struct ToolPrediction {
9 pub tool_name: String,
10 pub predicted_params: serde_json::Value,
11 pub confidence: f64,
12}
13
14#[derive(Debug, Clone, Hash, PartialEq, Eq)]
18pub struct SpeculationKey {
19 pub tool_name: String,
20 pub params_json: String,
21}
22
23impl SpeculationKey {
24 #[must_use]
25 pub fn new(tool_name: &str, params: &serde_json::Value) -> Self {
26 Self {
27 tool_name: tool_name.to_string(),
28 params_json: params.to_string(),
29 }
30 }
31}
32
33#[derive(Debug, Clone)]
35pub struct SpeculativeResult {
36 pub output: String,
37 pub metadata: Option<serde_json::Value>,
38 pub created_at: std::time::Instant,
39}
40
41#[derive(Debug)]
44pub struct SpeculationCache {
45 cache: Arc<Mutex<HashMap<SpeculationKey, SpeculativeResult>>>,
46 max_concurrent: usize,
47 active_count: Arc<std::sync::atomic::AtomicUsize>,
48}
49
50pub struct SpeculationSlotGuard {
51 active_count: Arc<std::sync::atomic::AtomicUsize>,
52}
53
54impl Drop for SpeculationSlotGuard {
55 fn drop(&mut self) {
56 let prev = self
59 .active_count
60 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
61 debug_assert!(
62 prev > 0,
63 "SpeculationSlotGuard dropped with count already 0"
64 );
65 }
66}
67
68impl SpeculationCache {
69 #[must_use]
70 pub fn new(max_concurrent: usize) -> Self {
71 Self {
72 cache: Arc::new(Mutex::new(HashMap::new())),
73 max_concurrent,
74 active_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
75 }
76 }
77
78 pub async fn get(
80 &self,
81 tool_name: &str,
82 params: &serde_json::Value,
83 ) -> Option<SpeculativeResult> {
84 let key = SpeculationKey::new(tool_name, params);
85 let cache = self.cache.lock().await;
86 cache.get(&key).cloned()
87 }
88
89 pub async fn insert(
91 &self,
92 tool_name: &str,
93 params: &serde_json::Value,
94 result: SpeculativeResult,
95 ) {
96 let key = SpeculationKey::new(tool_name, params);
97 let mut cache = self.cache.lock().await;
98 cache.insert(key, result);
99 }
100
101 pub async fn clear(&self) {
103 let mut cache = self.cache.lock().await;
104 let count = cache.len();
105 cache.clear();
106 if count > 0 {
107 debug!(cleared = count, "speculation cache cleared");
108 }
109 }
110
111 pub async fn size(&self) -> usize {
113 self.cache.lock().await.len()
114 }
115
116 pub fn can_speculate(&self) -> bool {
118 self.active_count.load(std::sync::atomic::Ordering::Acquire) < self.max_concurrent
119 }
120
121 pub fn start_speculation(&self) -> bool {
123 let prev = self
124 .active_count
125 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
126 if prev >= self.max_concurrent {
127 self.active_count
128 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
129 false
130 } else {
131 true
132 }
133 }
134
135 pub fn reserve_slot(&self) -> Option<SpeculationSlotGuard> {
139 let prev = self
140 .active_count
141 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
142 if prev >= self.max_concurrent {
143 self.active_count
144 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
145 None
146 } else {
147 Some(SpeculationSlotGuard {
148 active_count: Arc::clone(&self.active_count),
149 })
150 }
151 }
152
153 pub fn end_speculation(&self) {
158 let mut current = self.active_count.load(std::sync::atomic::Ordering::Acquire);
160 loop {
161 if current == 0 {
162 tracing::debug!("end_speculation called with no active slots — no-op");
163 return;
164 }
165 match self.active_count.compare_exchange_weak(
166 current,
167 current - 1,
168 std::sync::atomic::Ordering::AcqRel,
169 std::sync::atomic::Ordering::Acquire,
170 ) {
171 Ok(_) => return,
172 Err(updated) => current = updated,
173 }
174 }
175 }
176
177 pub fn active_count(&self) -> usize {
178 self.active_count.load(std::sync::atomic::Ordering::Acquire)
179 }
180}
181
182pub struct ToolPredictor {
184 min_confidence: f64,
185}
186
187impl ToolPredictor {
188 #[must_use]
189 pub fn new(min_confidence: f64) -> Self {
190 Self { min_confidence }
191 }
192
193 pub fn predict(
196 &self,
197 recent_tools: &[String],
198 available_tools: &[String],
199 ) -> Vec<ToolPrediction> {
200 let mut predictions = Vec::new();
201
202 if recent_tools.is_empty() || available_tools.is_empty() {
203 return predictions;
204 }
205
206 let last_tool = &recent_tools[recent_tools.len() - 1];
207
208 let follow_ups = common_follow_ups(last_tool);
209 for (follow_tool, confidence) in follow_ups {
210 if confidence >= self.min_confidence && available_tools.contains(&follow_tool) {
211 predictions.push(ToolPrediction {
212 tool_name: follow_tool,
213 predicted_params: serde_json::Value::Object(serde_json::Map::new()),
214 confidence,
215 });
216 }
217 }
218
219 let repeat_count = recent_tools
221 .iter()
222 .rev()
223 .take_while(|t| *t == last_tool)
224 .count();
225 if repeat_count >= 2 {
226 let confidence = 0.6 + (repeat_count as f64 * 0.05).min(0.2);
227 if confidence >= self.min_confidence
228 && !predictions.iter().any(|p| p.tool_name == *last_tool)
229 {
230 predictions.push(ToolPrediction {
231 tool_name: last_tool.clone(),
232 predicted_params: serde_json::Value::Object(serde_json::Map::new()),
233 confidence,
234 });
235 }
236 }
237
238 predictions.sort_by(|a, b| {
239 b.confidence
240 .partial_cmp(&a.confidence)
241 .unwrap_or(std::cmp::Ordering::Equal)
242 });
243 predictions
244 }
245}
246
247fn common_follow_ups(tool_name: &str) -> Vec<(String, f64)> {
249 match tool_name {
250 "file_read" => vec![
251 ("file_read".to_string(), 0.7),
252 ("memory_search".to_string(), 0.4),
253 ],
254 "memory_search" => vec![
255 ("memory_search".to_string(), 0.5),
256 ("file_read".to_string(), 0.3),
257 ],
258 "http_get" => vec![("http_get".to_string(), 0.6)],
259 "list_directory" => vec![
260 ("file_read".to_string(), 0.7),
261 ("list_directory".to_string(), 0.4),
262 ],
263 _ => Vec::new(),
264 }
265}
266
267pub fn is_safe_for_speculation(risk: &roboticus_core::RiskLevel) -> bool {
270 matches!(risk, roboticus_core::RiskLevel::Safe)
271}
272
273#[cfg(test)]
274mod tests {
275 use super::*;
276
277 #[test]
278 fn speculation_key_hashing() {
279 let key1 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/a.txt"}));
280 let key2 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/a.txt"}));
281 let key3 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/b.txt"}));
282
283 assert_eq!(key1, key2);
284 assert_ne!(key1, key3);
285 }
286
287 #[tokio::test]
288 async fn cache_insert_and_get() {
289 let cache = SpeculationCache::new(4);
290 let params = serde_json::json!({"path": "/tmp/test.txt"});
291
292 cache
293 .insert(
294 "file_read",
295 ¶ms,
296 SpeculativeResult {
297 output: "file contents".to_string(),
298 metadata: None,
299 created_at: std::time::Instant::now(),
300 },
301 )
302 .await;
303
304 let result = cache.get("file_read", ¶ms).await;
305 assert!(result.is_some());
306 assert_eq!(result.unwrap().output, "file contents");
307 }
308
309 #[tokio::test]
310 async fn cache_miss() {
311 let cache = SpeculationCache::new(4);
312 let params = serde_json::json!({"path": "/tmp/missing.txt"});
313 let result = cache.get("file_read", ¶ms).await;
314 assert!(result.is_none());
315 }
316
317 #[tokio::test]
318 async fn cache_clear() {
319 let cache = SpeculationCache::new(4);
320 let params = serde_json::json!({"key": "value"});
321 cache
322 .insert(
323 "tool1",
324 ¶ms,
325 SpeculativeResult {
326 output: "result".to_string(),
327 metadata: None,
328 created_at: std::time::Instant::now(),
329 },
330 )
331 .await;
332
333 assert_eq!(cache.size().await, 1);
334 cache.clear().await;
335 assert_eq!(cache.size().await, 0);
336 }
337
338 #[test]
339 fn concurrency_limit() {
340 let cache = SpeculationCache::new(2);
341 assert!(cache.can_speculate());
342 assert!(cache.start_speculation());
343 assert!(cache.start_speculation());
344 assert!(!cache.start_speculation());
345 assert_eq!(cache.active_count(), 2);
346
347 cache.end_speculation();
348 assert!(cache.can_speculate());
349 assert_eq!(cache.active_count(), 1);
350 }
351
352 #[test]
353 fn predictor_no_history() {
354 let predictor = ToolPredictor::new(0.3);
355 let predictions = predictor.predict(&[], &["file_read".to_string()]);
356 assert!(predictions.is_empty());
357 }
358
359 #[test]
360 fn predictor_known_sequence() {
361 let predictor = ToolPredictor::new(0.3);
362 let recent = vec!["list_directory".to_string()];
363 let available = vec!["file_read".to_string(), "list_directory".to_string()];
364 let predictions = predictor.predict(&recent, &available);
365 assert!(!predictions.is_empty());
366 assert_eq!(predictions[0].tool_name, "file_read");
367 assert!(predictions[0].confidence >= 0.7);
368 }
369
370 #[test]
371 fn predictor_repeated_tool() {
372 let predictor = ToolPredictor::new(0.3);
373 let recent = vec![
374 "file_read".to_string(),
375 "file_read".to_string(),
376 "file_read".to_string(),
377 ];
378 let available = vec!["file_read".to_string(), "memory_search".to_string()];
379 let predictions = predictor.predict(&recent, &available);
380 assert!(predictions.iter().any(|p| p.tool_name == "file_read"));
381 }
382
383 #[test]
384 fn predictor_confidence_filter() {
385 let predictor = ToolPredictor::new(0.9);
386 let recent = vec!["memory_search".to_string()];
387 let available = vec!["memory_search".to_string(), "file_read".to_string()];
388 let predictions = predictor.predict(&recent, &available);
389 assert!(predictions.is_empty() || predictions.iter().all(|p| p.confidence >= 0.9));
390 }
391
392 #[test]
393 fn predictor_unavailable_tool_filtered() {
394 let predictor = ToolPredictor::new(0.3);
395 let recent = vec!["list_directory".to_string()];
396 let available = vec!["memory_search".to_string()];
397 let predictions = predictor.predict(&recent, &available);
398 assert!(!predictions.iter().any(|p| p.tool_name == "file_read"));
399 }
400
401 #[test]
402 fn safe_for_speculation() {
403 assert!(is_safe_for_speculation(&roboticus_core::RiskLevel::Safe));
404 assert!(!is_safe_for_speculation(
405 &roboticus_core::RiskLevel::Caution
406 ));
407 assert!(!is_safe_for_speculation(
408 &roboticus_core::RiskLevel::Dangerous
409 ));
410 assert!(!is_safe_for_speculation(
411 &roboticus_core::RiskLevel::Forbidden
412 ));
413 }
414
415 #[test]
416 fn speculation_policy_gate_never_allows_approval_or_forbidden_risks() {
417 let risky = [
418 roboticus_core::RiskLevel::Caution,
419 roboticus_core::RiskLevel::Dangerous,
420 roboticus_core::RiskLevel::Forbidden,
421 ];
422 for risk in risky {
423 assert!(
424 !is_safe_for_speculation(&risk),
425 "speculative execution must remain Safe-only; got {risk:?}"
426 );
427 }
428 }
429
430 #[test]
431 fn predictions_sorted_by_confidence() {
432 let predictor = ToolPredictor::new(0.3);
433 let recent = vec!["list_directory".to_string()];
434 let available = vec!["file_read".to_string(), "list_directory".to_string()];
435 let predictions = predictor.predict(&recent, &available);
436 for i in 1..predictions.len() {
437 assert!(predictions[i - 1].confidence >= predictions[i].confidence);
438 }
439 }
440
441 #[test]
442 fn common_follow_ups_http_get() {
443 let predictor = ToolPredictor::new(0.3);
445 let recent = vec!["http_get".to_string()];
446 let available = vec!["http_get".to_string()];
447 let predictions = predictor.predict(&recent, &available);
448 assert!(
449 predictions.iter().any(|p| p.tool_name == "http_get"),
450 "http_get should predict a follow-up http_get"
451 );
452 }
453
454 #[test]
455 fn common_follow_ups_unknown_tool_returns_empty() {
456 let predictor = ToolPredictor::new(0.3);
458 let recent = vec!["unknown_exotic_tool".to_string()];
459 let available = vec!["unknown_exotic_tool".to_string(), "file_read".to_string()];
460 let predictions = predictor.predict(&recent, &available);
461 assert!(
463 predictions.is_empty(),
464 "unknown tool with single call should produce no predictions"
465 );
466 }
467
468 #[test]
469 fn predict_empty_available_tools() {
470 let predictor = ToolPredictor::new(0.3);
471 let recent = vec!["file_read".to_string()];
472 let predictions = predictor.predict(&recent, &[]);
473 assert!(
474 predictions.is_empty(),
475 "no available tools means no predictions"
476 );
477 }
478
479 #[test]
480 fn predict_empty_recent_tools() {
481 let predictor = ToolPredictor::new(0.3);
482 let available = vec!["file_read".to_string()];
483 let predictions = predictor.predict(&[], &available);
484 assert!(
485 predictions.is_empty(),
486 "no recent tools means no predictions"
487 );
488 }
489
490 #[test]
491 fn start_speculation_exhaustion_and_recovery() {
492 let cache = SpeculationCache::new(1);
493 assert!(cache.start_speculation(), "first slot should succeed");
494 assert!(!cache.start_speculation(), "second slot should fail");
495 assert_eq!(
496 cache.active_count(),
497 1,
498 "count should remain 1 after failed attempt"
499 );
500 cache.end_speculation();
501 assert_eq!(cache.active_count(), 0);
502 assert!(cache.start_speculation(), "slot should be available again");
503 }
504
505 #[test]
506 fn reserve_slot_guard_releases_on_drop() {
507 let cache = SpeculationCache::new(1);
508 let guard = cache.reserve_slot().expect("first reserve should succeed");
509 assert_eq!(cache.active_count(), 1);
510 drop(guard);
511 assert_eq!(
512 cache.active_count(),
513 0,
514 "dropping guard must release speculation slot"
515 );
516 }
517
518 #[tokio::test]
519 async fn reserve_slot_guard_releases_on_task_abort() {
520 let cache = Arc::new(SpeculationCache::new(1));
521 let cache_for_task = Arc::clone(&cache);
522 let task = tokio::spawn(async move {
523 let _guard = cache_for_task
524 .reserve_slot()
525 .expect("slot should be available");
526 tokio::time::sleep(std::time::Duration::from_secs(30)).await;
527 });
528 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
529 assert_eq!(cache.active_count(), 1);
530 task.abort();
531 let _ = task.await;
533 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
534 assert_eq!(
535 cache.active_count(),
536 0,
537 "aborted task must not leak active speculation slots"
538 );
539 }
540
541 #[test]
542 fn memory_search_follow_ups() {
543 let predictor = ToolPredictor::new(0.3);
544 let recent = vec!["memory_search".to_string()];
545 let available = vec!["memory_search".to_string(), "file_read".to_string()];
546 let predictions = predictor.predict(&recent, &available);
547 assert!(
548 predictions.iter().any(|p| p.tool_name == "memory_search"),
549 "memory_search should predict memory_search follow-up"
550 );
551 }
552
553 #[test]
554 fn repeated_tool_no_duplicate_with_follow_up() {
555 let predictor = ToolPredictor::new(0.3);
558 let recent = vec![
559 "file_read".to_string(),
560 "file_read".to_string(),
561 "file_read".to_string(),
562 ];
563 let available = vec!["file_read".to_string(), "memory_search".to_string()];
564 let predictions = predictor.predict(&recent, &available);
565 let file_read_count = predictions
566 .iter()
567 .filter(|p| p.tool_name == "file_read")
568 .count();
569 assert_eq!(
570 file_read_count, 1,
571 "file_read should appear exactly once (no duplicate from repeat heuristic)"
572 );
573 }
574
575 #[tokio::test]
576 async fn cache_different_tools_same_params() {
577 let cache = SpeculationCache::new(4);
578 let params = serde_json::json!({"path": "/tmp/test.txt"});
579 cache
580 .insert(
581 "file_read",
582 ¶ms,
583 SpeculativeResult {
584 output: "read result".to_string(),
585 metadata: None,
586 created_at: std::time::Instant::now(),
587 },
588 )
589 .await;
590 cache
591 .insert(
592 "file_write",
593 ¶ms,
594 SpeculativeResult {
595 output: "write result".to_string(),
596 metadata: None,
597 created_at: std::time::Instant::now(),
598 },
599 )
600 .await;
601 assert_eq!(cache.size().await, 2);
602 let read_result = cache.get("file_read", ¶ms).await.unwrap();
603 assert_eq!(read_result.output, "read result");
604 let write_result = cache.get("file_write", ¶ms).await.unwrap();
605 assert_eq!(write_result.output, "write result");
606 }
607
608 #[test]
609 fn speculation_key_different_tool_names() {
610 let params = serde_json::json!({"key": "value"});
611 let key1 = SpeculationKey::new("tool_a", ¶ms);
612 let key2 = SpeculationKey::new("tool_b", ¶ms);
613 assert_ne!(
614 key1, key2,
615 "different tool names should produce different keys"
616 );
617 }
618
619 #[test]
620 fn speculative_result_metadata() {
621 let result = SpeculativeResult {
622 output: "data".to_string(),
623 metadata: Some(serde_json::json!({"source": "cache"})),
624 created_at: std::time::Instant::now(),
625 };
626 assert_eq!(result.metadata.unwrap()["source"], "cache");
627 }
628}