1use super::Store;
7use kernex_core::{
8 config::SYSTEM_FACT_KEYS,
9 context::{CompactionStrategy, Context, ContextEntry, ContextNeeds},
10 error::KernexError,
11 message::Request,
12 traits::Summarizer,
13};
14
15pub use super::context_helpers::detect_language;
17#[cfg(test)]
18pub(super) use super::context_helpers::onboarding_hint_text;
19pub(super) use super::context_helpers::{
20 build_system_prompt, compute_onboarding_stage, SystemPromptContext,
21};
22
/// Fact keys describing who the user is; listed first in the rendered profile.
const IDENTITY_KEYS: &[&str] = &["name", "preferred_name", "pronouns"];

/// Fact keys describing the user's situation; listed right after the identity keys.
const CONTEXT_KEYS: &[&str] = &["timezone", "location", "occupation"];
28
29impl Store {
30 pub async fn build_context(
40 &self,
41 channel: &str,
42 incoming: &Request,
43 base_system_prompt: &str,
44 needs: &ContextNeeds,
45 active_project: Option<&str>,
46 summarizer: Option<&dyn Summarizer>,
47 ) -> Result<Context, KernexError> {
48 let project_key = active_project.unwrap_or("");
49 let conv_id = self
50 .get_or_create_conversation(channel, &incoming.sender_id, project_key)
51 .await?;
52
53 let history_fut = async {
54 let rows: Vec<(String, String)> = sqlx::query_as(
55 "SELECT role, content FROM (\
56 SELECT role, content, timestamp FROM messages \
57 WHERE conversation_id = ? ORDER BY timestamp DESC LIMIT ?\
58 ) ORDER BY timestamp ASC",
59 )
60 .bind(&conv_id)
61 .bind(self.max_context_messages as i64)
62 .fetch_all(&self.pool)
63 .await
64 .map_err(|e| KernexError::Store(format!("query failed: {e}")))?;
65
66 Ok::<Vec<ContextEntry>, KernexError>(
67 rows.into_iter()
68 .map(|(role, content)| ContextEntry { role, content })
69 .collect(),
70 )
71 };
72
73 let facts_fut = async {
74 self.get_facts(&incoming.sender_id)
75 .await
76 .unwrap_or_default()
77 };
78
79 let summaries_fut = async {
80 if needs.summaries {
81 self.get_recent_summaries(channel, &incoming.sender_id, 3)
82 .await
83 .unwrap_or_default()
84 } else {
85 vec![]
86 }
87 };
88
89 let recall_fut = async {
90 if needs.recall {
91 self.search_messages(&incoming.text, &conv_id, &incoming.sender_id, 5)
92 .await
93 .unwrap_or_default()
94 } else {
95 vec![]
96 }
97 };
98
99 let tasks_fut = async {
100 if needs.pending_tasks {
101 self.get_tasks_for_sender(&incoming.sender_id)
102 .await
103 .unwrap_or_default()
104 } else {
105 vec![]
106 }
107 };
108
109 let outcomes_fut = async {
110 if needs.outcomes {
111 self.get_recent_outcomes(&incoming.sender_id, 15, active_project)
112 .await
113 .unwrap_or_default()
114 } else {
115 vec![]
116 }
117 };
118
119 let lessons_fut = async {
120 self.get_lessons(&incoming.sender_id, active_project)
121 .await
122 .unwrap_or_default()
123 };
124
125 let (history_res, facts, summaries, recall, pending_tasks, outcomes, lessons) = tokio::join!(
126 history_fut,
127 facts_fut,
128 summaries_fut,
129 recall_fut,
130 tasks_fut,
131 outcomes_fut,
132 lessons_fut,
133 );
134
135 let history = history_res?;
136
137 let compact_summary =
139 if let (CompactionStrategy::Summarize, Some(s)) = (&needs.compact, summarizer) {
140 let total: (i64,) =
141 sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
142 .bind(&conv_id)
143 .fetch_one(&self.pool)
144 .await
145 .map_err(|e| KernexError::Store(format!("count failed: {e}")))?;
146
147 let overflow_count = (total.0 as usize).saturating_sub(self.max_context_messages);
148
149 if overflow_count > 0 {
150 let overflow_rows: Vec<(String, String)> = sqlx::query_as(
151 "SELECT role, content FROM messages \
152 WHERE conversation_id = ? ORDER BY timestamp ASC LIMIT ?",
153 )
154 .bind(&conv_id)
155 .bind(overflow_count as i64)
156 .fetch_all(&self.pool)
157 .await
158 .map_err(|e| KernexError::Store(format!("query failed: {e}")))?;
159
160 if !overflow_rows.is_empty() {
161 let text = overflow_rows
162 .iter()
163 .map(|(role, content)| format!("{role}: {content}"))
164 .collect::<Vec<_>>()
165 .join("\n");
166
167 match s.summarize(&text).await {
168 Ok(summary) if !summary.is_empty() => Some(summary),
169 _ => None,
170 }
171 } else {
172 None
173 }
174 } else {
175 None
176 }
177 } else {
178 None
179 };
180
181 let language =
183 if let Some((_, lang)) = facts.iter().find(|(k, _)| k == "preferred_language") {
184 lang.clone()
185 } else {
186 let detected = detect_language(&incoming.text).to_string();
187 let _ = self
188 .store_fact(&incoming.sender_id, "preferred_language", &detected)
189 .await;
190 detected
191 };
192
193 let real_fact_count = facts
195 .iter()
196 .filter(|(k, _)| !SYSTEM_FACT_KEYS.contains(&k.as_str()))
197 .count();
198 let has_tasks = !pending_tasks.is_empty();
199
200 let current_stage: u8 = facts
201 .iter()
202 .find(|(k, _)| k == "onboarding_stage")
203 .and_then(|(_, v)| v.parse().ok())
204 .unwrap_or(0);
205
206 let new_stage = compute_onboarding_stage(current_stage, real_fact_count, has_tasks);
207
208 let onboarding_hint = if new_stage != current_stage {
209 let _ = self
210 .store_fact(
211 &incoming.sender_id,
212 "onboarding_stage",
213 &new_stage.to_string(),
214 )
215 .await;
216 Some(new_stage)
217 } else if current_stage == 0 && real_fact_count == 0 {
218 Some(0u8)
219 } else {
220 if facts.iter().all(|(k, _)| k != "onboarding_stage") && current_stage == 0 {
221 let bootstrapped = compute_onboarding_stage(0, real_fact_count, has_tasks);
222 let final_stage = (0..=4).fold(0u8, |s, _| {
223 compute_onboarding_stage(s, real_fact_count, has_tasks)
224 });
225 if final_stage > 0 {
226 let _ = self
227 .store_fact(
228 &incoming.sender_id,
229 "onboarding_stage",
230 &final_stage.to_string(),
231 )
232 .await;
233 }
234 let _ = bootstrapped;
235 None
236 } else {
237 None
238 }
239 };
240
241 let facts_for_prompt: &[(String, String)] = if needs.profile { &facts } else { &[] };
242 let built_prompt = build_system_prompt(&SystemPromptContext {
243 base_rules: base_system_prompt,
244 facts: facts_for_prompt,
245 summaries: &summaries,
246 recall: &recall,
247 pending_tasks: &pending_tasks,
248 outcomes: &outcomes,
249 lessons: &lessons,
250 language: &language,
251 onboarding_hint,
252 });
253
254 let system_prompt = if let Some(summary) = compact_summary {
255 format!("[Earlier conversation summary]\n{summary}\n\n{built_prompt}")
256 } else {
257 built_prompt
258 };
259
260 Ok(Context {
261 system_prompt,
262 history,
263 current_message: incoming.text.clone(),
264 mcp_servers: Vec::new(),
265 toolboxes: Vec::new(),
266 max_turns: None,
267 allowed_tools: None,
268 model: None,
269 session_id: None,
270 agent_name: None,
271 hook_runner: None,
272 permission_rules: None,
273 extended_thinking: false,
274 })
275 }
276}
277
278pub fn format_user_profile(facts: &[(String, String)]) -> String {
281 let user_facts: Vec<&(String, String)> = facts
282 .iter()
283 .filter(|(k, _)| !SYSTEM_FACT_KEYS.contains(&k.as_str()))
284 .collect();
285
286 if user_facts.is_empty() {
287 return String::new();
288 }
289
290 let mut lines = vec!["User profile:".to_string()];
291
292 for key in IDENTITY_KEYS {
293 if let Some((_, v)) = user_facts.iter().find(|(k, _)| k == key) {
294 lines.push(format!("- {key}: {v}"));
295 }
296 }
297
298 for key in CONTEXT_KEYS {
299 if let Some((_, v)) = user_facts.iter().find(|(k, _)| k == key) {
300 lines.push(format!("- {key}: {v}"));
301 }
302 }
303
304 let known_keys: Vec<&str> = IDENTITY_KEYS
305 .iter()
306 .chain(CONTEXT_KEYS.iter())
307 .copied()
308 .collect();
309 for (k, v) in &user_facts {
310 if !known_keys.contains(&k.as_str()) {
311 lines.push(format!("- {k}: {v}"));
312 }
313 }
314
315 lines.join("\n")
316}