use sqlx::PgPool;
use std::sync::Arc;
use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
use systemprompt_traits::{DynFileUploadProvider, FileUploadInput, RepositoryError};
use crate::models::a2a::Part;
#[derive(Clone)]
pub struct FileUploadContext<'a> {
pub upload_provider: &'a DynFileUploadProvider,
pub context_id: &'a ContextId,
pub user_id: Option<&'a UserId>,
pub session_id: Option<&'a SessionId>,
pub trace_id: Option<&'a TraceId>,
}
impl std::fmt::Debug for FileUploadContext<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FileUploadContext")
.field("upload_provider", &"<DynFileUploadProvider>")
.field("context_id", &self.context_id)
.field("user_id", &self.user_id)
.field("session_id", &self.session_id)
.field("trace_id", &self.trace_id)
.finish()
}
}
pub async fn get_message_parts(
pool: &Arc<PgPool>,
message_id: &MessageId,
) -> Result<Vec<Part>, RepositoryError> {
let part_rows: Vec<crate::models::MessagePart> = sqlx::query_as!(
crate::models::MessagePart,
r#"SELECT
id as "id!",
message_id as "message_id!: MessageId",
task_id as "task_id!: TaskId",
part_kind as "part_kind!",
sequence_number as "sequence_number!",
text_content,
file_name,
file_mime_type,
file_uri,
file_bytes,
data_content,
metadata
FROM message_parts WHERE message_id = $1 ORDER BY sequence_number ASC"#,
message_id.as_str()
)
.fetch_all(pool.as_ref())
.await
.map_err(RepositoryError::database)?;
let mut parts = Vec::new();
for row in part_rows {
let part = match row.part_kind.as_str() {
"text" => {
let text = row
.text_content
.ok_or_else(|| RepositoryError::InvalidData("Missing text_content".into()))?;
Part::Text(crate::models::a2a::TextPart { text })
},
"file" => Part::File(crate::models::a2a::FilePart {
file: crate::models::a2a::FileContent {
name: row.file_name,
mime_type: row.file_mime_type,
bytes: row.file_bytes,
url: row.file_uri,
},
}),
"data" => {
let data_value = row
.data_content
.ok_or_else(|| RepositoryError::InvalidData("Missing data_content".into()))?;
let serde_json::Value::Object(data) = data_value else {
return Err(RepositoryError::InvalidData(
"Data content must be a JSON object".into(),
));
};
Part::Data(crate::models::a2a::DataPart { data })
},
_ => {
return Err(RepositoryError::InvalidData(format!(
"Unknown part kind: {}",
row.part_kind
)));
},
};
parts.push(part);
}
Ok(parts)
}
#[allow(missing_debug_implementations)]
pub struct PersistPartSqlxParams<'a> {
pub tx: &'a mut sqlx::Transaction<'static, sqlx::Postgres>,
pub part: &'a Part,
pub message_id: &'a MessageId,
pub task_id: &'a TaskId,
pub sequence_number: i32,
pub upload_ctx: Option<&'a FileUploadContext<'a>>,
}
pub async fn persist_part_sqlx(params: PersistPartSqlxParams<'_>) -> Result<(), RepositoryError> {
let PersistPartSqlxParams {
tx,
part,
message_id,
task_id,
sequence_number,
upload_ctx,
} = params;
match part {
Part::Text(text_part) => {
sqlx::query!(
r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, text_content)
VALUES ($1, $2, 'text', $3, $4)"#,
message_id.as_str(),
task_id.as_str(),
sequence_number,
text_part.text
)
.execute(&mut **tx)
.await
.map_err(RepositoryError::database)?;
},
Part::File(file_part) => {
let upload_result = try_upload_file(file_part, upload_ctx).await;
let (file_id, file_uri) = match upload_result {
Some((id, uri)) => (Some(id), Some(uri)),
None => (None, None),
};
sqlx::query!(
r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, file_name, file_mime_type, file_uri, file_bytes, file_id)
VALUES ($1, $2, 'file', $3, $4, $5, $6, $7, $8)"#,
message_id.as_str(),
task_id.as_str(),
sequence_number,
file_part.file.name,
file_part.file.mime_type,
file_uri,
file_part.file.bytes.as_deref(),
file_id
)
.execute(&mut **tx)
.await
.map_err(RepositoryError::database)?;
},
Part::Data(data_part) => {
let data_json =
serde_json::to_value(&data_part.data).map_err(RepositoryError::Serialization)?;
sqlx::query!(
r#"INSERT INTO message_parts (message_id, task_id, part_kind, sequence_number, data_content)
VALUES ($1, $2, 'data', $3, $4)"#,
message_id.as_str(),
task_id.as_str(),
sequence_number,
data_json
)
.execute(&mut **tx)
.await
.map_err(RepositoryError::database)?;
},
}
Ok(())
}
async fn try_upload_file(
file_part: &crate::models::a2a::FilePart,
upload_ctx: Option<&FileUploadContext<'_>>,
) -> Option<(uuid::Uuid, String)> {
let ctx = upload_ctx?;
if !ctx.upload_provider.is_enabled() {
return None;
}
let mime_type = file_part
.file
.mime_type
.as_deref()
.unwrap_or("application/octet-stream");
let bytes = file_part.file.bytes.as_deref()?;
let mut input = FileUploadInput::new(mime_type, bytes, ctx.context_id.clone());
if let Some(name) = &file_part.file.name {
input = input.with_name(name);
}
if let Some(user_id) = ctx.user_id {
input = input.with_user_id(user_id.clone());
}
if let Some(session_id) = ctx.session_id {
input = input.with_session_id(session_id.clone());
}
if let Some(trace_id) = ctx.trace_id {
input = input.with_trace_id(trace_id.clone());
}
match ctx.upload_provider.upload_file(input).await {
Ok(uploaded) => {
let file_uuid = uuid::Uuid::parse_str(uploaded.file_id.as_str())
.map_err(|e| {
tracing::warn!(file_id = %uploaded.file_id, error = %e, "Invalid UUID from file service");
e
})
.ok()?;
Some((file_uuid, uploaded.public_url))
},
Err(e) => {
tracing::warn!(error = %e, "File upload failed, continuing with base64 only");
None
},
}
}
pub async fn persist_part_with_tx(
tx: &mut dyn systemprompt_database::DatabaseTransaction,
part: &Part,
message_id: &MessageId,
task_id: &TaskId,
sequence_number: i32,
) -> Result<(), RepositoryError> {
let message_id_str = message_id.as_str();
let task_id_str = task_id.as_str();
match part {
Part::Text(text_part) => {
let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
sequence_number, text_content) VALUES ($1, $2, 'text', $3, $4)";
tx.execute(
&query,
&[
&message_id_str,
&task_id_str,
&sequence_number,
&text_part.text,
],
)
.await?;
},
Part::File(file_part) => {
let uri_opt: Option<&str> = None;
let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
sequence_number, file_name, file_mime_type, file_uri, file_bytes) \
VALUES ($1, $2, 'file', $3, $4, $5, $6, $7)";
tx.execute(
&query,
&[
&message_id_str,
&task_id_str,
&sequence_number,
&file_part.file.name,
&file_part.file.mime_type,
&uri_opt,
&file_part.file.bytes.as_deref(),
],
)
.await?;
},
Part::Data(data_part) => {
let data_json = serde_json::to_string(&data_part.data)?;
let query: &str = "INSERT INTO message_parts (message_id, task_id, part_kind, \
sequence_number, data_content) VALUES ($1, $2, 'data', $3, $4)";
tx.execute(
&query,
&[&message_id_str, &task_id_str, &sequence_number, &data_json],
)
.await?;
},
}
Ok(())
}