rumdl_lib/types/
heading_level.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::fmt;
3
/// Markdown heading level (1-6)
///
/// Markdown supports exactly 6 levels of headings, from # (level 1) through ###### (level 6).
///
/// The inner value is private, so code outside this module can only obtain a `HeadingLevel`
/// through [`HeadingLevel::new`] (which validates the range), through `Deserialize` (which
/// routes through `new`, so bad config values are rejected), or through `Default` (level 1).
/// The range check therefore happens at runtime, at construction time; afterwards every
/// value is known to be in range.
///
/// Ordering compares the numeric level, so level 1 sorts before level 6.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct HeadingLevel(u8);
11
12impl HeadingLevel {
13    /// Create a new heading level, validating it's in the range 1-6.
14    ///
15    /// # Errors
16    /// Returns `HeadingLevelError` if the level is not between 1 and 6 inclusive.
17    pub fn new(level: u8) -> Result<Self, HeadingLevelError> {
18        if (1..=6).contains(&level) {
19            Ok(Self(level))
20        } else {
21            Err(HeadingLevelError(level))
22        }
23    }
24
25    /// Get the underlying heading level value (1-6).
26    pub fn get(self) -> u8 {
27        self.0
28    }
29
30    /// Convert to usize for compatibility with existing code.
31    pub fn as_usize(self) -> usize {
32        self.0 as usize
33    }
34}
35
/// Error type for invalid heading levels.
///
/// Returned by [`HeadingLevel::new`] (and therefore surfaced during deserialization) when
/// the requested level is outside 1-6. It carries the rejected value, available via
/// [`HeadingLevelError::value`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeadingLevelError(u8);

impl HeadingLevelError {
    /// The out-of-range level that was rejected.
    #[must_use]
    pub const fn value(self) -> u8 {
        self.0
    }
}

impl fmt::Display for HeadingLevelError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Heading level must be between 1 and 6, got {}. \
             Markdown supports only 6 heading levels (# through ######).",
            self.0
        )
    }
}

impl std::error::Error for HeadingLevelError {}
52
53impl<'de> Deserialize<'de> for HeadingLevel {
54    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
55    where
56        D: Deserializer<'de>,
57    {
58        let level = u8::deserialize(deserializer)?;
59        HeadingLevel::new(level).map_err(serde::de::Error::custom)
60    }
61}
62
63impl Serialize for HeadingLevel {
64    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
65    where
66        S: Serializer,
67    {
68        self.0.serialize(serializer)
69    }
70}
71
72impl Default for HeadingLevel {
73    fn default() -> Self {
74        Self(1) // Safe: 1 is always valid
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config wrapper used to exercise serde round-trips through TOML.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestConfig {
        level: HeadingLevel,
    }

    #[test]
    fn test_valid_heading_levels() {
        for level in 1..=6 {
            let h = HeadingLevel::new(level).unwrap();
            assert_eq!(h.get(), level);
            assert_eq!(h.as_usize(), usize::from(level));
        }
    }

    #[test]
    fn test_invalid_heading_levels() {
        for level in [0, 7, 8, 10, 255] {
            assert!(
                HeadingLevel::new(level).is_err(),
                "level {level} should be rejected"
            );
        }
    }

    #[test]
    fn test_error_message() {
        let err = HeadingLevel::new(7).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Heading level must be between 1 and 6, got 7. \
             Markdown supports only 6 heading levels (# through ######)."
        );
    }

    #[test]
    fn test_default() {
        assert_eq!(HeadingLevel::default().get(), 1);
    }

    #[test]
    fn test_ordering_follows_level() {
        let levels: Vec<HeadingLevel> = (1..=6).map(|l| HeadingLevel::new(l).unwrap()).collect();
        assert!(levels.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_roundtrip() {
        // HeadingLevel must serialize and deserialize within a struct
        let config = TestConfig {
            level: HeadingLevel::new(3).unwrap(),
        };
        let serialized = toml::to_string(&config).unwrap();
        // Serialized as a bare integer, not as a nested table or wrapper.
        assert_eq!(serialized.trim(), "level = 3");
        let deserialized: TestConfig = toml::from_str(&serialized).unwrap();
        assert_eq!(deserialized.level.get(), 3);
    }

    #[test]
    fn test_validation_error() {
        let err = toml::from_str::<TestConfig>("level = 10")
            .unwrap_err()
            .to_string();
        // Check both parts so a truncated or replaced message is caught.
        assert!(err.contains("must be between 1 and 6"), "unexpected error: {err}");
        assert!(err.contains("got 10"), "unexpected error: {err}");

        // Also test that valid config deserializes correctly
        let config: TestConfig = toml::from_str("level = 3").unwrap();
        assert_eq!(config.level.get(), 3);
    }

    #[test]
    fn test_rejects_values_outside_u8() {
        for toml_str in ["level = 256", "level = -1"] {
            assert!(
                toml::from_str::<TestConfig>(toml_str).is_err(),
                "{toml_str} should be rejected"
            );
        }
    }
}