1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use futures_util::StreamExt;
5use futures_util::stream::FuturesUnordered;
6use tokio::sync::{RwLock, Semaphore, broadcast, mpsc};
7use uuid::Uuid;
8
9use forge_core::cluster::NodeId;
10use forge_core::realtime::{Change, ReadSet, SessionId, SubscriptionId};
11
12use super::invalidation::{InvalidationConfig, InvalidationEngine};
13use super::listener::{ChangeListener, ListenerConfig};
14use super::manager::SubscriptionManager;
15use super::message::{
16 JobData, RealtimeConfig, RealtimeMessage, SessionServer, WorkflowData, WorkflowStepData,
17};
18use crate::function::{FunctionEntry, FunctionRegistry};
19
#[derive(Debug, Clone)]
pub struct ReactorConfig {
    /// Configuration for the database change listener.
    pub listener: ListenerConfig,
    /// Configuration for the invalidation engine.
    pub invalidation: InvalidationConfig,
    /// Realtime/session-server configuration.
    pub realtime: RealtimeConfig,
    /// Maximum number of listener restarts after failures before real-time
    /// updates are disabled permanently.
    pub max_listener_restarts: u32,
    /// Base restart delay in milliseconds; doubled on each retry attempt.
    pub listener_restart_delay_ms: u64,
    /// Upper bound on query groups re-executed concurrently during a flush.
    pub max_concurrent_reexecutions: usize,
    /// Seconds between stale-session cleanup passes.
    pub session_cleanup_interval_secs: u64,
}
33
34impl Default for ReactorConfig {
35 fn default() -> Self {
36 Self {
37 listener: ListenerConfig::default(),
38 invalidation: InvalidationConfig::default(),
39 realtime: RealtimeConfig::default(),
40 max_listener_restarts: 5,
41 listener_restart_delay_ms: 1000,
42 max_concurrent_reexecutions: 64,
43 session_cleanup_interval_secs: 60,
44 }
45 }
46}
47
/// A per-session subscription to progress updates for one background job.
#[derive(Debug, Clone)]
pub struct JobSubscription {
    #[allow(dead_code)]
    pub subscription_id: SubscriptionId,
    // Session that receives `JobUpdate` messages.
    pub session_id: SessionId,
    // Client-chosen identifier echoed back in update messages.
    pub client_sub_id: String,
    #[allow(dead_code)]
    pub job_id: Uuid,
    // Auth context re-checked against the job owner on every update.
    pub auth_context: forge_core::function::AuthContext,
}
59
/// A per-session subscription to progress updates for one workflow run.
#[derive(Debug, Clone)]
pub struct WorkflowSubscription {
    #[allow(dead_code)]
    pub subscription_id: SubscriptionId,
    // Session that receives `WorkflowUpdate` messages.
    pub session_id: SessionId,
    // Client-chosen identifier echoed back in update messages.
    pub client_sub_id: String,
    #[allow(dead_code)]
    pub workflow_id: Uuid,
    // Auth context re-checked against the workflow owner on every update.
    pub auth_context: forge_core::function::AuthContext,
}
71
/// Coordinates real-time subscriptions: listens for database changes,
/// invalidates and re-executes affected query groups, and pushes fresh
/// results (plus job/workflow progress) to connected sessions.
pub struct Reactor {
    node_id: NodeId,
    db_pool: sqlx::PgPool,
    registry: FunctionRegistry,
    subscription_manager: Arc<SubscriptionManager>,
    session_server: Arc<SessionServer>,
    change_listener: Arc<ChangeListener>,
    invalidation_engine: Arc<InvalidationEngine>,
    // Job id -> sessions subscribed to that job's progress.
    job_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
    // Workflow run id -> sessions subscribed to that run's progress.
    workflow_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
    // Signals the background loop spawned by `start` to exit.
    shutdown_tx: broadcast::Sender<()>,
    max_listener_restarts: u32,
    listener_restart_delay_ms: u64,
    max_concurrent_reexecutions: usize,
    session_cleanup_interval_secs: u64,
}
91
92impl Reactor {
93 pub fn new(
95 node_id: NodeId,
96 db_pool: sqlx::PgPool,
97 registry: FunctionRegistry,
98 config: ReactorConfig,
99 ) -> Self {
100 let subscription_manager = Arc::new(SubscriptionManager::new(
101 config.realtime.max_subscriptions_per_session,
102 ));
103 let session_server = Arc::new(SessionServer::new(node_id, config.realtime.clone()));
104 let change_listener = Arc::new(ChangeListener::new(db_pool.clone(), config.listener));
105 let invalidation_engine = Arc::new(InvalidationEngine::new(
106 subscription_manager.clone(),
107 config.invalidation,
108 ));
109 let (shutdown_tx, _) = broadcast::channel(1);
110
111 Self {
112 node_id,
113 db_pool,
114 registry,
115 subscription_manager,
116 session_server,
117 change_listener,
118 invalidation_engine,
119 job_subscriptions: Arc::new(RwLock::new(HashMap::new())),
120 workflow_subscriptions: Arc::new(RwLock::new(HashMap::new())),
121 shutdown_tx,
122 max_listener_restarts: config.max_listener_restarts,
123 listener_restart_delay_ms: config.listener_restart_delay_ms,
124 max_concurrent_reexecutions: config.max_concurrent_reexecutions,
125 session_cleanup_interval_secs: config.session_cleanup_interval_secs,
126 }
127 }
128
    /// Identifier of the cluster node this reactor runs on.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Shared handle to the session server (connection registry / fan-out).
    pub fn session_server(&self) -> Arc<SessionServer> {
        self.session_server.clone()
    }

    /// Shared handle to the subscription manager.
    pub fn subscription_manager(&self) -> Arc<SubscriptionManager> {
        self.subscription_manager.clone()
    }

    /// New receiver that fires when [`Reactor::stop`] is called.
    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }

    /// Register an outbound message channel for `session_id`.
    pub fn register_session(&self, session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) {
        self.session_server.register_connection(session_id, sender);
        tracing::trace!(?session_id, "Session registered");
    }
