// cognee_cognify/memify/persist_sessions.rs

//! Stage 2 of `improve()` — persist session Q&A text into the permanent
//! knowledge graph.
//!
//! Ported from:
//! - `/tmp/cognee-python/cognee/tasks/memify/extract_user_sessions.py`
//! - `/tmp/cognee-python/cognee/tasks/memify/cognify_session.py`
//!
//! For each session ID:
//! 1. Load all Q&A entries.
//! 2. Concatenate into a single string matching the Python format
//!    (`"Session ID: <sid>\n\nQuestion: <q>\n\nAnswer: <a>\n\n"`, with one
//!    Question/Answer pair per entry).
//! 3. Run `AddPipeline::add_with_params(...)` with
//!    `node_set = ["user_sessions_from_cache"]`.
//! 4. Run `cognify(...)` on the resulting data rows to extract entities and
//!    relationships.
//!
//! Empty sessions are skipped (logged at info level and counted). A failure
//! in the add, dataset-lookup, or cognify step of an individual session is
//! logged and counted but does not abort the loop (matches Python's
//! try/except per-session). The one exception is a session-store error while
//! loading a session's entries: it is propagated to the caller and stops the
//! run.

use std::fmt::Write as _;
use std::sync::Arc;

use cognee_core::CpuPool;
use cognee_database::{DatabaseConnection, PipelineRunRepository};
use cognee_embedding::EmbeddingEngine;
use cognee_graph::GraphDBTrait;
use cognee_ingestion::{AddParams, AddPipeline};
use cognee_llm::Llm;
use cognee_models::DataInput;
use cognee_ontology::OntologyResolver;
use cognee_session::SessionStore;
use cognee_storage::StorageTrait;
use cognee_vector::VectorDB;
use thiserror::Error;
use tracing::{info, warn};
use uuid::Uuid;

use crate::config::CognifyConfig;
use crate::error::CognifyError;
use crate::tasks::cognify;

/// Node-set tag attached to session-derived data; matches Python
/// `cognify_session.py:32`.
///
/// `persist_sessions_in_knowledge_graph` passes this as the only `node_set`
/// entry when it ingests a session's concatenated text.
pub const USER_SESSIONS_NODE_SET: &str = "user_sessions_from_cache";

