Skip to main content

forge_runtime/realtime/
reactor.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use futures_util::StreamExt;
5use futures_util::stream::FuturesUnordered;
6use tokio::sync::{RwLock, Semaphore, broadcast, mpsc};
7use uuid::Uuid;
8
9use forge_core::cluster::NodeId;
10use forge_core::realtime::{Change, ReadSet, SessionId, SubscriptionId};
11
12use super::invalidation::{InvalidationConfig, InvalidationEngine};
13use super::listener::{ChangeListener, ListenerConfig};
14use super::manager::SubscriptionManager;
15use super::message::{
16    JobData, RealtimeConfig, RealtimeMessage, SessionServer, WorkflowData, WorkflowStepData,
17};
18use crate::function::{FunctionEntry, FunctionRegistry};
19
/// Tuning knobs for the [`Reactor`] and its sub-components.
#[derive(Debug, Clone)]
pub struct ReactorConfig {
    /// Configuration forwarded to the `ChangeListener`.
    pub listener: ListenerConfig,
    /// Configuration forwarded to the `InvalidationEngine`.
    pub invalidation: InvalidationConfig,
    /// Session/subscription limits shared with the `SessionServer` and
    /// `SubscriptionManager`.
    pub realtime: RealtimeConfig,
    /// Maximum number of times the change listener is restarted after failure.
    pub max_listener_restarts: u32,
    /// Base restart delay; doubles with each attempt for exponential backoff.
    pub listener_restart_delay_ms: u64,
    /// Maximum concurrent re-executions during flush.
    pub max_concurrent_reexecutions: usize,
    /// Session cleanup interval in seconds.
    pub session_cleanup_interval_secs: u64,
}
33
impl Default for ReactorConfig {
    fn default() -> Self {
        Self {
            listener: ListenerConfig::default(),
            invalidation: InvalidationConfig::default(),
            realtime: RealtimeConfig::default(),
            // Up to 5 restarts, starting at 1s and doubling each attempt.
            max_listener_restarts: 5,
            listener_restart_delay_ms: 1000,
            max_concurrent_reexecutions: 64,
            session_cleanup_interval_secs: 60,
        }
    }
}
47
/// Job subscription tracking.
/// The job_id is the HashMap key, not stored here.
#[derive(Debug, Clone)]
pub struct JobSubscription {
    /// Session that owns this subscription; job updates are routed here.
    pub session_id: SessionId,
    /// Client-chosen subscription id echoed back in `JobUpdate` messages.
    pub client_sub_id: String,
    /// Auth context captured at subscribe time; owner access is re-checked
    /// against it when job changes are fanned out.
    pub auth_context: forge_core::function::AuthContext,
}
56
/// Workflow subscription tracking.
/// The workflow_id is the HashMap key, not stored here.
#[derive(Debug, Clone)]
pub struct WorkflowSubscription {
    /// Session that owns this subscription; workflow updates are routed here.
    pub session_id: SessionId,
    /// Client-chosen subscription id echoed back in `WorkflowUpdate` messages.
    pub client_sub_id: String,
    /// Auth context captured at subscribe time; owner access is re-checked
    /// against it when workflow changes are fanned out.
    pub auth_context: forge_core::function::AuthContext,
}
65
/// ChangeListener -> InvalidationEngine -> Group Re-execution -> SSE Fan-out
pub struct Reactor {
    /// Identity of this cluster node.
    node_id: NodeId,
    /// Postgres pool used for query re-execution and job/workflow lookups.
    db_pool: sqlx::PgPool,
    /// Registry used to resolve query functions and their metadata.
    registry: FunctionRegistry,
    /// Coalesces identical query subscriptions into shared groups.
    subscription_manager: Arc<SubscriptionManager>,
    /// Tracks connections and routes realtime messages to sessions.
    session_server: Arc<SessionServer>,
    /// Source of database change events.
    change_listener: Arc<ChangeListener>,
    /// Debounces changes into query-group invalidations.
    invalidation_engine: Arc<InvalidationEngine>,
    /// Job subscriptions: job_id -> list of subscribers.
    job_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
    /// Workflow subscriptions: workflow_id -> list of subscribers.
    workflow_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
    /// Broadcast used to signal the background task to stop.
    shutdown_tx: broadcast::Sender<()>,
    // The remaining fields mirror their `ReactorConfig` counterparts.
    max_listener_restarts: u32,
    listener_restart_delay_ms: u64,
    max_concurrent_reexecutions: usize,
    session_cleanup_interval_secs: u64,
}
85
86impl Reactor {
87    /// Create a new reactor.
88    pub fn new(
89        node_id: NodeId,
90        db_pool: sqlx::PgPool,
91        registry: FunctionRegistry,
92        config: ReactorConfig,
93    ) -> Self {
94        let subscription_manager = Arc::new(SubscriptionManager::new(
95            config.realtime.max_subscriptions_per_session,
96        ));
97        let session_server = Arc::new(SessionServer::new(node_id, config.realtime.clone()));
98        let change_listener = Arc::new(ChangeListener::new(db_pool.clone(), config.listener));
99        let invalidation_engine = Arc::new(InvalidationEngine::new(
100            subscription_manager.clone(),
101            config.invalidation,
102        ));
103        let (shutdown_tx, _) = broadcast::channel(1);
104
105        Self {
106            node_id,
107            db_pool,
108            registry,
109            subscription_manager,
110            session_server,
111            change_listener,
112            invalidation_engine,
113            job_subscriptions: Arc::new(RwLock::new(HashMap::new())),
114            workflow_subscriptions: Arc::new(RwLock::new(HashMap::new())),
115            shutdown_tx,
116            max_listener_restarts: config.max_listener_restarts,
117            listener_restart_delay_ms: config.listener_restart_delay_ms,
118            max_concurrent_reexecutions: config.max_concurrent_reexecutions,
119            session_cleanup_interval_secs: config.session_cleanup_interval_secs,
120        }
121    }
122
    /// The cluster node this reactor runs on.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }
126
127    pub fn session_server(&self) -> Arc<SessionServer> {
128        self.session_server.clone()
129    }
130
131    pub fn subscription_manager(&self) -> Arc<SubscriptionManager> {
132        self.subscription_manager.clone()
133    }
134
    /// A receiver on the shutdown broadcast used to stop the reactor's
    /// background task.
    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }
