use super::{TaskRepository, task_state_to_db_string};
use crate::models::a2a::{Message, Task, TaskState};
use crate::repository::context::message::{
FileUploadContext, PersistMessageSqlxParams, get_next_sequence_number_sqlx,
persist_message_sqlx,
};
use systemprompt_traits::RepositoryError;
#[allow(missing_debug_implementations)]
pub struct UpdateTaskAndSaveMessagesParams<'a> {
pub task: &'a Task,
pub user_message: &'a Message,
pub agent_message: &'a Message,
pub user_id: Option<&'a systemprompt_identifiers::UserId>,
pub session_id: &'a systemprompt_identifiers::SessionId,
pub trace_id: &'a systemprompt_identifiers::TraceId,
}
impl TaskRepository {
pub async fn update_task_and_save_messages(
&self,
params: UpdateTaskAndSaveMessagesParams<'_>,
) -> Result<Task, RepositoryError> {
let UpdateTaskAndSaveMessagesParams {
task,
user_message,
agent_message,
user_id,
session_id,
trace_id,
} = params;
let mut tx = self
.write_pool
.begin()
.await
.map_err(RepositoryError::database)?;
let status = task_state_to_db_string(task.status.state);
let metadata_json = task
.metadata
.as_ref().map_or_else(|| serde_json::json!({}), |m| {
serde_json::to_value(m).unwrap_or_else(|e| {
tracing::warn!(error = %e, task_id = %task.id, "Failed to serialize task metadata");
serde_json::json!({})
})
});
let task_id_str = task.id.as_str();
let is_completed = task.status.state == TaskState::Completed;
let result = if is_completed {
sqlx::query!(
r#"UPDATE agent_tasks SET
status = $1,
status_timestamp = $2,
metadata = $3,
updated_at = CURRENT_TIMESTAMP,
completed_at = CURRENT_TIMESTAMP,
started_at = COALESCE(started_at, CURRENT_TIMESTAMP),
execution_time_ms = EXTRACT(EPOCH FROM (CURRENT_TIMESTAMP - COALESCE(started_at, CURRENT_TIMESTAMP))) * 1000
WHERE task_id = $4"#,
status,
task.status.timestamp,
metadata_json,
task_id_str
)
.execute(&mut *tx)
.await
.map_err(RepositoryError::database)?
} else {
sqlx::query!(
r#"UPDATE agent_tasks SET status = $1, status_timestamp = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP WHERE task_id = $4"#,
status,
task.status.timestamp,
metadata_json,
task_id_str
)
.execute(&mut *tx)
.await
.map_err(RepositoryError::database)?
};
if result.rows_affected() == 0 {
return Err(RepositoryError::NotFound(format!(
"Task not found for update: {}",
task.id
)));
}
let upload_ctx = self
.file_upload_provider
.as_ref()
.map(|svc| FileUploadContext {
upload_provider: svc,
context_id: &task.context_id,
user_id,
session_id: Some(session_id),
trace_id: Some(trace_id),
});
let user_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
persist_message_sqlx(PersistMessageSqlxParams {
tx: &mut tx,
message: user_message,
task_id: &task.id,
context_id: &task.context_id,
sequence_number: user_seq,
user_id,
session_id,
trace_id,
upload_ctx: upload_ctx.as_ref(),
})
.await?;
let agent_seq = get_next_sequence_number_sqlx(&mut tx, &task.id).await?;
persist_message_sqlx(PersistMessageSqlxParams {
tx: &mut tx,
message: agent_message,
task_id: &task.id,
context_id: &task.context_id,
sequence_number: agent_seq,
user_id,
session_id,
trace_id,
upload_ctx: upload_ctx.as_ref(),
})
.await?;
tx.commit().await.map_err(RepositoryError::database)?;
if let Some(ref analytics_provider) = self.session_analytics_provider {
for _ in 0..2 {
if let Err(e) = analytics_provider.increment_message_count(session_id).await {
tracing::warn!(error = %e, "Failed to increment analytics message count");
}
}
}
let updated_task = self.get_task(&task.id).await?.ok_or_else(|| {
RepositoryError::NotFound(format!("Task not found after update: {}", task.id))
})?;
Ok(updated_task)
}
pub async fn delete_task(
&self,
task_id: &systemprompt_identifiers::TaskId,
) -> Result<(), RepositoryError> {
let task_id_str = task_id.as_str();
sqlx::query!(
"DELETE FROM message_parts WHERE message_id IN (SELECT message_id FROM task_messages \
WHERE task_id = $1)",
task_id_str
)
.execute(&*self.write_pool)
.await
.map_err(RepositoryError::database)?;
sqlx::query!("DELETE FROM task_messages WHERE task_id = $1", task_id_str)
.execute(&*self.write_pool)
.await
.map_err(RepositoryError::database)?;
sqlx::query!(
"DELETE FROM task_execution_steps WHERE task_id = $1",
task_id_str
)
.execute(&*self.write_pool)
.await
.map_err(RepositoryError::database)?;
sqlx::query!("DELETE FROM agent_tasks WHERE task_id = $1", task_id_str)
.execute(&*self.write_pool)
.await
.map_err(RepositoryError::database)?;
Ok(())
}
}