// batuta/pipeline_analysis.rs
//! Pipeline Library Analysis
//!
//! Provides analysis of Python ML library usage for conversion guidance.
//! Extracts NumPy, scikit-learn, and PyTorch operations to map to Rust equivalents.

use anyhow::Result;
use std::path::Path;

#[cfg(feature = "native")]
use tracing::info;

#[cfg(feature = "native")]
use walkdir::WalkDir;

use crate::numpy_converter::{NumPyConverter, NumPyOp};
use crate::pytorch_converter::{PyTorchConverter, PyTorchOperation};
use crate::sklearn_converter::{SklearnAlgorithm, SklearnConverter};

19/// Analyzer for ML library usage in Python code
20pub struct LibraryAnalyzer {
21    numpy_converter: NumPyConverter,
22    sklearn_converter: SklearnConverter,
23    pytorch_converter: PyTorchConverter,
24}
25
26impl Default for LibraryAnalyzer {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl LibraryAnalyzer {
33    /// Create a new library analyzer
34    pub fn new() -> Self {
35        Self {
36            numpy_converter: NumPyConverter::new(),
37            sklearn_converter: SklearnConverter::new(),
38            pytorch_converter: PyTorchConverter::new(),
39        }
40    }
41
42    /// Analyze Python source for NumPy usage and provide conversion guidance
43    #[cfg(feature = "native")]
44    pub fn analyze_numpy_usage(&self, input_path: &Path) -> Result<Vec<String>> {
45        contract_pre_analyze!(input_path);
46        let converter = &self.numpy_converter;
47        analyze_library(input_path, &["import numpy", "from numpy"], "NumPy", |path, content| {
48            let operations = [
49                ("np.add", NumPyOp::Add),
50                ("np.subtract", NumPyOp::Subtract),
51                ("np.multiply", NumPyOp::Multiply),
52                ("np.dot", NumPyOp::Dot),
53                ("np.sum", NumPyOp::Sum),
54                ("np.array", NumPyOp::Array),
55            ];
56            operations
57                .iter()
58                .filter_map(|(pattern, op)| {
59                    if content.contains(pattern) {
60                        converter.convert(op).map(|r| {
61                            format!("{}: {} → {}", path.display(), pattern, r.code_template)
62                        })
63                    } else {
64                        None
65                    }
66                })
67                .collect()
68        })
69    }
70
71    /// Stub for WASM build
72    #[cfg(not(feature = "native"))]
73    pub fn analyze_numpy_usage(&self, _input_path: &Path) -> Result<Vec<String>> {
74        Ok(Vec::new())
75    }
76
77    /// Analyze Python source for sklearn usage and provide conversion guidance
78    #[cfg(feature = "native")]
79    pub fn analyze_sklearn_usage(&self, input_path: &Path) -> Result<Vec<String>> {
80        contract_pre_analyze!(input_path);
81        let converter = &self.sklearn_converter;
82        analyze_library(
83            input_path,
84            &["import sklearn", "from sklearn"],
85            "sklearn",
86            |path, content| {
87                let algorithms = [
88                    ("LinearRegression", SklearnAlgorithm::LinearRegression),
89                    ("LogisticRegression", SklearnAlgorithm::LogisticRegression),
90                    ("KMeans", SklearnAlgorithm::KMeans),
91                    ("DecisionTreeClassifier", SklearnAlgorithm::DecisionTreeClassifier),
92                    ("RandomForestClassifier", SklearnAlgorithm::RandomForestClassifier),
93                    ("StandardScaler", SklearnAlgorithm::StandardScaler),
94                    ("train_test_split", SklearnAlgorithm::TrainTestSplit),
95                ];
96                algorithms
97                    .iter()
98                    .filter(|(pattern, _)| content.contains(*pattern))
99                    .filter_map(|(pattern, alg)| {
100                        converter.convert(alg).map(|r| {
101                            format!(
102                                "{}: {} ({}) → {}",
103                                path.display(),
104                                pattern,
105                                alg.sklearn_module(),
106                                r.code_template
107                            )
108                        })
109                    })
110                    .collect()
111            },
112        )
113    }
114
115    /// Stub for WASM build
116    #[cfg(not(feature = "native"))]
117    pub fn analyze_sklearn_usage(&self, _input_path: &Path) -> Result<Vec<String>> {
118        Ok(Vec::new())
119    }
120
121    /// Analyze Python source for PyTorch usage and provide conversion guidance
122    #[cfg(feature = "native")]
123    pub fn analyze_pytorch_usage(&self, input_path: &Path) -> Result<Vec<String>> {
124        contract_pre_analyze!(input_path);
125        let converter = &self.pytorch_converter;
126        analyze_library(
127            input_path,
128            &["import torch", "from torch", "from transformers"],
129            "PyTorch",
130            |path, content| {
131                let operations = [
132                    ("torch.load", PyTorchOperation::LoadModel),
133                    ("from_pretrained", PyTorchOperation::LoadModel),
134                    ("AutoTokenizer", PyTorchOperation::LoadTokenizer),
135                    (".forward(", PyTorchOperation::Forward),
136                    (".generate(", PyTorchOperation::Generate),
137                    ("nn.Linear", PyTorchOperation::Linear),
138                    ("MultiheadAttention", PyTorchOperation::Attention),
139                    ("tokenizer.encode", PyTorchOperation::Encode),
140                    ("tokenizer.decode", PyTorchOperation::Decode),
141                ];
142                operations
143                    .iter()
144                    .filter(|(pattern, _)| content.contains(*pattern))
145                    .filter_map(|(pattern, op)| {
146                        converter.convert(op).map(|r| {
147                            format!(
148                                "{}: {} ({}) → {}",
149                                path.display(),
150                                pattern,
151                                op.pytorch_module(),
152                                r.code_template
153                            )
154                        })
155                    })
156                    .collect()
157            },
158        )
159    }
160
161    /// Stub for WASM build
162    #[cfg(not(feature = "native"))]
163    pub fn analyze_pytorch_usage(&self, _input_path: &Path) -> Result<Vec<String>> {
164        Ok(Vec::new())
165    }
166}
167
168/// Shared helper: walk Python files matching import patterns and apply conversion logic
169#[cfg(feature = "native")]
170fn analyze_library<F>(
171    input_path: &Path,
172    import_patterns: &[&str],
173    lib_name: &str,
174    match_content: F,
175) -> Result<Vec<String>>
176where
177    F: Fn(&Path, &str) -> Vec<String>,
178{
179    let mut recommendations = Vec::new();
180    for entry in WalkDir::new(input_path).follow_links(true).into_iter().filter_map(|e| e.ok()) {
181        let Some(ext) = entry.path().extension() else {
182            continue;
183        };
184        if ext != "py" {
185            continue;
186        }
187        let Ok(content) = std::fs::read_to_string(entry.path()) else {
188            continue;
189        };
190        if !import_patterns.iter().any(|p| content.contains(p)) {
191            continue;
192        }
193        info!("  Found {} usage in: {}", lib_name, entry.path().display());
194        recommendations.extend(match_content(entry.path(), &content));
195    }
196    Ok(recommendations)
197}
198
199#[cfg(test)]
200mod tests {
201    use super::*;
202    use std::path::PathBuf;
203
204    fn setup_dir(name: &str) -> PathBuf {
205        let dir = std::env::temp_dir().join(name);
206        let _ = std::fs::remove_dir_all(&dir);
207        std::fs::create_dir_all(&dir).expect("mkdir failed");
208        dir
209    }
210
211    fn cleanup(dir: &Path) {
212        let _ = std::fs::remove_dir_all(dir);
213    }
214
215    #[test]
216    fn test_library_analyzer_creation() {
217        let _analyzer = LibraryAnalyzer::new();
218    }
219
220    #[test]
221    fn test_library_analyzer_default() {
222        let _analyzer = LibraryAnalyzer::default();
223    }
224
225    // ===== Nonexistent paths =====
226
227    #[cfg(feature = "native")]
228    #[test]
229    fn test_analyze_numpy_nonexistent_path() {
230        let analyzer = LibraryAnalyzer::new();
231        let result = analyzer.analyze_numpy_usage(Path::new("/nonexistent/path"));
232        assert!(result.is_ok());
233        assert!(result.expect("operation failed").is_empty());
234    }
235
236    #[cfg(feature = "native")]
237    #[test]
238    fn test_analyze_sklearn_nonexistent_path() {
239        let analyzer = LibraryAnalyzer::new();
240        let result = analyzer.analyze_sklearn_usage(Path::new("/nonexistent/path"));
241        assert!(result.is_ok());
242        assert!(result.expect("operation failed").is_empty());
243    }
244
245    #[cfg(feature = "native")]
246    #[test]
247    fn test_analyze_pytorch_nonexistent_path() {
248        let analyzer = LibraryAnalyzer::new();
249        let result = analyzer.analyze_pytorch_usage(Path::new("/nonexistent/path"));
250        assert!(result.is_ok());
251        assert!(result.expect("operation failed").is_empty());
252    }
253
254    // ===== NumPy with real files =====
255
256    #[cfg(feature = "native")]
257    #[test]
258    fn test_analyze_numpy_with_matching_file() {
259        let dir = setup_dir("test_pa_numpy");
260        std::fs::write(
261            dir.join("model.py"),
262            "import numpy as np\nx = np.array([1,2,3])\ny = np.dot(x, x)\nz = np.sum(y)\n",
263        )
264        .expect("unexpected failure");
265        let analyzer = LibraryAnalyzer::new();
266        let results = analyzer.analyze_numpy_usage(&dir).expect("unexpected failure");
267        assert!(!results.is_empty());
268        assert!(results.iter().any(|r| r.contains("np.array")));
269        assert!(results.iter().any(|r| r.contains("np.dot")));
270        assert!(results.iter().any(|r| r.contains("np.sum")));
271        cleanup(&dir);
272    }
273
274    #[cfg(feature = "native")]
275    #[test]
276    fn test_analyze_numpy_no_import() {
277        let dir = setup_dir("test_pa_numpy_noimport");
278        std::fs::write(dir.join("script.py"), "x = [1, 2, 3]\nprint(sum(x))\n")
279            .expect("fs write failed");
280        let analyzer = LibraryAnalyzer::new();
281        let results = analyzer.analyze_numpy_usage(&dir).expect("unexpected failure");
282        assert!(results.is_empty());
283        cleanup(&dir);
284    }
285
286    #[cfg(feature = "native")]
287    #[test]
288    fn test_analyze_numpy_non_python_files_ignored() {
289        let dir = setup_dir("test_pa_numpy_nonpy");
290        std::fs::write(dir.join("data.txt"), "import numpy as np\nnp.array([1])\n")
291            .expect("fs write failed");
292        let analyzer = LibraryAnalyzer::new();
293        let results = analyzer.analyze_numpy_usage(&dir).expect("unexpected failure");
294        assert!(results.is_empty());
295        cleanup(&dir);
296    }
297
298    #[cfg(feature = "native")]
299    #[test]
300    fn test_analyze_numpy_add_subtract_multiply() {
301        let dir = setup_dir("test_pa_numpy_ops");
302        std::fs::write(
303            dir.join("ops.py"),
304            "import numpy as np\na = np.add(x, y)\nb = np.subtract(x, y)\nc = np.multiply(x, y)\n",
305        )
306        .expect("unexpected failure");
307        let analyzer = LibraryAnalyzer::new();
308        let results = analyzer.analyze_numpy_usage(&dir).expect("unexpected failure");
309        assert!(results.iter().any(|r| r.contains("np.add")));
310        assert!(results.iter().any(|r| r.contains("np.subtract")));
311        assert!(results.iter().any(|r| r.contains("np.multiply")));
312        cleanup(&dir);
313    }
314
315    // ===== sklearn with real files =====
316
317    #[cfg(feature = "native")]
318    #[test]
319    fn test_analyze_sklearn_with_matching_file() {
320        let dir = setup_dir("test_pa_sklearn");
321        std::fs::write(
322            dir.join("train.py"),
323            "from sklearn.linear_model import LinearRegression\nfrom sklearn.cluster import KMeans\nmodel = LinearRegression()\nkm = KMeans(n_clusters=3)\n",
324        )
325        .expect("unexpected failure");
326        let analyzer = LibraryAnalyzer::new();
327        let results = analyzer.analyze_sklearn_usage(&dir).expect("unexpected failure");
328        assert!(!results.is_empty());
329        assert!(results.iter().any(|r| r.contains("LinearRegression")));
330        assert!(results.iter().any(|r| r.contains("KMeans")));
331        cleanup(&dir);
332    }
333
334    #[cfg(feature = "native")]
335    #[test]
336    fn test_analyze_sklearn_no_import() {
337        let dir = setup_dir("test_pa_sklearn_noimport");
338        std::fs::write(dir.join("script.py"), "print('hello')\n").expect("fs write failed");
339        let analyzer = LibraryAnalyzer::new();
340        let results = analyzer.analyze_sklearn_usage(&dir).expect("unexpected failure");
341        assert!(results.is_empty());
342        cleanup(&dir);
343    }
344
345    #[cfg(feature = "native")]
346    #[test]
347    fn test_analyze_sklearn_more_algorithms() {
348        // Only algorithms registered in SklearnConverter::new() produce output
349        let dir = setup_dir("test_pa_sklearn_more");
350        std::fs::write(
351            dir.join("ml.py"),
352            "from sklearn.tree import DecisionTreeClassifier\nfrom sklearn.preprocessing import StandardScaler\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.linear_model import LogisticRegression\n",
353        )
354        .expect("unexpected failure");
355        let analyzer = LibraryAnalyzer::new();
356        let results = analyzer.analyze_sklearn_usage(&dir).expect("unexpected failure");
357        assert!(results.iter().any(|r| r.contains("DecisionTreeClassifier")));
358        assert!(results.iter().any(|r| r.contains("StandardScaler")));
359        assert!(results.iter().any(|r| r.contains("train_test_split")));
360        assert!(results.iter().any(|r| r.contains("LogisticRegression")));
361        cleanup(&dir);
362    }
363
364    // ===== PyTorch with real files =====
365
366    #[cfg(feature = "native")]
367    #[test]
368    fn test_analyze_pytorch_with_matching_file() {
369        let dir = setup_dir("test_pa_pytorch");
370        std::fs::write(
371            dir.join("infer.py"),
372            "import torch\nmodel = torch.load('model.pt')\nout = model.forward(x)\n",
373        )
374        .expect("unexpected failure");
375        let analyzer = LibraryAnalyzer::new();
376        let results = analyzer.analyze_pytorch_usage(&dir).expect("unexpected failure");
377        assert!(!results.is_empty());
378        assert!(results.iter().any(|r| r.contains("torch.load")));
379        assert!(results.iter().any(|r| r.contains(".forward(")));
380        cleanup(&dir);
381    }
382
383    #[cfg(feature = "native")]
384    #[test]
385    fn test_analyze_pytorch_no_import() {
386        let dir = setup_dir("test_pa_pytorch_noimport");
387        std::fs::write(dir.join("app.py"), "print('hello')\n").expect("fs write failed");
388        let analyzer = LibraryAnalyzer::new();
389        let results = analyzer.analyze_pytorch_usage(&dir).expect("unexpected failure");
390        assert!(results.is_empty());
391        cleanup(&dir);
392    }
393
394    #[cfg(feature = "native")]
395    #[test]
396    fn test_analyze_pytorch_transformers() {
397        let dir = setup_dir("test_pa_pytorch_hf");
398        std::fs::write(
399            dir.join("hf.py"),
400            "from transformers import AutoTokenizer\ntokenizer = AutoTokenizer.from_pretrained('bert')\nids = tokenizer.encode('hello')\ntext = tokenizer.decode(ids)\n",
401        )
402        .expect("unexpected failure");
403        let analyzer = LibraryAnalyzer::new();
404        let results = analyzer.analyze_pytorch_usage(&dir).expect("unexpected failure");
405        assert!(results.iter().any(|r| r.contains("AutoTokenizer")));
406        assert!(results.iter().any(|r| r.contains("from_pretrained")));
407        assert!(results.iter().any(|r| r.contains("tokenizer.encode")));
408        assert!(results.iter().any(|r| r.contains("tokenizer.decode")));
409        cleanup(&dir);
410    }
411
412    #[cfg(feature = "native")]
413    #[test]
414    fn test_analyze_pytorch_nn_modules() {
415        let dir = setup_dir("test_pa_pytorch_nn");
416        std::fs::write(
417            dir.join("model.py"),
418            "import torch\nimport torch.nn as nn\nlayer = nn.Linear(10, 5)\nattn = nn.MultiheadAttention(512, 8)\nout = model.generate(ids)\n",
419        )
420        .expect("unexpected failure");
421        let analyzer = LibraryAnalyzer::new();
422        let results = analyzer.analyze_pytorch_usage(&dir).expect("unexpected failure");
423        assert!(results.iter().any(|r| r.contains("nn.Linear")));
424        assert!(results.iter().any(|r| r.contains("MultiheadAttention")));
425        assert!(results.iter().any(|r| r.contains(".generate(")));
426        cleanup(&dir);
427    }
428
429    // ===== Subdirectory traversal =====
430
431    #[cfg(feature = "native")]
432    #[test]
433    fn test_analyze_numpy_recursive() {
434        let dir = setup_dir("test_pa_numpy_recurse");
435        let sub = dir.join("pkg").join("sub");
436        std::fs::create_dir_all(&sub).expect("mkdir failed");
437        std::fs::write(sub.join("deep.py"), "from numpy import array\nx = np.array([1])\n")
438            .expect("unexpected failure");
439        let analyzer = LibraryAnalyzer::new();
440        let results = analyzer.analyze_numpy_usage(&dir).expect("unexpected failure");
441        assert!(results.iter().any(|r| r.contains("np.array")));
442        cleanup(&dir);
443    }
444
445    // ===== Empty directory =====
446
447    #[cfg(feature = "native")]
448    #[test]
449    fn test_analyze_all_empty_dir() {
450        let dir = setup_dir("test_pa_all_empty");
451        let analyzer = LibraryAnalyzer::new();
452        assert!(analyzer.analyze_numpy_usage(&dir).expect("unexpected failure").is_empty());
453        assert!(analyzer.analyze_sklearn_usage(&dir).expect("unexpected failure").is_empty());
454        assert!(analyzer.analyze_pytorch_usage(&dir).expect("unexpected failure").is_empty());
455        cleanup(&dir);
456    }
457}