cognee_cognify/fact_extraction/
extractor.rs1use std::sync::Arc;
7
8use cognee_llm::{GenerationOptions, Llm, LlmExt};
9use tracing::debug;
10
11use super::models::{GraphModel, KnowledgeGraph};
12use crate::error::CognifyError;
13
/// Built-in system prompt for graph extraction, embedded at compile time from
/// `prompts/generate_graph_prompt.txt`.
///
/// Used as the system prompt whenever a caller does not supply a custom one.
const DEFAULT_GRAPH_PROMPT: &str = include_str!("prompts/generate_graph_prompt.txt");
20
/// Extracts structured facts (for example a [`KnowledgeGraph`]) from text by
/// asking an LLM for structured output.
///
/// Cloning only clones the `Arc` handle, so every clone talks to the same LLM.
#[derive(Clone)]
pub struct FactExtractor {
    /// Shared handle to the LLM that serves every extraction request.
    llm: Arc<dyn Llm>,
}
44
45impl FactExtractor {
46 pub fn new(llm: Arc<dyn Llm>) -> Self {
54 Self { llm }
55 }
56
57 pub fn default_graph_prompt() -> &'static str {
59 DEFAULT_GRAPH_PROMPT
60 }
61
62 pub async fn extract<M: GraphModel>(
82 &self,
83 text: &str,
84 custom_prompt: Option<&str>,
85 ) -> Result<M, CognifyError> {
86 debug!("Extracting model {} from text", std::any::type_name::<M>());
87 let system_prompt = custom_prompt.unwrap_or(DEFAULT_GRAPH_PROMPT);
88
89 let result: M = self
90 .llm
91 .create_structured_output(
92 text,
93 system_prompt,
94 Some(GenerationOptions {
101 temperature: Some(0.1),
102 max_tokens: None,
103 ..Default::default()
104 }),
105 )
106 .await
107 .map_err(|e| CognifyError::LlmError(e.to_string()))?;
108
109 debug!("Extracted model {}", std::any::type_name::<M>());
110 Ok(result)
111 }
112
113 pub async fn extract_facts(
128 &self,
129 text: &str,
130 custom_prompt: Option<&str>,
131 ) -> Result<KnowledgeGraph, CognifyError> {
132 debug!("Extracting facts from text: {}", text);
133
134 let mut graph: KnowledgeGraph = self.extract(text, custom_prompt).await?;
135
136 debug!(
137 "Extracted graph with {} nodes and {} edges",
138 graph.node_count(),
139 graph.edge_count()
140 );
141
142 for node in &mut graph.nodes {
145 if node.name.is_empty() {
146 node.name = node.id.clone();
147 }
148 }
149
150 Ok(graph)
151 }
152
153 pub async fn extract_facts_batch(
168 &self,
169 texts: Vec<String>, custom_prompt: Option<String>, ) -> Result<Vec<KnowledgeGraph>, CognifyError> {
172 let mut tasks = Vec::new();
173
174 for text in texts {
175 let llm_clone = Arc::clone(&self.llm);
176 let prompt_clone = custom_prompt.clone();
177
178 let task = tokio::spawn(async move {
179 let extractor = FactExtractor { llm: llm_clone };
180 extractor
181 .extract_facts(&text, prompt_clone.as_deref())
182 .await
183 });
184
185 tasks.push(task);
186 }
187
188 let results = futures::future::join_all(tasks).await;
189
190 let mut graphs = Vec::new();
191 for result in results {
192 let graph =
193 result.map_err(|e| CognifyError::LlmError(format!("Task join error: {e}")))??;
194 graphs.push(graph);
195 }
196
197 Ok(graphs)
198 }
199
200 pub fn llm(&self) -> &Arc<dyn Llm> {
202 &self.llm
203 }
204}
205
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::super::models::Node;
    use super::*;

    /// Test double that answers every raw structured-output request with the
    /// same pre-serialized JSON payload and reports a fixed model name.
    ///
    /// Replaces the per-scenario mock types: scenarios differ only in the
    /// payload and the model name, so both are plain fields here.
    #[derive(Clone)]
    struct CannedLlm {
        model: &'static str,
        payload: serde_json::Value,
    }

    #[async_trait::async_trait]
    impl Llm for CannedLlm {
        async fn generate(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _options: Option<GenerationOptions>,
        ) -> cognee_llm::LlmResult<cognee_llm::GenerationResponse> {
            // None of the tests below go through plain text generation.
            unimplemented!()
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> cognee_llm::LlmResult<serde_json::Value> {
            Ok(self.payload.clone())
        }

        fn model(&self) -> &str {
            self.model
        }
    }

    /// Minimal non-graph model used to exercise the generic `extract` path.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomEvent {
        event_name: String,
        participants: Vec<String>,
    }

    impl super::super::models::GraphModel for CustomEvent {}

    /// Builds an extractor whose LLM always answers with `payload` serialized to JSON.
    fn extractor_returning<T: serde::Serialize>(model: &'static str, payload: &T) -> FactExtractor {
        let llm = CannedLlm {
            model,
            payload: serde_json::to_value(payload).unwrap(),
        };
        FactExtractor::new(Arc::new(llm))
    }

    /// Shorthand for building a graph node from string slices.
    fn node(id: &str, name: &str, node_type: &str, description: &str) -> Node {
        Node {
            id: id.to_string(),
            name: name.to_string(),
            node_type: node_type.to_string(),
            description: description.to_string(),
        }
    }

    /// Extractor whose LLM returns a graph with one fully populated node.
    fn single_node_extractor() -> FactExtractor {
        let graph = KnowledgeGraph {
            nodes: vec![node("test_node", "Test Node", "TEST", "A test node")],
            edges: vec![],
        };
        extractor_returning("mock", &graph)
    }

    /// Extractor whose LLM returns two nodes, the first of which has an empty name.
    fn empty_name_extractor() -> FactExtractor {
        let graph = KnowledgeGraph {
            nodes: vec![
                node("alice_johnson", "", "PERSON", "A person"),
                node("techcorp", "TechCorp", "ORGANIZATION", "A company"),
            ],
            edges: vec![],
        };
        extractor_returning("mock-empty-name", &graph)
    }

    /// Extractor whose LLM returns a serialized `CustomEvent`.
    fn custom_event_extractor() -> FactExtractor {
        let event = CustomEvent {
            event_name: "Conference".to_string(),
            participants: vec!["Alice".to_string(), "Bob".to_string()],
        };
        extractor_returning("mock-custom", &event)
    }

    #[tokio::test]
    async fn test_fact_extractor_creation() {
        let extractor = single_node_extractor();
        assert_eq!(extractor.llm().model(), "mock");
    }

    #[tokio::test]
    async fn test_extract_facts() {
        let extractor = single_node_extractor();

        let graph = extractor
            .extract_facts("Test text", None)
            .await
            .expect("extraction with the canned LLM should succeed");

        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.nodes[0].id, "test_node");
    }

    #[tokio::test]
    async fn test_empty_node_name_defaults_to_id() {
        let extractor = empty_name_extractor();

        let graph = extractor.extract_facts("Test text", None).await.unwrap();

        assert_eq!(graph.node_count(), 2);

        // The empty name is replaced by the node id.
        assert_eq!(graph.nodes[0].id, "alice_johnson");
        assert_eq!(graph.nodes[0].name, "alice_johnson");

        // A non-empty name is left untouched.
        assert_eq!(graph.nodes[1].id, "techcorp");
        assert_eq!(graph.nodes[1].name, "TechCorp");
    }

    #[tokio::test]
    async fn test_extract_generic_custom_model() {
        let extractor = custom_event_extractor();

        let event: CustomEvent = extractor.extract("Test text", None).await.unwrap();

        assert_eq!(event.event_name, "Conference");
        assert_eq!(event.participants, vec!["Alice", "Bob"]);
    }

    #[tokio::test]
    async fn test_extract_generic_knowledge_graph() {
        let extractor = empty_name_extractor();

        // The generic path must NOT apply the name-defaulting clean-up.
        let graph: KnowledgeGraph = extractor.extract("Test text", None).await.unwrap();
        assert_eq!(graph.nodes[0].name, "");
    }

    #[tokio::test]
    async fn test_extract_facts_delegates_to_extract() {
        let extractor = single_node_extractor();

        let via_extract: KnowledgeGraph = extractor.extract("Test text", None).await.unwrap();
        let via_facts = extractor.extract_facts("Test text", None).await.unwrap();

        assert_eq!(via_extract.node_count(), via_facts.node_count());
        assert_eq!(via_extract.nodes[0].id, via_facts.nodes[0].id);
    }

    #[tokio::test]
    async fn test_extract_facts_batch_processes_every_text() {
        let extractor = empty_name_extractor();
        let texts = vec![
            "first".to_string(),
            "second".to_string(),
            "third".to_string(),
        ];

        let graphs = extractor.extract_facts_batch(texts, None).await.unwrap();

        // One graph per input, each with the name-defaulting clean-up applied.
        assert_eq!(graphs.len(), 3);
        for graph in &graphs {
            assert_eq!(graph.node_count(), 2);
            assert_eq!(graph.nodes[0].name, "alice_johnson");
            assert_eq!(graph.nodes[1].name, "TechCorp");
        }
    }

    #[tokio::test]
    async fn test_extract_facts_batch_empty_input() {
        let extractor = single_node_extractor();

        let graphs = extractor
            .extract_facts_batch(Vec::new(), None)
            .await
            .unwrap();

        assert!(graphs.is_empty());
    }

    #[test]
    fn graph_prompt_matches_vendored_txt() {
        let vendored = include_str!("prompts/generate_graph_prompt.txt");
        assert_eq!(
            DEFAULT_GRAPH_PROMPT, vendored,
            "const drifted from vendored .txt"
        );
        assert!(
            vendored.contains("Every edge should include a description"),
            "edge-description paragraph missing — not the Python prompt"
        );
        assert!(
            vendored.contains(r#"label it as **"Person"**"#),
            "Title-case 'Person' missing — UPPERCASE Rust prompt regressed"
        );
        assert!(
            !vendored.contains("the entity type label in uppercase"),
            "old UPPERCASE-forcing line still present"
        );
    }
}