use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::{RwLock, mpsc};
use uuid::Uuid;
use terraphim_config::{Config, Role};
use terraphim_types::Document;
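/// Unique identifier for a summarization task, backed by a UUID v4.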
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TaskId(pub Uuid);
impl Default for TaskId {
fn default() -> Self {
Self::new()
}
}
impl TaskId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
}
impl std::fmt::Display for TaskId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
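/// Priority of a summarization task.
///
/// Variants are ordered from `Low` to `Critical` so they can be compared
/// directly; `Normal` is the default.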
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub enum Priority {
Low = 0,
#[default]
Normal = 1,
High = 2,
Critical = 3,
}
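/// Lifecycle state of a summarization task, from queueing through
/// completion, failure (with retry bookkeeping), or cancellation.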
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskStatus {
Pending {
queued_at: DateTime<Utc>,
position_in_queue: Option<usize>,
},
Processing {
started_at: DateTime<Utc>,
progress: Option<f32>,
},
Completed {
summary: String,
completed_at: DateTime<Utc>,
processing_duration_seconds: u64,
},
Failed {
error: String,
failed_at: DateTime<Utc>,
retry_count: u32,
next_retry_at: Option<DateTime<Utc>>,
},
Cancelled {
cancelled_at: DateTime<Utc>,
reason: String,
},
}
impl TaskStatus {
pub fn is_terminal(&self) -> bool {
matches!(
self,
TaskStatus::Completed { .. } | TaskStatus::Failed { .. } | TaskStatus::Cancelled { .. }
)
}
pub fn is_processing(&self) -> bool {
matches!(self, TaskStatus::Processing { .. })
}
pub fn is_pending(&self) -> bool {
matches!(self, TaskStatus::Pending { .. })
}
}
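/// A single summarization request: the document to summarize, the role (and
/// optional config) to summarize it under, plus priority, retry limits, and
/// submission options. Construct with [`SummarizationTask::new`] and the
/// `with_*` builder methods.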
#[derive(Debug, Clone)]
pub struct SummarizationTask {
pub id: TaskId,
pub document: Document,
pub role: Role,
pub config: Option<Config>,
pub priority: Priority,
pub retry_count: u32,
pub max_retries: u32,
pub created_at: DateTime<Utc>,
pub max_summary_length: Option<usize>,
pub force_regenerate: bool,
pub callback_url: Option<String>,
}
impl SummarizationTask {
pub fn new(document: Document, role: Role) -> Self {
Self {
id: TaskId::new(),
document,
role,
config: None,
priority: Priority::default(),
retry_count: 0,
max_retries: 3,
created_at: Utc::now(),
max_summary_length: None,
force_regenerate: false,
callback_url: None,
}
}
pub fn with_priority(mut self, priority: Priority) -> Self {
self.priority = priority;
self
}
pub fn with_max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = max_retries;
self
}
pub fn with_max_summary_length(mut self, length: usize) -> Self {
self.max_summary_length = Some(length);
self
}
pub fn with_force_regenerate(mut self, force: bool) -> Self {
self.force_regenerate = force;
self
}
pub fn with_callback_url(mut self, url: String) -> Self {
self.callback_url = Some(url);
self
}
pub fn with_config(mut self, config: Config) -> Self {
self.config = Some(config);
self
}
pub fn can_retry(&self) -> bool {
self.retry_count < self.max_retries
}
pub fn increment_retry(&mut self) {
self.retry_count += 1;
}
pub fn get_summary_length(&self) -> usize {
self.max_summary_length.unwrap_or(250)
}
}
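/// Configuration for the summarization queue: capacity, worker concurrency,
/// queueing and retention limits, per-provider rate limits, and retry
/// backoff bounds.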
#[derive(Debug, Clone)]
pub struct QueueConfig {
pub max_queue_size: usize,
pub max_concurrent_workers: usize,
pub max_queue_time: Duration,
pub task_retention_time: Duration,
pub rate_limits: HashMap<String, RateLimitConfig>,
pub retry_delay: Duration,
pub max_retry_delay: Duration,
}
impl Default for QueueConfig {
fn default() -> Self {
let mut rate_limits = HashMap::new();
rate_limits.insert(
"openrouter".to_string(),
RateLimitConfig {
max_requests_per_minute: 60,
max_tokens_per_minute: 10000,
burst_size: 10,
},
);
rate_limits.insert(
"ollama".to_string(),
RateLimitConfig {
max_requests_per_minute: 300,
max_tokens_per_minute: 50000,
burst_size: 50,
},
);
Self {
max_queue_size: 1000,
max_concurrent_workers: 3,
            max_queue_time: Duration::from_secs(300),
            task_retention_time: Duration::from_secs(3600),
            rate_limits,
retry_delay: Duration::from_secs(1),
max_retry_delay: Duration::from_secs(60),
}
}
}
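/// Per-provider rate limit: maximum requests and tokens per minute, and the
/// allowed burst size.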
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
pub max_requests_per_minute: u32,
pub max_tokens_per_minute: u32,
pub burst_size: u32,
}
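/// Commands sent from the [`SummarizationQueue`] handle to the queue worker
/// over the command channel.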
#[derive(Debug)]
pub enum QueueCommand {
SubmitTask(Box<SummarizationTask>),
CancelTask(TaskId, String),
Pause,
Resume,
GetStats(tokio::sync::oneshot::Sender<QueueStats>),
Shutdown,
}
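/// Snapshot of queue state: task counts by status, average processing time,
/// pause flag, active worker count, and per-provider rate limiter status.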
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueueStats {
pub queue_size: usize,
pub pending_tasks: usize,
pub processing_tasks: usize,
pub completed_tasks: usize,
pub failed_tasks: usize,
pub cancelled_tasks: usize,
pub avg_processing_time_seconds: Option<u64>,
pub is_paused: bool,
pub active_workers: usize,
pub rate_limiter_status: HashMap<String, RateLimiterStatus>,
}
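/// Point-in-time view of one provider's rate limiter: remaining tokens out
/// of the maximum, requests made in the current window, and seconds until
/// the window resets.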
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimiterStatus {
pub current_tokens: f64,
pub max_tokens: f64,
pub requests_in_window: u32,
pub reset_in_seconds: u32,
}
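/// Outcome of submitting a task: accepted into the queue (with an estimated
/// position and wait time), rejected because the queue is full, a duplicate
/// of an already-tracked task, or a validation failure.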
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubmitResult {
Queued {
task_id: TaskId,
position_in_queue: usize,
estimated_wait_time_seconds: Option<u64>,
},
QueueFull,
Duplicate(TaskId),
ValidationError(String),
}
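/// Client handle for the asynchronous summarization queue.
///
/// The handle sends [`QueueCommand`]s over an mpsc channel and reads task
/// status from a shared map; a separate worker task is expected to drain the
/// receiving end of the channel. A minimal usage sketch (assuming such a
/// worker is running elsewhere and a `Document` and `Role` are in scope):
///
/// ```ignore
/// let (tx, rx) = tokio::sync::mpsc::channel(100);
/// // ... spawn a worker that consumes `rx` ...
/// let queue = SummarizationQueue::new(QueueConfig::default(), tx);
/// let task = SummarizationTask::new(document, role).with_priority(Priority::High);
/// match queue.submit_task(task).await? {
///     SubmitResult::Queued { task_id, .. } => println!("queued {task_id}"),
///     other => eprintln!("not queued: {other:?}"),
/// }
/// ```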
pub struct SummarizationQueue {
command_sender: mpsc::Sender<QueueCommand>,
pub(crate) task_status: Arc<RwLock<HashMap<TaskId, TaskStatus>>>,
config: QueueConfig,
}
impl SummarizationQueue {
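    /// Creates a queue handle from a configuration and the sending half of
    /// the command channel; the receiving half is expected to be consumed by
    /// the queue worker.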
pub fn new(config: QueueConfig, command_sender: mpsc::Sender<QueueCommand>) -> Self {
let task_status = Arc::new(RwLock::new(HashMap::new()));
Self {
command_sender,
task_status,
config,
}
}
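    /// Validates a task and hands it to the queue worker.
    ///
    /// Returns `Duplicate` if the task id is already tracked,
    /// `ValidationError` if the document body is empty, `QueueFull` when the
    /// queue is at capacity, and otherwise `Queued` with an estimated
    /// position and wait time. Errors if the worker channel is closed.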
pub async fn submit_task(
&self,
task: SummarizationTask,
) -> Result<SubmitResult, crate::ServiceError> {
let task_status = self.task_status.read().await;
if task_status.contains_key(&task.id) {
return Ok(SubmitResult::Duplicate(task.id.clone()));
}
drop(task_status);
if task.document.body.trim().is_empty() {
return Ok(SubmitResult::ValidationError(
"Document body is empty".to_string(),
));
}
let stats = self.get_stats().await?;
if stats.queue_size >= self.config.max_queue_size {
return Ok(SubmitResult::QueueFull);
}
let task_id = task.id.clone();
        self.command_sender
            .send(QueueCommand::SubmitTask(Box::new(task)))
            .await
            .map_err(|_| crate::ServiceError::Config("Queue worker not running".to_string()))?;
let estimated_wait = self.estimate_wait_time(stats.pending_tasks + 1).await;
Ok(SubmitResult::Queued {
task_id,
position_in_queue: stats.pending_tasks + 1,
estimated_wait_time_seconds: estimated_wait.map(|d| d.as_secs()),
})
}
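    /// Requests cancellation of a tracked task with a human-readable reason.
    /// Returns `Ok(false)` if the task id is unknown.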
pub async fn cancel_task(
&self,
task_id: TaskId,
reason: String,
) -> Result<bool, crate::ServiceError> {
let task_status = self.task_status.read().await;
if !task_status.contains_key(&task_id) {
return Ok(false);
}
drop(task_status);
        self.command_sender
            .send(QueueCommand::CancelTask(task_id, reason))
            .await
            .map_err(|_| crate::ServiceError::Config("Queue worker not running".to_string()))?;
Ok(true)
}
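    /// Returns the current status of a task, if it is tracked.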
pub async fn get_task_status(&self, task_id: &TaskId) -> Option<TaskStatus> {
let task_status = self.task_status.read().await;
task_status.get(task_id).cloned()
}
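    /// Requests a statistics snapshot from the queue worker via a oneshot
    /// channel.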
pub async fn get_stats(&self) -> Result<QueueStats, crate::ServiceError> {
let (sender, receiver) = tokio::sync::oneshot::channel();
        self.command_sender
            .send(QueueCommand::GetStats(sender))
            .await
            .map_err(|_| crate::ServiceError::Config("Queue worker not running".to_string()))?;
receiver
.await
.map_err(|_| crate::ServiceError::Config("Failed to get queue stats".to_string()))
}
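    /// Sends a `Pause` command to the queue worker.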
pub async fn pause(&self) -> Result<(), crate::ServiceError> {
        self.command_sender
            .send(QueueCommand::Pause)
            .await
            .map_err(|_| crate::ServiceError::Config("Queue worker not running".to_string()))?;
Ok(())
}
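    /// Sends a `Resume` command to the queue worker.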
pub async fn resume(&self) -> Result<(), crate::ServiceError> {
        self.command_sender
            .send(QueueCommand::Resume)
            .await
            .map_err(|_| crate::ServiceError::Config("Queue worker not running".to_string()))?;
Ok(())
}
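    /// Rough wait-time estimate: queue position multiplied by an assumed
    /// average processing time, divided across the configured workers.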
async fn estimate_wait_time(&self, position: usize) -> Option<Duration> {
if position == 0 {
return Some(Duration::from_secs(0));
}
let avg_processing_time = Duration::from_secs(10);
let concurrent_workers = self.config.max_concurrent_workers;
let estimated_seconds =
(position as u64 * avg_processing_time.as_secs()) / concurrent_workers as u64;
Some(Duration::from_secs(estimated_seconds))
}
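    /// Sends a `Shutdown` command to the queue worker.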
pub async fn shutdown(&self) -> Result<(), crate::ServiceError> {
        self.command_sender
            .send(QueueCommand::Shutdown)
            .await
            .map_err(|_| crate::ServiceError::Config("Queue worker not running".to_string()))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use terraphim_config::Role;
use tokio::sync::mpsc;
fn create_test_document() -> Document {
Document {
id: "test-doc".to_string(),
title: "Test Document".to_string(),
body: "This is a test document for summarization.".to_string(),
url: "http://example.com".to_string(),
description: None,
summarization: None,
stub: None,
tags: Some(vec![]),
rank: None,
source_haystack: None,
doc_type: terraphim_types::DocumentType::KgEntry,
synonyms: None,
route: None,
priority: None,
}
}
fn create_test_role() -> Role {
Role {
shortname: Some("test-role".to_string()),
name: "Test Role".to_string().into(),
relevance_function: terraphim_types::RelevanceFunction::TitleScorer,
haystacks: vec![],
terraphim_it: false,
theme: "default".to_string(),
kg: None,
llm_enabled: false,
llm_api_key: None,
llm_model: None,
llm_auto_summarize: false,
llm_chat_enabled: false,
llm_chat_system_prompt: None,
llm_chat_model: None,
llm_context_window: Some(32768),
extra: ahash::AHashMap::new(),
llm_router_enabled: false,
llm_router_config: None,
}
}
#[test]
fn test_task_id_generation() {
let id1 = TaskId::new();
let id2 = TaskId::new();
assert_ne!(id1, id2);
}
#[test]
fn test_task_creation() {
let document = create_test_document();
let role = create_test_role();
let task = SummarizationTask::new(document.clone(), role.clone());
assert_eq!(task.document.id, document.id);
assert_eq!(task.role.name, role.name);
assert_eq!(task.priority, Priority::Normal);
assert_eq!(task.retry_count, 0);
assert!(task.can_retry());
}
#[test]
fn test_task_builder_methods() {
let document = create_test_document();
let role = create_test_role();
let task = SummarizationTask::new(document, role)
.with_priority(Priority::High)
.with_max_retries(5)
.with_max_summary_length(500)
.with_force_regenerate(true)
.with_callback_url("http://callback.com".to_string());
assert_eq!(task.priority, Priority::High);
assert_eq!(task.max_retries, 5);
assert_eq!(task.get_summary_length(), 500);
assert!(task.force_regenerate);
assert_eq!(task.callback_url, Some("http://callback.com".to_string()));
}
#[test]
fn test_priority_ordering() {
assert!(Priority::Critical > Priority::High);
assert!(Priority::High > Priority::Normal);
assert!(Priority::Normal > Priority::Low);
}
#[test]
fn test_task_status_checks() {
let pending = TaskStatus::Pending {
queued_at: Utc::now(),
position_in_queue: Some(1),
};
assert!(pending.is_pending());
assert!(!pending.is_terminal());
let completed = TaskStatus::Completed {
summary: "test".to_string(),
completed_at: Utc::now(),
processing_duration_seconds: 10,
};
assert!(completed.is_terminal());
assert!(!completed.is_processing());
}
#[tokio::test]
async fn test_queue_creation() {
let config = QueueConfig::default();
let (command_sender, _receiver) = mpsc::channel(10);
let queue = SummarizationQueue::new(config, command_sender);
assert!(queue.task_status.read().await.is_empty());
}
}