// cognee_search/retrievers/temporal_retriever.rs

1use std::borrow::Cow;
2use std::collections::{HashMap, HashSet};
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, TimeZone, Utc};
7use cognee_embedding::EmbeddingEngine;
8use cognee_graph::{GraphDBTrait, NodeData};
9use cognee_llm::{GenerationOptions, Llm, LlmExt, Message};
10use cognee_vector::VectorDB;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14
15use cognee_session::SessionContext;
16
17use crate::graph_retrieval::{
18    DEFAULT_TRIPLET_DISTANCE_PENALTY, GraphRetrievalConfig, RankedGraphEdge,
19    brute_force_triplet_search,
20};
21use crate::retrievers::SearchRetriever;
22use crate::types::{
23    SearchContext, SearchError, SearchItem, SearchOutput, SearchParams, SearchType,
24};
25use crate::utils::{build_messages_with_history, render_graph_user_prompt, resolve_system_prompt};
26
/// Result count used when the constructor is given no `top_k`.
const DEFAULT_TOP_K: usize = 10;
/// Wide-search candidate count used when the constructor is given no `wide_search_top_k`.
const DEFAULT_WIDE_SEARCH_TOP_K: usize = 100;
/// Data type of the vector collection queried when ranking events semantically.
const TEMPORAL_DATA_TYPE: &str = "Event";
/// Field name of the vector collection queried when ranking events semantically.
const TEMPORAL_FIELD_NAME: &str = "name";
/// System prompt used for interval extraction when no custom prompt is configured.
/// The `{time_now}` placeholder is replaced with the current date before the prompt is sent.
const DEFAULT_TEMPORAL_INTERVAL_PROMPT: &str = "You are tasked with identifying relevant time periods where the answer to a given query should be searched.\nCurrent date is:  `{time_now}`. Determine relevant period(s) and return structured intervals.\n\nExtraction rules:\n\n1. Query without specific timestamp: use the time period with starts_at set to None and ends_at set to now.\n2. Explicit time intervals: If the query specifies a range (e.g., from 2010 to 2020, between January and March 2023), extract both start and end dates. Always assign the earlier date to starts_at and the later date to ends_at.\n3. Single timestamp: If the query refers to one specific moment (e.g., in 2015, on March 5, 2022), set starts_at and ends_at to that same timestamp.\n4. Open-ended time references: For phrases such as \"before X\" or \"after X\", represent the unspecified side as None. For example: before 2009 → starts_at: None, ends_at: 2009; after 2009 → starts_at: 2009, ends_at: None.\n5. Current-time references (\"now\", \"current\", \"today\"): If the query explicitly refers to the present, set both starts_at and ends_at to now (the ingestion timestamp).\n6. \"Who is\" and \"Who was\" questions: These imply a general identity or biographical inquiry without a specific temporal scope. Set both starts_at and ends_at to None.\n7. Ordering rule: Always ensure the earlier date is assigned to starts_at and the later date to ends_at.\n8. No temporal information: If no valid or inferable time reference is found, set both starts_at and ends_at to None.";
32
/// Raw interval as produced by the LLM's structured output: both bounds are
/// optional, unparsed strings.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct QueryInterval {
    /// Lower bound of the interval, if any.
    starts_at: Option<String>,
    /// Upper bound of the interval, if any.
    ends_at: Option<String>,
}
38
/// Interval whose bounds have been converted to UTC datetimes.
/// A `None` bound means that side is open (or could not be parsed).
#[derive(Debug, Clone)]
struct ParsedInterval {
    /// Inclusive lower bound.
    start: Option<DateTime<Utc>>,
    /// Inclusive upper bound.
    end: Option<DateTime<Utc>>,
}
44
45impl QueryInterval {
46    fn parse(self) -> ParsedInterval {
47        ParsedInterval {
48            start: self
49                .starts_at
50                .as_deref()
51                .and_then(|value| parse_bound(value, false)),
52            end: self
53                .ends_at
54                .as_deref()
55                .and_then(|value| parse_bound(value, true)),
56        }
57    }
58}
59
/// Retriever that narrows graph search to events attached to timestamps inside
/// a time interval extracted from the query, with a plain triplet search as
/// fallback.
pub struct TemporalRetriever {
    /// Vector store used for semantic event ranking and triplet search.
    vector_db: Arc<dyn VectorDB>,
    /// Engine used to embed the query text.
    embedding_engine: Arc<dyn EmbeddingEngine>,
    /// Graph store holding Timestamp, Interval and Event nodes.
    graph_db: Arc<dyn GraphDBTrait>,
    /// LLM used for interval extraction and for completions.
    llm: Arc<dyn Llm>,
    /// Default number of results; per-request params may override it.
    top_k: usize,
    /// Default wide-search candidate count; per-request params may override it.
    wide_search_top_k: usize,
    /// Default triplet distance penalty passed to the graph retrieval config.
    triplet_distance_penalty: f32,
    /// Default feedback influence; the constructor always sets it to `0.0`.
    feedback_influence: f32,
    /// Custom interval-extraction prompt; the built-in prompt is used when `None`.
    temporal_interval_prompt: Option<String>,
    /// Default system prompt text for completions.
    system_prompt: Option<String>,
    /// Default system prompt path for completions.
    system_prompt_path: Option<String>,
    /// Template used to render the user prompt for completions.
    user_prompt_template: Option<String>,
    /// Generation options forwarded to every LLM call.
    generation_options: Option<GenerationOptions>,
}
75
76impl TemporalRetriever {
77    #[allow(clippy::too_many_arguments)]
78    pub fn new(
79        vector_db: Arc<dyn VectorDB>,
80        embedding_engine: Arc<dyn EmbeddingEngine>,
81        graph_db: Arc<dyn GraphDBTrait>,
82        llm: Arc<dyn Llm>,
83        top_k: Option<usize>,
84        wide_search_top_k: Option<usize>,
85        triplet_distance_penalty: Option<f32>,
86        temporal_interval_prompt: Option<String>,
87        system_prompt: Option<String>,
88        system_prompt_path: Option<String>,
89        user_prompt_template: Option<String>,
90        generation_options: Option<GenerationOptions>,
91    ) -> Self {
92        Self {
93            vector_db,
94            embedding_engine,
95            graph_db,
96            llm,
97            top_k: top_k.unwrap_or(DEFAULT_TOP_K),
98            wide_search_top_k: wide_search_top_k.unwrap_or(DEFAULT_WIDE_SEARCH_TOP_K),
99            triplet_distance_penalty: triplet_distance_penalty
100                .unwrap_or(DEFAULT_TRIPLET_DISTANCE_PENALTY),
101            feedback_influence: 0.0,
102            temporal_interval_prompt,
103            system_prompt,
104            system_prompt_path,
105            user_prompt_template,
106            generation_options,
107        }
108    }
109
110    async fn extract_interval(&self, query: &str) -> Result<Option<ParsedInterval>, SearchError> {
111        let now = chrono::Local::now().format("%d-%m-%Y").to_string();
112        let prompt_template = self
113            .temporal_interval_prompt
114            .as_deref()
115            .unwrap_or(DEFAULT_TEMPORAL_INTERVAL_PROMPT);
116        let system_prompt = prompt_template.replace("{time_now}", &now);
117
118        let interval = match self
119            .llm
120            .create_structured_output_with_messages::<QueryInterval>(
121                vec![
122                    Message::system(system_prompt),
123                    Message::user(query.to_string()),
124                ],
125                self.generation_options.clone(),
126            )
127            .await
128        {
129            Ok(interval) => interval,
130            Err(_) => return Ok(None),
131        };
132
133        let parsed = interval.parse();
134        if parsed.start.is_none() && parsed.end.is_none() {
135            return Ok(None);
136        }
137
138        Ok(Some(parsed))
139    }
140
141    fn get_graph_retrieval_config(&self, params: &SearchParams) -> GraphRetrievalConfig {
142        GraphRetrievalConfig {
143            top_k: params.top_k_or(self.top_k),
144            wide_search_top_k: params.wide_search_top_k_or(self.wide_search_top_k),
145            triplet_distance_penalty: params
146                .triplet_distance_penalty_or(self.triplet_distance_penalty),
147            feedback_influence: params.feedback_influence_or(self.feedback_influence),
148            node_type: params.node_type.clone(),
149            node_name: params.node_name.clone(),
150            node_name_filter_operator: params
151                .node_name_filter_operator
152                .as_deref()
153                .unwrap_or("OR")
154                .to_string(),
155        }
156    }
157
158    async fn get_ranked_graph_edges(
159        &self,
160        query: &str,
161        params: &SearchParams,
162    ) -> Result<Vec<RankedGraphEdge>, SearchError> {
163        brute_force_triplet_search(
164            query,
165            self.vector_db.as_ref(),
166            self.embedding_engine.as_ref(),
167            self.graph_db.as_ref(),
168            &self.get_graph_retrieval_config(params),
169        )
170        .await
171    }
172
173    fn ranked_edges_to_context(ranked_edges: Vec<RankedGraphEdge>) -> SearchContext {
174        ranked_edges
175            .into_iter()
176            .map(|edge| SearchItem {
177                id: None,
178                score: Some(edge.score),
179                payload: json!({
180                    "source_id": edge.source_id,
181                    "target_id": edge.target_id,
182                    "relationship": edge.relationship_name,
183                    "source_name": edge.source_name,
184                    "target_name": edge.target_name,
185                    "source_text": edge.source_text,
186                    "target_text": edge.target_text,
187                    "source_description": edge.source_description,
188                    "target_description": edge.target_description,
189                }),
190            })
191            .collect()
192    }
193
194    async fn get_fallback_context(
195        &self,
196        query: &str,
197        params: &SearchParams,
198    ) -> Result<SearchContext, SearchError> {
199        let ranked_edges = self.get_ranked_graph_edges(query, params).await?;
200        Ok(Self::ranked_edges_to_context(ranked_edges))
201    }
202
203    async fn rank_temporal_events(
204        &self,
205        query: &str,
206        event_ids: &HashSet<String>,
207        ranked_edges: &[RankedGraphEdge],
208    ) -> Result<Vec<(String, f32)>, SearchError> {
209        let mut scores = HashMap::<String, f32>::new();
210
211        for edge in ranked_edges {
212            if event_ids.contains(&edge.source_id) {
213                let score = scores.entry(edge.source_id.clone()).or_insert(edge.score);
214                *score = score.max(edge.score);
215            }
216            if event_ids.contains(&edge.target_id) {
217                let score = scores.entry(edge.target_id.clone()).or_insert(edge.score);
218                *score = score.max(edge.score);
219            }
220        }
221
222        if self
223            .vector_db
224            .has_collection(TEMPORAL_DATA_TYPE, TEMPORAL_FIELD_NAME)
225            .await?
226        {
227            let query_embeddings = self.embedding_engine.embed(&[query]).await?;
228            let query_vector = query_embeddings.into_iter().next().ok_or_else(|| {
229                SearchError::InvalidInput("embedding engine returned no vectors".to_string())
230            })?;
231
232            let semantic_results = self
233                .vector_db
234                .search_similar(
235                    TEMPORAL_DATA_TYPE,
236                    TEMPORAL_FIELD_NAME,
237                    &query_vector,
238                    self.wide_search_top_k.max(self.top_k),
239                )
240                .await?;
241
242            for result in semantic_results {
243                let event_id = result.id.to_string();
244                if !event_ids.contains(&event_id) {
245                    continue;
246                }
247
248                let score = scores.entry(event_id).or_insert(result.score);
249                *score = score.max(result.score);
250            }
251        }
252
253        let mut ranked = event_ids
254            .iter()
255            .map(|event_id| {
256                (
257                    event_id.clone(),
258                    scores.get(event_id).copied().unwrap_or(0.0),
259                )
260            })
261            .collect::<Vec<_>>();
262
263        ranked.sort_by(|left, right| {
264            right
265                .1
266                .partial_cmp(&left.1)
267                .unwrap_or(std::cmp::Ordering::Equal)
268                .then_with(|| left.0.cmp(&right.0))
269        });
270
271        Ok(ranked)
272    }
273
274    fn temporal_context_to_text(context: &SearchContext) -> String {
275        context
276            .iter()
277            .map(|item| {
278                if item.payload.get("event_id").is_some() {
279                    let name = item
280                        .payload
281                        .get("event_name")
282                        .and_then(Value::as_str)
283                        .unwrap_or("Unnamed event");
284                    let description = item
285                        .payload
286                        .get("event_description")
287                        .and_then(Value::as_str)
288                        .unwrap_or("No description");
289                    let time = item
290                        .payload
291                        .get("event_time")
292                        .and_then(Value::as_str)
293                        .unwrap_or("unknown time");
294
295                    return format!("{name} ({time}): {description}");
296                }
297
298                let source = item
299                    .payload
300                    .get("source_name")
301                    .and_then(Value::as_str)
302                    .or_else(|| item.payload.get("source_id").and_then(Value::as_str))
303                    .unwrap_or("unknown_source");
304                let target = item
305                    .payload
306                    .get("target_name")
307                    .and_then(Value::as_str)
308                    .or_else(|| item.payload.get("target_id").and_then(Value::as_str))
309                    .unwrap_or("unknown_target");
310                let relationship = item
311                    .payload
312                    .get("relationship")
313                    .and_then(Value::as_str)
314                    .or_else(|| {
315                        item.payload
316                            .get("relationship_name")
317                            .and_then(Value::as_str)
318                    })
319                    .unwrap_or("related_to");
320
321                format!("{source} -[{relationship}]-> {target}")
322            })
323            .collect::<Vec<_>>()
324            .join("\n")
325    }
326}
327
328#[async_trait]
329impl SearchRetriever for TemporalRetriever {
    /// Identifies this retriever as the handler for `SearchType::Temporal`.
    fn search_type(&self) -> SearchType {
        SearchType::Temporal
    }
333
334    async fn get_context(
335        &self,
336        query: &str,
337        params: &SearchParams,
338    ) -> Result<SearchContext, SearchError> {
339        if self.graph_db.is_empty().await? {
340            return Ok(vec![]);
341        }
342
343        let Some(interval) = self.extract_interval(query).await? else {
344            return self.get_fallback_context(query, params).await;
345        };
346
347        // Fix 1: Use typed query to find Timestamp nodes instead of full graph scan.
348        let (candidate_timestamps, _) = self
349            .graph_db
350            .get_filtered_graph_data(&HashMap::from([(
351                Cow::Borrowed("type"),
352                vec![json!("Timestamp")],
353            )]))
354            .await?;
355
356        let interval_from_ms = interval.start.map(|dt| dt.timestamp_millis());
357        let interval_to_ms = interval.end.map(|dt| dt.timestamp_millis());
358
359        let matching_ts_ids: Vec<String> = candidate_timestamps
360            .into_iter()
361            .filter_map(|(id, props)| {
362                let time_at = props.get("time_at")?.as_i64()?;
363                is_within_interval_ms(time_at, interval_from_ms, interval_to_ms).then_some(id)
364            })
365            .collect();
366
367        // Fix 2: Collect Event nodes reachable within 1-2 hops from matching Timestamps.
368        let mut event_node_ids = HashSet::new();
369        for ts_id in &matching_ts_ids {
370            for node_props in self.graph_db.get_neighbors(ts_id).await? {
371                let node_type = node_props.get("type").and_then(|v| v.as_str());
372                match node_type {
373                    Some("Event") => {
374                        if let Some(id) = node_props.get("id").and_then(|v| v.as_str()) {
375                            event_node_ids.insert(id.to_string());
376                        }
377                    }
378                    Some("Interval") => {
379                        // Hop through Interval node to reach Event nodes (hop 2).
380                        if let Some(interval_id) = node_props.get("id").and_then(|v| v.as_str()) {
381                            for inner_props in self.graph_db.get_neighbors(interval_id).await? {
382                                if inner_props.get("type").and_then(|v| v.as_str()) == Some("Event")
383                                    && let Some(id) = inner_props.get("id").and_then(|v| v.as_str())
384                                {
385                                    event_node_ids.insert(id.to_string());
386                                }
387                            }
388                        }
389                    }
390                    _ => {}
391                }
392            }
393        }
394
395        if event_node_ids.is_empty() {
396            return self.get_fallback_context(query, params).await;
397        }
398
399        let ranked_edges = self.get_ranked_graph_edges(query, params).await?;
400        let ranked_events = self
401            .rank_temporal_events(query, &event_node_ids, &ranked_edges)
402            .await?;
403
404        // Fetch Event nodes by ID for building the context payload.
405        let event_id_list: Vec<String> = ranked_events
406            .iter()
407            .take(params.top_k_or(self.top_k))
408            .map(|(id, _)| id.clone())
409            .collect();
410        let event_nodes = self.graph_db.get_nodes(&event_id_list).await?;
411        let nodes_by_id: HashMap<String, NodeData> =
412            event_id_list.into_iter().zip(event_nodes).collect();
413
414        let mut temporal_context = Vec::new();
415
416        for (event_id, score) in ranked_events.into_iter().take(params.top_k_or(self.top_k)) {
417            let Some(event_node) = nodes_by_id.get(&event_id) else {
418                continue;
419            };
420
421            temporal_context.push(SearchItem {
422                id: None,
423                score: Some(score),
424                payload: json!({
425                    "event_id": event_id,
426                    "event_name": extract_node_name(event_node),
427                    "event_description": extract_node_description(event_node),
428                }),
429            });
430        }
431
432        if temporal_context.is_empty() {
433            return Ok(Self::ranked_edges_to_context(ranked_edges));
434        }
435
436        Ok(temporal_context)
437    }
438
439    async fn get_completion(
440        &self,
441        query: &str,
442        context: Option<SearchContext>,
443        session: &SessionContext,
444        params: &SearchParams,
445    ) -> Result<SearchOutput, SearchError> {
446        let completion_context = match context {
447            Some(existing_context) => existing_context,
448            None => self.get_context(query, params).await?,
449        };
450
451        let system_prompt = resolve_system_prompt(
452            params
453                .system_prompt
454                .as_deref()
455                .or(self.system_prompt.as_deref()),
456            params
457                .system_prompt_path
458                .as_deref()
459                .or(self.system_prompt_path.as_deref()),
460        )?;
461
462        let user_prompt = render_graph_user_prompt(
463            self.user_prompt_template.as_deref(),
464            query,
465            &Self::temporal_context_to_text(&completion_context),
466        );
467
468        let messages = build_messages_with_history(system_prompt, user_prompt, session);
469
470        if let Some(schema) = &params.response_schema {
471            let structured_value = self
472                .llm
473                .create_structured_output_with_messages_raw(
474                    messages,
475                    schema,
476                    self.generation_options.clone(),
477                )
478                .await
479                .map_err(|e| SearchError::LlmError(e.to_string()))?;
480            Ok(SearchOutput::Structured(structured_value))
481        } else {
482            let completion = self
483                .llm
484                .generate(messages, self.generation_options.clone())
485                .await?;
486            Ok(SearchOutput::Text(completion.content))
487        }
488    }
489}
490
491fn extract_node_name(node_data: &NodeData) -> String {
492    node_data
493        .get("name")
494        .and_then(Value::as_str)
495        .or_else(|| node_data.get("title").and_then(Value::as_str))
496        .unwrap_or("Unnamed event")
497        .to_string()
498}
499
500fn extract_node_description(node_data: &NodeData) -> String {
501    node_data
502        .get("description")
503        .and_then(Value::as_str)
504        .or_else(|| node_data.get("text").and_then(Value::as_str))
505        .unwrap_or("")
506        .to_string()
507}
508
/// Reports whether `time_at_ms` lies in the closed range `[from_ms, to_ms]`.
///
/// A `None` bound leaves that side of the range open. Used to test the
/// millisecond `time_at` value of Timestamp nodes.
fn is_within_interval_ms(time_at_ms: i64, from_ms: Option<i64>, to_ms: Option<i64>) -> bool {
    let after_start = match from_ms {
        Some(from) => time_at_ms >= from,
        None => true,
    };
    let before_end = match to_ms {
        Some(to) => time_at_ms <= to,
        None => true,
    };

    after_start && before_end
}
513
514fn parse_bound(input: &str, is_end: bool) -> Option<DateTime<Utc>> {
515    let trimmed = input.trim();
516    if trimmed.is_empty() {
517        return None;
518    }
519
520    if let Ok(timestamp) = DateTime::parse_from_rfc3339(trimmed) {
521        return Some(timestamp.with_timezone(&Utc));
522    }
523
524    if let Ok(naive_dt) = NaiveDateTime::parse_from_str(trimmed, "%Y-%m-%d %H:%M:%S") {
525        return Some(Utc.from_utc_datetime(&naive_dt));
526    }
527
528    if let Ok(date) = NaiveDate::parse_from_str(trimmed, "%Y-%m-%d") {
529        return to_datetime(date, is_end);
530    }
531
532    if trimmed.len() == 7 {
533        let month_candidate = format!("{trimmed}-01");
534        if let Ok(date) = NaiveDate::parse_from_str(&month_candidate, "%Y-%m-%d") {
535            return if is_end {
536                let (next_year, next_month) = if date.month() == 12 {
537                    (date.year() + 1, 1)
538                } else {
539                    (date.year(), date.month() + 1)
540                };
541
542                let next_month_start = NaiveDate::from_ymd_opt(next_year, next_month, 1)?;
543                let month_end = next_month_start.pred_opt()?;
544                to_datetime(month_end, true)
545            } else {
546                to_datetime(date, false)
547            };
548        }
549    }
550
551    if trimmed.len() == 4
552        && trimmed.chars().all(|character| character.is_ascii_digit())
553        && let Ok(year) = trimmed.parse::<i32>()
554    {
555        let date = if is_end {
556            NaiveDate::from_ymd_opt(year, 12, 31)?
557        } else {
558            NaiveDate::from_ymd_opt(year, 1, 1)?
559        };
560
561        return to_datetime(date, is_end);
562    }
563
564    None
565}
566
567fn to_datetime(date: NaiveDate, is_end: bool) -> Option<DateTime<Utc>> {
568    let naive_dt = if is_end {
569        date.and_hms_opt(23, 59, 59)?
570    } else {
571        date.and_hms_opt(0, 0, 0)?
572    };
573
574    Some(Utc.from_utc_datetime(&naive_dt))
575}
576
577#[cfg(test)]
578#[allow(
579    clippy::unwrap_used,
580    clippy::expect_used,
581    reason = "test code — panics are acceptable failures"
582)]
583mod tests {
584    use std::borrow::Cow;
585    use std::collections::{HashMap, HashSet};
586    use std::sync::{Arc, Mutex};
587
588    use async_trait::async_trait;
589    use cognee_embedding::EmbeddingResult;
590    use cognee_embedding::engine::EmbeddingEngine;
591    use cognee_graph::{EdgeData, GraphDBResult, GraphDBTrait, GraphNode, NodeData};
592    use cognee_llm::{
593        GenerationOptions, GenerationResponse, Llm, LlmError, LlmResult, Message, TokenUsage,
594    };
595    use cognee_vector::{SearchResult, VectorDB, VectorDBResult, VectorPoint};
596
597    use chrono::{TimeZone, Utc};
598    use serde_json::{Value, json};
599    use uuid::Uuid;
600
601    use cognee_session::SessionContext;
602
603    use super::{QueryInterval, TemporalRetriever};
604    use crate::graph_retrieval::RankedGraphEdge;
605    use crate::retrievers::SearchRetriever;
606    use crate::types::{SearchItem, SearchOutput, SearchParams};
607
    /// Embedding engine stub that returns constant values.
    struct TestEmbeddingEngine;

    #[async_trait]
    impl EmbeddingEngine for TestEmbeddingEngine {
        /// Ignores the input texts and always yields the single vector `[0.3, 0.7]`.
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![vec![0.3, 0.7]])
        }

        /// Matches the length of the vector returned by `embed`.
        fn dimension(&self) -> usize {
            2
        }

        /// Fixed value.
        fn batch_size(&self) -> usize {
            16
        }

        /// Fixed value.
        fn max_sequence_length(&self) -> usize {
            512
        }
    }
