use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, watch};
use tracing::Span;
use uuid::Uuid;
use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
use crate::function::{JobDispatch, WorkflowDispatch};
use crate::http::CircuitBreakerClient;
/// Execution context handed to a daemon: its identity, database and HTTP
/// access, shutdown notification, optional job/workflow dispatch, and
/// environment access.
pub struct DaemonContext {
    /// Daemon name; together with `instance_id` it selects the row that
    /// `heartbeat` updates in `forge_daemons`.
    pub daemon_name: String,
    /// Identifier of this running instance of the daemon.
    pub instance_id: Uuid,
    /// Postgres pool backing `db`, `db_conn`, `conn` and `heartbeat`.
    db_pool: sqlx::PgPool,
    /// Circuit-breaker HTTP client used by `http` and `raw_http`.
    http_client: CircuitBreakerClient,
    /// Timeout applied to the client returned by `http`; `None` until set via
    /// `set_http_timeout`.
    http_timeout: Option<Duration>,
    /// Shutdown flag receiver (`true` means shutdown was requested). Wrapped in
    /// a mutex so the `&self` methods can reach the receiver.
    shutdown_rx: Mutex<watch::Receiver<bool>>,
    /// Optional dispatcher used by `dispatch_job`; `None` unless configured.
    job_dispatch: Option<Arc<dyn JobDispatch>>,
    /// Optional dispatcher used by `start_workflow`; `None` unless configured.
    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
    /// Source of environment variables exposed through `EnvAccess`.
    env_provider: Arc<dyn EnvProvider>,
    /// Tracing span that was current when the context was constructed.
    span: Span,
}
impl DaemonContext {
pub fn new(
daemon_name: String,
instance_id: Uuid,
db_pool: sqlx::PgPool,
http_client: CircuitBreakerClient,
shutdown_rx: watch::Receiver<bool>,
) -> Self {
Self {
daemon_name,
instance_id,
db_pool,
http_client,
http_timeout: None,
shutdown_rx: Mutex::new(shutdown_rx),
job_dispatch: None,
workflow_dispatch: None,
env_provider: Arc::new(RealEnvProvider::new()),
span: Span::current(),
}
}
pub fn with_job_dispatch(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
self.job_dispatch = Some(dispatcher);
self
}
pub fn with_workflow_dispatch(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
self.workflow_dispatch = Some(dispatcher);
self
}
pub fn with_env_provider(mut self, provider: Arc<dyn EnvProvider>) -> Self {
self.env_provider = provider;
self
}
pub fn db(&self) -> crate::function::ForgeDb {
crate::function::ForgeDb::from_pool(&self.db_pool)
}
pub fn db_conn(&self) -> crate::function::DbConn<'_> {
crate::function::DbConn::Pool(self.db_pool.clone())
}
pub async fn conn(&self) -> sqlx::Result<crate::function::ForgeConn<'static>> {
Ok(crate::function::ForgeConn::Pool(
self.db_pool.acquire().await?,
))
}
pub fn http(&self) -> crate::http::HttpClient {
self.http_client.with_timeout(self.http_timeout)
}
pub fn raw_http(&self) -> &reqwest::Client {
self.http_client.inner()
}
pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
self.http_timeout = timeout;
}
pub async fn dispatch_job<T: serde::Serialize>(
&self,
job_type: &str,
args: T,
) -> crate::Result<Uuid> {
let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
crate::error::ForgeError::Internal("Job dispatch not available".to_string())
})?;
let args_json = serde_json::to_value(args)?;
dispatcher.dispatch_by_name(job_type, args_json, None).await
}
pub async fn start_workflow<T: serde::Serialize>(
&self,
workflow_name: &str,
input: T,
) -> crate::Result<Uuid> {
let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
crate::error::ForgeError::Internal("Workflow dispatch not available".to_string())
})?;
let input_json = serde_json::to_value(input)?;
dispatcher
.start_by_name(workflow_name, input_json, None)
.await
}
pub fn is_shutdown_requested(&self) -> bool {
self.shutdown_rx
.try_lock()
.map(|rx| *rx.borrow())
.unwrap_or(false)
}
pub async fn shutdown_signal(&self) {
let mut rx = self.shutdown_rx.lock().await;
while !*rx.borrow_and_update() {
if rx.changed().await.is_err() {
break;
}
}
}
pub async fn heartbeat(&self) -> crate::Result<()> {
tracing::trace!(daemon.name = %self.daemon_name, "Sending heartbeat");
sqlx::query!(
r#"
UPDATE forge_daemons
SET last_heartbeat = NOW()
WHERE name = $1 AND instance_id = $2
"#,
self.daemon_name,
self.instance_id,
)
.execute(&self.db_pool)
.await
.map_err(|e| crate::ForgeError::Database(e.to_string()))?;
Ok(())
}
pub fn trace_id(&self) -> String {
self.instance_id.to_string()
}
pub fn span(&self) -> &Span {
&self.span
}
}
impl EnvAccess for DaemonContext {
    /// Borrows the environment provider held by this context: a
    /// `RealEnvProvider` by default, or whatever was installed with
    /// `with_env_provider`.
    fn env_provider(&self) -> &dyn EnvProvider {
        &*self.env_provider
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Builds a `DaemonContext` named "test_daemon". It uses a lazily-connecting
    /// pool (no database is contacted) and a default circuit-breaker client.
    fn context_with(instance_id: Uuid, shutdown_rx: watch::Receiver<bool>) -> DaemonContext {
        let lazy_pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");
        let http = CircuitBreakerClient::with_defaults(reqwest::Client::new());
        DaemonContext::new(
            "test_daemon".to_string(),
            instance_id,
            lazy_pool,
            http,
            shutdown_rx,
        )
    }

    #[tokio::test]
    async fn test_daemon_context_creation() {
        let (tx, rx) = watch::channel(false);
        let id = Uuid::new_v4();
        let ctx = context_with(id, rx);

        assert_eq!(ctx.daemon_name, "test_daemon");
        assert_eq!(ctx.instance_id, id);
        assert!(!ctx.is_shutdown_requested());

        tx.send(true).unwrap();
        assert!(ctx.is_shutdown_requested());
    }

    #[tokio::test]
    async fn test_shutdown_signal() {
        let (tx, rx) = watch::channel(false);
        let ctx = context_with(Uuid::new_v4(), rx);

        // Flip the flag from another task after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            tx.send(true).unwrap();
        });

        tokio::time::timeout(Duration::from_millis(200), ctx.shutdown_signal())
            .await
            .expect("Shutdown signal should complete");
        assert!(ctx.is_shutdown_requested());
    }
}