1use std::collections::HashMap;
36use std::sync::Arc;
37use std::sync::atomic::{AtomicUsize, Ordering};
38use std::time::Duration;
39
40use chrono::{DateTime, Utc};
41
42use futures_core::future::BoxFuture;
43use futures_core::stream::BoxStream;
44use sqlx::postgres::{PgConnection, PgQueryResult, PgRow};
45use sqlx::{Postgres, Transaction};
46use tokio::sync::Mutex as AsyncMutex;
47use uuid::Uuid;
48
49use tracing::Instrument;
50
51use super::dispatch::{JobDispatch, KvHandle, WorkflowDispatch};
52use crate::auth::Claims;
53use crate::env::{EnvAccess, EnvProvider, RealEnvProvider};
54use crate::http::CircuitBreakerClient;
55
/// Signs [`Claims`] into an encoded token string.
///
/// Installed on a `MutationContext` via `set_token_issuer` and used by its
/// `issue_token`, `issue_token_pair` and `rotate_refresh_token` helpers.
/// Implementations must be `Send + Sync` because the issuer is shared behind an `Arc`.
pub trait TokenIssuer: Send + Sync {
    /// Sign `claims` and return the encoded token.
    ///
    /// # Errors
    /// Implementation-defined; the `MutationContext` helpers return the error unchanged.
    fn sign(&self, claims: &Claims) -> crate::error::Result<String>;
}
64
/// A Postgres connection handed out by `MutationContext::conn`.
///
/// Either a connection checked out of the pool, or a lock guard over the request's
/// shared transaction slot. Both variants deref to [`PgConnection`]; the `Tx` variant
/// panics on deref if the slot has been emptied (the transaction was taken).
pub enum ForgeConn<'a> {
    /// A connection acquired from the pool for the lifetime of this value.
    Pool(sqlx::pool::PoolConnection<Postgres>),
    /// Exclusive access to the shared transaction slot while the guard is held;
    /// the inner `Option` is `None` once the transaction has been taken out.
    Tx(tokio::sync::MutexGuard<'a, Option<Transaction<'static, Postgres>>>),
}
81
82impl std::ops::Deref for ForgeConn<'_> {
83 type Target = PgConnection;
84 fn deref(&self) -> &PgConnection {
85 match self {
86 ForgeConn::Pool(c) => c,
87 ForgeConn::Tx(g) => g
88 .as_ref()
89 .expect("ForgeConn::Tx held while transaction was already taken"),
90 }
91 }
92}
93
94impl std::ops::DerefMut for ForgeConn<'_> {
95 fn deref_mut(&mut self) -> &mut PgConnection {
96 match self {
97 ForgeConn::Pool(c) => c,
98 ForgeConn::Tx(g) => g
99 .as_mut()
100 .expect("ForgeConn::Tx held while transaction was already taken"),
101 }
102 }
103}
104
/// Cloneable handle to the Postgres pool, returned by `QueryContext::db`.
///
/// Its `sqlx::Executor` implementation wraps `fetch_optional`, `execute`, `fetch_all`
/// and `fetch_one` in `db.query` tracing spans. The `Debug` output deliberately omits
/// the pool.
#[derive(Clone)]
pub struct ForgeDb(sqlx::PgPool);
118
119impl std::fmt::Debug for ForgeDb {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_tuple("ForgeDb").finish()
122 }
123}
124
125impl ForgeDb {
126 pub fn from_pool(pool: &sqlx::PgPool) -> Self {
128 Self(pool.clone())
129 }
130}
131
/// Classify a SQL statement by its leading keyword for the `db.operation.name` span field.
///
/// Leading whitespace is skipped and the keyword is matched ASCII-case-insensitively
/// against the first six bytes only. Anything other than SELECT/INSERT/UPDATE/DELETE
/// (including CTEs starting with `WITH`) is reported as `"OTHER"`.
fn sql_operation(sql: &str) -> &'static str {
    // Keyword bytes paired with the label reported in the span.
    const OPERATIONS: [(&[u8], &str); 4] = [
        (b"select", "SELECT"),
        (b"insert", "INSERT"),
        (b"update", "UPDATE"),
        (b"delete", "DELETE"),
    ];

    let head = sql.trim_start().as_bytes();
    OPERATIONS
        .iter()
        .find_map(|&(keyword, label)| {
            let prefix = head.get(..keyword.len())?;
            prefix.eq_ignore_ascii_case(keyword).then_some(label)
        })
        .unwrap_or("OTHER")
}
142
/// Tracing-instrumented executor over the pool wrapped by [`ForgeDb`].
///
/// `fetch_optional`, `execute`, `fetch_all` and `fetch_one` run inside a `db.query` span
/// carrying `db.system = "postgresql"` and `db.operation.name` from `sql_operation`.
/// `fetch_many`, `prepare_with` and `describe` delegate to the pool without a span.
///
/// Implemented for `'static` because each future takes ownership of `self`
/// (`async move`), so nothing is borrowed from the caller.
impl sqlx::Executor<'static> for ForgeDb {
    type Database = Postgres;

    // Forwarded to the pool's stream as-is; no `db.query` span is opened here.
    // NOTE(review): sqlx's streaming helpers are presumably built on `fetch_many`,
    // so streamed queries are likely untraced — confirm if tracing coverage matters.
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> BoxStream<'e, Result<sqlx::Either<PgQueryResult, PgRow>, sqlx::Error>>
    where
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        (&self.0).fetch_many(query)
    }

    // Zero-or-one row, traced.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> BoxFuture<'e, Result<Option<PgRow>, sqlx::Error>>
    where
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        // Classify before `query` is moved into the future.
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        Box::pin(
            async move { sqlx::Executor::fetch_optional(&self.0, query).await }.instrument(span),
        )
    }

    // Statement without result rows, traced.
    fn execute<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, Result<PgQueryResult, sqlx::Error>>
    where
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        Box::pin(async move { sqlx::Executor::execute(&self.0, query).await }.instrument(span))
    }

    // All rows collected into a Vec, traced.
    fn fetch_all<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, Result<Vec<PgRow>, sqlx::Error>>
    where
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        Box::pin(async move { sqlx::Executor::fetch_all(&self.0, query).await }.instrument(span))
    }

    // Exactly one row, traced.
    fn fetch_one<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, Result<PgRow, sqlx::Error>>
    where
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        Box::pin(async move { sqlx::Executor::fetch_one(&self.0, query).await }.instrument(span))
    }

    // Statement preparation; delegated without a span.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Postgres as sqlx::Database>::TypeInfo],
    ) -> BoxFuture<'e, Result<<Postgres as sqlx::Database>::Statement<'q>, sqlx::Error>> {
        Box::pin(async move { sqlx::Executor::prepare_with(&self.0, sql, parameters).await })
    }

    // Query description; delegated without a span.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> BoxFuture<'e, Result<sqlx::Describe<Postgres>, sqlx::Error>> {
        Box::pin(async move { sqlx::Executor::describe(&self.0, sql).await })
    }
}
216
/// Database handle returned by `QueryContext::db_conn` and `MutationContext::tx` / `db_conn`.
///
/// Queries run either directly on the pool or on the request's shared transaction.
#[non_exhaustive]
pub enum DbConn<'a> {
    /// Run each query directly against the pool.
    Pool(sqlx::PgPool),
    /// Run queries on the shared transaction, locking its mutex per query.
    Transaction(
        // Shared transaction slot; the inner `Option` is `None` once the transaction was taken.
        Arc<AsyncMutex<Option<Transaction<'static, Postgres>>>>,
        // The context's pool; ignored by `DbConn::fetch_one`/`fetch_optional`/`fetch_all`/`execute`.
        &'a sqlx::PgPool,
    ),
}
242
243impl DbConn<'_> {
244 pub async fn fetch_one<'q, O>(
246 &self,
247 query: sqlx::query::QueryAs<'q, Postgres, O, sqlx::postgres::PgArguments>,
248 ) -> sqlx::Result<O>
249 where
250 O: Send + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
251 {
252 match self {
253 DbConn::Pool(pool) => query.fetch_one(pool).await,
254 DbConn::Transaction(tx, _) => {
255 let mut guard = tx.lock().await;
256 let conn = guard.as_mut().ok_or(sqlx::Error::PoolClosed)?;
257 query.fetch_one(&mut **conn).await
258 }
259 }
260 }
261
262 pub async fn fetch_optional<'q, O>(
264 &self,
265 query: sqlx::query::QueryAs<'q, Postgres, O, sqlx::postgres::PgArguments>,
266 ) -> sqlx::Result<Option<O>>
267 where
268 O: Send + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
269 {
270 match self {
271 DbConn::Pool(pool) => query.fetch_optional(pool).await,
272 DbConn::Transaction(tx, _) => {
273 let mut guard = tx.lock().await;
274 let conn = guard.as_mut().ok_or(sqlx::Error::PoolClosed)?;
275 query.fetch_optional(&mut **conn).await
276 }
277 }
278 }
279
280 pub async fn fetch_all<'q, O>(
282 &self,
283 query: sqlx::query::QueryAs<'q, Postgres, O, sqlx::postgres::PgArguments>,
284 ) -> sqlx::Result<Vec<O>>
285 where
286 O: Send + Unpin + for<'r> sqlx::FromRow<'r, PgRow>,
287 {
288 match self {
289 DbConn::Pool(pool) => query.fetch_all(pool).await,
290 DbConn::Transaction(tx, _) => {
291 let mut guard = tx.lock().await;
292 let conn = guard.as_mut().ok_or(sqlx::Error::PoolClosed)?;
293 query.fetch_all(&mut **conn).await
294 }
295 }
296 }
297
298 pub async fn execute<'q>(
300 &self,
301 query: sqlx::query::Query<'q, Postgres, sqlx::postgres::PgArguments>,
302 ) -> sqlx::Result<PgQueryResult> {
303 match self {
304 DbConn::Pool(pool) => query.execute(pool).await,
305 DbConn::Transaction(tx, _) => {
306 let mut guard = tx.lock().await;
307 let conn = guard.as_mut().ok_or(sqlx::Error::PoolClosed)?;
308 query.execute(&mut **conn).await
309 }
310 }
311 }
312}
313
314impl std::fmt::Debug for DbConn<'_> {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 match self {
317 DbConn::Pool(_) => f.debug_tuple("DbConn::Pool").finish(),
318 DbConn::Transaction(_, _) => f.debug_tuple("DbConn::Transaction").finish(),
319 }
320 }
321}
322
323impl std::fmt::Debug for ForgeConn<'_> {
324 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325 match self {
326 ForgeConn::Pool(_) => f.debug_tuple("ForgeConn::Pool").finish(),
327 ForgeConn::Tx(_) => f.debug_tuple("ForgeConn::Tx").finish(),
328 }
329 }
330}
331
/// Executor over a borrowed [`ForgeConn`], so either a pooled connection or the
/// shared transaction can be passed wherever sqlx expects an executor.
///
/// Each method reborrows the connection through `DerefMut` (which panics if a `Tx`
/// slot is empty) and delegates to `&mut PgConnection`. `fetch_optional`, `execute`,
/// `fetch_all` and `fetch_one` are wrapped in a `db.query` tracing span; `fetch_many`,
/// `prepare_with` and `describe` are not.
impl<'c> sqlx::Executor<'c> for &'c mut ForgeConn<'_> {
    type Database = Postgres;

