kernex-memory 0.8.2

Pluggable storage for Kernex: conversations, learning, and scheduled tasks.
Documentation
//! Context building and user profile formatting.
//!
//! Helper functions for onboarding, system prompt composition, language
//! detection, and relative time formatting live in `context_helpers`.

use super::Store;
use crate::error::MemoryError;
use kernex_core::{
    config::SYSTEM_FACT_KEYS,
    context::{CompactionStrategy, Context, ContextEntry, ContextNeeds},
    message::Request,
    traits::Summarizer,
};

// Re-export helpers so existing `super::context::*` paths in tests keep working.
pub use super::context_helpers::detect_language;
#[cfg(test)]
pub(super) use super::context_helpers::onboarding_hint_text;
pub(super) use super::context_helpers::{
    build_system_prompt, compute_onboarding_stage, SystemPromptContext,
};

/// Identity fact keys — shown first in the user profile, in this order.
///
/// `format_user_profile` prints at most one line per key here (the first
/// matching fact wins), before any other fact.
const IDENTITY_KEYS: &[&str] = &["name", "preferred_name", "pronouns"];

/// Context fact keys — shown second in the user profile, in this order.
///
/// Emitted right after the identity keys; facts whose key appears in
/// neither list are appended afterwards in their original order.
const CONTEXT_KEYS: &[&str] = &["timezone", "location", "occupation"];

