use std::sync::{Arc, mpsc};
use std::time::Duration;
use uuid::Uuid;
use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
use crate::function::AuthContext;
use crate::http::CircuitBreakerClient;
/// Returns the initial saved-data value for a job: an empty JSON object.
///
/// Jobs accumulate checkpoint state under string keys, so the neutral
/// starting value is `{}` rather than `null`.
pub fn empty_saved_data() -> serde_json::Value {
    serde_json::json!({})
}
/// Everything a running job handler needs: identity and retry counters,
/// auth, database and HTTP access, persisted checkpoint data, progress
/// reporting, and environment lookup.
pub struct JobContext {
/// Unique id of this job. NOTE(review): a nil UUID appears to mark an
/// in-memory/test context — `set_saved`/`save` skip DB persistence for
/// it; confirm that convention with the scheduler.
pub job_id: Uuid,
/// Name of the job type, used to route to the right handler.
pub job_type: String,
/// Current attempt number (1-based; see `is_retry`).
pub attempt: u32,
/// Maximum number of attempts before the job is considered failed.
pub max_attempts: u32,
/// Authentication context the job runs under; defaults to unauthenticated.
pub auth: AuthContext,
// Checkpoint data shared behind an async RwLock; mirrored to the
// `forge_jobs.job_context` column when job_id is non-nil.
saved_data: Arc<tokio::sync::RwLock<serde_json::Value>>,
// Postgres pool used for persistence, heartbeat, and cancellation checks.
db_pool: sqlx::PgPool,
// Circuit-breaker-wrapped HTTP client; see `http()` / `raw_http()`.
http_client: CircuitBreakerClient,
// Optional per-context timeout override applied by `http()`.
http_timeout: Option<Duration>,
// Channel for progress updates; absent when nobody is listening.
progress_tx: Option<mpsc::Sender<ProgressUpdate>>,
// Environment variable provider, swappable for tests.
env_provider: Arc<dyn EnvProvider>,
}
/// A single progress report emitted by a job via [`JobContext::progress`].
#[derive(Debug, Clone)]
pub struct ProgressUpdate {
/// Job this update belongs to.
pub job_id: Uuid,
/// Completion percentage; the sender clamps this to at most 100.
pub percentage: u8,
/// Human-readable status message.
pub message: String,
}
impl JobContext {
/// Creates a context for one execution attempt of a job.
///
/// The context starts unauthenticated, with empty saved data, no HTTP
/// timeout override, no progress channel, and the real environment
/// provider; use the `with_*` builders below to override any of these.
pub fn new(
job_id: Uuid,
job_type: String,
attempt: u32,
max_attempts: u32,
db_pool: sqlx::PgPool,
http_client: CircuitBreakerClient,
) -> Self {
Self {
job_id,
job_type,
attempt,
max_attempts,
auth: AuthContext::unauthenticated(),
saved_data: Arc::new(tokio::sync::RwLock::new(empty_saved_data())),
db_pool,
http_client,
http_timeout: None,
progress_tx: None,
env_provider: Arc::new(RealEnvProvider::new()),
}
}
/// Replaces the in-memory saved data (builder style). Does not persist
/// anything to the database by itself.
pub fn with_saved(mut self, data: serde_json::Value) -> Self {
self.saved_data = Arc::new(tokio::sync::RwLock::new(data));
self
}
/// Sets the auth context the job runs under (builder style).
pub fn with_auth(mut self, auth: AuthContext) -> Self {
self.auth = auth;
self
}
/// Attaches a channel that will receive [`ProgressUpdate`]s (builder style).
pub fn with_progress(mut self, tx: mpsc::Sender<ProgressUpdate>) -> Self {
self.progress_tx = Some(tx);
self
}
/// Overrides the environment provider, e.g. with a mock in tests
/// (builder style).
pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
self.env_provider = provider;
self
}
/// Returns a `ForgeDb` handle backed by this context's pool.
pub fn db(&self) -> crate::function::ForgeDb {
crate::function::ForgeDb::from_pool(&self.db_pool)
}
/// Returns a `DbConn` wrapping a clone of the pool (pools are cheaply
/// cloneable handles to the same underlying connections).
pub fn db_conn(&self) -> crate::function::DbConn<'_> {
crate::function::DbConn::Pool(self.db_pool.clone())
}
/// Acquires a dedicated connection from the pool.
///
/// # Errors
/// Returns the underlying `sqlx` error if acquiring a connection fails.
pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
Ok(crate::function::ForgeConn::Pool(
self.db_pool.acquire().await?,
))
}
/// Returns the HTTP client with this context's timeout override applied.
pub fn http(&self) -> crate::http::HttpClient {
self.http_client.with_timeout(self.http_timeout)
}
/// Returns the inner `reqwest::Client`. NOTE(review): this presumably
/// bypasses the circuit-breaker wrapper and the context timeout —
/// confirm before using it for calls that need that protection.
pub fn raw_http(&self) -> &reqwest::Client {
self.http_client.inner()
}
/// Sets (or clears, with `None`) the timeout used by [`Self::http`].
pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
self.http_timeout = timeout;
}
/// Reports progress to the listener, if one is attached.
///
/// `percentage` is clamped to 100. A context without a progress channel
/// silently succeeds.
///
/// # Errors
/// Returns `ForgeError::Job` if the channel send fails (i.e. the
/// receiver has been dropped).
pub fn progress(&self, percentage: u8, message: impl Into<String>) -> crate::Result<()> {
let update = ProgressUpdate {
job_id: self.job_id,
percentage: percentage.min(100),
message: message.into(),
};
if let Some(ref tx) = self.progress_tx {
tx.send(update)
.map_err(|e| crate::ForgeError::Job(format!("Failed to send progress: {}", e)))?;
}
Ok(())
}
/// Returns a snapshot (clone) of the current saved data.
pub async fn saved(&self) -> serde_json::Value {
self.saved_data.read().await.clone()
}
/// Replaces the saved data wholesale and persists it.
///
/// The write lock is released before the database round-trip (see
/// `clone_and_drop`). A nil `job_id` skips persistence — presumably
/// the in-memory/test case.
///
/// # Errors
/// Returns `ForgeError::Database` if the UPDATE fails.
pub async fn set_saved(&self, data: serde_json::Value) -> crate::Result<()> {
let mut guard = self.saved_data.write().await;
*guard = data;
let persisted = Self::clone_and_drop(guard);
if self.job_id.is_nil() {
return Ok(());
}
self.persist_saved_data(persisted).await
}
/// Sets a single key in the saved data and persists the result.
///
/// If the current value is not a JSON object it is replaced by one
/// containing only this key (see `apply_save`). A nil `job_id` skips
/// persistence.
///
/// # Errors
/// Returns `ForgeError::Database` if the UPDATE fails.
pub async fn save(&self, key: &str, value: serde_json::Value) -> crate::Result<()> {
let mut guard = self.saved_data.write().await;
Self::apply_save(&mut guard, key, value);
let persisted = Self::clone_and_drop(guard);
if self.job_id.is_nil() {
return Ok(());
}
self.persist_saved_data(persisted).await
}
/// Checks the database for a cancellation request on this job.
///
/// A missing row counts as "not cancelled".
///
/// # Errors
/// Returns `ForgeError::Database` if the query fails.
pub async fn is_cancel_requested(&self) -> crate::Result<bool> {
let row = sqlx::query_scalar!(
r#"
SELECT status
FROM forge_jobs
WHERE id = $1
"#,
self.job_id
)
.fetch_optional(&self.db_pool)
.await
.map_err(|e| crate::ForgeError::Database(e.to_string()))?;
Ok(matches!(
row.as_deref(),
Some("cancel_requested") | Some("cancelled")
))
}
/// Convenience guard: converts a pending cancellation into an error so
/// callers can bail out with `ctx.check_cancelled().await?`.
///
/// # Errors
/// `ForgeError::JobCancelled` when cancellation was requested;
/// `ForgeError::Database` if the status lookup fails.
pub async fn check_cancelled(&self) -> crate::Result<()> {
if self.is_cancel_requested().await? {
Err(crate::ForgeError::JobCancelled(
"Job cancellation requested".to_string(),
))
} else {
Ok(())
}
}
// Writes the saved data to the job's row. Callers must have already
// released the saved_data lock (see clone_and_drop).
async fn persist_saved_data(&self, data: serde_json::Value) -> crate::Result<()> {
sqlx::query!(
r#"
UPDATE forge_jobs
SET job_context = $2
WHERE id = $1
"#,
self.job_id,
data,
)
.execute(&self.db_pool)
.await
.map_err(|e| crate::ForgeError::Database(e.to_string()))?;
Ok(())
}
// Inserts `key` into `data`, coercing non-object values into a fresh
// object containing only this key.
fn apply_save(data: &mut serde_json::Value, key: &str, value: serde_json::Value) {
if let Some(map) = data.as_object_mut() {
map.insert(key.to_string(), value);
} else {
let mut map = serde_json::Map::new();
map.insert(key.to_string(), value);
*data = serde_json::Value::Object(map);
}
}
// Clones the guarded value and explicitly drops the write guard, so the
// caller can await the persistence query without holding the lock.
fn clone_and_drop(
guard: tokio::sync::RwLockWriteGuard<'_, serde_json::Value>,
) -> serde_json::Value {
let cloned = guard.clone();
drop(guard);
cloned
}
/// Updates the job's `last_heartbeat` timestamp to NOW() so the
/// scheduler can tell the job is still alive.
///
/// # Errors
/// Returns `ForgeError::Database` if the UPDATE fails.
pub async fn heartbeat(&self) -> crate::Result<()> {
sqlx::query!(
r#"
UPDATE forge_jobs
SET last_heartbeat = NOW()
WHERE id = $1
"#,
self.job_id,
)
.execute(&self.db_pool)
.await
.map_err(|e| crate::ForgeError::Database(e.to_string()))?;
Ok(())
}
/// True when this execution is a retry (attempt numbers are 1-based).
pub fn is_retry(&self) -> bool {
self.attempt > 1
}
/// True when this is the final allowed attempt (or beyond).
pub fn is_last_attempt(&self) -> bool {
self.attempt >= self.max_attempts
}
}
/// Exposes this context's environment provider through the shared
/// [`EnvAccess`] trait, so env lookups go through the (swappable)
/// provider rather than `std::env` directly.
impl EnvAccess for JobContext {
    fn env_provider(&self) -> &dyn EnvProvider {
        // Deref the Arc to borrow the trait object it owns.
        &*self.env_provider
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Builds a lazily-connecting pool aimed at a nonexistent database.
    /// No connection is ever opened: these tests only exercise in-memory
    /// behavior (nil job ids skip persistence).
    fn mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Shorthand for constructing a `JobContext` with the given identity
    /// and attempt counters, removing the boilerplate each test repeated.
    fn mock_ctx(job_id: Uuid, job_type: &str, attempt: u32, max_attempts: u32) -> JobContext {
        JobContext::new(
            job_id,
            job_type.to_string(),
            attempt,
            max_attempts,
            mock_pool(),
            CircuitBreakerClient::with_defaults(reqwest::Client::new()),
        )
    }

    #[tokio::test]
    async fn test_job_context_creation() {
        let job_id = Uuid::new_v4();
        let ctx = mock_ctx(job_id, "test_job", 1, 3);
        assert_eq!(ctx.job_id, job_id);
        assert_eq!(ctx.job_type, "test_job");
        assert_eq!(ctx.attempt, 1);
        assert_eq!(ctx.max_attempts, 3);
        assert!(!ctx.is_retry());
        assert!(!ctx.is_last_attempt());
    }

    #[tokio::test]
    async fn test_is_retry() {
        // Any attempt after the first counts as a retry.
        assert!(mock_ctx(Uuid::new_v4(), "test", 2, 3).is_retry());
    }

    #[tokio::test]
    async fn test_is_last_attempt() {
        assert!(mock_ctx(Uuid::new_v4(), "test", 3, 3).is_last_attempt());
    }

    #[test]
    fn test_progress_update() {
        let update = ProgressUpdate {
            job_id: Uuid::new_v4(),
            percentage: 50,
            message: "Halfway there".to_string(),
        };
        assert_eq!(update.percentage, 50);
        assert_eq!(update.message, "Halfway there");
    }

    #[tokio::test]
    async fn test_saved_data_in_memory() {
        // Nil job id => saved data stays purely in memory.
        let ctx = mock_ctx(Uuid::nil(), "test_job", 1, 3)
            .with_saved(serde_json::json!({"foo": "bar"}));
        let saved = ctx.saved().await;
        assert_eq!(saved["foo"], "bar");
    }

    #[tokio::test]
    async fn test_save_key_value() {
        let ctx = mock_ctx(Uuid::nil(), "test_job", 1, 3);
        ctx.save("charge_id", serde_json::json!("ch_123"))
            .await
            .unwrap();
        ctx.save("amount", serde_json::json!(100)).await.unwrap();
        let saved = ctx.saved().await;
        assert_eq!(saved["charge_id"], "ch_123");
        assert_eq!(saved["amount"], 100);
    }
}