systemprompt-agent 0.2.0

Core Agent protocol module for systemprompt.io
Documentation
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use sqlx::PgPool;
use std::sync::Arc;
use systemprompt_database::DbPool;
use systemprompt_identifiers::TaskId;
use systemprompt_models::{ExecutionStep, PlannedTool, StepContent, StepId, StepStatus};

#[allow(missing_debug_implementations)]
struct ParseStepParams {
    step_id: String,
    task_id: TaskId,
    status: String,
    content: serde_json::Value,
    started_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    duration_ms: Option<i32>,
    error_message: Option<String>,
}

fn parse_step(params: ParseStepParams) -> Result<ExecutionStep> {
    let ParseStepParams {
        step_id,
        task_id,
        status,
        content,
        started_at,
        completed_at,
        duration_ms,
        error_message,
    } = params;
    let status = status
        .parse::<StepStatus>()
        .map_err(|e| anyhow::anyhow!("Invalid status: {}", e))?;
    let content: StepContent =
        serde_json::from_value(content).map_err(|e| anyhow::anyhow!("Invalid content: {}", e))?;
    Ok(ExecutionStep {
        step_id: step_id.into(),
        task_id,
        status,
        started_at,
        completed_at,
        duration_ms,
        error_message,
        content,
    })
}

#[derive(Debug, Clone)]
pub struct ExecutionStepRepository {
    pool: Arc<PgPool>,
    write_pool: Arc<PgPool>,
}

impl ExecutionStepRepository {
    pub fn new(db: &DbPool) -> Result<Self> {
        let pool = db.pool_arc()?;
        let write_pool = db.write_pool_arc()?;
        Ok(Self { pool, write_pool })
    }

    pub async fn create(&self, step: &ExecutionStep) -> Result<()> {
        let step_id_str = step.step_id.as_str();
        let task_id = &step.task_id;
        let status_str = step.status.to_string();
        let step_type_str = step.content.step_type().to_string();
        let title = step.content.title();
        let content_json =
            serde_json::to_value(&step.content).context("Failed to serialize step content")?;
        sqlx::query!(
            r#"INSERT INTO task_execution_steps (
                step_id, task_id, step_type, title, status, content, started_at, completed_at, duration_ms, error_message
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"#,
            step_id_str,
            task_id.as_str(),
            step_type_str,
            title,
            status_str,
            content_json,
            step.started_at,
            step.completed_at,
            step.duration_ms,
            step.error_message
        )
        .execute(&*self.write_pool)
        .await
        .context("Failed to create execution step")?;
        Ok(())
    }

    pub async fn get(&self, step_id: &StepId) -> Result<Option<ExecutionStep>> {
        let step_id_str = step_id.as_str();
        let row = sqlx::query!(
            r#"SELECT step_id, task_id as "task_id!: TaskId", status, content,
                    started_at as "started_at!", completed_at, duration_ms, error_message
                FROM task_execution_steps WHERE step_id = $1"#,
            step_id_str
        )
        .fetch_optional(&*self.pool)
        .await
        .context(format!("Failed to get execution step: {step_id}"))?;
        row.map(|r| {
            parse_step(ParseStepParams {
                step_id: r.step_id,
                task_id: r.task_id,
                status: r.status,
                content: r.content,
                started_at: r.started_at,
                completed_at: r.completed_at,
                duration_ms: r.duration_ms,
                error_message: r.error_message,
            })
        })
        .transpose()
    }

    pub async fn list_by_task(&self, task_id: &TaskId) -> Result<Vec<ExecutionStep>> {
        let rows = sqlx::query!(
            r#"SELECT step_id, task_id as "task_id!: TaskId", status, content,
                    started_at as "started_at!", completed_at, duration_ms, error_message
                FROM task_execution_steps WHERE task_id = $1 ORDER BY started_at ASC"#,
            task_id.as_str()
        )
        .fetch_all(&*self.pool)
        .await
        .context(format!(
            "Failed to list execution steps for task: {}",
            task_id
        ))?;
        rows.into_iter()
            .map(|r| {
                parse_step(ParseStepParams {
                    step_id: r.step_id,
                    task_id: r.task_id,
                    status: r.status,
                    content: r.content,
                    started_at: r.started_at,
                    completed_at: r.completed_at,
                    duration_ms: r.duration_ms,
                    error_message: r.error_message,
                })
            })
            .collect()
    }

