// forge_runtime/realtime/reactor.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3
4use futures_util::StreamExt;
5use futures_util::stream::FuturesUnordered;
6use tokio::sync::{RwLock, Semaphore, broadcast, mpsc};
7use uuid::Uuid;
8
9use forge_core::cluster::NodeId;
10use forge_core::realtime::{Change, ReadSet, SessionId, SubscriptionId};
11
12use super::invalidation::{InvalidationConfig, InvalidationEngine};
13use super::listener::{ChangeListener, ListenerConfig};
14use super::manager::SubscriptionManager;
15use super::message::{
16    JobData, RealtimeConfig, RealtimeMessage, SessionServer, WorkflowData, WorkflowStepData,
17};
18use crate::function::{FunctionEntry, FunctionRegistry};
19
/// Tuning knobs for the reactor and the sub-components it owns.
#[derive(Debug, Clone)]
pub struct ReactorConfig {
    /// Configuration forwarded to the database change listener.
    pub listener: ListenerConfig,
    /// Configuration forwarded to the invalidation engine.
    pub invalidation: InvalidationConfig,
    /// Real-time session/subscription configuration.
    pub realtime: RealtimeConfig,
    /// Give up restarting the change listener after this many failures.
    pub max_listener_restarts: u32,
    /// Doubles with each attempt for exponential backoff
    pub listener_restart_delay_ms: u64,
    /// Maximum concurrent re-executions during flush.
    pub max_concurrent_reexecutions: usize,
    /// Session cleanup interval in seconds.
    pub session_cleanup_interval_secs: u64,
}
33
34impl Default for ReactorConfig {
35    fn default() -> Self {
36        Self {
37            listener: ListenerConfig::default(),
38            invalidation: InvalidationConfig::default(),
39            realtime: RealtimeConfig::default(),
40            max_listener_restarts: 5,
41            listener_restart_delay_ms: 1000,
42            max_concurrent_reexecutions: 64,
43            session_cleanup_interval_secs: 60,
44        }
45    }
46}
47
/// Job subscription tracking.
/// The job_id is the HashMap key, not stored here.
#[derive(Debug, Clone)]
pub struct JobSubscription {
    /// Session that owns this subscription; updates are routed here.
    pub session_id: SessionId,
    /// Client-chosen identifier echoed back in update messages.
    pub client_sub_id: String,
    /// Auth context captured at subscribe time; re-checked on each update.
    pub auth_context: forge_core::function::AuthContext,
}
56
/// Workflow subscription tracking.
/// The workflow_id is the HashMap key, not stored here.
#[derive(Debug, Clone)]
pub struct WorkflowSubscription {
    /// Session that owns this subscription; updates are routed here.
    pub session_id: SessionId,
    /// Client-chosen identifier echoed back in update messages.
    pub client_sub_id: String,
    /// Auth context captured at subscribe time; re-checked on each update.
    pub auth_context: forge_core::function::AuthContext,
}
65
/// ChangeListener -> InvalidationEngine -> Group Re-execution -> SSE Fan-out
pub struct Reactor {
    /// This node's cluster identity.
    node_id: NodeId,
    db_pool: sqlx::PgPool,
    /// Registered functions used to (re-)execute subscribed queries.
    registry: FunctionRegistry,
    /// Coalesces identical query subscriptions into shared groups.
    subscription_manager: Arc<SubscriptionManager>,
    /// Tracks connected sessions and delivers outbound messages.
    session_server: Arc<SessionServer>,
    /// Streams database change notifications into the reactor loop.
    change_listener: Arc<ChangeListener>,
    /// Debounces raw changes into query-group invalidations.
    invalidation_engine: Arc<InvalidationEngine>,
    /// Job subscriptions: job_id -> list of subscribers.
    job_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
    /// Workflow subscriptions: workflow_id -> list of subscribers.
    workflow_subscriptions: Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
    /// Broadcasts a shutdown signal to the spawned reactor task.
    shutdown_tx: broadcast::Sender<()>,
    // The remaining fields are copied out of ReactorConfig at construction.
    max_listener_restarts: u32,
    listener_restart_delay_ms: u64,
    max_concurrent_reexecutions: usize,
    session_cleanup_interval_secs: u64,
}
85
86impl Reactor {
87    /// Create a new reactor.
88    pub fn new(
89        node_id: NodeId,
90        db_pool: sqlx::PgPool,
91        registry: FunctionRegistry,
92        config: ReactorConfig,
93    ) -> Self {
94        let subscription_manager = Arc::new(SubscriptionManager::new(
95            config.realtime.max_subscriptions_per_session,
96        ));
97        let session_server = Arc::new(SessionServer::new(node_id, config.realtime.clone()));
98        let change_listener = Arc::new(ChangeListener::new(db_pool.clone(), config.listener));
99        let invalidation_engine = Arc::new(InvalidationEngine::new(
100            subscription_manager.clone(),
101            config.invalidation,
102        ));
103        let (shutdown_tx, _) = broadcast::channel(1);
104
105        Self {
106            node_id,
107            db_pool,
108            registry,
109            subscription_manager,
110            session_server,
111            change_listener,
112            invalidation_engine,
113            job_subscriptions: Arc::new(RwLock::new(HashMap::new())),
114            workflow_subscriptions: Arc::new(RwLock::new(HashMap::new())),
115            shutdown_tx,
116            max_listener_restarts: config.max_listener_restarts,
117            listener_restart_delay_ms: config.listener_restart_delay_ms,
118            max_concurrent_reexecutions: config.max_concurrent_reexecutions,
119            session_cleanup_interval_secs: config.session_cleanup_interval_secs,
120        }
121    }
122
    /// This node's cluster identity.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Shared handle to the session server (connection registry / fan-out).
    pub fn session_server(&self) -> Arc<SessionServer> {
        self.session_server.clone()
    }

    /// Shared handle to the subscription manager.
    pub fn subscription_manager(&self) -> Arc<SubscriptionManager> {
        self.subscription_manager.clone()
    }

    /// New receiver for the reactor's shutdown broadcast.
    pub fn shutdown_receiver(&self) -> broadcast::Receiver<()> {
        self.shutdown_tx.subscribe()
    }
138
    /// Register a new session.
    ///
    /// `sender` is the channel the session server will use to deliver
    /// `RealtimeMessage`s to this session's connection.
    pub fn register_session(&self, session_id: SessionId, sender: mpsc::Sender<RealtimeMessage>) {
        self.session_server.register_connection(session_id, sender);
        tracing::trace!(?session_id, "Session registered");
    }
144
145    /// Remove a session and all its subscriptions.
146    pub async fn remove_session(&self, session_id: SessionId) {
147        // Clean up query subscriptions via the subscription manager
148        self.subscription_manager
149            .remove_session_subscriptions(session_id);
150
151        // Clean up session server
152        self.session_server.remove_connection(session_id);
153
154        // Clean up job subscriptions
155        {
156            let mut job_subs = self.job_subscriptions.write().await;
157            for subscribers in job_subs.values_mut() {
158                subscribers.retain(|s| s.session_id != session_id);
159            }
160            job_subs.retain(|_, v| !v.is_empty());
161        }
162
163        // Clean up workflow subscriptions
164        {
165            let mut workflow_subs = self.workflow_subscriptions.write().await;
166            for subscribers in workflow_subs.values_mut() {
167                subscribers.retain(|s| s.session_id != session_id);
168            }
169            workflow_subs.retain(|_, v| !v.is_empty());
170        }
171
172        tracing::trace!(?session_id, "Session removed");
173    }
174
    /// Subscribe to a query. Uses query groups for coalescing.
    ///
    /// Returns the new subscription id together with the query's current
    /// result so the caller can deliver an immediate initial payload.
    pub async fn subscribe(
        &self,
        session_id: SessionId,
        client_sub_id: String,
        query_name: String,
        args: serde_json::Value,
        auth_context: forge_core::function::AuthContext,
    ) -> forge_core::Result<(SubscriptionId, serde_json::Value)> {
        // Look up function info for compile-time metadata
        // (unknown or non-query entries fall back to empty slices).
        let (table_deps, selected_cols) = match self.registry.get(&query_name) {
            Some(FunctionEntry::Query { info, .. }) => {
                (info.table_dependencies, info.selected_columns)
            }
            _ => (&[] as &[&str], &[] as &[&str]),
        };

        let (group_id, subscription_id, is_new_group) = self.subscription_manager.subscribe(
            session_id,
            client_sub_id,
            &query_name,
            &args,
            &auth_context,
            table_deps,
            selected_cols,
        )?;

        // Register subscription in session server for message routing.
        // On failure, roll back the manager-side subscription.
        if let Err(error) = self
            .session_server
            .add_subscription(session_id, subscription_id)
        {
            self.subscription_manager.unsubscribe(subscription_id);
            return Err(error);
        }

        // Only execute the query if this is a new group (no cached result yet)
        let data = if is_new_group {
            let (data, read_set) = match self.execute_query(&query_name, &args, &auth_context).await
            {
                Ok(result) => result,
                // Roll back both registrations if the first execution fails.
                Err(error) => {
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };

            let result_hash = Self::compute_hash(&data);

            tracing::trace!(
                ?group_id,
                query = %query_name,
                "New query group created"
            );

            // Seed the group with its read set and result hash so later
            // invalidation cycles can detect actual changes.
            self.subscription_manager
                .update_group(group_id, read_set, result_hash);

            data
        } else {
            // Group exists, re-execute to get fresh data for this subscriber
            // (they might have joined mid-cycle)
            let (data, _) = match self.execute_query(&query_name, &args, &auth_context).await {
                Ok(result) => result,
                Err(error) => {
                    self.unsubscribe(subscription_id);
                    return Err(error);
                }
            };
            data
        };

        tracing::trace!(?subscription_id, "Subscription created");
        Ok((subscription_id, data))
    }
250
    /// Unsubscribe from a query.
    ///
    /// Removes the routing entry from the session server first, then the
    /// manager-side subscription.
    pub fn unsubscribe(&self, subscription_id: SubscriptionId) {
        self.session_server.remove_subscription(subscription_id);
        self.subscription_manager.unsubscribe(subscription_id);
        tracing::trace!(?subscription_id, "Subscription removed");
    }
257
258    /// Subscribe to job progress updates.
259    pub async fn subscribe_job(
260        &self,
261        session_id: SessionId,
262        client_sub_id: String,
263        job_id: Uuid,
264        auth_context: &forge_core::function::AuthContext,
265    ) -> forge_core::Result<JobData> {
266        Self::ensure_job_access(&self.db_pool, job_id, auth_context).await?;
267        let job_data = self.fetch_job_data(job_id).await?;
268
269        let subscription = JobSubscription {
270            session_id,
271            client_sub_id,
272            auth_context: auth_context.clone(),
273        };
274
275        let mut subs = self.job_subscriptions.write().await;
276        subs.entry(job_id).or_default().push(subscription);
277
278        tracing::trace!(%job_id, %session_id, "Job subscription created");
279        Ok(job_data)
280    }
281
282    /// Unsubscribe from job updates.
283    pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) {
284        let mut subs = self.job_subscriptions.write().await;
285        for subscribers in subs.values_mut() {
286            subscribers
287                .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
288        }
289        subs.retain(|_, v| !v.is_empty());
290    }
291
292    /// Subscribe to workflow progress updates.
293    pub async fn subscribe_workflow(
294        &self,
295        session_id: SessionId,
296        client_sub_id: String,
297        workflow_id: Uuid,
298        auth_context: &forge_core::function::AuthContext,
299    ) -> forge_core::Result<WorkflowData> {
300        Self::ensure_workflow_access(&self.db_pool, workflow_id, auth_context).await?;
301        let workflow_data = self.fetch_workflow_data(workflow_id).await?;
302
303        let subscription = WorkflowSubscription {
304            session_id,
305            client_sub_id,
306            auth_context: auth_context.clone(),
307        };
308
309        let mut subs = self.workflow_subscriptions.write().await;
310        subs.entry(workflow_id).or_default().push(subscription);
311
312        tracing::trace!(%workflow_id, %session_id, "Workflow subscription created");
313        Ok(workflow_data)
314    }
315
316    /// Unsubscribe from workflow updates.
317    pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) {
318        let mut subs = self.workflow_subscriptions.write().await;
319        for subscribers in subs.values_mut() {
320            subscribers
321                .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id));
322        }
323        subs.retain(|_, v| !v.is_empty());
324    }
325
326    #[allow(clippy::type_complexity)]
327    async fn fetch_job_data(&self, job_id: Uuid) -> forge_core::Result<JobData> {
328        Self::fetch_job_data_static(job_id, &self.db_pool).await
329    }
330
    /// Fetch the current workflow snapshot using this reactor's pool.
    async fn fetch_workflow_data(&self, workflow_id: Uuid) -> forge_core::Result<WorkflowData> {
        Self::fetch_workflow_data_static(workflow_id, &self.db_pool).await
    }
