systemprompt-agent 0.2.0

Core Agent protocol module for systemprompt.io
Documentation
use super::{TaskRepository, task_state_to_db_string};
use crate::models::a2a::{Message, Task, TaskState};
use crate::repository::context::message::{
    FileUploadContext, PersistMessageSqlxParams, get_next_sequence_number_sqlx,
    persist_message_sqlx,
};
use systemprompt_traits::RepositoryError;

/// Arguments for [`TaskRepository::update_task_and_save_messages`].
///
/// Bundles the task to write, the user/agent message pair to persist with it,
/// and the identifiers attached to the persisted messages, so the repository
/// method takes a single parameter.
// NOTE(review): the reason `Debug` is not derived is not visible here;
// presumably not every borrowed field type implements it — confirm before
// replacing this allow with a derive.
#[allow(missing_debug_implementations)]
pub struct UpdateTaskAndSaveMessagesParams<'p> {
    /// Task whose status and metadata are written; its id and context id are
    /// also used when persisting both messages.
    pub task: &'p Task,
    /// Message persisted first (lower sequence number).
    pub user_message: &'p Message,
    /// Message persisted second, right after `user_message`.
    pub agent_message: &'p Message,
    /// Optional user the persisted messages (and any file uploads) belong to.
    pub user_id: Option<&'p systemprompt_identifiers::UserId>,
    /// Session attached to the messages and used for analytics counting.
    pub session_id: &'p systemprompt_identifiers::SessionId,
    /// Trace identifier attached to the persisted messages and uploads.
    pub trace_id: &'p systemprompt_identifiers::TraceId,
}

impl TaskRepository {
    pub async fn update_task_and_save_messages(
        &self,
        params: UpdateTaskAndSaveMessagesParams<'_>,
    ) -> Result<Task, RepositoryError> {
        let UpdateTaskAndSaveMessagesParams {
            task,
            user_message,
            agent_message,
            user_id,
            session_id,
            trace_id,
        } = params;
        let mut tx = self
            .write_pool
            .begin()
            .await
            .map_err(RepositoryError::database)?;

        let status = task_state_to_db_string(task.status.state);
        let metadata_json = task
            .metadata
            .as_ref().map_or_else(|| serde_json::json!({}), |m| {
                serde_json::to_value(m).unwrap_or_else(|e| {
                    tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
                    serde_json::json!({})
                })
            });

        let task_id_str = task.id.as_str();
        let is_completed = task.status.state == TaskState::Completed;

        let result = if is_completed {
            sqlx::query!(
                r#"UPDATE agent_tasks SET
                    status = $1,
                    status_timestamp = $2,
                    metadata = $3,
                    updated_at = CURRENT_TIMESTAMP,
                    completed_at = CURRENT_TIMESTAMP,
                    started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
                    execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
                WHERE task_id = $4"#,
                status,
                task.status.timestamp,
                metadata_json,
                task_id_str
            )
            .execute(&mut *tx)
            .await
            .map_err(RepositoryError::database)?
        } else {
            sqlx::query!(
                r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP WHERE task_id = $4"#,
                status,
                task.status.timestamp,
                metadata_json,
                task_id_str
            )
            .execute(&mut *tx)
            .await
            .map_err(RepositoryError::database)?
        };

        if result.rows_affected() == 0 {
            return Err(RepositoryError::NotFound(format!(
                "Task not found for update: {}",
                task.id
            )));
        }

        let upload_ctx = self
            .file_upload_provider
            .as_ref()
            .map(|svc| FileUploadContext {
                upload_provider: svc,
                context_id: &task.context_id,
                user_id,
                session_id: Some(session_id),
                trace_id: Some(trace_id),
            });

        let user_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
        persist_message_sqlx(PersistMessageSqlxParams {
            tx: &mut tx,
            message: user_message,
            task_id: &task.id,
            context_id: &task.context_id,
            sequence_number: user_seq,
            user_id,
            session_id,
            trace_id,
            upload_ctx: upload_ctx.as_ref(),
        })
        .await?;

        let agent_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
        persist_message_sqlx(PersistMessageSqlxParams {
            tx: &mut tx,
            message: agent_message,
            task_id: &task.id,
            context_id: &task.context_id,
            sequence_number: agent_seq,
            user_id,
            session_id,
            trace_id,
            upload_ctx: upload_ctx.as_ref(),
        })
        .await?;

        tx.commit().await.map_err(RepositoryError::database)?;

        if let Some(ref analytics_provider) = self.session_analytics_provider {
            for _ in 0..2 {
                if let Err(e) = analytics_provider.increment_message_count(session_id).await {
                    tracing::warn!(error = %e, "Failed to increment analytics message count");
                }
            }
        }

        let updated_task = self.get_task(&task.id).await?.ok_or_else(|| {
            RepositoryError::NotFound(format!("Task not found after update: {}", task.id))
        })?;

        Ok(updated_task)
    }

    pub async fn delete_task(
        &self,
        task_id: &systemprompt_identifiers::TaskId,
    ) -> Result<(), RepositoryError> {
        let task_id_str = task_id.as_str();

        sqlx::query!(
            "DELETE FROM message_parts WHERE message_id IN (SELECT message_id FROM task_messages \
             WHERE task_id = $1)",
            task_id_str
        )
        .execute(&*self.write_pool)
        .await
        .map_err(RepositoryError::database)?;

        sqlx::query!("DELETE FROM task_messages WHERE task_id = $1", task_id_str)
            .execute(&*self.write_pool)
            .await
            .map_err(RepositoryError::database)?;

        sqlx::query!(
            "DELETE FROM task_execution_steps WHERE task_id = $1",
            task_id_str
        )
        .execute(&*self.write_pool)
        .await
        .map_err(RepositoryError::database)?;

        sqlx::query!("DELETE FROM agent_tasks WHERE task_id = $1", task_id_str)
            .execute(&*self.write_pool)
            .await
            .map_err(RepositoryError::database)?;

        Ok(())
    }
}