// python_check_updates/parsers/conda.rs

use super::{Dependency, DependencyParser};
use crate::version::VersionSpec;
use anyhow::{Context, Result};
use serde_yaml::Value;
use std::fs;
use std::path::PathBuf;

/// Parser for conda `environment.yml` / `environment.yaml` files.
///
/// The parser holds no state; build it with [`CondaParser::new`] or
/// `CondaParser::default()`.
#[derive(Debug, Default, Clone, Copy)]
pub struct CondaParser;

11impl CondaParser {
12    pub fn new() -> Self {
13        Self
14    }
15
16    /// Parse a single dependency string from conda format
17    /// Examples:
18    /// - "numpy" -> (numpy, Any)
19    /// - "numpy=1.24.0" -> (numpy, ==1.24.0)
20    /// - "numpy>=1.24.0" -> (numpy, >=1.24.0)
21    /// - "python=3.9.*" -> (python, ==3.9.*)
22    fn parse_conda_dependency(dep_str: &str) -> Option<(String, VersionSpec)> {
23        let dep_str = dep_str.trim();
24
25        // Skip empty strings or comments
26        if dep_str.is_empty() || dep_str.starts_with('#') {
27            return None;
28        }
29
30        // Conda uses = for exact version, >= for minimum, etc.
31        // Examples: numpy=1.24.0, numpy>=1.24, numpy, python=3.9.*
32
33        // Check for version operators (order matters - check >= before =)
34        if let Some(idx) = dep_str.find(">=") {
35            let name = dep_str[..idx].trim().to_lowercase();
36            let version_str = dep_str[idx + 2..].trim();
37            return match VersionSpec::parse(&format!(">={}", version_str)) {
38                Ok(spec) => Some((name, spec)),
39                Err(_) => Some((name, VersionSpec::Any)),
40            };
41        }
42
43        if let Some(idx) = dep_str.find("<=") {
44            let name = dep_str[..idx].trim().to_lowercase();
45            let version_str = dep_str[idx + 2..].trim();
46            return match VersionSpec::parse(&format!("<={}", version_str)) {
47                Ok(spec) => Some((name, spec)),
48                Err(_) => Some((name, VersionSpec::Any)),
49            };
50        }
51
52        if let Some(idx) = dep_str.find("!=") {
53            let name = dep_str[..idx].trim().to_lowercase();
54            let version_str = dep_str[idx + 2..].trim();
55            return match VersionSpec::parse(&format!("!={}", version_str)) {
56                Ok(spec) => Some((name, spec)),
57                Err(_) => Some((name, VersionSpec::Any)),
58            };
59        }
60
61        if let Some(idx) = dep_str.find('>') {
62            let name = dep_str[..idx].trim().to_lowercase();
63            let version_str = dep_str[idx + 1..].trim();
64            return match VersionSpec::parse(&format!(">{}", version_str)) {
65                Ok(spec) => Some((name, spec)),
66                Err(_) => Some((name, VersionSpec::Any)),
67            };
68        }
69
70        if let Some(idx) = dep_str.find('<') {
71            let name = dep_str[..idx].trim().to_lowercase();
72            let version_str = dep_str[idx + 1..].trim();
73            return match VersionSpec::parse(&format!("<{}", version_str)) {
74                Ok(spec) => Some((name, spec)),
75                Err(_) => Some((name, VersionSpec::Any)),
76            };
77        }
78
79        if let Some(idx) = dep_str.find('=') {
80            let name = dep_str[..idx].trim().to_lowercase();
81            let version_str = dep_str[idx + 1..].trim();
82
83            // Conda uses = for pinning, convert to ==
84            return match VersionSpec::parse(&format!("=={}", version_str)) {
85                Ok(spec) => Some((name, spec)),
86                Err(_) => Some((name, VersionSpec::Any)),
87            };
88        }
89
90        // No version specified - just package name
91        let name = dep_str.to_lowercase();
92        Some((name, VersionSpec::Any))
93    }
94
95    /// Parse a pip dependency string (these follow pip format, not conda format)
96    /// Examples:
97    /// - "numpy" -> (numpy, Any)
98    /// - "numpy==1.24.0" -> (numpy, ==1.24.0)
99    /// - "numpy>=1.24.0,<2.0.0" -> (numpy, >=1.24.0,<2.0.0)
100    fn parse_pip_dependency(dep_str: &str) -> Option<(String, VersionSpec)> {
101        let dep_str = dep_str.trim();
102
103        // Skip empty strings or comments
104        if dep_str.is_empty() || dep_str.starts_with('#') {
105            return None;
106        }
107
108        // Pip format uses various operators: ==, >=, <=, ~=, !=, <, >
109        // and can have multiple constraints separated by commas
110
111        // Find where the version spec starts (first operator character)
112        let operators = ["==", ">=", "<=", "~=", "!=", "<", ">", "^", "~"];
113        let mut split_pos = None;
114
115        for op in &operators {
116            if let Some(pos) = dep_str.find(op) {
117                if split_pos.is_none() || pos < split_pos.unwrap() {
118                    split_pos = Some(pos);
119                }
120            }
121        }
122
123        if let Some(pos) = split_pos {
124            let name = dep_str[..pos].trim().to_lowercase();
125            let version_str = dep_str[pos..].trim();
126
127            return match VersionSpec::parse(version_str) {
128                Ok(spec) => Some((name, spec)),
129                Err(_) => Some((name, VersionSpec::Any)),
130            };
131        }
132
133        // No version specified - just package name
134        let name = dep_str.to_lowercase();
135        Some((name, VersionSpec::Any))
136    }
137}
138
139impl DependencyParser for CondaParser {
140    fn parse(&self, path: &PathBuf) -> Result<Vec<Dependency>> {
141        let content = fs::read_to_string(path)
142            .context(format!("Failed to read file: {}", path.display()))?;
143
144        let yaml: Value = serde_yaml::from_str(&content)
145            .context(format!("Failed to parse YAML: {}", path.display()))?;
146
147        let mut dependencies = Vec::new();
148
149        // Get the dependencies list
150        if let Some(deps) = yaml.get("dependencies").and_then(|v| v.as_sequence()) {
151            for (idx, dep) in deps.iter().enumerate() {
152                // Line number is approximate - YAML line numbers are tricky
153                // We'll use the array index + 1 (assuming dependencies: starts at line 1)
154                let line_number = idx + 2; // +2 because: 1 for "dependencies:" line, 1 for 0-based index
155
156                // Dependencies can be either strings or objects (for pip section)
157                if let Some(dep_str) = dep.as_str() {
158                    // Regular conda dependency as a string
159                    if let Some((name, version_spec)) = Self::parse_conda_dependency(dep_str) {
160                        dependencies.push(Dependency {
161                            name,
162                            version_spec,
163                            source_file: path.clone(),
164                            line_number,
165                            original_line: format!("  - {}", dep_str),
166                        });
167                    }
168                } else if let Some(pip_section) = dep.as_mapping() {
169                    // This might be a pip section: { pip: [...] }
170                    if let Some(pip_deps) = pip_section.get("pip").and_then(|v| v.as_sequence()) {
171                        for (pip_idx, pip_dep) in pip_deps.iter().enumerate() {
172                            if let Some(pip_dep_str) = pip_dep.as_str() {
173                                if let Some((name, version_spec)) = Self::parse_pip_dependency(pip_dep_str) {
174                                    dependencies.push(Dependency {
175                                        name,
176                                        version_spec,
177                                        source_file: path.clone(),
178                                        line_number: line_number + pip_idx + 1, // Approximate line number
179                                        original_line: format!("    - {}", pip_dep_str),
180                                    });
181                                }
182                            }
183                        }
184                    }
185                }
186            }
187        }
188
189        Ok(dependencies)
190    }
191
192    fn can_parse(&self, path: &PathBuf) -> bool {
193        path.file_name()
194            .and_then(|n| n.to_str())
195            .map(|n| n == "environment.yml" || n == "environment.yaml")
196            .unwrap_or(false)
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write `content` to a fresh temporary file that is deleted when dropped.
    fn write_temp_yaml(content: &str) -> NamedTempFile {
        let mut temp_file = NamedTempFile::new().unwrap();
        write!(temp_file, "{}", content).unwrap();
        temp_file
    }

    #[test]
    fn test_can_parse() {
        let parser = CondaParser::new();
        assert!(parser.can_parse(&PathBuf::from("environment.yml")));
        assert!(parser.can_parse(&PathBuf::from("environment.yaml")));
        assert!(!parser.can_parse(&PathBuf::from("requirements.txt")));
        assert!(!parser.can_parse(&PathBuf::from("pyproject.toml")));
    }

    #[test]
    fn test_parse_conda_dependency() {
        // Test simple package name
        let (name, spec) = CondaParser::parse_conda_dependency("numpy").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Any));

        // Test conda-style pinned version (=)
        let (name, spec) = CondaParser::parse_conda_dependency("numpy=1.24.0").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Pinned(_)));

        // Test minimum version
        let (name, spec) = CondaParser::parse_conda_dependency("numpy>=1.24.0").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Minimum(_)));

        // Test wildcard version
        let (name, spec) = CondaParser::parse_conda_dependency("python=3.9.*").unwrap();
        assert_eq!(name, "python");
        assert!(matches!(spec, VersionSpec::Wildcard { .. }));
    }

    #[test]
    fn test_parse_pip_dependency() {
        // Test simple package name
        let (name, spec) = CondaParser::parse_pip_dependency("requests").unwrap();
        assert_eq!(name, "requests");
        assert!(matches!(spec, VersionSpec::Any));

        // Test pip-style pinned version (==)
        let (name, spec) = CondaParser::parse_pip_dependency("requests==2.28.0").unwrap();
        assert_eq!(name, "requests");
        assert!(matches!(spec, VersionSpec::Pinned(_)));

        // Test range
        let (name, spec) = CondaParser::parse_pip_dependency("numpy>=1.24.0,<2.0.0").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Range { .. }));

        // Test compatible release
        let (name, spec) = CondaParser::parse_pip_dependency("flask~=2.0.0").unwrap();
        assert_eq!(name, "flask");
        assert!(matches!(spec, VersionSpec::Compatible(_)));
    }

    #[test]
    fn test_parse_environment_yml() {
        let yaml_content = r#"
name: myenv
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.9.*
  - numpy=1.24.0
  - pandas>=1.5.0
  - scikit-learn
  - pip:
    - requests==2.28.0
    - flask>=2.0.0,<3.0.0
    - django
"#;

        let temp_file = write_temp_yaml(yaml_content);
        let path = temp_file.path().to_path_buf();

        let parser = CondaParser::new();
        let dependencies = parser.parse(&path).unwrap();

        // Should find 7 dependencies total (4 conda + 3 pip)
        assert_eq!(dependencies.len(), 7);

        // Check conda dependencies
        let python_dep = dependencies.iter().find(|d| d.name == "python").unwrap();
        assert!(matches!(python_dep.version_spec, VersionSpec::Wildcard { .. }));

        let numpy_dep = dependencies.iter().find(|d| d.name == "numpy").unwrap();
        assert!(matches!(numpy_dep.version_spec, VersionSpec::Pinned(_)));

        let pandas_dep = dependencies.iter().find(|d| d.name == "pandas").unwrap();
        assert!(matches!(pandas_dep.version_spec, VersionSpec::Minimum(_)));

        let sklearn_dep = dependencies.iter().find(|d| d.name == "scikit-learn").unwrap();
        assert!(matches!(sklearn_dep.version_spec, VersionSpec::Any));

        // Check pip dependencies
        let requests_dep = dependencies.iter().find(|d| d.name == "requests").unwrap();
        assert!(matches!(requests_dep.version_spec, VersionSpec::Pinned(_)));

        let flask_dep = dependencies.iter().find(|d| d.name == "flask").unwrap();
        assert!(matches!(flask_dep.version_spec, VersionSpec::Range { .. }));

        let django_dep = dependencies.iter().find(|d| d.name == "django").unwrap();
        assert!(matches!(django_dep.version_spec, VersionSpec::Any));
    }

    #[test]
    fn test_parse_environment_yaml() {
        let yaml_content = r#"
dependencies:
  - numpy=1.24.0
"#;

        // A private directory lets the file carry the exact name `can_parse`
        // requires without touching a shared location. The directory is removed
        // when `dir` is dropped, even if an assertion below fails.
        let dir = tempfile::tempdir().unwrap();
        let yaml_path = dir.path().join("environment.yaml");
        std::fs::write(&yaml_path, yaml_content).unwrap();

        let parser = CondaParser::new();
        assert!(parser.can_parse(&yaml_path));

        let dependencies = parser.parse(&yaml_path).unwrap();
        assert_eq!(dependencies.len(), 1);
        assert_eq!(dependencies[0].name, "numpy");
    }

    #[test]
    fn test_empty_dependencies() {
        let yaml_content = r#"
name: myenv
dependencies: []
"#;

        let temp_file = write_temp_yaml(yaml_content);
        let path = temp_file.path().to_path_buf();

        let parser = CondaParser::new();
        let dependencies = parser.parse(&path).unwrap();

        assert_eq!(dependencies.len(), 0);
    }

    #[test]
    fn test_missing_dependencies_key() {
        let temp_file = write_temp_yaml("name: myenv\nchannels:\n  - defaults\n");
        let path = temp_file.path().to_path_buf();

        let dependencies = CondaParser::new().parse(&path).unwrap();

        assert!(dependencies.is_empty());
    }

    #[test]
    fn test_invalid_yaml_is_error() {
        let temp_file = write_temp_yaml("dependencies: [numpy\n");
        let path = temp_file.path().to_path_buf();

        assert!(CondaParser::new().parse(&path).is_err());
    }
}