Skip to main content

brainwires_reasoning/
summarizer.rs

1//! Summarizer - Context Summarization
2//!
3//! Uses a provider to generate summaries for tiered memory demotion,
4//! reducing the need for expensive API calls for context compression.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Initial byte capacity reserved for the condensed conversation text that
/// `LocalSummarizer::compact_conversation` assembles before prompting.
const CONTEXT_BUFFER_CAPACITY: usize = 4000;
15
/// Outcome of a single summarization attempt.
#[derive(Debug, Clone)]
pub struct SummarizationResult {
    /// Text of the summary that was produced.
    pub summary: String,
    /// How much the summary can be trusted, on a 0.0 - 1.0 scale.
    pub confidence: f32,
    /// `true` when the summary came from the LLM, `false` for the fallback path.
    pub used_local_llm: bool,
}

impl SummarizationResult {
    /// Wrap an LLM-generated `summary` together with the caller-supplied
    /// `confidence`.
    pub fn from_local(summary: String, confidence: f32) -> Self {
        SummarizationResult {
            used_local_llm: true,
            confidence,
            summary,
        }
    }

    /// Wrap a fallback (non-LLM) `summary`; the confidence is fixed at 0.3.
    pub fn from_fallback(summary: String) -> Self {
        SummarizationResult {
            used_local_llm: false,
            confidence: 0.3,
            summary,
        }
    }
}
46
47/// Key fact extracted from content
48#[derive(Clone, Debug)]
49pub struct ExtractedFact {
50    /// The fact content
51    pub fact: String,
52    /// Type of fact (decision, definition, requirement, etc.)
53    pub fact_type: FactCategory,
54    /// Confidence score
55    pub confidence: f32,
56}
57
/// The kinds of fact an [`ExtractedFact`] can be tagged with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FactCategory {
    /// Something that was decided.
    Decision,
    /// The meaning of a term, or a clarification of it.
    Definition,
    /// A constraint or something that is required.
    Requirement,
    /// A modification to code, including fixes.
    CodeChange,
    /// A change to configuration or settings.
    Configuration,
    /// A pointer to external documentation or resources.
    Reference,
    /// Anything that fits none of the other categories.
    Other,
}

impl FactCategory {
    /// Classify a free-form label by case-insensitive keyword search.
    ///
    /// Keyword groups are tried in a fixed order (decision, definition,
    /// requirement, code change, configuration, reference) and the first
    /// group with a keyword occurring anywhere in `s` wins. Labels matching
    /// no group map to [`FactCategory::Other`]; this never fails.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        /// True when `text` contains at least one of `keywords` as a substring.
        fn mentions_any(text: &str, keywords: &[&str]) -> bool {
            keywords.iter().any(|keyword| text.contains(*keyword))
        }

        let label = s.to_lowercase();

