1use std::borrow::Cow;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
7use cognee_embedding::EmbeddingEngine;
8use cognee_graph::{GraphDBTrait, NodeData};
9use cognee_llm::{GenerationOptions, Llm, LlmExt, Message};
10use cognee_vector::VectorDB;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14
15use cognee_session::SessionContext;
16
17use crate::graph_retrieval::{
18 DEFAULT_TRIPLET_DISTANCE_PENALTY, GraphRetrievalConfig, RankedGraphEdge,
19 brute_force_triplet_search,
20};
21use crate::retrievers::SearchRetriever;
22use crate::types::{
23 SearchContext, SearchError, SearchItem, SearchOutput, SearchParams, SearchType,
24};
25use crate::utils::{build_messages_with_history, render_graph_user_prompt, resolve_system_prompt};
26
/// Result count used when neither the constructor nor the search params give a `top_k`.
const DEFAULT_TOP_K: usize = 10;
/// Candidate count used for the wide search when no override is given.
const DEFAULT_WIDE_SEARCH_TOP_K: usize = 100;
/// Data-type half of the vector collection queried when ranking events.
const TEMPORAL_DATA_TYPE: &str = "Event";
/// Field-name half of the vector collection queried when ranking events.
const TEMPORAL_FIELD_NAME: &str = "name";
/// System prompt used for interval extraction when no custom prompt is configured.
/// The `{time_now}` placeholder is substituted before the prompt is sent.
const DEFAULT_TEMPORAL_INTERVAL_PROMPT: &str = "You are tasked with identifying relevant time periods where the answer to a given query should be searched.\nCurrent date is: `{time_now}`. Determine relevant period(s) and return structured intervals.\n\nExtraction rules:\n\n1. Query without specific timestamp: use the time period with starts_at set to None and ends_at set to now.\n2. Explicit time intervals: If the query specifies a range (e.g., from 2010 to 2020, between January and March 2023), extract both start and end dates. Always assign the earlier date to starts_at and the later date to ends_at.\n3. Single timestamp: If the query refers to one specific moment (e.g., in 2015, on March 5, 2022), set starts_at and ends_at to that same timestamp.\n4. Open-ended time references: For phrases such as \"before X\" or \"after X\", represent the unspecified side as None. For example: before 2009 → starts_at: None, ends_at: 2009; after 2009 → starts_at: 2009, ends_at: None.\n5. Current-time references (\"now\", \"current\", \"today\"): If the query explicitly refers to the present, set both starts_at and ends_at to now (the ingestion timestamp).\n6. \"Who is\" and \"Who was\" questions: These imply a general identity or biographical inquiry without a specific temporal scope. Set both starts_at and ends_at to None.\n7. Ordering rule: Always ensure the earlier date is assigned to starts_at and the later date to ends_at.\n8. No temporal information: If no valid or inferable time reference is found, set both starts_at and ends_at to None.";
32
// Raw interval as returned by the LLM's structured output: two optional,
// free-form date strings that still have to be parsed.
// NOTE(review): plain `//` comments are used on purpose; `///` doc comments on
// a `JsonSchema`-derived type presumably end up in the generated schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct QueryInterval {
    // Textual lower bound of the interval, if the model supplied one.
    starts_at: Option<String>,
    // Textual upper bound of the interval, if the model supplied one.
    ends_at: Option<String>,
}
38
/// Interval whose bounds have been converted to UTC datetimes.
///
/// A `None` bound means that side was absent or could not be parsed.
#[derive(Debug, Clone)]
struct ParsedInterval {
    /// Lower bound, if any.
    start: Option<DateTime<Utc>>,
    /// Upper bound, if any.
    end: Option<DateTime<Utc>>,
}
44
45impl QueryInterval {
46 fn parse(self) -> ParsedInterval {
47 ParsedInterval {
48 start: self
49 .starts_at
50 .as_deref()
51 .and_then(|value| parse_bound(value, false)),
52 end: self
53 .ends_at
54 .as_deref()
55 .and_then(|value| parse_bound(value, true)),
56 }
57 }
58}
59
/// Retriever that answers time-scoped queries.
///
/// It asks the LLM for a time interval, collects `Event` nodes reachable from
/// `Timestamp` nodes inside that interval, and falls back to ranked graph
/// edges when no interval or no matching event is found.
pub struct TemporalRetriever {
    /// Vector store used for the triplet search and for event similarity lookups.
    vector_db: Arc<dyn VectorDB>,
    /// Engine that embeds the query text.
    embedding_engine: Arc<dyn EmbeddingEngine>,
    /// Graph store holding the timestamp, interval and event nodes.
    graph_db: Arc<dyn GraphDBTrait>,
    /// Model used for interval extraction and for the final completion.
    llm: Arc<dyn Llm>,
    /// Result count used when the search params do not override it.
    top_k: usize,
    /// Wide-search candidate count used when the search params do not override it.
    wide_search_top_k: usize,
    /// Penalty forwarded into the graph retrieval config.
    triplet_distance_penalty: f32,
    /// Feedback weight forwarded into the graph retrieval config; `new` sets it to `0.0`.
    feedback_influence: f32,
    /// Optional replacement for `DEFAULT_TEMPORAL_INTERVAL_PROMPT`.
    temporal_interval_prompt: Option<String>,
    /// System prompt used when the search params carry none.
    system_prompt: Option<String>,
    /// System prompt path used when the search params carry none.
    system_prompt_path: Option<String>,
    /// Optional template handed to `render_graph_user_prompt`.
    user_prompt_template: Option<String>,
    /// Options cloned into every LLM call made by this retriever.
    generation_options: Option<GenerationOptions>,
}
75
76impl TemporalRetriever {
77 #[allow(clippy::too_many_arguments)]
78 pub fn new(
79 vector_db: Arc<dyn VectorDB>,
80 embedding_engine: Arc<dyn EmbeddingEngine>,
81 graph_db: Arc<dyn GraphDBTrait>,
82 llm: Arc<dyn Llm>,
83 top_k: Option<usize>,
84 wide_search_top_k: Option<usize>,
85 triplet_distance_penalty: Option<f32>,
86 temporal_interval_prompt: Option<String>,
87 system_prompt: Option<String>,
88 system_prompt_path: Option<String>,
89 user_prompt_template: Option<String>,
90 generation_options: Option<GenerationOptions>,
91 ) -> Self {
92 Self {
93 vector_db,
94 embedding_engine,
95 graph_db,
96 llm,
97 top_k: top_k.unwrap_or(DEFAULT_TOP_K),
98 wide_search_top_k: wide_search_top_k.unwrap_or(DEFAULT_WIDE_SEARCH_TOP_K),
99 triplet_distance_penalty: triplet_distance_penalty
100 .unwrap_or(DEFAULT_TRIPLET_DISTANCE_PENALTY),
101 feedback_influence: 0.0,
102 temporal_interval_prompt,
103 system_prompt,
104 system_prompt_path,
105 user_prompt_template,
106 generation_options,
107 }
108 }
109
110 async fn extract_interval(&self, query: &str) -> Result<Option<ParsedInterval>, SearchError> {
111 let now = chrono::Local::now().format("%d-%m-%Y").to_string();
112 let prompt_template = self
113 .temporal_interval_prompt
114 .as_deref()
115 .unwrap_or(DEFAULT_TEMPORAL_INTERVAL_PROMPT);
116 let system_prompt = prompt_template.replace("{time_now}", &now);
117
118 let interval = match self
119 .llm
120 .create_structured_output_with_messages::<QueryInterval>(
121 vec![
122 Message::system(system_prompt),
123 Message::user(query.to_string()),
124 ],
125 self.generation_options.clone(),
126 )
127 .await
128 {
129 Ok(interval) => interval,
130 Err(_) => return Ok(None),
131 };
132
133 let parsed = interval.parse();
134 if parsed.start.is_none() && parsed.end.is_none() {
135 return Ok(None);
136 }
137
138 Ok(Some(parsed))
139 }
140
141 fn get_graph_retrieval_config(&self, params: &SearchParams) -> GraphRetrievalConfig {
142 GraphRetrievalConfig {
143 top_k: params.top_k_or(self.top_k),
144 wide_search_top_k: params.wide_search_top_k_or(self.wide_search_top_k),
145 triplet_distance_penalty: params
146 .triplet_distance_penalty_or(self.triplet_distance_penalty),
147 feedback_influence: params.feedback_influence_or(self.feedback_influence),
148 node_type: params.node_type.clone(),
149 node_name: params.node_name.clone(),
150 node_name_filter_operator: params
151 .node_name_filter_operator
152 .as_deref()
153 .unwrap_or("OR")
154 .to_string(),
155 }
156 }
157
158 async fn get_ranked_graph_edges(
159 &self,
160 query: &str,
161 params: &SearchParams,
162 ) -> Result<Vec<RankedGraphEdge>, SearchError> {
163 brute_force_triplet_search(
164 query,
165 self.vector_db.as_ref(),
166 self.embedding_engine.as_ref(),
167 self.graph_db.as_ref(),
168 &self.get_graph_retrieval_config(params),
169 )
170 .await
171 }
172
173 fn ranked_edges_to_context(ranked_edges: Vec<RankedGraphEdge>) -> SearchContext {
174 ranked_edges
175 .into_iter()
176 .map(|edge| SearchItem {
177 id: None,
178 score: Some(edge.score),
179 payload: json!({
180 "source_id": edge.source_id,
181 "target_id": edge.target_id,
182 "relationship": edge.relationship_name,
183 "source_name": edge.source_name,
184 "target_name": edge.target_name,
185 "source_text": edge.source_text,
186 "target_text": edge.target_text,
187 "source_description": edge.source_description,
188 "target_description": edge.target_description,
189 }),
190 })
191 .collect()
192 }
193
194 async fn get_fallback_context(
195 &self,
196 query: &str,
197 params: &SearchParams,
198 ) -> Result<SearchContext, SearchError> {
199 let ranked_edges = self.get_ranked_graph_edges(query, params).await?;
200 Ok(Self::ranked_edges_to_context(ranked_edges))
201 }
202
203 async fn rank_temporal_events(
204 &self,
205 query: &str,
206 event_ids: &HashSet<String>,
207 ranked_edges: &[RankedGraphEdge],
208 ) -> Result<Vec<(String, f32)>, SearchError> {
209 let mut scores = HashMap::<String, f32>::new();
210
211 for edge in ranked_edges {
212 if event_ids.contains(&edge.source_id) {
213 let score = scores.entry(edge.source_id.clone()).or_insert(edge.score);
214 *score = score.max(edge.score);
215 }
216 if event_ids.contains(&edge.target_id) {
217 let score = scores.entry(edge.target_id.clone()).or_insert(edge.score);
218 *score = score.max(edge.score);
219 }
220 }
221
222 if self
223 .vector_db
224 .has_collection(TEMPORAL_DATA_TYPE, TEMPORAL_FIELD_NAME)
225 .await?
226 {
227 let query_embeddings = self.embedding_engine.embed(&[query]).await?;
228 let query_vector = query_embeddings.into_iter().next().ok_or_else(|| {
229 SearchError::InvalidInput("embedding engine returned no vectors".to_string())
230 })?;
231
232 let semantic_results = self
233 .vector_db
234 .search_similar(
235 TEMPORAL_DATA_TYPE,
236 TEMPORAL_FIELD_NAME,
237 &query_vector,
238 self.wide_search_top_k.max(self.top_k),
239 )
240 .await?;
241
242 for result in semantic_results {
243 let event_id = result.id.to_string();
244 if !event_ids.contains(&event_id) {
245 continue;
246 }
247
248 let score = scores.entry(event_id).or_insert(result.score);
249 *score = score.max(result.score);
250 }
251 }
252
253 let mut ranked = event_ids
254 .iter()
255 .map(|event_id| {
256 (
257 event_id.clone(),
258 scores.get(event_id).copied().unwrap_or(0.0),
259 )
260 })
261 .collect::<Vec<_>>();
262
263 ranked.sort_by(|left, right| {
264 right
265 .1
266 .partial_cmp(&left.1)
267 .unwrap_or(std::cmp::Ordering::Equal)
268 .then_with(|| left.0.cmp(&right.0))
269 });
270
271 Ok(ranked)
272 }
273
274 fn temporal_context_to_text(context: &SearchContext) -> String {
275 context
276 .iter()
277 .map(|item| {
278 if item.payload.get("event_id").is_some() {
279 let name = item
280 .payload
281 .get("event_name")
282 .and_then(Value::as_str)
283 .unwrap_or("Unnamed event");
284 let description = item
285 .payload
286 .get("event_description")
287 .and_then(Value::as_str)
288 .unwrap_or("No description");
289 let time = item
290 .payload
291 .get("event_time")
292 .and_then(Value::as_str)
293 .unwrap_or("unknown time");
294
295 return format!("{name} ({time}): {description}");
296 }
297
298 let source = item
299 .payload
300 .get("source_name")
301 .and_then(Value::as_str)
302 .or_else(|| item.payload.get("source_id").and_then(Value::as_str))
303 .unwrap_or("unknown_source");
304 let target = item
305 .payload
306 .get("target_name")
307 .and_then(Value::as_str)
308 .or_else(|| item.payload.get("target_id").and_then(Value::as_str))
309 .unwrap_or("unknown_target");
310 let relationship = item
311 .payload
312 .get("relationship")
313 .and_then(Value::as_str)
314 .or_else(|| {
315 item.payload
316 .get("relationship_name")
317 .and_then(Value::as_str)
318 })
319 .unwrap_or("related_to");
320
321 format!("{source} -[{relationship}]-> {target}")
322 })
323 .collect::<Vec<_>>()
324 .join("\n")
325 }
326}
327
328#[async_trait]
329impl SearchRetriever for TemporalRetriever {
    /// Identifies this retriever as the handler for `SearchType::Temporal`.
    fn search_type(&self) -> SearchType {
        SearchType::Temporal
    }
333
334 async fn get_context(
335 &self,
336 query: &str,
337 params: &SearchParams,
338 ) -> Result<SearchContext, SearchError> {
339 if self.graph_db.is_empty().await? {
340 return Ok(vec![]);
341 }
342
343 let Some(interval) = self.extract_interval(query).await? else {
344 return self.get_fallback_context(query, params).await;
345 };
346
347 let (candidate_timestamps, _) = self
349 .graph_db
350 .get_filtered_graph_data(&HashMap::from([(
351 Cow::Borrowed("type"),
352 vec![json!("Timestamp")],
353 )]))
354 .await?;
355
356 let interval_from_ms = interval.start.map(|dt| dt.timestamp_millis());
357 let interval_to_ms = interval.end.map(|dt| dt.timestamp_millis());
358
359 let matching_ts_ids: Vec<String> = candidate_timestamps
360 .into_iter()
361 .filter_map(|(id, props)| {
362 let time_at = props.get("time_at")?.as_i64()?;
363 is_within_interval_ms(time_at, interval_from_ms, interval_to_ms).then_some(id)
364 })
365 .collect();
366
367 let mut event_node_ids = HashSet::new();
369 for ts_id in &matching_ts_ids {
370 for node_props in self.graph_db.get_neighbors(ts_id).await? {
371 let node_type = node_props.get("type").and_then(|v| v.as_str());
372 match node_type {
373 Some("Event") => {
374 if let Some(id) = node_props.get("id").and_then(|v| v.as_str()) {
375 event_node_ids.insert(id.to_string());
376 }
377 }
378 Some("Interval") => {
379 if let Some(interval_id) = node_props.get("id").and_then(|v| v.as_str()) {
381 for inner_props in self.graph_db.get_neighbors(interval_id).await? {
382 if inner_props.get("type").and_then(|v| v.as_str()) == Some("Event")
383 && let Some(id) = inner_props.get("id").and_then(|v| v.as_str())
384 {
385 event_node_ids.insert(id.to_string());
386 }
387 }
388 }
389 }
390 _ => {}
391 }
392 }
393 }
394
395 if event_node_ids.is_empty() {
396 return self.get_fallback_context(query, params).await;
397 }
398
399 let ranked_edges = self.get_ranked_graph_edges(query, params).await?;
400 let ranked_events = self
401 .rank_temporal_events(query, &event_node_ids, &ranked_edges)
402 .await?;
403
404 let event_id_list: Vec<String> = ranked_events
406 .iter()
407 .take(params.top_k_or(self.top_k))
408 .map(|(id, _)| id.clone())
409 .collect();
410 let event_nodes = self.graph_db.get_nodes(&event_id_list).await?;
411 let nodes_by_id: HashMap<String, NodeData> =
412 event_id_list.into_iter().zip(event_nodes).collect();
413
414 let mut temporal_context = Vec::new();
415
416 for (event_id, score) in ranked_events.into_iter().take(params.top_k_or(self.top_k)) {
417 let Some(event_node) = nodes_by_id.get(&event_id) else {
418 continue;
419 };
420
421 temporal_context.push(SearchItem {
422 id: None,
423 score: Some(score),
424 payload: json!({
425 "event_id": event_id,
426 "event_name": extract_node_name(event_node),
427 "event_description": extract_node_description(event_node),
428 }),
429 });
430 }
431
432 if temporal_context.is_empty() {
433 return Ok(Self::ranked_edges_to_context(ranked_edges));
434 }
435
436 Ok(temporal_context)
437 }
438
439 async fn get_completion(
440 &self,
441 query: &str,
442 context: Option<SearchContext>,
443 session: &SessionContext,
444 params: &SearchParams,
445 ) -> Result<SearchOutput, SearchError> {
446 let completion_context = match context {
447 Some(existing_context) => existing_context,
448 None => self.get_context(query, params).await?,
449 };
450
451 let system_prompt = resolve_system_prompt(
452 params
453 .system_prompt
454 .as_deref()
455 .or(self.system_prompt.as_deref()),
456 params
457 .system_prompt_path
458 .as_deref()
459 .or(self.system_prompt_path.as_deref()),
460 )?;
461
462 let user_prompt = render_graph_user_prompt(
463 self.user_prompt_template.as_deref(),
464 query,
465 &Self::temporal_context_to_text(&completion_context),
466 );
467
468 let messages = build_messages_with_history(system_prompt, user_prompt, session);
469
470 if let Some(schema) = ¶ms.response_schema {
471 let structured_value = self
472 .llm
473 .create_structured_output_with_messages_raw(
474 messages,
475 schema,
476 self.generation_options.clone(),
477 )
478 .await
479 .map_err(|e| SearchError::LlmError(e.to_string()))?;
480 Ok(SearchOutput::Structured(structured_value))
481 } else {
482 let completion = self
483 .llm
484 .generate(messages, self.generation_options.clone())
485 .await?;
486 Ok(SearchOutput::Text(completion.content))
487 }
488 }
489}
490
491fn extract_node_name(node_data: &NodeData) -> String {
492 node_data
493 .get("name")
494 .and_then(Value::as_str)
495 .or_else(|| node_data.get("title").and_then(Value::as_str))
496 .unwrap_or("Unnamed event")
497 .to_string()
498}
499
500fn extract_node_description(node_data: &NodeData) -> String {
501 node_data
502 .get("description")
503 .and_then(Value::as_str)
504 .or_else(|| node_data.get("text").and_then(Value::as_str))
505 .unwrap_or("")
506 .to_string()
507}
508
/// Reports whether `time_at_ms` lies inside the interval `[from_ms, to_ms]`.
///
/// Both bounds are inclusive; a `None` bound leaves that side unbounded.
fn is_within_interval_ms(time_at_ms: i64, from_ms: Option<i64>, to_ms: Option<i64>) -> bool {
    let before_start = from_ms.is_some_and(|lower| time_at_ms < lower);
    let after_end = to_ms.is_some_and(|upper| time_at_ms > upper);

    !(before_start || after_end)
}
513
514fn parse_bound(input: &str, is_end: bool) -> Option<DateTime<Utc>> {
515 let trimmed = input.trim();
516 if trimmed.is_empty() {
517 return None;
518 }
519
520 if let Ok(timestamp) = DateTime::parse_from_rfc3339(trimmed) {
521 return Some(timestamp.with_timezone(&Utc));
522 }
523
524 if let Ok(naive_dt) = NaiveDateTime::parse_from_str(trimmed, "%Y-%m-%d %H:%M:%S") {
525 return Some(Utc.from_utc_datetime(&naive_dt));
526 }
527
528 if let Ok(date) = NaiveDate::parse_from_str(trimmed, "%Y-%m-%d") {
529 return to_datetime(date, is_end);
530 }
531
532 if trimmed.len() == 7 {
533 let month_candidate = format!("{trimmed}-01");
534 if let Ok(date) = NaiveDate::parse_from_str(&month_candidate, "%Y-%m-%d") {
535 return if is_end {
536 let (next_year, next_month) = if date.month() == 12 {
537 (date.year() + 1, 1)
538 } else {
539 (date.year(), date.month() + 1)
540 };
541
542 let next_month_start = NaiveDate::from_ymd_opt(next_year, next_month, 1)?;
543 let month_end = next_month_start.pred_opt()?;
544 to_datetime(month_end, true)
545 } else {
546 to_datetime(date, false)
547 };
548 }
549 }
550
551 if trimmed.len() == 4
552 && trimmed.chars().all(|character| character.is_ascii_digit())
553 && let Ok(year) = trimmed.parse::<i32>()
554 {
555 let date = if is_end {
556 NaiveDate::from_ymd_opt(year, 12, 31)?
557 } else {
558 NaiveDate::from_ymd_opt(year, 1, 1)?
559 };
560
561 return to_datetime(date, is_end);
562 }
563
564 None
565}
566
567fn to_datetime(date: NaiveDate, is_end: bool) -> Option<DateTime<Utc>> {
568 let naive_dt = if is_end {
569 date.and_hms_opt(23, 59, 59)?
570 } else {
571 date.and_hms_opt(0, 0, 0)?
572 };
573
574 Some(Utc.from_utc_datetime(&naive_dt))
575}
576
577#[cfg(test)]
578#[allow(
579 clippy::unwrap_used,
580 clippy::expect_used,
581 reason = "test code — panics are acceptable failures"
582)]
583mod tests {
584 use std::borrow::Cow;
585 use std::collections::{HashMap, HashSet};
586 use std::sync::{Arc, Mutex};
587
588 use async_trait::async_trait;
589 use cognee_embedding::EmbeddingResult;
590 use cognee_embedding::engine::EmbeddingEngine;
591 use cognee_graph::{EdgeData, GraphDBResult, GraphDBTrait, GraphNode, NodeData};
592 use cognee_llm::{
593 GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
594 };
595 use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};
596
597 use chrono::{TimeZone, Utc};
598 use serde_json::{Value, json};
599 use uuid::Uuid;
600
601 use cognee_session::SessionContext;
602
603 use super::{QueryInterval, TemporalRetriever};
604 use crate::graph_retrieval::RankedGraphEdge;
605 use crate::retrievers::SearchRetriever;
606 use crate::types::{SearchItem, SearchOutput, SearchParams};
607
    /// Embedding engine stub that ignores its input and always returns the
    /// same single two-dimensional vector.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        /// Returns one fixed vector regardless of how many texts are passed.
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![vec![0.3, 0.7]])
        }

        /// Matches the length of the vector returned by `embed`.
        fn dimension(&self) -> usize {
            2
        }

        fn batch_size(&self) -> usize {
            16
        }

        fn max_sequence_length(&self) -> usize {
            512
        }
    }
