use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tokio::task::JoinHandle;
use terraphim_config::Role;
use terraphim_types::Document;
use crate::ServiceError;
use crate::summarization_queue::{
Priority, QueueCommand, QueueConfig, QueueStats, SubmitResult, SummarizationQueue,
SummarizationTask, TaskId, TaskStatus,
};
use crate::summarization_worker::SummarizationWorker;
/// Owns a [`SummarizationQueue`] together with the spawned
/// [`SummarizationWorker`] task that consumes its commands.
///
/// Construct it with `SummarizationManager::new` or via
/// `SummarizationManagerBuilder`. Call `SummarizationManager::shutdown` to
/// stop the queue and wait for the worker to exit.
pub struct SummarizationManager {
/// Queue that tasks are submitted to. It is given a clone of
/// `command_sender` when it is constructed.
queue: SummarizationQueue,
/// Join handle of the spawned worker. It is `None` once `shutdown` has taken it.
worker_handle: Option<JoinHandle<Result<(), ServiceError>>>,
/// Manager-side copy of the command channel sender. It is never read in this
/// file, hence the `dead_code` allow.
/// NOTE(review): presumably retained so the channel stays open for the
/// manager's whole lifetime; confirm before removing.
#[allow(dead_code)] command_sender: mpsc::Sender<QueueCommand>,
}
impl SummarizationManager {
pub fn new(config: QueueConfig) -> Self {
let (command_sender, command_receiver) = mpsc::channel(100);
let queue = SummarizationQueue::new(config.clone(), command_sender.clone());
let task_status = queue.get_task_status_storage();
let worker = SummarizationWorker::new(config, task_status);
let worker_handle = Some(tokio::spawn(
async move { worker.run(command_receiver).await },
));
Self {
queue,
worker_handle,
command_sender,
}
}
pub async fn summarize_document(
&self,
document: Document,
role: Role,
priority: Option<Priority>,
max_summary_length: Option<usize>,
force_regenerate: Option<bool>,
callback_url: Option<String>,
) -> Result<SubmitResult, ServiceError> {
let mut task = SummarizationTask::new(document, role);
if let Some(priority) = priority {
task = task.with_priority(priority);
}
if let Some(length) = max_summary_length {
task = task.with_max_summary_length(length);
}
if let Some(force) = force_regenerate {
task = task.with_force_regenerate(force);
}
if let Some(url) = callback_url {
task = task.with_callback_url(url);
}
self.queue.submit_task(task).await
}
#[allow(clippy::too_many_arguments)]
pub async fn summarize_document_with_config(
&self,
document: Document,
role: Role,
config: terraphim_config::Config,
priority: Option<Priority>,
max_summary_length: Option<usize>,
force_regenerate: Option<bool>,
callback_url: Option<String>,
) -> Result<SubmitResult, ServiceError> {
let mut task = SummarizationTask::new(document, role).with_config(config);
if let Some(priority) = priority {
task = task.with_priority(priority);
}
if let Some(length) = max_summary_length {
task = task.with_max_summary_length(length);
}
if let Some(force) = force_regenerate {
task = task.with_force_regenerate(force);
}
if let Some(url) = callback_url {
task = task.with_callback_url(url);
}
self.queue.submit_task(task).await
}
pub async fn get_task_status(&self, task_id: &TaskId) -> Option<TaskStatus> {
self.queue.get_task_status(task_id).await
}
pub async fn cancel_task(&self, task_id: TaskId, reason: String) -> Result<bool, ServiceError> {
self.queue.cancel_task(task_id, reason).await
}
pub async fn get_stats(&self) -> Result<QueueStats, ServiceError> {
self.queue.get_stats().await
}
pub async fn pause(&self) -> Result<(), ServiceError> {
self.queue.pause().await
}
pub async fn resume(&self) -> Result<(), ServiceError> {
self.queue.resume().await
}
pub async fn process_document_fields(
&self,
doc: &mut Document,
role: &Role,
extract_description: bool,
queue_summarization: bool,
) -> Result<Option<TaskId>, ServiceError> {
let mut task_id = None;
if extract_description && doc.description.is_none() && !doc.body.is_empty() {
match Self::extract_description_from_body(&doc.body, 200) {
Ok(description) => {
log::debug!(
"Generated description for document '{}': {} chars",
doc.id,
description.len()
);
doc.description = Some(description);
}
Err(e) => {
log::warn!(
"Failed to extract description for document '{}': {}",
doc.id,
e
);
}
}
}
if queue_summarization && doc.body.len() >= 500 {
let submit_result = self
.summarize_document(
doc.clone(),
role.clone(),
Some(Priority::Normal),
Some(300), Some(false), None, )
.await?;
match submit_result {
SubmitResult::Queued {
task_id: queued_task_id,
..
} => {
task_id = Some(queued_task_id.clone());
log::debug!(
"Queued AI summarization for document '{}' with task ID: {:?}",
doc.id,
queued_task_id
);
}
SubmitResult::Duplicate(existing_task_id) => {
task_id = Some(existing_task_id.clone());
log::debug!(
"Document '{}' already has summarization task: {:?}",
doc.id,
existing_task_id
);
}
SubmitResult::ValidationError(error) => {
log::warn!("Validation error for document '{}': {}", doc.id, error);
}
SubmitResult::QueueFull => {
log::warn!(
"Summarization queue is full, cannot queue document '{}'",
doc.id
);
}
}
}
Ok(task_id)
}
pub fn extract_description_from_body(
body: &str,
max_length: usize,
) -> Result<String, ServiceError> {
if body.is_empty() {
return Err(ServiceError::Config("Document body is empty".to_string()));
}
let first_paragraph = body
.split('\n')
.map(|line| line.trim())
.find(|line| !line.is_empty() && line.len() > 10)
.unwrap_or_else(|| body.trim());
if first_paragraph.len() <= max_length {
return Ok(first_paragraph.to_string());
}
let truncated = &first_paragraph[..max_length];
if let Some(last_period) = truncated.rfind(". ") {
if last_period > max_length / 2 {
return Ok(truncated[..=last_period].to_string());
}
}
if let Some(last_space) = truncated.rfind(' ') {
if last_space > max_length / 2 {
return Ok(format!("{}...", &truncated[..last_space]));
}
}
Ok(format!("{}...", &first_paragraph[..max_length - 3]))
}
pub async fn process_documents_batch(
&self,
documents: &mut [Document],
role: &Role,
extract_description: bool,
queue_summarization: bool,
) -> Result<Vec<Option<TaskId>>, ServiceError> {
log::info!(
"Processing {} documents for description and summarization",
documents.len()
);
let mut task_ids = Vec::with_capacity(documents.len());
let mut successful_count = 0;
let mut error_count = 0;
for doc in documents.iter_mut() {
match self
.process_document_fields(doc, role, extract_description, queue_summarization)
.await
{
Ok(task_id) => {
task_ids.push(task_id);
successful_count += 1;
}
Err(e) => {
log::error!("Failed to process document '{}': {}", doc.id, e);
task_ids.push(None);
error_count += 1;
}
}
}
log::info!(
"Completed batch processing: {} successful, {} errors",
successful_count,
error_count
);
Ok(task_ids)
}
pub async fn shutdown(&mut self) -> Result<(), ServiceError> {
self.queue.shutdown().await?;
if let Some(handle) = self.worker_handle.take() {
match handle.await {
Ok(result) => result?,
Err(e) => {
log::error!("Worker task panicked: {:?}", e);
return Err(ServiceError::Config("Worker task panicked".to_string()));
}
}
}
log::info!("Summarization manager shut down successfully");
Ok(())
}
pub fn is_healthy(&self) -> bool {
self.worker_handle
.as_ref()
.is_some_and(|handle| !handle.is_finished())
}
pub fn get_queue(&self) -> &SummarizationQueue {
&self.queue
}
}
/// Fluent builder for [`SummarizationManager`].
///
/// Starts from `QueueConfig::default()`. Each setter overrides one
/// `QueueConfig` field, and [`build`](SummarizationManagerBuilder::build)
/// creates the manager.
pub struct SummarizationManagerBuilder {
    config: QueueConfig,
}

