systemprompt-agent 0.2.2

Agent-to-Agent (A2A) protocol for systemprompt.io AI governance: streaming, JSON-RPC models, task lifecycle, .well-known discovery, and governed agent orchestration.
Documentation
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
use systemprompt_traits::RepositoryError;

use crate::models::a2a::{Message, MessageRole, Part};
use crate::repository::task::constructor::batch_queries;


/// Loads every message belonging to `task_id`, ordered by `sequence_number`
/// ascending, with each message's parts attached.
///
/// Parts are fetched with one batch query for the task and grouped by
/// message id; a message with no stored parts gets an empty part list.
/// A stored `role` of `"user"` or `"ROLE_USER"` becomes [`MessageRole::User`];
/// any other value becomes [`MessageRole::Agent`]. A row without a
/// `context_id` falls back to [`ContextId::empty`].
///
/// # Errors
///
/// Returns [`RepositoryError`] if the message query or the batched parts
/// query fails.
pub async fn get_messages_by_task(
    pool: &Arc<PgPool>,
    task_id: &TaskId,
) -> Result<Vec<Message>, RepositoryError> {
    let message_rows: Vec<crate::models::TaskMessage> = sqlx::query_as!(
        crate::models::TaskMessage,
        r#"SELECT
            id as "id!",
            task_id as "task_id!: TaskId",
            message_id as "message_id!: MessageId",
            client_message_id,
            role as "role!",
            context_id as "context_id?: ContextId",
            user_id as "user_id?: UserId",
            session_id as "session_id?: SessionId",
            trace_id as "trace_id?: TraceId",
            sequence_number as "sequence_number!",
            created_at as "created_at!",
            updated_at as "updated_at!",
            metadata,
            reference_task_ids
        FROM task_messages WHERE task_id = $1 ORDER BY sequence_number ASC"#,
        task_id.as_str()
    )
    .fetch_all(pool.as_ref())
    .await
    .map_err(RepositoryError::database)?;

    // No messages means no parts to attach; skip the second round-trip.
    if message_rows.is_empty() {
        return Ok(Vec::new());
    }

    // Every row here carries the same task id, so collecting one id per row
    // would send N duplicates to the batch query. Dedupe before fetching.
    let mut task_ids: Vec<String> = message_rows
        .iter()
        .map(|r| r.task_id.to_string())
        .collect();
    task_ids.sort_unstable();
    task_ids.dedup();

    let all_parts = batch_queries::fetch_message_parts(pool, &task_ids).await?;
    let parts_by_message = group_parts_by_message(all_parts);

    let messages: Vec<Message> = message_rows
        .into_iter()
        .map(|row| {
            let parts = parts_by_message
                .get(&row.message_id)
                .cloned()
                .unwrap_or_default();

            let reference_task_ids = row
                .reference_task_ids
                .map(|ids| ids.into_iter().map(TaskId::new).collect());

            let role = match row.role.as_str() {
                "user" | "ROLE_USER" => MessageRole::User,
                _ => MessageRole::Agent,
            };

            Message {
                role,
                message_id: row.message_id,
                task_id: Some(row.task_id),
                context_id: row.context_id.unwrap_or_else(ContextId::empty),
                parts,
                metadata: row.metadata,
                extensions: None,
                reference_task_ids,
            }
        })
        .collect();

    Ok(messages)
}

pub async fn get_messages_by_context(
    pool: &Arc<PgPool>,
    context_id: &ContextId,
) -> Result<Vec<Message>, RepositoryError> {
    let message_rows: Vec<crate::models::TaskMessage> = sqlx::query_as!(
        crate::models::TaskMessage,
        r#"SELECT
            m.id as "id!",
            m.task_id as "task_id!: TaskId",
            m.message_id as "message_id!: MessageId",
            m.client_message_id,
            m.role as "role!",
            m.context_id as "context_id?: ContextId",
            m.user_id as "user_id?: UserId",
            m.session_id as "session_id?: SessionId",
            m.trace_id as "trace_id?: TraceId",
            m.sequence_number as "sequence_number!",
            m.created_at as "created_at!",
            m.updated_at as "updated_at!",
            m.metadata,
            m.reference_task_ids
        FROM task_messages m
        JOIN agent_tasks t ON m.task_id = t.task_id
        WHERE t.context_id = $1
        ORDER BY m.created_at ASC"#,
        context_id.as_str()
    )
    .fetch_all(pool.as_ref())
    .await
    .map_err(RepositoryError::database)?;

    let task_ids: Vec<String> = message_rows.iter().map(|r| r.task_id.to_string()).collect();
    let all_parts = batch_queries::fetch_message_parts(pool, &task_ids).await?;
    let parts_by_message = group_parts_by_message(all_parts);

    let mut messages = Vec::new();

    for row in message_rows {
        let parts = parts_by_message
            .get(&row.message_id)
            .cloned()
            .unwrap_or_default();

        let role = match row.role.as_str() {
            "user" | "ROLE_USER" => MessageRole::User,
            _ => MessageRole::Agent,
        };

        messages.push(Message {
            role,
            message_id: row.message_id,
            task_id: Some(row.task_id),
            context_id: row.context_id.unwrap_or_else(|| context_id.clone()),
            parts,
            metadata: row.metadata,
            extensions: None,
            reference_task_ids: None,
        });
    }

    Ok(messages)
}