impl Store {
    /// Build a conversation context from memory for the provider.
    ///
    /// The `channel` parameter identifies the communication channel since
    /// `Request` is channel-agnostic.
    ///
    /// History, facts, summaries, recall hits, pending tasks, outcomes and
    /// lessons are loaded concurrently. Everything except the history query
    /// is best-effort: a failure is logged at `warn` and treated as "no data"
    /// so a degraded store still yields a usable context.
    ///
    /// When `needs.compact` is [`CompactionStrategy::Summarize`] and a
    /// `summarizer` is provided, overflow messages (those beyond
    /// `max_context_messages`) are summarized and prepended to the system
    /// prompt instead of being silently dropped.
    ///
    /// Side effects: may persist the sender's `preferred_language` fact (the
    /// first time a language is detected) and `onboarding_stage` fact (when the
    /// onboarding stage advances). Failures of these writes are logged, not
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError`] if the conversation cannot be resolved or if
    /// the history, overflow-count, or overflow-row queries fail.
    pub async fn build_context(
        &self,
        channel: &str,
        incoming: &Request,
        base_system_prompt: &str,
        needs: &ContextNeeds,
        active_project: Option<&str>,
        summarizer: Option<&dyn Summarizer>,
    ) -> Result<Context, MemoryError> {
        let project_key = active_project.unwrap_or("");
        let conv_id = self
            .get_or_create_conversation(channel, &incoming.sender_id, project_key)
            .await?;

        // The newest `max_context_messages` messages, returned oldest-first.
        let history_fut = async {
            let rows: Vec<(String, String)> = sqlx::query_as(
                "SELECT role, content FROM (\
                     SELECT role, content, timestamp FROM messages \
                     WHERE conversation_id = ? ORDER BY timestamp DESC LIMIT ?\
                 ) ORDER BY timestamp ASC",
            )
            .bind(&conv_id)
            .bind(self.max_context_messages as i64)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| MemoryError::sqlite("query failed", e))?;

            Ok::<Vec<ContextEntry>, MemoryError>(
                rows.into_iter()
                    .map(|(role, content)| ContextEntry { role, content })
                    .collect(),
            )
        };

        let facts_fut = async {
            ok_or_warn(
                self.get_facts(&incoming.sender_id).await,
                "load user facts",
            )
            .unwrap_or_default()
        };

        let summaries_fut = async {
            if needs.summaries {
                ok_or_warn(
                    self.get_recent_summaries(channel, &incoming.sender_id, 3)
                        .await,
                    "load recent summaries",
                )
                .unwrap_or_default()
            } else {
                vec![]
            }
        };

        let recall_fut = async {
            if needs.recall {
                ok_or_warn(
                    self.search_messages(&incoming.text, &conv_id, &incoming.sender_id, 5, None)
                        .await,
                    "search messages for recall",
                )
                .unwrap_or_default()
            } else {
                vec![]
            }
        };

        let tasks_fut = async {
            if needs.pending_tasks {
                ok_or_warn(
                    self.get_tasks_for_sender(&incoming.sender_id).await,
                    "load pending tasks",
                )
                .unwrap_or_default()
            } else {
                vec![]
            }
        };

        let outcomes_fut = async {
            if needs.outcomes {
                ok_or_warn(
                    self.get_recent_outcomes(&incoming.sender_id, 15, active_project)
                        .await,
                    "load recent outcomes",
                )
                .unwrap_or_default()
            } else {
                vec![]
            }
        };

        let lessons_fut = async {
            ok_or_warn(
                self.get_lessons(&incoming.sender_id, active_project).await,
                "load lessons",
            )
            .unwrap_or_default()
        };

        let (history_res, facts, summaries, recall, pending_tasks, outcomes, lessons) = tokio::join!(
            history_fut,
            facts_fut,
            summaries_fut,
            recall_fut,
            tasks_fut,
            outcomes_fut,
            lessons_fut,
        );

        let history = history_res?;

        // Detect history overflow once. We only run the COUNT when the
        // history loader hit its LIMIT, so short conversations pay nothing.
        // Whichever branch handles overflow (summarize or warn) reuses the
        // same count.
        let overflow_count = if history.len() >= self.max_context_messages {
            let total: (i64,) =
                sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
                    .bind(&conv_id)
                    .fetch_one(&self.pool)
                    .await
                    .map_err(|e| MemoryError::sqlite("count failed", e))?;
            (total.0 as usize).saturating_sub(self.max_context_messages)
        } else {
            0
        };

        // Auto-compact: summarize overflow messages instead of silently dropping.
        // NOTE(review): the full overflow is re-read and re-summarized on every
        // call once a conversation exceeds the window; cost grows with history.
        let compact_summary = if overflow_count > 0 {
            if let (CompactionStrategy::Summarize, Some(s)) = (&needs.compact, summarizer) {
                let overflow_rows: Vec<(String, String)> = sqlx::query_as(
                    "SELECT role, content FROM messages \
                     WHERE conversation_id = ? ORDER BY timestamp ASC LIMIT ?",
                )
                .bind(&conv_id)
                .bind(overflow_count as i64)
                .fetch_all(&self.pool)
                .await
                .map_err(|e| MemoryError::sqlite("query failed", e))?;

                if overflow_rows.is_empty() {
                    None
                } else {
                    let text = overflow_rows
                        .iter()
                        .map(|(role, content)| format!("{role}: {content}"))
                        .collect::<Vec<_>>()
                        .join("\n");

                    match s.summarize(&text).await {
                        Ok(summary) if !summary.is_empty() => Some(summary),
                        Ok(_) => {
                            tracing::warn!(
                                conversation_id = %conv_id,
                                overflow = overflow_count,
                                "summarizer returned empty output; dropping {overflow_count} oldest messages",
                            );
                            None
                        }
                        Err(e) => {
                            tracing::warn!(
                                conversation_id = %conv_id,
                                overflow = overflow_count,
                                error = %e,
                                "summarizer failed; falling back to silent drop of {overflow_count} oldest messages",
                            );
                            None
                        }
                    }
                }
            } else {
                // Drop strategy or no summarizer wired in. Surface this as a
                // warn so operators running with default tracing see that
                // history is being lost and have a path to the fix
                // (RuntimeBuilder::auto_compact). One log per overflow event.
                tracing::warn!(
                    conversation_id = %conv_id,
                    overflow = overflow_count,
                    max = self.max_context_messages,
                    "history overflow: silently dropping {overflow_count} oldest messages. \
                     Enable RuntimeBuilder::auto_compact for summarization.",
                );
                None
            }
        } else {
            None
        };

        // Resolve language: stored preference > auto-detect > English.
        // A freshly detected language is persisted so later turns are stable.
        let language = match facts.iter().find(|(k, _)| k == "preferred_language") {
            Some((_, lang)) => lang.clone(),
            None => {
                let detected = detect_language(&incoming.text).to_string();
                ok_or_warn(
                    self.store_fact(&incoming.sender_id, "preferred_language", &detected)
                        .await,
                    "store preferred_language",
                );
                detected
            }
        };

        // Progressive onboarding: compute stage and inject hint on transitions.
        let real_fact_count = facts
            .iter()
            .filter(|(k, _)| !SYSTEM_FACT_KEYS.contains(&k.as_str()))
            .count();
        let has_tasks = !pending_tasks.is_empty();

        // A missing or unparsable stored stage counts as stage 0.
        let current_stage: u8 = facts
            .iter()
            .find(|(k, _)| k == "onboarding_stage")
            .and_then(|(_, v)| v.parse().ok())
            .unwrap_or(0);

        let new_stage = compute_onboarding_stage(current_stage, real_fact_count, has_tasks);

        // No separate bootstrap path is needed for senders without a stored
        // stage: they start at 0, and any advance from 0 is caught by the
        // first branch below (which persists it).
        let onboarding_hint = if new_stage != current_stage {
            ok_or_warn(
                self.store_fact(
                    &incoming.sender_id,
                    "onboarding_stage",
                    &new_stage.to_string(),
                )
                .await,
                "store onboarding_stage",
            );
            Some(new_stage)
        } else if current_stage == 0 && real_fact_count == 0 {
            // Brand-new sender with nothing known yet: keep the stage-0 hint.
            Some(0u8)
        } else {
            None
        };

        let facts_for_prompt: &[(String, String)] = if needs.profile { &facts } else { &[] };
        let built_prompt = build_system_prompt(&SystemPromptContext {
            base_rules: base_system_prompt,
            facts: facts_for_prompt,
            summaries: &summaries,
            recall: &recall,
            pending_tasks: &pending_tasks,
            outcomes: &outcomes,
            lessons: &lessons,
            language: &language,
            onboarding_hint,
        });

        let system_prompt = if let Some(summary) = compact_summary {
            format!("[Earlier conversation summary]\n{summary}\n\n{built_prompt}")
        } else {
            built_prompt
        };

        Ok(Context {
            system_prompt,
            history,
            current_message: incoming.text.clone(),
            mcp_servers: Vec::new(),
            toolboxes: Vec::new(),
            max_turns: None,
            allowed_tools: None,
            model: None,
            session_id: None,
            agent_name: None,
            hook_runner: None,
            permission_rules: None,
            extended_thinking: false,
        })
    }
}

/// Convert a best-effort memory operation's result into an `Option`, logging
/// any failure at `warn` instead of discarding it silently.
///
/// `operation` is a short human-readable label recorded as a log field.
/// Callers decide the fallback (e.g. `.unwrap_or_default()` for reads, or
/// simply ignoring the result for fire-and-forget writes).
fn ok_or_warn<T, E: std::fmt::Display>(result: Result<T, E>, operation: &str) -> Option<T> {
    match result {
        Ok(value) => Some(value),
        Err(e) => {
            tracing::warn!(
                operation,
                error = %e,
                "best-effort memory operation failed; continuing without it",
            );
            None
        }
    }
}

/// Format user facts into a structured profile, filtering system keys
/// and grouping identity facts first, then context, then the rest.
pub fn format_user_profile(facts: &[(String, String)]) -> String {
    // Only user-visible facts take part; internal bookkeeping keys are hidden.
    let visible: Vec<&(String, String)> = facts
        .iter()
        .filter(|(key, _)| !SYSTEM_FACT_KEYS.contains(&key.as_str()))
        .collect();

    if visible.is_empty() {
        return String::new();
    }

    // True for keys that have a fixed slot (identity, then context).
    let has_fixed_slot =
        |key: &str| IDENTITY_KEYS.contains(&key) || CONTEXT_KEYS.contains(&key);

    let mut out = vec!["User profile:".to_string()];

    // Fixed-order section: identity keys, then context keys; first match per key.
    out.extend(IDENTITY_KEYS.iter().chain(CONTEXT_KEYS.iter()).filter_map(|wanted| {
        visible
            .iter()
            .find(|(key, _)| key == wanted)
            .map(|(_, value)| format!("- {wanted}: {value}"))
    }));

    // Everything else, in the order the facts were supplied.
    out.extend(
        visible
            .iter()
            .filter(|(key, _)| !has_fixed_slot(key.as_str()))
            .map(|(key, value)| format!("- {key}: {value}")),
    );

    out.join("\n")
}