Skip to main content

entrenar/yaml_mode/manifest/
shorthand.rs

1//! Human-readable value shorthand for YAML configs.
2//!
3//! Supports SI suffix notation for integer fields:
4//! - `32K` → 32,768 (kibi, powers of 2)
5//! - `1M` → 1,048,576 (mebi)
6//! - `10B` → 10,000,000,000 (giga, base-10 for token counts)
7//!
8//! Also supports:
9//! - Plain integers: `1024`
10//! - Underscore notation: `32_768` (YAML native)
11//! - Scientific notation strings: `"1e6"` → 1,000,000
12
13use serde::{Deserialize, Deserializer};
14
/// Parse a human-readable size string into a `usize`.
///
/// Leading and trailing whitespace is ignored. Accepted forms:
/// - Plain numbers: "1024", "32768"
/// - Binary suffixes: "32K" (32*1024), "1M" (1*1024²), "1G" (1*1024³)
/// - Decimal suffixes: "10B" (10*10⁹), "1T" (1*10¹²)
/// - Scientific notation: "1e6", "3.2e4"
/// - Fractional values, with or without a suffix: "1.5K" → 1536
///
/// Suffixes are matched case-insensitively. K/M/G use binary (powers of 2)
/// since they're used for model dimensions. B/T use decimal since they're used
/// for token/parameter counts where "10B" means 10 billion.
///
/// Values that go through floating point (scientific notation or a decimal
/// point) are truncated toward zero after scaling.
///
/// # Errors
///
/// Returns a descriptive message when the input is empty, is not a valid
/// number, is negative or non-finite, or does not fit in a `usize`.
pub fn parse_human_usize(s: &str) -> Result<usize, String> {
    // Checked f64 -> usize conversion. A bare `as usize` cast saturates
    // silently (negative/NaN -> 0, huge/inf -> usize::MAX), which would hide
    // config mistakes that the integer path below reports as errors.
    fn float_to_usize(value: f64, input: &str) -> Result<usize, String> {
        // 2^BITS is exactly representable as f64, whereas `usize::MAX as f64`
        // rounds up on 64-bit targets; comparing against it keeps the bound exact.
        let limit = 2f64.powi(usize::BITS as i32);
        if !value.is_finite() {
            return Err(format!("non-finite value '{input}'"));
        }
        if value < 0.0 {
            return Err(format!("negative value '{input}'"));
        }
        if value >= limit {
            return Err(format!("overflow: '{input}' does not fit in usize"));
        }
        // In range and non-negative, so the cast only drops the fractional part.
        Ok(value as usize)
    }

    let s = s.trim();
    if s.is_empty() {
        return Err("empty string".into());
    }

    // Try scientific notation first (e.g., "1e6", "3.2e4")
    if s.contains('e') || s.contains('E') {
        let value: f64 = s
            .parse()
            .map_err(|e| format!("invalid scientific notation '{s}': {e}"))?;
        return float_to_usize(value, s);
    }

    // Check for SI suffix. The matched byte is ASCII, so slicing off the last
    // byte always lands on a char boundary.
    let (num_str, multiplier) = match s.as_bytes().last() {
        Some(b'K' | b'k') => (&s[..s.len() - 1], 1024_usize),
        Some(b'M' | b'm') => (&s[..s.len() - 1], 1024 * 1024),
        Some(b'G' | b'g') => (&s[..s.len() - 1], 1024 * 1024 * 1024),
        Some(b'B' | b'b') => (&s[..s.len() - 1], 1_000_000_000_usize),
        Some(b'T' | b't') => (&s[..s.len() - 1], 1_000_000_000_000_usize),
        _ => (s, 1),
    };

    // Parse the numeric part (allow float for "1.5K" etc.)
    if num_str.contains('.') {
        let v: f64 = num_str
            .parse()
            .map_err(|e| format!("invalid number '{num_str}': {e}"))?;
        float_to_usize(v * multiplier as f64, s)
    } else {
        let v: usize = num_str
            .parse()
            .map_err(|e| format!("invalid number '{num_str}': {e}"))?;
        v.checked_mul(multiplier)
            .ok_or_else(|| format!("overflow: {v} * {multiplier}"))
    }
}
58
59/// Deserialize an `Option<usize>` that accepts both numbers and human-readable strings.
60///
61/// # Examples (YAML)
62/// ```yaml
63/// vocab_size: 32K       # → Some(32768)
64/// vocab_size: 32768     # → Some(32768)
65/// vocab_size: "1e5"     # → Some(100000)
66/// ```
67pub fn deserialize_human_usize_opt<'de, D>(deserializer: D) -> Result<Option<usize>, D::Error>
68where
69    D: Deserializer<'de>,
70{
71    #[derive(Deserialize)]
72    #[serde(untagged)]
73    enum NumOrStr {
74        Num(usize),
75        Float(f64),
76        Str(String),
77    }
78
79    let opt: Option<NumOrStr> = Option::deserialize(deserializer)?;
80    match opt {
81        None => Ok(None),
82        Some(NumOrStr::Num(n)) => Ok(Some(n)),
83        Some(NumOrStr::Float(f)) => Ok(Some(f as usize)),
84        Some(NumOrStr::Str(s)) => parse_human_usize(&s).map(Some).map_err(serde::de::Error::custom),
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture with an optional `vocab_size` routed through the custom deserializer.
    #[derive(Deserialize)]
    struct VocabConfig {
        #[serde(
            default,
            skip_serializing_if = "Option::is_none",
            deserialize_with = "deserialize_human_usize_opt"
        )]
        vocab_size: Option<usize>,
    }

    /// Fixture with an optional `count` routed through the custom deserializer.
    #[derive(Deserialize)]
    struct CountConfig {
        #[serde(
            default,
            skip_serializing_if = "Option::is_none",
            deserialize_with = "deserialize_human_usize_opt"
        )]
        count: Option<usize>,
    }

    /// Parse `input`, panicking if the parser rejects it.
    fn parsed(input: &str) -> usize {
        parse_human_usize(input).expect("valid")
    }

    /// Deserialize a YAML document into `T`, panicking on failure.
    fn from_yaml<T: serde::de::DeserializeOwned>(yaml: &str) -> T {
        serde_yaml::from_str(yaml).expect("should parse")
    }

    /// Assert that every `(input, expected)` pair parses to the expected value.
    fn assert_all_parse(cases: &[(&str, usize)]) {
        for &(input, expected) in cases {
            assert_eq!(parsed(input), expected);
        }
    }

    #[test]
    fn test_plain_numbers() {
        assert_all_parse(&[("1024", 1024), ("0", 0), ("32768", 32768)]);
    }

    #[test]
    fn test_si_suffix_binary() {
        assert_all_parse(&[
            ("32K", 32 * 1024),
            ("1M", 1024 * 1024),
            ("1G", 1024 * 1024 * 1024),
        ]);
    }

    #[test]
    fn test_si_suffix_lowercase() {
        assert_all_parse(&[("32k", 32 * 1024), ("1m", 1024 * 1024)]);
    }

    #[test]
    fn test_si_suffix_decimal() {
        assert_all_parse(&[("10B", 10_000_000_000), ("1T", 1_000_000_000_000)]);
    }

    #[test]
    fn test_scientific_notation() {
        assert_all_parse(&[("1e6", 1_000_000), ("3.2e4", 32000), ("1E5", 100_000)]);
    }

    #[test]
    fn test_fractional_suffix() {
        assert_all_parse(&[
            ("1.5K", 1536),    // 1.5 * 1024
            ("0.5M", 524_288), // 0.5 * 1M
        ]);
    }

    #[test]
    fn test_empty_string_errors() {
        assert!(parse_human_usize("").is_err());
    }

    #[test]
    fn test_invalid_string_errors() {
        for input in ["abc", "K"] {
            assert!(parse_human_usize(input).is_err());
        }
    }

    #[test]
    fn test_whitespace_trimmed() {
        assert_eq!(parsed("  32K  "), 32 * 1024);
    }

    #[test]
    fn test_serde_deserialize_number() {
        let config: VocabConfig = from_yaml("vocab_size: 32768");
        assert_eq!(config.vocab_size, Some(32768));
    }

    #[test]
    fn test_serde_deserialize_string_suffix() {
        let config: VocabConfig = from_yaml("vocab_size: \"32K\"");
        assert_eq!(config.vocab_size, Some(32 * 1024));
    }

    #[test]
    fn test_serde_deserialize_none() {
        // The field is absent from the document, so `default` supplies `None`.
        let config: VocabConfig = from_yaml("other: 123");
        assert_eq!(config.vocab_size, None);
    }

    #[test]
    fn test_serde_deserialize_scientific() {
        let config: CountConfig = from_yaml("count: \"1e6\"");
        assert_eq!(config.count, Some(1_000_000));
    }
}