use std::sync::Arc;
use cognee_llm::{GenerationOptions, Llm, LlmExt};
use cognee_models::DocumentChunk;
use super::models::{SummarizedContent, TextSummary};
use crate::error::CognifyError;
fn default_summary_options() -> GenerationOptions {
GenerationOptions {
temperature: Some(0.3),
max_tokens: None,
..Default::default()
}
}
/// Default system prompt for chunk summarization, embedded into the binary at compile time
/// from `prompts/summarize_content.txt`. Used whenever the caller supplies no custom prompt.
const DEFAULT_SUMMARY_PROMPT: &str = include_str!("prompts/summarize_content.txt");
/// Produces text summaries of document chunks by prompting an LLM.
///
/// Without a schema, the model output is deserialized into [`SummarizedContent`]. With a
/// custom JSON schema, only the string `summary` field of the raw output is used.
#[derive(Clone)]
pub struct SummaryExtractor {
    /// LLM client used for every summary request; shared via `Arc` so clones are cheap.
    llm: Arc<dyn Llm>,
    /// Optional caller-supplied JSON schema for the structured summary output.
    summary_schema: Option<serde_json::Value>,
}
impl SummaryExtractor {
pub fn new(llm: Arc<dyn Llm>) -> Self {
Self {
llm,
summary_schema: None,
}
}
pub fn new_with_schema(llm: Arc<dyn Llm>, schema: Option<serde_json::Value>) -> Self {
Self {
llm,
summary_schema: schema,
}
}
pub async fn extract_summary(
&self,
text: &str,
custom_prompt: Option<&str>,
) -> Result<SummarizedContent, CognifyError> {
let system_prompt = custom_prompt.unwrap_or(DEFAULT_SUMMARY_PROMPT);
let options = Some(default_summary_options());
match &self.summary_schema {
None => {
let summarized: SummarizedContent = self
.llm
.create_structured_output(text, system_prompt, options)
.await
.map_err(|e| CognifyError::LlmError(e.to_string()))?;
Ok(summarized)
}
Some(schema) => {
let raw: serde_json::Value = self
.llm
.create_structured_output_raw(text, system_prompt, schema, options)
.await
.map_err(|e| CognifyError::LlmError(e.to_string()))?;
let summary = raw.get("summary").and_then(|v| v.as_str()).ok_or_else(|| {
CognifyError::LlmError(
"summary_schema output missing string `summary` field".to_string(),
)
})?;
Ok(SummarizedContent {
summary: summary.to_string(),
description: String::new(),
})
}
}
}
pub async fn summarize_chunks(
&self,
chunks: &[DocumentChunk],
custom_prompt: Option<String>,
) -> Result<Vec<TextSummary>, CognifyError> {
if chunks.is_empty() {
return Ok(vec![]);
}
let mut tasks = Vec::new();
for chunk in chunks {
let llm_clone = Arc::clone(&self.llm);
let schema_clone = self.summary_schema.clone();
let prompt_clone = custom_prompt.clone();
let text = chunk.text.clone();
let task = tokio::spawn(async move {
let extractor = SummaryExtractor {
llm: llm_clone,
summary_schema: schema_clone,
};
extractor
.extract_summary(&text, prompt_clone.as_deref())
.await
});
tasks.push(task);
}
let results = futures::future::join_all(tasks).await;
let model_name = self.llm.model().to_string();
let mut summaries = Vec::new();
for (chunk_index, result) in results.into_iter().enumerate() {
let chunk = &chunks[chunk_index];
let summarized =
result.map_err(|e| CognifyError::LlmError(format!("Task join error: {e}")))??;
let text_summary =
TextSummary::from_summarized_content(chunk.base.id, summarized, model_name.clone());
summaries.push(text_summary);
}
Ok(summaries)
}
pub fn llm(&self) -> &Arc<dyn Llm> {
&self.llm
}
}
// Unit tests for the summary extractor. They cover the bundled prompt text, the
// constructors, and `validate_summary_schema` from the crate's config module. No test here
// issues a real LLM request.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::config::validate_summary_schema;

    // The bundled default prompt must be present and contain its opening instruction.
    #[test]
    #[allow(
        clippy::const_is_empty,
        reason = "intentional sanity check that the const is non-empty"
    )]
    fn test_default_prompt_not_empty() {
        assert!(!DEFAULT_SUMMARY_PROMPT.is_empty());
        assert!(DEFAULT_SUMMARY_PROMPT.contains("Summarize the chunk for retrieval"));
    }

    // Guards the structural markers of the vendored prompt file.
    // NOTE(review): `DEFAULT_SUMMARY_PROMPT` is itself `include_str!` of this same path, so
    // the `assert_eq!` below cannot fail. The two `contains` checks are the assertions that
    // actually catch prompt edits.
    #[test]
    fn summary_prompt_matches_vendored_txt() {
        let vendored = include_str!("prompts/summarize_content.txt");
        assert_eq!(
            DEFAULT_SUMMARY_PROMPT, vendored,
            "const drifted from vendored .txt"
        );
        assert!(
            vendored.contains("Output two sections only"),
            "Python two-section structure marker missing"
        );
        assert!(
            vendored.contains("Max 200 tokens"),
            "token-limit marker missing"
        );
    }

    // `new` must not attach a custom schema.
    #[test]
    fn new_returns_no_schema() {
        let llm: Arc<dyn Llm> = Arc::new(NoopLlm);
        let extractor = SummaryExtractor::new(llm);
        assert!(extractor.summary_schema.is_none());
    }

    // `new_with_schema` must store the provided schema unchanged.
    #[test]
    fn new_with_schema_stores_schema() {
        let llm: Arc<dyn Llm> = Arc::new(NoopLlm);
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "summary": { "type": "string" } }
        });
        let extractor = SummaryExtractor::new_with_schema(llm, Some(schema.clone()));
        assert_eq!(extractor.summary_schema, Some(schema));
    }

    // An object schema with a string `summary` property is accepted.
    #[test]
    fn validate_summary_schema_accepts_valid() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "summary": { "type": "string" } }
        });
        assert!(validate_summary_schema(&schema).is_ok());
    }

    // A schema without a `summary` property is rejected.
    #[test]
    fn validate_summary_schema_rejects_missing_summary() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "other_field": { "type": "string" } }
        });
        assert!(validate_summary_schema(&schema).is_err());
    }

    // A `summary` property whose type is not `string` is rejected.
    #[test]
    fn validate_summary_schema_rejects_non_string_summary() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "summary": { "type": "integer" } }
        });
        assert!(validate_summary_schema(&schema).is_err());
    }

    // A non-object JSON value (here an array) is rejected.
    #[test]
    fn validate_summary_schema_rejects_non_object() {
        let schema = serde_json::json!([1, 2, 3]);
        assert!(validate_summary_schema(&schema).is_err());
    }

    // Minimal `Llm` stand-in for constructor tests. `model()` and `max_context_length()`
    // return fixed values. The async request methods panic via `unimplemented!()`, so this
    // stub must never be used to actually generate output.
    struct NoopLlm;

    #[async_trait::async_trait]
    impl Llm for NoopLlm {
        async fn generate(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _options: Option<cognee_llm::types::GenerationOptions>,
        ) -> cognee_llm::LlmResult<cognee_llm::GenerationResponse> {
            unimplemented!()
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _json_schema: &serde_json::Value,
            _options: Option<cognee_llm::types::GenerationOptions>,
        ) -> cognee_llm::LlmResult<serde_json::Value> {
            unimplemented!()
        }

        fn model(&self) -> &str {
            "noop"
        }

        fn max_context_length(&self) -> u32 {
            4096
        }
    }
}