150
151 pub async fn remove_session(&self, session_id: SessionId) {
153 self.subscription_manager
155 .remove_session_subscriptions(session_id);
156
157 self.session_server.remove_connection(session_id);
159
160 {
162 let mut job_subs = self.job_subscriptions.write().await;
163 for subscribers in job_subs.values_mut() {
164 subscribers.retain(|s| s.session_id != session_id);
165 }
166 job_subs.retain(|_, v| !v.is_empty());
167 }
168
169 {
171 let mut workflow_subs = self.workflow_subscriptions.write().await;
172 for subscribers in workflow_subs.values_mut() {
173 subscribers.retain(|s| s.session_id != session_id);
174 }
175 workflow_subs.retain(|_, v| !v.is_empty());
176 }
177
178 tracing::trace!(?session_id, "Session removed");
179 }
180
181 pub async fn subscribe(
183 &self,
184 session_id: SessionId,
185 client_sub_id: String,
186 query_name: String,
187 args: serde_json::Value,
188 auth_context: forge_core::function::AuthContext,
189 ) -> forge_core::Result<(SubscriptionId, serde_json::Value)> {
190 let (table_deps, selected_cols) = match self.registry.get(&query_name) {
192 Some(FunctionEntry::Query { info, .. }) => {
193 (info.table_dependencies, info.selected_columns)
194 }
195 _ => (&[] as &[&str], &[] as &[&str]),
196 };
197
198 let (group_id, subscription_id, is_new_group) = self.subscription_manager.subscribe(
199 session_id,
200 client_sub_id,
201 &query_name,
202 &args,
203 &auth_context,
204 table_deps,
205 selected_cols,
206 )?;
207
208 if let Err(error) = self
210 .session_server
211 .add_subscription(session_id, subscription_id)
212 {
213 self.subscription_manager.unsubscribe(subscription_id);
214 return Err(error);
215 }
216
217 let data = if is_new_group {
219 let (data, read_set) = match self.execute_query(&query_name, &args, &auth_context).await
220 {
221 Ok(result) => result,
222 Err(error) => {
223 self.unsubscribe(subscription_id);
224 return Err(error);
225 }
226 };
227
228 let result_hash = Self::compute_hash(&data);
229
230 tracing::trace!(
231 ?group_id,
232 query = %query_name,
233 "New query group created"
234 );
235
236 self.subscription_manager
237 .update_group(group_id, read_set, result_hash);
238
239 data
240 } else {
241 let (data, _) = match self.execute_query(&query_name, &args, &auth_context).await {
244 Ok(result) => result,
245 Err(error) => {
246 self.unsubscribe(subscription_id);
247 return Err(error);
248 }
249 };
250 data
251 };
252
253 tracing::trace!(?subscription_id, "Subscription created");
254 Ok((subscription_id, data))
255 }
256
    /// Remove a query subscription from both the session server and the
    /// subscription manager.
    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
        self.session_server.remove_subscription(subscription_id);
        self.subscription_manager.unsubscribe(subscription_id);
        tracing::trace!(?subscription_id, "Subscription removed");
    }
