//! forge_runtime/realtime/reactor.rs
//!
//! Realtime reactor: ChangeListener -> InvalidationEngine -> group
//! re-execution -> SSE fan-out.

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use futures_util::StreamExt;
5use futures_util::stream::FuturesUnordered;
6use tokio::sync::{RwLock, Semaphore, broadcast, mpsc};
7use uuid::Uuid;
8
9use forge_core::cluster::NodeId;
10use forge_core::realtime::{Change, ReadSet, SessionId, SubscriptionId};
11
12use super::invalidation::{InvalidationConfig, InvalidationEngine};
13use super::listener::{ChangeListener, ListenerConfig};
14use super::manager::SubscriptionManager;
15use super::message::{
16    JobData, RealtimeConfig, RealtimeMessage, SessionServer, WorkflowData, WorkflowStepData,
17};
18use crate::function::{FunctionEntry, FunctionRegistry};
19
#[derive(Debug, Clone)]
pub struct ReactorConfig {
    /// Configuration for the database change listener.
    pub listener: ListenerConfig,
    /// Configuration for debounced query-group invalidation.
    pub invalidation: InvalidationConfig,
    /// Realtime session/subscription settings (also caps subscriptions per session).
    pub realtime: RealtimeConfig,
    /// Maximum times the change listener is restarted after a failure
    /// before real-time updates are disabled.
    pub max_listener_restarts: u32,
    /// Doubles with each attempt for exponential backoff
    pub listener_restart_delay_ms: u64,
    /// Maximum concurrent re-executions during flush.
    pub max_concurrent_reexecutions: usize,
    /// Session cleanup interval in seconds.
    pub session_cleanup_interval_secs: u64,
}
33
34impl Default for ReactorConfig {
35    fn default() -> Self {
36        Self {
37            listener: ListenerConfig::default(),
38            invalidation: InvalidationConfig::default(),
39            realtime: RealtimeConfig::default(),
40            max_listener_restarts: 5,
41            listener_restart_delay_ms: 1000,
42            max_concurrent_reexecutions: 64,
43            session_cleanup_interval_secs: 60,
44        }
45    }
46}
47
/// Job subscription tracking.
///
/// One entry per client subscribed to a single job's progress updates.
#[derive(Debug, Clone)]
pub struct JobSubscription {
    /// Server-side subscription id; retained for diagnostics only (hence dead_code).
    #[allow(dead_code)]
    pub subscription_id: SubscriptionId,
    /// Session that owns this subscription.
    pub session_id: SessionId,
    /// Client-chosen id echoed back in `JobUpdate` messages.
    pub client_sub_id: String,
    /// Job being followed; also the key of the subscriptions map.
    #[allow(dead_code)]
    pub job_id: Uuid,
    /// Auth context used to re-check job access on every update.
    pub auth_context: forge_core::function::AuthContext,
}
59
/// Workflow subscription tracking.
///
/// One entry per client subscribed to a single workflow run's updates.
#[derive(Debug, Clone)]
pub struct WorkflowSubscription {
    /// Server-side subscription id; retained for diagnostics only (hence dead_code).
    #[allow(dead_code)]
    pub subscription_id: SubscriptionId,
    /// Session that owns this subscription.
    pub session_id: SessionId,
    /// Client-chosen id echoed back in `WorkflowUpdate` messages.
    pub client_sub_id: String,
    /// Workflow run being followed; also the key of the subscriptions map.
    #[allow(dead_code)]
    pub workflow_id: Uuid,
    /// Auth context used to re-check workflow access on every update.
    pub auth_context: forge_core::function::AuthContext,
}
71
/// ChangeListener -> InvalidationEngine -> Group Re-execution -> SSE Fan-out
pub struct Reactor {
    /// Id of this node, exposed via [`Reactor::node_id`].
    node_id: NodeId,
    /// Pool used for job/workflow lookups and query execution.
    db_pool: sqlx::PgPool,
    /// Registered functions; consulted for query metadata and execution.
    registry: FunctionRegistry,
    /// Tracks query subscriptions, coalesced into groups.
    subscription_manager: Arc<SubscriptionManager>,
    /// Routes realtime messages to connected sessions.
    session_server: Arc<SessionServer>,
    /// Streams database change events; restarted with backoff on failure.
    change_listener: Arc<ChangeListener>,
    /// Debounces raw changes into query-group invalidations.
    invalidation_engine: Arc<InvalidationEngine>,
    /// Job subscriptions: job_id -> list of subscribers.
    job_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
    /// Workflow subscriptions: workflow_id -> list of subscribers.
    workflow_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
    /// Signals the background task spawned by `start` to stop.
    shutdown_tx: broadcast::Sender<()>,
    // The remaining fields mirror ReactorConfig; see its docs.
    max_listener_restarts: u32,
    listener_restart_delay_ms: u64,
    max_concurrent_reexecutions: usize,
    session_cleanup_interval_secs: u64,
}
91
92impl Reactor {
93    /// Create a new reactor.
94    pub fn new(
95        node_id: NodeId,
96        db_pool: sqlx::PgPool,
97        registry: FunctionRegistry,
98        config: ReactorConfig,
99    ) -> Self {
100        let subscription_manager = Arc::new(SubscriptionManager::new(
101            config.realtime.max_subscriptions_per_session,
102        ));
103        let session_server = Arc::new(SessionServer::new(node_id, config.realtime.clone()));
104        let change_listener = Arc::new(ChangeListener::new(db_pool.clone(), config.listener));
105        let invalidation_engine = Arc::new(InvalidationEngine::new(
106            subscription_manager.clone(),
107            config.invalidation,
108        ));
109        let (shutdown_tx, _) = broadcast::channel(1);
110
111        Self {
112            node_id,
113            db_pool,
114            registry,
115            subscription_manager,
116            session_server,
117            change_listener,
118            invalidation_engine,
119            job_subscriptions: Arc::new(RwLock::new(HashMap::new())),
120            workflow_subscriptions: Arc::new(RwLock::new(HashMap::new())),
121            shutdown_tx,
122            max_listener_restarts: config.max_listener_restarts,
123            listener_restart_delay_ms: config.listener_restart_delay_ms,
124            max_concurrent_reexecutions: config.max_concurrent_reexecutions,
125            session_cleanup_interval_secs: config.session_cleanup_interval_secs,
126        }
127    }
128
    /// The cluster node id this reactor was created with.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }
132
133    pub fn session_server(&self) -> Arc<SessionServer> {
134        self.session_server.clone()
135    }
136
137    pub fn subscription_manager(&self) -> Arc<SubscriptionManager> {
138        self.subscription_manager.clone()
139    }
140
    /// A receiver that fires when shutdown is signalled on the internal channel.
    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }
