rumdl_lib/types/
heading_level.rs1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2use std::fmt;
3
/// A Markdown heading level, validated to lie in `1..=6` (`#` through `######`).
///
/// The wrapped `u8` is private, so code outside this module can only obtain a
/// `HeadingLevel` through the validating constructor, deserialization, or
/// `Default`. Comparisons and ordering follow the numeric level.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct HeadingLevel(u8);
11
12impl HeadingLevel {
13 pub fn new(level: u8) -> Result<Self, HeadingLevelError> {
18 if (1..=6).contains(&level) {
19 Ok(Self(level))
20 } else {
21 Err(HeadingLevelError(level))
22 }
23 }
24
25 pub fn get(self) -> u8 {
27 self.0
28 }
29
30 pub fn as_usize(self) -> usize {
32 self.0 as usize
33 }
34}
35
/// Error returned when a heading level outside `1..=6` is supplied.
///
/// Holds the rejected value so it can be reported back to the user.
/// Equality and hashing are derived so callers can compare whole
/// `Result<HeadingLevel, HeadingLevelError>` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HeadingLevelError(u8);

impl HeadingLevelError {
    /// Returns the out-of-range value that was rejected.
    #[must_use]
    pub const fn value(self) -> u8 {
        self.0
    }
}
39
40impl fmt::Display for HeadingLevelError {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 write!(
43 f,
44 "Heading level must be between 1 and 6, got {}. \
45 Markdown supports only 6 heading levels (# through ######).",
46 self.0
47 )
48 }
49}
50
51impl std::error::Error for HeadingLevelError {}
52
53impl<'de> Deserialize<'de> for HeadingLevel {
54 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
55 where
56 D: Deserializer<'de>,
57 {
58 let level = u8::deserialize(deserializer)?;
59 HeadingLevel::new(level).map_err(serde::de::Error::custom)
60 }
61}
62
63impl Serialize for HeadingLevel {
64 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
65 where
66 S: Serializer,
67 {
68 self.0.serialize(serializer)
69 }
70}
71
72impl Default for HeadingLevel {
73 fn default() -> Self {
74 Self(1) }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config wrapper shared by the serde tests.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestConfig {
        level: HeadingLevel,
    }

    #[test]
    fn test_valid_heading_levels() {
        for level in 1..=6 {
            let h = HeadingLevel::new(level).unwrap();
            assert_eq!(h.get(), level);
            assert_eq!(h.as_usize(), level as usize);
        }
    }

    #[test]
    fn test_invalid_heading_levels() {
        for level in [0, 7, 8, 10, 255] {
            assert!(HeadingLevel::new(level).is_err());
        }
    }

    #[test]
    fn test_error_message_names_rejected_level() {
        let msg = HeadingLevel::new(7).unwrap_err().to_string();
        assert!(msg.contains("must be between 1 and 6"));
        assert!(msg.contains("got 7"));
    }

    #[test]
    fn test_ordering_follows_level() {
        let h1 = HeadingLevel::new(1).unwrap();
        let h6 = HeadingLevel::new(6).unwrap();
        assert!(h1 < h6);
    }

    #[test]
    fn test_default() {
        assert_eq!(HeadingLevel::default().get(), 1);
    }

    #[test]
    fn test_roundtrip() {
        // Every valid level must survive a serialize/deserialize cycle.
        for level in 1..=6 {
            let config = TestConfig {
                level: HeadingLevel::new(level).unwrap(),
            };
            let serialized = toml::to_string(&config).unwrap();
            let deserialized: TestConfig = toml::from_str(&serialized).unwrap();
            assert_eq!(deserialized.level, config.level);
        }
    }

    #[test]
    fn test_validation_error() {
        let result: Result<TestConfig, _> = toml::from_str("level = 10");
        let err = result.unwrap_err().to_string();
        // The domain-specific message must reach the user, including the bad value.
        assert!(err.contains("must be between 1 and 6"));
        assert!(err.contains("got 10"));

        let config: TestConfig = toml::from_str("level = 3").unwrap();
        assert_eq!(config.level.get(), 3);
    }

    #[test]
    fn test_out_of_range_integers_rejected() {
        // Covers values inside the u8 range (0, 7) as well as ones that do not
        // fit in a u8 at all (300, -1).
        for toml_str in ["level = 0", "level = 7", "level = 300", "level = -1"] {
            assert!(
                toml::from_str::<TestConfig>(toml_str).is_err(),
                "{toml_str} should be rejected"
            );
        }
    }
}