// avdl_serde_code_generator/to_rust.rs

1#![allow(dead_code)]
2use crate::Rule;
3use pest::iterators::{Pair, Pairs};
4use regex::Regex;
5use std::{
6  cell::RefCell,
7  error::Error,
8  fmt::Write as FmtWrite,
9  io::{self, Write},
10  str::FromStr,
11};
12
/// Derive list for generated structs and variant enums.
const DERIVES: &str = "Serialize, Deserialize, Debug";
/// Derive list for generated structs whose names were registered with
/// `set_derive_hash_for_ident`.
const DERIVES_WITH_HASH: &str = "Serialize, Deserialize, Debug, Hash, PartialEq, Eq";
/// Derive list for generated C-like enums (uses the `serde_repr` derives).
const DERIVES_WITH_HASH_ENUM: &str =
  "Serialize_repr, Deserialize_repr, Debug, Hash, PartialEq, Eq";

thread_local! {
  /// When `true`, generated types get `#[derive(...)]` lines and fields get serde
  /// attributes. Set with `set_add_derive`.
  pub static ADD_DERIVE: RefCell<bool> = RefCell::new(false);
  /// Record names that should additionally derive `Hash`, `PartialEq` and `Eq`.
  /// Set with `set_derive_hash_for_ident`.
  pub static DERIVE_HASH_FOR_IDENT: RefCell<Vec<String>> = RefCell::new(Vec::new());
}
23
24pub fn set_derive_hash_for_ident(v: Vec<String>) {
25  DERIVE_HASH_FOR_IDENT.with(|t| {
26    *t.borrow_mut() = v;
27  });
28}
29
30fn should_derive_hash(v: &str) -> bool {
31  DERIVE_HASH_FOR_IDENT.with(|t| t.borrow().iter().any(|s| s.as_str() == v))
32}
33
34pub fn set_add_derive(v: bool) {
35  ADD_DERIVE.with(|add_derive| {
36    *add_derive.borrow_mut() = v;
37  })
38}
39
/// Converts an AVDL `import idl` path into the body of a Rust `use` declaration
/// (everything after `use `, including the trailing `;`).
///
/// * `"chat_ui.avdl"` (no `/`) → `super::chat_ui::*;`, a glob import of a sibling module.
/// * `"../keybase1"` → `super::super::keybase1;`: each `..` becomes `super`, plus one
///   leading `super` to step out of the generated module itself.
/// * `"github.com/keybase/client/go/protocol/gregor1"` → `crate::protocol::gregor1;`:
///   every segment before `protocol` is dropped. If no segment is `protocol`, the result
///   is just `crate::;`.
///
/// A trailing `.avdl` is removed from the path. When `as_name` is non-empty and differs
/// from the last path segment, ` as {as_name}` is appended. A glob import cannot be
/// renamed in Rust, so for a bare file name the module itself is aliased instead
/// (`super::chat_ui as ui;`).
fn convert_path_str_to_rust_mod(path: &str, as_name: &str) -> String {
  // A Rust path never carries the `.avdl` file extension.
  let path = path.strip_suffix(".avdl").unwrap_or(path);
  let parts: Vec<&str> = path.split('/').collect();
  // `split` always yields at least one piece, so there is always a last segment.
  let stem = parts[parts.len() - 1];
  let alias = if as_name.is_empty() || as_name == stem {
    None
  } else {
    Some(as_name)
  };

  let mut module = if parts.len() == 1 {
    if alias.is_some() {
      // `use a::* as b;` is not valid Rust, so alias the module itself.
      format!("super::{}", stem)
    } else {
      format!("super::{}::*", stem)
    }
  } else if parts[0] == ".." {
    let relative: Vec<&str> = parts
      .iter()
      .map(|&part| if part == ".." { "super" } else { part })
      .collect();
    format!("super::{}", relative.join("::"))
  } else {
    let from_protocol: Vec<&str> = parts
      .iter()
      .copied()
      .skip_while(|&part| part != "protocol")
      .collect();
    format!("crate::{}", from_protocol.join("::"))
  };

  if let Some(alias) = alias {
    module.push_str(" as ");
    module.push_str(alias);
  }
  module.push(';');
  module
}
76
77fn convert_to_import<W>(w: &mut W, p: Pair<Rule>) -> Result<(), Box<dyn Error>>
78where
79  W: Write,
80{
81  assert_eq!(p.as_rule(), Rule::import);
82  let parts = p.into_inner().take(3).collect::<Vec<Pair<Rule>>>();
83  let (kind, path, as_name) = (&parts[0], &parts[1], parts.get(2));
84  match kind.as_str() {
85    "idl" => {
86      write!(
87        w,
88        "use {}",
89        convert_path_str_to_rust_mod(path.as_str(), as_name.map(|p| p.as_str()).unwrap_or(""))
90      )?;
91    }
92    _ => panic!("Unhandled import kind (not idl)"),
93  }
94
95  Ok(())
96}
97
/// Returns whether `idl_type` must be wrapped in `Box<T>` wherever it is used.
///
/// Workaround: these types would otherwise be infinitely sized structs (they contain
/// themselves), and boxing them breaks the cycle.
fn should_box(idl_type: &str) -> bool {
  matches!(idl_type, "MessageUnboxed" | "UIMessage")
}
106
107fn convert_idl_type_to_rust_type(idl_type: &str) -> String {
108  if should_box(idl_type) {
109    return format!("Box<{}>", idl_type);
110  }
111  match idl_type {
112    // "bytes" => String::from("Vec<u8>"),
113    "bytes" => String::from("String"),
114    "boolean" => String::from("bool"),
115    "uint" => String::from("u32"),
116    "int" => String::from("i32"),
117    "int64" => String::from("i64"),
118    "long" => String::from("i64"),
119    "double" => String::from("f64"),
120    "uint64" => String::from("u64"),
121    "uint32" => String::from("u32"),
122    "array" => String::from("Vec"),
123    "string" => String::from("String"),
124    "Hash" => String::from("String"),
125    "void" => String::from("()"),
126    "map" => String::from("std::collections::HashMap"),
127    _ => idl_type.into(),
128  }
129}
130
/// Turns a dotted AVDL namespace path (`gregor1.UID`) into a Rust path (`gregor1::UID`).
fn convert_dot_to_sep(a: &str) -> String {
  a.replace('.', "::")
}
134
/// An AVDL type already rendered to Rust syntax.
enum AVDLType {
  /// A plain (possibly generic) type, e.g. `Vec<String>`.
  Simple(String),
  /// `union { null, T }`, rendered as `Option<T>`.
  Maybe(String),
  /// General unions are not supported; formatting one panics.
  Union(),
}