    pub async fn complete_step(
        &self,
        step_id: &StepId,
        started_at: DateTime<Utc>,
        tool_result: Option<serde_json::Value>,
    ) -> Result<()> {
        let completed_at = Utc::now();
        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
        let step_id_str = step_id.as_str();
        let status_str = StepStatus::Completed.to_string();

        if let Some(result) = tool_result {
            sqlx::query!(
                r#"UPDATE task_execution_steps SET
                    status = $2,
                    completed_at = $3,
                    duration_ms = $4,
                    content = jsonb_set(content, '{tool_result}', $5)
                WHERE step_id = $1"#,
                step_id_str,
                status_str,
                completed_at,
                duration_ms,
                result
            )
            .execute(&*self.write_pool)
            .await
            .context(format!("Failed to complete execution step: {step_id}"))?;
        } else {
            sqlx::query!(
                r#"UPDATE task_execution_steps SET
                    status = $2,
                    completed_at = $3,
                    duration_ms = $4
                WHERE step_id = $1"#,
                step_id_str,
                status_str,
                completed_at,
                duration_ms
            )
            .execute(&*self.write_pool)
            .await
            .context(format!("Failed to complete execution step: {step_id}"))?;
        }

        Ok(())
    }

    pub async fn fail_step(
        &self,
        step_id: &StepId,
        started_at: DateTime<Utc>,
        error_message: &str,
    ) -> Result<()> {
        let completed_at = Utc::now();
        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
        let step_id_str = step_id.as_str();
        let status_str = StepStatus::Failed.to_string();

        sqlx::query!(
            r#"UPDATE task_execution_steps SET
                status = $2,
                completed_at = $3,
                duration_ms = $4,
                error_message = $5
            WHERE step_id = $1"#,
            step_id_str,
            status_str,
            completed_at,
            duration_ms,
            error_message
        )
        .execute(&*self.write_pool)
        .await
        .context(format!("Failed to fail execution step: {step_id}"))?;

        Ok(())
    }

    pub async fn fail_in_progress_steps_for_task(
        &self,
        task_id: &TaskId,
        error_message: &str,
    ) -> Result<u64> {
        let completed_at = Utc::now();
        let in_progress_str = StepStatus::InProgress.to_string();
        let failed_str = StepStatus::Failed.to_string();
        let task_id_str = task_id.as_str();

        let result = sqlx::query!(
            r#"UPDATE task_execution_steps SET
                status = $3,
                completed_at = $4,
                error_message = $5
            WHERE task_id = $1 AND status = $2"#,
            task_id_str,
            in_progress_str,
            failed_str,
            completed_at,
            error_message
        )
        .execute(&*self.write_pool)
        .await
        .context(format!(
            "Failed to fail in-progress steps for task: {}",
            task_id
        ))?;

        Ok(result.rows_affected())
    }

    pub async fn complete_planning_step(
        &self,
        step_id: &StepId,
        started_at: DateTime<Utc>,
        reasoning: Option<String>,
        planned_tools: Option<Vec<PlannedTool>>,
    ) -> Result<ExecutionStep> {
        let completed_at = Utc::now();
        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
        let step_id_str = step_id.as_str();
        let status_str = StepStatus::Completed.to_string();

        let content = StepContent::planning(reasoning, planned_tools);
        let content_json =
            serde_json::to_value(&content).context("Failed to serialize planning content")?;

        let row = sqlx::query!(
            r#"UPDATE task_execution_steps SET
                status = $2,
                completed_at = $3,
                duration_ms = $4,
                content = $5
            WHERE step_id = $1
            RETURNING step_id, task_id as "task_id!: TaskId", status, content,
                    started_at as "started_at!", completed_at, duration_ms, error_message"#,
            step_id_str,
            status_str,
            completed_at,
            duration_ms,
            content_json
        )
        .fetch_one(&*self.write_pool)
        .await
        .context(format!("Failed to complete planning step: {step_id}"))?;

        parse_step(ParseStepParams {
            step_id: row.step_id,
            task_id: row.task_id,
            status: row.status,
            content: row.content,
            started_at: row.started_at,
            completed_at: row.completed_at,
            duration_ms: row.duration_ms,
            error_message: row.error_message,
        })
    }

    pub async fn mcp_execution_id_exists(&self, mcp_execution_id: &str) -> Result<bool> {
        let exists = sqlx::query_scalar!(
            r#"SELECT EXISTS(SELECT 1 FROM mcp_tool_executions WHERE mcp_execution_id = $1) as "exists!""#,
            mcp_execution_id
        )
        .fetch_one(&*self.pool)
        .await
        .context("Failed to check mcp_execution_id existence")?;

        Ok(exists)
    }
}