    // Streaming fetch; delegated without a span.
    fn fetch_many<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> BoxStream<'e, Result<sqlx::Either<PgQueryResult, PgRow>, sqlx::Error>>
    where
        'c: 'e,
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        // Reborrow for 'e via DerefMut so the returned stream can borrow the connection.
        let conn: &'e mut PgConnection = &mut *self;
        conn.fetch_many(query)
    }

    // Zero-or-one row, traced.
    fn fetch_optional<'e, 'q: 'e, E>(
        self,
        query: E,
    ) -> BoxFuture<'e, Result<Option<PgRow>, sqlx::Error>>
    where
        'c: 'e,
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        // Classify before `query` is moved into the connection call.
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        let conn: &'e mut PgConnection = &mut *self;
        Box::pin(conn.fetch_optional(query).instrument(span))
    }

    // Statement without result rows, traced.
    fn execute<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, Result<PgQueryResult, sqlx::Error>>
    where
        'c: 'e,
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        let conn: &'e mut PgConnection = &mut *self;
        Box::pin(conn.execute(query).instrument(span))
    }

    // All rows collected into a Vec, traced.
    fn fetch_all<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, Result<Vec<PgRow>, sqlx::Error>>
    where
        'c: 'e,
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        let conn: &'e mut PgConnection = &mut *self;
        Box::pin(conn.fetch_all(query).instrument(span))
    }

    // Exactly one row, traced.
    fn fetch_one<'e, 'q: 'e, E>(self, query: E) -> BoxFuture<'e, Result<PgRow, sqlx::Error>>
    where
        'c: 'e,
        E: sqlx::Execute<'q, Postgres> + 'q,
    {
        let op = sql_operation(query.sql());
        let span =
            tracing::info_span!("db.query", db.system = "postgresql", db.operation.name = op,);
        let conn: &'e mut PgConnection = &mut *self;
        Box::pin(conn.fetch_one(query).instrument(span))
    }

    // Statement preparation; delegated without a span.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Postgres as sqlx::Database>::TypeInfo],
    ) -> BoxFuture<'e, Result<<Postgres as sqlx::Database>::Statement<'q>, sqlx::Error>>
    where
        'c: 'e,
    {
        let conn: &'e mut PgConnection = &mut *self;
        conn.prepare_with(sql, parameters)
    }

    // Query description; delegated without a span.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> BoxFuture<'e, Result<sqlx::Describe<Postgres>, sqlx::Error>>
    where
        'c: 'e,
    {
        let conn: &'e mut PgConnection = &mut *self;
        conn.describe(sql)
    }
}
421
/// Authentication state of the caller for a single request.
///
/// NOTE(review): the derived `Debug` prints all claims — confirm none are sensitive
/// before logging this type.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AuthContext {
    /// Caller id when it is a UUID; `None` for anonymous callers and for contexts
    /// built with `authenticated_without_uuid`.
    user_id: Option<Uuid>,
    /// Role names granted to the caller, compared by exact string equality.
    roles: Vec<String>,
    /// Token claims keyed by claim name (e.g. `sub`, `tenant_id`).
    claims: HashMap<String, serde_json::Value>,
    /// `true` when built by one of the `authenticated*` constructors.
    authenticated: bool,
    /// Token expiry as a Unix timestamp in seconds, if recorded via `with_token_exp`.
    token_exp: Option<i64>,
}
433
434impl AuthContext {
435 pub fn unauthenticated() -> Self {
437 Self {
438 user_id: None,
439 roles: Vec::new(),
440 claims: HashMap::new(),
441 authenticated: false,
442 token_exp: None,
443 }
444 }
445
446 pub fn authenticated(
448 user_id: Uuid,
449 roles: Vec<String>,
450 claims: HashMap<String, serde_json::Value>,
451 ) -> Self {
452 Self {
453 user_id: Some(user_id),
454 roles,
455 claims,
456 authenticated: true,
457 token_exp: None,
458 }
459 }
460
461 pub fn authenticated_without_uuid(
467 roles: Vec<String>,
468 claims: HashMap<String, serde_json::Value>,
469 ) -> Self {
470 Self {
471 user_id: None,
472 roles,
473 claims,
474 authenticated: true,
475 token_exp: None,
476 }
477 }
478
479 pub fn with_token_exp(mut self, exp: i64) -> Self {
484 self.token_exp = Some(exp);
485 self
486 }
487
488 pub fn token_exp(&self) -> Option<i64> {
490 self.token_exp
491 }
492
493 pub fn token_is_expired(&self) -> bool {
497 self.token_exp
498 .map(|exp| exp < chrono::Utc::now().timestamp())
499 .unwrap_or(false)
500 }
501
502 pub fn is_authenticated(&self) -> bool {
504 self.authenticated
505 }
506
507 pub fn user_id(&self) -> Option<Uuid> {
509 self.user_id
510 }
511
512 pub fn require_user_id(&self) -> crate::error::Result<Uuid> {
514 self.user_id
515 .ok_or_else(|| crate::error::ForgeError::Unauthorized("Authentication required".into()))
516 }
517
518 pub fn has_role(&self, role: &str) -> bool {
520 self.roles.iter().any(|r| r == role)
521 }
522
523 pub fn require_role(&self, role: &str) -> crate::error::Result<()> {
525 if self.has_role(role) {
526 Ok(())
527 } else {
528 Err(crate::error::ForgeError::Forbidden(format!(
529 "Required role '{}' not present",
530 role
531 )))
532 }
533 }
534
535 pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
537 self.claims.get(key)
538 }
539
540 pub fn claims(&self) -> &HashMap<String, serde_json::Value> {
542 &self.claims
543 }
544
545 pub fn roles(&self) -> &[String] {
547 &self.roles
548 }
549
550 pub fn subject(&self) -> Option<&str> {
556 self.claims.get("sub").and_then(|v| v.as_str())
557 }
558
559 pub fn require_subject(&self) -> crate::error::Result<&str> {
561 if !self.authenticated {
562 return Err(crate::error::ForgeError::Unauthorized(
563 "Authentication required".to_string(),
564 ));
565 }
566 self.subject().ok_or_else(|| {
567 crate::error::ForgeError::Unauthorized("No subject claim in token".to_string())
568 })
569 }
570
571 pub fn principal_id(&self) -> Option<String> {
575 self.subject()
576 .map(ToString::to_string)
577 .or_else(|| self.user_id.map(|id| id.to_string()))
578 }
579
580 pub fn is_admin(&self) -> bool {
582 self.roles.iter().any(|r| r == "admin")
583 }
584
585 pub fn tenant_id(&self) -> Option<uuid::Uuid> {
590 self.claims
591 .get("tenant_id")
592 .and_then(|v| v.as_str())
593 .and_then(|s| uuid::Uuid::parse_str(s).ok())
594 }
595}
596
/// Per-request metadata: identifiers, client details and creation time.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RequestMetadata {
    /// Unique id for this request.
    pub(crate) request_id: Uuid,
    /// Trace id for this request; forwarded by `MutationContext::start_workflow`.
    pub(crate) trace_id: String,
    /// Client IP address, if known.
    pub(crate) client_ip: Option<String>,
    /// Client user-agent string, if known.
    pub(crate) user_agent: Option<String>,
    /// Correlation id, if one was set.
    pub(crate) correlation_id: Option<String>,
    /// UTC time at which this metadata value was constructed.
    pub(crate) timestamp: chrono::DateTime<chrono::Utc>,
}
609
610impl RequestMetadata {
611 pub fn new() -> Self {
613 Self {
614 request_id: Uuid::new_v4(),
615 trace_id: Uuid::new_v4().to_string(),
616 client_ip: None,
617 user_agent: None,
618 correlation_id: None,
619 timestamp: chrono::Utc::now(),
620 }
621 }
622
623 pub fn with_trace_id(trace_id: String) -> Self {
625 Self {
626 request_id: Uuid::new_v4(),
627 trace_id,
628 client_ip: None,
629 user_agent: None,
630 correlation_id: None,
631 timestamp: chrono::Utc::now(),
632 }
633 }
634
635 #[doc(hidden)]
642 pub fn __build_internal(
643 request_id: Uuid,
644 trace_id: String,
645 client_ip: Option<String>,
646 user_agent: Option<String>,
647 correlation_id: Option<String>,
648 ) -> Self {
649 Self {
650 request_id,
651 trace_id,
652 client_ip,
653 user_agent,
654 correlation_id,
655 timestamp: chrono::Utc::now(),
656 }
657 }
658
659 pub fn set_client_ip(&mut self, ip: Option<String>) {
661 self.client_ip = ip;
662 }
663
664 pub fn set_user_agent(&mut self, ua: Option<String>) {
666 self.user_agent = ua;
667 }
668
669 pub fn set_correlation_id(&mut self, id: Option<String>) {
671 self.correlation_id = id;
672 }
673
674 pub fn request_id(&self) -> Uuid {
676 self.request_id
677 }
678
679 pub fn trace_id(&self) -> &str {
681 &self.trace_id
682 }
683
684 pub fn client_ip(&self) -> Option<&str> {
686 self.client_ip.as_deref()
687 }
688
689 pub fn user_agent(&self) -> Option<&str> {
691 self.user_agent.as_deref()
692 }
693
694 pub fn correlation_id(&self) -> Option<&str> {
696 self.correlation_id.as_deref()
697 }
698
699 pub fn timestamp(&self) -> chrono::DateTime<chrono::Utc> {
701 self.timestamp
702 }
703}
704
705impl Default for RequestMetadata {
706 fn default() -> Self {
707 Self::new()
708 }
709}
710
/// Per-request context passed to read-only query handlers.
#[non_exhaustive]
pub struct QueryContext {
    /// Authentication state of the caller.
    pub auth: AuthContext,
    /// Request metadata (ids, client details, timestamp).
    pub request: RequestMetadata,
    /// Pool backing `db()` and `db_conn()`.
    db_pool: sqlx::PgPool,
    /// Source of environment variables for `EnvAccess`.
    env_provider: Arc<dyn EnvProvider>,
    /// Optional KV store handle; `kv()` errors while this is `None`.
    kv: Option<Arc<dyn KvHandle>>,
}
720
721impl QueryContext {
722 pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
724 Self {
725 auth,
726 request,
727 db_pool,
728 env_provider: RealEnvProvider::shared(),
729 kv: None,
730 }
731 }
732
733 pub fn with_env(
735 db_pool: sqlx::PgPool,
736 auth: AuthContext,
737 request: RequestMetadata,
738 env_provider: Arc<dyn EnvProvider>,
739 ) -> Self {
740 Self {
741 auth,
742 request,
743 db_pool,
744 env_provider,
745 kv: None,
746 }
747 }
748
749 pub fn set_kv(&mut self, kv: Arc<dyn KvHandle>) {
752 self.kv = Some(kv);
753 }
754
755 pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
761 self.kv
762 .as_deref()
763 .ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
764 }
765
766 pub fn db(&self) -> ForgeDb {
775 ForgeDb(self.db_pool.clone())
776 }
777
778 pub fn db_conn(&self) -> DbConn<'_> {
792 DbConn::Pool(self.db_pool.clone())
793 }
794
795 pub fn user_id(&self) -> crate::error::Result<Uuid> {
797 self.auth.require_user_id()
798 }
799
800 pub fn tenant_id(&self) -> Option<Uuid> {
802 self.auth.tenant_id()
803 }
804
805 pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
810 self.auth.claim(key)
811 }
812}
813
814impl EnvAccess for QueryContext {
815 fn env_provider(&self) -> &dyn EnvProvider {
816 self.env_provider.as_ref()
817 }
818}
819
/// Lifetimes applied by `MutationContext::issue_token_pair` and
/// `MutationContext::rotate_refresh_token`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AuthTokenTtl {
    /// Access-token lifetime in seconds.
    pub access_token_secs: i64,
    /// Refresh-token lifetime in days.
    pub refresh_token_days: i64,
}

