Skip to main content

do_memory_mcp/server/
cache_warming.rs

//! Cache warming configuration and methods
//!
//! This module contains cache warming configuration and methods for pre-loading
//! episodes, patterns, and query patterns to improve initial query performance.
5
6use anyhow::Result;
7use do_memory_core::SelfLearningMemory;
8use std::sync::Arc;
9use tracing::{debug, info};
10
/// Configuration for the cache warming process.
///
/// Usually built with [`CacheWarmingConfig::from_env`], which falls back to
/// [`CacheWarmingConfig::DEFAULT_RECENT_EPISODES_LIMIT`],
/// [`CacheWarmingConfig::DEFAULT_PATTERNS_PER_DOMAIN`] and
/// [`CacheWarmingConfig::default_sample_queries`] when the corresponding
/// environment variables are unset or invalid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheWarmingConfig {
    /// Number of recent episodes to pre-load
    pub recent_episodes_limit: usize,
    /// Number of patterns to pre-load per domain
    pub patterns_per_domain: usize,
    /// Sample queries to execute for warming
    pub sample_queries: Vec<SampleQuery>,
}

/// Sample query replayed during cache warming.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SampleQuery {
    /// Free-text task description used as the query text.
    pub description: String,
    /// Task domain for the query context (e.g. `"web-api"`, `"testing"`).
    pub domain: String,
    /// Optional programming language for the query context (e.g. `"rust"`).
    pub language: Option<String>,
    /// Optional framework for the query context (e.g. `"axum"`).
    pub framework: Option<String>,
    /// Tags attached to the query context.
    pub tags: Vec<String>,
}

impl CacheWarmingConfig {
    /// Fallback for `recent_episodes_limit` when `MCP_CACHE_WARMING_EPISODES`
    /// is unset or not a valid non-negative integer.
    pub const DEFAULT_RECENT_EPISODES_LIMIT: usize = 50;

    /// Fallback for `patterns_per_domain` when `MCP_CACHE_WARMING_PATTERNS`
    /// is unset or not a valid non-negative integer.
    pub const DEFAULT_PATTERNS_PER_DOMAIN: usize = 20;

    /// Create cache warming config from environment variables.
    ///
    /// Reads:
    /// - `MCP_CACHE_WARMING_EPISODES` -> `recent_episodes_limit` (default 50)
    /// - `MCP_CACHE_WARMING_PATTERNS` -> `patterns_per_domain` (default 20)
    ///
    /// Surrounding whitespace in a value is ignored. A variable that is
    /// missing, not valid UTF-8, or not a non-negative integer falls back to
    /// its default. `sample_queries` is always
    /// [`CacheWarmingConfig::default_sample_queries`].
    pub fn from_env() -> Self {
        Self {
            recent_episodes_limit: Self::env_usize(
                "MCP_CACHE_WARMING_EPISODES",
                Self::DEFAULT_RECENT_EPISODES_LIMIT,
            ),
            patterns_per_domain: Self::env_usize(
                "MCP_CACHE_WARMING_PATTERNS",
                Self::DEFAULT_PATTERNS_PER_DOMAIN,
            ),
            sample_queries: Self::default_sample_queries(),
        }
    }

    /// Default sample queries for cache warming.
    ///
    /// Five Rust-oriented queries, one each for the `web-api`,
    /// `data-processing`, `testing`, `debugging` and `refactoring` domains.
    pub fn default_sample_queries() -> Vec<SampleQuery> {
        vec![
            SampleQuery {
                description: "implement api endpoint".to_string(),
                domain: "web-api".to_string(),
                language: Some("rust".to_string()),
                framework: Some("axum".to_string()),
                tags: vec!["rest".to_string(), "api".to_string()],
            },
            SampleQuery {
                description: "parse json data".to_string(),
                domain: "data-processing".to_string(),
                language: Some("rust".to_string()),
                framework: None,
                tags: vec!["json".to_string(), "parsing".to_string()],
            },
            SampleQuery {
                description: "write unit tests".to_string(),
                domain: "testing".to_string(),
                language: Some("rust".to_string()),
                framework: None,
                tags: vec!["unit-tests".to_string(), "testing".to_string()],
            },
            SampleQuery {
                description: "debug performance issue".to_string(),
                domain: "debugging".to_string(),
                language: Some("rust".to_string()),
                framework: None,
                tags: vec!["performance".to_string(), "debugging".to_string()],
            },
            SampleQuery {
                description: "refactor code for maintainability".to_string(),
                domain: "refactoring".to_string(),
                language: Some("rust".to_string()),
                framework: None,
                tags: vec!["refactoring".to_string(), "maintainability".to_string()],
            },
        ]
    }

    /// Read `key` from the environment as a `usize`, falling back to `default`.
    fn env_usize(key: &str, default: usize) -> usize {
        Self::parse_usize_or(std::env::var(key).ok().as_deref(), default)
    }

