1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use futures_util::StreamExt;
5use futures_util::stream::FuturesUnordered;
6use tokio::sync::{RwLock, Semaphore, broadcast, mpsc};
7use uuid::Uuid;
8
9use forge_core::cluster::NodeId;
10use forge_core::realtime::{Change, ReadSet, SessionId, SubscriptionId};
11
12use super::invalidation::{InvalidationConfig, InvalidationEngine};
13use super::listener::{ChangeListener, ListenerConfig};
14use super::manager::SubscriptionManager;
15use super::message::{
16 JobData, RealtimeConfig, RealtimeMessage, SessionServer, WorkflowData, WorkflowStepData,
17};
18use crate::function::{FunctionEntry, FunctionRegistry};
19
/// Tunables for the [`Reactor`] and the components it constructs.
#[derive(Debug, Clone)]
pub struct ReactorConfig {
    /// Settings forwarded to the Postgres change listener.
    pub listener: ListenerConfig,
    /// Settings forwarded to the invalidation engine.
    pub invalidation: InvalidationConfig,
    /// Settings shared by the session server and subscription manager.
    pub realtime: RealtimeConfig,
    /// How many times a crashed change listener is restarted before
    /// real-time updates are permanently disabled.
    pub max_listener_restarts: u32,
    /// Base delay before a listener restart; doubled on each successive
    /// attempt (exponential backoff).
    pub listener_restart_delay_ms: u64,
    /// Upper bound on query re-executions running concurrently during an
    /// invalidation flush.
    pub max_concurrent_reexecutions: usize,
    /// Interval between sweeps that drop stale sessions.
    pub session_cleanup_interval_secs: u64,
}

impl Default for ReactorConfig {
    fn default() -> Self {
        Self {
            listener: ListenerConfig::default(),
            invalidation: InvalidationConfig::default(),
            realtime: RealtimeConfig::default(),
            max_listener_restarts: 5,
            listener_restart_delay_ms: 1000,
            max_concurrent_reexecutions: 64,
            session_cleanup_interval_secs: 60,
        }
    }
}
47
/// One client's watch on live status updates for a single job.
#[derive(Debug, Clone)]
pub struct JobSubscription {
    #[allow(dead_code)]
    pub subscription_id: SubscriptionId,
    // Session that receives the update messages.
    pub session_id: SessionId,
    // Client-chosen identifier echoed back in every update.
    pub client_sub_id: String,
    #[allow(dead_code)]
    pub job_id: Uuid,
    // Auth captured at subscribe time; re-checked on every change fan-out.
    pub auth_context: forge_core::function::AuthContext,
}

/// One client's watch on live status updates for a single workflow run.
#[derive(Debug, Clone)]
pub struct WorkflowSubscription {
    #[allow(dead_code)]
    pub subscription_id: SubscriptionId,
    // Session that receives the update messages.
    pub session_id: SessionId,
    // Client-chosen identifier echoed back in every update.
    pub client_sub_id: String,
    #[allow(dead_code)]
    pub workflow_id: Uuid,
    // Auth captured at subscribe time; re-checked on every change fan-out.
    pub auth_context: forge_core::function::AuthContext,
}
71
/// Real-time hub for one node: listens for database changes, invalidates
/// cached query results, and pushes fresh data to subscribed sessions.
pub struct Reactor {
    node_id: NodeId,
    db_pool: sqlx::PgPool,
    registry: FunctionRegistry,
    subscription_manager: Arc<SubscriptionManager>,
    session_server: Arc<SessionServer>,
    change_listener: Arc<ChangeListener>,
    invalidation_engine: Arc<InvalidationEngine>,
    // Job id -> direct watchers, maintained outside the query-group system.
    job_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
    // Workflow run id -> direct watchers.
    workflow_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
    // Broadcast so the background loop and any外部 receivers see shutdown.
    shutdown_tx: broadcast::Sender<()>,
    max_listener_restarts: u32,
    listener_restart_delay_ms: u64,
    max_concurrent_reexecutions: usize,
    session_cleanup_interval_secs: u64,
}
91
92impl Reactor {
    /// Builds a reactor and its owned components (subscription manager,
    /// session server, change listener, invalidation engine).
    ///
    /// Nothing runs until [`Reactor::start`] is called.
    pub fn new(
        node_id: NodeId,
        db_pool: sqlx::PgPool,
        registry: FunctionRegistry,
        config: ReactorConfig,
    ) -> Self {
        let subscription_manager = Arc::new(SubscriptionManager::new(
            config.realtime.max_subscriptions_per_session,
        ));
        let session_server = Arc::new(SessionServer::new(node_id, config.realtime.clone()));
        let change_listener = Arc::new(ChangeListener::new(db_pool.clone(), config.listener));
        let invalidation_engine = Arc::new(InvalidationEngine::new(
            subscription_manager.clone(),
            config.invalidation,
        ));
        // Capacity 1 is enough: shutdown is a single one-shot signal.
        let (shutdown_tx, _) = broadcast::channel(1);

        Self {
            node_id,
            db_pool,
            registry,
            subscription_manager,
            session_server,
            change_listener,
            invalidation_engine,
            job_subscriptions: Arc::new(RwLock::new(HashMap::new())),
            workflow_subscriptions: Arc::new(RwLock::new(HashMap::new())),
            shutdown_tx,
            max_listener_restarts: config.max_listener_restarts,
            listener_restart_delay_ms: config.listener_restart_delay_ms,
            max_concurrent_reexecutions: config.max_concurrent_reexecutions,
            session_cleanup_interval_secs: config.session_cleanup_interval_secs,
        }
    }
128
    /// Identifier of the cluster node this reactor runs on.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Shared handle to the session server (connection registry).
    pub fn session_server(&self) -> Arc<SessionServer> {
        self.session_server.clone()
    }

    /// Shared handle to the query subscription manager.
    pub fn subscription_manager(&self) -> Arc<SubscriptionManager> {
        self.subscription_manager.clone()
    }

    /// New receiver on the shutdown broadcast; fires when [`Reactor::stop`]
    /// is called.
    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }

    /// Associates an outbound message channel with a client session.
    pub fn register_session(&self, session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) {
        self.session_server.register_connection(session_id, sender);
        tracing::trace!(?session_id, "Session registered");
    }
