Skip to main content

roboticus_agent/
speculative.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::Mutex;
4use tracing::debug;
5
6/// Prediction of a likely tool call based on conversation context.
7#[derive(Debug, Clone)]
8pub struct ToolPrediction {
9    pub tool_name: String,
10    pub predicted_params: serde_json::Value,
11    pub confidence: f64,
12}
13
14/// Cache key for speculative results.
15/// Uses the full JSON string for exact matching — no hash collisions and
16/// no dependency on DefaultHasher stability across Rust versions.
17#[derive(Debug, Clone, Hash, PartialEq, Eq)]
18pub struct SpeculationKey {
19    pub tool_name: String,
20    pub params_json: String,
21}
22
23impl SpeculationKey {
24    #[must_use]
25    pub fn new(tool_name: &str, params: &serde_json::Value) -> Self {
26        Self {
27            tool_name: tool_name.to_string(),
28            params_json: params.to_string(),
29        }
30    }
31}
32
33/// Cached result from a speculative tool execution.
34#[derive(Debug, Clone)]
35pub struct SpeculativeResult {
36    pub output: String,
37    pub metadata: Option<serde_json::Value>,
38    pub created_at: std::time::Instant,
39}
40
41/// Manages speculative execution of read-only tools.
42/// Spawns tokio tasks to pre-fetch results for predicted tool calls.
43#[derive(Debug)]
44pub struct SpeculationCache {
45    cache: Arc<Mutex<HashMap<SpeculationKey, SpeculativeResult>>>,
46    max_concurrent: usize,
47    active_count: Arc<std::sync::atomic::AtomicUsize>,
48}
49
/// Holds one speculation slot and returns it to the shared counter on drop.
pub struct SpeculationSlotGuard {
    /// Counter of reserved slots, shared with the cache that issued the guard.
    active_count: Arc<std::sync::atomic::AtomicUsize>,
}

