sekuire 0.1.0

The official SDK for the Sekuire Agent Identity Protocol
Documentation
//! Sekuire Task Worker
//!
//! Handles SSE connection to Core for real-time task delivery.
//! Provides `on_task()` registration for capability handlers.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use futures_util::StreamExt;
use reqwest::header::{HeaderMap, AUTHORIZATION, CONTENT_TYPE};
use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tokio::time::sleep;

// ============================================================================
// Types
// ============================================================================

/// A task delivered by Core as the JSON payload of a `new_task` SSE event.
///
/// Some fields also accept an alternative wire name through `serde(alias)`.
#[derive(Clone, Debug, Deserialize)]
pub struct TaskEvent {
    /// Identifier of the task (the wire name `id` is accepted too).
    #[serde(alias = "id")]
    pub task_id: String,
    /// Capability used to pick a handler (the wire name `target_capability` is accepted too).
    #[serde(alias = "target_capability")]
    pub capability: Option<String>,
    /// Tool name, used for routing when `capability` is absent
    /// (the wire name `target_tool` is accepted too).
    #[serde(alias = "target_tool")]
    pub tool: Option<String>,
    /// Raw JSON input passed to the selected handler.
    pub input: serde_json::Value,
    /// Workspace identifier for the task.
    pub workspace_id: String,
    /// Requesting agent's id, if provided; surfaced to handlers as `TaskContext::requester_id`.
    pub requester_agent_id: Option<String>,
    /// Parent task identifier, if provided.
    pub parent_task_id: Option<String>,
    /// Workflow identifier, if provided.
    pub workflow_id: Option<String>,
}

/// Identity metadata for the task being executed, handed to every task handler
/// alongside the task's JSON input.
#[derive(Clone, Debug)]
pub struct TaskContext {
    /// Identifier of the task being handled.
    pub task_id: String,
    /// Workspace identifier the task was delivered for.
    pub workspace_id: String,
    /// Requesting agent's id, copied from the event's `requester_agent_id`.
    pub requester_id: Option<String>,
    /// Workflow identifier copied from the event, if present.
    pub workflow_id: Option<String>,
}

impl From<&TaskEvent> for TaskContext {
    fn from(event: &TaskEvent) -> Self {
        Self {
            task_id: event.task_id.clone(),
            workspace_id: event.workspace_id.clone(),
            requester_id: event.requester_agent_id.clone(),
            workflow_id: event.workflow_id.clone(),
        }
    }
}

/// Handler for one capability, invoked by `TaskWorker` for each routed task.
///
/// `TaskWorker` reports an `Ok` value as the task output with status
/// `"completed"`. It reports an `Err` as status `"failed"`, with the error's
/// `Display` text stored under `"error"`.
#[async_trait::async_trait]
pub trait TaskHandler: Send + Sync {
    /// Process one task given its context and raw JSON input.
    async fn handle(
        &self,
        ctx: TaskContext,
        input: serde_json::Value,
    ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>;
}

/// Adapter that turns an async function or closure into a [`TaskHandler`].
///
/// The wrapped callable is invoked with the task context and input, and its
/// future's result is returned unchanged.
pub struct FnHandler<F>(pub F);

#[async_trait::async_trait]
impl<F, Fut> TaskHandler for FnHandler<F>
where
    F: Fn(TaskContext, serde_json::Value) -> Fut + Send + Sync,
    Fut: std::future::Future<
            Output = Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>,
        > + Send,
{
    async fn handle(
        &self,
        ctx: TaskContext,
        input: serde_json::Value,
    ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
        let FnHandler(callback) = self;
        let pending = callback(ctx, input);
        pending.await
    }
}

/// Worker configuration.
///
/// The `Debug` implementation redacts `token`, so a configuration can be
/// logged with `{:?}` without leaking the bearer credential.
#[derive(Clone)]
pub struct WorkerConfig {
    /// Base URL of the Core API; request paths such as `/api/v1/tasks/stream`
    /// are appended to it.
    pub api_base_url: String,
    /// Bearer token sent in the `Authorization` header.
    pub token: String,
    /// Pause between consecutive heartbeat requests.
    pub heartbeat_interval: Duration,
    /// Initial delay before reconnecting the SSE stream.
    pub reconnect_delay: Duration,
    /// Upper bound for the exponentially growing reconnect delay.
    pub max_reconnect_delay: Duration,
}

impl std::fmt::Debug for WorkerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WorkerConfig")
            .field("api_base_url", &self.api_base_url)
            // Never print the credential itself.
            .field("token", &"<redacted>")
            .field("heartbeat_interval", &self.heartbeat_interval)
            .field("reconnect_delay", &self.reconnect_delay)
            .field("max_reconnect_delay", &self.max_reconnect_delay)
            .finish()
    }
}