        if mentions_any(&label, &["decision"]) {
            return FactCategory::Decision;
        }
        if mentions_any(&label, &["definition", "define"]) {
            return FactCategory::Definition;
        }
        if mentions_any(&label, &["requirement", "must", "should"]) {
            return FactCategory::Requirement;
        }
        if mentions_any(&label, &["code", "change", "fix"]) {
            return FactCategory::CodeChange;
        }
        if mentions_any(&label, &["config", "setting"]) {
            return FactCategory::Configuration;
        }
        if mentions_any(&label, &["reference", "see", "link"]) {
            return FactCategory::Reference;
        }
        FactCategory::Other
    }
}
102
103/// Summarizer for context compression
104pub struct LocalSummarizer {
105    provider: Arc<dyn Provider>,
106    model_id: String,
107    /// Maximum tokens for summary output
108    max_summary_tokens: u32,
109    /// Maximum facts to extract per summary
110    max_facts: usize,
111}
112
113impl LocalSummarizer {
114    /// Create a new summarizer
115    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
116        Self {
117            provider,
118            model_id: model_id.into(),
119            max_summary_tokens: 150,
120            max_facts: 5,
121        }
122    }
123
124    /// Set maximum summary tokens
125    pub fn with_max_summary_tokens(mut self, tokens: u32) -> Self {
126        self.max_summary_tokens = tokens;
127        self
128    }
129
130    /// Set maximum facts to extract
131    pub fn with_max_facts(mut self, facts: usize) -> Self {
132        self.max_facts = facts;
133        self
134    }
135
136    /// Summarize a message for warm tier storage
137    ///
138    /// Generates a 50-100 word summary suitable for the warm memory tier.
139    pub async fn summarize_message(
140        &self,
141        content: &str,
142        role: &str,
143    ) -> Option<SummarizationResult> {
144        let timer = InferenceTimer::new("summarize_message", &self.model_id);
145
146        // Skip very short content
147        if content.trim().len() < 50 {
148            return Some(SummarizationResult::from_fallback(content.to_string()));
149        }
150
151        let prompt = format!(
152            "Summarize this {} message in 50-100 words. Preserve key information, decisions, and technical details.\n\nMessage:\n{}\n\nSummary:",
153            role,
154            // Truncate very long content for efficiency
155            if content.len() > 2000 {
156                &content[..2000]
157            } else {
158                content
159            }
160        );
161
162        let messages = vec![Message::user(&prompt)];
163        let options = ChatOptions::creative(self.max_summary_tokens);
164
165        match self.provider.chat(&messages, None, &options).await {
166            Ok(response) => {
167                let summary = response.message.text_or_summary();
168                let cleaned = self.clean_summary(&summary);
169                if cleaned.len() < 10 {
170                    timer.finish(false);
171                    return None;
172                }
173                timer.finish(true);
174                Some(SummarizationResult::from_local(cleaned, 0.8))
175            }
176            Err(e) => {
177                warn!(target: "local_llm", "Message summarization failed: {}", e);
178                timer.finish(false);
179                None
180            }
181        }
182    }
183
184    /// Extract key facts from a summary for cold tier storage
185    ///
186    /// Parses structured facts from content for ultra-compressed archival.
187    pub async fn extract_facts(&self, summary: &str) -> Option<Vec<ExtractedFact>> {
188        let timer = InferenceTimer::new("extract_facts", &self.model_id);
189
190        // Skip very short summaries
191        if summary.trim().len() < 30 {
192            return Some(vec![ExtractedFact {
193                fact: summary.to_string(),
194                fact_type: FactCategory::Other,
195                confidence: 0.5,
196            }]);
197        }
198
199        let prompt = format!(
200            "Extract {} key facts from this text. Format each as: TYPE: fact\nTypes: Decision, Definition, Requirement, CodeChange, Configuration, Reference, Other\n\nText:\n{}\n\nFacts:",
201            self.max_facts, summary
202        );
203
204        let messages = vec![Message::user(&prompt)];
205        let options = ChatOptions::factual(200);
206
207        match self.provider.chat(&messages, None, &options).await {
208            Ok(response) => {
209                let output = response.message.text_or_summary();
210                let facts = self.parse_facts(&output);
211                if facts.is_empty() {
212                    timer.finish(false);
213                    return None;
214                }
215                timer.finish(true);
216                Some(facts)
217            }
218            Err(e) => {
219                warn!(target: "local_llm", "Fact extraction failed: {}", e);
220                timer.finish(false);
221                None
222            }
223        }
224    }
225
226    /// Compact a conversation for emergency context reduction
227    ///
228    /// Used when token count exceeds threshold (e.g., 80k tokens).
229    pub async fn compact_conversation(
230        &self,
231        messages: &[(String, String)], // (role, content) pairs
232        keep_recent: usize,
233    ) -> Option<String> {
234        let timer = InferenceTimer::new("compact_conversation", &self.model_id);
235
236        if messages.len() <= keep_recent {
237            return None; // Nothing to compact
238        }
239
240        let to_compact = &messages[..messages.len() - keep_recent];
241
242        // Build a condensed representation
243        let mut context = String::with_capacity(CONTEXT_BUFFER_CAPACITY);
244        for (role, content) in to_compact.iter().take(20) {
245            let truncated = if content.len() > 200 {
246                &content[..200]
247            } else {
248                content
249            };
250            context.push_str(&format!("[{}]: {}\n", role, truncated));
251        }
252
253        if to_compact.len() > 20 {
254            context.push_str(&format!(
255                "\n... ({} more messages)\n",
256                to_compact.len() - 20
257            ));
258        }
259
260        let prompt = format!(
261            "Summarize this conversation history in 200-300 words. Focus on: decisions made, key technical details, current task state.\n\n{}\n\nSummary:",
262            context
263        );
264
265        let chat_messages = vec![Message::user(&prompt)];
266        let options = ChatOptions::creative(400);
267
268        match self.provider.chat(&chat_messages, None, &options).await {
269            Ok(response) => {
270                let summary = response.message.text_or_summary();
271                let cleaned = self.clean_summary(&summary);
272                timer.finish(true);
273                Some(cleaned)
274            }
275            Err(e) => {
276                warn!(target: "local_llm", "Conversation compaction failed: {}", e);
277                timer.finish(false);
278                None
279            }
280        }
281    }
282
283    /// Heuristic summarization (no LLM)
284    pub fn summarize_heuristic(&self, content: &str) -> SummarizationResult {
285        SummarizationResult::from_fallback(self.truncate_summary(content))
286    }
287
288    /// Extract entities from content for summary metadata
289    pub fn extract_entities(&self, content: &str) -> Vec<String> {
290        let mut entities = Vec::new();
291
292        // Extract file paths
293        let path_patterns = [r"([a-zA-Z0-9_\-/]+\.[a-z]{2,4})", r"src/[a-zA-Z0-9_\-/]+"];
294        for pattern in path_patterns {
295            if let Ok(re) = regex::Regex::new(pattern) {
296                for cap in re.captures_iter(content) {
297                    if let Some(m) = cap.get(0) {
298                        let entity = m.as_str().to_string();
299                        if !entities.contains(&entity) && entity.len() > 3 {
300                            entities.push(entity);
301                        }
302                    }
303                }
304            }
305        }
306
307        // Extract function/type names (PascalCase or snake_case)
308        if let Ok(re) = regex::Regex::new(r"\b([A-Z][a-zA-Z0-9]+|[a-z]+_[a-z_]+)\b") {
309            for cap in re.captures_iter(content) {
310                if let Some(m) = cap.get(1) {
311                    let entity = m.as_str().to_string();
312                    if !entities.contains(&entity)
313                        && entity.len() > 3
314                        && ![
315                            "This", "That", "These", "Those", "What", "When", "Where", "Which",
316                        ]
317                        .contains(&entity.as_str())
318                    {
319                        entities.push(entity);
320                    }
321                }
322            }
323        }
324
325        // Limit to top 10 entities
326        entities.truncate(10);
327        entities
328    }
329
330    /// Truncate content to create a simple summary
331    fn truncate_summary(&self, content: &str) -> String {
332        let words: Vec<&str> = content.split_whitespace().collect();
333        if words.len() <= 100 {
334            content.to_string()
335        } else {
336            format!("{}...", words[..100].join(" "))
337        }
338    }
339
340    /// Clean up LLM output for summary
341    fn clean_summary(&self, output: &str) -> String {
342        let mut cleaned = output.trim().to_string();
343
344        // Remove common prefixes
345        let prefixes = [
346            "Summary:",
347            "Here's a summary:",
348            "Here is a summary:",
349            "The summary is:",
350        ];
351        for prefix in prefixes {
352            if cleaned.to_lowercase().starts_with(&prefix.to_lowercase()) {
353                cleaned = cleaned[prefix.len()..].trim().to_string();
354            }
355        }
356
357        // Remove trailing incomplete sentences
358        if let Some(last_period) = cleaned.rfind('.')
359            && last_period < cleaned.len() - 20
360        {
361            cleaned = cleaned[..=last_period].to_string();
362        }
363
364        cleaned
365    }
366
367    /// Parse facts from LLM output
368    fn parse_facts(&self, output: &str) -> Vec<ExtractedFact> {
369        let mut facts = Vec::new();
370
371        for line in output.lines() {
372            let line = line.trim();
373            if line.is_empty() {
374                continue;
375            }
376
377            // Try to parse "TYPE: fact" format
378            if let Some(colon_pos) = line.find(':') {
379                let type_part = &line[..colon_pos].trim();
380                let fact_part = line[colon_pos + 1..].trim();
381
382                if !fact_part.is_empty() {
383                    facts.push(ExtractedFact {
384                        fact: fact_part.to_string(),
385                        fact_type: FactCategory::from_str(type_part),
386                        confidence: 0.75,
387                    });
388                }
389            } else if line.len() > 10 {
390                // Line without type prefix
391                facts.push(ExtractedFact {
392                    fact: line.to_string(),
393                    fact_type: FactCategory::Other,
394                    confidence: 0.5,
395                });
396            }
397
398            if facts.len() >= self.max_facts {
399                break;
400            }
401        }
402
403        facts
404    }
405
406    /// Heuristic fact extraction (no LLM)
407    fn _extract_facts_heuristic(&self, content: &str) -> Vec<ExtractedFact> {
408        let mut facts = Vec::new();
409
410        // Look for sentences with decision/requirement indicators
411        for sentence in content.split(['.', '!', '?']) {
412            let sentence = sentence.trim();
413            if sentence.len() < 10 {
414                continue;
415            }
416
417            let lower = sentence.to_lowercase();
418            let fact_type = if lower.contains("decided")
419                || lower.contains("will use")
420                || lower.contains("chose")
421            {
422                FactCategory::Decision
423            } else if lower.contains("must")
424                || lower.contains("should")
425                || lower.contains("need to")
426            {
427                FactCategory::Requirement
428            } else if lower.contains("is defined as") || lower.contains("means") {
429                FactCategory::Definition
430            } else if lower.contains("changed")
431                || lower.contains("fixed")
432                || lower.contains("updated")
433            {
434                FactCategory::CodeChange
435            } else if lower.contains("configured") || lower.contains("set to") {
436                FactCategory::Configuration
437            } else {
438                continue; // Skip non-fact sentences
439            };
440
441            facts.push(ExtractedFact {
442                fact: sentence.to_string(),
443                fact_type,
444                confidence: 0.5,
445            });
446
447            if facts.len() >= self.max_facts {
448                break;
449            }
450        }
451
452        // If no facts found, create one from the first sentence
453        if facts.is_empty()
454            && let Some(first_sentence) = content.split('.').next()
455            && first_sentence.len() > 10
456        {
457            facts.push(ExtractedFact {
458                fact: first_sentence.trim().to_string(),
459                fact_type: FactCategory::Other,
460                confidence: 0.3,
461            });
462        }
463
464        facts
465    }
466}
467
468/// Builder for LocalSummarizer
469pub struct LocalSummarizerBuilder {
470    provider: Option<Arc<dyn Provider>>,
471    model_id: String,
472    max_summary_tokens: u32,
473    max_facts: usize,
474}
475
476impl Default for LocalSummarizerBuilder {
477    fn default() -> Self {
478        Self {
479            provider: None,
480            model_id: "lfm2-1.2b".to_string(), // Use larger model for summarization
481            max_summary_tokens: 150,
482            max_facts: 5,
483        }
484    }
485}
486
487impl LocalSummarizerBuilder {
488    /// Create a new builder with default settings.
489    pub fn new() -> Self {
490        Self::default()
491    }
492
493    /// Set the provider to use for summarization.
494    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
495        self.provider = Some(provider);
496        self
497    }
498
499    /// Set the model ID to use for inference.
500    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
501        self.model_id = model_id.into();
502        self
503    }
504
505    /// Set the maximum number of tokens for summary output.
506    pub fn max_summary_tokens(mut self, tokens: u32) -> Self {
507        self.max_summary_tokens = tokens;
508        self
509    }
510
511    /// Set the maximum number of facts to extract per summary.
512    pub fn max_facts(mut self, facts: usize) -> Self {
513        self.max_facts = facts;
514        self
515    }
516
517    /// Build the summarizer, returning `None` if no provider was set.
518    pub fn build(self) -> Option<LocalSummarizer> {
519        self.provider.map(|p| {
520            LocalSummarizer::new(p, self.model_id)
521                .with_max_summary_tokens(self.max_summary_tokens)
522                .with_max_facts(self.max_facts)
523        })
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the `*_direct` helpers below are simplified copies of the
    // private `LocalSummarizer` logic rather than calls into it — presumably
    // because building a `LocalSummarizer` requires a `Provider`. They do not
    // exercise the production code paths; a mock provider would let these
    // tests call the real methods.

    /// Both constructors set the flag, confidence and summary text.
    #[test]
    fn test_summarization_result() {
        let result = SummarizationResult::from_local("Test summary".to_string(), 0.9);
        assert!(result.used_local_llm);
        assert_eq!(result.confidence, 0.9);
        assert_eq!(result.summary, "Test summary");

        let fallback = SummarizationResult::from_fallback("Fallback".to_string());
        assert!(!fallback.used_local_llm);
        assert_eq!(fallback.confidence, 0.3);
        assert_eq!(fallback.summary, "Fallback");
    }

    /// Labels are matched case-insensitively, earlier keyword groups win, and
    /// unknown labels fall back to `Other`.
    #[test]
    fn test_fact_category_parsing() {
        assert_eq!(FactCategory::from_str("Decision"), FactCategory::Decision);
        assert_eq!(
            FactCategory::from_str("REQUIREMENT"),
            FactCategory::Requirement
        );
        assert_eq!(
            FactCategory::from_str("code change"),
            FactCategory::CodeChange
        );
        assert_eq!(
            FactCategory::from_str("Configuration"),
            FactCategory::Configuration
        );
        assert_eq!(FactCategory::from_str("Reference"), FactCategory::Reference);
        assert_eq!(
            FactCategory::from_str("code requirement"),
            FactCategory::Requirement
        );
        assert_eq!(FactCategory::from_str("random"), FactCategory::Other);
        assert_eq!(FactCategory::from_str(""), FactCategory::Other);
    }

    /// The helper finds the file path and the PascalCase type name, and drops
    /// the capitalized word on its stop list.
    #[test]
    fn test_entity_extraction() {
        let content = "Modified src/main.rs and added LocalSummarizer to handle_request function";
        let entities = extract_entities_direct(content);

        assert!(entities.iter().any(|e| e == "src/main.rs"));
        assert!(entities.iter().any(|e| e == "LocalSummarizer"));
        assert!(!entities.iter().any(|e| e == "Modified"));
    }

    fn extract_entities_direct(content: &str) -> Vec<String> {
        let mut entities = Vec::new();

        // Extract file paths
        if let Ok(re) = regex::Regex::new(r"([a-zA-Z0-9_\-/]+\.[a-z]{2,4})") {
            for cap in re.captures_iter(content) {
                if let Some(m) = cap.get(0) {
                    entities.push(m.as_str().to_string());
                }
            }
        }

        // Extract PascalCase names
        if let Ok(re) = regex::Regex::new(r"\b([A-Z][a-zA-Z0-9]+)\b") {
            for cap in re.captures_iter(content) {
                if let Some(m) = cap.get(1) {
                    let name = m.as_str().to_string();
                    if !["Modified", "This", "That"].contains(&name.as_str()) {
                        entities.push(name);
                    }
                }
            }
        }

        entities
    }

    /// Each sentence is classified by its indicator word, in input order.
    #[test]
    fn test_heuristic_fact_extraction() {
        let content =
            "We decided to use Rust. The config must be updated. The function was changed.";
        let facts = extract_facts_heuristic_direct(content);

        let kinds: Vec<FactCategory> = facts.iter().map(|f| f.fact_type).collect();
        assert_eq!(
            kinds,
            vec![
                FactCategory::Decision,
                FactCategory::Requirement,
                FactCategory::CodeChange,
            ]
        );
        assert_eq!(facts[0].fact, "We decided to use Rust");
    }

    fn extract_facts_heuristic_direct(content: &str) -> Vec<ExtractedFact> {
        let mut facts = Vec::new();

        for sentence in content.split('.') {
            let sentence = sentence.trim();
            if sentence.len() < 10 {
                continue;
            }

            let lower = sentence.to_lowercase();
            let fact_type = if lower.contains("decided") {
                FactCategory::Decision
            } else if lower.contains("must") {
                FactCategory::Requirement
            } else if lower.contains("changed") {
                FactCategory::CodeChange
            } else {
                continue;
            };

            facts.push(ExtractedFact {
                fact: sentence.to_string(),
                fact_type,
                confidence: 0.5,
            });
        }

        facts
    }

    /// Long input keeps exactly 100 words with the ellipsis glued to the last
    /// one; short input is returned byte-for-byte.
    #[test]
    fn test_truncate_summary() {
        let long_content = "word ".repeat(200);
        let truncated = truncate_summary_direct(&long_content);

        assert_eq!(truncated.split_whitespace().count(), 100);
        assert!(truncated.ends_with("word..."));

        let short_content = "  short  text ";
        assert_eq!(truncate_summary_direct(short_content), short_content);
    }

    fn truncate_summary_direct(content: &str) -> String {
        let words: Vec<&str> = content.split_whitespace().collect();
        if words.len() <= 100 {
            content.to_string()
        } else {
            format!("{}...", words[..100].join(" "))
        }
    }
}