// cognee_cognify/temporal_extraction/event_extractor.rs

1use std::sync::Arc;
2
3use cognee_llm::{GenerationOptions, Llm, LlmExt};
4use cognee_models::{CognifyInterval, RawExtractedTimestamp, TemporalEvent, to_cognify_timestamp};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7
8use crate::CognifyError;
9
/// Prompt passed to the LLM alongside each text chunk when extracting events.
///
/// Embedded at compile time via `include_str!`, so the path
/// `prompts/temporal_event_extraction.txt` is resolved relative to this source
/// file and a missing file is a build error rather than a runtime one.
const TEMPORAL_EVENT_EXTRACTION_PROMPT: &str =
    include_str!("prompts/temporal_event_extraction.txt");
12
13/// Raw event as returned by the LLM.
14#[derive(Debug, Serialize, Deserialize, JsonSchema)]
15struct RawEvent {
16    pub name: String,
17    pub description: Option<String>,
18    pub time_from: Option<RawExtractedTimestamp>,
19    pub time_to: Option<RawExtractedTimestamp>,
20    pub location: Option<String>,
21}
22
23/// Object wrapper for structured-output APIs that require a root JSON object.
24#[derive(Debug, Serialize, Deserialize, JsonSchema)]
25struct RawEventsOutput {
26    #[serde(default)]
27    pub events: Vec<RawEvent>,
28}
29
/// Extracts [`TemporalEvent`]s from text chunks by asking an LLM for
/// structured output.
pub struct TemporalEventExtractor {
    // Shared LLM backend used for the structured-output extraction call.
    pub(crate) llm: Arc<dyn Llm>,
}
33
34impl TemporalEventExtractor {
35    pub fn new(llm: Arc<dyn Llm>) -> Self {
36        Self { llm }
37    }
38
39    /// Extract events from a single chunk of text.
40    /// Returns an empty Vec (with a warning log) on LLM or parse errors
41    /// — extraction failures must not abort the cognify pipeline.
42    pub async fn extract_events(
43        &self,
44        chunk_text: &str,
45    ) -> Result<Vec<TemporalEvent>, CognifyError> {
46        // Python parity: `acreate_structured_output` passes no output cap on
47        // extraction, so responses use the model's full default budget. A small
48        // max_tokens truncates large event lists mid-JSON. Leave it None.
49        let options = GenerationOptions {
50            temperature: Some(0.1),
51            max_tokens: None,
52            ..Default::default()
53        };
54
55        let raw: RawEventsOutput = match self
56            .llm
57            .create_structured_output::<RawEventsOutput>(
58                chunk_text,
59                TEMPORAL_EVENT_EXTRACTION_PROMPT,
60                Some(options),
61            )
62            .await
63        {
64            Ok(v) => v,
65            Err(e) => {
66                tracing::warn!("Temporal event extraction failed: {e}");
67                return Ok(vec![]);
68            }
69        };
70
71        let events = raw
72            .events
73            .into_iter()
74            .filter_map(convert_raw_event)
75            .collect();
76
77        Ok(events)
78    }
79}
80
81fn convert_raw_event(raw: RawEvent) -> Option<TemporalEvent> {
82    if raw.name.trim().is_empty() {
83        return None;
84    }
85
86    // If both bounds are present, build an Interval instead of a single point.
87    let (at, during) = match (raw.time_from, raw.time_to) {
88        (Some(from), Some(to)) => {
89            let ts_from = to_cognify_timestamp(from)?;
90            let ts_to = to_cognify_timestamp(to)?;
91            (
92                None,
93                Some(CognifyInterval {
94                    time_from: ts_from,
95                    time_to: ts_to,
96                }),
97            )
98        }
99        (Some(from), None) => (to_cognify_timestamp(from), None),
100        (None, Some(to)) => (to_cognify_timestamp(to), None),
101        (None, None) => (None, None),
102    };
103
104    Some(TemporalEvent {
105        name: raw.name,
106        description: raw.description,
107        location: raw.location,
108        at,
109        during,
110        attributes: vec![], // populated by Phase 4
111    })
112}
113
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use cognee_llm::error::{LlmError, LlmResult};
    use cognee_llm::types::{GenerationOptions, GenerationResponse, Message};
    use serde_json::Value;

    /// Mock LLM that returns a pre-configured JSON value from
    /// `create_structured_output_with_messages_raw`.
    struct MockLlm {
        response: Result<Value, String>,
    }

    impl MockLlm {
        fn with_json(value: Value) -> Self {
            Self {
                response: Ok(value),
            }
        }

        fn with_error(msg: &str) -> Self {
            Self {
                response: Err(msg.to_string()),
            }
        }
    }

    #[async_trait]
    impl Llm for MockLlm {
        async fn generate(
            &self,
            _messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            unimplemented!("not used in event_extractor tests")
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<Value> {
            match &self.response {
                Ok(v) => Ok(v.clone()),
                Err(msg) => Err(LlmError::ApiError(msg.clone())),
            }
        }

        fn model(&self) -> &str {
            "mock-llm"
        }
    }

    #[tokio::test]
    async fn extract_events_happy_path() {
        // Mock returns two events: one point-in-time, one interval.
        let json = serde_json::json!({
            "events": [
                {
                    "name": "Moon Landing",
                    "description": "First humans on the Moon",
                    "time_from": { "year": 1969, "month": 7, "day": 20, "hour": 20, "minute": 17, "second": 0 },
                    "time_to": null,
                    "location": "Sea of Tranquility"
                },
                {
                    "name": "World War II",
                    "description": "Global conflict",
                    "time_from": { "year": 1939, "month": 9, "day": 1 },
                    "time_to": { "year": 1945, "month": 9, "day": 2 },
                    "location": null
                }
            ]
        });

        let llm = Arc::new(MockLlm::with_json(json));
        let extractor = TemporalEventExtractor::new(llm);

        let events = extractor.extract_events("some text").await.unwrap();
        assert_eq!(events.len(), 2);

        // First event: point-in-time (only time_from, no time_to).
        let e0 = &events[0];
        assert_eq!(e0.name, "Moon Landing");
        assert_eq!(e0.description.as_deref(), Some("First humans on the Moon"));
        assert_eq!(e0.location.as_deref(), Some("Sea of Tranquility"));
        assert!(e0.at.is_some(), "point-in-time event should have `at`");
        assert!(e0.during.is_none());
        let ts = e0.at.as_ref().unwrap();
        assert_eq!(ts.year, 1969);
        assert_eq!(ts.month, 7);
        assert_eq!(ts.day, 20);

        // Second event: interval (both time_from and time_to).
        let e1 = &events[1];
        assert_eq!(e1.name, "World War II");
        assert!(e1.at.is_none());
        assert!(e1.during.is_some(), "interval event should have `during`");
        let interval = e1.during.as_ref().unwrap();
        assert_eq!(interval.time_from.year, 1939);
        assert_eq!(interval.time_to.year, 1945);
    }

    #[tokio::test]
    async fn extract_events_returns_empty_on_llm_error() {
        let llm = Arc::new(MockLlm::with_error("service unavailable"));
        let extractor = TemporalEventExtractor::new(llm);

        let events = extractor.extract_events("some text").await.unwrap();
        assert!(events.is_empty(), "LLM error should yield empty vec");
    }

    #[tokio::test]
    async fn extract_events_filters_empty_names() {
        let json = serde_json::json!({
            "events": [
                {
                    "name": "",
                    "description": null,
                    "time_from": null,
                    "time_to": null,
                    "location": null
                },
                {
                    "name": "Valid Event",
                    "description": "Has a name",
                    "time_from": { "year": 2020, "month": 1, "day": 1 },
                    "time_to": null,
                    "location": null
                }
            ]
        });

        let llm = Arc::new(MockLlm::with_json(json));
        let extractor = TemporalEventExtractor::new(llm);

        let events = extractor.extract_events("some text").await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].name, "Valid Event");
    }

    #[test]
    fn convert_raw_event_point_in_time() {
        let raw = RawEvent {
            name: "Launch".to_string(),
            description: Some("Rocket launch".to_string()),
            time_from: Some(RawExtractedTimestamp {
                year: 2024,
                month: 3,
                day: 15,
                hour: 10,
                minute: 30,
                second: 0,
            }),
            time_to: None,
            location: Some("Cape Canaveral".to_string()),
        };

        let event = convert_raw_event(raw).unwrap();
        assert_eq!(event.name, "Launch");
        assert!(event.at.is_some());
        assert!(event.during.is_none());
        let ts = event.at.unwrap();
        assert_eq!(ts.year, 2024);
        assert_eq!(ts.month, 3);
        assert_eq!(ts.day, 15);
        assert_eq!(ts.hour, 10);
        assert_eq!(ts.minute, 30);
        assert_eq!(ts.timestamp_str, "2024-03-15 10:30:00");
    }

    #[test]
    fn convert_raw_event_interval() {
        let raw = RawEvent {
            name: "Conference".to_string(),
            description: None,
            time_from: Some(RawExtractedTimestamp {
                year: 2025,
                month: 6,
                day: 1,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            time_to: Some(RawExtractedTimestamp {
                year: 2025,
                month: 6,
                day: 5,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            location: None,
        };

        let event = convert_raw_event(raw).unwrap();
        assert_eq!(event.name, "Conference");
        assert!(event.at.is_none());
        assert!(event.during.is_some());
        let interval = event.during.unwrap();
        assert_eq!(interval.time_from.year, 2025);
        assert_eq!(interval.time_from.day, 1);
        assert_eq!(interval.time_to.day, 5);
    }

    #[test]
    fn convert_raw_event_invalid_timestamp() {
        // Month 13 is invalid — to_cognify_timestamp returns None.
        // For a point-in-time case (only time_from), the event is still
        // returned but with at: None and during: None.
        let raw = RawEvent {
            name: "Bad Date".to_string(),
            description: None,
            time_from: Some(RawExtractedTimestamp {
                year: 2024,
                month: 13,
                day: 1,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            time_to: None,
            location: None,
        };

        let event = convert_raw_event(raw).expect("event with invalid timestamp is still returned");
        assert!(
            event.at.is_none(),
            "Invalid month should cause `at` to be None"
        );
        assert!(event.during.is_none());

        // For an interval case, if time_from is invalid the entire interval
        // is dropped — convert_raw_event returns None because `?` propagates
        // the None from to_cognify_timestamp inside the (Some, Some) branch.
        let raw_interval = RawEvent {
            name: "Bad Interval".to_string(),
            description: None,
            time_from: Some(RawExtractedTimestamp {
                year: 2024,
                month: 13,
                day: 1,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            time_to: Some(RawExtractedTimestamp {
                year: 2024,
                month: 6,
                day: 1,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            location: None,
        };

        let result = convert_raw_event(raw_interval);
        assert!(
            result.is_none(),
            "Invalid month in interval should cause convert_raw_event to return None"
        );
    }

    #[test]
    fn convert_raw_event_invalid_time_to_drops_interval_event() {
        // Mirror of the case above: a valid start but an invalid end bound
        // must also discard the whole event via `?`.
        let raw = RawEvent {
            name: "Bad End".to_string(),
            description: None,
            time_from: Some(RawExtractedTimestamp {
                year: 2024,
                month: 6,
                day: 1,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            time_to: Some(RawExtractedTimestamp {
                year: 2024,
                month: 13,
                day: 1,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            location: None,
        };

        assert!(
            convert_raw_event(raw).is_none(),
            "Invalid time_to in interval should cause convert_raw_event to return None"
        );
    }

    #[test]
    fn convert_raw_event_rejects_whitespace_only_name() {
        // The name check trims first, so whitespace-only names are treated
        // exactly like empty ones.
        let raw = RawEvent {
            name: "  \t\n ".to_string(),
            description: Some("Name is only whitespace".to_string()),
            time_from: None,
            time_to: None,
            location: None,
        };

        assert!(
            convert_raw_event(raw).is_none(),
            "whitespace-only names should be filtered like empty names"
        );
    }

    #[test]
    fn convert_raw_event_only_time_to_is_point_in_time() {
        // A lone `time_to` is used as the event's single point in time.
        let raw = RawEvent {
            name: "Deadline".to_string(),
            description: None,
            time_from: None,
            time_to: Some(RawExtractedTimestamp {
                year: 2025,
                month: 6,
                day: 5,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            location: None,
        };

        let event = convert_raw_event(raw).unwrap();
        assert_eq!(event.name, "Deadline");
        assert!(event.during.is_none());
        let ts = event.at.expect("lone time_to should populate `at`");
        assert_eq!(ts.year, 2025);
        assert_eq!(ts.month, 6);
        assert_eq!(ts.day, 5);
    }

    #[test]
    fn convert_raw_event_without_timestamps_keeps_event() {
        // Undated events are kept; only their time fields are empty.
        let raw = RawEvent {
            name: "Undated".to_string(),
            description: Some("No time information".to_string()),
            time_from: None,
            time_to: None,
            location: Some("Somewhere".to_string()),
        };

        let event = convert_raw_event(raw).unwrap();
        assert_eq!(event.name, "Undated");
        assert_eq!(event.description.as_deref(), Some("No time information"));
        assert_eq!(event.location.as_deref(), Some("Somewhere"));
        assert!(event.at.is_none());
        assert!(event.during.is_none());
        assert!(event.attributes.is_empty());
    }
}