1use std::sync::Arc;
2
3use cognee_llm::{GenerationOptions, Llm, LlmExt};
4use cognee_models::{CognifyInterval, RawExtractedTimestamp, TemporalEvent, to_cognify_timestamp};
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7
8use crate::CognifyError;
9
/// Prompt text embedded at compile time from `prompts/temporal_event_extraction.txt`.
///
/// `TemporalEventExtractor::extract_events` passes it, together with the chunk
/// text, to every structured-output request it makes.
const TEMPORAL_EVENT_EXTRACTION_PROMPT: &str =
    include_str!("prompts/temporal_event_extraction.txt");
12
// One event exactly as the LLM returned it, before `convert_raw_event` turns it
// into a `TemporalEvent`.
//
// NOTE: plain `//` comments are used on purpose. `schemars` copies `///` doc
// comments into the generated JSON schema as descriptions, so doc comments
// here would change the schema derived for the structured-output request.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct RawEvent {
    // Event name; `convert_raw_event` drops the event when this is blank.
    pub name: String,
    // Optional free-text description, copied to the converted event as-is.
    pub description: Option<String>,
    // Start of an interval, or the single timestamp of a point-in-time event.
    pub time_from: Option<RawExtractedTimestamp>,
    // End of an interval; used as the single timestamp when `time_from` is absent.
    pub time_to: Option<RawExtractedTimestamp>,
    // Optional location, copied to the converted event as-is.
    pub location: Option<String>,
}
22
// Top-level shape requested from the LLM: a list of raw events.
//
// NOTE: plain `//` comments are used on purpose. `schemars` copies `///` doc
// comments into the generated JSON schema as descriptions.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct RawEventsOutput {
    // `serde(default)`: a response without an `events` key deserializes to an
    // empty list instead of failing.
    #[serde(default)]
    pub events: Vec<RawEvent>,
}
29
/// Extracts [`TemporalEvent`]s from text by asking an LLM for structured output.
pub struct TemporalEventExtractor {
    /// Shared handle to the LLM backend used for every extraction request.
    pub(crate) llm: Arc<dyn Llm>,
}
33
34impl TemporalEventExtractor {
35 pub fn new(llm: Arc<dyn Llm>) -> Self {
36 Self { llm }
37 }
38
39 pub async fn extract_events(
43 &self,
44 chunk_text: &str,
45 ) -> Result<Vec<TemporalEvent>, CognifyError> {
46 let options = GenerationOptions {
50 temperature: Some(0.1),
51 max_tokens: None,
52 ..Default::default()
53 };
54
55 let raw: RawEventsOutput = match self
56 .llm
57 .create_structured_output::<RawEventsOutput>(
58 chunk_text,
59 TEMPORAL_EVENT_EXTRACTION_PROMPT,
60 Some(options),
61 )
62 .await
63 {
64 Ok(v) => v,
65 Err(e) => {
66 tracing::warn!("Temporal event extraction failed: {e}");
67 return Ok(vec![]);
68 }
69 };
70
71 let events = raw
72 .events
73 .into_iter()
74 .filter_map(convert_raw_event)
75 .collect();
76
77 Ok(events)
78 }
79}
80
81fn convert_raw_event(raw: RawEvent) -> Option<TemporalEvent> {
82 if raw.name.trim().is_empty() {
83 return None;
84 }
85
86 let (at, during) = match (raw.time_from, raw.time_to) {
88 (Some(from), Some(to)) => {
89 let ts_from = to_cognify_timestamp(from)?;
90 let ts_to = to_cognify_timestamp(to)?;
91 (
92 None,
93 Some(CognifyInterval {
94 time_from: ts_from,
95 time_to: ts_to,
96 }),
97 )
98 }
99 (Some(from), None) => (to_cognify_timestamp(from), None),
100 (None, Some(to)) => (to_cognify_timestamp(to), None),
101 (None, None) => (None, None),
102 };
103
104 Some(TemporalEvent {
105 name: raw.name,
106 description: raw.description,
107 location: raw.location,
108 at,
109 during,
110 attributes: vec![], })
112}
113
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use cognee_llm::error::{LlmError, LlmResult};
    use cognee_llm::types::{GenerationOptions, GenerationResponse, Message};
    use serde_json::Value;

    /// Stand-in [`Llm`] that replays a single canned structured-output result.
    struct MockLlm {
        /// `Ok` is the JSON handed back to the caller; `Err` is the message
        /// wrapped in `LlmError::ApiError`.
        response: Result<Value, String>,
    }

    impl MockLlm {
        /// Mock whose structured-output call succeeds with `value`.
        fn with_json(value: Value) -> Self {
            Self { response: Ok(value) }
        }

        /// Mock whose structured-output call fails with `msg`.
        fn with_error(msg: &str) -> Self {
            Self {
                response: Err(msg.to_string()),
            }
        }
    }

    #[async_trait]
    impl Llm for MockLlm {
        async fn generate(
            &self,
            _messages: Vec<Message>,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<GenerationResponse> {
            unimplemented!("not used in event_extractor tests")
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<Message>,
            _json_schema: &Value,
            _options: Option<GenerationOptions>,
        ) -> LlmResult<Value> {
            self.response.clone().map_err(LlmError::ApiError)
        }

        fn model(&self) -> &str {
            "mock-llm"
        }
    }

    /// Wraps `mock` in the extractor under test.
    fn extractor_for(mock: MockLlm) -> TemporalEventExtractor {
        TemporalEventExtractor::new(Arc::new(mock))
    }

    #[tokio::test]
    async fn extract_events_happy_path() {
        let moon_landing = serde_json::json!({
            "name": "Moon Landing",
            "description": "First humans on the Moon",
            "time_from": { "year": 1969, "month": 7, "day": 20, "hour": 20, "minute": 17, "second": 0 },
            "time_to": null,
            "location": "Sea of Tranquility"
        });
        let world_war = serde_json::json!({
            "name": "World War II",
            "description": "Global conflict",
            "time_from": { "year": 1939, "month": 9, "day": 1 },
            "time_to": { "year": 1945, "month": 9, "day": 2 },
            "location": null
        });
        let payload = serde_json::json!({ "events": [moon_landing, world_war] });

        let extractor = extractor_for(MockLlm::with_json(payload));
        let events = extractor.extract_events("some text").await.unwrap();
        assert_eq!(events.len(), 2);
        let (point, span) = (&events[0], &events[1]);

        // Only `time_from` was supplied, so this is a point-in-time event.
        assert_eq!(point.name, "Moon Landing");
        assert_eq!(point.description.as_deref(), Some("First humans on the Moon"));
        assert_eq!(point.location.as_deref(), Some("Sea of Tranquility"));
        assert!(point.at.is_some(), "point-in-time event should have `at`");
        assert!(point.during.is_none());
        let when = point.at.as_ref().unwrap();
        assert_eq!(when.year, 1969);
        assert_eq!(when.month, 7);
        assert_eq!(when.day, 20);

        // Both bounds were supplied, so this is an interval event.
        assert_eq!(span.name, "World War II");
        assert!(span.at.is_none());
        assert!(span.during.is_some(), "interval event should have `during`");
        let range = span.during.as_ref().unwrap();
        assert_eq!(range.time_from.year, 1939);
        assert_eq!(range.time_to.year, 1945);
    }

    #[tokio::test]
    async fn extract_events_returns_empty_on_llm_error() {
        let extractor = extractor_for(MockLlm::with_error("service unavailable"));

        let events = extractor.extract_events("some text").await.unwrap();
        assert!(events.is_empty(), "LLM error should yield empty vec");
    }

    #[tokio::test]
    async fn extract_events_filters_empty_names() {
        let nameless = serde_json::json!({
            "name": "",
            "description": null,
            "time_from": null,
            "time_to": null,
            "location": null
        });
        let named = serde_json::json!({
            "name": "Valid Event",
            "description": "Has a name",
            "time_from": { "year": 2020, "month": 1, "day": 1 },
            "time_to": null,
            "location": null
        });
        let payload = serde_json::json!({ "events": [nameless, named] });

        let extractor = extractor_for(MockLlm::with_json(payload));
        let events = extractor.extract_events("some text").await.unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].name, "Valid Event");
    }

    #[test]
    fn convert_raw_event_point_in_time() {
        let launch_time = RawExtractedTimestamp {
            year: 2024,
            month: 3,
            day: 15,
            hour: 10,
            minute: 30,
            second: 0,
        };
        let raw = RawEvent {
            name: "Launch".to_string(),
            description: Some("Rocket launch".to_string()),
            time_from: Some(launch_time),
            time_to: None,
            location: Some("Cape Canaveral".to_string()),
        };

        let event = convert_raw_event(raw).unwrap();
        assert_eq!(event.name, "Launch");
        assert!(event.at.is_some());
        assert!(event.during.is_none());
        let when = event.at.unwrap();
        assert_eq!(when.year, 2024);
        assert_eq!(when.month, 3);
        assert_eq!(when.day, 15);
        assert_eq!(when.hour, 10);
        assert_eq!(when.minute, 30);
        assert_eq!(when.timestamp_str, "2024-03-15 10:30:00");
    }

    #[test]
    fn convert_raw_event_interval() {
        let opens = RawExtractedTimestamp {
            year: 2025,
            month: 6,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
        };
        let closes = RawExtractedTimestamp {
            year: 2025,
            month: 6,
            day: 5,
            hour: 0,
            minute: 0,
            second: 0,
        };
        let raw = RawEvent {
            name: "Conference".to_string(),
            description: None,
            time_from: Some(opens),
            time_to: Some(closes),
            location: None,
        };

        let event = convert_raw_event(raw).unwrap();
        assert_eq!(event.name, "Conference");
        assert!(event.at.is_none());
        assert!(event.during.is_some());
        let range = event.during.unwrap();
        assert_eq!(range.time_from.year, 2025);
        assert_eq!(range.time_from.day, 1);
        assert_eq!(range.time_to.day, 5);
    }

    #[test]
    fn convert_raw_event_invalid_timestamp() {
        // Case 1: a lone invalid timestamp keeps the event but leaves it undated.
        let lone_invalid = RawEvent {
            name: "Bad Date".to_string(),
            description: None,
            time_from: Some(RawExtractedTimestamp {
                year: 2024,
                month: 13,
                day: 1,
                hour: 0,
                minute: 0,
                second: 0,
            }),
            time_to: None,
            location: None,
        };

        let event =
            convert_raw_event(lone_invalid).expect("event with invalid timestamp is still returned");
        assert!(
            event.at.is_none(),
            "Invalid month should cause `at` to be None"
        );
        assert!(event.during.is_none());

        // Case 2: an invalid bound inside an interval rejects the whole event.
        let invalid_start = RawExtractedTimestamp {
            year: 2024,
            month: 13,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
        };
        let valid_end = RawExtractedTimestamp {
            year: 2024,
            month: 6,
            day: 1,
            hour: 0,
            minute: 0,
            second: 0,
        };
        let broken_interval = RawEvent {
            name: "Bad Interval".to_string(),
            description: None,
            time_from: Some(invalid_start),
            time_to: Some(valid_end),
            location: None,
        };

        let outcome = convert_raw_event(broken_interval);
        assert!(
            outcome.is_none(),
            "Invalid month in interval should cause convert_raw_event to return None"
        );
    }
}