150
151 pub async fn remove_session(&self, session_id: SessionId) {
153 self.subscription_manager
155 .remove_session_subscriptions(session_id);
156
157 self.session_server.remove_connection(session_id);
159
160 {
162 let mut job_subs = self.job_subscriptions.write().await;
163 for subscribers in job_subs.values_mut() {
164 subscribers.retain(|s| s.session_id != session_id);
165 }
166 job_subs.retain(|_, v| !v.is_empty());
167 }
168
169 {
171 let mut workflow_subs = self.workflow_subscriptions.write().await;
172 for subscribers in workflow_subs.values_mut() {
173 subscribers.retain(|s| s.session_id != session_id);
174 }
175 workflow_subs.retain(|_, v| !v.is_empty());
176 }
177
178 tracing::trace!(?session_id, "Session removed");
179 }
180
    /// Creates a query subscription and returns its id plus the query's
    /// current result.
    ///
    /// Subscriptions with identical (query, args, auth) share a query
    /// group. The query is executed for every subscriber so each caller
    /// gets fresh data, but only a brand-new group records the read set and
    /// result hash used for later invalidation.
    ///
    /// # Errors
    /// Fails if the subscription cannot be registered or the query execution
    /// (including its auth checks) fails; partially-created state is rolled
    /// back before the error is returned.
    pub async fn subscribe(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        query_name: String,
        args: serde_json::Value,
        auth_context: forge_core::function::AuthContext,
    ) -> forge_core::Result<(SubscriptionId, serde_json::Value)> {
        // Static dependency metadata from the registry; unknown names fall
        // back to empty slices (execute_query rejects them later).
        let (table_deps, selected_cols) = match self.registry.get(&query_name) {
            Some(FunctionEntry::Query { info, .. }) => {
                (info.table_dependencies, info.selected_columns)
            }
            _ => (&[] as &[&str], &[] as &[&str]),
        };

        let (group_id, subscription_id, is_new_group) = self.subscription_manager.subscribe(
            session_id,
            client_sub_id,
            &query_name,
            &args,
            &auth_context,
            table_deps,
            selected_cols,
        )?;

        // Roll back the manager-side subscription if the session server
        // refuses to track it.
        if let Err(error) = self
            .session_server
            .add_subscription(session_id, subscription_id)
        {
            self.subscription_manager.unsubscribe(subscription_id);
            return Err(error);
        }

        let data = if is_new_group {
            let (data, read_set) = match self.execute_query(&query_name, &args, &auth_context).await
            {
                Ok(result) => result,
                Err(error) => {
                    // Undo both registrations on execution failure.
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };

            let result_hash = Self::compute_hash(&data);

            tracing::trace!(
                ?group_id,
                query = %query_name,
                "New query group created"
            );

            // Seed the group's invalidation state from this first execution.
            self.subscription_manager
                .update_group(group_id, read_set, result_hash);

            data
        } else {
            // Existing group: run the query for fresh data but leave the
            // group's read set / hash untouched.
            let (data, _) = match self.execute_query(&query_name, &args, &auth_context).await {
                Ok(result) => result,
                Err(error) => {
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };
            data
        };

        tracing::trace!(?subscription_id, "Subscription created");
        Ok((subscription_id, data))
    }
256
    /// Removes a query subscription from both the session server and the
    /// subscription manager.
    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
        self.session_server.remove_subscription(subscription_id);
        self.subscription_manager.unsubscribe(subscription_id);
        tracing::trace!(?subscription_id, "Subscription removed");
    }
263
    /// Subscribes a session to live updates for one job and returns the
    /// job's current state.
    ///
    /// # Errors
    /// Fails when the job does not exist or the caller does not own it
    /// (admins always pass the ownership check).
    pub async fn subscribe_job(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        job_id: Uuid,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<JobData> {
        let subscription_id = SubscriptionId::new();
        // Check access and fetch the snapshot before recording any state.
        Self::ensure_job_access(&self.db_pool, job_id, auth_context).await?;
        let job_data = self.fetch_job_data(job_id).await?;

        let subscription = JobSubscription {
            subscription_id,
            session_id,
            client_sub_id,
            job_id,
            auth_context: auth_context.clone(),
        };

        let mut subs = self.job_subscriptions.write().await;
        subs.entry(job_id).or_default().push(subscription);

        tracing::trace!(?subscription_id, %job_id, "Job subscription created");
        Ok(job_data)
    }
290
291 pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) {
293 let mut subs = self.job_subscriptions.write().await;
294 for subscribers in subs.values_mut() {
295 subscribers
296 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
297 }
298 subs.retain(|_, v| !v.is_empty());
299 }
300
    /// Subscribes a session to live updates for one workflow run and
    /// returns the run's current state (including its steps).
    ///
    /// # Errors
    /// Fails when the workflow run does not exist or the caller does not
    /// own it (admins always pass the ownership check).
    pub async fn subscribe_workflow(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        workflow_id: Uuid,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<WorkflowData> {
        let subscription_id = SubscriptionId::new();
        // Check access and fetch the snapshot before recording any state.
        Self::ensure_workflow_access(&self.db_pool, workflow_id, auth_context).await?;
        let workflow_data = self.fetch_workflow_data(workflow_id).await?;

        let subscription = WorkflowSubscription {
            subscription_id,
            session_id,
            client_sub_id,
            workflow_id,
            auth_context: auth_context.clone(),
        };

        let mut subs = self.workflow_subscriptions.write().await;
        subs.entry(workflow_id).or_default().push(subscription);

        tracing::trace!(?subscription_id, %workflow_id, "Workflow subscription created");
        Ok(workflow_data)
    }
327
328 pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) {
330 let mut subs = self.workflow_subscriptions.write().await;
331 for subscribers in subs.values_mut() {
332 subscribers
333 .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
334 }
335 subs.retain(|_, v| !v.is_empty());
336 }
337
    /// Instance wrapper over [`Self::fetch_job_data_static`] using this
    /// reactor's pool.
    #[allow(clippy::type_complexity)]
    async fn fetch_job_data(&self, job_id: Uuid) -> forge_core::Result<JobData> {
        Self::fetch_job_data_static(job_id, &self.db_pool).await
    }

    /// Instance wrapper over [`Self::fetch_workflow_data_static`] using
    /// this reactor's pool.
    async fn fetch_workflow_data(&self, workflow_id: Uuid) -> forge_core::Result<WorkflowData> {
        Self::fetch_workflow_data_static(workflow_id, &self.db_pool).await
    }

    /// Instance wrapper over [`Self::execute_query_static`] using this
    /// reactor's registry and pool.
    async fn execute_query(
        &self,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        Self::execute_query_static(
            &self.registry,
            &self.db_pool,
            query_name,
            args,
            auth_context,
        )
        .await
    }
363
364 fn compute_hash(data: &serde_json::Value) -> String {
365 use std::collections::hash_map::DefaultHasher;
366 use std::hash::{Hash, Hasher};
367
368 let json = serde_json::to_string(data).unwrap_or_default();
369 let mut hasher = DefaultHasher::new();
370 json.hash(&mut hasher);
371 format!("{:x}", hasher.finish())
372 }
373
374 async fn flush_invalidations(
376 invalidation_engine: &Arc<InvalidationEngine>,
377 subscription_manager: &Arc<SubscriptionManager>,
378 session_server: &Arc<SessionServer>,
379 registry: &FunctionRegistry,
380 db_pool: &sqlx::PgPool,
381 max_concurrent: usize,
382 ) {
383 let invalidated_groups = invalidation_engine.check_pending().await;
384 if invalidated_groups.is_empty() {
385 return;
386 }
387
388 tracing::trace!(
389 count = invalidated_groups.len(),
390 "Invalidating query groups"
391 );
392
393 let groups_to_process: Vec<_> = invalidated_groups
395 .iter()
396 .filter_map(|gid| {
397 subscription_manager.get_group(*gid).map(|g| {
398 (
399 g.id,
400 g.query_name.clone(),
401 (*g.args).clone(),
402 g.last_result_hash.clone(),
403 g.auth_context.clone(),
404 )
405 })
406 })
407 .collect();
408
409 let semaphore = Arc::new(Semaphore::new(max_concurrent));
411 let mut futures = FuturesUnordered::new();
412
413 for (group_id, query_name, args, last_hash, auth_context) in groups_to_process {
414 let permit = match semaphore.clone().acquire_owned().await {
415 Ok(p) => p,
416 Err(_) => break,
417 };
418 let registry = registry.clone();
419 let db_pool = db_pool.clone();
420
421 futures.push(async move {
422 let result = Self::execute_query_static(
423 ®istry,
424 &db_pool,
425 &query_name,
426 &args,
427 &auth_context,
428 )
429 .await;
430 drop(permit);
431 (group_id, last_hash, result)
432 });
433 }
434
435 while let Some((group_id, last_hash, result)) = futures.next().await {
437 match result {
438 Ok((new_data, read_set)) => {
439 let new_hash = Self::compute_hash(&new_data);
440
441 if last_hash.as_ref() != Some(&new_hash) {
442 subscription_manager.update_group(group_id, read_set, new_hash);
444
445 let subscribers = subscription_manager.get_group_subscribers(group_id);
447 for (session_id, client_sub_id) in subscribers {
448 let message = RealtimeMessage::Data {
449 subscription_id: client_sub_id.clone(),
450 data: new_data.clone(),
451 };
452
453 if let Err(e) = session_server.try_send_to_session(session_id, message)
454 {
455 tracing::trace!(
456 client_id = %client_sub_id,
457 error = ?e,
458 "Failed to send update to subscriber"
459 );
460 }
461 }
462 }
463 }
464 Err(e) => {
465 tracing::warn!(?group_id, error = %e, "Failed to re-execute query group");
466 }
467 }
468 }
469 }
470
471 pub async fn start(&self) -> forge_core::Result<()> {
473 let listener = self.change_listener.clone();
474 let invalidation_engine = self.invalidation_engine.clone();
475 let subscription_manager = self.subscription_manager.clone();
476 let job_subscriptions = self.job_subscriptions.clone();
477 let workflow_subscriptions = self.workflow_subscriptions.clone();
478 let session_server = self.session_server.clone();
479 let registry = self.registry.clone();
480 let db_pool = self.db_pool.clone();
481 let mut shutdown_rx = self.shutdown_tx.subscribe();
482 let max_restarts = self.max_listener_restarts;
483 let base_delay_ms = self.listener_restart_delay_ms;
484 let max_concurrent = self.max_concurrent_reexecutions;
485 let cleanup_secs = self.session_cleanup_interval_secs;
486
487 let mut change_rx = listener.subscribe();
488
489 tokio::spawn(async move {
490 tracing::debug!("Reactor listening for changes");
491
492 let mut restart_count: u32 = 0;
493 let (listener_error_tx, mut listener_error_rx) = mpsc::channel::<String>(1);
494
495 let listener_clone = listener.clone();
497 let error_tx = listener_error_tx.clone();
498 let mut listener_handle = Some(tokio::spawn(async move {
499 if let Err(e) = listener_clone.run().await {
500 let _ = error_tx.send(format!("Change listener error: {}", e)).await;
501 }
502 }));
503
504 let mut flush_interval = tokio::time::interval(std::time::Duration::from_millis(25));
505 flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
506
507 let mut cleanup_interval =
508 tokio::time::interval(std::time::Duration::from_secs(cleanup_secs));
509 cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
510
511 loop {
512 tokio::select! {
513 result = change_rx.recv() => {
514 match result {
515 Ok(change) => {
516 Self::handle_change(
517 &change,
518 &invalidation_engine,
519 &job_subscriptions,
520 &workflow_subscriptions,
521 &session_server,
522 &db_pool,
523 ).await;
524 }
525 Err(broadcast::error::RecvError::Lagged(n)) => {
526 tracing::warn!("Reactor lagged by {} messages", n);
527 }
528 Err(broadcast::error::RecvError::Closed) => {
529 tracing::debug!("Change channel closed");
530 break;
531 }
532 }
533 }
534 _ = flush_interval.tick() => {
535 Self::flush_invalidations(
536 &invalidation_engine,
537 &subscription_manager,
538 &session_server,
539 ®istry,
540 &db_pool,
541 max_concurrent,
542 ).await;
543 }
544 _ = cleanup_interval.tick() => {
545 session_server.cleanup_stale(std::time::Duration::from_secs(300));
546 }
547 Some(error_msg) = listener_error_rx.recv() => {
548 if restart_count >= max_restarts {
549 tracing::error!(
550 attempts = restart_count,
551 last_error = %error_msg,
552 "Change listener failed permanently, real-time updates disabled"
553 );
554 break;
555 }
556
557 restart_count += 1;
558 let delay = base_delay_ms * 2u64.saturating_pow(restart_count - 1);
559 tracing::warn!(
560 attempt = restart_count,
561 max = max_restarts,
562 delay_ms = delay,
563 error = %error_msg,
564 "Change listener restarting"
565 );
566
567 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
568
569 let listener_clone = listener.clone();
570 let error_tx = listener_error_tx.clone();
571 if let Some(handle) = listener_handle.take() {
572 handle.abort();
573 }
574 change_rx = listener.subscribe();
575 listener_handle = Some(tokio::spawn(async move {
576 if let Err(e) = listener_clone.run().await {
577 let _ = error_tx.send(format!("Change listener error: {}", e)).await;
578 }
579 }));
580 }
581 _ = shutdown_rx.recv() => {
582 tracing::debug!("Reactor shutting down");
583 break;
584 }
585 }
586 }
587
588 if let Some(handle) = listener_handle {
589 handle.abort();
590 }
591 });
592
593 Ok(())
594 }
595
    /// Routes one database change notification.
    ///
    /// Changes to the reactor-managed tables (`forge_jobs`,
    /// `forge_workflow_runs`, `forge_workflow_steps`) are fanned out
    /// directly to their watchers and deliberately skip the invalidation
    /// engine; every other table feeds query-group invalidation.
    #[allow(clippy::too_many_arguments)]
    async fn handle_change(
        change: &Change,
        invalidation_engine: &Arc<InvalidationEngine>,
        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        tracing::trace!(table = %change.table, op = ?change.operation, row_id = ?change.row_id, "Processing change");

        match change.table.as_str() {
            "forge_jobs" => {
                // Changes without a row id cannot be routed to watchers.
                if let Some(job_id) = change.row_id {
                    Self::handle_job_change(job_id, job_subscriptions, session_server, db_pool)
                        .await;
                }
                return;
            }
            "forge_workflow_runs" => {
                if let Some(workflow_id) = change.row_id {
                    Self::handle_workflow_change(
                        workflow_id,
                        workflow_subscriptions,
                        session_server,
                        db_pool,
                    )
                    .await;
                }
                return;
            }
            "forge_workflow_steps" => {
                if let Some(step_id) = change.row_id {
                    Self::handle_workflow_step_change(
                        step_id,
                        workflow_subscriptions,
                        session_server,
                        db_pool,
                    )
                    .await;
                }
                return;
            }
            _ => {}
        }

        // Application-table change: let the invalidation engine mark the
        // query groups whose read sets cover this table.
        invalidation_engine.process_change(change.clone()).await;
    }
646
    /// Pushes a fresh job snapshot to every watcher of `job_id`.
    ///
    /// Ownership is re-checked against each watcher's stored auth context on
    /// every change; watchers that no longer pass are pruned instead of
    /// receiving data.
    async fn handle_job_change(
        job_id: Uuid,
        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Snapshot the watcher list so the lock is not held across awaits.
        let subs = job_subscriptions.read().await;
        let subscribers = match subs.get(&job_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let job_data = match Self::fetch_job_data_static(job_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                // Job row may have been deleted between change and fetch.
                tracing::debug!(%job_id, error = %e, "Failed to fetch job data");
                return;
            }
        };

        let owner_subject = match Self::fetch_job_owner_subject_static(job_id, db_pool).await {
            Ok(owner) => owner,
            Err(e) => {
                tracing::debug!(%job_id, error = %e, "Failed to fetch job owner");
                return;
            }
        };

        // Watchers that failed the ownership re-check, to prune afterwards.
        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();

        for sub in &subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::JobUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                job: job_data.clone(),
            };

            // Best-effort delivery; a failed send only affects that session.
            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%job_id, error = %e, "Failed to send job update");
            }
        }

        if !unauthorized.is_empty() {
            let mut subs = job_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&job_id) {
                entries
                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
706
    /// Pushes a fresh workflow snapshot to every watcher of `workflow_id`.
    ///
    /// Mirrors [`Self::handle_job_change`]: ownership is re-checked per
    /// watcher on every change, and failing watchers are pruned.
    async fn handle_workflow_change(
        workflow_id: Uuid,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Snapshot the watcher list so the lock is not held across awaits.
        let subs = workflow_subscriptions.read().await;
        let subscribers = match subs.get(&workflow_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let workflow_data = match Self::fetch_workflow_data_static(workflow_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                // Run may have been deleted between change and fetch.
                tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow data");
                return;
            }
        };

        let owner_subject =
            match Self::fetch_workflow_owner_subject_static(workflow_id, db_pool).await {
                Ok(owner) => owner,
                Err(e) => {
                    tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow owner");
                    return;
                }
            };

        // Watchers that failed the ownership re-check, to prune afterwards.
        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();

        for sub in &subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::WorkflowUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                workflow: workflow_data.clone(),
            };

            // Best-effort delivery; a failed send only affects that session.
            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%workflow_id, error = %e, "Failed to send workflow update");
            }
        }

        if !unauthorized.is_empty() {
            let mut subs = workflow_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&workflow_id) {
                entries
                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
767
    /// Resolves a step-row change to its parent workflow run and re-uses
    /// the workflow fan-out path, since clients subscribe to whole runs.
    async fn handle_workflow_step_change(
        step_id: Uuid,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        let workflow_id: Option<Uuid> = match sqlx::query_scalar(
            "SELECT workflow_run_id FROM forge_workflow_steps WHERE id = $1",
        )
        .bind(step_id)
        .fetch_optional(db_pool)
        .await
        {
            Ok(id) => id,
            Err(e) => {
                tracing::debug!(%step_id, error = %e, "Failed to look up workflow for step");
                return;
            }
        };

        // A missing row (step deleted in the meantime) is silently ignored.
        if let Some(wf_id) = workflow_id {
            Self::handle_workflow_change(wf_id, workflow_subscriptions, session_server, db_pool)
                .await;
        }
    }
793
    /// Loads the client-facing snapshot of one job row.
    ///
    /// # Errors
    /// `NotFound` when no row matches, `Sql` on database failure.
    #[allow(clippy::type_complexity)]
    async fn fetch_job_data_static(
        job_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<JobData> {
        let row: Option<(
            String,
            Option<i32>,
            Option<String>,
            Option<serde_json::Value>,
            Option<String>,
        )> = sqlx::query_as(
            r#"
            SELECT status, progress_percent, progress_message, output, last_error
            FROM forge_jobs WHERE id = $1
            "#,
        )
        .bind(job_id)
        .fetch_optional(db_pool)
        .await
        .map_err(forge_core::ForgeError::Sql)?;

        match row {
            Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
                job_id: job_id.to_string(),
                status,
                progress_percent,
                progress_message,
                output,
                error,
            }),
            None => Err(forge_core::ForgeError::NotFound(format!(
                "Job {} not found",
                job_id
            ))),
        }
    }

    /// Looks up a job's nullable `owner_subject` column.
    ///
    /// The nested option distinguishes "row missing" (outer `None`, mapped
    /// to `NotFound`) from "owner column is NULL" (inner `None`, returned).
    async fn fetch_job_owner_subject_static(
        job_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<Option<String>> {
        let owner_subject: Option<Option<String>> =
            sqlx::query_scalar("SELECT owner_subject FROM forge_jobs WHERE id = $1")
                .bind(job_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        owner_subject
            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))
    }