impl WorkerConfig {
    pub fn new(api_base_url: impl Into<String>, token: impl Into<String>) -> Self {
        Self {
            api_base_url: api_base_url.into(),
            token: token.into(),
            heartbeat_interval: Duration::from_secs(10),
            reconnect_delay: Duration::from_secs(1),
            max_reconnect_delay: Duration::from_secs(30),
        }
    }
}

/// JSON body POSTed to `/api/v1/tasks/{task_id}/complete` once a task finishes.
///
/// The worker sends `status` as `"completed"` (with `output`) or `"failed"`
/// (with `error`). Absent options are serialized as JSON `null`.
#[derive(Serialize, Debug)]
struct CompleteTaskRequest {
    status: String,
    output: Option<serde_json::Value>,
    error: Option<serde_json::Value>,
}

/// JSON body POSTed periodically to `/api/v1/agents/heartbeat`.
///
/// The worker currently always reports `status: "idle"` and `load: 0.0`.
#[derive(Serialize, Debug)]
struct HeartbeatRequest {
    status: String,
    load: f32,
}

// ============================================================================
// Task Worker
// ============================================================================

/// Worker that receives tasks from Core over an SSE stream and dispatches them
/// to the registered [`TaskHandler`]s.
pub struct TaskWorker {
    /// Endpoint, credential and timing settings.
    config: WorkerConfig,
    /// Handlers keyed by capability name; the `"default"` entry is the fallback.
    handlers: Arc<RwLock<HashMap<String, Arc<dyn TaskHandler>>>>,
    /// Shared HTTP client with the bearer token installed as a default header.
    client: reqwest::Client,
    /// Set by `start`, cleared by `stop`; also read by the heartbeat task.
    running: Arc<RwLock<bool>>,
}

impl TaskWorker {
    /// Create a new TaskWorker
    pub fn new(config: WorkerConfig) -> Self {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            format!("Bearer {}", config.token).parse().unwrap(),
        );
        headers.insert(CONTENT_TYPE, "application/json".parse().unwrap());

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .unwrap();

