systemprompt-agent 0.1.22

Core Agent protocol module for systemprompt.io
Documentation
use crate::models::a2a::{Artifact, Message, MessageRole, Part, Task, TaskStatus};
use crate::models::{MessagePart, TaskMessage, TaskRow};
use crate::repository::execution::ExecutionStepRepository;
use systemprompt_identifiers::{
    AgentName, ContextId, MessageId, SessionId, TaskId, TraceId, UserId,
};
use systemprompt_models::ExecutionStep;
use systemprompt_traits::RepositoryError;

use super::{TaskConstructor, converters};

/// Loads the `agent_tasks` row identified by `task_id` and assembles the
/// complete A2A [`Task`] from it (history, artifacts and execution steps).
///
/// # Errors
/// Returns [`RepositoryError::NotFound`] when no row exists for `task_id`,
/// and propagates any error raised while loading or converting the related
/// records.
pub async fn construct_task_from_task_id(
    constructor: &TaskConstructor,
    task_id: &TaskId,
) -> Result<Task, RepositoryError> {
    let task_row = fetch_task_row(constructor, task_id).await?;
    let task = construct_task_from_row(constructor, &task_row).await?;
    Ok(task)
}

async fn fetch_task_row(
    constructor: &TaskConstructor,
    task_id: &TaskId,
) -> Result<TaskRow, RepositoryError> {
    let pool = constructor.pool();
    let task_id_str = task_id.as_str();

    sqlx::query_as!(
        TaskRow,
        r#"SELECT
            task_id as "task_id!: TaskId",
            context_id as "context_id!: ContextId",
            status as "status!",
            status_timestamp,
            user_id as "user_id?: UserId",
            session_id as "session_id?: SessionId",
            trace_id as "trace_id?: TraceId",
            agent_name as "agent_name?: AgentName",
            started_at,
            completed_at,
            execution_time_ms,
            error_message,
            metadata,
            created_at as "created_at!",
            updated_at as "updated_at!"
        FROM agent_tasks WHERE task_id = $1"#,
        task_id_str
    )
    .fetch_one(pool.as_ref())
    .await
    .map_err(|e| match e {
        sqlx::Error::RowNotFound => {
            RepositoryError::NotFound(format!("Task {} not found", task_id))
        },
        _ => RepositoryError::database(e),
    })
}

/// Builds an A2A [`Task`] from an `agent_tasks` row by loading its message
/// history, artifacts and execution steps.
///
/// Each loader in this module reports an empty collection as `None`. When
/// execution steps are found, they are attached to the metadata produced by
/// `converters::construct_metadata`.
///
/// # Errors
/// Propagates any loader error, and returns [`RepositoryError::InvalidData`]
/// when the stored status string cannot be parsed into a task state.
async fn construct_task_from_row(
    constructor: &TaskConstructor,
    row: &TaskRow,
) -> Result<Task, RepositoryError> {
    let id = TaskId::new(&row.task_id);

    let history = load_task_messages(constructor, &id).await?;
    let artifacts = load_task_artifacts(constructor, &id).await?;
    let steps = load_execution_steps(constructor, &id).await?;

    let mut metadata = converters::construct_metadata(row);
    // Only overwrite the converter's value when steps were actually found.
    if steps.is_some() {
        metadata.execution_steps = steps;
    }

    let state = converters::parse_task_state(&row.status)
        .map_err(|err| RepositoryError::InvalidData(err.to_string()))?;

    let status = TaskStatus {
        state,
        message: None,
        timestamp: row.status_timestamp,
    };

    Ok(Task {
        id,
        context_id: row.context_id.clone(),
        status,
        history,
        artifacts,
        metadata: Some(metadata),
        created_at: Some(row.created_at),
        last_modified: Some(row.updated_at),
    })
}

/// Loads the message history of `task_id`, ordered by `sequence_number`.
///
/// Returns `Ok(None)` when the task has no stored messages. Each message's
/// parts are loaded with one additional query per message.
///
/// # Errors
/// Database failures are mapped with [`RepositoryError::database`]; errors
/// from loading or converting message parts are propagated unchanged.
async fn load_task_messages(
    constructor: &TaskConstructor,
    task_id: &TaskId,
) -> Result<Option<Vec<Message>>, RepositoryError> {
    let db = constructor.pool();
    let key = task_id.as_str();

    let rows = sqlx::query_as!(
        TaskMessage,
        r#"SELECT
            id as "id!",
            task_id as "task_id!: TaskId",
            message_id as "message_id!: MessageId",
            client_message_id,
            role as "role!",
            context_id as "context_id?: ContextId",
            user_id as "user_id?: UserId",
            session_id as "session_id?: SessionId",
            trace_id as "trace_id?: TraceId",
            sequence_number as "sequence_number!",
            created_at as "created_at!",
            updated_at as "updated_at!",
            metadata,
            reference_task_ids
        FROM task_messages WHERE task_id = $1 ORDER BY sequence_number ASC"#,
        key
    )
    .fetch_all(db.as_ref())
    .await
    .map_err(RepositoryError::database)?;

    if rows.is_empty() {
        return Ok(None);
    }

    let mut history = Vec::with_capacity(rows.len());
    for row in rows {
        // Parts must be fetched before `row` is consumed by the builder.
        let parts = load_message_parts(constructor, row.message_id.as_str(), task_id).await?;
        history.push(build_message_from_row(row, parts));
    }

    Ok(Some(history))
}