impl AuthTokenTtl {
    /// Access-token lifetime used by `Default`: one hour.
    const DEFAULT_ACCESS_TOKEN_SECS: i64 = 60 * 60;
    /// Refresh-token lifetime used by `Default`: thirty days.
    const DEFAULT_REFRESH_TOKEN_DAYS: i64 = 30;

    /// Build TTLs from an access-token lifetime (seconds) and refresh-token lifetime (days).
    pub fn new(access_token_secs: i64, refresh_token_days: i64) -> Self {
        AuthTokenTtl {
            refresh_token_days,
            access_token_secs,
        }
    }
}

impl Default for AuthTokenTtl {
    /// One-hour access tokens and thirty-day refresh tokens.
    fn default() -> Self {
        Self::new(
            Self::DEFAULT_ACCESS_TOKEN_SECS,
            Self::DEFAULT_REFRESH_TOKEN_DAYS,
        )
    }
}
848
/// Per-request context passed to mutation handlers.
///
/// Database access goes through the shared transaction when one is attached
/// (`with_transaction`), otherwise through the pool.
#[non_exhaustive]
pub struct MutationContext {
    /// Authentication state of the caller.
    pub auth: AuthContext,
    /// Request metadata (ids, client details, timestamp).
    pub request: RequestMetadata,
    /// Pool used when no transaction is attached, and always by the token helpers.
    db_pool: sqlx::PgPool,
    /// Circuit-breaker HTTP client; `http()` applies `http_timeout` to it.
    http_client: CircuitBreakerClient,
    /// Per-context HTTP timeout passed to `CircuitBreakerClient::with_timeout`.
    http_timeout: Option<Duration>,
    /// Job dispatcher; `dispatch_job*` and `cancel_job` error while this is `None`.
    job_dispatch: Option<Arc<dyn JobDispatch>>,
    /// Workflow dispatcher; `start_workflow` errors while this is `None`.
    workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
    /// Source of environment variables for `EnvAccess`.
    env_provider: Arc<dyn EnvProvider>,
    /// Shared transaction slot when transactional; the inner `Option` becomes `None`
    /// once the transaction has been taken out.
    tx: Option<Arc<AsyncMutex<Option<Transaction<'static, Postgres>>>>>,
    /// Signer used by `issue_token`, `issue_token_pair` and `rotate_refresh_token`.
    token_issuer: Option<Arc<dyn TokenIssuer>>,
    /// Token lifetimes used when issuing or rotating token pairs.
    token_ttl: AuthTokenTtl,
    /// Jobs dispatched so far in this request (only counted while a limit is set).
    dispatched_job_count: Arc<AtomicUsize>,
    /// Maximum jobs per request; `0` disables the limit.
    max_jobs_per_request: usize,
    /// Optional KV store handle.
    kv: Option<Arc<dyn KvHandle>>,
    /// Optional email sender.
    email_sender: Option<Arc<dyn crate::email::EmailSender>>,
}
872
873impl MutationContext {
874 pub fn new(db_pool: sqlx::PgPool, auth: AuthContext, request: RequestMetadata) -> Self {
876 Self {
877 auth,
878 request,
879 db_pool,
880 http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
881 http_timeout: None,
882 job_dispatch: None,
883 workflow_dispatch: None,
884 env_provider: RealEnvProvider::shared(),
885 tx: None,
886 token_issuer: None,
887 token_ttl: AuthTokenTtl::default(),
888 dispatched_job_count: Arc::new(AtomicUsize::new(0)),
889 max_jobs_per_request: 0,
890 kv: None,
891 email_sender: None,
892 }
893 }
894
895 pub fn with_dispatch(
897 db_pool: sqlx::PgPool,
898 auth: AuthContext,
899 request: RequestMetadata,
900 http_client: CircuitBreakerClient,
901 job_dispatch: Option<Arc<dyn JobDispatch>>,
902 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
903 ) -> Self {
904 Self {
905 auth,
906 request,
907 db_pool,
908 http_client,
909 http_timeout: None,
910 job_dispatch,
911 workflow_dispatch,
912 env_provider: RealEnvProvider::shared(),
913 tx: None,
914 token_issuer: None,
915 token_ttl: AuthTokenTtl::default(),
916 dispatched_job_count: Arc::new(AtomicUsize::new(0)),
917 max_jobs_per_request: 0,
918 kv: None,
919 email_sender: None,
920 }
921 }
922
923 pub fn with_env(
925 db_pool: sqlx::PgPool,
926 auth: AuthContext,
927 request: RequestMetadata,
928 http_client: CircuitBreakerClient,
929 job_dispatch: Option<Arc<dyn JobDispatch>>,
930 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
931 env_provider: Arc<dyn EnvProvider>,
932 ) -> Self {
933 Self {
934 auth,
935 request,
936 db_pool,
937 http_client,
938 http_timeout: None,
939 job_dispatch,
940 workflow_dispatch,
941 env_provider,
942 tx: None,
943 token_issuer: None,
944 token_ttl: AuthTokenTtl::default(),
945 dispatched_job_count: Arc::new(AtomicUsize::new(0)),
946 max_jobs_per_request: 0,
947 kv: None,
948 email_sender: None,
949 }
950 }
951
952 pub fn with_transaction(
961 db_pool: sqlx::PgPool,
962 tx: Transaction<'static, Postgres>,
963 auth: AuthContext,
964 request: RequestMetadata,
965 http_client: CircuitBreakerClient,
966 job_dispatch: Option<Arc<dyn JobDispatch>>,
967 workflow_dispatch: Option<Arc<dyn WorkflowDispatch>>,
968 ) -> (
969 Self,
970 Arc<AsyncMutex<Option<Transaction<'static, Postgres>>>>,
971 ) {
972 let tx_handle = Arc::new(AsyncMutex::new(Some(tx)));
973
974 let ctx = Self {
975 auth,
976 request,
977 db_pool,
978 http_client,
979 http_timeout: None,
980 job_dispatch,
981 workflow_dispatch,
982 env_provider: RealEnvProvider::shared(),
983 tx: Some(tx_handle.clone()),
984 token_issuer: None,
985 token_ttl: AuthTokenTtl::default(),
986 dispatched_job_count: Arc::new(AtomicUsize::new(0)),
987 max_jobs_per_request: 0,
988 kv: None,
989 email_sender: None,
990 };
991
992 (ctx, tx_handle)
993 }
994
995 pub fn set_kv(&mut self, kv: Arc<dyn KvHandle>) {
998 self.kv = Some(kv);
999 }
1000
1001 pub fn kv(&self) -> crate::error::Result<&dyn KvHandle> {
1007 self.kv
1008 .as_deref()
1009 .ok_or_else(|| crate::error::ForgeError::internal("KV store not available"))
1010 }
1011
1012 pub fn set_email(&mut self, sender: Arc<dyn crate::email::EmailSender>) {
1014 self.email_sender = Some(sender);
1015 }
1016
1017 pub fn email(&self) -> crate::error::Result<&dyn crate::email::EmailSender> {
1019 self.email_sender
1020 .as_deref()
1021 .ok_or_else(|| crate::error::ForgeError::internal("Email not configured"))
1022 }
1023
1024 pub fn is_transactional(&self) -> bool {
1025 self.tx.is_some()
1026 }
1027
1028 pub async fn conn(&self) -> sqlx::Result<ForgeConn<'_>> {
1040 match &self.tx {
1041 Some(tx) => Ok(ForgeConn::Tx(tx.lock().await)),
1042 None => Ok(ForgeConn::Pool(self.db_pool.acquire().await?)),
1043 }
1044 }
1045
1046 pub fn bypass_pool(&self) -> &sqlx::PgPool {
1056 &self.db_pool
1057 }
1058
1059 pub fn tx(&self) -> DbConn<'_> {
1073 match &self.tx {
1074 Some(tx) => DbConn::Transaction(tx.clone(), &self.db_pool),
1075 None => DbConn::Pool(self.db_pool.clone()),
1076 }
1077 }
1078
1079 pub fn db_conn(&self) -> DbConn<'_> {
1081 self.tx()
1082 }
1083
1084 pub fn http(&self) -> crate::http::HttpClient {
1090 self.http_client.with_timeout(self.http_timeout)
1091 }
1092
1093 pub fn raw_http(&self) -> &reqwest::Client {
1095 self.http_client.inner()
1096 }
1097
1098 pub fn set_http_timeout(&mut self, timeout: Option<Duration>) {
1100 self.http_timeout = timeout;
1101 }
1102
1103 pub fn user_id(&self) -> crate::error::Result<Uuid> {
1105 self.auth.require_user_id()
1106 }
1107
1108 pub fn tenant_id(&self) -> Option<Uuid> {
1110 self.auth.tenant_id()
1111 }
1112
1113 pub fn claim(&self, key: &str) -> Option<&serde_json::Value> {
1117 self.auth.claim(key)
1118 }
1119
1120 pub fn set_token_issuer(&mut self, issuer: Arc<dyn TokenIssuer>) {
1122 self.token_issuer = Some(issuer);
1123 }
1124
1125 pub fn set_token_ttl(&mut self, ttl: AuthTokenTtl) {
1127 self.token_ttl = ttl;
1128 }
1129
1130 pub fn set_max_jobs_per_request(&mut self, limit: usize) {
1133 self.max_jobs_per_request = limit;
1134 }
1135
1136 pub fn issue_token(&self, claims: &Claims) -> crate::error::Result<String> {
1151 let issuer = self.token_issuer.as_ref().ok_or_else(|| {
1152 crate::error::ForgeError::internal(
1153 "Token issuer not available. Configure [auth] with an HMAC algorithm in forge.toml",
1154 )
1155 })?;
1156 issuer.sign(claims)
1157 }
1158
1159 pub async fn issue_token_pair(
1169 &self,
1170 user_id: Uuid,
1171 roles: &[&str],
1172 ) -> crate::error::Result<crate::auth::TokenPair> {
1173 let issuer = self.token_issuer.clone().ok_or_else(|| {
1174 crate::error::ForgeError::internal(
1175 "Token issuer not available. Configure [auth] in forge.toml",
1176 )
1177 })?;
1178 let access_ttl = self.token_ttl.access_token_secs;
1179 let refresh_ttl = self.token_ttl.refresh_token_days;
1180 crate::auth::tokens::issue_token_pair(
1181 &self.db_pool,
1182 user_id,
1183 roles,
1184 access_ttl,
1185 refresh_ttl,
1186 move |uid, r, ttl| {
1187 let claims = Claims::builder()
1188 .subject(uid)
1189 .roles(r.iter().map(|s| s.to_string()).collect())
1190 .duration_secs(ttl)
1191 .build()
1192 .map_err(crate::error::ForgeError::internal)?;
1193 issuer.sign(&claims)
1194 },
1195 )
1196 .await
1197 }
1198
1199 pub async fn rotate_refresh_token(
1204 &self,
1205 old_refresh_token: &str,
1206 ) -> crate::error::Result<crate::auth::TokenPair> {
1207 let issuer = self.token_issuer.clone().ok_or_else(|| {
1208 crate::error::ForgeError::internal(
1209 "Token issuer not available. Configure [auth] in forge.toml",
1210 )
1211 })?;
1212 let access_ttl = self.token_ttl.access_token_secs;
1213 let refresh_ttl = self.token_ttl.refresh_token_days;
1214 crate::auth::tokens::rotate_refresh_token(
1215 &self.db_pool,
1216 old_refresh_token,
1217 access_ttl,
1218 refresh_ttl,
1219 move |uid, r, ttl| {
1220 let claims = Claims::builder()
1221 .subject(uid)
1222 .roles(r.iter().map(|s| s.to_string()).collect())
1223 .duration_secs(ttl)
1224 .build()
1225 .map_err(crate::error::ForgeError::internal)?;
1226 issuer.sign(&claims)
1227 },
1228 )
1229 .await
1230 }
1231
1232 pub async fn revoke_refresh_token(&self, refresh_token: &str) -> crate::error::Result<()> {
1234 crate::auth::tokens::revoke_refresh_token(&self.db_pool, refresh_token).await
1235 }
1236
1237 pub async fn revoke_all_refresh_tokens(&self, user_id: Uuid) -> crate::error::Result<()> {
1239 crate::auth::tokens::revoke_all_refresh_tokens(&self.db_pool, user_id).await
1240 }
1241
1242 pub async fn dispatch_job<T: serde::Serialize>(
1251 &self,
1252 job_type: &str,
1253 args: T,
1254 ) -> crate::error::Result<Uuid> {
1255 if self.max_jobs_per_request > 0 {
1256 let count = self.dispatched_job_count.fetch_add(1, Ordering::Relaxed);
1257 if count >= self.max_jobs_per_request {
1258 self.dispatched_job_count.fetch_sub(1, Ordering::Relaxed);
1261 return Err(crate::error::ForgeError::Validation(format!(
1262 "max_jobs_per_request limit of {} exceeded",
1263 self.max_jobs_per_request
1264 )));
1265 }
1266 }
1267
1268 let args_json = serde_json::to_value(args)?;
1269 let dispatcher = self
1270 .job_dispatch
1271 .as_ref()
1272 .ok_or_else(|| crate::error::ForgeError::internal("Job dispatch not available"))?;
1273
1274 if let Some(tx) = &self.tx {
1275 let mut guard = tx.lock().await;
1276 let conn = guard.as_mut().ok_or_else(|| {
1277 crate::error::ForgeError::internal("Transaction already taken; cannot dispatch job")
1278 })?;
1279 return dispatcher
1280 .dispatch_in_conn(
1281 conn,
1282 job_type,
1283 args_json,
1284 self.auth.principal_id(),
1285 self.auth.tenant_id(),
1286 )
1287 .await;
1288 }
1289
1290 dispatcher
1291 .dispatch_by_name(
1292 job_type,
1293 args_json,
1294 self.auth.principal_id(),
1295 self.auth.tenant_id(),
1296 )
1297 .await
1298 }
1299
1300 pub async fn dispatch_job_at<T: serde::Serialize>(
1310 &self,
1311 job_type: &str,
1312 args: T,
1313 scheduled_at: DateTime<Utc>,
1314 ) -> crate::error::Result<Uuid> {
1315 if self.max_jobs_per_request > 0 {
1316 let count = self.dispatched_job_count.fetch_add(1, Ordering::Relaxed);
1317 if count >= self.max_jobs_per_request {
1318 self.dispatched_job_count.fetch_sub(1, Ordering::Relaxed);
1319 return Err(crate::error::ForgeError::Validation(format!(
1320 "max_jobs_per_request limit of {} exceeded",
1321 self.max_jobs_per_request
1322 )));
1323 }
1324 }
1325
1326 let args_json = serde_json::to_value(args)?;
1327 let dispatcher = self
1328 .job_dispatch
1329 .as_ref()
1330 .ok_or_else(|| crate::error::ForgeError::internal("Job dispatch not available"))?;
1331
1332 if let Some(tx) = &self.tx {
1333 let mut guard = tx.lock().await;
1334 let conn = guard.as_mut().ok_or_else(|| {
1335 crate::error::ForgeError::internal("Transaction already taken; cannot dispatch job")
1336 })?;
1337 return dispatcher
1338 .dispatch_in_conn_at(
1339 conn,
1340 job_type,
1341 args_json,
1342 scheduled_at,
1343 self.auth.principal_id(),
1344 self.auth.tenant_id(),
1345 )
1346 .await;
1347 }
1348
1349 dispatcher
1350 .dispatch_by_name_at(
1351 job_type,
1352 args_json,
1353 scheduled_at,
1354 self.auth.principal_id(),
1355 self.auth.tenant_id(),
1356 )
1357 .await
1358 }
1359
1360 pub async fn dispatch_job_after<T: serde::Serialize>(
1369 &self,
1370 job_type: &str,
1371 args: T,
1372 delay: Duration,
1373 ) -> crate::error::Result<Uuid> {
1374 let scheduled_at = Utc::now()
1375 + chrono::Duration::from_std(delay)
1376 .map_err(|_| crate::error::ForgeError::InvalidArgument("delay too large".into()))?;
1377 self.dispatch_job_at(job_type, args, scheduled_at).await
1378 }
1379
1380 pub async fn dispatch<J: crate::ForgeJob>(&self, args: J::Args) -> crate::error::Result<Uuid> {
1383 self.dispatch_job(J::info().name, args).await
1384 }
1385
1386 pub async fn dispatch_at<J: crate::ForgeJob>(
1388 &self,
1389 args: J::Args,
1390 scheduled_at: DateTime<Utc>,
1391 ) -> crate::error::Result<Uuid> {
1392 self.dispatch_job_at(J::info().name, args, scheduled_at)
1393 .await
1394 }
1395
1396 pub async fn dispatch_after<J: crate::ForgeJob>(
1398 &self,
1399 args: J::Args,
1400 delay: Duration,
1401 ) -> crate::error::Result<Uuid> {
1402 self.dispatch_job_after(J::info().name, args, delay).await
1403 }
1404
1405 pub async fn cancel_job(
1407 &self,
1408 job_id: Uuid,
1409 reason: Option<String>,
1410 ) -> crate::error::Result<bool> {
1411 let dispatcher = self
1412 .job_dispatch
1413 .as_ref()
1414 .ok_or_else(|| crate::error::ForgeError::internal("Job dispatch not available"))?;
1415 dispatcher.cancel(job_id, reason).await
1416 }
1417
1418 pub async fn start_workflow<T: serde::Serialize>(
1425 &self,
1426 workflow_name: &str,
1427 input: T,
1428 ) -> crate::error::Result<Uuid> {
1429 let input_json = serde_json::to_value(input)?;
1430 let dispatcher = self
1431 .workflow_dispatch
1432 .as_ref()
1433 .ok_or_else(|| crate::error::ForgeError::internal("Workflow dispatch not available"))?;
1434
1435 let trace_id = Some(self.request.trace_id().to_string());
1436
1437 if let Some(tx) = &self.tx {
1438 let mut guard = tx.lock().await;
1439 let conn = guard.as_mut().ok_or_else(|| {
1440 crate::error::ForgeError::internal(
1441 "Transaction already taken; cannot start workflow",
1442 )
1443 })?;
1444 return dispatcher
1445 .start_in_conn(
1446 conn,
1447 workflow_name,
1448 input_json,
1449 self.auth.principal_id(),
1450 trace_id,
1451 )
1452 .await;
1453 }
1454
1455 dispatcher
1456 .start_by_name(
1457 workflow_name,
1458 input_json,
1459 self.auth.principal_id(),
1460 trace_id,
1461 )
1462 .await
1463 }
1464
1465 pub async fn start<W: crate::ForgeWorkflow>(
1468 &self,
1469 input: W::Input,
1470 ) -> crate::error::Result<Uuid> {
1471 self.start_workflow(W::info().name, input).await
1472 }
1473}
1474
1475impl EnvAccess for MutationContext {
1476 fn env_provider(&self) -> &dyn EnvProvider {
1477 self.env_provider.as_ref()
1478 }
1479}
1480
1481#[cfg(test)]
1482#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
1483mod tests {
1484 use super::*;
1485
    // An anonymous context has no user id, and requiring one fails.
    #[test]
    fn test_auth_context_unauthenticated() {
        let ctx = AuthContext::unauthenticated();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id().is_none());
        assert!(ctx.require_user_id().is_err());
    }