628
629 struct TestVectorDb {
630 collections: HashMap<String, Vec<SearchResult>>,
631 }
632
633 impl TestVectorDb {
634 fn key(data_type: &str, field_name: &str) -> String {
635 format!("{data_type}_{field_name}")
636 }
637 }
638
639 #[async_trait]
640 impl VectorDB for TestVectorDb {
641 async fn create_collection(
642 &self,
643 _data_type: &str,
644 _field_name: &str,
645 _dimension: usize,
646 ) -> VectorDBResult<()> {
647 Ok(())
648 }
649
650 async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
651 Ok(self
652 .collections
653 .contains_key(&Self::key(data_type, field_name)))
654 }
655
656 async fn index_points(
657 &self,
658 _data_type: &str,
659 _field_name: &str,
660 _points: &[VectorPoint],
661 ) -> VectorDBResult<()> {
662 Ok(())
663 }
664
665 async fn search_similar(
666 &self,
667 data_type: &str,
668 field_name: &str,
669 _query_vector: &[f32],
670 top_k: usize,
671 ) -> VectorDBResult<Vec<SearchResult>> {
672 Ok(self
673 .collections
674 .get(&Self::key(data_type, field_name))
675 .cloned()
676 .unwrap_or_default()
677 .into_iter()
678 .take(top_k)
679 .collect())
680 }
681
682 async fn delete_collection(
683 &self,
684 _data_type: &str,
685 _field_name: &str,
686 ) -> VectorDBResult<()> {
687 Ok(())
688 }
689
690 async fn delete_points(
691 &self,
692 _data_type: &str,
693 _field_name: &str,
694 _point_ids: &[Uuid],
695 ) -> VectorDBResult<()> {
696 Ok(())
697 }
698
699 async fn collection_size(
700 &self,
701 data_type: &str,
702 field_name: &str,
703 ) -> VectorDBResult<usize> {
704 Ok(self
705 .collections
706 .get(&Self::key(data_type, field_name))
707 .map(|items| items.len())
708 .unwrap_or_default())
709 }
710 }
711
    /// In-memory graph store stub backed by fixed node, edge and neighbour
    /// tables. Mutating operations are no-ops and most lookups return empty
    /// results; only the methods documented below serve real data.
    struct TestGraphDb {
        /// All nodes of the fake graph as `(id, properties)` pairs.
        nodes: Vec<GraphNode>,
        /// All edges of the fake graph.
        edges: Vec<EdgeData>,
        /// Pre-computed neighbour lists keyed by node id.
        neighbors: HashMap<String, Vec<NodeData>>,
    }

    #[async_trait]
    impl GraphDBTrait for TestGraphDb {
        async fn initialize(&self) -> GraphDBResult<()> {
            Ok(())
        }

        /// The graph counts as empty when it has no nodes.
        async fn is_empty(&self) -> GraphDBResult<bool> {
            Ok(self.nodes.is_empty())
        }

        async fn query(
            &self,
            _query: &str,
            _params: Option<HashMap<Cow<'static, str>, Value>>,
        ) -> GraphDBResult<Vec<Vec<Value>>> {
            Ok(vec![])
        }

        async fn delete_graph(&self) -> GraphDBResult<()> {
            Ok(())
        }

        async fn has_node(&self, _node_id: &str) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn add_node_raw(&self, _node: serde_json::Value) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_nodes_raw(&self, _nodes: Vec<serde_json::Value>) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_node(&self, _node_id: &str) -> GraphDBResult<()> {
            Ok(())
        }

        async fn delete_nodes(&self, _node_ids: &[String]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_node(&self, _node_id: &str) -> GraphDBResult<Option<NodeData>> {
            Ok(None)
        }

        /// Returns the properties of the requested nodes in request order.
        /// Unknown ids are silently dropped, so the result may be shorter than
        /// `node_ids`.
        async fn get_nodes(&self, node_ids: &[String]) -> GraphDBResult<Vec<NodeData>> {
            let nodes_map: HashMap<&str, &NodeData> = self
                .nodes
                .iter()
                .map(|(id, data)| (id.as_str(), data))
                .collect();
            Ok(node_ids
                .iter()
                .filter_map(|id| nodes_map.get(id.as_str()).map(|d| (*d).clone()))
                .collect())
        }

        async fn has_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
        ) -> GraphDBResult<bool> {
            Ok(false)
        }

        async fn has_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        async fn add_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
            _properties: Option<HashMap<Cow<'static, str>, Value>>,
        ) -> GraphDBResult<()> {
            Ok(())
        }

        async fn add_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<()> {
            Ok(())
        }

        async fn get_edges(&self, _node_id: &str) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        /// Looks the node up in the `neighbors` table; unknown ids have no neighbours.
        async fn get_neighbors(&self, node_id: &str) -> GraphDBResult<Vec<NodeData>> {
            Ok(self.neighbors.get(node_id).cloned().unwrap_or_default())
        }

        async fn get_connections(
            &self,
            _node_id: &str,
        ) -> GraphDBResult<Vec<(NodeData, HashMap<Cow<'static, str>, Value>, NodeData)>> {
            Ok(vec![])
        }

        /// Returns copies of every node and edge.
        async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }

        async fn get_graph_metrics(
            &self,
            _include_optional: bool,
        ) -> GraphDBResult<HashMap<Cow<'static, str>, Value>> {
            Ok(HashMap::new())
        }

        /// Ignores the filters and returns the whole graph.
        async fn get_filtered_graph_data(
            &self,
            _attribute_filters: &HashMap<Cow<'static, str>, Vec<Value>>,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }

        /// Ignores the node-set selection and returns the whole graph.
        async fn get_nodeset_subgraph(
            &self,
            _node_type: &str,
            _node_names: &[String],
            _node_name_filter_operator: &str,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }
    }
