Skip to main content

mofa_foundation/llm/
task_orchestrator.rs

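//! Task orchestration framework for spawning background LLM tasks.
//!
//! This module provides:
//! - Background task spawning via `tokio::spawn`
//! - Origin-based result routing: every result carries the [`TaskOrigin`] it was spawned with
//! - Task lifecycle tracking (`Pending` → `Running` → `Completed` / `Failed`)
//! - Result streaming via a `tokio::sync::broadcast` channel
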
9use crate::llm::LLMProvider;
10use anyhow::Result;
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::{RwLock, broadcast};
17use uuid::Uuid;
18
19/// Where to route task results
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct TaskOrigin {
22    /// Routing key (e.g., "channel:chat_id")
23    pub routing_key: String,
24    /// Optional metadata
25    #[serde(flatten)]
26    pub metadata: HashMap<String, Value>,
27}
28
29impl TaskOrigin {
30    /// Create a new task origin
31    pub fn new(routing_key: impl Into<String>) -> Self {
32        Self {
33            routing_key: routing_key.into(),
34            metadata: HashMap::new(),
35        }
36    }
37
38    /// Create from channel and chat_id
39    pub fn from_channel(channel: &str, chat_id: &str) -> Self {
40        Self::new(format!("{}:{}", channel, chat_id))
41    }
42
43    /// Add metadata
44    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
45        self.metadata.insert(key.into(), value);
46        self
47    }
48}
49
50/// Background task status
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52pub enum TaskStatus {
53    /// Task is pending
54    Pending,
55    /// Task is running
56    Running,
57    /// Task completed successfully
58    Completed(String),
59    /// Task failed
60    Failed(String),
61}
62
63/// Background task with origin tracking
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct BackgroundTask {
66    /// Unique task ID
67    pub id: String,
68    /// Task prompt/description
69    pub prompt: String,
70    /// Where to route results
71    pub origin: TaskOrigin,
72    /// Current status
73    pub status: TaskStatus,
74    /// When the task started
75    pub started_at: DateTime<Utc>,
76    /// When the task completed (if applicable)
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub completed_at: Option<DateTime<Utc>>,
79}
80
81impl BackgroundTask {
82    /// Create a new background task
83    pub fn new(prompt: impl Into<String>, origin: TaskOrigin) -> Self {
84        Self {
85            id: Uuid::new_v4().to_string()[..8].to_string(),
86            prompt: prompt.into(),
87            origin,
88            status: TaskStatus::Pending,
89            started_at: Utc::now(),
90            completed_at: None,
91        }
92    }
93
94    /// Mark as running
95    pub fn mark_running(&mut self) {
96        self.status = TaskStatus::Running;
97    }
98
99    /// Mark as completed
100    pub fn mark_completed(&mut self, result: impl Into<String>) {
101        self.status = TaskStatus::Completed(result.into());
102        self.completed_at = Some(Utc::now());
103    }
104
105    /// Mark as failed
106    pub fn mark_failed(&mut self, error: impl Into<String>) {
107        self.status = TaskStatus::Failed(error.into());
108        self.completed_at = Some(Utc::now());
109    }
110
111    /// Check if task is finished
112    pub fn is_finished(&self) -> bool {
113        matches!(
114            self.status,
115            TaskStatus::Completed(_) | TaskStatus::Failed(_)
116        )
117    }
118}
119
120/// Result from a completed task
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct TaskResult {
123    /// Task ID
124    pub task_id: String,
125    /// Where to route the result
126    pub origin: TaskOrigin,
127    /// Result content
128    pub content: String,
129    /// Whether the task succeeded
130    pub success: bool,
131    /// Timestamp
132    pub timestamp: DateTime<Utc>,
133}
134
135impl TaskResult {
136    /// Create a successful result
137    pub fn success(
138        task_id: impl Into<String>,
139        origin: TaskOrigin,
140        content: impl Into<String>,
141    ) -> Self {
142        Self {
143            task_id: task_id.into(),
144            origin,
145            content: content.into(),
146            success: true,
147            timestamp: Utc::now(),
148        }
149    }
150
151    /// Create a failed result
152    pub fn failure(
153        task_id: impl Into<String>,
154        origin: TaskOrigin,
155        error: impl Into<String>,
156    ) -> Self {
157        Self {
158            task_id: task_id.into(),
159            origin,
160            content: error.into(),
161            success: false,
162            timestamp: Utc::now(),
163        }
164    }
165}
166
/// Tuning knobs for a [`TaskOrchestrator`].
#[derive(Debug, Clone)]
pub struct TaskOrchestratorConfig {
    /// Limit checked by `TaskOrchestrator::spawn` before it accepts a new task.
    pub max_concurrent_tasks: usize,
    /// Model identifier used for every task's chat-completion request.
    pub default_model: String,
}

impl Default for TaskOrchestratorConfig {
    /// Ten concurrent tasks, using the `gpt-4o-mini` model.
    fn default() -> Self {
        let default_model = String::from("gpt-4o-mini");
        Self {
            default_model,
            max_concurrent_tasks: 10,
        }
    }
}
184
185/// Orchestrator for background tasks
186pub struct TaskOrchestrator {
187    /// LLM provider for tasks
188    provider: Arc<dyn LLMProvider>,
189    /// Active tasks
190    active_tasks: Arc<RwLock<HashMap<String, BackgroundTask>>>,
191    /// Result sender
192    result_sender: broadcast::Sender<TaskResult>,
193    /// Configuration
194    config: TaskOrchestratorConfig,
195}
196
197impl TaskOrchestrator {
198    /// Create a new task orchestrator
199    pub fn new(provider: Arc<dyn LLMProvider>, config: TaskOrchestratorConfig) -> Self {
200        let (result_sender, _) = broadcast::channel(100);
201
202        Self {
203            provider,
204            active_tasks: Arc::new(RwLock::new(HashMap::new())),
205            result_sender,
206            config,
207        }
208    }
209
210    /// Create with default configuration
211    pub fn with_defaults(provider: Arc<dyn LLMProvider>) -> Self {
212        Self::new(provider, TaskOrchestratorConfig::default())
213    }
214
215    /// Spawn a background task
216    pub async fn spawn(&self, prompt: &str, origin: TaskOrigin) -> Result<String> {
217        // Check concurrent limit
218        let active_count = self.active_tasks.read().await.len();
219        if active_count >= self.config.max_concurrent_tasks {
220            return Err(anyhow::anyhow!(
221                "Maximum concurrent tasks ({}) reached",
222                self.config.max_concurrent_tasks
223            ));
224        }
225
226        // Create task
227        let mut task = BackgroundTask::new(prompt, origin.clone());
228        let task_id = task.id.clone();
229        task.mark_running();
230
231        // Store task
232        self.active_tasks
233            .write()
234            .await
235            .insert(task_id.clone(), task.clone());
236
237        // Spawn background task
238        let provider = Arc::clone(&self.provider);
239        let active_tasks = Arc::clone(&self.active_tasks);
240        let result_sender = self.result_sender.clone();
241        let model = self.config.default_model.clone();
242        let prompt = prompt.to_string();
243        let task_id_clone = task_id.clone();
244
245        tokio::spawn(async move {
246            let result = Self::run_task(&provider, &model, &prompt).await;
247
248            // Update task status
249            {
250                let mut tasks = active_tasks.write().await;
251                if let Some(task) = tasks.get_mut(&task_id_clone) {
252                    match &result {
253                        Ok(content) => task.mark_completed(content),
254                        Err(e) => task.mark_failed(e.to_string()),
255                    }
256                }
257            }
258
259            // Send result
260            let task_result = match &result {
261                Ok(content) => TaskResult::success(&task_id_clone, origin, content),
262                Err(e) => TaskResult::failure(&task_id_clone, origin, e.to_string()),
263            };
264
265            let _ = result_sender.send(task_result);
266
267            // Cleanup completed tasks after a delay
268            tokio::time::sleep(tokio::time::Duration::from_secs(300)).await;
269            let mut tasks = active_tasks.write().await;
270            tasks.remove(&task_id_clone);
271        });
272
273        Ok(task_id)
274    }
275
276    /// Run a single task to completion
277    async fn run_task(
278        provider: &Arc<dyn LLMProvider>,
279        model: &str,
280        prompt: &str,
281    ) -> Result<String> {
282        use crate::llm::types::ChatCompletionRequest;
283
284        let request = ChatCompletionRequest::new(model)
285            .system(
286                "You are a helpful assistant. Complete the given task thoroughly and concisely.",
287            )
288            .user(prompt);
289
290        let response = provider.chat(request).await?;
291
292        response
293            .content()
294            .map(|s| s.to_string())
295            .ok_or_else(|| anyhow::anyhow!("No response content"))
296    }
297
298    /// Subscribe to task results
299    pub fn subscribe_results(&self) -> broadcast::Receiver<TaskResult> {
300        self.result_sender.subscribe()
301    }
302
303    /// Get all active tasks
304    pub async fn get_active_tasks(&self) -> Vec<BackgroundTask> {
305        self.active_tasks.read().await.values().cloned().collect()
306    }
307
308    /// Get a specific task
309    pub async fn get_task(&self, task_id: &str) -> Option<BackgroundTask> {
310        self.active_tasks.read().await.get(task_id).cloned()
311    }
312
313    /// Get the configuration
314    pub fn config(&self) -> &TaskOrchestratorConfig {
315        &self.config
316    }
317}