use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use futures_util::StreamExt;
use reqwest::header::{HeaderMap, AUTHORIZATION, CONTENT_TYPE};
use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tokio::time::sleep;
/// A task delivered by the server as the JSON payload of a `new_task` SSE message.
///
/// Several fields accept an alternative wire name through `serde(alias)`:
/// `id` for `task_id`, `target_capability` for `capability`, and `target_tool`
/// for `tool`.
#[derive(Clone, Debug, Deserialize)]
pub struct TaskEvent {
    /// Task identifier (wire alias `id`); used to build the completion URL.
    #[serde(alias = "id")]
    pub task_id: String,
    /// Handler key (wire alias `target_capability`); takes precedence over `tool`.
    #[serde(alias = "target_capability")]
    pub capability: Option<String>,
    /// Fallback handler key (wire alias `target_tool`) used when `capability` is absent.
    #[serde(alias = "target_tool")]
    pub tool: Option<String>,
    /// JSON payload passed to the selected handler.
    pub input: serde_json::Value,
    /// Copied into [`TaskContext::workspace_id`].
    pub workspace_id: String,
    /// Copied into [`TaskContext::requester_id`].
    pub requester_agent_id: Option<String>,
    /// Deserialized but not read by the worker in this file.
    pub parent_task_id: Option<String>,
    /// Copied into [`TaskContext::workflow_id`].
    pub workflow_id: Option<String>,
}
/// Per-task metadata handed to a [`TaskHandler`] alongside the task's input.
///
/// Built from a [`TaskEvent`]; the event's `requester_agent_id` appears here as
/// `requester_id`.
#[derive(Clone, Debug)]
pub struct TaskContext {
    /// Copied from `TaskEvent::task_id`.
    pub task_id: String,
    /// Copied from `TaskEvent::workspace_id`.
    pub workspace_id: String,
    /// Copied from `TaskEvent::requester_agent_id`.
    pub requester_id: Option<String>,
    /// Copied from `TaskEvent::workflow_id`.
    pub workflow_id: Option<String>,
}
impl From<&TaskEvent> for TaskContext {
fn from(event: &TaskEvent) -> Self {
Self {
task_id: event.task_id.clone(),
workspace_id: event.workspace_id.clone(),
requester_id: event.requester_agent_id.clone(),
workflow_id: event.workflow_id.clone(),
}
}
}
/// Asynchronous task processor registered on a worker under a capability name.
///
/// The worker reports an `Ok` value as the task's output with status `completed`.
/// For an `Err`, it reports status `failed` and sends the error's `Display` text
/// as `{"error": ...}`.
#[async_trait::async_trait]
pub trait TaskHandler: Send + Sync {
    /// Processes a single task given its metadata and JSON input.
    async fn handle(
        &self,
        ctx: TaskContext,
        input: serde_json::Value,
    ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>;
}
/// Adapter that turns a closure returning a future into a [`TaskHandler`].
pub struct FnHandler<F>(pub F);

#[async_trait::async_trait]
impl<H, R> TaskHandler for FnHandler<H>
where
    H: Fn(TaskContext, serde_json::Value) -> R + Send + Sync,
    R: std::future::Future<
            Output = Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>,
        > + Send,
{
    /// Invokes the wrapped closure and awaits the future it returns.
    async fn handle(
        &self,
        ctx: TaskContext,
        input: serde_json::Value,
    ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
        let callback = &self.0;
        callback(ctx, input).await
    }
}
/// Connection and timing settings for a [`TaskWorker`].
///
/// `Debug` is implemented by hand rather than derived, so the bearer token never
/// shows up in logs that format the config with `{:?}`.
#[derive(Clone)]
pub struct WorkerConfig {
    /// Base URL; endpoint paths such as `/api/v1/tasks/stream` are appended to it.
    pub api_base_url: String,
    /// Credential sent as `Authorization: Bearer <token>`.
    pub token: String,
    /// Pause between consecutive heartbeat requests.
    pub heartbeat_interval: Duration,
    /// Initial delay before reconnecting the task stream.
    pub reconnect_delay: Duration,
    /// Upper bound on the reconnect delay.
    pub max_reconnect_delay: Duration,
}

impl std::fmt::Debug for WorkerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WorkerConfig")
            .field("api_base_url", &self.api_base_url)
            // Redact the secret; a derived impl would print it verbatim.
            .field("token", &"<redacted>")
            .field("heartbeat_interval", &self.heartbeat_interval)
            .field("reconnect_delay", &self.reconnect_delay)
            .field("max_reconnect_delay", &self.max_reconnect_delay)
            .finish()
    }
}

