1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3
4use tokio::sync::broadcast;
5
6use crate::agent::AgentRegistry;
7use crate::circuit;
8use crate::dependency::DependencyStore;
9use crate::engine::Engine;
10use crate::error::{EnvoyError, Result};
11use crate::event::bus::{DeliveryTracker, EventBus};
12use crate::message::MessageStore;
13use crate::monitor::{ProjectConfigStore, SubscriptionStore};
14use crate::rate_limit::{HybridRateLimiter, RateLimitConfig};
15use crate::status::NudgeConfig;
16use crate::task::store::TaskStore;
17
/// Locks `lock`, recovering the guard if a previous holder panicked.
///
/// Poisoning is not treated as fatal. A message is written to stderr and the
/// inner guard is returned, so the caller sees whatever state the panicking
/// holder left behind.
pub(crate) fn recover_lock<T>(lock: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    lock.lock().unwrap_or_else(|e| {
        eprintln!("mutex lock poisoned, recovering: {e}");
        e.into_inner()
    })
}
27
28pub(crate) struct WsRegistry {
30 senders: Mutex<HashMap<String, broadcast::Sender<String>>>,
31}
32
33impl WsRegistry {
34 pub(crate) fn new() -> Self {
35 Self {
36 senders: Mutex::new(HashMap::new()),
37 }
38 }
39
40 pub(crate) fn register(&self, agent_id: &str) -> broadcast::Receiver<String> {
41 let mut senders = recover_lock(&self.senders);
42 if let Some(tx) = senders.get(agent_id) {
43 tx.subscribe()
44 } else {
45 let (tx, rx) = broadcast::channel(256);
46 senders.insert(agent_id.to_string(), tx);
47 rx
48 }
49 }
50
51 pub(crate) fn unregister(&self, agent_id: &str) {
52 let mut senders = recover_lock(&self.senders);
53 senders.remove(agent_id);
54 }
55
56 pub(crate) fn send_json(
57 &self,
58 agent_id: &str,
59 event_type: &str,
60 data: &serde_json::Value,
61 ) -> bool {
62 let event = serde_json::json!({
63 "event": event_type,
64 "data": data
65 });
66 let senders = recover_lock(&self.senders);
67 if let Some(tx) = senders.get(agent_id) {
68 tx.send(event.to_string()).is_ok()
69 } else {
70 false
71 }
72 }
73}
74
/// Shared server state: the stores, the engine, the WebSocket registry and
/// runtime configuration.
///
/// NOTE(review): Rust drops struct fields in declaration order. Keep that in
/// mind before reordering these fields.
pub struct AppState {
    /// Agent records, built from the engine graph in [`AppState::new`].
    pub agent_registry: AgentRegistry,
    pub audit_store: crate::audit::AuditStore,
    pub dependency_store: DependencyStore,
    pub message_store: MessageStore,
    pub event_bus: EventBus,
    pub delivery_tracker: DeliveryTracker,
    pub task_store: TaskStore,
    pub subscription_store: SubscriptionStore,
    pub project_config_store: ProjectConfigStore,
    pub circuit_breaker: circuit::CircuitBreaker,
    /// The engine behind a std mutex.
    ///
    /// It is locked inside `spawn_blocking` by `with_graph_async`,
    /// `with_engine_async` and `run_nudge_loop`.
    pub(crate) engine: Arc<Mutex<Engine>>,
    /// Per-agent WebSocket broadcast channels.
    pub(crate) ws_registry: WsRegistry,
    /// Rate limiter, built from the engine graph with `RateLimitConfig::default()`.
    pub rate_limiter: HybridRateLimiter,
    /// Nudge-loop settings, re-read by `run_nudge_loop` on every pass.
    pub nudge_config: Mutex<NudgeConfig>,
    /// Timestamp taken when `AppState::new` ran.
    pub start_time: chrono::DateTime<chrono::Utc>,
    /// Atheneum location. `None` unless set via `with_atheneum`.
    #[cfg(feature = "atheneum")]
    pub atheneum_path: Option<String>,
}
95
96impl AppState {
97 pub fn new(engine: Engine) -> Result<Self> {
98 let agent_registry = AgentRegistry::new(engine.graph())?;
99 let rate_limiter = HybridRateLimiter::new(
100 engine.graph(),
101 RateLimitConfig::default(),
102 1000, )?;
104 Ok(Self {
105 agent_registry,
106 audit_store: crate::audit::AuditStore::new(),
107 dependency_store: DependencyStore::new(),
108 message_store: MessageStore::new(),
109 event_bus: EventBus::new(),
110 delivery_tracker: DeliveryTracker::new(),
111 task_store: TaskStore::new(),
112 subscription_store: SubscriptionStore::new(),
113 project_config_store: ProjectConfigStore::new(),
114 circuit_breaker: circuit::CircuitBreaker::with_defaults(),
115 engine: Arc::new(Mutex::new(engine)),
116 ws_registry: WsRegistry::new(),
117 rate_limiter,
118 nudge_config: Mutex::new(NudgeConfig::default()),
119 start_time: chrono::Utc::now(),
120 #[cfg(feature = "atheneum")]
121 atheneum_path: None,
122 })
123 }
124
125 #[cfg(feature = "atheneum")]
126 pub fn with_atheneum(mut self, path: Option<String>) -> Self {
127 self.atheneum_path = path;
128 self
129 }
130
131 #[cfg(feature = "atheneum")]
132 pub fn require_atheneum_path(&self) -> Result<String> {
133 self.atheneum_path
134 .clone()
135 .ok_or_else(|| EnvoyError::Atheneum(anyhow::anyhow!("atheneum not configured")))
136 }
137
138 pub async fn with_graph_async<F, T>(&self, f: F) -> Result<T>
141 where
142 F: FnOnce(&sqlitegraph::SqliteGraph) -> T + Send + 'static,
143 T: Send + 'static,
144 {
145 let engine = self.engine.clone();
146 let result = tokio::task::spawn_blocking(move || {
147 let engine = recover_lock(&engine);
148 f(engine.graph())
149 })
150 .await
151 .map_err(|_| EnvoyError::InvalidEntity("blocking task panicked".into()))?;
152 Ok(result)
153 }
154
155 pub async fn with_engine_async<F, T>(&self, f: F) -> Result<T>
157 where
158 F: FnOnce(&Engine) -> Result<T> + Send + 'static,
159 T: Send + 'static,
160 {
161 let engine = self.engine.clone();
162 tokio::task::spawn_blocking(move || {
163 let engine = recover_lock(&engine);
164 f(&engine)
165 })
166 .await
167 .map_err(|_| EnvoyError::InvalidEntity("blocking task panicked".into()))?
168 }
169}
170
171pub async fn run_nudge_loop(state: Arc<AppState>) {
173 loop {
174 let interval = {
175 let cfg = recover_lock(&state.nudge_config);
176 cfg.check_interval_seconds
177 };
178 tokio::time::sleep(std::time::Duration::from_secs(interval)).await;
179
180 let threshold = recover_lock(&state.nudge_config).stale_threshold_minutes;
181 let stale = match state.agent_registry.get_stale_agents(threshold) {
182 Ok(s) => s,
183 Err(e) => {
184 eprintln!("nudge loop: failed to get stale agents: {e}");
185 continue;
186 }
187 };
188
189 for agent in &stale {
190 let nudge_data = serde_json::json!({
191 "reason": format!(
192 "No heartbeat for {} minutes. Current status: {:?}",
193 threshold,
194 agent.status.as_ref().map(|s| s.state.as_str()).unwrap_or("unknown")
195 ),
196 "severity": "warning",
197 "agent_id": agent.agent_id,
198 "last_heartbeat": agent.last_heartbeat_at,
199 });
200 state
201 .ws_registry
202 .send_json(&agent.agent_id, "nudge", &nudge_data);
203
204 let state_fb = state.clone();
206 let agent_id_fb = agent.agent_id.clone();
207 let (deps, reclaimed) = tokio::task::spawn_blocking(move || {
208 let engine = recover_lock(&state_fb.engine);
209 let deps = state_fb
210 .dependency_store
211 .find_by_blocker(engine.graph(), &agent_id_fb)
212 .unwrap_or_default();
213 let reclaimed = state_fb
214 .task_store
215 .reclaim_stale(engine.graph(), &agent_id_fb)
216 .unwrap_or_default();
217 (deps, reclaimed)
218 })
219 .await
220 .unwrap_or((Vec::new(), Vec::new()));
221
222 for dep in &deps {
224 let unblock_msg = serde_json::json!({
225 "blocker_agent": agent.agent_id,
226 "blocker_status": agent.status.as_ref().map(|s| s.state.as_str()).unwrap_or("unknown"),
227 "message": format!(
228 "Your blocker ({}) may be stalled — no heartbeat for {}m",
229 agent.agent_id, threshold
230 ),
231 });
232 state
233 .ws_registry
234 .send_json(&dep.dependent_agent, "blocker_stale", &unblock_msg);
235 }
236 for task_id in &reclaimed {
237 let reclaim_msg = serde_json::json!({
238 "task_id": task_id,
239 "message": format!("Task reclaimed — {} is stale", agent.agent_id),
240 });
241 state
242 .ws_registry
243 .send_json(&agent.agent_id, "task_reclaimed", &reclaim_msg);
244 }
245 }
246 }
247}
248
249pub type SharedState = Arc<AppState>;