846
    /// Loads the client-facing snapshot of one workflow run, including its
    /// steps ordered by start time (never-started steps sort last).
    ///
    /// # Errors
    /// `NotFound` when the run does not exist, `Sql` on database failure.
    #[allow(clippy::type_complexity)]
    async fn fetch_workflow_data_static(
        workflow_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<WorkflowData> {
        let row: Option<(
            String,
            Option<String>,
            Option<serde_json::Value>,
            Option<String>,
        )> = sqlx::query_as(
            r#"
            SELECT status, current_step, output, error
            FROM forge_workflow_runs WHERE id = $1
            "#,
        )
        .bind(workflow_id)
        .fetch_optional(db_pool)
        .await
        .map_err(forge_core::ForgeError::Sql)?;

        let (status, current_step, output, error) = match row {
            Some(r) => r,
            None => {
                return Err(forge_core::ForgeError::NotFound(format!(
                    "Workflow {} not found",
                    workflow_id
                )));
            }
        };

        let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
            r#"
            SELECT step_name, status, error
            FROM forge_workflow_steps
            WHERE workflow_run_id = $1
            ORDER BY started_at ASC NULLS LAST
            "#,
        )
        .bind(workflow_id)
        .fetch_all(db_pool)
        .await
        .map_err(forge_core::ForgeError::Sql)?;

        let steps = step_rows
            .into_iter()
            .map(|(name, status, error)| WorkflowStepData {
                name,
                status,
                error,
            })
            .collect();

        Ok(WorkflowData {
            workflow_id: workflow_id.to_string(),
            status,
            current_step,
            steps,
            output,
            error,
        })
    }

    /// Looks up a workflow run's nullable `owner_subject` column.
    ///
    /// Outer `None` (row missing) maps to `NotFound`; inner `None` (NULL
    /// owner) is returned as-is.
    async fn fetch_workflow_owner_subject_static(
        workflow_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<Option<String>> {
        let owner_subject: Option<Option<String>> =
            sqlx::query_scalar("SELECT owner_subject FROM forge_workflow_runs WHERE id = $1")
                .bind(workflow_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        owner_subject.ok_or_else(|| {
            forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
        })
    }
925
    /// Executes a registered query with auth and identity-argument checks,
    /// and derives the [`ReadSet`] used for invalidation tracking.
    ///
    /// When the query declares no table dependencies, a table name guessed
    /// from the query name (see [`Self::extract_table_name`]) is used as a
    /// fallback dependency.
    ///
    /// # Errors
    /// `Validation` when `query_name` is not a registered query; auth errors
    /// from [`Self::check_query_auth`] / [`Self::check_identity_args`]; any
    /// failure from the query handler itself.
    async fn execute_query_static(
        registry: &FunctionRegistry,
        db_pool: &sqlx::PgPool,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        match registry.get(query_name) {
            Some(FunctionEntry::Query { info, handler }) => {
                Self::check_query_auth(info, auth_context)?;
                // Identity scoping is only enforced for non-public queries
                // that actually declare input arguments.
                let enforce = !info.is_public && info.has_input_args;
                Self::check_identity_args(query_name, args, auth_context, enforce)?;

                let ctx = forge_core::function::QueryContext::new(
                    db_pool.clone(),
                    auth_context.clone(),
                    forge_core::function::RequestMetadata::new(),
                );

                // Treat an empty object `{}` as "no arguments".
                let normalized_args = match args {
                    v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
                    v => v.clone(),
                };

                let data = handler(&ctx, normalized_args).await?;

                let mut read_set = ReadSet::new();

                if info.table_dependencies.is_empty() {
                    // No declared dependencies: fall back to a table name
                    // derived from the query's naming convention.
                    let table_name = Self::extract_table_name(query_name);
                    read_set.add_table(&table_name);
                    tracing::trace!(
                        query = %query_name,
                        fallback_table = %table_name,
                        "Using naming convention fallback for table dependency"
                    );
                } else {
                    for table in info.table_dependencies {
                        read_set.add_table(*table);
                    }
                }

                Ok((data, read_set))
            }
            _ => Err(forge_core::ForgeError::Validation(format!(
                "Query '{}' not found or not a query",
                query_name
            ))),
        }
    }
976
977 fn extract_table_name(query_name: &str) -> String {
978 if let Some(rest) = query_name.strip_prefix("get_") {
979 rest.to_string()
980 } else if let Some(rest) = query_name.strip_prefix("list_") {
981 rest.to_string()
982 } else if let Some(rest) = query_name.strip_prefix("find_") {
983 rest.to_string()
984 } else if let Some(rest) = query_name.strip_prefix("fetch_") {
985 rest.to_string()
986 } else {
987 query_name.to_string()
988 }
989 }
990
    /// Gate for executing a query: public queries are open to everyone;
    /// otherwise the caller must be authenticated and hold the query's
    /// required role, if one is declared.
    fn check_query_auth(
        info: &forge_core::function::FunctionInfo,
        auth: &forge_core::function::AuthContext,
    ) -> forge_core::Result<()> {
        if info.is_public {
            return Ok(());
        }

        if !auth.is_authenticated() {
            return Err(forge_core::ForgeError::Unauthorized(
                "Authentication required".into(),
            ));
        }

        // let-chain: only rejects when a role is required AND missing.
        if let Some(role) = info.required_role
            && !auth.has_role(role)
        {
            return Err(forge_core::ForgeError::Forbidden(format!(
                "Role '{}' required",
                role
            )));
        }

        Ok(())
    }
1016
    /// Enforces that identity- and tenant-scoped arguments match the caller.
    ///
    /// Admins bypass all checks. For everyone else:
    /// - any recognized identity key present in `args` (`user_id`,
    ///   `owner_id`, `subject`, `principal_id` and camelCase variants) must
    ///   be a non-empty string equal to the caller's user id or principal id;
    /// - any tenant key (`tenant_id` / `tenantId`) must equal the caller's
    ///   `tenant_id` claim;
    /// - with `enforce_scope`, an authenticated caller must supply at least
    ///   one such scope key (non-object args are rejected outright).
    ///
    /// # Errors
    /// `Unauthorized` for unauthenticated use of scoped keys, `Forbidden` on
    /// mismatch or missing scope, `InvalidArgument` for non-string values.
    fn check_identity_args(
        function_name: &str,
        args: &serde_json::Value,
        auth: &forge_core::function::AuthContext,
        enforce_scope: bool,
    ) -> forge_core::Result<()> {
        if auth.is_admin() {
            return Ok(());
        }

        // Non-object args cannot carry scope keys; that is only a problem
        // when scope is mandatory for this caller.
        let Some(obj) = args.as_object() else {
            if enforce_scope && auth.is_authenticated() {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' must include identity or tenant scope arguments"
                )));
            }
            return Ok(());
        };

        // Acceptable principal spellings: the user id plus a deduplicated
        // principal/subject id.
        let mut principal_values: Vec<String> = Vec::new();
        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
            principal_values.push(user_id);
        }
        if let Some(subject) = auth.principal_id()
            && !principal_values.iter().any(|v| v == &subject)
        {
            principal_values.push(subject);
        }

        let mut has_scope_key = false;

        for key in [
            "user_id",
            "userId",
            "owner_id",
            "ownerId",
            "owner_subject",
            "ownerSubject",
            "subject",
            "sub",
            "principal_id",
            "principalId",
        ] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
                )));
            }

            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            // The supplied value must exactly match one principal spelling.
            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
                )));
            }
        }

        for key in ["tenant_id", "tenantId"] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
                )));
            }

            // A caller without a tenant claim may not pass tenant arguments.
            let expected = auth
                .claim("tenant_id")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    forge_core::ForgeError::Forbidden(format!(
                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
                    ))
                })?;

            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            if actual.trim().is_empty() || actual != expected {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
                )));
            }
        }

        if enforce_scope && auth.is_authenticated() && !has_scope_key {
            return Err(forge_core::ForgeError::Forbidden(format!(
                "Function '{function_name}' must include identity or tenant scope arguments"
            )));
        }

        Ok(())
    }
