Skip to main content

aft/
context.rs

1use std::cell::{Ref, RefCell, RefMut};
2use std::path::{Component, Path, PathBuf};
3use std::sync::mpsc;
4
5use fastembed::TextEmbedding;
6use notify::RecommendedWatcher;
7
8use crate::backup::BackupStore;
9use crate::callgraph::CallGraph;
10use crate::checkpoint::CheckpointStore;
11use crate::config::Config;
12use crate::language::LanguageProvider;
13use crate::lsp::manager::LspManager;
14use crate::search_index::SearchIndex;
15use crate::semantic_index::SemanticIndex;
16
/// Status of the semantic search index, as tracked by `AppContext`.
///
/// `AppContext::new` initializes the status to [`SemanticIndexStatus::Disabled`].
/// The type derives `PartialEq`/`Eq` so callers can compare statuses directly
/// (e.g. `status == SemanticIndexStatus::Ready`) instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SemanticIndexStatus {
    /// Initial state set by `AppContext::new`.
    Disabled,
    /// Presumably: an index build is in progress — confirm against the code
    /// that updates the status.
    Building,
    /// Presumably: an index has been built and is available — confirm against
    /// the code that updates the status.
    Ready,
    /// Carries a message; presumably a description of why the build failed.
    Failed(String),
}
24
/// Message type of the semantic-index build channel; `AppContext` stores the
/// receiving end in its `semantic_index_rx` field.
pub enum SemanticIndexEvent {
    /// Carries a [`SemanticIndex`]; presumably the freshly built index.
    Ready(SemanticIndex),
    /// Carries a message; presumably the build error — confirm against the sender.
    Failed(String),
}
29
/// Normalize a path by resolving `.` and `..` components lexically,
/// without touching the filesystem. This prevents path traversal
/// attacks when `fs::canonicalize` fails (e.g. for non-existent paths).
///
/// Rules applied, component by component:
/// - `.` is dropped.
/// - `..` removes the preceding *normal* component (`a/b/..` becomes `a`).
/// - `..` directly below the root is dropped (`/..` becomes `/`), because the
///   root has no parent.
/// - `..` with nothing to cancel is kept, so leading parent references of a
///   relative path survive intact (`../../x` stays `../../x`). A `..` never
///   cancels another `..`.
///
/// Symlinks are not consulted; the result is purely textual.
fn normalize_path(path: &Path) -> PathBuf {
    let mut result = PathBuf::new();
    for component in path.components() {
        match component {
            // Skip `.`
            Component::CurDir => {}
            Component::ParentDir => {
                // Decide based on what the accumulated path currently ends with.
                // `PathBuf::pop` alone is not enough: it would also pop a kept
                // `..`, turning `../..` into an empty path.
                let ends_with_normal = matches!(
                    result.components().next_back(),
                    Some(Component::Normal(_))
                );
                let at_root = matches!(
                    result.components().next_back(),
                    Some(Component::RootDir)
                );

                if ends_with_normal {
                    result.pop();
                } else if !at_root {
                    // Empty path, a kept `..`, or a bare prefix: keep the `..`.
                    result.push(component);
                }
                // At the root, `..` is a no-op and is dropped.
            }
            _ => result.push(component),
        }
    }
    result
}
49
/// Resolve `path` against the filesystem as far as it actually exists.
///
/// Walks upward from `path`, collecting trailing segments that do not exist,
/// until an existing ancestor is found. That ancestor is canonicalized
/// (symlinks followed) and the missing segments are re-appended unchanged.
///
/// A relative path whose very first component is missing (e.g. `new_file.rs`)
/// walks all the way up to the empty path. In that case the current directory
/// is used as the anchor, so the result is absolute just like it would be for
/// an existing relative path. An entirely empty input is returned unchanged.
///
/// If canonicalization of the anchor fails, the un-canonicalized ancestor is
/// used instead.
///
/// NOTE(review): `Path::exists` follows symlinks and reports `false` for a
/// dangling symlink, so such a link is treated as a missing segment and
/// re-appended without being resolved. Callers that rely on this for
/// containment checks should confirm that is acceptable.
fn resolve_with_existing_ancestors(path: &Path) -> PathBuf {
    let mut existing = path.to_path_buf();
    let mut tail_segments = Vec::new();

    while !existing.exists() {
        // Paths ending in `..`, a root, or the empty path have no file name:
        // there is nothing left to peel off.
        match existing.file_name() {
            Some(name) => tail_segments.push(name.to_owned()),
            None => break,
        }

        existing = match existing.parent() {
            Some(parent) => parent.to_path_buf(),
            None => break,
        };
    }

    // The parent of a single relative component is the empty path, which never
    // "exists"; anchor at the current directory instead so relative and
    // absolute inputs resolve consistently.
    let walked_off_relative_path = existing.as_os_str().is_empty() && !tail_segments.is_empty();
    let canonical = if walked_off_relative_path {
        std::fs::canonicalize(Path::new("."))
    } else {
        std::fs::canonicalize(&existing)
    };

    let mut resolved = canonical.unwrap_or(existing);
    // Segments were collected leaf-first; re-append them root-first.
    resolved.extend(tail_segments.into_iter().rev());

    resolved
}
74
/// Shared application context threaded through all command handlers.
///
/// Holds the language provider, backup/checkpoint stores, configuration,
/// call graph engine, search and semantic indexes (plus the receivers used to
/// pick up their build results), the file watcher, and the LSP manager.
/// Constructed once at startup and passed by reference to `dispatch`.
///
/// Stores use `RefCell` for interior mutability — the binary is single-threaded
/// (one request at a time on the stdin read loop) so runtime borrow checking
/// is safe and never contended.
pub struct AppContext {
    /// Language provider; exposed via `provider()`.
    provider: Box<dyn LanguageProvider>,
    /// Backup store; created empty in `new`.
    backup: RefCell<BackupStore>,
    /// Checkpoint store; created empty in `new`.
    checkpoint: RefCell<CheckpointStore>,
    /// Runtime configuration supplied to `new`.
    config: RefCell<Config>,
    /// Call graph engine; `None` until something installs one.
    callgraph: RefCell<Option<CallGraph>>,
    /// Search index; `None` until something installs one.
    search_index: RefCell<Option<SearchIndex>>,
    /// Receiver for the search-index build result (index + pre-warmed symbol cache).
    search_index_rx:
        RefCell<Option<crossbeam_channel::Receiver<(SearchIndex, crate::parser::SymbolCache)>>>,
    /// Semantic search index; `None` until something installs one.
    semantic_index: RefCell<Option<SemanticIndex>>,
    /// Receiver for semantic-index build events.
    semantic_index_rx: RefCell<Option<crossbeam_channel::Receiver<SemanticIndexEvent>>>,
    /// Current semantic-index status; starts as `Disabled`.
    semantic_index_status: RefCell<SemanticIndexStatus>,
    /// Cached semantic embedding model.
    semantic_embedding_model: RefCell<Option<TextEmbedding>>,
    /// File watcher handle (kept alive to continue watching).
    watcher: RefCell<Option<RecommendedWatcher>>,
    /// Receiver for file watcher events.
    watcher_rx: RefCell<Option<mpsc::Receiver<notify::Result<notify::Event>>>>,
    /// LSP manager.
    lsp_manager: RefCell<LspManager>,
}
101
102impl AppContext {
103    pub fn new(provider: Box<dyn LanguageProvider>, config: Config) -> Self {
104        AppContext {
105            provider,
106            backup: RefCell::new(BackupStore::new()),
107            checkpoint: RefCell::new(CheckpointStore::new()),
108            config: RefCell::new(config),
109            callgraph: RefCell::new(None),
110            search_index: RefCell::new(None),
111            search_index_rx: RefCell::new(None),
112            semantic_index: RefCell::new(None),
113            semantic_index_rx: RefCell::new(None),
114            semantic_index_status: RefCell::new(SemanticIndexStatus::Disabled),
115            semantic_embedding_model: RefCell::new(None),
116            watcher: RefCell::new(None),
117            watcher_rx: RefCell::new(None),
118            lsp_manager: RefCell::new(LspManager::new()),
119        }
120    }
121
122    /// Access the language provider.
123    pub fn provider(&self) -> &dyn LanguageProvider {
124        self.provider.as_ref()
125    }
126
127    /// Access the backup store.
128    pub fn backup(&self) -> &RefCell<BackupStore> {
129        &self.backup
130    }
131
132    /// Access the checkpoint store.
133    pub fn checkpoint(&self) -> &RefCell<CheckpointStore> {
134        &self.checkpoint
135    }
136
137    /// Access the configuration (shared borrow).
138    pub fn config(&self) -> Ref<'_, Config> {
139        self.config.borrow()
140    }
141
142    /// Access the configuration (mutable borrow).
143    pub fn config_mut(&self) -> RefMut<'_, Config> {
144        self.config.borrow_mut()
145    }
146
147    /// Access the call graph engine.
148    pub fn callgraph(&self) -> &RefCell<Option<CallGraph>> {
149        &self.callgraph
150    }
151
152    /// Access the search index.
153    pub fn search_index(&self) -> &RefCell<Option<SearchIndex>> {
154        &self.search_index
155    }
156
157    /// Access the search-index build receiver (returns index + pre-warmed symbol cache).
158    pub fn search_index_rx(
159        &self,
160    ) -> &RefCell<Option<crossbeam_channel::Receiver<(SearchIndex, crate::parser::SymbolCache)>>>
161    {
162        &self.search_index_rx
163    }
164
165    /// Access the semantic search index.
166    pub fn semantic_index(&self) -> &RefCell<Option<SemanticIndex>> {
167        &self.semantic_index
168    }
169
170    /// Access the semantic-index build receiver.
171    pub fn semantic_index_rx(
172        &self,
173    ) -> &RefCell<Option<crossbeam_channel::Receiver<SemanticIndexEvent>>> {
174        &self.semantic_index_rx
175    }
176
177    pub fn semantic_index_status(&self) -> &RefCell<SemanticIndexStatus> {
178        &self.semantic_index_status
179    }
180
181    /// Access the cached semantic embedding model.
182    pub fn semantic_embedding_model(&self) -> &RefCell<Option<TextEmbedding>> {
183        &self.semantic_embedding_model
184    }
185
186    /// Access the file watcher handle (kept alive to continue watching).
187    pub fn watcher(&self) -> &RefCell<Option<RecommendedWatcher>> {
188        &self.watcher
189    }
190
191    /// Access the watcher event receiver.
192    pub fn watcher_rx(&self) -> &RefCell<Option<mpsc::Receiver<notify::Result<notify::Event>>>> {
193        &self.watcher_rx
194    }
195
196    /// Access the LSP manager.
197    pub fn lsp(&self) -> RefMut<'_, LspManager> {
198        self.lsp_manager.borrow_mut()
199    }
200
201    /// Notify LSP servers that a file was written.
202    /// Call this after write_format_validate in command handlers.
203    pub fn lsp_notify_file_changed(&self, file_path: &Path, content: &str) {
204        if let Ok(mut lsp) = self.lsp_manager.try_borrow_mut() {
205            if let Err(e) = lsp.notify_file_changed(file_path, content) {
206                log::warn!("sync error for {}: {}", file_path.display(), e);
207            }
208        }
209    }
210
211    /// Notify LSP and optionally wait for diagnostics.
212    ///
213    /// Call this after `write_format_validate` when the request has `"diagnostics": true`.
214    /// Sends didChange to the server, waits briefly for publishDiagnostics, and returns
215    /// any diagnostics for the file. If no server is running, returns empty immediately.
216    pub fn lsp_notify_and_collect_diagnostics(
217        &self,
218        file_path: &Path,
219        content: &str,
220        timeout: std::time::Duration,
221    ) -> Vec<crate::lsp::diagnostics::StoredDiagnostic> {
222        let Ok(mut lsp) = self.lsp_manager.try_borrow_mut() else {
223            return Vec::new();
224        };
225
226        // Clear any queued notifications before this write so the wait loop only
227        // observes diagnostics triggered by the current change.
228        lsp.drain_events();
229
230        // Send didChange/didOpen
231        if let Err(e) = lsp.notify_file_changed(file_path, content) {
232            log::warn!("sync error for {}: {}", file_path.display(), e);
233            return Vec::new();
234        }
235
236        // Wait for diagnostics to arrive
237        lsp.wait_for_diagnostics(file_path, timeout)
238    }
239
240    /// Post-write LSP hook: notify server and optionally collect diagnostics.
241    ///
242    /// This is the single call site for all command handlers after `write_format_validate`.
243    /// When `diagnostics` is true, it notifies the server, waits until matching
244    /// diagnostics arrive or the timeout expires, and returns diagnostics for the file.
245    /// When false, it just notifies (fire-and-forget).
246    pub fn lsp_post_write(
247        &self,
248        file_path: &Path,
249        content: &str,
250        params: &serde_json::Value,
251    ) -> Vec<crate::lsp::diagnostics::StoredDiagnostic> {
252        let wants_diagnostics = params
253            .get("diagnostics")
254            .and_then(|v| v.as_bool())
255            .unwrap_or(false);
256
257        if !wants_diagnostics {
258            self.lsp_notify_file_changed(file_path, content);
259            return Vec::new();
260        }
261
262        let wait_ms = params
263            .get("wait_ms")
264            .and_then(|v| v.as_u64())
265            .unwrap_or(1500)
266            .min(10_000); // Cap at 10 seconds to prevent hangs from adversarial input
267
268        self.lsp_notify_and_collect_diagnostics(
269            file_path,
270            content,
271            std::time::Duration::from_millis(wait_ms),
272        )
273    }
274
275    /// Validate that a file path falls within the configured project root.
276    ///
277    /// When `project_root` is configured (normal plugin usage), this resolves the
278    /// path and checks it starts with the root. Returns the canonicalized path on
279    /// success, or an error response on violation.
280    ///
281    /// When no `project_root` is configured (direct CLI usage), all paths pass
282    /// through unrestricted for backward compatibility.
283    pub fn validate_path(
284        &self,
285        req_id: &str,
286        path: &Path,
287    ) -> Result<std::path::PathBuf, crate::protocol::Response> {
288        let config = self.config();
289        // When restrict_to_project_root is false (default), allow all paths
290        if !config.restrict_to_project_root {
291            return Ok(path.to_path_buf());
292        }
293        let root = match &config.project_root {
294            Some(r) => r.clone(),
295            None => return Ok(path.to_path_buf()), // No root configured, allow all
296        };
297        drop(config);
298
299        // Resolve the path (follow symlinks, normalize ..)
300        let resolved = std::fs::canonicalize(path)
301            .unwrap_or_else(|_| resolve_with_existing_ancestors(&normalize_path(path)));
302
303        let resolved_root = std::fs::canonicalize(&root).unwrap_or(root);
304
305        if !resolved.starts_with(&resolved_root) {
306            return Err(crate::protocol::Response::error(
307                req_id,
308                "path_outside_root",
309                format!(
310                    "path '{}' is outside the project root '{}'",
311                    path.display(),
312                    resolved_root.display()
313                ),
314            ));
315        }
316
317        Ok(resolved)
318    }
319
320    /// Count active LSP server instances.
321    pub fn lsp_server_count(&self) -> usize {
322        self.lsp_manager
323            .try_borrow()
324            .map(|lsp| lsp.server_count())
325            .unwrap_or(0)
326    }
327
328    /// Symbol cache statistics from the language provider.
329    pub fn symbol_cache_stats(&self) -> serde_json::Value {
330        if let Some(tsp) = self
331            .provider
332            .as_any()
333            .downcast_ref::<crate::parser::TreeSitterProvider>()
334        {
335            let (local, warm) = tsp.symbol_cache_stats();
336            serde_json::json!({
337                "local_entries": local,
338                "warm_entries": warm,
339            })
340        } else {
341            serde_json::json!({
342                "local_entries": 0,
343                "warm_entries": 0,
344            })
345        }
346    }
347}