Skip to main content

omni_dev/data/
amendments.rs

1//! Amendment data structures and validation.
2
3use std::fs;
4use std::path::Path;
5
6use anyhow::{Context, Result};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9
/// Amendment file structure.
// Top-level document of the amendments YAML file (read by `load_from_file`,
// written by `save_to_file`). NOTE(review): `///` text on this type and its
// field is presumably picked up by the `JsonSchema` derive as schema
// descriptions, so maintainer-only notes are kept in `//` comments like this.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
// Only the schema is marked strict here (the attribute is `schemars`, not
// `serde`); presumably YAML loading still tolerates unknown keys — confirm.
#[schemars(deny_unknown_fields)]
pub struct AmendmentFile {
    /// List of commit amendments to apply.
    pub amendments: Vec<Amendment>,
}
17
18/// Individual commit amendment.
19///
20/// `summary` is force-included in `required` via the
21/// `#[schemars(extend(...))]` attribute even though `#[serde(default)]`
22/// would normally exclude it: OpenAI's strict-subset rule requires every
23/// property in `properties` to appear in `required`, while we still want
24/// graceful YAML loading for files written before the field existed.
25#[derive(Debug, Serialize, Deserialize, JsonSchema)]
26#[schemars(deny_unknown_fields)]
27#[schemars(extend("required" = ["commit", "message", "summary"]))]
28pub struct Amendment {
29    /// Full 40-character SHA-1 commit hash.
30    pub commit: String,
31    /// New commit message.
32    pub message: String,
33    /// Brief summary of what this commit changes (for cross-commit coherence).
34    /// Empty string when the model has nothing to add.
35    #[serde(default)]
36    pub summary: String,
37}
38
39impl AmendmentFile {
40    /// Loads amendments from a YAML file.
41    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
42        let content = fs::read_to_string(&path).with_context(|| {
43            format!("Failed to read amendment file: {}", path.as_ref().display())
44        })?;
45
46        let amendment_file: Self =
47            crate::data::from_yaml(&content).context("Failed to parse YAML amendment file")?;
48
49        amendment_file.validate()?;
50
51        Ok(amendment_file)
52    }
53
54    /// Validates amendment file structure and content.
55    pub fn validate(&self) -> Result<()> {
56        // Empty amendments are allowed - they indicate no changes are needed
57        for (i, amendment) in self.amendments.iter().enumerate() {
58            amendment
59                .validate()
60                .with_context(|| format!("Invalid amendment at index {i}"))?;
61        }
62
63        Ok(())
64    }
65
66    /// Saves amendments to a YAML file with proper multiline formatting.
67    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
68        let yaml_content =
69            serde_yaml::to_string(self).context("Failed to serialize amendments to YAML")?;
70
71        // Post-process YAML to use literal block scalars for multiline messages
72        let formatted_yaml = self.format_multiline_yaml(&yaml_content);
73
74        fs::write(&path, formatted_yaml).with_context(|| {
75            format!(
76                "Failed to write amendment file: {}",
77                path.as_ref().display()
78            )
79        })?;
80
81        Ok(())
82    }
83
84    /// Formats YAML to use literal block scalars for multiline messages.
85    fn format_multiline_yaml(&self, yaml: &str) -> String {
86        let mut result = String::new();
87        let lines: Vec<&str> = yaml.lines().collect();
88        let mut i = 0;
89
90        while i < lines.len() {
91            let line = lines[i];
92
93            // Check if this is a message field with a quoted multiline string
94            if line.trim_start().starts_with("message:") && line.contains('"') {
95                let indent = line.len() - line.trim_start().len();
96                let indent_str = " ".repeat(indent);
97
98                // Extract the quoted content
99                if let Some(start_quote) = line.find('"') {
100                    if let Some(end_quote) = line.rfind('"') {
101                        if start_quote != end_quote {
102                            let quoted_content = &line[start_quote + 1..end_quote];
103
104                            // Check if it contains newlines (multiline content)
105                            if quoted_content.contains("\\n") {
106                                // Convert to literal block scalar format
107                                result.push_str(&format!("{indent_str}message: |\n"));
108
109                                // Process the content, converting \n to actual newlines
110                                let unescaped = quoted_content.replace("\\n", "\n");
111                                for (line_idx, content_line) in unescaped.lines().enumerate() {
112                                    if line_idx == 0 && content_line.trim().is_empty() {
113                                        // Skip leading empty line
114                                        continue;
115                                    }
116                                    result.push_str(&format!("{indent_str}  {content_line}\n"));
117                                }
118                                i += 1;
119                                continue;
120                            }
121                        }
122                    }
123                }
124            }
125
126            // Default: just copy the line as-is
127            result.push_str(line);
128            result.push('\n');
129            i += 1;
130        }
131
132        result
133    }
134}
135
136impl Amendment {
137    /// Creates a new amendment.
138    pub fn new(commit: String, message: String) -> Self {
139        Self {
140            commit,
141            message,
142            summary: String::new(),
143        }
144    }
145
146    /// Validates amendment structure.
147    pub fn validate(&self) -> Result<()> {
148        // Validate commit hash format
149        if self.commit.len() != crate::git::FULL_HASH_LEN {
150            anyhow::bail!(
151                "Commit hash must be exactly {} characters long, got: {}",
152                crate::git::FULL_HASH_LEN,
153                self.commit.len()
154            );
155        }
156
157        if !self.commit.chars().all(|c| c.is_ascii_hexdigit()) {
158            anyhow::bail!("Commit hash must contain only hexadecimal characters");
159        }
160
161        if !self
162            .commit
163            .chars()
164            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit())
165        {
166            anyhow::bail!("Commit hash must be lowercase");
167        }
168
169        // Validate message content
170        if self.message.trim().is_empty() {
171            anyhow::bail!("Commit message cannot be empty");
172        }
173
174        Ok(())
175    }
176}
177
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a fresh temporary directory under `<crate root>/tmp`.
    ///
    /// Anchored at `CARGO_MANIFEST_DIR` rather than the process's working
    /// directory, so the unit tests and the property tests all use the same
    /// scratch location regardless of where the test binary is run from.
    fn scratch_dir() -> TempDir {
        let root = Path::new(env!("CARGO_MANIFEST_DIR")).join("tmp");
        fs::create_dir_all(&root).expect("failed to create tmp root");
        TempDir::new_in(&root).expect("failed to create temp dir")
    }

    // ── Amendment::validate ──────────────────────────────────────────

    #[test]
    fn valid_amendment() {
        let amendment = Amendment::new("a".repeat(40), "feat: add feature".to_string());
        assert!(amendment.validate().is_ok());
    }

    #[test]
    fn short_hash_rejected() {
        let amendment = Amendment::new("abc1234".to_string(), "feat: add feature".to_string());
        let err = amendment.validate().unwrap_err();
        assert!(err.to_string().contains("exactly"));
    }

    #[test]
    fn uppercase_hash_rejected() {
        let amendment = Amendment::new("A".repeat(40), "feat: add feature".to_string());
        let err = amendment.validate().unwrap_err();
        assert!(err.to_string().contains("lowercase"));
    }

    #[test]
    fn non_hex_hash_rejected() {
        let amendment = Amendment::new("g".repeat(40), "feat: add feature".to_string());
        let err = amendment.validate().unwrap_err();
        assert!(err.to_string().contains("hexadecimal"));
    }

    #[test]
    fn empty_message_rejected() {
        let amendment = Amendment::new("a".repeat(40), "   ".to_string());
        let err = amendment.validate().unwrap_err();
        assert!(err.to_string().contains("empty"));
    }

    #[test]
    fn valid_hex_digits() {
        // All valid hex chars: 0-9, a-f
        let hash = "0123456789abcdef0123456789abcdef01234567";
        let amendment = Amendment::new(hash.to_string(), "fix: something".to_string());
        assert!(amendment.validate().is_ok());
    }

    // ── AmendmentFile::validate ──────────────────────────────────────

    #[test]
    fn validate_empty_amendments_ok() {
        let file = AmendmentFile { amendments: vec![] };
        assert!(file.validate().is_ok());
    }

    #[test]
    fn validate_propagates_amendment_errors() {
        let file = AmendmentFile {
            amendments: vec![Amendment::new("short".to_string(), "msg".to_string())],
        };
        let err = file.validate().unwrap_err();
        assert!(err.to_string().contains("index 0"));
    }

    // ── AmendmentFile round-trip ─────────────────────────────────────

    #[test]
    fn save_and_load_roundtrip() -> Result<()> {
        let dir = scratch_dir();
        let path = dir.path().join("amendments.yaml");

        let original = AmendmentFile {
            amendments: vec![
                Amendment {
                    commit: "a".repeat(40),
                    message: "feat(cli): add new command".to_string(),
                    summary: "Adds the twiddle command".to_string(),
                },
                Amendment {
                    commit: "b".repeat(40),
                    message: "fix(git): resolve rebase issue\n\nDetailed body here.".to_string(),
                    summary: String::new(),
                },
            ],
        };

        original.save_to_file(&path)?;
        let loaded = AmendmentFile::load_from_file(&path)?;

        assert_eq!(loaded.amendments.len(), 2);
        assert_eq!(loaded.amendments[0].commit, "a".repeat(40));
        assert_eq!(loaded.amendments[0].message, "feat(cli): add new command");
        assert_eq!(loaded.amendments[0].summary, "Adds the twiddle command");
        assert_eq!(loaded.amendments[1].commit, "b".repeat(40));
        assert!(loaded.amendments[1]
            .message
            .contains("resolve rebase issue"));
        assert!(loaded.amendments[1].summary.is_empty());
        Ok(())
    }

    #[test]
    fn load_invalid_yaml_fails() -> Result<()> {
        let dir = scratch_dir();
        let path = dir.path().join("bad.yaml");
        fs::write(&path, "not: valid: yaml: [{{")?;
        assert!(AmendmentFile::load_from_file(&path).is_err());
        Ok(())
    }

    #[test]
    fn load_nonexistent_file_fails() {
        assert!(AmendmentFile::load_from_file("/nonexistent/path.yaml").is_err());
    }

    // ── property tests ────────────────────────────────────────────

    mod prop {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn valid_hex_hash_nonempty_msg_validates(
                hash in "[0-9a-f]{40}",
                msg in "[a-zA-Z0-9].{0,200}",
            ) {
                let amendment = Amendment::new(hash, msg);
                prop_assert!(amendment.validate().is_ok());
            }

            #[test]
            fn wrong_length_hash_rejects(
                len in (1_usize..80).prop_filter("not 40", |l| *l != 40),
            ) {
                let hash: String = "a".repeat(len);
                let amendment = Amendment::new(hash, "valid message".to_string());
                prop_assert!(amendment.validate().is_err());
            }

            #[test]
            fn non_hex_char_in_hash_rejects(
                pos in 0_usize..40,
                bad_idx in 0_usize..20,
            ) {
                let bad_chars = "ghijklmnopqrstuvwxyz";
                let bad_char = bad_chars.as_bytes()[bad_idx % bad_chars.len()] as char;
                let mut chars: Vec<char> = "a".repeat(40).chars().collect();
                chars[pos] = bad_char;
                let hash: String = chars.into_iter().collect();
                let amendment = Amendment::new(hash, "valid message".to_string());
                prop_assert!(amendment.validate().is_err());
            }

            #[test]
            fn whitespace_only_message_rejects(
                hash in "[0-9a-f]{40}",
                ws in "[ \t\n]{1,20}",
            ) {
                let amendment = Amendment::new(hash, ws);
                prop_assert!(amendment.validate().is_err());
            }

            #[test]
            fn roundtrip_save_load(
                count in 1_usize..5,
            ) {
                let dir = scratch_dir();
                let path = dir.path().join("amendments.yaml");
                let amendments: Vec<Amendment> = (0..count)
                    .map(|i| {
                        let hash = format!("{i:0>40x}");
                        Amendment::new(hash, format!("feat: message {i}"))
                    })
                    .collect();
                let original = AmendmentFile { amendments };
                original.save_to_file(&path).unwrap();
                let loaded = AmendmentFile::load_from_file(&path).unwrap();
                prop_assert_eq!(loaded.amendments.len(), original.amendments.len());
                for (orig, load) in original.amendments.iter().zip(loaded.amendments.iter()) {
                    prop_assert_eq!(&orig.commit, &load.commit);
                    // Messages may differ slightly due to YAML block scalar formatting
                    prop_assert!(load.message.contains(orig.message.lines().next().unwrap()));
                }
            }
        }
    }
}