138
    /// Register a new session.
    ///
    /// Hands the session's outbound message channel to the session server so
    /// realtime messages can be routed to it.
    pub fn register_session(&self, session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) {
        self.session_server.register_connection(session_id, sender);
        tracing::trace!(?session_id, "Session registered");
    }
144
    /// Remove a session and all its subscriptions.
    ///
    /// Cleans up, in order: query-group subscriptions, the session-server
    /// connection, and any job/workflow subscriptions the session held.
    pub async fn remove_session(&self, session_id: SessionId) {
        // Clean up query subscriptions via the subscription manager
        self.subscription_manager
            .remove_session_subscriptions(session_id);

        // Clean up session server
        self.session_server.remove_connection(session_id);

        // Clean up job subscriptions. The extra scope bounds the write lock
        // so it is released before the workflow cleanup below.
        {
            let mut job_subs = self.job_subscriptions.write().await;
            for subscribers in job_subs.values_mut() {
                subscribers.retain(|s| s.session_id != session_id);
            }
            // Drop map entries that no longer have any subscribers.
            job_subs.retain(|_, v| !v.is_empty());
        }

        // Clean up workflow subscriptions (same pattern as jobs above).
        {
            let mut workflow_subs = self.workflow_subscriptions.write().await;
            for subscribers in workflow_subs.values_mut() {
                subscribers.retain(|s| s.session_id != session_id);
            }
            workflow_subs.retain(|_, v| !v.is_empty());
        }

        tracing::trace!(?session_id, "Session removed");
    }