// Implementing `Display` (instead of `ToString` directly) still provides `to_string()`
// through the blanket impl, and also allows `{}` in format strings.
impl std::fmt::Display for AVDLType {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    match self {
      AVDLType::Simple(s) => f.write_str(s),
      AVDLType::Maybe(s) => write!(f, "Option<{}>", s),
      AVDLType::Union() => panic!("Not implemented"),
    }
  }
}
150
151struct AVDLSimpleType {
152  ty: Option<String>,
153  generics: Vec<AVDLType>,
154}
155
156impl ToString for AVDLSimpleType {
157  fn to_string(&self) -> String {
158    let mut s = self.ty.as_ref().unwrap().clone();
159    if !self.generics.is_empty() {
160      write!(
161        &mut s,
162        "<{}>",
163        self
164          .generics
165          .iter()
166          .map(|g| g.to_string())
167          .collect::<Vec<String>>()
168          .join(", ")
169      )
170      .unwrap();
171    }
172    s
173  }
174}
175
176impl<'a> From<Pair<'a, Rule>> for AVDLSimpleType {
177  fn from(pair: Pair<'a, Rule>) -> Self {
178    assert_eq!(pair.as_rule(), Rule::simple_ty);
179    let parts = pair.into_inner();
180    let mut ty = AVDLSimpleType {
181      ty: None,
182      generics: vec![],
183    };
184
185    let transform = |p: &Pair<Rule>| convert_dot_to_sep(&convert_idl_type_to_rust_type(p.as_str()));
186
187    for pair in parts {
188      match pair.as_rule() {
189        Rule::ns_ident => ty.ty = Some(transform(&pair)),
190        Rule::ty => ty.generics.push(pair.into()),
191        _ => unreachable!(),
192      }
193    }
194
195    if ty.ty.as_ref().map(|s| s.as_str()) == Some("std::collections::HashMap")
196      && ty.generics.len() == 1
197    {
198      ty.generics
199        .insert(0, AVDLType::Simple(String::from("String")));
200    }
201
202    ty
203  }
204}
205
206impl<'a> From<Pair<'a, Rule>> for AVDLType {
207  fn from(pair: Pair<'a, Rule>) -> Self {
208    assert!(
209      pair.as_rule() == Rule::ty || pair.as_rule() == Rule::maybe_ty,
210      "Unexpected rule: {:?}",
211      pair.as_rule()
212    );
213
214    let inner = pair.into_inner().next().unwrap();
215    match inner.as_rule() {
216      Rule::simple_ty => {
217        let ty: AVDLSimpleType = inner.into();
218        AVDLType::Simple(ty.to_string())
219      }
220      Rule::maybe_ty => {
221        let inner_ty: AVDLType = inner.into();
222        AVDLType::Maybe(inner_ty.to_string())
223      }
224      _ => panic!("Unhandled case: {:?}", inner.as_rule()),
225    }
226  }
227}
228
229struct AVDLIdent(String);
230impl<'a> From<Pair<'a, Rule>> for AVDLIdent {
231  fn from(pair: Pair<'a, Rule>) -> Self {
232    assert_eq!(pair.as_rule(), Rule::ident);
233    let s = match pair.as_str() {
234      "box" => "box_",
235      "match" => "match_",
236      "ref" => "ref_",
237      "type" => "ty",
238      "self" => "self_",
239      "where" => "where_",
240      _ => pair.as_str(),
241    };
242    AVDLIdent(s.into())
243  }
244}
245
246fn convert_typedef<W>(w: &mut W, p: Pair<Rule>) -> Result<(), Box<dyn Error>>
247where
248  W: Write,
249{
250  assert_eq!(p.as_rule(), Rule::typedef);
251  let mut parts = p.into_inner();
252  let mut type_name: Option<String> = None;
253  let mut type_target: Option<String> = None;
254  while let Some(pair) = parts.next() {
255    match pair.as_rule() {
256      Rule::ty => {
257        let ty: AVDLType = pair.into();
258        type_target = Some(ty.to_string());
259      }
260      Rule::lint => {
261        write!(w, "// LINT: {}\n", pair.as_str())?;
262      }
263      Rule::record => type_name = Some(pair.into_inner().next().unwrap().as_str().into()),
264      _ => unreachable!(),
265    }
266  }
267
268  write!(
269    w,
270    "pub type {} = {};",
271    type_name.unwrap(),
272    type_target.unwrap()
273  )?;
274
275  Ok(())
276}
277
278struct EnumCaseTy {
279  enum_name: String,
280  comment: Option<String>,
281  preline_comment: Option<String>,
282}
283
284impl<'a> From<Pair<'a, Rule>> for EnumCaseTy {
285  fn from(pair: Pair<'a, Rule>) -> Self {
286    let mut parts = pair.into_inner();
287    let mut enum_name: Option<String> = None;
288    let mut comment: Option<String> = None;
289    while let Some(pair) = parts.next() {
290      match pair.as_rule() {
291        Rule::ident => enum_name = Some(quiet_voice(pair.into())),
292        Rule::comment => comment = Some(format!(" {}", pair.as_str())),
293        _ => unreachable!(),
294      }
295    }
296
297    EnumCaseTy {
298      enum_name: enum_name.unwrap(),
299      comment,
300      preline_comment: None,
301    }
302  }
303}
304
305impl WriteTo for EnumCaseTy {
306  fn write_to<W: Write>(&self, w: &mut W) -> Result<(), io::Error> {
307    if let Some(preline_comment) = self.preline_comment.as_ref() {
308      write!(w, "  {}", preline_comment)?;
309    }
310    write!(
311      w,
312      "  {},{}",
313      self.enum_name,
314      self.comment.as_ref().map(|s| s.as_str()).unwrap_or("\n")
315    )?;
316    Ok(())
317  }
318}
319
320struct EnumTy {
321  ident: String,
322  cases: Vec<EnumCaseTy>,
323}
324
325impl<'a> From<Pair<'a, Rule>> for EnumTy {
326  fn from(pair: Pair<'a, Rule>) -> Self {
327    let mut parts = pair.into_inner();
328    let mut ident: Option<String> = None;
329    let mut comment: Option<String> = None;
330    let mut cases: Vec<EnumCaseTy> = vec![];
331    while let Some(pair) = parts.next() {
332      match pair.as_rule() {
333        Rule::ident => ident = Some(AVDLIdent::from(pair).0),
334        Rule::enum_case => {
335          let mut case: EnumCaseTy = pair.into();
336          case.preline_comment = comment.take();
337          cases.push(case);
338        }
339        Rule::comment => comment = Some(pair.as_str().into()),
340        _ => unreachable!(),
341      }
342    }
343
344    EnumTy {
345      ident: ident.expect("Couldn't find name for variant"),
346      cases,
347    }
348  }
349}
350
351impl WriteTo for EnumTy {
352  fn write_to<W: Write>(&self, w: &mut W) -> Result<(), io::Error> {
353    write!(
354      w,
355      "#[derive({})]\n#[repr(u8)]\npub enum {} {{\n",
356      DERIVES_WITH_HASH_ENUM, self.ident
357    )?;
358    for case in self.cases.iter() {
359      case.write_to(w)?;
360    }
361    write!(w, "}}")?;
362    Ok(())
363  }
364}
365
366fn convert_enum<W>(w: &mut W, p: Pair<Rule>) -> Result<(), Box<dyn Error>>
367where
368  W: Write,
369{
370  assert_eq!(p.as_rule(), Rule::enum_ty);
371  // This is actually already a rust enum!
372  let ty: EnumTy = p.into();
373  ty.write_to(w)?;
374
375  // write!(w, "{}", p.as_str())?;
376  Ok(())
377}
378
379// Turns FOO -> Foo
380fn quiet_voice(s: AVDLIdent) -> String {
381  s.0
382    .chars()
383    .enumerate()
384    .map(|(i, c)| {
385      if i == 0 {
386        c.to_ascii_uppercase()
387      } else {
388        c.to_ascii_lowercase()
389      }
390    })
391    .collect()
392}
393
394struct VariantCaseTy {
395  enum_name: String,
396  enum_inner_ty: String,
397  comment: Option<String>,
398}
399
400impl<'a> From<Pair<'a, Rule>> for VariantCaseTy {
401  fn from(pair: Pair<'a, Rule>) -> Self {
402    let mut parts = pair.into_inner();
403    let mut enum_name: Option<String> = None;
404    let mut enum_inner_ty: Option<AVDLType> = None;
405    let mut comment: Option<String> = None;
406    while let Some(pair) = parts.next() {
407      match pair.as_rule() {
408        Rule::ident => enum_name = Some(quiet_voice(pair.into())),
409        Rule::ty => enum_inner_ty = Some(pair.into()),
410        Rule::comment => comment = Some(format!(" {}", pair.as_str())),
411        _ => unreachable!(),
412      }
413    }
414
415    VariantCaseTy {
416      enum_name: enum_name.unwrap(),
417      enum_inner_ty: enum_inner_ty.unwrap().to_string(),
418      comment,
419    }
420  }
421}
422
423impl WriteTo for VariantCaseTy {
424  fn write_to<W: Write>(&self, w: &mut W) -> Result<(), io::Error> {
425    if self.enum_name == "Default" {
426      writeln!(w, "  {} {{}},", self.enum_name)?;
427    } else {
428      let enum_lowercase = match self.enum_name.as_str() {
429        "Move" => String::from("r#move"),
430        s => s.to_ascii_lowercase(),
431      };
432      writeln!(
433        w,
434        "  {} {{{}: {}}},",
435        self.enum_name, enum_lowercase, self.enum_inner_ty
436      )?;
437    }
438    Ok(())
439  }
440}
441
442struct VariantTy {
443  ident: String,
444  cases: Vec<VariantCaseTy>,
445}
446
447impl<'a> From<Pair<'a, Rule>> for VariantTy {
448  fn from(pair: Pair<'a, Rule>) -> Self {
449    let mut parts = pair.into_inner();
450    let mut ident: Option<String> = None;
451    let mut cases: Vec<VariantCaseTy> = vec![];
452    while let Some(pair) = parts.next() {
453      match pair.as_rule() {
454        Rule::ident => ident = Some(pair.as_str().into()),
455        Rule::variant_case => cases.push(pair.into()),
456        Rule::variant_param => {}
457        _ => unreachable!(),
458      }
459    }
460
461    VariantTy {
462      ident: ident.expect("Couldn't find name for variant"),
463      cases,
464    }
465  }
466}
467
468impl WriteTo for VariantTy {
469  fn write_to<W: Write>(&self, w: &mut W) -> Result<(), io::Error> {
470    let added_derives = write_derives(w, false).unwrap();
471    if added_derives {
472      write!(w, "#[serde(untagged)]\n")?;
473    }
474    write!(w, "pub enum {} {{\n", self.ident)?;
475    for case in self.cases.iter() {
476      case.write_to(w)?;
477    }
478    write!(w, "}}")?;
479    Ok(())
480  }
481}
482
483fn convert_variant<W>(w: &mut W, p: Pair<Rule>) -> Result<(), Box<dyn Error>>
484where
485  W: Write,
486{
487  assert_eq!(p.as_rule(), Rule::variant_ty);
488  let ty: VariantTy = p.into();
489  ty.write_to(w)?;
490  Ok(())
491}
492
493fn convert_fixed<W>(w: &mut W, p: Pair<Rule>) -> Result<(), Box<dyn Error>>
494where
495  W: Write,
496{
497  assert_eq!(p.as_rule(), Rule::fixed_ty);
498  let mut parts = p.into_inner();
499  let mut ty: Option<AVDLType> = None;
500  let mut _byte_size: usize = 0;
501  while let Some(pair) = parts.next() {
502    match pair.as_rule() {
503      Rule::ty => ty = Some(pair.into()),
504      Rule::byte_size => {
505        _byte_size = usize::from_str(pair.as_str()).expect("Couldn't parse byte_size")
506      }
507      _ => unreachable!(),
508    }
509  }
510
511  write!(
512    w,
513    // This is the correct one, but doesn't serialize/deserialize easily
514    // "pub type {} = [u8;{}];",
515    "pub type {} = Vec<u8>;",
516    ty.unwrap().to_string(),
517    // byte_size
518  )?;
519  Ok(())
520}
521
522struct JSONKey {
523  rename_to: String,
524}
525impl<'a> From<&str> for JSONKey {
526  fn from(s: &str) -> Self {
527    let re = Regex::new(r#"@jsonkey\("([^"]+)"\)"#).unwrap();
528    let mut captures = re.captures_iter(s);
529    JSONKey {
530      rename_to: captures.next().expect("Regex didn't match")[1].into(),
531    }
532  }
533}
534
535struct AVDLRecordProp {
536  ty: AVDLType,
537  field: String,
538  attributes: Vec<String>,
539}
540
541impl AVDLRecordProp {
542  fn can_hash(&self) -> bool {
543    match self.ty.to_string().as_str() {
544      "f64" => false,
545      _ => true,
546    }
547  }
548}
549
/// A parsed AVDL item that can write its generated Rust source to an `io::Write` sink.
pub trait WriteTo {
  /// Writes the generated Rust code for `self` to `w`.
  ///
  /// # Errors
  /// Returns an I/O error if writing to `w` fails.
  fn write_to<W: Write>(&self, w: &mut W) -> Result<(), io::Error>;
}
553
554impl WriteTo for AVDLRecordProp {
555  fn write_to<W: Write>(&self, w: &mut W) -> Result<(), io::Error> {
556    let mut json_key: Option<JSONKey> = None;
557    let add_derive = ADD_DERIVE.with(|add_derive| *add_derive.borrow());
558    let mut is_optional = add_derive;
559    let mut add_default = false;
560    let mut is_bytes = false;
561    for attr in self.attributes.iter() {
562      if attr.contains("@jsonkey") {
563        json_key = Some(attr.as_str().into())
564      } else if attr.contains("@optional(true)") {
565        is_optional = true;
566      } else {
567        write!(w, "  // {}\n", attr)?;
568      }
569    }
570
571    if json_key.is_none() && add_derive {
572      match self.field.as_str() {
573        "ty" => write!(w, "  #[serde(rename = \"type\")]\n").unwrap(),
574        "box_" => write!(w, "  #[serde(rename = \"box\")]\n").unwrap(),
575        "match_" => write!(w, "  #[serde(rename = \"match\")]\n").unwrap(),
576        "ref_" => write!(w, "  #[serde(rename = \"ref\")]\n").unwrap(),
577        "self_" => write!(w, "  #[serde(rename = \"self\")]\n").unwrap(),
578        "where_" => write!(w, "  #[serde(rename = \"where\")]\n").unwrap(),
579        _ => {}
580      }
581    }
582
583    if let Some(json_key) = json_key {
584      if add_derive {
585        write!(w, "  #[serde(rename = \"{}\")]\n", json_key.rename_to).unwrap();
586      }
587    }
588
589    let ty_str = self.ty.to_string();
590    if ty_str == "String" || ty_str.contains("Vec") {
591      add_default = true;
592      is_optional = true;
593    }
594
595    if ty_str.contains("Vec<u8>") {
596      is_bytes = true;
597    }
598
599    if ty_str == "bool" {
600      add_default = true;
601    }
602
603    if ty_str.contains("Option<") {
604      is_optional = false;
605    }
606
607    if add_default && add_derive {
608      write!(w, "  #[serde(default)]\n").unwrap();
609    }
610
611    if is_bytes && add_derive {
612      // write!(w, r#"  #[serde(with = "Base64Standard")]\n"#).unwrap();
613    }
614
615    if is_optional {
616      write!(
617        w,
618        "  pub {}: Option<{}>,\n",
619        self.field,
620        self.ty.to_string()
621      )?;
622    } else {
623      write!(w, "  pub {}: {},\n", self.field, self.ty.to_string())?;
624    }
625    Ok(())
626  }
627}
628
629impl<'a> From<Pair<'a, Rule>> for AVDLRecordProp {
630  fn from(pair: Pair<'a, Rule>) -> Self {
631    assert_eq!(pair.as_rule(), Rule::record_prop);
632    let mut ty: Option<AVDLType> = None;
633    let mut field: Option<String> = None;
634    let mut attributes = vec![];
635    let mut parts = pair.into_inner();
636    while let Some(pair) = parts.next() {
637      match pair.as_rule() {
638        Rule::lint | Rule::generic_annotation => attributes.push(pair.as_str().into()),
639        Rule::ty => ty = Some(pair.into()),
640        Rule::ident => {
641          let ident: AVDLIdent = pair.into();
642          field = Some(ident.0);
643        }
644        _ => panic!("Unhandled case: {:?}", pair),
645      }
646    }
647
648    AVDLRecordProp {
649      ty: ty.expect("Couldn't find types"),
650      field: field.expect("Couldn't find field"),
651      attributes,
652    }
653  }
654}
655
656fn write_derives<W>(w: &mut W, can_hash: bool) -> Result<bool, Box<dyn Error>>
657where
658  W: Write,
659{
660  let did_add = ADD_DERIVE.with(|add_derive| {
661    if *add_derive.borrow() {
662      write!(
663        w,
664        "#[derive({})]\n",
665        if can_hash { DERIVES_WITH_HASH } else { DERIVES }
666      )
667      .unwrap();
668      true
669    } else {
670      false
671    }
672  });
673
674  Ok(did_add)
675}
676
677fn convert_record<W>(w: &mut W, p: Pair<Rule>) -> Result<(), Box<dyn Error>>
678where
679  W: Write,
680{
681  assert_eq!(p.as_rule(), Rule::record);
682  let mut parts = p.into_inner();
683  let mut type_name: Option<AVDLType> = None;
684  let mut record_props: Vec<AVDLRecordProp> = vec![];
685
686  while let Some(pair) = parts.next() {
687    match pair.as_rule() {
688      Rule::ty => type_name = Some(pair.into()),
689      Rule::comment => write!(w, "{}", pair.as_str())?,
690      Rule::record_prop => {
691        record_props.push(pair.into());
692      }
693      _ => panic!("Unhandled case: {:?}", pair),
694    }
695  }
696
697  // let can_hash = false;
698  let ident = type_name.expect("No Record name").to_string();
699  let can_hash = should_derive_hash(&ident);
700  println!("Can Hash {} {}", ident, can_hash);
701
702  write_derives(w, can_hash)?;
703  write!(w, "pub struct {} {{\n", ident)?;
704  for prop in record_props.into_iter() {
705    prop.write_to(w)?;
706  }
707  write!(w, "}}")?;
708
709  Ok(())
710}
711
/// Placeholder for converting a protocol method declaration such as
/// `OutboxID generateOutboxID();`.
///
/// Not translated yet: nothing is written to `_w` and the pair is ignored, so this always
/// returns `Ok(())`.
fn convert_interface_fn<W>(_w: &mut W, _p: Pair<Rule>) -> Result<(), Box<dyn Error>>
where
  W: Write,
{
  // Not implemented. Skipping for now
  Ok(())
}
719
720pub fn build_rust_code_from_avdl<W>(mut input: Pairs<Rule>, w: &mut W) -> Result<(), Box<dyn Error>>
721where
722  W: Write,
723{
724  for node in input.next().expect("Nothing to parse").into_inner() {
725    match node.as_rule() {
726      Rule::namespace_annotation => {
727        if let Some(n) = node.into_inner().next() {
728          match n.as_rule() {
729            Rule::namespace_name => write!(w, "// Namespace: {:?}\n", n.as_str())?,
730            _ => unreachable!(),
731          }
732        }
733      }
734      Rule::protocol => {
735        let mut inner = node.into_inner();
736        while let Some(n) = inner.next() {
737          match n.as_rule() {
738            Rule::protocol_name => {
739              let protocol_name = n.as_str();
740              write!(w, "// Protocol: {:?}\n", protocol_name)?;
741              write!(w, "#![allow(dead_code)]\n")?;
742              write!(w, "#![allow(non_snake_case)]\n")?;
743              write!(w, "#![allow(non_camel_case_types)]\n")?;
744              write!(w, "#![allow(unused_imports)]\n")?;
745
746              write!(w, "use serde::{{Serialize, Deserialize}};\n")?;
747              write!(w, "use serde_repr::{{Deserialize_repr, Serialize_repr}};")?;
748              write!(w, "use super::*;\n")?;
749              if protocol_name.to_ascii_lowercase() != "common" {
750                // write!(w, "use super::common::*;\n")?
751                // write!(w, "use super::*;\n")?
752              }
753            }
754            Rule::protocol_body => {
755              let mut inner = n.into_inner();
756              while let Some(protocol_body_node) = inner.next() {
757                let separator = match protocol_body_node.as_rule() {
758                  Rule::comment => "",
759                  Rule::generic_annotation | Rule::import => "\n",
760                  _ => "\n\n",
761                };
762
763                match protocol_body_node.as_rule() {
764                  Rule::comment => write!(w, "{}", protocol_body_node.as_str())?,
765                  Rule::import => convert_to_import(w, protocol_body_node)?,
766                  Rule::typedef => convert_typedef(w, protocol_body_node)?,
767                  Rule::generic_annotation => write!(w, "// {}", protocol_body_node.as_str())?,
768                  Rule::record => convert_record(w, protocol_body_node)?,
769                  Rule::enum_ty => convert_enum(w, protocol_body_node)?,
770                  Rule::variant_ty => convert_variant(w, protocol_body_node)?,
771                  Rule::fixed_ty => convert_fixed(w, protocol_body_node)?,
772                  Rule::interface_fn => convert_interface_fn(w, protocol_body_node)?,
773                  _ => {}
774                }
775
776                write!(w, "{}", separator)?;
777              }
778            }
779            _ => unreachable!(),
780          }
781          write!(w, "\n")?;
782        }
783      }
784      Rule::generic_annotation => {}
785      _ => unreachable!(),
786    }
787  }
788  Ok(())
789}
790
// Unit tests for the individual AVDL -> Rust conversions. Each test parses a snippet for
// a single grammar rule with `AVDLParser` and compares the generated text exactly.
#[cfg(test)]
mod tests {
  use super::*;
  use crate::AVDLParser;
  use pest::Parser;