263
264 pub async fn subscribe_job(
266 &self,
267 session_id: SessionId,
268 client_sub_id: String,
269 job_id: Uuid,
270 auth_context: &forge_core::function::AuthContext,
271 ) -> forge_core::Result<JobData> {
272 let subscription_id = SubscriptionId::new();
273 Self::ensure_job_access(&self.db_pool, job_id, auth_context).await?;
274 let job_data = self.fetch_job_data(job_id).await?;
275
276 let subscription = JobSubscription {
277 subscription_id,
278 session_id,
279 client_sub_id,
280 job_id,
281 auth_context: auth_context.clone(),
282 };
283
284 let mut subs = self.job_subscriptions.write().await;
285 subs.entry(job_id).or_default().push(subscription);
286
287 tracing::trace!(?subscription_id, %job_id, "Job subscription created");
288 Ok(job_data)
289 }
290
291 pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) {
293 let mut subs = self.job_subscriptions.write().await;
294 for subscribers in subs.values_mut() {
295 subscribers
296 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
297 }
298 subs.retain(|_, v| !v.is_empty());
299 }
300
301 pub async fn subscribe_workflow(
303 &self,
304 session_id: SessionId,
305 client_sub_id: String,
306 workflow_id: Uuid,
307 auth_context: &forge_core::function::AuthContext,
308 ) -> forge_core::Result<WorkflowData> {
309 let subscription_id = SubscriptionId::new();
310 Self::ensure_workflow_access(&self.db_pool, workflow_id, auth_context).await?;
311 let workflow_data = self.fetch_workflow_data(workflow_id).await?;
312
313 let subscription = WorkflowSubscription {
314 subscription_id,
315 session_id,
316 client_sub_id,
317 workflow_id,
318 auth_context: auth_context.clone(),
319 };
320
321 let mut subs = self.workflow_subscriptions.write().await;
322 subs.entry(workflow_id).or_default().push(subscription);
323
324 tracing::trace!(?subscription_id, %workflow_id, "Workflow subscription created");
325 Ok(workflow_data)
326 }
327
328 pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) {
330 let mut subs = self.workflow_subscriptions.write().await;
331 for subscribers in subs.values_mut() {
332 subscribers
333 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
334 }
335 subs.retain(|_, v| !v.is_empty());
336 }
337
    /// Instance wrapper over [`Reactor::fetch_job_data_static`].
    #[allow(clippy::type_complexity)]
    async fn fetch_job_data(&self, job_id: Uuid) -> forge_core::Result<JobData> {
        Self::fetch_job_data_static(job_id, &self.db_pool).await
    }

    /// Instance wrapper over [`Reactor::fetch_workflow_data_static`].
    async fn fetch_workflow_data(&self, workflow_id: Uuid) -> forge_core::Result<WorkflowData> {
        Self::fetch_workflow_data_static(workflow_id, &self.db_pool).await
    }
346
    /// Instance wrapper over [`Reactor::execute_query_static`].
    async fn execute_query(
        &self,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        Self::execute_query_static(
            &self.registry,
            &self.db_pool,
            query_name,
            args,
            auth_context,
        )
        .await
    }
