1use crate::ast::*;
4use nom::{
5 branch::alt,
6 bytes::complete::{tag, take_while, take_while1},
7 character::complete::{char, digit1, multispace0, multispace1},
8 combinator::{all_consuming, map, map_res, opt, recognize, value},
9 multi::{many0, separated_list0},
10 sequence::{delimited, pair, preceded, terminated},
11 IResult,
12};
13use std::path::Path;
14
15pub fn parse_file(path: impl AsRef<Path>) -> Result<Protocol, String> {
17 let content =
18 std::fs::read_to_string(path.as_ref()).map_err(|e| format!("failed to read file: {}", e))?;
19 parse_protocol(&content)
20}
21
22pub fn parse_protocol(input: &str) -> Result<Protocol, String> {
24 let input = remove_comments(input);
26
27 let result = all_consuming(protocol_parser)(&input);
28 match result {
29 Ok((_, protocol)) => Ok(protocol),
30 Err(e) => Err(format!("parse error: {:?}", e)),
31 }
32}
33
/// Strips comments and preprocessor-style lines from protocol source text.
///
/// - `/* ... */` is replaced by a single space so it still separates adjacent
///   tokens. Newlines inside the comment are kept, so the output has the same
///   number of lines as the input. An unterminated block comment swallows the
///   rest of the input and still leaves its separating space.
/// - `// ...` is removed up to (not including) the end of the line.
/// - `#` anywhere, and `%` as the first non-whitespace character of a line, are
///   removed together with the rest of that line (not including the newline).
///
/// Every newline of the input appears exactly once in the output.
fn remove_comments(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();
    // True while only whitespace has been emitted since the last newline;
    // `%` lines are only recognised in that position.
    let mut at_line_start = true;

    while let Some(c) = chars.next() {
        match c {
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                while let Some(inner) = chars.next() {
                    if inner == '*' && chars.peek() == Some(&'/') {
                        chars.next();
                        break;
                    }
                    // Preserve line structure across multi-line comments.
                    if inner == '\n' {
                        result.push('\n');
                    }
                }
                result.push(' ');
                at_line_start = false;
            }
            // Line comment: skip to the newline but leave it for the `'\n'` arm,
            // which emits it exactly once.
            '/' if chars.peek() == Some(&'/') => {
                while chars.next_if(|&next| next != '\n').is_some() {}
            }
            '#' => {
                while chars.next_if(|&next| next != '\n').is_some() {}
            }
            '%' if at_line_start => {
                while chars.next_if(|&next| next != '\n').is_some() {}
            }
            '\n' => {
                result.push('\n');
                at_line_start = true;
            }
            _ if c.is_whitespace() => result.push(c),
            _ => {
                result.push(c);
                at_line_start = false;
            }
        }
    }

    result
}
94
/// Returns the numeric value of a symbolic length the parser knows without a
/// `const` definition in the input, or `None` for any other name.
fn resolve_well_known_constant(name: &str) -> Option<u32> {
    const KNOWN: [(&str, u32); 2] = [("VIR_UUID_BUFLEN", 16), ("VIR_UUID_STRING_BUFLEN", 37)];

    KNOWN
        .iter()
        .find(|(known, _)| *known == name)
        .map(|&(_, size)| size)
}
103
104fn ws<'a, F, O, E>(inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O, E>
107where
108 F: FnMut(&'a str) -> IResult<&'a str, O, E>,
109 E: nom::error::ParseError<&'a str>,
110{
111 delimited(multispace0, inner, multispace0)
112}
113
114fn identifier(input: &str) -> IResult<&str, &str> {
115 recognize(pair(
116 take_while1(|c: char| c.is_ascii_alphabetic() || c == '_'),
117 take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'),
118 ))(input)
119}
120
121fn integer(input: &str) -> IResult<&str, i64> {
122 alt((
123 map_res(
125 preceded(
126 alt((tag("0x"), tag("0X"))),
127 take_while1(|c: char| c.is_ascii_hexdigit()),
128 ),
129 |s: &str| i64::from_str_radix(s, 16),
130 ),
131 map_res(recognize(pair(opt(char('-')), digit1)), |s: &str| {
133 s.parse::<i64>()
134 }),
135 ))(input)
136}
137
138fn const_value(input: &str) -> IResult<&str, ConstValue> {
139 alt((
140 map(integer, ConstValue::Int),
141 map(identifier, |s| ConstValue::Ident(s.to_string())),
142 ))(input)
143}
144
145fn protocol_parser(input: &str) -> IResult<&str, Protocol> {
148 let (input, items) = many0(ws(definition))(input)?;
149
150 let mut protocol = Protocol::new("remote");
151
152 for item in items {
153 match item {
154 Definition::Const(c) => protocol.constants.push(c),
155 Definition::Type(t) => protocol.types.push(t),
156 }
157 }
158
159 extract_procedures(&mut protocol);
161
162 Ok((input, protocol))
163}
164
165fn extract_procedures(protocol: &mut Protocol) {
171 let procedure_enum = protocol
173 .types
174 .iter()
175 .find_map(|t| {
176 if let TypeDef::Enum(e) = t {
177 if e.name == "remote_procedure" {
178 return Some(e.clone());
179 }
180 }
181 None
182 });
183
184 let procedure_enum = match procedure_enum {
185 Some(e) => e,
186 None => return,
187 };
188
189 let struct_names: std::collections::HashSet<String> = protocol
191 .types
192 .iter()
193 .filter_map(|t| {
194 if let TypeDef::Struct(s) = t {
195 Some(s.name.clone())
196 } else {
197 None
198 }
199 })
200 .collect();
201
202 for variant in &procedure_enum.variants {
204 let number = match &variant.value {
205 Some(ConstValue::Int(n)) => *n as u32,
206 _ => continue,
207 };
208
209 let base_name = variant
211 .name
212 .strip_prefix("REMOTE_PROC_")
213 .unwrap_or(&variant.name)
214 .to_lowercase();
215
216 let args_name = format!("remote_{}_args", base_name);
217 let ret_name = format!("remote_{}_ret", base_name);
218
219 let args = if struct_names.contains(&args_name) {
220 Some(args_name)
221 } else {
222 None
223 };
224
225 let ret = if struct_names.contains(&ret_name) {
226 Some(ret_name)
227 } else {
228 None
229 };
230
231 protocol.procedures.push(Procedure {
232 name: variant.name.clone(),
233 number,
234 args,
235 ret,
236 priority: Priority::default(),
237 });
238 }
239}
240
/// One top-level item of a protocol file, as returned by `definition`.
///
/// `protocol_parser` sorts these items into `Protocol::constants` and
/// `Protocol::types`.
enum Definition {
    /// A `const NAME = VALUE;` declaration.
    Const(Constant),
    /// A `struct`, `enum`, `union` or `typedef` declaration.
    Type(TypeDef),
}
245
246fn definition(input: &str) -> IResult<&str, Definition> {
247 alt((
248 map(const_def, Definition::Const),
249 map(type_def, Definition::Type),
250 ))(input)
251}
252
253fn const_def(input: &str) -> IResult<&str, Constant> {
255 let (input, _) = tag("const")(input)?;
256 let (input, _) = multispace1(input)?;
257 let (input, name) = identifier(input)?;
258 let (input, _) = ws(char('='))(input)?;
259 let (input, value) = const_value(input)?;
260 let (input, _) = ws(char(';'))(input)?;
261
262 Ok((
263 input,
264 Constant {
265 name: name.to_string(),
266 value,
267 },
268 ))
269}
270
271fn type_def(input: &str) -> IResult<&str, TypeDef> {
273 alt((
274 map(struct_def, TypeDef::Struct),
275 map(enum_def, TypeDef::Enum),
276 map(union_def, TypeDef::Union),
277 map(typedef_def, TypeDef::Typedef),
278 ))(input)
279}
280
281fn struct_def(input: &str) -> IResult<&str, StructDef> {
283 let (input, _) = tag("struct")(input)?;
284 let (input, _) = multispace1(input)?;
285 let (input, name) = identifier(input)?;
286 let (input, _) = ws(char('{'))(input)?;
287 let (input, fields) = many0(ws(field_def))(input)?;
288 let (input, _) = ws(char('}'))(input)?;
289 let (input, _) = ws(char(';'))(input)?;
290
291 Ok((
292 input,
293 StructDef {
294 name: name.to_string(),
295 fields,
296 },
297 ))
298}
299
300fn field_def(input: &str) -> IResult<&str, Field> {
302 let (input, ty) = type_spec(input)?;
303 let (input, _) = multispace1(input)?;
304 let (input, name) = identifier(input)?;
305 let (input, ty) = array_suffix(input, ty)?;
307 let (input, _) = ws(char(';'))(input)?;
308
309 Ok((
310 input,
311 Field {
312 name: name.to_string(),
313 ty,
314 },
315 ))
316}
317
318fn array_suffix(input: &str, base_ty: Type) -> IResult<&str, Type> {
320 let input_trimmed = input.trim_start();
321
322 if input_trimmed.starts_with('[') {
323 let (input, _) = multispace0(input)?;
325 let (input, _) = char('[')(input)?;
326 let (input, len) = ws(const_value)(input)?;
327 let (input, _) = char(']')(input)?;
328
329 let size = match &len {
330 ConstValue::Int(n) => *n as u32,
331 ConstValue::Ident(name) => resolve_well_known_constant(name).unwrap_or(0),
332 };
333
334 match &base_ty {
336 Type::Opaque { .. } => Ok((
337 input,
338 Type::Opaque {
339 len: LengthSpec::Fixed(size),
340 },
341 )),
342 _ => Ok((
343 input,
344 Type::Array {
345 elem: Box::new(base_ty),
346 len: LengthSpec::Fixed(size),
347 },
348 )),
349 }
350 } else if input_trimmed.starts_with('<') {
351 match &base_ty {
355 Type::String { .. } | Type::Opaque { .. } => {
356 let (input, _) = multispace0(input)?;
358 let (input, _) = char('<')(input)?;
359 let (input, len) = ws(opt(const_value))(input)?;
360 let (input, _) = char('>')(input)?;
361
362 let max = len.and_then(|v| match v {
363 ConstValue::Int(n) => Some(n as u32),
364 ConstValue::Ident(_) => None,
365 });
366
367 match base_ty {
369 Type::String { .. } => Ok((input, Type::String { max_len: max })),
370 Type::Opaque { .. } => Ok((
371 input,
372 Type::Opaque {
373 len: LengthSpec::Variable { max },
374 },
375 )),
376 _ => unreachable!(),
377 }
378 }
379 _ => {
380 let (input, _) = multispace0(input)?;
382 let (input, _) = char('<')(input)?;
383 let (input, len) = ws(opt(const_value))(input)?;
384 let (input, _) = char('>')(input)?;
385
386 let max = len.and_then(|v| match v {
387 ConstValue::Int(n) => Some(n as u32),
388 ConstValue::Ident(_) => None,
389 });
390
391 Ok((
392 input,
393 Type::Array {
394 elem: Box::new(base_ty),
395 len: LengthSpec::Variable { max },
396 },
397 ))
398 }
399 }
400 } else {
401 Ok((input, base_ty))
403 }
404}
405
406fn type_spec(input: &str) -> IResult<&str, Type> {
408 alt((
409 value(Type::Void, tag("void")),
410 value(
412 Type::UHyper,
413 pair(tag("unsigned"), preceded(multispace1, tag("hyper"))),
414 ),
415 value(
416 Type::UInt,
417 pair(tag("unsigned"), preceded(multispace1, tag("int"))),
418 ),
419 map(
421 pair(tag("unsigned"), preceded(multispace1, tag("char"))),
422 |_| Type::Named("u8".to_string()),
423 ),
424 map(
426 pair(tag("unsigned"), preceded(multispace1, tag("short"))),
427 |_| Type::Named("u16".to_string()),
428 ),
429 value(Type::Named("i8".to_string()), tag("char")),
431 value(Type::Named("i16".to_string()), tag("short")),
433 value(Type::Hyper, tag("hyper")),
434 value(Type::Int, tag("int")),
435 value(Type::Float, tag("float")),
436 value(Type::Double, tag("double")),
437 value(Type::Bool, tag("bool")),
438 string_type,
439 opaque_type,
440 optional_type,
441 map(identifier, |s| Type::Named(s.to_string())),
442 ))(input)
443}
444
445fn optional_type(input: &str) -> IResult<&str, Type> {
447 let (input, ty) = alt((
448 value(
449 Type::UHyper,
450 pair(tag("unsigned"), preceded(multispace1, tag("hyper"))),
451 ),
452 value(
453 Type::UInt,
454 pair(tag("unsigned"), preceded(multispace1, tag("int"))),
455 ),
456 value(Type::Hyper, tag("hyper")),
457 value(Type::Int, tag("int")),
458 value(Type::Float, tag("float")),
459 value(Type::Double, tag("double")),
460 value(Type::Bool, tag("bool")),
461 map(identifier, |s| Type::Named(s.to_string())),
462 ))(input)?;
463
464 let (input, _) = ws(char('*'))(input)?;
465
466 Ok((input, Type::Optional(Box::new(ty))))
467}
468
469fn string_type(input: &str) -> IResult<&str, Type> {
471 let (input, _) = tag("string")(input)?;
472 let (input, max_len) = opt(delimited(char('<'), ws(opt(integer)), char('>')))(input)?;
473
474 let max_len = max_len.flatten().map(|n| n as u32);
475
476 Ok((input, Type::String { max_len }))
477}
478
479fn opaque_type(input: &str) -> IResult<&str, Type> {
481 let (input, _) = tag("opaque")(input)?;
482
483 Ok((
484 input,
485 Type::Opaque {
486 len: LengthSpec::Variable { max: None },
487 },
488 ))
489}
490
491fn enum_def(input: &str) -> IResult<&str, EnumDef> {
493 let (input, _) = tag("enum")(input)?;
494 let (input, _) = multispace1(input)?;
495 let (input, name) = identifier(input)?;
496 let (input, _) = ws(char('{'))(input)?;
497 let (input, variants) = separated_list0(ws(char(',')), ws(enum_variant))(input)?;
498 let (input, _) = opt(ws(char(',')))(input)?; let (input, _) = ws(char('}'))(input)?;
500 let (input, _) = ws(char(';'))(input)?;
501
502 Ok((
503 input,
504 EnumDef {
505 name: name.to_string(),
506 variants,
507 },
508 ))
509}
510
511fn enum_variant(input: &str) -> IResult<&str, EnumVariant> {
513 let (input, name) = identifier(input)?;
514 let (input, value) = opt(preceded(ws(char('=')), const_value))(input)?;
515
516 Ok((
517 input,
518 EnumVariant {
519 name: name.to_string(),
520 value,
521 },
522 ))
523}
524
525fn union_def(input: &str) -> IResult<&str, UnionDef> {
527 let (input, _) = tag("union")(input)?;
528 let (input, _) = multispace1(input)?;
529 let (input, name) = identifier(input)?;
530 let (input, _) = ws(tag("switch"))(input)?;
531 let (input, _) = ws(char('('))(input)?;
532 let (input, disc_ty) = type_spec(input)?;
533 let (input, _) = multispace1(input)?;
534 let (input, disc_name) = identifier(input)?;
535 let (input, _) = ws(char(')'))(input)?;
536 let (input, _) = ws(char('{'))(input)?;
537 let (input, cases) = many0(ws(union_case))(input)?;
538 let (input, default) = opt(union_default)(input)?;
539 let (input, _) = ws(char('}'))(input)?;
540 let (input, _) = ws(char(';'))(input)?;
541
542 Ok((
543 input,
544 UnionDef {
545 name: name.to_string(),
546 discriminant: Field {
547 name: disc_name.to_string(),
548 ty: disc_ty,
549 },
550 cases,
551 default,
552 },
553 ))
554}
555
556fn union_case(input: &str) -> IResult<&str, UnionCase> {
558 let (input, _) = tag("case")(input)?;
559 let (input, _) = multispace1(input)?;
560 let (input, value) = const_value(input)?;
561 let (input, _) = ws(char(':'))(input)?;
562
563 let (input, field) = alt((
565 map(field_def, Some),
566 map(terminated(tag("void"), ws(char(';'))), |_| None),
567 ))(input)?;
568
569 Ok((
570 input,
571 UnionCase {
572 values: vec![value],
573 field,
574 },
575 ))
576}
577
578fn union_default(input: &str) -> IResult<&str, Box<Type>> {
580 let (input, _) = ws(tag("default"))(input)?;
581 let (input, _) = ws(char(':'))(input)?;
582 let (input, field) = field_def(input)?;
583
584 Ok((input, Box::new(field.ty)))
585}
586
587fn typedef_def(input: &str) -> IResult<&str, TypedefDef> {
589 let (input, _) = tag("typedef")(input)?;
590 let (input, _) = multispace1(input)?;
591 let (input, target) = type_spec(input)?;
592 let (input, _) = multispace0(input)?;
593
594 let (input, is_pointer) = opt(char('*'))(input)?;
596 let (input, _) = multispace0(input)?;
597 let (input, name) = identifier(input)?;
598
599 let (input, target) = array_suffix(input, target)?;
601 let (input, _) = ws(char(';'))(input)?;
602
603 let target = if is_pointer.is_some() {
604 Type::Optional(Box::new(target))
605 } else {
606 target
607 };
608
609 Ok((
610 input,
611 TypedefDef {
612 name: name.to_string(),
613 target,
614 },
615 ))
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` and asserts that it defines exactly one type.
    fn parse_single(src: &str) -> Protocol {
        let protocol = parse_protocol(src).unwrap();
        assert_eq!(protocol.types.len(), 1);
        protocol
    }

    #[test]
    fn test_remove_comments() {
        let stripped = remove_comments(
            r#"
        /* block comment */
        const FOO = 1; // line comment
        # preprocessor
        const BAR = 2;
        "#,
        );
        for removed in ["block comment", "line comment", "preprocessor"].iter() {
            assert!(!stripped.contains(*removed));
        }
    }

    #[test]
    fn test_parse_const() {
        let (_, constant) = const_def("const FOO = 42;").unwrap();
        assert_eq!(constant.name, "FOO");
        assert!(matches!(constant.value, ConstValue::Int(42)));
    }

    #[test]
    fn test_parse_struct() {
        let protocol = parse_single(
            r#"
        struct Point {
            int x;
            int y;
        };
        "#,
        );
        match &protocol.types[0] {
            TypeDef::Struct(s) => {
                assert_eq!(s.name, "Point");
                assert_eq!(s.fields.len(), 2);
            }
            _ => panic!("expected struct"),
        }
    }

    #[test]
    fn test_parse_enum() {
        let protocol = parse_single(
            r#"
        enum Color {
            RED = 0,
            GREEN = 1,
            BLUE = 2
        };
        "#,
        );
        match &protocol.types[0] {
            TypeDef::Enum(e) => {
                assert_eq!(e.name, "Color");
                assert_eq!(e.variants.len(), 3);
            }
            _ => panic!("expected enum"),
        }
    }

    #[test]
    fn test_parse_typedef() {
        let protocol = parse_single("typedef string remote_string<>;");
        match &protocol.types[0] {
            TypeDef::Typedef(t) => assert_eq!(t.name, "remote_string"),
            _ => panic!("expected typedef"),
        }
    }
}