Skip to main content

cognee_cognify/summarization/extractor.rs

1//! Summary extractor using LLM for text summarization.
2//!
3//! Port of Python's:
4//! - cognee/infrastructure/llm/extraction/extract_summary.py
5//! - cognee/tasks/summarization/summarize_text.py
6
7use std::sync::Arc;
8
9use cognee_llm::{GenerationOptions, Llm, LlmExt};
10use cognee_models::DocumentChunk;
11
12use super::models::{SummarizedContent, TextSummary};
13use crate::error::CognifyError;
14
15/// Default summarization options shared by both the typed and dynamic paths.
16fn default_summary_options() -> GenerationOptions {
17    // Python parity: `acreate_structured_output` passes no output cap on the
18    // summarization call (the ≤200-token limit is enforced via the prompt, not
19    // an API max_tokens). A hard cap here can truncate the structured JSON
20    // response mid-object. Leave max_tokens as None to match Python.
21    GenerationOptions {
22        temperature: Some(0.3),
23        max_tokens: None,
24        ..Default::default()
25    }
26}
27
28/// Default system prompt for text summarization.
29///
30/// Vendored byte-for-byte from Python's
31/// `cognee/infrastructure/llm/prompts/summarize_content.txt` (structured
32/// categories + ordered facts, ≤200 tokens). Kept in sync via the prompt-parity
33/// drift guard.
34const DEFAULT_SUMMARY_PROMPT: &str = include_str!("prompts/summarize_content.txt");
35
36/// Summary extractor for text chunks.
37///
38/// Uses an LLM (via the Llm trait) to generate hierarchical summaries from text chunks.
39/// Produces TextSummary objects linked to source chunks via deterministic UUIDs.
40///
41/// # Example
42/// ```ignore
43/// use cognee_cognify::SummaryExtractor;
44/// use cognee_llm::OpenAIAdapter;
45/// use std::sync::Arc;
46///
47/// let llm = Arc::new(OpenAIAdapter::new("gpt-4", "sk-...", None)?);
48/// let extractor = SummaryExtractor::new(llm);
49///
50/// let text = "Long article text here...";
51/// let summary = extractor.extract_summary(text, None).await?;
52///
53/// println!("Summary: {}", summary.summary);
54/// ```
55#[derive(Clone)]
56pub struct SummaryExtractor {
57    llm: Arc<dyn Llm>,
58    /// When `Some`, the dynamic schema path is taken instead of the typed
59    /// `SummarizedContent` path (Python `summarization_model` parity).
60    summary_schema: Option<serde_json::Value>,
61}
62
63impl SummaryExtractor {
64    /// Create a new summary extractor using the built-in `SummarizedContent` schema.
65    pub fn new(llm: Arc<dyn Llm>) -> Self {
66        Self {
67            llm,
68            summary_schema: None,
69        }
70    }
71
72    /// Create a new summary extractor with an optional custom output schema.
73    ///
74    /// When `schema` is `Some`, the LLM is called via the dynamic raw path and
75    /// the `summary` string field is extracted from the response. When `None`,
76    /// the built-in typed `SummarizedContent` path is used.
77    pub fn new_with_schema(llm: Arc<dyn Llm>, schema: Option<serde_json::Value>) -> Self {
78        Self {
79            llm,
80            summary_schema: schema,
81        }
82    }
83
84    /// Extract a summary from text.
85    ///
86    /// When `summary_schema` is `None`, uses the typed `SummarizedContent` path.
87    /// When `Some`, calls the LLM with the custom schema and extracts the
88    /// `summary` string field from the raw response (Python parity).
89    pub async fn extract_summary(
90        &self,
91        text: &str,
92        custom_prompt: Option<&str>,
93    ) -> Result<SummarizedContent, CognifyError> {
94        let system_prompt = custom_prompt.unwrap_or(DEFAULT_SUMMARY_PROMPT);
95        let options = Some(default_summary_options());
96
97        match &self.summary_schema {
98            None => {
99                let summarized: SummarizedContent = self
100                    .llm
101                    .create_structured_output(text, system_prompt, options)
102                    .await
103                    .map_err(|e| CognifyError::LlmError(e.to_string()))?;
104                Ok(summarized)
105            }
106            Some(schema) => {
107                let raw: serde_json::Value = self
108                    .llm
109                    .create_structured_output_raw(text, system_prompt, schema, options)
110                    .await
111                    .map_err(|e| CognifyError::LlmError(e.to_string()))?;
112                let summary = raw.get("summary").and_then(|v| v.as_str()).ok_or_else(|| {
113                    CognifyError::LlmError(
114                        "summary_schema output missing string `summary` field".to_string(),
115                    )
116                })?;
117                Ok(SummarizedContent {
118                    summary: summary.to_string(),
119                    description: String::new(),
120                })
121            }
122        }
123    }
124
125    /// Summarize multiple text chunks in parallel.
126    ///
127    /// # Arguments
128    /// * `chunks` - Slice of DocumentChunks to summarize
129    /// * `custom_prompt` - Optional custom system prompt
130    ///
131    /// # Returns
132    /// A vector of TextSummary objects, one per input chunk
133    ///
134    /// # Errors
135    /// Returns CognifyError::LlmError if any LLM call fails
136    pub async fn summarize_chunks(
137        &self,
138        chunks: &[DocumentChunk],
139        custom_prompt: Option<String>,
140    ) -> Result<Vec<TextSummary>, CognifyError> {
141        if chunks.is_empty() {
142            return Ok(vec![]);
143        }
144
145        let mut tasks = Vec::new();
146
147        for chunk in chunks {
148            let llm_clone = Arc::clone(&self.llm);
149            let schema_clone = self.summary_schema.clone();
150            let prompt_clone = custom_prompt.clone();
151            let text = chunk.text.clone();
152
153            let task = tokio::spawn(async move {
154                let extractor = SummaryExtractor {
155                    llm: llm_clone,
156                    summary_schema: schema_clone,
157                };
158                extractor
159                    .extract_summary(&text, prompt_clone.as_deref())
160                    .await
161            });
162
163            tasks.push(task);
164        }
165
166        let results = futures::future::join_all(tasks).await;
167
168        // Get model name from LLM
169        let model_name = self.llm.model().to_string();
170
171        let mut summaries = Vec::new();
172        for (chunk_index, result) in results.into_iter().enumerate() {
173            let chunk = &chunks[chunk_index];
174            let summarized =
175                result.map_err(|e| CognifyError::LlmError(format!("Task join error: {e}")))??;
176
177            let text_summary =
178                TextSummary::from_summarized_content(chunk.base.id, summarized, model_name.clone());
179
180            summaries.push(text_summary);
181        }
182
183        Ok(summaries)
184    }
185
186    /// Get a reference to the underlying LLM.
187    pub fn llm(&self) -> &Arc<dyn Llm> {
188        &self.llm
189    }
190}
191
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::validate_summary_schema;

    // Structural checks only. Anything that needs a real LLM lives in the
    // integration tests under `tests/`.

    /// Schema with a string-typed `summary` property, which is the shape the
    /// custom-schema path expects.
    fn string_summary_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": { "summary": { "type": "string" } }
        })
    }

    #[test]
    #[allow(
        clippy::const_is_empty,
        reason = "intentional sanity check that the const is non-empty"
    )]
    fn test_default_prompt_not_empty() {
        let prompt = DEFAULT_SUMMARY_PROMPT;
        assert!(!prompt.is_empty());
        assert!(prompt.contains("Summarize the chunk for retrieval"));
    }

    #[test]
    fn summary_prompt_matches_vendored_txt() {
        // NOTE(review): `DEFAULT_SUMMARY_PROMPT` is itself defined with this
        // same `include_str!`, so the equality check below cannot fail. The
        // marker checks are what actually catch an edited prompt file.
        let vendored = include_str!("prompts/summarize_content.txt");
        assert_eq!(
            DEFAULT_SUMMARY_PROMPT, vendored,
            "const drifted from vendored .txt"
        );

        let required_markers = [
            (
                "Output two sections only",
                "Python two-section structure marker missing",
            ),
            ("Max 200 tokens", "token-limit marker missing"),
        ];
        for (marker, failure) in required_markers {
            assert!(vendored.contains(marker), "{failure}");
        }
    }

    #[test]
    fn new_returns_no_schema() {
        let extractor = SummaryExtractor::new(Arc::new(NoopLlm));
        assert!(extractor.summary_schema.is_none());
    }

    #[test]
    fn new_with_schema_stores_schema() {
        let expected = string_summary_schema();
        let extractor = SummaryExtractor::new_with_schema(Arc::new(NoopLlm), Some(expected.clone()));
        assert_eq!(extractor.summary_schema, Some(expected));
    }

    #[test]
    fn validate_summary_schema_accepts_valid() {
        let schema = string_summary_schema();
        assert!(validate_summary_schema(&schema).is_ok());
    }

    #[test]
    fn validate_summary_schema_rejects_missing_summary() {
        let without_summary = serde_json::json!({
            "type": "object",
            "properties": { "other_field": { "type": "string" } }
        });
        assert!(validate_summary_schema(&without_summary).is_err());
    }

    #[test]
    fn validate_summary_schema_rejects_non_string_summary() {
        let integer_summary = serde_json::json!({
            "type": "object",
            "properties": { "summary": { "type": "integer" } }
        });
        assert!(validate_summary_schema(&integer_summary).is_err());
    }

    #[test]
    fn validate_summary_schema_rejects_non_object() {
        let array_schema = serde_json::json!([1, 2, 3]);
        assert!(validate_summary_schema(&array_schema).is_err());
    }

    /// LLM stand-in for structural tests. Only the metadata accessors return
    /// values; the generation methods are never called and panic if they are.
    struct NoopLlm;

    #[async_trait::async_trait]
    impl Llm for NoopLlm {
        fn model(&self) -> &str {
            "noop"
        }

        fn max_context_length(&self) -> u32 {
            4096
        }

        async fn generate(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _options: Option<cognee_llm::types::GenerationOptions>,
        ) -> cognee_llm::LlmResult<cognee_llm::GenerationResponse> {
            unimplemented!()
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _json_schema: &serde_json::Value,
            _options: Option<cognee_llm::types::GenerationOptions>,
        ) -> cognee_llm::LlmResult<serde_json::Value> {
            unimplemented!()
        }
    }
}