628
    /// In-memory vector store stub with canned search results per collection.
    struct TestVectorDb {
        /// Canned results keyed by `TestVectorDb::key(data_type, field_name)`.
        collections: HashMap<String, Vec<SearchResult>>,
    }
632
633    impl TestVectorDb {
634        fn key(data_type: &str, field_name: &str) -> String {
635            format!("{data_type}_{field_name}")
636        }
637    }
638
639    #[async_trait]
640    impl VectorDB for TestVectorDb {
641        async fn create_collection(
642            &self,
643            _data_type: &str,
644            _field_name: &str,
645            _dimension: usize,
646        ) -> VectorDBResult<()> {
647            Ok(())
648        }
649
650        async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
651            Ok(self
652                .collections
653                .contains_key(&Self::key(data_type, field_name)))
654        }
655
656        async fn index_points(
657            &self,
658            _data_type: &str,
659            _field_name: &str,
660            _points: &[VectorPoint],
661        ) -> VectorDBResult<()> {
662            Ok(())
663        }
664
665        async fn search_similar(
666            &self,
667            data_type: &str,
668            field_name: &str,
669            _query_vector: &[f32],
670            top_k: usize,
671        ) -> VectorDBResult<Vec<SearchResult>> {
672            Ok(self
673                .collections
674                .get(&Self::key(data_type, field_name))
675                .cloned()
676                .unwrap_or_default()
677                .into_iter()
678                .take(top_k)
679                .collect())
680        }
681
682        async fn delete_collection(
683            &self,
684            _data_type: &str,
685            _field_name: &str,
686        ) -> VectorDBResult<()> {
687            Ok(())
688        }
689
690        async fn delete_points(
691            &self,
692            _data_type: &str,
693            _field_name: &str,
694            _point_ids: &[Uuid],
695        ) -> VectorDBResult<()> {
696            Ok(())
697        }
698
699        async fn collection_size(
700            &self,
701            data_type: &str,
702            field_name: &str,
703        ) -> VectorDBResult<usize> {
704            Ok(self
705                .collections
706                .get(&Self::key(data_type, field_name))
707                .map(|items| items.len())
708                .unwrap_or_default())
709        }
710    }
711
    /// In-memory graph store stub backed by fixed node, edge and neighbor lists.
    struct TestGraphDb {
        /// All nodes of the graph; also what the filtered and subgraph queries return.
        nodes: Vec<GraphNode>,
        /// All edges of the graph.
        edges: Vec<EdgeData>,
        /// Maps node_id -> list of neighbor NodeData returned by get_neighbors.
        neighbors: HashMap<String, Vec<NodeData>>,
    }
