1use std::collections::HashMap;
36use std::sync::{Arc, Mutex};
37
38use sqlx::postgres::{PgArguments, PgConnection, PgQueryResult, PgRow};
39use sqlx::{FromRow, Postgres, Transaction};
40use tokio::sync::Mutex as AsyncMutex;
41use uuid::Uuid;
42
43use tracing::Instrument;
44
45use super::dispatch::{JobDispatch, WorkflowDispatch};
46use crate::auth::Claims;
47use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
48use crate::http::CircuitBreakerClient;
49use crate::job::JobInfo;
50
51pub trait TokenIssuer: Send + Sync {
56 fn sign(&self, claims: &Claims) -> crate::error::Result<String>;
58}
59
60pub enum DbConn<'a> {
62 Pool(&'a sqlx::PgPool),
63 Transaction(Arc<AsyncMutex<Transaction<'static, Postgres>>>),
64}
65
66impl DbConn<'_> {
67 pub async fn fetch_one<'q, O>(
68 &self,
69 query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
70 ) -> sqlx::Result<O>
71 where
72 O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
73 {
74 let span = tracing::info_span!(
75 "db.query",
76 db.system = "postgresql",
77 db.operation.name = "fetch_one",
78 );
79 async {
80 match self {
81 DbConn::Pool(pool) => query.fetch_one(*pool).await,
82 DbConn::Transaction(tx) => query.fetch_one(&mut **tx.lock().await).await,
83 }
84 }
85 .instrument(span)
86 .await
87 }
88
89 pub async fn fetch_optional<'q, O>(
90 &self,
91 query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
92 ) -> sqlx::Result<Option<O>>
93 where
94 O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
95 {
96 let span = tracing::info_span!(
97 "db.query",
98 db.system = "postgresql",
99 db.operation.name = "fetch_optional",
100 );
101 async {
102 match self {
103 DbConn::Pool(pool) => query.fetch_optional(*pool).await,
104 DbConn::Transaction(tx) => query.fetch_optional(&mut **tx.lock().await).await,
105 }
106 }
107 .instrument(span)
108 .await
109 }
110
111 pub async fn fetch_all<'q, O>(
112 &self,
113 query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
114 ) -> sqlx::Result<Vec<O>>
115 where
116 O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
117 {
118 let span = tracing::info_span!(
119 "db.query",
120 db.system = "postgresql",
121 db.operation.name = "fetch_all",
122 );
123 async {
124 match self {
125 DbConn::Pool(pool) => query.fetch_all(*pool).await,
126 DbConn::Transaction(tx) => query.fetch_all(&mut **tx.lock().await).await,
127 }
128 }
129 .instrument(span)
130 .await
131 }
132
133 pub async fn execute<'q>(
134 &self,
135 query: sqlx::query::Query<'q, Postgres, PgArguments>,
136 ) -> sqlx::Result<PgQueryResult> {
137 let span = tracing::info_span!(
138 "db.query",
139 db.system = "postgresql",
140 db.operation.name = "execute",
141 );
142 async {
143 match self {
144 DbConn::Pool(pool) => query.execute(*pool).await,
145 DbConn::Transaction(tx) => query.execute(&mut **tx.lock().await).await,
146 }
147 }
148 .instrument(span)
149 .await
150 }
151}
152
153pub enum ForgeConn<'a> {
167 Pool(sqlx::pool::PoolConnection<Postgres>),
168 Tx(tokio::sync::MutexGuard<'a, Transaction<'static, Postgres>>),
169}
170
171impl std::ops::Deref for ForgeConn<'_> {
172 type Target = PgConnection;
173 fn deref(&self) -> &PgConnection {
174 match self {
175 ForgeConn::Pool(c) => c,
176 ForgeConn::Tx(g) => g,
177 }
178 }
179}
180
181impl std::ops::DerefMut for ForgeConn<'_> {
182 fn deref_mut(&mut self) -> &mut PgConnection {
183 match self {
184 ForgeConn::Pool(c) => c,
185 ForgeConn::Tx(g) => g,
186 }
187 }
188}
189
190#[derive(Debug, Clone)]
191pub struct PendingJob {
192 pub id: Uuid,
193 pub job_type: String,
194 pub args: serde_json::Value,
195 pub context: serde_json::Value,
196 pub owner_subject: Option<String>,
197 pub priority: i32,
198 pub max_attempts: i32,
199 pub worker_capability: Option<String>,
200}
201
202#[derive(Debug, Clone)]
203pub struct PendingWorkflow {
204 pub id: Uuid,
205 pub workflow_name: String,
206 pub input: serde_json::Value,
207 pub owner_subject: Option<String>,
208}
209
210#[derive(Default)]
211pub struct OutboxBuffer {
212 pub jobs: Vec<PendingJob>,
213 pub workflows: Vec<PendingWorkflow>,
214}
215
216#[derive(Debug, Clone)]
218pub struct AuthContext {
219 user_id: Option<Uuid>,
221 roles: Vec<String>,
223 claims: HashMap<String, serde_json::Value>,
225 authenticated: bool,
227}
228
229impl AuthContext {
230 pub fn unauthenticated() -> Self {
232 Self {
233 user_id: None,
234 roles: Vec::new(),
235 claims: HashMap::new(),
236 authenticated: false,
237 }
238 }
239
240 pub fn authenticated(
242 user_id: Uuid,
243 roles: Vec<String>,
244 claims: HashMap<String, serde_json::Value>,
245 ) -> Self {
246 Self {
247 user_id: Some(user_id),
248 roles,
249 claims,
250 authenticated: true,
251 }
252 }
253
254 pub fn authenticated_without_uuid(
260 roles: Vec<String>,
261 claims: HashMap<String, serde_json::Value>,
262 ) -> Self {
263 Self {
264 user_id: None,
265 roles,
266 claims,
267 authenticated: true,
268 }
269 }
270
271 pub fn is_authenticated(&self) -> bool {
273 self.authenticated
274 }
275
276 pub fn user_id(&self) -> Option<Uuid> {
278 self.user_id
279 }
280
281 pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
283 self.user_id
284 .ok_or_else(|| crate::error::ForgeError::Unauthorized("Authentication required".into()))
285 }
286
287 pub fn has_role(&self, role: &str) -> bool {
289 self.roles.iter().any(|r| r == role)
290 }
291
292 pub fn require_role(&self, role: &str) -> crate::error::Result<()> {
294 if self.has_role(role) {
295 Ok(())
296 } else {
297 Err(crate::error::ForgeError::Forbidden(format!(
298 "Required role '{}' not present",
299 role
300 )))
301 }
302 }
303
304 pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
306 self.claims.get(key)
307 }
308
309 pub fn claims(&self) -> &HashMap<String, serde_json::Value> {
311 &self.claims
312 }
313
314 pub fn roles(&self) -> &[String] {
316 &self.roles
317 }
318
319 pub fn subject(&self) -> Option<&str> {
325 self.claims.get("sub").and_then(|v| v.as_str())
326 }
327
328 pub fn require_subject(&self) -> crate::error::Result<&str> {
330 if !self.authenticated {
331 return Err(crate::error::ForgeError::Unauthorized(
332 "Authentication required".to_string(),
333 ));
334 }
335 self.subject().ok_or_else(|| {
336 crate::error::ForgeError::Unauthorized("No subject claim in token".to_string())
337 })
338 }
339
340 pub fn principal_id(&self) -> Option<String> {
344 self.subject()
345 .map(ToString::to_string)
346 .or_else(|| self.user_id.map(|id| id.to_string()))
347 }
348
349 pub fn is_admin(&self) -> bool {
351 self.roles.iter().any(|r| r == "admin")
352 }
353}
354
355#[derive(Debug, Clone)]
357pub struct RequestMetadata {
358 pub request_id: Uuid,
360 pub trace_id: String,
362 pub client_ip: Option<String>,
364 pub user_agent: Option<String>,
366 pub timestamp: chrono::DateTime<chrono::Utc>,
368}
369
370impl RequestMetadata {
371 pub fn new() -> Self {
373 Self {
374 request_id: Uuid::new_v4(),
375 trace_id: Uuid::new_v4().to_string(),
376 client_ip: None,
377 user_agent: None,
378 timestamp: chrono::Utc::now(),
379 }
380 }
381
382 pub fn with_trace_id(trace_id: String) -> Self {
384 Self {
385 request_id: Uuid::new_v4(),
386 trace_id,
387 client_ip: None,
388 user_agent: None,
389 timestamp: chrono::Utc::now(),
390 }
391 }
392}
393
394impl Default for RequestMetadata {
395 fn default() -> Self {
396 Self::new()
397 }
398}
399
400pub struct QueryContext {
402 pub auth: AuthContext,
404 pub request: RequestMetadata,
406 db_pool: sqlx::PgPool,
408 env_provider: Arc<dyn EnvProvider>,
410}
411
412impl QueryContext {
413 pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
415 Self {
416 auth,
417 request,
418 db_pool,
419 env_provider: Arc::new(RealEnvProvider::new()),
420 }
421 }
422
423 pub fn with_env(
425 db_pool: sqlx::PgPool,
426 auth: AuthContext,
427 request: RequestMetadata,
428 env_provider: Arc<dyn EnvProvider>,
429 ) -> Self {
430 Self {
431 auth,
432 request,
433 db_pool,
434 env_provider,
435 }
436 }
437
438 pub fn db(&self) -> &sqlx::PgPool {
439 &self.db_pool
440 }
441
442 pub fn db_conn(&self) -> DbConn<'_> {
445 DbConn::Pool(&self.db_pool)
446 }
447
448 pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
449 self.auth.require_user_id()
450 }
451
452 pub fn require_subject(&self) -> crate::error::Result<&str> {
454 self.auth.require_subject()
455 }
456}
457
458impl EnvAccess for QueryContext {
459 fn env_provider(&self) -> &dyn EnvProvider {
460 self.env_provider.as_ref()
461 }
462}
463
464pub type JobInfoLookup = Arc<dyn Fn(&str) -> Option<JobInfo> + Send + Sync>;
466
467pub struct MutationContext {
469 pub auth: AuthContext,
471 pub request: RequestMetadata,
473 db_pool: sqlx::PgPool,
475 http_client: CircuitBreakerClient,
477 job_dispatch: Option<Arc<dyn JobDispatch>>,
479 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
481 env_provider: Arc<dyn EnvProvider>,
483 tx: Option<Arc<AsyncMutex<Transaction<'static, Postgres>>>>,
485 outbox: Option<Arc<Mutex<OutboxBuffer>>>,
487 job_info_lookup: Option<JobInfoLookup>,
489 token_issuer: Option<Arc<dyn TokenIssuer>>,
491}
492
493impl MutationContext {
494 pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
496 Self {
497 auth,
498 request,
499 db_pool,
500 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
501 job_dispatch: None,
502 workflow_dispatch: None,
503 env_provider: Arc::new(RealEnvProvider::new()),
504 tx: None,
505 outbox: None,
506 job_info_lookup: None,
507 token_issuer: None,
508 }
509 }
510
511 pub fn with_dispatch(
513 db_pool: sqlx::PgPool,
514 auth: AuthContext,
515 request: RequestMetadata,
516 http_client: CircuitBreakerClient,
517 job_dispatch: Option<Arc<dyn JobDispatch>>,
518 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
519 ) -> Self {
520 Self {
521 auth,
522 request,
523 db_pool,
524 http_client,
525 job_dispatch,
526 workflow_dispatch,
527 env_provider: Arc::new(RealEnvProvider::new()),
528 tx: None,
529 outbox: None,
530 job_info_lookup: None,
531 token_issuer: None,
532 }
533 }
534
535 pub fn with_env(
537 db_pool: sqlx::PgPool,
538 auth: AuthContext,
539 request: RequestMetadata,
540 http_client: CircuitBreakerClient,
541 job_dispatch: Option<Arc<dyn JobDispatch>>,
542 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
543 env_provider: Arc<dyn EnvProvider>,
544 ) -> Self {
545 Self {
546 auth,
547 request,
548 db_pool,
549 http_client,
550 job_dispatch,
551 workflow_dispatch,
552 env_provider,
553 tx: None,
554 outbox: None,
555 job_info_lookup: None,
556 token_issuer: None,
557 }
558 }
559
560 #[allow(clippy::type_complexity)]
562 pub fn with_transaction(
563 db_pool: sqlx::PgPool,
564 tx: Transaction<'static, Postgres>,
565 auth: AuthContext,
566 request: RequestMetadata,
567 http_client: CircuitBreakerClient,
568 job_info_lookup: JobInfoLookup,
569 ) -> (
570 Self,
571 Arc<AsyncMutex<Transaction<'static, Postgres>>>,
572 Arc<Mutex<OutboxBuffer>>,
573 ) {
574 let tx_handle = Arc::new(AsyncMutex::new(tx));
575 let outbox = Arc::new(Mutex::new(OutboxBuffer::default()));
576
577 let ctx = Self {
578 auth,
579 request,
580 db_pool,
581 http_client,
582 job_dispatch: None,
583 workflow_dispatch: None,
584 env_provider: Arc::new(RealEnvProvider::new()),
585 tx: Some(tx_handle.clone()),
586 outbox: Some(outbox.clone()),
587 job_info_lookup: Some(job_info_lookup),
588 token_issuer: None,
589 };
590
591 (ctx, tx_handle, outbox)
592 }
593
594 pub fn is_transactional(&self) -> bool {
595 self.tx.is_some()
596 }
597
598 pub fn db(&self) -> DbConn<'_> {
599 match &self.tx {
600 Some(tx) => DbConn::Transaction(tx.clone()),
601 None => DbConn::Pool(&self.db_pool),
602 }
603 }
604
605 pub async fn conn(&self) -> sqlx::Result<ForgeConn<'_>> {
617 match &self.tx {
618 Some(tx) => Ok(ForgeConn::Tx(tx.lock().await)),
619 None => Ok(ForgeConn::Pool(self.db_pool.acquire().await?)),
620 }
621 }
622
623 pub fn pool(&self) -> &sqlx::PgPool {
625 &self.db_pool
626 }
627
628 pub fn http(&self) -> &reqwest::Client {
634 self.http_client.inner()
635 }
636
637 pub fn http_with_circuit_breaker(&self) -> &CircuitBreakerClient {
639 &self.http_client
640 }
641
642 pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
643 self.auth.require_user_id()
644 }
645
646 pub fn require_subject(&self) -> crate::error::Result<&str> {
647 self.auth.require_subject()
648 }
649
650 pub fn set_token_issuer(&mut self, issuer: Arc<dyn TokenIssuer>) {
652 self.token_issuer = Some(issuer);
653 }
654
655 pub fn issue_token(&self, claims: &Claims) -> crate::error::Result<String> {
670 let issuer = self.token_issuer.as_ref().ok_or_else(|| {
671 crate::error::ForgeError::Internal(
672 "Token issuer not available. Configure [auth] with an HMAC algorithm in forge.toml"
673 .into(),
674 )
675 })?;
676 issuer.sign(claims)
677 }
678
679 pub async fn dispatch_job<T: serde::Serialize>(
681 &self,
682 job_type: &str,
683 args: T,
684 ) -> crate::error::Result<Uuid> {
685 let args_json = serde_json::to_value(args)?;
686
687 if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
689 let job_info = job_info_lookup(job_type).ok_or_else(|| {
690 crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
691 })?;
692
693 let pending = PendingJob {
694 id: Uuid::new_v4(),
695 job_type: job_type.to_string(),
696 args: args_json,
697 context: serde_json::json!({}),
698 owner_subject: self.auth.principal_id(),
699 priority: job_info.priority.as_i32(),
700 max_attempts: job_info.retry.max_attempts as i32,
701 worker_capability: job_info.worker_capability.map(|s| s.to_string()),
702 };
703
704 let job_id = pending.id;
705 outbox
706 .lock()
707 .expect("outbox lock poisoned")
708 .jobs
709 .push(pending);
710 return Ok(job_id);
711 }
712
713 let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
715 crate::error::ForgeError::Internal("Job dispatch not available".into())
716 })?;
717 dispatcher
718 .dispatch_by_name(job_type, args_json, self.auth.principal_id())
719 .await
720 }
721
722 pub async fn dispatch_job_with_context<T: serde::Serialize>(
724 &self,
725 job_type: &str,
726 args: T,
727 context: serde_json::Value,
728 ) -> crate::error::Result<Uuid> {
729 let args_json = serde_json::to_value(args)?;
730
731 if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
732 let job_info = job_info_lookup(job_type).ok_or_else(|| {
733 crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
734 })?;
735
736 let pending = PendingJob {
737 id: Uuid::new_v4(),
738 job_type: job_type.to_string(),
739 args: args_json,
740 context,
741 owner_subject: self.auth.principal_id(),
742 priority: job_info.priority.as_i32(),
743 max_attempts: job_info.retry.max_attempts as i32,
744 worker_capability: job_info.worker_capability.map(|s| s.to_string()),
745 };
746
747 let job_id = pending.id;
748 outbox
749 .lock()
750 .expect("outbox lock poisoned")
751 .jobs
752 .push(pending);
753 return Ok(job_id);
754 }
755
756 let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
757 crate::error::ForgeError::Internal("Job dispatch not available".into())
758 })?;
759 dispatcher
760 .dispatch_by_name(job_type, args_json, self.auth.principal_id())
761 .await
762 }
763
764 pub async fn cancel_job(
766 &self,
767 job_id: Uuid,
768 reason: Option<String>,
769 ) -> crate::error::Result<bool> {
770 let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
771 crate::error::ForgeError::Internal("Job dispatch not available".into())
772 })?;
773 dispatcher.cancel(job_id, reason).await
774 }
775
776 pub async fn start_workflow<T: serde::Serialize>(
778 &self,
779 workflow_name: &str,
780 input: T,
781 ) -> crate::error::Result<Uuid> {
782 let input_json = serde_json::to_value(input)?;
783
784 if let Some(outbox) = &self.outbox {
786 let pending = PendingWorkflow {
787 id: Uuid::new_v4(),
788 workflow_name: workflow_name.to_string(),
789 input: input_json,
790 owner_subject: self.auth.principal_id(),
791 };
792
793 let workflow_id = pending.id;
794 outbox
795 .lock()
796 .expect("outbox lock poisoned")
797 .workflows
798 .push(pending);
799 return Ok(workflow_id);
800 }
801
802 let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
804 crate::error::ForgeError::Internal("Workflow dispatch not available".into())
805 })?;
806 dispatcher
807 .start_by_name(workflow_name, input_json, self.auth.principal_id())
808 .await
809 }
810}
811
812impl EnvAccess for MutationContext {
813 fn env_provider(&self) -> &dyn EnvProvider {
814 self.env_provider.as_ref()
815 }
816}
817
// Unit tests for `AuthContext` and `RequestMetadata`.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// An anonymous context has no identity and refuses `require_user_id`.
    #[test]
    fn test_auth_context_unauthenticated() {
        let anon = AuthContext::unauthenticated();

        assert!(!anon.is_authenticated());
        assert!(anon.user_id().is_none());
        assert!(anon.require_user_id().is_err());
    }

    /// An authenticated context exposes its user id and answers role checks.
    #[test]
    fn test_auth_context_authenticated() {
        let user_id = Uuid::new_v4();
        let roles = vec!["admin".to_string(), "user".to_string()];
        let auth = AuthContext::authenticated(user_id, roles, HashMap::new());

        assert!(auth.is_authenticated());
        assert_eq!(auth.user_id(), Some(user_id));
        assert!(auth.require_user_id().is_ok());

        for granted in ["admin", "user"] {
            assert!(auth.has_role(granted));
        }
        assert!(!auth.has_role("superadmin"));

        assert!(auth.require_role("admin").is_ok());
        assert!(auth.require_role("superadmin").is_err());
    }

    /// Claims passed at construction are retrievable by key.
    #[test]
    fn test_auth_context_with_claims() {
        let expected = serde_json::json!("org-123");
        let claims = HashMap::from([("org_id".to_string(), expected.clone())]);

        let auth = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);

        assert_eq!(auth.claim("org_id"), Some(&expected));
        assert!(auth.claim("nonexistent").is_none());
    }

    /// `new` fills in a trace id; `with_trace_id` keeps the supplied one.
    #[test]
    fn test_request_metadata() {
        let generated = RequestMetadata::new();
        assert!(!generated.trace_id.is_empty());
        assert!(generated.client_ip.is_none());

        let supplied = RequestMetadata::with_trace_id("trace-123".to_string());
        assert_eq!(supplied.trace_id, "trace-123");
    }
}