/// Loads the parts of one message, scoped to both `message_id` and `task_id`
/// and ordered by `sequence_number`, then converts them with
/// `converters::build_parts_from_rows`.
///
/// # Errors
/// Database failures are mapped with [`RepositoryError::database`]; conversion
/// errors from `build_parts_from_rows` are returned as-is.
async fn load_message_parts(
    constructor: &TaskConstructor,
    message_id: &str,
    task_id: &TaskId,
) -> Result<Vec<Part>, RepositoryError> {
    let db = constructor.pool();
    let owning_task = task_id.as_str();

    let rows = sqlx::query_as!(
        MessagePart,
        r#"SELECT
            id as "id!",
            message_id as "message_id!",
            task_id as "task_id!",
            part_kind as "part_kind!",
            sequence_number as "sequence_number!",
            text_content,
            file_name,
            file_mime_type,
            file_uri,
            file_bytes,
            data_content,
            metadata
        FROM message_parts WHERE message_id = $1 AND task_id = $2 ORDER BY sequence_number ASC"#,
        message_id,
        owning_task
    )
    .fetch_all(db.as_ref())
    .await
    .map_err(RepositoryError::database)?;

    converters::build_parts_from_rows(&rows)
}

/// Fetches every artifact recorded for `task_id` via the artifact repository,
/// returning `None` instead of an empty list.
///
/// # Errors
/// Any artifact-repository failure is reported as
/// [`RepositoryError::InvalidData`] carrying the error's string form.
async fn load_task_artifacts(
    constructor: &TaskConstructor,
    task_id: &TaskId,
) -> Result<Option<Vec<Artifact>>, RepositoryError> {
    let found = constructor
        .artifact_repo()
        .get_artifacts_by_task(task_id)
        .await
        .map_err(|err| RepositoryError::InvalidData(err.to_string()))?;

    Ok((!found.is_empty()).then_some(found))
}

/// Lists the execution steps recorded for `task_id`, returning `None` when
/// there are none.
///
/// A fresh [`ExecutionStepRepository`] is created from the constructor's
/// database pool on every call.
///
/// # Errors
/// Propagates repository construction errors via `?` and wraps listing
/// failures in [`RepositoryError::Other`].
async fn load_execution_steps(
    constructor: &TaskConstructor,
    task_id: &TaskId,
) -> Result<Option<Vec<ExecutionStep>>, RepositoryError> {
    let repo = ExecutionStepRepository::new(constructor.db_pool())?;
    let steps = repo
        .list_by_task(task_id)
        .await
        .map_err(RepositoryError::Other)?;

    Ok(match steps.is_empty() {
        true => None,
        false => Some(steps),
    })
}

/// Converts a persisted `task_messages` row plus its already-loaded parts into
/// an A2A [`Message`].
///
/// * `reference_task_ids` are converted element-wise via `Into`.
/// * Stored metadata that is absent (`NULL`) or JSON `null` is treated as an
///   empty object.
/// * When the row carries a `client_message_id`, it is merged into object
///   metadata under the `"clientMessageId"` key. Non-object metadata (arrays,
///   scalars) is passed through unchanged.
/// * Metadata that ends up as an empty object is reported as `None`.
/// * Stored roles `"user"` / `"ROLE_USER"` map to [`MessageRole::User`];
///   every other value maps to [`MessageRole::Agent`].
/// * A missing `context_id` becomes `ContextId::empty()`.
fn build_message_from_row(msg_row: TaskMessage, parts: Vec<Part>) -> Message {
    let reference_task_ids = msg_row
        .reference_task_ids
        .map(|ids| ids.into_iter().map(Into::into).collect());

    // JSON `null` means "no metadata". Normalising it to `{}` keeps the client
    // id below from being silently dropped and avoids emitting `Some(null)`.
    let mut final_metadata = match msg_row.metadata {
        None | Some(serde_json::Value::Null) => serde_json::json!({}),
        Some(value) => value,
    };
    // The row is owned, so the client id can be moved rather than cloned.
    if let Some(client_id) = msg_row.client_message_id {
        if let Some(obj) = final_metadata.as_object_mut() {
            obj.insert(
                "clientMessageId".to_string(),
                serde_json::Value::String(client_id),
            );
        }
    }

    let role = match msg_row.role.as_str() {
        "user" | "ROLE_USER" => MessageRole::User,
        _ => MessageRole::Agent,
    };

    // Structural emptiness check; no throwaway `json!({})` allocation.
    let metadata = if matches!(final_metadata.as_object(), Some(map) if map.is_empty()) {
        None
    } else {
        Some(final_metadata)
    };

    Message {
        role,
        parts,
        message_id: msg_row.message_id,
        task_id: Some(msg_row.task_id),
        context_id: msg_row.context_id.unwrap_or_else(ContextId::empty),
        metadata,
        extensions: None,
        reference_task_ids,
    }
}