334
335    /// Execute a query and return data with read set.
336    async fn execute_query(
337        &self,
338        query_name: &str,
339        args: &serde_json::Value,
340        auth_context: &forge_core::function::AuthContext,
341    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
342        Self::execute_query_static(
343            &self.registry,
344            &self.db_pool,
345            query_name,
346            args,
347            auth_context,
348        )
349        .await
350    }
351
352    fn compute_hash(data: &serde_json::Value) -> String {
353        use std::collections::hash_map::DefaultHasher;
354        use std::hash::{Hash, Hasher};
355
356        let json = serde_json::to_string(data).unwrap_or_default();
357        let mut hasher = DefaultHasher::new();
358        json.hash(&mut hasher);
359        format!("{:x}", hasher.finish())
360    }
361
362    /// Parallel group re-execution with bounded concurrency.
363    async fn flush_invalidations(
364        invalidation_engine: &Arc<InvalidationEngine>,
365        subscription_manager: &Arc<SubscriptionManager>,
366        session_server: &Arc<SessionServer>,
367        registry: &FunctionRegistry,
368        db_pool: &sqlx::PgPool,
369        max_concurrent: usize,
370    ) {
371        let invalidated_groups = invalidation_engine.check_pending().await;
372        if invalidated_groups.is_empty() {
373            return;
374        }
375
376        tracing::trace!(
377            count = invalidated_groups.len(),
378            "Invalidating query groups"
379        );
380
381        // Collect group data we need for re-execution
382        let groups_to_process: Vec<_> = invalidated_groups
383            .iter()
384            .filter_map(|gid| {
385                subscription_manager.get_group(*gid).map(|g| {
386                    (
387                        g.id,
388                        g.query_name.clone(),
389                        (*g.args).clone(),
390                        g.last_result_hash.clone(),
391                        g.auth_context.clone(),
392                    )
393                })
394            })
395            .collect();
396
397        // Parallel re-execution bounded by semaphore
398        let semaphore = Arc::new(Semaphore::new(max_concurrent));
399        let mut futures = FuturesUnordered::new();
400
401        for (group_id, query_name, args, last_hash, auth_context) in groups_to_process {
402            let permit = match semaphore.clone().acquire_owned().await {
403                Ok(p) => p,
404                Err(_) => break,
405            };
406            let registry = registry.clone();
407            let db_pool = db_pool.clone();
408
409            futures.push(async move {
410                let result = Self::execute_query_static(
411                    &registry,
412                    &db_pool,
413                    &query_name,
414                    &args,
415                    &auth_context,
416                )
417                .await;
418                drop(permit);
419                (group_id, last_hash, result)
420            });
421        }
422
423        // Process results and fan out to subscribers
424        while let Some((group_id, last_hash, result)) = futures.next().await {
425            match result {
426                Ok((new_data, read_set)) => {
427                    let new_hash = Self::compute_hash(&new_data);
428
429                    if last_hash.as_ref() != Some(&new_hash) {
430                        // Update group state
431                        subscription_manager.update_group(group_id, read_set, new_hash);
432
433                        // Fan out to all subscribers in this group
434                        let subscribers = subscription_manager.get_group_subscribers(group_id);
435                        for (session_id, client_sub_id) in subscribers {
436                            let message = RealtimeMessage::Data {
437                                subscription_id: client_sub_id.clone(),
438                                data: new_data.clone(),
439                            };
440
441                            if let Err(e) = session_server.try_send_to_session(session_id, message)
442                            {
443                                tracing::trace!(
444                                    client_id = %client_sub_id,
445                                    error = ?e,
446                                    "Failed to send update to subscriber"
447                                );
448                            }
449                        }
450                    }
451                }
452                Err(e) => {
453                    tracing::warn!(?group_id, error = %e, "Failed to re-execute query group");
454                }
455            }
456        }
457    }
458
    /// Start the reactor.
    ///
    /// Spawns a background task that multiplexes four concerns: incoming
    /// change events, a periodic invalidation flush, periodic stale-session
    /// cleanup, and change-listener crash/restart handling — until shutdown.
    pub async fn start(&self) -> forge_core::Result<()> {
        // Clone everything the background task needs; the task owns these.
        let listener = self.change_listener.clone();
        let invalidation_engine = self.invalidation_engine.clone();
        let subscription_manager = self.subscription_manager.clone();
        let job_subscriptions = self.job_subscriptions.clone();
        let workflow_subscriptions = self.workflow_subscriptions.clone();
        let session_server = self.session_server.clone();
        let registry = self.registry.clone();
        let db_pool = self.db_pool.clone();
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        let max_restarts = self.max_listener_restarts;
        let base_delay_ms = self.listener_restart_delay_ms;
        let max_concurrent = self.max_concurrent_reexecutions;
        let cleanup_secs = self.session_cleanup_interval_secs;

        let mut change_rx = listener.subscribe();

        tokio::spawn(async move {
            tracing::debug!("Reactor listening for changes");

            let mut restart_count: u32 = 0;
            // The listener task reports fatal errors through this channel so
            // the select loop below can decide whether to restart it.
            let (listener_error_tx, mut listener_error_rx) = mpsc::channel::<String>(1);

            // Start initial listener
            let listener_clone = listener.clone();
            let error_tx = listener_error_tx.clone();
            let mut listener_handle = Some(tokio::spawn(async move {
                if let Err(e) = listener_clone.run().await {
                    let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                }
            }));

            // Flush every 25ms; skip (rather than burst) missed ticks.
            let mut flush_interval = tokio::time::interval(std::time::Duration::from_millis(25));
            flush_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            let mut cleanup_interval =
                tokio::time::interval(std::time::Duration::from_secs(cleanup_secs));
            cleanup_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

            loop {
                tokio::select! {
                    result = change_rx.recv() => {
                        match result {
                            Ok(change) => {
                                Self::handle_change(
                                    &change,
                                    &invalidation_engine,
                                    &job_subscriptions,
                                    &workflow_subscriptions,
                                    &session_server,
                                    &db_pool,
                                ).await;
                            }
                            // The broadcast receiver fell behind and changes
                            // were dropped; log and keep going.
                            Err(broadcast::error::RecvError::Lagged(n)) => {
                                tracing::warn!("Reactor lagged by {} messages", n);
                            }
                            Err(broadcast::error::RecvError::Closed) => {
                                tracing::debug!("Change channel closed");
                                break;
                            }
                        }
                    }
                    _ = flush_interval.tick() => {
                        Self::flush_invalidations(
                            &invalidation_engine,
                            &subscription_manager,
                            &session_server,
                            &registry,
                            &db_pool,
                            max_concurrent,
                        ).await;
                    }
                    _ = cleanup_interval.tick() => {
                        // Evict sessions considered stale after five minutes.
                        session_server.cleanup_stale(std::time::Duration::from_secs(300));
                    }
                    Some(error_msg) = listener_error_rx.recv() => {
                        if restart_count >= max_restarts {
                            tracing::error!(
                                attempts = restart_count,
                                last_error = %error_msg,
                                "Change listener failed permanently, real-time updates disabled"
                            );
                            break;
                        }

                        // Exponential backoff: base_delay * 2^(attempt - 1).
                        restart_count += 1;
                        let delay = base_delay_ms * 2u64.saturating_pow(restart_count - 1);
                        tracing::warn!(
                            attempt = restart_count,
                            max = max_restarts,
                            delay_ms = delay,
                            error = %error_msg,
                            "Change listener restarting"
                        );

                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;

                        // Abort the old task and take a fresh broadcast
                        // receiver before spawning the replacement.
                        let listener_clone = listener.clone();
                        let error_tx = listener_error_tx.clone();
                        if let Some(handle) = listener_handle.take() {
                            handle.abort();
                        }
                        change_rx = listener.subscribe();
                        listener_handle = Some(tokio::spawn(async move {
                            if let Err(e) = listener_clone.run().await {
                                let _ = error_tx.send(format!("Change listener error: {}", e)).await;
                            }
                        }));
                    }
                    _ = shutdown_rx.recv() => {
                        tracing::debug!("Reactor shutting down");
                        break;
                    }
                }
            }

            // Loop exited (shutdown or permanent failure): stop the listener.
            if let Some(handle) = listener_handle {
                handle.abort();
            }
        });

        Ok(())
    }
