1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{
5 AuthContext, ForgeError, FunctionInfo, FunctionKind, JobDispatch, MutationContext,
6 QueryContext, RequestMetadata, Result, WorkflowDispatch,
7 rate_limit::{RateLimitConfig, RateLimitKey},
8};
9use serde_json::Value;
10
11use super::cache::QueryCache;
12use super::registry::{FunctionEntry, FunctionRegistry};
13use crate::rate_limit::RateLimiter;
14
15pub enum RouteResult {
17 Query(Value),
19 Mutation(Value),
21}
22
23pub struct FunctionRouter {
25 registry: Arc<FunctionRegistry>,
26 db_pool: sqlx::PgPool,
27 http_client: reqwest::Client,
28 job_dispatcher: Option<Arc<dyn JobDispatch>>,
29 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
30 rate_limiter: RateLimiter,
31 query_cache: QueryCache,
32}
33
34impl FunctionRouter {
35 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
37 let rate_limiter = RateLimiter::new(db_pool.clone());
38 Self {
39 registry,
40 db_pool,
41 http_client: reqwest::Client::new(),
42 job_dispatcher: None,
43 workflow_dispatcher: None,
44 rate_limiter,
45 query_cache: QueryCache::new(),
46 }
47 }
48
49 pub fn with_http_client(
51 registry: Arc<FunctionRegistry>,
52 db_pool: sqlx::PgPool,
53 http_client: reqwest::Client,
54 ) -> Self {
55 let rate_limiter = RateLimiter::new(db_pool.clone());
56 Self {
57 registry,
58 db_pool,
59 http_client,
60 job_dispatcher: None,
61 workflow_dispatcher: None,
62 rate_limiter,
63 query_cache: QueryCache::new(),
64 }
65 }
66
67 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
69 self.job_dispatcher = Some(dispatcher);
70 self
71 }
72
73 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
75 self.workflow_dispatcher = Some(dispatcher);
76 self
77 }
78
79 pub async fn route(
81 &self,
82 function_name: &str,
83 args: Value,
84 auth: AuthContext,
85 request: RequestMetadata,
86 ) -> Result<RouteResult> {
87 let entry = self.registry.get(function_name).ok_or_else(|| {
88 ForgeError::NotFound(format!("Function '{}' not found", function_name))
89 })?;
90
91 self.check_auth(entry.info(), &auth)?;
93
94 self.check_rate_limit(entry.info(), function_name, &auth, &request)
96 .await?;
97
98 match entry {
99 FunctionEntry::Query { handler, info, .. } => {
100 if let Some(ttl) = info.cache_ttl {
102 if let Some(cached) = self.query_cache.get(function_name, &args) {
103 return Ok(RouteResult::Query(cached));
104 }
105
106 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
108 let result = handler(&ctx, args.clone()).await?;
109
110 self.query_cache.set(
111 function_name,
112 &args,
113 result.clone(),
114 Duration::from_secs(ttl),
115 );
116
117 Ok(RouteResult::Query(result))
118 } else {
119 let ctx = QueryContext::new(self.db_pool.clone(), auth, request);
120 let result = handler(&ctx, args).await?;
121 Ok(RouteResult::Query(result))
122 }
123 }
124 FunctionEntry::Mutation { handler, .. } => {
125 let ctx = MutationContext::with_dispatch(
126 self.db_pool.clone(),
127 auth,
128 request,
129 self.http_client.clone(),
130 self.job_dispatcher.clone(),
131 self.workflow_dispatcher.clone(),
132 );
133 let result = handler(&ctx, args).await?;
134 Ok(RouteResult::Mutation(result))
135 }
136 }
137 }
138
139 fn check_auth(&self, info: &FunctionInfo, auth: &AuthContext) -> Result<()> {
141 if info.is_public {
143 return Ok(());
144 }
145
146 if info.requires_auth && !auth.is_authenticated() {
148 return Err(ForgeError::Unauthorized("Authentication required".into()));
149 }
150
151 if let Some(role) = info.required_role {
153 if !auth.is_authenticated() {
155 return Err(ForgeError::Unauthorized("Authentication required".into()));
156 }
157 if !auth.has_role(role) {
158 return Err(ForgeError::Forbidden(format!("Role '{}' required", role)));
159 }
160 }
161
162 Ok(())
163 }
164
165 async fn check_rate_limit(
167 &self,
168 info: &FunctionInfo,
169 function_name: &str,
170 auth: &AuthContext,
171 request: &RequestMetadata,
172 ) -> Result<()> {
173 let (requests, per_secs) = match (info.rate_limit_requests, info.rate_limit_per_secs) {
175 (Some(r), Some(p)) => (r, p),
176 _ => return Ok(()),
177 };
178
179 let key_type: RateLimitKey = info
181 .rate_limit_key
182 .unwrap_or("user")
183 .parse()
184 .unwrap_or_default();
185
186 let config =
187 RateLimitConfig::new(requests, Duration::from_secs(per_secs)).with_key(key_type);
188
189 let bucket_key = self
191 .rate_limiter
192 .build_key(key_type, function_name, auth, request);
193
194 self.rate_limiter.enforce(&bucket_key, &config).await?;
196
197 Ok(())
198 }
199
200 pub fn get_function_kind(&self, function_name: &str) -> Option<FunctionKind> {
202 self.registry.get(function_name).map(|e| e.kind())
203 }
204
205 pub fn has_function(&self, function_name: &str) -> bool {
207 self.registry.get(function_name).is_some()
208 }
209}
210
211#[cfg(test)]
212mod tests {
213 use super::*;
214
215 #[test]
216 fn test_check_auth_public() {
217 let info = FunctionInfo {
218 name: "test",
219 description: None,
220 kind: FunctionKind::Query,
221 requires_auth: false,
222 required_role: None,
223 is_public: true,
224 cache_ttl: None,
225 timeout: None,
226 rate_limit_requests: None,
227 rate_limit_per_secs: None,
228 rate_limit_key: None,
229 log_level: None,
230 table_dependencies: &[],
231 transactional: false,
232 };
233
234 let _auth = AuthContext::unauthenticated();
235
236 assert!(info.is_public);
239 }
240}