1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::PathBuf;
10
11use super::episodic_memory::{Episode, Outcome};
12
13use crate::core::memory_policy::ProceduralPolicy;
14
/// Collection of learned procedures for a single project.
///
/// The store is (de)serialized with serde; `save` and `load` persist it as
/// JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProceduralStore {
    /// Project identifier. `store_path` uses it as the file stem of the
    /// on-disk store (`<project_hash>.json`).
    pub project_hash: String,
    /// Known procedures. `add_procedure` keeps at most
    /// `policy.max_procedures` entries, dropping the lowest-confidence ones.
    pub procedures: Vec<Procedure>,
}
20
/// A reusable workflow: an ordered list of tool steps together with usage
/// statistics and the keywords that make it relevant to a task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Procedure {
    /// Identifier used by `record_usage` to find the procedure. Detected
    /// procedures get `proc-` followed by a short hash of their name.
    pub id: String,
    /// Name of the procedure; `add_procedure` merges procedures that share
    /// the same name.
    pub name: String,
    /// Free-text description.
    pub description: String,
    /// Steps of the workflow, in execution order.
    pub steps: Vec<ProcedureStep>,
    /// Keywords that `matches_context` searches for (case-insensitively) in
    /// a task description.
    pub activation_keywords: Vec<String>,
    /// Confidence score. `record_usage` keeps it within `0.0..=1.0`, and
    /// `suggest` ignores procedures below 0.3.
    pub confidence: f32,
    /// Number of recorded uses.
    pub times_used: u32,
    /// Number of recorded uses that succeeded.
    pub times_succeeded: u32,
    /// Time of the most recent recorded use; feeds the recency part of the
    /// suggestion score.
    pub last_used: DateTime<Utc>,
    /// Set to `true` for procedures created by `detect_patterns`; not read
    /// anywhere else in this file.
    pub project_specific: bool,
    /// Time the procedure was created.
    pub created_at: DateTime<Utc>,
}
35
/// One step of a [`Procedure`]: a single tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct ProcedureStep {
    /// Name of the tool to invoke (for example `ctx_shell`).
    pub tool: String,
    /// Free-text description of the step; pattern detection leaves it empty.
    pub description: String,
    /// Whether the step may be skipped; `format_suggestion` marks such steps
    /// with ` (optional)`.
    pub optional: bool,
}
42
43impl Procedure {
44 pub fn success_rate(&self) -> f32 {
45 if self.times_used == 0 {
46 return 0.0;
47 }
48 self.times_succeeded as f32 / self.times_used as f32
49 }
50
51 pub fn matches_context(&self, task: &str) -> bool {
52 let task_lower = task.to_lowercase();
53 self.activation_keywords
54 .iter()
55 .any(|kw| task_lower.contains(&kw.to_lowercase()))
56 }
57}
58
59impl ProceduralStore {
60 pub fn new(project_hash: &str) -> Self {
61 Self {
62 project_hash: project_hash.to_string(),
63 procedures: Vec::new(),
64 }
65 }
66
67 pub fn suggest(&self, task: &str) -> Vec<&Procedure> {
68 let mut matches: Vec<(&Procedure, f32)> = self
69 .procedures
70 .iter()
71 .filter(|p| p.matches_context(task) && p.confidence >= 0.3)
72 .map(|p| {
73 let score = p.confidence * 0.5 + p.success_rate() * 0.3 + usage_recency(p) * 0.2;
74 (p, score)
75 })
76 .collect();
77
78 matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
79 matches.into_iter().map(|(p, _)| p).collect()
80 }
81
82 pub fn record_usage(&mut self, procedure_id: &str, success: bool) {
83 if let Some(proc) = self.procedures.iter_mut().find(|p| p.id == procedure_id) {
84 proc.times_used += 1;
85 if success {
86 proc.times_succeeded += 1;
87 }
88 proc.last_used = Utc::now();
89 proc.confidence =
90 (proc.confidence * 0.8 + if success { 0.2 } else { -0.1 }).clamp(0.0, 1.0);
91 }
92 }
93
94 pub fn add_procedure(&mut self, procedure: Procedure, policy: &ProceduralPolicy) {
95 if let Some(existing) = self
96 .procedures
97 .iter_mut()
98 .find(|p| p.name == procedure.name)
99 {
100 existing.confidence = existing.confidence.midpoint(procedure.confidence);
101 existing.steps = procedure.steps;
102 existing.activation_keywords = procedure.activation_keywords;
103 } else {
104 self.procedures.push(procedure);
105 }
106
107 if self.procedures.len() > policy.max_procedures {
108 self.procedures.sort_by(|a, b| {
109 b.confidence
110 .partial_cmp(&a.confidence)
111 .unwrap_or(std::cmp::Ordering::Equal)
112 });
113 self.procedures.truncate(policy.max_procedures);
114 }
115 }
116
117 pub fn detect_patterns(&mut self, episodes: &[Episode], policy: &ProceduralPolicy) {
118 let sequences = extract_tool_sequences(episodes);
119 let patterns = find_repeated_sequences(&sequences, policy);
120
121 for (steps, count, keywords) in patterns {
122 if count < policy.min_repetitions || steps.len() < policy.min_sequence_len {
123 continue;
124 }
125
126 let name = generate_procedure_name(&steps);
127 let already_exists = self.procedures.iter().any(|p| p.name == name);
128 if already_exists {
129 continue;
130 }
131
132 let success_count = episodes
133 .iter()
134 .filter(|ep| matches!(ep.outcome, Outcome::Success { .. }))
135 .count();
136 let confidence = success_count as f32 / episodes.len().max(1) as f32;
137
138 self.add_procedure(
139 Procedure {
140 id: format!("proc-{}", md5_short(&name)),
141 name,
142 description: format!("Detected workflow ({count} repetitions)"),
143 steps,
144 activation_keywords: keywords,
145 confidence,
146 times_used: count as u32,
147 times_succeeded: success_count as u32,
148 last_used: Utc::now(),
149 project_specific: true,
150 created_at: Utc::now(),
151 },
152 policy,
153 );
154 }
155 }
156
157 fn store_path(project_hash: &str) -> Option<PathBuf> {
158 let dir = crate::core::data_dir::lean_ctx_data_dir()
159 .ok()?
160 .join("memory")
161 .join("procedures");
162 Some(dir.join(format!("{project_hash}.json")))
163 }
164
165 pub fn load(project_hash: &str) -> Option<Self> {
166 let path = Self::store_path(project_hash)?;
167 let data = std::fs::read_to_string(path).ok()?;
168 serde_json::from_str(&data).ok()
169 }
170
171 pub fn load_or_create(project_hash: &str) -> Self {
172 Self::load(project_hash).unwrap_or_else(|| Self::new(project_hash))
173 }
174
175 pub fn save(&self) -> Result<(), String> {
176 let path = Self::store_path(&self.project_hash)
177 .ok_or_else(|| "Cannot determine data directory".to_string())?;
178 if let Some(dir) = path.parent() {
179 std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
180 }
181 let json = serde_json::to_string_pretty(self).map_err(|e| format!("{e}"))?;
182 std::fs::write(path, json).map_err(|e| format!("{e}"))
183 }
184}
185
186fn extract_tool_sequences(episodes: &[Episode]) -> Vec<Vec<String>> {
187 episodes
188 .iter()
189 .map(|ep| ep.actions.iter().map(|a| a.tool.clone()).collect())
190 .collect()
191}
192
193fn find_repeated_sequences(
194 sequences: &[Vec<String>],
195 policy: &ProceduralPolicy,
196) -> Vec<(Vec<ProcedureStep>, usize, Vec<String>)> {
197 let mut ngram_counts: HashMap<Vec<String>, usize> = HashMap::new();
198
199 for seq in sequences {
200 if seq.len() < policy.min_sequence_len {
201 continue;
202 }
203 let max_win = seq.len().min(policy.max_window_size);
204 for window_size in policy.min_sequence_len..=max_win {
205 for window in seq.windows(window_size) {
206 let key: Vec<String> = window.to_vec();
207 *ngram_counts.entry(key).or_insert(0) += 1;
208 }
209 }
210 }
211
212 let mut results: Vec<(Vec<ProcedureStep>, usize, Vec<String>)> = Vec::new();
213
214 let mut sorted: Vec<_> = ngram_counts.into_iter().collect();
215 sorted.sort_by(|a, b| {
216 let score_a = a.1 * a.0.len();
217 let score_b = b.1 * b.0.len();
218 score_b.cmp(&score_a)
219 });
220
221 let mut seen_prefixes: std::collections::HashSet<String> = std::collections::HashSet::new();
222
223 for (tools, count) in sorted {
224 if count < policy.min_repetitions {
225 continue;
226 }
227
228 let prefix = tools.join("->");
229 let is_substring = seen_prefixes.iter().any(|s| s.contains(&prefix));
230 if is_substring {
231 continue;
232 }
233
234 seen_prefixes.insert(prefix);
235
236 let steps: Vec<ProcedureStep> = tools
237 .iter()
238 .map(|t| ProcedureStep {
239 tool: t.clone(),
240 description: String::new(),
241 optional: false,
242 })
243 .collect();
244
245 let keywords: Vec<String> = tools
246 .iter()
247 .filter(|t| !t.starts_with("ctx_"))
248 .cloned()
249 .collect();
250
251 results.push((steps, count, keywords));
252 }
253
254 results
255}
256
257fn generate_procedure_name(steps: &[ProcedureStep]) -> String {
258 let tools: Vec<&str> = steps.iter().map(|s| s.tool.as_str()).collect();
259 let short: Vec<&str> = tools
260 .iter()
261 .map(|t| t.strip_prefix("ctx_").unwrap_or(t))
262 .collect();
263 format!("workflow-{}", short.join("-"))
264}
265
266fn md5_short(input: &str) -> String {
267 use md5::{Digest, Md5};
268 let result = Md5::digest(input.as_bytes());
269 format!("{result:x}")[..8].to_string()
270}
271
272fn usage_recency(proc: &Procedure) -> f32 {
273 let days_old = Utc::now().signed_duration_since(proc.last_used).num_days() as f32;
274 (1.0 - days_old / 30.0).max(0.0)
275}
276
277pub fn format_suggestion(proc: &Procedure) -> String {
278 let mut output = format!(
279 "Suggested workflow: {} (confidence: {:.0}%, used {}x, success rate: {:.0}%)\n",
280 proc.name,
281 proc.confidence * 100.0,
282 proc.times_used,
283 proc.success_rate() * 100.0
284 );
285 for (i, step) in proc.steps.iter().enumerate() {
286 let opt = if step.optional { " (optional)" } else { "" };
287 output.push_str(&format!(" {}. {}{opt}\n", i + 1, step.tool));
288 }
289 output
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::episodic_memory::{Action, Episode, Outcome};

    /// Builds a successful episode whose actions invoke `tools` in order.
    fn make_episode_with_tools(tools: &[&str]) -> Episode {
        let actions = tools
            .iter()
            .map(|tool| Action {
                tool: (*tool).to_string(),
                description: String::new(),
                timestamp: Utc::now(),
                duration_ms: 100,
                success: true,
            })
            .collect();

        Episode {
            id: "ep-1".to_string(),
            session_id: "s-1".to_string(),
            timestamp: Utc::now(),
            task_description: "test task".to_string(),
            actions,
            outcome: Outcome::Success { tests_passed: true },
            affected_files: vec![],
            summary: String::new(),
            duration_secs: 60,
            tokens_used: 1000,
        }
    }

    /// Builds a procedure with the given identity and statistics, no steps,
    /// no keywords, an empty description and `project_specific == false`.
    fn bare_procedure(
        id: &str,
        name: &str,
        confidence: f32,
        times_used: u32,
        times_succeeded: u32,
    ) -> Procedure {
        Procedure {
            id: id.to_string(),
            name: name.to_string(),
            description: String::new(),
            steps: vec![],
            activation_keywords: vec![],
            confidence,
            times_used,
            times_succeeded,
            last_used: Utc::now(),
            project_specific: false,
            created_at: Utc::now(),
        }
    }

    /// Builds a single procedure step.
    fn step(tool: &str, description: &str, optional: bool) -> ProcedureStep {
        ProcedureStep {
            tool: tool.to_string(),
            description: description.to_string(),
            optional,
        }
    }

    #[test]
    fn detect_patterns_from_episodes() {
        let policy = ProceduralPolicy::default();
        let mut episodes: Vec<Episode> = Vec::new();
        for _ in 0..5 {
            episodes.push(make_episode_with_tools(&["ctx_read", "ctx_shell", "ctx_read"]));
        }

        let mut store = ProceduralStore::new("test");
        store.detect_patterns(&episodes, &policy);

        assert!(
            !store.procedures.is_empty(),
            "Should detect at least one pattern"
        );
    }

    #[test]
    fn suggest_matching_procedure() {
        let policy = ProceduralPolicy::default();
        let mut store = ProceduralStore::new("test");
        let deploy = Procedure {
            description: "Deploy".to_string(),
            steps: vec![step("ctx_shell", "cargo build", false)],
            activation_keywords: vec!["deploy".to_string(), "release".to_string()],
            project_specific: true,
            ..bare_procedure("proc-1", "deploy-workflow", 0.8, 5, 4)
        };
        store.add_procedure(deploy, &policy);

        let suggestions = store.suggest("deploy the new version");
        assert_eq!(suggestions.len(), 1);
        assert_eq!(suggestions[0].name, "deploy-workflow");

        let none = store.suggest("refactor the database layer");
        assert!(none.is_empty());
    }

    #[test]
    fn record_usage_updates_confidence() {
        let policy = ProceduralPolicy::default();
        let mut store = ProceduralStore::new("test");
        let workflow = Procedure {
            description: "Test".to_string(),
            ..bare_procedure("proc-1", "test-workflow", 0.5, 0, 0)
        };
        store.add_procedure(workflow, &policy);

        store.record_usage("proc-1", true);
        let proc = &store.procedures[0];
        assert_eq!(proc.times_used, 1);
        assert_eq!(proc.times_succeeded, 1);
        assert!(proc.confidence > 0.5);
    }

    #[test]
    fn success_rate_calculation() {
        let proc = bare_procedure("p", "n", 0.5, 10, 7);
        assert!((proc.success_rate() - 0.7).abs() < 0.01);
    }

    #[test]
    fn max_procedures_enforced() {
        let policy = ProceduralPolicy::default();
        let mut store = ProceduralStore::new("test");
        for i in 0..110 {
            let candidate = bare_procedure(
                &format!("p-{i}"),
                &format!("workflow-{i}"),
                i as f32 / 110.0,
                0,
                0,
            );
            store.add_procedure(candidate, &policy);
        }
        assert!(store.procedures.len() <= policy.max_procedures);
    }

    #[test]
    fn format_suggestion_output() {
        let proc = Procedure {
            steps: vec![
                step("ctx_shell", "test", false),
                step("ctx_shell", "build", true),
            ],
            ..bare_procedure("p", "deploy-workflow", 0.85, 10, 8)
        };
        let output = format_suggestion(&proc);
        assert!(output.contains("deploy-workflow"));
        assert!(output.contains("85%"));
        assert!(output.contains("(optional)"));
    }
}