846
847 #[derive(Default)]
848 struct TestLlm {
849 completion_response: String,
850 interval_response: Option<QueryInterval>,
851 fail_structured_output: bool,
852 last_messages: Mutex<Vec<Message>>,
853 structured_completion_response: Mutex<Option<Value>>,
857 last_structured_messages: Mutex<Vec<Message>>,
859 }
860
861 #[async_trait]
862 impl Llm for TestLlm {
863 async fn generate(
864 &self,
865 messages: Vec<Message>,
866 _options: Option<GenerationOptions>,
867 ) -> LlmResult<GenerationResponse> {
868 self.last_messages.lock().unwrap().clone_from(&messages);
869
870 Ok(GenerationResponse {
871 content: self.completion_response.clone(),
872 model: "test-model".to_string(),
873 usage: Some(TokenUsage {
874 prompt_tokens: 1,
875 completion_tokens: 1,
876 total_tokens: 2,
877 }),
878 finish_reason: Some("stop".to_string()),
879 })
880 }
881
882 async fn create_structured_output_with_messages_raw(
883 &self,
884 messages: Vec<Message>,
885 _json_schema: &serde_json::Value,
886 _options: Option<GenerationOptions>,
887 ) -> LlmResult<serde_json::Value> {
888 self.last_structured_messages
889 .lock()
890 .unwrap()
891 .clone_from(&messages);
892
893 if self.fail_structured_output {
894 return Err(LlmError::ConfigError("forced failure".to_string()));
895 }
896
897 if let Some(value) = self.structured_completion_response.lock().unwrap().clone() {
899 return Ok(value);
900 }
901
902 let response = self
903 .interval_response
904 .clone()
905 .ok_or_else(|| LlmError::ConfigError("missing interval response".to_string()))?;
906
907 serde_json::to_value(response).map_err(|error| LlmError::ConfigError(error.to_string()))
908 }
909
910 fn model(&self) -> &str {
911 "test-model"
912 }
913 }
914
915 fn event_node_data(id: &str, name: &str) -> NodeData {
916 HashMap::from([
917 (Cow::Borrowed("id"), json!(id)),
918 (Cow::Borrowed("name"), json!(name)),
919 (Cow::Borrowed("type"), json!("Event")),
920 (
921 Cow::Borrowed("description"),
922 json!(format!("Description for {name}")),
923 ),
924 ])
925 }
926
927 fn timestamp_node(id: &str, time_at_ms: i64) -> GraphNode {
928 (
929 id.to_string(),
930 HashMap::from([
931 (Cow::Borrowed("id"), json!(id)),
932 (Cow::Borrowed("type"), json!("Timestamp")),
933 (Cow::Borrowed("time_at"), json!(time_at_ms)),
934 ]),
935 )
936 }
937
938 fn event_graph_node(id: &str, name: &str) -> GraphNode {
939 (id.to_string(), event_node_data(id, name))
940 }
941
942 #[tokio::test]
943 async fn returns_temporal_event_context_when_interval_matches() {
944 let launch_event_ms: i64 = 1710460800000;
946 let old_event_ms: i64 = 1578614400000;
948
949 let ts_in_2024 = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa";
950 let ts_in_2020 = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb";
951 let event_launch = "11111111-1111-1111-1111-111111111111";
952 let event_old = "22222222-2222-2222-2222-222222222222";
953
954 let vector_db = Arc::new(TestVectorDb {
955 collections: HashMap::from([
956 (
957 TestVectorDb::key("Entity", "name"),
958 vec![SearchResult {
959 id: uuid::Uuid::new_v4(),
960 score: 0.8,
961 metadata: HashMap::from([(String::from("type"), json!("entity"))]),
962 }],
963 ),
964 (
965 TestVectorDb::key("Event", "name"),
966 vec![SearchResult {
967 id: uuid::Uuid::parse_str(event_launch).unwrap(),
968 score: 0.95,
969 metadata: HashMap::new(),
970 }],
971 ),
972 ]),
973 });
974
975 let embedding_engine = Arc::new(TestEmbeddingEngine);
976 let graph_db = Arc::new(TestGraphDb {
977 nodes: vec![
978 timestamp_node(ts_in_2024, launch_event_ms),
979 timestamp_node(ts_in_2020, old_event_ms),
980 event_graph_node(event_launch, "Launch event"),
981 event_graph_node(event_old, "Old event"),
982 ],
983 edges: vec![
984 (
985 event_launch.to_string(),
986 ts_in_2024.to_string(),
987 "at".to_string(),
988 HashMap::new(),
989 ),
990 (
991 event_old.to_string(),
992 ts_in_2020.to_string(),
993 "at".to_string(),
994 HashMap::new(),
995 ),
996 ],
997 neighbors: HashMap::from([
998 (
1000 ts_in_2024.to_string(),
1001 vec![event_node_data(event_launch, "Launch event")],
1002 ),
1003 (
1005 ts_in_2020.to_string(),
1006 vec![event_node_data(event_old, "Old event")],
1007 ),
1008 ]),
1009 });
1010 let llm = Arc::new(TestLlm {
1011 completion_response: "temporal answer".to_string(),
1012 interval_response: Some(QueryInterval {
1013 starts_at: Some("2024-01-01".to_string()),
1014 ends_at: Some("2024-12-31".to_string()),
1015 }),
1016 fail_structured_output: false,
1017 last_messages: Mutex::new(vec![]),
1018 structured_completion_response: Mutex::new(None),
1019 last_structured_messages: Mutex::new(vec![]),
1020 });
1021
1022 let retriever = TemporalRetriever::new(
1023 vector_db,
1024 embedding_engine,
1025 graph_db,
1026 llm,
1027 Some(5),
1028 Some(10),
1029 Some(0.0),
1030 None,
1031 None,
1032 None,
1033 None,
1034 None,
1035 );
1036
1037 let context = retriever
1038 .get_context("What happened in 2024?", &SearchParams::default())
1039 .await
1040 .unwrap();
1041
1042 assert_eq!(context.len(), 1);
1043 assert_eq!(
1044 context[0].payload.get("event_name").and_then(Value::as_str),
1045 Some("Launch event")
1046 );
1047 }
1048
1049 #[test]
1052 fn parse_bound_datetime_space_format() {
1053 use chrono::{TimeZone, Utc};
1054 let result = super::parse_bound("2024-01-15 10:30:00", false);
1055 assert_eq!(
1056 result,
1057 Some(Utc.with_ymd_and_hms(2024, 1, 15, 10, 30, 0).unwrap())
1058 );
1059 }
1060
1061 #[test]
1062 fn parse_bound_rfc3339_format() {
1063 use chrono::{TimeZone, Utc};
1064 let result = super::parse_bound("2024-01-15T10:30:00Z", false);
1065 assert_eq!(
1066 result,
1067 Some(Utc.with_ymd_and_hms(2024, 1, 15, 10, 30, 0).unwrap())
1068 );
1069 }
1070
1071 #[test]
1072 fn parse_bound_date_only_start() {
1073 use chrono::{TimeZone, Utc};
1074 let result = super::parse_bound("2024-03-15", false);
1075 assert_eq!(
1076 result,
1077 Some(Utc.with_ymd_and_hms(2024, 3, 15, 0, 0, 0).unwrap())
1078 );
1079 }
1080
1081 #[test]
1082 fn parse_bound_date_only_end() {
1083 use chrono::{TimeZone, Utc};
1084 let result = super::parse_bound("2024-03-15", true);
1085 assert_eq!(
1086 result,
1087 Some(Utc.with_ymd_and_hms(2024, 3, 15, 23, 59, 59).unwrap())
1088 );
1089 }
1090
1091 #[test]
1092 fn parse_bound_month_start() {
1093 use chrono::{TimeZone, Utc};
1094 let result = super::parse_bound("2024-03", false);
1095 assert_eq!(
1096 result,
1097 Some(Utc.with_ymd_and_hms(2024, 3, 1, 0, 0, 0).unwrap())
1098 );
1099 }
1100
1101 #[test]
1102 fn parse_bound_month_end() {
1103 use chrono::{TimeZone, Utc};
1104 let result = super::parse_bound("2024-03", true);
1105 assert_eq!(
1106 result,
1107 Some(Utc.with_ymd_and_hms(2024, 3, 31, 23, 59, 59).unwrap())
1108 );
1109 }
1110
1111 #[test]
1112 fn parse_bound_month_end_leap_year() {
1113 use chrono::{TimeZone, Utc};
1114 let result = super::parse_bound("2024-02", true);
1115 assert_eq!(
1116 result,
1117 Some(Utc.with_ymd_and_hms(2024, 2, 29, 23, 59, 59).unwrap())
1118 );
1119 }
1120
1121 #[test]
1122 fn parse_bound_month_end_non_leap_year() {
1123 use chrono::{TimeZone, Utc};
1124 let result = super::parse_bound("2023-02", true);
1125 assert_eq!(
1126 result,
1127 Some(Utc.with_ymd_and_hms(2023, 2, 28, 23, 59, 59).unwrap())
1128 );
1129 }
1130
1131 #[test]
1132 fn parse_bound_month_end_december_wrap() {
1133 use chrono::{TimeZone, Utc};
1134 let result = super::parse_bound("2024-12", true);
1135 assert_eq!(
1136 result,
1137 Some(Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap())
1138 );
1139 }
1140
1141 #[test]
1142 fn parse_bound_year_start() {
1143 use chrono::{TimeZone, Utc};
1144 let result = super::parse_bound("2024", false);
1145 assert_eq!(
1146 result,
1147 Some(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap())
1148 );
1149 }
1150
1151 #[test]
1152 fn parse_bound_year_end() {
1153 use chrono::{TimeZone, Utc};
1154 let result = super::parse_bound("2024", true);
1155 assert_eq!(
1156 result,
1157 Some(Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap())
1158 );
1159 }
1160
/// Empty and whitespace-only inputs carry no date and must yield `None`.
#[test]
fn parse_bound_empty_and_whitespace_returns_none() {
    for blank in ["", " "] {
        assert_eq!(super::parse_bound(blank, false), None);
    }
}
1166
/// Text that is not a recognisable date must yield `None`.
#[test]
fn parse_bound_invalid_input_returns_none() {
    for garbage in ["not-a-date", "abc"] {
        assert_eq!(super::parse_bound(garbage, false), None);
    }
}
1172
/// With both bounds present the interval is inclusive at each end: values on
/// either edge count as inside, values beyond them do not.
#[test]
fn is_within_interval_ms_basic_cases() {
    use super::is_within_interval_ms;

    let (lower, upper) = (Some(100), Some(1000));

    // Strictly inside, then exactly on each edge.
    for inside in [500, 100, 1000] {
        assert!(is_within_interval_ms(inside, lower, upper));
    }

    // Just below the lower bound, then above the upper bound.
    for outside in [50, 1500] {
        assert!(!is_within_interval_ms(outside, lower, upper));
    }
}
1190
/// A missing bound places no restriction on that side; with both bounds
/// missing every value, including the `i64` extremes, is accepted.
#[test]
fn is_within_interval_ms_open_ended_bounds() {
    use super::is_within_interval_ms;

    // (value, lower bound, upper bound, expected result)
    let cases = [
        // Only an upper bound.
        (50, None, Some(1000), true),
        (1500, None, Some(1000), false),
        // Only a lower bound.
        (1500, Some(100), None, true),
        (50, Some(100), None, false),
        // No bounds at all.
        (0, None, None, true),
        (i64::MAX, None, None, true),
        (i64::MIN, None, None, true),
    ];

    for (value, lower, upper, expected) in cases {
        assert_eq!(is_within_interval_ms(value, lower, upper), expected);
    }
}
1208
/// Both bounds given as full dates: the start maps to the beginning of its
/// day and the end to the last second of its day.
#[test]
fn query_interval_parse_both_bounds() {
    use chrono::{TimeZone, Utc};

    let whole_year = QueryInterval {
        starts_at: Some(String::from("2024-01-01")),
        ends_at: Some(String::from("2024-12-31")),
    };
    let bounds = whole_year.parse();

    let first_second = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
    let last_second = Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap();
    assert_eq!(bounds.start, Some(first_second));
    assert_eq!(bounds.end, Some(last_second));
}
1229
/// An interval with neither bound set parses to an interval with neither
/// bound set.
#[test]
fn query_interval_parse_none_bounds() {
    let unbounded = QueryInterval {
        starts_at: None,
        ends_at: None,
    };

    let bounds = unbounded.parse();

    assert!(bounds.start.is_none());
    assert!(bounds.end.is_none());
}
1240
/// Each bound is parsed independently: a lone start leaves the end open and
/// a lone end leaves the start open.
#[test]
fn query_interval_parse_partial_bounds() {
    use chrono::{TimeZone, Utc};

    // Start only, given at month precision.
    let from_june = QueryInterval {
        starts_at: Some(String::from("2024-06")),
        ends_at: None,
    };
    let bounds = from_june.parse();
    let june_first = Utc.with_ymd_and_hms(2024, 6, 1, 0, 0, 0).unwrap();
    assert_eq!(bounds.start, Some(june_first));
    assert!(bounds.end.is_none());

    // End only, given at year precision.
    let until_2024 = QueryInterval {
        starts_at: None,
        ends_at: Some(String::from("2024")),
    };
    let bounds = until_2024.parse();
    let year_end = Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap();
    assert!(bounds.start.is_none());
    assert_eq!(bounds.end, Some(year_end));
}
1269
1270 #[tokio::test]
1271 async fn falls_back_to_graph_context_when_interval_extraction_fails() {
1272 let vector_db = Arc::new(TestVectorDb {
1273 collections: HashMap::from([(
1274 TestVectorDb::key("Entity", "name"),
1275 vec![SearchResult {
1276 id: uuid::Uuid::parse_str("33333333-3333-3333-3333-333333333333").unwrap(),
1277 score: 0.9,
1278 metadata: HashMap::new(),
1279 }],
1280 )]),
1281 });
1282
1283 let embedding_engine = Arc::new(TestEmbeddingEngine);
1284 let graph_db = Arc::new(TestGraphDb {
1285 nodes: vec![
1286 (
1287 "33333333-3333-3333-3333-333333333333".to_string(),
1288 HashMap::from([
1289 (
1290 Cow::Borrowed("id"),
1291 json!("33333333-3333-3333-3333-333333333333"),
1292 ),
1293 (Cow::Borrowed("name"), json!("Entity A")),
1294 ]),
1295 ),
1296 (
1297 "44444444-4444-4444-4444-444444444444".to_string(),
1298 HashMap::from([
1299 (
1300 Cow::Borrowed("id"),
1301 json!("44444444-4444-4444-4444-444444444444"),
1302 ),
1303 (Cow::Borrowed("name"), json!("Entity B")),
1304 ]),
1305 ),
1306 ],
1307 edges: vec![(
1308 "33333333-3333-3333-3333-333333333333".to_string(),
1309 "44444444-4444-4444-4444-444444444444".to_string(),
1310 "connected_to".to_string(),
1311 HashMap::new(),
1312 )],
1313 neighbors: HashMap::new(),
1314 });
1315 let llm = Arc::new(TestLlm {
1316 completion_response: "fallback answer".to_string(),
1317 interval_response: None,
1318 fail_structured_output: true,
1319 last_messages: Mutex::new(vec![]),
1320 structured_completion_response: Mutex::new(None),
1321 last_structured_messages: Mutex::new(vec![]),
1322 });
1323
1324 let retriever = TemporalRetriever::new(
1325 vector_db,
1326 embedding_engine,
1327 graph_db,
1328 llm,
1329 Some(3),
1330 Some(10),
1331 Some(0.0),
1332 None,
1333 None,
1334 None,
1335 None,
1336 None,
1337 );
1338
1339 let context = retriever
1340 .get_context("What happened?", &SearchParams::default())
1341 .await
1342 .unwrap();
1343 assert_eq!(context.len(), 1);
1344 assert_eq!(
1345 context[0]
1346 .payload
1347 .get("relationship")
1348 .and_then(Value::as_str),
1349 Some("connected_to")
1350 );
1351 }
1352
1353 fn build_retriever_with_llm(llm: TestLlm) -> TemporalRetriever {
1354 TemporalRetriever::new(
1355 Arc::new(TestVectorDb {
1356 collections: HashMap::new(),
1357 }),
1358 Arc::new(TestEmbeddingEngine),
1359 Arc::new(TestGraphDb {
1360 nodes: vec![],
1361 edges: vec![],
1362 neighbors: HashMap::new(),
1363 }),
1364 Arc::new(llm),
1365 Some(5),
1366 Some(10),
1367 Some(0.0),
1368 None,
1369 None,
1370 None,
1371 None,
1372 None,
1373 )
1374 }
1375
1376 #[tokio::test]
1377 async fn extract_interval_returns_parsed_interval_from_llm() {
1378 let llm = TestLlm {
1379 completion_response: String::new(),
1380 interval_response: Some(QueryInterval {
1381 starts_at: Some("2024-01-01".into()),
1382 ends_at: Some("2024-12-31".into()),
1383 }),
1384 fail_structured_output: false,
1385 last_messages: Mutex::new(vec![]),
1386 ..TestLlm::default()
1387 };
1388 let retriever = build_retriever_with_llm(llm);
1389
1390 let result = retriever
1391 .extract_interval("What happened in 2024?")
1392 .await
1393 .unwrap();
1394
1395 let parsed = result.expect("should return Some(ParsedInterval)");
1396 assert_eq!(
1397 parsed.start,
1398 Some(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap())
1399 );
1400 assert_eq!(
1401 parsed.end,
1402 Some(Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap())
1403 );
1404 }
1405
1406 #[tokio::test]
1407 async fn extract_interval_returns_none_when_llm_returns_none_none() {
1408 let llm = TestLlm {
1409 completion_response: String::new(),
1410 interval_response: Some(QueryInterval {
1411 starts_at: None,
1412 ends_at: None,
1413 }),
1414 fail_structured_output: false,
1415 last_messages: Mutex::new(vec![]),
1416 ..TestLlm::default()
1417 };
1418 let retriever = build_retriever_with_llm(llm);
1419
1420 let result = retriever
1421 .extract_interval("Who is Einstein?")
1422 .await
1423 .unwrap();
1424
1425 assert!(
1426 result.is_none(),
1427 "both fields None means no interval detected"
1428 );
1429 }
1430
1431 #[tokio::test]
1432 async fn extract_interval_returns_none_when_llm_fails() {
1433 let llm = TestLlm {
1434 completion_response: String::new(),
1435 interval_response: None,
1436 fail_structured_output: true,
1437 last_messages: Mutex::new(vec![]),
1438 ..TestLlm::default()
1439 };
1440 let retriever = build_retriever_with_llm(llm);
1441
1442 let result = retriever.extract_interval("What happened?").await.unwrap();
1443
1444 assert!(result.is_none(), "error should be swallowed gracefully");
1445 }
1446
1447 #[tokio::test]
1448 async fn extract_interval_with_only_starts_at() {
1449 let llm = TestLlm {
1450 completion_response: String::new(),
1451 interval_response: Some(QueryInterval {
1452 starts_at: Some("2024-01-01".into()),
1453 ends_at: None,
1454 }),
1455 fail_structured_output: false,
1456 last_messages: Mutex::new(vec![]),
1457 ..TestLlm::default()
1458 };
1459 let retriever = build_retriever_with_llm(llm);
1460
1461 let result = retriever
1462 .extract_interval("What happened after 2024?")
1463 .await
1464 .unwrap();
1465
1466 let parsed = result.expect("should return Some(ParsedInterval)");
1467 assert_eq!(
1468 parsed.start,
1469 Some(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap())
1470 );
1471 assert_eq!(parsed.end, None);
1472 }
1473
1474 #[tokio::test]
1475 async fn extract_interval_with_only_ends_at() {
1476 let llm = TestLlm {
1477 completion_response: String::new(),
1478 interval_response: Some(QueryInterval {
1479 starts_at: None,
1480 ends_at: Some("2024-12-31".into()),
1481 }),
1482 fail_structured_output: false,
1483 last_messages: Mutex::new(vec![]),
1484 ..TestLlm::default()
1485 };
1486 let retriever = build_retriever_with_llm(llm);
1487
1488 let result = retriever
1489 .extract_interval("What happened before 2025?")
1490 .await
1491 .unwrap();
1492
1493 let parsed = result.expect("should return Some(ParsedInterval)");
1494 assert_eq!(parsed.start, None);
1495 assert_eq!(
1496 parsed.end,
1497 Some(Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap())
1498 );
1499 }
1500
1501 fn build_retriever(
1504 vector_db: Arc<dyn VectorDB>,
1505 graph_db: Arc<dyn GraphDBTrait>,
1506 llm: Arc<dyn Llm>,
1507 ) -> TemporalRetriever {
1508 TemporalRetriever::new(
1509 vector_db,
1510 Arc::new(TestEmbeddingEngine),
1511 graph_db,
1512 llm,
1513 Some(10),
1514 Some(100),
1515 Some(0.0),
1516 None,
1517 None,
1518 None,
1519 None,
1520 None,
1521 )
1522 }
1523
1524 #[tokio::test]
1525 async fn get_context_with_time_from_only() {
1526 let ts_2020 = "aa000000-0000-0000-0000-000000000001";
1528 let ts_2024_jan = "aa000000-0000-0000-0000-000000000002";
1529 let ts_2024_jul = "aa000000-0000-0000-0000-000000000003";
1530 let ev_old = "bb000000-0000-0000-0000-000000000001";
1531 let ev_jan = "bb000000-0000-0000-0000-000000000002";
1532 let ev_jul = "bb000000-0000-0000-0000-000000000003";
1533
1534 let graph_db = Arc::new(TestGraphDb {
1535 nodes: vec![
1536 timestamp_node(ts_2020, 1592179200000),
1537 timestamp_node(ts_2024_jan, 1705276800000),
1538 timestamp_node(ts_2024_jul, 1721001600000),
1539 event_graph_node(ev_old, "Old event"),
1540 event_graph_node(ev_jan, "Jan event"),
1541 event_graph_node(ev_jul, "Jul event"),
1542 ],
1543 edges: vec![
1544 (
1545 ev_old.to_string(),
1546 ts_2020.to_string(),
1547 "at".to_string(),
1548 HashMap::new(),
1549 ),
1550 (
1551 ev_jan.to_string(),
1552 ts_2024_jan.to_string(),
1553 "at".to_string(),
1554 HashMap::new(),
1555 ),
1556 (
1557 ev_jul.to_string(),
1558 ts_2024_jul.to_string(),
1559 "at".to_string(),
1560 HashMap::new(),
1561 ),
1562 ],
1563 neighbors: HashMap::from([
1564 (
1565 ts_2020.to_string(),
1566 vec![event_node_data(ev_old, "Old event")],
1567 ),
1568 (
1569 ts_2024_jan.to_string(),
1570 vec![event_node_data(ev_jan, "Jan event")],
1571 ),
1572 (
1573 ts_2024_jul.to_string(),
1574 vec![event_node_data(ev_jul, "Jul event")],
1575 ),
1576 ]),
1577 });
1578
1579 let vector_db = Arc::new(TestVectorDb {
1580 collections: HashMap::from([(
1581 TestVectorDb::key("Event", "name"),
1582 vec![
1583 SearchResult {
1584 id: uuid::Uuid::parse_str(ev_jan).unwrap(),
1585 score: 0.9,
1586 metadata: HashMap::new(),
1587 },
1588 SearchResult {
1589 id: uuid::Uuid::parse_str(ev_jul).unwrap(),
1590 score: 0.85,
1591 metadata: HashMap::new(),
1592 },
1593 ],
1594 )]),
1595 });
1596
1597 let llm = Arc::new(TestLlm {
1598 completion_response: String::new(),
1599 interval_response: Some(QueryInterval {
1600 starts_at: Some("2024-01-01".to_string()),
1601 ends_at: None,
1602 }),
1603 fail_structured_output: false,
1604 last_messages: Mutex::new(vec![]),
1605 ..TestLlm::default()
1606 });
1607
1608 let retriever = build_retriever(vector_db, graph_db, llm);
1609
1610 let context = retriever
1611 .get_context("What happened since 2024?", &SearchParams::default())
1612 .await
1613 .unwrap();
1614
1615 assert_eq!(context.len(), 2, "should have 2 items (both 2024 events)");
1616
1617 let event_names: Vec<&str> = context
1618 .iter()
1619 .filter_map(|item| item.payload.get("event_name").and_then(Value::as_str))
1620 .collect();
1621 assert!(
1622 event_names.contains(&"Jan event"),
1623 "Jan event should be present"
1624 );
1625 assert!(
1626 event_names.contains(&"Jul event"),
1627 "Jul event should be present"
1628 );
1629 assert!(
1630 !event_names.contains(&"Old event"),
1631 "Old event should NOT be present"
1632 );
1633 }
1634
1635 #[tokio::test]
1636 async fn get_context_with_time_to_only() {
1637 let ts_2020 = "aa000000-0000-0000-0000-000000000001";
1638 let ts_2024_jan = "aa000000-0000-0000-0000-000000000002";
1639 let ts_2024_jul = "aa000000-0000-0000-0000-000000000003";
1640 let ev_old = "bb000000-0000-0000-0000-000000000001";
1641 let ev_jan = "bb000000-0000-0000-0000-000000000002";
1642 let ev_jul = "bb000000-0000-0000-0000-000000000003";
1643
1644 let graph_db = Arc::new(TestGraphDb {
1645 nodes: vec![
1646 timestamp_node(ts_2020, 1592179200000),
1647 timestamp_node(ts_2024_jan, 1705276800000),
1648 timestamp_node(ts_2024_jul, 1721001600000),
1649 event_graph_node(ev_old, "Old event"),
1650 event_graph_node(ev_jan, "Jan event"),
1651 event_graph_node(ev_jul, "Jul event"),
1652 ],
1653 edges: vec![
1654 (
1655 ev_old.to_string(),
1656 ts_2020.to_string(),
1657 "at".to_string(),
1658 HashMap::new(),
1659 ),
1660 (
1661 ev_jan.to_string(),
1662 ts_2024_jan.to_string(),
1663 "at".to_string(),
1664 HashMap::new(),
1665 ),
1666 (
1667 ev_jul.to_string(),
1668 ts_2024_jul.to_string(),
1669 "at".to_string(),
1670 HashMap::new(),
1671 ),
1672 ],
1673 neighbors: HashMap::from([
1674 (
1675 ts_2020.to_string(),
1676 vec![event_node_data(ev_old, "Old event")],
1677 ),
1678 (
1679 ts_2024_jan.to_string(),
1680 vec![event_node_data(ev_jan, "Jan event")],
1681 ),
1682 (
1683 ts_2024_jul.to_string(),
1684 vec![event_node_data(ev_jul, "Jul event")],
1685 ),
1686 ]),
1687 });
1688
1689 let vector_db = Arc::new(TestVectorDb {
1690 collections: HashMap::from([(
1691 TestVectorDb::key("Event", "name"),
1692 vec![SearchResult {
1693 id: uuid::Uuid::parse_str(ev_old).unwrap(),
1694 score: 0.88,
1695 metadata: HashMap::new(),
1696 }],
1697 )]),
1698 });
1699
1700 let llm = Arc::new(TestLlm {
1701 completion_response: String::new(),
1702 interval_response: Some(QueryInterval {
1703 starts_at: None,
1704 ends_at: Some("2021-12-31".to_string()),
1705 }),
1706 fail_structured_output: false,
1707 last_messages: Mutex::new(vec![]),
1708 ..TestLlm::default()
1709 });
1710
1711 let retriever = build_retriever(vector_db, graph_db, llm);
1712
1713 let context = retriever
1714 .get_context("What happened before 2022?", &SearchParams::default())
1715 .await
1716 .unwrap();
1717
1718 assert_eq!(context.len(), 1, "should have 1 item (only the 2020 event)");
1719 assert_eq!(
1720 context[0].payload.get("event_name").and_then(Value::as_str),
1721 Some("Old event")
1722 );
1723 }
1724
1725 #[tokio::test]
1726 async fn get_context_falls_back_when_no_events_in_range() {
1727 let ts_2020 = "aa000000-0000-0000-0000-000000000010";
1728 let ts_2021 = "aa000000-0000-0000-0000-000000000011";
1729 let ev_2020 = "bb000000-0000-0000-0000-000000000010";
1730 let ev_2021 = "bb000000-0000-0000-0000-000000000011";
1731 let entity_a = "cc000000-0000-0000-0000-000000000001";
1732 let entity_b = "cc000000-0000-0000-0000-000000000002";
1733
1734 let graph_db = Arc::new(TestGraphDb {
1735 nodes: vec![
1736 timestamp_node(ts_2020, 1577836800000), timestamp_node(ts_2021, 1609459200000), event_graph_node(ev_2020, "Event 2020"),
1739 event_graph_node(ev_2021, "Event 2021"),
1740 (
1741 entity_a.to_string(),
1742 HashMap::from([
1743 (Cow::Borrowed("id"), json!(entity_a)),
1744 (Cow::Borrowed("name"), json!("Entity A")),
1745 (Cow::Borrowed("type"), json!("Entity")),
1746 ]),
1747 ),
1748 (
1749 entity_b.to_string(),
1750 HashMap::from([
1751 (Cow::Borrowed("id"), json!(entity_b)),
1752 (Cow::Borrowed("name"), json!("Entity B")),
1753 (Cow::Borrowed("type"), json!("Entity")),
1754 ]),
1755 ),
1756 ],
1757 edges: vec![
1758 (
1759 ev_2020.to_string(),
1760 ts_2020.to_string(),
1761 "at".to_string(),
1762 HashMap::new(),
1763 ),
1764 (
1765 ev_2021.to_string(),
1766 ts_2021.to_string(),
1767 "at".to_string(),
1768 HashMap::new(),
1769 ),
1770 (
1771 entity_a.to_string(),
1772 entity_b.to_string(),
1773 "connected_to".to_string(),
1774 HashMap::new(),
1775 ),
1776 ],
1777 neighbors: HashMap::from([
1778 (
1779 ts_2020.to_string(),
1780 vec![event_node_data(ev_2020, "Event 2020")],
1781 ),
1782 (
1783 ts_2021.to_string(),
1784 vec![event_node_data(ev_2021, "Event 2021")],
1785 ),
1786 ]),
1787 });
1788
1789 let vector_db = Arc::new(TestVectorDb {
1790 collections: HashMap::from([(
1791 TestVectorDb::key("Entity", "name"),
1792 vec![SearchResult {
1793 id: uuid::Uuid::parse_str(entity_a).unwrap(),
1794 score: 0.9,
1795 metadata: HashMap::new(),
1796 }],
1797 )]),
1798 });
1799
1800 let llm = Arc::new(TestLlm {
1801 completion_response: String::new(),
1802 interval_response: Some(QueryInterval {
1803 starts_at: Some("2030".to_string()),
1804 ends_at: Some("2031".to_string()),
1805 }),
1806 fail_structured_output: false,
1807 last_messages: Mutex::new(vec![]),
1808 ..TestLlm::default()
1809 });
1810
1811 let retriever = build_retriever(vector_db, graph_db, llm);
1812
1813 let context = retriever
1814 .get_context("What happened in 2030?", &SearchParams::default())
1815 .await
1816 .unwrap();
1817
1818 assert!(
1820 !context.is_empty(),
1821 "fallback should produce at least one result"
1822 );
1823 assert!(
1824 context
1825 .iter()
1826 .any(|item| item.payload.get("relationship").is_some()),
1827 "fallback context items should have 'relationship' in payload"
1828 );
1829 assert!(
1830 context
1831 .iter()
1832 .all(|item| item.payload.get("event_id").is_none()),
1833 "fallback context should NOT have 'event_id' (those are temporal items)"
1834 );
1835 }
1836
1837 #[tokio::test]
1838 async fn get_context_on_empty_graph() {
1839 let graph_db = Arc::new(TestGraphDb {
1840 nodes: vec![],
1841 edges: vec![],
1842 neighbors: HashMap::new(),
1843 });
1844
1845 let vector_db = Arc::new(TestVectorDb {
1846 collections: HashMap::new(),
1847 });
1848
1849 let llm = Arc::new(TestLlm {
1851 completion_response: String::new(),
1852 interval_response: None,
1853 fail_structured_output: false,
1854 last_messages: Mutex::new(vec![]),
1855 ..TestLlm::default()
1856 });
1857
1858 let retriever = build_retriever(vector_db, graph_db, llm);
1859
1860 let context = retriever
1861 .get_context("Anything?", &SearchParams::default())
1862 .await
1863 .unwrap();
1864
1865 assert!(
1866 context.is_empty(),
1867 "empty graph should return empty context"
1868 );
1869 }
1870
1871 #[tokio::test]
1872 async fn get_context_respects_top_k() {
1873 let mut nodes = Vec::new();
1874 let mut edges = Vec::new();
1875 let mut neighbors = HashMap::new();
1876 let mut vector_results = Vec::new();
1877
1878 for i in 1..=5 {
1879 let ts_id = format!("aa000000-0000-0000-0000-0000000000{i:02}");
1880 let ev_id = format!("bb000000-0000-0000-0000-0000000000{i:02}");
1881 let ev_name = format!("Event {i}");
1882 let time_ms = 1704067200000_i64 + (i as i64 - 1) * 30 * 86400 * 1000;
1884
1885 nodes.push(timestamp_node(&ts_id, time_ms));
1886 nodes.push(event_graph_node(&ev_id, &ev_name));
1887 edges.push((
1888 ev_id.clone(),
1889 ts_id.clone(),
1890 "at".to_string(),
1891 HashMap::new(),
1892 ));
1893 neighbors.insert(ts_id, vec![event_node_data(&ev_id, &ev_name)]);
1894 vector_results.push(SearchResult {
1895 id: uuid::Uuid::parse_str(&ev_id).unwrap(),
1896 score: 0.9 - (i as f32 * 0.01),
1897 metadata: HashMap::new(),
1898 });
1899 }
1900
1901 let graph_db = Arc::new(TestGraphDb {
1902 nodes,
1903 edges,
1904 neighbors,
1905 });
1906
1907 let vector_db = Arc::new(TestVectorDb {
1908 collections: HashMap::from([(TestVectorDb::key("Event", "name"), vector_results)]),
1909 });
1910
1911 let llm = Arc::new(TestLlm {
1912 completion_response: String::new(),
1913 interval_response: Some(QueryInterval {
1914 starts_at: Some("2024-01-01".to_string()),
1915 ends_at: Some("2024-12-31".to_string()),
1916 }),
1917 fail_structured_output: false,
1918 last_messages: Mutex::new(vec![]),
1919 ..TestLlm::default()
1920 });
1921
1922 let retriever = build_retriever(vector_db, graph_db, llm);
1923
1924 let params = SearchParams {
1925 top_k: Some(2),
1926 ..Default::default()
1927 };
1928
1929 let context = retriever
1930 .get_context("What happened in 2024?", ¶ms)
1931 .await
1932 .unwrap();
1933
1934 assert_eq!(context.len(), 2, "top_k=2 should limit results to 2 items");
1935 }
1936
1937 #[tokio::test]
1938 async fn get_context_2hop_interval_traversal() {
1939 let ts1 = "aa000000-0000-0000-0000-000000000020";
1940 let ts2 = "aa000000-0000-0000-0000-000000000021";
1941 let interval_id = "ii000000-0000-0000-0000-000000000001";
1942 let event_id = "bb000000-0000-0000-0000-000000000020";
1943
1944 let interval_node_data: NodeData = HashMap::from([
1945 (Cow::Borrowed("id"), json!(interval_id)),
1946 (Cow::Borrowed("type"), json!("Interval")),
1947 (Cow::Borrowed("name"), json!("Feb-Mar 2024")),
1948 ]);
1949
1950 let graph_db = Arc::new(TestGraphDb {
1951 nodes: vec![
1952 timestamp_node(ts1, 1706745600000), timestamp_node(ts2, 1709251200000), (interval_id.to_string(), interval_node_data.clone()),
1955 event_graph_node(event_id, "Team Meeting"),
1956 ],
1957 edges: vec![(
1958 event_id.to_string(),
1959 interval_id.to_string(),
1960 "during".to_string(),
1961 HashMap::new(),
1962 )],
1963 neighbors: HashMap::from([
1964 (ts1.to_string(), vec![interval_node_data.clone()]),
1966 (ts2.to_string(), vec![interval_node_data]),
1968 (
1970 interval_id.to_string(),
1971 vec![event_node_data(event_id, "Team Meeting")],
1972 ),
1973 ]),
1974 });
1975
1976 let vector_db = Arc::new(TestVectorDb {
1977 collections: HashMap::from([(
1978 TestVectorDb::key("Event", "name"),
1979 vec![SearchResult {
1980 id: uuid::Uuid::parse_str(event_id).unwrap(),
1981 score: 0.92,
1982 metadata: HashMap::new(),
1983 }],
1984 )]),
1985 });
1986
1987 let llm = Arc::new(TestLlm {
1988 completion_response: String::new(),
1989 interval_response: Some(QueryInterval {
1990 starts_at: Some("2024-02".to_string()),
1991 ends_at: Some("2024-03".to_string()),
1992 }),
1993 fail_structured_output: false,
1994 last_messages: Mutex::new(vec![]),
1995 ..TestLlm::default()
1996 });
1997
1998 let retriever = build_retriever(vector_db, graph_db, llm);
1999
2000 let context = retriever
2001 .get_context(
2002 "What meetings happened in Feb-Mar 2024?",
2003 &SearchParams::default(),
2004 )
2005 .await
2006 .unwrap();
2007
2008 assert_eq!(
2009 context.len(),
2010 1,
2011 "should find 1 event via 2-hop traversal (Timestamp -> Interval -> Event)"
2012 );
2013 assert_eq!(
2014 context[0].payload.get("event_name").and_then(Value::as_str),
2015 Some("Team Meeting")
2016 );
2017 }
2018
2019 fn default_session() -> SessionContext {
2024 SessionContext {
2025 session_id: None,
2026 history: vec![],
2027 formatted_history: String::new(),
2028 graph_context: None,
2029 }
2030 }
2031
2032 fn make_event_context() -> Vec<SearchItem> {
2033 vec![
2034 SearchItem {
2035 id: None,
2036 score: Some(0.9),
2037 payload: json!({
2038 "event_id": "evt-1",
2039 "event_name": "Product Launch",
2040 "event_description": "Launched the new product",
2041 "event_time": "2024-03-15",
2042 }),
2043 },
2044 SearchItem {
2045 id: None,
2046 score: Some(0.7),
2047 payload: json!({
2048 "event_id": "evt-2",
2049 "event_name": "Quarterly Review",
2050 "event_description": "Reviewed Q1 results",
2051 "event_time": "2024-04-01",
2052 }),
2053 },
2054 ]
2055 }
2056
2057 fn simple_retriever(llm: Arc<TestLlm>) -> TemporalRetriever {
2058 let vector_db = Arc::new(TestVectorDb {
2059 collections: HashMap::new(),
2060 });
2061 let embedding_engine = Arc::new(TestEmbeddingEngine);
2062 let graph_db = Arc::new(TestGraphDb {
2063 nodes: vec![],
2064 edges: vec![],
2065 neighbors: HashMap::new(),
2066 });
2067
2068 TemporalRetriever::new(
2069 vector_db,
2070 embedding_engine,
2071 graph_db,
2072 llm,
2073 Some(5),
2074 Some(10),
2075 Some(0.0),
2076 None,
2077 None,
2078 None,
2079 None,
2080 None,
2081 )
2082 }
2083
2084 #[tokio::test]
2085 async fn get_completion_generates_text_from_context() {
2086 let llm = Arc::new(TestLlm {
2087 completion_response: "The product was launched in March 2024.".to_string(),
2088 last_messages: Mutex::new(vec![]),
2089 last_structured_messages: Mutex::new(vec![]),
2090 ..TestLlm::default()
2091 });
2092
2093 let retriever = simple_retriever(llm);
2094 let context = make_event_context();
2095 let session = default_session();
2096
2097 let output = retriever
2098 .get_completion(
2099 "What happened in 2024?",
2100 Some(context),
2101 &session,
2102 &SearchParams::default(),
2103 )
2104 .await
2105 .unwrap();
2106
2107 match output {
2108 SearchOutput::Text(text) => {
2109 assert_eq!(text, "The product was launched in March 2024.");
2110 }
2111 other => panic!("Expected SearchOutput::Text, got {other:?}"),
2112 }
2113 }
2114
2115 #[tokio::test]
2116 async fn get_completion_with_provided_context_passes_to_llm() {
2117 let llm = Arc::new(TestLlm {
2118 completion_response: "completion result".to_string(),
2119 last_messages: Mutex::new(vec![]),
2120 last_structured_messages: Mutex::new(vec![]),
2121 ..TestLlm::default()
2122 });
2123
2124 let retriever = simple_retriever(Arc::clone(&llm));
2125 let context = make_event_context();
2126 let session = default_session();
2127
2128 retriever
2129 .get_completion(
2130 "What happened in 2024?",
2131 Some(context),
2132 &session,
2133 &SearchParams::default(),
2134 )
2135 .await
2136 .unwrap();
2137
2138 let messages = llm.last_messages.lock().unwrap();
2139 assert_eq!(messages.len(), 2, "Expected system + user messages");
2140
2141 let user_msg = &messages[1].content;
2143 assert!(
2144 user_msg.contains("Product Launch"),
2145 "User prompt should contain event name from context"
2146 );
2147 assert!(
2148 user_msg.contains("Quarterly Review"),
2149 "User prompt should contain second event name from context"
2150 );
2151 }
2152
2153 #[tokio::test]
2154 async fn get_completion_without_context_calls_get_context() {
2155 let launch_event_ms: i64 = 1710460800000; let ts_id = "ts-aaa";
2158 let event_id = "ev-111";
2159
2160 let vector_db = Arc::new(TestVectorDb {
2161 collections: HashMap::from([(
2162 TestVectorDb::key("Entity", "name"),
2163 vec![SearchResult {
2164 id: Uuid::new_v4(),
2165 score: 0.8,
2166 metadata: HashMap::new(),
2167 }],
2168 )]),
2169 });
2170
2171 let embedding_engine = Arc::new(TestEmbeddingEngine);
2172 let graph_db = Arc::new(TestGraphDb {
2173 nodes: vec![
2174 timestamp_node(ts_id, launch_event_ms),
2175 event_graph_node(event_id, "Launch"),
2176 ],
2177 edges: vec![],
2178 neighbors: HashMap::from([(
2179 ts_id.to_string(),
2180 vec![event_node_data(event_id, "Launch")],
2181 )]),
2182 });
2183
2184 let llm = Arc::new(TestLlm {
2185 completion_response: "answer from internal context".to_string(),
2186 interval_response: Some(QueryInterval {
2187 starts_at: Some("2024-01-01".to_string()),
2188 ends_at: Some("2024-12-31".to_string()),
2189 }),
2190 last_messages: Mutex::new(vec![]),
2191 last_structured_messages: Mutex::new(vec![]),
2192 ..TestLlm::default()
2193 });
2194
2195 let retriever = TemporalRetriever::new(
2196 vector_db,
2197 embedding_engine,
2198 graph_db,
2199 llm.clone(),
2200 Some(5),
2201 Some(10),
2202 Some(0.0),
2203 None,
2204 None,
2205 None,
2206 None,
2207 None,
2208 );
2209
2210 let session = default_session();
2211
2212 let output = retriever
2213 .get_completion(
2214 "What happened in 2024?",
2215 None,
2216 &session,
2217 &SearchParams::default(),
2218 )
2219 .await
2220 .unwrap();
2221
2222 match output {
2223 SearchOutput::Text(text) => {
2224 assert_eq!(text, "answer from internal context");
2225 }
2226 other => panic!("Expected SearchOutput::Text, got {other:?}"),
2227 }
2228
2229 let messages = llm.last_messages.lock().unwrap();
2231 assert!(!messages.is_empty(), "LLM generate should have been called");
2232 let user_msg = &messages[1].content;
2233 assert!(
2234 user_msg.contains("Launch"),
2235 "User prompt should reference the event from internal context"
2236 );
2237 }
2238
2239 #[tokio::test]
2240 async fn get_completion_with_response_schema() {
2241 let structured_value = json!({
2242 "answer": "The product launched in 2024",
2243 "confidence": 0.95
2244 });
2245
2246 let llm = Arc::new(TestLlm {
2247 completion_response: "should not be used".to_string(),
2248 structured_completion_response: Mutex::new(Some(structured_value.clone())),
2249 last_messages: Mutex::new(vec![]),
2250 last_structured_messages: Mutex::new(vec![]),
2251 ..TestLlm::default()
2252 });
2253
2254 let retriever = simple_retriever(llm);
2255 let context = make_event_context();
2256 let session = default_session();
2257
2258 let params = SearchParams {
2259 response_schema: Some(json!({
2260 "type": "object",
2261 "properties": {
2262 "answer": { "type": "string" },
2263 "confidence": { "type": "number" }
2264 }
2265 })),
2266 ..SearchParams::default()
2267 };
2268
2269 let output = retriever
2270 .get_completion("What happened in 2024?", Some(context), &session, ¶ms)
2271 .await
2272 .unwrap();
2273
2274 match output {
2275 SearchOutput::Structured(value) => {
2276 assert_eq!(value, structured_value);
2277 }
2278 other => panic!("Expected SearchOutput::Structured, got {other:?}"),
2279 }
2280 }
2281
2282 #[tokio::test]
2283 async fn get_completion_includes_session_history() {
2284 let llm = Arc::new(TestLlm {
2285 completion_response: "history-aware answer".to_string(),
2286 last_messages: Mutex::new(vec![]),
2287 last_structured_messages: Mutex::new(vec![]),
2288 ..TestLlm::default()
2289 });
2290
2291 let retriever = simple_retriever(Arc::clone(&llm));
2292 let context = make_event_context();
2293
2294 let session = SessionContext {
2295 session_id: Some("sess-1".to_string()),
2296 history: vec![
2297 Message::user("Previous question?".to_string()),
2298 Message::assistant("Previous answer.".to_string()),
2299 ],
2300 formatted_history: "Q: Previous question?\nA: Previous answer.".to_string(),
2301 graph_context: None,
2302 };
2303
2304 retriever
2305 .get_completion(
2306 "Follow-up question?",
2307 Some(context),
2308 &session,
2309 &SearchParams::default(),
2310 )
2311 .await
2312 .unwrap();
2313
2314 let messages = llm.last_messages.lock().unwrap();
2315 assert_eq!(messages.len(), 2, "Expected system + user messages");
2316
2317 let system_msg = &messages[0].content;
2319 assert!(
2320 system_msg.contains("Previous question?"),
2321 "System prompt should include session history"
2322 );
2323 assert!(
2324 system_msg.contains("Previous answer."),
2325 "System prompt should include session history answer"
2326 );
2327 }
2328
2329 fn make_ranked_edge(source_id: &str, target_id: &str, score: f32) -> RankedGraphEdge {
2334 RankedGraphEdge {
2335 source_id: source_id.to_string(),
2336 target_id: target_id.to_string(),
2337 relationship_name: "related_to".to_string(),
2338 score,
2339 source_name: format!("Source-{source_id}"),
2340 target_name: format!("Target-{target_id}"),
2341 dataset_id: None,
2342 source_text: None,
2343 target_text: None,
2344 source_description: None,
2345 target_description: None,
2346 }
2347 }
2348
2349 fn ranking_retriever() -> TemporalRetriever {
2350 let vector_db = Arc::new(TestVectorDb {
2351 collections: HashMap::from([(
2352 TestVectorDb::key("Event", "name"),
2353 vec![
2354 SearchResult {
2355 id: Uuid::parse_str("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa").unwrap(),
2356 score: 0.9,
2357 metadata: HashMap::new(),
2358 },
2359 SearchResult {
2360 id: Uuid::parse_str("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb").unwrap(),
2361 score: 0.5,
2362 metadata: HashMap::new(),
2363 },
2364 SearchResult {
2365 id: Uuid::parse_str("cccccccc-cccc-cccc-cccc-cccccccccccc").unwrap(),
2366 score: 0.3,
2367 metadata: HashMap::new(),
2368 },
2369 ],
2370 )]),
2371 });
2372
2373 let embedding_engine = Arc::new(TestEmbeddingEngine);
2374 let graph_db = Arc::new(TestGraphDb {
2375 nodes: vec![],
2376 edges: vec![],
2377 neighbors: HashMap::new(),
2378 });
2379
2380 let llm = Arc::new(TestLlm::default());
2381
2382 TemporalRetriever::new(
2383 vector_db,
2384 embedding_engine,
2385 graph_db,
2386 llm,
2387 Some(5),
2388 Some(10),
2389 Some(0.0),
2390 None,
2391 None,
2392 None,
2393 None,
2394 None,
2395 )
2396 }
2397
2398 #[tokio::test]
2399 async fn rank_sorts_by_combined_score() {
2400 let retriever = ranking_retriever();
2401
2402 let event_ids: HashSet<String> = [
2403 "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
2404 "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
2405 "cccccccc-cccc-cccc-cccc-cccccccccccc",
2406 ]
2407 .iter()
2408 .map(|s| s.to_string())
2409 .collect();
2410
2411 let ranked_edges = vec![
2412 make_ranked_edge("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "other-node", 0.8),
2413 make_ranked_edge("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "other-node", 0.4),
2414 make_ranked_edge("cccccccc-cccc-cccc-cccc-cccccccccccc", "other-node", 0.2),
2415 ];
2416
2417 let ranked = retriever
2418 .rank_temporal_events("test query", &event_ids, &ranked_edges)
2419 .await
2420 .unwrap();
2421
2422 assert_eq!(ranked.len(), 3);
2423 assert_eq!(ranked[0].0, "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa");
2424 assert_eq!(ranked[1].0, "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb");
2425 assert_eq!(ranked[2].0, "cccccccc-cccc-cccc-cccc-cccccccccccc");
2426
2427 assert!(ranked[0].1 >= ranked[1].1);
2429 assert!(ranked[1].1 >= ranked[2].1);
2430 }
2431
2432 #[tokio::test]
2433 async fn rank_events_not_in_vector_get_default_score() {
2434 let retriever = ranking_retriever();
2435
2436 let event_ids: HashSet<String> = [
2437 "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
2438 "dddddddd-dddd-dddd-dddd-dddddddddddd", ]
2440 .iter()
2441 .map(|s| s.to_string())
2442 .collect();
2443
2444 let ranked_edges = vec![make_ranked_edge(
2445 "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
2446 "other-node",
2447 0.8,
2448 )];
2449
2450 let ranked = retriever
2451 .rank_temporal_events("test query", &event_ids, &ranked_edges)
2452 .await
2453 .unwrap();
2454
2455 assert_eq!(ranked.len(), 2);
2456
2457 let unknown_event = ranked
2458 .iter()
2459 .find(|(id, _)| id == "dddddddd-dddd-dddd-dddd-dddddddddddd")
2460 .unwrap();
2461 assert!(
2462 unknown_event.1.abs() < f32::EPSILON,
2463 "Unknown event should have score 0.0, got {}",
2464 unknown_event.1
2465 );
2466
2467 let known_event = ranked
2468 .iter()
2469 .find(|(id, _)| id == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
2470 .unwrap();
2471 assert!(
2472 known_event.1 > unknown_event.1,
2473 "Known event should have higher score"
2474 );
2475 }
2476
2477 #[tokio::test]
2478 async fn rank_empty_vector_results() {
2479 let vector_db = Arc::new(TestVectorDb {
2480 collections: HashMap::new(),
2481 });
2482 let embedding_engine = Arc::new(TestEmbeddingEngine);
2483 let graph_db = Arc::new(TestGraphDb {
2484 nodes: vec![],
2485 edges: vec![],
2486 neighbors: HashMap::new(),
2487 });
2488 let llm = Arc::new(TestLlm::default());
2489
2490 let retriever = TemporalRetriever::new(
2491 vector_db,
2492 embedding_engine,
2493 graph_db,
2494 llm,
2495 Some(5),
2496 Some(10),
2497 Some(0.0),
2498 None,
2499 None,
2500 None,
2501 None,
2502 None,
2503 );
2504
2505 let event_ids: HashSet<String> = ["ev-1", "ev-2"].iter().map(|s| s.to_string()).collect();
2506
2507 let ranked_edges = vec![
2508 make_ranked_edge("ev-1", "other", 0.6),
2509 make_ranked_edge("other", "ev-2", 0.3),
2510 ];
2511
2512 let ranked = retriever
2513 .rank_temporal_events("query", &event_ids, &ranked_edges)
2514 .await
2515 .unwrap();
2516
2517 assert_eq!(ranked.len(), 2);
2518 assert_eq!(ranked[0].0, "ev-1");
2519 assert!((ranked[0].1 - 0.6).abs() < f32::EPSILON);
2520 assert_eq!(ranked[1].0, "ev-2");
2521 assert!((ranked[1].1 - 0.3).abs() < f32::EPSILON);
2522 }
2523
2524 #[tokio::test]
2525 async fn rank_empty_event_ids() {
2526 let retriever = ranking_retriever();
2527
2528 let event_ids: HashSet<String> = HashSet::new();
2529 let ranked_edges = vec![make_ranked_edge("some-node", "other-node", 0.5)];
2530
2531 let ranked = retriever
2532 .rank_temporal_events("query", &event_ids, &ranked_edges)
2533 .await
2534 .unwrap();
2535
2536 assert!(
2537 ranked.is_empty(),
2538 "Empty event_ids should yield empty result"
2539 );
2540 }
2541
2542 #[tokio::test]
2543 async fn rank_mismatched_vector_ids() {
2544 let vector_db = Arc::new(TestVectorDb {
2545 collections: HashMap::from([(
2546 TestVectorDb::key("Event", "name"),
2547 vec![SearchResult {
2548 id: Uuid::parse_str("ffffffff-ffff-ffff-ffff-ffffffffffff").unwrap(),
2549 score: 0.99,
2550 metadata: HashMap::new(),
2551 }],
2552 )]),
2553 });
2554 let embedding_engine = Arc::new(TestEmbeddingEngine);
2555 let graph_db = Arc::new(TestGraphDb {
2556 nodes: vec![],
2557 edges: vec![],
2558 neighbors: HashMap::new(),
2559 });
2560 let llm = Arc::new(TestLlm::default());
2561
2562 let retriever = TemporalRetriever::new(
2563 vector_db,
2564 embedding_engine,
2565 graph_db,
2566 llm,
2567 Some(5),
2568 Some(10),
2569 Some(0.0),
2570 None,
2571 None,
2572 None,
2573 None,
2574 None,
2575 );
2576
2577 let event_ids: HashSet<String> =
2578 ["ev-abc", "ev-def"].iter().map(|s| s.to_string()).collect();
2579
2580 let ranked_edges = vec![make_ranked_edge("ev-abc", "something", 0.4)];
2581
2582 let ranked = retriever
2583 .rank_temporal_events("query", &event_ids, &ranked_edges)
2584 .await
2585 .unwrap();
2586
2587 assert_eq!(ranked.len(), 2);
2588 let ev_abc = ranked.iter().find(|(id, _)| id == "ev-abc").unwrap();
2589 assert!((ev_abc.1 - 0.4).abs() < f32::EPSILON);
2590
2591 let ev_def = ranked.iter().find(|(id, _)| id == "ev-def").unwrap();
2592 assert!(ev_def.1.abs() < f32::EPSILON);
2593 }
2594
/// Event payloads are rendered one per line as `<name> (<time>): <description>`.
#[test]
fn context_to_text_formats_event_items() {
    let launch = SearchItem {
        id: None,
        score: Some(0.9),
        payload: json!({
            "event_id": "evt-1",
            "event_name": "Product Launch",
            "event_description": "Launched the new product",
            "event_time": "2024-03-15",
        }),
    };
    let review = SearchItem {
        id: None,
        score: Some(0.7),
        payload: json!({
            "event_id": "evt-2",
            "event_name": "Quarterly Review",
            "event_description": "Reviewed Q1 results",
            "event_time": "2024-04-01",
        }),
    };
    let context = vec![launch, review];

    let rendered = TemporalRetriever::temporal_context_to_text(&context);
    let lines: Vec<&str> = rendered.lines().collect();

    let expected = [
        "Product Launch (2024-03-15): Launched the new product",
        "Quarterly Review (2024-04-01): Reviewed Q1 results",
    ];

    assert_eq!(lines.len(), 2);
    for (actual, wanted) in lines.iter().zip(expected.iter()) {
        assert_eq!(actual, wanted);
    }
}
2637
/// Triplet payloads are rendered one per line as `<source> -[<relationship>]-> <target>`.
#[test]
fn context_to_text_formats_triplet_items() {
    let alice_knows_bob = SearchItem {
        id: None,
        score: Some(0.8),
        payload: json!({
            "source_name": "Alice",
            "target_name": "Bob",
            "relationship": "knows",
        }),
    };
    let company_produces = SearchItem {
        id: None,
        score: Some(0.6),
        payload: json!({
            "source_name": "Company X",
            "target_name": "Product Y",
            "relationship": "produces",
        }),
    };
    let context = vec![alice_knows_bob, company_produces];

    let rendered = TemporalRetriever::temporal_context_to_text(&context);
    let lines: Vec<&str> = rendered.lines().collect();

    let expected = ["Alice -[knows]-> Bob", "Company X -[produces]-> Product Y"];

    assert_eq!(lines.len(), 2);
    for (actual, wanted) in lines.iter().zip(expected.iter()) {
        assert_eq!(actual, wanted);
    }
}
2668
/// An empty context renders to the empty string.
#[test]
fn context_to_text_empty_context() {
    let no_items = Vec::<SearchItem>::new();

    let rendered = TemporalRetriever::temporal_context_to_text(&no_items);

    assert_eq!(rendered, "");
}
2675
/// An event payload carrying only `event_id` is rendered with the placeholder
/// name, time and description.
#[test]
fn context_to_text_missing_fields_use_defaults() {
    let bare_event = SearchItem {
        id: None,
        score: Some(0.5),
        payload: json!({
            "event_id": "evt-bare",
        }),
    };
    let context = vec![bare_event];

    let rendered = TemporalRetriever::temporal_context_to_text(&context);

    assert_eq!(rendered, "Unnamed event (unknown time): No description");
}
2689
/// A context mixing an event item and a triplet item renders each in its own
/// format, preserving the input order.
#[test]
fn context_to_text_mixed_items() {
    let conference_event = SearchItem {
        id: None,
        score: Some(0.9),
        payload: json!({
            "event_id": "evt-1",
            "event_name": "Conference",
            "event_description": "Annual tech conference",
            "event_time": "2024-06-15",
        }),
    };
    let speaker_triplet = SearchItem {
        id: None,
        score: Some(0.7),
        payload: json!({
            "source_name": "Speaker",
            "target_name": "Conference",
            "relationship": "presents_at",
        }),
    };
    let context = vec![conference_event, speaker_triplet];

    let rendered = TemporalRetriever::temporal_context_to_text(&context);
    let lines: Vec<&str> = rendered.lines().collect();

    let expected = [
        "Conference (2024-06-15): Annual tech conference",
        "Speaker -[presents_at]-> Conference",
    ];

    assert_eq!(lines.len(), 2);
    for (actual, wanted) in lines.iter().zip(expected.iter()) {
        assert_eq!(actual, wanted);
    }
}
2721}