        Self {
            config,
            handlers: Arc::new(RwLock::new(HashMap::new())),
            client,
            running: Arc::new(RwLock::new(false)),
        }
    }

    /// Register a handler for a capability
    pub async fn on_task<H: TaskHandler + 'static>(&self, capability: &str, handler: H) {
        let mut handlers = self.handlers.write().await;
        handlers.insert(capability.to_string(), Arc::new(handler));
        eprintln!("[Worker] Registered handler for capability: {}", capability);
    }

    /// Register a function as a handler
    pub async fn on_task_fn<F, Fut>(&self, capability: &str, handler: F)
    where
        F: Fn(TaskContext, serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<
                Output = Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>>,
            > + Send
            + 'static,
    {
        self.on_task(capability, FnHandler(handler)).await;
    }

    /// Start the worker (runs until stopped)
    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        *self.running.write().await = true;

        // Start heartbeat task
        let heartbeat_config = self.config.clone();
        let heartbeat_client = self.client.clone();
        let heartbeat_running = self.running.clone();
        tokio::spawn(async move {
            Self::heartbeat_loop(heartbeat_config, heartbeat_client, heartbeat_running).await;
        });

        // Main SSE loop with reconnection
        let mut current_delay = self.config.reconnect_delay;

        while *self.running.read().await {
            match self.connect().await {
                Ok(()) => {
                    // Connection closed normally
                    current_delay = self.config.reconnect_delay;
                }
                Err(e) => {
                    eprintln!("[Worker] SSE connection error: {:?}", e);
                }
            }

            if *self.running.read().await {
                eprintln!("[Worker] Reconnecting in {:?}...", current_delay);
                sleep(current_delay).await;

                // Exponential backoff
                current_delay = std::cmp::min(current_delay * 2, self.config.max_reconnect_delay);
            }
        }

        Ok(())
    }

    /// Stop the worker
    pub async fn stop(&self) {
        eprintln!("[Worker] Stopping task worker...");
        *self.running.write().await = false;
    }

    /// Connect to SSE stream
    async fn connect(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let url = format!("{}/api/v1/tasks/stream", self.config.api_base_url);
        eprintln!("[Worker] Connecting to SSE stream: {}", url);

        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            format!("Bearer {}", self.config.token).parse().unwrap(),
        );

        // EventSource expects a RequestBuilder, not a built Request
        let builder = reqwest::Client::new().get(&url).headers(headers);
        let mut es = EventSource::new(builder)?;

        eprintln!("[Worker] SSE connection established");

        while let Some(event) = es.next().await {
            if !*self.running.read().await {
                break;
            }

            match event {
                Ok(Event::Open) => {
                    eprintln!("[Worker] SSE stream opened");
                }
                Ok(Event::Message(msg)) => {
                    if msg.event == "new_task" {
                        if let Ok(task) = serde_json::from_str::<TaskEvent>(&msg.data) {
                            self.handle_task(task).await;
                        }
                    } else if msg.event == "connected" {
                        eprintln!("[Worker] Connected event: {}", msg.data);
                    }
                }
                Err(e) => {
                    eprintln!("[Worker] SSE error: {:?}", e);
                    return Err(Box::new(e));
                }
            }
        }

        Ok(())
    }

    /// Handle incoming task
    async fn handle_task(&self, task: TaskEvent) {
        eprintln!("[Worker] Received task: {}", task.task_id);

        let capability = task
            .capability
            .clone()
            .or_else(|| task.tool.clone())
            .unwrap_or_else(|| "default".to_string());

        let handler = {
            let handlers = self.handlers.read().await;
            handlers
                .get(&capability)
                .or_else(|| handlers.get("default"))
                .cloned()
        };

        let Some(handler) = handler else {
            eprintln!("[Worker] No handler for capability: {}", capability);
            self.complete_task(
                &task.task_id,
                "failed",
                None,
                Some(serde_json::json!({
                    "error": format!("No handler registered for capability: {}", capability)
                })),
            )
            .await;
            return;
        };

        let ctx = TaskContext::from(&task);

        eprintln!("[Worker] Executing handler for: {}", capability);

        match handler.handle(ctx, task.input.clone()).await {
            Ok(output) => {
                self.complete_task(&task.task_id, "completed", Some(output), None)
                    .await;
                eprintln!("[Worker] Task {} completed successfully", task.task_id);
            }
            Err(e) => {
                eprintln!("[Worker] Task {} failed: {:?}", task.task_id, e);
                self.complete_task(
                    &task.task_id,
                    "failed",
                    None,
                    Some(serde_json::json!({"error": e.to_string()})),
                )
                .await;
            }
        }
    }

    /// Report task completion
    async fn complete_task(
        &self,
        task_id: &str,
        status: &str,
        output: Option<serde_json::Value>,
        error: Option<serde_json::Value>,
    ) {
        let url = format!(
            "{}/api/v1/tasks/{}/complete",
            self.config.api_base_url, task_id
        );

        let request = CompleteTaskRequest {
            status: status.to_string(),
            output,
            error,
        };

        if let Err(e) = self.client.post(&url).json(&request).send().await {
            eprintln!("[Worker] Failed to report task completion: {:?}", e);
        }
    }

    /// Heartbeat loop
    async fn heartbeat_loop(
        config: WorkerConfig,
        client: reqwest::Client,
        running: Arc<RwLock<bool>>,
    ) {
        let url = format!("{}/api/v1/agents/heartbeat", config.api_base_url);

        while *running.read().await {
            let request = HeartbeatRequest {
                status: "idle".to_string(),
                load: 0.0,
            };

            if let Err(e) = client.post(&url).json(&request).send().await {
                eprintln!("[Worker] Heartbeat failed: {:?}", e);
            }

            sleep(config.heartbeat_interval).await;
        }
    }
}

// ============================================================================
// Convenience Function
// ============================================================================

/// Convenience constructor for a worker with the default timings.
///
/// Equivalent to `TaskWorker::new(WorkerConfig::new(api_base_url, token))`.
pub fn create_worker(api_base_url: impl Into<String>, token: impl Into<String>) -> TaskWorker {
    let config = WorkerConfig::new(api_base_url, token);
    TaskWorker::new(config)
}