1use std::collections::HashSet;
13
14use cognee_session::{SessionManager, SessionStore};
15use serde::{Deserialize, Serialize};
16use tracing::debug;
17
18use crate::observability::COGNEE_SEARCH_TYPE;
19use crate::types::SearchError;
20use crate::{
21 SearchOrchestrator, SearchRequest, SearchResponse, SearchType, record_override, route_query,
22};
23
/// Optional tuning knobs for a graph recall.
///
/// Every field defaults to `None`. [`run_graph`] copies each value into the
/// identically named field of the `SearchRequest` it builds, so `None` leaves
/// that request field unset.
#[derive(Debug, Clone, Default)]
pub struct RecallOptions {
    /// Forwarded as `SearchRequest::system_prompt`.
    pub system_prompt: Option<String>,
    /// Forwarded as `SearchRequest::system_prompt_path`.
    pub system_prompt_path: Option<String>,
    /// Forwarded as `SearchRequest::node_name`.
    pub node_name: Option<Vec<String>>,
    /// Forwarded as `SearchRequest::node_name_filter_operator`.
    pub node_name_filter_operator: Option<String>,
    /// Forwarded as `SearchRequest::only_context`.
    pub only_context: Option<bool>,
    /// Forwarded as `SearchRequest::wide_search_top_k`.
    pub wide_search_top_k: Option<usize>,
    /// Forwarded as `SearchRequest::triplet_distance_penalty`.
    pub triplet_distance_penalty: Option<f32>,
    /// Forwarded as `SearchRequest::feedback_influence`.
    pub feedback_influence: Option<f32>,
    /// Forwarded as `SearchRequest::neighborhood_depth`.
    pub neighborhood_depth: Option<usize>,
    /// Forwarded as `SearchRequest::neighborhood_seed_top_k`.
    pub neighborhood_seed_top_k: Option<usize>,
}
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
47#[serde(rename_all = "snake_case")]
48pub enum RecallSource {
49 Session,
50 Graph,
51 Trace,
52 GraphContext,
53}
54
55impl RecallSource {
56 pub fn as_str(&self) -> &'static str {
58 match self {
59 RecallSource::Session => "session",
60 RecallSource::Graph => "graph",
61 RecallSource::Trace => "trace",
62 RecallSource::GraphContext => "graph_context",
63 }
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
74#[serde(rename_all = "snake_case")]
75pub enum RecallScope {
76 Auto,
77 Graph,
78 Session,
79 Trace,
80 GraphContext,
81}
82
83impl RecallScope {
84 pub const ALL: &'static [Self] = &[Self::Graph, Self::Session, Self::Trace, Self::GraphContext];
87
88 #[cfg_attr(not(test), allow(dead_code))]
89 fn as_wire(&self) -> &'static str {
90 match self {
91 Self::Auto => "auto",
92 Self::Graph => "graph",
93 Self::Session => "session",
94 Self::Trace => "trace",
95 Self::GraphContext => "graph_context",
96 }
97 }
98
99 fn from_wire(s: &str) -> Option<Self> {
100 match s {
101 "auto" => Some(Self::Auto),
102 "graph" => Some(Self::Graph),
103 "session" => Some(Self::Session),
104 "trace" => Some(Self::Trace),
105 "graph_context" => Some(Self::GraphContext),
106 _ => None,
107 }
108 }
109
110 pub fn as_source(&self) -> Option<RecallSource> {
112 match self {
113 RecallScope::Auto => None,
114 RecallScope::Graph => Some(RecallSource::Graph),
115 RecallScope::Session => Some(RecallSource::Session),
116 RecallScope::Trace => Some(RecallSource::Trace),
117 RecallScope::GraphContext => Some(RecallSource::GraphContext),
118 }
119 }
120}
121
/// Raw, not-yet-validated scope selection supplied by a caller.
///
/// [`normalize_scope`] turns this into a list of [`RecallScope`] values and
/// rejects unknown names.
#[derive(Debug, Clone)]
pub enum ScopeInput {
    /// A single scope name, e.g. `"graph"` or `"all"`.
    Single(String),
    /// Several scope names; duplicates are removed during normalization.
    Many(Vec<String>),
}
132
133impl From<&str> for ScopeInput {
134 fn from(s: &str) -> Self {
135 ScopeInput::Single(s.to_string())
136 }
137}
138
139impl From<String> for ScopeInput {
140 fn from(s: String) -> Self {
141 ScopeInput::Single(s)
142 }
143}
144
145impl From<Vec<String>> for ScopeInput {
146 fn from(v: Vec<String>) -> Self {
147 ScopeInput::Many(v)
148 }
149}
150
151pub fn normalize_scope(input: Option<ScopeInput>) -> Result<Vec<RecallScope>, SearchError> {
161 let raw: Vec<String> = match input {
162 None => return Ok(vec![RecallScope::Auto]),
163 Some(ScopeInput::Single(s)) => vec![s],
164 Some(ScopeInput::Many(v)) => v,
165 };
166
167 if raw.is_empty() {
168 return Ok(vec![]);
172 }
173
174 fn is_valid_wire(s: &str) -> bool {
179 s == "all" || RecallScope::from_wire(s).is_some()
180 }
181 let unknown: Vec<&str> = raw
182 .iter()
183 .filter(|s| !is_valid_wire(s))
184 .map(String::as_str)
185 .collect();
186 if !unknown.is_empty() {
187 let valid_sorted = ["all", "auto", "graph", "graph_context", "session", "trace"];
190 return Err(SearchError::InvalidInput(format!(
194 "Unknown recall scope(s): {unknown:?}. Valid values: {valid_sorted:?}"
195 )));
196 }
197
198 if raw.iter().any(|s| s == "all") {
201 return Ok(RecallScope::ALL.to_vec());
202 }
203
204 let mut seen: HashSet<&str> = HashSet::new();
206 let mut out: Vec<RecallScope> = Vec::with_capacity(raw.len());
207 for s in &raw {
208 if seen.insert(s.as_str())
209 && let Some(scope) = RecallScope::from_wire(s)
210 {
211 out.push(scope);
212 }
213 }
214 Ok(out)
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct RecallItem {
220 pub source: RecallSource,
222 pub content: serde_json::Value,
224 pub score: f64,
227}
228
229pub async fn search_session(
234 query_text: &str,
235 session_id: Option<&str>,
236 user_id: Option<&str>,
237 top_k: usize,
238 store: Option<&dyn SessionStore>,
239) -> Result<Vec<RecallItem>, SearchError> {
240 let (Some(sid), Some(store)) = (session_id, store) else {
241 return Ok(vec![]);
245 };
246
247 let query_tokens = tokenize(query_text);
248 if query_tokens.is_empty() {
249 return Ok(vec![]);
250 }
251
252 let entries = store.get_all_qa_entries(sid, user_id).await?;
253 if entries.is_empty() {
254 return Ok(vec![]);
255 }
256
257 let mut scored: Vec<(usize, usize)> = entries
258 .iter()
259 .enumerate()
260 .map(|(idx, entry)| {
261 let entry_text = format!(
263 "{} {} {}",
264 entry.question,
265 entry.context.as_deref().unwrap_or(""),
266 entry.answer,
267 );
268 let entry_tokens = tokenize(&entry_text);
269 let overlap = query_tokens.intersection(&entry_tokens).count();
270 (idx, overlap)
271 })
272 .filter(|(_, overlap)| *overlap > 0)
273 .collect();
274
275 scored.sort_by_key(|s| std::cmp::Reverse(s.1));
276 scored.truncate(top_k);
277
278 Ok(scored
279 .into_iter()
280 .map(|(idx, overlap)| {
281 let entry = &entries[idx];
282 RecallItem {
283 source: RecallSource::Session,
284 content: serde_json::json!({
285 "question": entry.question,
286 "answer": entry.answer,
287 "context": entry.context,
288 "session_id": entry.session_id,
289 "created_at": entry.created_at.to_rfc3339(),
290 }),
291 score: overlap as f64,
292 }
293 })
294 .collect())
295}
296
297pub async fn search_trace(
303 query_text: &str,
304 session_id: Option<&str>,
305 user_id: Option<&str>,
306 top_k: usize,
307 sm: Option<&SessionManager>,
308) -> Result<Vec<RecallItem>, SearchError> {
309 let (Some(sid), Some(sm)) = (session_id, sm) else {
310 return Ok(vec![]);
311 };
312 let Some(uid) = user_id else {
315 return Ok(vec![]);
316 };
317 if uid.is_empty() {
318 return Ok(vec![]);
319 }
320
321 let query_tokens = tokenize(query_text);
322 if query_tokens.is_empty() {
323 return Ok(vec![]);
324 }
325
326 let entries = sm.get_agent_trace_session(uid, Some(sid), None).await?;
327 if entries.is_empty() {
328 return Ok(vec![]);
329 }
330
331 let mut scored: Vec<(usize, usize)> = entries
332 .iter()
333 .enumerate()
334 .map(|(idx, e)| {
335 let mut parts: Vec<String> = vec![
338 e.origin_function.clone(),
339 e.status.clone(),
340 e.memory_query.clone(),
341 e.memory_context.clone(),
342 e.session_feedback.clone(),
343 e.error_message.clone(),
344 ];
345 match serde_json::to_string(&e.method_params) {
349 Ok(s) => parts.push(s),
350 Err(_) => parts.push(format!("{:?}", e.method_params)),
351 }
352 if let Some(ref mrv) = e.method_return_value {
353 match serde_json::to_string(mrv) {
354 Ok(s) => parts.push(s),
355 Err(_) => parts.push(format!("{mrv:?}")),
356 }
357 }
358
359 let joined = parts.join(" ");
360 let entry_tokens = tokenize(&joined);
361 let overlap = query_tokens.intersection(&entry_tokens).count();
362 (idx, overlap)
363 })
364 .filter(|(_, overlap)| *overlap > 0)
365 .collect();
366
367 scored.sort_by_key(|s| std::cmp::Reverse(s.1));
368 scored.truncate(top_k);
369
370 Ok(scored
371 .into_iter()
372 .map(|(idx, overlap)| {
373 let e = &entries[idx];
374 RecallItem {
375 source: RecallSource::Trace,
376 content: serde_json::json!({
377 "trace_id": e.trace_id,
378 "origin_function": e.origin_function,
379 "status": e.status,
380 "memory_query": e.memory_query,
381 "memory_context": e.memory_context,
382 "method_params": e.method_params,
383 "method_return_value": e.method_return_value,
384 "error_message": e.error_message,
385 "session_feedback": e.session_feedback,
386 }),
387 score: overlap as f64,
388 }
389 })
390 .collect())
391}
392
393pub async fn fetch_graph_context(
397 session_id: Option<&str>,
398 user_id: Option<&str>,
399 sm: Option<&SessionManager>,
400) -> Result<Vec<RecallItem>, SearchError> {
401 let (Some(_sid), Some(sm)) = (session_id, sm) else {
402 return Ok(vec![]);
403 };
404 let snapshot_opt = sm.get_graph_context(session_id, user_id).await?;
405 match snapshot_opt {
406 Some(snapshot) if !snapshot.is_empty() => Ok(vec![RecallItem {
407 source: RecallSource::GraphContext,
408 content: serde_json::Value::String(snapshot),
409 score: 1.0,
410 }]),
411 _ => Ok(vec![]),
412 }
413}
414
415#[allow(clippy::too_many_arguments)]
420pub async fn run_graph(
421 query_text: &str,
422 query_type: Option<SearchType>,
423 datasets: Option<Vec<String>>,
424 top_k: usize,
425 auto_route: bool,
426 session_id: Option<&str>,
427 search_orchestrator: &SearchOrchestrator,
428 span: &tracing::Span,
429 options: Option<&RecallOptions>,
430) -> Result<(Vec<RecallItem>, SearchType, bool, SearchResponse), SearchError> {
431 let (search_type, auto_routed) = match (query_type, auto_route) {
434 (Some(qt), true) => {
435 let routed = route_query(query_text);
436 record_override(routed.search_type, qt);
437 (qt, false)
438 }
439 (Some(qt), false) => (qt, false),
440 (None, true) => {
441 let routed = route_query(query_text);
442 debug!(
443 search_type = ?routed.search_type,
444 confidence = routed.confidence,
445 "recall: auto-routed query"
446 );
447 (routed.search_type, true)
448 }
449 (None, false) => (SearchType::GraphCompletion, false),
450 };
451
452 span.record(COGNEE_SEARCH_TYPE, format!("{search_type:?}").as_str());
453
454 let request = SearchRequest {
455 query_text: query_text.to_string(),
456 search_type,
457 top_k: Some(top_k),
458 datasets,
459 dataset_ids: None,
460 system_prompt: options.and_then(|o| o.system_prompt.clone()),
461 system_prompt_path: options.and_then(|o| o.system_prompt_path.clone()),
462 only_context: options.and_then(|o| o.only_context),
463 use_combined_context: None,
464 session_id: session_id.map(|s| s.to_string()),
465 node_type: None,
466 node_name: options.and_then(|o| o.node_name.clone()),
467 wide_search_top_k: options.and_then(|o| o.wide_search_top_k),
468 triplet_distance_penalty: options.and_then(|o| o.triplet_distance_penalty),
469 save_interaction: None,
470 user_id: None,
471 verbose: None,
472 feedback_influence: options.and_then(|o| o.feedback_influence),
473 retriever_specific_config: None,
474 response_schema: None,
475 custom_search_type: None,
476 auto_feedback_detection: None,
477 node_name_filter_operator: options.and_then(|o| o.node_name_filter_operator.clone()),
478 neighborhood_depth: options.and_then(|o| o.neighborhood_depth),
479 neighborhood_seed_top_k: options.and_then(|o| o.neighborhood_seed_top_k),
480 summarize_context: None,
481 };
482
483 let response = search_orchestrator.search(&request).await?;
484
485 let items: Vec<RecallItem> = match &response.result {
486 crate::SearchOutput::Items(search_items) => search_items
487 .iter()
488 .enumerate()
489 .map(|(i, item)| RecallItem {
490 source: RecallSource::Graph,
491 content: serde_json::to_value(item)
492 .unwrap_or_else(|_| serde_json::Value::String(format!("{item:?}"))),
493 score: 1.0 - (i as f64 * 0.01),
494 })
495 .collect(),
496 crate::SearchOutput::Text(text) => vec![RecallItem {
497 source: RecallSource::Graph,
498 content: serde_json::Value::String(text.clone()),
499 score: 1.0,
500 }],
501 crate::SearchOutput::Texts(texts) => texts
502 .iter()
503 .enumerate()
504 .map(|(i, t)| RecallItem {
505 source: RecallSource::Graph,
506 content: serde_json::Value::String(t.clone()),
507 score: 1.0 - (i as f64 * 0.01),
508 })
509 .collect(),
510 other => vec![RecallItem {
511 source: RecallSource::Graph,
512 content: serde_json::to_value(other)
513 .unwrap_or_else(|_| serde_json::Value::String(format!("{other:?}"))),
514 score: 1.0,
515 }],
516 };
517
518 Ok((items, search_type, auto_routed, response))
519}
520
/// Splits `text` into a set of lowercase keyword tokens.
///
/// Tokens are maximal runs of alphanumeric characters (per
/// [`char::is_alphanumeric`]); everything else is a separator. Tokens shorter
/// than two characters are discarded, and the rest are lowercased. Duplicates
/// collapse because the result is a set.
fn tokenize(text: &str) -> HashSet<String> {
    text.split(|c: char| !c.is_alphanumeric())
        // Count characters, not bytes: `str::len()` is a byte length, so a lone
        // non-ASCII letter such as "é" (2 bytes) would otherwise pass the
        // minimum-length filter while a lone ASCII letter is dropped.
        .filter(|word| word.chars().nth(1).is_some())
        .map(str::to_lowercase)
        .collect()
}
528
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Expected expansion of the `"all"` scope, in canonical order.
    const EXPANDED_ALL: [RecallScope; 4] = [
        RecallScope::Graph,
        RecallScope::Session,
        RecallScope::Trace,
        RecallScope::GraphContext,
    ];

    /// Builds a multi-value scope input from string literals.
    fn many(names: &[&str]) -> ScopeInput {
        ScopeInput::Many(names.iter().map(|name| (*name).to_string()).collect())
    }

    #[test]
    fn tokenize_splits_and_lowercases() {
        let tokens = tokenize("Hello, World! How are you?");
        for expected in ["hello", "world", "how", "are", "you"] {
            assert!(tokens.contains(expected), "missing token {expected:?}");
        }
        // Single-character words are filtered out.
        assert!(!tokens.contains("a"));
    }

    #[test]
    fn tokenize_empty_string() {
        assert!(tokenize("").is_empty());
    }

    #[test]
    fn recall_source_serializes_correctly() {
        for (source, json) in [
            (RecallSource::Session, "\"session\""),
            (RecallSource::Graph, "\"graph\""),
        ] {
            assert_eq!(serde_json::to_string(&source).expect("serialize"), json);
        }
    }

    #[test]
    fn recall_source_trace_serializes_correctly() {
        let json = serde_json::to_string(&RecallSource::Trace).expect("serialize");
        assert_eq!(json, "\"trace\"");
    }

    #[test]
    fn recall_source_graph_context_serializes_correctly() {
        let json = serde_json::to_string(&RecallSource::GraphContext).expect("serialize");
        assert_eq!(json, "\"graph_context\"");
    }

    #[test]
    fn test_normalize_scope_none_returns_auto() {
        let scopes = normalize_scope(None).expect("normalize");
        assert_eq!(scopes, vec![RecallScope::Auto]);
    }

    #[test]
    fn test_normalize_scope_string_passes_through() {
        let cases = [
            ("graph", RecallScope::Graph),
            ("session", RecallScope::Session),
            ("trace", RecallScope::Trace),
            ("graph_context", RecallScope::GraphContext),
            ("auto", RecallScope::Auto),
        ];
        for (s, expected) in cases {
            let scopes = normalize_scope(Some(ScopeInput::from(s))).expect("normalize");
            assert_eq!(scopes, vec![expected], "scope={s}");
        }
    }

    #[test]
    fn test_normalize_scope_list_dedupes() {
        let input = many(&["session", "graph", "session", "trace", "graph"]);
        let scopes = normalize_scope(Some(input)).expect("normalize");
        // First occurrences win and input order is preserved.
        assert_eq!(
            scopes,
            vec![RecallScope::Session, RecallScope::Graph, RecallScope::Trace]
        );
    }

    #[test]
    fn test_normalize_scope_all_expands() {
        let alone = normalize_scope(Some(ScopeInput::from("all"))).expect("normalize");
        assert_eq!(alone, EXPANDED_ALL);

        // "all" overrides any other scope listed next to it.
        let mixed = normalize_scope(Some(many(&["session", "all"]))).expect("normalize");
        assert_eq!(mixed, EXPANDED_ALL);
    }

    #[test]
    fn test_normalize_scope_unknown_returns_error() {
        let err = normalize_scope(Some(ScopeInput::from("nonsense"))).expect_err("should error");
        assert!(
            matches!(err, SearchError::InvalidInput(_)),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn test_normalize_scope_error_message_matches_python() {
        let msg = match normalize_scope(Some(ScopeInput::from("foo"))) {
            Err(SearchError::InvalidInput(msg)) => msg,
            other => panic!("expected InvalidInput, got {other:?}"),
        };
        let expected = r#"Unknown recall scope(s): ["foo"]. Valid values: ["all", "auto", "graph", "graph_context", "session", "trace"]"#;
        assert_eq!(msg, expected);
    }

    #[test]
    fn recall_scope_all_constant_matches_canonical_order() {
        assert_eq!(RecallScope::ALL, EXPANDED_ALL.as_slice());
    }

    #[test]
    fn recall_scope_serde_round_trip() {
        let cases = [
            ("\"auto\"", RecallScope::Auto),
            ("\"graph\"", RecallScope::Graph),
            ("\"session\"", RecallScope::Session),
            ("\"trace\"", RecallScope::Trace),
            ("\"graph_context\"", RecallScope::GraphContext),
        ];
        for (json, scope) in cases {
            let parsed: RecallScope = serde_json::from_str(json).expect("deserialize");
            assert_eq!(parsed, scope);
            assert_eq!(serde_json::to_string(&scope).expect("serialize"), json);
        }
    }

    #[test]
    fn recall_scope_as_wire_matches_serde() {
        let cases = [
            (RecallScope::Auto, "auto"),
            (RecallScope::Graph, "graph"),
            (RecallScope::Session, "session"),
            (RecallScope::Trace, "trace"),
            (RecallScope::GraphContext, "graph_context"),
        ];
        for (scope, wire) in cases {
            assert_eq!(scope.as_wire(), wire);
        }
    }
}