144
    /// Register a new session.
    ///
    /// `sender` is the channel through which the session server delivers
    /// realtime messages to this session's connection.
    pub fn register_session(&self, session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) {
        self.session_server.register_connection(session_id, sender);
        tracing::trace!(?session_id, "Session registered");
    }
150
151    /// Remove a session and all its subscriptions.
152    pub async fn remove_session(&self, session_id: SessionId) {
153        // Clean up query subscriptions via the subscription manager
154        self.subscription_manager
155            .remove_session_subscriptions(session_id);
156
157        // Clean up session server
158        self.session_server.remove_connection(session_id);
159
160        // Clean up job subscriptions
161        {
162            let mut job_subs = self.job_subscriptions.write().await;
163            for subscribers in job_subs.values_mut() {
164                subscribers.retain(|s| s.session_id != session_id);
165            }
166            job_subs.retain(|_, v| !v.is_empty());
167        }
168
169        // Clean up workflow subscriptions
170        {
171            let mut workflow_subs = self.workflow_subscriptions.write().await;
172            for subscribers in workflow_subs.values_mut() {
173                subscribers.retain(|s| s.session_id != session_id);
174            }
175            workflow_subs.retain(|_, v| !v.is_empty());
176        }
177
178        tracing::trace!(?session_id, "Session removed");
179    }
180
    /// Subscribe to a query. Uses query groups for coalescing.
    ///
    /// Registers the subscription with the subscription manager and the
    /// session server, executes the query for the initial payload, and
    /// returns the subscription id together with that data. Any failure
    /// after registration rolls the partially created subscription back
    /// before the error is returned.
    pub async fn subscribe(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        query_name: String,
        args: serde_json::Value,
        auth_context: forge_core::function::AuthContext,
    ) -> forge_core::Result<(SubscriptionId, serde_json::Value)> {
        // Look up function info for compile-time metadata
        let (table_deps, selected_cols) = match self.registry.get(&query_name) {
            Some(FunctionEntry::Query { info, .. }) => {
                (info.table_dependencies, info.selected_columns)
            }
            // Not a registered query: fall back to empty metadata.
            _ => (&[] as &[&str], &[] as &[&str]),
        };

        let (group_id, subscription_id, is_new_group) = self.subscription_manager.subscribe(
            session_id,
            client_sub_id,
            &query_name,
            &args,
            &auth_context,
            table_deps,
            selected_cols,
        )?;

        // Register subscription in session server for message routing
        if let Err(error) = self
            .session_server
            .add_subscription(session_id, subscription_id)
        {
            // Roll back the manager-side registration on failure.
            self.subscription_manager.unsubscribe(subscription_id);
            return Err(error);
        }

        // Only execute the query if this is a new group (no cached result yet)
        let data = if is_new_group {
            let (data, read_set) = match self.execute_query(&query_name, &args, &auth_context).await
            {
                Ok(result) => result,
                Err(error) => {
                    // Roll back both session-server and manager registrations.
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };

            let result_hash = Self::compute_hash(&data);

            tracing::trace!(
                ?group_id,
                query = %query_name,
                "New query group created"
            );

            // Seed the group's read set and result hash for invalidation.
            self.subscription_manager
                .update_group(group_id, read_set, result_hash);

            data
        } else {
            // Group exists, re-execute to get fresh data for this subscriber
            // (they might have joined mid-cycle)
            let (data, _) = match self.execute_query(&query_name, &args, &auth_context).await {
                Ok(result) => result,
                Err(error) => {
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };
            data
        };

        tracing::trace!(?subscription_id, "Subscription created");
        Ok((subscription_id, data))
    }
256
    /// Unsubscribe from a query.
    ///
    /// Removes the subscription from both the session server (message
    /// routing) and the subscription manager (group membership).
    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
        self.session_server.remove_subscription(subscription_id);
        self.subscription_manager.unsubscribe(subscription_id);
        tracing::trace!(?subscription_id, "Subscription removed");
    }
263
    /// Subscribe to job progress updates.
    ///
    /// Verifies the caller may access the job, reads its current state for
    /// the initial payload, and records the subscription for future fan-out.
    pub async fn subscribe_job(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        job_id: Uuid,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<JobData> {
        let subscription_id = SubscriptionId::new();
        // Access is checked before any subscription state is recorded.
        Self::ensure_job_access(&self.db_pool, job_id, auth_context).await?;
        let job_data = self.fetch_job_data(job_id).await?;

        let subscription = JobSubscription {
            subscription_id,
            session_id,
            client_sub_id,
            job_id,
            // Kept so access can be re-checked on each later update.
            auth_context: auth_context.clone(),
        };

        let mut subs = self.job_subscriptions.write().await;
        subs.entry(job_id).or_default().push(subscription);

        tracing::trace!(?subscription_id, %job_id, "Job subscription created");
        Ok(job_data)
    }
290
291    /// Unsubscribe from job updates.
292    pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) {
293        let mut subs = self.job_subscriptions.write().await;
294        for subscribers in subs.values_mut() {
295            subscribers
296                .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
297        }
298        subs.retain(|_, v| !v.is_empty());
299    }
300
    /// Subscribe to workflow progress updates.
    ///
    /// Verifies the caller may access the workflow, reads its current state
    /// for the initial payload, and records the subscription for fan-out.
    pub async fn subscribe_workflow(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        workflow_id: Uuid,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<WorkflowData> {
        let subscription_id = SubscriptionId::new();
        // Access is checked before any subscription state is recorded.
        Self::ensure_workflow_access(&self.db_pool, workflow_id, auth_context).await?;
        let workflow_data = self.fetch_workflow_data(workflow_id).await?;

        let subscription = WorkflowSubscription {
            subscription_id,
            session_id,
            client_sub_id,
            workflow_id,
            // Kept so access can be re-checked on each later update.
            auth_context: auth_context.clone(),
        };

        let mut subs = self.workflow_subscriptions.write().await;
        subs.entry(workflow_id).or_default().push(subscription);

        tracing::trace!(?subscription_id, %workflow_id, "Workflow subscription created");
        Ok(workflow_data)
    }
327
328    /// Unsubscribe from workflow updates.
329    pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) {
330        let mut subs = self.workflow_subscriptions.write().await;
331        for subscribers in subs.values_mut() {
332            subscribers
333                .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
334        }
335        subs.retain(|_, v| !v.is_empty());
336    }
337
338    #[allow(clippy::type_complexity)]
339    async fn fetch_job_data(&self, job_id: Uuid) -> forge_core::Result<JobData> {
340        Self::fetch_job_data_static(job_id, &self.db_pool).await
341    }
342
    /// Fetch a workflow's current state using this reactor's pool.
    async fn fetch_workflow_data(&self, workflow_id: Uuid) -> forge_core::Result<WorkflowData> {
        Self::fetch_workflow_data_static(workflow_id, &self.db_pool).await
    }
346
    /// Execute a query and return data with read set.
    ///
    /// Convenience wrapper over [`Self::execute_query_static`] bound to this
    /// reactor's registry and pool.
    async fn execute_query(
        &self,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        Self::execute_query_static(
            &self.registry,
            &self.db_pool,
            query_name,
            args,
            auth_context,
        )
        .await
    }
363
364    fn compute_hash(data: &serde_json::Value) -> String {
365        use std::collections::hash_map::DefaultHasher;
366        use std::hash::{Hash, Hasher};
367
368        let json = serde_json::to_string(data).unwrap_or_default();
369        let mut hasher = DefaultHasher::new();
370        json.hash(&mut hasher);
371        format!("{:x}", hasher.finish())
372    }
373
374    /// Parallel group re-execution with bounded concurrency.
375    async fn flush_invalidations(
376        invalidation_engine: &Arc<InvalidationEngine>,
377        subscription_manager: &Arc<SubscriptionManager>,
378        session_server: &Arc<SessionServer>,
379        registry: &FunctionRegistry,
380        db_pool: &sqlx::PgPool,
381        max_concurrent: usize,
382    ) {
383        let invalidated_groups = invalidation_engine.check_pending().await;
384        if invalidated_groups.is_empty() {
385            return;
386        }
387
388        tracing::trace!(
389            count = invalidated_groups.len(),
390            "Invalidating query groups"
391        );
392
393        // Collect group data we need for re-execution
394        let groups_to_process: Vec<_> = invalidated_groups
395            .iter()
396            .filter_map(|gid| {
397                subscription_manager.get_group(*gid).map(|g| {
398                    (
399                        g.id,
400                        g.query_name.clone(),
401                        (*g.args).clone(),
402                        g.last_result_hash.clone(),
403                        g.auth_context.clone(),
404                    )
405                })
406            })
407            .collect();
408
409        // Parallel re-execution bounded by semaphore
410        let semaphore = Arc::new(Semaphore::new(max_concurrent));
411        let mut futures = FuturesUnordered::new();
412
413        for (group_id, query_name, args, last_hash, auth_context) in groups_to_process {
414            let permit = match semaphore.clone().acquire_owned().await {
415                Ok(p) => p,
416                Err(_) => break,
417            };
418            let registry = registry.clone();
419            let db_pool = db_pool.clone();
420
421            futures.push(async move {
422                let result = Self::execute_query_static(
423                    &registry,
424                    &db_pool,
425                    &query_name,
426                    &args,
427                    &auth_context,
428                )
429                .await;
430                drop(permit);
431                (group_id, last_hash, result)
432            });
433        }
434
435        // Process results and fan out to subscribers
436        while let Some((group_id, last_hash, result)) = futures.next().await {
437            match result {
438                Ok((new_data, read_set)) => {
439                    let new_hash = Self::compute_hash(&new_data);
440
441                    if last_hash.as_ref() != Some(&new_hash) {
442                        // Update group state
443                        subscription_manager.update_group(group_id, read_set, new_hash);
444
445                        // Fan out to all subscribers in this group
446                        let subscribers = subscription_manager.get_group_subscribers(group_id);
447                        for (session_id, client_sub_id) in subscribers {
448                            let message = RealtimeMessage::Data {
449                                subscription_id: client_sub_id.clone(),
450                                data: new_data.clone(),
451                            };
452
453                            if let Err(e) = session_server.try_send_to_session(session_id, message)
454                            {
455                                tracing::trace!(
456                                    client_id = %client_sub_id,
457                                    error = ?e,
458                                    "Failed to send update to subscriber"
459                                );
460                            }
461                        }
462                    }
463                }
464                Err(e) => {
465                    tracing::warn!(?group_id, error = %e, "Failed to re-execute query group");
466                }
467            }
468        }
469    }
470
    /// Start the reactor.
    ///
    /// Spawns a background task that:
    /// - consumes change events from the listener and dispatches them,
    /// - flushes pending invalidations every 25ms,
    /// - periodically removes stale sessions,
    /// - restarts the change listener with exponential backoff on failure,
    /// - exits on shutdown, channel close, or exhausted restarts.
    pub async fn start(&self) -> forge_core::Result<()> {
        // Clone everything the detached task needs; the task owns its copies.
        let listener = self.change_listener.clone();
        let invalidation_engine = self.invalidation_engine.clone();
        let subscription_manager = self.subscription_manager.clone();
        let job_subscriptions = self.job_subscriptions.clone();
        let workflow_subscriptions = self.workflow_subscriptions.clone();
        let session_server = self.session_server.clone();
        let registry = self.registry.clone();
        let db_pool = self.db_pool.clone();
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        let max_restarts = self.max_listener_restarts;
        let base_delay_ms = self.listener_restart_delay_ms;
        let max_concurrent = self.max_concurrent_reexecutions;
        let cleanup_secs = self.session_cleanup_interval_secs;

        let mut change_rx = listener.subscribe();

        tokio::spawn(async move {
            tracing::debug!("Reactor listening for changes");

            let mut restart_count: u32 = 0;
            // Listener failures are reported through this channel so the
            // select loop below can schedule a restart.
            let (listener_error_tx, mut listener_error_rx) = mpsc::channel::<String>(1);

            // Start initial listener
            let listener_clone = listener.clone();
            let error_tx = listener_error_tx.clone();
            let mut listener_handle = Some(tokio::spawn(async move {
                if let Err(e) = listener_clone.run().await {
                    let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                }
            }));

            // Invalidation flush cadence; missed ticks are skipped rather
            // than bursted to avoid flush storms after a stall.
            let mut flush_interval = tokio::time::interval(std::time::Duration::from_millis(25));
            flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            let mut cleanup_interval =
                tokio::time::interval(std::time::Duration::from_secs(cleanup_secs));
            cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            loop {
                tokio::select! {
                    result = change_rx.recv() => {
                        match result {
                            Ok(change) => {
                                Self::handle_change(
                                    &change,
                                    &invalidation_engine,
                                    &job_subscriptions,
                                    &workflow_subscriptions,
                                    &session_server,
                                    &db_pool,
                                ).await;
                            }
                            Err(broadcast::error::RecvError::Lagged(n)) => {
                                // Broadcast buffer overflowed; n changes were dropped.
                                tracing::warn!("Reactor lagged by {} messages", n);
                            }
                            Err(broadcast::error::RecvError::Closed) => {
                                tracing::debug!("Change channel closed");
                                break;
                            }
                        }
                    }
                    _ = flush_interval.tick() => {
                        Self::flush_invalidations(
                            &invalidation_engine,
                            &subscription_manager,
                            &session_server,
                            &registry,
                            &db_pool,
                            max_concurrent,
                        ).await;
                    }
                    _ = cleanup_interval.tick() => {
                        // 5-minute staleness threshold for session cleanup.
                        session_server.cleanup_stale(std::time::Duration::from_secs(300));
                    }
                    Some(error_msg) = listener_error_rx.recv() => {
                        if restart_count >= max_restarts {
                            tracing::error!(
                                attempts = restart_count,
                                last_error = %error_msg,
                                "Change listener failed permanently, real-time updates disabled"
                            );
                            break;
                        }

                        restart_count += 1;
                        // Exponential backoff: base * 2^(attempt - 1).
                        let delay = base_delay_ms * 2u64.saturating_pow(restart_count - 1);
                        tracing::warn!(
                            attempt = restart_count,
                            max = max_restarts,
                            delay_ms = delay,
                            error = %error_msg,
                            "Change listener restarting"
                        );

                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;

                        let listener_clone = listener.clone();
                        let error_tx = listener_error_tx.clone();
                        if let Some(handle) = listener_handle.take() {
                            handle.abort();
                        }
                        // Re-subscribe so we read from the restarted listener.
                        change_rx = listener.subscribe();
                        listener_handle = Some(tokio::spawn(async move {
                            if let Err(e) = listener_clone.run().await {
                                let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                            }
                        }));
                    }
                    _ = shutdown_rx.recv() => {
                        tracing::debug!("Reactor shutting down");
                        break;
                    }
                }
            }

            // Tear down the listener task when the loop exits for any reason.
            if let Some(handle) = listener_handle {
                handle.abort();
            }
        });

        Ok(())
    }
595
596    /// Handle a database change event.
597    #[allow(clippy::too_many_arguments)]
598    async fn handle_change(
599        change: &Change,
600        invalidation_engine: &Arc<InvalidationEngine>,
601        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
602        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
603        session_server: &Arc<SessionServer>,
604        db_pool: &sqlx::PgPool,
605    ) {
606        tracing::trace!(table = %change.table, op = ?change.operation, row_id = ?change.row_id, "Processing change");
607
608        match change.table.as_str() {
609            "forge_jobs" => {
610                if let Some(job_id) = change.row_id {
611                    Self::handle_job_change(job_id, job_subscriptions, session_server, db_pool)
612                        .await;
613                }
614                return;
615            }
616            "forge_workflow_runs" => {
617                if let Some(workflow_id) = change.row_id {
618                    Self::handle_workflow_change(
619                        workflow_id,
620                        workflow_subscriptions,
621                        session_server,
622                        db_pool,
623                    )
624                    .await;
625                }
626                return;
627            }
628            "forge_workflow_steps" => {
629                if let Some(step_id) = change.row_id {
630                    Self::handle_workflow_step_change(
631                        step_id,
632                        workflow_subscriptions,
633                        session_server,
634                        db_pool,
635                    )
636                    .await;
637                }
638                return;
639            }
640            _ => {}
641        }
642
643        // Record change for debounced group invalidation
644        invalidation_engine.process_change(change.clone()).await;
645    }
646
647    async fn handle_job_change(
648        job_id: Uuid,
649        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
650        session_server: &Arc<SessionServer>,
651        db_pool: &sqlx::PgPool,
652    ) {
653        let subs = job_subscriptions.read().await;
654        let subscribers = match subs.get(&job_id) {
655            Some(s) if !s.is_empty() => s.clone(),
656            _ => return,
657        };
658        drop(subs);
659
660        let job_data = match Self::fetch_job_data_static(job_id, db_pool).await {
661            Ok(data) => data,
662            Err(e) => {
663                tracing::debug!(%job_id, error = %e, "Failed to fetch job data");
664                return;
665            }
666        };
667
668        let owner_subject = match Self::fetch_job_owner_subject_static(job_id, db_pool).await {
669            Ok(owner) => owner,
670            Err(e) => {
671                tracing::debug!(%job_id, error = %e, "Failed to fetch job owner");
672                return;
673            }
674        };
675
676        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();
677
678        for sub in &subscribers {
679            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
680                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
681                continue;
682            }
683
684            let message = RealtimeMessage::JobUpdate {
685                client_sub_id: sub.client_sub_id.clone(),
686                job: job_data.clone(),
687            };
688
689            if let Err(e) = session_server
690                .send_to_session(sub.session_id, message)
691                .await
692            {
693                tracing::trace!(%job_id, error = %e, "Failed to send job update");
694            }
695        }
696
697        if !unauthorized.is_empty() {
698            let mut subs = job_subscriptions.write().await;
699            if let Some(entries) = subs.get_mut(&job_id) {
700                entries
701                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
702            }
703            subs.retain(|_, v| !v.is_empty());
704        }
705    }
706
707    async fn handle_workflow_change(
708        workflow_id: Uuid,
709        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
710        session_server: &Arc<SessionServer>,
711        db_pool: &sqlx::PgPool,
712    ) {
713        let subs = workflow_subscriptions.read().await;
714        let subscribers = match subs.get(&workflow_id) {
715            Some(s) if !s.is_empty() => s.clone(),
716            _ => return,
717        };
718        drop(subs);
719
720        let workflow_data = match Self::fetch_workflow_data_static(workflow_id, db_pool).await {
721            Ok(data) => data,
722            Err(e) => {
723                tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow data");
724                return;
725            }
726        };
727
728        let owner_subject =
729            match Self::fetch_workflow_owner_subject_static(workflow_id, db_pool).await {
730                Ok(owner) => owner,
731                Err(e) => {
732                    tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow owner");
733                    return;
734                }
735            };
736
737        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();
738
739        for sub in &subscribers {
740            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
741                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
742                continue;
743            }
744
745            let message = RealtimeMessage::WorkflowUpdate {
746                client_sub_id: sub.client_sub_id.clone(),
747                workflow: workflow_data.clone(),
748            };
749
750            if let Err(e) = session_server
751                .send_to_session(sub.session_id, message)
752                .await
753            {
754                tracing::trace!(%workflow_id, error = %e, "Failed to send workflow update");
755            }
756        }
757
758        if !unauthorized.is_empty() {
759            let mut subs = workflow_subscriptions.write().await;
760            if let Some(entries) = subs.get_mut(&workflow_id) {
761                entries
762                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
763            }
764            subs.retain(|_, v| !v.is_empty());
765        }
766    }
767
768    async fn handle_workflow_step_change(
769        step_id: Uuid,
770        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
771        session_server: &Arc<SessionServer>,
772        db_pool: &sqlx::PgPool,
773    ) {
774        let workflow_id: Option<Uuid> = match sqlx::query_scalar(
775            "SELECT workflow_run_id FROM forge_workflow_steps WHERE id = $1",
776        )
777        .bind(step_id)
778        .fetch_optional(db_pool)
779        .await
780        {
781            Ok(id) => id,
782            Err(e) => {
783                tracing::debug!(%step_id, error = %e, "Failed to look up workflow for step");
784                return;
785            }
786        };
787
788        if let Some(wf_id) = workflow_id {
789            Self::handle_workflow_change(wf_id, workflow_subscriptions, session_server, db_pool)
790                .await;
791        }
792    }
793
794    #[allow(clippy::type_complexity)]
795    async fn fetch_job_data_static(
796        job_id: Uuid,
797        db_pool: &sqlx::PgPool,
798    ) -> forge_core::Result<JobData> {
799        let row: Option<(
800            String,
801            Option<i32>,
802            Option<String>,
803            Option<serde_json::Value>,
804            Option<String>,
805        )> = sqlx::query_as(
806            r#"
807                SELECT status, progress_percent, progress_message, output, last_error
808                FROM forge_jobs WHERE id = $1
809                "#,
810        )
811        .bind(job_id)
812        .fetch_optional(db_pool)
813        .await
814        .map_err(forge_core::ForgeError::Sql)?;
815
816        match row {
817            Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
818                job_id: job_id.to_string(),
819                status,
820                progress_percent,
821                progress_message,
822                output,
823                error,
824            }),
825            None => Err(forge_core::ForgeError::NotFound(format!(
826                "Job {} not found",
827                job_id
828            ))),
829        }
830    }
831
832    async fn fetch_job_owner_subject_static(
833        job_id: Uuid,
834        db_pool: &sqlx::PgPool,
835    ) -> forge_core::Result<Option<String>> {
836        let owner_subject: Option<Option<String>> =
837            sqlx::query_scalar("SELECT owner_subject FROM forge_jobs WHERE id = $1")
838                .bind(job_id)
839                .fetch_optional(db_pool)
840                .await
841                .map_err(forge_core::ForgeError::Sql)?;
842
843        owner_subject
844            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))
845    }
846
847    #[allow(clippy::type_complexity)]
848    async fn fetch_workflow_data_static(
849        workflow_id: Uuid,
850        db_pool: &sqlx::PgPool,
851    ) -> forge_core::Result<WorkflowData> {
852        let row: Option<(
853            String,
854            Option<String>,
855            Option<serde_json::Value>,
856            Option<String>,
857        )> = sqlx::query_as(
858            r#"
859                SELECT status, current_step, output, error
860                FROM forge_workflow_runs WHERE id = $1
861                "#,
862        )
863        .bind(workflow_id)
864        .fetch_optional(db_pool)
865        .await
866        .map_err(forge_core::ForgeError::Sql)?;
867
868        let (status, current_step, output, error) = match row {
869            Some(r) => r,
870            None => {
871                return Err(forge_core::ForgeError::NotFound(format!(
872                    "Workflow {} not found",
873                    workflow_id
874                )));
875            }
876        };
877
878        let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
879            r#"
880            SELECT step_name, status, error
881            FROM forge_workflow_steps
882            WHERE workflow_run_id = $1
883            ORDER BY started_at ASC NULLS LAST
884            "#,
885        )
886        .bind(workflow_id)
887        .fetch_all(db_pool)
888        .await
889        .map_err(forge_core::ForgeError::Sql)?;
890
891        let steps = step_rows
892            .into_iter()
893            .map(|(name, status, error)| WorkflowStepData {
894                name,
895                status,
896                error,
897            })
898            .collect();
899
900        Ok(WorkflowData {
901            workflow_id: workflow_id.to_string(),
902            status,
903            current_step,
904            steps,
905            output,
906            error,
907        })
908    }
909
910    async fn fetch_workflow_owner_subject_static(
911        workflow_id: Uuid,
912        db_pool: &sqlx::PgPool,
913    ) -> forge_core::Result<Option<String>> {
914        let owner_subject: Option<Option<String>> =
915            sqlx::query_scalar("SELECT owner_subject FROM forge_workflow_runs WHERE id = $1")
916                .bind(workflow_id)
917                .fetch_optional(db_pool)
918                .await
919                .map_err(forge_core::ForgeError::Sql)?;
920
921        owner_subject.ok_or_else(|| {
922            forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
923        })
924    }
925
926    async fn execute_query_static(
927        registry: &FunctionRegistry,
928        db_pool: &sqlx::PgPool,
929        query_name: &str,
930        args: &serde_json::Value,
931        auth_context: &forge_core::function::AuthContext,
932    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
933        match registry.get(query_name) {
934            Some(FunctionEntry::Query { info, handler }) => {
935                Self::check_query_auth(info, auth_context)?;
936                let enforce = !info.is_public && info.has_input_args;
937                Self::check_identity_args(query_name, args, auth_context, enforce)?;
938
939                let ctx = forge_core::function::QueryContext::new(
940                    db_pool.clone(),
941                    auth_context.clone(),
942                    forge_core::function::RequestMetadata::new(),
943                );
944
945                let normalized_args = match args {
946                    v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
947                    v => v.clone(),
948                };
949
950                let data = handler(&ctx, normalized_args).await?;
951
952                let mut read_set = ReadSet::new();
953
954                if info.table_dependencies.is_empty() {
955                    let table_name = Self::extract_table_name(query_name);
956                    read_set.add_table(&table_name);
957                    tracing::trace!(
958                        query = %query_name,
959                        fallback_table = %table_name,
960                        "Using naming convention fallback for table dependency"
961                    );
962                } else {
963                    for table in info.table_dependencies {
964                        read_set.add_table(*table);
965                    }
966                }
967
968                Ok((data, read_set))
969            }
970            _ => Err(forge_core::ForgeError::Validation(format!(
971                "Query '{}' not found or not a query",
972                query_name
973            ))),
974        }
975    }
976
977    fn extract_table_name(query_name: &str) -> String {
978        if let Some(rest) = query_name.strip_prefix("get_") {
979            rest.to_string()
980        } else if let Some(rest) = query_name.strip_prefix("list_") {
981            rest.to_string()
982        } else if let Some(rest) = query_name.strip_prefix("find_") {
983            rest.to_string()
984        } else if let Some(rest) = query_name.strip_prefix("fetch_") {
985            rest.to_string()
986        } else {
987            query_name.to_string()
988        }
989    }
990
991    fn check_query_auth(
992        info: &forge_core::function::FunctionInfo,
993        auth: &forge_core::function::AuthContext,
994    ) -> forge_core::Result<()> {
995        if info.is_public {
996            return Ok(());
997        }
998
999        if !auth.is_authenticated() {
1000            return Err(forge_core::ForgeError::Unauthorized(
1001                "Authentication required".into(),
1002            ));
1003        }
1004
1005        if let Some(role) = info.required_role
1006            && !auth.has_role(role)
1007        {
1008            return Err(forge_core::ForgeError::Forbidden(format!(
1009                "Role '{}' required",
1010                role
1011            )));
1012        }
1013
1014        Ok(())
1015    }
1016
    /// Enforce that identity- and tenant-scoped arguments match the
    /// authenticated principal.
    ///
    /// Admins bypass all checks. For everyone else, any recognized scope key
    /// present in `args` must equal the caller's own identity (or tenant
    /// claim). When `enforce_scope` is set, an authenticated caller must
    /// supply at least one scope key.
    ///
    /// # Errors
    /// - `Unauthorized` when a scope key is supplied without authentication.
    /// - `InvalidArgument` when a scope key is not a string value.
    /// - `Forbidden` when a value does not match the caller, when the caller
    ///   has no tenant claim but passes a tenant key, or when `enforce_scope`
    ///   is violated.
    fn check_identity_args(
        function_name: &str,
        args: &serde_json::Value,
        auth: &forge_core::function::AuthContext,
        enforce_scope: bool,
    ) -> forge_core::Result<()> {
        // Admins may query across users and tenants.
        if auth.is_admin() {
            return Ok(());
        }

        // Non-object args carry no scope keys; that is only rejected when
        // scope enforcement is on and the caller is authenticated.
        let Some(obj) = args.as_object() else {
            if enforce_scope && auth.is_authenticated() {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' must include identity or tenant scope arguments"
                )));
            }
            return Ok(());
        };

        // Every string the caller may legitimately use to refer to themselves:
        // the UUID user id and (deduplicated) the token principal/subject id.
        let mut principal_values: Vec<String> = Vec::new();
        if let Some(user_id) = auth.user_id().map(|id| id.to_string()) {
            principal_values.push(user_id);
        }
        if let Some(subject) = auth.principal_id()
            && !principal_values.iter().any(|v| v == &subject)
        {
            principal_values.push(subject);
        }

        let mut has_scope_key = false;

        // Identity-scoped keys, in both snake_case and camelCase spellings.
        for key in [
            "user_id",
            "userId",
            "owner_id",
            "ownerId",
            "owner_subject",
            "ownerSubject",
            "subject",
            "sub",
            "principal_id",
            "principalId",
        ] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            // Anonymous callers may never pass identity-scoped arguments.
            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for identity-scoped argument '{key}'"
                )));
            }

            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            // The supplied value must be non-blank and match one of the
            // caller's own identifiers.
            if actual.trim().is_empty() || !principal_values.iter().any(|v| v == actual) {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated principal"
                )));
            }
        }

        // Tenant-scoped keys: must match the caller's "tenant_id" claim exactly.
        for key in ["tenant_id", "tenantId"] {
            let Some(value) = obj.get(key) else {
                continue;
            };
            has_scope_key = true;

            if !auth.is_authenticated() {
                return Err(forge_core::ForgeError::Unauthorized(format!(
                    "Function '{function_name}' requires authentication for tenant-scoped argument '{key}'"
                )));
            }

            // A caller with no tenant claim may not pass tenant arguments at all.
            let expected = auth
                .claim("tenant_id")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    forge_core::ForgeError::Forbidden(format!(
                        "Function '{function_name}' argument '{key}' is not allowed for this principal"
                    ))
                })?;

            let serde_json::Value::String(actual) = value else {
                return Err(forge_core::ForgeError::InvalidArgument(format!(
                    "Function '{function_name}' argument '{key}' must be a non-empty string"
                )));
            };

            if actual.trim().is_empty() || actual != expected {
                return Err(forge_core::ForgeError::Forbidden(format!(
                    "Function '{function_name}' argument '{key}' does not match authenticated tenant"
                )));
            }
        }

        // With enforcement on, an authenticated caller must have supplied at
        // least one identity or tenant scope key above.
        if enforce_scope && auth.is_authenticated() && !has_scope_key {
            return Err(forge_core::ForgeError::Forbidden(format!(
                "Function '{function_name}' must include identity or tenant scope arguments"
            )));
        }

        Ok(())
    }