    /// Parse an optional raw value as `usize`, ignoring surrounding whitespace.
    ///
    /// Returns `default` when `raw` is `None` or does not parse.
    fn parse_usize_or(raw: Option<&str>, default: usize) -> usize {
        raw.and_then(|value| value.trim().parse().ok())
            .unwrap_or(default)
    }
}
89
90/// Warm the cache by pre-loading recent episodes and common query patterns
91///
92/// This method performs cache warming to improve initial query performance by:
93/// 1. Pre-loading recent episodes into cache
94/// 2. Pre-computing common query patterns
95/// 3. Warming up pattern extraction and retrieval systems
96pub async fn warm_cache(
97    memory: &Arc<SelfLearningMemory>,
98    config: &CacheWarmingConfig,
99) -> Result<()> {
100    info!("Starting cache warming process");
101
102    let start_time = std::time::Instant::now();
103
104    // Warm episodes cache
105    warm_episodes_cache(memory, config).await?;
106
107    // Warm patterns cache
108    warm_patterns_cache(memory, config).await?;
109
110    // Warm common query patterns
111    warm_query_patterns(memory, config).await?;
112
113    let duration = start_time.elapsed();
114    info!("Cache warming completed in {:.2}s", duration.as_secs_f64());
115
116    Ok(())
117}
118
119/// Warm the episodes cache by loading recent episodes
120async fn warm_episodes_cache(
121    memory: &Arc<SelfLearningMemory>,
122    config: &CacheWarmingConfig,
123) -> Result<()> {
124    info!(
125        "Warming episodes cache with {} recent episodes",
126        config.recent_episodes_limit
127    );
128
129    // Create a generic context to retrieve diverse episodes
130    let context = do_memory_core::TaskContext {
131        domain: "general".to_string(),
132        language: None,
133        framework: None,
134        complexity: do_memory_core::ComplexityLevel::Moderate,
135        tags: vec![],
136    };
137
138    // Retrieve recent episodes using a broad query
139    let episodes = memory
140        .retrieve_relevant_context(
141            "recent tasks".to_string(),
142            context,
143            config.recent_episodes_limit,
144        )
145        .await;
146
147    info!("Pre-loaded {} episodes into cache", episodes.len());
148
149    Ok(())
150}
151
152/// Warm the patterns cache by loading relevant patterns
153async fn warm_patterns_cache(
154    memory: &Arc<SelfLearningMemory>,
155    config: &CacheWarmingConfig,
156) -> Result<()> {
157    info!(
158        "Warming patterns cache with {} patterns per domain",
159        config.patterns_per_domain
160    );
161
162    // Warm patterns for common domains
163    let common_domains = vec![
164        "web-api",
165        "data-processing",
166        "code-generation",
167        "debugging",
168        "refactoring",
169        "testing",
170        "analysis",
171        "documentation",
172    ];
173
174    for domain in &common_domains {
175        let context = do_memory_core::TaskContext {
176            domain: domain.to_string(),
177            language: None,
178            framework: None,
179            complexity: do_memory_core::ComplexityLevel::Moderate,
180            tags: vec![domain.to_string()],
181        };
182
183        let patterns = memory
184            .retrieve_relevant_patterns(&context, config.patterns_per_domain)
185            .await;
186
187        debug!(
188            "Pre-loaded {} patterns for domain '{}'",
189            patterns.len(),
190            domain
191        );
192    }
193
194    info!("Patterns cache warming completed");
195    Ok(())
196}
197
198/// Warm common query patterns by executing typical queries
199async fn warm_query_patterns(
200    memory: &Arc<SelfLearningMemory>,
201    config: &CacheWarmingConfig,
202) -> Result<()> {
203    info!(
204        "Warming query patterns with {} sample queries",
205        config.sample_queries.len()
206    );
207
208    // Execute sample queries to warm up retrieval systems
209    for query in &config.sample_queries {
210        let context = do_memory_core::TaskContext {
211            domain: query.domain.clone(),
212            language: query.language.clone(),
213            framework: query.framework.clone(),
214            complexity: do_memory_core::ComplexityLevel::Moderate,
215            tags: query.tags.clone(),
216        };
217
218        // Query memory (this will populate caches)
219        let _episodes = memory
220            .retrieve_relevant_context(query.description.clone(), context.clone(), 5)
221            .await;
222
223        // Query patterns
224        let _patterns = memory.retrieve_relevant_patterns(&context, 5).await;
225
226        debug!("Warmed query pattern: '{}'", query.description);
227    }
228
229    info!("Query patterns warming completed");
230    Ok(())
231}
232
/// Check whether cache warming is enabled via the `MCP_CACHE_WARMING_ENABLED`
/// environment variable.
///
/// Warming is enabled in two cases:
/// - the variable is unset or not valid UTF-8 (the default), or
/// - its value, ignoring surrounding whitespace and ASCII case, is one of
///   `true`, `1`, `yes` or `on`.
///
/// Any other value, including an empty string or `false`/`0`, disables it.
pub fn is_cache_warming_enabled() -> bool {
    match std::env::var("MCP_CACHE_WARMING_ENABLED") {
        Ok(value) => parse_cache_warming_flag(&value),
        // Unset (or non-UTF-8) keeps the historical default of "enabled".
        Err(_) => true,
    }
}

/// Interpret a raw `MCP_CACHE_WARMING_ENABLED` value as a boolean flag.
///
/// Trims surrounding whitespace and compares ASCII-case-insensitively
/// against the accepted truthy spellings. This avoids allocating a lowercase copy.
fn parse_cache_warming_flag(raw: &str) -> bool {
    let value = raw.trim();
    ["true", "1", "yes", "on"]
        .iter()
        .any(|truthy| value.eq_ignore_ascii_case(truthy))
}