Skip to main content

kernex_memory/store/
context.rs

1//! Context building and user profile formatting.
2//!
3//! Helper functions for onboarding, system prompt composition, language
4//! detection, and relative time formatting live in `context_helpers`.
5
6use super::Store;
7use kernex_core::{
8    config::SYSTEM_FACT_KEYS,
9    context::{CompactionStrategy, Context, ContextEntry, ContextNeeds},
10    error::KernexError,
11    message::Request,
12    traits::Summarizer,
13};
14
15// Re-export helpers so existing `super::context::*` paths in tests keep working.
16pub use super::context_helpers::detect_language;
17#[cfg(test)]
18pub(super) use super::context_helpers::onboarding_hint_text;
19pub(super) use super::context_helpers::{
20    build_system_prompt, compute_onboarding_stage, SystemPromptContext,
21};
22
/// Identity fact keys — shown first in the user profile rendered by
/// [`format_user_profile`].
const IDENTITY_KEYS: &[&str] = &["name", "preferred_name", "pronouns"];

/// Context fact keys — shown second in the user profile, after identity
/// keys and before any remaining user facts.
const CONTEXT_KEYS: &[&str] = &["timezone", "location", "occupation"];
28
29impl Store {
30    /// Build a conversation context from memory for the provider.
31    ///
32    /// The `channel` parameter identifies the communication channel since
33    /// `Request` is channel-agnostic.
34    ///
35    /// When `needs.compact` is [`CompactionStrategy::Summarize`] and a
36    /// `summarizer` is provided, overflow messages (those beyond
37    /// `max_context_messages`) are summarized and prepended to the system
38    /// prompt instead of being silently dropped.
39    pub async fn build_context(
40        &self,
41        channel: &str,
42        incoming: &Request,
43        base_system_prompt: &str,
44        needs: &ContextNeeds,
45        active_project: Option<&str>,
46        summarizer: Option<&dyn Summarizer>,
47    ) -> Result<Context, KernexError> {
48        let project_key = active_project.unwrap_or("");
49        let conv_id = self
50            .get_or_create_conversation(channel, &incoming.sender_id, project_key)
51            .await?;
52
53        let history_fut = async {
54            let rows: Vec<(String, String)> = sqlx::query_as(
55                "SELECT role, content FROM (\
56                     SELECT role, content, timestamp FROM messages \
57                     WHERE conversation_id = ? ORDER BY timestamp DESC LIMIT ?\
58                 ) ORDER BY timestamp ASC",
59            )
60            .bind(&conv_id)
61            .bind(self.max_context_messages as i64)
62            .fetch_all(&self.pool)
63            .await
64            .map_err(|e| KernexError::Store(format!("query failed: {e}")))?;
65
66            Ok::<Vec<ContextEntry>, KernexError>(
67                rows.into_iter()
68                    .map(|(role, content)| ContextEntry { role, content })
69                    .collect(),
70            )
71        };
72
73        let facts_fut = async {
74            self.get_facts(&incoming.sender_id)
75                .await
76                .unwrap_or_default()
77        };
78
79        let summaries_fut = async {
80            if needs.summaries {
81                self.get_recent_summaries(channel, &incoming.sender_id, 3)
82                    .await
83                    .unwrap_or_default()
84            } else {
85                vec![]
86            }
87        };
88
89        let recall_fut = async {
90            if needs.recall {
91                self.search_messages(&incoming.text, &conv_id, &incoming.sender_id, 5)
92                    .await
93                    .unwrap_or_default()
94            } else {
95                vec![]
96            }
97        };
98
99        let tasks_fut = async {
100            if needs.pending_tasks {
101                self.get_tasks_for_sender(&incoming.sender_id)
102                    .await
103                    .unwrap_or_default()
104            } else {
105                vec![]
106            }
107        };
108
109        let outcomes_fut = async {
110            if needs.outcomes {
111                self.get_recent_outcomes(&incoming.sender_id, 15, active_project)
112                    .await
113                    .unwrap_or_default()
114            } else {
115                vec![]
116            }
117        };
118
119        let lessons_fut = async {
120            self.get_lessons(&incoming.sender_id, active_project)
121                .await
122                .unwrap_or_default()
123        };
124
125        let (history_res, facts, summaries, recall, pending_tasks, outcomes, lessons) = tokio::join!(
126            history_fut,
127            facts_fut,
128            summaries_fut,
129            recall_fut,
130            tasks_fut,
131            outcomes_fut,
132            lessons_fut,
133        );
134
135        let history = history_res?;
136
137        // Auto-compact: summarize overflow messages instead of silently dropping.
138        let compact_summary =
139            if let (CompactionStrategy::Summarize, Some(s)) = (&needs.compact, summarizer) {
140                let total: (i64,) =
141                    sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
142                        .bind(&conv_id)
143                        .fetch_one(&self.pool)
144                        .await
145                        .map_err(|e| KernexError::Store(format!("count failed: {e}")))?;
146
147                let overflow_count = (total.0 as usize).saturating_sub(self.max_context_messages);
148
149                if overflow_count > 0 {
150                    let overflow_rows: Vec<(String, String)> = sqlx::query_as(
151                        "SELECT role, content FROM messages \
152                     WHERE conversation_id = ? ORDER BY timestamp ASC LIMIT ?",
153                    )
154                    .bind(&conv_id)
155                    .bind(overflow_count as i64)
156                    .fetch_all(&self.pool)
157                    .await
158                    .map_err(|e| KernexError::Store(format!("query failed: {e}")))?;
159
160                    if !overflow_rows.is_empty() {
161                        let text = overflow_rows
162                            .iter()
163                            .map(|(role, content)| format!("{role}: {content}"))
164                            .collect::<Vec<_>>()
165                            .join("\n");
166
167                        match s.summarize(&text).await {
168                            Ok(summary) if !summary.is_empty() => Some(summary),
169                            _ => None,
170                        }
171                    } else {
172                        None
173                    }
174                } else {
175                    None
176                }
177            } else {
178                None
179            };
180
181        // Resolve language: stored preference > auto-detect > English.
182        let language =
183            if let Some((_, lang)) = facts.iter().find(|(k, _)| k == "preferred_language") {
184                lang.clone()
185            } else {
186                let detected = detect_language(&incoming.text).to_string();
187                let _ = self
188                    .store_fact(&incoming.sender_id, "preferred_language", &detected)
189                    .await;
190                detected
191            };
192
193        // Progressive onboarding: compute stage and inject hint on transitions.
194        let real_fact_count = facts
195            .iter()
196            .filter(|(k, _)| !SYSTEM_FACT_KEYS.contains(&k.as_str()))
197            .count();
198        let has_tasks = !pending_tasks.is_empty();
199
200        let current_stage: u8 = facts
201            .iter()
202            .find(|(k, _)| k == "onboarding_stage")
203            .and_then(|(_, v)| v.parse().ok())
204            .unwrap_or(0);
205
206        let new_stage = compute_onboarding_stage(current_stage, real_fact_count, has_tasks);
207
208        let onboarding_hint = if new_stage != current_stage {
209            let _ = self
210                .store_fact(
211                    &incoming.sender_id,
212                    "onboarding_stage",
213                    &new_stage.to_string(),
214                )
215                .await;
216            Some(new_stage)
217        } else if current_stage == 0 && real_fact_count == 0 {
218            Some(0u8)
219        } else {
220            if facts.iter().all(|(k, _)| k != "onboarding_stage") && current_stage == 0 {
221                let bootstrapped = compute_onboarding_stage(0, real_fact_count, has_tasks);
222                let final_stage = (0..=4).fold(0u8, |s, _| {
223                    compute_onboarding_stage(s, real_fact_count, has_tasks)
224                });
225                if final_stage > 0 {
226                    let _ = self
227                        .store_fact(
228                            &incoming.sender_id,
229                            "onboarding_stage",
230                            &final_stage.to_string(),
231                        )
232                        .await;
233                }
234                let _ = bootstrapped;
235                None
236            } else {
237                None
238            }
239        };
240
241        let facts_for_prompt: &[(String, String)] = if needs.profile { &facts } else { &[] };
242        let built_prompt = build_system_prompt(&SystemPromptContext {
243            base_rules: base_system_prompt,
244            facts: facts_for_prompt,
245            summaries: &summaries,
246            recall: &recall,
247            pending_tasks: &pending_tasks,
248            outcomes: &outcomes,
249            lessons: &lessons,
250            language: &language,
251            onboarding_hint,
252        });
253
254        let system_prompt = if let Some(summary) = compact_summary {
255            format!("[Earlier conversation summary]\n{summary}\n\n{built_prompt}")
256        } else {
257            built_prompt
258        };
259
260        Ok(Context {
261            system_prompt,
262            history,
263            current_message: incoming.text.clone(),
264            mcp_servers: Vec::new(),
265            toolboxes: Vec::new(),
266            max_turns: None,
267            allowed_tools: None,
268            model: None,
269            session_id: None,
270            agent_name: None,
271            hook_runner: None,
272            permission_rules: None,
273            extended_thinking: false,
274        })
275    }
276}
277
278/// Format user facts into a structured profile, filtering system keys
279/// and grouping identity facts first, then context, then the rest.
280pub fn format_user_profile(facts: &[(String, String)]) -> String {
281    let user_facts: Vec<&(String, String)> = facts
282        .iter()
283        .filter(|(k, _)| !SYSTEM_FACT_KEYS.contains(&k.as_str()))
284        .collect();
285
286    if user_facts.is_empty() {
287        return String::new();
288    }
289
290    let mut lines = vec!["User profile:".to_string()];
291
292    for key in IDENTITY_KEYS {
293        if let Some((_, v)) = user_facts.iter().find(|(k, _)| k == key) {
294            lines.push(format!("- {key}: {v}"));
295        }
296    }
297
298    for key in CONTEXT_KEYS {
299        if let Some((_, v)) = user_facts.iter().find(|(k, _)| k == key) {
300            lines.push(format!("- {key}: {v}"));
301        }
302    }
303
304    let known_keys: Vec<&str> = IDENTITY_KEYS
305        .iter()
306        .chain(CONTEXT_KEYS.iter())
307        .copied()
308        .collect();
309    for (k, v) in &user_facts {
310        if !known_keys.contains(&k.as_str()) {
311            lines.push(format!("- {k}: {v}"));
312        }
313    }
314
315    lines.join("\n")
316}