1use crate::{KnowledgeError, Result, TypeFact, TypeFactKind};
7use arrow::array::{ArrayRef, RecordBatch, StringArray};
8use arrow::datatypes::{DataType, Field, Schema};
9use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
10use parquet::arrow::ArrowWriter;
11use std::fs::File;
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use tracing::{debug, info};
15
16pub struct TypeDatabase {
18 path: PathBuf,
20 schema: Arc<Schema>,
22}
23
24impl TypeDatabase {
25 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
27 let path = path.as_ref().to_path_buf();
28 let schema = Arc::new(Self::create_schema());
29 Ok(Self { path, schema })
30 }
31
32 pub fn temp() -> Result<Self> {
34 let path = std::env::temp_dir().join("depyler-types.parquet");
35 Self::new(path)
36 }
37
38 fn create_schema() -> Schema {
40 Schema::new(vec![
41 Field::new("module", DataType::Utf8, false),
42 Field::new("symbol", DataType::Utf8, false),
43 Field::new("kind", DataType::Utf8, false),
44 Field::new("signature", DataType::Utf8, false),
45 Field::new("return_type", DataType::Utf8, false),
46 ])
47 }
48
49 pub fn write(&self, facts: &[TypeFact]) -> Result<()> {
51 info!(path = %self.path.display(), count = facts.len(), "Writing type facts");
52
53 let batch = self.facts_to_batch(facts)?;
54
55 let file = File::create(&self.path)?;
56 let mut writer = ArrowWriter::try_new(file, self.schema.clone(), None)?;
57 writer.write(&batch)?;
58 writer.close()?;
59
60 debug!(path = %self.path.display(), "Write complete");
61 Ok(())
62 }
63
64 fn facts_to_batch(&self, facts: &[TypeFact]) -> Result<RecordBatch> {
66 let modules: Vec<&str> = facts.iter().map(|f| f.module.as_str()).collect();
67 let symbols: Vec<&str> = facts.iter().map(|f| f.symbol.as_str()).collect();
68 let kinds: Vec<String> = facts.iter().map(|f| f.kind.to_string()).collect();
69 let signatures: Vec<&str> = facts.iter().map(|f| f.signature.as_str()).collect();
70 let return_types: Vec<&str> = facts.iter().map(|f| f.return_type.as_str()).collect();
71
72 let columns: Vec<ArrayRef> = vec![
73 Arc::new(StringArray::from(modules)),
74 Arc::new(StringArray::from(symbols)),
75 Arc::new(StringArray::from(kinds.iter().map(|s| s.as_str()).collect::<Vec<_>>())),
76 Arc::new(StringArray::from(signatures)),
77 Arc::new(StringArray::from(return_types)),
78 ];
79
80 RecordBatch::try_new(self.schema.clone(), columns)
81 .map_err(|e| KnowledgeError::DatabaseError(e.to_string()))
82 }
83
84 pub fn read_all(&self) -> Result<Vec<TypeFact>> {
86 if !self.path.exists() {
87 return Ok(Vec::new());
88 }
89
90 let file = File::open(&self.path)?;
91 let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
92 let reader = builder.build()?;
93
94 let mut facts = Vec::new();
95 for batch in reader {
96 let batch = batch?;
97 let batch_facts = self.batch_to_facts(&batch)?;
98 facts.extend(batch_facts);
99 }
100
101 debug!(path = %self.path.display(), count = facts.len(), "Read type facts");
102 Ok(facts)
103 }
104
105 fn batch_to_facts(&self, batch: &RecordBatch) -> Result<Vec<TypeFact>> {
107 let modules = batch
108 .column(0)
109 .as_any()
110 .downcast_ref::<StringArray>()
111 .ok_or_else(|| KnowledgeError::DatabaseError("Invalid module column".to_string()))?;
112
113 let symbols = batch
114 .column(1)
115 .as_any()
116 .downcast_ref::<StringArray>()
117 .ok_or_else(|| KnowledgeError::DatabaseError("Invalid symbol column".to_string()))?;
118
119 let kinds = batch
120 .column(2)
121 .as_any()
122 .downcast_ref::<StringArray>()
123 .ok_or_else(|| KnowledgeError::DatabaseError("Invalid kind column".to_string()))?;
124
125 let signatures = batch
126 .column(3)
127 .as_any()
128 .downcast_ref::<StringArray>()
129 .ok_or_else(|| KnowledgeError::DatabaseError("Invalid signature column".to_string()))?;
130
131 let return_types = batch
132 .column(4)
133 .as_any()
134 .downcast_ref::<StringArray>()
135 .ok_or_else(|| KnowledgeError::DatabaseError("Invalid return_type column".to_string()))?;
136
137 let mut facts = Vec::with_capacity(batch.num_rows());
138 for i in 0..batch.num_rows() {
139 let kind_str = kinds.value(i);
140 let kind: TypeFactKind = kind_str.parse()?;
141
142 facts.push(TypeFact {
143 module: modules.value(i).to_string(),
144 symbol: symbols.value(i).to_string(),
145 kind,
146 signature: signatures.value(i).to_string(),
147 return_type: return_types.value(i).to_string(),
148 });
149 }
150
151 Ok(facts)
152 }
153
154 pub fn find_signature(&self, module: &str, symbol: &str) -> Option<String> {
156 self.read_all()
157 .ok()?
158 .into_iter()
159 .find(|f| f.module == module && f.symbol == symbol)
160 .map(|f| f.signature)
161 }
162
163 pub fn query_by_module(&self, prefix: &str) -> Result<Vec<TypeFact>> {
165 let all = self.read_all()?;
166 Ok(all
167 .into_iter()
168 .filter(|f| f.module.starts_with(prefix))
169 .collect())
170 }
171
172 pub fn path(&self) -> &Path {
174 &self.path
175 }
176
177 pub fn exists(&self) -> bool {
179 self.path.exists()
180 }
181
182 pub fn size_bytes(&self) -> Result<u64> {
184 let metadata = std::fs::metadata(&self.path)?;
185 Ok(metadata.len())
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a database at `test.parquet` inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the database so the directory
    /// stays alive for the duration of the calling test.
    fn scratch_db() -> (TempDir, TypeDatabase) {
        let dir = TempDir::new().unwrap();
        let db = TypeDatabase::new(dir.path().join("test.parquet")).unwrap();
        (dir, db)
    }

    /// Facts written to disk come back in the same order with the same fields.
    #[test]
    fn test_roundtrip() {
        let (_dir, db) = scratch_db();

        db.write(&[
            TypeFact::function("requests", "get", "(url: str) -> Response", "Response"),
            TypeFact::class("requests.models", "Response"),
            TypeFact::method("requests.models", "Response", "json", "(self) -> dict", "dict"),
        ])
        .unwrap();
        assert!(db.exists());

        let loaded = db.read_all().unwrap();
        assert_eq!(loaded.len(), 3);

        let first = &loaded[0];
        assert_eq!(first.module, "requests");
        assert_eq!(first.symbol, "get");
        assert_eq!(loaded[1].kind, TypeFactKind::Class);
        assert_eq!(loaded[2].symbol, "Response.json");
    }

    /// An exact module/symbol match yields its signature; anything else yields `None`.
    #[test]
    fn test_find_signature() {
        let (_dir, db) = scratch_db();

        db.write(&[TypeFact::function(
            "requests",
            "get",
            "(url: str, **kwargs) -> Response",
            "Response",
        )])
        .unwrap();

        let found = db.find_signature("requests", "get");
        assert!(found.is_some());
        assert!(found.unwrap().contains("url: str"));

        assert!(db.find_signature("requests", "post").is_none());
    }

    /// Module queries match by prefix, so a shorter prefix selects more facts.
    #[test]
    fn test_query_by_module() {
        let (_dir, db) = scratch_db();

        db.write(&[
            TypeFact::function("requests.api", "get", "(url: str) -> Response", "Response"),
            TypeFact::function("requests.api", "post", "(url: str) -> Response", "Response"),
            TypeFact::class("requests.models", "Response"),
        ])
        .unwrap();

        let api_only = db.query_by_module("requests.api").unwrap();
        assert_eq!(api_only.len(), 2);

        let everything = db.query_by_module("requests").unwrap();
        assert_eq!(everything.len(), 3);
    }
}