363
364 fn compute_hash(data: &serde_json::Value) -> String {
365 use std::collections::hash_map::DefaultHasher;
366 use std::hash::{Hash, Hasher};
367
368 let json = serde_json::to_string(data).unwrap_or_default();
369 let mut hasher = DefaultHasher::new();
370 json.hash(&mut hasher);
371 format!("{:x}", hasher.finish())
372 }
373
374 async fn flush_invalidations(
376 invalidation_engine: &Arc<InvalidationEngine>,
377 subscription_manager: &Arc<SubscriptionManager>,
378 session_server: &Arc<SessionServer>,
379 registry: &FunctionRegistry,
380 db_pool: &sqlx::PgPool,
381 max_concurrent: usize,
382 ) {
383 let invalidated_groups = invalidation_engine.check_pending().await;
384 if invalidated_groups.is_empty() {
385 return;
386 }
387
388 tracing::trace!(
389 count = invalidated_groups.len(),
390 "Invalidating query groups"
391 );
392
393 let groups_to_process: Vec<_> = invalidated_groups
395 .iter()
396 .filter_map(|gid| {
397 subscription_manager.get_group(*gid).map(|g| {
398 (
399 g.id,
400 g.query_name.clone(),
401 (*g.args).clone(),
402 g.last_result_hash.clone(),
403 g.auth_context.clone(),
404 )
405 })
406 })
407 .collect();
408
409 let semaphore = Arc::new(Semaphore::new(max_concurrent));
411 let mut futures = FuturesUnordered::new();
412
413 for (group_id, query_name, args, last_hash, auth_context) in groups_to_process {
414 let permit = match semaphore.clone().acquire_owned().await {
415 Ok(p) => p,
416 Err(_) => break,
417 };
418 let registry = registry.clone();
419 let db_pool = db_pool.clone();
420
421 futures.push(async move {
422 let result = Self::execute_query_static(
423 ®istry,
424 &db_pool,
425 &query_name,
426 &args,
427 &auth_context,
428 )
429 .await;
430 drop(permit);
431 (group_id, last_hash, result)
432 });
433 }
434
435 while let Some((group_id, last_hash, result)) = futures.next().await {
437 match result {
438 Ok((new_data, read_set)) => {
439 let new_hash = Self::compute_hash(&new_data);
440
441 if last_hash.as_ref() != Some(&new_hash) {
442 subscription_manager.update_group(group_id, read_set, new_hash);
444
445 let subscribers = subscription_manager.get_group_subscribers(group_id);
447 for (session_id, client_sub_id) in subscribers {
448 let message = RealtimeMessage::Data {
449 subscription_id: client_sub_id.clone(),
450 data: new_data.clone(),
451 };
452
453 if let Err(e) = session_server.try_send_to_session(session_id, message)
454 {
455 tracing::trace!(
456 client_id = %client_sub_id,
457 error = ?e,
458 "Failed to send update to subscriber"
459 );
460 }
461 }
462 }
463 }
464 Err(e) => {
465 tracing::warn!(?group_id, error = %e, "Failed to re-execute query group");
466 }
467 }
468 }
469 }
470
471 pub async fn start(&self) -> forge_core::Result<()> {
473 let listener = self.change_listener.clone();
474 let invalidation_engine = self.invalidation_engine.clone();
475 let subscription_manager = self.subscription_manager.clone();
476 let job_subscriptions = self.job_subscriptions.clone();
477 let workflow_subscriptions = self.workflow_subscriptions.clone();
478 let session_server = self.session_server.clone();
479 let registry = self.registry.clone();
480 let db_pool = self.db_pool.clone();
481 let mut shutdown_rx = self.shutdown_tx.subscribe();
482 let max_restarts = self.max_listener_restarts;
483 let base_delay_ms = self.listener_restart_delay_ms;
484 let max_concurrent = self.max_concurrent_reexecutions;
485 let cleanup_secs = self.session_cleanup_interval_secs;
486
487 let mut change_rx = listener.subscribe();
488
489 tokio::spawn(async move {
490 tracing::debug!("Reactor listening for changes");
491
492 let mut restart_count: u32 = 0;
493 let (listener_error_tx, mut listener_error_rx) = mpsc::channel::<String>(1);
494
495 let listener_clone = listener.clone();
497 let error_tx = listener_error_tx.clone();
498 let mut listener_handle = Some(tokio::spawn(async move {
499 if let Err(e) = listener_clone.run().await {
500 let _ = error_tx.send(format!("Change listener error: {}", e)).await;
501 }
502 }));
503
504 let mut flush_interval = tokio::time::interval(std::time::Duration::from_millis(25));
505 flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
506
507 let mut cleanup_interval =
508 tokio::time::interval(std::time::Duration::from_secs(cleanup_secs));
509 cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
510
511 loop {
512 tokio::select! {
513 result = change_rx.recv() => {
514 match result {
515 Ok(change) => {
516 Self::handle_change(
517 &change,
518 &invalidation_engine,
519 &job_subscriptions,
520 &workflow_subscriptions,
521 &session_server,
522 &db_pool,
523 ).await;
524 }
525 Err(broadcast::error::RecvError::Lagged(n)) => {
526 tracing::warn!("Reactor lagged by {} messages", n);
527 }
528 Err(broadcast::error::RecvError::Closed) => {
529 tracing::debug!("Change channel closed");
530 break;
531 }
532 }
533 }
534 _ = flush_interval.tick() => {
535 Self::flush_invalidations(
536 &invalidation_engine,
537 &subscription_manager,
538 &session_server,
539 ®istry,
540 &db_pool,
541 max_concurrent,
542 ).await;
543 }
544 _ = cleanup_interval.tick() => {
545 session_server.cleanup_stale(std::time::Duration::from_secs(300));
546 }
547 Some(error_msg) = listener_error_rx.recv() => {
548 if restart_count >= max_restarts {
549 tracing::error!(
550 attempts = restart_count,
551 last_error = %error_msg,
552 "Change listener failed permanently, real-time updates disabled"
553 );
554 break;
555 }
556
557 restart_count += 1;
558 let delay = base_delay_ms * 2u64.saturating_pow(restart_count - 1);
559 tracing::warn!(
560 attempt = restart_count,
561 max = max_restarts,
562 delay_ms = delay,
563 error = %error_msg,
564 "Change listener restarting"
565 );
566
567 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
568
569 let listener_clone = listener.clone();
570 let error_tx = listener_error_tx.clone();
571 if let Some(handle) = listener_handle.take() {
572 handle.abort();
573 }
574 change_rx = listener.subscribe();
575 listener_handle = Some(tokio::spawn(async move {
576 if let Err(e) = listener_clone.run().await {
577 let _ = error_tx.send(format!("Change listener error: {}", e)).await;
578 }
579 }));
580 }
581 _ = shutdown_rx.recv() => {
582 tracing::debug!("Reactor shutting down");
583 break;
584 }
585 }
586 }
587
588 if let Some(handle) = listener_handle {
589 handle.abort();
590 }
591 });
592
593 Ok(())
594 }
595
596 #[allow(clippy::too_many_arguments)]
598 async fn handle_change(
599 change: &Change,
600 invalidation_engine: &Arc<InvalidationEngine>,
601 job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
602 workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
603 session_server: &Arc<SessionServer>,
604 db_pool: &sqlx::PgPool,
605 ) {
606 tracing::trace!(table = %change.table, op = ?change.operation, row_id = ?change.row_id, "Processing change");
607
608 match change.table.as_str() {
609 "forge_jobs" => {
610 if let Some(job_id) = change.row_id {
611 Self::handle_job_change(job_id, job_subscriptions, session_server, db_pool)
612 .await;
613 }
614 return;
615 }
616 "forge_workflow_runs" => {
617 if let Some(workflow_id) = change.row_id {
618 Self::handle_workflow_change(
619 workflow_id,
620 workflow_subscriptions,
621 session_server,
622 db_pool,
623 )
624 .await;
625 }
626 return;
627 }
628 "forge_workflow_steps" => {
629 if let Some(step_id) = change.row_id {
630 Self::handle_workflow_step_change(
631 step_id,
632 workflow_subscriptions,
633 session_server,
634 db_pool,
635 )
636 .await;
637 }
638 return;
639 }
640 _ => {}
641 }
642
643 invalidation_engine.process_change(change.clone()).await;
645 }
646
    /// Push the current state of `job_id` to every subscribed session.
    ///
    /// Owner access is re-checked per subscriber on each change; sessions
    /// that no longer pass the check are removed from the subscription map
    /// instead of receiving the update.
    async fn handle_job_change(
        job_id: Uuid,
        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Clone the subscriber list so the DB fetches below run without
        // holding the read lock.
        let subs = job_subscriptions.read().await;
        let subscribers = match subs.get(&job_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let job_data = match Self::fetch_job_data_static(job_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(%job_id, error = %e, "Failed to fetch job data");
                return;
            }
        };

        let owner_subject = match Self::fetch_job_owner_subject_static(job_id, db_pool).await {
            Ok(owner) => owner,
            Err(e) => {
                tracing::debug!(%job_id, error = %e, "Failed to fetch job owner");
                return;
            }
        };

        // Subscribers failing the owner check are collected and pruned
        // after the fan-out loop.
        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();

        for sub in &subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::JobUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                job: job_data.clone(),
            };

            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%job_id, error = %e, "Failed to send job update");
            }
        }

        if !unauthorized.is_empty() {
            let mut subs = job_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&job_id) {
                entries
                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
706
    /// Push the current state of `workflow_id` to every subscribed session.
    ///
    /// Mirrors [`Reactor::handle_job_change`]: owner access is re-checked
    /// per subscriber, and failing sessions are pruned rather than sent the
    /// update.
    async fn handle_workflow_change(
        workflow_id: Uuid,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Clone the subscriber list so the DB fetches below run without
        // holding the read lock.
        let subs = workflow_subscriptions.read().await;
        let subscribers = match subs.get(&workflow_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let workflow_data = match Self::fetch_workflow_data_static(workflow_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow data");
                return;
            }
        };

        let owner_subject =
            match Self::fetch_workflow_owner_subject_static(workflow_id, db_pool).await {
                Ok(owner) => owner,
                Err(e) => {
                    tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow owner");
                    return;
                }
            };

        // Subscribers failing the owner check are collected and pruned
        // after the fan-out loop.
        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();

        for sub in &subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::WorkflowUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                workflow: workflow_data.clone(),
            };

            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%workflow_id, error = %e, "Failed to send workflow update");
            }
        }

        if !unauthorized.is_empty() {
            let mut subs = workflow_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&workflow_id) {
                entries
                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
767
768 async fn handle_workflow_step_change(
769 step_id: Uuid,
770 workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
771 session_server: &Arc<SessionServer>,
772 db_pool: &sqlx::PgPool,
773 ) {
774 let workflow_id: Option<Uuid> = match sqlx::query_scalar(
775 "SELECT workflow_run_id FROM forge_workflow_steps WHERE id = $1",
776 )
777 .bind(step_id)
778 .fetch_optional(db_pool)
779 .await
780 {
781 Ok(id) => id,
782 Err(e) => {
783 tracing::debug!(%step_id, error = %e, "Failed to look up workflow for step");
784 return;
785 }
786 };
787
788 if let Some(wf_id) = workflow_id {
789 Self::handle_workflow_change(wf_id, workflow_subscriptions, session_server, db_pool)
790 .await;
791 }
792 }
793
794 #[allow(clippy::type_complexity)]
795 async fn fetch_job_data_static(
796 job_id: Uuid,
797 db_pool: &sqlx::PgPool,
798 ) -> forge_core::Result<JobData> {
799 let row: Option<(
800 String,
801 Option<i32>,
802 Option<String>,
803 Option<serde_json::Value>,
804 Option<String>,
805 )> = sqlx::query_as(
806 r#"
807 SELECT status, progress_percent, progress_message, output, last_error
808 FROM forge_jobs WHERE id = $1
809 "#,
810 )
811 .bind(job_id)
812 .fetch_optional(db_pool)
813 .await
814 .map_err(forge_core::ForgeError::Sql)?;
815
816 match row {
817 Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
818 job_id: job_id.to_string(),
819 status,
820 progress_percent,
821 progress_message,
822 output,
823 error,
824 }),
825 None => Err(forge_core::ForgeError::NotFound(format!(
826 "Job {} not found",
827 job_id
828 ))),
829 }
830 }
831
    /// Fetch the `owner_subject` column for a job.
    ///
    /// The outer `Option` distinguishes "row missing" (mapped to
    /// `NotFound`) from the inner "owner column is NULL" (`Ok(None)`).
    async fn fetch_job_owner_subject_static(
        job_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<Option<String>> {
        let owner_subject: Option<Option<String>> =
            sqlx::query_scalar("SELECT owner_subject FROM forge_jobs WHERE id = $1")
                .bind(job_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        owner_subject
            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))
    }
