//! Observation phase: graph-based context gathering.
//!
//! This module implements the observation phase of the agent loop, gathering
//! relevant context from the code graph to inform the agent's operations.

6use crate::Result;
7use forge_core::{
8    types::SymbolId,
9    Forge,
10};
11use std::collections::HashMap;
12use std::sync::Arc;
13
14/// Observer for gathering context from a code graph.
15///
16/// The Observer uses Forge SDK to query symbols and references.
17#[derive(Clone)]
18pub struct Observer {
19    /// The Forge SDK instance for graph queries
20    forge: Arc<Forge>,
21    /// Cache for observation results (query -> observation)
22    cache: Arc<tokio::sync::RwLock<HashMap<String, Observation>>>,
23}
24
25impl Observer {
26    /// Creates a new Observer with given Forge instance.
27    pub fn new(forge: Forge) -> Self {
28        Self {
29            forge: Arc::new(forge),
30            cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
31        }
32    }
33
34    /// Gathers observation data for a natural language query.
35    pub async fn gather(&self, query: &str) -> Result<Observation> {
36        // Check cache first
37        {
38            let cache = self.cache.read().await;
39            if let Some(cached) = cache.get(query) {
40                return Ok(cached.clone());
41            }
42        }
43
44        // For now, gather symbols using the graph API
45        let symbols = self.gather_symbols(query).await?;
46
47        let observation = crate::Observation {
48            query: query.to_string(),
49            symbols,
50        };
51
52        // Cache the result
53        {
54            let mut cache = self.cache.write().await;
55            cache.insert(query.to_string(), observation.clone());
56        }
57
58        Ok(observation)
59    }
60
61    /// Gathers symbols by parsing the query.
62    async fn gather_symbols(&self, query: &str) -> Result<Vec<ObservedSymbol>> {
63        let _graph = self.forge.as_ref();
64        let mut symbols = Vec::new();
65
66        // Parse query: if it contains "find" and "named", extract the name
67        let query_lower = query.to_lowercase();
68
69        if query_lower.contains("find") && query_lower.contains("named") {
70            // Extract name after "named"
71            if let Some(pos) = query_lower.find("named") {
72                let remaining = &query[pos + 6..];
73                let name = remaining.trim().trim_end_matches('?').trim().to_string();
74                if !name.is_empty() {
75                    // For now, we can't find by name without knowing the file
76                    // Return a placeholder symbol
77                    symbols.push(ObservedSymbol {
78                        id: SymbolId(0),
79                        name: name.clone(),
80                        kind: forge_core::types::SymbolKind::Function,
81                        location: forge_core::types::Location {
82                            file_path: std::path::PathBuf::from("<unknown>"),
83                            byte_start: 0,
84                            byte_end: 0,
85                            line_number: 0,
86                        },
87                    });
88                }
89            }
90        }
91
92        Ok(symbols)
93    }
94
95    /// Clears the observation cache.
96    pub async fn clear_cache(&self) {
97        let mut cache = self.cache.write().await;
98        cache.clear();
99    }
100}
101
102/// Result of the observation phase.
103///
104/// Contains relevant context gathered from the code graph.
105#[derive(Clone, Debug)]
106pub struct Observation {
107    /// The original Query
108    pub query: String,
109    /// Relevant symbols found
110    pub symbols: Vec<ObservedSymbol>,
111}
112
113/// A symbol observed during the observation phase.
114#[derive(Clone, Debug)]
115pub struct ObservedSymbol {
116    /// Unique symbol identifier
117    pub id: SymbolId,
118    /// Symbol name
119    pub name: String,
120    /// Kind of symbol
121    pub kind: forge_core::types::SymbolKind,
122    /// Source location
123    pub location: forge_core::types::Location,
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use forge_core::Forge;
    use tempfile::TempDir;

    /// Builds an observer over a fresh, empty Forge database.
    ///
    /// The returned `TempDir` must be kept alive for the whole test; dropping
    /// it removes the directory the database was opened in.
    async fn create_test_observer() -> (Observer, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let forge = Forge::open(temp_dir.path()).await.unwrap();
        (Observer::new(forge), temp_dir)
    }

    #[tokio::test]
    async fn test_observer_creation() {
        let (observer, _temp_dir) = create_test_observer().await;

        // A freshly built observer starts with an empty cache.
        assert!(observer.cache.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_observation_caching() {
        let (observer, _temp_dir) = create_test_observer().await;

        // "test query" is not a "find ... named" query, so nothing is found.
        let first = observer.gather("test query").await.unwrap();
        assert_eq!(first.query, "test query");
        assert!(first.symbols.is_empty());

        // The result is stored under the exact query string.
        assert!(observer.cache.read().await.contains_key("test query"));

        // A second call is served from the cache and returns the same data.
        let second = observer.gather("test query").await.unwrap();
        assert_eq!(second.query, first.query);
        assert_eq!(second.symbols.len(), first.symbols.len());
        assert_eq!(observer.cache.read().await.len(), 1);
    }

    #[tokio::test]
    async fn test_find_named_query_yields_placeholder_symbol() {
        let (observer, _temp_dir) = create_test_observer().await;

        // Keywords match case-insensitively; the trailing '?' is stripped.
        let observation = observer
            .gather("Find the function NAMED parse_config?")
            .await
            .unwrap();

        assert_eq!(observation.symbols.len(), 1);
        let symbol = &observation.symbols[0];
        assert_eq!(symbol.name, "parse_config");
        assert_eq!(
            symbol.location.file_path,
            std::path::PathBuf::from("<unknown>")
        );
        assert_eq!(symbol.location.line_number, 0);
    }

    #[tokio::test]
    async fn test_query_without_find_yields_no_symbols() {
        let (observer, _temp_dir) = create_test_observer().await;

        let observation = observer.gather("show the thing named foo").await.unwrap();
        assert!(observation.symbols.is_empty());
    }

    #[tokio::test]
    async fn test_clones_share_cache() {
        let (observer, _temp_dir) = create_test_observer().await;
        let clone = observer.clone();

        observer.gather("test query").await.unwrap();

        // Both handles point at the same reference-counted cache.
        assert!(clone.cache.read().await.contains_key("test query"));
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let (observer, _temp_dir) = create_test_observer().await;

        observer.gather("test query").await.unwrap();
        assert!(!observer.cache.read().await.is_empty());

        observer.clear_cache().await;

        assert!(observer.cache.read().await.is_empty());
    }
}