use sqlx::PgPool;
use std::sync::Arc;
use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
use systemprompt_traits::RepositoryError;
use crate::models::a2a::Message;
use super::parts::get_message_parts;
pub async fn get_messages_by_task(
pool: &Arc<PgPool>,
task_id: &TaskId,
) -> Result<Vec<Message>, RepositoryError> {
let message_rows: Vec<crate::models::TaskMessage> = sqlx::query_as!(
crate::models::TaskMessage,
r#"SELECT
id as "id!",
task_id as "task_id!: TaskId",
message_id as "message_id!: MessageId",
client_message_id,
role as "role!",
context_id as "context_id?: ContextId",
user_id as "user_id?: UserId",
session_id as "session_id?: SessionId",
trace_id as "trace_id?: TraceId",
sequence_number as "sequence_number!",
created_at as "created_at!",
updated_at as "updated_at!",
metadata,
reference_task_ids
FROM task_messages WHERE task_id = $1 ORDER BY sequence_number ASC"#,
task_id.as_str()
)
.fetch_all(pool.as_ref())
.await
.map_err(|e| RepositoryError::database(e))?;
let mut messages = Vec::new();
for row in message_rows {
let parts = get_message_parts(pool, &row.message_id).await?;
let reference_task_ids = row
.reference_task_ids
.map(|ids| ids.into_iter().map(TaskId::new).collect());
messages.push(Message {
role: row.role,
id: row.message_id,
task_id: Some(row.task_id),
context_id: row.context_id.unwrap_or_else(ContextId::empty),
kind: "message".to_string(),
parts,
metadata: row.metadata,
extensions: None,
reference_task_ids,
});
}
Ok(messages)
}
pub async fn get_messages_by_context(
pool: &Arc<PgPool>,
context_id: &ContextId,
) -> Result<Vec<Message>, RepositoryError> {
let message_rows: Vec<crate::models::TaskMessage> = sqlx::query_as!(
crate::models::TaskMessage,
r#"SELECT
m.id as "id!",
m.task_id as "task_id!: TaskId",
m.message_id as "message_id!: MessageId",
m.client_message_id,
m.role as "role!",
m.context_id as "context_id?: ContextId",
m.user_id as "user_id?: UserId",
m.session_id as "session_id?: SessionId",
m.trace_id as "trace_id?: TraceId",
m.sequence_number as "sequence_number!",
m.created_at as "created_at!",
m.updated_at as "updated_at!",
m.metadata,
m.reference_task_ids
FROM task_messages m
JOIN agent_tasks t ON m.task_id = t.task_id
WHERE t.context_id = $1
ORDER BY m.created_at ASC"#,
context_id.as_str()
)
.fetch_all(pool.as_ref())
.await
.map_err(|e| RepositoryError::database(e))?;
let mut messages = Vec::new();
for row in message_rows {
let parts = get_message_parts(pool, &row.message_id).await?;
messages.push(Message {
role: row.role,
id: row.message_id,
task_id: Some(row.task_id),
context_id: row.context_id.unwrap_or_else(|| context_id.clone()),
kind: "message".to_string(),
parts,
metadata: row.metadata,
extensions: None,
reference_task_ids: None,
});
}
Ok(messages)
}
pub async fn get_next_sequence_number(
pool: &Arc<PgPool>,
task_id: &TaskId,
) -> Result<i32, RepositoryError> {
let row = sqlx::query!(
r#"SELECT MAX(sequence_number) as "max_seq" FROM task_messages WHERE task_id = $1"#,
task_id.as_str()
)
.fetch_optional(pool.as_ref())
.await
.map_err(|e| RepositoryError::database(e))?;
Ok(row.and_then(|r| r.max_seq).map(|s| s + 1).unwrap_or(0))
}
pub async fn get_next_sequence_number_sqlx(
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
task_id: &TaskId,
) -> Result<i32, RepositoryError> {
let row = sqlx::query!(
r#"SELECT MAX(sequence_number) as "max_seq" FROM task_messages WHERE task_id = $1"#,
task_id.as_str()
)
.fetch_optional(&mut **tx)
.await
.map_err(|e| RepositoryError::database(e))?;
Ok(row.and_then(|r| r.max_seq).map(|s| s + 1).unwrap_or(0))
}
pub async fn get_next_sequence_number_in_tx(
tx: &mut dyn systemprompt_database::DatabaseTransaction,
task_id: &TaskId,
) -> Result<i32, RepositoryError> {
let query: &str =
"SELECT MAX(sequence_number) as max_seq FROM task_messages WHERE task_id = $1";
let task_id_str = task_id.as_str();
let row = tx.fetch_optional(&query, &[&task_id_str]).await?;
let max_seq = if let Some(ref r) = row {
r.get("max_seq").and_then(|v| v.as_i64()).map(|v| v as i32)
} else {
None
};
Ok(max_seq.map(|s| s + 1).unwrap_or(0))
}