use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use chrono::Utc;
use forge_core::{
AuthContext, CircuitBreakerClient, ForgeError, FunctionInfo, FunctionKind, JobDispatch,
MutationContext, OutboxBuffer, PendingJob, PendingWorkflow, QueryContext, RequestMetadata,
Result, WorkflowDispatch,
job::JobStatus,
rate_limit::{RateLimitConfig, RateLimitKey},
workflow::WorkflowStatus,
};
use serde_json::Value;
use tracing::Instrument;
use super::cache::QueryCache;
use super::registry::{BoxedMutationFn, FunctionEntry, FunctionRegistry};
use crate::db::Database;
use crate::rate_limit::HybridRateLimiter;
/// Shared authorization gate for functions, jobs and workflows.
///
/// Public entries always pass; private entries require an authenticated
/// caller, and additionally the given role when one is configured.
fn require_auth(is_public: bool, required_role: Option<&str>, auth: &AuthContext) -> Result<()> {
    match (is_public, auth.is_authenticated()) {
        // Public functions skip every check.
        (true, _) => Ok(()),
        // Private + anonymous caller: rejected outright.
        (false, false) => Err(ForgeError::Unauthorized("Authentication required".into())),
        // Authenticated caller: enforce the role requirement, if any.
        (false, true) => match required_role {
            Some(role) if !auth.has_role(role) => {
                Err(ForgeError::Forbidden(format!("Role '{role}' required")))
            }
            _ => Ok(()),
        },
    }
}
/// Outcome of [`FunctionRouter::route`], tagged by which kind of callable
/// handled the request so the transport layer can shape its response.
pub enum RouteResult {
    /// Result payload of a registered query handler (possibly served from cache).
    Query(Value),
    /// Result payload of a registered mutation handler.
    Mutation(Value),
    /// Acknowledgement for a dispatched background job: `{ "job_id": ... }`.
    Job(Value),
    /// Acknowledgement for a started workflow: `{ "workflow_id": ... }`.
    Workflow(Value),
}
/// Central dispatcher that resolves a function name to a registered query or
/// mutation, or falls back to job/workflow dispatchers, applying auth,
/// rate limiting and query caching along the way.
pub struct FunctionRouter {
    /// Registered query and mutation handlers, looked up by name.
    registry: Arc<FunctionRegistry>,
    /// Database handle exposing a primary (writes/consistent reads) and a read pool.
    db: Database,
    /// Outbound HTTP client wrapped in a circuit breaker, handed to mutation contexts.
    http_client: CircuitBreakerClient,
    /// Optional dispatcher for enqueueing background jobs by name.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional dispatcher for starting workflows by name.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Enforces per-function request/period limits.
    rate_limiter: HybridRateLimiter,
    /// TTL cache for query results, scoped per auth principal/roles/claims.
    query_cache: QueryCache,
    /// Optional token issuer made available to mutation contexts.
    token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
    /// TTL configuration applied to tokens issued from mutation contexts.
    token_ttl: forge_core::AuthTokenTtl,
}
impl FunctionRouter {
pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
let rate_limiter = HybridRateLimiter::new(db.primary().clone());
Self {
registry,
db,
http_client: CircuitBreakerClient::with_defaults(reqwest::Client::new()),
job_dispatcher: None,
workflow_dispatcher: None,
rate_limiter,
query_cache: QueryCache::new(),
token_issuer: None,
token_ttl: forge_core::AuthTokenTtl::default(),
}
}
pub fn with_http_client(
registry: Arc<FunctionRegistry>,
db: Database,
http_client: CircuitBreakerClient,
) -> Self {
let rate_limiter = HybridRateLimiter::new(db.primary().clone());
Self {
registry,
db,
http_client,
job_dispatcher: None,
workflow_dispatcher: None,
rate_limiter,
query_cache: QueryCache::new(),
token_issuer: None,
token_ttl: forge_core::AuthTokenTtl::default(),
}
}
pub fn with_token_issuer(mut self, issuer: Arc<dyn forge_core::TokenIssuer>) -> Self {
self.token_issuer = Some(issuer);
self
}
pub fn with_token_ttl(mut self, ttl: forge_core::AuthTokenTtl) -> Self {
self.token_ttl = ttl;
self
}
pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
self.token_ttl = ttl;
}
pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
self.job_dispatcher = Some(dispatcher);
self
}
pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
self.workflow_dispatcher = Some(dispatcher);
self
}
pub async fn route(
&self,
function_name: &str,
args: Value,
auth: AuthContext,
request: RequestMetadata,
) -> Result<RouteResult> {
if let Some(entry) = self.registry.get(function_name) {
self.check_auth(entry.info(), &auth)?;
if !entry.info().is_public {
self.verify_user_exists(&auth).await?;
}
self.check_rate_limit(entry.info(), function_name, &auth, &request)
.await?;
return match entry {
FunctionEntry::Query { handler, info, .. } => {
let pool = if info.consistent {
self.db.primary().clone()
} else {
self.db.read_pool().clone()
};
let auth_scope = Self::auth_cache_scope(&auth);
if let Some(ttl) = info.cache_ttl {
if let Some(cached) =
self.query_cache
.get(function_name, &args, auth_scope.as_deref())
{
return Ok(RouteResult::Query(Value::clone(&cached)));
}
let ctx = QueryContext::new(pool, auth, request);
let result = handler(&ctx, args.clone()).await?;
self.query_cache.set(
function_name,
&args,
auth_scope.as_deref(),
result.clone(),
Duration::from_secs(ttl),
);
Ok(RouteResult::Query(result))
} else {
let ctx = QueryContext::new(pool, auth, request);
let result = handler(&ctx, args).await?;
Ok(RouteResult::Query(result))
}
}
FunctionEntry::Mutation { handler, info } => {
if info.transactional {
self.execute_transactional(info, handler, args, auth, request)
.await
} else {
let mut ctx = MutationContext::with_dispatch(
self.db.primary().clone(),
auth,
request,
self.http_client.clone(),
self.job_dispatcher.clone(),
self.workflow_dispatcher.clone(),
);
if let Some(ref issuer) = self.token_issuer {
ctx.set_token_issuer(issuer.clone());
}
ctx.set_token_ttl(self.token_ttl.clone());
ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
let result = handler(&ctx, args).await?;
Ok(RouteResult::Mutation(result))
}
}
};
}
if let Some(ref job_dispatcher) = self.job_dispatcher
&& let Some(job_info) = job_dispatcher.get_info(function_name)
{
self.check_job_auth(&job_info, &auth)?;
match job_dispatcher
.dispatch_by_name(function_name, args.clone(), auth.principal_id())
.await
{
Ok(job_id) => {
return Ok(RouteResult::Job(serde_json::json!({ "job_id": job_id })));
}
Err(ForgeError::NotFound(_)) => {}
Err(e) => return Err(e),
}
}
if let Some(ref workflow_dispatcher) = self.workflow_dispatcher
&& let Some(workflow_info) = workflow_dispatcher.get_info(function_name)
{
self.check_workflow_auth(&workflow_info, &auth)?;
match workflow_dispatcher
.start_by_name(function_name, args.clone(), auth.principal_id())
.await
{
Ok(workflow_id) => {
return Ok(RouteResult::Workflow(
serde_json::json!({ "workflow_id": workflow_id }),
));
}
Err(ForgeError::NotFound(_)) => {}
Err(e) => return Err(e),
}
}
Err(ForgeError::NotFound(format!(
"Function '{}' not found",
function_name
)))
}
fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
require_auth(info.is_public, info.required_role, auth)
}
async fn verify_user_exists(&self, auth: &AuthContext) -> Result<()> {
let user_id = match auth.user_id() {
Some(id) => id,
None => return Ok(()),
};
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)")
.bind(user_id)
.fetch_one(self.db.read_pool())
.await
.unwrap_or(false);
if !exists {
return Err(ForgeError::Unauthorized("User no longer exists".into()));
}
Ok(())
}
fn check_job_auth(&self, info: &forge_core::job::JobInfo, auth: &AuthContext) -> Result<()> {
require_auth(info.is_public, info.required_role, auth)
}
fn check_workflow_auth(
&self,
info: &forge_core::workflow::WorkflowInfo,
auth: &AuthContext,
) -> Result<()> {
require_auth(info.is_public, info.required_role, auth)
}
async fn check_rate_limit(
&self,
info: &FunctionInfo,
function_name: &str,
auth: &AuthContext,
request: &RequestMetadata,
) -> Result<()> {
let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
(Some(r), Some(p)) => (r, p),
_ => return Ok(()),
};
let key_str = info.rate_limit_key.unwrap_or("user");
let key_type: RateLimitKey = match key_str.parse() {
Ok(k) => k,
Err(_) => {
tracing::error!(
function = %function_name,
key = %key_str,
"Invalid rate limit key, falling back to 'user'"
);
RateLimitKey::default()
}
};
let config =
RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
let bucket_key = self
.rate_limiter
.build_key(key_type, function_name, auth, request);
self.rate_limiter.enforce(&bucket_key, &config).await?;
Ok(())
}
fn auth_cache_scope(auth: &AuthContext) -> Option<String> {
if !auth.is_authenticated() {
return Some("anon".to_string());
}
let mut roles = auth.roles().to_vec();
roles.sort();
roles.dedup();
let mut claims = BTreeMap::new();
for (k, v) in auth.claims() {
claims.insert(k.clone(), v.clone());
}
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
roles.hash(&mut hasher);
serde_json::to_string(&claims)
.unwrap_or_default()
.hash(&mut hasher);
let principal = auth
.principal_id()
.unwrap_or_else(|| "authenticated".to_string());
Some(format!(
"subject:{principal}:scope:{:016x}",
hasher.finish()
))
}
pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
self.registry.get(function_name).map(|e| e.kind())
}
pub fn has_function(&self, function_name: &str) -> bool {
self.registry.get(function_name).is_some()
}
async fn execute_transactional(
&self,
info: &FunctionInfo,
handler: &BoxedMutationFn,
args: Value,
auth: AuthContext,
request: RequestMetadata,
) -> Result<RouteResult> {
let span = tracing::info_span!("db.transaction", db.system = "postgresql",);
async {
let primary = self.db.primary();
let tx = primary
.begin()
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
let job_dispatcher = self.job_dispatcher.clone();
let job_lookup: forge_core::JobInfoLookup =
Arc::new(move |name: &str| job_dispatcher.as_ref().and_then(|d| d.get_info(name)));
let (mut ctx, tx_handle, outbox) = MutationContext::with_transaction(
primary.clone(),
tx,
auth,
request,
self.http_client.clone(),
job_lookup,
);
if let Some(ref issuer) = self.token_issuer {
ctx.set_token_issuer(issuer.clone());
}
ctx.set_token_ttl(self.token_ttl.clone());
ctx.set_http_timeout(info.http_timeout.map(Duration::from_secs));
match handler(&ctx, args).await {
Ok(value) => {
drop(ctx);
let buffer = {
let guard = outbox.lock().unwrap_or_else(|poisoned| {
tracing::error!("Outbox mutex was poisoned, recovering");
poisoned.into_inner()
});
OutboxBuffer {
jobs: guard.jobs.clone(),
workflows: guard.workflows.clone(),
}
};
let mut tx = Arc::try_unwrap(tx_handle)
.map_err(|_| ForgeError::Internal("Transaction still in use".into()))?
.into_inner();
for job in &buffer.jobs {
Self::insert_job(&mut tx, job).await?;
}
for workflow in &buffer.workflows {
if self
.workflow_dispatcher
.as_ref()
.and_then(|d| d.get_info(&workflow.workflow_name))
.is_none()
{
return Err(ForgeError::NotFound(format!(
"Workflow '{}' not found",
workflow.workflow_name
)));
}
Self::insert_workflow(&mut tx, workflow).await?;
}
tx.commit()
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
Ok(RouteResult::Mutation(value))
}
Err(e) => {
drop(ctx);
Err(e)
}
}
}
.instrument(span)
.await
}
async fn insert_job(
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
job: &PendingJob,
) -> Result<()> {
let now = Utc::now();
sqlx::query!(
r#"
INSERT INTO forge_jobs (
id, job_type, input, job_context, status, priority, attempts, max_attempts,
worker_capability, owner_subject, scheduled_at, created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
"#,
job.id,
&job.job_type,
job.args as _,
job.context as _,
JobStatus::Pending.as_str(),
job.priority,
0i32,
job.max_attempts,
job.worker_capability.as_deref(),
job.owner_subject as _,
now,
now,
)
.execute(&mut **tx)
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
Ok(())
}
async fn insert_workflow(
tx: &mut sqlx::Transaction<'_, sqlx::Postgres>,
workflow: &PendingWorkflow,
) -> Result<()> {
let now = Utc::now();
sqlx::query!(
r#"
INSERT INTO forge_workflow_runs (
id, workflow_name, workflow_version, workflow_signature,
owner_subject, input, status, current_step,
step_results, started_at, trace_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
"#,
workflow.id,
&workflow.workflow_name,
&workflow.workflow_version,
&workflow.workflow_signature,
workflow.owner_subject as _,
workflow.input as _,
WorkflowStatus::Created.as_str(),
Option::<String>::None,
serde_json::json!({}) as _,
now,
workflow.id.to_string(),
)
.execute(&mut **tx)
.await
.map_err(|e| ForgeError::Database(e.to_string()))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// The original test built an auth context it never used and only
    /// asserted the literal it had just written (`info.is_public`). This
    /// version actually exercises `require_auth` on both paths.
    #[test]
    fn test_check_auth_public() {
        let info = FunctionInfo {
            name: "test",
            description: None,
            kind: FunctionKind::Query,
            required_role: None,
            is_public: true,
            cache_ttl: None,
            timeout: None,
            http_timeout: None,
            rate_limit_requests: None,
            rate_limit_per_secs: None,
            rate_limit_key: None,
            log_level: None,
            table_dependencies: &[],
            selected_columns: &[],
            transactional: false,
            consistent: false,
            max_upload_size_bytes: None,
        };
        let auth = AuthContext::unauthenticated();
        // Public function: an anonymous caller is allowed through.
        assert!(require_auth(info.is_public, info.required_role, &auth).is_ok());
        // Private function: the same anonymous caller must be rejected.
        assert!(require_auth(false, None, &auth).is_err());
    }

    /// Two identical principals with differing claims (tenant here) must not
    /// share a query-cache scope, otherwise cached results could leak across
    /// tenants.
    #[test]
    fn test_auth_cache_scope_changes_with_claims() {
        let user_id = uuid::Uuid::new_v4();
        let auth_a = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-a".to_string()),
                ),
            ]),
        );
        let auth_b = AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([
                (
                    "sub".to_string(),
                    serde_json::Value::String(user_id.to_string()),
                ),
                (
                    "tenant_id".to_string(),
                    serde_json::Value::String("tenant-b".to_string()),
                ),
            ]),
        );
        let scope_a = FunctionRouter::auth_cache_scope(&auth_a);
        let scope_b = FunctionRouter::auth_cache_scope(&auth_b);
        assert_ne!(scope_a, scope_b);
    }
}