718
    /// `GraphDBTrait` stub. Reads are answered from the struct's fields; writes
    /// and most other lookups are no-ops that return empty or negative results.
    #[async_trait]
    impl GraphDBTrait for TestGraphDb {
        /// No-op.
        async fn initialize(&self) -> GraphDBResult<()> {
            Ok(())
        }

        /// The graph counts as empty when it has no nodes.
        async fn is_empty(&self) -> GraphDBResult<bool> {
            Ok(self.nodes.is_empty())
        }

        /// Always returns no rows.
        async fn query(
            &self,
            _query: &str,
            _params: Option<HashMap<Cow<'static, str>, Value>>,
        ) -> GraphDBResult<Vec<Vec<Value>>> {
            Ok(vec![])
        }

        /// No-op.
        async fn delete_graph(&self) -> GraphDBResult<()> {
            Ok(())
        }

        /// Always reports `false`, regardless of `nodes`.
        async fn has_node(&self, _node_id: &str) -> GraphDBResult<bool> {
            Ok(false)
        }

        /// No-op.
        async fn add_node_raw(&self, _node: serde_json::Value) -> GraphDBResult<()> {
            Ok(())
        }

        /// No-op.
        async fn add_nodes_raw(&self, _nodes: Vec<serde_json::Value>) -> GraphDBResult<()> {
            Ok(())
        }

        /// No-op.
        async fn delete_node(&self, _node_id: &str) -> GraphDBResult<()> {
            Ok(())
        }

        /// No-op.
        async fn delete_nodes(&self, _node_ids: &[String]) -> GraphDBResult<()> {
            Ok(())
        }

        /// Always returns `None`, regardless of `nodes`.
        async fn get_node(&self, _node_id: &str) -> GraphDBResult<Option<NodeData>> {
            Ok(None)
        }

        /// Returns the data of the requested nodes in request order.
        /// Ids that are not in `nodes` are skipped, so the result can be
        /// shorter than `node_ids`.
        async fn get_nodes(&self, node_ids: &[String]) -> GraphDBResult<Vec<NodeData>> {
            let nodes_map: HashMap<&str, &NodeData> = self
                .nodes
                .iter()
                .map(|(id, data)| (id.as_str(), data))
                .collect();
            Ok(node_ids
                .iter()
                .filter_map(|id| nodes_map.get(id.as_str()).map(|d| (*d).clone()))
                .collect())
        }

        /// Always reports `false`.
        async fn has_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
        ) -> GraphDBResult<bool> {
            Ok(false)
        }

        /// Always returns an empty list.
        async fn has_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        /// No-op.
        async fn add_edge(
            &self,
            _source_id: &str,
            _target_id: &str,
            _relationship_name: &str,
            _properties: Option<HashMap<Cow<'static, str>, Value>>,
        ) -> GraphDBResult<()> {
            Ok(())
        }

        /// No-op.
        async fn add_edges(&self, _edges: &[EdgeData]) -> GraphDBResult<()> {
            Ok(())
        }

        /// Always returns an empty list, regardless of `edges`.
        async fn get_edges(&self, _node_id: &str) -> GraphDBResult<Vec<EdgeData>> {
            Ok(vec![])
        }

        /// Returns the canned neighbor list for `node_id`, empty when none is registered.
        async fn get_neighbors(&self, node_id: &str) -> GraphDBResult<Vec<NodeData>> {
            Ok(self.neighbors.get(node_id).cloned().unwrap_or_default())
        }

        /// Always returns an empty list.
        async fn get_connections(
            &self,
            _node_id: &str,
        ) -> GraphDBResult<Vec<(NodeData, HashMap<Cow<'static, str>, Value>, NodeData)>> {
            Ok(vec![])
        }

        /// Returns clones of all nodes and edges.
        async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }

        /// Always returns an empty metrics map.
        async fn get_graph_metrics(
            &self,
            _include_optional: bool,
        ) -> GraphDBResult<HashMap<Cow<'static, str>, Value>> {
            Ok(HashMap::new())
        }

        /// Ignores the filters and returns clones of all nodes and edges.
        async fn get_filtered_graph_data(
            &self,
            _attribute_filters: &HashMap<Cow<'static, str>, Vec<Value>>,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }

        /// Ignores the arguments and returns clones of all nodes and edges.
        async fn get_nodeset_subgraph(
            &self,
            _node_type: &str,
            _node_names: &[String],
            _node_name_filter_operator: &str,
        ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
            Ok((self.nodes.clone(), self.edges.clone()))
        }
    }
846
    /// LLM stub with canned responses that records the messages it receives.
    #[derive(Default)]
    struct TestLlm {
        /// Text returned as the content of every `generate` call.
        completion_response: String,
        /// Interval serialized as the structured output when no custom
        /// structured response is set.
        interval_response: Option<QueryInterval>,
        /// When true, structured-output calls fail with a config error.
        fail_structured_output: bool,
        /// Messages captured by the most recent `generate` call.
        last_messages: Mutex<Vec<Message>>,
        /// When set, `create_structured_output_with_messages_raw` returns this
        /// value instead of serializing `interval_response`. Used by tests that
        /// exercise the response_schema path in `get_completion`.
        structured_completion_response: Mutex<Option<Value>>,
        /// Messages captured by the most recent `create_structured_output_with_messages_raw` call.
        last_structured_messages: Mutex<Vec<Message>>,
    }
