1use crate::llm::LLMProvider;
10use anyhow::Result;
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::{RwLock, broadcast};
17use uuid::Uuid;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct TaskOrigin {
22 pub routing_key: String,
24 #[serde(flatten)]
26 pub metadata: HashMap<String, Value>,
27}
28
29impl TaskOrigin {
30 pub fn new(routing_key: impl Into<String>) -> Self {
32 Self {
33 routing_key: routing_key.into(),
34 metadata: HashMap::new(),
35 }
36 }
37
38 pub fn from_channel(channel: &str, chat_id: &str) -> Self {
40 Self::new(format!("{}:{}", channel, chat_id))
41 }
42
43 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
45 self.metadata.insert(key.into(), value);
46 self
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52pub enum TaskStatus {
53 Pending,
55 Running,
57 Completed(String),
59 Failed(String),
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct BackgroundTask {
66 pub id: String,
68 pub prompt: String,
70 pub origin: TaskOrigin,
72 pub status: TaskStatus,
74 pub started_at: DateTime<Utc>,
76 #[serde(skip_serializing_if = "Option::is_none")]
78 pub completed_at: Option<DateTime<Utc>>,
79}
80
81impl BackgroundTask {
82 pub fn new(prompt: impl Into<String>, origin: TaskOrigin) -> Self {
84 Self {
85 id: Uuid::new_v4().to_string()[..8].to_string(),
86 prompt: prompt.into(),
87 origin,
88 status: TaskStatus::Pending,
89 started_at: Utc::now(),
90 completed_at: None,
91 }
92 }
93
94 pub fn mark_running(&mut self) {
96 self.status = TaskStatus::Running;
97 }
98
99 pub fn mark_completed(&mut self, result: impl Into<String>) {
101 self.status = TaskStatus::Completed(result.into());
102 self.completed_at = Some(Utc::now());
103 }
104
105 pub fn mark_failed(&mut self, error: impl Into<String>) {
107 self.status = TaskStatus::Failed(error.into());
108 self.completed_at = Some(Utc::now());
109 }
110
111 pub fn is_finished(&self) -> bool {
113 matches!(
114 self.status,
115 TaskStatus::Completed(_) | TaskStatus::Failed(_)
116 )
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct TaskResult {
123 pub task_id: String,
125 pub origin: TaskOrigin,
127 pub content: String,
129 pub success: bool,
131 pub timestamp: DateTime<Utc>,
133}
134
135impl TaskResult {
136 pub fn success(
138 task_id: impl Into<String>,
139 origin: TaskOrigin,
140 content: impl Into<String>,
141 ) -> Self {
142 Self {
143 task_id: task_id.into(),
144 origin,
145 content: content.into(),
146 success: true,
147 timestamp: Utc::now(),
148 }
149 }
150
151 pub fn failure(
153 task_id: impl Into<String>,
154 origin: TaskOrigin,
155 error: impl Into<String>,
156 ) -> Self {
157 Self {
158 task_id: task_id.into(),
159 origin,
160 content: error.into(),
161 success: false,
162 timestamp: Utc::now(),
163 }
164 }
165}
166
/// Tunables for a [`TaskOrchestrator`].
#[derive(Debug, Clone)]
pub struct TaskOrchestratorConfig {
    /// Task limit checked by `TaskOrchestrator::spawn` before starting a new task.
    pub max_concurrent_tasks: usize,
    /// Model name sent to the provider for every spawned task.
    pub default_model: String,
}

impl Default for TaskOrchestratorConfig {
    /// Defaults: at most 10 concurrent tasks, using the `"gpt-4o-mini"` model.
    fn default() -> Self {
        let default_model = String::from("gpt-4o-mini");
        let max_concurrent_tasks = 10;
        Self {
            max_concurrent_tasks,
            default_model,
        }
    }
}
184
185pub struct TaskOrchestrator {
187 provider: Arc<dyn LLMProvider>,
189 active_tasks: Arc<RwLock<HashMap<String, BackgroundTask>>>,
191 result_sender: broadcast::Sender<TaskResult>,
193 config: TaskOrchestratorConfig,
195}
196
197impl TaskOrchestrator {
198 pub fn new(provider: Arc<dyn LLMProvider>, config: TaskOrchestratorConfig) -> Self {
200 let (result_sender, _) = broadcast::channel(100);
201
202 Self {
203 provider,
204 active_tasks: Arc::new(RwLock::new(HashMap::new())),
205 result_sender,
206 config,
207 }
208 }
209
210 pub fn with_defaults(provider: Arc<dyn LLMProvider>) -> Self {
212 Self::new(provider, TaskOrchestratorConfig::default())
213 }
214
215 pub async fn spawn(&self, prompt: &str, origin: TaskOrigin) -> Result<String> {
217 let active_count = self.active_tasks.read().await.len();
219 if active_count >= self.config.max_concurrent_tasks {
220 return Err(anyhow::anyhow!(
221 "Maximum concurrent tasks ({}) reached",
222 self.config.max_concurrent_tasks
223 ));
224 }
225
226 let mut task = BackgroundTask::new(prompt, origin.clone());
228 let task_id = task.id.clone();
229 task.mark_running();
230
231 self.active_tasks
233 .write()
234 .await
235 .insert(task_id.clone(), task.clone());
236
237 let provider = Arc::clone(&self.provider);
239 let active_tasks = Arc::clone(&self.active_tasks);
240 let result_sender = self.result_sender.clone();
241 let model = self.config.default_model.clone();
242 let prompt = prompt.to_string();
243 let task_id_clone = task_id.clone();
244
245 tokio::spawn(async move {
246 let result = Self::run_task(&provider, &model, &prompt).await;
247
248 {
250 let mut tasks = active_tasks.write().await;
251 if let Some(task) = tasks.get_mut(&task_id_clone) {
252 match &result {
253 Ok(content) => task.mark_completed(content),
254 Err(e) => task.mark_failed(e.to_string()),
255 }
256 }
257 }
258
259 let task_result = match &result {
261 Ok(content) => TaskResult::success(&task_id_clone, origin, content),
262 Err(e) => TaskResult::failure(&task_id_clone, origin, e.to_string()),
263 };
264
265 let _ = result_sender.send(task_result);
266
267 tokio::time::sleep(tokio::time::Duration::from_secs(300)).await;
269 let mut tasks = active_tasks.write().await;
270 tasks.remove(&task_id_clone);
271 });
272
273 Ok(task_id)
274 }
275
276 async fn run_task(
278 provider: &Arc<dyn LLMProvider>,
279 model: &str,
280 prompt: &str,
281 ) -> Result<String> {
282 use crate::llm::types::ChatCompletionRequest;
283
284 let request = ChatCompletionRequest::new(model)
285 .system(
286 "You are a helpful assistant. Complete the given task thoroughly and concisely.",
287 )
288 .user(prompt);
289
290 let response = provider.chat(request).await?;
291
292 response
293 .content()
294 .map(|s| s.to_string())
295 .ok_or_else(|| anyhow::anyhow!("No response content"))
296 }
297
298 pub fn subscribe_results(&self) -> broadcast::Receiver<TaskResult> {
300 self.result_sender.subscribe()
301 }
302
303 pub async fn get_active_tasks(&self) -> Vec<BackgroundTask> {
305 self.active_tasks.read().await.values().cloned().collect()
306 }
307
308 pub async fn get_task(&self, task_id: &str) -> Option<BackgroundTask> {
310 self.active_tasks.read().await.get(task_id).cloned()
311 }
312
313 pub fn config(&self) -> &TaskOrchestratorConfig {
315 &self.config
316 }
317}