1493
    // A UUID-authenticated context exposes the id and matches roles by exact name.
    #[test]
    fn test_auth_context_authenticated() {
        let user_id = Uuid::new_v4();
        let ctx = AuthContext::authenticated(
            user_id,
            vec!["admin".to_string(), "user".to_string()],
            HashMap::new(),
        );

        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id(), Some(user_id));
        assert!(ctx.require_user_id().is_ok());
        assert!(ctx.has_role("admin"));
        assert!(ctx.has_role("user"));
        // Role names are compared exactly, not by prefix.
        assert!(!ctx.has_role("superadmin"));
        assert!(ctx.require_role("admin").is_ok());
        assert!(ctx.require_role("superadmin").is_err());
    }
1512
    // Arbitrary claims are retrievable by name; unknown names yield `None`.
    #[test]
    fn test_auth_context_with_claims() {
        let mut claims = HashMap::new();
        claims.insert("org_id".to_string(), serde_json::json!("org-123"));

        let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);

        assert_eq!(ctx.claim("org_id"), Some(&serde_json::json!("org-123")));
        assert!(ctx.claim("nonexistent").is_none());
    }
1523
1524 #[test]
1525 fn test_request_metadata() {
1526 let meta = RequestMetadata::new();
1527 assert!(!meta.trace_id.is_empty());
1528 assert!(meta.client_ip.is_none());
1529
1530 let meta2 = RequestMetadata::with_trace_id("trace-123".to_string());
1531 assert_eq!(meta2.trace_id, "trace-123");
1532 }
1533
1534 #[test]
1535 fn auth_context_without_uuid_carries_claims_but_no_user_id() {
1536 let mut claims = HashMap::new();
1537 claims.insert("sub".to_string(), serde_json::json!("user@example.com"));
1538 let ctx = AuthContext::authenticated_without_uuid(vec!["user".to_string()], claims);
1539
1540 assert!(ctx.is_authenticated());
1541 assert!(ctx.user_id().is_none());
1542 assert!(ctx.require_user_id().is_err());
1543 assert_eq!(ctx.subject(), Some("user@example.com"));
1544 assert!(ctx.has_role("user"));
1545 }
1546
1547 #[test]
1548 fn require_subject_errors_when_unauthenticated() {
1549 let ctx = AuthContext::unauthenticated();
1550 let err = ctx.require_subject().unwrap_err();
1551 assert!(matches!(err, crate::error::ForgeError::Unauthorized(_)));
1552 }
1553
1554 #[test]
1555 fn require_subject_errors_when_authenticated_without_sub_claim() {
1556 let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], HashMap::new());
1560 let err = ctx.require_subject().unwrap_err();
1561 assert!(matches!(err, crate::error::ForgeError::Unauthorized(_)));
1562 }
1563
1564 #[test]
1565 fn require_subject_returns_sub_claim_when_present() {
1566 let mut claims = HashMap::new();
1567 claims.insert("sub".to_string(), serde_json::json!("abc"));
1568 let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);
1569 assert_eq!(ctx.require_subject().unwrap(), "abc");
1570 }
1571
1572 #[test]
1573 fn principal_id_prefers_sub_claim_over_uuid() {
1574 let user_id = Uuid::new_v4();
1575 let mut claims = HashMap::new();
1576 claims.insert("sub".to_string(), serde_json::json!("string-sub"));
1577 let ctx = AuthContext::authenticated(user_id, vec![], claims);
1578 assert_eq!(ctx.principal_id(), Some("string-sub".to_string()));
1579 }
1580
1581 #[test]
1582 fn principal_id_falls_back_to_uuid_when_no_sub_claim() {
1583 let user_id = Uuid::new_v4();
1584 let ctx = AuthContext::authenticated(user_id, vec![], HashMap::new());
1585 assert_eq!(ctx.principal_id(), Some(user_id.to_string()));
1586 }
1587
1588 #[test]
1589 fn principal_id_is_none_for_unauthenticated_with_no_sub() {
1590 let ctx = AuthContext::unauthenticated();
1591 assert_eq!(ctx.principal_id(), None);
1592 }
1593
1594 #[test]
1595 fn is_admin_only_true_when_admin_role_present() {
1596 let plain = AuthContext::authenticated(Uuid::new_v4(), vec!["user".into()], HashMap::new());
1597 assert!(!plain.is_admin());
1598 let admin =
1599 AuthContext::authenticated(Uuid::new_v4(), vec!["admin".into()], HashMap::new());
1600 assert!(admin.is_admin());
1601 assert!(!AuthContext::unauthenticated().is_admin());
1602 }
1603
1604 #[test]
1605 fn tenant_id_parses_valid_uuid_claim() {
1606 let tenant = Uuid::new_v4();
1607 let mut claims = HashMap::new();
1608 claims.insert(
1609 "tenant_id".to_string(),
1610 serde_json::json!(tenant.to_string()),
1611 );
1612 let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);
1613 assert_eq!(ctx.tenant_id(), Some(tenant));
1614 }
1615
1616 #[test]
1617 fn tenant_id_returns_none_for_missing_or_invalid_claim() {
1618 let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], HashMap::new());
1619 assert!(ctx.tenant_id().is_none());
1620
1621 let mut claims = HashMap::new();
1622 claims.insert("tenant_id".to_string(), serde_json::json!(123));
1623 let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);
1624 assert!(ctx.tenant_id().is_none());
1625
1626 let mut claims = HashMap::new();
1627 claims.insert("tenant_id".to_string(), serde_json::json!("not-a-uuid"));
1628 let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], claims);
1629 assert!(ctx.tenant_id().is_none());
1630 }
1631
1632 #[test]
1633 fn token_exp_round_trips_and_drives_expiry_check() {
1634 let ctx = AuthContext::authenticated(Uuid::new_v4(), vec![], HashMap::new());
1635 assert!(!ctx.token_is_expired());
1636 assert!(ctx.token_exp().is_none());
1637
1638 let expired = ctx.clone().with_token_exp(1);
1639 assert_eq!(expired.token_exp(), Some(1));
1640 assert!(expired.token_is_expired());
1641
1642 let live = ctx.with_token_exp(chrono::Utc::now().timestamp() + 3600);
1643 assert!(!live.token_is_expired());
1644 }
1645
1646 #[test]
1647 fn token_is_expired_false_for_unauthenticated_without_exp() {
1648 let ctx = AuthContext::unauthenticated();
1649 assert!(!ctx.token_is_expired());
1650 }
1651
1652 #[test]
1653 fn claims_and_roles_accessors_return_stored_values() {
1654 let mut claims = HashMap::new();
1655 claims.insert("k".to_string(), serde_json::json!("v"));
1656 let ctx = AuthContext::authenticated(
1657 Uuid::new_v4(),
1658 vec!["a".into(), "b".into()],
1659 claims.clone(),
1660 );
1661
1662 assert_eq!(ctx.claims(), &claims);
1663 assert_eq!(ctx.roles(), &["a".to_string(), "b".to_string()]);
1664 }
1665
1666 #[test]
1667 fn request_metadata_setters_mutate_fields() {
1668 let mut meta = RequestMetadata::new();
1669 meta.set_client_ip(Some("1.2.3.4".to_string()));
1670 meta.set_user_agent(Some("ua/1".to_string()));
1671 meta.set_correlation_id(Some("corr-1".to_string()));
1672
1673 assert_eq!(meta.client_ip(), Some("1.2.3.4"));
1674 assert_eq!(meta.user_agent(), Some("ua/1"));
1675 assert_eq!(meta.correlation_id(), Some("corr-1"));
1676
1677 meta.set_client_ip(None);
1678 assert!(meta.client_ip().is_none());
1679 }
1680
1681 #[test]
1682 fn request_metadata_internal_constructor_carries_fields() {
1683 let rid = Uuid::new_v4();
1684 let meta = RequestMetadata::__build_internal(
1685 rid,
1686 "t-1".into(),
1687 Some("ip".into()),
1688 Some("ua".into()),
1689 Some("corr".into()),
1690 );
1691 assert_eq!(meta.request_id(), rid);
1692 assert_eq!(meta.trace_id(), "t-1");
1693 assert_eq!(meta.client_ip(), Some("ip"));
1694 assert_eq!(meta.user_agent(), Some("ua"));
1695 assert_eq!(meta.correlation_id(), Some("corr"));
1696 }
1697
1698 #[test]
1699 fn request_metadata_default_matches_new() {
1700 let a = RequestMetadata::default();
1701 let b = RequestMetadata::new();
1702 assert!(a.client_ip().is_none());
1703 assert!(b.user_agent().is_none());
1704 }
1705
1706 #[test]
1707 fn auth_token_ttl_default_is_one_hour_and_thirty_days() {
1708 let ttl = AuthTokenTtl::default();
1709 assert_eq!(ttl.access_token_secs, 3600);
1710 assert_eq!(ttl.refresh_token_days, 30);
1711
1712 let custom = AuthTokenTtl::new(60, 7);
1713 assert_eq!(custom.access_token_secs, 60);
1714 assert_eq!(custom.refresh_token_days, 7);
1715 }
1716
1717 #[test]
1718 fn sql_operation_classifies_common_prefixes() {
1719 assert_eq!(sql_operation("SELECT 1"), "SELECT");
1720 assert_eq!(sql_operation(" select * from users"), "SELECT");
1721 assert_eq!(sql_operation("Insert into x values (1)"), "INSERT");
1722 assert_eq!(sql_operation("UPDATE x SET v = 1"), "UPDATE");
1723 assert_eq!(sql_operation("delete from x"), "DELETE");
1724 }
1725
1726 #[test]
1727 fn sql_operation_falls_back_to_other_for_unknown_or_short() {
1728 assert_eq!(
1729 sql_operation("WITH cte AS (SELECT 1) SELECT * FROM cte"),
1730 "OTHER"
1731 );
1732 assert_eq!(sql_operation("BEGIN"), "OTHER");
1733 assert_eq!(sql_operation(""), "OTHER");
1734 assert_eq!(sql_operation("hi"), "OTHER");
1735 }
1736}