impl SummarizationManagerBuilder {
    /// Returns a builder holding `QueueConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Overrides `QueueConfig::max_queue_size`.
    pub fn max_queue_size(self, size: usize) -> Self {
        let Self { mut config } = self;
        config.max_queue_size = size;
        Self { config }
    }

    /// Overrides `QueueConfig::max_concurrent_workers`.
    pub fn max_concurrent_workers(self, workers: usize) -> Self {
        let Self { mut config } = self;
        config.max_concurrent_workers = workers;
        Self { config }
    }

    /// Overrides `QueueConfig::max_queue_time`.
    pub fn max_queue_time(self, duration: std::time::Duration) -> Self {
        let Self { mut config } = self;
        config.max_queue_time = duration;
        Self { config }
    }

    /// Overrides `QueueConfig::task_retention_time`.
    pub fn task_retention_time(self, duration: std::time::Duration) -> Self {
        let Self { mut config } = self;
        config.task_retention_time = duration;
        Self { config }
    }

    /// Consumes the builder and creates the manager. This spawns the worker, so
    /// it must run inside a Tokio runtime.
    pub fn build(self) -> SummarizationManager {
        let Self { config } = self;
        SummarizationManager::new(config)
    }
}

impl Default for SummarizationManagerBuilder {
    fn default() -> Self {
        Self {
            config: QueueConfig::default(),
        }
    }
}
impl SummarizationQueue {
    /// Hands out another reference to the queue's task-status map.
    ///
    /// The returned `Arc` and the queue point at the same map. Nothing is copied;
    /// only the reference count is incremented.
    pub(crate) fn get_task_status_storage(
        &self,
    ) -> Arc<RwLock<std::collections::HashMap<TaskId, TaskStatus>>> {
        let shared = &self.task_status;
        Arc::clone(shared)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::time::sleep;
// Builds a minimal document with `description: None` and a non-empty body.
fn create_test_document() -> Document {
Document {
id: "test-doc".to_string(),
title: "Test Document".to_string(),
body: "This is a test document for summarization with enough content to make it interesting.".to_string(),
url: "http://example.com".to_string(),
description: None,
summarization: None,
stub: None,
tags: Some(vec![]),
rank: None,
source_haystack: None,
doc_type: terraphim_types::DocumentType::KgEntry,
synonyms: None,
route: None,
priority: None,
}
}
// Builds a role with LLM features disabled and an `llm_provider` entry of
// "test" in `extra`.
fn create_test_role() -> Role {
Role {
shortname: Some("test-role".to_string()),
name: "Test Role".to_string().into(),
relevance_function: terraphim_types::RelevanceFunction::TitleScorer,
haystacks: vec![],
terraphim_it: false,
theme: "default".to_string(),
kg: None,
llm_enabled: false,
llm_api_key: None,
llm_model: None,
llm_auto_summarize: false,
llm_chat_enabled: false,
llm_chat_system_prompt: None,
llm_chat_model: None,
llm_context_window: Some(32768),
extra: {
let mut extra = ahash::AHashMap::new();
extra.insert(
"llm_provider".to_string(),
serde_json::Value::String("test".to_string()),
);
extra
},
llm_router_enabled: false,
llm_router_config: None,
}
}
// The worker must be running right after construction and still running
// 100ms later.
#[tokio::test]
async fn test_manager_creation() {
let manager = SummarizationManager::new(QueueConfig::default());
assert!(manager.is_healthy());
sleep(Duration::from_millis(100)).await;
assert!(manager.is_healthy());
}
// The builder path produces a manager with a running worker.
#[tokio::test]
async fn test_manager_builder() {
let manager = SummarizationManagerBuilder::new()
.max_queue_size(500)
.max_concurrent_workers(2)
.max_queue_time(Duration::from_secs(180))
.build();
assert!(manager.is_healthy());
}
// A submission with every option set is expected to be `Queued`, and the
// queue must then report some status for the returned task ID.
#[tokio::test]
async fn test_task_submission() {
let manager = SummarizationManager::new(QueueConfig::default());
sleep(Duration::from_millis(100)).await;
let document = create_test_document();
let role = create_test_role();
let result = manager
.summarize_document(
document,
role,
Some(Priority::High),
Some(200),
Some(false),
Some("http://callback.com".to_string()),
)
.await;
assert!(result.is_ok());
match result.unwrap() {
SubmitResult::Queued { task_id, .. } => {
sleep(Duration::from_millis(100)).await;
let status = manager.get_task_status(&task_id).await;
assert!(status.is_some());
}
other => panic!("Unexpected result: {:?}", other),
}
}
// Cancelling a queued task should record `TaskStatus::Cancelled` with the
// supplied reason. The queue is paused first, presumably so the worker does
// not pick the task up before it is cancelled. Ignored because it depends
// on timing.
#[tokio::test]
#[ignore = "Flaky test - timing dependent"]
async fn test_task_cancellation() {
let manager = SummarizationManager::new(QueueConfig::default());
sleep(Duration::from_millis(100)).await;
manager.pause().await.expect("Failed to pause manager");
let document = create_test_document();
let role = create_test_role();
let result = manager
.summarize_document(
document,
role,
Some(Priority::Low), None,
None,
None,
)
.await
.unwrap();
if let SubmitResult::Queued { task_id, .. } = result {
let cancelled = manager
.cancel_task(task_id.clone(), "Test cancellation".to_string())
.await
.unwrap();
assert!(cancelled, "Task cancellation should succeed");
sleep(Duration::from_millis(100)).await;
let status = manager.get_task_status(&task_id).await;
if let Some(TaskStatus::Cancelled { reason, .. }) = status {
assert_eq!(reason, "Test cancellation");
} else {
panic!("Task should be cancelled, got: {:?}", status);
}
}
}
// A freshly created manager reports an empty, unpaused queue.
#[tokio::test]
async fn test_queue_stats() {
let manager = SummarizationManager::new(QueueConfig::default());
sleep(Duration::from_millis(100)).await;
let stats = manager.get_stats().await.unwrap();
assert_eq!(stats.queue_size, 0);
assert_eq!(stats.pending_tasks, 0);
assert_eq!(stats.processing_tasks, 0);
assert!(!stats.is_paused);
}
// `pause`/`resume` toggle the `is_paused` flag reported by `get_stats`.
#[tokio::test]
async fn test_pause_resume() {
let manager = SummarizationManager::new(QueueConfig::default());
sleep(Duration::from_millis(100)).await;
manager.pause().await.unwrap();
let stats = manager.get_stats().await.unwrap();
assert!(stats.is_paused);
manager.resume().await.unwrap();
let stats = manager.get_stats().await.unwrap();
assert!(!stats.is_paused);
}
// After `shutdown` the worker handle has been taken, so `is_healthy` is false.
#[tokio::test]
async fn test_manager_shutdown() {
let mut manager = SummarizationManager::new(QueueConfig::default());
sleep(Duration::from_millis(100)).await;
assert!(manager.is_healthy());
manager.shutdown().await.unwrap();
assert!(!manager.is_healthy());
}
}