Skip to main content

standard_version/
pyproject.rs

//! pyproject.toml version file engine.
//!
//! Implements [`VersionFile`] for Python's `pyproject.toml` manifest. It detects
//! and rewrites the `version` field inside the `[project]` table (PEP 621)
//! while preserving the file's formatting.

7use crate::toml_helpers;
8use crate::version_file::{VersionFile, VersionFileError};
9
/// Header of the TOML table whose `version` key this engine reads and rewrites.
///
/// Keys with the same name in other tables (e.g. `[tool.poetry]`) are ignored.
const SECTION: &str = "[project]";

/// Version file engine for Python's `pyproject.toml` manifest.
///
/// This is a stateless unit struct; all behavior lives in its [`VersionFile`]
/// impl. Because it carries no data, it is free to construct (including via
/// `Default`), copy, compare, and hash.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PyprojectVersionFile;

17impl VersionFile for PyprojectVersionFile {
18    fn name(&self) -> &str {
19        "pyproject.toml"
20    }
21
22    fn filenames(&self) -> &[&str] {
23        &["pyproject.toml"]
24    }
25
26    fn detect(&self, content: &str) -> bool {
27        toml_helpers::detect_version_in_section(content, SECTION)
28    }
29
30    fn read_version(&self, content: &str) -> Option<String> {
31        toml_helpers::read_version_in_section(content, SECTION)
32    }
33
34    fn write_version(&self, content: &str, new_version: &str) -> Result<String, VersionFileError> {
35        toml_helpers::write_version_in_section(content, SECTION, new_version)
36    }
37}
38
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    const BASIC_PYPROJECT: &str = r#"[project]
name = "my-package"
version = "0.1.0"
description = "A test package"
"#;

    const MULTI_SECTION_PYPROJECT: &str = r#"[project]
name = "my-package"
version = "0.1.0"
description = "A test package"

[tool.poetry]
version = "0.1.0"
"#;

    /// Splits `content` at the `[tool.poetry]` header.
    ///
    /// Returns the text before and after the header, so assertions can be
    /// scoped to a single table. Panics if the header is missing, because every
    /// caller passes a fixture that contains it.
    fn split_at_poetry(content: &str) -> (&str, &str) {
        content
            .split_once("[tool.poetry]")
            .expect("fixture must contain a [tool.poetry] section")
    }

    // --- detect ---

    #[test]
    fn detect_with_project_version() {
        assert!(PyprojectVersionFile.detect(BASIC_PYPROJECT));
    }

    #[test]
    fn detect_without_project_section() {
        let content = "[tool.poetry]\nversion = \"1.0.0\"\n";
        assert!(!PyprojectVersionFile.detect(content));
    }

    #[test]
    fn detect_project_without_version() {
        let content = "[project]\nname = \"x\"\n\n[tool.poetry]\nversion = \"1.0.0\"\n";
        assert!(!PyprojectVersionFile.detect(content));
    }

    // --- read_version ---

    #[test]
    fn read_version_basic() {
        assert_eq!(
            PyprojectVersionFile.read_version(BASIC_PYPROJECT),
            Some("0.1.0".to_string()),
        );
    }

    #[test]
    fn read_version_no_project() {
        let content = "[tool.poetry]\nversion = \"1.0.0\"\n";
        assert_eq!(PyprojectVersionFile.read_version(content), None);
    }

    #[test]
    fn read_version_ignores_other_sections() {
        // Another table with its own `version` precedes [project]; only the
        // [project] value may be reported.
        let content =
            "[tool.poetry]\nversion = \"9.9.9\"\n\n[project]\nname = \"x\"\nversion = \"0.3.0\"\n";
        assert_eq!(
            PyprojectVersionFile.read_version(content),
            Some("0.3.0".to_string()),
        );
    }

    // --- write_version ---

    #[test]
    fn write_version_basic() {
        let result = PyprojectVersionFile
            .write_version(BASIC_PYPROJECT, "1.0.0")
            .unwrap();
        assert!(result.contains("version = \"1.0.0\""));
        assert!(!result.contains("0.1.0"), "old version should be replaced");
        assert!(result.contains("name = \"my-package\""));
        assert!(result.contains("description = \"A test package\""));
    }

    #[test]
    fn write_version_only_in_project_section() {
        let result = PyprojectVersionFile
            .write_version(MULTI_SECTION_PYPROJECT, "2.0.0")
            .unwrap();
        // Check each table separately. A global occurrence count cannot tell
        // whether the *right* table was rewritten.
        let (project, poetry) = split_at_poetry(&result);
        assert!(
            project.contains("version = \"2.0.0\""),
            "[project] version should be updated",
        );
        assert!(!project.contains("0.1.0"), "[project] should not keep the old version");
        assert!(
            poetry.contains("version = \"0.1.0\""),
            "tool.poetry version should remain 0.1.0",
        );
        assert!(!poetry.contains("2.0.0"), "tool.poetry version must not be rewritten");
    }

    #[test]
    fn write_version_no_field_returns_error() {
        let content = "[project]\nname = \"x\"\n";
        let err = PyprojectVersionFile.write_version(content, "1.0.0");
        assert!(err.is_err());
    }

    #[test]
    fn write_version_preserves_no_trailing_newline() {
        let content = "[project]\nname = \"x\"\nversion = \"0.1.0\"";
        let result = PyprojectVersionFile
            .write_version(content, "0.2.0")
            .unwrap();
        assert!(!result.ends_with('\n'));
        assert!(result.contains("version = \"0.2.0\""));
    }

    #[test]
    fn integration_with_tempdir() {
        use std::fs;

        let dir = tempfile::tempdir().unwrap();
        let pyproject = dir.path().join("pyproject.toml");
        fs::write(
            &pyproject,
            r#"[project]
name = "example"
version = "0.1.0"
requires-python = ">=3.8"

[tool.setuptools]
packages = ["example"]
"#,
        )
        .unwrap();

        let content = fs::read_to_string(&pyproject).unwrap();
        assert!(PyprojectVersionFile.detect(&content));
        assert_eq!(
            PyprojectVersionFile.read_version(&content),
            Some("0.1.0".to_string()),
        );

        let updated = PyprojectVersionFile
            .write_version(&content, "2.0.0")
            .unwrap();
        fs::write(&pyproject, &updated).unwrap();

        let on_disk = fs::read_to_string(&pyproject).unwrap();
        assert!(on_disk.contains("version = \"2.0.0\""));
        assert!(on_disk.contains("name = \"example\""));
        assert!(on_disk.contains("requires-python = \">=3.8\""));
        assert!(on_disk.contains("packages = [\"example\"]"));
        // Round-trip: the rewritten file reads back as the new version.
        assert_eq!(
            PyprojectVersionFile.read_version(&on_disk),
            Some("2.0.0".to_string()),
        );
    }
}