// deepstrike_core/memory/extractor.rs

use crate::memory::semantic::MemoryEntry;
use crate::types::message::{Content, ContentPart, Message, Role};

/// Rules deciding which conversation messages are turned into memory entries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractionPolicy {
    /// Minimum length an assistant message or a successful tool result must
    /// reach before it is extracted.
    pub min_length: usize,
    /// Whether successful tool results are considered for extraction.
    pub include_tool_results: bool,
    /// Whether user messages phrased as questions (ending in `?`) are extracted.
    pub include_questions: bool,
}

10impl Default for ExtractionPolicy {
11    fn default() -> Self {
12        Self {
13            min_length: 100,
14            include_tool_results: true,
15            include_questions: true,
16        }
17    }
18}
19
20pub struct MemoryExtractor {
21    pub policy: ExtractionPolicy,
22}
23
24impl MemoryExtractor {
25    pub fn new(policy: ExtractionPolicy) -> Self {
26        Self { policy }
27    }
28
29    pub fn extract(&self, messages: &[Message]) -> Vec<MemoryEntry> {
30        let mut entries = Vec::new();
31        for msg in messages {
32            match msg.role {
33                Role::Assistant => {
34                    if let Some(text) = msg.content.as_text() {
35                        if text.len() >= self.policy.min_length {
36                            entries.push(entry(text));
37                        }
38                    }
39                }
40                Role::User if self.policy.include_questions => {
41                    if let Some(text) = msg.content.as_text() {
42                        if text.ends_with('?') {
43                            entries.push(entry(text));
44                        }
45                    }
46                }
47                Role::Tool if self.policy.include_tool_results => {
48                    if let Content::Parts(parts) = &msg.content {
49                        for part in parts {
50                            if let ContentPart::ToolResult {
51                                output, is_error, ..
52                            } = part
53                            {
54                                if !is_error && output.len() >= self.policy.min_length {
55                                    entries.push(entry(output));
56                                }
57                            }
58                        }
59                    }
60                }
61                _ => {}
62            }
63        }
64        entries
65    }
66}
67
68fn entry(text: &str) -> MemoryEntry {
69    MemoryEntry {
70        text: text.to_string(),
71        score: 0.0,
72        metadata: serde_json::Value::Null,
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Extractor configured with the stock policy (`min_length` = 100).
    fn default_extractor() -> MemoryExtractor {
        MemoryExtractor::new(ExtractionPolicy::default())
    }

    #[test]
    fn extracts_long_assistant_messages() {
        let entries = default_extractor().extract(&[Message::assistant("a".repeat(101))]);
        assert_eq!(entries.len(), 1);
    }

    #[test]
    fn min_length_is_inclusive() {
        let entries = default_extractor().extract(&[Message::assistant("a".repeat(100))]);
        assert_eq!(entries.len(), 1);
    }

    #[test]
    fn skips_short_assistant_messages() {
        let entries = default_extractor().extract(&[Message::assistant("short")]);
        assert!(entries.is_empty());
    }

    #[test]
    fn extracts_user_questions() {
        let entries = default_extractor().extract(&[Message::user("What is the answer?")]);
        assert_eq!(entries.len(), 1);
    }

    #[test]
    fn skips_user_statements() {
        let entries = default_extractor().extract(&[Message::user("This is a statement.")]);
        assert!(entries.is_empty());
    }

    #[test]
    fn questions_can_be_disabled() {
        let extractor = MemoryExtractor::new(ExtractionPolicy {
            include_questions: false,
            ..ExtractionPolicy::default()
        });
        let entries = extractor.extract(&[Message::user("What is the answer?")]);
        assert!(entries.is_empty());
    }

    #[test]
    fn preserves_message_order() {
        let long = "a".repeat(101);
        let entries = default_extractor().extract(&[
            Message::user("What is the answer?"),
            Message::assistant(long.clone()),
        ]);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].text, "What is the answer?");
        assert_eq!(entries[1].text, long);
    }
}