systemprompt-agent 0.2.2

Agent-to-Agent (A2A) protocol for systemprompt.io AI governance: streaming, JSON-RPC models, task lifecycle, .well-known discovery, and governed agent orchestration.
Documentation
use sqlx::PgPool;
use std::sync::Arc;
use systemprompt_traits::RepositoryError;

use crate::models::a2a::{Task, TaskState};

/// Maps a [`TaskState`] to the string stored in the `status` column of
/// `agent_tasks`.
///
/// Every variant is spelled out (there is no `_` arm), so the match stays
/// exhaustive. Arms are listed alphabetically by variant name.
pub const fn task_state_to_db_string(state: TaskState) -> &'static str {
    match state {
        TaskState::AuthRequired => "TASK_STATE_AUTH_REQUIRED",
        TaskState::Canceled => "TASK_STATE_CANCELED",
        TaskState::Completed => "TASK_STATE_COMPLETED",
        TaskState::Failed => "TASK_STATE_FAILED",
        TaskState::InputRequired => "TASK_STATE_INPUT_REQUIRED",
        TaskState::Pending => "TASK_STATE_PENDING",
        TaskState::Rejected => "TASK_STATE_REJECTED",
        TaskState::Submitted => "TASK_STATE_SUBMITTED",
        TaskState::Unknown => "TASK_STATE_UNKNOWN",
        TaskState::Working => "TASK_STATE_WORKING",
    }
}

/// Borrowed inputs for [`create_task`], bundled into one struct so the
/// function takes a single argument.
///
/// All fields are references tied to the lifetime `'a`; the struct owns
/// nothing.
// NOTE(review): the reason for suppressing `missing_debug_implementations`
// is not recorded here — confirm and document it, or derive `Debug`.
#[allow(missing_debug_implementations)]
pub struct CreateTaskParams<'a> {
    /// Connection pool the `INSERT` is executed against.
    pub pool: &'a Arc<PgPool>,
    /// Task to persist; its id, context id, status (state and timestamp) and
    /// metadata are written to the row.
    pub task: &'a Task,
    /// Written to the `user_id` column.
    pub user_id: &'a systemprompt_identifiers::UserId,
    /// Written to the `session_id` column.
    pub session_id: &'a systemprompt_identifiers::SessionId,
    /// Written to the `trace_id` column.
    pub trace_id: &'a systemprompt_identifiers::TraceId,
    /// Written to the `agent_name` column.
    pub agent_name: &'a str,
}

/// Inserts a new row into `agent_tasks` for `params.task`.
///
/// The task's metadata is stored as JSON. When the task has no metadata, or
/// when serializing it fails (a warning is logged in that case), an empty
/// JSON object is stored instead, so a metadata problem never blocks the
/// insert.
///
/// Returns the task id as a `String` on success.
///
/// # Errors
///
/// Returns the error built by [`RepositoryError::database`] when the
/// `INSERT` statement fails.
pub async fn create_task(params: CreateTaskParams<'_>) -> Result<String, RepositoryError> {
    let pool = params.pool;
    let task = params.task;
    let agent_name = params.agent_name;

    // Best-effort metadata: fall back to `{}` rather than failing the insert.
    let metadata = match task.metadata.as_ref() {
        None => serde_json::json!({}),
        Some(raw) => match serde_json::to_value(raw) {
            Ok(value) => value,
            Err(e) => {
                tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
                serde_json::json!({})
            },
        },
    };

    // Plain string views of the identifiers, bound as query parameters below.
    let task_key = task.id.as_str();
    let context_key = task.context_id.as_str();
    let user_key = params.user_id.as_ref();
    let session_key = params.session_id.as_ref();
    let trace_key = params.trace_id.as_ref();
    let status_label = task_state_to_db_string(task.status.state);

    let outcome = sqlx::query!(
        r#"INSERT INTO agent_tasks (task_id, context_id, status, status_timestamp, user_id, session_id, trace_id, metadata, agent_name)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"#,
        task_key,
        context_key,
        status_label,
        task.status.timestamp,
        user_key,
        session_key,
        trace_key,
        metadata,
        agent_name
    )
    .execute(pool.as_ref())
    .await;

    outcome.map_err(RepositoryError::database)?;

    Ok(task.id.to_string())
}

