systemprompt-agent 0.2.0

Core Agent protocol module for systemprompt.io
Documentation
use crate::models::a2a::{
    Artifact, Message, MessageRole, Part, Task, TaskState, TaskStatus, TextPart,
};
use crate::repository::context::ContextRepository;
use crate::repository::task::TaskRepository;
use crate::services::MessageService;
use rmcp::ErrorData as McpError;
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
use systemprompt_models::{Config, TaskMetadata};

#[derive(Debug)]
pub struct TaskResult {
    pub task_id: TaskId,
    pub is_owner: bool,
}

pub async fn ensure_task_exists(
    db_pool: &DbPool,
    request_context: &mut systemprompt_models::execution::context::RequestContext,
    tool_name: &str,
    mcp_server_name: &str,
) -> Result<TaskResult, McpError> {
    if let Some(task_id) = request_context.task_id() {
        tracing::info!(task_id = %task_id.as_str(), "Task reused from parent");
        return Ok(TaskResult {
            task_id: task_id.clone(),
            is_owner: false,
        });
    }

    let context_id = request_context.context_id();
    let context_repo = ContextRepository::new(db_pool).map_err(|e| {
        McpError::internal_error(format!("Failed to create context repository: {e}"), None)
    })?;

    let context_id = if context_id.is_empty() {
        if let Ok(Some(existing)) = context_repo
            .find_by_session_id(request_context.session_id())
            .await
        {
            tracing::debug!(
                context_id = %existing.context_id,
                session_id = %request_context.session_id(),
                "Reusing existing context for MCP session"
            );
            request_context.execution.context_id = existing.context_id.clone();
            existing.context_id
        } else {
            let new_context_id = context_repo
                .create_context(
                    request_context.user_id(),
                    Some(request_context.session_id()),
                    &format!("MCP Session: {}", request_context.session_id()),
                )
                .await
                .map_err(|e| {
                    tracing::error!(error = %e, "Failed to auto-create context for MCP session");
                    McpError::internal_error(format!("Failed to create context: {e}"), None)
                })?;

            request_context.execution.context_id = new_context_id.clone();
            tracing::info!(
                context_id = %new_context_id,
                session_id = %request_context.session_id(),
                "Auto-created context for MCP session"
            );
            new_context_id
        }
    } else {
        let old_context_id = context_id.clone();
        match context_repo
            .validate_context_ownership(&old_context_id, request_context.user_id())
            .await
        {
            Ok(()) => old_context_id,
            Err(e) => {
                tracing::warn!(
                    context_id = %old_context_id,
                    user_id = %request_context.user_id(),
                    error = %e,
                    "Context validation failed, auto-creating new context"
                );
                let new_context_id = context_repo
                    .create_context(
                        request_context.user_id(),
                        Some(request_context.session_id()),
                        &format!("MCP Session: {}", request_context.session_id()),
                    )
                    .await
                    .map_err(|e| {
                        tracing::error!(error = %e, "Failed to auto-create replacement context");
                        McpError::internal_error(format!("Failed to create context: {e}"), None)
                    })?;

                request_context.execution.context_id = new_context_id.clone();
                tracing::info!(
                    old_context_id = %old_context_id,
                    new_context_id = %new_context_id,
                    session_id = %request_context.session_id(),
                    "Auto-created replacement context for invalid context_id"
                );
                new_context_id
            },
        }
    };

    let task_repo = TaskRepository::new(db_pool).map_err(|e| {
        McpError::internal_error(format!("Failed to create task repository: {e}"), None)
    })?;

    let task_id = TaskId::generate();

    let agent_name = request_context.agent_name().to_string();

    let metadata = TaskMetadata::new_mcp_execution(
        agent_name.clone(),
        tool_name.to_string(),
        mcp_server_name.to_string(),
    );

    let task = Task {
        id: task_id.clone(),
        context_id: context_id.clone(),
        status: TaskStatus {
            state: TaskState::Submitted,
            message: None,
            timestamp: Some(chrono::Utc::now()),
        },
        history: None,
        artifacts: None,
        metadata: Some(metadata),
        created_at: Some(chrono::Utc::now()),
        last_modified: Some(chrono::Utc::now()),
    };

    task_repo
        .create_task(crate::repository::task::RepoCreateTaskParams {
            task: &task,
            user_id: request_context.user_id(),
            session_id: request_context.session_id(),
            trace_id: request_context.trace_id(),
            agent_name: &agent_name,
        })
        .await
        .map_err(|e| McpError::internal_error(format!("Failed to create task: {e}"), None))?;

    request_context.execution.task_id = Some(task_id.clone());

    tracing::info!(
        task_id = %task_id.as_str(),
        tool = %tool_name,
        agent = %agent_name,
        "Task created"
    );

    Ok(TaskResult {
        task_id,
        is_owner: true,
    })
}

