systemprompt-agent 0.2.2

Agent-to-Agent (A2A) protocol for systemprompt.io AI governance: streaming, JSON-RPC models, task lifecycle, .well-known discovery, and governed agent orchestration.
Documentation
use anyhow::{Result, anyhow};
use serde_json::json;
use uuid::Uuid;

use crate::models::a2a::{Message, MessageRole, Part, TextPart};
use crate::repository::context::message::PersistMessageWithTxParams;
use crate::repository::task::TaskRepository;
use systemprompt_database::{DatabaseProvider, DatabaseTransaction, DbPool};
use systemprompt_identifiers::{ContextId, TaskId};
use systemprompt_models::RequestContext;

/// Inputs for [`MessageService::persist_message_in_tx`]: a single message plus
/// the identifiers it is stored under, written through a caller-supplied
/// transaction.
///
/// `Debug` is implemented by hand for this type (it is not derived), and that
/// implementation leaves `tx` out of the output.
pub struct PersistMessageInTxParams<'a> {
    /// Transaction the sequence-number lookup and the message write go
    /// through. `persist_message_in_tx` never commits it; the caller does.
    pub tx: &'a mut dyn DatabaseTransaction,
    /// Message to persist.
    pub message: &'a Message,
    /// Task the message is stored under; also used to look up the next
    /// sequence number.
    pub task_id: &'a TaskId,
    /// Context identifier forwarded to the repository with the message.
    pub context_id: &'a ContextId,
    /// User identifier forwarded to the repository; may be absent.
    pub user_id: Option<&'a systemprompt_identifiers::UserId>,
    /// Session identifier forwarded to the repository.
    pub session_id: &'a systemprompt_identifiers::SessionId,
    /// Trace identifier forwarded to the repository.
    pub trace_id: &'a systemprompt_identifiers::TraceId,
}

impl std::fmt::Debug for PersistMessageInTxParams<'_> {
    /// Writes every field except `tx`, closing with `..` to signal that the
    /// output is intentionally incomplete.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that adding a field to the struct forces a decision
        // here; `tx` is the one field deliberately left out of the output.
        let Self {
            tx: _,
            message,
            task_id,
            context_id,
            user_id,
            session_id,
            trace_id,
        } = self;

        let mut out = f.debug_struct("PersistMessageInTxParams");
        out.field("message", message);
        out.field("task_id", task_id);
        out.field("context_id", context_id);
        out.field("user_id", user_id);
        out.field("session_id", session_id);
        out.field("trace_id", trace_id);
        out.finish_non_exhaustive()
    }
}

/// Inputs for [`MessageService::persist_messages`]: a batch of messages that
/// is written inside one transaction opened and committed by the service.
#[derive(Debug)]
pub struct PersistMessagesParams<'a> {
    /// Task every message in the batch is stored under.
    pub task_id: &'a TaskId,
    /// Context identifier forwarded with every message.
    pub context_id: &'a ContextId,
    /// Messages to persist, processed in vector order. An empty vector makes
    /// `persist_messages` return an empty result without opening a
    /// transaction.
    pub messages: Vec<Message>,
    /// User identifier forwarded with every message; may be absent.
    pub user_id: Option<&'a systemprompt_identifiers::UserId>,
    /// Session identifier forwarded with every message.
    pub session_id: &'a systemprompt_identifiers::SessionId,
    /// Trace identifier forwarded with every message.
    pub trace_id: &'a systemprompt_identifiers::TraceId,
}

/// Inputs for [`MessageService::create_tool_execution_message`], which stores
/// a generated message describing a tool call.
#[derive(Debug)]
pub struct CreateToolExecutionMessageParams<'a> {
    /// Task the generated message is stored under; its string form is also
    /// embedded in the message text as the "Execution ID".
    pub task_id: &'a TaskId,
    /// Context identifier copied into the generated message.
    pub context_id: &'a ContextId,
    /// Tool name, written into both the message text and its metadata.
    pub tool_name: &'a str,
    /// Tool arguments, pretty-printed into the message text and stored as-is
    /// in the metadata.
    pub tool_args: &'a serde_json::Value,
    /// Source of the user, session and trace identifiers used when the
    /// message is persisted.
    pub request_context: &'a RequestContext,
}

/// Service that persists messages for a task through a [`TaskRepository`].
pub struct MessageService {
    // Repository used for sequence-number lookups, message writes, and for
    // reaching the database pool when a transaction has to be opened.
    task_repo: TaskRepository,
}

impl std::fmt::Debug for MessageService {
    /// Prints the type name followed by `..`; no field is included in the
    /// output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MessageService");
        out.finish_non_exhaustive()
    }
}

impl MessageService {
    /// Builds a service whose repository is created from `db_pool`.
    ///
    /// # Errors
    /// Propagates any failure reported by `TaskRepository::new`.
    pub fn new(db_pool: &DbPool) -> Result<Self> {
        let task_repo = TaskRepository::new(db_pool)?;
        Ok(Self { task_repo })
    }

