1use anyhow::Result;
7use do_memory_core::SelfLearningMemory;
8use std::sync::Arc;
9use tracing::{debug, info};
10
11#[derive(Debug, Clone)]
13pub struct CacheWarmingConfig {
14 pub recent_episodes_limit: usize,
16 pub patterns_per_domain: usize,
18 pub sample_queries: Vec<SampleQuery>,
20}
21
/// A representative task description used to pre-issue memory lookups.
///
/// Each field is copied into a `TaskContext` when the query is replayed during
/// cache warming; `description` is the free-text query string.
#[derive(Clone, Debug)]
pub struct SampleQuery {
    /// Free-text description passed as the context-retrieval query.
    pub description: String,
    /// Task domain, e.g. `"web-api"`.
    pub domain: String,
    /// Optional programming language of the task.
    pub language: Option<String>,
    /// Optional framework of the task.
    pub framework: Option<String>,
    /// Tags attached to the task context.
    pub tags: Vec<String>,
}
31
32impl CacheWarmingConfig {
33 pub fn from_env() -> Self {
35 Self {
36 recent_episodes_limit: std::env::var("MCP_CACHE_WARMING_EPISODES")
37 .unwrap_or_else(|_| "50".to_string())
38 .parse()
39 .unwrap_or(50),
40 patterns_per_domain: std::env::var("MCP_CACHE_WARMING_PATTERNS")
41 .unwrap_or_else(|_| "20".to_string())
42 .parse()
43 .unwrap_or(20),
44 sample_queries: Self::default_sample_queries(),
45 }
46 }
47
48 pub fn default_sample_queries() -> Vec<SampleQuery> {
50 vec![
51 SampleQuery {
52 description: "implement api endpoint".to_string(),
53 domain: "web-api".to_string(),
54 language: Some("rust".to_string()),
55 framework: Some("axum".to_string()),
56 tags: vec!["rest".to_string(), "api".to_string()],
57 },
58 SampleQuery {
59 description: "parse json data".to_string(),
60 domain: "data-processing".to_string(),
61 language: Some("rust".to_string()),
62 framework: None,
63 tags: vec!["json".to_string(), "parsing".to_string()],
64 },
65 SampleQuery {
66 description: "write unit tests".to_string(),
67 domain: "testing".to_string(),
68 language: Some("rust".to_string()),
69 framework: None,
70 tags: vec!["unit-tests".to_string(), "testing".to_string()],
71 },
72 SampleQuery {
73 description: "debug performance issue".to_string(),
74 domain: "debugging".to_string(),
75 language: Some("rust".to_string()),
76 framework: None,
77 tags: vec!["performance".to_string(), "debugging".to_string()],
78 },
79 SampleQuery {
80 description: "refactor code for maintainability".to_string(),
81 domain: "refactoring".to_string(),
82 language: Some("rust".to_string()),
83 framework: None,
84 tags: vec!["refactoring".to_string(), "maintainability".to_string()],
85 },
86 ]
87 }
88}
89
90pub async fn warm_cache(
97 memory: &Arc<SelfLearningMemory>,
98 config: &CacheWarmingConfig,
99) -> Result<()> {
100 info!("Starting cache warming process");
101
102 let start_time = std::time::Instant::now();
103
104 warm_episodes_cache(memory, config).await?;
106
107 warm_patterns_cache(memory, config).await?;
109
110 warm_query_patterns(memory, config).await?;
112
113 let duration = start_time.elapsed();
114 info!("Cache warming completed in {:.2}s", duration.as_secs_f64());
115
116 Ok(())
117}
118
119async fn warm_episodes_cache(
121 memory: &Arc<SelfLearningMemory>,
122 config: &CacheWarmingConfig,
123) -> Result<()> {
124 info!(
125 "Warming episodes cache with {} recent episodes",
126 config.recent_episodes_limit
127 );
128
129 let context = do_memory_core::TaskContext {
131 domain: "general".to_string(),
132 language: None,
133 framework: None,
134 complexity: do_memory_core::ComplexityLevel::Moderate,
135 tags: vec![],
136 };
137
138 let episodes = memory
140 .retrieve_relevant_context(
141 "recent tasks".to_string(),
142 context,
143 config.recent_episodes_limit,
144 )
145 .await;
146
147 info!("Pre-loaded {} episodes into cache", episodes.len());
148
149 Ok(())
150}
151
152async fn warm_patterns_cache(
154 memory: &Arc<SelfLearningMemory>,
155 config: &CacheWarmingConfig,
156) -> Result<()> {
157 info!(
158 "Warming patterns cache with {} patterns per domain",
159 config.patterns_per_domain
160 );
161
162 let common_domains = vec![
164 "web-api",
165 "data-processing",
166 "code-generation",
167 "debugging",
168 "refactoring",
169 "testing",
170 "analysis",
171 "documentation",
172 ];
173
174 for domain in &common_domains {
175 let context = do_memory_core::TaskContext {
176 domain: domain.to_string(),
177 language: None,
178 framework: None,
179 complexity: do_memory_core::ComplexityLevel::Moderate,
180 tags: vec![domain.to_string()],
181 };
182
183 let patterns = memory
184 .retrieve_relevant_patterns(&context, config.patterns_per_domain)
185 .await;
186
187 debug!(
188 "Pre-loaded {} patterns for domain '{}'",
189 patterns.len(),
190 domain
191 );
192 }
193
194 info!("Patterns cache warming completed");
195 Ok(())
196}
197
198async fn warm_query_patterns(
200 memory: &Arc<SelfLearningMemory>,
201 config: &CacheWarmingConfig,
202) -> Result<()> {
203 info!(
204 "Warming query patterns with {} sample queries",
205 config.sample_queries.len()
206 );
207
208 for query in &config.sample_queries {
210 let context = do_memory_core::TaskContext {
211 domain: query.domain.clone(),
212 language: query.language.clone(),
213 framework: query.framework.clone(),
214 complexity: do_memory_core::ComplexityLevel::Moderate,
215 tags: query.tags.clone(),
216 };
217
218 let _episodes = memory
220 .retrieve_relevant_context(query.description.clone(), context.clone(), 5)
221 .await;
222
223 let _patterns = memory.retrieve_relevant_patterns(&context, 5).await;
225
226 debug!("Warmed query pattern: '{}'", query.description);
227 }
228
229 info!("Query patterns warming completed");
230 Ok(())
231}
232
/// Reports whether cache warming should run, based on `MCP_CACHE_WARMING_ENABLED`.
///
/// Warming is enabled when the variable is unset (or not valid Unicode), or when
/// its whitespace-trimmed value is, ignoring ASCII case, one of `true`, `1`,
/// `yes` or `on`. Any other value — including `false`, `0` and the empty
/// string — disables it.
pub fn is_cache_warming_enabled() -> bool {
    match std::env::var("MCP_CACHE_WARMING_ENABLED") {
        Ok(raw) => {
            let value = raw.trim();
            ["true", "1", "yes", "on"]
                .iter()
                .any(|truthy| value.eq_ignore_ascii_case(truthy))
        }
        // Not set (or not Unicode): keep the long-standing default of enabled.
        Err(_) => true,
    }
}