Skip to main content

cognee_lib/api/
improve.rs

//! Bidirectional session-graph bridge — `improve()`.
//!
//! Multi-stage pipeline matching Python `cognee.api.v1.improve.improve()`:
//! 1. Apply feedback weights from session Q&A entries to graph nodes/edges.
//! 2. Persist session Q&A text into the permanent knowledge graph
//!    (plus stage 2b: persist agent-trace feedback text).
//! 3. Default enrichment: reuse `memify()` for triplet embeddings
//!    (plus opt-in stage 3b: build a global context index).
//! 4. Sync recent graph edges into the session's `graph_context`.
//!
//! Each stage is wrapped in a warning-only handler so that a failure in one
//! stage does not abort subsequent stages (matches Python semantics).
11
12use std::sync::Arc;
13
14use cognee_cognify::memify::sync_graph_session::DEFAULT_MAX_LINES;
15use cognee_cognify::{
16    CognifyConfig, MemifyConfig, MemifyResult, apply_feedback_weights_pipeline,
17    persist_sessions_in_knowledge_graph, run_memify, sync_graph_to_session,
18};
19use cognee_database::{
20    CheckpointStore, DatabaseConnection, PipelineRunRepository, SeaOrmPipelineRunRepository,
21};
22use cognee_embedding::EmbeddingEngine;
23use cognee_graph::GraphDBTrait;
24use cognee_ingestion::{AddParams, AddPipeline};
25use cognee_llm::Llm;
26use cognee_models::DataInput;
27use cognee_ontology::OntologyResolver;
28use cognee_session::{ImproveLockGuard, SessionManager, SessionStore};
29use cognee_storage::StorageTrait;
30use cognee_vector::VectorDB;
31use tracing::{info, warn};
32use uuid::Uuid;
33
34use super::error::ApiError;
35
36/// Result of an `improve()` operation.
37#[derive(Debug, Clone, Default)]
38pub struct ImproveResult {
39    /// Names of stages that were executed.
40    pub stages_run: Vec<String>,
41    /// Result of the memify (triplet embedding) stage, if it ran.
42    pub memify_result: Option<MemifyResult>,
43    /// Number of feedback QA entries that were processed (Stage 1).
44    pub feedback_entries_processed: usize,
45    /// Number of feedback QA entries whose graph updates all applied cleanly.
46    pub feedback_entries_applied: usize,
47    /// Number of sessions whose Q&A text was persisted to the graph (Stage 2).
48    pub sessions_persisted: usize,
49    /// Total number of edges newly synced into session contexts (Stage 4).
50    pub edges_synced: usize,
51}
52
53/// Parameters for [`improve`].
54///
55/// All fields are required at construction time — `Default` is intentionally
56/// not derived because several fields (`owner_id`, the engine handles, and
57/// `cognify_config`) have no sensible default value. This forces every caller
58/// to think about each dependency. Callers that omit optional behavior should
59/// pass `None` explicitly for the `Option<...>` fields.
60///
61/// LIB-04 (Decision 8) introduced this struct to replace the previous 18
62/// positional parameters. E-05 (this commit) extended it with three v2
63/// power-user fields — `extraction_tasks`, `enrichment_tasks`, `data` —
64/// matching Python's `ImprovePayloadDTO` field-for-field. They are pure-data
65/// fields and currently informational: the orchestrator does not yet branch
66/// on them, but they are accepted by the constructor so callers (especially
67/// the HTTP layer) can plumb the raw payload through without dropping fields.
68pub struct ImproveParams<'a> {
69    /// Dataset name to operate on (Stage 2 persistence + Stage 4 lookup).
70    pub dataset_name: String,
71    /// Session ids that drive Stages 1, 2, and 4. `None` or empty skips them.
72    pub session_ids: Option<Vec<String>>,
73    /// Optional graph node-name filter applied to the memify (Stage 3) pass.
74    pub node_name: Option<Vec<String>>,
75    /// Owner UUID under which graph/session reads and writes are scoped.
76    pub owner_id: Uuid,
77    /// Optional tenant UUID for multi-tenant deployments.
78    pub tenant_id: Option<Uuid>,
79    /// Mixing factor for feedback weight propagation (Stage 1).
80    pub feedback_alpha: f64,
81
82    /// Optional list of extraction-task identifiers (Python parity:
83    /// `extraction_tasks: Optional[List[str]]`). Currently informational —
84    /// reserved for future power-user overrides matching Python's
85    /// `ImproveKwargs.extraction_tasks`.
86    pub extraction_tasks: Option<Vec<String>>,
87    /// Optional list of enrichment-task identifiers (Python parity:
88    /// `enrichment_tasks: Optional[List[str]]`). Currently informational.
89    pub enrichment_tasks: Option<Vec<String>>,
90    /// Optional inline text payload (Python parity: `data: Optional[str]`).
91    /// Currently informational; reserved for future power-user overrides.
92    pub data: Option<String>,
93
94    /// When `true` and not running in background, build the global context
95    /// index (graph summary) after Stage 3.
96    ///
97    /// Mirrors Python's `build_global_context_index` parameter.
98    /// Default `false` (opt-in) — matches Python parity.
99    pub build_global_context_index: bool,
100
101    /// When `true`, treat this as a background run: skips stages that
102    /// require the prior stage to have completed synchronously (e.g. the
103    /// global context index and the sync-graph stage).
104    ///
105    /// Background dispatch is handled by the host (HTTP server or CLI);
106    /// this flag is used only for stage-skipping logic parity with Python.
107    pub run_in_background: bool,
108
109    /// LLM handle (used by Stage 2 cognify-of-session-text).
110    pub llm: Arc<dyn Llm>,
111    /// File storage handle.
112    pub storage: Arc<dyn StorageTrait>,
113    /// Graph database handle.
114    pub graph_db: Arc<dyn GraphDBTrait>,
115    /// Vector database handle.
116    pub vector_db: Arc<dyn VectorDB>,
117    /// Embedding engine handle.
118    pub embedding_engine: Arc<dyn EmbeddingEngine>,
119    /// Ontology resolver handle.
120    pub ontology_resolver: Arc<dyn OntologyResolver>,
121
122    /// Metadata DB connection. Required for Stage 4 (dataset lookup).
123    pub db: Option<Arc<DatabaseConnection>>,
124    /// Session backing store. Required for Stages 1 and 2.
125    pub session_store: Option<Arc<dyn SessionStore>>,
126    /// Session manager. Required for Stages 1 and 4.
127    pub session_manager: Option<Arc<SessionManager>>,
128    /// Add pipeline (borrowed). Required for Stage 2.
129    pub add_pipeline: Option<&'a AddPipeline>,
130    /// Checkpoint store. Required for Stage 4.
131    pub checkpoint_store: Option<Arc<dyn CheckpointStore>>,
132
133    /// Borrowed cognify configuration used by Stage 2 persistence.
134    pub cognify_config: &'a CognifyConfig,
135}
136
137/// Bidirectional session-graph bridge.
138///
139/// Background dispatch is a host-side concern — this function is strictly
140/// synchronous. Stage 4 always runs when sessions are present.
141///
142/// All inputs are passed via [`ImproveParams`] (see Decision 8 / LIB-04).
143pub async fn improve(params: ImproveParams<'_>) -> Result<ImproveResult, ApiError> {
144    let ImproveParams {
145        dataset_name,
146        session_ids,
147        node_name,
148        owner_id,
149        tenant_id,
150        feedback_alpha,
151        llm,
152        storage,
153        graph_db,
154        vector_db,
155        embedding_engine,
156        ontology_resolver,
157        db,
158        session_store,
159        session_manager,
160        add_pipeline,
161        checkpoint_store,
162        cognify_config,
163        build_global_context_index,
164        run_in_background,
165        // E-05 v2 power-user fields — currently informational; the orchestrator
166        // does not yet branch on them. Accepting them here keeps the struct
167        // shape Python-parity-aligned for HTTP plumbing.
168        extraction_tasks: _extraction_tasks,
169        enrichment_tasks: _enrichment_tasks,
170        data: _data,
171    } = params;
172
173    // ---- Improve lock (parity with Python session_lock.py:136-150) ----
174    //
175    // When exactly one session is targeted, acquire a per-session lock so
176    // that concurrent `improve()` calls on the same session don't duplicate
177    // work (e.g. auto-improve + idle-watcher + SessionEnd firing at once).
178    // Multi-session improves skip the lock — the pattern is rare and locking
179    // N sessions atomically is messy (matches Python comment verbatim).
180    //
181    // The guard holds a `String`, not a `MutexGuard`, so it is Send-safe
182    // across `.await` points.
183    let _improve_guard = if let Some(ref sids) = session_ids {
184        if sids.len() == 1 {
185            match ImproveLockGuard::acquire(&sids[0]) {
186                Some(g) => Some(g),
187                None => {
188                    info!(
189                        session_id = %sids[0],
190                        "improve: session already being improved, skipping"
191                    );
192                    // Parity with Python `return {}` — return empty result.
193                    return Ok(ImproveResult::default());
194                }
195            }
196        } else {
197            None
198        }
199    } else {
200        None
201    };
202
203    let mut result = ImproveResult::default();
204    let has_sessions = session_ids.as_ref().is_some_and(|ids| !ids.is_empty());
205
206    // Wrap the body in a `cognee.api.improve` OTEL span for parity with
207    // Python's `cognee.api.v1.improve.improve()` (gap 03 / task 03-07).
208    // Attribute names mirror the analytics payload below and the Python
209    // span's verbose names (`dataset`, `session_count`, `run_in_background`).
210    let session_count = session_ids.as_ref().map(|v| v.len()).unwrap_or(0);
211    let span = tracing::info_span!(
212        "cognee.api.improve",
213        dataset = %dataset_name,
214        session_count = session_count,
215        run_in_background = false,
216    );
217    let _enter = span.enter();
218
219    // Mirrors Python `send_telemetry("cognee.improve", ...)` from
220    // cognee/api/v1/improve/improve.py:91.
221    #[cfg(feature = "telemetry")]
222    {
223        cognee_telemetry::send_telemetry(
224            "cognee.improve",
225            owner_id,
226            Some(serde_json::json!({
227                "dataset": dataset_name.clone(),
228                "session_count": session_count,
229                "session_ids": session_ids.clone(),
230                "run_in_background": false,
231                "cognee_version": env!("CARGO_PKG_VERSION"),
232            })),
233        );
234    }
235
236    // ---- Stage 1: Apply Feedback Weights ----
237    if has_sessions {
238        #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
239        let sids = session_ids
240            .as_ref()
241            .expect("has_sessions guarantees session_ids is Some with non-empty vec");
242        match (session_store.as_ref(), session_manager.as_ref()) {
243            (Some(store), Some(mgr)) => {
244                match apply_feedback_weights_pipeline(
245                    sids,
246                    owner_id,
247                    feedback_alpha,
248                    &*graph_db,
249                    Arc::clone(store),
250                    Arc::clone(mgr),
251                )
252                .await
253                {
254                    Ok(r) => {
255                        info!(
256                            processed = r.processed,
257                            applied = r.applied,
258                            skipped = r.skipped,
259                            "improve stage 1 (feedback_weights) complete"
260                        );
261                        result.feedback_entries_processed = r.processed;
262                        result.feedback_entries_applied = r.applied;
263                        result.stages_run.push("apply_feedback_weights".to_string());
264                    }
265                    Err(e) => {
266                        warn!("improve stage 1 (feedback_weights) failed (non-fatal): {e}");
267                    }
268                }
269            }
270            _ => {
271                warn!(
272                    "improve stage 1: session_store and session_manager are required; skipping feedback_weights"
273                );
274            }
275        }
276    }
277
278    // ---- Stage 2: Persist Session Q&A to Graph ----
279    if has_sessions {
280        #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
281        let sids = session_ids
282            .as_ref()
283            .expect("has_sessions guarantees session_ids is Some with non-empty vec");
284        // LIB-06-03: `persist_sessions_in_knowledge_graph` now requires
285        // `Arc<DatabaseConnection>` and `Arc<dyn CpuPool>`.
286        let stage2_db = db.clone();
287        match (session_store.as_ref(), add_pipeline, stage2_db) {
288            (Some(store), Some(pipeline), Some(database)) => {
289                let thread_pool: Arc<dyn cognee_core::CpuPool> =
290                    match cognee_core::RayonThreadPool::with_default_threads() {
291                        Ok(pool) => Arc::new(pool),
292                        Err(e) => {
293                            warn!(
294                                "improve stage 2: failed to construct thread pool: {e}; skipping persist_sessions"
295                            );
296                            return Ok(result);
297                        }
298                    };
299                let pipeline_run_repo: Arc<dyn PipelineRunRepository> =
300                    Arc::new(SeaOrmPipelineRunRepository::new(Arc::clone(&database)));
301                match persist_sessions_in_knowledge_graph(
302                    sids,
303                    &dataset_name,
304                    owner_id,
305                    tenant_id,
306                    Arc::clone(store),
307                    pipeline,
308                    Arc::clone(&llm),
309                    Arc::clone(&storage),
310                    Arc::clone(&graph_db),
311                    Arc::clone(&vector_db),
312                    Arc::clone(&embedding_engine),
313                    database,
314                    pipeline_run_repo,
315                    thread_pool,
316                    Arc::clone(&ontology_resolver),
317                    cognify_config,
318                )
319                .await
320                {
321                    Ok(r) => {
322                        info!(
323                            persisted = r.sessions_persisted,
324                            skipped = r.sessions_skipped,
325                            failed = r.sessions_failed,
326                            "improve stage 2 (persist_sessions) complete"
327                        );
328                        result.sessions_persisted = r.sessions_persisted;
329                        result.stages_run.push("persist_sessions".to_string());
330                    }
331                    Err(e) => {
332                        warn!("improve stage 2 (persist_sessions) failed (non-fatal): {e}");
333                    }
334                }
335            }
336            _ => {
337                warn!(
338                    "improve stage 2: session_store, add_pipeline, and DatabaseConnection are required; skipping persist_sessions"
339                );
340            }
341        }
342    }
343
344    // ---- Stage 2b: Persist Agent Trace Steps ----
345    //
346    // Mirrors Python's `_persist_session_traces` (improve.py:166-176).
347    // Reads `session_feedback` from each trace step and cognifies it into the
348    // permanent graph so that the plugin's tool-call activity reaches permanent
349    // memory — not just QA entries.
350    //
351    // Scoped-down 0.1.0 implementation: collects trace `session_feedback` text
352    // (the per-step LLM-generated feedback string) and runs it through the
353    // add→cognify path with node_set `"agent_trace_feedbacks"`.
354    //
355    // TODO(parity): Python's `persist_agent_trace_feedbacks_in_knowledge_graph_pipeline`
356    // uses per-step metadata (origin_function, status, method_params). The full
357    // parity pass should introduce a dedicated `persist_trace_feedbacks_in_knowledge_graph`
358    // function in cognee-cognify that preserves per-step provenance.
359    if has_sessions {
360        #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
361        let sids = session_ids
362            .as_ref()
363            .expect("has_sessions guarantees session_ids is Some with non-empty vec");
364
365        // Collect all trace feedback texts across the sessions.
366        let mut trace_texts: Vec<String> = Vec::new();
367        if let Some(mgr) = session_manager.as_ref() {
368            let user_id_str = owner_id.to_string();
369            for sid in sids {
370                match mgr
371                    .get_agent_trace_session(&user_id_str, Some(sid), None)
372                    .await
373                {
374                    Ok(steps) => {
375                        for step in &steps {
376                            if !step.session_feedback.is_empty() {
377                                trace_texts.push(format!(
378                                    "Session: {sid}\nFunction: {}\nStatus: {}\nFeedback: {}",
379                                    step.origin_function, step.status, step.session_feedback,
380                                ));
381                            }
382                        }
383                    }
384                    Err(e) => {
385                        warn!(
386                            session_id = sid,
387                            "improve stage 2b: could not read trace steps (non-fatal): {e}"
388                        );
389                    }
390                }
391            }
392        }
393
394        if !trace_texts.is_empty() {
395            let stage2b_db = db.clone();
396            let combined_text = trace_texts.join("\n\n");
397            match (add_pipeline, stage2b_db) {
398                (Some(pipeline), Some(database)) => {
399                    match cognee_core::RayonThreadPool::with_default_threads() {
400                        Ok(pool) => {
401                            let thread_pool: Arc<dyn cognee_core::CpuPool> = Arc::new(pool);
402                            let pipeline_run_repo: Arc<dyn PipelineRunRepository> =
403                                Arc::new(SeaOrmPipelineRunRepository::new(Arc::clone(&database)));
404                            let add_params = AddParams {
405                                node_set: Some(vec!["agent_trace_feedbacks".to_string()]),
406                                ..Default::default()
407                            };
408                            match pipeline
409                                .add_with_params(
410                                    vec![DataInput::Text(combined_text)],
411                                    &dataset_name,
412                                    owner_id,
413                                    tenant_id,
414                                    &add_params,
415                                )
416                                .await
417                                .map_err(|e| e.to_string())
418                            {
419                                Ok(data_rows) if !data_rows.is_empty() => {
420                                    // Resolve dataset_id the same way persist_sessions does.
421                                    let dataset_id_opt =
422                                        cognee_database::ops::datasets::get_dataset_by_name(
423                                            database.as_ref(),
424                                            &dataset_name,
425                                            owner_id,
426                                            tenant_id,
427                                        )
428                                        .await
429                                        .ok()
430                                        .flatten()
431                                        .map(|ds| ds.id);
432                                    if let Some(dataset_id) = dataset_id_opt {
433                                        match cognee_cognify::tasks::cognify(
434                                            data_rows,
435                                            dataset_id,
436                                            Some(owner_id),
437                                            None,
438                                            tenant_id,
439                                            Arc::clone(&llm),
440                                            Arc::clone(&storage),
441                                            Arc::clone(&graph_db),
442                                            Arc::clone(&vector_db),
443                                            Arc::clone(&embedding_engine),
444                                            Arc::clone(&database),
445                                            pipeline_run_repo,
446                                            thread_pool,
447                                            Arc::clone(&ontology_resolver),
448                                            cognify_config,
449                                        )
450                                        .await
451                                        {
452                                            Ok(_) => {
453                                                info!(
454                                                    trace_items = trace_texts.len(),
455                                                    "improve stage 2b (persist_trace_steps) complete"
456                                                );
457                                            }
458                                            Err(e) => {
459                                                warn!(
460                                                    "improve stage 2b: cognify of trace steps failed (non-fatal): {e}"
461                                                );
462                                            }
463                                        }
464                                    } else {
465                                        warn!(
466                                            "improve stage 2b: dataset lookup returned None; trace steps not cognified"
467                                        );
468                                    }
469                                }
470                                Ok(_) => {
471                                    warn!(
472                                        "improve stage 2b: add returned no rows; trace steps not cognified"
473                                    );
474                                }
475                                Err(e) => {
476                                    warn!(
477                                        "improve stage 2b: add of trace text failed (non-fatal): {e}"
478                                    );
479                                }
480                            }
481                        }
482                        Err(e) => {
483                            warn!("improve stage 2b: rayon pool init failed (non-fatal): {e}");
484                        }
485                    }
486                }
487                _ => {
488                    warn!(
489                        "improve stage 2b: add_pipeline and DatabaseConnection are required; trace steps not cognified"
490                    );
491                }
492            }
493        }
494        // Always push the stage name so stages_run stays consistent with Python,
495        // even when no traces were present or cognification was skipped/failed.
496        result.stages_run.push("persist_trace_steps".to_string());
497    }
498
499    // ---- Stage 3: Default Enrichment (always) ----
500    let memify_config = if let Some(names) = node_name {
501        MemifyConfig::default().with_node_name_filter(names)
502    } else {
503        MemifyConfig::default()
504    };
505    match db.as_ref() {
506        Some(database) => match cognee_core::RayonThreadPool::with_default_threads() {
507            Ok(pool) => {
508                let thread_pool: Arc<dyn cognee_core::CpuPool> = Arc::new(pool);
509                let pipeline_run_repo: Arc<dyn PipelineRunRepository> =
510                    Arc::new(SeaOrmPipelineRunRepository::new(Arc::clone(database)));
511                match run_memify(
512                    Arc::clone(&graph_db),
513                    Arc::clone(&vector_db),
514                    Arc::clone(&embedding_engine),
515                    thread_pool,
516                    Arc::clone(database),
517                    pipeline_run_repo,
518                    None,
519                    Some(owner_id),
520                    tenant_id,
521                    &memify_config,
522                )
523                .await
524                {
525                    Ok(mr) => {
526                        info!(
527                            triplets = mr.triplet_count,
528                            "improve stage 3 (memify) complete"
529                        );
530                        result.memify_result = Some(mr);
531                        result.stages_run.push("memify".to_string());
532                    }
533                    Err(e) => {
534                        warn!("improve stage 3 (memify) failed (non-fatal): {e}");
535                    }
536                }
537            }
538            Err(e) => {
539                warn!("improve stage 3 (memify) failed (non-fatal): rayon pool init: {e}");
540            }
541        },
542        None => {
543            warn!(
544                "improve stage 3: a relational database connection is required by the LIB-06 \
545                 executor-routed memify; skipping memify"
546            );
547        }
548    }
549
550    // ---- Stage 3b: Global Context Index (opt-in) ----
551    //
552    // Mirrors Python's `_build_global_context_index` (improve.py:201-213).
553    // When `build_global_context_index` is `true` and not running in background:
554    // build a graph summary and store it in the session graph-context so the
555    // search side can prepend it as background knowledge.
556    //
557    // Partial 0.1.0 implementation: retrieves graph summaries already stored as
558    // TextSummary nodes and concatenates them as the global context. Python's full
559    // implementation (`global_context_index_pipeline`) also builds bucket and root
560    // summaries via an LLM pass.
561    // TODO(parity): implement bucket/root summary indexing via a dedicated
562    // `global_context_index_pipeline` function in cognee-cognify that mirrors
563    // Python's `bucketing_strategy="graph"` / `max_bucket_size=4` pass.
564    if build_global_context_index {
565        if run_in_background {
566            warn!(
567                "improve stage 3b: global context index skipped in background mode \
568                 because ordered background pipeline chaining is not supported"
569            );
570        } else if let Some(sm) = session_manager.as_ref() {
571            // Partial 0.1.0: read all graph edges via `get_graph_data()` and format
572            // as "source_id → relationship → target_id" lines, then store as the
573            // global context so any session can prepend it as background knowledge.
574            // TODO(parity): replace with a full `global_context_index_pipeline` that
575            // uses LLM bucket/root summarisation (`bucketing_strategy="graph"`,
576            // `max_bucket_size=4`) matching Python's `_build_global_context_index`.
577            match graph_db.get_graph_data().await {
578                Ok((_nodes, edges)) if !edges.is_empty() => {
579                    let global_context = edges
580                        .iter()
581                        .map(|(src, tgt, rel, _props)| format!("{src} → {rel} → {tgt}"))
582                        .collect::<Vec<_>>()
583                        .join("\n");
584                    let user_id_str = owner_id.to_string();
585                    // Store under a synthetic global-context key so any session can read it.
586                    let global_session_key = "_global_context_index";
587                    match sm
588                        .set_graph_context(
589                            Some(global_session_key),
590                            Some(&user_id_str),
591                            &global_context,
592                        )
593                        .await
594                    {
595                        Ok(()) => {
596                            info!(
597                                edges = edges.len(),
598                                "improve stage 3b (global_context_index) complete"
599                            );
600                            result.stages_run.push("global_context_index".to_string());
601                        }
602                        Err(e) => {
603                            warn!(
604                                "improve stage 3b: failed to store global context (non-fatal): {e}"
605                            );
606                        }
607                    }
608                }
609                Ok(_) => {
610                    info!("improve stage 3b: graph has no edges; skipping global_context_index");
611                }
612                Err(e) => {
613                    warn!("improve stage 3b: failed to load graph data (non-fatal): {e}");
614                }
615            }
616        } else {
617            warn!("improve stage 3b: session_manager is required; skipping global_context_index");
618        }
619    }
620
621    // ---- Stage 4: Sync Graph to Session Cache ----
622    //
623    // Stage 4 always runs when sessions are present (background dispatch is host-side).
624    if has_sessions {
625        #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
626        let sids = session_ids
627            .as_ref()
628            .expect("has_sessions guarantees session_ids is Some with non-empty vec");
629        match (
630            db.as_ref(),
631            session_manager.as_ref(),
632            checkpoint_store.as_ref(),
633        ) {
634            (Some(dbc), Some(mgr), Some(ckstore)) => {
635                // Stage 4 requires a dataset UUID. Resolve from the name.
636                let dataset_id_opt = cognee_database::ops::datasets::get_dataset_by_name(
637                    dbc.as_ref(),
638                    &dataset_name,
639                    owner_id,
640                    tenant_id,
641                )
642                .await
643                .ok()
644                .flatten()
645                .map(|ds| ds.id);
646                let Some(dataset_id) = dataset_id_opt else {
647                    warn!(
648                        dataset_name = %dataset_name,
649                        "improve stage 4: dataset not found; skipping sync_graph_to_session"
650                    );
651                    return Ok(result);
652                };
653
654                let user_id_str = owner_id.to_string();
655                let mut total_synced = 0usize;
656                let mut any_ran = false;
657                for sid in sids {
658                    match sync_graph_to_session(
659                        &user_id_str,
660                        sid,
661                        dataset_id,
662                        dbc.as_ref(),
663                        mgr.as_ref(),
664                        ckstore.as_ref(),
665                        DEFAULT_MAX_LINES,
666                    )
667                    .await
668                    {
669                        Ok(r) => {
670                            info!(
671                                session_id = sid,
672                                synced = r.synced,
673                                total = r.total,
674                                "improve stage 4: session synced"
675                            );
676                            total_synced += r.synced;
677                            any_ran = true;
678                        }
679                        Err(e) => {
680                            warn!(
681                                session_id = sid,
682                                "improve stage 4 failed for session (non-fatal): {e}"
683                            );
684                        }
685                    }
686                }
687                result.edges_synced = total_synced;
688                if any_ran {
689                    result.stages_run.push("sync_graph_to_session".to_string());
690                }
691            }
692            _ => {
693                warn!(
694                    "improve stage 4: db, session_manager, and checkpoint_store are required; skipping sync_graph_to_session"
695                );
696            }
697        }
698    }
699
700    Ok(result)
701}
702
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `ImproveResult` must report that nothing ran and nothing
    /// was changed.
    #[test]
    fn improve_result_default_fields() {
        let ImproveResult {
            stages_run,
            memify_result,
            feedback_entries_processed,
            feedback_entries_applied,
            sessions_persisted,
            edges_synced,
        } = ImproveResult::default();

        assert!(stages_run.is_empty());
        assert!(memify_result.is_none());

        let counters = [
            ("feedback_entries_processed", feedback_entries_processed),
            ("feedback_entries_applied", feedback_entries_applied),
            ("sessions_persisted", sessions_persisted),
            ("edges_synced", edges_synced),
        ];
        for (label, value) in counters {
            assert_eq!(value, 0, "{label} should default to zero");
        }
    }
}
717}