pub async fn complete_task(
    db_pool: &DbPool,
    task_id: &TaskId,
    jwt_token: &str,
) -> Result<(), McpError> {
    if let Err(e) = trigger_task_completion_broadcast(db_pool, task_id, jwt_token).await {
        tracing::error!(
            task_id = %task_id.as_str(),
            error = ?e,
            "Webhook broadcast failed"
        );
    }

    Ok(())
}

async fn trigger_task_completion_broadcast(
    db_pool: &DbPool,
    task_id: &TaskId,
    jwt_token: &str,
) -> Result<(), McpError> {
    let task_repo = TaskRepository::new(db_pool).map_err(|e| {
        McpError::internal_error(format!("Failed to create task repository: {e}"), None)
    })?;

    let task_info = task_repo
        .get_task_context_info(task_id)
        .await
        .map_err(|e| {
            McpError::internal_error(format!("Failed to load task for webhook: {e}"), None)
        })?;

    if let Some(info) = task_info {
        let context_id = info.context_id;
        let user_id = info.user_id;

        let config = Config::get().map_err(|e| McpError::internal_error(e.to_string(), None))?;
        let webhook_url = format!("{}/api/v1/webhook/broadcast", config.api_server_url);
        let webhook_payload = serde_json::json!({
            "event_type": "task_completed",
            "entity_id": task_id.as_str(),
            "context_id": context_id,
            "user_id": user_id,
        });

        tracing::debug!(
            task_id = %task_id.as_str(),
            context_id = %context_id,
            "Webhook triggering"
        );

        let client = reqwest::Client::new();
        match client
            .post(webhook_url)
            .header("Authorization", format!("Bearer {jwt_token}"))
            .json(&webhook_payload)
            .timeout(std::time::Duration::from_secs(5))
            .send()
            .await
        {
            Ok(response) => {
                if response.status().is_success() {
                    tracing::debug!(
                        task_id = %task_id.as_str(),
                        "Task completed, webhook success"
                    );
                } else {
                    let status = response.status();
                    tracing::error!(
                        task_id = %task_id.as_str(),
                        status = %status,
                        "Task completed, webhook failed"
                    );
                }
            },
            Err(e) => {
                tracing::error!(
                    task_id = %task_id.as_str(),
                    error = %e,
                    "Webhook failed"
                );
            },
        }
    }

    Ok(())
}

#[derive(Debug)]
pub struct SaveMessagesForToolExecutionParams<'a> {
    pub db_pool: &'a DbPool,
    pub task_id: &'a TaskId,
    pub context_id: &'a ContextId,
    pub tool_name: &'a str,
    pub tool_result: &'a str,
    pub artifact: Option<&'a Artifact>,
    pub user_id: &'a UserId,
    pub session_id: &'a SessionId,
    pub trace_id: &'a TraceId,
}

pub async fn save_messages_for_tool_execution(
    params: SaveMessagesForToolExecutionParams<'_>,
) -> Result<(), McpError> {
    let SaveMessagesForToolExecutionParams {
        db_pool,
        task_id,
        context_id,
        tool_name,
        tool_result,
        artifact,
        user_id,
        session_id,
        trace_id,
    } = params;
    let message_service = MessageService::new(db_pool).map_err(|e| {
        McpError::internal_error(format!("Failed to create message service: {e}"), None)
    })?;

    let user_message = Message {
        role: MessageRole::User,
        parts: vec![Part::Text(TextPart {
            text: format!("Execute tool: {tool_name}"),
        })],
        message_id: MessageId::generate(),
        task_id: Some(task_id.clone()),
        context_id: context_id.clone(),
        metadata: None,
        extensions: None,
        reference_task_ids: None,
    };

    let agent_text = artifact.map_or_else(
        || format!("Tool execution completed. Result: {tool_result}"),
        |artifact| {
            format!(
                "Tool execution completed. Result: {}\n\nArtifact created: {} (type: {})",
                tool_result, artifact.id, artifact.metadata.artifact_type
            )
        },
    );

    let agent_message = Message {
        role: MessageRole::Agent,
        parts: vec![Part::Text(TextPart { text: agent_text })],
        message_id: MessageId::generate(),
        task_id: Some(task_id.clone()),
        context_id: context_id.clone(),
        metadata: None,
        extensions: None,
        reference_task_ids: None,
    };

    message_service
        .persist_messages(crate::services::PersistMessagesParams {
            task_id,
            context_id,
            messages: vec![user_message, agent_message],
            user_id: Some(user_id),
            session_id,
            trace_id,
        })
        .await
        .map_err(|e| McpError::internal_error(format!("Failed to save messages: {e}"), None))?;

    Ok(())
}