  /// Parses `input` as rule `r`, runs `conversion_fn` on the first resulting pair, and
  /// asserts that the bytes it wrote equal `expected`.
  fn test_conversion<F>(
    r: Rule,
    conversion_fn: F,
    input: &str,
    expected: &str,
  ) -> Result<(), Box<dyn Error>>
  where
    F: Fn(&mut Vec<u8>, Pair<Rule>) -> Result<(), Box<dyn Error>>,
  {
    let mut output = vec![];
    conversion_fn(&mut output, AVDLParser::parse(r, input)?.next().unwrap())?;
    assert_eq!(String::from_utf8(output).unwrap(), expected);
    Ok(())
  }

  #[test]
  fn test_type_conversion() {
    {
      let input = AVDLParser::parse(Rule::ty, "map<TeamInviteID,AnnotatedTeamInvite>");
      let ty: AVDLType = input.unwrap().next().unwrap().into();
      assert_eq!(
        ty.to_string(),
        "std::collections::HashMap<TeamInviteID, AnnotatedTeamInvite>"
      );
    }
    {
      // A one-argument map only names the value type; a String key is added.
      let input = AVDLParser::parse(Rule::ty, "map<int>");
      let ty: AVDLType = input.unwrap().next().unwrap().into();
      assert_eq!(ty.to_string(), "std::collections::HashMap<String, i32>");
    }
  }

