1use crate::{KnowledgeError, Result, TypeDatabase, TypeFact, TypeFactKind};
7use std::collections::HashMap;
8use std::path::Path;
9
10pub struct TypeQuery {
12 db: TypeDatabase,
14 cache: HashMap<String, TypeFact>,
16 cache_populated: bool,
18}
19
20impl TypeQuery {
21 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
23 let db = TypeDatabase::new(path)?;
24 Ok(Self {
25 db,
26 cache: HashMap::new(),
27 cache_populated: false,
28 })
29 }
30
31 pub fn warm_cache(&mut self) -> Result<()> {
33 if self.cache_populated {
34 return Ok(());
35 }
36
37 let facts = self.db.read_all()?;
38 for fact in facts {
39 let key = format!("{}.{}", fact.module, fact.symbol);
40 self.cache.insert(key, fact);
41 }
42
43 self.cache_populated = true;
44 Ok(())
45 }
46
47 pub fn find_signature(&mut self, module: &str, symbol: &str) -> Result<String> {
58 self.warm_cache()?;
59
60 let key = format!("{module}.{symbol}");
62 if let Some(fact) = self.cache.get(&key) {
63 return Ok(fact.signature.clone());
64 }
65
66 let submodule_patterns = ["api", "core", "_api", "main", "base"];
68 for submod in &submodule_patterns {
69 let alt_key = format!("{module}.{submod}.{symbol}");
70 if let Some(fact) = self.cache.get(&alt_key) {
71 return Ok(fact.signature.clone());
72 }
73 }
74
75 Err(KnowledgeError::SymbolNotFound {
76 module: module.to_string(),
77 symbol: symbol.to_string(),
78 })
79 }
80
81 pub fn find_return_type(&mut self, module: &str, symbol: &str) -> Result<String> {
86 self.warm_cache()?;
87
88 let key = format!("{module}.{symbol}");
90 if let Some(fact) = self.cache.get(&key) {
91 return Ok(fact.return_type.clone());
92 }
93
94 let submodule_patterns = ["api", "core", "_api", "main", "base"];
97 for submod in &submodule_patterns {
98 let alt_key = format!("{module}.{submod}.{symbol}");
99 if let Some(fact) = self.cache.get(&alt_key) {
100 return Ok(fact.return_type.clone());
101 }
102 }
103
104 Err(KnowledgeError::SymbolNotFound {
105 module: module.to_string(),
106 symbol: symbol.to_string(),
107 })
108 }
109
110 pub fn find_fact(&mut self, module: &str, symbol: &str) -> Result<TypeFact> {
114 self.warm_cache()?;
115
116 let key = format!("{module}.{symbol}");
118 if let Some(fact) = self.cache.get(&key) {
119 return Ok(fact.clone());
120 }
121
122 let submodule_patterns = ["api", "core", "_api", "main", "base"];
124 for submod in &submodule_patterns {
125 let alt_key = format!("{module}.{submod}.{symbol}");
126 if let Some(fact) = self.cache.get(&alt_key) {
127 return Ok(fact.clone());
128 }
129 }
130
131 Err(KnowledgeError::SymbolNotFound {
132 module: module.to_string(),
133 symbol: symbol.to_string(),
134 })
135 }
136
137 pub fn find_functions(&mut self, module: &str) -> Result<Vec<TypeFact>> {
139 self.warm_cache()?;
140
141 Ok(self
142 .cache
143 .values()
144 .filter(|f| f.module == module && f.kind == TypeFactKind::Function)
145 .cloned()
146 .collect())
147 }
148
149 pub fn find_classes(&mut self, module: &str) -> Result<Vec<TypeFact>> {
151 self.warm_cache()?;
152
153 Ok(self
154 .cache
155 .values()
156 .filter(|f| f.module == module && f.kind == TypeFactKind::Class)
157 .cloned()
158 .collect())
159 }
160
161 pub fn find_methods(&mut self, module: &str, class_name: &str) -> Result<Vec<TypeFact>> {
163 self.warm_cache()?;
164
165 let prefix = format!("{class_name}.");
166 Ok(self
167 .cache
168 .values()
169 .filter(|f| f.module == module && f.kind == TypeFactKind::Method && f.symbol.starts_with(&prefix))
170 .cloned()
171 .collect())
172 }
173
174 pub fn has_symbol(&mut self, module: &str, symbol: &str) -> bool {
178 if self.warm_cache().is_err() {
179 return false;
180 }
181
182 let key = format!("{module}.{symbol}");
184 if self.cache.contains_key(&key) {
185 return true;
186 }
187
188 let submodule_patterns = ["api", "core", "_api", "main", "base"];
190 for submod in &submodule_patterns {
191 let alt_key = format!("{module}.{submod}.{symbol}");
192 if self.cache.contains_key(&alt_key) {
193 return true;
194 }
195 }
196
197 false
198 }
199
200 pub fn count(&mut self) -> usize {
202 if self.warm_cache().is_ok() {
203 self.cache.len()
204 } else {
205 0
206 }
207 }
208
209 pub fn search(&mut self, pattern: &str) -> Result<Vec<TypeFact>> {
211 self.warm_cache()?;
212
213 let pattern_lower = pattern.to_lowercase();
214 Ok(self
215 .cache
216 .values()
217 .filter(|f| {
218 f.symbol.to_lowercase().contains(&pattern_lower)
219 || f.module.to_lowercase().contains(&pattern_lower)
220 })
221 .cloned()
222 .collect())
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writes a small `requests`-shaped fixture to a fresh database and returns
    /// a query over it.
    ///
    /// The `TempDir` is returned so that the directory outlives the query;
    /// callers bind it to `_temp`.
    fn setup_test_db() -> (TempDir, TypeQuery) {
        let temp = TempDir::new().unwrap();
        let db_path = temp.path().join("test.parquet");
        let db = TypeDatabase::new(&db_path).unwrap();

        let facts = vec![
            TypeFact::function("requests", "get", "(url: str, **kwargs) -> Response", "Response"),
            TypeFact::function("requests", "post", "(url: str, data: dict) -> Response", "Response"),
            // Stored under a submodule so the `module.api.symbol` fallback is exercised.
            TypeFact::function(
                "requests.api",
                "request",
                "(method: str, url: str, **kwargs) -> Response",
                "Response",
            ),
            TypeFact::class("requests.models", "Response"),
            TypeFact::method("requests.models", "Response", "json", "(self) -> dict", "dict"),
            TypeFact::method("requests.models", "Response", "text", "(self) -> str", "str"),
        ];

        db.write(&facts).unwrap();

        let query = TypeQuery::new(&db_path).unwrap();
        (temp, query)
    }

    #[test]
    fn test_find_signature() {
        let (_temp, mut query) = setup_test_db();

        let sig = query.find_signature("requests", "get").unwrap();
        assert!(sig.contains("url: str"));
        assert!(sig.contains("**kwargs"));
    }

    #[test]
    fn test_find_return_type() {
        let (_temp, mut query) = setup_test_db();

        let ret = query.find_return_type("requests", "get").unwrap();
        assert_eq!(ret, "Response");
    }

    #[test]
    fn test_submodule_fallback() {
        let (_temp, mut query) = setup_test_db();

        // `requests.request` is not stored directly; it lives in `requests.api`.
        let sig = query.find_signature("requests", "request").unwrap();
        assert!(sig.contains("method: str"));
        assert_eq!(query.find_return_type("requests", "request").unwrap(), "Response");
        assert!(query.has_symbol("requests", "request"));

        // The direct key must still work for the submodule itself.
        assert!(query.find_fact("requests.api", "request").is_ok());
    }

    #[test]
    fn test_find_fact_kind() {
        let (_temp, mut query) = setup_test_db();

        let fact = query.find_fact("requests.models", "Response").unwrap();
        assert!(fact.kind == TypeFactKind::Class);
    }

    #[test]
    fn test_find_functions() {
        let (_temp, mut query) = setup_test_db();

        // Only functions whose module is exactly `requests`; `requests.api` is excluded.
        let functions = query.find_functions("requests").unwrap();
        assert_eq!(functions.len(), 2);

        let names: Vec<_> = functions.iter().map(|f| f.symbol.as_str()).collect();
        assert!(names.contains(&"get"));
        assert!(names.contains(&"post"));
    }

    #[test]
    fn test_find_classes() {
        let (_temp, mut query) = setup_test_db();

        let classes = query.find_classes("requests.models").unwrap();
        assert_eq!(classes.len(), 1);
        assert_eq!(classes[0].symbol, "Response");
    }

    #[test]
    fn test_find_methods() {
        let (_temp, mut query) = setup_test_db();

        let methods = query.find_methods("requests.models", "Response").unwrap();
        assert_eq!(methods.len(), 2);

        let method_names: Vec<_> = methods.iter().map(|m| m.symbol.as_str()).collect();
        assert!(method_names.contains(&"Response.json"));
        assert!(method_names.contains(&"Response.text"));
    }

    #[test]
    fn test_has_symbol() {
        let (_temp, mut query) = setup_test_db();

        assert!(query.has_symbol("requests", "get"));
        assert!(!query.has_symbol("requests", "put"));
    }

    #[test]
    fn test_count() {
        let (_temp, mut query) = setup_test_db();

        assert_eq!(query.count(), 6);
    }

    #[test]
    fn test_search() {
        let (_temp, mut query) = setup_test_db();

        let results = query.search("json").unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].symbol, "Response.json");
    }

    #[test]
    fn test_symbol_not_found() {
        let (_temp, mut query) = setup_test_db();

        let result = query.find_signature("unknown", "function");
        assert!(result.is_err());
        assert!(matches!(result, Err(KnowledgeError::SymbolNotFound { .. })));

        assert!(matches!(
            query.find_return_type("unknown", "function"),
            Err(KnowledgeError::SymbolNotFound { .. })
        ));
        assert!(matches!(
            query.find_fact("unknown", "function"),
            Err(KnowledgeError::SymbolNotFound { .. })
        ));
    }
}