systemprompt-agent 0.2.0

Core Agent protocol module for systemprompt.io
Documentation
use std::sync::Arc;

use axum::response::sse::Event;
use serde_json::json;
use systemprompt_identifiers::{ContextId, MessageId, TaskId};
use systemprompt_models::{A2AEventBuilder, AgUiEventBuilder, RequestContext};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};

use crate::models::a2a::jsonrpc::NumberOrString;
use crate::models::a2a::{Message, TaskState};
use crate::repository::task::TaskRepository;
use crate::services::a2a_server::processing::message::{MessageProcessor, StreamEvent};

use super::handlers::text::TextStreamState;
use super::handlers::{HandleCompleteParams, HandleErrorParams, handle_complete, handle_error};
use super::webhook_client::WebhookContext;

/// Owned inputs for [`process_events`], which drives one streaming A2A request.
pub struct ProcessEventsParams {
    /// SSE sink used to push JSON-RPC status updates back to the requesting client.
    pub tx: UnboundedSender<Event>,
    /// Receiver of [`StreamEvent`]s consumed by the event loop.
    pub chunk_rx: UnboundedReceiver<StreamEvent>,
    /// Task being processed.
    pub task_id: TaskId,
    /// Context the task belongs to.
    pub context_id: ContextId,
    /// Id used for the streamed text/tool-call messages of this run.
    pub message_id: MessageId,
    /// Incoming message; handed to the completion handler.
    pub original_message: Message,
    /// Name of the agent serving the request.
    pub agent_name: String,
    /// Request context; supplies the user id and auth token for webhooks.
    pub context: RequestContext,
    /// Repository used to persist task state transitions.
    pub task_repo: TaskRepository,
    /// Shared message processor; handed to the completion handler.
    pub processor: Arc<MessageProcessor>,
    /// JSON-RPC request id echoed in every SSE status event.
    pub request_id: NumberOrString,
}

impl std::fmt::Debug for ProcessEventsParams {
    /// Prints only the identifying fields and marks the rest as elided (`..`).
    ///
    /// NOTE(review): the omitted fields include `context`, which exposes
    /// `auth_token()`; keeping it out of debug output looks deliberate — confirm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ProcessEventsParams");
        out.field("task_id", &self.task_id);
        out.field("context_id", &self.context_id);
        out.field("message_id", &self.message_id);
        out.field("agent_name", &self.agent_name);
        out.finish_non_exhaustive()
    }
}

/// Borrowed inputs for [`send_a2a_status_event`].
struct SendA2aStatusEventParams<'a> {
    /// SSE sink the event is pushed onto.
    tx: &'a UnboundedSender<Event>,
    /// Task whose status changed.
    task_id: &'a TaskId,
    /// Context the task belongs to.
    context_id: &'a ContextId,
    /// Wire name of the new state (e.g. `"working"`, `"completed"`, `"failed"`).
    state: &'a str,
    /// Value of the `final` flag in the payload.
    is_final: bool,
    /// JSON-RPC request id echoed as `id`.
    request_id: &'a NumberOrString,
}

/// Pushes a JSON-RPC 2.0 A2A `status-update` result onto the SSE channel.
///
/// The payload carries the task and context ids, the given state string, a fresh
/// RFC 3339 timestamp, and the `final` flag. A closed channel is not treated as an
/// error: the failed send is only traced.
fn send_a2a_status_event(params: &SendA2aStatusEventParams<'_>) {
    let status = json!({
        "state": params.state,
        "timestamp": chrono::Utc::now().to_rfc3339()
    });
    let payload = json!({
        "jsonrpc": "2.0",
        "id": params.request_id,
        "result": {
            "kind": "status-update",
            "taskId": params.task_id.as_str(),
            "contextId": params.context_id.as_str(),
            "status": status,
            "final": params.is_final
        }
    });

    let sse_event = Event::default().data(payload.to_string());
    if params.tx.send(sse_event).is_err() {
        tracing::trace!("Failed to send status event, channel closed");
    }
}

pub struct EmitRunStartedParams<'a> {
    pub tx: &'a UnboundedSender<Event>,
    pub webhook_context: &'a WebhookContext,
    pub context_id: &'a ContextId,
    pub task_id: &'a TaskId,
    pub task_repo: &'a TaskRepository,
    pub request_id: &'a NumberOrString,
}

pub async fn emit_run_started(params: EmitRunStartedParams<'_>) {
    let EmitRunStartedParams {
        tx,
        webhook_context,
        context_id,
        task_id,
        task_repo,
        request_id,
    } = params;
    let working_timestamp = chrono::Utc::now();
    if let Err(e) = task_repo
        .update_task_state(task_id, TaskState::Working, &working_timestamp)
        .await
    {
        tracing::error!(task_id = %task_id, error = %e, "Failed to update task state");
        return;
    }

    send_a2a_status_event(&SendA2aStatusEventParams {
        tx,
        task_id,
        context_id,
        state: "working",
        is_final: false,
        request_id,
    });

    let a2a_event = A2AEventBuilder::task_status_update(
        task_id.clone(),
        context_id.clone(),
        TaskState::Working,
        None,
    );
    if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
        tracing::error!(error = %e, "Failed to broadcast A2A working");
    }

    let event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
    if let Err(e) = webhook_context.broadcast_agui(event).await {
        tracing::error!(error = %e, "Failed to broadcast RUN_STARTED");
    }
}