  #[test]
  fn test_convert_import() -> Result<(), Box<dyn Error>> {
    test_conversion(
      Rule::import,
      convert_to_import,
      r#"import idl "chat_ui.avdl";"#,
      "use super::chat_ui::*;",
    )
    .unwrap();
    test_conversion(
      Rule::import,
      convert_to_import,
      r#"import idl "github.com/keybase/client/go/protocol/gregor1" as gregor1;"#,
      "use crate::protocol::gregor1;",
    )
    .unwrap();
    test_conversion(
      Rule::import,
      convert_to_import,
      r#"import idl "github.com/keybase/client/go/protocol/gregor1" as otherGregor;"#,
      "use crate::protocol::gregor1 as otherGregor;",
    )
    .unwrap();

    Ok(())
  }

  // Interface functions are not translated yet, so these only check that the grammar
  // accepts the declarations and that nothing is emitted.
  #[test]
  fn test_interface_fn() {
    // test_conversion(
    //   Rule::typedef,
    //   convert_typedef,
    //   r#"GetInboxAndUnboxLocalRes getInboxAndUnboxLocal(union { null, GetInboxLocalQuery} query, union { null, Pagination } pagination, keybase1.TLFIdentifyBehavior identifyBehavior);"#,
    //   r#"fn getInboxAndUnboxLocal(&self, query: Option<GetInboxLocalQuery> query, pagination: Option<Pagination>, identifyBehavior: keybase1.TLFIdentifyBehavior) -> GetInboxAndUnboxLocalRes;"#,
    // )
    // .unwrap();
    test_conversion(
      Rule::interface_fn,
      convert_interface_fn,
      r#"GetInboxAndUnboxLocalRes getInboxAndUnboxLocal(union { null, GetInboxLocalQuery} query, union { null, Pagination } pagination, keybase1.TLFIdentifyBehavior identifyBehavior);"#,
      r#""#,
    )
    .unwrap();
    test_conversion(
      Rule::interface_fn,
      convert_interface_fn,
      r#"  UnreadlineRes getUnreadline(int sessionID, ConversationID convID,
MessageID readMsgID, keybase1.TLFIdentifyBehavior identifyBehavior);"#,
      r#""#,
    )
    .unwrap();
    test_conversion(
      Rule::interface_fn,
      convert_interface_fn,
      r#"    OutboxID generateOutboxID();"#,
      r#""#,
    )
    .unwrap();

