// systemprompt-agent 0.2.0
//
// Core Agent protocol module for systemprompt.io
// Documentation
use systemprompt_identifiers::{ContextId, SessionId, TaskId, TraceId, UserId};
use systemprompt_traits::RepositoryError;

use crate::models::a2a::{Message, MessageRole};

use super::parts::{
    FileUploadContext, PersistPartSqlxParams, persist_part_sqlx, persist_part_with_tx,
};

/// Returns the lowercase label for `role` that is bound as the `role` value
/// when a message row is inserted.
const fn role_to_str(role: &MessageRole) -> &'static str {
    match *role {
        MessageRole::Agent => "agent",
        MessageRole::User => "user",
    }
}

/// Arguments consumed by [`persist_message_sqlx`], bundled into one struct.
// `Debug` is not derived for this type, so the lint is silenced instead.
// NOTE(review): presumably one of the borrowed field types lacks `Debug` — confirm.
#[allow(missing_debug_implementations)]
pub struct PersistMessageSqlxParams<'a> {
    /// Open sqlx Postgres transaction on which every statement is executed.
    pub tx: &'a mut sqlx::Transaction<'static, sqlx::Postgres>,
    /// Message to store; its id, role, metadata, reference task ids and parts
    /// are read.
    pub message: &'a Message,
    /// Bound to `task_messages.task_id` and forwarded to each persisted part.
    pub task_id: &'a TaskId,
    /// Bound to `task_messages.context_id`.
    pub context_id: &'a ContextId,
    /// Bound to `task_messages.sequence_number`.
    pub sequence_number: i32,
    /// Optional user, bound to `task_messages.user_id`.
    pub user_id: Option<&'a UserId>,
    /// Bound to `task_messages.session_id`.
    pub session_id: &'a SessionId,
    /// Bound to `task_messages.trace_id`.
    pub trace_id: &'a TraceId,
    /// Forwarded unchanged to `persist_part_sqlx` for every message part.
    pub upload_ctx: Option<&'a FileUploadContext<'a>>,
}

pub async fn persist_message_sqlx(
    params: PersistMessageSqlxParams<'_>,
) -> Result<(), RepositoryError> {
    let PersistMessageSqlxParams {
        tx,
        message,
        task_id,
        context_id,
        sequence_number,
        user_id,
        session_id,
        trace_id,
        upload_ctx,
    } = params;
    let metadata_json =
        serde_json::to_value(&message.metadata).map_err(RepositoryError::Serialization)?;

    sqlx::query!(
        "DELETE FROM message_parts WHERE message_id = $1",
        message.message_id.as_str()
    )
    .execute(&mut **tx)
    .await
    .map_err(RepositoryError::database)?;

    sqlx::query!(
        "DELETE FROM task_messages WHERE message_id = $1",
        message.message_id.as_str()
    )
    .execute(&mut **tx)
    .await
    .map_err(RepositoryError::database)?;

    let client_message_id = message
        .metadata
        .as_ref()
        .and_then(|m| m.get("clientMessageId"))
        .and_then(|v| v.as_str());

    let reference_task_ids: Option<Vec<String>> = message
        .reference_task_ids
        .as_ref()
        .map(|ids| ids.iter().map(ToString::to_string).collect());

    sqlx::query!(
        r#"INSERT INTO task_messages (task_id, message_id, client_message_id, role, context_id,
        user_id, session_id, trace_id, sequence_number, metadata, reference_task_ids)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"#,
        task_id.as_str(),
        message.message_id.as_str(),
        client_message_id,
        role_to_str(&message.role),
        context_id.as_str(),
        user_id.map(UserId::as_str),
        session_id.as_str(),
        trace_id.as_str(),
        sequence_number,
        metadata_json,
        reference_task_ids.as_deref()
    )
    .execute(&mut **tx)
    .await
    .map_err(RepositoryError::database)?;

    for (idx, part) in message.parts.iter().enumerate() {
        persist_part_sqlx(PersistPartSqlxParams {
            tx,
            part,
            message_id: &message.message_id,
            task_id,
            sequence_number: idx as i32,
            upload_ctx,
        })
        .await?;
    }

    Ok(())
}

