1use std::num::{NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8};
4use std::str::FromStr;
5
6use serde::{de::Error, Deserialize};
7
/// Generates three range-restricted wrappers around one unsigned integer width.
///
/// Arguments, in order: the primitive type (`u8`, ...), its `NonZero*` counterpart, and the
/// names to give the "limited", "limited non-zero" and "bounded" wrapper types.
///
/// Every wrapper validates its value on construction through `TryFrom<$primitive>`,
/// `FromStr` and serde `Deserialize`, and serializes transparently as the bare integer.
macro_rules! uint {
    ($primitive:ident, $nz:ident, $lim:ident, $lim_nz:ident, $bounded:ident) => {
        /// An unsigned integer guaranteed to lie in `0..=MAX`.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, Hash)]
        #[repr(transparent)]
        #[serde(transparent)]
        pub struct $lim<const MAX: $primitive>($primitive);

        impl<const MAX: $primitive> $lim<MAX> {
            /// The smallest representable value (zero).
            pub const MIN: Self = Self(<$primitive>::MIN);

            /// The largest representable value (`MAX`).
            pub const MAX: Self = Self(MAX);

            /// Validates `value` against the upper bound.
            ///
            /// # Errors
            ///
            /// Returns `"value is greater than <MAX>"` when `value > MAX`.
            fn new(value: $primitive) -> Result<Self, String> {
                if value > MAX {
                    Err(format!("value is greater than {}", MAX))
                } else {
                    Ok(Self(value))
                }
            }
        }

        impl<const MAX: $primitive> FromStr for $lim<MAX> {
            type Err = String;

            /// Parses a decimal integer, then applies the bounds check.
            ///
            /// Integer parse failures are reported with the parser's own error message.
            fn from_str(src: &str) -> Result<Self, Self::Err> {
                Self::new(src.parse::<$primitive>().map_err(|e| e.to_string())?)
            }
        }

        impl<const MAX: $primitive> TryFrom<$primitive> for $lim<MAX> {
            type Error = String;

            fn try_from(value: $primitive) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl<'de, const MAX: $primitive> Deserialize<'de> for $lim<MAX> {
            /// Deserializes the underlying integer, then applies the bounds check; a failed
            /// check is reported through the deserializer's custom error.
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                Self::new(Deserialize::deserialize(deserializer)?).map_err(D::Error::custom)
            }
        }

        impl<const MAX: $primitive> From<$lim<MAX>> for $primitive {
            fn from(value: $lim<MAX>) -> Self {
                value.0
            }
        }

        /// A non-zero unsigned integer guaranteed to lie in `1..=MAX`.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, Hash)]
        #[repr(transparent)]
        #[serde(transparent)]
        pub struct $lim_nz<const MAX: $primitive>($nz);

        impl<const MAX: $primitive> $lim_nz<MAX> {
            /// The smallest representable value (one).
            pub const MIN: Self = Self($nz::MIN);

            /// The largest representable value (`MAX`).
            ///
            /// Referencing this constant with `MAX == 0` is a compile-time error.
            pub const MAX: Self = match $nz::new(MAX) {
                Some(max) => Self(max),
                // Evaluated during const evaluation: a zero bound aborts compilation with a
                // clear message instead of forging an invalid non-zero value via
                // `new_unchecked`.
                None => panic!("MAX must be non-zero"),
            };

            /// Validates `value` against the upper bound and rejects zero.
            ///
            /// # Errors
            ///
            /// Returns `"value is greater than <MAX>"` when `value > MAX`, or
            /// `"value is zero"` when `value == 0`.
            fn new(value: $primitive) -> Result<Self, String> {
                if value > MAX {
                    Err(format!("value is greater than {}", MAX))
                } else if let Some(value) = $nz::new(value) {
                    Ok(Self(value))
                } else {
                    Err("value is zero".into())
                }
            }
        }

        impl<const MAX: $primitive> FromStr for $lim_nz<MAX> {
            type Err = String;

            /// Parses a decimal integer, then applies the bounds and non-zero checks.
            fn from_str(src: &str) -> Result<Self, Self::Err> {
                Self::new(src.parse::<$primitive>().map_err(|e| e.to_string())?)
            }
        }

        impl<const MAX: $primitive> TryFrom<$primitive> for $lim_nz<MAX> {
            type Error = String;

            fn try_from(value: $primitive) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl<'de, const MAX: $primitive> Deserialize<'de> for $lim_nz<MAX> {
            /// Deserializes the underlying integer, then applies the bounds and non-zero
            /// checks; a failed check is reported through the deserializer's custom error.
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                Self::new(Deserialize::deserialize(deserializer)?).map_err(D::Error::custom)
            }
        }

        impl<const MAX: $primitive> From<$lim_nz<MAX>> for $nz {
            fn from(value: $lim_nz<MAX>) -> Self {
                value.0
            }
        }

        impl<const MAX: $primitive> From<$lim_nz<MAX>> for $primitive {
            fn from(value: $lim_nz<MAX>) -> Self {
                value.0.into()
            }
        }

        /// A non-zero unsigned integer guaranteed to lie in `MIN..=MAX`.
        ///
        /// Zero is always rejected, even when `MIN == 0`. If `MIN > MAX`, no value passes
        /// validation.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, Hash)]
        #[repr(transparent)]
        #[serde(transparent)]
        pub struct $bounded<const MIN: $primitive, const MAX: $primitive>($nz);

        impl<const MIN: $primitive, const MAX: $primitive> $bounded<MIN, MAX> {
            /// The smallest representable value (`MIN`).
            ///
            /// Referencing this constant with `MIN == 0` is a compile-time error.
            pub const MIN: Self = match $nz::new(MIN) {
                Some(min) => Self(min),
                None => panic!("MIN must be non-zero"),
            };

            /// The largest representable value (`MAX`).
            ///
            /// Referencing this constant with `MAX == 0` is a compile-time error.
            pub const MAX: Self = match $nz::new(MAX) {
                Some(max) => Self(max),
                None => panic!("MAX must be non-zero"),
            };

            /// Validates `value` against both bounds and rejects zero.
            ///
            /// # Errors
            ///
            /// Returns `"value is less than <MIN>"`, `"value is greater than <MAX>"`, or
            /// `"value is zero"` (only reachable when `MIN == 0`).
            fn new(value: $primitive) -> Result<Self, String> {
                if value < MIN {
                    Err(format!("value is less than {}", MIN))
                } else if value > MAX {
                    Err(format!("value is greater than {}", MAX))
                } else if let Some(value) = $nz::new(value) {
                    Ok(Self(value))
                } else {
                    Err("value is zero".into())
                }
            }
        }

        impl<const MIN: $primitive, const MAX: $primitive> TryFrom<$primitive>
            for $bounded<MIN, MAX>
        {
            type Error = String;

            fn try_from(value: $primitive) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl<const MIN: $primitive, const MAX: $primitive> FromStr for $bounded<MIN, MAX> {
            type Err = String;

            /// Parses a decimal integer, then applies the range and non-zero checks.
            fn from_str(src: &str) -> Result<Self, Self::Err> {
                Self::new(src.parse::<$primitive>().map_err(|e| e.to_string())?)
            }
        }

        impl<'de, const MIN: $primitive, const MAX: $primitive> Deserialize<'de>
            for $bounded<MIN, MAX>
        {
            /// Deserializes the underlying integer, then applies the range and non-zero
            /// checks; a failed check is reported through the deserializer's custom error.
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                Self::new(Deserialize::deserialize(deserializer)?).map_err(D::Error::custom)
            }
        }

        impl<const MIN: $primitive, const MAX: $primitive> From<$bounded<MIN, MAX>> for $nz {
            fn from(value: $bounded<MIN, MAX>) -> Self {
                value.0
            }
        }

        impl<const MIN: $primitive, const MAX: $primitive> From<$bounded<MIN, MAX>>
            for $primitive
        {
            fn from(value: $bounded<MIN, MAX>) -> Self {
                value.0.into()
            }
        }
    };
}
196
197uint!(u8, NonZeroU8, LimitedU8, LimitedNonZeroU8, BoundedU8);
198uint!(u16, NonZeroU16, LimitedU16, LimitedNonZeroU16, BoundedU16);
199uint!(u32, NonZeroU32, LimitedU32, LimitedNonZeroU32, BoundedU32);
200uint!(u64, NonZeroU64, LimitedU64, LimitedNonZeroU64, BoundedU64);
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a result that must succeed, panicking with the rendered error otherwise.
    fn expect_ok<T, E: std::fmt::Display>(result: Result<T, E>) -> T {
        result.unwrap_or_else(|e| panic!("failed to deserialize: {e}"))
    }

    /// Returns the rendered error of a result that must fail.
    fn expect_err<T, E: std::fmt::Display>(result: Result<T, E>) -> String {
        match result {
            Ok(_) => panic!("deserialization should fail"),
            Err(e) => e.to_string(),
        }
    }

    /// The `MIN`/`MAX` constants agree with `TryFrom<u8>` at the bounds.
    #[test]
    fn u8_min_max() {
        assert_eq!(LimitedU8::<10>::try_from(0u8), Ok(LimitedU8::<10>::MIN));
        assert_eq!(LimitedU8::<10>::try_from(10u8), Ok(LimitedU8::<10>::MAX));
        assert_eq!(LimitedNonZeroU8::<10>::try_from(1u8), Ok(LimitedNonZeroU8::<10>::MIN));
        assert_eq!(LimitedNonZeroU8::<10>::try_from(10u8), Ok(LimitedNonZeroU8::<10>::MAX));
        assert_eq!(BoundedU8::<7, 10>::try_from(7u8), Ok(BoundedU8::<7, 10>::MIN));
        assert_eq!(BoundedU8::<7, 10>::try_from(10u8), Ok(BoundedU8::<7, 10>::MAX));
    }

    /// Parsing from text applies the same bounds checks as `TryFrom`.
    #[test]
    fn u8_from_str() {
        type Lim = LimitedU8<10>;
        type LimNz = LimitedNonZeroU8<10>;
        type Bnd = BoundedU8<7, 10>;

        for (text, expected) in [
            ("0", Ok(Lim::MIN)),
            ("10", Ok(Lim::MAX)),
            ("11", Err("value is greater than 10".to_string())),
        ] {
            assert_eq!(text.parse::<Lim>(), expected);
        }

        for (text, expected) in [
            ("1", Ok(LimNz::MIN)),
            ("10", Ok(LimNz::MAX)),
            ("11", Err("value is greater than 10".to_string())),
        ] {
            assert_eq!(text.parse::<LimNz>(), expected);
        }

        for (text, expected) in [
            ("6", Err("value is less than 7".to_string())),
            ("7", Ok(Bnd::MIN)),
            ("10", Ok(Bnd::MAX)),
            ("11", Err("value is greater than 10".to_string())),
        ] {
            assert_eq!(text.parse::<Bnd>(), expected);
        }
    }

    /// Deserialization enforces the bounds for JSON numbers and form-encoded strings alike.
    #[test]
    fn deserialize_u8_from_str() {
        {
            #[derive(Deserialize, Debug)]
            struct Foo {
                bar: LimitedU8<10>,
            }

            let foo: Foo = expect_ok(serde_json::from_str(r#"{"bar": 0}"#));
            assert_eq!(foo.bar, LimitedU8::<10>::MIN);

            let err = expect_err(serde_json::from_str::<Foo>(r#"{"bar": "0"}"#));
            assert!(err.contains("invalid type: string"));

            let foo: Foo = expect_ok(serde_html_form::from_str("bar=0"));
            assert_eq!(foo.bar, LimitedU8::<10>::MIN);

            let foo: Foo = expect_ok(serde_html_form::from_str("bar=10"));
            assert_eq!(foo.bar, LimitedU8::<10>::MAX);

            assert_eq!(
                expect_err(serde_html_form::from_str::<Foo>("bar=11")),
                "value is greater than 10"
            );
        }

        {
            #[derive(Deserialize, Debug)]
            struct Foo {
                bar: LimitedNonZeroU8<10>,
            }

            assert_eq!(
                expect_err(serde_json::from_str::<Foo>(r#"{"bar": 0}"#)),
                "value is zero at line 1 column 10"
            );

            let err = expect_err(serde_json::from_str::<Foo>(r#"{"bar": "0"}"#));
            assert!(err.contains("invalid type: string"));

            assert_eq!(
                expect_err(serde_html_form::from_str::<Foo>("bar=0")),
                "value is zero"
            );

            let foo: Foo = expect_ok(serde_html_form::from_str("bar=10"));
            assert_eq!(foo.bar, LimitedNonZeroU8::<10>::MAX);

            assert_eq!(
                expect_err(serde_html_form::from_str::<Foo>("bar=11")),
                "value is greater than 10"
            );
        }

        {
            #[derive(Deserialize, Debug)]
            struct Foo {
                bar: BoundedU8<1, 10>,
            }

            assert_eq!(
                expect_err(serde_json::from_str::<Foo>(r#"{"bar": 0}"#)),
                "value is less than 1 at line 1 column 10"
            );

            let err = expect_err(serde_json::from_str::<Foo>(r#"{"bar": "0"}"#));
            assert!(err.contains("invalid type: string"));

            assert_eq!(
                expect_err(serde_html_form::from_str::<Foo>("bar=0")),
                "value is less than 1"
            );

            let foo: Foo = expect_ok(serde_html_form::from_str("bar=10"));
            assert_eq!(foo.bar, BoundedU8::<1, 10>::MAX);

            assert_eq!(
                expect_err(serde_html_form::from_str::<Foo>("bar=11")),
                "value is greater than 10"
            );
        }
    }
}