1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use cognee_graph::{EdgeKey, GraphDBTrait};
24use cognee_session::{
25 SessionManager, SessionQAEntry, SessionQAUpdate, SessionStore, UsedGraphElementIds,
26};
27use thiserror::Error;
28use tracing::{info, warn};
29use uuid::Uuid;
30
/// Errors produced while applying session feedback to graph element weights.
#[derive(Debug, Error)]
pub enum FeedbackError {
    /// A feedback score outside the accepted `1..=5` range was supplied.
    #[error("Invalid feedback_score {0}: must be in [1, 5]")]
    InvalidScore(i32),

    /// The streaming-update learning rate was not in the half-open interval `(0, 1]`
    /// (NaN is also rejected).
    #[error("Invalid alpha {0}: must be in (0, 1]")]
    InvalidAlpha(f64),

    /// Wraps a `cognee_session::SessionError` (converted automatically by `?`).
    #[error("Session error: {0}")]
    Session(#[from] cognee_session::SessionError),

    /// Wraps a `cognee_graph::GraphDBError` (converted automatically by `?`).
    #[error("Graph error: {0}")]
    Graph(#[from] cognee_graph::GraphDBError),
}
46
/// Key written into a QA entry's `memify_metadata` map recording whether feedback
/// weights were applied for that entry. Only entries whose value is `true` are
/// treated as already applied; `false` entries remain eligible on later runs.
pub const FEEDBACK_WEIGHTS_APPLIED_KEY: &str = "feedback_weights_applied";

/// Number of decimal places that updated feedback weights are rounded to.
const FEEDBACK_WEIGHT_DECIMALS: i32 = 4;

/// Separator between the three components of a serialized edge id,
/// e.g. `"a|||b|||rel"`. Only the first two occurrences split the id, so the
/// third component may itself contain the delimiter.
pub const EDGE_ID_DELIMITER: &str = "|||";
57
/// Counters summarising one run of [`apply_feedback_weights_pipeline`].
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (e.g. in tests)
/// instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FeedbackApplyResult {
    /// QA entries that reached the weight-update step, counted whether or not
    /// every individual node/edge update succeeded.
    pub processed: usize,
    /// Subset of `processed` entries for which all node and edge updates
    /// reported success.
    pub applied: usize,
    /// Entries that were ineligible, lacked a usable score, or referenced no
    /// non-empty graph element ids.
    pub skipped: usize,
}
68
69pub fn normalize_feedback_score(score: i32) -> Result<f64, FeedbackError> {
73 if !(1..=5).contains(&score) {
74 return Err(FeedbackError::InvalidScore(score));
75 }
76 Ok((score as f64 - 1.0) / 4.0)
77}
78
79pub fn stream_update_weight(
84 previous_weight: f64,
85 normalized_rating: f64,
86 alpha: f64,
87) -> Result<f64, FeedbackError> {
88 if !(alpha > 0.0 && alpha <= 1.0) {
89 return Err(FeedbackError::InvalidAlpha(alpha));
90 }
91 let updated = previous_weight + alpha * (normalized_rating - previous_weight);
92 let clipped = updated.clamp(0.0, 1.0);
93 let factor = 10f64.powi(FEEDBACK_WEIGHT_DECIMALS);
94 Ok((clipped * factor).round() / factor)
95}
96
97fn is_eligible(entry: &SessionQAEntry) -> bool {
100 let score = match entry.feedback_score {
101 Some(s) if (1..=5).contains(&s) => s,
102 _ => return false,
103 };
104 let _ = score;
105
106 if let Some(meta) = &entry.memify_metadata
107 && meta.get(FEEDBACK_WEIGHTS_APPLIED_KEY).copied() == Some(true)
108 {
109 return false;
110 }
111
112 match &entry.used_graph_element_ids {
113 Some(ids) => {
114 let has_nodes = ids.node_ids.iter().any(|s| !s.is_empty());
115 let has_edges = ids.edge_ids.iter().any(|s| !s.is_empty());
116 has_nodes || has_edges
117 }
118 None => false,
119 }
120}
121
/// Returns the distinct, non-empty ids from `ids`, sorted in ascending
/// (lexicographic) order.
fn dedup_sorted<'a>(ids: impl IntoIterator<Item = &'a String>) -> Vec<String> {
    let mut unique: Vec<String> = ids
        .into_iter()
        .filter(|id| !id.is_empty())
        .cloned()
        .collect();
    // Sorting first puts duplicates next to each other so `dedup` removes them all.
    unique.sort_unstable();
    unique.dedup();
    unique
}
134
135fn parse_edge_id(id: &str) -> Option<EdgeKey> {
139 let parts: Vec<&str> = id.splitn(3, EDGE_ID_DELIMITER).collect();
140 if parts.len() != 3 {
141 return None;
142 }
143 Some((
144 parts[0].to_string(),
145 parts[1].to_string(),
146 parts[2].to_string(),
147 ))
148}
149
150async fn update_node_weights(
154 graph_db: &dyn GraphDBTrait,
155 ids: &[String],
156 normalized_rating: f64,
157 alpha: f64,
158) -> Result<bool, FeedbackError> {
159 if ids.is_empty() {
160 return Ok(true);
161 }
162 let existing = graph_db.get_node_feedback_weights(ids).await?;
163 let mut updates: HashMap<String, f64> = HashMap::new();
164 let mut all_found = true;
165 for id in ids {
166 match existing.get(id).copied() {
167 Some(prev) => {
168 updates.insert(
169 id.clone(),
170 stream_update_weight(prev, normalized_rating, alpha)?,
171 );
172 }
173 None => {
174 all_found = false;
175 }
176 }
177 }
178 if updates.is_empty() {
179 return Ok(false);
180 }
181 let results = graph_db.set_node_feedback_weights(&updates).await?;
182 let all_written = updates
183 .keys()
184 .all(|k| results.get(k).copied().unwrap_or(false));
185 Ok(all_found && all_written)
186}
187
188async fn update_edge_weights(
189 graph_db: &dyn GraphDBTrait,
190 edge_ids: &[String],
191 normalized_rating: f64,
192 alpha: f64,
193) -> Result<bool, FeedbackError> {
194 if edge_ids.is_empty() {
195 return Ok(true);
196 }
197 let mut keys: Vec<EdgeKey> = Vec::with_capacity(edge_ids.len());
200 let mut all_parsed = true;
201 for id in edge_ids {
202 match parse_edge_id(id) {
203 Some(k) => keys.push(k),
204 None => {
205 warn!("feedback_weights: malformed edge id {id:?}, skipping");
206 all_parsed = false;
207 }
208 }
209 }
210 if keys.is_empty() {
211 return Ok(false);
212 }
213 let existing = graph_db.get_edge_feedback_weights(&keys).await?;
214 let mut updates: HashMap<EdgeKey, f64> = HashMap::new();
215 let mut all_found = true;
216 for k in &keys {
217 match existing.get(k).copied() {
218 Some(prev) => {
219 updates.insert(
220 k.clone(),
221 stream_update_weight(prev, normalized_rating, alpha)?,
222 );
223 }
224 None => {
225 all_found = false;
226 }
227 }
228 }
229 if updates.is_empty() {
230 return Ok(false);
231 }
232 let results = graph_db.set_edge_feedback_weights(&updates).await?;
233 let all_written = updates
234 .keys()
235 .all(|k| results.get(k).copied().unwrap_or(false));
236 Ok(all_parsed && all_found && all_written)
237}
238
239async fn mark_feedback_processed(
242 session_manager: &SessionManager,
243 session_id: &str,
244 user_id: &str,
245 qa_id: &str,
246 current_metadata: Option<&HashMap<String, bool>>,
247 success: bool,
248) -> Result<(), FeedbackError> {
249 let mut meta: HashMap<String, bool> = current_metadata.cloned().unwrap_or_default();
250 meta.insert(FEEDBACK_WEIGHTS_APPLIED_KEY.to_string(), success);
251
252 session_manager
253 .update_qa(
254 Some(session_id),
255 Some(user_id),
256 qa_id,
257 SessionQAUpdate {
258 memify_metadata: Some(Some(meta)),
259 ..Default::default()
260 },
261 )
262 .await?;
263 Ok(())
264}
265
266#[allow(clippy::too_many_arguments)]
268pub async fn apply_feedback_weights_pipeline(
269 session_ids: &[String],
270 owner_id: Uuid,
271 alpha: f64,
272 graph_db: &dyn GraphDBTrait,
273 session_store: Arc<dyn SessionStore>,
274 session_manager: Arc<SessionManager>,
275) -> Result<FeedbackApplyResult, FeedbackError> {
276 if !(alpha > 0.0 && alpha <= 1.0) {
277 return Err(FeedbackError::InvalidAlpha(alpha));
278 }
279
280 let user_id_str = owner_id.to_string();
281 let mut result = FeedbackApplyResult::default();
282
283 for session_id in session_ids {
284 let entries = session_store
285 .get_all_qa_entries(session_id, Some(&user_id_str))
286 .await?;
287
288 for entry in entries {
289 if !is_eligible(&entry) {
290 result.skipped += 1;
291 continue;
292 }
293
294 let score = match entry.feedback_score {
295 Some(s) => s,
296 None => {
297 result.skipped += 1;
299 continue;
300 }
301 };
302 let normalized = match normalize_feedback_score(score) {
303 Ok(v) => v,
304 Err(_) => {
305 result.skipped += 1;
306 continue;
307 }
308 };
309
310 let used = entry
311 .used_graph_element_ids
312 .as_ref()
313 .cloned()
314 .unwrap_or(UsedGraphElementIds::default());
315 let node_ids = dedup_sorted(used.node_ids.iter());
316 let edge_ids = dedup_sorted(used.edge_ids.iter());
317
318 if node_ids.is_empty() && edge_ids.is_empty() {
319 mark_feedback_processed(
322 &session_manager,
323 session_id,
324 &user_id_str,
325 &entry.id.to_string(),
326 entry.memify_metadata.as_ref(),
327 false,
328 )
329 .await?;
330 result.skipped += 1;
331 continue;
332 }
333
334 let node_success = update_node_weights(graph_db, &node_ids, normalized, alpha).await?;
335 let edge_success = update_edge_weights(graph_db, &edge_ids, normalized, alpha).await?;
336 let success = node_success && edge_success;
337
338 mark_feedback_processed(
339 &session_manager,
340 session_id,
341 &user_id_str,
342 &entry.id.to_string(),
343 entry.memify_metadata.as_ref(),
344 success,
345 )
346 .await?;
347
348 info!(
349 qa_id = %entry.id,
350 session_id = session_id,
351 nodes = node_ids.len(),
352 edges = edge_ids.len(),
353 applied = success,
354 "feedback_weights: processed QA entry"
355 );
356
357 result.processed += 1;
358 if success {
359 result.applied += 1;
360 }
361 }
362 }
363
364 info!(
365 processed = result.processed,
366 applied = result.applied,
367 skipped = result.skipped,
368 "feedback_weights: stage complete"
369 );
370 Ok(result)
371}
372
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Builds a QA entry in which only the eligibility-relevant fields vary.
    fn qa_entry(
        feedback_score: Option<i32>,
        used_graph_element_ids: Option<UsedGraphElementIds>,
        memify_metadata: Option<HashMap<String, bool>>,
    ) -> SessionQAEntry {
        SessionQAEntry {
            id: Uuid::new_v4(),
            session_id: "s".into(),
            user_id: None,
            question: "q".into(),
            answer: "a".into(),
            context: None,
            created_at: chrono::Utc::now(),
            feedback_text: None,
            feedback_score,
            used_graph_element_ids,
            memify_metadata,
        }
    }

    /// Shorthand for `Some(UsedGraphElementIds { .. })` built from string slices.
    fn ids(nodes: &[&str], edges: &[&str]) -> Option<UsedGraphElementIds> {
        Some(UsedGraphElementIds {
            node_ids: nodes.iter().map(|s| (*s).to_owned()).collect(),
            edge_ids: edges.iter().map(|s| (*s).to_owned()).collect(),
        })
    }

    #[test]
    fn normalize_scores_endpoints() {
        assert!((normalize_feedback_score(1).unwrap() - 0.0).abs() < 1e-9);
        assert!((normalize_feedback_score(3).unwrap() - 0.5).abs() < 1e-9);
        assert!((normalize_feedback_score(5).unwrap() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn normalize_scores_intermediate() {
        assert!((normalize_feedback_score(2).unwrap() - 0.25).abs() < 1e-9);
        assert!((normalize_feedback_score(4).unwrap() - 0.75).abs() < 1e-9);
    }

    #[test]
    fn normalize_scores_rejects_out_of_range() {
        assert!(normalize_feedback_score(0).is_err());
        assert!(normalize_feedback_score(6).is_err());
        assert!(normalize_feedback_score(-1).is_err());
    }

    #[test]
    fn stream_update_midpoint() {
        let w = stream_update_weight(0.5, 1.0, 0.1).unwrap();
        assert!((w - 0.55).abs() < 1e-9, "got {w}");
    }

    #[test]
    fn stream_update_zero_stays_zero() {
        let w = stream_update_weight(0.0, 0.0, 1.0).unwrap();
        assert!((w - 0.0).abs() < 1e-9);
    }

    #[test]
    fn stream_update_full_alpha_jumps_to_rating() {
        let w = stream_update_weight(0.3, 0.7, 1.0).unwrap();
        assert!((w - 0.7).abs() < 1e-9, "got {w}");
    }

    #[test]
    fn stream_update_clips_high() {
        // 0.9 + 0.5 * (2.0 - 0.9) = 1.45, clamped to 1.0.
        let w = stream_update_weight(0.9, 2.0, 0.5).unwrap();
        assert!((w - 1.0).abs() < 1e-9);
    }

    #[test]
    fn stream_update_clips_low() {
        let w = stream_update_weight(0.1, -1.0, 0.5).unwrap();
        assert!((w - 0.0).abs() < 1e-9);
    }

    #[test]
    fn stream_update_rejects_bad_alpha() {
        assert!(stream_update_weight(0.5, 0.5, 0.0).is_err());
        assert!(stream_update_weight(0.5, 0.5, 1.1).is_err());
        assert!(stream_update_weight(0.5, 0.5, -0.1).is_err());
    }

    #[test]
    fn stream_update_rejects_nan_alpha() {
        assert!(matches!(
            stream_update_weight(0.5, 0.5, f64::NAN),
            Err(FeedbackError::InvalidAlpha(_))
        ));
    }

    #[test]
    fn stream_update_rounds_to_4dp() {
        // 0.1 + 0.1 * (0.5 - 0.1) = 0.14 once rounded to 4 decimal places.
        let w = stream_update_weight(0.1, 0.5, 0.1).unwrap();
        assert!((w - 0.14).abs() < 1e-12);
    }

    #[test]
    fn parse_edge_id_ok() {
        let k = parse_edge_id("a|||b|||rel").unwrap();
        assert_eq!(k.0, "a");
        assert_eq!(k.1, "b");
        assert_eq!(k.2, "rel");
    }

    #[test]
    fn parse_edge_id_with_extra_delim_in_rel() {
        let k = parse_edge_id("a|||b|||rel|||extra").unwrap();
        assert_eq!(k.2, "rel|||extra");
    }

    #[test]
    fn parse_edge_id_malformed() {
        assert!(parse_edge_id("no_delim").is_none());
        assert!(parse_edge_id("only|||one").is_none());
    }

    #[test]
    fn parse_edge_id_empty_string() {
        assert!(parse_edge_id("").is_none());
    }

    #[test]
    fn dedup_sorted_works() {
        let v = [
            "b".to_string(),
            "a".to_string(),
            "".to_string(),
            "a".to_string(),
        ];
        let result = dedup_sorted(v.iter());
        assert_eq!(result, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn is_eligible_valid_node_ids() {
        let entry = qa_entry(Some(4), ids(&["n1"], &[]), None);
        assert!(is_eligible(&entry));
    }

    #[test]
    fn is_eligible_valid_edge_ids_only() {
        let entry = qa_entry(Some(2), ids(&[], &["a|||b|||rel"]), None);
        assert!(is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_already_applied() {
        let mut meta = HashMap::new();
        meta.insert("feedback_weights_applied".to_string(), true);
        let entry = qa_entry(Some(4), ids(&["n1"], &[]), Some(meta));
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_retries_previous_failure() {
        let mut meta = HashMap::new();
        meta.insert(FEEDBACK_WEIGHTS_APPLIED_KEY.to_string(), false);
        let entry = qa_entry(Some(4), ids(&["n1"], &[]), Some(meta));
        assert!(is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_missing_ids() {
        let entry = qa_entry(Some(4), None, None);
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_only_empty_ids() {
        let entry = qa_entry(Some(5), ids(&[""], &[""]), None);
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_invalid_score() {
        let entry = qa_entry(Some(0), ids(&["n1"], &[]), None);
        assert!(!is_eligible(&entry));
    }

    #[test]
    fn is_eligible_rejects_missing_score() {
        let entry = qa_entry(None, ids(&["n1"], &[]), None);
        assert!(!is_eligible(&entry));
    }
}