aiscript_directive/
lib.rs

1use std::collections::HashMap;
2
3use aiscript_lexer::{Scanner, TokenType};
4
5use serde_json::Value;
6
7pub use validator::Validator;
8pub mod route;
9pub mod validator;
10
11pub trait FromDirective {
12    fn from_directive(directive: Directive) -> Result<Self, String>
13    where
14        Self: Sized;
15}
16
17#[derive(Debug, Clone, Eq, PartialEq)]
18pub struct Directive {
19    pub name: String,
20    pub params: DirectiveParams,
21    pub line: u32,
22}
23
24#[derive(Debug, Clone, Eq, PartialEq)]
25pub enum DirectiveParams {
26    KeyValue(HashMap<String, Value>),
27    Array(Vec<Value>),
28    Directives(Vec<Directive>),
29}
30
31impl Directive {
32    pub fn get_arg_value(&self, name: &'static str) -> Option<&Value> {
33        if let DirectiveParams::KeyValue(kv) = &self.params {
34            kv.get(name)
35        } else {
36            None
37        }
38    }
39}
40
41pub struct DirectiveParser<'a, 'b: 'a> {
42    scanner: &'a mut Scanner<'b>,
43}
44
45impl<'a, 'b> DirectiveParser<'a, 'b> {
46    pub fn new(scanner: &'a mut Scanner<'b>) -> Self {
47        if scanner.check(TokenType::Eof) {
48            scanner.advance();
49        }
50        Self { scanner }
51    }
52
53    #[must_use]
54    pub fn parse_validators(&mut self) -> Vec<Box<dyn Validator>> {
55        self.parse_directives()
56            .into_iter()
57            .filter_map(|directive| match FromDirective::from_directive(directive) {
58                Ok(validator) => Some(validator),
59                Err(err) => {
60                    self.scanner.error(&err);
61                    None
62                }
63            })
64            .collect()
65    }
66
67    #[must_use]
68    pub fn parse_directives(&mut self) -> Vec<Directive> {
69        let mut directives = Vec::new();
70        while self.scanner.check(TokenType::At) {
71            if let Some(directive) = self.parse_directive() {
72                directives.push(directive);
73            }
74        }
75        directives
76    }
77
78    #[must_use]
79    pub fn parse_directive(&mut self) -> Option<Directive> {
80        self.scanner
81            .consume(TokenType::At, "Expected '@' at start of directive");
82
83        if self.scanner.is_at_end() {
84            self.scanner.error_at_current("Unexpected end");
85            return None;
86        }
87
88        let name_token = self.scanner.current;
89        self.scanner.advance();
90        let name = name_token.lexeme.to_owned();
91
92        let params = if self.scanner.match_token(TokenType::OpenParen) {
93            let params = self.parse_parameters()?;
94            self.scanner
95                .consume(TokenType::CloseParen, "Expect ')' after parameters.");
96            params
97        } else {
98            DirectiveParams::KeyValue(HashMap::new())
99        };
100
101        Some(Directive {
102            name,
103            params,
104            line: name_token.line,
105        })
106    }
107
108    fn parse_parameters(&mut self) -> Option<DirectiveParams> {
109        // Handle empty parentheses case first
110        if self.scanner.check(TokenType::CloseParen) {
111            return Some(DirectiveParams::KeyValue(HashMap::new()));
112        }
113
114        if self.scanner.check(TokenType::OpenBracket) {
115            // Parse array
116            let array = self.parse_array()?;
117            Some(DirectiveParams::Array(array))
118        } else if self.scanner.check(TokenType::At) {
119            // Parse one or more directives separated by commas
120            // self.scanner.advance(); // consume '@'
121            let mut directives = Vec::new();
122            loop {
123                if let Some(directive) = self.parse_directive() {
124                    directives.push(directive);
125                }
126                if !self.scanner.check(TokenType::Comma) {
127                    break;
128                }
129                self.scanner.advance(); // consume ','
130            }
131            Some(DirectiveParams::Directives(directives))
132        } else if self.scanner.check(TokenType::Identifier) {
133            // Parse key-value parameters
134            let mut params = HashMap::new();
135            while !self.scanner.check(TokenType::CloseParen) {
136                self.scanner
137                    .consume(TokenType::Identifier, "Expect parameter key.");
138                let key = self.scanner.previous.lexeme.to_owned();
139                self.scanner
140                    .consume(TokenType::Equal, "Expect '=' after parameter key.");
141                let value = self.parse_value()?;
142                params.insert(key, value);
143                if !self.scanner.check(TokenType::Comma) {
144                    break;
145                }
146                self.scanner.advance(); // consume ','
147            }
148            Some(DirectiveParams::KeyValue(params))
149        } else {
150            self.scanner.error("Expected parameters.");
151            None
152        }
153    }
154
155    fn parse_array(&mut self) -> Option<Vec<Value>> {
156        self.scanner
157            .consume(TokenType::OpenBracket, "Expect '[' before array.");
158        let mut values = Vec::new();
159
160        while !self.scanner.check(TokenType::CloseBracket) {
161            values.push(self.parse_value()?);
162
163            if self.scanner.check(TokenType::Comma) {
164                self.scanner.advance(); // consume comma
165            }
166        }
167
168        self.scanner
169            .consume(TokenType::CloseBracket, "Expect '] at the end of array.");
170        Some(values)
171    }
172
173    fn parse_value(&mut self) -> Option<Value> {
174        let token = self.scanner.current;
175        self.scanner.advance();
176        match token.kind {
177            TokenType::String => Some(Value::String(token.lexeme.to_owned())),
178            TokenType::Number => {
179                let num_str = token.lexeme;
180                // First try parsing as i64 (integer)
181                if let Ok(int_val) = num_str.parse::<i64>() {
182                    Some(Value::Number(serde_json::Number::from(int_val)))
183                } else {
184                    // If not an integer, try as f64 (float)
185                    match num_str.parse::<f64>() {
186                        Ok(float_val) => match serde_json::Number::from_f64(float_val) {
187                            Some(num) => Some(Value::Number(num)),
188                            None => {
189                                self.scanner.error("Invalid float value");
190                                None
191                            }
192                        },
193                        Err(err) => {
194                            self.scanner.error(&format!("Invalid number: {err}"));
195                            None
196                        }
197                    }
198                }
199            }
200            TokenType::True => Some(Value::Bool(true)),
201            TokenType::False => Some(Value::Bool(false)),
202            TokenType::OpenBracket => {
203                let values = self.parse_array()?;
204                Some(Value::Array(values))
205            }
206            _ => {
207                self.scanner
208                    .error(&format!("Unexpected token {:?}", token.kind));
209                None
210            }
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use aiscript_lexer::Scanner;
    use serde_json::{json, Value};
    use std::collections::HashMap;

    /// Parses `source` with a fresh scanner and returns the first directive.
    fn parse_single_directive(source: &str) -> Option<Directive> {
        let mut scanner = Scanner::new(source);
        let mut directive_parser = DirectiveParser::new(&mut scanner);
        directive_parser.parse_directive()
    }

    /// Borrows the key/value map from `params`, panicking with `msg` otherwise.
    fn key_values<'p>(params: &'p DirectiveParams, msg: &str) -> &'p HashMap<String, Value> {
        match params {
            DirectiveParams::KeyValue(map) => map,
            _ => panic!("{}", msg),
        }
    }

    /// Borrows the array elements from `params`, panicking with `msg` otherwise.
    fn array_values<'p>(params: &'p DirectiveParams, msg: &str) -> &'p [Value] {
        match params {
            DirectiveParams::Array(items) => items,
            _ => panic!("{}", msg),
        }
    }

    /// Borrows the nested directives from `params`, panicking with `msg` otherwise.
    fn nested_directives<'p>(params: &'p DirectiveParams, msg: &str) -> &'p [Directive] {
        match params {
            DirectiveParams::Directives(children) => children,
            _ => panic!("{}", msg),
        }
    }

    /// Asserts that `actual` has the same length and elements as `expected`.
    fn assert_values(actual: &[Value], expected: &[Value]) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_simple_directive() {
        let parsed = parse_single_directive("@validate").unwrap();
        assert_eq!(parsed.name, "validate");
        assert!(matches!(parsed.params, DirectiveParams::KeyValue(ref map) if map.is_empty()));
    }

    #[test]
    fn test_directive_with_array() {
        let parsed = parse_single_directive("@values([1, 2, 3])").unwrap();
        assert_eq!(parsed.name, "values");
        let items = array_values(&parsed.params, "Expected Array parameters");
        assert_values(items, &[json!(1), json!(2), json!(3)]);
    }

    #[test]
    fn test_directive_with_mixed_array() {
        let parsed = parse_single_directive(r#"@values([1, "test", true])"#).unwrap();
        assert_eq!(parsed.name, "values");
        let items = array_values(&parsed.params, "Expected Array parameters");
        assert_values(items, &[json!(1), json!("test"), json!(true)]);
    }

    #[test]
    fn test_directive_with_key_value() {
        let parsed = parse_single_directive(r#"@validate(min=1, max=10, name="test")"#).unwrap();
        assert_eq!(parsed.name, "validate");
        let args = key_values(&parsed.params, "Expected KeyValue parameters");
        assert_eq!(args.len(), 3);
        let expectations = [("min", json!(1)), ("max", json!(10)), ("name", json!("test"))];
        for (key, want) in expectations.iter() {
            assert_eq!(args.get(*key).unwrap(), want);
        }
    }

    #[test]
    fn test_directive_with_nested_directives() {
        let parsed =
            parse_single_directive("@combine(@length(min=5), @pattern(regex=\"[a-z]+\"))").unwrap();
        assert_eq!(parsed.name, "combine");
        let children = nested_directives(&parsed.params, "Expected Directives parameters");
        assert_eq!(children.len(), 2);

        let length = &children[0];
        assert_eq!(length.name, "length");
        let length_args = key_values(
            &length.params,
            "Expected KeyValue parameters for length directive",
        );
        assert_eq!(length_args.get("min").unwrap(), &json!(5));

        let pattern = &children[1];
        assert_eq!(pattern.name, "pattern");
        let pattern_args = key_values(
            &pattern.params,
            "Expected KeyValue parameters for pattern directive",
        );
        assert_eq!(pattern_args.get("regex").unwrap(), &json!("[a-z]+"));
    }

    #[test]
    fn test_directive_with_empty_array() {
        let parsed = parse_single_directive("@values([])").unwrap();
        assert_eq!(parsed.name, "values");
        let items = array_values(&parsed.params, "Expected Array parameters");
        assert_eq!(items.len(), 0);
    }

    #[test]
    fn test_directive_with_empty_key_value() {
        let parsed = parse_single_directive("@validate()").unwrap();
        assert_eq!(parsed.name, "validate");
        let args = key_values(&parsed.params, "Expected KeyValue parameters");
        assert!(args.is_empty());
    }

    #[test]
    fn test_invalid_directives() {
        // assert!(parse_single_directive("validate").is_none()); // Missing @
        let malformed = [
            "@",               // Missing name
            "@validate(",      // Unclosed parenthesis
            "@validate(min=)", // Missing value
            "@validate(=5)",   // Missing key
        ];
        for source in malformed.iter() {
            assert!(parse_single_directive(source).is_none());
        }
    }

    #[test]
    fn test_complex_nested_directives() {
        let parsed = parse_single_directive(
            r#"@group(
                @validate(min=1, max=10),
                @format([1, 2, 3]),
                @nested(@check(value=true))
            )"#,
        )
        .unwrap();

        assert_eq!(parsed.name, "group");
        let children = nested_directives(&parsed.params, "Expected Directives parameters");
        assert_eq!(children.len(), 3);

        // First child: key/value arguments.
        let validate = &children[0];
        assert_eq!(validate.name, "validate");
        let validate_args = key_values(&validate.params, "Expected KeyValue parameters for validate");
        assert_eq!(validate_args.get("min").unwrap(), &json!(1));
        assert_eq!(validate_args.get("max").unwrap(), &json!(10));

        // Second child: array argument.
        let format = &children[1];
        assert_eq!(format.name, "format");
        let format_items = array_values(&format.params, "Expected Array parameters for format");
        assert_values(format_items, &[json!(1), json!(2), json!(3)]);

        // Third child: one more level of nesting.
        let nested = &children[2];
        assert_eq!(nested.name, "nested");
        let inner = nested_directives(&nested.params, "Expected Directives parameters for nested");
        assert_eq!(inner.len(), 1);
        let check = &inner[0];
        assert_eq!(check.name, "check");
        let check_args = key_values(&check.params, "Expected KeyValue parameters for check");
        assert_eq!(check_args.get("value").unwrap(), &json!(true));
    }
}