583
584    /// Handle a database change event.
585    #[allow(clippy::too_many_arguments)]
586    async fn handle_change(
587        change: &Change,
588        invalidation_engine: &Arc<InvalidationEngine>,
589        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
590        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
591        session_server: &Arc<SessionServer>,
592        db_pool: &sqlx::PgPool,
593    ) {
594        tracing::trace!(table = %change.table, op = ?change.operation, row_id = ?change.row_id, "Processing change");
595
596        match change.table.as_str() {
597            "forge_jobs" => {
598                if let Some(job_id) = change.row_id {
599                    Self::handle_job_change(job_id, job_subscriptions, session_server, db_pool)
600                        .await;
601                }
602                return;
603            }
604            "forge_workflow_runs" => {
605                if let Some(workflow_id) = change.row_id {
606                    Self::handle_workflow_change(
607                        workflow_id,
608                        workflow_subscriptions,
609                        session_server,
610                        db_pool,
611                    )
612                    .await;
613                }
614                return;
615            }
616            "forge_workflow_steps" => {
617                if let Some(step_id) = change.row_id {
618                    Self::handle_workflow_step_change(
619                        step_id,
620                        workflow_subscriptions,
621                        session_server,
622                        db_pool,
623                    )
624                    .await;
625                }
626                return;
627            }
628            _ => {}
629        }
630
631        // Record change for debounced group invalidation
632        invalidation_engine.process_change(change.clone()).await;
633    }
634
635    async fn handle_job_change(
636        job_id: Uuid,
637        job_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<JobSubscription>>>>,
638        session_server: &Arc<SessionServer>,
639        db_pool: &sqlx::PgPool,
640    ) {
641        let subs = job_subscriptions.read().await;
642        let subscribers = match subs.get(&job_id) {
643            Some(s) if !s.is_empty() => s.clone(),
644            _ => return,
645        };
646        drop(subs);
647
648        let job_data = match Self::fetch_job_data_static(job_id, db_pool).await {
649            Ok(data) => data,
650            Err(e) => {
651                tracing::debug!(%job_id, error = %e, "Failed to fetch job data");
652                return;
653            }
654        };
655
656        let owner_subject = match Self::fetch_job_owner_subject_static(job_id, db_pool).await {
657            Ok(owner) => owner,
658            Err(e) => {
659                tracing::debug!(%job_id, error = %e, "Failed to fetch job owner");
660                return;
661            }
662        };
663
664        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();
665
666        for sub in &subscribers {
667            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
668                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
669                continue;
670            }
671
672            let message = RealtimeMessage::JobUpdate {
673                client_sub_id: sub.client_sub_id.clone(),
674                job: job_data.clone(),
675            };
676
677            if let Err(e) = session_server
678                .send_to_session(sub.session_id, message)
679                .await
680            {
681                tracing::trace!(%job_id, error = %e, "Failed to send job update");
682            }
683        }
684
685        if !unauthorized.is_empty() {
686            let mut subs = job_subscriptions.write().await;
687            if let Some(entries) = subs.get_mut(&job_id) {
688                entries
689                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
690            }
691            subs.retain(|_, v| !v.is_empty());
692        }
693    }
694
    /// Push the latest state of a workflow run to its subscribers, pruning
    /// any whose auth context no longer passes the owner check.
    async fn handle_workflow_change(
        workflow_id: Uuid,
        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
        session_server: &Arc<SessionServer>,
        db_pool: &sqlx::PgPool,
    ) {
        // Clone the subscriber list and drop the read lock before any
        // database or network work below.
        let subs = workflow_subscriptions.read().await;
        let subscribers = match subs.get(&workflow_id) {
            Some(s) if !s.is_empty() => s.clone(),
            _ => return,
        };
        drop(subs);

        let workflow_data = match Self::fetch_workflow_data_static(workflow_id, db_pool).await {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow data");
                return;
            }
        };

        let owner_subject =
            match Self::fetch_workflow_owner_subject_static(workflow_id, db_pool).await {
                Ok(owner) => owner,
                Err(e) => {
                    tracing::debug!(%workflow_id, error = %e, "Failed to fetch workflow owner");
                    return;
                }
            };

        // Subscribers failing the ownership re-check are skipped here and
        // pruned from the map afterwards.
        let mut unauthorized: HashSet<(SessionId, String)> = HashSet::new();

        for sub in &subscribers {
            if Self::check_owner_access(owner_subject.clone(), &sub.auth_context).is_err() {
                unauthorized.insert((sub.session_id, sub.client_sub_id.clone()));
                continue;
            }

            let message = RealtimeMessage::WorkflowUpdate {
                client_sub_id: sub.client_sub_id.clone(),
                workflow: workflow_data.clone(),
            };

            // Send failures are non-fatal; the session may already be gone.
            if let Err(e) = session_server
                .send_to_session(sub.session_id, message)
                .await
            {
                tracing::trace!(%workflow_id, error = %e, "Failed to send workflow update");
            }
        }

        if !unauthorized.is_empty() {
            let mut subs = workflow_subscriptions.write().await;
            if let Some(entries) = subs.get_mut(&workflow_id) {
                entries
                    .retain(|e| !unauthorized.contains(&(e.session_id, e.client_sub_id.clone())));
            }
            subs.retain(|_, v| !v.is_empty());
        }
    }
