use crate::common::traits::AIProvider;
use dashmap::DashMap;
use std::sync::Arc;
pub struct Prefetcher {
model: Arc<dyn AIProvider>,
pattern_cache: Arc<DashMap<String, Vec<String>>>,
pending_prefetches: Arc<DashMap<String, tokio::sync::oneshot::Receiver<String>>>,
}
impl Prefetcher {
pub fn new(model: Arc<dyn AIProvider>) -> Self {
Self {
model,
pattern_cache: Arc::new(DashMap::new()),
pending_prefetches: Arc::new(DashMap::new()),
}
}
pub async fn prefetch_for_task(&self, task: &str) -> Result<(), Box<dyn std::error::Error>> {
let patterns = self.extract_patterns(task).await;
for pattern in patterns {
if !self.pending_prefetches.contains_key(&pattern) {
let model_clone = Arc::clone(&self.model);
let pattern_clone = pattern.clone();
let (tx, rx) = tokio::sync::oneshot::channel();
self.pending_prefetches.insert(pattern.clone(), rx);
tokio::spawn(async move {
let result = model_clone.complete(&pattern_clone).await;
let _ = tx.send(result.unwrap_or_default());
});
}
}
Ok(())
}
pub async fn get_prefetch(&self, pattern: &str) -> Option<String> {
None
}
async fn extract_patterns(&self, task: &str) -> Vec<String> {
let mut patterns = Vec::new();
let task_lower = task.to_lowercase();
if task_lower.contains("function") || task_lower.contains("implement") {
patterns.push(format!("Explain the function: {}", task));
patterns.push(format!("Improve the function: {}", task));
patterns.push(format!("Add tests for: {}", task));
}
if task_lower.contains("bug") || task_lower.contains("error") || task_lower.contains("fix")
{
patterns.push(format!("Debug this further: {}", task));
patterns.push("Provide alternative solutions".to_string());
patterns.push("Review best practices".to_string());
}
if task_lower.contains("refactor") || task_lower.contains("optimiz") {
patterns.push(format!("Performance analysis of: {}", task));
patterns.push("Code review suggestions".to_string());
patterns.push("Security considerations".to_string());
}
if task_lower.contains("rust") {
patterns.push("Rust best practices".to_string());
patterns.push("Memory safety considerations".to_string());
} else if task_lower.contains("python") {
patterns.push("Python best practices".to_string());
patterns.push("Performance optimization".to_string());
} else if task_lower.contains("javascript") || task_lower.contains("js") {
patterns.push("JavaScript best practices".to_string());
patterns.push("Async/await patterns".to_string());
}
if task_lower.contains("api") || task_lower.contains("endpoint") {
patterns.push("API documentation".to_string());
patterns.push("Error handling for API".to_string());
}
if task_lower.contains("test") || task_lower.contains("unit") {
patterns.push("Test edge cases".to_string());
patterns.push("Integration test suggestions".to_string());
}
patterns.truncate(5);
patterns
}
pub async fn prefetch_from_code_context(
&self,
code: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let mut patterns = Vec::new();
if code.contains("fn ") {
patterns.push("Suggest improvements to these functions".to_string());
}
if code.contains("TODO") || code.contains("FIXME") || code.contains("FIXME") {
patterns.push("Address these TODO items".to_string());
}
if code.contains("unwrap()") || code.contains("expect(") {
patterns.push("Improve error handling in this code".to_string());
}
for pattern in patterns {
let model_clone = Arc::clone(&self.model);
let pattern_clone = pattern.clone();
tokio::spawn(async move {
let _ = model_clone.complete(&pattern_clone).await;
});
}
Ok(())
}
pub async fn cancel_all_prefetches(&self) {
self.pending_prefetches.clear();
}
pub async fn stats(&self) -> PrefetchStats {
PrefetchStats {
pattern_cache_size: self.pattern_cache.len(),
pending_prefetches: self.pending_prefetches.len(),
}
}
}
#[derive(Debug)]
pub struct PrefetchStats {
pub pattern_cache_size: usize,
pub pending_prefetches: usize,
}