// usage/spec/helpers.rs

1use indexmap::IndexMap;
2use kdl::{KdlEntry, KdlEntryFormat, KdlNode, KdlValue};
3use miette::SourceSpan;
4use std::fmt::Debug;
5use std::ops::RangeBounds;
6
7use crate::error::UsageErr;
8use crate::spec::context::ParsingContext;
9
/// Compute the number of `#` characters needed to delimit `value` as a KDL raw
/// multiline string (`#"""` … `"""#`).
///
/// The closing delimiter is `"""` followed by `n` `#` characters, so `n` must be
/// strictly greater than the number of `#` that follow any `"""` inside the value.
///
/// Every `"""` that is followed by at least one `#` necessarily ends at the last
/// quote of a maximal run of `"` characters. So it is enough to measure the hashes
/// after each run of three or more quotes. Scanning whole runs (rather than
/// non-overlapping `"""` matches) matters: in `""""#` the terminating `"""#`
/// starts at the *second* quote.
///
/// Returns at least 1.
fn raw_multiline_hash_count(value: &str) -> usize {
    // `"` and `#` are ASCII, and UTF-8 continuation bytes never equal them,
    // so a byte scan is safe on any `&str`.
    let bytes = value.as_bytes();
    let mut max_count = 0;
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] != b'"' {
            i += 1;
            continue;
        }
        // Consume the full run of consecutive quotes starting here.
        let run_start = i;
        while i < bytes.len() && bytes[i] == b'"' {
            i += 1;
        }
        if i - run_start >= 3 {
            let count = bytes[i..].iter().take_while(|&&b| b == b'#').count();
            max_count = max_count.max(count);
        }
    }
    max_count + 1
}
23
24/// Create a KdlEntry for a string value, using KDL raw multiline string syntax (`#"""..."""#`)
25/// when the value contains newlines. The number of `#` characters is automatically determined
26/// to ensure the value can be embedded safely.
27pub(crate) fn string_entry(key: Option<&str>, value: &str) -> KdlEntry {
28    let mut entry = match key {
29        Some(k) => KdlEntry::new_prop(k, KdlValue::String(value.to_string())),
30        None => KdlEntry::new(KdlValue::String(value.to_string())),
31    };
32    if value.contains('\n') {
33        let n = raw_multiline_hash_count(value);
34        let hashes = "#".repeat(n);
35        let repr = format!("{hashes}\"\"\"\n{value}\n\"\"\"{hashes}");
36        entry.set_format(KdlEntryFormat {
37            value_repr: repr,
38            leading: " ".into(),
39            trailing: "".into(),
40            after_ty: "".into(),
41            before_ty_name: "".into(),
42            after_ty_name: "".into(),
43            after_key: "".into(),
44            after_eq: "".into(),
45            autoformat_keep: true,
46        });
47    }
48    entry
49}
50
51#[derive(Debug)]
52pub struct NodeHelper<'a> {
53    pub(crate) node: &'a KdlNode,
54    pub(crate) ctx: &'a ParsingContext,
55}
56
57impl<'a> NodeHelper<'a> {
58    pub(crate) fn new(ctx: &'a ParsingContext, node: &'a KdlNode) -> Self {
59        Self { node, ctx }
60    }
61
62    pub(crate) fn name(&self) -> &str {
63        self.node.name().value()
64    }
65    pub(crate) fn span(&self) -> SourceSpan {
66        (self.node.span().offset(), self.node.span().len()).into()
67    }
68    pub(crate) fn ensure_arg_len<R>(&self, range: R) -> Result<&Self, UsageErr>
69    where
70        R: RangeBounds<usize> + Debug,
71    {
72        let count = self.args().count();
73        if !range.contains(&count) {
74            let ctx = self.ctx;
75            let span = self.span();
76            bail_parse!(ctx, span, "expected {range:?} arguments, got {count}",)
77        }
78        Ok(self)
79    }
80    pub(crate) fn get(&self, key: &str) -> Option<ParseEntry<'_>> {
81        self.node.entry(key).map(|e| ParseEntry::new(self.ctx, e))
82    }
83    pub(crate) fn arg(&self, i: usize) -> Result<ParseEntry<'_>, UsageErr> {
84        if let Some(entry) = self.args().nth(i) {
85            return Ok(entry);
86        }
87        bail_parse!(self.ctx, self.span(), "missing argument")
88    }
89    pub(crate) fn args(&self) -> impl Iterator<Item = ParseEntry<'_>> + '_ {
90        self.node
91            .entries()
92            .iter()
93            .filter(|e| e.name().is_none())
94            .map(|e| ParseEntry::new(self.ctx, e))
95    }
96    pub(crate) fn props(&self) -> IndexMap<&str, ParseEntry<'_>> {
97        self.node
98            .entries()
99            .iter()
100            .filter_map(|e| {
101                e.name()
102                    .map(|key| (key.value(), ParseEntry::new(self.ctx, e)))
103            })
104            .collect()
105    }
106    pub(crate) fn children(&self) -> Vec<Self> {
107        self.node
108            .children()
109            .map(|c| {
110                c.nodes()
111                    .iter()
112                    .map(|n| NodeHelper::new(self.ctx, n))
113                    .collect()
114            })
115            .unwrap_or_default()
116    }
117}
118
119#[derive(Debug)]
120pub(crate) struct ParseEntry<'a> {
121    pub(crate) ctx: &'a ParsingContext,
122    pub(crate) entry: &'a KdlEntry,
123    pub(crate) value: &'a KdlValue,
124}
125
126impl<'a> ParseEntry<'a> {
127    fn new(ctx: &'a ParsingContext, entry: &'a KdlEntry) -> Self {
128        Self {
129            ctx,
130            entry,
131            value: entry.value(),
132        }
133    }
134
135    fn span(&self) -> SourceSpan {
136        (self.entry.span().offset(), self.entry.span().len()).into()
137    }
138}
139
140impl ParseEntry<'_> {
141    pub fn ensure_usize(&self) -> Result<usize, UsageErr> {
142        match self.value.as_integer() {
143            Some(i) => Ok(i as usize),
144            None => bail_parse!(self.ctx, self.span(), "expected usize"),
145        }
146    }
147    #[allow(dead_code)]
148    pub fn ensure_f64(&self) -> Result<f64, UsageErr> {
149        match self.value.as_float() {
150            Some(f) => Ok(f),
151            None => bail_parse!(self.ctx, self.span(), "expected float"),
152        }
153    }
154    pub fn ensure_bool(&self) -> Result<bool, UsageErr> {
155        match self.value.as_bool() {
156            Some(b) => Ok(b),
157            None => bail_parse!(self.ctx, self.span(), "expected bool"),
158        }
159    }
160    pub fn ensure_string(&self) -> Result<String, UsageErr> {
161        match self.value.as_string() {
162            Some(s) => Ok(s.to_string()),
163            None => bail_parse!(self.ctx, self.span(), "expected string"),
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use kdl::KdlDocument;
    use std::path::Path;

    /// Parse `input` into a KDL document, together with a parsing context for it.
    fn parse_node(input: &str) -> (ParsingContext, KdlDocument) {
        let ctx = ParsingContext::new(Path::new("test.kdl"), input);
        let doc: KdlDocument = input.parse().unwrap();
        (ctx, doc)
    }

    /// Parse `input` and run `check` against a `NodeHelper` for its first node.
    fn with_first_node(input: &str, check: impl FnOnce(&NodeHelper<'_>)) {
        let (ctx, doc) = parse_node(input);
        let node = doc.nodes().first().unwrap();
        check(&NodeHelper::new(&ctx, node));
    }

    #[test]
    fn test_node_helper_name() {
        with_first_node("test_node \"arg1\"", |helper| {
            assert_eq!(helper.name(), "test_node");
        });
    }

    #[test]
    fn test_node_helper_arg() {
        with_first_node("node \"first\" \"second\"", |helper| {
            let first = helper.arg(0).unwrap().ensure_string().unwrap();
            let second = helper.arg(1).unwrap().ensure_string().unwrap();
            assert_eq!(first, "first");
            assert_eq!(second, "second");
        });
    }

    #[test]
    fn test_node_helper_args_count() {
        with_first_node("node \"a\" \"b\" \"c\"", |helper| {
            assert_eq!(helper.args().count(), 3);
        });
    }

    #[test]
    fn test_node_helper_props() {
        with_first_node("node key1=\"value1\" key2=\"value2\"", |helper| {
            let props = helper.props();
            assert_eq!(props.len(), 2);
            assert_eq!(props["key1"].ensure_string().unwrap(), "value1");
            assert_eq!(props["key2"].ensure_string().unwrap(), "value2");
        });
    }

    #[test]
    fn test_node_helper_get() {
        with_first_node("node name=\"test\"", |helper| {
            assert!(helper.get("name").is_some());
            assert!(helper.get("nonexistent").is_none());
        });
    }

    #[test]
    fn test_node_helper_children() {
        with_first_node("parent { child1; child2 }", |helper| {
            let children = helper.children();
            assert_eq!(children.len(), 2);
            assert_eq!(children[0].name(), "child1");
            assert_eq!(children[1].name(), "child2");
        });
    }

    #[test]
    fn test_node_helper_ensure_arg_len_valid() {
        with_first_node("node \"a\" \"b\"", |helper| {
            assert!(helper.ensure_arg_len(2..=2).is_ok());
            assert!(helper.ensure_arg_len(1..=3).is_ok());
            assert!(helper.ensure_arg_len(0..).is_ok());
        });
    }

    #[test]
    fn test_node_helper_ensure_arg_len_invalid() {
        with_first_node("node \"a\"", |helper| {
            assert!(helper.ensure_arg_len(2..=2).is_err());
        });
    }

    #[test]
    fn test_parse_entry_ensure_usize() {
        with_first_node("node 42", |helper| {
            assert_eq!(helper.arg(0).unwrap().ensure_usize().unwrap(), 42);
        });
    }

    #[test]
    fn test_parse_entry_ensure_bool() {
        with_first_node("node #true", |helper| {
            assert!(helper.arg(0).unwrap().ensure_bool().unwrap());
        });
    }

    #[test]
    fn test_parse_entry_ensure_string() {
        with_first_node("node \"hello\"", |helper| {
            assert_eq!(helper.arg(0).unwrap().ensure_string().unwrap(), "hello");
        });
    }

    #[test]
    fn test_parse_entry_type_mismatch() {
        with_first_node("node \"not_a_number\"", |helper| {
            assert!(helper.arg(0).unwrap().ensure_usize().is_err());
        });
    }
}