/// Returns the sequence number to use for the next message appended to
/// `task_id`: one past the current maximum, or `0` when the task has no
/// messages yet.
///
/// NOTE(review): this reads `MAX(sequence_number)` on the pool, outside any
/// transaction, so two concurrent appenders could observe the same value.
/// Confirm callers serialize writes per task or rely on a unique constraint.
///
/// # Errors
///
/// Returns [`RepositoryError`] if the query fails.
pub async fn get_next_sequence_number(
    pool: &Arc<PgPool>,
    task_id: &TaskId,
) -> Result<i32, RepositoryError> {
    let current_max = sqlx::query!(
        r#"SELECT MAX(sequence_number) as "max_seq" FROM task_messages WHERE task_id = $1"#,
        task_id.as_str()
    )
    .fetch_optional(pool.as_ref())
    .await
    .map_err(RepositoryError::database)?
    .and_then(|record| record.max_seq);

    let next = match current_max {
        Some(highest) => highest + 1,
        None => 0,
    };
    Ok(next)
}

/// Transactional variant of `get_next_sequence_number`. It runs the same
/// `MAX(sequence_number)` lookup on the supplied sqlx Postgres transaction and
/// returns one past the maximum, or `0` when the task has no messages.
///
/// # Errors
///
/// Returns [`RepositoryError`] if the query fails.
pub async fn get_next_sequence_number_sqlx(
    tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
    task_id: &TaskId,
) -> Result<i32, RepositoryError> {
    let record = sqlx::query!(
        r#"SELECT MAX(sequence_number) as "max_seq" FROM task_messages WHERE task_id = $1"#,
        task_id.as_str()
    )
    .fetch_optional(&mut **tx)
    .await
    .map_err(RepositoryError::database)?;

    if let Some(highest) = record.and_then(|r| r.max_seq) {
        Ok(highest + 1)
    } else {
        Ok(0)
    }
}

/// Variant of `get_next_sequence_number` that runs on a type-erased
/// `DatabaseTransaction`. The result row is read through its `max_seq` key as
/// a JSON number. Returns one past the maximum, or `0` when there is no row,
/// the value is null, or the value is not an integer.
///
/// # Errors
///
/// Propagates any error returned by the transaction's `fetch_optional`.
pub async fn get_next_sequence_number_in_tx(
    tx: &mut dyn systemprompt_database::DatabaseTransaction,
    task_id: &TaskId,
) -> Result<i32, RepositoryError> {
    let sql: &str =
        "SELECT MAX(sequence_number) as max_seq FROM task_messages WHERE task_id = $1";
    let id_param = task_id.as_str();
    let maybe_row = tx.fetch_optional(&sql, &[&id_param]).await?;

    // `sequence_number` is read as i32 elsewhere in this module, so the
    // narrowing cast from the JSON i64 is expected to be lossless.
    let highest = match maybe_row.as_ref() {
        Some(row) => row
            .get("max_seq")
            .and_then(serde_json::Value::as_i64)
            .map(|n| n as i32),
        None => None,
    };

    Ok(highest.map_or(0, |n| n + 1))
}

/// Converts raw part rows into A2A [`Part`] values and buckets them by owning
/// message id, preserving the input order within each bucket.
///
/// Decoding rules, keyed on `part_kind`:
/// - `"text"`: a text part when `text_content` is present, otherwise dropped.
/// - `"file"`: always a file part, built from the `file_*` columns.
/// - `"data"`: a data part only when `data_content` is a JSON object.
/// - Anything else: dropped.
fn group_parts_by_message(
    all_parts: Vec<crate::models::MessagePart>,
) -> HashMap<MessageId, Vec<Part>> {
    use crate::models::a2a::{DataPart, FileContent, FilePart, TextPart};

    all_parts
        .into_iter()
        .filter_map(|record| {
            let decoded = match record.part_kind.as_str() {
                "text" => record.text_content.map(|text| Part::Text(TextPart { text })),
                "file" => {
                    let file = FileContent {
                        name: record.file_name,
                        mime_type: record.file_mime_type,
                        bytes: record.file_bytes,
                        url: record.file_uri,
                    };
                    Some(Part::File(FilePart { file }))
                }
                "data" => match record.data_content {
                    Some(serde_json::Value::Object(data)) => Some(Part::Data(DataPart { data })),
                    _ => None,
                },
                _ => None,
            }?;
            Some((record.message_id, decoded))
        })
        .fold(
            HashMap::new(),
            |mut grouped: HashMap<MessageId, Vec<Part>>, (owner, part)| {
                grouped.entry(owner).or_default().push(part);
                grouped
            },
        )
}