45/// Error type for Stage 2 (`persist_sessions_in_knowledge_graph`).
46#[derive(Debug, Error)]
47pub enum PersistSessionsError {
48    #[error("Session error: {0}")]
49    Session(#[from] cognee_session::SessionError),
50
51    #[error("Ingestion error: {0}")]
52    Ingestion(String),
53
54    #[error("Cognify error: {0}")]
55    Cognify(#[from] CognifyError),
56}
57
/// Summary of a Stage 2 run.
///
/// Every processed session increments exactly one of the three counters.
/// `Default` yields all-zero counters; `PartialEq`/`Eq` let callers and tests
/// compare whole summaries directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PersistSessionsResult {
    /// Number of sessions whose text was successfully persisted to the graph
    /// (the cognify step returned `Ok`).
    pub sessions_persisted: usize,
    /// Number of sessions that were skipped (empty).
    pub sessions_skipped: usize,
    /// Number of sessions that failed to persist (non-fatal; logged). Covers
    /// a failed or empty add, a failed or empty dataset lookup, and a failed
    /// cognify step.
    pub sessions_failed: usize,
}

69/// Concatenate all Q&A entries of a session into a single string.
70///
71/// Matches Python `extract_user_sessions.py:62-67`:
72/// ```text
73/// Session ID: {sid}
74///
75/// Question: {q}
76///
77/// Answer: {a}
78///
79/// Question: {q2}
80///
81/// Answer: {a2}
82///
83/// ```
84fn concat_session_entries(session_id: &str, entries: &[cognee_session::SessionQAEntry]) -> String {
85    let mut buf = format!("Session ID: {session_id}\n\n");
86    for e in entries {
87        buf.push_str(&format!(
88            "Question: {}\n\nAnswer: {}\n\n",
89            e.question, e.answer
90        ));
91    }
92    buf
93}
94
95/// Persist session Q&A text into the permanent graph.
96#[allow(clippy::too_many_arguments)]
97pub async fn persist_sessions_in_knowledge_graph(
98    session_ids: &[String],
99    dataset_name: &str,
100    owner_id: Uuid,
101    tenant_id: Option<Uuid>,
102    session_store: Arc<dyn SessionStore>,
103    add_pipeline: &AddPipeline,
104    llm: Arc<dyn Llm>,
105    storage: Arc<dyn StorageTrait>,
106    graph_db: Arc<dyn GraphDBTrait>,
107    vector_db: Arc<dyn VectorDB>,
108    embedding_engine: Arc<dyn EmbeddingEngine>,
109    database: Arc<DatabaseConnection>,
110    pipeline_run_repo: Arc<dyn PipelineRunRepository>,
111    thread_pool: Arc<dyn CpuPool>,
112    ontology_resolver: Arc<dyn OntologyResolver>,
113    cognify_config: &CognifyConfig,
114) -> Result<PersistSessionsResult, PersistSessionsError> {
115    let user_id_str = owner_id.to_string();
116    let mut result = PersistSessionsResult::default();
117
118    for sid in session_ids {
119        let entries = session_store
120            .get_all_qa_entries(sid, Some(&user_id_str))
121            .await?;
122        if entries.is_empty() {
123            info!(
124                session_id = sid,
125                "persist_sessions: empty session, skipping"
126            );
127            result.sessions_skipped += 1;
128            continue;
129        }
130
131        let buf = concat_session_entries(sid, &entries);
132        if buf.trim().is_empty() {
133            result.sessions_skipped += 1;
134            continue;
135        }
136
137        let params = AddParams {
138            node_set: Some(vec![USER_SESSIONS_NODE_SET.to_string()]),
139            ..Default::default()
140        };
141
142        let add_result = match add_pipeline
143            .add_with_params(
144                vec![DataInput::Text(buf)],
145                dataset_name,
146                owner_id,
147                tenant_id,
148                &params,
149            )
150            .await
151        {
152            Ok(v) => v,
153            Err(e) => {
154                warn!(session_id = sid, "persist_sessions: add failed: {e}");
155                result.sessions_failed += 1;
156                continue;
157            }
158        };
159
160        if add_result.is_empty() {
161            warn!(session_id = sid, "persist_sessions: add returned no rows");
162            result.sessions_failed += 1;
163            continue;
164        }
165
166        // Each Data row in add_result belongs to exactly one dataset; use the
167        // first row's dataset lookup by querying via name. AddPipeline does
168        // not currently return the dataset_id directly, so we derive it from
169        // the node_set-tagged Data entries through the same helper that
170        // cognify_dataset_refs uses: look up by name/owner.
171        let dataset_id = match cognee_database::ops::datasets::get_dataset_by_name(
172            database.as_ref(),
173            dataset_name,
174            owner_id,
175            tenant_id,
176        )
177        .await
178        {
179            Ok(Some(ds)) => ds.id,
180            Ok(None) => {
181                warn!(
182                    session_id = sid,
183                    dataset_name = dataset_name,
184                    "persist_sessions: dataset lookup returned None"
185                );
186                result.sessions_failed += 1;
187                continue;
188            }
189            Err(e) => {
190                warn!(
191                    session_id = sid,
192                    "persist_sessions: dataset lookup failed: {e}"
193                );
194                result.sessions_failed += 1;
195                continue;
196            }
197        };
198
199        match cognify(
200            add_result,
201            dataset_id,
202            Some(owner_id),
203            None,
204            tenant_id,
205            Arc::clone(&llm),
206            Arc::clone(&storage),
207            Arc::clone(&graph_db),
208            Arc::clone(&vector_db),
209            Arc::clone(&embedding_engine),
210            Arc::clone(&database),
211            Arc::clone(&pipeline_run_repo),
212            Arc::clone(&thread_pool),
213            Arc::clone(&ontology_resolver),
214            cognify_config,
215        )
216        .await
217        {
218            Ok(_) => {
219                info!(session_id = sid, "persist_sessions: session persisted");
220                result.sessions_persisted += 1;
221            }
222            Err(e) => {
223                warn!(
224                    session_id = sid,
225                    "persist_sessions: cognify failed (non-fatal): {e}"
226                );
227                result.sessions_failed += 1;
228            }
229        }
230    }
231
232    Ok(result)
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use cognee_session::SessionQAEntry;
    use uuid::Uuid;

    /// Build an entry with the given question/answer and every optional
    /// field left unset.
    fn mk_entry(q: &str, a: &str) -> SessionQAEntry {
        SessionQAEntry {
            id: Uuid::new_v4(),
            session_id: "s1".into(),
            user_id: None,
            question: q.into(),
            answer: a.into(),
            context: None,
            created_at: chrono::Utc::now(),
            feedback_text: None,
            feedback_score: None,
            used_graph_element_ids: None,
            memify_metadata: None,
        }
    }

    #[test]
    fn concat_format_matches_python() {
        let entries = vec![mk_entry("q1", "a1"), mk_entry("q2", "a2")];
        let out = concat_session_entries("sid-1", &entries);
        let expected =
            "Session ID: sid-1\n\nQuestion: q1\n\nAnswer: a1\n\nQuestion: q2\n\nAnswer: a2\n\n";
        assert_eq!(out, expected);
    }

    #[test]
    fn concat_empty_entries() {
        let out = concat_session_entries("sid-empty", &[]);
        assert_eq!(out, "Session ID: sid-empty\n\n");
    }

    #[test]
    fn concat_output_is_never_blank() {
        // The header is always present, so even an empty session ID with no
        // entries yields non-whitespace text.
        let out = concat_session_entries("", &[]);
        assert!(!out.trim().is_empty());
    }
}
271}