systemprompt-agent 0.1.19

Core Agent protocol module for systemprompt.io
Documentation
use systemprompt_identifiers::{ContextId, SessionId, TaskId, TraceId, UserId};
use systemprompt_traits::RepositoryError;

use crate::models::a2a::Message;

use super::parts::{FileUploadContext, persist_part_sqlx, persist_part_with_tx};

/// Writes `message` and its parts through an open sqlx Postgres transaction.
///
/// Rows already stored under the message id in `message_parts` and
/// `task_messages` are deleted first, then a new `task_messages` row is
/// inserted and every entry of `message.parts` is handed to
/// [`persist_part_sqlx`] together with its position in the list. The
/// transaction is only borrowed; this function never commits it.
///
/// The `client_message_id` column is filled from the `clientMessageId`
/// metadata entry when that entry exists and is a string, and is NULL
/// otherwise.
///
/// # Errors
///
/// * [`RepositoryError::Serialization`] when the message metadata cannot be
///   converted to a JSON value (checked before any statement runs).
/// * The error produced by `RepositoryError::database` when one of the SQL
///   statements fails.
/// * Whatever error [`persist_part_sqlx`] returns for a part.
pub async fn persist_message_sqlx(
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    message: &Message,
    task_id: &TaskId,
    context_id: &ContextId,
    sequence_number: i32,
    user_id: Option<&UserId>,
    session_id: &SessionId,
    trace_id: &TraceId,
    upload_ctx: Option<&FileUploadContext<'_>>,
) -> Result<(), RepositoryError> {
    // Serialize up front so a bad payload fails before any row is touched.
    let metadata =
        serde_json::to_value(&message.metadata).map_err(RepositoryError::Serialization)?;

    let message_id = message.id.as_str();

    let client_message_id = message
        .metadata
        .as_ref()
        .and_then(|meta| meta.get("clientMessageId")?.as_str());

    let referenced_tasks = message.reference_task_ids.as_ref().map(|ids| {
        ids.iter()
            .map(|id| id.to_string())
            .collect::<Vec<String>>()
    });

    // Clear the previous copy of this message: parts first, then the message row.
    sqlx::query!(
        "DELETE FROM message_parts WHERE message_id = $1",
        message_id
    )
    .execute(&mut **tx)
    .await
    .map_err(RepositoryError::database)?;

    sqlx::query!(
        "DELETE FROM task_messages WHERE message_id = $1",
        message_id
    )
    .execute(&mut **tx)
    .await
    .map_err(RepositoryError::database)?;

    sqlx::query!(
        r#"INSERT INTO task_messages (task_id, message_id, client_message_id, role, context_id,
        user_id, session_id, trace_id, sequence_number, metadata, reference_task_ids)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"#,
        task_id.as_str(),
        message_id,
        client_message_id,
        message.role,
        context_id.as_str(),
        user_id.map(UserId::as_str),
        session_id.as_str(),
        trace_id.as_str(),
        sequence_number,
        metadata,
        referenced_tasks.as_deref()
    )
    .execute(&mut **tx)
    .await
    .map_err(RepositoryError::database)?;

    // Parts are stored in list order; the list position becomes the part index.
    for (position, part) in message.parts.iter().enumerate() {
        let part_index = position as i32;
        persist_part_sqlx(tx, part, &message.id, task_id, part_index, upload_ctx).await?;
    }

    Ok(())
}

/// Writes `message` and its parts through a `DatabaseTransaction` trait object.
///
/// Performs the same sequence as the sqlx-based variant: delete the
/// `message_parts` and `task_messages` rows stored under the message id,
/// insert a new `task_messages` row, then pass each entry of `message.parts`
/// to [`persist_part_with_tx`] along with its position in the list. Unlike
/// the sqlx variant, the metadata is bound as a JSON string and no file
/// upload context is accepted. The transaction is only borrowed; this
/// function never commits it.
///
/// The `client_message_id` column is filled from the `clientMessageId`
/// metadata entry when that entry exists and is a string.
///
/// # Errors
///
/// Propagates, via `?`, a metadata serialization failure (raised before any
/// statement runs), any error returned by `tx.execute`, and any error
/// returned by [`persist_part_with_tx`].
pub async fn persist_message_with_tx(
    tx: &mut dyn systemprompt_database::DatabaseTransaction,
    message: &Message,
    task_id: &TaskId,
    context_id: &ContextId,
    sequence_number: i32,
    user_id: Option<&UserId>,
    session_id: &SessionId,
    trace_id: &TraceId,
) -> Result<(), RepositoryError> {
    // Serialize up front so a bad payload fails before any row is touched.
    let metadata_text = serde_json::to_string(&message.metadata)?;

    // Bind every parameter to a local so it can be borrowed into the
    // parameter slices below.
    let message_key = message.id.as_str();
    let task_key = task_id.as_str();
    let context_key = context_id.as_str();
    let user_key = user_id.map(UserId::as_str);
    let session_key = session_id.as_str();
    let trace_key = trace_id.as_str();

    let client_message_id = message
        .metadata
        .as_ref()
        .and_then(|meta| meta.get("clientMessageId")?.as_str());

    let referenced_tasks = message.reference_task_ids.as_ref().map(|ids| {
        ids.iter()
            .map(|id| id.to_string())
            .collect::<Vec<String>>()
    });

    // Clear the previous copy of this message: parts first, then the message row.
    let purge_parts_sql: &str = "DELETE FROM message_parts WHERE message_id = $1";
    tx.execute(&purge_parts_sql, &[&message_key]).await?;

    let purge_message_sql: &str = "DELETE FROM task_messages WHERE message_id = $1";
    tx.execute(&purge_message_sql, &[&message_key]).await?;

    let insert_sql: &str = "INSERT INTO task_messages (task_id, message_id, client_message_id, \
                            role, context_id, user_id, session_id, trace_id, sequence_number, \
                            metadata, reference_task_ids) VALUES ($1, $2, $3, $4, $5, $6, $7, \
                            $8, $9, $10, $11)";
    tx.execute(
        &insert_sql,
        &[
            &task_key,
            &message_key,
            &client_message_id,
            &message.role,
            &context_key,
            &user_key,
            &session_key,
            &trace_key,
            &sequence_number,
            &metadata_text,
            &referenced_tasks,
        ],
    )
    .await?;

    // Parts are stored in list order; the list position becomes the part index.
    for (position, part) in message.parts.iter().enumerate() {
        let part_index = position as i32;
        persist_part_with_tx(tx, part, &message.id, task_id, part_index).await?;
    }

    Ok(())
}