//! systemprompt-agent 0.2.0
//!
//! Core Agent protocol module for systemprompt.io.
use std::sync::Arc;

use axum::response::sse::Event;
use serde_json::json;
use systemprompt_identifiers::{
    AgentName, ContextId, MessageId, SessionId, TaskId, TraceId, UserId,
};
use systemprompt_models::{RequestContext, TaskMetadata};
use tokio::sync::mpsc::UnboundedSender;
use uuid::Uuid;

use crate::models::a2a::jsonrpc::NumberOrString;
use crate::models::a2a::protocol::PushNotificationConfig;
use crate::models::a2a::{Message, Task, TaskState, TaskStatus};
use crate::repository::content::PushNotificationConfigRepository;
use crate::repository::context::ContextRepository;
use crate::repository::task::TaskRepository;
use crate::services::a2a_server::errors::classify_database_error;
use crate::services::a2a_server::handlers::AgentHandlerState;
use crate::services::a2a_server::processing::message::MessageProcessor;

use super::agent_loader::load_agent_runtime;
use super::broadcast::{BroadcastTaskCreatedParams, broadcast_task_created};
use super::types::{PersistTaskInput, StreamInput, StreamSetupResult};

/// Builds a Server-Sent Event whose `data` is a JSON-RPC 2.0 error response.
///
/// The serialized payload has the shape
/// `{"jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": request_id}`.
/// Only the SSE `data` field is populated; no event name or SSE id is set.
pub fn create_jsonrpc_error_event(code: i32, message: &str, request_id: &NumberOrString) -> Event {
    let error_object = json!({ "code": code, "message": message });
    let envelope = json!({
        "jsonrpc": "2.0",
        "error": error_object,
        "id": request_id
    });
    let payload = envelope.to_string();
    Event::default().data(payload)
}

pub fn detect_mcp_server_and_update_context(
    agent_name: &str,
    context: &mut RequestContext,
    state: &Arc<AgentHandlerState>,
) {
    let is_mcp_server = state
        .agent_state
        .mcp_service_provider()
        .is_some_and(|provider| {
            provider
                .validate_registry()
                .ok()
                .and_then(|()| {
                    provider
                        .find_server(agent_name)
                        .map_err(|e| {
                            tracing::trace!(agent_name = %agent_name, error = %e, "MCP server lookup failed");
                            e
                        })
                        .ok()
                        .flatten()
                })
                .is_some()
        });

    if is_mcp_server && context.agent_name().as_str() != agent_name {
        tracing::info!(
            agent_name = %agent_name,
            context_agent = %context.agent_name().as_str(),
            "MCP server handling request from agent"
        );
    } else if !is_mcp_server && context.agent_name().as_str() != agent_name {
        tracing::warn!(
            context_agent = %context.agent_name().as_str(),
            service_agent = %agent_name,
            "Agent mismatch, using service name"
        );

        context.execution.agent_name = AgentName::new(agent_name.to_string());
    }
}

/// Returns the task id carried by `message`, or mints a fresh UUIDv4-based id if absent.
pub fn resolve_task_id(message: &Message) -> TaskId {
    match &message.task_id {
        Some(existing) => existing.clone(),
        None => TaskId::new(Uuid::new_v4().to_string()),
    }
}

