systemprompt-agent 0.2.0

Core Agent protocol module for systemprompt.io
Documentation
use crate::models::a2a::{Task, TaskStatus};
use crate::models::{
    ArtifactPartRow, ArtifactRow, ExecutionStepBatchRow, MessagePart, TaskMessage, TaskRow,
};
use std::collections::HashMap;
use std::sync::Arc;
use systemprompt_identifiers::{ArtifactId, MessageId, TaskId};
use systemprompt_traits::RepositoryError;

use super::batch_builders::{build_artifacts, build_execution_steps, build_messages};
use super::{TaskConstructor, converters};

pub async fn construct_tasks_batch(
    constructor: &TaskConstructor,
    task_ids: &[TaskId],
) -> Result<Vec<Task>, RepositoryError> {
    if task_ids.is_empty() {
        return Ok(Vec::new());
    }

    let pool = Arc::clone(constructor.pool());
    let task_id_strings: Vec<String> = task_ids.iter().map(ToString::to_string).collect();

    let task_rows = super::batch_queries::fetch_task_rows(&pool, &task_id_strings).await?;
    let all_messages = super::batch_queries::fetch_messages(&pool, &task_id_strings).await?;
    let all_parts = super::batch_queries::fetch_message_parts(&pool, &task_id_strings).await?;
    let all_artifact_rows = super::batch_queries::fetch_artifacts(&pool, &task_id_strings).await?;
    let all_execution_steps =
        super::batch_queries::fetch_execution_steps(&pool, &task_id_strings).await?;

    let artifact_ids: Vec<String> = all_artifact_rows
        .iter()
        .map(|a| a.artifact_id.to_string())
        .collect();
    let all_artifact_parts =
        super::batch_queries::fetch_artifact_parts(&pool, &artifact_ids).await?;

    let parts_by_message: HashMap<MessageId, Vec<&MessagePart>> =
        group_by_key(&all_parts, |p| p.message_id.clone());
    let messages_by_task: HashMap<TaskId, Vec<&TaskMessage>> =
        group_by_key(&all_messages, |m| m.task_id.clone());
    let artifacts_by_task: HashMap<TaskId, Vec<&ArtifactRow>> =
        group_by_key(&all_artifact_rows, |a| a.task_id.clone());
    let artifact_parts_by_id: HashMap<ArtifactId, Vec<&ArtifactPartRow>> =
        group_by_key(&all_artifact_parts, |p| p.artifact_id.clone());
    let steps_by_task: HashMap<TaskId, Vec<&ExecutionStepBatchRow>> =
        group_by_key(&all_execution_steps, |s| s.task_id.clone());

    build_tasks(&BuildTasksParams {
        task_rows: &task_rows,
        messages_by_task: &messages_by_task,
        parts_by_message: &parts_by_message,
        artifacts_by_task: &artifacts_by_task,
        artifact_parts_by_id: &artifact_parts_by_id,
        steps_by_task: &steps_by_task,
    })
}

/// Buckets `items` by the key that `key_fn` derives from each element.
///
/// - Elements are borrowed, never cloned.
/// - Within each bucket, elements keep the relative order they had in `items`.
/// - An empty slice produces an empty map.
fn group_by_key<T, F, K>(items: &[T], key_fn: F) -> HashMap<K, Vec<&T>>
where
    F: Fn(&T) -> K,
    K: std::hash::Hash + Eq,
{
    let mut buckets: HashMap<K, Vec<&T>> = HashMap::new();
    for element in items {
        buckets.entry(key_fn(element)).or_default().push(element);
    }
    buckets
}

/// Borrowed index from a key to every row that shares it, in the shape
/// produced by `group_by_key`.
type RowsByKey<'a, K, V> = &'a HashMap<K, Vec<&'a V>>;

/// Everything `build_tasks` needs, bundled into a single argument.
///
/// It holds the task rows plus each child collection, already grouped by the
/// id that links it back to its parent.
// NOTE(review): rustc documents `missing_debug_implementations` as a lint for
// public types, so this allow on a private struct is likely redundant. It may
// have been added because not every borrowed row type implements `Debug`;
// confirm before removing it or swapping it for a derive.
#[allow(missing_debug_implementations)]
struct BuildTasksParams<'a> {
    /// Task rows to convert; `build_tasks` produces one task per row, in this order.
    task_rows: &'a [TaskRow],
    /// Messages keyed by the task they belong to.
    messages_by_task: RowsByKey<'a, TaskId, TaskMessage>,
    /// Message parts keyed by the message they belong to.
    parts_by_message: RowsByKey<'a, MessageId, MessagePart>,
    /// Artifact rows keyed by the task they belong to.
    artifacts_by_task: RowsByKey<'a, TaskId, ArtifactRow>,
    /// Artifact parts keyed by the artifact they belong to.
    artifact_parts_by_id: RowsByKey<'a, ArtifactId, ArtifactPartRow>,
    /// Execution step rows keyed by the task they belong to.
    steps_by_task: RowsByKey<'a, TaskId, ExecutionStepBatchRow>,
}

/// Assembles one [`Task`] per entry in `params.task_rows`, preserving row order.
///
/// Each task gets the messages, artifacts and execution steps that the lookup
/// maps hold for its id. When a task id has no entry in a map, the matching
/// `build_*` helper receives `None`.
///
/// # Errors
///
/// Returns [`RepositoryError::InvalidData`] if a row's `status` cannot be
/// parsed into a task state. The message names the offending task id, so a
/// single bad row in a large batch can be located.
fn build_tasks(params: &BuildTasksParams<'_>) -> Result<Vec<Task>, RepositoryError> {
    let BuildTasksParams {
        task_rows,
        messages_by_task,
        parts_by_message,
        artifacts_by_task,
        artifact_parts_by_id,
        steps_by_task,
    } = params;

    // Exactly one task is produced per row, so size the output up front.
    let mut tasks = Vec::with_capacity(task_rows.len());

    for row in *task_rows {
        let history = build_messages(messages_by_task.get(&row.task_id), parts_by_message);
        let artifacts = build_artifacts(artifacts_by_task.get(&row.task_id), artifact_parts_by_id);
        let execution_steps = build_execution_steps(steps_by_task.get(&row.task_id));

        let mut metadata = converters::construct_metadata(row);
        metadata.execution_steps = execution_steps;

        // `parse_task_state` only sees the status string, so its error cannot
        // say which task in the batch was malformed; add the id here.
        let task_state = converters::parse_task_state(&row.status).map_err(|e| {
            RepositoryError::InvalidData(format!("task {}: {e}", row.task_id))
        })?;

        tasks.push(Task {
            id: row.task_id.clone(),
            context_id: row.context_id.clone(),
            status: TaskStatus {
                state: task_state,
                message: None,
                timestamp: row.status_timestamp,
            },
            history,
            artifacts,
            metadata: Some(metadata),
            created_at: Some(row.created_at),
            last_modified: Some(row.updated_at),
        });
    }

    Ok(tasks)
}