1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use sqlx::postgres::{PgArguments, PgQueryResult, PgRow};
5use sqlx::{FromRow, Postgres, Transaction};
6use tokio::sync::Mutex as AsyncMutex;
7use uuid::Uuid;
8
9use super::dispatch::{JobDispatch, WorkflowDispatch};
10use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
11use crate::http::CircuitBreakerClient;
12use crate::job::JobInfo;
13
14pub enum DbConn<'a> {
16 Pool(&'a sqlx::PgPool),
17 Transaction(Arc<AsyncMutex<Transaction<'static, Postgres>>>),
18}
19
20impl DbConn<'_> {
21 pub async fn fetch_one<'q, O>(
22 &self,
23 query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
24 ) -> sqlx::Result<O>
25 where
26 O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
27 {
28 match self {
29 DbConn::Pool(pool) => query.fetch_one(*pool).await,
30 DbConn::Transaction(tx) => query.fetch_one(&mut **tx.lock().await).await,
31 }
32 }
33
34 pub async fn fetch_optional<'q, O>(
35 &self,
36 query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
37 ) -> sqlx::Result<Option<O>>
38 where
39 O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
40 {
41 match self {
42 DbConn::Pool(pool) => query.fetch_optional(*pool).await,
43 DbConn::Transaction(tx) => query.fetch_optional(&mut **tx.lock().await).await,
44 }
45 }
46
47 pub async fn fetch_all<'q, O>(
48 &self,
49 query: sqlx::query::QueryAs<'q, Postgres, O, PgArguments>,
50 ) -> sqlx::Result<Vec<O>>
51 where
52 O: Send + Unpin + for<'r> FromRow<'r, PgRow>,
53 {
54 match self {
55 DbConn::Pool(pool) => query.fetch_all(*pool).await,
56 DbConn::Transaction(tx) => query.fetch_all(&mut **tx.lock().await).await,
57 }
58 }
59
60 pub async fn execute<'q>(
61 &self,
62 query: sqlx::query::Query<'q, Postgres, PgArguments>,
63 ) -> sqlx::Result<PgQueryResult> {
64 match self {
65 DbConn::Pool(pool) => query.execute(*pool).await,
66 DbConn::Transaction(tx) => query.execute(&mut **tx.lock().await).await,
67 }
68 }
69}
70
71#[derive(Debug, Clone)]
72pub struct PendingJob {
73 pub id: Uuid,
74 pub job_type: String,
75 pub args: serde_json::Value,
76 pub context: serde_json::Value,
77 pub priority: i32,
78 pub max_attempts: i32,
79 pub worker_capability: Option<String>,
80}
81
82#[derive(Debug, Clone)]
83pub struct PendingWorkflow {
84 pub id: Uuid,
85 pub workflow_name: String,
86 pub input: serde_json::Value,
87}
88
89#[derive(Default)]
90pub struct OutboxBuffer {
91 pub jobs: Vec<PendingJob>,
92 pub workflows: Vec<PendingWorkflow>,
93}
94
95#[derive(Debug, Clone)]
97pub struct AuthContext {
98 user_id: Option<Uuid>,
100 roles: Vec<String>,
102 claims: HashMap<String, serde_json::Value>,
104 authenticated: bool,
106}
107
108impl AuthContext {
109 pub fn unauthenticated() -> Self {
111 Self {
112 user_id: None,
113 roles: Vec::new(),
114 claims: HashMap::new(),
115 authenticated: false,
116 }
117 }
118
119 pub fn authenticated(
121 user_id: Uuid,
122 roles: Vec<String>,
123 claims: HashMap<String, serde_json::Value>,
124 ) -> Self {
125 Self {
126 user_id: Some(user_id),
127 roles,
128 claims,
129 authenticated: true,
130 }
131 }
132
133 pub fn authenticated_without_uuid(
139 roles: Vec<String>,
140 claims: HashMap<String, serde_json::Value>,
141 ) -> Self {
142 Self {
143 user_id: None,
144 roles,
145 claims,
146 authenticated: true,
147 }
148 }
149
150 pub fn is_authenticated(&self) -> bool {
152 self.authenticated
153 }
154
155 pub fn user_id(&self) -> Option<Uuid> {
157 self.user_id
158 }
159
160 pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
162 self.user_id
163 .ok_or_else(|| crate::error::ForgeError::Unauthorized("Authentication required".into()))
164 }
165
166 pub fn has_role(&self, role: &str) -> bool {
168 self.roles.iter().any(|r| r == role)
169 }
170
171 pub fn require_role(&self, role: &str) -> crate::error::Result<()> {
173 if self.has_role(role) {
174 Ok(())
175 } else {
176 Err(crate::error::ForgeError::Forbidden(format!(
177 "Required role '{}' not present",
178 role
179 )))
180 }
181 }
182
183 pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
185 self.claims.get(key)
186 }
187
188 pub fn roles(&self) -> &[String] {
190 &self.roles
191 }
192
193 pub fn subject(&self) -> Option<&str> {
199 self.claims.get("sub").and_then(|v| v.as_str())
200 }
201
202 pub fn require_subject(&self) -> crate::error::Result<&str> {
204 if !self.authenticated {
205 return Err(crate::error::ForgeError::Unauthorized(
206 "Authentication required".to_string(),
207 ));
208 }
209 self.subject().ok_or_else(|| {
210 crate::error::ForgeError::Unauthorized("No subject claim in token".to_string())
211 })
212 }
213}
214
215#[derive(Debug, Clone)]
217pub struct RequestMetadata {
218 pub request_id: Uuid,
220 pub trace_id: String,
222 pub client_ip: Option<String>,
224 pub user_agent: Option<String>,
226 pub timestamp: chrono::DateTime<chrono::Utc>,
228}
229
230impl RequestMetadata {
231 pub fn new() -> Self {
233 Self {
234 request_id: Uuid::new_v4(),
235 trace_id: Uuid::new_v4().to_string(),
236 client_ip: None,
237 user_agent: None,
238 timestamp: chrono::Utc::now(),
239 }
240 }
241
242 pub fn with_trace_id(trace_id: String) -> Self {
244 Self {
245 request_id: Uuid::new_v4(),
246 trace_id,
247 client_ip: None,
248 user_agent: None,
249 timestamp: chrono::Utc::now(),
250 }
251 }
252}
253
254impl Default for RequestMetadata {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260pub struct QueryContext {
262 pub auth: AuthContext,
264 pub request: RequestMetadata,
266 db_pool: sqlx::PgPool,
268 env_provider: Arc<dyn EnvProvider>,
270}
271
272impl QueryContext {
273 pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
275 Self {
276 auth,
277 request,
278 db_pool,
279 env_provider: Arc::new(RealEnvProvider::new()),
280 }
281 }
282
283 pub fn with_env(
285 db_pool: sqlx::PgPool,
286 auth: AuthContext,
287 request: RequestMetadata,
288 env_provider: Arc<dyn EnvProvider>,
289 ) -> Self {
290 Self {
291 auth,
292 request,
293 db_pool,
294 env_provider,
295 }
296 }
297
298 pub fn db(&self) -> &sqlx::PgPool {
300 &self.db_pool
301 }
302
303 pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
305 self.auth.require_user_id()
306 }
307
308 pub fn require_subject(&self) -> crate::error::Result<&str> {
310 self.auth.require_subject()
311 }
312}
313
314impl EnvAccess for QueryContext {
315 fn env_provider(&self) -> &dyn EnvProvider {
316 self.env_provider.as_ref()
317 }
318}
319
320pub type JobInfoLookup = Arc<dyn Fn(&str) -> Option<JobInfo> + Send + Sync>;
322
323pub struct MutationContext {
325 pub auth: AuthContext,
327 pub request: RequestMetadata,
329 db_pool: sqlx::PgPool,
331 http_client: CircuitBreakerClient,
333 job_dispatch: Option<Arc<dyn JobDispatch>>,
335 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
337 env_provider: Arc<dyn EnvProvider>,
339 tx: Option<Arc<AsyncMutex<Transaction<'static, Postgres>>>>,
341 outbox: Option<Arc<Mutex<OutboxBuffer>>>,
343 job_info_lookup: Option<JobInfoLookup>,
345}
346
347impl MutationContext {
348 pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
350 Self {
351 auth,
352 request,
353 db_pool,
354 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
355 job_dispatch: None,
356 workflow_dispatch: None,
357 env_provider: Arc::new(RealEnvProvider::new()),
358 tx: None,
359 outbox: None,
360 job_info_lookup: None,
361 }
362 }
363
364 pub fn with_dispatch(
366 db_pool: sqlx::PgPool,
367 auth: AuthContext,
368 request: RequestMetadata,
369 http_client: CircuitBreakerClient,
370 job_dispatch: Option<Arc<dyn JobDispatch>>,
371 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
372 ) -> Self {
373 Self {
374 auth,
375 request,
376 db_pool,
377 http_client,
378 job_dispatch,
379 workflow_dispatch,
380 env_provider: Arc::new(RealEnvProvider::new()),
381 tx: None,
382 outbox: None,
383 job_info_lookup: None,
384 }
385 }
386
387 pub fn with_env(
389 db_pool: sqlx::PgPool,
390 auth: AuthContext,
391 request: RequestMetadata,
392 http_client: CircuitBreakerClient,
393 job_dispatch: Option<Arc<dyn JobDispatch>>,
394 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
395 env_provider: Arc<dyn EnvProvider>,
396 ) -> Self {
397 Self {
398 auth,
399 request,
400 db_pool,
401 http_client,
402 job_dispatch,
403 workflow_dispatch,
404 env_provider,
405 tx: None,
406 outbox: None,
407 job_info_lookup: None,
408 }
409 }
410
411 #[allow(clippy::type_complexity)]
413 pub fn with_transaction(
414 db_pool: sqlx::PgPool,
415 tx: Transaction<'static, Postgres>,
416 auth: AuthContext,
417 request: RequestMetadata,
418 http_client: CircuitBreakerClient,
419 job_info_lookup: JobInfoLookup,
420 ) -> (
421 Self,
422 Arc<AsyncMutex<Transaction<'static, Postgres>>>,
423 Arc<Mutex<OutboxBuffer>>,
424 ) {
425 let tx_handle = Arc::new(AsyncMutex::new(tx));
426 let outbox = Arc::new(Mutex::new(OutboxBuffer::default()));
427
428 let ctx = Self {
429 auth,
430 request,
431 db_pool,
432 http_client,
433 job_dispatch: None,
434 workflow_dispatch: None,
435 env_provider: Arc::new(RealEnvProvider::new()),
436 tx: Some(tx_handle.clone()),
437 outbox: Some(outbox.clone()),
438 job_info_lookup: Some(job_info_lookup),
439 };
440
441 (ctx, tx_handle, outbox)
442 }
443
444 pub fn is_transactional(&self) -> bool {
445 self.tx.is_some()
446 }
447
448 pub fn db(&self) -> DbConn<'_> {
449 match &self.tx {
450 Some(tx) => DbConn::Transaction(tx.clone()),
451 None => DbConn::Pool(&self.db_pool),
452 }
453 }
454
455 pub fn pool(&self) -> &sqlx::PgPool {
457 &self.db_pool
458 }
459
460 pub fn http(&self) -> &reqwest::Client {
466 self.http_client.inner()
467 }
468
469 pub fn http_with_circuit_breaker(&self) -> &CircuitBreakerClient {
471 &self.http_client
472 }
473
474 pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
475 self.auth.require_user_id()
476 }
477
478 pub fn require_subject(&self) -> crate::error::Result<&str> {
479 self.auth.require_subject()
480 }
481
482 pub async fn dispatch_job<T: serde::Serialize>(
484 &self,
485 job_type: &str,
486 args: T,
487 ) -> crate::error::Result<Uuid> {
488 let args_json = serde_json::to_value(args)?;
489
490 if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
492 let job_info = job_info_lookup(job_type).ok_or_else(|| {
493 crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
494 })?;
495
496 let pending = PendingJob {
497 id: Uuid::new_v4(),
498 job_type: job_type.to_string(),
499 args: args_json,
500 context: serde_json::json!({}),
501 priority: job_info.priority.as_i32(),
502 max_attempts: job_info.retry.max_attempts as i32,
503 worker_capability: job_info.worker_capability.map(|s| s.to_string()),
504 };
505
506 let job_id = pending.id;
507 outbox.lock().unwrap().jobs.push(pending);
508 return Ok(job_id);
509 }
510
511 let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
513 crate::error::ForgeError::Internal("Job dispatch not available".into())
514 })?;
515 dispatcher.dispatch_by_name(job_type, args_json).await
516 }
517
518 pub async fn dispatch_job_with_context<T: serde::Serialize>(
520 &self,
521 job_type: &str,
522 args: T,
523 context: serde_json::Value,
524 ) -> crate::error::Result<Uuid> {
525 let args_json = serde_json::to_value(args)?;
526
527 if let (Some(outbox), Some(job_info_lookup)) = (&self.outbox, &self.job_info_lookup) {
528 let job_info = job_info_lookup(job_type).ok_or_else(|| {
529 crate::error::ForgeError::NotFound(format!("Job type '{}' not found", job_type))
530 })?;
531
532 let pending = PendingJob {
533 id: Uuid::new_v4(),
534 job_type: job_type.to_string(),
535 args: args_json,
536 context,
537 priority: job_info.priority.as_i32(),
538 max_attempts: job_info.retry.max_attempts as i32,
539 worker_capability: job_info.worker_capability.map(|s| s.to_string()),
540 };
541
542 let job_id = pending.id;
543 outbox.lock().unwrap().jobs.push(pending);
544 return Ok(job_id);
545 }
546
547 let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
548 crate::error::ForgeError::Internal("Job dispatch not available".into())
549 })?;
550 dispatcher.dispatch_by_name(job_type, args_json).await
551 }
552
553 pub async fn cancel_job(
555 &self,
556 job_id: Uuid,
557 reason: Option<String>,
558 ) -> crate::error::Result<bool> {
559 let dispatcher = self.job_dispatch.as_ref().ok_or_else(|| {
560 crate::error::ForgeError::Internal("Job dispatch not available".into())
561 })?;
562 dispatcher.cancel(job_id, reason).await
563 }
564
565 pub async fn start_workflow<T: serde::Serialize>(
567 &self,
568 workflow_name: &str,
569 input: T,
570 ) -> crate::error::Result<Uuid> {
571 let input_json = serde_json::to_value(input)?;
572
573 if let Some(outbox) = &self.outbox {
575 let pending = PendingWorkflow {
576 id: Uuid::new_v4(),
577 workflow_name: workflow_name.to_string(),
578 input: input_json,
579 };
580
581 let workflow_id = pending.id;
582 outbox.lock().unwrap().workflows.push(pending);
583 return Ok(workflow_id);
584 }
585
586 let dispatcher = self.workflow_dispatch.as_ref().ok_or_else(|| {
588 crate::error::ForgeError::Internal("Workflow dispatch not available".into())
589 })?;
590 dispatcher.start_by_name(workflow_name, input_json).await
591 }
592}
593
594impl EnvAccess for MutationContext {
595 fn env_provider(&self) -> &dyn EnvProvider {
596 self.env_provider.as_ref()
597 }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// An anonymous context has no user id and rejects `require_user_id`.
    #[test]
    fn test_auth_context_unauthenticated() {
        let anonymous = AuthContext::unauthenticated();

        assert!(!anonymous.is_authenticated());
        assert_eq!(anonymous.user_id(), None);
        assert!(anonymous.require_user_id().is_err());
    }

    /// An authenticated context exposes its user id and answers role checks.
    #[test]
    fn test_auth_context_authenticated() {
        let expected_id = Uuid::new_v4();
        let granted = vec![String::from("admin"), String::from("user")];
        let auth = AuthContext::authenticated(expected_id, granted, HashMap::new());

        assert!(auth.is_authenticated());
        assert_eq!(auth.user_id(), Some(expected_id));
        assert!(auth.require_user_id().is_ok());

        // Roles that were granted are reported as present.
        for held in ["admin", "user"] {
            assert!(auth.has_role(held));
        }
        assert!(!auth.has_role("superadmin"));

        assert!(auth.require_role("admin").is_ok());
        assert!(auth.require_role("superadmin").is_err());
    }

    /// Claims are retrievable by key; unknown keys yield `None`.
    #[test]
    fn test_auth_context_with_claims() {
        let org = serde_json::json!("org-123");
        let claims: HashMap<String, serde_json::Value> =
            std::iter::once(("org_id".to_string(), org.clone())).collect();

        let auth = AuthContext::authenticated(Uuid::new_v4(), Vec::new(), claims);

        assert_eq!(auth.claim("org_id"), Some(&org));
        assert!(auth.claim("nonexistent").is_none());
    }

    /// `new` generates a trace id; `with_trace_id` keeps the supplied one.
    #[test]
    fn test_request_metadata() {
        let generated = RequestMetadata::new();
        assert!(!generated.trace_id.is_empty());
        assert!(generated.client_ip.is_none());

        let supplied = RequestMetadata::with_trace_id(String::from("trace-123"));
        assert_eq!(supplied.trace_id, "trace-123");
    }
}