pub async fn process_events(params: ProcessEventsParams) {
    let ProcessEventsParams {
        tx,
        mut chunk_rx,
        task_id,
        context_id,
        message_id,
        original_message,
        agent_name,
        context,
        task_repo,
        processor,
        request_id,
    } = params;

    let webhook_context =
        WebhookContext::new(context.user_id().clone(), context.auth_token().as_str());

    emit_run_started(EmitRunStartedParams {
        tx: &tx,
        webhook_context: &webhook_context,
        context_id: &context_id,
        task_id: &task_id,
        task_repo: &task_repo,
        request_id: &request_id,
    })
    .await;

    tracing::info!("Stream channel received, waiting for events...");

    let mut text_state = TextStreamState::new().with_webhook_context(webhook_context.clone());

    while let Some(event) = chunk_rx.recv().await {
        match event {
            StreamEvent::Text(text) => {
                text_state.handle_text(text, &message_id).await;
            },
            StreamEvent::ToolCallStarted(tool_call) => {
                let tool_call_id = tool_call.ai_tool_call_id.as_str();
                let start_event = AgUiEventBuilder::tool_call_start(
                    tool_call_id,
                    &tool_call.name,
                    Some(message_id.to_string()),
                );
                if let Err(e) = webhook_context.broadcast_agui(start_event).await {
                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_START");
                }

                let args_json =
                    serde_json::to_string(&tool_call.arguments).unwrap_or_else(|_| String::new());
                let args_event = AgUiEventBuilder::tool_call_args(tool_call_id, &args_json);
                if let Err(e) = webhook_context.broadcast_agui(args_event).await {
                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_ARGS");
                }

                let end_event = AgUiEventBuilder::tool_call_end(tool_call_id);
                if let Err(e) = webhook_context.broadcast_agui(end_event).await {
                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_END");
                }
            },
            StreamEvent::ToolResult { call_id, result } => {
                let result_value =
                    serde_json::to_value(&result).unwrap_or_else(|_| serde_json::Value::Null);
                let result_event = AgUiEventBuilder::tool_call_result(
                    uuid::Uuid::new_v4().to_string(),
                    &call_id,
                    result_value,
                );
                if let Err(e) = webhook_context.broadcast_agui(result_event).await {
                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_RESULT");
                }
            },
            StreamEvent::ExecutionStepUpdate { step } => {
                let step_event = AgUiEventBuilder::execution_step(step.clone(), context_id.clone());
                if let Err(e) = webhook_context.broadcast_agui(step_event).await {
                    tracing::error!(error = %e, "Failed to broadcast execution_step");
                }
            },
            StreamEvent::Complete {
                full_text,
                artifacts,
            } => {
                text_state.finalize(&message_id).await;

                let complete_params = HandleCompleteParams {
                    tx: &tx,
                    webhook_context: &webhook_context,
                    full_text,
                    artifacts,
                    task_id: &task_id,
                    context_id: &context_id,
                    id: message_id.as_str(),
                    original_message: &original_message,
                    agent_name: &agent_name,
                    context: &context,
                    auth_token: context.auth_token().as_str(),
                    task_repo: &task_repo,
                    processor: &processor,
                };
                handle_complete(complete_params).await;

                send_a2a_status_event(&SendA2aStatusEventParams {
                    tx: &tx,
                    task_id: &task_id,
                    context_id: &context_id,
                    state: "completed",
                    is_final: true,
                    request_id: &request_id,
                });

                let a2a_event = A2AEventBuilder::task_status_update(
                    task_id.clone(),
                    context_id.clone(),
                    TaskState::Completed,
                    None,
                );
                if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
                    tracing::error!(error = %e, "Failed to broadcast A2A completed");
                }

                break;
            },
            StreamEvent::Error(error) => {
                text_state.finalize(&message_id).await;
                handle_error(HandleErrorParams {
                    tx: &tx,
                    webhook_context: &webhook_context,
                    error,
                    task_id: &task_id,
                    context_id: &context_id,
                    task_repo: &task_repo,
                })
                .await;

                send_a2a_status_event(&SendA2aStatusEventParams {
                    tx: &tx,
                    task_id: &task_id,
                    context_id: &context_id,
                    state: "failed",
                    is_final: true,
                    request_id: &request_id,
                });

                let a2a_event = A2AEventBuilder::task_status_update(
                    task_id.clone(),
                    context_id.clone(),
                    TaskState::Failed,
                    None,
                );
                if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
                    tracing::error!(error = %e, "Failed to broadcast A2A failed");
                }

                break;
            },
        }
    }

    drop(tx);

    tracing::info!("Stream event loop ended");
}

pub async fn handle_stream_creation_error(
    webhook_context: &WebhookContext,
    error: anyhow::Error,
    task_id: &TaskId,
    _context_id: &ContextId,
    task_repo: &TaskRepository,
) {
    let error_msg = format!("Failed to create stream: {}", error);
    tracing::error!(task_id = %task_id, error = %error, "Failed to create stream");

    let failed_timestamp = chrono::Utc::now();
    if let Err(e) = task_repo
        .update_task_failed_with_error(task_id, &error_msg, &failed_timestamp)
        .await
    {
        tracing::error!(task_id = %task_id, error = %e, "Failed to update task to failed state");
    }

    let error_event = AgUiEventBuilder::run_error(
        format!("Failed to process message: {error}"),
        Some("STREAM_CREATION_ERROR".to_string()),
    );
    if let Err(e) = webhook_context.broadcast_agui(error_event).await {
        tracing::error!(error = %e, "Failed to broadcast RUN_ERROR");
    }
}