use crate::models::a2a::{Task, TaskStatus};
use crate::models::{
ArtifactPartRow, ArtifactRow, ExecutionStepBatchRow, MessagePart, TaskMessage, TaskRow,
};
use std::collections::HashMap;
use std::sync::Arc;
use systemprompt_identifiers::{ArtifactId, MessageId, TaskId};
use systemprompt_traits::RepositoryError;
use super::batch_builders::{build_artifacts, build_execution_steps, build_messages};
use super::{TaskConstructor, converters};
pub async fn construct_tasks_batch(
constructor: &TaskConstructor,
task_ids: &[TaskId],
) -> Result<Vec<Task>, RepositoryError> {
if task_ids.is_empty() {
return Ok(Vec::new());
}
let pool = Arc::clone(constructor.pool());
let task_id_strings: Vec<String> = task_ids.iter().map(ToString::to_string).collect();
let task_rows = super::batch_queries::fetch_task_rows(&pool, &task_id_strings).await?;
let all_messages = super::batch_queries::fetch_messages(&pool, &task_id_strings).await?;
let all_parts = super::batch_queries::fetch_message_parts(&pool, &task_id_strings).await?;
let all_artifact_rows = super::batch_queries::fetch_artifacts(&pool, &task_id_strings).await?;
let all_execution_steps =
super::batch_queries::fetch_execution_steps(&pool, &task_id_strings).await?;
let artifact_ids: Vec<String> = all_artifact_rows
.iter()
.map(|a| a.artifact_id.to_string())
.collect();
let all_artifact_parts =
super::batch_queries::fetch_artifact_parts(&pool, &artifact_ids).await?;
let parts_by_message: HashMap<MessageId, Vec<&MessagePart>> =
group_by_key(&all_parts, |p| p.message_id.clone());
let messages_by_task: HashMap<TaskId, Vec<&TaskMessage>> =
group_by_key(&all_messages, |m| m.task_id.clone());
let artifacts_by_task: HashMap<TaskId, Vec<&ArtifactRow>> =
group_by_key(&all_artifact_rows, |a| a.task_id.clone());
let artifact_parts_by_id: HashMap<ArtifactId, Vec<&ArtifactPartRow>> =
group_by_key(&all_artifact_parts, |p| p.artifact_id.clone());
let steps_by_task: HashMap<TaskId, Vec<&ExecutionStepBatchRow>> =
group_by_key(&all_execution_steps, |s| s.task_id.clone());
build_tasks(&BuildTasksParams {
task_rows: &task_rows,
messages_by_task: &messages_by_task,
parts_by_message: &parts_by_message,
artifacts_by_task: &artifacts_by_task,
artifact_parts_by_id: &artifact_parts_by_id,
steps_by_task: &steps_by_task,
})
}
fn group_by_key<T, F, K>(items: &[T], key_fn: F) -> HashMap<K, Vec<&T>>
where
F: Fn(&T) -> K,
K: std::hash::Hash + Eq,
{
items.iter().fold(HashMap::new(), |mut acc, item| {
let key = key_fn(item);
acc.entry(key).or_default().push(item);
acc
})
}
#[allow(missing_debug_implementations)]
struct BuildTasksParams<'a> {
task_rows: &'a [TaskRow],
messages_by_task: &'a HashMap<TaskId, Vec<&'a TaskMessage>>,
parts_by_message: &'a HashMap<MessageId, Vec<&'a MessagePart>>,
artifacts_by_task: &'a HashMap<TaskId, Vec<&'a ArtifactRow>>,
artifact_parts_by_id: &'a HashMap<ArtifactId, Vec<&'a ArtifactPartRow>>,
steps_by_task: &'a HashMap<TaskId, Vec<&'a ExecutionStepBatchRow>>,
}
fn build_tasks(params: &BuildTasksParams<'_>) -> Result<Vec<Task>, RepositoryError> {
let BuildTasksParams {
task_rows,
messages_by_task,
parts_by_message,
artifacts_by_task,
artifact_parts_by_id,
steps_by_task,
} = params;
let mut tasks = Vec::new();
for row in *task_rows {
let history = build_messages(messages_by_task.get(&row.task_id), parts_by_message);
let artifacts = build_artifacts(artifacts_by_task.get(&row.task_id), artifact_parts_by_id);
let execution_steps = build_execution_steps(steps_by_task.get(&row.task_id));
let mut metadata = converters::construct_metadata(row);
metadata.execution_steps = execution_steps;
let task_state = converters::parse_task_state(&row.status)
.map_err(|e| RepositoryError::InvalidData(e.to_string()))?;
tasks.push(Task {
id: row.task_id.clone(),
context_id: row.context_id.clone(),
status: TaskStatus {
state: task_state,
message: None,
timestamp: row.status_timestamp,
},
history,
artifacts,
metadata: Some(metadata),
created_at: Some(row.created_at),
last_modified: Some(row.updated_at),
});
}
Ok(tasks)
}