//! praxis_context/default.rs
//!
//! Default [`ContextStrategy`] implementation: keeps the recent message
//! window under a token budget and summarizes older history via an LLM.

1use std::sync::Arc;
2use anyhow::Result;
3use async_trait::async_trait;
4use tiktoken_rs::cl100k_base;
5use chrono::Utc;
6
7use praxis_llm::{ChatClient, Message, Content};
8use praxis_persist::{PersistenceClient, DBMessage};
9use crate::strategy::{ContextStrategy, ContextWindow};
10use crate::templates::{DEFAULT_SYSTEM_PROMPT_TEMPLATE, DEFAULT_SUMMARIZATION_PROMPT};
11
/// Context strategy that serves the messages persisted since the last
/// summary update and, when that window grows past a token budget,
/// triggers a background LLM summarization of it.
pub struct DefaultContextStrategy {
    // Token budget for the current window; exceeding it spawns a
    // fire-and-forget summarization task (see `get_context_window`).
    max_tokens: usize,
    // Chat client used to generate summaries.
    llm_client: Arc<dyn ChatClient>,
    // System-prompt template; its `<summary>` placeholder is replaced with
    // the thread's summary text (or a fallback when none exists).
    system_prompt_template: String,
    // Summarization-prompt template; its `<previous_summary>` and
    // `<conversation>` placeholders are substituted before the LLM call.
    summarization_template: String,
}
18
19impl DefaultContextStrategy {
20    pub fn new(
21        max_tokens: usize,
22        llm_client: Arc<dyn ChatClient>,
23    ) -> Self {
24        Self {
25            max_tokens,
26            llm_client,
27            system_prompt_template: DEFAULT_SYSTEM_PROMPT_TEMPLATE.to_string(),
28            summarization_template: DEFAULT_SUMMARIZATION_PROMPT.to_string(),
29        }
30    }
31    
32    pub fn with_templates(
33        max_tokens: usize,
34        llm_client: Arc<dyn ChatClient>,
35        system_prompt_template: String,
36        summarization_template: String,
37    ) -> Self {
38        Self {
39            max_tokens,
40            llm_client,
41            system_prompt_template,
42            summarization_template,
43        }
44    }
45    
46    /// Count tokens in messages using tiktoken
47    fn count_tokens(&self, messages: &[DBMessage]) -> Result<usize> {
48        let bpe = cl100k_base().map_err(|e| anyhow::anyhow!("Tokenizer error: {}", e))?;
49        
50        let mut total_tokens = 0;
51        for msg in messages {
52            let tokens = bpe.encode_with_special_tokens(&msg.content);
53            total_tokens += tokens.len();
54        }
55        
56        Ok(total_tokens)
57    }
58    
59    /// Build conversation text from messages
60    fn build_conversation_text(messages: &[DBMessage]) -> String {
61        messages.iter()
62            .map(|m| {
63                let role = match m.role {
64                    praxis_persist::MessageRole::User => "User",
65                    praxis_persist::MessageRole::Assistant => "Assistant",
66                };
67                format!("{}: {}", role, m.content)
68            })
69            .collect::<Vec<_>>()
70            .join("\n")
71    }
72    
73    /// Generate summary of old messages
74    async fn generate_summary(&self, messages: &[DBMessage], previous_summary: Option<&str>) -> Result<String> {
75        let conversation = Self::build_conversation_text(messages);
76        
77        let previous_summary_text = previous_summary.unwrap_or("Não temos resumo ainda.");
78        
79        let summary_prompt = self.summarization_template
80            .replace("<previous_summary>", previous_summary_text)
81            .replace("<conversation>", &conversation);
82        
83        let request = praxis_llm::ChatRequest::new(
84            "gpt-4o-mini".to_string(),
85            vec![Message::Human {
86                content: Content::text(summary_prompt),
87                name: None,
88            }],
89        );
90        
91        let response = self.llm_client.chat(request).await?;
92        
93        let summary = response.content.unwrap_or_else(|| "Summary generation failed".to_string());
94        
95        Ok(summary)
96    }
97    
98    /// Build system prompt.
99    fn build_system_prompt(&self, summary: Option<&str>) -> String {
100        let summary_text = summary.unwrap_or("Não temos resumo ainda.");
101        self.system_prompt_template.replace("<summary>", summary_text)
102    }
103}
104
#[async_trait]
impl ContextStrategy for DefaultContextStrategy {
    /// Assemble the context window for `thread_id`.
    ///
    /// Returns the system prompt (rendered with the thread's existing
    /// summary, if any) plus every message persisted after
    /// `last_summary_update`. When that window exceeds `max_tokens`, a
    /// detached background task regenerates the summary; this call still
    /// returns the full window — the shorter window presumably takes effect
    /// on a later call once `save_thread_summary` has advanced
    /// `last_summary_update` (TODO confirm that it does).
    ///
    /// # Errors
    /// Fails if the thread does not exist, or if persistence or token
    /// counting fails.
    async fn get_context_window(
        &self,
        thread_id: &str,
        persist_client: Arc<dyn PersistenceClient>,
    ) -> Result<ContextWindow> {
        // 1. Get thread — it must already exist; this strategy never creates it.
        let thread = persist_client.get_thread(thread_id).await?
            .ok_or_else(|| anyhow::anyhow!("Thread {} not found - should be created before sending messages", thread_id))?;
        
        // 2. Fetch messages after last_summary_update: anything older is
        //    assumed to be covered by the stored summary.
        let messages_to_evaluate = persist_client
            .get_messages_after(thread_id, thread.last_summary_update)
            .await?;
        
        let existing_summary = thread.summary.as_ref().map(|s| s.text.as_str());
        // No new messages: the context is just the system prompt.
        if messages_to_evaluate.is_empty() {
            return Ok(ContextWindow {
                system_prompt: self.build_system_prompt(existing_summary),
                messages: vec![],
            });
        }
        
        // 3. Count tokens of CURRENT WINDOW
        let current_window_tokens = self.count_tokens(&messages_to_evaluate)?;
        
        // 4. If current window exceeds max_tokens, spawn async summary generation.
        //    Fire-and-forget: the response path below does not wait on it.
        if current_window_tokens > self.max_tokens {
            // Clone everything needed for fire-and-forget task
            // (the spawned future must be 'static, so it cannot borrow
            // `self`, `thread_id`, or the message slice).
            let messages_clone = messages_to_evaluate.clone();
            let previous_summary = existing_summary.map(|s| s.to_string());
            let persist_client_clone = Arc::clone(&persist_client);
            let thread_id_owned = thread_id.to_string();
            
            // Clone strategy fields to recreate context in async task
            let strategy = Self {
                max_tokens: self.max_tokens,
                llm_client: self.llm_client.clone(),
                system_prompt_template: self.system_prompt_template.clone(),
                summarization_template: self.summarization_template.clone(),
            };
            
            tokio::spawn(async move {
                // Best-effort: generation and save errors are deliberately
                // dropped. NOTE(review): consider logging failures so silent
                // summarization breakage is observable.
                if let Ok(summary_text) = strategy
                    .generate_summary(&messages_clone, previous_summary.as_deref())
                    .await {
                        let summary_time = Utc::now();
                        let _ = persist_client_clone.save_thread_summary(
                            &thread_id_owned,
                            summary_text,
                            summary_time
                        ).await;
                }
            });
        }
        
        // 6. Build system prompt with existing summary (if any) — the one
        //    possibly being regenerated above is intentionally not awaited.
        let system_prompt = self.build_system_prompt(existing_summary);
        
        // 7. Convert DBMessage → praxis_llm::Message, silently skipping any
        //    message whose TryInto conversion fails.
        let llm_messages = messages_to_evaluate
            .into_iter()
            .filter_map(|msg| msg.try_into().ok())
            .collect();
        
        Ok(ContextWindow {
            system_prompt,
            messages: llm_messages,
        })
    }
}
177