entrenar/yaml_mode/manifest/
shorthand.rs1use serde::{Deserialize, Deserializer};
14
/// Parses a human-friendly numeric string into a `usize`.
///
/// Surrounding whitespace is ignored. Accepted forms:
///
/// - plain integers, e.g. `"1024"`;
/// - scientific notation (any input containing `e` or `E` is parsed as a
///   float), e.g. `"1e6"`, `"3.2e4"`;
/// - a trailing one-letter suffix, case-insensitive: `K`, `M`, `G` multiply by
///   1024, 1024², 1024³; `B` and `T` multiply by 1_000_000_000 and
///   1_000_000_000_000;
/// - a fractional mantissa (optionally with a suffix), e.g. `"1.5K"` → `1536`.
///
/// Float-derived results that are in range are truncated toward zero.
///
/// # Errors
///
/// Returns a descriptive message when the input is empty, is not a valid
/// number, or produces a value that is negative, non-finite, or too large to
/// be represented as a `usize`.
pub fn parse_human_usize(s: &str) -> Result<usize, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty string".into());
    }

    // Anything with an exponent marker is treated as scientific notation and
    // never carries a suffix.
    if s.contains('e') || s.contains('E') {
        let v: f64 = s.parse().map_err(|e| format!("invalid scientific notation '{s}': {e}"))?;
        return float_to_usize(v).ok_or_else(|| format!("value out of range for usize: '{s}'"));
    }

    // The matched suffix is a single ASCII byte, so slicing one byte off the
    // end always lands on a UTF-8 character boundary.
    let (num_str, multiplier) = match s.as_bytes().last() {
        Some(b'K' | b'k') => (&s[..s.len() - 1], 1024_usize),
        Some(b'M' | b'm') => (&s[..s.len() - 1], 1024 * 1024),
        Some(b'G' | b'g') => (&s[..s.len() - 1], 1024 * 1024 * 1024),
        Some(b'B' | b'b') => (&s[..s.len() - 1], 1_000_000_000_usize),
        Some(b'T' | b't') => (&s[..s.len() - 1], 1_000_000_000_000_usize),
        _ => (s, 1),
    };

    if num_str.contains('.') {
        let v: f64 = num_str.parse().map_err(|e| format!("invalid number '{num_str}': {e}"))?;
        // A bare `as usize` would silently clamp negatives to 0 and huge
        // values to `usize::MAX`; report those as errors like the integer path.
        float_to_usize(v * multiplier as f64)
            .ok_or_else(|| format!("value out of range for usize: '{s}'"))
    } else {
        let v: usize = num_str.parse().map_err(|e| format!("invalid number '{num_str}': {e}"))?;
        v.checked_mul(multiplier).ok_or_else(|| format!("overflow: {v} * {multiplier}"))
    }
}

/// Converts a float to `usize`, truncating toward zero.
///
/// Returns `None` for NaN, infinities, negative values, and values at or above
/// `usize::MAX + 1`, all of which an `as` cast would silently saturate.
fn float_to_usize(v: f64) -> Option<usize> {
    // `usize::MAX as f64` may round up to the next power of two; adding 1.0
    // yields an exclusive upper bound that is exact on 32- and 64-bit targets.
    let upper_exclusive = (usize::MAX as f64) + 1.0;
    (v.is_finite() && v >= 0.0 && v < upper_exclusive).then(|| v as usize)
}
58
59pub fn deserialize_human_usize_opt<'de, D>(deserializer: D) -> Result<Option<usize>, D::Error>
68where
69 D: Deserializer<'de>,
70{
71 #[derive(Deserialize)]
72 #[serde(untagged)]
73 enum NumOrStr {
74 Num(usize),
75 Float(f64),
76 Str(String),
77 }
78
79 let opt: Option<NumOrStr> = Option::deserialize(deserializer)?;
80 match opt {
81 None => Ok(None),
82 Some(NumOrStr::Num(n)) => Ok(Some(n)),
83 Some(NumOrStr::Float(f)) => Ok(Some(f as usize)),
84 Some(NumOrStr::Str(s)) => parse_human_usize(&s).map(Some).map_err(serde::de::Error::custom),
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture whose optional `vocab_size` goes through the shorthand deserializer.
    #[derive(Deserialize)]
    struct VocabConfig {
        #[serde(
            default,
            skip_serializing_if = "Option::is_none",
            deserialize_with = "deserialize_human_usize_opt"
        )]
        vocab_size: Option<usize>,
    }

    /// Fixture whose optional `count` goes through the shorthand deserializer.
    #[derive(Deserialize)]
    struct CountConfig {
        #[serde(
            default,
            skip_serializing_if = "Option::is_none",
            deserialize_with = "deserialize_human_usize_opt"
        )]
        count: Option<usize>,
    }

    /// Checks that each `(input, expected)` pair parses to the expected value.
    fn assert_parses(cases: &[(&str, usize)]) {
        for &(input, expected) in cases {
            assert_eq!(parse_human_usize(input).expect("valid"), expected);
        }
    }

    /// Checks that every input is rejected by the parser.
    fn assert_rejected(inputs: &[&str]) {
        for &input in inputs {
            assert!(parse_human_usize(input).is_err());
        }
    }

    /// Deserializes a `VocabConfig` from YAML and returns its `vocab_size`.
    fn vocab_size_from(yaml: &str) -> Option<usize> {
        let config: VocabConfig = serde_yaml::from_str(yaml).expect("should parse");
        config.vocab_size
    }

    #[test]
    fn test_plain_numbers() {
        assert_parses(&[("1024", 1024), ("0", 0), ("32768", 32768)]);
    }

    #[test]
    fn test_si_suffix_binary() {
        assert_parses(&[
            ("32K", 32 * 1024),
            ("1M", 1024 * 1024),
            ("1G", 1024 * 1024 * 1024),
        ]);
    }

    #[test]
    fn test_si_suffix_lowercase() {
        assert_parses(&[("32k", 32 * 1024), ("1m", 1024 * 1024)]);
    }

    #[test]
    fn test_si_suffix_decimal() {
        assert_parses(&[("10B", 10_000_000_000), ("1T", 1_000_000_000_000)]);
    }

    #[test]
    fn test_scientific_notation() {
        assert_parses(&[("1e6", 1_000_000), ("3.2e4", 32000), ("1E5", 100_000)]);
    }

    #[test]
    fn test_fractional_suffix() {
        assert_parses(&[("1.5K", 1536), ("0.5M", 524_288)]);
    }

    #[test]
    fn test_empty_string_errors() {
        assert_rejected(&[""]);
    }

    #[test]
    fn test_invalid_string_errors() {
        assert_rejected(&["abc", "K"]);
    }

    #[test]
    fn test_whitespace_trimmed() {
        assert_parses(&[(" 32K ", 32 * 1024)]);
    }

    #[test]
    fn test_serde_deserialize_number() {
        assert_eq!(vocab_size_from("vocab_size: 32768"), Some(32768));
    }

    #[test]
    fn test_serde_deserialize_string_suffix() {
        assert_eq!(vocab_size_from("vocab_size: \"32K\""), Some(32 * 1024));
    }

    #[test]
    fn test_serde_deserialize_none() {
        // The key is absent, so the field falls back to its default.
        assert_eq!(vocab_size_from("other: 123"), None);
    }

    #[test]
    fn test_serde_deserialize_scientific() {
        let yaml = "count: \"1e6\"";
        let config: CountConfig = serde_yaml::from_str(yaml).expect("should parse");
        assert_eq!(config.count, Some(1_000_000));
    }
}