1126
1127    async fn ensure_job_access(
1128        db_pool: &sqlx::PgPool,
1129        job_id: Uuid,
1130        auth: &forge_core::function::AuthContext,
1131    ) -> forge_core::Result<()> {
1132        let owner_subject_row: Option<(Option<String>,)> =
1133            sqlx::query_as(r#"SELECT owner_subject FROM forge_jobs WHERE id = $1"#)
1134                .bind(job_id)
1135                .fetch_optional(db_pool)
1136                .await
1137                .map_err(forge_core::ForgeError::Sql)?;
1138
1139        let owner_subject = owner_subject_row
1140            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))?
1141            .0;
1142
1143        Self::check_owner_access(owner_subject, auth)
1144    }
1145
1146    async fn ensure_workflow_access(
1147        db_pool: &sqlx::PgPool,
1148        workflow_id: Uuid,
1149        auth: &forge_core::function::AuthContext,
1150    ) -> forge_core::Result<()> {
1151        let owner_subject_row: Option<(Option<String>,)> =
1152            sqlx::query_as(r#"SELECT owner_subject FROM forge_workflow_runs WHERE id = $1"#)
1153                .bind(workflow_id)
1154                .fetch_optional(db_pool)
1155                .await
1156                .map_err(forge_core::ForgeError::Sql)?;
1157
1158        let owner_subject = owner_subject_row
1159            .ok_or_else(|| {
1160                forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
1161            })?
1162            .0;
1163
1164        Self::check_owner_access(owner_subject, auth)
1165    }
1166
1167    fn check_owner_access(
1168        owner_subject: Option<String>,
1169        auth: &forge_core::function::AuthContext,
1170    ) -> forge_core::Result<()> {
1171        if auth.is_admin() {
1172            return Ok(());
1173        }
1174
1175        let principal = auth.principal_id().ok_or_else(|| {
1176            forge_core::ForgeError::Unauthorized("Authentication required".to_string())
1177        })?;
1178
1179        match owner_subject {
1180            Some(owner) if owner == principal => Ok(()),
1181            Some(_) => Err(forge_core::ForgeError::Forbidden(
1182                "Not authorized to access this resource".to_string(),
1183            )),
1184            None => Err(forge_core::ForgeError::Forbidden(
1185                "Resource has no owner; admin role required".to_string(),
1186            )),
1187        }
1188    }
1189
    /// Shut the reactor down: broadcast on the shutdown channel (the send
    /// error is ignored — it only means no task is currently listening) and
    /// then stop the change listener.
    pub fn stop(&self) {
        let _ = self.shutdown_tx.send(());
        self.change_listener.stop();
    }
1194
    /// Snapshot current reactor metrics, aggregating counts from the session
    /// server, subscription manager, invalidation engine, and change listener.
    pub async fn stats(&self) -> ReactorStats {
        let session_stats = self.session_server.stats();
        let inv_stats = self.invalidation_engine.stats().await;

        ReactorStats {
            connections: session_stats.connections,
            subscriptions: session_stats.subscriptions,
            query_groups: self.subscription_manager.group_count(),
            pending_invalidations: inv_stats.pending_groups,
            listener_running: self.change_listener.is_running(),
        }
    }
1207}
1208
/// Reactor statistics.
///
/// A point-in-time snapshot assembled by `Reactor::stats` from the session
/// server, subscription manager, invalidation engine, and change listener.
#[derive(Debug, Clone)]
pub struct ReactorStats {
    /// Connection count as reported by the session server.
    pub connections: usize,
    /// Subscription count as reported by the session server.
    pub subscriptions: usize,
    /// Number of query groups tracked by the subscription manager.
    pub query_groups: usize,
    /// Number of groups with pending invalidations in the invalidation engine.
    pub pending_invalidations: usize,
    /// Whether the Postgres change listener is currently running.
    pub listener_running: bool,
}
1218
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Pins the documented defaults; update alongside `ReactorConfig::default()`.
    #[test]
    fn test_reactor_config_default() {
        let config = ReactorConfig::default();
        assert_eq!(config.listener.channel, "forge_changes");
        assert_eq!(config.invalidation.debounce_ms, 50);
        assert_eq!(config.max_listener_restarts, 5);
        assert_eq!(config.listener_restart_delay_ms, 1000);
        assert_eq!(config.max_concurrent_reexecutions, 64);
        assert_eq!(config.session_cleanup_interval_secs, 60);
    }

    // Hashing must be deterministic for equal JSON and differ for these
    // distinct inputs (used to suppress duplicate realtime updates).
    #[test]
    fn test_compute_hash() {
        let data1 = serde_json::json!({"name": "test"});
        let data2 = serde_json::json!({"name": "test"});
        let data3 = serde_json::json!({"name": "different"});

        let hash1 = Reactor::compute_hash(&data1);
        let hash2 = Reactor::compute_hash(&data2);
        let hash3 = Reactor::compute_hash(&data3);

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    // A non-admin passing someone else's user_id must be rejected.
    #[test]
    fn test_check_identity_args_rejects_cross_user() {
        let user_id = uuid::Uuid::new_v4();
        let auth = forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result = Reactor::check_identity_args(
            "list_orders",
            // A freshly generated UUID cannot match the caller's identity.
            &serde_json::json!({"user_id": uuid::Uuid::new_v4().to_string()}),
            &auth,
            true,
        );
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    // With enforce_scope on, empty args from an authenticated non-admin
    // must be rejected.
    #[test]
    fn test_check_identity_args_requires_scope_for_non_public_queries() {
        let user_id = uuid::Uuid::new_v4();
        let auth = forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result =
            Reactor::check_identity_args("list_orders", &serde_json::json!({}), &auth, true);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    // Admins may access resources owned by other principals.
    #[test]
    fn test_check_owner_access_allows_admin() {
        let auth = forge_core::function::AuthContext::authenticated_without_uuid(
            vec!["admin".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String("admin-1".to_string()),
            )]),
        );

        let result = Reactor::check_owner_access(Some("other-user".to_string()), &auth);
        assert!(result.is_ok());
    }
}