860
861    #[async_trait]
862    impl Llm for TestLlm {
863        async fn generate(
864            &self,
865            messages: Vec<Message>,
866            _options: Option<GenerationOptions>,
867        ) -> LlmResult<GenerationResponse> {
868            self.last_messages.lock().unwrap().clone_from(&messages);
869
870            Ok(GenerationResponse {
871                content: self.completion_response.clone(),
872                model: "test-model".to_string(),
873                usage: Some(TokenUsage {
874                    prompt_tokens: 1,
875                    completion_tokens: 1,
876                    total_tokens: 2,
877                }),
878                finish_reason: Some("stop".to_string()),
879            })
880        }
881
882        async fn create_structured_output_with_messages_raw(
883            &self,
884            messages: Vec<Message>,
885            _json_schema: &serde_json::Value,
886            _options: Option<GenerationOptions>,
887        ) -> LlmResult<serde_json::Value> {
888            self.last_structured_messages
889                .lock()
890                .unwrap()
891                .clone_from(&messages);
892
893            if self.fail_structured_output {
894                return Err(LlmError::ConfigError("forced failure".to_string()));
895            }
896
897            // If a custom structured completion response is set, return it.
898            if let Some(value) = self.structured_completion_response.lock().unwrap().clone() {
899                return Ok(value);
900            }
901
902            let response = self
903                .interval_response
904                .clone()
905                .ok_or_else(|| LlmError::ConfigError("missing interval response".to_string()))?;
906
907            serde_json::to_value(response).map_err(|error| LlmError::ConfigError(error.to_string()))
908        }
909
910        fn model(&self) -> &str {
911            "test-model"
912        }
913    }
914
915    fn event_node_data(id: &str, name: &str) -> NodeData {
916        HashMap::from([
917            (Cow::Borrowed("id"), json!(id)),
918            (Cow::Borrowed("name"), json!(name)),
919            (Cow::Borrowed("type"), json!("Event")),
920            (
921                Cow::Borrowed("description"),
922                json!(format!("Description for {name}")),
923            ),
924        ])
925    }
926
927    fn timestamp_node(id: &str, time_at_ms: i64) -> GraphNode {
928        (
929            id.to_string(),
930            HashMap::from([
931                (Cow::Borrowed("id"), json!(id)),
932                (Cow::Borrowed("type"), json!("Timestamp")),
933                (Cow::Borrowed("time_at"), json!(time_at_ms)),
934            ]),
935        )
936    }
937
938    fn event_graph_node(id: &str, name: &str) -> GraphNode {
939        (id.to_string(), event_node_data(id, name))
940    }
941
942    #[tokio::test]
943    async fn returns_temporal_event_context_when_interval_matches() {
944        // 2024-03-15 00:00:00 UTC in milliseconds
945        let launch_event_ms: i64 = 1710460800000;
946        // 2020-01-10 00:00:00 UTC in milliseconds
947        let old_event_ms: i64 = 1578614400000;
948
949        let ts_in_2024 = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa";
950        let ts_in_2020 = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb";
951        let event_launch = "11111111-1111-1111-1111-111111111111";
952        let event_old = "22222222-2222-2222-2222-222222222222";
953
954        let vector_db = Arc::new(TestVectorDb {
955            collections: HashMap::from([
956                (
957                    TestVectorDb::key("Entity", "name"),
958                    vec![SearchResult {
959                        id: uuid::Uuid::new_v4(),
960                        score: 0.8,
961                        metadata: HashMap::from([(String::from("type"), json!("entity"))]),
962                    }],
963                ),
964                (
965                    TestVectorDb::key("Event", "name"),
966                    vec![SearchResult {
967                        id: uuid::Uuid::parse_str(event_launch).unwrap(),
968                        score: 0.95,
969                        metadata: HashMap::new(),
970                    }],
971                ),
972            ]),
973        });
974
975        let embedding_engine = Arc::new(TestEmbeddingEngine);
976        let graph_db = Arc::new(TestGraphDb {
977            nodes: vec![
978                timestamp_node(ts_in_2024, launch_event_ms),
979                timestamp_node(ts_in_2020, old_event_ms),
980                event_graph_node(event_launch, "Launch event"),
981                event_graph_node(event_old, "Old event"),
982            ],
983            edges: vec![
984                (
985                    event_launch.to_string(),
986                    ts_in_2024.to_string(),
987                    "at".to_string(),
988                    HashMap::new(),
989                ),
990                (
991                    event_old.to_string(),
992                    ts_in_2020.to_string(),
993                    "at".to_string(),
994                    HashMap::new(),
995                ),
996            ],
997            neighbors: HashMap::from([
998                // The 2024 Timestamp node has the Launch event as a neighbor.
999                (
1000                    ts_in_2024.to_string(),
1001                    vec![event_node_data(event_launch, "Launch event")],
1002                ),
1003                // The 2020 Timestamp node has the Old event as a neighbor.
1004                (
1005                    ts_in_2020.to_string(),
1006                    vec![event_node_data(event_old, "Old event")],
1007                ),
1008            ]),
1009        });
1010        let llm = Arc::new(TestLlm {
1011            completion_response: "temporal answer".to_string(),
1012            interval_response: Some(QueryInterval {
1013                starts_at: Some("2024-01-01".to_string()),
1014                ends_at: Some("2024-12-31".to_string()),
1015            }),
1016            fail_structured_output: false,
1017            last_messages: Mutex::new(vec![]),
1018            structured_completion_response: Mutex::new(None),
1019            last_structured_messages: Mutex::new(vec![]),
1020        });
1021
1022        let retriever = TemporalRetriever::new(
1023            vector_db,
1024            embedding_engine,
1025            graph_db,
1026            llm,
1027            Some(5),
1028            Some(10),
1029            Some(0.0),
1030            None,
1031            None,
1032            None,
1033            None,
1034            None,
1035        );
1036
1037        let context = retriever
1038            .get_context("What happened in 2024?", &SearchParams::default())
1039            .await
1040            .unwrap();
1041
1042        assert_eq!(context.len(), 1);
1043        assert_eq!(
1044            context[0].payload.get("event_name").and_then(Value::as_str),
1045            Some("Launch event")
1046        );
1047    }
1048
1049    // ── parse_bound tests ──────────────────────────────────────────────
1050
/// A space-separated `YYYY-MM-DD HH:MM:SS` string parses to that UTC instant.
#[test]
fn parse_bound_datetime_space_format() {
    use chrono::{TimeZone, Utc};
    let expected = Utc.with_ymd_and_hms(2024, 1, 15, 10, 30, 0).unwrap();
    let parsed = super::parse_bound("2024-01-15 10:30:00", false);
    assert_eq!(parsed, Some(expected));
}
1060
/// An RFC 3339 string with a `Z` suffix parses to that UTC instant.
#[test]
fn parse_bound_rfc3339_format() {
    use chrono::{TimeZone, Utc};
    let expected = Utc.with_ymd_and_hms(2024, 1, 15, 10, 30, 0).unwrap();
    let parsed = super::parse_bound("2024-01-15T10:30:00Z", false);
    assert_eq!(parsed, Some(expected));
}
1070
/// A bare date with the flag set to `false` resolves to midnight of that day.
#[test]
fn parse_bound_date_only_start() {
    use chrono::{TimeZone, Utc};
    let start_of_day = Utc.with_ymd_and_hms(2024, 3, 15, 0, 0, 0).unwrap();
    let parsed = super::parse_bound("2024-03-15", false);
    assert_eq!(parsed, Some(start_of_day));
}
1080
/// A bare date with the flag set to `true` resolves to 23:59:59 of that day.
#[test]
fn parse_bound_date_only_end() {
    use chrono::{TimeZone, Utc};
    let end_of_day = Utc.with_ymd_and_hms(2024, 3, 15, 23, 59, 59).unwrap();
    let parsed = super::parse_bound("2024-03-15", true);
    assert_eq!(parsed, Some(end_of_day));
}
1090
/// A `YYYY-MM` string with the flag set to `false` resolves to midnight on
/// the first of the month.
#[test]
fn parse_bound_month_start() {
    use chrono::{TimeZone, Utc};
    let first_of_month = Utc.with_ymd_and_hms(2024, 3, 1, 0, 0, 0).unwrap();
    let parsed = super::parse_bound("2024-03", false);
    assert_eq!(parsed, Some(first_of_month));
}
1100
/// A `YYYY-MM` string with the flag set to `true` resolves to 23:59:59 on the
/// last day of the month (March has 31 days).
#[test]
fn parse_bound_month_end() {
    use chrono::{TimeZone, Utc};
    let last_of_month = Utc.with_ymd_and_hms(2024, 3, 31, 23, 59, 59).unwrap();
    let parsed = super::parse_bound("2024-03", true);
    assert_eq!(parsed, Some(last_of_month));
}
1110
/// February of a leap year ends on the 29th.
#[test]
fn parse_bound_month_end_leap_year() {
    use chrono::{TimeZone, Utc};
    let leap_day_end = Utc.with_ymd_and_hms(2024, 2, 29, 23, 59, 59).unwrap();
    let parsed = super::parse_bound("2024-02", true);
    assert_eq!(parsed, Some(leap_day_end));
}
1120
/// February of a non-leap year ends on the 28th.
#[test]
fn parse_bound_month_end_non_leap_year() {
    use chrono::{TimeZone, Utc};
    let feb_end = Utc.with_ymd_and_hms(2023, 2, 28, 23, 59, 59).unwrap();
    let parsed = super::parse_bound("2023-02", true);
    assert_eq!(parsed, Some(feb_end));
}
1130
/// The end of December stays within the same year (31 December, 23:59:59).
#[test]
fn parse_bound_month_end_december_wrap() {
    use chrono::{TimeZone, Utc};
    let december_end = Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap();
    let parsed = super::parse_bound("2024-12", true);
    assert_eq!(parsed, Some(december_end));
}
1140
/// A bare year with the flag set to `false` resolves to midnight on 1 January.
#[test]
fn parse_bound_year_start() {
    use chrono::{TimeZone, Utc};
    let new_year = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
    let parsed = super::parse_bound("2024", false);
    assert_eq!(parsed, Some(new_year));
}
1150
/// A bare year with the flag set to `true` resolves to 23:59:59 on 31 December.
#[test]
fn parse_bound_year_end() {
    use chrono::{TimeZone, Utc};
    let year_end = Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap();
    let parsed = super::parse_bound("2024", true);
    assert_eq!(parsed, Some(year_end));
}
1160
/// Empty and whitespace-only inputs yield no bound.
#[test]
fn parse_bound_empty_and_whitespace_returns_none() {
    for blank in ["", "  "] {
        assert_eq!(super::parse_bound(blank, false), None);
    }
}
1166
/// Strings that are not dates yield no bound.
#[test]
fn parse_bound_invalid_input_returns_none() {
    for garbage in ["not-a-date", "abc"] {
        assert_eq!(super::parse_bound(garbage, false), None);
    }
}
1172
1173    // ── is_within_interval_ms tests ───────────────────────────────────
1174
/// Checks a closed interval `[100, 1000]`: interior and both boundaries are
/// accepted, values outside either end are rejected.
#[test]
fn is_within_interval_ms_basic_cases() {
    use super::is_within_interval_ms;

    let (lower, upper) = (Some(100), Some(1000));

    // Interior value, then the inclusive lower and upper boundaries.
    for accepted in [500, 100, 1000] {
        assert!(is_within_interval_ms(accepted, lower, upper));
    }

    // One value below the range and one above it.
    for rejected in [50, 1500] {
        assert!(!is_within_interval_ms(rejected, lower, upper));
    }
}
1190
/// Checks intervals where one or both bounds are missing.
#[test]
fn is_within_interval_ms_open_ended_bounds() {
    use super::is_within_interval_ms;

    // (value, lower bound, upper bound, expected verdict)
    let cases = [
        // Open start: only the upper bound constrains.
        (50, None, Some(1000), true),
        (1500, None, Some(1000), false),
        // Open end: only the lower bound constrains.
        (1500, Some(100), None, true),
        (50, Some(100), None, false),
        // No bounds at all: every value matches.
        (0, None, None, true),
        (i64::MAX, None, None, true),
        (i64::MIN, None, None, true),
    ];

    for (value, lower, upper, expected) in cases {
        assert_eq!(is_within_interval_ms(value, lower, upper), expected);
    }
}
1208
1209    // ── QueryInterval::parse tests ────────────────────────────────────
1210
/// Both date-only bounds are parsed: the start to midnight, the end to
/// 23:59:59 of its day.
#[test]
fn query_interval_parse_both_bounds() {
    use chrono::{TimeZone, Utc};

    let interval = QueryInterval {
        starts_at: Some(String::from("2024-01-01")),
        ends_at: Some(String::from("2024-12-31")),
    };
    let bounds = interval.parse();

    let expected_start = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
    let expected_end = Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap();
    assert_eq!(bounds.start, Some(expected_start));
    assert_eq!(bounds.end, Some(expected_end));
}
1229
/// An interval without bounds parses to a result without bounds.
#[test]
fn query_interval_parse_none_bounds() {
    let interval = QueryInterval {
        starts_at: None,
        ends_at: None,
    };
    let bounds = interval.parse();

    assert_eq!(bounds.start, None);
    assert_eq!(bounds.end, None);
}
1240
/// A single supplied bound is parsed while the missing side stays `None`.
#[test]
fn query_interval_parse_partial_bounds() {
    use chrono::{TimeZone, Utc};

    // Start only: a `YYYY-MM` start resolves to the first of the month.
    let start_only = QueryInterval {
        starts_at: Some(String::from("2024-06")),
        ends_at: None,
    };
    let from_bounds = start_only.parse();
    let june_first = Utc.with_ymd_and_hms(2024, 6, 1, 0, 0, 0).unwrap();
    assert_eq!(from_bounds.start, Some(june_first));
    assert!(from_bounds.end.is_none());

    // End only: a bare-year end resolves to the last second of that year.
    let end_only = QueryInterval {
        starts_at: None,
        ends_at: Some(String::from("2024")),
    };
    let until_bounds = end_only.parse();
    let year_end = Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap();
    assert!(until_bounds.start.is_none());
    assert_eq!(until_bounds.end, Some(year_end));
}
1269
1270    #[tokio::test]
1271    async fn falls_back_to_graph_context_when_interval_extraction_fails() {
1272        let vector_db = Arc::new(TestVectorDb {
1273            collections: HashMap::from([(
1274                TestVectorDb::key("Entity", "name"),
1275                vec![SearchResult {
1276                    id: uuid::Uuid::parse_str("33333333-3333-3333-3333-333333333333").unwrap(),
1277                    score: 0.9,
1278                    metadata: HashMap::new(),
1279                }],
1280            )]),
1281        });
1282
1283        let embedding_engine = Arc::new(TestEmbeddingEngine);
1284        let graph_db = Arc::new(TestGraphDb {
1285            nodes: vec![
1286                (
1287                    "33333333-3333-3333-3333-333333333333".to_string(),
1288                    HashMap::from([
1289                        (
1290                            Cow::Borrowed("id"),
1291                            json!("33333333-3333-3333-3333-333333333333"),
1292                        ),
1293                        (Cow::Borrowed("name"), json!("Entity A")),
1294                    ]),
1295                ),
1296                (
1297                    "44444444-4444-4444-4444-444444444444".to_string(),
1298                    HashMap::from([
1299                        (
1300                            Cow::Borrowed("id"),
1301                            json!("44444444-4444-4444-4444-444444444444"),
1302                        ),
1303                        (Cow::Borrowed("name"), json!("Entity B")),
1304                    ]),
1305                ),
1306            ],
1307            edges: vec![(
1308                "33333333-3333-3333-3333-333333333333".to_string(),
1309                "44444444-4444-4444-4444-444444444444".to_string(),
1310                "connected_to".to_string(),
1311                HashMap::new(),
1312            )],
1313            neighbors: HashMap::new(),
1314        });
1315        let llm = Arc::new(TestLlm {
1316            completion_response: "fallback answer".to_string(),
1317            interval_response: None,
1318            fail_structured_output: true,
1319            last_messages: Mutex::new(vec![]),
1320            structured_completion_response: Mutex::new(None),
1321            last_structured_messages: Mutex::new(vec![]),
1322        });
1323
1324        let retriever = TemporalRetriever::new(
1325            vector_db,
1326            embedding_engine,
1327            graph_db,
1328            llm,
1329            Some(3),
1330            Some(10),
1331            Some(0.0),
1332            None,
1333            None,
1334            None,
1335            None,
1336            None,
1337        );
1338
1339        let context = retriever
1340            .get_context("What happened?", &SearchParams::default())
1341            .await
1342            .unwrap();
1343        assert_eq!(context.len(), 1);
1344        assert_eq!(
1345            context[0]
1346                .payload
1347                .get("relationship")
1348                .and_then(Value::as_str),
1349            Some("connected_to")
1350        );
1351    }
1352
1353    fn build_retriever_with_llm(llm: TestLlm) -> TemporalRetriever {
1354        TemporalRetriever::new(
1355            Arc::new(TestVectorDb {
1356                collections: HashMap::new(),
1357            }),
1358            Arc::new(TestEmbeddingEngine),
1359            Arc::new(TestGraphDb {
1360                nodes: vec![],
1361                edges: vec![],
1362                neighbors: HashMap::new(),
1363            }),
1364            Arc::new(llm),
1365            Some(5),
1366            Some(10),
1367            Some(0.0),
1368            None,
1369            None,
1370            None,
1371            None,
1372            None,
1373        )
1374    }
1375
1376    #[tokio::test]
1377    async fn extract_interval_returns_parsed_interval_from_llm() {
1378        let llm = TestLlm {
1379            completion_response: String::new(),
1380            interval_response: Some(QueryInterval {
1381                starts_at: Some("2024-01-01".into()),
1382                ends_at: Some("2024-12-31".into()),
1383            }),
1384            fail_structured_output: false,
1385            last_messages: Mutex::new(vec![]),
1386            ..TestLlm::default()
1387        };
1388        let retriever = build_retriever_with_llm(llm);
1389
1390        let result = retriever
1391            .extract_interval("What happened in 2024?")
1392            .await
1393            .unwrap();
1394
1395        let parsed = result.expect("should return Some(ParsedInterval)");
1396        assert_eq!(
1397            parsed.start,
1398            Some(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap())
1399        );
1400        assert_eq!(
1401            parsed.end,
1402            Some(Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap())
1403        );
1404    }
1405
1406    #[tokio::test]
1407    async fn extract_interval_returns_none_when_llm_returns_none_none() {
1408        let llm = TestLlm {
1409            completion_response: String::new(),
1410            interval_response: Some(QueryInterval {
1411                starts_at: None,
1412                ends_at: None,
1413            }),
1414            fail_structured_output: false,
1415            last_messages: Mutex::new(vec![]),
1416            ..TestLlm::default()
1417        };
1418        let retriever = build_retriever_with_llm(llm);
1419
1420        let result = retriever
1421            .extract_interval("Who is Einstein?")
1422            .await
1423            .unwrap();
1424
1425        assert!(
1426            result.is_none(),
1427            "both fields None means no interval detected"
1428        );
1429    }
1430
1431    #[tokio::test]
1432    async fn extract_interval_returns_none_when_llm_fails() {
1433        let llm = TestLlm {
1434            completion_response: String::new(),
1435            interval_response: None,
1436            fail_structured_output: true,
1437            last_messages: Mutex::new(vec![]),
1438            ..TestLlm::default()
1439        };
1440        let retriever = build_retriever_with_llm(llm);
1441
1442        let result = retriever.extract_interval("What happened?").await.unwrap();
1443
1444        assert!(result.is_none(), "error should be swallowed gracefully");
1445    }
1446
1447    #[tokio::test]
1448    async fn extract_interval_with_only_starts_at() {
1449        let llm = TestLlm {
1450            completion_response: String::new(),
1451            interval_response: Some(QueryInterval {
1452                starts_at: Some("2024-01-01".into()),
1453                ends_at: None,
1454            }),
1455            fail_structured_output: false,
1456            last_messages: Mutex::new(vec![]),
1457            ..TestLlm::default()
1458        };
1459        let retriever = build_retriever_with_llm(llm);
1460
1461        let result = retriever
1462            .extract_interval("What happened after 2024?")
1463            .await
1464            .unwrap();
1465
1466        let parsed = result.expect("should return Some(ParsedInterval)");
1467        assert_eq!(
1468            parsed.start,
1469            Some(Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap())
1470        );
1471        assert_eq!(parsed.end, None);
1472    }
1473
1474    #[tokio::test]
1475    async fn extract_interval_with_only_ends_at() {
1476        let llm = TestLlm {
1477            completion_response: String::new(),
1478            interval_response: Some(QueryInterval {
1479                starts_at: None,
1480                ends_at: Some("2024-12-31".into()),
1481            }),
1482            fail_structured_output: false,
1483            last_messages: Mutex::new(vec![]),
1484            ..TestLlm::default()
1485        };
1486        let retriever = build_retriever_with_llm(llm);
1487
1488        let result = retriever
1489            .extract_interval("What happened before 2025?")
1490            .await
1491            .unwrap();
1492
1493        let parsed = result.expect("should return Some(ParsedInterval)");
1494        assert_eq!(parsed.start, None);
1495        assert_eq!(
1496            parsed.end,
1497            Some(Utc.with_ymd_and_hms(2024, 12, 31, 23, 59, 59).unwrap())
1498        );
1499    }
1500
1501    // ---- Phase 3: get_context edge case tests ----
1502
1503    fn build_retriever(
1504        vector_db: Arc<dyn VectorDB>,
1505        graph_db: Arc<dyn GraphDBTrait>,
1506        llm: Arc<dyn Llm>,
1507    ) -> TemporalRetriever {
1508        TemporalRetriever::new(
1509            vector_db,
1510            Arc::new(TestEmbeddingEngine),
1511            graph_db,
1512            llm,
1513            Some(10),
1514            Some(100),
1515            Some(0.0),
1516            None,
1517            None,
1518            None,
1519            None,
1520            None,
1521        )
1522    }
1523
1524    #[tokio::test]
1525    async fn get_context_with_time_from_only() {
1526        // Timestamps: 2020-06-15, 2024-01-15, 2024-07-15
1527        let ts_2020 = "aa000000-0000-0000-0000-000000000001";
1528        let ts_2024_jan = "aa000000-0000-0000-0000-000000000002";
1529        let ts_2024_jul = "aa000000-0000-0000-0000-000000000003";
1530        let ev_old = "bb000000-0000-0000-0000-000000000001";
1531        let ev_jan = "bb000000-0000-0000-0000-000000000002";
1532        let ev_jul = "bb000000-0000-0000-0000-000000000003";
1533
1534        let graph_db = Arc::new(TestGraphDb {
1535            nodes: vec![
1536                timestamp_node(ts_2020, 1592179200000),
1537                timestamp_node(ts_2024_jan, 1705276800000),
1538                timestamp_node(ts_2024_jul, 1721001600000),
1539                event_graph_node(ev_old, "Old event"),
1540                event_graph_node(ev_jan, "Jan event"),
1541                event_graph_node(ev_jul, "Jul event"),
1542            ],
1543            edges: vec![
1544                (
1545                    ev_old.to_string(),
1546                    ts_2020.to_string(),
1547                    "at".to_string(),
1548                    HashMap::new(),
1549                ),
1550                (
1551                    ev_jan.to_string(),
1552                    ts_2024_jan.to_string(),
1553                    "at".to_string(),
1554                    HashMap::new(),
1555                ),
1556                (
1557                    ev_jul.to_string(),
1558                    ts_2024_jul.to_string(),
1559                    "at".to_string(),
1560                    HashMap::new(),
1561                ),
1562            ],
1563            neighbors: HashMap::from([
1564                (
1565                    ts_2020.to_string(),
1566                    vec![event_node_data(ev_old, "Old event")],
1567                ),
1568                (
1569                    ts_2024_jan.to_string(),
1570                    vec![event_node_data(ev_jan, "Jan event")],
1571                ),
1572                (
1573                    ts_2024_jul.to_string(),
1574                    vec![event_node_data(ev_jul, "Jul event")],
1575                ),
1576            ]),
1577        });
1578
1579        let vector_db = Arc::new(TestVectorDb {
1580            collections: HashMap::from([(
1581                TestVectorDb::key("Event", "name"),
1582                vec![
1583                    SearchResult {
1584                        id: uuid::Uuid::parse_str(ev_jan).unwrap(),
1585                        score: 0.9,
1586                        metadata: HashMap::new(),
1587                    },
1588                    SearchResult {
1589                        id: uuid::Uuid::parse_str(ev_jul).unwrap(),
1590                        score: 0.85,
1591                        metadata: HashMap::new(),
1592                    },
1593                ],
1594            )]),
1595        });
1596
1597        let llm = Arc::new(TestLlm {
1598            completion_response: String::new(),
1599            interval_response: Some(QueryInterval {
1600                starts_at: Some("2024-01-01".to_string()),
1601                ends_at: None,
1602            }),
1603            fail_structured_output: false,
1604            last_messages: Mutex::new(vec![]),
1605            ..TestLlm::default()
1606        });
1607
1608        let retriever = build_retriever(vector_db, graph_db, llm);
1609
1610        let context = retriever
1611            .get_context("What happened since 2024?", &SearchParams::default())
1612            .await
1613            .unwrap();
1614
1615        assert_eq!(context.len(), 2, "should have 2 items (both 2024 events)");
1616
1617        let event_names: Vec<&str> = context
1618            .iter()
1619            .filter_map(|item| item.payload.get("event_name").and_then(Value::as_str))
1620            .collect();
1621        assert!(
1622            event_names.contains(&"Jan event"),
1623            "Jan event should be present"
1624        );
1625        assert!(
1626            event_names.contains(&"Jul event"),
1627            "Jul event should be present"
1628        );
1629        assert!(
1630            !event_names.contains(&"Old event"),
1631            "Old event should NOT be present"
1632        );
1633    }
1634
1635    #[tokio::test]
1636    async fn get_context_with_time_to_only() {
1637        let ts_2020 = "aa000000-0000-0000-0000-000000000001";
1638        let ts_2024_jan = "aa000000-0000-0000-0000-000000000002";
1639        let ts_2024_jul = "aa000000-0000-0000-0000-000000000003";
1640        let ev_old = "bb000000-0000-0000-0000-000000000001";
1641        let ev_jan = "bb000000-0000-0000-0000-000000000002";
1642        let ev_jul = "bb000000-0000-0000-0000-000000000003";
1643
1644        let graph_db = Arc::new(TestGraphDb {
1645            nodes: vec![
1646                timestamp_node(ts_2020, 1592179200000),
1647                timestamp_node(ts_2024_jan, 1705276800000),
1648                timestamp_node(ts_2024_jul, 1721001600000),
1649                event_graph_node(ev_old, "Old event"),
1650                event_graph_node(ev_jan, "Jan event"),
1651                event_graph_node(ev_jul, "Jul event"),
1652            ],
1653            edges: vec![
1654                (
1655                    ev_old.to_string(),
1656                    ts_2020.to_string(),
1657                    "at".to_string(),
1658                    HashMap::new(),
1659                ),
1660                (
1661                    ev_jan.to_string(),
1662                    ts_2024_jan.to_string(),
1663                    "at".to_string(),
1664                    HashMap::new(),
1665                ),
1666                (
1667                    ev_jul.to_string(),
1668                    ts_2024_jul.to_string(),
1669                    "at".to_string(),
1670                    HashMap::new(),
1671                ),
1672            ],
1673            neighbors: HashMap::from([
1674                (
1675                    ts_2020.to_string(),
1676                    vec![event_node_data(ev_old, "Old event")],
1677                ),
1678                (
1679                    ts_2024_jan.to_string(),
1680                    vec![event_node_data(ev_jan, "Jan event")],
1681                ),
1682                (
1683                    ts_2024_jul.to_string(),
1684                    vec![event_node_data(ev_jul, "Jul event")],
1685                ),
1686            ]),
1687        });
1688
1689        let vector_db = Arc::new(TestVectorDb {
1690            collections: HashMap::from([(
1691                TestVectorDb::key("Event", "name"),
1692                vec![SearchResult {
1693                    id: uuid::Uuid::parse_str(ev_old).unwrap(),
1694                    score: 0.88,
1695                    metadata: HashMap::new(),
1696                }],
1697            )]),
1698        });
1699
1700        let llm = Arc::new(TestLlm {
1701            completion_response: String::new(),
1702            interval_response: Some(QueryInterval {
1703                starts_at: None,
1704                ends_at: Some("2021-12-31".to_string()),
1705            }),
1706            fail_structured_output: false,
1707            last_messages: Mutex::new(vec![]),
1708            ..TestLlm::default()
1709        });
1710
1711        let retriever = build_retriever(vector_db, graph_db, llm);
1712
1713        let context = retriever
1714            .get_context("What happened before 2022?", &SearchParams::default())
1715            .await
1716            .unwrap();
1717
1718        assert_eq!(context.len(), 1, "should have 1 item (only the 2020 event)");
1719        assert_eq!(
1720            context[0].payload.get("event_name").and_then(Value::as_str),
1721            Some("Old event")
1722        );
1723    }
1724
1725    #[tokio::test]
1726    async fn get_context_falls_back_when_no_events_in_range() {
1727        let ts_2020 = "aa000000-0000-0000-0000-000000000010";
1728        let ts_2021 = "aa000000-0000-0000-0000-000000000011";
1729        let ev_2020 = "bb000000-0000-0000-0000-000000000010";
1730        let ev_2021 = "bb000000-0000-0000-0000-000000000011";
1731        let entity_a = "cc000000-0000-0000-0000-000000000001";
1732        let entity_b = "cc000000-0000-0000-0000-000000000002";
1733
1734        let graph_db = Arc::new(TestGraphDb {
1735            nodes: vec![
1736                timestamp_node(ts_2020, 1577836800000), // 2020-01-01
1737                timestamp_node(ts_2021, 1609459200000), // 2021-01-01
1738                event_graph_node(ev_2020, "Event 2020"),
1739                event_graph_node(ev_2021, "Event 2021"),
1740                (
1741                    entity_a.to_string(),
1742                    HashMap::from([
1743                        (Cow::Borrowed("id"), json!(entity_a)),
1744                        (Cow::Borrowed("name"), json!("Entity A")),
1745                        (Cow::Borrowed("type"), json!("Entity")),
1746                    ]),
1747                ),
1748                (
1749                    entity_b.to_string(),
1750                    HashMap::from([
1751                        (Cow::Borrowed("id"), json!(entity_b)),
1752                        (Cow::Borrowed("name"), json!("Entity B")),
1753                        (Cow::Borrowed("type"), json!("Entity")),
1754                    ]),
1755                ),
1756            ],
1757            edges: vec![
1758                (
1759                    ev_2020.to_string(),
1760                    ts_2020.to_string(),
1761                    "at".to_string(),
1762                    HashMap::new(),
1763                ),
1764                (
1765                    ev_2021.to_string(),
1766                    ts_2021.to_string(),
1767                    "at".to_string(),
1768                    HashMap::new(),
1769                ),
1770                (
1771                    entity_a.to_string(),
1772                    entity_b.to_string(),
1773                    "connected_to".to_string(),
1774                    HashMap::new(),
1775                ),
1776            ],
1777            neighbors: HashMap::from([
1778                (
1779                    ts_2020.to_string(),
1780                    vec![event_node_data(ev_2020, "Event 2020")],
1781                ),
1782                (
1783                    ts_2021.to_string(),
1784                    vec![event_node_data(ev_2021, "Event 2021")],
1785                ),
1786            ]),
1787        });
1788
1789        let vector_db = Arc::new(TestVectorDb {
1790            collections: HashMap::from([(
1791                TestVectorDb::key("Entity", "name"),
1792                vec![SearchResult {
1793                    id: uuid::Uuid::parse_str(entity_a).unwrap(),
1794                    score: 0.9,
1795                    metadata: HashMap::new(),
1796                }],
1797            )]),
1798        });
1799
1800        let llm = Arc::new(TestLlm {
1801            completion_response: String::new(),
1802            interval_response: Some(QueryInterval {
1803                starts_at: Some("2030".to_string()),
1804                ends_at: Some("2031".to_string()),
1805            }),
1806            fail_structured_output: false,
1807            last_messages: Mutex::new(vec![]),
1808            ..TestLlm::default()
1809        });
1810
1811        let retriever = build_retriever(vector_db, graph_db, llm);
1812
1813        let context = retriever
1814            .get_context("What happened in 2030?", &SearchParams::default())
1815            .await
1816            .unwrap();
1817
1818        // Falls back to graph triplet search; context should have "relationship" in payload
1819        assert!(
1820            !context.is_empty(),
1821            "fallback should produce at least one result"
1822        );
1823        assert!(
1824            context
1825                .iter()
1826                .any(|item| item.payload.get("relationship").is_some()),
1827            "fallback context items should have 'relationship' in payload"
1828        );
1829        assert!(
1830            context
1831                .iter()
1832                .all(|item| item.payload.get("event_id").is_none()),
1833            "fallback context should NOT have 'event_id' (those are temporal items)"
1834        );
1835    }
1836
1837    #[tokio::test]
1838    async fn get_context_on_empty_graph() {
1839        let graph_db = Arc::new(TestGraphDb {
1840            nodes: vec![],
1841            edges: vec![],
1842            neighbors: HashMap::new(),
1843        });
1844
1845        let vector_db = Arc::new(TestVectorDb {
1846            collections: HashMap::new(),
1847        });
1848
1849        // LLM won't be called since the graph is empty
1850        let llm = Arc::new(TestLlm {
1851            completion_response: String::new(),
1852            interval_response: None,
1853            fail_structured_output: false,
1854            last_messages: Mutex::new(vec![]),
1855            ..TestLlm::default()
1856        });
1857
1858        let retriever = build_retriever(vector_db, graph_db, llm);
1859
1860        let context = retriever
1861            .get_context("Anything?", &SearchParams::default())
1862            .await
1863            .unwrap();
1864
1865        assert!(
1866            context.is_empty(),
1867            "empty graph should return empty context"
1868        );
1869    }
1870
1871    #[tokio::test]
1872    async fn get_context_respects_top_k() {
1873        let mut nodes = Vec::new();
1874        let mut edges = Vec::new();
1875        let mut neighbors = HashMap::new();
1876        let mut vector_results = Vec::new();
1877
1878        for i in 1..=5 {
1879            let ts_id = format!("aa000000-0000-0000-0000-0000000000{i:02}");
1880            let ev_id = format!("bb000000-0000-0000-0000-0000000000{i:02}");
1881            let ev_name = format!("Event {i}");
1882            // All in 2024: Jan through May
1883            let time_ms = 1704067200000_i64 + (i as i64 - 1) * 30 * 86400 * 1000;
1884
1885            nodes.push(timestamp_node(&ts_id, time_ms));
1886            nodes.push(event_graph_node(&ev_id, &ev_name));
1887            edges.push((
1888                ev_id.clone(),
1889                ts_id.clone(),
1890                "at".to_string(),
1891                HashMap::new(),
1892            ));
1893            neighbors.insert(ts_id, vec![event_node_data(&ev_id, &ev_name)]);
1894            vector_results.push(SearchResult {
1895                id: uuid::Uuid::parse_str(&ev_id).unwrap(),
1896                score: 0.9 - (i as f32 * 0.01),
1897                metadata: HashMap::new(),
1898            });
1899        }
1900
1901        let graph_db = Arc::new(TestGraphDb {
1902            nodes,
1903            edges,
1904            neighbors,
1905        });
1906
1907        let vector_db = Arc::new(TestVectorDb {
1908            collections: HashMap::from([(TestVectorDb::key("Event", "name"), vector_results)]),
1909        });
1910
1911        let llm = Arc::new(TestLlm {
1912            completion_response: String::new(),
1913            interval_response: Some(QueryInterval {
1914                starts_at: Some("2024-01-01".to_string()),
1915                ends_at: Some("2024-12-31".to_string()),
1916            }),
1917            fail_structured_output: false,
1918            last_messages: Mutex::new(vec![]),
1919            ..TestLlm::default()
1920        });
1921
1922        let retriever = build_retriever(vector_db, graph_db, llm);
1923
1924        let params = SearchParams {
1925            top_k: Some(2),
1926            ..Default::default()
1927        };
1928
1929        let context = retriever
1930            .get_context("What happened in 2024?", &params)
1931            .await
1932            .unwrap();
1933
1934        assert_eq!(context.len(), 2, "top_k=2 should limit results to 2 items");
1935    }
1936
1937    #[tokio::test]
1938    async fn get_context_2hop_interval_traversal() {
1939        let ts1 = "aa000000-0000-0000-0000-000000000020";
1940        let ts2 = "aa000000-0000-0000-0000-000000000021";
1941        let interval_id = "ii000000-0000-0000-0000-000000000001";
1942        let event_id = "bb000000-0000-0000-0000-000000000020";
1943
1944        let interval_node_data: NodeData = HashMap::from([
1945            (Cow::Borrowed("id"), json!(interval_id)),
1946            (Cow::Borrowed("type"), json!("Interval")),
1947            (Cow::Borrowed("name"), json!("Feb-Mar 2024")),
1948        ]);
1949
1950        let graph_db = Arc::new(TestGraphDb {
1951            nodes: vec![
1952                timestamp_node(ts1, 1706745600000), // 2024-02-01
1953                timestamp_node(ts2, 1709251200000), // 2024-03-01
1954                (interval_id.to_string(), interval_node_data.clone()),
1955                event_graph_node(event_id, "Team Meeting"),
1956            ],
1957            edges: vec![(
1958                event_id.to_string(),
1959                interval_id.to_string(),
1960                "during".to_string(),
1961                HashMap::new(),
1962            )],
1963            neighbors: HashMap::from([
1964                // Timestamp T1 -> Interval (1st hop)
1965                (ts1.to_string(), vec![interval_node_data.clone()]),
1966                // Timestamp T2 -> Interval (1st hop)
1967                (ts2.to_string(), vec![interval_node_data]),
1968                // Interval -> Event (2nd hop)
1969                (
1970                    interval_id.to_string(),
1971                    vec![event_node_data(event_id, "Team Meeting")],
1972                ),
1973            ]),
1974        });
1975
1976        let vector_db = Arc::new(TestVectorDb {
1977            collections: HashMap::from([(
1978                TestVectorDb::key("Event", "name"),
1979                vec![SearchResult {
1980                    id: uuid::Uuid::parse_str(event_id).unwrap(),
1981                    score: 0.92,
1982                    metadata: HashMap::new(),
1983                }],
1984            )]),
1985        });
1986
1987        let llm = Arc::new(TestLlm {
1988            completion_response: String::new(),
1989            interval_response: Some(QueryInterval {
1990                starts_at: Some("2024-02".to_string()),
1991                ends_at: Some("2024-03".to_string()),
1992            }),
1993            fail_structured_output: false,
1994            last_messages: Mutex::new(vec![]),
1995            ..TestLlm::default()
1996        });
1997
1998        let retriever = build_retriever(vector_db, graph_db, llm);
1999
2000        let context = retriever
2001            .get_context(
2002                "What meetings happened in Feb-Mar 2024?",
2003                &SearchParams::default(),
2004            )
2005            .await
2006            .unwrap();
2007
2008        assert_eq!(
2009            context.len(),
2010            1,
2011            "should find 1 event via 2-hop traversal (Timestamp -> Interval -> Event)"
2012        );
2013        assert_eq!(
2014            context[0].payload.get("event_name").and_then(Value::as_str),
2015            Some("Team Meeting")
2016        );
2017    }
2018
2019    // -----------------------------------------------------------------------
2020    // Phase 4 — get_completion unit tests
2021    // -----------------------------------------------------------------------
2022
2023    fn default_session() -> SessionContext {
2024        SessionContext {
2025            session_id: None,
2026            history: vec![],
2027            formatted_history: String::new(),
2028            graph_context: None,
2029        }
2030    }
2031
2032    fn make_event_context() -> Vec<SearchItem> {
2033        vec![
2034            SearchItem {
2035                id: None,
2036                score: Some(0.9),
2037                payload: json!({
2038                    "event_id": "evt-1",
2039                    "event_name": "Product Launch",
2040                    "event_description": "Launched the new product",
2041                    "event_time": "2024-03-15",
2042                }),
2043            },
2044            SearchItem {
2045                id: None,
2046                score: Some(0.7),
2047                payload: json!({
2048                    "event_id": "evt-2",
2049                    "event_name": "Quarterly Review",
2050                    "event_description": "Reviewed Q1 results",
2051                    "event_time": "2024-04-01",
2052                }),
2053            },
2054        ]
2055    }
2056
2057    fn simple_retriever(llm: Arc<TestLlm>) -> TemporalRetriever {
2058        let vector_db = Arc::new(TestVectorDb {
2059            collections: HashMap::new(),
2060        });
2061        let embedding_engine = Arc::new(TestEmbeddingEngine);
2062        let graph_db = Arc::new(TestGraphDb {
2063            nodes: vec![],
2064            edges: vec![],
2065            neighbors: HashMap::new(),
2066        });
2067
2068        TemporalRetriever::new(
2069            vector_db,
2070            embedding_engine,
2071            graph_db,
2072            llm,
2073            Some(5),
2074            Some(10),
2075            Some(0.0),
2076            None,
2077            None,
2078            None,
2079            None,
2080            None,
2081        )
2082    }
2083
2084    #[tokio::test]
2085    async fn get_completion_generates_text_from_context() {
2086        let llm = Arc::new(TestLlm {
2087            completion_response: "The product was launched in March 2024.".to_string(),
2088            last_messages: Mutex::new(vec![]),
2089            last_structured_messages: Mutex::new(vec![]),
2090            ..TestLlm::default()
2091        });
2092
2093        let retriever = simple_retriever(llm);
2094        let context = make_event_context();
2095        let session = default_session();
2096
2097        let output = retriever
2098            .get_completion(
2099                "What happened in 2024?",
2100                Some(context),
2101                &session,
2102                &SearchParams::default(),
2103            )
2104            .await
2105            .unwrap();
2106
2107        match output {
2108            SearchOutput::Text(text) => {
2109                assert_eq!(text, "The product was launched in March 2024.");
2110            }
2111            other => panic!("Expected SearchOutput::Text, got {other:?}"),
2112        }
2113    }
2114
2115    #[tokio::test]
2116    async fn get_completion_with_provided_context_passes_to_llm() {
2117        let llm = Arc::new(TestLlm {
2118            completion_response: "completion result".to_string(),
2119            last_messages: Mutex::new(vec![]),
2120            last_structured_messages: Mutex::new(vec![]),
2121            ..TestLlm::default()
2122        });
2123
2124        let retriever = simple_retriever(Arc::clone(&llm));
2125        let context = make_event_context();
2126        let session = default_session();
2127
2128        retriever
2129            .get_completion(
2130                "What happened in 2024?",
2131                Some(context),
2132                &session,
2133                &SearchParams::default(),
2134            )
2135            .await
2136            .unwrap();
2137
2138        let messages = llm.last_messages.lock().unwrap();
2139        assert_eq!(messages.len(), 2, "Expected system + user messages");
2140
2141        // The user prompt should contain the temporal context text.
2142        let user_msg = &messages[1].content;
2143        assert!(
2144            user_msg.contains("Product Launch"),
2145            "User prompt should contain event name from context"
2146        );
2147        assert!(
2148            user_msg.contains("Quarterly Review"),
2149            "User prompt should contain second event name from context"
2150        );
2151    }
2152
2153    #[tokio::test]
2154    async fn get_completion_without_context_calls_get_context() {
2155        // Setup a graph with temporal data so get_context can produce context.
2156        let launch_event_ms: i64 = 1710460800000; // 2024-03-15 UTC
2157        let ts_id = "ts-aaa";
2158        let event_id = "ev-111";
2159
2160        let vector_db = Arc::new(TestVectorDb {
2161            collections: HashMap::from([(
2162                TestVectorDb::key("Entity", "name"),
2163                vec![SearchResult {
2164                    id: Uuid::new_v4(),
2165                    score: 0.8,
2166                    metadata: HashMap::new(),
2167                }],
2168            )]),
2169        });
2170
2171        let embedding_engine = Arc::new(TestEmbeddingEngine);
2172        let graph_db = Arc::new(TestGraphDb {
2173            nodes: vec![
2174                timestamp_node(ts_id, launch_event_ms),
2175                event_graph_node(event_id, "Launch"),
2176            ],
2177            edges: vec![],
2178            neighbors: HashMap::from([(
2179                ts_id.to_string(),
2180                vec![event_node_data(event_id, "Launch")],
2181            )]),
2182        });
2183
2184        let llm = Arc::new(TestLlm {
2185            completion_response: "answer from internal context".to_string(),
2186            interval_response: Some(QueryInterval {
2187                starts_at: Some("2024-01-01".to_string()),
2188                ends_at: Some("2024-12-31".to_string()),
2189            }),
2190            last_messages: Mutex::new(vec![]),
2191            last_structured_messages: Mutex::new(vec![]),
2192            ..TestLlm::default()
2193        });
2194
2195        let retriever = TemporalRetriever::new(
2196            vector_db,
2197            embedding_engine,
2198            graph_db,
2199            llm.clone(),
2200            Some(5),
2201            Some(10),
2202            Some(0.0),
2203            None,
2204            None,
2205            None,
2206            None,
2207            None,
2208        );
2209
2210        let session = default_session();
2211
2212        let output = retriever
2213            .get_completion(
2214                "What happened in 2024?",
2215                None,
2216                &session,
2217                &SearchParams::default(),
2218            )
2219            .await
2220            .unwrap();
2221
2222        match output {
2223            SearchOutput::Text(text) => {
2224                assert_eq!(text, "answer from internal context");
2225            }
2226            other => panic!("Expected SearchOutput::Text, got {other:?}"),
2227        }
2228
2229        // Verify that the LLM's generate was called with messages containing context.
2230        let messages = llm.last_messages.lock().unwrap();
2231        assert!(!messages.is_empty(), "LLM generate should have been called");
2232        let user_msg = &messages[1].content;
2233        assert!(
2234            user_msg.contains("Launch"),
2235            "User prompt should reference the event from internal context"
2236        );
2237    }
2238
2239    #[tokio::test]
2240    async fn get_completion_with_response_schema() {
2241        let structured_value = json!({
2242            "answer": "The product launched in 2024",
2243            "confidence": 0.95
2244        });
2245
2246        let llm = Arc::new(TestLlm {
2247            completion_response: "should not be used".to_string(),
2248            structured_completion_response: Mutex::new(Some(structured_value.clone())),
2249            last_messages: Mutex::new(vec![]),
2250            last_structured_messages: Mutex::new(vec![]),
2251            ..TestLlm::default()
2252        });
2253
2254        let retriever = simple_retriever(llm);
2255        let context = make_event_context();
2256        let session = default_session();
2257
2258        let params = SearchParams {
2259            response_schema: Some(json!({
2260                "type": "object",
2261                "properties": {
2262                    "answer": { "type": "string" },
2263                    "confidence": { "type": "number" }
2264                }
2265            })),
2266            ..SearchParams::default()
2267        };
2268
2269        let output = retriever
2270            .get_completion("What happened in 2024?", Some(context), &session, &params)
2271            .await
2272            .unwrap();
2273
2274        match output {
2275            SearchOutput::Structured(value) => {
2276                assert_eq!(value, structured_value);
2277            }
2278            other => panic!("Expected SearchOutput::Structured, got {other:?}"),
2279        }
2280    }
2281
2282    #[tokio::test]
2283    async fn get_completion_includes_session_history() {
2284        let llm = Arc::new(TestLlm {
2285            completion_response: "history-aware answer".to_string(),
2286            last_messages: Mutex::new(vec![]),
2287            last_structured_messages: Mutex::new(vec![]),
2288            ..TestLlm::default()
2289        });
2290
2291        let retriever = simple_retriever(Arc::clone(&llm));
2292        let context = make_event_context();
2293
2294        let session = SessionContext {
2295            session_id: Some("sess-1".to_string()),
2296            history: vec![
2297                Message::user("Previous question?".to_string()),
2298                Message::assistant("Previous answer.".to_string()),
2299            ],
2300            formatted_history: "Q: Previous question?\nA: Previous answer.".to_string(),
2301            graph_context: None,
2302        };
2303
2304        retriever
2305            .get_completion(
2306                "Follow-up question?",
2307                Some(context),
2308                &session,
2309                &SearchParams::default(),
2310            )
2311            .await
2312            .unwrap();
2313
2314        let messages = llm.last_messages.lock().unwrap();
2315        assert_eq!(messages.len(), 2, "Expected system + user messages");
2316
2317        // The system prompt should contain session history (prepended via TASK:).
2318        let system_msg = &messages[0].content;
2319        assert!(
2320            system_msg.contains("Previous question?"),
2321            "System prompt should include session history"
2322        );
2323        assert!(
2324            system_msg.contains("Previous answer."),
2325            "System prompt should include session history answer"
2326        );
2327    }
2328
2329    // -----------------------------------------------------------------------
2330    // Phase 5 — rank_temporal_events unit tests
2331    // -----------------------------------------------------------------------
2332
2333    fn make_ranked_edge(source_id: &str, target_id: &str, score: f32) -> RankedGraphEdge {
2334        RankedGraphEdge {
2335            source_id: source_id.to_string(),
2336            target_id: target_id.to_string(),
2337            relationship_name: "related_to".to_string(),
2338            score,
2339            source_name: format!("Source-{source_id}"),
2340            target_name: format!("Target-{target_id}"),
2341            dataset_id: None,
2342            source_text: None,
2343            target_text: None,
2344            source_description: None,
2345            target_description: None,
2346        }
2347    }
2348
2349    fn ranking_retriever() -> TemporalRetriever {
2350        let vector_db = Arc::new(TestVectorDb {
2351            collections: HashMap::from([(
2352                TestVectorDb::key("Event", "name"),
2353                vec![
2354                    SearchResult {
2355                        id: Uuid::parse_str("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa").unwrap(),
2356                        score: 0.9,
2357                        metadata: HashMap::new(),
2358                    },
2359                    SearchResult {
2360                        id: Uuid::parse_str("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb").unwrap(),
2361                        score: 0.5,
2362                        metadata: HashMap::new(),
2363                    },
2364                    SearchResult {
2365                        id: Uuid::parse_str("cccccccc-cccc-cccc-cccc-cccccccccccc").unwrap(),
2366                        score: 0.3,
2367                        metadata: HashMap::new(),
2368                    },
2369                ],
2370            )]),
2371        });
2372
2373        let embedding_engine = Arc::new(TestEmbeddingEngine);
2374        let graph_db = Arc::new(TestGraphDb {
2375            nodes: vec![],
2376            edges: vec![],
2377            neighbors: HashMap::new(),
2378        });
2379
2380        let llm = Arc::new(TestLlm::default());
2381
2382        TemporalRetriever::new(
2383            vector_db,
2384            embedding_engine,
2385            graph_db,
2386            llm,
2387            Some(5),
2388            Some(10),
2389            Some(0.0),
2390            None,
2391            None,
2392            None,
2393            None,
2394            None,
2395        )
2396    }
2397
2398    #[tokio::test]
2399    async fn rank_sorts_by_combined_score() {
2400        let retriever = ranking_retriever();
2401
2402        let event_ids: HashSet<String> = [
2403            "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
2404            "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
2405            "cccccccc-cccc-cccc-cccc-cccccccccccc",
2406        ]
2407        .iter()
2408        .map(|s| s.to_string())
2409        .collect();
2410
2411        let ranked_edges = vec![
2412            make_ranked_edge("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "other-node", 0.8),
2413            make_ranked_edge("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "other-node", 0.4),
2414            make_ranked_edge("cccccccc-cccc-cccc-cccc-cccccccccccc", "other-node", 0.2),
2415        ];
2416
2417        let ranked = retriever
2418            .rank_temporal_events("test query", &event_ids, &ranked_edges)
2419            .await
2420            .unwrap();
2421
2422        assert_eq!(ranked.len(), 3);
2423        assert_eq!(ranked[0].0, "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa");
2424        assert_eq!(ranked[1].0, "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb");
2425        assert_eq!(ranked[2].0, "cccccccc-cccc-cccc-cccc-cccccccccccc");
2426
2427        // Verify descending order.
2428        assert!(ranked[0].1 >= ranked[1].1);
2429        assert!(ranked[1].1 >= ranked[2].1);
2430    }
2431
2432    #[tokio::test]
2433    async fn rank_events_not_in_vector_get_default_score() {
2434        let retriever = ranking_retriever();
2435
2436        let event_ids: HashSet<String> = [
2437            "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
2438            "dddddddd-dddd-dddd-dddd-dddddddddddd", // Not in vector DB
2439        ]
2440        .iter()
2441        .map(|s| s.to_string())
2442        .collect();
2443
2444        let ranked_edges = vec![make_ranked_edge(
2445            "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
2446            "other-node",
2447            0.8,
2448        )];
2449
2450        let ranked = retriever
2451            .rank_temporal_events("test query", &event_ids, &ranked_edges)
2452            .await
2453            .unwrap();
2454
2455        assert_eq!(ranked.len(), 2);
2456
2457        let unknown_event = ranked
2458            .iter()
2459            .find(|(id, _)| id == "dddddddd-dddd-dddd-dddd-dddddddddddd")
2460            .unwrap();
2461        assert!(
2462            unknown_event.1.abs() < f32::EPSILON,
2463            "Unknown event should have score 0.0, got {}",
2464            unknown_event.1
2465        );
2466
2467        let known_event = ranked
2468            .iter()
2469            .find(|(id, _)| id == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
2470            .unwrap();
2471        assert!(
2472            known_event.1 > unknown_event.1,
2473            "Known event should have higher score"
2474        );
2475    }
2476
2477    #[tokio::test]
2478    async fn rank_empty_vector_results() {
2479        let vector_db = Arc::new(TestVectorDb {
2480            collections: HashMap::new(),
2481        });
2482        let embedding_engine = Arc::new(TestEmbeddingEngine);
2483        let graph_db = Arc::new(TestGraphDb {
2484            nodes: vec![],
2485            edges: vec![],
2486            neighbors: HashMap::new(),
2487        });
2488        let llm = Arc::new(TestLlm::default());
2489
2490        let retriever = TemporalRetriever::new(
2491            vector_db,
2492            embedding_engine,
2493            graph_db,
2494            llm,
2495            Some(5),
2496            Some(10),
2497            Some(0.0),
2498            None,
2499            None,
2500            None,
2501            None,
2502            None,
2503        );
2504
2505        let event_ids: HashSet<String> = ["ev-1", "ev-2"].iter().map(|s| s.to_string()).collect();
2506
2507        let ranked_edges = vec![
2508            make_ranked_edge("ev-1", "other", 0.6),
2509            make_ranked_edge("other", "ev-2", 0.3),
2510        ];
2511
2512        let ranked = retriever
2513            .rank_temporal_events("query", &event_ids, &ranked_edges)
2514            .await
2515            .unwrap();
2516
2517        assert_eq!(ranked.len(), 2);
2518        assert_eq!(ranked[0].0, "ev-1");
2519        assert!((ranked[0].1 - 0.6).abs() < f32::EPSILON);
2520        assert_eq!(ranked[1].0, "ev-2");
2521        assert!((ranked[1].1 - 0.3).abs() < f32::EPSILON);
2522    }
2523
2524    #[tokio::test]
2525    async fn rank_empty_event_ids() {
2526        let retriever = ranking_retriever();
2527
2528        let event_ids: HashSet<String> = HashSet::new();
2529        let ranked_edges = vec![make_ranked_edge("some-node", "other-node", 0.5)];
2530
2531        let ranked = retriever
2532            .rank_temporal_events("query", &event_ids, &ranked_edges)
2533            .await
2534            .unwrap();
2535
2536        assert!(
2537            ranked.is_empty(),
2538            "Empty event_ids should yield empty result"
2539        );
2540    }
2541
2542    #[tokio::test]
2543    async fn rank_mismatched_vector_ids() {
2544        let vector_db = Arc::new(TestVectorDb {
2545            collections: HashMap::from([(
2546                TestVectorDb::key("Event", "name"),
2547                vec![SearchResult {
2548                    id: Uuid::parse_str("ffffffff-ffff-ffff-ffff-ffffffffffff").unwrap(),
2549                    score: 0.99,
2550                    metadata: HashMap::new(),
2551                }],
2552            )]),
2553        });
2554        let embedding_engine = Arc::new(TestEmbeddingEngine);
2555        let graph_db = Arc::new(TestGraphDb {
2556            nodes: vec![],
2557            edges: vec![],
2558            neighbors: HashMap::new(),
2559        });
2560        let llm = Arc::new(TestLlm::default());
2561
2562        let retriever = TemporalRetriever::new(
2563            vector_db,
2564            embedding_engine,
2565            graph_db,
2566            llm,
2567            Some(5),
2568            Some(10),
2569            Some(0.0),
2570            None,
2571            None,
2572            None,
2573            None,
2574            None,
2575        );
2576
2577        let event_ids: HashSet<String> =
2578            ["ev-abc", "ev-def"].iter().map(|s| s.to_string()).collect();
2579
2580        let ranked_edges = vec![make_ranked_edge("ev-abc", "something", 0.4)];
2581
2582        let ranked = retriever
2583            .rank_temporal_events("query", &event_ids, &ranked_edges)
2584            .await
2585            .unwrap();
2586
2587        assert_eq!(ranked.len(), 2);
2588        let ev_abc = ranked.iter().find(|(id, _)| id == "ev-abc").unwrap();
2589        assert!((ev_abc.1 - 0.4).abs() < f32::EPSILON);
2590
2591        let ev_def = ranked.iter().find(|(id, _)| id == "ev-def").unwrap();
2592        assert!(ev_def.1.abs() < f32::EPSILON);
2593    }
2594
2595    // -----------------------------------------------------------------------
2596    // Phase 6 — temporal_context_to_text unit tests
2597    // -----------------------------------------------------------------------
2598
2599    #[test]
2600    fn context_to_text_formats_event_items() {
2601        let context = vec![
2602            SearchItem {
2603                id: None,
2604                score: Some(0.9),
2605                payload: json!({
2606                    "event_id": "evt-1",
2607                    "event_name": "Product Launch",
2608                    "event_description": "Launched the new product",
2609                    "event_time": "2024-03-15",
2610                }),
2611            },
2612            SearchItem {
2613                id: None,
2614                score: Some(0.7),
2615                payload: json!({
2616                    "event_id": "evt-2",
2617                    "event_name": "Quarterly Review",
2618                    "event_description": "Reviewed Q1 results",
2619                    "event_time": "2024-04-01",
2620                }),
2621            },
2622        ];
2623
2624        let text = TemporalRetriever::temporal_context_to_text(&context);
2625        let lines: Vec<&str> = text.lines().collect();
2626
2627        assert_eq!(lines.len(), 2);
2628        assert_eq!(
2629            lines[0],
2630            "Product Launch (2024-03-15): Launched the new product"
2631        );
2632        assert_eq!(
2633            lines[1],
2634            "Quarterly Review (2024-04-01): Reviewed Q1 results"
2635        );
2636    }
2637
2638    #[test]
2639    fn context_to_text_formats_triplet_items() {
2640        let context = vec![
2641            SearchItem {
2642                id: None,
2643                score: Some(0.8),
2644                payload: json!({
2645                    "source_name": "Alice",
2646                    "target_name": "Bob",
2647                    "relationship": "knows",
2648                }),
2649            },
2650            SearchItem {
2651                id: None,
2652                score: Some(0.6),
2653                payload: json!({
2654                    "source_name": "Company X",
2655                    "target_name": "Product Y",
2656                    "relationship": "produces",
2657                }),
2658            },
2659        ];
2660
2661        let text = TemporalRetriever::temporal_context_to_text(&context);
2662        let lines: Vec<&str> = text.lines().collect();
2663
2664        assert_eq!(lines.len(), 2);
2665        assert_eq!(lines[0], "Alice -[knows]-> Bob");
2666        assert_eq!(lines[1], "Company X -[produces]-> Product Y");
2667    }
2668
2669    #[test]
2670    fn context_to_text_empty_context() {
2671        let context: Vec<SearchItem> = vec![];
2672        let text = TemporalRetriever::temporal_context_to_text(&context);
2673        assert_eq!(text, "");
2674    }
2675
2676    #[test]
2677    fn context_to_text_missing_fields_use_defaults() {
2678        let context = vec![SearchItem {
2679            id: None,
2680            score: Some(0.5),
2681            payload: json!({
2682                "event_id": "evt-bare",
2683            }),
2684        }];
2685
2686        let text = TemporalRetriever::temporal_context_to_text(&context);
2687        assert_eq!(text, "Unnamed event (unknown time): No description");
2688    }
2689
2690    #[test]
2691    fn context_to_text_mixed_items() {
2692        let context = vec![
2693            SearchItem {
2694                id: None,
2695                score: Some(0.9),
2696                payload: json!({
2697                    "event_id": "evt-1",
2698                    "event_name": "Conference",
2699                    "event_description": "Annual tech conference",
2700                    "event_time": "2024-06-15",
2701                }),
2702            },
2703            SearchItem {
2704                id: None,
2705                score: Some(0.7),
2706                payload: json!({
2707                    "source_name": "Speaker",
2708                    "target_name": "Conference",
2709                    "relationship": "presents_at",
2710                }),
2711            },
2712        ];
2713
2714        let text = TemporalRetriever::temporal_context_to_text(&context);
2715        let lines: Vec<&str> = text.lines().collect();
2716
2717        assert_eq!(lines.len(), 2);
2718        assert_eq!(lines[0], "Conference (2024-06-15): Annual tech conference");
2719        assert_eq!(lines[1], "Speaker -[presents_at]-> Conference");
2720    }
2721}