impl WorkerConfig {
    /// Creates a config with default timings.
    ///
    /// The defaults are a 10 s heartbeat interval, a 1 s initial reconnect delay,
    /// and a 30 s maximum reconnect delay.
    pub fn new(api_base_url: impl Into<String>, token: impl Into<String>) -> Self {
        Self {
            api_base_url: api_base_url.into(),
            token: token.into(),
            heartbeat_interval: Duration::from_secs(10),
            reconnect_delay: Duration::from_secs(1),
            max_reconnect_delay: Duration::from_secs(30),
        }
    }
}
/// JSON body posted to `/api/v1/tasks/{id}/complete`.
///
/// Fields serialize in declaration order; `None` values serialize as `null`.
#[derive(Serialize, Debug)]
struct CompleteTaskRequest {
    /// Either `"completed"` or `"failed"` as sent by the worker.
    status: String,
    /// Handler result on success.
    output: Option<serde_json::Value>,
    /// Error object on failure.
    error: Option<serde_json::Value>,
}

/// JSON body posted to `/api/v1/agents/heartbeat`.
#[derive(Serialize, Debug)]
struct HeartbeatRequest {
    /// Reported agent status (the worker always sends `"idle"`).
    status: String,
    /// Reported load figure (the worker always sends `0.0`).
    load: f32,
}
/// Worker that receives tasks over SSE and dispatches them to capability handlers.
///
/// It reports each outcome back to the API and sends periodic heartbeats while running.
pub struct TaskWorker {
    /// Settings supplied at construction.
    config: WorkerConfig,
    /// Handlers keyed by capability name; the `"default"` entry is the fallback.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn TaskHandler>>>>,
    /// HTTP client built in `new` with the auth and content-type default headers.
    client: reqwest::Client,
    /// Set by `start`, cleared by `stop`; polled by the stream and heartbeat loops.
    running: Arc<RwLock<bool>>,
}
impl TaskWorker {
    /// Creates a worker from `config`.
    ///
    /// The worker's HTTP client sends `Authorization: Bearer <token>` and
    /// `Content-Type: application/json` as default headers on every request.
    ///
    /// # Panics
    ///
    /// Panics if `config.token` cannot be placed in an HTTP header value (for
    /// example, it contains a newline) or if the HTTP client cannot be built.
    pub fn new(config: WorkerConfig) -> Self {
        let mut auth_value =
            reqwest::header::HeaderValue::from_str(&format!("Bearer {}", config.token))
                .expect("worker token must be a valid HTTP header value");
        // Mark the credential sensitive so header `Debug` output does not reveal it.
        auth_value.set_sensitive(true);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, auth_value);
        headers.insert(
            CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );
        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .expect("failed to build the task worker's HTTP client");