    /// Persists one message inside the transaction carried by `params` and
    /// returns the sequence number it was stored with.
    ///
    /// The sequence number is fetched from the repository through the same
    /// transaction immediately before the write. The transaction is not
    /// committed here; that is left to the caller.
    ///
    /// # Errors
    /// Returns the sequence-number lookup error unchanged, or a
    /// `"Failed to persist message: ..."` error when the write fails.
    pub async fn persist_message_in_tx(&self, params: PersistMessageInTxParams<'_>) -> Result<i32> {
        let PersistMessageInTxParams {
            tx,
            message,
            task_id,
            context_id,
            user_id,
            session_id,
            trace_id,
        } = params;

        let repo = &self.task_repo;
        let sequence_number = repo.get_next_sequence_number_in_tx(tx, task_id).await?;

        let write_params = PersistMessageWithTxParams {
            tx,
            message,
            task_id,
            context_id,
            sequence_number,
            user_id,
            session_id,
            trace_id,
        };
        if let Err(e) = repo.persist_message_with_tx(write_params).await {
            return Err(anyhow!("Failed to persist message: {}", e));
        }

        tracing::info!(
            message_id = %message.message_id,
            task_id = %task_id,
            sequence_number = sequence_number,
            "Message persisted"
        );

        Ok(sequence_number)
    }

    /// Persists a batch of messages in a single transaction and returns their
    /// sequence numbers in the same order as the input.
    ///
    /// An empty batch returns an empty vector without touching the database.
    /// The transaction is committed only after every message has been
    /// written; any failure returns early, before the commit call.
    ///
    /// # Errors
    /// Propagates failures from opening the transaction, from persisting any
    /// individual message, and from the commit.
    pub async fn persist_messages(&self, params: PersistMessagesParams<'_>) -> Result<Vec<i32>> {
        let PersistMessagesParams {
            task_id,
            context_id,
            messages,
            user_id,
            session_id,
            trace_id,
        } = params;

        // Nothing to write: skip opening a transaction altogether.
        if messages.is_empty() {
            return Ok(Vec::new());
        }

        let mut tx = self
            .task_repo
            .db_pool()
            .as_ref()
            .begin_transaction()
            .await?;
        let mut assigned = Vec::with_capacity(messages.len());

        tracing::info!(
            task_id = %task_id,
            message_count = messages.len(),
            "Persisting multiple messages"
        );

        for message in &messages {
            let single = PersistMessageInTxParams {
                tx: &mut *tx,
                message,
                task_id,
                context_id,
                user_id,
                session_id,
                trace_id,
            };
            assigned.push(self.persist_message_in_tx(single).await?);
        }

        tx.commit().await?;

        tracing::info!(
            task_id = %task_id,
            sequence_numbers = ?assigned,
            "Messages persisted successfully"
        );

        Ok(assigned)
    }

    /// Generates and persists a message that records a tool call, returning
    /// the new message id together with its sequence number.
    ///
    /// The message is created with the `User` role, a fresh UUID v4 as its
    /// id, a text part describing the tool, its arguments and the task id,
    /// and metadata that marks it as synthetic (`"is_synthetic": true`,
    /// `"source": "mcp_direct_call"`). It is written and committed in its own
    /// transaction, using the identifiers supplied by `request_context`.
    ///
    /// # Errors
    /// Propagates failures from opening the transaction, persisting the
    /// message, and committing.
    pub async fn create_tool_execution_message(
        &self,
        params: CreateToolExecutionMessageParams<'_>,
    ) -> Result<(String, i32)> {
        let CreateToolExecutionMessageParams {
            task_id,
            context_id,
            tool_name,
            tool_args,
            request_context,
        } = params;

        let message_id = Uuid::new_v4().to_string();

        // Prefer the pretty-printed form; fall back to the compact rendering
        // if pretty serialization fails.
        let tool_args_display = match serde_json::to_string_pretty(tool_args) {
            Ok(pretty) => pretty,
            Err(_) => tool_args.to_string(),
        };

        // One timestamp shared by the text and the metadata.
        let timestamp = chrono::Utc::now().to_rfc3339();

        let text = format!(
            "Executed MCP tool: {} with arguments:\n{}\n\nExecution ID: {} at {}",
            tool_name,
            tool_args_display,
            task_id.as_str(),
            timestamp
        );
        let metadata = json!({
            "source": "mcp_direct_call",
            "tool_name": tool_name,
            "is_synthetic": true,
            "tool_args": tool_args,
            "execution_timestamp": timestamp,
        });

        let message = Message {
            role: MessageRole::User,
            message_id: message_id.clone().into(),
            task_id: Some(task_id.clone()),
            context_id: context_id.clone(),
            parts: vec![Part::Text(TextPart { text })],
            metadata: Some(metadata),
            extensions: None,
            reference_task_ids: None,
        };

        let mut tx = self
            .task_repo
            .db_pool()
            .as_ref()
            .begin_transaction()
            .await?;

        let single = PersistMessageInTxParams {
            tx: &mut *tx,
            message: &message,
            task_id,
            context_id,
            user_id: Some(request_context.user_id()),
            session_id: request_context.session_id(),
            trace_id: request_context.trace_id(),
        };
        let sequence_number = self.persist_message_in_tx(single).await?;

        tx.commit().await?;

        tracing::info!(
            message_id = %message_id,
            task_id = %task_id,
            tool_name = %tool_name,
            sequence_number = sequence_number,
            "Created synthetic tool execution message"
        );

        Ok((message_id, sequence_number))
    }
}