174
    /// Subscribe to a query. Uses query groups for coalescing.
    ///
    /// Returns the new subscription id plus the query's current result data.
    /// When this subscription creates a new group, the query result seeds the
    /// group's read set and result hash; otherwise the query is still
    /// re-executed so this subscriber receives fresh data.
    ///
    /// On failure after registration, the partial registration is rolled
    /// back before the error is returned.
    pub async fn subscribe(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        query_name: String,
        args: serde_json::Value,
        auth_context: forge_core::function::AuthContext,
    ) -> forge_core::Result<(SubscriptionId, serde_json::Value)> {
        // Look up function info for compile-time metadata; anything that is
        // not a registered query falls back to empty dependency/column lists.
        let (table_deps, selected_cols) = match self.registry.get(&query_name) {
            Some(FunctionEntry::Query { info, .. }) => {
                (info.table_dependencies, info.selected_columns)
            }
            _ => (&[] as &[&str], &[] as &[&str]),
        };

        let (group_id, subscription_id, is_new_group) = self.subscription_manager.subscribe(
            session_id,
            client_sub_id,
            &query_name,
            &args,
            &auth_context,
            table_deps,
            selected_cols,
        )?;

        // Register subscription in session server for message routing;
        // roll back the manager registration if that fails.
        if let Err(error) = self
            .session_server
            .add_subscription(session_id, subscription_id)
        {
            self.subscription_manager.unsubscribe(subscription_id);
            return Err(error);
        }

        // Only execute the query if this is a new group (no cached result yet)
        let data = if is_new_group {
            let (data, read_set) = match self.execute_query(&query_name, &args, &auth_context).await
            {
                Ok(result) => result,
                Err(error) => {
                    // Roll back both registrations on execution failure.
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };

            let result_hash = Self::compute_hash(&data);

            tracing::trace!(
                ?group_id,
                query = %query_name,
                "New query group created"
            );

            self.subscription_manager
                .update_group(group_id, read_set, result_hash);

            data
        } else {
            // Group exists, re-execute to get fresh data for this subscriber
            // (they might have joined mid-cycle)
            let (data, _) = match self.execute_query(&query_name, &args, &auth_context).await {
                Ok(result) => result,
                Err(error) => {
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };
            data
        };

        tracing::trace!(?subscription_id, "Subscription created");
        Ok((subscription_id, data))
    }
250
    /// Unsubscribe from a query.
    ///
    /// Removes the subscription from both the session server (routing) and
    /// the subscription manager (group membership).
    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
        self.session_server.remove_subscription(subscription_id);
        self.subscription_manager.unsubscribe(subscription_id);
        tracing::trace!(?subscription_id, "Subscription removed");
    }
257
    /// Subscribe to job progress updates.
    ///
    /// Checks the caller's access to the job, returns its current state, and
    /// records the subscription for fan-out on future job changes.
    pub async fn subscribe_job(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        job_id: Uuid,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<JobData> {
        // Access and existence checks happen before anything is recorded.
        Self::ensure_job_access(&self.db_pool, job_id, auth_context).await?;
        let job_data = self.fetch_job_data(job_id).await?;

        let subscription = JobSubscription {
            session_id,
            client_sub_id,
            auth_context: auth_context.clone(),
        };

        let mut subs = self.job_subscriptions.write().await;
        subs.entry(job_id).or_default().push(subscription);

        tracing::trace!(%job_id, %session_id, "Job subscription created");
        Ok(job_data)
    }
281
282    /// Unsubscribe from job updates.
283    pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) {
284        let mut subs = self.job_subscriptions.write().await;
285        for subscribers in subs.values_mut() {
286            subscribers
287                .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
288        }
289        subs.retain(|_, v| !v.is_empty());
290    }
291
    /// Subscribe to workflow progress updates.
    ///
    /// Checks the caller's access to the workflow run, returns its current
    /// state, and records the subscription for fan-out on workflow changes.
    pub async fn subscribe_workflow(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        workflow_id: Uuid,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<WorkflowData> {
        // Access and existence checks happen before anything is recorded.
        Self::ensure_workflow_access(&self.db_pool, workflow_id, auth_context).await?;
        let workflow_data = self.fetch_workflow_data(workflow_id).await?;

        let subscription = WorkflowSubscription {
            session_id,
            client_sub_id,
            auth_context: auth_context.clone(),
        };

        let mut subs = self.workflow_subscriptions.write().await;
        subs.entry(workflow_id).or_default().push(subscription);

        tracing::trace!(%workflow_id, %session_id, "Workflow subscription created");
        Ok(workflow_data)
    }
315
316    /// Unsubscribe from workflow updates.
317    pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) {
318        let mut subs = self.workflow_subscriptions.write().await;
319        for subscribers in subs.values_mut() {
320            subscribers
321                .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
322        }
323        subs.retain(|_, v| !v.is_empty());
324    }
325
326    #[allow(clippy::type_complexity)]
327    async fn fetch_job_data(&self, job_id: Uuid) -> forge_core::Result<JobData> {
328        Self::fetch_job_data_static(job_id, &self.db_pool).await
329    }
330
    /// Fetch a workflow run's current state (thin wrapper over
    /// [`Self::fetch_workflow_data_static`] using this reactor's pool).
    async fn fetch_workflow_data(&self, workflow_id: Uuid) -> forge_core::Result<WorkflowData> {
        Self::fetch_workflow_data_static(workflow_id, &self.db_pool).await
    }
334
    /// Execute a query and return data with read set.
    ///
    /// Delegates to [`Self::execute_query_static`] with this reactor's
    /// registry and pool.
    async fn execute_query(
        &self,
        query_name: &str,
        args: &serde_json::Value,
        auth_context: &forge_core::function::AuthContext,
    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
        Self::execute_query_static(
            &self.registry,
            &self.db_pool,
            query_name,
            args,
            auth_context,
        )
        .await
    }
351
352    fn compute_hash(data: &serde_json::Value) -> String {
353        use std::collections::hash_map::DefaultHasher;
354        use std::hash::{Hash, Hasher};
355
356        let json = serde_json::to_string(data).unwrap_or_default();
357        let mut hasher = DefaultHasher::new();
358        json.hash(&mut hasher);
359        format!("{:x}", hasher.finish())
360    }
361
362    /// Parallel group re-execution with bounded concurrency.
363    async fn flush_invalidations(
364        invalidation_engine: &Arc<InvalidationEngine>,
365        subscription_manager: &Arc<SubscriptionManager>,
366        session_server: &Arc<SessionServer>,
367        registry: &FunctionRegistry,
368        db_pool: &sqlx::PgPool,
369        max_concurrent: usize,
370    ) {
371        let invalidated_groups = invalidation_engine.check_pending().await;
372        if invalidated_groups.is_empty() {
373            return;
374        }
375
376        tracing::trace!(
377            count = invalidated_groups.len(),
378            "Invalidating query groups"
379        );
380
381        // Collect group data we need for re-execution
382        let groups_to_process: Vec<_> = invalidated_groups
383            .iter()
384            .filter_map(|gid| {
385                subscription_manager.get_group(*gid).map(|g| {
386                    (
387                        g.id,
388                        g.query_name.clone(),
389                        (*g.args).clone(),
390                        g.last_result_hash.clone(),
391                        g.auth_context.clone(),
392                    )
393                })
394            })
395            .collect();
396
397        // Parallel re-execution bounded by semaphore
398        let semaphore = Arc::new(Semaphore::new(max_concurrent));
399        let mut futures = FuturesUnordered::new();
400
401        for (group_id, query_name, args, last_hash, auth_context) in groups_to_process {
402            let permit = match semaphore.clone().acquire_owned().await {
403                Ok(p) => p,
404                Err(_) => break,
405            };
406            let registry = registry.clone();
407            let db_pool = db_pool.clone();
408
409            futures.push(async move {
410                let result = Self::execute_query_static(
411                    &registry,
412                    &db_pool,
413                    &query_name,
414                    &args,
415                    &auth_context,
416                )
417                .await;
418                drop(permit);
419                (group_id, last_hash, result)
420            });
421        }
422
423        // Process results and fan out to subscribers
424        while let Some((group_id, last_hash, result)) = futures.next().await {
425            match result {
426                Ok((new_data, read_set)) => {
427                    let new_hash = Self::compute_hash(&new_data);
428
429                    if last_hash.as_ref() != Some(&new_hash) {
430                        // Update group state
431                        subscription_manager.update_group(group_id, read_set, new_hash);
432
433                        // Fan out to all subscribers in this group
434                        let subscribers = subscription_manager.get_group_subscribers(group_id);
435                        for (session_id, client_sub_id) in subscribers {
436                            let message = RealtimeMessage::Data {
437                                subscription_id: client_sub_id.clone(),
438                                data: new_data.clone(),
439                            };
440
441                            if let Err(e) = session_server.try_send_to_session(session_id, message)
442                            {
443                                tracing::trace!(
444                                    client_id = %client_sub_id,
445                                    error = ?e,
446                                    "Failed to send update to subscriber"
447                                );
448                            }
449                        }
450                    }
451                }
452                Err(e) => {
453                    tracing::warn!(?group_id, error = %e, "Failed to re-execute query group");
454                }
455            }
456        }
457    }
458
    /// Start the reactor.
    ///
    /// Spawns a background task that drives the realtime pipeline:
    /// - routes database change events into [`Self::handle_change`],
    /// - flushes pending invalidations on a 25ms tick,
    /// - periodically removes stale sessions,
    /// - restarts the change listener with exponential backoff on failure,
    /// - exits on shutdown broadcast, a closed change channel, or after the
    ///   listener exhausts its restart budget.
    pub async fn start(&self) -> forge_core::Result<()> {
        // Clone everything the spawned task needs so the future is 'static.
        let listener = self.change_listener.clone();
        let invalidation_engine = self.invalidation_engine.clone();
        let subscription_manager = self.subscription_manager.clone();
        let job_subscriptions = self.job_subscriptions.clone();
        let workflow_subscriptions = self.workflow_subscriptions.clone();
        let session_server = self.session_server.clone();
        let registry = self.registry.clone();
        let db_pool = self.db_pool.clone();
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        let max_restarts = self.max_listener_restarts;
        let base_delay_ms = self.listener_restart_delay_ms;
        let max_concurrent = self.max_concurrent_reexecutions;
        let cleanup_secs = self.session_cleanup_interval_secs;

        // Subscribe to the change broadcast before the listener task starts.
        let mut change_rx = listener.subscribe();

        tokio::spawn(async move {
            tracing::debug!("Reactor listening for changes");

            let mut restart_count: u32 = 0;
            // The listener task reports fatal errors back over this channel.
            let (listener_error_tx, mut listener_error_rx) = mpsc::channel::<String>(1);

            // Start initial listener
            let listener_clone = listener.clone();
            let error_tx = listener_error_tx.clone();
            let mut listener_handle = Some(tokio::spawn(async move {
                if let Err(e) = listener_clone.run().await {
                    let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                }
            }));

            // Skip missed ticks instead of bursting after a long iteration.
            let mut flush_interval = tokio::time::interval(std::time::Duration::from_millis(25));
            flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            let mut cleanup_interval =
                tokio::time::interval(std::time::Duration::from_secs(cleanup_secs));
            cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            loop {
                tokio::select! {
                    result = change_rx.recv() => {
                        match result {
                            Ok(change) => {
                                Self::handle_change(
                                    &change,
                                    &invalidation_engine,
                                    &job_subscriptions,
                                    &workflow_subscriptions,
                                    &session_server,
                                    &db_pool,
                                ).await;
                            }
                            Err(broadcast::error::RecvError::Lagged(n)) => {
                                tracing::warn!("Reactor lagged by {} messages", n);
                            }
                            Err(broadcast::error::RecvError::Closed) => {
                                tracing::debug!("Change channel closed");
                                break;
                            }
                        }
                    }
                    _ = flush_interval.tick() => {
                        Self::flush_invalidations(
                            &invalidation_engine,
                            &subscription_manager,
                            &session_server,
                            &registry,
                            &db_pool,
                            max_concurrent,
                        ).await;
                    }
                    _ = cleanup_interval.tick() => {
                        // Staleness threshold is fixed at 5 minutes.
                        session_server.cleanup_stale(std::time::Duration::from_secs(300));
                    }
                    Some(error_msg) = listener_error_rx.recv() => {
                        if restart_count >= max_restarts {
                            tracing::error!(
                                attempts = restart_count,
                                last_error = %error_msg,
                                "Change listener failed permanently, real-time updates disabled"
                            );
                            break;
                        }

                        restart_count += 1;
                        // Exponential backoff: base delay doubles each attempt.
                        let delay = base_delay_ms * 2u64.saturating_pow(restart_count - 1);
                        tracing::warn!(
                            attempt = restart_count,
                            max = max_restarts,
                            delay_ms = delay,
                            error = %error_msg,
                            "Change listener restarting"
                        );

                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;

                        let listener_clone = listener.clone();
                        let error_tx = listener_error_tx.clone();
                        // Abort the previous listener task before spawning a new one.
                        if let Some(handle) = listener_handle.take() {
                            handle.abort();
                        }
                        // Take a fresh receiver for the restarted listener run.
                        change_rx = listener.subscribe();
                        listener_handle = Some(tokio::spawn(async move {
                            if let Err(e) = listener_clone.run().await {
                                let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                            }
                        }));
                    }
                    _ = shutdown_rx.recv() => {
                        tracing::debug!("Reactor shutting down");
                        break;
                    }
                }
            }

            // Ensure the listener task does not outlive the reactor loop.
            if let Some(handle) = listener_handle {
                handle.abort();
            }
        });

        Ok(())
    }
583
584    /// Handle a database change event.
585    #[allow(clippy::too_many_arguments)]
586    async fn handle_change(
587        change: &Change,
588        invalidation_engine: &Arc<InvalidationEngine>,
589        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
590        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
591        session_server: &Arc<SessionServer>,
592        db_pool: &sqlx::PgPool,
593    ) {
594        tracing::trace!(table = %change.table, op = ?change.operation, row_id = ?change.row_id, "Processing change");
595
596        match change.table.as_str() {
597            "forge_jobs" => {
598                if let Some(job_id) = change.row_id {
599                    Self::handle_job_change(job_id, job_subscriptions, session_server, db_pool)
600                        .await;
601                }
602                return;
603            }
604            "forge_workflow_runs" => {
605                if let Some(workflow_id) = change.row_id {
606                    Self::handle_workflow_change(
607                        workflow_id,
608                        workflow_subscriptions,
609                        session_server,
610                        db_pool,
611                    )
612                    .await;
613                }
614                return;
615            }
616            "forge_workflow_steps" => {
617                if let Some(step_id) = change.row_id {
618                    Self::handle_workflow_step_change(
619                        step_id,
620                        workflow_subscriptions,
621                        session_server,
622                        db_pool,
623                    )
624                    .await;
625                }
626                return;
627            }
628            _ => {}
629        }
630
631        // Record change for debounced group invalidation
632        invalidation_engine.process_change(change.clone()).await;
633    }
634
    /// Fan a job change out to that job's subscribers.
    ///
    /// Fetches the job's current state and owner, re-checks owner access for
    /// each subscriber, and prunes subscriptions that fail the check.
    async fn handle_job_change(
        job_id: Uuid,
        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Snapshot subscribers and release the read lock before DB work.
        let subs = job_subscriptions.read().await;
        let subscribers = match subs.get(&job_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let job_data = match Self::fetch_job_data_static(job_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(%job_id, error = %e, "Failed to fetch job data");
                return;
            }
        };

        let owner_subject = match Self::fetch_job_owner_subject_static(job_id, db_pool).await {
            Ok(owner) => owner,
            Err(e) => {
                tracing::debug!(%job_id, error = %e, "Failed to fetch job owner");
                return;
            }
        };

        // Subscriptions that fail the owner check are collected here and
        // removed after the fan-out loop.
        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();

        for sub in &subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::JobUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                job: job_data.clone(),
            };

            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%job_id, error = %e, "Failed to send job update");
            }
        }

        // Prune subscriptions that are no longer authorized.
        if !unauthorized.is_empty() {
            let mut subs = job_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&job_id) {
                entries
                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
694
    /// Fan a workflow change out to that workflow's subscribers.
    ///
    /// Fetches the workflow's current state and owner, re-checks owner
    /// access for each subscriber, and prunes subscriptions that fail.
    async fn handle_workflow_change(
        workflow_id: Uuid,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Snapshot subscribers and release the read lock before DB work.
        let subs = workflow_subscriptions.read().await;
        let subscribers = match subs.get(&workflow_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let workflow_data = match Self::fetch_workflow_data_static(workflow_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow data");
                return;
            }
        };

        let owner_subject =
            match Self::fetch_workflow_owner_subject_static(workflow_id, db_pool).await {
                Ok(owner) => owner,
                Err(e) => {
                    tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow owner");
                    return;
                }
            };

        // Subscriptions that fail the owner check are collected here and
        // removed after the fan-out loop.
        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();

        for sub in &subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::WorkflowUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                workflow: workflow_data.clone(),
            };

            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%workflow_id, error = %e, "Failed to send workflow update");
            }
        }

        // Prune subscriptions that are no longer authorized.
        if !unauthorized.is_empty() {
            let mut subs = workflow_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&workflow_id) {
                entries
                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
755
    /// Handle a change to a workflow step row.
    ///
    /// Steps carry no subscriptions of their own: the step's parent workflow
    /// run is looked up and the event is handled as a workflow change.
    async fn handle_workflow_step_change(
        step_id: Uuid,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        let workflow_id: Option<Uuid> = match sqlx::query_scalar(
            "SELECT workflow_run_id FROM forge_workflow_steps WHERE id = $1",
        )
        .bind(step_id)
        .fetch_optional(db_pool)
        .await
        {
            Ok(id) => id,
            Err(e) => {
                tracing::debug!(%step_id, error = %e, "Failed to look up workflow for step");
                return;
            }
        };

        // No matching step row: nothing to fan out.
        if let Some(wf_id) = workflow_id {
            Self::handle_workflow_change(wf_id, workflow_subscriptions, session_server, db_pool)
                .await;
        }
    }
781
    /// Load a job's current status, progress, output, and last error.
    ///
    /// Returns `NotFound` when no `forge_jobs` row exists for `job_id`.
    // The allow covers the five-column tuple used to decode the row.
    #[allow(clippy::type_complexity)]
    async fn fetch_job_data_static(
        job_id: Uuid,
        db_pool: &sqlx::PgPool,
    ) -> forge_core::Result<JobData> {
        // Row tuple: (status, progress_percent, progress_message, output, last_error).
        let row: Option<(
            String,
            Option<i32>,
            Option<String>,
            Option<serde_json::Value>,
            Option<String>,
        )> = sqlx::query_as(
            r#"
                SELECT status, progress_percent, progress_message, output, last_error
                FROM forge_jobs WHERE id = $1
                "#,
        )
        .bind(job_id)
        .fetch_optional(db_pool)
        .await
        .map_err(forge_core::ForgeError::Sql)?;

        match row {
            Some((status, progress_percent, progress_message, output, error)) => Ok(JobData {
                job_id: job_id.to_string(),
                status,
                progress_percent,
                progress_message,
                output,
                error,
            }),
            None => Err(forge_core::ForgeError::NotFound(format!(
                "Job {} not found",
                job_id
            ))),
        }
    }
819
820    async fn fetch_job_owner_subject_static(
821        job_id: Uuid,
822        db_pool: &sqlx::PgPool,
823    ) -> forge_core::Result<Option<String>> {
824        let owner_subject: Option<Option<String>> =
825            sqlx::query_scalar("SELECT owner_subject FROM forge_jobs WHERE id = $1")
826                .bind(job_id)
827                .fetch_optional(db_pool)
828                .await
829                .map_err(forge_core::ForgeError::Sql)?;
830
831        owner_subject
832            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))
833    }
834
835    #[allow(clippy::type_complexity)]
836    async fn fetch_workflow_data_static(
837        workflow_id: Uuid,
838        db_pool: &sqlx::PgPool,
839    ) -> forge_core::Result<WorkflowData> {
840        let row: Option<(
841            String,
842            Option<String>,
843            Option<serde_json::Value>,
844            Option<String>,
845        )> = sqlx::query_as(
846            r#"
847                SELECT status, current_step, output, error
848                FROM forge_workflow_runs WHERE id = $1
849                "#,
850        )
851        .bind(workflow_id)
852        .fetch_optional(db_pool)
853        .await
854        .map_err(forge_core::ForgeError::Sql)?;
855
856        let (status, current_step, output, error) = match row {
857            Some(r) => r,
858            None => {
859                return Err(forge_core::ForgeError::NotFound(format!(
860                    "Workflow {} not found",
861                    workflow_id
862                )));
863            }
864        };
865
866        let step_rows: Vec<(String, String, Option<String>)> = sqlx::query_as(
867            r#"
868            SELECT step_name, status, error
869            FROM forge_workflow_steps
870            WHERE workflow_run_id = $1
871            ORDER BY started_at ASC NULLS LAST
872            "#,
873        )
874        .bind(workflow_id)
875        .fetch_all(db_pool)
876        .await
877        .map_err(forge_core::ForgeError::Sql)?;
878
879        let steps = step_rows
880            .into_iter()
881            .map(|(name, status, error)| WorkflowStepData {
882                name,
883                status,
884                error,
885            })
886            .collect();
887
888        Ok(WorkflowData {
889            workflow_id: workflow_id.to_string(),
890            status,
891            current_step,
892            steps,
893            output,
894            error,
895        })
896    }
897
898    async fn fetch_workflow_owner_subject_static(
899        workflow_id: Uuid,
900        db_pool: &sqlx::PgPool,
901    ) -> forge_core::Result<Option<String>> {
902        let owner_subject: Option<Option<String>> =
903            sqlx::query_scalar("SELECT owner_subject FROM forge_workflow_runs WHERE id = $1")
904                .bind(workflow_id)
905                .fetch_optional(db_pool)
906                .await
907                .map_err(forge_core::ForgeError::Sql)?;
908
909        owner_subject.ok_or_else(|| {
910            forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
911        })
912    }
913
914    async fn execute_query_static(
915        registry: &FunctionRegistry,
916        db_pool: &sqlx::PgPool,
917        query_name: &str,
918        args: &serde_json::Value,
919        auth_context: &forge_core::function::AuthContext,
920    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
921        match registry.get(query_name) {
922            Some(FunctionEntry::Query { info, handler }) => {
923                Self::check_query_auth(info, auth_context)?;
924                let enforce = !info.is_public && info.has_input_args;
925                auth_context.check_identity_args(query_name, args, enforce)?;
926
927                let ctx = forge_core::function::QueryContext::new(
928                    db_pool.clone(),
929                    auth_context.clone(),
930                    forge_core::function::RequestMetadata::new(),
931                );
932
933                let normalized_args = match args {
934                    v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
935                    v => v.clone(),
936                };
937
938                let data = handler(&ctx, normalized_args).await?;
939
940                let mut read_set = ReadSet::new();
941
942                if info.table_dependencies.is_empty() {
943                    let table_name = Self::extract_table_name(query_name);
944                    read_set.add_table(&table_name);
945                    tracing::trace!(
946                        query = %query_name,
947                        fallback_table = %table_name,
948                        "Using naming convention fallback for table dependency"
949                    );
950                } else {
951                    for table in info.table_dependencies {
952                        read_set.add_table(*table);
953                    }
954                }
955
956                Ok((data, read_set))
957            }
958            _ => Err(forge_core::ForgeError::Validation(format!(
959                "Query '{}' not found or not a query",
960                query_name
961            ))),
962        }
963    }
964
965    fn extract_table_name(query_name: &str) -> String {
966        if let Some(rest) = query_name.strip_prefix("get_") {
967            rest.to_string()
968        } else if let Some(rest) = query_name.strip_prefix("list_") {
969            rest.to_string()
970        } else if let Some(rest) = query_name.strip_prefix("find_") {
971            rest.to_string()
972        } else if let Some(rest) = query_name.strip_prefix("fetch_") {
973            rest.to_string()
974        } else {
975            query_name.to_string()
976        }
977    }
978
979    fn check_query_auth(
980        info: &forge_core::function::FunctionInfo,
981        auth: &forge_core::function::AuthContext,
982    ) -> forge_core::Result<()> {
983        if info.is_public {
984            return Ok(());
985        }
986
987        if !auth.is_authenticated() {
988            return Err(forge_core::ForgeError::Unauthorized(
989                "Authentication required".into(),
990            ));
991        }
992
993        if let Some(role) = info.required_role
994            && !auth.has_role(role)
995        {
996            return Err(forge_core::ForgeError::Forbidden(format!(
997                "Role '{}' required",
998                role
999            )));
1000        }
1001
1002        Ok(())
1003    }
1004
1005    async fn ensure_job_access(
1006        db_pool: &sqlx::PgPool,
1007        job_id: Uuid,
1008        auth: &forge_core::function::AuthContext,
1009    ) -> forge_core::Result<()> {
1010        let owner_subject_row: Option<(Option<String>,)> =
1011            sqlx::query_as(r#"SELECT owner_subject FROM forge_jobs WHERE id = $1"#)
1012                .bind(job_id)
1013                .fetch_optional(db_pool)
1014                .await
1015                .map_err(forge_core::ForgeError::Sql)?;
1016
1017        let owner_subject = owner_subject_row
1018            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))?
1019            .0;
1020
1021        Self::check_owner_access(owner_subject, auth)
1022    }
1023
1024    async fn ensure_workflow_access(
1025        db_pool: &sqlx::PgPool,
1026        workflow_id: Uuid,
1027        auth: &forge_core::function::AuthContext,
1028    ) -> forge_core::Result<()> {
1029        let owner_subject_row: Option<(Option<String>,)> =
1030            sqlx::query_as(r#"SELECT owner_subject FROM forge_workflow_runs WHERE id = $1"#)
1031                .bind(workflow_id)
1032                .fetch_optional(db_pool)
1033                .await
1034                .map_err(forge_core::ForgeError::Sql)?;
1035
1036        let owner_subject = owner_subject_row
1037            .ok_or_else(|| {
1038                forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
1039            })?
1040            .0;
1041
1042        Self::check_owner_access(owner_subject, auth)
1043    }
1044
1045    fn check_owner_access(
1046        owner_subject: Option<String>,
1047        auth: &forge_core::function::AuthContext,
1048    ) -> forge_core::Result<()> {
1049        if auth.is_admin() {
1050            return Ok(());
1051        }
1052
1053        // Treat empty string the same as NULL (no owner)
1054        let Some(owner) = owner_subject.filter(|s| !s.is_empty()) else {
1055            return Ok(());
1056        };
1057
1058        let principal = auth.principal_id().ok_or_else(|| {
1059            forge_core::ForgeError::Unauthorized("Authentication required".to_string())
1060        })?;
1061
1062        if owner == principal {
1063            Ok(())
1064        } else {
1065            Err(forge_core::ForgeError::Forbidden(
1066                "Not authorized to access this resource".to_string(),
1067            ))
1068        }
1069    }
1070
    /// Signal the reactor to shut down.
    ///
    /// Broadcasts on the shutdown channel — the send error is ignored
    /// because there may be no live receivers — and stops the change
    /// listener.
    pub fn stop(&self) {
        let _ = self.shutdown_tx.send(());
        self.change_listener.stop();
    }
1075
1076    pub async fn stats(&self) -> ReactorStats {
1077        let session_stats = self.session_server.stats();
1078        let inv_stats = self.invalidation_engine.stats().await;
1079
1080        ReactorStats {
1081            connections: session_stats.connections,
1082            subscriptions: session_stats.subscriptions,
1083            query_groups: self.subscription_manager.group_count(),
1084            pending_invalidations: inv_stats.pending_groups,
1085            listener_running: self.change_listener.is_running(),
1086        }
1087    }
1088}
1089
/// Reactor statistics.
///
/// A point-in-time snapshot assembled by `Reactor::stats`.
#[derive(Debug, Clone)]
pub struct ReactorStats {
    /// Connection count reported by the session server.
    pub connections: usize,
    /// Subscription count reported by the session server.
    pub subscriptions: usize,
    /// Number of query groups tracked by the subscription manager.
    pub query_groups: usize,
    /// Groups with pending invalidations in the invalidation engine.
    pub pending_invalidations: usize,
    /// Whether the change listener is currently running.
    pub listener_running: bool,
}
1099
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // Pins the documented default configuration values; a change here is
    // a behavior change for every deployment relying on defaults.
    #[test]
    fn test_reactor_config_default() {
        let config = ReactorConfig::default();
        assert_eq!(config.listener.channel, "forge_changes");
        assert_eq!(config.invalidation.debounce_ms, 50);
        assert_eq!(config.max_listener_restarts, 5);
        assert_eq!(config.listener_restart_delay_ms, 1000);
        assert_eq!(config.max_concurrent_reexecutions, 64);
        assert_eq!(config.session_cleanup_interval_secs, 60);
    }

    // Hashing must be deterministic for equal JSON values and must
    // discriminate between different payloads.
    #[test]
    fn test_compute_hash() {
        let data1 = serde_json::json!({"name": "test"});
        let data2 = serde_json::json!({"name": "test"});
        let data3 = serde_json::json!({"name": "different"});

        let hash1 = Reactor::compute_hash(&data1);
        let hash2 = Reactor::compute_hash(&data2);
        let hash3 = Reactor::compute_hash(&data3);

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    // An authenticated user must not be able to pass another user's id
    // through identity-scoped query args.
    #[test]
    fn test_check_identity_args_rejects_cross_user() {
        let user_id = uuid::Uuid::new_v4();
        let auth = forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result = auth.check_identity_args(
            "list_orders",
            &serde_json::json!({"user_id": uuid::Uuid::new_v4().to_string()}),
            true,
        );
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    // With enforcement on, a non-public query invoked without any
    // identity-scoping args must be rejected rather than returning
    // unscoped data.
    #[test]
    fn test_check_identity_args_requires_scope_for_non_public_queries() {
        let user_id = uuid::Uuid::new_v4();
        let auth = forge_core::function::AuthContext::authenticated(
            user_id,
            vec!["user".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String(user_id.to_string()),
            )]),
        );

        let result = auth.check_identity_args("list_orders", &serde_json::json!({}), true);
        assert!(matches!(result, Err(forge_core::ForgeError::Forbidden(_))));
    }

    // Admins bypass the ownership check even when the resource is owned
    // by a different subject.
    #[test]
    fn test_check_owner_access_allows_admin() {
        let auth = forge_core::function::AuthContext::authenticated_without_uuid(
            vec!["admin".to_string()],
            HashMap::from([(
                "sub".to_string(),
                serde_json::Value::String("admin-1".to_string()),
            )]),
        );

        let result = Reactor::check_owner_access(Some("other-user".to_string()), &auth);
        assert!(result.is_ok());
    }
}