Skip to main content

forge_agent/
observe.rs

1//! Observation phase - Graph-based context gathering.
2//!
3//! This module implements the observation phase of the agent loop, gathering
4//! relevant context from the code graph to inform intelligent operations.
5
6use crate::{AgentError, Result};
7use forge_core::{
8    types::{Symbol, SymbolId},
9    Forge,
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13
/// Observer for gathering context from a code graph.
///
/// Holds a handle to the Forge SDK together with a cache of previously
/// gathered observations keyed by the query string. Both fields live behind
/// an [`Arc`], so cloning an `Observer` produces a handle that shares the
/// same Forge instance and the same cache as the original.
#[derive(Clone)]
pub struct Observer {
    /// The Forge SDK instance for graph queries
    forge: Arc<Forge>,
    /// Cache for observation results (query -> observation), guarded by an
    /// async read/write lock so it can be used from `&self` methods
    cache: Arc<tokio::sync::RwLock<HashMap<String, Observation>>>,
}
24
25impl Observer {
26    /// Creates a new Observer with given Forge instance.
27    pub fn new(forge: Forge) -> Self {
28        Self {
29            forge: Arc::new(forge),
30            cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
31        }
32    }
33
34    /// Gathers observation data for a natural language query.
35    pub async fn gather(&self, query: &str) -> Result<Observation> {
36        // Check cache first
37        {
38            let cache = self.cache.read().await;
39            if let Some(cached) = cache.get(query) {
40                return Ok(cached.clone());
41            }
42        }
43
44        // For now, gather symbols using the graph API
45        let symbols = self.gather_symbols(query).await?;
46
47        let observation = crate::Observation {
48            query: query.to_string(),
49            symbols,
50        };
51
52        // Cache the result
53        {
54            let mut cache = self.cache.write().await;
55            cache.insert(query.to_string(), observation.clone());
56        }
57
58        Ok(observation)
59    }
60
61    /// Gathers symbols by parsing the query.
62    async fn gather_symbols(&self, query: &str) -> Result<Vec<ObservedSymbol>> {
63        let _graph = self.forge.as_ref();
64        let mut symbols = Vec::new();
65
66        // Parse query: if it contains "find" and "named", extract the name
67        let query_lower = query.to_lowercase();
68
69        if query_lower.contains("find") && query_lower.contains("named") {
70            // Extract name after "named"
71            if let Some(pos) = query_lower.find("named") {
72                let remaining = &query[pos + 6..];
73                let name = remaining.trim().trim_end_matches('?').trim().to_string();
74                if !name.is_empty() {
75                    // For now, we can't find by name without knowing the file
76                    // Return a placeholder symbol
77                    symbols.push(ObservedSymbol {
78                        id: SymbolId(0),
79                        name: name.clone(),
80                        kind: forge_core::types::SymbolKind::Function,
81                        location: forge_core::types::Location {
82                            file_path: std::path::PathBuf::from("<unknown>"),
83                            byte_start: 0,
84                            byte_end: 0,
85                            line_number: 0,
86                        },
87                    });
88                }
89            }
90        }
91
92        Ok(symbols)
93    }
94
95    /// Clears the observation cache.
96    pub async fn clear_cache(&self) {
97        let mut cache = self.cache.write().await;
98        cache.clear();
99    }
100}
101
/// Parsed query representation.
///
/// Currently stores nothing beyond the original query text.
/// NOTE(review): nothing in the visible part of this module constructs or
/// reads this type — confirm whether it is still needed.
#[derive(Debug, Clone)]
pub struct ParsedQuery {
    /// Original query string
    original: String,
}
108
/// Result of the observation phase.
///
/// Pairs the query text with the symbols gathered for it. Values of this type
/// are what `Observer` stores in its cache and hands back to callers.
#[derive(Clone, Debug)]
pub struct Observation {
    /// The original query, stored exactly as it was passed in
    pub query: String,
    /// Relevant symbols found; empty when the query matched nothing
    pub symbols: Vec<ObservedSymbol>,
}
119
/// A symbol observed during the observation phase.
///
/// The observer in this module currently fills these in as placeholders
/// (id `0`, kind `Function`, path `"<unknown>"`, zeroed offsets), so only
/// `name` carries information taken from the query.
#[derive(Clone, Debug)]
pub struct ObservedSymbol {
    /// Unique symbol identifier
    pub id: SymbolId,
    /// Symbol name
    pub name: String,
    /// Kind of symbol
    pub kind: forge_core::types::SymbolKind,
    /// Source location
    pub location: forge_core::types::Location,
}
132
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds an `Observer` on top of a Forge instance opened in a fresh
    /// temporary directory.
    ///
    /// The `TempDir` is returned alongside the observer so that callers can
    /// keep it in scope for as long as the observer is in use.
    async fn create_test_observer() -> (Observer, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let forge = Forge::open(temp_dir.path()).await.unwrap();
        let observer = Observer::new(forge);
        (observer, temp_dir)
    }

    #[tokio::test]
    async fn test_observer_creation() {
        let (observer, _temp_dir) = create_test_observer().await;

        // A freshly created observer should have an empty cache
        assert!(observer.cache.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_observation_caching() {
        let (observer, _temp_dir) = create_test_observer().await;

        // First call: the query has no "find ... named" form, so no symbols
        let first = observer.gather("test query").await.unwrap();
        assert!(first.symbols.is_empty());

        // The result must actually have been stored in the cache
        assert_eq!(observer.cache.read().await.len(), 1);
        assert!(observer.cache.read().await.contains_key("test query"));

        // Second call should hit the cache and must not add another entry
        let second = observer.gather("test query").await.unwrap();
        assert_eq!(observer.cache.read().await.len(), 1);

        // Results should be identical
        assert_eq!(first.query, second.query);
        assert_eq!(first.symbols.len(), second.symbols.len());
    }

    #[tokio::test]
    async fn test_named_query_yields_placeholder_symbol() {
        let (observer, _temp_dir) = create_test_observer().await;

        let observation = observer
            .gather("find the function named foo?")
            .await
            .unwrap();

        // The name keeps its original spelling and loses the trailing '?'
        assert_eq!(observation.query, "find the function named foo?");
        assert_eq!(observation.symbols.len(), 1);
        assert_eq!(observation.symbols[0].name, "foo");
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let (observer, _temp_dir) = create_test_observer().await;

        // Add something to the cache and make sure it really landed there;
        // otherwise the emptiness check below would prove nothing.
        observer.gather("test query").await.unwrap();
        assert!(!observer.cache.read().await.is_empty());

        // Clear cache
        observer.clear_cache().await;

        // Cache should be empty
        assert!(observer.cache.read().await.is_empty());
    }
}
186}