755
756    async fn handle_workflow_step_change(
757        step_id: Uuid,
758        workflow_subscriptions: &Arc<RwLock<HashMap<Uuid, Vec<WorkflowSubscription>>>>,
759        session_server: &Arc<SessionServer>,
760        db_pool: &sqlx::PgPool,
761    ) {
762        let workflow_id: Option<Uuid> = match sqlx::query_scalar!(
763            "SELECT workflow_run_id FROM forge_workflow_steps WHERE id = $1",
764            step_id,
765        )
766        .fetch_optional(db_pool)
767        .await
768        {
769            Ok(id) => id,
770            Err(e) => {
771                tracing::debug!(%step_id, error = %e, "Failed to look up workflow for step");
772                return;
773            }
774        };
775
776        if let Some(wf_id) = workflow_id {
777            Self::handle_workflow_change(wf_id, workflow_subscriptions, session_server, db_pool)
778                .await;
779        }
780    }
781
782    #[allow(clippy::type_complexity)]
783    async fn fetch_job_data_static(
784        job_id: Uuid,
785        db_pool: &sqlx::PgPool,
786    ) -> forge_core::Result<JobData> {
787        let row = sqlx::query!(
788            r#"
789                SELECT status, progress_percent, progress_message, output, last_error
790                FROM forge_jobs WHERE id = $1
791                "#,
792            job_id
793        )
794        .fetch_optional(db_pool)
795        .await
796        .map_err(forge_core::ForgeError::Sql)?;
797
798        match row {
799            Some(row) => Ok(JobData {
800                job_id: job_id.to_string(),
801                status: row.status,
802                progress_percent: row.progress_percent,
803                progress_message: row.progress_message,
804                output: row.output,
805                error: row.last_error,
806            }),
807            None => Err(forge_core::ForgeError::NotFound(format!(
808                "Job {} not found",
809                job_id
810            ))),
811        }
812    }
813
814    async fn fetch_job_owner_subject_static(
815        job_id: Uuid,
816        db_pool: &sqlx::PgPool,
817    ) -> forge_core::Result<Option<String>> {
818        let owner_subject: Option<Option<String>> =
819            sqlx::query_scalar!("SELECT owner_subject FROM forge_jobs WHERE id = $1", job_id)
820                .fetch_optional(db_pool)
821                .await
822                .map_err(forge_core::ForgeError::Sql)?;
823
824        owner_subject
825            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))
826    }
827
828    #[allow(clippy::type_complexity)]
829    async fn fetch_workflow_data_static(
830        workflow_id: Uuid,
831        db_pool: &sqlx::PgPool,
832    ) -> forge_core::Result<WorkflowData> {
833        let row = sqlx::query!(
834            r#"
835                SELECT status, current_step, output, error
836                FROM forge_workflow_runs WHERE id = $1
837                "#,
838            workflow_id
839        )
840        .fetch_optional(db_pool)
841        .await
842        .map_err(forge_core::ForgeError::Sql)?;
843
844        let row = match row {
845            Some(r) => r,
846            None => {
847                return Err(forge_core::ForgeError::NotFound(format!(
848                    "Workflow {} not found",
849                    workflow_id
850                )));
851            }
852        };
853
854        let step_rows = sqlx::query!(
855            r#"
856            SELECT step_name, status, error
857            FROM forge_workflow_steps
858            WHERE workflow_run_id = $1
859            ORDER BY started_at ASC NULLS LAST
860            "#,
861            workflow_id
862        )
863        .fetch_all(db_pool)
864        .await
865        .map_err(forge_core::ForgeError::Sql)?;
866
867        let steps = step_rows
868            .into_iter()
869            .map(|row| WorkflowStepData {
870                name: row.step_name,
871                status: row.status,
872                error: row.error,
873            })
874            .collect();
875
876        Ok(WorkflowData {
877            workflow_id: workflow_id.to_string(),
878            status: row.status,
879            current_step: row.current_step,
880            steps,
881            output: row.output,
882            error: row.error,
883        })
884    }
885
886    async fn fetch_workflow_owner_subject_static(
887        workflow_id: Uuid,
888        db_pool: &sqlx::PgPool,
889    ) -> forge_core::Result<Option<String>> {
890        let owner_subject: Option<Option<String>> = sqlx::query_scalar!(
891            "SELECT owner_subject FROM forge_workflow_runs WHERE id = $1",
892            workflow_id,
893        )
894        .fetch_optional(db_pool)
895        .await
896        .map_err(forge_core::ForgeError::Sql)?;
897
898        owner_subject.ok_or_else(|| {
899            forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
900        })
901    }
902
903    async fn execute_query_static(
904        registry: &FunctionRegistry,
905        db_pool: &sqlx::PgPool,
906        query_name: &str,
907        args: &serde_json::Value,
908        auth_context: &forge_core::function::AuthContext,
909    ) -> forge_core::Result<(serde_json::Value, ReadSet)> {
910        match registry.get(query_name) {
911            Some(FunctionEntry::Query { info, handler }) => {
912                Self::check_query_auth(info, auth_context)?;
913
914                let ctx = forge_core::function::QueryContext::new(
915                    db_pool.clone(),
916                    auth_context.clone(),
917                    forge_core::function::RequestMetadata::new(),
918                );
919
920                let normalized_args = match args {
921                    v if v.as_object().is_some_and(|o| o.is_empty()) => serde_json::Value::Null,
922                    v => v.clone(),
923                };
924
925                let data = handler(&ctx, normalized_args).await?;
926
927                let mut read_set = ReadSet::new();
928
929                if info.table_dependencies.is_empty() {
930                    let table_name = Self::extract_table_name(query_name);
931                    read_set.add_table(&table_name);
932                    tracing::trace!(
933                        query = %query_name,
934                        fallback_table = %table_name,
935                        "Using naming convention fallback for table dependency"
936                    );
937                } else {
938                    for table in info.table_dependencies {
939                        read_set.add_table(*table);
940                    }
941                }
942
943                Ok((data, read_set))
944            }
945            _ => Err(forge_core::ForgeError::Validation(format!(
946                "Query '{}' not found or not a query",
947                query_name
948            ))),
949        }
950    }
951
952    fn extract_table_name(query_name: &str) -> String {
953        if let Some(rest) = query_name.strip_prefix("get_") {
954            rest.to_string()
955        } else if let Some(rest) = query_name.strip_prefix("list_") {
956            rest.to_string()
957        } else if let Some(rest) = query_name.strip_prefix("find_") {
958            rest.to_string()
959        } else if let Some(rest) = query_name.strip_prefix("fetch_") {
960            rest.to_string()
961        } else {
962            query_name.to_string()
963        }
964    }
965
966    fn check_query_auth(
967        info: &forge_core::function::FunctionInfo,
968        auth: &forge_core::function::AuthContext,
969    ) -> forge_core::Result<()> {
970        if info.is_public {
971            return Ok(());
972        }
973
974        if !auth.is_authenticated() {
975            return Err(forge_core::ForgeError::Unauthorized(
976                "Authentication required".into(),
977            ));
978        }
979
980        if let Some(role) = info.required_role
981            && !auth.has_role(role)
982        {
983            return Err(forge_core::ForgeError::Forbidden(format!(
984                "Role '{}' required",
985                role
986            )));
987        }
988
989        Ok(())
990    }
991
992    async fn ensure_job_access(
993        db_pool: &sqlx::PgPool,
994        job_id: Uuid,
995        auth: &forge_core::function::AuthContext,
996    ) -> forge_core::Result<()> {
997        let owner_subject_row = sqlx::query_scalar!(
998            r#"SELECT owner_subject FROM forge_jobs WHERE id = $1"#,
999            job_id
1000        )
1001        .fetch_optional(db_pool)
1002        .await
1003        .map_err(forge_core::ForgeError::Sql)?;
1004
1005        let owner_subject = owner_subject_row
1006            .ok_or_else(|| forge_core::ForgeError::NotFound(format!("Job {} not found", job_id)))?;
1007
1008        Self::check_owner_access(owner_subject, auth)
1009    }
1010
1011    async fn ensure_workflow_access(
1012        db_pool: &sqlx::PgPool,
1013        workflow_id: Uuid,
1014        auth: &forge_core::function::AuthContext,
1015    ) -> forge_core::Result<()> {
1016        let owner_subject_row = sqlx::query_scalar!(
1017            r#"SELECT owner_subject FROM forge_workflow_runs WHERE id = $1"#,
1018            workflow_id
1019        )
1020        .fetch_optional(db_pool)
1021        .await
1022        .map_err(forge_core::ForgeError::Sql)?;
1023
1024        let owner_subject = owner_subject_row.ok_or_else(|| {
1025            forge_core::ForgeError::NotFound(format!("Workflow {} not found", workflow_id))
1026        })?;
1027
1028        Self::check_owner_access(owner_subject, auth)
1029    }
1030
1031    fn check_owner_access(
1032        owner_subject: Option<String>,
1033        auth: &forge_core::function::AuthContext,
1034    ) -> forge_core::Result<()> {
1035        if auth.is_admin() {
1036            return Ok(());
1037        }
1038
1039        // Treat empty string the same as NULL (no owner)
1040        let Some(owner) = owner_subject.filter(|s| !s.is_empty()) else {
1041            return Ok(());
1042        };
1043
1044        let principal = auth.principal_id().ok_or_else(|| {
1045            forge_core::ForgeError::Unauthorized("Authentication required".to_string())
1046        })?;
1047
1048        if owner == principal {
1049            Ok(())
1050        } else {
1051            Err(forge_core::ForgeError::Forbidden(
1052                "Not authorized to access this resource".to_string(),
1053            ))
1054        }
1055    }
1056
    /// Signal shutdown to the reactor's background tasks and stop the
    /// change listener.
    pub fn stop(&self) {
        // A send error only means there are no live receivers; safe to ignore.
        let _ = self.shutdown_tx.send(());
        self.change_listener.stop();
    }