/// Checks that `context_id` can be loaded for `user_id` before any work starts.
///
/// The lookup goes through [`ContextRepository::get_context`], and any error it returns is
/// treated as a validation failure. NOTE(review): an `Ok` result is accepted without
/// inspection. Whether a missing or foreign context surfaces as `Err` depends on
/// `get_context`, so confirm that.
///
/// # Errors
///
/// Returns `Err(())` if the repository cannot be created or the lookup fails. Before
/// returning, it sends a JSON-RPC `-32603` error event tagged with `request_id` on `tx`.
/// If the channel is already closed, that is only logged at trace level.
pub async fn validate_context(
    context_id: &ContextId,
    user_id: &UserId,
    state: &Arc<AgentHandlerState>,
    tx: &UnboundedSender<Event>,
    request_id: &NumberOrString,
) -> Result<(), ()> {
    // Relays an internal-error description to the client over the SSE channel.
    let notify_client = |detail: String| {
        let event = create_jsonrpc_error_event(-32603, &detail, request_id);
        if tx.send(event).is_err() {
            tracing::trace!("Failed to send error event, channel closed");
        }
    };

    let context_repo = match ContextRepository::new(&state.db_pool) {
        Ok(repo) => repo,
        Err(e) => {
            tracing::error!(error = %e, "Failed to create ContextRepository");
            notify_client(format!("Failed to initialize context repository: {e}"));
            return Err(());
        },
    };

    if let Err(e) = context_repo.get_context(context_id, user_id).await {
        tracing::error!(
            context_id = %context_id,
            user_id = %user_id,
            error = %e,
            "Context validation failed"
        );
        notify_client(format!("Context validation failed: {e}"));
        return Err(());
    }

    tracing::info!(context_id = %context_id, user_id = %user_id, "Context validated");

    Ok(())
}

/// Persists a freshly submitted task and records the agent against its context.
///
/// Builds a [`Task`] in the `Submitted` state, with no history or artifacts and
/// agent-message metadata for `agent_name`. It is stored through
/// [`TaskRepository::create_task`], attributed to the user, session and trace carried by
/// `context`. All three timestamps (status, `created_at`, `last_modified`) share one
/// instant.
///
/// After the insert, the agent is tracked in the context on a best-effort basis. A failure
/// there is logged as a warning and does not fail the call.
///
/// On success, returns the repository so callers can keep using it for this task.
///
/// # Errors
///
/// Returns `Err(())` if the repository cannot be created or the insert fails. Before
/// returning, it sends a JSON-RPC `-32603` error event tagged with `request_id` on `tx`.
/// If the channel is already closed, that is only logged at trace level.
pub async fn persist_initial_task(input: PersistTaskInput<'_>) -> Result<TaskRepository, ()> {
    let PersistTaskInput {
        task_id,
        context_id,
        agent_name,
        context,
        state,
        tx,
        request_id,
    } = input;

    // Relays an internal-error description to the client over the SSE channel.
    let send_error = |detail: String| {
        if tx
            .send(create_jsonrpc_error_event(-32603, &detail, request_id))
            .is_err()
        {
            tracing::trace!("Failed to send error event, channel closed");
        }
    };

    let task_repo = TaskRepository::new(&state.db_pool).map_err(|e| {
        tracing::error!(error = %e, "Failed to create TaskRepository");
        send_error(format!("Failed to initialize task repository: {e}"));
    })?;
    let metadata = TaskMetadata::new_agent_message(agent_name.to_string());

    // Take the clock once so a brand-new task reports identical status, creation and
    // modification times instead of three slightly different instants.
    let now = chrono::Utc::now();
    let task = Task {
        id: task_id.clone(),
        context_id: context_id.clone(),
        status: TaskStatus {
            state: TaskState::Submitted,
            message: None,
            timestamp: Some(now),
        },
        history: None,
        artifacts: None,
        metadata: Some(metadata),
        created_at: Some(now),
        last_modified: Some(now),
    };

    task_repo
        .create_task(crate::repository::task::RepoCreateTaskParams {
            task: &task,
            // `context.user_id()` is already a `&UserId`; no need to rebuild it.
            user_id: context.user_id(),
            session_id: &SessionId::new(context.session_id().as_str()),
            trace_id: &TraceId::new(context.trace_id().as_str()),
            agent_name,
        })
        .await
        .map_err(|e| {
            tracing::error!(task_id = %task_id, error = %e, "Failed to persist task at start");
            let error_detail = classify_database_error(&e);
            send_error(format!("Failed to create task: {error_detail}"));
        })?;

    tracing::info!(task_id = %task_id, "Task persisted to database at stream start");

    // Best effort: the task already exists, so a tracking failure is only logged.
    if let Err(e) = task_repo
        .track_agent_in_context(context_id, agent_name)
        .await
    {
        tracing::warn!(context_id = %context_id, error = %e, "Failed to track agent in context");
    }

    Ok(task_repo)
}