846
847 #[allow(clippy::type_complexity)]
848 async fn fetch_workflow_data_static(
849 workflow_id: Uuid,
850 db_pool: &sqlx::PgPool,
851 ) -> forge_core::Result<WorkflowData> {
852 let row: Option<(
853 String,
854 Option<String>,
855 Option<serde_json::Value>,
856 Option<String>,
857 )> = sqlx::query_as(
858 r#"
859 SELECT status, current_step, output, error
860 FROM forge_workflow_runs WHERE id = $1
861 "#,
862 )
863 .bind(workflow_id)
864 .fetch_optional(db_pool)
865 .await
866 .map_err(forge_core::ForgeError::Sql)?;
867
868 let (status, current_step, output, error) = match row {
869 Some(r) => r,
870 None => {
871 return Err(forge_core::ForgeError::NotFound(format!(
872 "Workflow {} not found",
873 workflow_id
874 )));
875 }
876 };
877
878 let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
879 r#"
880 SELECT step_name, status, error
881 FROM forge_workflow_steps
882 WHERE workflow_run_id = $1
883 ORDER BY started_at ASC NULLS LAST
884 "#,
885 )
886 .bind(workflow_id)
887 .fetch_all(db_pool)
888 .await
889 .map_err(forge_core::ForgeError::Sql)?;
890
891 let steps = step_rows
892 .into_iter()
893 .map(|(name, status, error)| WorkflowStepData {
894 name,
895 status,
896 error,
897 })
898 .collect();
899
900 Ok(WorkflowData {
901 workflow_id: workflow_id.to_string(),
902 status,
903 current_step,
904 steps,
905 output,
906 error,
907 })
908 }
909
    /// Fetch the `owner_subject` column for a workflow run.
    ///
    /// The outer `Option` distinguishes "row missing" (mapped to
    /// `NotFound`) from the inner "owner column is NULL" (`Ok(None)`).
    async fn fetch_workflow_owner_subject_static(
        workflow_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<Option<String>> {
        let owner_subject: Option<Option<String>> =
            sqlx::query_scalar("SELECT owner_subject FROM forge_workflow_runs WHERE id = $1")
                .bind(workflow_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        owner_subject.ok_or_else(|| {
            forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
        })
    }
925
    /// Execute a registered query function and compute its invalidation
    /// `ReadSet`.
    ///
    /// Auth requirements and identity-scope checks are enforced before the
    /// handler runs. If the query declares no table dependencies, a table
    /// name derived from the query name is used as a best-effort fallback
    /// dependency.
    async fn execute_query_static(
        registry: &FunctionRegistry,
        db_pool: &sqlx::PgPool,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        match registry.get(query_name) {
            Some(FunctionEntry::Query { info, handler }) => {
                Self::check_query_auth(info, auth_context)?;
                // Scope enforcement applies only to non-public queries.
                Self::check_identity_args(query_name, args, auth_context, !info.is_public)?;

                let ctx = forge_core::function::QueryContext::new(
                    db_pool.clone(),
                    auth_context.clone(),
                    forge_core::function::RequestMetadata::new(),
                );

                // Handlers receive `Null` in place of an empty object `{}`.
                let normalized_args = match args {
                    v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
                    v => v.clone(),
                };

                let data = handler(&ctx, normalized_args).await?;

                let mut read_set = ReadSet::new();

                if info.table_dependencies.is_empty() {
                    let table_name = Self::extract_table_name(query_name);
                    read_set.add_table(&table_name);
                    tracing::trace!(
                        query = %query_name,
                        fallback_table = %table_name,
                        "Using naming convention fallback for table dependency"
                    );
                } else {
                    for table in info.table_dependencies {
                        read_set.add_table(*table);
                    }
                }

                Ok((data, read_set))
            }
            _ => Err(forge_core::ForgeError::Validation(format!(
                "Query '{}' not found or not a query",
                query_name
            ))),
        }
    }
975
976 fn extract_table_name(query_name: &str) -> String {
977 if let Some(rest) = query_name.strip_prefix("get_") {
978 rest.to_string()
979 } else if let Some(rest) = query_name.strip_prefix("list_") {
980 rest.to_string()
981 } else if let Some(rest) = query_name.strip_prefix("find_") {
982 rest.to_string()
983 } else if let Some(rest) = query_name.strip_prefix("fetch_") {
984 rest.to_string()
985 } else {
986 query_name.to_string()
987 }
988 }
989
    /// Enforce a query's declared auth requirements: public queries always
    /// pass; otherwise the caller must be authenticated and, when the query
    /// declares a required role, must hold that role.
    fn check_query_auth(
        info: &forge_core::function::FunctionInfo,
        auth: &forge_core::function::AuthContext,
    ) -> forge_core::Result<()> {
        if info.is_public {
            return Ok(());
        }

        if !auth.is_authenticated() {
            return Err(forge_core::ForgeError::Unauthorized(
                "Authentication required".into(),
            ));
        }

        if let Some(role) = info.required_role
            && !auth.has_role(role)
        {
            return Err(forge_core::ForgeError::Forbidden(format!(
                "Role '{}' required",
                role
            )));
        }

        Ok(())
    }
1015
    /// Guard against cross-principal access via identity/tenant arguments.
    ///
    /// Admin callers bypass all checks. For everyone else:
    /// - identity-scoped keys (`user_id`, `owner_subject`, `subject`, ...)
    ///   must be non-empty strings matching one of the caller's principal
    ///   values (user id or principal subject);
    /// - tenant-scoped keys (`tenant_id`/`tenantId`) must match the
    ///   caller's `tenant_id` claim;
    /// - when `enforce_scope` is true (non-public functions), authenticated
    ///   callers must supply at least one scope key.
    fn check_identity_args(
        function_name: &str,
        args: &serde_json::Value,
        auth: &forge_core::function::AuthContext,
        enforce_scope: bool,
    ) -> forge_core::Result<()> {
        if auth.is_admin() {
            return Ok(());
        }

        // Non-object args cannot carry scope keys; reject only when scoping
        // is mandatory for an authenticated caller.
        let Some(obj) = args.as_object() else {
            if enforce_scope && auth.is_authenticated() {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' must include identity or tenant scope arguments"
                )));
            }
            return Ok(());
        };

        // All values an identity-scoped argument may legally equal.
        let mut principal_values: Vec<String> = Vec::new();
        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
            principal_values.push(user_id);
        }
        if let Some(subject) = auth.principal_id()
            && !principal_values.iter().any(|v| v == &subject)
        {
            principal_values.push(subject);
        }

        let mut has_scope_key = false;

        for key in [
            "user_id",
            "userId",
            "owner_id",
            "ownerId",
            "owner_subject",
            "ownerSubject",
            "subject",
            "sub",
            "principal_id",
            "principalId",
        ] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
                )));
            }

            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
                )));
            }
        }

        for key in ["tenant_id", "tenantId"] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
                )));
            }

            // A caller without a tenant claim may not pass tenant filters.
            let expected = auth
                .claim("tenant_id")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    forge_core::ForgeError::Forbidden(format!(
                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
                    ))
                })?;

            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            if actual.trim().is_empty() || actual != expected {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
                )));
            }
        }

        if enforce_scope && auth.is_authenticated() && !has_scope_key {
            return Err(forge_core::ForgeError::Forbidden(format!(
                "Function '{function_name}' must include identity or tenant scope arguments"
            )));
        }

        Ok(())
    }
