Skip to main content

cognee_cognify/fact_extraction/
extractor.rs

1//! Fact extractor using LLM for knowledge graph extraction.
2//!
3//! Port of Python's cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py
4//! and cognee/tasks/graph/extract_graph_from_data.py
5
6use std::sync::Arc;
7
8use cognee_llm::{GenerationOptions, Llm, LlmExt};
9use tracing::debug;
10
11use super::models::{GraphModel, KnowledgeGraph};
12use crate::error::CognifyError;
13
14/// Default system prompt for knowledge graph extraction.
15///
16/// Vendored byte-for-byte from Python's
17/// `cognee/infrastructure/llm/prompts/generate_graph_prompt.txt` (kept in sync via
18/// the prompt-parity drift guard in the inline `#[cfg(test)]` block below).
19const DEFAULT_GRAPH_PROMPT: &str = include_str!("prompts/generate_graph_prompt.txt");
20
21/// Fact extractor for knowledge graph generation.
22///
23/// Uses an LLM (via the Llm trait) to extract structured facts from text.
24/// Produces a KnowledgeGraph containing nodes (entities) and edges (relationships).
25///
26/// # Example
27/// ```ignore
28/// use cognee_cognify::FactExtractor;
29/// use cognee_llm::OpenAIAdapter;
30/// use std::sync::Arc;
31///
32/// let llm = Arc::new(OpenAIAdapter::new("gpt-4", "sk-...", None)?);
33/// let extractor = FactExtractor::new(llm);
34///
35/// let text = "Alice works at TechCorp in San Francisco.";
36/// let graph = extractor.extract_facts(text, None).await?;
37///
38/// println!("Extracted {} nodes and {} edges", graph.node_count(), graph.edge_count());
39/// ```
40#[derive(Clone)]
41pub struct FactExtractor {
42    llm: Arc<dyn Llm>,
43}
44
45impl FactExtractor {
46    /// Create a new fact extractor with the given LLM.
47    ///
48    /// # Arguments
49    /// * `llm` - An LLM implementation (e.g., OpenAIAdapter, OllamaAdapter)
50    ///
51    /// # Returns
52    /// A new FactExtractor instance
53    pub fn new(llm: Arc<dyn Llm>) -> Self {
54        Self { llm }
55    }
56
57    /// Return the default graph extraction prompt used by `extract_facts`.
58    pub fn default_graph_prompt() -> &'static str {
59        DEFAULT_GRAPH_PROMPT
60    }
61
62    /// Extract a structured model from text via LLM.
63    ///
64    /// Generic counterpart of [`extract_facts`](Self::extract_facts).
65    /// Works with any type implementing [`GraphModel`], which requires
66    /// `Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync`.
67    ///
68    /// The LLM's [`create_structured_output`](cognee_llm::LlmExt::create_structured_output)
69    /// method infers the JSON schema from `M` and deserializes the response
70    /// into the concrete type.
71    ///
72    /// No post-processing is applied; for the default [`KnowledgeGraph`] flow
73    /// with name-fallback logic, use [`extract_facts`](Self::extract_facts).
74    ///
75    /// # Arguments
76    /// * `text` - Input text to extract from
77    /// * `custom_prompt` - Optional custom system prompt (uses [`DEFAULT_GRAPH_PROMPT`] if None)
78    ///
79    /// # Errors
80    /// Returns [`CognifyError::LlmError`] if the LLM call fails
81    pub async fn extract<M: GraphModel>(
82        &self,
83        text: &str,
84        custom_prompt: Option<&str>,
85    ) -> Result<M, CognifyError> {
86        debug!("Extracting model {} from text", std::any::type_name::<M>());
87        let system_prompt = custom_prompt.unwrap_or(DEFAULT_GRAPH_PROMPT);
88
89        let result: M = self
90            .llm
91            .create_structured_output(
92                text,
93                system_prompt,
94                // Python parity: `acreate_structured_output` passes NO
95                // max_tokens/max_completion_tokens to the extraction call, so
96                // the response uses the model's full default output budget. A
97                // small cap here truncates large graphs mid-JSON on dense
98                // chunks, aborting cognify with a deserialization error. Leave
99                // max_tokens as None to match Python (no artificial cap).
100                Some(GenerationOptions {
101                    temperature: Some(0.1),
102                    max_tokens: None,
103                    ..Default::default()
104                }),
105            )
106            .await
107            .map_err(|e| CognifyError::LlmError(e.to_string()))?;
108
109        debug!("Extracted model {}", std::any::type_name::<M>());
110        Ok(result)
111    }
112
113    /// Extract facts (knowledge graph) from text.
114    ///
115    /// Mirrors Python's `extract_content_graph` function.
116    /// Uses the LLM to extract structured Node and Edge objects from the input text.
117    ///
118    /// # Arguments
119    /// * `text` - Input text to extract facts from
120    /// * `custom_prompt` - Optional custom system prompt (uses DEFAULT_GRAPH_PROMPT if None)
121    ///
122    /// # Returns
123    /// A KnowledgeGraph containing extracted nodes and edges
124    ///
125    /// # Errors
126    /// Returns CognifyError::LlmError if the LLM call fails
127    pub async fn extract_facts(
128        &self,
129        text: &str,
130        custom_prompt: Option<&str>,
131    ) -> Result<KnowledgeGraph, CognifyError> {
132        debug!("Extracting facts from text: {}", text);
133
134        let mut graph: KnowledgeGraph = self.extract(text, custom_prompt).await?;
135
136        debug!(
137            "Extracted graph with {} nodes and {} edges",
138            graph.node_count(),
139            graph.edge_count()
140        );
141
142        // Post-processing: ensure every node has a non-empty name (Python compat).
143        // In Python's Node.__init__, name defaults to id when empty.
144        for node in &mut graph.nodes {
145            if node.name.is_empty() {
146                node.name = node.id.clone();
147            }
148        }
149
150        Ok(graph)
151    }
152
153    /// Extract facts from multiple text chunks in parallel.
154    ///
155    /// Mirrors Python's pattern in extract_graph_from_data.py where all chunks
156    /// are processed concurrently using asyncio.gather.
157    ///
158    /// # Arguments
159    /// * `texts` - Slice of text strings to extract facts from
160    /// * `custom_prompt` - Optional custom system prompt
161    ///
162    /// # Returns
163    /// A vector of KnowledgeGraphs, one per input text
164    ///
165    /// # Errors
166    /// Returns CognifyError::LlmError if any LLM call fails
167    pub async fn extract_facts_batch(
168        &self,
169        texts: Vec<String>, // Changed to owned Vec<String> to avoid lifetime issues
170        custom_prompt: Option<String>, // Changed to owned String
171    ) -> Result<Vec<KnowledgeGraph>, CognifyError> {
172        let mut tasks = Vec::new();
173
174        for text in texts {
175            let llm_clone = Arc::clone(&self.llm);
176            let prompt_clone = custom_prompt.clone();
177
178            let task = tokio::spawn(async move {
179                let extractor = FactExtractor { llm: llm_clone };
180                extractor
181                    .extract_facts(&text, prompt_clone.as_deref())
182                    .await
183            });
184
185            tasks.push(task);
186        }
187
188        let results = futures::future::join_all(tasks).await;
189
190        let mut graphs = Vec::new();
191        for result in results {
192            let graph =
193                result.map_err(|e| CognifyError::LlmError(format!("Task join error: {e}")))??;
194            graphs.push(graph);
195        }
196
197        Ok(graphs)
198    }
199
200    /// Get a reference to the underlying LLM.
201    pub fn llm(&self) -> &Arc<dyn Llm> {
202        &self.llm
203    }
204}
205
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::super::models::Node;
    use super::*;

    /// Mock LLM that answers every structured-output request with a fixed JSON
    /// payload. `FactExtractor` never calls `generate`, so it is left unimplemented.
    #[derive(Clone)]
    struct MockLlm {
        model: &'static str,
        payload: serde_json::Value,
    }

    impl MockLlm {
        /// Build a shared mock named `model` that always returns `value` as JSON.
        fn returning<T: serde::Serialize>(model: &'static str, value: &T) -> Arc<Self> {
            Arc::new(Self {
                model,
                payload: serde_json::to_value(value).expect("mock payload must serialize"),
            })
        }
    }

    #[async_trait::async_trait]
    impl Llm for MockLlm {
        async fn generate(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _options: Option<GenerationOptions>,
        ) -> cognee_llm::LlmResult<cognee_llm::GenerationResponse> {
            unimplemented!()
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> cognee_llm::LlmResult<serde_json::Value> {
            Ok(self.payload.clone())
        }

        fn model(&self) -> &str {
            self.model
        }
    }

    /// Shorthand constructor for a graph node.
    fn node(id: &str, name: &str, node_type: &str, description: &str) -> Node {
        Node {
            id: id.to_string(),
            name: name.to_string(),
            node_type: node_type.to_string(),
            description: description.to_string(),
        }
    }

    /// Mock returning one fully-populated node.
    fn single_node_llm() -> Arc<MockLlm> {
        let graph = KnowledgeGraph {
            nodes: vec![node("test_node", "Test Node", "TEST", "A test node")],
            edges: vec![],
        };
        MockLlm::returning("mock", &graph)
    }

    /// Mock returning one node with an empty name (fallback target) and one
    /// with a real name (must be left untouched).
    fn empty_name_llm() -> Arc<MockLlm> {
        let graph = KnowledgeGraph {
            nodes: vec![
                node("alice_johnson", "", "PERSON", "A person"),
                node("techcorp", "TechCorp", "ORGANIZATION", "A company"),
            ],
            edges: vec![],
        };
        MockLlm::returning("mock-empty-name", &graph)
    }

    #[test]
    fn test_fact_extractor_creation() {
        let extractor = FactExtractor::new(single_node_llm());
        assert_eq!(extractor.llm().model(), "mock");
    }

    #[tokio::test]
    async fn test_extract_facts() {
        let extractor = FactExtractor::new(single_node_llm());

        let graph = extractor
            .extract_facts("Test text", None)
            .await
            .expect("extraction should succeed");

        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.nodes[0].id, "test_node");
    }

    #[tokio::test]
    async fn test_empty_node_name_defaults_to_id() {
        let extractor = FactExtractor::new(empty_name_llm());

        let graph = extractor.extract_facts("Test text", None).await.unwrap();

        assert_eq!(graph.node_count(), 2);

        // Node with empty name should have name set to its id
        assert_eq!(graph.nodes[0].id, "alice_johnson");
        assert_eq!(graph.nodes[0].name, "alice_johnson");

        // Node with non-empty name should remain unchanged
        assert_eq!(graph.nodes[1].id, "techcorp");
        assert_eq!(graph.nodes[1].name, "TechCorp");
    }

    // ── Tests for the generic extract<M> method ────────────────────────

    /// A custom graph model used to verify generic extraction.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
    struct CustomEvent {
        event_name: String,
        participants: Vec<String>,
    }

    impl super::super::models::GraphModel for CustomEvent {}

    #[tokio::test]
    async fn test_extract_generic_custom_model() {
        let payload = CustomEvent {
            event_name: "Conference".to_string(),
            participants: vec!["Alice".to_string(), "Bob".to_string()],
        };
        let extractor = FactExtractor::new(MockLlm::returning("mock-custom", &payload));

        let event: CustomEvent = extractor.extract("Test text", None).await.unwrap();
        assert_eq!(event.event_name, "Conference");
        assert_eq!(event.participants, vec!["Alice", "Bob"]);
    }

    #[tokio::test]
    async fn test_extract_generic_knowledge_graph() {
        // Verify that extract::<KnowledgeGraph> works (without post-processing)
        let extractor = FactExtractor::new(empty_name_llm());

        let graph: KnowledgeGraph = extractor.extract("Test text", None).await.unwrap();
        // No post-processing: empty name stays empty (unlike extract_facts)
        assert_eq!(graph.nodes[0].name, "");
    }

    #[tokio::test]
    async fn test_extract_facts_delegates_to_extract() {
        // Verify extract_facts still applies post-processing on top of extract
        let extractor = FactExtractor::new(single_node_llm());

        let via_extract: KnowledgeGraph = extractor.extract("Test text", None).await.unwrap();
        let via_facts = extractor.extract_facts("Test text", None).await.unwrap();

        // Both should get the same node; a non-empty name passes through unchanged.
        assert_eq!(via_extract.node_count(), via_facts.node_count());
        assert_eq!(via_extract.nodes[0].id, via_facts.nodes[0].id);
        assert_eq!(via_extract.nodes[0].name, via_facts.nodes[0].name);
    }

    // ── Tests for extract_facts_batch ──────────────────────────────────

    #[tokio::test]
    async fn test_extract_facts_batch_returns_one_graph_per_text() {
        let extractor = FactExtractor::new(empty_name_llm());

        let graphs = extractor
            .extract_facts_batch(
                vec!["first".to_string(), "second".to_string()],
                Some("custom prompt".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(graphs.len(), 2);
        // Each spawned task goes through extract_facts, so the name fallback applies.
        for graph in &graphs {
            assert_eq!(graph.nodes[0].name, "alice_johnson");
            assert_eq!(graph.nodes[1].name, "TechCorp");
        }
    }

    #[tokio::test]
    async fn test_extract_facts_batch_empty_input() {
        let extractor = FactExtractor::new(single_node_llm());

        let graphs = extractor.extract_facts_batch(Vec::new(), None).await.unwrap();

        assert!(graphs.is_empty());
    }

    #[test]
    fn graph_prompt_matches_vendored_txt() {
        // The public accessor must expose exactly the vendored file. This cannot
        // detect drift from upstream Python on its own (both sides read the same
        // file); the content markers below are what distinguish the Python prompt.
        // Manual re-sync: cp /tmp/cognee-python/cognee/infrastructure/llm/prompts/generate_graph_prompt.txt \
        //   crates/cognify/src/fact_extraction/prompts/generate_graph_prompt.txt
        let vendored = include_str!("prompts/generate_graph_prompt.txt");
        assert_eq!(
            FactExtractor::default_graph_prompt(),
            vendored,
            "default_graph_prompt() drifted from vendored .txt"
        );
        assert!(!vendored.trim().is_empty(), "vendored prompt is empty");
        // Python markers the old Rust prompt did NOT have:
        assert!(
            vendored.contains("Every edge should include a description"),
            "edge-description paragraph missing — not the Python prompt"
        );
        assert!(
            vendored.contains(r#"label it as **"Person"**"#),
            "Title-case 'Person' missing — UPPERCASE Rust prompt regressed"
        );
        assert!(
            !vendored.contains("the entity type label in uppercase"),
            "old UPPERCASE-forcing line still present"
        );
    }
}
433            !vendored.contains("the entity type label in uppercase"),
434            "old UPPERCASE-forcing line still present"
435        );
436    }
437}