impl Drop for SpeculationSlotGuard {
    /// Releases the slot held by this guard.
    ///
    /// The decrement saturates at zero. The same counter can also be lowered
    /// through `SpeculationCache::end_speculation`, so it may already be zero
    /// when a guard is dropped. A plain `fetch_sub` would then wrap to
    /// `usize::MAX` and make every later reservation fail, and panicking here
    /// is not an option because `drop` may run during unwinding.
    fn drop(&mut self) {
        // `checked_sub` yields `None` at zero, which makes `fetch_update`
        // leave the counter untouched; that outcome is deliberately ignored.
        let _ = self.active_count.fetch_update(
            std::sync::atomic::Ordering::AcqRel,
            std::sync::atomic::Ordering::Acquire,
            |active| active.checked_sub(1),
        );
    }
}
67
68impl SpeculationCache {
69    #[must_use]
70    pub fn new(max_concurrent: usize) -> Self {
71        Self {
72            cache: Arc::new(Mutex::new(HashMap::new())),
73            max_concurrent,
74            active_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
75        }
76    }
77
78    /// Check if we have a cached result for the given tool + params.
79    pub async fn get(
80        &self,
81        tool_name: &str,
82        params: &serde_json::Value,
83    ) -> Option<SpeculativeResult> {
84        let key = SpeculationKey::new(tool_name, params);
85        let cache = self.cache.lock().await;
86        cache.get(&key).cloned()
87    }
88
89    /// Store a speculative result.
90    pub async fn insert(
91        &self,
92        tool_name: &str,
93        params: &serde_json::Value,
94        result: SpeculativeResult,
95    ) {
96        let key = SpeculationKey::new(tool_name, params);
97        let mut cache = self.cache.lock().await;
98        cache.insert(key, result);
99    }
100
101    /// Clear all cached speculative results (called at turn completion).
102    pub async fn clear(&self) {
103        let mut cache = self.cache.lock().await;
104        let count = cache.len();
105        cache.clear();
106        if count > 0 {
107            debug!(cleared = count, "speculation cache cleared");
108        }
109    }
110
111    /// Number of cached results.
112    pub async fn size(&self) -> usize {
113        self.cache.lock().await.len()
114    }
115
116    /// Whether we can spawn another speculative task.
117    pub fn can_speculate(&self) -> bool {
118        self.active_count.load(std::sync::atomic::Ordering::Acquire) < self.max_concurrent
119    }
120
121    /// Try to reserve a speculation slot. Returns true if the slot was acquired.
122    pub fn start_speculation(&self) -> bool {
123        let prev = self
124            .active_count
125            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
126        if prev >= self.max_concurrent {
127            self.active_count
128                .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
129            false
130        } else {
131            true
132        }
133    }
134
135    /// Reserve a speculation slot and get an RAII guard that releases it on drop.
136    ///
137    /// Prefer this over start/end pairs to guarantee cleanup on cancellation/panic.
138    pub fn reserve_slot(&self) -> Option<SpeculationSlotGuard> {
139        let prev = self
140            .active_count
141            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
142        if prev >= self.max_concurrent {
143            self.active_count
144                .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
145            None
146        } else {
147            Some(SpeculationSlotGuard {
148                active_count: Arc::clone(&self.active_count),
149            })
150        }
151    }
152
153    /// Release a speculation slot.
154    ///
155    /// Saturates at zero: calling this when no slots are active is a no-op
156    /// rather than wrapping `usize` to `MAX`.
157    pub fn end_speculation(&self) {
158        // CAS loop: decrement only if current value > 0.
159        let mut current = self.active_count.load(std::sync::atomic::Ordering::Acquire);
160        loop {
161            if current == 0 {
162                tracing::debug!("end_speculation called with no active slots — no-op");
163                return;
164            }
165            match self.active_count.compare_exchange_weak(
166                current,
167                current - 1,
168                std::sync::atomic::Ordering::AcqRel,
169                std::sync::atomic::Ordering::Acquire,
170            ) {
171                Ok(_) => return,
172                Err(updated) => current = updated,
173            }
174        }
175    }
176
177    pub fn active_count(&self) -> usize {
178        self.active_count.load(std::sync::atomic::Ordering::Acquire)
179    }
180}
181
182/// Predicts likely next tool calls based on conversation context.
183pub struct ToolPredictor {
184    min_confidence: f64,
185}
186
187impl ToolPredictor {
188    #[must_use]
189    pub fn new(min_confidence: f64) -> Self {
190        Self { min_confidence }
191    }
192
193    /// Analyze recent tool call history to predict likely next calls.
194    /// Uses pattern matching on sequential tool usage.
195    pub fn predict(
196        &self,
197        recent_tools: &[String],
198        available_tools: &[String],
199    ) -> Vec<ToolPrediction> {
200        let mut predictions = Vec::new();
201
202        if recent_tools.is_empty() || available_tools.is_empty() {
203            return predictions;
204        }
205
206        let last_tool = &recent_tools[recent_tools.len() - 1];
207
208        let follow_ups = common_follow_ups(last_tool);
209        for (follow_tool, confidence) in follow_ups {
210            if confidence >= self.min_confidence && available_tools.contains(&follow_tool) {
211                predictions.push(ToolPrediction {
212                    tool_name: follow_tool,
213                    predicted_params: serde_json::Value::Object(serde_json::Map::new()),
214                    confidence,
215                });
216            }
217        }
218
219        // Repeated tool calls raise confidence that the same tool will be called again
220        let repeat_count = recent_tools
221            .iter()
222            .rev()
223            .take_while(|t| *t == last_tool)
224            .count();
225        if repeat_count >= 2 {
226            let confidence = 0.6 + (repeat_count as f64 * 0.05).min(0.2);
227            if confidence >= self.min_confidence
228                && !predictions.iter().any(|p| p.tool_name == *last_tool)
229            {
230                predictions.push(ToolPrediction {
231                    tool_name: last_tool.clone(),
232                    predicted_params: serde_json::Value::Object(serde_json::Map::new()),
233                    confidence,
234                });
235            }
236        }
237
238        predictions.sort_by(|a, b| {
239            b.confidence
240                .partial_cmp(&a.confidence)
241                .unwrap_or(std::cmp::Ordering::Equal)
242        });
243        predictions
244    }
245}
246
/// Looks up the tools that commonly follow `tool_name`.
///
/// Returns `(tool, confidence)` pairs in table order, or an empty vector for
/// a tool name that has no entry.
fn common_follow_ups(tool_name: &str) -> Vec<(String, f64)> {
    let follow_ups: &[(&str, f64)] = match tool_name {
        "file_read" => &[("file_read", 0.7), ("memory_search", 0.4)],
        "memory_search" => &[("memory_search", 0.5), ("file_read", 0.3)],
        "http_get" => &[("http_get", 0.6)],
        "list_directory" => &[("file_read", 0.7), ("list_directory", 0.4)],
        _ => &[],
    };

    follow_ups
        .iter()
        .map(|&(name, confidence)| (name.to_owned(), confidence))
        .collect()
}
266
267/// Only `Safe` tools are eligible for speculative pre-fetching — they have
268/// no side effects and can be re-executed without consequence.
269pub fn is_safe_for_speculation(risk: &roboticus_core::RiskLevel) -> bool {
270    matches!(risk, roboticus_core::RiskLevel::Safe)
271}
272
/// Unit tests for the speculation cache, slot accounting and tool predictor.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned tool-name list from string literals.
    fn tools(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Builds a metadata-free result with the given output, stamped "now".
    fn result_with_output(output: &str) -> SpeculativeResult {
        SpeculativeResult {
            output: output.to_string(),
            metadata: None,
            created_at: std::time::Instant::now(),
        }
    }

    #[test]
    fn speculation_key_hashing() {
        let key1 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/a.txt"}));
        let key2 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/a.txt"}));
        let key3 = SpeculationKey::new("file_read", &serde_json::json!({"path": "/tmp/b.txt"}));

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[tokio::test]
    async fn cache_insert_and_get() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"path": "/tmp/test.txt"});

        cache
            .insert("file_read", &params, result_with_output("file contents"))
            .await;

        let result = cache
            .get("file_read", &params)
            .await
            .expect("inserted entry should be returned");
        assert_eq!(result.output, "file contents");
    }

    #[tokio::test]
    async fn cache_miss() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"path": "/tmp/missing.txt"});
        let result = cache.get("file_read", &params).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_clear() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"key": "value"});
        cache
            .insert("tool1", &params, result_with_output("result"))
            .await;

        assert_eq!(cache.size().await, 1);
        cache.clear().await;
        assert_eq!(cache.size().await, 0);
    }

    #[test]
    fn concurrency_limit() {
        let cache = SpeculationCache::new(2);
        assert!(cache.can_speculate());
        assert!(cache.start_speculation());
        assert!(cache.start_speculation());
        assert!(!cache.start_speculation());
        assert_eq!(cache.active_count(), 2);

        cache.end_speculation();
        assert!(cache.can_speculate());
        assert_eq!(cache.active_count(), 1);
    }

    #[test]
    fn predictor_no_history() {
        let predictor = ToolPredictor::new(0.3);
        let predictions = predictor.predict(&[], &tools(&["file_read"]));
        assert!(predictions.is_empty());
    }

    #[test]
    fn predictor_known_sequence() {
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["list_directory"]);
        let available = tools(&["file_read", "list_directory"]);
        let predictions = predictor.predict(&recent, &available);
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].tool_name, "file_read");
        assert!(predictions[0].confidence >= 0.7);
    }

    #[test]
    fn predictor_repeated_tool() {
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["file_read", "file_read", "file_read"]);
        let available = tools(&["file_read", "memory_search"]);
        let predictions = predictor.predict(&recent, &available);
        assert!(predictions.iter().any(|p| p.tool_name == "file_read"));
    }

    #[test]
    fn predictor_confidence_filter() {
        let predictor = ToolPredictor::new(0.9);
        let recent = tools(&["memory_search"]);
        let available = tools(&["memory_search", "file_read"]);
        let predictions = predictor.predict(&recent, &available);
        // `all` is vacuously true for an empty list, so this also covers the
        // case where nothing clears the threshold.
        assert!(predictions.iter().all(|p| p.confidence >= 0.9));
    }

    #[test]
    fn predictor_unavailable_tool_filtered() {
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["list_directory"]);
        let available = tools(&["memory_search"]);
        let predictions = predictor.predict(&recent, &available);
        assert!(!predictions.iter().any(|p| p.tool_name == "file_read"));
    }

    #[test]
    fn safe_for_speculation() {
        assert!(is_safe_for_speculation(&roboticus_core::RiskLevel::Safe));
        assert!(!is_safe_for_speculation(
            &roboticus_core::RiskLevel::Caution
        ));
        assert!(!is_safe_for_speculation(
            &roboticus_core::RiskLevel::Dangerous
        ));
        assert!(!is_safe_for_speculation(
            &roboticus_core::RiskLevel::Forbidden
        ));
    }

    #[test]
    fn speculation_policy_gate_never_allows_approval_or_forbidden_risks() {
        let risky = [
            roboticus_core::RiskLevel::Caution,
            roboticus_core::RiskLevel::Dangerous,
            roboticus_core::RiskLevel::Forbidden,
        ];
        for risk in risky {
            assert!(
                !is_safe_for_speculation(&risk),
                "speculative execution must remain Safe-only; got {risk:?}"
            );
        }
    }

    #[test]
    fn predictions_sorted_by_confidence() {
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["list_directory"]);
        let available = tools(&["file_read", "list_directory"]);
        let predictions = predictor.predict(&recent, &available);
        assert!(predictions
            .windows(2)
            .all(|pair| pair[0].confidence >= pair[1].confidence));
    }

    #[test]
    fn common_follow_ups_http_get() {
        // http_get follow-ups should predict another http_get
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["http_get"]);
        let available = tools(&["http_get"]);
        let predictions = predictor.predict(&recent, &available);
        assert!(
            predictions.iter().any(|p| p.tool_name == "http_get"),
            "http_get should predict a follow-up http_get"
        );
    }

    #[test]
    fn common_follow_ups_unknown_tool_returns_empty() {
        // Unknown tool names produce no follow-up predictions (only repeat heuristic)
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["unknown_exotic_tool"]);
        let available = tools(&["unknown_exotic_tool", "file_read"]);
        let predictions = predictor.predict(&recent, &available);
        // No follow-ups for unknown tool, and only 1 call so no repeat heuristic
        assert!(
            predictions.is_empty(),
            "unknown tool with single call should produce no predictions"
        );
    }

    #[test]
    fn predict_empty_available_tools() {
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["file_read"]);
        let predictions = predictor.predict(&recent, &[]);
        assert!(
            predictions.is_empty(),
            "no available tools means no predictions"
        );
    }

    #[test]
    fn predict_empty_recent_tools() {
        let predictor = ToolPredictor::new(0.3);
        let available = tools(&["file_read"]);
        let predictions = predictor.predict(&[], &available);
        assert!(
            predictions.is_empty(),
            "no recent tools means no predictions"
        );
    }

    #[test]
    fn start_speculation_exhaustion_and_recovery() {
        let cache = SpeculationCache::new(1);
        assert!(cache.start_speculation(), "first slot should succeed");
        assert!(!cache.start_speculation(), "second slot should fail");
        assert_eq!(
            cache.active_count(),
            1,
            "count should remain 1 after failed attempt"
        );
        cache.end_speculation();
        assert_eq!(cache.active_count(), 0);
        assert!(cache.start_speculation(), "slot should be available again");
    }

    #[test]
    fn reserve_slot_guard_releases_on_drop() {
        let cache = SpeculationCache::new(1);
        let guard = cache.reserve_slot().expect("first reserve should succeed");
        assert_eq!(cache.active_count(), 1);
        drop(guard);
        assert_eq!(
            cache.active_count(),
            0,
            "dropping guard must release speculation slot"
        );
    }

    #[tokio::test]
    async fn reserve_slot_guard_releases_on_task_abort() {
        let cache = Arc::new(SpeculationCache::new(1));
        let cache_for_task = Arc::clone(&cache);
        // Explicit readiness signal instead of a timing-based sleep, so the
        // test cannot race the spawned task on a slow or loaded machine.
        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<()>();
        let task = tokio::spawn(async move {
            let _guard = cache_for_task
                .reserve_slot()
                .expect("slot should be available");
            let _ = ready_tx.send(());
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
        });
        ready_rx
            .await
            .expect("task should signal once it holds the slot");
        assert_eq!(cache.active_count(), 1);
        task.abort();
        // The join handle resolves only after the cancelled task's future,
        // and with it the guard, has been dropped.
        let joined = task.await;
        assert!(
            matches!(&joined, Err(err) if err.is_cancelled()),
            "task should report cancellation"
        );
        assert_eq!(
            cache.active_count(),
            0,
            "aborted task must not leak active speculation slots"
        );
    }

    #[test]
    fn memory_search_follow_ups() {
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["memory_search"]);
        let available = tools(&["memory_search", "file_read"]);
        let predictions = predictor.predict(&recent, &available);
        assert!(
            predictions.iter().any(|p| p.tool_name == "memory_search"),
            "memory_search should predict memory_search follow-up"
        );
    }

    #[test]
    fn repeated_tool_no_duplicate_with_follow_up() {
        // file_read repeated 3 times: follow-up includes file_read (0.7),
        // repeat heuristic should not add a duplicate
        let predictor = ToolPredictor::new(0.3);
        let recent = tools(&["file_read", "file_read", "file_read"]);
        let available = tools(&["file_read", "memory_search"]);
        let predictions = predictor.predict(&recent, &available);
        let file_read_count = predictions
            .iter()
            .filter(|p| p.tool_name == "file_read")
            .count();
        assert_eq!(
            file_read_count, 1,
            "file_read should appear exactly once (no duplicate from repeat heuristic)"
        );
    }

    #[tokio::test]
    async fn cache_different_tools_same_params() {
        let cache = SpeculationCache::new(4);
        let params = serde_json::json!({"path": "/tmp/test.txt"});
        cache
            .insert("file_read", &params, result_with_output("read result"))
            .await;
        cache
            .insert("file_write", &params, result_with_output("write result"))
            .await;
        assert_eq!(cache.size().await, 2);
        let read_result = cache.get("file_read", &params).await.unwrap();
        assert_eq!(read_result.output, "read result");
        let write_result = cache.get("file_write", &params).await.unwrap();
        assert_eq!(write_result.output, "write result");
    }

    #[test]
    fn speculation_key_different_tool_names() {
        let params = serde_json::json!({"key": "value"});
        let key1 = SpeculationKey::new("tool_a", &params);
        let key2 = SpeculationKey::new("tool_b", &params);
        assert_ne!(
            key1, key2,
            "different tool names should produce different keys"
        );
    }

    #[test]
    fn speculative_result_metadata() {
        let result = SpeculativeResult {
            output: "data".to_string(),
            metadata: Some(serde_json::json!({"source": "cache"})),
            created_at: std::time::Instant::now(),
        };
        assert_eq!(result.metadata.unwrap()["source"], "cache");
    }
}