Skip to main content

cognee_cognify/memify/
feedback_weights.rs

1//! Stage 1 of `improve()` — apply feedback weights from session Q&A entries
2//! to graph nodes/edges.
3//!
4//! Ported from:
5//! - `/tmp/cognee-python/cognee/tasks/memify/apply_feedback_weights.py`
6//! - `/tmp/cognee-python/cognee/tasks/memify/extract_feedback_qas.py`
7//!
8//! The pipeline:
9//! 1. For each session, fetches all Q&A entries via `SessionStore::get_all_qa_entries`.
10//! 2. Filters to *eligible* entries: `1 <= feedback_score <= 5`, not already
11//!    marked `memify_metadata["feedback_weights_applied"] = true`, and with at
12//!    least one node id or edge id in `used_graph_element_ids`.
13//! 3. Normalizes the score to `[0, 1]` and applies a streaming update
14//!    (`w' = w + alpha * (r - w)`, clipped, rounded to 4 dp).
15//! 4. Reads/writes the `feedback_weight` property via the batch methods
16//!    on `GraphDBTrait`.
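//! 5. Marks the QA entry as processed whether or not every graph update was
//!    applied (the stored flag records the outcome). Errors returned by the
//!    graph or session layer are not caught: they abort the run and are
//!    propagated to the caller.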
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use cognee_graph::{EdgeKey, GraphDBTrait};
24use cognee_session::{
25    SessionManager, SessionQAEntry, SessionQAUpdate, SessionStore, UsedGraphElementIds,
26};
27use thiserror::Error;
28use tracing::{info, warn};
29use uuid::Uuid;
30
31/// Error type for Stage 1 (`apply_feedback_weights`).
32#[derive(Debug, Error)]
33pub enum FeedbackError {
34    #[error("Invalid feedback_score {0}: must be in [1, 5]")]
35    InvalidScore(i32),
36
37    #[error("Invalid alpha {0}: must be in (0, 1]")]
38    InvalidAlpha(f64),
39
40    #[error("Session error: {0}")]
41    Session(#[from] cognee_session::SessionError),
42
43    #[error("Graph error: {0}")]
44    Graph(#[from] cognee_graph::GraphDBError),
45}
46
/// `memify_metadata` key whose boolean value records the outcome of Stage 1
/// processing for a QA entry.
pub const FEEDBACK_WEIGHTS_APPLIED_KEY: &str = "feedback_weights_applied";

/// Decimal places kept when an updated weight is rounded. Mirrors Python's
/// `FEEDBACK_WEIGHT_DECIMALS = 4`.
const FEEDBACK_WEIGHT_DECIMALS: i32 = 4;

/// Separator between the three segments of an edge id encoded as
/// `"source|||target|||relation"` in `used_graph_element_ids.edge_ids`.
pub const EDGE_ID_DELIMITER: &str = "|||";
57
/// Counters describing the outcome of one Stage 1 run.
#[derive(Debug, Clone, Default)]
pub struct FeedbackApplyResult {
    /// Eligible entries that went through the weight update (i.e. were not
    /// skipped).
    pub processed: usize,
    /// Entries whose graph updates were applied in full.
    pub applied: usize,
    /// Entries that were passed over (ineligible or already applied).
    pub skipped: usize,
}
68
69/// Normalize a 1..5 feedback score to [0.0, 1.0].
70///
71/// Matches Python `normalize_feedback_score()` (`apply_feedback_weights.py:43`).
72pub fn normalize_feedback_score(score: i32) -> Result<f64, FeedbackError> {
73    if !(1..=5).contains(&score) {
74        return Err(FeedbackError::InvalidScore(score));
75    }
76    Ok((score as f64 - 1.0) / 4.0)
77}
78
79/// Streaming weight update `w' = w + alpha * (r - w)` clipped to `[0, 1]`
80/// and rounded to 4 decimal places.
81///
82/// Matches Python `stream_update_weight()` (`apply_feedback_weights.py:53`).
83pub fn stream_update_weight(
84    previous_weight: f64,
85    normalized_rating: f64,
86    alpha: f64,
87) -> Result<f64, FeedbackError> {
88    if !(alpha > 0.0 && alpha <= 1.0) {
89        return Err(FeedbackError::InvalidAlpha(alpha));
90    }
91    let updated = previous_weight + alpha * (normalized_rating - previous_weight);
92    let clipped = updated.clamp(0.0, 1.0);
93    let factor = 10f64.powi(FEEDBACK_WEIGHT_DECIMALS);
94    Ok((clipped * factor).round() / factor)
95}
96
97/// Eligibility check matching Python `_is_eligible()`
98/// (`extract_feedback_qas.py:15-41`).
99fn is_eligible(entry: &SessionQAEntry) -> bool {
100    let score = match entry.feedback_score {
101        Some(s) if (1..=5).contains(&s) => s,
102        _ => return false,
103    };
104    let _ = score;
105
106    if let Some(meta) = &entry.memify_metadata
107        && meta.get(FEEDBACK_WEIGHTS_APPLIED_KEY).copied() == Some(true)
108    {
109        return false;
110    }
111
112    match &entry.used_graph_element_ids {
113        Some(ids) => {
114            let has_nodes = ids.node_ids.iter().any(|s| !s.is_empty());
115            let has_edges = ids.edge_ids.iter().any(|s| !s.is_empty());
116            has_nodes || has_edges
117        }
118        None => false,
119    }
120}
121
/// Return the distinct, non-empty ids in ascending lexical order.
/// Mirrors Python `_extract_ids()`.
fn dedup_sorted<'a>(ids: impl IntoIterator<Item = &'a String>) -> Vec<String> {
    // Work on borrowed slices so only the surviving ids are copied.
    let mut unique: Vec<&str> = ids
        .into_iter()
        .map(String::as_str)
        .filter(|id| !id.is_empty())
        .collect();
    unique.sort_unstable();
    // Equal ids are adjacent after sorting, so `dedup` removes every repeat.
    unique.dedup();
    unique.into_iter().map(str::to_owned).collect()
}
134
135/// Parse an edge id string `"source|||target|||rel"` into an
136/// [`EdgeKey`]. Returns `None` if the id does not have the expected three
137/// segments.
138fn parse_edge_id(id: &str) -> Option<EdgeKey> {
139    let parts: Vec<&str> = id.splitn(3, EDGE_ID_DELIMITER).collect();
140    if parts.len() != 3 {
141        return None;
142    }
143    Some((
144        parts[0].to_string(),
145        parts[1].to_string(),
146        parts[2].to_string(),
147    ))
148}
149
150/// Inner per-element update: fetch existing weights, compute new weights
151/// via `stream_update_weight`, write back, and report whether every id
152/// was found and written.
153async fn update_node_weights(
154    graph_db: &dyn GraphDBTrait,
155    ids: &[String],
156    normalized_rating: f64,
157    alpha: f64,
158) -> Result<bool, FeedbackError> {
159    if ids.is_empty() {
160        return Ok(true);
161    }
162    let existing = graph_db.get_node_feedback_weights(ids).await?;
163    let mut updates: HashMap<String, f64> = HashMap::new();
164    let mut all_found = true;
165    for id in ids {
166        match existing.get(id).copied() {
167            Some(prev) => {
168                updates.insert(
169                    id.clone(),
170                    stream_update_weight(prev, normalized_rating, alpha)?,
171                );
172            }
173            None => {
174                all_found = false;
175            }
176        }
177    }
178    if updates.is_empty() {
179        return Ok(false);
180    }
181    let results = graph_db.set_node_feedback_weights(&updates).await?;
182    let all_written = updates
183        .keys()
184        .all(|k| results.get(k).copied().unwrap_or(false));
185    Ok(all_found && all_written)
186}
187
188async fn update_edge_weights(
189    graph_db: &dyn GraphDBTrait,
190    edge_ids: &[String],
191    normalized_rating: f64,
192    alpha: f64,
193) -> Result<bool, FeedbackError> {
194    if edge_ids.is_empty() {
195        return Ok(true);
196    }
197    // Parse "source|||target|||rel" strings; silently skip malformed ones
198    // and treat them as "not found" for the all-applied flag.
199    let mut keys: Vec<EdgeKey> = Vec::with_capacity(edge_ids.len());
200    let mut all_parsed = true;
201    for id in edge_ids {
202        match parse_edge_id(id) {
203            Some(k) => keys.push(k),
204            None => {
205                warn!("feedback_weights: malformed edge id {id:?}, skipping");
206                all_parsed = false;
207            }
208        }
209    }
210    if keys.is_empty() {
211        return Ok(false);
212    }
213    let existing = graph_db.get_edge_feedback_weights(&keys).await?;
214    let mut updates: HashMap<EdgeKey, f64> = HashMap::new();
215    let mut all_found = true;
216    for k in &keys {
217        match existing.get(k).copied() {
218            Some(prev) => {
219                updates.insert(
220                    k.clone(),
221                    stream_update_weight(prev, normalized_rating, alpha)?,
222                );
223            }
224            None => {
225                all_found = false;
226            }
227        }
228    }
229    if updates.is_empty() {
230        return Ok(false);
231    }
232    let results = graph_db.set_edge_feedback_weights(&updates).await?;
233    let all_written = updates
234        .keys()
235        .all(|k| results.get(k).copied().unwrap_or(false));
236    Ok(all_parsed && all_found && all_written)
237}
238
239/// Mark the QA entry's `memify_metadata["feedback_weights_applied"] = success`
240/// via the session manager.
241async fn mark_feedback_processed(
242    session_manager: &SessionManager,
243    session_id: &str,
244    user_id: &str,
245    qa_id: &str,
246    current_metadata: Option<&HashMap<String, bool>>,
247    success: bool,
248) -> Result<(), FeedbackError> {
249    let mut meta: HashMap<String, bool> = current_metadata.cloned().unwrap_or_default();
250    meta.insert(FEEDBACK_WEIGHTS_APPLIED_KEY.to_string(), success);
251
252    session_manager
253        .update_qa(
254            Some(session_id),
255            Some(user_id),
256            qa_id,
257            SessionQAUpdate {
258                memify_metadata: Some(Some(meta)),
259                ..Default::default()
260            },
261        )
262        .await?;
263    Ok(())
264}
265
266/// Apply feedback-weight updates for the given sessions.
267#[allow(clippy::too_many_arguments)]
268pub async fn apply_feedback_weights_pipeline(
269    session_ids: &[String],
270    owner_id: Uuid,
271    alpha: f64,
272    graph_db: &dyn GraphDBTrait,
273    session_store: Arc<dyn SessionStore>,
274    session_manager: Arc<SessionManager>,
275) -> Result<FeedbackApplyResult, FeedbackError> {
276    if !(alpha > 0.0 && alpha <= 1.0) {
277        return Err(FeedbackError::InvalidAlpha(alpha));
278    }
279
280    let user_id_str = owner_id.to_string();
281    let mut result = FeedbackApplyResult::default();
282
283    for session_id in session_ids {
284        let entries = session_store
285            .get_all_qa_entries(session_id, Some(&user_id_str))
286            .await?;
287
288        for entry in entries {
289            if !is_eligible(&entry) {
290                result.skipped += 1;
291                continue;
292            }
293
294            let score = match entry.feedback_score {
295                Some(s) => s,
296                None => {
297                    // Unreachable: is_eligible requires Some(valid).
298                    result.skipped += 1;
299                    continue;
300                }
301            };
302            let normalized = match normalize_feedback_score(score) {
303                Ok(v) => v,
304                Err(_) => {
305                    result.skipped += 1;
306                    continue;
307                }
308            };
309
310            let used = entry
311                .used_graph_element_ids
312                .as_ref()
313                .cloned()
314                .unwrap_or(UsedGraphElementIds::default());
315            let node_ids = dedup_sorted(used.node_ids.iter());
316            let edge_ids = dedup_sorted(used.edge_ids.iter());
317
318            if node_ids.is_empty() && edge_ids.is_empty() {
319                // Eligible entry with no usable ids — mark as processed
320                // (success=false) so we don't revisit.
321                mark_feedback_processed(
322                    &session_manager,
323                    session_id,
324                    &user_id_str,
325                    &entry.id.to_string(),
326                    entry.memify_metadata.as_ref(),
327                    false,
328                )
329                .await?;
330                result.skipped += 1;
331                continue;
332            }
333
334            let node_success = update_node_weights(graph_db, &node_ids, normalized, alpha).await?;
335            let edge_success = update_edge_weights(graph_db, &edge_ids, normalized, alpha).await?;
336            let success = node_success && edge_success;
337
338            mark_feedback_processed(
339                &session_manager,
340                session_id,
341                &user_id_str,
342                &entry.id.to_string(),
343                entry.memify_metadata.as_ref(),
344                success,
345            )
346            .await?;
347
348            info!(
349                qa_id = %entry.id,
350                session_id = session_id,
351                nodes = node_ids.len(),
352                edges = edge_ids.len(),
353                applied = success,
354                "feedback_weights: processed QA entry"
355            );
356
357            result.processed += 1;
358            if success {
359                result.applied += 1;
360            }
361        }
362    }
363
364    info!(
365        processed = result.processed,
366        applied = result.applied,
367        skipped = result.skipped,
368        "feedback_weights: stage complete"
369    );
370    Ok(result)
371}
372
/// Unit tests for the pure helpers of Stage 1: score normalization, the
/// streaming update, edge-id parsing, id de-duplication, and eligibility.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Build a QA entry with fixed filler values, varying only the three
    /// fields that `is_eligible` inspects.
    fn qa_entry(
        feedback_score: Option<i32>,
        used_graph_element_ids: Option<UsedGraphElementIds>,
        memify_metadata: Option<HashMap<String, bool>>,
    ) -> SessionQAEntry {
        SessionQAEntry {
            id: Uuid::new_v4(),
            session_id: "s".into(),
            user_id: None,
            question: "q".into(),
            answer: "a".into(),
            context: None,
            created_at: chrono::Utc::now(),
            feedback_text: None,
            feedback_score,
            used_graph_element_ids,
            memify_metadata,
        }
    }

    /// Used-element ids holding only the given node ids.
    fn nodes(ids: &[&str]) -> Option<UsedGraphElementIds> {
        Some(UsedGraphElementIds {
            node_ids: ids.iter().map(|id| (*id).into()).collect(),
            edge_ids: vec![],
        })
    }

    /// Metadata map carrying only the applied flag.
    fn applied_flag(value: bool) -> Option<HashMap<String, bool>> {
        let mut meta = HashMap::new();
        meta.insert(FEEDBACK_WEIGHTS_APPLIED_KEY.to_string(), value);
        Some(meta)
    }

    #[test]
    fn normalize_scores_endpoints() {
        assert!((normalize_feedback_score(1).unwrap() - 0.0).abs() < 1e-9);
        assert!((normalize_feedback_score(3).unwrap() - 0.5).abs() < 1e-9);
        assert!((normalize_feedback_score(5).unwrap() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn normalize_scores_rejects_out_of_range() {
        assert!(normalize_feedback_score(0).is_err());
        assert!(normalize_feedback_score(6).is_err());
        assert!(normalize_feedback_score(-1).is_err());
    }

    #[test]
    fn stream_update_midpoint() {
        let w = stream_update_weight(0.5, 1.0, 0.1).unwrap();
        assert!((w - 0.55).abs() < 1e-9, "got {w}");
    }

    #[test]
    fn stream_update_zero_stays_zero() {
        let w = stream_update_weight(0.0, 0.0, 1.0).unwrap();
        assert!((w - 0.0).abs() < 1e-9);
    }

    #[test]
    fn stream_update_clips_high() {
        // The rating itself is not clamped: 0.9 + 0.5 * (2.0 - 0.9) = 1.45,
        // and it is the result that gets clipped to 1.0.
        let w = stream_update_weight(0.9, 2.0, 0.5).unwrap();
        assert!((w - 1.0).abs() < 1e-9);
    }

    #[test]
    fn stream_update_clips_low() {
        // 0.1 + 0.5 * (-1.0 - 0.1) = -0.45, clipped to 0.0.
        let w = stream_update_weight(0.1, -1.0, 0.5).unwrap();
        assert!((w - 0.0).abs() < 1e-9);
    }

    #[test]
    fn stream_update_rejects_bad_alpha() {
        assert!(stream_update_weight(0.5, 0.5, 0.0).is_err());
        assert!(stream_update_weight(0.5, 0.5, 1.1).is_err());
        assert!(stream_update_weight(0.5, 0.5, -0.1).is_err());
    }

    #[test]
    fn stream_update_rejects_nan_alpha() {
        assert!(stream_update_weight(0.5, 0.5, f64::NAN).is_err());
    }

    #[test]
    fn stream_update_rounds_to_4dp() {
        // 0.1 + 0.1 * (0.5 - 0.1) is 0.14 on paper; rounding removes any
        // floating-point leftovers.
        let w = stream_update_weight(0.1, 0.5, 0.1).unwrap();
        assert!((w - 0.14).abs() < 1e-12);
    }

    #[test]
    fn parse_edge_id_ok() {
        let k = parse_edge_id("a|||b|||rel").unwrap();
        assert_eq!(k.0, "a");
        assert_eq!(k.1, "b");
        assert_eq!(k.2, "rel");
    }

    #[test]
    fn parse_edge_id_with_extra_delim_in_rel() {
        let k = parse_edge_id("a|||b|||rel|||extra").unwrap();
        assert_eq!(k.2, "rel|||extra");
    }

    #[test]
    fn parse_edge_id_malformed() {
        assert!(parse_edge_id("no_delim").is_none());
        assert!(parse_edge_id("only|||one").is_none());
        assert!(parse_edge_id("").is_none());
    }

    #[test]
    fn dedup_sorted_works() {
        let v = [
            "b".to_string(),
            "a".to_string(),
            "".to_string(),
            "a".to_string(),
        ];
        let result = dedup_sorted(v.iter());
        assert_eq!(result, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn is_eligible_valid_node_ids() {
        let entry = qa_entry(Some(4), nodes(&["n1"]), None);
        assert!(is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_already_applied() {
        let entry = qa_entry(Some(4), nodes(&["n1"]), applied_flag(true));
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_allows_entry_flagged_false() {
        // Only an explicit `true` flag excludes an entry.
        let entry = qa_entry(Some(4), nodes(&["n1"]), applied_flag(false));
        assert!(is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_missing_ids() {
        let entry = qa_entry(Some(4), None, None);
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_blank_ids() {
        let entry = qa_entry(Some(4), nodes(&[""]), None);
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_invalid_score() {
        let entry = qa_entry(Some(0), nodes(&["n1"]), None);
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_missing_score() {
        let entry = qa_entry(None, nodes(&["n1"]), None);
        assert!(!is_eligible(&entry));
    }
}
549}