/// Stores the push-notification callback for `task_id`, if one was supplied.
///
/// Does nothing when `callback_config` is `None`. This is best effort: failing to create
/// the repository or to save the config is logged as a warning and otherwise ignored.
pub async fn save_push_notification_config(
    task_id: &TaskId,
    callback_config: Option<&PushNotificationConfig>,
    state: &Arc<AgentHandlerState>,
) {
    if let Some(config) = callback_config {
        tracing::info!(url = %config.url, "Push notification callback registered");

        match PushNotificationConfigRepository::new(&state.db_pool) {
            Err(e) => {
                tracing::warn!(task_id = %task_id, error = %e, "Failed to create PushNotificationConfigRepository");
            },
            Ok(config_repo) => {
                if let Err(e) = config_repo.add_config(task_id, config).await {
                    tracing::warn!(task_id = %task_id, error = %e, "Failed to save push notification config");
                } else {
                    tracing::info!(task_id = %task_id, "Push notification config saved");
                }
            },
        }
    }
}

/// Prepares everything needed to stream the handling of an incoming message.
///
/// In order, this:
/// 1. reconciles the agent recorded in the request context with `agent_name`;
/// 2. picks the task id (reused from the message when present), context id and a new
///    message id;
/// 3. validates the context for the requesting user;
/// 4. persists the task in the `Submitted` state;
/// 5. broadcasts the task-created notification;
/// 6. stores any push-notification callback (best effort);
/// 7. loads the agent runtime and builds a [`MessageProcessor`].
///
/// # Errors
///
/// Returns `Err(())` if any fallible step fails. The failing helper, or the processor
/// branch below, is expected to have already reported the error to the client on `tx`.
/// NOTE(review): a failure after step 4 returns with the task already persisted and
/// broadcast. Confirm whether `load_agent_runtime` updates the task state on failure.
/// The processor-creation branch here does not.
pub async fn setup_stream(
    input: StreamInput,
    tx: &UnboundedSender<Event>,
) -> Result<StreamSetupResult, ()> {
    let StreamInput {
        message,
        agent_name,
        state,
        request_id,
        mut context,
        callback_config,
    } = input;

    detect_mcp_server_and_update_context(&agent_name, &mut context, &state);

    let task_id = resolve_task_id(&message);
    let context_id = message.context_id.clone();
    let message_id = MessageId::new(Uuid::new_v4().to_string());
    tracing::info!(
        task_id = %task_id,
        context_id = %context_id,
        message_id = %message_id,
        "Generated IDs"
    );

    validate_context(&context_id, context.user_id(), &state, tx, &request_id).await?;

    let task_repo = persist_initial_task(PersistTaskInput {
        task_id: &task_id,
        context_id: &context_id,
        agent_name: &agent_name,
        context: &context,
        state: &state,
        tx,
        request_id: &request_id,
    })
    .await?;

    broadcast_task_created(BroadcastTaskCreatedParams {
        task_id: &task_id,
        context_id: &context_id,
        user_id: context.user_id().as_str(),
        user_message: &message,
        agent_name: &agent_name,
        token: context.auth.auth_token.as_str(),
    })
    .await;

    save_push_notification_config(&task_id, callback_config.as_ref(), &state).await;

    let agent_runtime =
        load_agent_runtime(&agent_name, &task_id, &task_repo, tx, &request_id).await?;

    let processor = match MessageProcessor::new(&state.db_pool, Arc::clone(&state.ai_service)) {
        Ok(built) => Arc::new(built),
        Err(e) => {
            tracing::error!(error = %e, "Failed to create MessageProcessor");
            let event = create_jsonrpc_error_event(
                -32603,
                &format!("Failed to initialize message processor: {e}"),
                &request_id,
            );
            if tx.send(event).is_err() {
                tracing::trace!("Failed to send error event, channel closed");
            }
            return Err(());
        },
    };

    Ok(StreamSetupResult {
        task_id,
        context_id,
        message_id,
        message,
        agent_name,
        context,
        task_repo,
        agent_runtime,
        processor,
        request_id,
    })
}