Skip to main content

bn/
util.rs

1//! Utility functions for bean ID parsing and status conversion.
2
3use crate::bean::Status;
4use anyhow::{Context, Result};
5use std::path::Path;
6use std::str::FromStr;
7
8/// Validate a bean ID to prevent path traversal attacks.
9///
10/// Valid IDs match the pattern: ^[a-zA-Z0-9._-]+$
11/// This prevents directory escape attacks like "../../../etc/passwd".
12///
13/// # Examples
14/// - "1" ✓ (valid)
15/// - "3.2.1" ✓ (valid)
16/// - "my-task" ✓ (valid)
17/// - "task_v1.0" ✓ (valid)
18/// - "../etc/passwd" ✗ (invalid)
19/// - "task/../escape" ✗ (invalid)
20pub fn validate_bean_id(id: &str) -> Result<()> {
21    if id.is_empty() {
22        return Err(anyhow::anyhow!("Bean ID cannot be empty"));
23    }
24
25    if id.len() > 255 {
26        return Err(anyhow::anyhow!("Bean ID too long (max 255 characters)"));
27    }
28
29    // Check that ID only contains safe characters: alphanumeric, dots, underscores, hyphens
30    if !id
31        .chars()
32        .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '_' || c == '-')
33    {
34        return Err(anyhow::anyhow!(
35            "Invalid bean ID '{}': must contain only alphanumeric characters, dots, underscores, and hyphens",
36            id
37        ));
38    }
39
40    // Ensure no path traversal sequences
41    if id.contains("..") {
42        return Err(anyhow::anyhow!(
43            "Invalid bean ID '{}': cannot contain '..' (path traversal protection)",
44            id
45        ));
46    }
47
48    Ok(())
49}
50
/// One dot-separated component of a bean ID, as produced by ID parsing.
///
/// Ordering is derived, so it follows variant declaration order: every `Num`
/// sorts before every `Alpha`. Two `Num`s compare numerically; two `Alpha`s
/// compare as strings (byte-wise lexicographic).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum IdSegment {
    /// A segment that parsed as a `u64`.
    Num(u64),
    /// Any other segment, stored verbatim.
    Alpha(String),
}
76
77/// Compare two bean IDs using natural ordering.
78/// Parses IDs as dot-separated segments and compares them.
79/// Numeric segments are compared numerically, alpha segments lexicographically.
80/// Numeric segments sort before alpha segments.
81///
82/// # Examples
83/// - "1" < "2" (numeric comparison)
84/// - "1" < "10" (numeric comparison, not string comparison)
85/// - "3.1" < "3.2" (multi-level comparison)
86/// - "abc" < "def" (alpha comparison)
87/// - "1" < "abc" (numeric sorts before alpha)
88pub fn natural_cmp(a: &str, b: &str) -> std::cmp::Ordering {
89    let sa = parse_id_segments(a);
90    let sb = parse_id_segments(b);
91    sa.cmp(&sb)
92}
93
94/// Parse a dot-separated ID into segments.
95///
96/// Each segment is parsed as numeric (u64) if possible, otherwise kept as a string.
97/// Used for natural ID comparison.
98///
99/// # Examples
100/// - "1" → [Num(1)]
101/// - "3.1" → [Num(3), Num(1)]
102/// - "my-task" → [Alpha("my-task")]
103/// - "1.abc.2" → [Num(1), Alpha("abc"), Num(2)]
104fn parse_id_segments(id: &str) -> Vec<IdSegment> {
105    id.split('.')
106        .map(|seg| match seg.parse::<u64>() {
107            Ok(n) => IdSegment::Num(n),
108            Err(_) => IdSegment::Alpha(seg.to_string()),
109        })
110        .collect()
111}
112
113/// Convert a status string to a Status enum, or None if invalid.
114///
115/// Valid inputs: "open", "in_progress", "closed"
116pub fn parse_status(s: &str) -> Option<Status> {
117    match s {
118        "open" => Some(Status::Open),
119        "in_progress" => Some(Status::InProgress),
120        "closed" => Some(Status::Closed),
121        _ => None,
122    }
123}
124
125/// Implement FromStr for Status to support standard parsing.
126impl FromStr for Status {
127    type Err = String;
128
129    fn from_str(s: &str) -> Result<Self, Self::Err> {
130        parse_status(s).ok_or_else(|| format!("Invalid status: {}", s))
131    }
132}
133
/// Turn a bean title into a lowercase, hyphen-separated slug for filenames.
///
/// Processing, in order:
/// 1. Surrounding whitespace is trimmed and the text is lowercased.
/// 2. ASCII letters and digits are kept. Every whitespace character and every
///    `-` acts as a word separator. All other characters (punctuation, and
///    non-ASCII letters such as `é`) are dropped without separating words.
/// 3. Runs of separators become a single `-`; separators at either end vanish.
/// 4. A slug longer than 50 characters is cut to 50, and a `-` left dangling
///    at the cut is removed.
/// 5. An empty result becomes `"unnamed"`.
///
/// # Examples
/// - "My Task" → "my-task"
/// - "Build API v2.0" → "build-api-v20"
/// - "Foo   Bar" → "foo-bar"
/// - "Implement `bn show` to render Markdown" → "implement-bn-show-to-render-markdown"
/// - "Update Bean parser to read .md + YAML frontmatter" → "update-bean-parser-to-read-md-yaml-frontmatter"
/// - "My-Task!!!" → "my-task"
/// - "   Spaces   " → "spaces"
/// - "" (empty) → "unnamed"
/// - "a" (single char) → "a"
pub fn title_to_slug(title: &str) -> String {
    const MAX_SLUG_LEN: usize = 50;

    let lowered = title.trim().to_lowercase();
    let mut slug = String::with_capacity(lowered.len());
    // A separator is only written once the next kept character arrives. This
    // collapses runs and drops separators at the start and end in one pass.
    let mut separator_pending = false;

    for ch in lowered.chars() {
        if ch.is_ascii_alphanumeric() {
            if separator_pending && !slug.is_empty() {
                slug.push('-');
            }
            separator_pending = false;
            slug.push(ch);
        } else if ch == '-' || ch.is_whitespace() {
            separator_pending = true;
        }
        // Anything else is discarded and leaves `separator_pending` untouched.
    }

    if slug.len() > MAX_SLUG_LEN {
        // The slug is pure ASCII, so byte truncation never splits a character.
        slug.truncate(MAX_SLUG_LEN);
        let kept = slug.trim_end_matches('-').len();
        slug.truncate(kept);
    }

    if slug.is_empty() {
        String::from("unnamed")
    } else {
        slug
    }
}
205
206/// Write contents to a file atomically using write-to-temp + rename.
207///
208/// Writes to a temporary file in the same directory as `path`, then renames
209/// it to the target. `rename()` is atomic on POSIX when source and destination
210/// are on the same filesystem (guaranteed here since we use the same directory).
211/// The temp file is cleaned up on error.
212pub fn atomic_write(path: &Path, contents: &str) -> Result<()> {
213    let tmp_path = path.with_extension(format!("tmp.{}", std::process::id()));
214
215    // Write to temp file; clean up on failure
216    if let Err(e) = std::fs::write(&tmp_path, contents) {
217        let _ = std::fs::remove_file(&tmp_path);
218        return Err(e)
219            .with_context(|| format!("Failed to write temp file: {}", tmp_path.display()));
220    }
221
222    // Atomic rename; clean up temp on failure
223    if let Err(e) = std::fs::rename(&tmp_path, path) {
224        let _ = std::fs::remove_file(&tmp_path);
225        return Err(e).with_context(|| {
226            format!(
227                "Failed to rename {} -> {}",
228                tmp_path.display(),
229                path.display()
230            )
231        });
232    }
233
234    Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    // ---------- title_to_slug tests ----------

    #[test]
    fn title_to_slug_simple_case() {
        assert_eq!(title_to_slug("My Task"), "my-task");
    }

    #[test]
    fn title_to_slug_with_numbers_and_dots() {
        assert_eq!(title_to_slug("Build API v2.0"), "build-api-v20");
    }

    #[test]
    fn title_to_slug_multiple_spaces() {
        assert_eq!(title_to_slug("Foo   Bar"), "foo-bar");
    }

    #[test]
    fn title_to_slug_with_backticks() {
        assert_eq!(
            title_to_slug("Implement `bn show` to render Markdown"),
            "implement-bn-show-to-render-markdown"
        );
    }

    #[test]
    fn title_to_slug_with_special_chars() {
        assert_eq!(
            title_to_slug("Update Bean parser to read .md + YAML frontmatter"),
            "update-bean-parser-to-read-md-yaml-frontmatter"
        );
    }

    #[test]
    fn title_to_slug_with_exclamation() {
        assert_eq!(title_to_slug("My-Task!!!"), "my-task");
    }

    #[test]
    fn title_to_slug_leading_trailing_spaces() {
        assert_eq!(title_to_slug("   Spaces   "), "spaces");
    }

    #[test]
    fn title_to_slug_empty_string() {
        assert_eq!(title_to_slug(""), "unnamed");
    }

    #[test]
    fn title_to_slug_single_character() {
        assert_eq!(title_to_slug("a"), "a");
        assert_eq!(title_to_slug("Z"), "z");
    }

    #[test]
    fn title_to_slug_only_spaces() {
        assert_eq!(title_to_slug("   "), "unnamed");
    }

    #[test]
    fn title_to_slug_only_special_chars() {
        assert_eq!(title_to_slug("!!!@@@###"), "unnamed");
    }

    #[test]
    fn title_to_slug_truncate_50_chars() {
        let long_title = "a".repeat(60);
        let result = title_to_slug(&long_title);
        assert_eq!(result, "a".repeat(50));
        assert_eq!(result.len(), 50);
    }

    #[test]
    fn title_to_slug_truncate_with_hyphens() {
        // "word word ... word" slugs to "word-word-...-word" (99 chars). Cutting
        // at 50 characters lands just after a hyphen, which must be trimmed.
        let title = "word ".repeat(20);
        let result = title_to_slug(&title);
        assert_eq!(result, "word-".repeat(10).trim_end_matches('-'));
        assert_eq!(result.len(), 49);
        assert!(!result.ends_with('-'));
    }

    #[test]
    fn title_to_slug_mixed_case() {
        assert_eq!(
            title_to_slug("ThIs Is A MiXeD CaSe TiTle"),
            "this-is-a-mixed-case-title"
        );
    }

    #[test]
    fn title_to_slug_numbers_preserved() {
        assert_eq!(
            title_to_slug("Task 123 Version 4.5.6"),
            "task-123-version-456"
        );
    }

    #[test]
    fn title_to_slug_consecutive_hyphens() {
        assert_eq!(title_to_slug("foo---bar"), "foo-bar");
        assert_eq!(title_to_slug("foo - - bar"), "foo-bar");
    }

    #[test]
    fn title_to_slug_unicode_removed() {
        // Unicode characters are not ASCII alphanumeric, so they get removed
        assert_eq!(title_to_slug("café"), "caf");
        assert_eq!(title_to_slug("naïve"), "nave");
    }

    #[test]
    fn title_to_slug_all_whitespace_types() {
        assert_eq!(title_to_slug("foo\tbar\nbaz"), "foo-bar-baz");
    }

    #[test]
    fn title_to_slug_exactly_50_chars() {
        let title = "a".repeat(50);
        assert_eq!(title_to_slug(&title), title);
    }

    // ---------- natural_cmp tests ----------

    #[test]
    fn natural_cmp_single_digit() {
        assert_eq!(natural_cmp("1", "2"), Ordering::Less);
        assert_eq!(natural_cmp("2", "1"), Ordering::Greater);
        assert_eq!(natural_cmp("1", "1"), Ordering::Equal);
    }

    #[test]
    fn natural_cmp_multi_digit() {
        assert_eq!(natural_cmp("1", "10"), Ordering::Less);
        assert_eq!(natural_cmp("10", "1"), Ordering::Greater);
        assert_eq!(natural_cmp("10", "10"), Ordering::Equal);
    }

    #[test]
    fn natural_cmp_multi_level() {
        assert_eq!(natural_cmp("3.1", "3.2"), Ordering::Less);
        assert_eq!(natural_cmp("3.2", "3.1"), Ordering::Greater);
        assert_eq!(natural_cmp("3.1", "3.1"), Ordering::Equal);
    }

    #[test]
    fn natural_cmp_three_level() {
        assert_eq!(natural_cmp("3.2.1", "3.2.2"), Ordering::Less);
        assert_eq!(natural_cmp("3.2.2", "3.2.1"), Ordering::Greater);
        assert_eq!(natural_cmp("3.2.1", "3.2.1"), Ordering::Equal);
    }

    #[test]
    fn natural_cmp_different_prefix() {
        assert_eq!(natural_cmp("2.1", "3.1"), Ordering::Less);
        assert_eq!(natural_cmp("10.5", "9.99"), Ordering::Greater);
    }

    #[test]
    fn natural_cmp_prefix_sorts_first() {
        assert_eq!(natural_cmp("1", "1.1"), Ordering::Less);
        assert_eq!(natural_cmp("3.2", "3.2.1"), Ordering::Less);
        assert_eq!(natural_cmp("3.2.1", "3.2"), Ordering::Greater);
    }

    #[test]
    fn natural_cmp_sorts_mixed_ids() {
        let mut ids = vec!["10", "abc", "1.10", "2", "1.2", "1"];
        ids.sort_by(|a, b| natural_cmp(a, b));
        assert_eq!(ids, vec!["1", "1.2", "1.10", "2", "10", "abc"]);
    }

    // ---------- parse_id_segments tests ----------

    #[test]
    fn parse_id_segments_single() {
        assert_eq!(parse_id_segments("1"), vec![IdSegment::Num(1)]);
        assert_eq!(parse_id_segments("42"), vec![IdSegment::Num(42)]);
    }

    #[test]
    fn parse_id_segments_multi_level() {
        assert_eq!(
            parse_id_segments("1.2"),
            vec![IdSegment::Num(1), IdSegment::Num(2)]
        );
        assert_eq!(
            parse_id_segments("3.2.1"),
            vec![IdSegment::Num(3), IdSegment::Num(2), IdSegment::Num(1)]
        );
    }

    #[test]
    fn parse_id_segments_leading_zeros() {
        // Leading zeros are parsed as decimal, not octal
        assert_eq!(parse_id_segments("01"), vec![IdSegment::Num(1)]);
        assert_eq!(
            parse_id_segments("03.02"),
            vec![IdSegment::Num(3), IdSegment::Num(2)]
        );
    }

    #[test]
    fn parse_id_segments_alpha() {
        assert_eq!(
            parse_id_segments("abc"),
            vec![IdSegment::Alpha("abc".to_string())]
        );
        assert_eq!(
            parse_id_segments("1.abc.2"),
            vec![
                IdSegment::Num(1),
                IdSegment::Alpha("abc".to_string()),
                IdSegment::Num(2)
            ]
        );
    }

    #[test]
    fn natural_cmp_alpha_ids() {
        // Alpha IDs should not all compare equal
        assert_eq!(natural_cmp("abc", "def"), Ordering::Less);
        assert_eq!(natural_cmp("def", "abc"), Ordering::Greater);
        assert_eq!(natural_cmp("abc", "abc"), Ordering::Equal);
    }

    #[test]
    fn natural_cmp_numeric_before_alpha() {
        assert_eq!(natural_cmp("1", "abc"), Ordering::Less);
        assert_eq!(natural_cmp("abc", "1"), Ordering::Greater);
    }

    #[test]
    fn natural_cmp_mixed_segments() {
        // "1.abc.2" vs "1.abc.3" — third segment differs
        assert_eq!(natural_cmp("1.abc.2", "1.abc.3"), Ordering::Less);
        // "1.abc" vs "1.def" — second segment differs
        assert_eq!(natural_cmp("1.abc", "1.def"), Ordering::Less);
    }

    // ---------- parse_status tests ----------

    #[test]
    fn parse_status_valid_open() {
        assert_eq!(parse_status("open"), Some(Status::Open));
    }

    #[test]
    fn parse_status_valid_in_progress() {
        assert_eq!(parse_status("in_progress"), Some(Status::InProgress));
    }

    #[test]
    fn parse_status_valid_closed() {
        assert_eq!(parse_status("closed"), Some(Status::Closed));
    }

    #[test]
    fn parse_status_invalid() {
        assert_eq!(parse_status("invalid"), None);
        assert_eq!(parse_status(""), None);
        assert_eq!(parse_status("OPEN"), None);
        assert_eq!(parse_status("Closed"), None);
    }

    #[test]
    fn parse_status_whitespace() {
        assert_eq!(parse_status("open "), None);
        assert_eq!(parse_status(" open"), None);
    }

    // ---------- Status::FromStr tests ----------

    #[test]
    fn status_from_str_open() {
        assert_eq!("open".parse::<Status>(), Ok(Status::Open));
    }

    #[test]
    fn status_from_str_in_progress() {
        assert_eq!("in_progress".parse::<Status>(), Ok(Status::InProgress));
    }

    #[test]
    fn status_from_str_closed() {
        assert_eq!("closed".parse::<Status>(), Ok(Status::Closed));
    }

    #[test]
    fn status_from_str_invalid() {
        assert!("invalid".parse::<Status>().is_err());
        assert!("".parse::<Status>().is_err());
        assert_eq!(
            "bogus".parse::<Status>(),
            Err("Invalid status: bogus".to_string())
        );
    }

    // ---------- validate_bean_id tests ----------

    #[test]
    fn validate_bean_id_simple_numeric() {
        assert!(validate_bean_id("1").is_ok());
        assert!(validate_bean_id("42").is_ok());
        assert!(validate_bean_id("999").is_ok());
    }

    #[test]
    fn validate_bean_id_dotted() {
        assert!(validate_bean_id("3.1").is_ok());
        assert!(validate_bean_id("3.2.1").is_ok());
        assert!(validate_bean_id("1.2.3.4.5").is_ok());
    }

    #[test]
    fn validate_bean_id_with_underscores() {
        assert!(validate_bean_id("task_1").is_ok());
        assert!(validate_bean_id("my_task_v1").is_ok());
    }

    #[test]
    fn validate_bean_id_with_hyphens() {
        assert!(validate_bean_id("my-task").is_ok());
        assert!(validate_bean_id("task-v1-0").is_ok());
    }

    #[test]
    fn validate_bean_id_alphanumeric() {
        assert!(validate_bean_id("abc123def").is_ok());
        assert!(validate_bean_id("Task1").is_ok());
    }

    #[test]
    fn validate_bean_id_empty_fails() {
        assert!(validate_bean_id("").is_err());
    }

    #[test]
    fn validate_bean_id_path_traversal_fails() {
        assert!(validate_bean_id("../etc/passwd").is_err());
        assert!(validate_bean_id("..").is_err());
        assert!(validate_bean_id("foo/../bar").is_err());
        assert!(validate_bean_id("task..escape").is_err());
    }

    #[test]
    fn validate_bean_id_absolute_path_fails() {
        assert!(validate_bean_id("/etc/passwd").is_err());
    }

    #[test]
    fn validate_bean_id_spaces_fail() {
        assert!(validate_bean_id("my task").is_err());
        assert!(validate_bean_id(" 1").is_err());
        assert!(validate_bean_id("1 ").is_err());
    }

    #[test]
    fn validate_bean_id_special_chars_fail() {
        let rejected = [
            "task@home",
            "task#1",
            "task$money",
            "task%complete",
            "task&friend",
            "task*star",
            "task(paren",
            "task)close",
            "task+plus",
            "task=equals",
            "task[bracket",
            "task]close",
            "task{brace",
            "task}close",
            "task|pipe",
            "task;semicolon",
            "task:colon",
            "task\"quote",
            "task'apostrophe",
            "task<less",
            "task>greater",
            "task,comma",
            "task?question",
        ];
        for id in rejected {
            assert!(validate_bean_id(id).is_err(), "expected {:?} to be rejected", id);
        }
    }

    #[test]
    fn validate_bean_id_too_long() {
        let long_id = "a".repeat(256);
        assert!(validate_bean_id(&long_id).is_err());

        let max_id = "a".repeat(255);
        assert!(validate_bean_id(&max_id).is_ok());
    }

    // ---------- atomic_write tests ----------

    #[test]
    fn test_atomic_write_creates_file_with_correct_contents() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.yaml");

        atomic_write(&path, "hello: world\n").unwrap();

        let contents = std::fs::read_to_string(&path).unwrap();
        assert_eq!(contents, "hello: world\n");
    }

    #[test]
    fn test_atomic_write_overwrites_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.yaml");

        std::fs::write(&path, "old content").unwrap();
        atomic_write(&path, "new content").unwrap();

        let contents = std::fs::read_to_string(&path).unwrap();
        assert_eq!(contents, "new content");
    }

    #[test]
    fn test_atomic_write_no_temp_file_left_behind() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.yaml");

        atomic_write(&path, "data").unwrap();

        let entries: Vec<_> = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(|e| e.ok())
            .collect();
        assert_eq!(entries.len(), 1, "only the target file should exist");
        assert_eq!(entries[0].file_name().to_str().unwrap(), "test.yaml");
    }

    #[test]
    fn test_atomic_write_same_stem_different_extensions() {
        let dir = tempfile::tempdir().unwrap();
        let md = dir.path().join("1-task.md");
        let yaml = dir.path().join("1-task.yaml");

        atomic_write(&md, "markdown").unwrap();
        atomic_write(&yaml, "yaml").unwrap();

        assert_eq!(std::fs::read_to_string(&md).unwrap(), "markdown");
        assert_eq!(std::fs::read_to_string(&yaml).unwrap(), "yaml");
        let count = std::fs::read_dir(dir.path()).unwrap().count();
        assert_eq!(count, 2, "no temp files should remain");
    }
}