        Self {
            config,
            handlers: Arc::new(RwLock::new(HashMap::new())),
            client,
            running: Arc::new(RwLock::new(false)),
        }
    }

    /// Registers `handler` for tasks whose capability (or, failing that, tool) equals
    /// `capability`.
    ///
    /// A handler already registered under that name is replaced. Registering under
    /// `"default"` installs the fallback used when no exact match exists.
    pub async fn on_task<H: TaskHandler + 'static>(&self, capability: &str, handler: H) {
        let mut handlers = self.handlers.write().await;
        handlers.insert(capability.to_string(), Arc::new(handler));
        eprintln!("[Worker] Registered handler for capability: {}", capability);
    }

    /// Like [`TaskWorker::on_task`], but accepts a closure returning a future.
    pub async fn on_task_fn<F, Fut>(&self, capability: &str, handler: F)
    where
        F: Fn(TaskContext, serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<
                Output = Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>,
            > + Send
            + 'static,
    {
        self.on_task(capability, FnHandler(handler)).await;
    }

    /// Runs the worker until [`TaskWorker::stop`] is called.
    ///
    /// Spawns the heartbeat loop, then repeatedly subscribes to the task stream.
    /// After each disconnect it waits before reconnecting. The wait doubles up to
    /// `max_reconnect_delay` and goes back to `reconnect_delay` whenever a
    /// connection successfully opens.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; connection errors are logged and retried.
    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        *self.running.write().await = true;

        let heartbeat_config = self.config.clone();
        let heartbeat_client = self.client.clone();
        let heartbeat_running = self.running.clone();
        tokio::spawn(async move {
            Self::heartbeat_loop(heartbeat_config, heartbeat_client, heartbeat_running).await;
        });

        let mut current_delay = self.config.reconnect_delay;
        while *self.running.read().await {
            // `connect` resets `current_delay` itself once the stream opens, because a
            // dropped-but-healthy stream is reported as an error, not as `Ok`.
            match self.connect(&mut current_delay).await {
                Ok(()) => current_delay = self.config.reconnect_delay,
                Err(e) => eprintln!("[Worker] SSE connection error: {:?}", e),
            }
            if *self.running.read().await {
                eprintln!("[Worker] Reconnecting in {:?}...", current_delay);
                sleep(current_delay).await;
                // Saturate so a huge configured maximum cannot overflow-panic.
                current_delay = current_delay
                    .saturating_mul(2)
                    .min(self.config.max_reconnect_delay);
            }
        }
        Ok(())
    }

    /// Clears the running flag.
    ///
    /// `start` returns once its loop observes the flag. An open stream checks the
    /// flag only when its next event arrives.
    pub async fn stop(&self) {
        eprintln!("[Worker] Stopping task worker...");
        *self.running.write().await = false;
    }

    /// Subscribes once to `/api/v1/tasks/stream` and dispatches events until the
    /// stream errors or the worker is stopped.
    ///
    /// On `Event::Open`, `backoff` is reset to `config.reconnect_delay`.
    /// `new_task` events are handled inline, so a slow handler delays later events.
    async fn connect(
        &self,
        backoff: &mut Duration,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let url = format!("{}/api/v1/tasks/stream", self.config.api_base_url);
        eprintln!("[Worker] Connecting to SSE stream: {}", url);
        // Reuse the shared client: it already carries the Authorization header and
        // keeps its connection pool across reconnects.
        let mut es = EventSource::new(self.client.get(&url))?;

        while let Some(event) = es.next().await {
            if !*self.running.read().await {
                break;
            }
            match event {
                Ok(Event::Open) => {
                    eprintln!("[Worker] SSE stream opened");
                    *backoff = self.config.reconnect_delay;
                }
                Ok(Event::Message(msg)) => match msg.event.as_str() {
                    "new_task" => match serde_json::from_str::<TaskEvent>(&msg.data) {
                        Ok(task) => self.handle_task(task).await,
                        Err(e) => {
                            eprintln!("[Worker] Ignoring malformed new_task event: {}", e)
                        }
                    },
                    "connected" => eprintln!("[Worker] Connected event: {}", msg.data),
                    _ => {}
                },
                Err(e) => {
                    eprintln!("[Worker] SSE error: {:?}", e);
                    return Err(Box::new(e));
                }
            }
        }
        Ok(())
    }

    /// Selects a handler for `task`, runs it, and reports the outcome.
    ///
    /// The lookup key is `capability`, then `tool`, then `"default"`. If no key
    /// matches, the registered `"default"` handler is used. If none exists, the
    /// task is reported as `failed`.
    async fn handle_task(&self, task: TaskEvent) {
        eprintln!("[Worker] Received task: {}", task.task_id);
        let capability = task
            .capability
            .as_deref()
            .or(task.tool.as_deref())
            .unwrap_or("default")
            .to_string();

        let handler = {
            let handlers = self.handlers.read().await;
            handlers
                .get(&capability)
                .or_else(|| handlers.get("default"))
                .cloned()
        };
        let Some(handler) = handler else {
            eprintln!("[Worker] No handler for capability: {}", capability);
            self.complete_task(
                &task.task_id,
                "failed",
                None,
                Some(serde_json::json!({
                    "error": format!("No handler registered for capability: {}", capability)
                })),
            )
            .await;
            return;
        };

        let ctx = TaskContext::from(&task);
        eprintln!("[Worker] Executing handler for: {}", capability);
        // `task` is owned: move the input out instead of cloning a potentially large value.
        match handler.handle(ctx, task.input).await {
            Ok(output) => {
                self.complete_task(&task.task_id, "completed", Some(output), None)
                    .await;
                eprintln!("[Worker] Task {} completed successfully", task.task_id);
            }
            Err(e) => {
                eprintln!("[Worker] Task {} failed: {:?}", task.task_id, e);
                self.complete_task(
                    &task.task_id,
                    "failed",
                    None,
                    Some(serde_json::json!({"error": e.to_string()})),
                )
                .await;
            }
        }
    }

    /// Posts the task outcome to `/api/v1/tasks/{task_id}/complete`.
    ///
    /// Best-effort: transport failures and non-success HTTP statuses are logged,
    /// not returned.
    async fn complete_task(
        &self,
        task_id: &str,
        status: &str,
        output: Option<serde_json::Value>,
        error: Option<serde_json::Value>,
    ) {
        let url = format!(
            "{}/api/v1/tasks/{}/complete",
            self.config.api_base_url, task_id
        );
        let request = CompleteTaskRequest {
            status: status.to_string(),
            output,
            error,
        };
        // `send` succeeds on 4xx/5xx; `error_for_status` turns those into errors too.
        let result = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .and_then(reqwest::Response::error_for_status);
        if let Err(e) = result {
            eprintln!("[Worker] Failed to report task completion: {:?}", e);
        }
    }

    /// Posts an `idle` heartbeat with load `0.0` every `heartbeat_interval` while
    /// `running` is set.
    ///
    /// Failures, including non-success HTTP statuses, are logged and do not stop
    /// the loop.
    async fn heartbeat_loop(
        config: WorkerConfig,
        client: reqwest::Client,
        running: Arc<RwLock<bool>>,
    ) {
        let url = format!("{}/api/v1/agents/heartbeat", config.api_base_url);
        while *running.read().await {
            let request = HeartbeatRequest {
                status: "idle".to_string(),
                load: 0.0,
            };
            let result = client
                .post(&url)
                .json(&request)
                .send()
                .await
                .and_then(reqwest::Response::error_for_status);
            if let Err(e) = result {
                eprintln!("[Worker] Heartbeat failed: {:?}", e);
            }
            sleep(config.heartbeat_interval).await;
        }
    }
}
/// Shorthand for `TaskWorker::new(WorkerConfig::new(api_base_url, token))`.
///
/// The returned worker uses the default timings from [`WorkerConfig::new`].
pub fn create_worker(api_base_url: impl Into<String>, token: impl Into<String>) -> TaskWorker {
    let config = WorkerConfig::new(api_base_url, token);
    TaskWorker::new(config)
}