Skip to main content

pcu/parsers/
conda.rs

1use super::{Dependency, DependencyParser};
2use check_updates_core::VersionSpec;
3use anyhow::{Context, Result};
4use serde_yaml::Value;
5use std::fs;
6use std::path::Path;
7
/// Parser for conda `environment.yml` / `environment.yaml` files.
///
/// The parser holds no state (it is a unit struct), so it is cheap to copy and
/// can be shared freely.
#[derive(Debug, Clone, Copy)]
pub struct CondaParser;
10
11impl Default for CondaParser {
12    fn default() -> Self {
13        Self::new()
14    }
15}
16
17impl CondaParser {
18    pub fn new() -> Self {
19        Self
20    }
21
22    /// Parse a single dependency string from conda format
23    /// Examples:
24    /// - "numpy" -> (numpy, Any)
25    /// - "numpy=1.24.0" -> (numpy, ==1.24.0)
26    /// - "numpy>=1.24.0" -> (numpy, >=1.24.0)
27    /// - "python=3.9.*" -> (python, ==3.9.*)
28    fn parse_conda_dependency(dep_str: &str) -> Option<(String, VersionSpec)> {
29        let dep_str = dep_str.trim();
30
31        // Skip empty strings or comments
32        if dep_str.is_empty() || dep_str.starts_with('#') {
33            return None;
34        }
35
36        // Conda uses = for exact version, >= for minimum, etc.
37        // Examples: numpy=1.24.0, numpy>=1.24, numpy, python=3.9.*
38
39        // Check for version operators (order matters - check >= before =)
40        if let Some(idx) = dep_str.find(">=") {
41            let name = dep_str[..idx].trim().to_lowercase();
42            let version_str = dep_str[idx + 2..].trim();
43            return match VersionSpec::parse(&format!(">={version_str}")) {
44                Ok(spec) => Some((name, spec)),
45                Err(_) => Some((name, VersionSpec::Any)),
46            };
47        }
48
49        if let Some(idx) = dep_str.find("<=") {
50            let name = dep_str[..idx].trim().to_lowercase();
51            let version_str = dep_str[idx + 2..].trim();
52            return match VersionSpec::parse(&format!("<={version_str}")) {
53                Ok(spec) => Some((name, spec)),
54                Err(_) => Some((name, VersionSpec::Any)),
55            };
56        }
57
58        if let Some(idx) = dep_str.find("!=") {
59            let name = dep_str[..idx].trim().to_lowercase();
60            let version_str = dep_str[idx + 2..].trim();
61            return match VersionSpec::parse(&format!("!={version_str}")) {
62                Ok(spec) => Some((name, spec)),
63                Err(_) => Some((name, VersionSpec::Any)),
64            };
65        }
66
67        if let Some(idx) = dep_str.find('>') {
68            let name = dep_str[..idx].trim().to_lowercase();
69            let version_str = dep_str[idx + 1..].trim();
70            return match VersionSpec::parse(&format!(">{version_str}")) {
71                Ok(spec) => Some((name, spec)),
72                Err(_) => Some((name, VersionSpec::Any)),
73            };
74        }
75
76        if let Some(idx) = dep_str.find('<') {
77            let name = dep_str[..idx].trim().to_lowercase();
78            let version_str = dep_str[idx + 1..].trim();
79            return match VersionSpec::parse(&format!("<{version_str}")) {
80                Ok(spec) => Some((name, spec)),
81                Err(_) => Some((name, VersionSpec::Any)),
82            };
83        }
84
85        if let Some(idx) = dep_str.find('=') {
86            let name = dep_str[..idx].trim().to_lowercase();
87            let version_str = dep_str[idx + 1..].trim();
88
89            // Conda uses = for pinning, convert to ==
90            return match VersionSpec::parse(&format!("=={version_str}")) {
91                Ok(spec) => Some((name, spec)),
92                Err(_) => Some((name, VersionSpec::Any)),
93            };
94        }
95
96        // No version specified - just package name
97        let name = dep_str.to_lowercase();
98        Some((name, VersionSpec::Any))
99    }
100
101    /// Parse a pip dependency string (these follow pip format, not conda format)
102    /// Examples:
103    /// - "numpy" -> (numpy, Any)
104    /// - "numpy==1.24.0" -> (numpy, ==1.24.0)
105    /// - "numpy>=1.24.0,<2.0.0" -> (numpy, >=1.24.0,<2.0.0)
106    fn parse_pip_dependency(dep_str: &str) -> Option<(String, VersionSpec)> {
107        let dep_str = dep_str.trim();
108
109        // Skip empty strings or comments
110        if dep_str.is_empty() || dep_str.starts_with('#') {
111            return None;
112        }
113
114        // Pip format uses various operators: ==, >=, <=, ~=, !=, <, >
115        // and can have multiple constraints separated by commas
116
117        // Find where the version spec starts (first operator character)
118        let operators = ["==", ">=", "<=", "~=", "!=", "<", ">", "^", "~"];
119        let mut split_pos = None;
120
121        for op in &operators {
122            if let Some(pos) = dep_str.find(op)
123                && split_pos.is_none_or(|sp| pos < sp) {
124                    split_pos = Some(pos);
125                }
126        }
127
128        if let Some(pos) = split_pos {
129            let name = dep_str[..pos].trim().to_lowercase();
130            let version_str = dep_str[pos..].trim();
131
132            return match VersionSpec::parse(version_str) {
133                Ok(spec) => Some((name, spec)),
134                Err(_) => Some((name, VersionSpec::Any)),
135            };
136        }
137
138        // No version specified - just package name
139        let name = dep_str.to_lowercase();
140        Some((name, VersionSpec::Any))
141    }
142}
143
144impl DependencyParser for CondaParser {
145    fn parse(&self, path: &Path) -> Result<Vec<Dependency>> {
146        let content = fs::read_to_string(path)
147            .context(format!("Failed to read file: {}", path.display()))?;
148
149        let yaml: Value = serde_yaml::from_str(&content)
150            .context(format!("Failed to parse YAML: {}", path.display()))?;
151
152        let mut dependencies = Vec::new();
153
154        // Get the dependencies list
155        if let Some(deps) = yaml.get("dependencies").and_then(|v| v.as_sequence()) {
156            for (idx, dep) in deps.iter().enumerate() {
157                // Line number is approximate - YAML line numbers are tricky
158                // We'll use the array index + 1 (assuming dependencies: starts at line 1)
159                let line_number = idx + 2; // +2 because: 1 for "dependencies:" line, 1 for 0-based index
160
161                // Dependencies can be either strings or objects (for pip section)
162                if let Some(dep_str) = dep.as_str() {
163                    // Regular conda dependency as a string
164                    if let Some((name, version_spec)) = Self::parse_conda_dependency(dep_str) {
165                        dependencies.push(Dependency {
166                            name,
167                            version_spec,
168                            source_file: path.to_path_buf(),
169                            line_number,
170                            original_line: format!("  - {dep_str}"),
171                        });
172                    }
173                } else if let Some(pip_section) = dep.as_mapping() {
174                    // This might be a pip section: { pip: [...] }
175                    if let Some(pip_deps) = pip_section.get("pip").and_then(|v| v.as_sequence()) {
176                        for (pip_idx, pip_dep) in pip_deps.iter().enumerate() {
177                            if let Some(pip_dep_str) = pip_dep.as_str()
178                                && let Some((name, version_spec)) = Self::parse_pip_dependency(pip_dep_str) {
179                                    dependencies.push(Dependency {
180                                        name,
181                                        version_spec,
182                                        source_file: path.to_path_buf(),
183                                        line_number: line_number + pip_idx + 1, // Approximate line number
184                                        original_line: format!("    - {pip_dep_str}"),
185                                    });
186                                }
187                        }
188                    }
189                }
190            }
191        }
192
193        Ok(dependencies)
194    }
195
196    fn can_parse(&self, path: &Path) -> bool {
197        path.file_name()
198            .and_then(|n| n.to_str())
199            .map(|n| n == "environment.yml" || n == "environment.yaml")
200            .unwrap_or(false)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::NamedTempFile;

    /// Write `content` to a fresh temporary file. The file is deleted when the
    /// returned handle is dropped, so keep it alive while parsing.
    fn temp_yaml(content: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(content.as_bytes()).unwrap();
        file
    }

    #[test]
    fn test_can_parse() {
        let parser = CondaParser::new();
        assert!(parser.can_parse(&PathBuf::from("environment.yml")));
        assert!(parser.can_parse(&PathBuf::from("environment.yaml")));
        assert!(!parser.can_parse(&PathBuf::from("requirements.txt")));
        assert!(!parser.can_parse(&PathBuf::from("pyproject.toml")));
    }

    #[test]
    fn test_parse_conda_dependency() {
        // Test simple package name
        let (name, spec) = CondaParser::parse_conda_dependency("numpy").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Any));

        // Test conda-style pinned version (=)
        let (name, spec) = CondaParser::parse_conda_dependency("numpy=1.24.0").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Pinned(_)));

        // Test minimum version
        let (name, spec) = CondaParser::parse_conda_dependency("numpy>=1.24.0").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Minimum(_)));

        // Test wildcard version
        let (name, spec) = CondaParser::parse_conda_dependency("python=3.9.*").unwrap();
        assert_eq!(name, "python");
        assert!(matches!(spec, VersionSpec::Wildcard { .. }));
    }

    #[test]
    fn test_blank_and_comment_entries_are_skipped() {
        assert!(CondaParser::parse_conda_dependency("").is_none());
        assert!(CondaParser::parse_conda_dependency("   ").is_none());
        assert!(CondaParser::parse_conda_dependency("# numpy").is_none());
        assert!(CondaParser::parse_pip_dependency("").is_none());
        assert!(CondaParser::parse_pip_dependency("# requests").is_none());
    }

    #[test]
    fn test_parse_pip_dependency() {
        // Test simple package name
        let (name, spec) = CondaParser::parse_pip_dependency("requests").unwrap();
        assert_eq!(name, "requests");
        assert!(matches!(spec, VersionSpec::Any));

        // Test pip-style pinned version (==)
        let (name, spec) = CondaParser::parse_pip_dependency("requests==2.28.0").unwrap();
        assert_eq!(name, "requests");
        assert!(matches!(spec, VersionSpec::Pinned(_)));

        // Test range
        let (name, spec) = CondaParser::parse_pip_dependency("numpy>=1.24.0,<2.0.0").unwrap();
        assert_eq!(name, "numpy");
        assert!(matches!(spec, VersionSpec::Range { .. }));

        // Test compatible release
        let (name, spec) = CondaParser::parse_pip_dependency("flask~=2.0.0").unwrap();
        assert_eq!(name, "flask");
        assert!(matches!(spec, VersionSpec::Compatible(_)));
    }

    #[test]
    fn test_parse_environment_yml() {
        let yaml_content = r#"
name: myenv
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.9.*
  - numpy=1.24.0
  - pandas>=1.5.0
  - scikit-learn
  - pip:
    - requests==2.28.0
    - flask>=2.0.0,<3.0.0
    - django
"#;

        let temp_file = temp_yaml(yaml_content);

        let parser = CondaParser::new();
        let dependencies = parser.parse(temp_file.path()).unwrap();

        // Should find 7 dependencies total (4 conda + 3 pip)
        assert_eq!(dependencies.len(), 7);

        // Check conda dependencies
        let python_dep = dependencies.iter().find(|d| d.name == "python").unwrap();
        assert!(matches!(python_dep.version_spec, VersionSpec::Wildcard { .. }));

        let numpy_dep = dependencies.iter().find(|d| d.name == "numpy").unwrap();
        assert!(matches!(numpy_dep.version_spec, VersionSpec::Pinned(_)));

        let pandas_dep = dependencies.iter().find(|d| d.name == "pandas").unwrap();
        assert!(matches!(pandas_dep.version_spec, VersionSpec::Minimum(_)));

        let sklearn_dep = dependencies.iter().find(|d| d.name == "scikit-learn").unwrap();
        assert!(matches!(sklearn_dep.version_spec, VersionSpec::Any));

        // Check pip dependencies
        let requests_dep = dependencies.iter().find(|d| d.name == "requests").unwrap();
        assert!(matches!(requests_dep.version_spec, VersionSpec::Pinned(_)));

        let flask_dep = dependencies.iter().find(|d| d.name == "flask").unwrap();
        assert!(matches!(flask_dep.version_spec, VersionSpec::Range { .. }));

        let django_dep = dependencies.iter().find(|d| d.name == "django").unwrap();
        assert!(matches!(django_dep.version_spec, VersionSpec::Any));
    }

    #[test]
    fn test_parse_environment_yaml() {
        let yaml_content = r#"
dependencies:
  - numpy=1.24.0
"#;

        // `can_parse` matches on the exact file name, so the file must really be
        // called "environment.yaml". Create it in a private temporary directory.
        // A fixed name in the shared temp dir could collide with concurrent test
        // runs and would be left behind if an assertion failed before cleanup.
        let dir = tempfile::tempdir().unwrap();
        let yaml_path = dir.path().join("environment.yaml");
        std::fs::write(&yaml_path, yaml_content).unwrap();

        let parser = CondaParser::new();
        assert!(parser.can_parse(&yaml_path));

        let dependencies = parser.parse(&yaml_path).unwrap();
        assert_eq!(dependencies.len(), 1);
        assert_eq!(dependencies[0].name, "numpy");
    }

    #[test]
    fn test_empty_dependencies() {
        let yaml_content = r#"
name: myenv
dependencies: []
"#;

        let temp_file = temp_yaml(yaml_content);

        let parser = CondaParser::new();
        let dependencies = parser.parse(temp_file.path()).unwrap();

        assert_eq!(dependencies.len(), 0);
    }

    #[test]
    fn test_missing_dependencies_key() {
        let temp_file = temp_yaml("name: myenv\nchannels:\n  - conda-forge\n");

        let dependencies = CondaParser::new().parse(temp_file.path()).unwrap();

        assert!(dependencies.is_empty());
    }

    #[test]
    fn test_invalid_yaml_is_error() {
        let temp_file = temp_yaml("dependencies: [numpy\n");

        assert!(CondaParser::new().parse(temp_file.path()).is_err());
    }
}