1125
    /// Verify that `auth` may access job `job_id`.
    ///
    /// Returns `NotFound` when the job row is missing, otherwise defers to
    /// [`Reactor::check_owner_access`] with the job's `owner_subject`.
    async fn ensure_job_access(
        db_pool: &sqlx::PgPool,
        job_id: Uuid,
        auth: &forge_core::function::AuthContext,
    ) -> forge_core::Result<()> {
        let owner_subject_row: Option<(Option<String>,)> =
            sqlx::query_as(r#"SELECT owner_subject FROM forge_jobs WHERE id = $1"#)
                .bind(job_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        let owner_subject = owner_subject_row
            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))?
            .0;

        Self::check_owner_access(owner_subject, auth)
    }
1144
    /// Verify that `auth` may access workflow run `workflow_id`.
    ///
    /// Returns `NotFound` when the run row is missing, otherwise defers to
    /// [`Reactor::check_owner_access`] with the run's `owner_subject`.
    async fn ensure_workflow_access(
        db_pool: &sqlx::PgPool,
        workflow_id: Uuid,
        auth: &forge_core::function::AuthContext,
    ) -> forge_core::Result<()> {
        let owner_subject_row: Option<(Option<String>,)> =
            sqlx::query_as(r#"SELECT owner_subject FROM forge_workflow_runs WHERE id = $1"#)
                .bind(workflow_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        let owner_subject = owner_subject_row
            .ok_or_else(|| {
                forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
            })?
            .0;

        Self::check_owner_access(owner_subject, auth)
    }
1165
1166 fn check_owner_access(
1167 owner_subject: Option<String>,
1168 auth: &forge_core::function::AuthContext,
1169 ) -> forge_core::Result<()> {
1170 if auth.is_admin() {
1171 return Ok(());
1172 }
1173
1174 let principal = auth.principal_id().ok_or_else(|| {
1175 forge_core::ForgeError::Unauthorized("Authentication required".to_string())
1176 })?;
1177
1178 match owner_subject {
1179 Some(owner) if owner == principal => Ok(()),
1180 Some(_) => Err(forge_core::ForgeError::Forbidden(
1181 "Not authorized to access this resource".to_string(),
1182 )),
1183 None => Err(forge_core::ForgeError::Forbidden(
1184 "Resource has no owner; admin role required".to_string(),
1185 )),
1186 }
1187 }
1188
1189 pub fn stop(&self) {
1190 let _ = self.shutdown_tx.send(());
1191 self.change_listener.stop();
1192 }
1193
1194 pub async fn stats(&self) -> ReactorStats {
1195 let session_stats = self.session_server.stats();
1196 let inv_stats = self.invalidation_engine.stats().await;
1197
1198 ReactorStats {
1199 connections: session_stats.connections,
1200 subscriptions: session_stats.subscriptions,
1201 query_groups: self.subscription_manager.group_count(),
1202 pending_invalidations: inv_stats.pending_groups,
1203 listener_running: self.change_listener.is_running(),
1204 }
1205 }
1206}
1207
/// Point-in-time snapshot of reactor health, produced by `Reactor::stats`.
#[derive(Debug, Clone)]
pub struct ReactorStats {
    /// Active connections reported by the session server.
    pub connections: usize,
    /// Active subscriptions reported by the session server.
    pub subscriptions: usize,
    /// Distinct query groups tracked by the subscription manager.
    pub query_groups: usize,
    /// Query groups with invalidations still pending in the engine.
    pub pending_invalidations: usize,
    /// Whether the change listener was running when the snapshot was taken.
    pub listener_running: bool,
}
1217
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a "user"-role auth context whose `sub` claim matches `user_id`.
    fn user_auth(user_id: uuid::Uuid) -> forge_core::function::AuthContext {
        forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        )
    }

    #[test]
    fn test_reactor_config_default() {
        let cfg = ReactorConfig::default();

        assert_eq!(cfg.listener.channel, "forge_changes");
        assert_eq!(cfg.invalidation.debounce_ms, 50);
        assert_eq!(cfg.max_listener_restarts, 5);
        assert_eq!(cfg.listener_restart_delay_ms, 1000);
        assert_eq!(cfg.max_concurrent_reexecutions, 64);
        assert_eq!(cfg.session_cleanup_interval_secs, 60);
    }

    #[test]
    fn test_compute_hash() {
        let same_a = serde_json::json!({"name": "test"});
        let same_b = serde_json::json!({"name": "test"});
        let other = serde_json::json!({"name": "different"});

        // Identical payloads hash equal; different payloads hash different.
        assert_eq!(Reactor::compute_hash(&same_a), Reactor::compute_hash(&same_b));
        assert_ne!(Reactor::compute_hash(&same_a), Reactor::compute_hash(&other));
    }

    #[test]
    fn test_check_identity_args_rejects_cross_user() {
        let auth = user_auth(uuid::Uuid::new_v4());

        // Asking for a different user's data must be forbidden.
        let foreign_user = uuid::Uuid::new_v4().to_string();
        let result = Reactor::check_identity_args(
            "list_orders",
            &serde_json::json!({"user_id": foreign_user}),
            &auth,
            true,
        );
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_identity_args_requires_scope_for_non_public_queries() {
        let auth = user_auth(uuid::Uuid::new_v4());

        // With scope enforcement on, an argument set lacking any
        // identity/tenant key must be rejected.
        let result =
            Reactor::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_owner_access_allows_admin() {
        let admin = forge_core::function::AuthContext::authenticated_without_uuid(
            vec!["admin".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String("admin-1".to_string()),
            )]),
        );

        // Admins may access resources owned by anyone.
        assert!(Reactor::check_owner_access(Some("other-user".to_string()), &admin).is_ok());
    }
}