Skip to main content

envoy/http/
state.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use parking_lot::Mutex;
5use tokio::sync::broadcast;
6
7use crate::agent::AgentRegistry;
8use crate::circuit;
9use crate::dependency::DependencyStore;
10use crate::engine::Engine;
11use crate::error::{EnvoyError, Result};
12use crate::event::bus::{DeliveryTracker, EventBus};
13use crate::message::MessageStore;
14use crate::monitor::{ProjectConfigStore, SubscriptionStore};
15use crate::rate_limit::{HybridRateLimiter, RateLimitConfig};
16use crate::status::NudgeConfig;
17use crate::task::store::TaskStore;
18
19#[cfg(feature = "atheneum")]
20use atheneum::AtheneumGraph;
21
22/// Registry of active WebSocket senders, keyed by agent_id.
23pub(crate) struct WsRegistry {
24    senders: Mutex<HashMap<String, broadcast::Sender<String>>>,
25}
26
27impl WsRegistry {
28    pub(crate) fn new() -> Self {
29        Self {
30            senders: Mutex::new(HashMap::new()),
31        }
32    }
33
34    pub(crate) fn register(&self, agent_id: &str) -> broadcast::Receiver<String> {
35        let mut senders = self.senders.lock();
36        if let Some(tx) = senders.get(agent_id) {
37            tx.subscribe()
38        } else {
39            let (tx, rx) = broadcast::channel(256);
40            senders.insert(agent_id.to_string(), tx);
41            rx
42        }
43    }
44
45    pub(crate) fn unregister(&self, agent_id: &str) {
46        let mut senders = self.senders.lock();
47        senders.remove(agent_id);
48    }
49
50    pub(crate) fn send_json(
51        &self,
52        agent_id: &str,
53        event_type: &str,
54        data: &serde_json::Value,
55    ) -> bool {
56        let event = serde_json::json!({
57            "event": event_type,
58            "data": data
59        });
60        let senders = self.senders.lock();
61        if let Some(tx) = senders.get(agent_id) {
62            tx.send(event.to_string()).is_ok()
63        } else {
64            false
65        }
66    }
67}
68
69/// Shared application state across all handlers.
70pub struct AppState {
71    pub agent_registry: AgentRegistry,
72    pub audit_store: crate::audit::AuditStore,
73    pub dependency_store: DependencyStore,
74    pub message_store: MessageStore,
75    pub event_bus: EventBus,
76    pub delivery_tracker: DeliveryTracker,
77    pub task_store: TaskStore,
78    pub subscription_store: SubscriptionStore,
79    pub project_config_store: ProjectConfigStore,
80    pub circuit_breaker: circuit::CircuitBreaker,
81    pub(crate) engine: Arc<Mutex<Engine>>,
82    pub(crate) ws_registry: WsRegistry,
83    pub rate_limiter: HybridRateLimiter,
84    pub nudge_config: Mutex<NudgeConfig>,
85    pub start_time: chrono::DateTime<chrono::Utc>,
86    #[cfg(feature = "atheneum")]
87    pub atheneum_path: Option<String>,
88    #[cfg(feature = "atheneum")]
89    atheneum_graph: Arc<Mutex<Option<AtheneumGraph>>>,
90}
91
92impl AppState {
93    pub fn new(engine: Engine) -> Result<Self> {
94        let agent_registry = AgentRegistry::new(engine.graph())?;
95        let rate_limiter = HybridRateLimiter::new(
96            engine.graph(),
97            RateLimitConfig::default(),
98            1000, // L1 capacity
99        )?;
100        Ok(Self {
101            agent_registry,
102            audit_store: crate::audit::AuditStore::new(),
103            dependency_store: DependencyStore::new(),
104            message_store: MessageStore::new(),
105            event_bus: EventBus::new(),
106            delivery_tracker: DeliveryTracker::new(),
107            task_store: TaskStore::new(),
108            subscription_store: SubscriptionStore::new(),
109            project_config_store: ProjectConfigStore::new(),
110            circuit_breaker: circuit::CircuitBreaker::with_defaults(),
111            engine: Arc::new(Mutex::new(engine)),
112            ws_registry: WsRegistry::new(),
113            rate_limiter,
114            nudge_config: Mutex::new(NudgeConfig::default()),
115            start_time: chrono::Utc::now(),
116            #[cfg(feature = "atheneum")]
117            atheneum_path: None,
118            #[cfg(feature = "atheneum")]
119            atheneum_graph: Arc::new(Mutex::new(None)),
120        })
121    }
122
123    #[cfg(feature = "atheneum")]
124    pub fn with_atheneum(mut self, path: Option<String>) -> Self {
125        self.atheneum_path = path;
126        self
127    }
128
129    #[cfg(feature = "atheneum")]
130    pub fn require_atheneum_path(&self) -> Result<String> {
131        self.atheneum_path
132            .clone()
133            .ok_or_else(|| EnvoyError::Atheneum(anyhow::anyhow!("atheneum not configured")))
134    }
135
136    /// Async version of with_graph — offloads DB work to the blocking thread pool.
137    /// Use this from all async handlers to avoid blocking tokio worker threads.
138    pub async fn with_graph_async<F, T>(&self, f: F) -> Result<T>
139    where
140        F: FnOnce(&sqlitegraph::SqliteGraph) -> T + Send + 'static,
141        T: Send + 'static,
142    {
143        let engine = self.engine.clone();
144        let result = tokio::task::spawn_blocking(move || {
145            let engine = engine.lock();
146            f(engine.graph())
147        })
148        .await
149        .map_err(|_| EnvoyError::InvalidEntity("blocking task panicked".into()))?;
150        Ok(result)
151    }
152
153    /// Async version that provides the full Engine (not just graph).
154    pub async fn with_engine_async<F, T>(&self, f: F) -> Result<T>
155    where
156        F: FnOnce(&Engine) -> Result<T> + Send + 'static,
157        T: Send + 'static,
158    {
159        let engine = self.engine.clone();
160        tokio::task::spawn_blocking(move || {
161            let engine = engine.lock();
162            f(&engine)
163        })
164        .await
165        .map_err(|_| EnvoyError::InvalidEntity("blocking task panicked".into()))?
166    }
167
168    /// Async version that provides a cached AtheneumGraph.
169    /// Eliminates per-request SQLite open cost by reusing a single connection.
170    #[cfg(feature = "atheneum")]
171    pub async fn with_atheneum_async<F, T>(&self, f: F) -> Result<T>
172    where
173        F: FnOnce(&AtheneumGraph) -> Result<T> + Send + 'static,
174        T: Send + 'static,
175    {
176        let path = self.require_atheneum_path()?;
177        let graph_arc = self.atheneum_graph.clone();
178        tokio::task::spawn_blocking(move || {
179            let mut guard = graph_arc.lock();
180            if guard.is_none() {
181                let g = AtheneumGraph::open(std::path::Path::new(&path))?;
182                *guard = Some(g);
183            }
184            let g = guard.as_ref().unwrap();
185            f(g)
186        })
187        .await
188        .map_err(|_| EnvoyError::InvalidEntity("blocking task panicked".into()))?
189    }
190}
191
192/// Background task that checks for stale agents and pushes nudge events.
193pub async fn run_nudge_loop(state: Arc<AppState>) {
194    loop {
195        let interval = {
196            let cfg = state.nudge_config.lock();
197            cfg.check_interval_seconds
198        };
199        tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
200
201        let threshold = state.nudge_config.lock().stale_threshold_minutes;
202        let stale = match state.agent_registry.get_stale_agents(threshold) {
203            Ok(s) => s,
204            Err(e) => {
205                eprintln!("nudge loop: failed to get stale agents: {e}");
206                continue;
207            }
208        };
209
210        for agent in &stale {
211            let nudge_data = serde_json::json!({
212                "reason": format!(
213                    "No heartbeat for {} minutes. Current status: {:?}",
214                    threshold,
215                    agent.status.as_ref().map(|s| s.state.as_str()).unwrap_or("unknown")
216                ),
217                "severity": "warning",
218                "agent_id": agent.agent_id,
219                "last_heartbeat": agent.last_heartbeat_at,
220            });
221            state
222                .ws_registry
223                .send_json(&agent.agent_id, "nudge", &nudge_data);
224
225            // Fetch blocked dependents + reclaim stale tasks via blocking pool
226            let state_fb = state.clone();
227            let agent_id_fb = agent.agent_id.clone();
228            let (deps, reclaimed) = tokio::task::spawn_blocking(move || {
229                let engine = state_fb.engine.lock();
230                let deps = state_fb
231                    .dependency_store
232                    .find_by_blocker(engine.graph(), &agent_id_fb)
233                    .unwrap_or_default();
234                let reclaimed = state_fb
235                    .task_store
236                    .reclaim_stale(engine.graph(), &agent_id_fb)
237                    .unwrap_or_default();
238                (deps, reclaimed)
239            })
240            .await
241            .unwrap_or((Vec::new(), Vec::new()));
242
243            // WS sends are in-memory
244            for dep in &deps {
245                let unblock_msg = serde_json::json!({
246                    "blocker_agent": agent.agent_id,
247                    "blocker_status": agent.status.as_ref().map(|s| s.state.as_str()).unwrap_or("unknown"),
248                    "message": format!(
249                        "Your blocker ({}) may be stalled — no heartbeat for {}m",
250                        agent.agent_id, threshold
251                    ),
252                });
253                state
254                    .ws_registry
255                    .send_json(&dep.dependent_agent, "blocker_stale", &unblock_msg);
256            }
257            for task_id in &reclaimed {
258                let reclaim_msg = serde_json::json!({
259                    "task_id": task_id,
260                    "message": format!("Task reclaimed — {} is stale", agent.agent_id),
261                });
262                state
263                    .ws_registry
264                    .send_json(&agent.agent_id, "task_reclaimed", &reclaim_msg);
265            }
266        }
267    }
268}
269
270pub type SharedState = Arc<AppState>;