/// Arguments consumed by [`persist_message_with_tx`], bundled into one struct.
// `Debug` is not derived for this type, so the lint is silenced instead.
// NOTE(review): presumably because the `dyn DatabaseTransaction` field has no
// `Debug` impl — confirm.
#[allow(missing_debug_implementations)]
pub struct PersistMessageWithTxParams<'a> {
    /// Transaction object on which every statement is executed.
    pub tx: &'a mut dyn systemprompt_database::DatabaseTransaction,
    /// Message to store; its id, role, metadata, reference task ids and parts
    /// are read.
    pub message: &'a Message,
    /// Bound to `task_messages.task_id` and forwarded to each persisted part.
    pub task_id: &'a TaskId,
    /// Bound to `task_messages.context_id`.
    pub context_id: &'a ContextId,
    /// Bound to `task_messages.sequence_number`.
    pub sequence_number: i32,
    /// Optional user, bound to `task_messages.user_id`.
    pub user_id: Option<&'a UserId>,
    /// Bound to `task_messages.session_id`.
    pub session_id: &'a SessionId,
    /// Bound to `task_messages.trace_id`.
    pub trace_id: &'a TraceId,
}

/// Stores `message` and all of its parts through a `DatabaseTransaction`.
///
/// The message metadata is serialized to a JSON string. Rows already stored
/// under the same message id are removed (`message_parts`, then
/// `task_messages`), a fresh `task_messages` row is inserted, and every part
/// is handed to [`persist_part_with_tx`] with its zero-based position as
/// sequence number. The transaction is neither committed nor rolled back here.
///
/// # Errors
///
/// Propagates, via `?`, a metadata serialization failure, any failure reported
/// by `tx.execute`, and any error from [`persist_part_with_tx`].
pub async fn persist_message_with_tx(
    params: PersistMessageWithTxParams<'_>,
) -> Result<(), RepositoryError> {
    const DELETE_PARTS_SQL: &str = "DELETE FROM message_parts WHERE message_id = $1";
    const DELETE_MESSAGE_SQL: &str = "DELETE FROM task_messages WHERE message_id = $1";
    const INSERT_MESSAGE_SQL: &str = "INSERT INTO task_messages (task_id, message_id, \
                                      client_message_id, role, context_id, user_id, session_id, \
                                      trace_id, sequence_number, metadata, reference_task_ids) \
                                      VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)";

    let PersistMessageWithTxParams {
        tx,
        message,
        task_id,
        context_id,
        sequence_number,
        user_id,
        session_id,
        trace_id,
    } = params;

    // Metadata serialization is the only fallible preparation step; it runs
    // before any row is touched.
    let metadata_json = serde_json::to_string(&message.metadata)?;

    // Remaining bind values are derived from the inputs without side effects.
    let message_id = message.message_id.as_str();
    let role = role_to_str(&message.role);

    let metadata = message.metadata.as_ref();
    let client_message_id = metadata
        .and_then(|meta| meta.get("clientMessageId"))
        .and_then(|value| value.as_str());

    let reference_task_ids: Option<Vec<String>> = message
        .reference_task_ids
        .as_ref()
        .map(|ids| ids.iter().map(ToString::to_string).collect());

    // Clear any previous copy of this message: parts first, then the row.
    tx.execute(&DELETE_PARTS_SQL, &[&message_id]).await?;
    tx.execute(&DELETE_MESSAGE_SQL, &[&message_id]).await?;

    // Bind order follows the column list of the INSERT statement.
    tx.execute(
        &INSERT_MESSAGE_SQL,
        &[
            &task_id.as_str(),
            &message_id,
            &client_message_id,
            &role,
            &context_id.as_str(),
            &user_id.map(UserId::as_str),
            &session_id.as_str(),
            &trace_id.as_str(),
            &sequence_number,
            &metadata_json,
            &reference_task_ids,
        ],
    )
    .await?;

    // Each part is stored with its position inside the message as sequence number.
    for (position, part) in message.parts.iter().enumerate() {
        let part_sequence = position as i32;
        persist_part_with_tx(tx, part, &message.message_id, task_id, part_sequence).await?;
    }

    Ok(())
}