    test_conversion(
      Rule::interface_fn,
      convert_interface_fn,
      r#"  void start(int sessionID, string username, IdentifyReason reason, boolean forceDisplay=false);"#,
      r#""#,
    )
    .unwrap();

    test_conversion(
      Rule::interface_fn,
      convert_interface_fn,
      r#"  PostRemoteRes postRemote(
    ConversationID conversationID,
    MessageBoxed messageBoxed,
    array<gregor1.UID> atMentions,
    ChannelMention channelMention,
    union { null, TopicNameState } topicNameState,
    // Add any atMentions to the conversation automatically with the given
    // status
    union { null, ConversationMemberStatus } joinMentionsAs
  );"#,
      r#""#,
    )
    .unwrap();

    test_conversion(
      Rule::interface_fn,
      convert_interface_fn,
      r#"  void chatAttachmentDownloadProgress(int sessionID, long bytesComplete, long bytesTotal) oneway;"#,
      r#""#,
    )
    .unwrap();
  }

  #[test]
  fn test_typedef() {
    test_conversion(
      Rule::typedef,
      convert_typedef,
      r#"@typedef("bytes")  record ThreadID {}"#,
      // TODO fix this when we support bytes
      // "pub type ThreadID = Vec<u8>;",
      "pub type ThreadID = String;",
    )
    .unwrap();
  }

  #[test]
  fn test_enum() {
    test_conversion(
      Rule::enum_ty,
      convert_enum,
      r#"enum RetentionPolicyType {
  NONE_0,
  RETAIN_1, // Keep messages forever
  EXPIRE_2, // Delete after a while
  INHERIT_3, // Use the team's policy
  EPHEMERAL_4 // Force all messages to be exploding.
}"#,
      &format!(
        "#[derive({})]\n#[repr(u8)]\npub enum RetentionPolicyType {{
  None_0,
  Retain_1, // Keep messages forever
  Expire_2, // Delete after a while
  Inherit_3, // Use the team's policy
  Ephemeral_4, // Force all messages to be exploding.
}}",
        DERIVES_WITH_HASH_ENUM
      ),
    )
    .unwrap();
  }

  // `set_add_derive` and `set_derive_hash_for_ident` write thread-local state, so these
  // settings only affect the thread running this test.
  #[test]
  fn test_record() {
    set_add_derive(true);
    test_conversion(
      Rule::record,
      convert_record,
      r#"record InboxVersInfo {
  gregor1.UID uid;
  gregor1.UID type;
  @optional(true)
  InboxVers vers;
}"#,
      &format!(
        r#"#[derive({})]
pub struct InboxVersInfo {{
  pub uid: Option<gregor1::UID>,
  #[serde(rename = "type")]
  pub ty: Option<gregor1::UID>,
  pub vers: Option<InboxVers>,
}}"#,
        format!("{}", DERIVES)
      ),
    )
    .unwrap();

    test_conversion(
      Rule::record,
      convert_record,
      r#"record InboxVersInfo {
    @mpackkey("b") @jsonkey("b")
    union { null, gregor1.UID } botUID;
    @mpackkey("c") @jsonkey("c")
    InboxVers vers;
}"#,
      &format!(
        r#"#[derive({})]
pub struct InboxVersInfo {{
  // @mpackkey("b")
  #[serde(rename = "b")]
  pub botUID: Option<gregor1::UID>,
  // @mpackkey("c")
  #[serde(rename = "c")]
  pub vers: Option<InboxVers>,
}}"#,
        format!("{}", DERIVES)
      ),
    )
    .unwrap();

    set_derive_hash_for_ident(vec![String::from("ConvSummary")]);
    test_conversion(
      Rule::record,
      convert_record,
      r#"record ConvSummary {
  @optional(true)
  array<string> supersedes;
}"#,
      &format!(
        "#[derive({})]
pub struct ConvSummary {{
  #[serde(default)]
  pub supersedes: Option<Vec<String>>,
}}",
        format!("{}, Hash, PartialEq, Eq", DERIVES)
      ),
    )
    .unwrap();
    set_derive_hash_for_ident(vec![]);

    test_conversion(
      Rule::record,
      convert_record,
      r#"record ConvSummary {
  array<string> supersedes;
}"#,
      &format!(
        "#[derive({})]
pub struct ConvSummary {{
  #[serde(default)]
  pub supersedes: Option<Vec<String>>,
}}",
        format!("{}", DERIVES)
      ),
    )
    .unwrap();

    test_conversion(
      Rule::record,
      convert_record,
      r#"record ConvSummary {
  @optional(true)
  array<f32> supersedes;
}"#,
      &format!(
        "#[derive({})]
pub struct ConvSummary {{
  #[serde(default)]
  pub supersedes: Option<Vec<f32>>,
}}",
        format!("{}", DERIVES)
      ),
    )
    .unwrap();
  }

  #[test]
  fn test_fixed() {
    test_conversion(
      Rule::fixed_ty,
      convert_fixed,
      r#"fixed Bytes32(32);"#,
      // "pub type Bytes32 = [u8;32];",
      "pub type Bytes32 = Vec<u8>;",
    )
    .unwrap();
  }

  #[test]
  fn test_variant() {
    set_add_derive(true);
    test_conversion(
      Rule::variant_ty,
      convert_variant,
      r#"variant AssetMetadata switch (AssetMetadataType assetType) {
  case IMAGE: AssetMetadataImage;
  case VIDEO: AssetMetadataVideo;
  case AUDIO: AssetMetadataAudio;
}"#,
      &format!(
        "#[derive({})]\n#[serde(untagged)]\npub enum AssetMetadata {{
  Image {{image: AssetMetadataImage}},
  Video {{video: AssetMetadataVideo}},
  Audio {{audio: AssetMetadataAudio}},
}}",
        DERIVES
      ),
    )
    .unwrap();
    // `move` is a Rust keyword, so its field is emitted as a raw identifier; the
    // `default` case becomes an empty struct variant.
    test_conversion(
      Rule::variant_ty,
      convert_variant,
      r#"variant AssetMetadata switch (AssetMetadataType assetType) {
    case IMAGE: AssetMetadataImage;
    case MOVE: Thing;
    default: void; // Note, if badged, we should urge an upgrade here.
}"#,
      &format!(
        "#[derive({})]\n#[serde(untagged)]\npub enum AssetMetadata {{
  Image {{image: AssetMetadataImage}},
  Move {{r#move: Thing}},
  Default {{}},
}}",
        DERIVES
      ),
    )
    .unwrap();
  }

  // Regression test: a `..`-relative import must climb one extra level out of the
  // generated module.
  #[test]
  fn test_import_bug() {
    assert_eq!(
      "super::super::keybase1;",
      convert_path_str_to_rust_mod("../keybase1", "keybase1")
    );
  }
}