Skip to main content

lean_ctx/core/
procedural_memory.rs

1//! Procedural Memory — recurring workflow detection and template storage.
2//!
3//! Detects repeated tool-call sequences in Episodic Memory and stores them
4//! as reusable Procedures with activation/termination conditions.
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::PathBuf;
10
11use super::episodic_memory::{Episode, Outcome};
12
13use crate::core::memory_policy::ProceduralPolicy;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ProceduralStore {
17    pub project_hash: String,
18    pub procedures: Vec<Procedure>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct Procedure {
23    pub id: String,
24    pub name: String,
25    pub description: String,
26    pub steps: Vec<ProcedureStep>,
27    pub activation_keywords: Vec<String>,
28    pub confidence: f32,
29    pub times_used: u32,
30    pub times_succeeded: u32,
31    pub last_used: DateTime<Utc>,
32    pub project_specific: bool,
33    pub created_at: DateTime<Utc>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
37pub struct ProcedureStep {
38    pub tool: String,
39    pub description: String,
40    pub optional: bool,
41}
42
43impl Procedure {
44    pub fn success_rate(&self) -> f32 {
45        if self.times_used == 0 {
46            return 0.0;
47        }
48        self.times_succeeded as f32 / self.times_used as f32
49    }
50
51    pub fn matches_context(&self, task: &str) -> bool {
52        let task_lower = task.to_lowercase();
53        self.activation_keywords
54            .iter()
55            .any(|kw| task_lower.contains(&kw.to_lowercase()))
56    }
57}
58
59impl ProceduralStore {
60    pub fn new(project_hash: &str) -> Self {
61        Self {
62            project_hash: project_hash.to_string(),
63            procedures: Vec::new(),
64        }
65    }
66
67    pub fn suggest(&self, task: &str) -> Vec<&Procedure> {
68        let mut matches: Vec<(&Procedure, f32)> = self
69            .procedures
70            .iter()
71            .filter(|p| p.matches_context(task) && p.confidence >= 0.3)
72            .map(|p| {
73                let score = p.confidence * 0.5 + p.success_rate() * 0.3 + usage_recency(p) * 0.2;
74                (p, score)
75            })
76            .collect();
77
78        matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
79        matches.into_iter().map(|(p, _)| p).collect()
80    }
81
82    pub fn record_usage(&mut self, procedure_id: &str, success: bool) {
83        if let Some(proc) = self.procedures.iter_mut().find(|p| p.id == procedure_id) {
84            proc.times_used += 1;
85            if success {
86                proc.times_succeeded += 1;
87            }
88            proc.last_used = Utc::now();
89            proc.confidence =
90                (proc.confidence * 0.8 + if success { 0.2 } else { -0.1 }).clamp(0.0, 1.0);
91        }
92    }
93
94    pub fn add_procedure(&mut self, procedure: Procedure, policy: &ProceduralPolicy) {
95        if let Some(existing) = self
96            .procedures
97            .iter_mut()
98            .find(|p| p.name == procedure.name)
99        {
100            existing.confidence = existing.confidence.midpoint(procedure.confidence);
101            existing.steps = procedure.steps;
102            existing.activation_keywords = procedure.activation_keywords;
103        } else {
104            self.procedures.push(procedure);
105        }
106
107        if self.procedures.len() > policy.max_procedures {
108            self.procedures.sort_by(|a, b| {
109                b.confidence
110                    .partial_cmp(&a.confidence)
111                    .unwrap_or(std::cmp::Ordering::Equal)
112            });
113            self.procedures.truncate(policy.max_procedures);
114        }
115    }
116
117    pub fn detect_patterns(&mut self, episodes: &[Episode], policy: &ProceduralPolicy) {
118        let sequences = extract_tool_sequences(episodes);
119        let patterns = find_repeated_sequences(&sequences, policy);
120
121        for (steps, count, keywords) in patterns {
122            if count < policy.min_repetitions || steps.len() < policy.min_sequence_len {
123                continue;
124            }
125
126            let name = generate_procedure_name(&steps);
127            let already_exists = self.procedures.iter().any(|p| p.name == name);
128            if already_exists {
129                continue;
130            }
131
132            let success_count = episodes
133                .iter()
134                .filter(|ep| matches!(ep.outcome, Outcome::Success { .. }))
135                .count();
136            let confidence = success_count as f32 / episodes.len().max(1) as f32;
137
138            self.add_procedure(
139                Procedure {
140                    id: format!("proc-{}", md5_short(&name)),
141                    name,
142                    description: format!("Detected workflow ({count} repetitions)"),
143                    steps,
144                    activation_keywords: keywords,
145                    confidence,
146                    times_used: count as u32,
147                    times_succeeded: success_count as u32,
148                    last_used: Utc::now(),
149                    project_specific: true,
150                    created_at: Utc::now(),
151                },
152                policy,
153            );
154        }
155    }
156
157    fn store_path(project_hash: &str) -> Option<PathBuf> {
158        let dir = crate::core::data_dir::lean_ctx_data_dir()
159            .ok()?
160            .join("memory")
161            .join("procedures");
162        Some(dir.join(format!("{project_hash}.json")))
163    }
164
165    pub fn load(project_hash: &str) -> Option<Self> {
166        let path = Self::store_path(project_hash)?;
167        let data = std::fs::read_to_string(path).ok()?;
168        serde_json::from_str(&data).ok()
169    }
170
171    pub fn load_or_create(project_hash: &str) -> Self {
172        Self::load(project_hash).unwrap_or_else(|| Self::new(project_hash))
173    }
174
175    pub fn save(&self) -> Result<(), String> {
176        let path = Self::store_path(&self.project_hash)
177            .ok_or_else(|| "Cannot determine data directory".to_string())?;
178        if let Some(dir) = path.parent() {
179            std::fs::create_dir_all(dir).map_err(|e| format!("{e}"))?;
180        }
181        let json = serde_json::to_string_pretty(self).map_err(|e| format!("{e}"))?;
182        std::fs::write(path, json).map_err(|e| format!("{e}"))
183    }
184}
185
186fn extract_tool_sequences(episodes: &[Episode]) -> Vec<Vec<String>> {
187    episodes
188        .iter()
189        .map(|ep| ep.actions.iter().map(|a| a.tool.clone()).collect())
190        .collect()
191}
192
193fn find_repeated_sequences(
194    sequences: &[Vec<String>],
195    policy: &ProceduralPolicy,
196) -> Vec<(Vec<ProcedureStep>, usize, Vec<String>)> {
197    let mut ngram_counts: HashMap<Vec<String>, usize> = HashMap::new();
198
199    for seq in sequences {
200        if seq.len() < policy.min_sequence_len {
201            continue;
202        }
203        let max_win = seq.len().min(policy.max_window_size);
204        for window_size in policy.min_sequence_len..=max_win {
205            for window in seq.windows(window_size) {
206                let key: Vec<String> = window.to_vec();
207                *ngram_counts.entry(key).or_insert(0) += 1;
208            }
209        }
210    }
211
212    let mut results: Vec<(Vec<ProcedureStep>, usize, Vec<String>)> = Vec::new();
213
214    let mut sorted: Vec<_> = ngram_counts.into_iter().collect();
215    sorted.sort_by(|a, b| {
216        let score_a = a.1 * a.0.len();
217        let score_b = b.1 * b.0.len();
218        score_b.cmp(&score_a)
219    });
220
221    let mut seen_prefixes: std::collections::HashSet<String> = std::collections::HashSet::new();
222
223    for (tools, count) in sorted {
224        if count < policy.min_repetitions {
225            continue;
226        }
227
228        let prefix = tools.join("->");
229        let is_substring = seen_prefixes.iter().any(|s| s.contains(&prefix));
230        if is_substring {
231            continue;
232        }
233
234        seen_prefixes.insert(prefix);
235
236        let steps: Vec<ProcedureStep> = tools
237            .iter()
238            .map(|t| ProcedureStep {
239                tool: t.clone(),
240                description: String::new(),
241                optional: false,
242            })
243            .collect();
244
245        let keywords: Vec<String> = tools
246            .iter()
247            .filter(|t| !t.starts_with("ctx_"))
248            .cloned()
249            .collect();
250
251        results.push((steps, count, keywords));
252    }
253
254    results
255}
256
257fn generate_procedure_name(steps: &[ProcedureStep]) -> String {
258    let tools: Vec<&str> = steps.iter().map(|s| s.tool.as_str()).collect();
259    let short: Vec<&str> = tools
260        .iter()
261        .map(|t| t.strip_prefix("ctx_").unwrap_or(t))
262        .collect();
263    format!("workflow-{}", short.join("-"))
264}
265
266fn md5_short(input: &str) -> String {
267    use md5::{Digest, Md5};
268    let result = Md5::digest(input.as_bytes());
269    format!("{result:x}")[..8].to_string()
270}
271
272fn usage_recency(proc: &Procedure) -> f32 {
273    let days_old = Utc::now().signed_duration_since(proc.last_used).num_days() as f32;
274    (1.0 - days_old / 30.0).max(0.0)
275}
276
277pub fn format_suggestion(proc: &Procedure) -> String {
278    let mut output = format!(
279        "Suggested workflow: {} (confidence: {:.0}%, used {}x, success rate: {:.0}%)\n",
280        proc.name,
281        proc.confidence * 100.0,
282        proc.times_used,
283        proc.success_rate() * 100.0
284    );
285    for (i, step) in proc.steps.iter().enumerate() {
286        let opt = if step.optional { " (optional)" } else { "" };
287        output.push_str(&format!("  {}. {}{opt}\n", i + 1, step.tool));
288    }
289    output
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::episodic_memory::{Action, Episode, Outcome};

    /// Successful episode whose actions call `tools` in order.
    fn make_episode_with_tools(tools: &[&str]) -> Episode {
        let actions = tools
            .iter()
            .map(|t| Action {
                tool: t.to_string(),
                description: String::new(),
                timestamp: Utc::now(),
                duration_ms: 100,
                success: true,
            })
            .collect();

        Episode {
            id: "ep-1".to_string(),
            session_id: "s-1".to_string(),
            timestamp: Utc::now(),
            task_description: "test task".to_string(),
            actions,
            outcome: Outcome::Success { tests_passed: true },
            affected_files: vec![],
            summary: String::new(),
            duration_secs: 60,
            tokens_used: 1000,
        }
    }

    /// Baseline procedure: no steps, no keywords, never used, confidence 0.5.
    /// Individual tests override the fields they care about.
    fn blank_procedure(id: &str, name: &str) -> Procedure {
        Procedure {
            id: id.to_string(),
            name: name.to_string(),
            description: String::new(),
            steps: vec![],
            activation_keywords: vec![],
            confidence: 0.5,
            times_used: 0,
            times_succeeded: 0,
            last_used: Utc::now(),
            project_specific: false,
            created_at: Utc::now(),
        }
    }

    /// Step calling `tool`, with the given description and optional flag.
    fn step(tool: &str, description: &str, optional: bool) -> ProcedureStep {
        ProcedureStep {
            tool: tool.to_string(),
            description: description.to_string(),
            optional,
        }
    }

    #[test]
    fn detect_patterns_from_episodes() {
        let policy = ProceduralPolicy::default();
        let mut episodes: Vec<Episode> = Vec::new();
        for _ in 0..5 {
            episodes.push(make_episode_with_tools(&["ctx_read", "ctx_shell", "ctx_read"]));
        }

        let mut store = ProceduralStore::new("test");
        store.detect_patterns(&episodes, &policy);

        assert!(
            !store.procedures.is_empty(),
            "Should detect at least one pattern"
        );
    }

    #[test]
    fn suggest_matching_procedure() {
        let policy = ProceduralPolicy::default();
        let mut store = ProceduralStore::new("test");
        let deploy = Procedure {
            description: "Deploy".to_string(),
            steps: vec![step("ctx_shell", "cargo build", false)],
            activation_keywords: vec!["deploy".to_string(), "release".to_string()],
            confidence: 0.8,
            times_used: 5,
            times_succeeded: 4,
            project_specific: true,
            ..blank_procedure("proc-1", "deploy-workflow")
        };
        store.add_procedure(deploy, &policy);

        let suggestions = store.suggest("deploy the new version");
        assert_eq!(suggestions.len(), 1);
        assert_eq!(suggestions[0].name, "deploy-workflow");

        let none = store.suggest("refactor the database layer");
        assert!(none.is_empty());
    }

    #[test]
    fn record_usage_updates_confidence() {
        let policy = ProceduralPolicy::default();
        let mut store = ProceduralStore::new("test");
        let fresh = Procedure {
            description: "Test".to_string(),
            ..blank_procedure("proc-1", "test-workflow")
        };
        store.add_procedure(fresh, &policy);

        store.record_usage("proc-1", true);

        let proc = &store.procedures[0];
        assert_eq!(proc.times_used, 1);
        assert_eq!(proc.times_succeeded, 1);
        assert!(proc.confidence > 0.5);
    }

    #[test]
    fn success_rate_calculation() {
        let proc = Procedure {
            times_used: 10,
            times_succeeded: 7,
            ..blank_procedure("p", "n")
        };
        assert!((proc.success_rate() - 0.7).abs() < 0.01);
    }

    #[test]
    fn max_procedures_enforced() {
        let policy = ProceduralPolicy::default();
        let mut store = ProceduralStore::new("test");
        for i in 0..110 {
            let candidate = Procedure {
                confidence: i as f32 / 110.0,
                ..blank_procedure(&format!("p-{i}"), &format!("workflow-{i}"))
            };
            store.add_procedure(candidate, &policy);
        }
        assert!(store.procedures.len() <= policy.max_procedures);
    }

    #[test]
    fn format_suggestion_output() {
        let proc = Procedure {
            steps: vec![
                step("ctx_shell", "test", false),
                step("ctx_shell", "build", true),
            ],
            confidence: 0.85,
            times_used: 10,
            times_succeeded: 8,
            ..blank_procedure("p", "deploy-workflow")
        };
        let output = format_suggestion(&proc);
        assert!(output.contains("deploy-workflow"));
        assert!(output.contains("85%"));
        assert!(output.contains("(optional)"));
    }
}
472}