1061
1062    pub async fn stats(&self) -> ReactorStats {
1063        let session_stats = self.session_server.stats();
1064        let inv_stats = self.invalidation_engine.stats().await;
1065
1066        ReactorStats {
1067            connections: session_stats.connections,
1068            subscriptions: session_stats.subscriptions,
1069            query_groups: self.subscription_manager.group_count(),
1070            pending_invalidations: inv_stats.pending_groups,
1071            listener_running: self.change_listener.is_running(),
1072        }
1073    }
1074}
1075
/// Reactor statistics.
///
/// Point-in-time snapshot produced by `Reactor::stats`.
#[derive(Debug, Clone)]
pub struct ReactorStats {
    /// Active realtime connections reported by the session server.
    pub connections: usize,
    /// Active subscriptions across all sessions.
    pub subscriptions: usize,
    /// Distinct query groups tracked by the subscription manager.
    pub query_groups: usize,
    /// Invalidation groups waiting to be flushed.
    pub pending_invalidations: usize,
    /// Whether the change listener background task is currently running.
    pub listener_running: bool,
}
1085
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Defaults in `ReactorConfig::default` should stay stable.
    #[test]
    fn test_reactor_config_default() {
        let cfg = ReactorConfig::default();
        assert_eq!(cfg.listener.channel, "forge_changes");
        assert_eq!(cfg.invalidation.debounce_ms, 50);
        assert_eq!(cfg.max_listener_restarts, 5);
        assert_eq!(cfg.listener_restart_delay_ms, 1000);
        assert_eq!(cfg.max_concurrent_reexecutions, 64);
        assert_eq!(cfg.session_cleanup_interval_secs, 60);
    }

    /// Equal JSON values hash identically; different values must not collide
    /// here.
    #[test]
    fn test_compute_hash() {
        let same_a = serde_json::json!({"name": "test"});
        let same_b = serde_json::json!({"name": "test"});
        let other = serde_json::json!({"name": "different"});

        assert_eq!(Reactor::compute_hash(&same_a), Reactor::compute_hash(&same_b));
        assert_ne!(Reactor::compute_hash(&same_a), Reactor::compute_hash(&other));
    }

    /// Admin callers may access resources owned by anyone.
    #[test]
    fn test_check_owner_access_allows_admin() {
        let claims = HashMap::from([(
            "sub".to_string(),
            serde_json::Value::String("admin-1".to_string()),
        )]);
        let auth = forge_core::function::AuthContext::authenticated_without_uuid(
            vec!["admin".to_string()],
            claims,
        );

        assert!(Reactor::check_owner_access(Some("other-user".to_string()), &auth).is_ok());
    }
}