// cognee_cognify/summarization/extractor.rs
use std::sync::Arc;

use cognee_llm::{GenerationOptions, Llm, LlmExt};
use cognee_models::DocumentChunk;

use super::models::{SummarizedContent, TextSummary};
use crate::error::CognifyError;

15fn default_summary_options() -> GenerationOptions {
17 GenerationOptions {
22 temperature: Some(0.3),
23 max_tokens: None,
24 ..Default::default()
25 }
26}
27
28const DEFAULT_SUMMARY_PROMPT: &str = include_str!("prompts/summarize_content.txt");
35
36#[derive(Clone)]
56pub struct SummaryExtractor {
57 llm: Arc<dyn Llm>,
58 summary_schema: Option<serde_json::Value>,
61}
62
63impl SummaryExtractor {
64 pub fn new(llm: Arc<dyn Llm>) -> Self {
66 Self {
67 llm,
68 summary_schema: None,
69 }
70 }
71
72 pub fn new_with_schema(llm: Arc<dyn Llm>, schema: Option<serde_json::Value>) -> Self {
78 Self {
79 llm,
80 summary_schema: schema,
81 }
82 }
83
84 pub async fn extract_summary(
90 &self,
91 text: &str,
92 custom_prompt: Option<&str>,
93 ) -> Result<SummarizedContent, CognifyError> {
94 let system_prompt = custom_prompt.unwrap_or(DEFAULT_SUMMARY_PROMPT);
95 let options = Some(default_summary_options());
96
97 match &self.summary_schema {
98 None => {
99 let summarized: SummarizedContent = self
100 .llm
101 .create_structured_output(text, system_prompt, options)
102 .await
103 .map_err(|e| CognifyError::LlmError(e.to_string()))?;
104 Ok(summarized)
105 }
106 Some(schema) => {
107 let raw: serde_json::Value = self
108 .llm
109 .create_structured_output_raw(text, system_prompt, schema, options)
110 .await
111 .map_err(|e| CognifyError::LlmError(e.to_string()))?;
112 let summary = raw.get("summary").and_then(|v| v.as_str()).ok_or_else(|| {
113 CognifyError::LlmError(
114 "summary_schema output missing string `summary` field".to_string(),
115 )
116 })?;
117 Ok(SummarizedContent {
118 summary: summary.to_string(),
119 description: String::new(),
120 })
121 }
122 }
123 }
124
125 pub async fn summarize_chunks(
137 &self,
138 chunks: &[DocumentChunk],
139 custom_prompt: Option<String>,
140 ) -> Result<Vec<TextSummary>, CognifyError> {
141 if chunks.is_empty() {
142 return Ok(vec![]);
143 }
144
145 let mut tasks = Vec::new();
146
147 for chunk in chunks {
148 let llm_clone = Arc::clone(&self.llm);
149 let schema_clone = self.summary_schema.clone();
150 let prompt_clone = custom_prompt.clone();
151 let text = chunk.text.clone();
152
153 let task = tokio::spawn(async move {
154 let extractor = SummaryExtractor {
155 llm: llm_clone,
156 summary_schema: schema_clone,
157 };
158 extractor
159 .extract_summary(&text, prompt_clone.as_deref())
160 .await
161 });
162
163 tasks.push(task);
164 }
165
166 let results = futures::future::join_all(tasks).await;
167
168 let model_name = self.llm.model().to_string();
170
171 let mut summaries = Vec::new();
172 for (chunk_index, result) in results.into_iter().enumerate() {
173 let chunk = &chunks[chunk_index];
174 let summarized =
175 result.map_err(|e| CognifyError::LlmError(format!("Task join error: {e}")))??;
176
177 let text_summary =
178 TextSummary::from_summarized_content(chunk.base.id, summarized, model_name.clone());
179
180 summaries.push(text_summary);
181 }
182
183 Ok(summaries)
184 }
185
186 pub fn llm(&self) -> &Arc<dyn Llm> {
188 &self.llm
189 }
190}
191
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::validate_summary_schema;

    /// LLM stub for constructor-level tests; any real generation call panics.
    struct NoopLlm;

    #[async_trait::async_trait]
    impl Llm for NoopLlm {
        async fn generate(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _options: Option<cognee_llm::types::GenerationOptions>,
        ) -> cognee_llm::LlmResult<cognee_llm::GenerationResponse> {
            unimplemented!()
        }
        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _json_schema: &serde_json::Value,
            _options: Option<cognee_llm::types::GenerationOptions>,
        ) -> cognee_llm::LlmResult<serde_json::Value> {
            unimplemented!()
        }
        fn model(&self) -> &str {
            "noop"
        }
        fn max_context_length(&self) -> u32 {
            4096
        }
    }

    /// Fresh type-erased handle to the stub LLM.
    fn noop_llm() -> Arc<dyn Llm> {
        Arc::new(NoopLlm)
    }

    /// Object schema whose only property is `summary`, typed as `ty`.
    fn schema_with_summary_type(ty: &str) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": { "summary": { "type": ty } }
        })
    }

    #[test]
    #[allow(
        clippy::const_is_empty,
        reason = "intentional sanity check that the const is non-empty"
    )]
    fn test_default_prompt_not_empty() {
        assert!(!DEFAULT_SUMMARY_PROMPT.is_empty());
        assert!(DEFAULT_SUMMARY_PROMPT.contains("Summarize the chunk for retrieval"));
    }

    #[test]
    fn summary_prompt_matches_vendored_txt() {
        let on_disk = include_str!("prompts/summarize_content.txt");
        assert_eq!(
            DEFAULT_SUMMARY_PROMPT, on_disk,
            "const drifted from vendored .txt"
        );
        assert!(
            on_disk.contains("Output two sections only"),
            "Python two-section structure marker missing"
        );
        assert!(
            on_disk.contains("Max 200 tokens"),
            "token-limit marker missing"
        );
    }

    #[test]
    fn new_returns_no_schema() {
        let extractor = SummaryExtractor::new(noop_llm());
        assert!(extractor.summary_schema.is_none());
    }

    #[test]
    fn new_with_schema_stores_schema() {
        let expected = schema_with_summary_type("string");
        let extractor = SummaryExtractor::new_with_schema(noop_llm(), Some(expected.clone()));
        assert_eq!(extractor.summary_schema, Some(expected));
    }

    #[test]
    fn validate_summary_schema_accepts_valid() {
        assert!(validate_summary_schema(&schema_with_summary_type("string")).is_ok());
    }

    #[test]
    fn validate_summary_schema_rejects_missing_summary() {
        let without_summary = serde_json::json!({
            "type": "object",
            "properties": { "other_field": { "type": "string" } }
        });
        assert!(validate_summary_schema(&without_summary).is_err());
    }

    #[test]
    fn validate_summary_schema_rejects_non_string_summary() {
        assert!(validate_summary_schema(&schema_with_summary_type("integer")).is_err());
    }

    #[test]
    fn validate_summary_schema_rejects_non_object() {
        let array_schema = serde_json::json!([1, 2, 3]);
        assert!(validate_summary_schema(&array_schema).is_err());
    }
}