Skip to main content

opendev_web/state/
mod.rs

1//! Shared application state.
2//!
3//! Thread-safe state shared between HTTP handlers and WebSocket connections.
4//! Uses `tokio::sync::oneshot` channels for approval, ask-user, and plan-approval
5//! notification so that waiting agent tasks are woken immediately on resolution
6//! (no polling).
7
8mod approvals;
9mod bridge;
10
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::{Mutex, RwLock, broadcast, mpsc, oneshot};
14
15use opendev_config::ModelRegistry;
16use opendev_history::SessionManager;
17use opendev_http::UserStore;
18use opendev_models::AppConfig;
19
20/// WebSocket broadcast message.
21#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
22pub struct WsBroadcast {
23    #[serde(rename = "type")]
24    pub msg_type: String,
25    #[serde(default)]
26    pub data: serde_json::Value,
27}
28
29/// Shared application state wrapped in Arc for use with Axum.
30#[derive(Clone)]
31pub struct AppState {
32    inner: Arc<AppStateInner>,
33}
34
35pub(super) struct AppStateInner {
36    /// Session manager for persistence.
37    pub(super) session_manager: RwLock<SessionManager>,
38    /// Application configuration.
39    pub(super) config: RwLock<AppConfig>,
40    /// Working directory for the current project.
41    pub(super) working_dir: String,
42    /// Broadcast channel for WebSocket messages.
43    pub(super) ws_tx: broadcast::Sender<WsBroadcast>,
44    /// Pending approval requests: approval_id -> (metadata, oneshot sender).
45    pub(super) pending_approvals: Mutex<HashMap<String, PendingApprovalSlot>>,
46    /// Pending ask-user requests: request_id -> (metadata, oneshot sender).
47    pub(super) pending_ask_users: Mutex<HashMap<String, PendingAskUserSlot>>,
48    /// Pending plan approval requests: request_id -> (metadata, oneshot sender).
49    pub(super) pending_plan_approvals: Mutex<HashMap<String, PendingPlanApprovalSlot>>,
50    /// Current operation mode (normal/plan).
51    pub(super) mode: RwLock<OperationMode>,
52    /// Autonomy level.
53    pub(super) autonomy_level: RwLock<String>,
54    /// Interrupt flag.
55    pub(super) interrupt_requested: Mutex<bool>,
56    /// Running sessions: session_id -> status.
57    pub(super) running_sessions: Mutex<HashMap<String, String>>,
58    /// Live message injection queues: session_id -> bounded mpsc sender.
59    pub(super) injection_queues: Mutex<HashMap<String, mpsc::Sender<String>>>,
60    /// Agent executor (trait-object, set once on first query).
61    pub(super) agent_executor: Mutex<Option<Arc<dyn AgentExecutor>>>,
62    /// User store for authentication.
63    pub(super) user_store: Arc<UserStore>,
64    /// Model/provider registry from models.dev cache.
65    pub(super) model_registry: RwLock<ModelRegistry>,
66    /// Bridge mode state.
67    pub(super) bridge: RwLock<BridgeState>,
68}
69
70/// Bridge mode state: when the TUI owns agent execution and
71/// the Web UI mirrors it.
72#[derive(Debug, Default)]
73pub(super) struct BridgeState {
74    /// Session ID currently owned by the TUI bridge.
75    pub(super) session_id: Option<String>,
76    /// Whether bridge mode is active.
77    pub(super) active: bool,
78}
79
80/// Operation mode for the agent.
81#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
82#[serde(rename_all = "lowercase")]
83pub enum OperationMode {
84    Normal,
85    Plan,
86}
87
88impl std::fmt::Display for OperationMode {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        match self {
91            OperationMode::Normal => write!(f, "normal"),
92            OperationMode::Plan => write!(f, "plan"),
93        }
94    }
95}
96
97/// Metadata for a pending approval request.
98#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
99pub struct PendingApproval {
100    pub tool_name: String,
101    pub arguments: serde_json::Value,
102    pub session_id: Option<String>,
103}
104
105/// Internal slot holding approval metadata and the oneshot sender.
106pub(super) struct PendingApprovalSlot {
107    pub meta: PendingApproval,
108    pub tx: Option<oneshot::Sender<ApprovalResult>>,
109}
110
/// Decision delivered over the oneshot channel when an approval is resolved.
///
/// An interrupt resolves every pending approval with both flags `false`.
#[derive(Clone, Debug)]
pub struct ApprovalResult {
    /// Whether the tool call was approved.
    pub approved: bool,
    /// Auto-approve flag returned alongside the decision.
    pub auto_approve: bool,
}
117
118/// Metadata for a pending ask-user request.
119#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
120pub struct PendingAskUser {
121    pub prompt: String,
122    pub session_id: Option<String>,
123}
124
125/// Internal slot holding ask-user metadata and the oneshot sender.
126pub(super) struct PendingAskUserSlot {
127    pub meta: PendingAskUser,
128    pub tx: Option<oneshot::Sender<AskUserResult>>,
129}
130
131/// Result sent through the oneshot channel when ask-user is resolved.
132#[derive(Debug, Clone)]
133pub struct AskUserResult {
134    pub answers: Option<serde_json::Value>,
135    pub cancelled: bool,
136}
137
138/// Metadata for a pending plan approval request.
139#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
140pub struct PendingPlanApproval {
141    pub data: serde_json::Value,
142    pub session_id: Option<String>,
143}
144
145/// Internal slot holding plan-approval metadata and the oneshot sender.
146pub(super) struct PendingPlanApprovalSlot {
147    pub meta: PendingPlanApproval,
148    pub tx: Option<oneshot::Sender<PlanApprovalResult>>,
149}
150
/// Outcome delivered over the oneshot channel when a plan approval resolves.
///
/// An interrupt resolves every pending plan with `action: "reject"` and
/// `feedback: "Interrupted"`.
#[derive(Clone, Debug)]
pub struct PlanApprovalResult {
    /// Chosen action label (e.g. `"reject"`).
    pub action: String,
    /// Free-form feedback accompanying the action.
    pub feedback: String,
}
157
158/// Trait for agent execution -- injected into AppState for testability.
159#[async_trait::async_trait]
160pub trait AgentExecutor: Send + Sync + 'static {
161    /// Execute a query for a given session. Called as a background task.
162    async fn execute_query(
163        &self,
164        message: String,
165        session_id: String,
166        state: AppState,
167    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
168}
169
/// Bound for each session's live message-injection queue.
///
/// Presumably passed to `mpsc::channel` when a queue is created (the call
/// site is not in this file; confirm there).
const INJECTION_QUEUE_CAPACITY: usize = 10;
172
173impl AppState {
174    /// Create a new AppState.
175    pub fn new(
176        session_manager: SessionManager,
177        config: AppConfig,
178        working_dir: String,
179        user_store: UserStore,
180        model_registry: ModelRegistry,
181    ) -> Self {
182        let (ws_tx, _) = broadcast::channel(256);
183        Self {
184            inner: Arc::new(AppStateInner {
185                session_manager: RwLock::new(session_manager),
186                config: RwLock::new(config),
187                working_dir,
188                ws_tx,
189                pending_approvals: Mutex::new(HashMap::new()),
190                pending_ask_users: Mutex::new(HashMap::new()),
191                pending_plan_approvals: Mutex::new(HashMap::new()),
192                mode: RwLock::new(OperationMode::Normal),
193                autonomy_level: RwLock::new("Manual".to_string()),
194                interrupt_requested: Mutex::new(false),
195                running_sessions: Mutex::new(HashMap::new()),
196                injection_queues: Mutex::new(HashMap::new()),
197                agent_executor: Mutex::new(None),
198                user_store: Arc::new(user_store),
199                model_registry: RwLock::new(model_registry),
200                bridge: RwLock::new(BridgeState::default()),
201            }),
202        }
203    }
204
205    // --- Accessors ---
206
207    /// Get a read guard for the session manager.
208    pub async fn session_manager(&self) -> tokio::sync::RwLockReadGuard<'_, SessionManager> {
209        self.inner.session_manager.read().await
210    }
211
212    /// Get a write guard for the session manager.
213    pub async fn session_manager_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, SessionManager> {
214        self.inner.session_manager.write().await
215    }
216
217    /// Get the current session ID (if a session is loaded).
218    pub async fn current_session_id(&self) -> Option<String> {
219        self.inner
220            .session_manager
221            .read()
222            .await
223            .current_session()
224            .map(|s| s.id.clone())
225    }
226
227    /// Get a read guard for the app config.
228    pub async fn config(&self) -> tokio::sync::RwLockReadGuard<'_, AppConfig> {
229        self.inner.config.read().await
230    }
231
232    /// Get a write guard for the app config.
233    pub async fn config_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, AppConfig> {
234        self.inner.config.write().await
235    }
236
237    /// Get the working directory.
238    pub fn working_dir(&self) -> &str {
239        &self.inner.working_dir
240    }
241
242    // --- User store ---
243
244    /// Get a reference to the user store.
245    pub fn user_store(&self) -> &UserStore {
246        &self.inner.user_store
247    }
248
249    // --- Model registry ---
250
251    /// Get a read guard for the model registry.
252    pub async fn model_registry(&self) -> tokio::sync::RwLockReadGuard<'_, ModelRegistry> {
253        self.inner.model_registry.read().await
254    }
255
256    /// Get a write guard for the model registry.
257    pub async fn model_registry_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, ModelRegistry> {
258        self.inner.model_registry.write().await
259    }
260
261    // --- WebSocket ---
262
263    /// Get a clone of the broadcast sender.
264    pub fn ws_sender(&self) -> broadcast::Sender<WsBroadcast> {
265        self.inner.ws_tx.clone()
266    }
267
268    /// Subscribe to WebSocket broadcasts.
269    pub fn ws_subscribe(&self) -> broadcast::Receiver<WsBroadcast> {
270        self.inner.ws_tx.subscribe()
271    }
272
273    /// Broadcast a message to all WebSocket subscribers.
274    pub fn broadcast(&self, msg: WsBroadcast) {
275        // Ignore send errors (no subscribers is fine).
276        let _ = self.inner.ws_tx.send(msg);
277    }
278
279    // --- Mode / settings ---
280
281    /// Get the current operation mode.
282    pub async fn mode(&self) -> OperationMode {
283        *self.inner.mode.read().await
284    }
285
286    /// Set the operation mode.
287    pub async fn set_mode(&self, mode: OperationMode) {
288        *self.inner.mode.write().await = mode;
289    }
290
291    // --- Autonomy level ---
292
293    /// Get the current autonomy level.
294    pub async fn autonomy_level(&self) -> String {
295        self.inner.autonomy_level.read().await.clone()
296    }
297
298    /// Set the autonomy level.
299    pub async fn set_autonomy_level(&self, level: String) {
300        *self.inner.autonomy_level.write().await = level;
301    }
302
303    // --- Interrupt ---
304
305    /// Request an interrupt.
306    ///
307    /// Also denies all pending approvals, ask-user, and plan-approval requests
308    /// by sending rejection through their oneshot channels so blocked tasks wake up.
309    pub async fn request_interrupt(&self) {
310        *self.inner.interrupt_requested.lock().await = true;
311
312        // Deny all pending approvals.
313        {
314            let mut approvals = self.inner.pending_approvals.lock().await;
315            for (_id, slot) in approvals.iter_mut() {
316                if let Some(tx) = slot.tx.take() {
317                    let _ = tx.send(ApprovalResult {
318                        approved: false,
319                        auto_approve: false,
320                    });
321                }
322            }
323            approvals.clear();
324        }
325
326        // Cancel all pending ask-user requests.
327        {
328            let mut ask_users = self.inner.pending_ask_users.lock().await;
329            for (_id, slot) in ask_users.iter_mut() {
330                if let Some(tx) = slot.tx.take() {
331                    let _ = tx.send(AskUserResult {
332                        answers: None,
333                        cancelled: true,
334                    });
335                }
336            }
337            ask_users.clear();
338        }
339
340        // Reject all pending plan approvals.
341        {
342            let mut plan_approvals = self.inner.pending_plan_approvals.lock().await;
343            for (_id, slot) in plan_approvals.iter_mut() {
344                if let Some(tx) = slot.tx.take() {
345                    let _ = tx.send(PlanApprovalResult {
346                        action: "reject".to_string(),
347                        feedback: "Interrupted".to_string(),
348                    });
349                }
350            }
351            plan_approvals.clear();
352        }
353    }
354
355    /// Clear the interrupt flag.
356    pub async fn clear_interrupt(&self) {
357        *self.inner.interrupt_requested.lock().await = false;
358    }
359
360    /// Check if interrupt has been requested.
361    pub async fn is_interrupt_requested(&self) -> bool {
362        *self.inner.interrupt_requested.lock().await
363    }
364
365    // --- Running sessions ---
366
367    /// Mark a session as running.
368    pub async fn set_session_running(&self, session_id: String) {
369        self.inner
370            .running_sessions
371            .lock()
372            .await
373            .insert(session_id, "running".to_string());
374    }
375
376    /// Mark a session as idle.
377    pub async fn set_session_idle(&self, session_id: &str) {
378        self.inner.running_sessions.lock().await.remove(session_id);
379    }
380
381    /// Check if a session is running.
382    pub async fn is_session_running(&self, session_id: &str) -> bool {
383        self.inner
384            .running_sessions
385            .lock()
386            .await
387            .contains_key(session_id)
388    }
389
390    // --- Git branch ---
391
392    /// Get the git branch for the working directory.
393    pub fn git_branch(&self) -> Option<String> {
394        let output = std::process::Command::new("git")
395            .args(["rev-parse", "--abbrev-ref", "HEAD"])
396            .current_dir(&self.inner.working_dir)
397            .output()
398            .ok()?;
399
400        if output.status.success() {
401            Some(String::from_utf8_lossy(&output.stdout).trim().to_string())
402        } else {
403            None
404        }
405    }
406}
407
408#[cfg(test)]
409mod tests;