1use anyhow::Result;
7use std::path::Path;
8
9#[cfg(feature = "native")]
10use tracing::info;
11
12#[cfg(feature = "native")]
13use walkdir::WalkDir;
14
15use crate::numpy_converter::{NumPyConverter, NumPyOp};
16use crate::pytorch_converter::{PyTorchConverter, PyTorchOperation};
17use crate::sklearn_converter::{SklearnAlgorithm, SklearnConverter};
18
/// Analyzes Python source trees for NumPy, scikit-learn, and PyTorch usage
/// and produces per-pattern conversion recommendations via the dedicated
/// converter for each library.
pub struct LibraryAnalyzer {
    // One converter per supported library; each maps a detected Python
    // pattern to a Rust-side code template (see the analyze_* methods).
    numpy_converter: NumPyConverter,
    sklearn_converter: SklearnConverter,
    pytorch_converter: PyTorchConverter,
}
25
26impl Default for LibraryAnalyzer {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl LibraryAnalyzer {
33 pub fn new() -> Self {
35 Self {
36 numpy_converter: NumPyConverter::new(),
37 sklearn_converter: SklearnConverter::new(),
38 pytorch_converter: PyTorchConverter::new(),
39 }
40 }
41
42 #[cfg(feature = "native")]
44 pub fn analyze_numpy_usage(&self, input_path: &Path) -> Result<Vec<String>> {
45 contract_pre_analyze!(input_path);
46 let converter = &self.numpy_converter;
47 analyze_library(input_path, &["import numpy", "from numpy"], "NumPy", |path, content| {
48 let operations = [
49 ("np.add", NumPyOp::Add),
50 ("np.subtract", NumPyOp::Subtract),
51 ("np.multiply", NumPyOp::Multiply),
52 ("np.dot", NumPyOp::Dot),
53 ("np.sum", NumPyOp::Sum),
54 ("np.array", NumPyOp::Array),
55 ];
56 operations
57 .iter()
58 .filter_map(|(pattern, op)| {
59 if content.contains(pattern) {
60 converter.convert(op).map(|r| {
61 format!("{}: {} → {}", path.display(), pattern, r.code_template)
62 })
63 } else {
64 None
65 }
66 })
67 .collect()
68 })
69 }
70
71 #[cfg(not(feature = "native"))]
73 pub fn analyze_numpy_usage(&self, _input_path: &Path) -> Result<Vec<String>> {
74 Ok(Vec::new())
75 }
76
77 #[cfg(feature = "native")]
79 pub fn analyze_sklearn_usage(&self, input_path: &Path) -> Result<Vec<String>> {
80 contract_pre_analyze!(input_path);
81 let converter = &self.sklearn_converter;
82 analyze_library(
83 input_path,
84 &["import sklearn", "from sklearn"],
85 "sklearn",
86 |path, content| {
87 let algorithms = [
88 ("LinearRegression", SklearnAlgorithm::LinearRegression),
89 ("LogisticRegression", SklearnAlgorithm::LogisticRegression),
90 ("KMeans", SklearnAlgorithm::KMeans),
91 ("DecisionTreeClassifier", SklearnAlgorithm::DecisionTreeClassifier),
92 ("RandomForestClassifier", SklearnAlgorithm::RandomForestClassifier),
93 ("StandardScaler", SklearnAlgorithm::StandardScaler),
94 ("train_test_split", SklearnAlgorithm::TrainTestSplit),
95 ];
96 algorithms
97 .iter()
98 .filter(|(pattern, _)| content.contains(*pattern))
99 .filter_map(|(pattern, alg)| {
100 converter.convert(alg).map(|r| {
101 format!(
102 "{}: {} ({}) → {}",
103 path.display(),
104 pattern,
105 alg.sklearn_module(),
106 r.code_template
107 )
108 })
109 })
110 .collect()
111 },
112 )
113 }
114
115 #[cfg(not(feature = "native"))]
117 pub fn analyze_sklearn_usage(&self, _input_path: &Path) -> Result<Vec<String>> {
118 Ok(Vec::new())
119 }
120
121 #[cfg(feature = "native")]
123 pub fn analyze_pytorch_usage(&self, input_path: &Path) -> Result<Vec<String>> {
124 contract_pre_analyze!(input_path);
125 let converter = &self.pytorch_converter;
126 analyze_library(
127 input_path,
128 &["import torch", "from torch", "from transformers"],
129 "PyTorch",
130 |path, content| {
131 let operations = [
132 ("torch.load", PyTorchOperation::LoadModel),
133 ("from_pretrained", PyTorchOperation::LoadModel),
134 ("AutoTokenizer", PyTorchOperation::LoadTokenizer),
135 (".forward(", PyTorchOperation::Forward),
136 (".generate(", PyTorchOperation::Generate),
137 ("nn.Linear", PyTorchOperation::Linear),
138 ("MultiheadAttention", PyTorchOperation::Attention),
139 ("tokenizer.encode", PyTorchOperation::Encode),
140 ("tokenizer.decode", PyTorchOperation::Decode),
141 ];
142 operations
143 .iter()
144 .filter(|(pattern, _)| content.contains(*pattern))
145 .filter_map(|(pattern, op)| {
146 converter.convert(op).map(|r| {
147 format!(
148 "{}: {} ({}) → {}",
149 path.display(),
150 pattern,
151 op.pytorch_module(),
152 r.code_template
153 )
154 })
155 })
156 .collect()
157 },
158 )
159 }
160
161 #[cfg(not(feature = "native"))]
163 pub fn analyze_pytorch_usage(&self, _input_path: &Path) -> Result<Vec<String>> {
164 Ok(Vec::new())
165 }
166}
167
/// Walks `input_path` (following symlinks), and for every `.py` file whose
/// contents mention one of `import_patterns`, logs the hit and collects the
/// recommendations produced by `match_content(path, contents)`.
///
/// Walk errors and unreadable files are skipped silently, so a nonexistent
/// `input_path` simply yields an empty result.
#[cfg(feature = "native")]
fn analyze_library<F>(
    input_path: &Path,
    import_patterns: &[&str],
    lib_name: &str,
    match_content: F,
) -> Result<Vec<String>>
where
    F: Fn(&Path, &str) -> Vec<String>,
{
    let mut findings = Vec::new();
    let walker = WalkDir::new(input_path).follow_links(true);
    for entry in walker.into_iter().flatten() {
        let file = entry.path();
        // Only Python sources are of interest.
        if !file.extension().map_or(false, |ext| ext == "py") {
            continue;
        }
        if let Ok(source) = std::fs::read_to_string(file) {
            // Cheap import check first; skip files that never mention the lib.
            if import_patterns.iter().any(|p| source.contains(p)) {
                info!(" Found {} usage in: {}", lib_name, file.display());
                findings.extend(match_content(file, &source));
            }
        }
    }
    Ok(findings)
}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Creates (or recreates) an empty scratch directory under the system
    /// temp dir and returns its path.
    fn setup_dir(name: &str) -> PathBuf {
        let scratch = std::env::temp_dir().join(name);
        let _ = std::fs::remove_dir_all(&scratch);
        std::fs::create_dir_all(&scratch).expect("mkdir failed");
        scratch
    }

    /// Best-effort removal of a scratch directory.
    fn cleanup(dir: &Path) {
        let _ = std::fs::remove_dir_all(dir);
    }

    #[test]
    fn test_library_analyzer_creation() {
        let _ = LibraryAnalyzer::new();
    }

    #[test]
    fn test_library_analyzer_default() {
        let _ = LibraryAnalyzer::default();
    }

    // --- nonexistent-path behavior: each analyzer returns Ok(empty) ---

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_numpy_nonexistent_path() {
        let outcome = LibraryAnalyzer::new().analyze_numpy_usage(Path::new("/nonexistent/path"));
        assert!(outcome.is_ok());
        assert!(outcome.expect("operation failed").is_empty());
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_sklearn_nonexistent_path() {
        let outcome = LibraryAnalyzer::new().analyze_sklearn_usage(Path::new("/nonexistent/path"));
        assert!(outcome.is_ok());
        assert!(outcome.expect("operation failed").is_empty());
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_pytorch_nonexistent_path() {
        let outcome = LibraryAnalyzer::new().analyze_pytorch_usage(Path::new("/nonexistent/path"));
        assert!(outcome.is_ok());
        assert!(outcome.expect("operation failed").is_empty());
    }

    // --- NumPy detection ---

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_numpy_with_matching_file() {
        let tmp = setup_dir("test_pa_numpy");
        let src = "import numpy as np\nx = np.array([1,2,3])\ny = np.dot(x, x)\nz = np.sum(y)\n";
        std::fs::write(tmp.join("model.py"), src).expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_numpy_usage(&tmp).expect("unexpected failure");
        assert!(!recs.is_empty());
        for needle in ["np.array", "np.dot", "np.sum"] {
            assert!(recs.iter().any(|r| r.contains(needle)));
        }
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_numpy_no_import() {
        let tmp = setup_dir("test_pa_numpy_noimport");
        std::fs::write(tmp.join("script.py"), "x = [1, 2, 3]\nprint(sum(x))\n")
            .expect("fs write failed");
        let recs = LibraryAnalyzer::new().analyze_numpy_usage(&tmp).expect("unexpected failure");
        assert!(recs.is_empty());
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_numpy_non_python_files_ignored() {
        let tmp = setup_dir("test_pa_numpy_nonpy");
        std::fs::write(tmp.join("data.txt"), "import numpy as np\nnp.array([1])\n")
            .expect("fs write failed");
        let recs = LibraryAnalyzer::new().analyze_numpy_usage(&tmp).expect("unexpected failure");
        assert!(recs.is_empty());
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_numpy_add_subtract_multiply() {
        let tmp = setup_dir("test_pa_numpy_ops");
        let src =
            "import numpy as np\na = np.add(x, y)\nb = np.subtract(x, y)\nc = np.multiply(x, y)\n";
        std::fs::write(tmp.join("ops.py"), src).expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_numpy_usage(&tmp).expect("unexpected failure");
        for needle in ["np.add", "np.subtract", "np.multiply"] {
            assert!(recs.iter().any(|r| r.contains(needle)));
        }
        cleanup(&tmp);
    }

    // --- sklearn detection ---

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_sklearn_with_matching_file() {
        let tmp = setup_dir("test_pa_sklearn");
        let src = "from sklearn.linear_model import LinearRegression\nfrom sklearn.cluster import KMeans\nmodel = LinearRegression()\nkm = KMeans(n_clusters=3)\n";
        std::fs::write(tmp.join("train.py"), src).expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_sklearn_usage(&tmp).expect("unexpected failure");
        assert!(!recs.is_empty());
        for needle in ["LinearRegression", "KMeans"] {
            assert!(recs.iter().any(|r| r.contains(needle)));
        }
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_sklearn_no_import() {
        let tmp = setup_dir("test_pa_sklearn_noimport");
        std::fs::write(tmp.join("script.py"), "print('hello')\n").expect("fs write failed");
        let recs = LibraryAnalyzer::new().analyze_sklearn_usage(&tmp).expect("unexpected failure");
        assert!(recs.is_empty());
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_sklearn_more_algorithms() {
        let tmp = setup_dir("test_pa_sklearn_more");
        let src = "from sklearn.tree import DecisionTreeClassifier\nfrom sklearn.preprocessing import StandardScaler\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.linear_model import LogisticRegression\n";
        std::fs::write(tmp.join("ml.py"), src).expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_sklearn_usage(&tmp).expect("unexpected failure");
        for needle in
            ["DecisionTreeClassifier", "StandardScaler", "train_test_split", "LogisticRegression"]
        {
            assert!(recs.iter().any(|r| r.contains(needle)));
        }
        cleanup(&tmp);
    }

    // --- PyTorch / transformers detection ---

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_pytorch_with_matching_file() {
        let tmp = setup_dir("test_pa_pytorch");
        let src = "import torch\nmodel = torch.load('model.pt')\nout = model.forward(x)\n";
        std::fs::write(tmp.join("infer.py"), src).expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_pytorch_usage(&tmp).expect("unexpected failure");
        assert!(!recs.is_empty());
        for needle in ["torch.load", ".forward("] {
            assert!(recs.iter().any(|r| r.contains(needle)));
        }
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_pytorch_no_import() {
        let tmp = setup_dir("test_pa_pytorch_noimport");
        std::fs::write(tmp.join("app.py"), "print('hello')\n").expect("fs write failed");
        let recs = LibraryAnalyzer::new().analyze_pytorch_usage(&tmp).expect("unexpected failure");
        assert!(recs.is_empty());
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_pytorch_transformers() {
        let tmp = setup_dir("test_pa_pytorch_hf");
        let src = "from transformers import AutoTokenizer\ntokenizer = AutoTokenizer.from_pretrained('bert')\nids = tokenizer.encode('hello')\ntext = tokenizer.decode(ids)\n";
        std::fs::write(tmp.join("hf.py"), src).expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_pytorch_usage(&tmp).expect("unexpected failure");
        for needle in ["AutoTokenizer", "from_pretrained", "tokenizer.encode", "tokenizer.decode"] {
            assert!(recs.iter().any(|r| r.contains(needle)));
        }
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_pytorch_nn_modules() {
        let tmp = setup_dir("test_pa_pytorch_nn");
        let src = "import torch\nimport torch.nn as nn\nlayer = nn.Linear(10, 5)\nattn = nn.MultiheadAttention(512, 8)\nout = model.generate(ids)\n";
        std::fs::write(tmp.join("model.py"), src).expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_pytorch_usage(&tmp).expect("unexpected failure");
        for needle in ["nn.Linear", "MultiheadAttention", ".generate("] {
            assert!(recs.iter().any(|r| r.contains(needle)));
        }
        cleanup(&tmp);
    }

    // --- directory traversal ---

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_numpy_recursive() {
        let tmp = setup_dir("test_pa_numpy_recurse");
        let nested = tmp.join("pkg").join("sub");
        std::fs::create_dir_all(&nested).expect("mkdir failed");
        std::fs::write(nested.join("deep.py"), "from numpy import array\nx = np.array([1])\n")
            .expect("unexpected failure");
        let recs = LibraryAnalyzer::new().analyze_numpy_usage(&tmp).expect("unexpected failure");
        assert!(recs.iter().any(|r| r.contains("np.array")));
        cleanup(&tmp);
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_analyze_all_empty_dir() {
        let tmp = setup_dir("test_pa_all_empty");
        let la = LibraryAnalyzer::new();
        assert!(la.analyze_numpy_usage(&tmp).expect("unexpected failure").is_empty());
        assert!(la.analyze_sklearn_usage(&tmp).expect("unexpected failure").is_empty());
        assert!(la.analyze_pytorch_usage(&tmp).expect("unexpected failure").is_empty());
        cleanup(&tmp);
    }
}