#[cfg(test)]
pub mod test;
use std::sync::Arc;
use std::time::Duration;
use std::{env, fs};
use async_trait::async_trait;
use log::{debug, error, info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use sqlx::migrate::Migrator;
use sqlx::postgres::PgPoolOptions;
use sqlx::{PgPool, Row};
use thiserror::Error;
use tokio::sync::Mutex;
use crate::config::Config;
use crate::consumer::ConsumeAttempt;
use crate::consumer::consumer::ConsumeAttemptResult;
use crate::database::Database;
use crate::transform::{TransformAttempt, TransformRequest};
use crate::worker::worker_manager::WorkerManagerResult;
/// Errors produced by the Postgres-backed `Database` implementation in this file.
///
/// Variants carrying `#[from]` are created implicitly by the `?` operator from the
/// wrapped error type.
#[derive(Debug, Error)]
pub enum PostgresDatabaseError {
/// A query, transaction or connection-pool operation failed inside `sqlx`.
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
/// The `sqlx` migrator could not be initialized or failed while applying migrations.
#[error("Migration error: {0}")]
Migration(#[from] sqlx::migrate::MigrateError),
/// A value could not be converted to or from JSON via `serde_json`.
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
/// Wraps a TOML deserialization failure.
#[error("TOML deserialization error: {0}")]
TomlDeserialization(#[from] toml::de::Error),
/// Wraps a TOML serialization failure.
#[error("TOML serialization error: {0}")]
TomlSerialization(#[from] toml::ser::Error),
/// An update matched no rows; the payload names what was missing.
#[error("Not found: {0}")]
NotFound(String),
/// An insert was skipped because the row already exists; the payload names the entity.
#[error("Attempt already exists: {0}")]
Conflict(String),
/// Handling the temporary migration file on disk failed (the I/O error itself is only logged).
#[error("Migration error occurred")]
MigrationError,
/// A stored dynamic-config value was not valid hexadecimal.
#[error("Hex decode error: {0}")]
HexDecodeError(#[from] hex::FromHexError),
}
/// Postgres-backed storage for transform requests, transform attempts and consume attempts.
///
/// The four type parameters (transform request, transform attempt, consume attempt, config)
/// are not stored; they only select which concrete types the `Database` impl works with.
#[derive(Debug, Clone)]
pub struct PostgresDatabase<TR, TA, CA, C> {
/// Connection pool every query in this type runs against.
pool: PgPool,
/// Zero-sized marker tying the generic parameters to the struct without holding values of them.
_marker: std::marker::PhantomData<(TR, TA, CA, C)>,
}
#[async_trait]
impl<TR, TA, CA, C> Database for PostgresDatabase<TR, TA, CA, C>
where
TR: TransformRequest + Send + Sync + for<'a> serde::Deserialize<'a> + serde::Serialize,
TA: TransformAttempt<
TransformRequestIdentifier = TR::Identifier,
CallArgsType = TR::Input,
ReturnType = TR::Output,
> + Send
+ Sync
+ DeserializeOwned
+ serde::Serialize,
CA: ConsumeAttempt<
TransformRequestIdentifier = TR::Identifier,
TransformAttemptIdentifier = TA::Identifier,
ConsumeVal = TR::Output,
> + Send
+ Sync
+ DeserializeOwned
+ serde::Serialize,
C: Config<KeyType = String, ValueType = Vec<u8>>,
TR::Input: serde::Serialize + DeserializeOwned,
TR::Output: serde::Serialize + DeserializeOwned,
CA::Identifier: serde::Serialize + DeserializeOwned,
CA::ReturnCtx: serde::Serialize + DeserializeOwned,
TR::Identifier: serde::Serialize + DeserializeOwned,
TA::Identifier: serde::Serialize + DeserializeOwned,
TA::ReturnPackage: serde::Serialize + DeserializeOwned,
{
type Config = C;
type ConsumeAttempt = CA;
type DatabaseError = PostgresDatabaseError;
type Input = TR::Input;
type Output = TR::Output;
type TransformAttempt = TA;
type TransformRequest = TR;
async fn new(ctx: Arc<Mutex<Self::Config>>) -> Result<Self, Self::DatabaseError> {
info!("Initializing PostgresDatabase connection pool");
let conn_str_bytes = ctx
.lock()
.await
.get("db.conn_str".to_string())
.await
.unwrap_or_default();
let connection_string: toml::Value = serde_json::from_slice(&conn_str_bytes)?;
let conn_str = connection_string.as_str().unwrap().to_owned();
let pool = PgPoolOptions::new()
.max_connections(20)
.acquire_timeout(Duration::from_secs(5))
.connect(&conn_str)
.await
.map_err(|e| {
error!("Failed to create database connection pool: {}", e);
e
})?;
info!("Database connection pool created successfully");
let instance = Self {
pool,
_marker: std::marker::PhantomData,
};
instance.run_migrations().await?;
Ok(instance)
}
async fn get_dyn_configs(
&mut self,
) -> Result<
Vec<(
<Self::Config as Config>::KeyType,
<Self::Config as Config>::ValueType,
)>,
Self::DatabaseError,
> {
debug!("Fetching dynamic configurations from database");
let rows = sqlx::query(
r#"
SELECT key, value FROM dynamic_configs
ORDER BY key
"#,
)
.fetch_all(&self.pool)
.await
.map_err(|e| {
error!("Failed to fetch dynamic configs: {}", e);
PostgresDatabaseError::Database(e)
})?;
debug!("Dynamic configurations fetched successfully");
let configs = rows
.into_iter()
.map(|row| {
let key: String = row.get("key");
let value_hex: String = row.get("value");
let value = hex::decode(value_hex).map_err(|e| {
error!("Failed to decode hex value for key '{}': {}", key, e);
PostgresDatabaseError::HexDecodeError(e)
})?;
Ok((key, value))
})
.collect::<Result<Vec<(String, Vec<u8>)>, PostgresDatabaseError>>()?;
Ok(configs)
}
/// Stores a new transform request and upserts the dynamic configs it carries.
///
/// The request insert and all config upserts share one transaction. Config values are
/// hex-encoded before being written.
///
/// # Errors
/// - `Serialization` when the request id or input cannot be turned into JSON.
/// - `Database` when any statement or the commit fails.
/// - `Conflict` when the insert affected no rows because the request id already exists.
///
/// NOTE(review): the config upserts are committed even when `Conflict` is returned —
/// confirm that this is intended.
async fn register_transform_request(
    &mut self,
    request: &Self::TransformRequest,
) -> Result<(), Self::DatabaseError> {
    const INSERT_REQUEST_SQL: &str = concat!(
        "\n",
        "INSERT INTO transform_requests (request_id, input)\n",
        "VALUES ($1, $2)\n",
        "ON CONFLICT (request_id) DO NOTHING\n",
        "RETURNING 1\n",
    );
    const UPSERT_CONFIG_SQL: &str = concat!(
        "\n",
        "INSERT INTO dynamic_configs (key, value, created_at)\n",
        "VALUES ($1, $2, NOW())\n",
        "ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value\n",
    );

    debug!("Registering new transform request");
    let request_id = serde_json::to_value(request.request_id())?;
    let input = serde_json::to_value(request.input())?;

    // Hex-encode every config value up front, before a connection is taken.
    let mut encoded_cfgs: Vec<(String, String)> = Vec::new();
    for (key, raw) in request.get_dyn_configs() {
        let value = hex::encode(raw);
        debug!("Serializing value for key '{}': {}", key, value);
        encoded_cfgs.push((key, value));
    }

    let mut txn = self.pool.begin().await?;
    let inserted = sqlx::query(INSERT_REQUEST_SQL)
        .bind(request_id)
        .bind(input)
        .execute(&mut *txn)
        .await?
        .rows_affected();
    for (key, value) in encoded_cfgs {
        sqlx::query(UPSERT_CONFIG_SQL)
            .bind(key)
            .bind(value)
            .execute(&mut *txn)
            .await?;
    }
    txn.commit().await?;

    // Zero affected rows means the `ON CONFLICT ... DO NOTHING` clause fired.
    if inserted == 0 {
        warn!("Transform request already exists");
        return Err(PostgresDatabaseError::Conflict(
            "Transform request already exists".into(),
        ));
    }
    debug!("Transform request registered successfully");
    Ok(())
}
/// Records a new transform attempt with status `pending`.
///
/// # Errors
/// - `Serialization` when the request id or attempt id cannot be turned into JSON.
/// - `Database` when the insert fails.
/// - `Conflict` when the insert affected no rows because the
///   `(request_id, attempt_id)` pair already exists.
async fn register_transform_attempt(
    &mut self,
    attempt: &Self::TransformAttempt,
) -> Result<(), Self::DatabaseError> {
    const INSERT_ATTEMPT_SQL: &str = concat!(
        "\n",
        "INSERT INTO transform_attempts\n",
        "(attempt_id, request_id, status)\n",
        "VALUES ($1, $2, 'pending')\n",
        "ON CONFLICT (request_id, attempt_id) DO NOTHING\n",
        "RETURNING 1\n",
    );

    debug!("Registering new transform attempt");
    let request_id = serde_json::to_value(attempt.request_id())?;
    let attempt_id = serde_json::to_value(attempt.attempt_id())?;

    let outcome = sqlx::query(INSERT_ATTEMPT_SQL)
        .bind(attempt_id)
        .bind(request_id)
        .execute(&self.pool)
        .await?;

    match outcome.rows_affected() {
        // Nothing inserted: the conflict clause skipped the row.
        0 => {
            warn!("Transform attempt already exists");
            Err(PostgresDatabaseError::Conflict(
                "Transform attempt already exists".into(),
            ))
        }
        _ => {
            debug!("Transform attempt registered successfully");
            Ok(())
        }
    }
}
/// Writes the outcome of a transform attempt: its return package, a `success` or
/// `failure` status, and a refreshed `updated_at`.
///
/// # Errors
/// - `Serialization` when the attempt id or return package cannot be turned into JSON.
/// - `Database` when the update fails.
/// - `NotFound` when no row has the given `attempt_id`.
async fn update_transform_attempt(
    &mut self,
    attempt: &WorkerManagerResult<Self::TransformAttempt>,
) -> Result<(), Self::DatabaseError> {
    const UPDATE_ATTEMPT_SQL: &str = concat!(
        "\n",
        "UPDATE transform_attempts\n",
        "SET return_pkg = $1,\n",
        "status = $2::attempt_status,\n",
        "updated_at = NOW()\n",
        "WHERE attempt_id = $3\n",
    );

    debug!("Updating transform attempt status");
    // The variant decides the status label; both variants carry the same payload shape.
    let status = match attempt {
        WorkerManagerResult::Success(..) => "success",
        WorkerManagerResult::Failure(..) => "failure",
    };
    let (id, pkg) = match attempt {
        WorkerManagerResult::Success(id, pkg) | WorkerManagerResult::Failure(id, pkg) => {
            (id, pkg)
        }
    };
    let attempt_id = serde_json::to_value(id)?;
    let return_pkg = serde_json::to_value(pkg)?;

    let updated = sqlx::query(UPDATE_ATTEMPT_SQL)
        .bind(return_pkg)
        .bind(status)
        .bind(attempt_id)
        .execute(&self.pool)
        .await?
        .rows_affected();

    if updated == 0 {
        warn!("Transform attempt not found for update");
        return Err(PostgresDatabaseError::NotFound(
            "Transform attempt not found".into(),
        ));
    }
    debug!("Transform attempt updated successfully");
    Ok(())
}
/// Records a new consume attempt with status `pending` and upserts the dynamic
/// configs it carries.
///
/// The attempt insert and all config upserts share one transaction. Config values are
/// hex-encoded before being written.
///
/// # Errors
/// - `Serialization` when any of the three identifiers cannot be turned into JSON.
/// - `Database` when any statement or the commit fails.
/// - `Conflict` when the insert affected no rows because the
///   `(request_id, attempt_id, consume_id)` triple already exists.
///
/// NOTE(review): the config upserts are committed even when `Conflict` is returned —
/// confirm that this is intended.
async fn register_consume_attempt(
    &mut self,
    attempt: &Self::ConsumeAttempt,
) -> Result<(), Self::DatabaseError> {
    const INSERT_CONSUME_SQL: &str = concat!(
        "\n",
        "INSERT INTO consume_attempts\n",
        "(request_id, attempt_id, consume_id, status)\n",
        "VALUES ($1, $2, $3, 'pending')\n",
        "ON CONFLICT (request_id, attempt_id, consume_id) DO NOTHING\n",
        "RETURNING 1\n",
    );
    const UPSERT_CONFIG_SQL: &str = concat!(
        "\n",
        "INSERT INTO dynamic_configs (key, value, created_at)\n",
        "VALUES ($1, $2, NOW())\n",
        "ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value\n",
    );

    debug!("Registering new consume attempt");
    let request_id = serde_json::to_value(attempt.request_id())?;
    let attempt_id = serde_json::to_value(attempt.attempt_id())?;
    let consume_id = serde_json::to_value(attempt.consume_id())?;

    // Hex-encode every config value up front, before a connection is taken.
    let mut encoded_cfgs: Vec<(String, String)> = Vec::new();
    for (key, raw) in attempt.get_dyn_configs() {
        let value = hex::encode(raw);
        debug!("Serializing value for key '{}': {}", key, value);
        encoded_cfgs.push((key, value));
    }

    let mut txn = self.pool.begin().await?;
    let inserted = sqlx::query(INSERT_CONSUME_SQL)
        .bind(request_id)
        .bind(attempt_id)
        .bind(consume_id)
        .execute(&mut *txn)
        .await?
        .rows_affected();
    for (key, value) in encoded_cfgs {
        sqlx::query(UPSERT_CONFIG_SQL)
            .bind(key)
            .bind(value)
            .execute(&mut *txn)
            .await?;
    }
    txn.commit().await?;

    // Zero affected rows means the `ON CONFLICT ... DO NOTHING` clause fired.
    if inserted == 0 {
        warn!("Consume attempt already exists");
        return Err(PostgresDatabaseError::Conflict(
            "Consume attempt already exists".into(),
        ));
    }
    debug!("Consume attempt registered successfully");
    Ok(())
}
/// Writes the outcome of a consume attempt: its return context, a `success` or
/// `failure` status, and a refreshed `updated_at`.
///
/// # Errors
/// - `Serialization` when the consume id or return context cannot be turned into JSON.
/// - `Database` when the update fails.
/// - `NotFound` when no row has the given `consume_id`.
///
/// NOTE(review): `update_transform_attempt` casts its status parameter with
/// `$2::attempt_status`, while this statement binds it uncast — verify against the
/// `consume_attempts.status` column type.
async fn update_consume_attempt(
    &mut self,
    attempt: ConsumeAttemptResult<Self::ConsumeAttempt>,
) -> Result<(), Self::DatabaseError> {
    const UPDATE_CONSUME_SQL: &str = concat!(
        "\n",
        "UPDATE consume_attempts\n",
        "SET return_ctx = $1,\n",
        "status = $2,\n",
        "updated_at = NOW()\n",
        "WHERE consume_id = $3\n",
    );

    debug!("Updating consume attempt status");
    // Pick the status label first, then take ownership of the shared payload.
    let status = match &attempt {
        ConsumeAttemptResult::Success(..) => "success",
        ConsumeAttemptResult::Failure(..) => "failure",
    };
    let (id, ctx) = match attempt {
        ConsumeAttemptResult::Success(id, ctx) | ConsumeAttemptResult::Failure(id, ctx) => {
            (id, ctx)
        }
    };
    let consume_id = serde_json::to_value(&id)?;
    let return_ctx = serde_json::to_value(&ctx)?;

    let updated = sqlx::query(UPDATE_CONSUME_SQL)
        .bind(return_ctx)
        .bind(status)
        .bind(consume_id)
        .execute(&self.pool)
        .await?
        .rows_affected();

    if updated == 0 {
        warn!("Consume attempt not found for update");
        return Err(PostgresDatabaseError::NotFound(
            "Consume attempt not found".into(),
        ));
    }
    debug!("Consume attempt updated successfully");
    Ok(())
}
/// Moves a transform request and everything recorded under it into the archive tables.
///
/// Inside one transaction, rows matching `request_id` are deleted from
/// `consume_attempts`, `transform_attempts` and `transform_requests` (in that order)
/// and inserted into the corresponding `archive_*` table. No error is raised when
/// nothing matches.
///
/// # Errors
/// - `Database` when the transaction cannot be started, a statement fails, or the
///   commit fails.
/// - `Serialization` when the request id cannot be turned into JSON.
async fn archive_request_with_id(
    &mut self,
    request_id: &<Self::TransformRequest as TransformRequest>::Identifier,
) -> Result<(), Self::DatabaseError> {
    const ARCHIVE_CONSUMES_SQL: &str = concat!(
        "\n",
        "WITH moved_consumes AS (\n",
        "DELETE FROM consume_attempts\n",
        "WHERE request_id = $1\n",
        "RETURNING *\n",
        ")\n",
        "INSERT INTO archive_consume_attempts\n",
        "SELECT * FROM moved_consumes\n",
    );
    const ARCHIVE_ATTEMPTS_SQL: &str = concat!(
        "\n",
        "WITH moved_attempts AS (\n",
        "DELETE FROM transform_attempts\n",
        "WHERE request_id = $1\n",
        "RETURNING *\n",
        ")\n",
        "INSERT INTO archive_transform_attempts\n",
        "SELECT * FROM moved_attempts\n",
    );
    const ARCHIVE_REQUESTS_SQL: &str = concat!(
        "\n",
        "WITH moved_requests AS (\n",
        "DELETE FROM transform_requests\n",
        "WHERE request_id = $1\n",
        "RETURNING *\n",
        ")\n",
        "INSERT INTO archive_transform_requests\n",
        "SELECT * FROM moved_requests\n",
    );

    debug!("Archiving transform request by request_id");
    let mut txn = self.pool.begin().await?;
    let request_id_json = serde_json::to_value(request_id)?;

    // Consume attempts go first, then transform attempts, then the request itself.
    // NOTE(review): presumably this order follows foreign-key dependencies — confirm
    // against the schema.
    for statement in [
        ARCHIVE_CONSUMES_SQL,
        ARCHIVE_ATTEMPTS_SQL,
        ARCHIVE_REQUESTS_SQL,
    ] {
        sqlx::query(statement)
            .bind(&request_id_json)
            .execute(&mut *txn)
            .await?;
    }

    txn.commit().await?;
    debug!("Transform request archived successfully");
    Ok(())
}
}
impl<TR, TA, CA, C> PostgresDatabase<TR, TA, CA, C>
where
    TR: TransformRequest + Send + Sync,
    TR::Identifier: Serialize + DeserializeOwned,
    TA: TransformAttempt + Send + Sync,
    TA::Identifier: Serialize + DeserializeOwned,
    CA: ConsumeAttempt + Send + Sync,
    CA::Identifier: Serialize + DeserializeOwned,
    C: Config + Send + Sync,
{
    /// Applies the embedded schema migration to the pool's database.
    ///
    /// The SQL in `MIGRATIONS` is written to a freshly named directory under the
    /// system temp dir, handed to the `sqlx` migrator, and the directory is removed
    /// again whether or not the migration succeeded.
    ///
    /// # Errors
    /// - `MigrationError` when the temporary directory or file cannot be created.
    /// - `Migration` when the migrator cannot be initialized or fails to apply.
    ///
    /// A failure to clean up the temporary directory is only logged; it does not
    /// turn a successful migration into an error.
    async fn run_migrations(&self) -> Result<(), PostgresDatabaseError> {
        // Per-process sequence number: together with the process id it keeps
        // concurrent callers (same second, same or different process) from sharing
        // a directory and deleting each other's migration file.
        static NEXT_RUN: std::sync::atomic::AtomicUsize =
            std::sync::atomic::AtomicUsize::new(0);

        info!("Running database migrations");
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0);
        let run_seq = NEXT_RUN.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let migrations_dir = env::temp_dir().join(format!(
            "shepherd_migrations_{}_{}_{}",
            timestamp,
            std::process::id(),
            run_seq
        ));
        let temp_file_path = migrations_dir.join("0000_default_schema.up.sql");
        trace!("Using temporary file for migrations: {:?}", temp_file_path);

        let outcome = self
            .apply_migrations_from(&migrations_dir, &temp_file_path)
            .await;

        // Best-effort cleanup on every path. A missing directory just means it was
        // never created, which the outcome above already reports.
        match fs::remove_dir_all(&migrations_dir) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => warn!("Failed to delete temporary migrations file: {}", e),
        }

        outcome?;
        info!("Database migrations completed successfully");
        Ok(())
    }

    /// Writes `MIGRATIONS` to `migration_file` inside `migrations_dir` and runs the
    /// `sqlx` migrator over that directory; cleanup is left to the caller.
    async fn apply_migrations_from(
        &self,
        migrations_dir: &std::path::Path,
        migration_file: &std::path::Path,
    ) -> Result<(), PostgresDatabaseError> {
        fs::create_dir_all(migrations_dir).map_err(|e| {
            error!(
                "Failed to create directories for temporary file path: {}",
                e
            );
            PostgresDatabaseError::MigrationError
        })?;
        fs::write(migration_file, MIGRATIONS).map_err(|e| {
            error!("Failed to write migrations to temporary file: {}", e);
            PostgresDatabaseError::MigrationError
        })?;
        let migrator = Migrator::new(migrations_dir).await.map_err(|e| {
            error!("Failed to initialize migrator: {}", e);
            e
        })?;
        migrator.run(&self.pool).await.map_err(|e| {
            error!("Failed to apply migrations: {}", e);
            e
        })?;
        Ok(())
    }
}
/// SQL text of the default schema migration, embedded at compile time from the sibling
/// file `0000_default_schema.up.sql`; `run_migrations` writes it to a temporary file
/// for the `sqlx` migrator to pick up.
const MIGRATIONS: &str = include_str!("./0000_default_schema.up.sql");