1126
    /// Verifies the caller may access job `job_id` by comparing its stored
    /// `owner_subject` against the caller (see [`Self::check_owner_access`]).
    ///
    /// # Errors
    /// `NotFound` when the job row is missing; otherwise the ownership
    /// check's `Unauthorized` / `Forbidden` errors.
    async fn ensure_job_access(
        db_pool: &sqlx::PgPool,
        job_id: Uuid,
        auth: &forge_core::function::AuthContext,
    ) -> forge_core::Result<()> {
        let owner_subject_row: Option<(Option<String>,)> =
            sqlx::query_as(r#"SELECT owner_subject FROM forge_jobs WHERE id = $1"#)
                .bind(job_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        // Outer Option = row existence; inner Option = nullable owner.
        let owner_subject = owner_subject_row
            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))?
            .0;

        Self::check_owner_access(owner_subject, auth)
    }

    /// Verifies the caller may access workflow run `workflow_id`; mirrors
    /// [`Self::ensure_job_access`].
    ///
    /// # Errors
    /// `NotFound` when the run row is missing; otherwise the ownership
    /// check's `Unauthorized` / `Forbidden` errors.
    async fn ensure_workflow_access(
        db_pool: &sqlx::PgPool,
        workflow_id: Uuid,
        auth: &forge_core::function::AuthContext,
    ) -> forge_core::Result<()> {
        let owner_subject_row: Option<(Option<String>,)> =
            sqlx::query_as(r#"SELECT owner_subject FROM forge_workflow_runs WHERE id = $1"#)
                .bind(workflow_id)
                .fetch_optional(db_pool)
                .await
                .map_err(forge_core::ForgeError::Sql)?;

        // Outer Option = row existence; inner Option = nullable owner.
        let owner_subject = owner_subject_row
            .ok_or_else(|| {
                forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
            })?
            .0;

        Self::check_owner_access(owner_subject, auth)
    }
1166
1167 fn check_owner_access(
1168 owner_subject: Option<String>,
1169 auth: &forge_core::function::AuthContext,
1170 ) -> forge_core::Result<()> {
1171 if auth.is_admin() {
1172 return Ok(());
1173 }
1174
1175 let principal = auth.principal_id().ok_or_else(|| {
1176 forge_core::ForgeError::Unauthorized("Authentication required".to_string())
1177 })?;
1178
1179 match owner_subject {
1180 Some(owner) if owner == principal => Ok(()),
1181 Some(_) => Err(forge_core::ForgeError::Forbidden(
1182 "Not authorized to access this resource".to_string(),
1183 )),
1184 None => Err(forge_core::ForgeError::Forbidden(
1185 "Resource has no owner; admin role required".to_string(),
1186 )),
1187 }
1188 }
1189
1190 pub fn stop(&self) {
1191 let _ = self.shutdown_tx.send(());
1192 self.change_listener.stop();
1193 }
1194
1195 pub async fn stats(&self) -> ReactorStats {
1196 let session_stats = self.session_server.stats();
1197 let inv_stats = self.invalidation_engine.stats().await;
1198
1199 ReactorStats {
1200 connections: session_stats.connections,
1201 subscriptions: session_stats.subscriptions,
1202 query_groups: self.subscription_manager.group_count(),
1203 pending_invalidations: inv_stats.pending_groups,
1204 listener_running: self.change_listener.is_running(),
1205 }
1206 }
1207}
1208
/// Point-in-time snapshot of reactor metrics, produced by `Reactor::stats`.
#[derive(Debug, Clone)]
pub struct ReactorStats {
    /// Number of active realtime connections reported by the session server.
    pub connections: usize,
    /// Number of active subscriptions reported by the session server.
    pub subscriptions: usize,
    /// Number of query groups tracked by the subscription manager.
    pub query_groups: usize,
    /// Number of groups with pending invalidations in the invalidation engine.
    pub pending_invalidations: usize,
    /// Whether the Postgres change listener is currently running.
    pub listener_running: bool,
}
1218
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an authenticated user-role context whose `sub` claim matches
    /// `user_id` — shared setup for the identity-argument tests.
    fn user_auth(user_id: uuid::Uuid) -> forge_core::function::AuthContext {
        forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        )
    }

    #[test]
    fn test_reactor_config_default() {
        let cfg = ReactorConfig::default();
        assert_eq!(cfg.listener.channel, "forge_changes");
        assert_eq!(cfg.invalidation.debounce_ms, 50);
        assert_eq!(cfg.max_listener_restarts, 5);
        assert_eq!(cfg.listener_restart_delay_ms, 1000);
        assert_eq!(cfg.max_concurrent_reexecutions, 64);
        assert_eq!(cfg.session_cleanup_interval_secs, 60);
    }

    #[test]
    fn test_compute_hash() {
        // Equal JSON values must hash identically; different values must not.
        let a = serde_json::json!({"name": "test"});
        let b = serde_json::json!({"name": "test"});
        let c = serde_json::json!({"name": "different"});

        assert_eq!(Reactor::compute_hash(&a), Reactor::compute_hash(&b));
        assert_ne!(Reactor::compute_hash(&a), Reactor::compute_hash(&c));
    }

    #[test]
    fn test_check_identity_args_rejects_cross_user() {
        let auth = user_auth(uuid::Uuid::new_v4());

        // A user_id argument belonging to a different user must be rejected.
        let foreign_id = uuid::Uuid::new_v4().to_string();
        let result = Reactor::check_identity_args(
            "list_orders",
            &serde_json::json!({"user_id": foreign_id}),
            &auth,
            true,
        );
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_identity_args_requires_scope_for_non_public_queries() {
        let auth = user_auth(uuid::Uuid::new_v4());

        // With scope enforcement on, an argument object carrying no identity
        // or tenant scope key must be rejected.
        let result =
            Reactor::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    #[test]
    fn test_check_owner_access_allows_admin() {
        // The admin role bypasses ownership checks, even for resources owned
        // by somebody else.
        let auth = forge_core::function::AuthContext::authenticated_without_uuid(
            vec!["admin".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String("admin-1".to_string()),
            )]),
        );

        assert!(Reactor::check_owner_access(Some("other-user".to_string()), &auth).is_ok());
    }
}