/// Records that `agent_name` is associated with `context_id` by inserting a
/// row into `context_agents`.
///
/// The statement uses `ON CONFLICT (context_id, agent_name) DO NOTHING`, so
/// calling this again for a pair that already exists leaves the table
/// unchanged and still returns `Ok(())`.
///
/// # Errors
///
/// Returns the error built by [`RepositoryError::database`] when the
/// statement fails.
pub async fn track_agent_in_context(
    pool: &Arc<PgPool>,
    context_id: &systemprompt_identifiers::ContextId,
    agent_name: &str,
) -> Result<(), RepositoryError> {
    let context_key = context_id.as_str();

    let outcome = sqlx::query!(
        r#"INSERT INTO context_agents (context_id, agent_name) VALUES ($1, $2)
        ON CONFLICT (context_id, agent_name) DO NOTHING"#,
        context_key,
        agent_name
    )
    .execute(pool.as_ref())
    .await;

    outcome.map_err(RepositoryError::database)?;

    Ok(())
}

/// Updates the stored status of the task identified by `task_id`.
///
/// Every call sets `status`, `status_timestamp` (from `timestamp`) and
/// `updated_at`. Two states update extra columns:
///
/// * [`TaskState::Completed`] also sets `completed_at`, fills `started_at`
///   if it is still `NULL`, and has the database compute
///   `execution_time_ms`.
/// * [`TaskState::Working`] also fills `started_at` if it is still `NULL`.
///
/// The number of affected rows is not inspected, so an id that matches no
/// row is not reported as an error by this function.
///
/// # Errors
///
/// Returns the error built by [`RepositoryError::database`] when the
/// `UPDATE` statement fails.
pub async fn update_task_state(
    pool: &Arc<PgPool>,
    task_id: &systemprompt_identifiers::TaskId,
    state: TaskState,
    timestamp: &chrono::DateTime<chrono::Utc>,
) -> Result<(), RepositoryError> {
    let task_key = task_id.as_str();
    let status_label = task_state_to_db_string(state);

    // The SQL literals below are kept byte-for-byte, including the leading
    // whitespace on their continuation lines, so the query text does not
    // change.
    let outcome = match state {
        TaskState::Completed => {
            sqlx::query!(
                r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
            completed_at = CURRENT_TIMESTAMP,
            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
            WHERE task_id = $3"#,
                status_label,
                timestamp,
                task_key
            )
            .execute(pool.as_ref())
            .await
        },
        TaskState::Working => {
            sqlx::query!(
                r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP,
            started_at = COALESCE(started_at, CURRENT_TIMESTAMP)
            WHERE task_id = $3"#,
                status_label,
                timestamp,
                task_key
            )
            .execute(pool.as_ref())
            .await
        },
        // All remaining states only touch the status columns.
        _ => {
            sqlx::query!(
                r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, updated_at = CURRENT_TIMESTAMP WHERE task_id = $3"#,
                status_label,
                timestamp,
                task_key
            )
            .execute(pool.as_ref())
            .await
        },
    };

    outcome.map_err(RepositoryError::database)?;

    Ok(())
}

/// Marks the task identified by `task_id` as failed and stores
/// `error_message` alongside it.
///
/// In a single `UPDATE` this sets `status` to the literal
/// `'TASK_STATE_FAILED'`, `status_timestamp` to `timestamp`, and
/// `error_message`. It also refreshes `updated_at`, sets `completed_at`,
/// fills `started_at` if it is still `NULL`, and has the database compute
/// `execution_time_ms`.
///
/// The status literal is embedded in the SQL text rather than taken from
/// [`task_state_to_db_string`]; the two must be kept in sync by hand.
///
/// The number of affected rows is not inspected, so an id that matches no
/// row is not reported as an error by this function.
///
/// # Errors
///
/// Returns the error built by [`RepositoryError::database`] when the
/// `UPDATE` statement fails.
pub async fn update_task_failed_with_error(
    pool: &Arc<PgPool>,
    task_id: &systemprompt_identifiers::TaskId,
    error_message: &str,
    timestamp: &chrono::DateTime<chrono::Utc>,
) -> Result<(), RepositoryError> {
    let task_key = task_id.as_str();

    let outcome = sqlx::query!(
        r#"UPDATE agent_tasks SET
            status = 'TASK_STATE_FAILED',
            status_timestamp = $1,
            error_message = $2,
            updated_at = CURRENT_TIMESTAMP,
            completed_at = CURRENT_TIMESTAMP,
            started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
            execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
        WHERE task_id = $3"#,
        timestamp,
        error_message,
        task_key
    )
    .execute(pool.as_ref())
    .await;

    outcome.map_err(RepositoryError::database)?;

    Ok(())
}