use super::*;
/// The discriminant of a union: the variable named inside `switch (...)`
/// together with its declared type.
#[derive(PartialEq, Debug)]
pub struct UnionSwitch<'a> {
    /// Name of the discriminant variable (e.g. `mode` in
    /// `switch (createmode4 mode)`).
    pub var_name: String,
    /// Declared type of the discriminant (e.g. `createmode4`).
    pub var_type: BasicType<'a>,
}
/// A parsed XDR `union` declaration.
#[derive(PartialEq, Debug)]
pub struct Union<'a> {
    /// The union's declared name.
    pub name: String,
    /// Non-default arms that declare a field, in declaration order.
    pub cases: Vec<UnionCase<'a>>,
    /// The `default:` arm when it declares a field. A `void` default arm is
    /// not stored here; its label `"default"` goes into `void_cases`.
    pub default: Option<UnionCase<'a>>,
    /// Labels of every arm whose body is `void`.
    pub void_cases: Vec<String>,
    /// The discriminant this union switches on.
    pub switch: UnionSwitch<'a>,
}
impl<'a> Union<'a> {
    /// Builds a [`Union`] from the child nodes of a `union` declaration.
    ///
    /// The first three nodes are read via `ident_str`: index 0 is the union
    /// name, index 1 the discriminant's type and index 2 the discriminant's
    /// variable name. Every remaining node must be a `Node::UnionCase` or a
    /// `Node::UnionDefault`, and they are processed in order.
    ///
    /// Labels of arms without a body (fall-through) are carried forward and
    /// merged into the next arm. Arms declaring a field go to `cases` (or to
    /// `default` for the `default:` arm); `void` arms only contribute their
    /// labels to `void_cases`. The default arm is labelled `"default"`.
    ///
    /// # Panics
    ///
    /// Panics if `vs` has fewer than three nodes, or if a node after the
    /// header is neither a union case nor a union default.
    pub fn new(vs: Vec<Node<'a>>) -> Self {
        let name = vs[0].ident_str().to_string();
        let switch = UnionSwitch {
            var_name: vs[2].ident_str().to_string(),
            var_type: BasicType::from(vs[1].ident_str().to_string()),
        };

        let mut cases = Vec::new();
        let mut default = None;
        let mut void_cases: Vec<String> = Vec::new();
        // Labels from fall-through `case X:` lines still waiting for an arm.
        let mut pending: Vec<String> = Vec::new();

        for node in vs.into_iter().skip(3) {
            let (is_default, stmt) = match node {
                Node::UnionCase(body) => {
                    let labels = std::mem::take(&mut pending);
                    (false, CaseStmt::parse(labels, body))
                }
                Node::UnionDefault(body) => {
                    let mut labels = std::mem::take(&mut pending);
                    labels.push("default".to_string());
                    (true, CaseStmt::parse(labels, body))
                }
                other => panic!("unexpected token type for union {:?}", other),
            };

            match stmt {
                // No body yet: keep the labels for the next arm.
                CaseStmt::Fallthrough(labels) => pending = labels,
                CaseStmt::Void(labels) => void_cases.extend(labels),
                CaseStmt::Defined(arm) if is_default => default = Some(arm),
                CaseStmt::Defined(arm) => cases.push(arm),
            }
        }

        Self {
            name,
            cases,
            default,
            void_cases,
            switch,
        }
    }

    /// Returns the union's declared name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
impl<'a> CompoundType for Union<'a> {
fn inner_types(&self) -> Vec<&ArrayType<BasicType>> {
self.cases
.iter()
.chain(self.default.iter())
.map(|f| &f.field_value)
.collect()
}
fn contains_opaque(&self) -> bool {
self.cases
.iter()
.chain(self.default.iter())
.any(|f| f.contains_opaque())
}
}
/// One arm of a union that declares a field.
#[derive(PartialEq, Debug)]
pub struct UnionCase<'a> {
    /// Labels selecting this arm, including any fall-through labels that
    /// preceded it. The default arm carries the label `"default"`.
    pub case_values: Vec<String>,
    /// Name of the field declared by this arm.
    pub field_name: String,
    /// Type of the field declared by this arm.
    pub field_value: ArrayType<BasicType<'a>>,
}
impl<'a> UnionCase<'a> {
    /// Builds an arm from its labels and the tokens of its field declaration.
    ///
    /// `field` must be exactly two nodes: the field's type followed by an
    /// identifier naming the field (as in `fattr4 createattrs;`). The field
    /// type is stored as `ArrayType::None(..)`.
    ///
    /// # Panics
    ///
    /// Panics if `field` is not a `[Node::Type(_), Node::Type(BasicType::Ident(_))]`
    /// pair. The panic message includes the offending tokens.
    pub fn new(case_values: Vec<String>, field: Vec<Node<'a>>) -> Self {
        match field.as_slice() {
            [Node::Type(ty), Node::Type(BasicType::Ident(ident))] => Self {
                case_values,
                field_name: ident.to_string(),
                field_value: ArrayType::None(ty.to_owned()),
            },
            // Also reached when the count is right but a token has the wrong
            // kind, so include the tokens to make the failure diagnosable.
            _ => panic!("invalid number of union field tokens: {:?}", field),
        }
    }

    /// Reports whether the type returned by `unwrap_array` for this arm's
    /// field is `BasicType::Opaque`.
    pub fn contains_opaque(&self) -> bool {
        matches!(self.field_value.unwrap_array(), BasicType::Opaque)
    }
}
/// Result of classifying the body of a single `case`/`default` arm.
enum CaseStmt<'a> {
    /// The arm has no body. Its labels carry over to the next arm.
    Fallthrough(Vec<String>),
    /// The arm declares a field.
    Defined(UnionCase<'a>),
    /// The arm's body is `void`. Only its labels are kept.
    Void(Vec<String>),
}
impl<'a> CaseStmt<'a> {
    /// Classifies the body of one `case`/`default` arm.
    ///
    /// `case_values` holds the labels accumulated so far: labels from
    /// preceding fall-through arms and, for the default arm, `"default"`.
    ///
    /// `nodes` is the arm's child nodes. It is either an optional leading
    /// `Node::Type` label followed by an optional `Node::UnionDataField` or
    /// `Node::UnionVoid`, or just the data field / void on its own.
    ///
    /// A label with nothing after it yields [`CaseStmt::Fallthrough`]. Any
    /// nodes past the second are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `nodes` is empty or contains tokens outside that shape.
    fn parse(mut case_values: Vec<String>, nodes: Vec<Node<'a>>) -> Self {
        // Walk the nodes front-to-back instead of repeatedly shifting the Vec.
        let mut nodes = nodes.into_iter();

        match nodes.next() {
            Some(Node::Type(t)) => case_values.push(t.as_str().to_string()),
            Some(Node::UnionVoid) => return Self::Void(case_values),
            Some(Node::UnionDataField(field)) => {
                return Self::Defined(UnionCase::new(case_values, field));
            }
            None => unreachable!("empty union case body"),
            Some(other) => unreachable!("unexpected union case token {:?}", other),
        }

        match nodes.next() {
            // A bare label falls through to the next arm.
            None => Self::Fallthrough(case_values),
            Some(Node::UnionDataField(field)) => Self::Defined(UnionCase::new(case_values, field)),
            Some(Node::UnionVoid) => Self::Void(case_values),
            Some(other) => unreachable!("unexpected union case token {:?}", other),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{walk, Rule, XDRParser};
    use pest::Parser;
    use std::borrow::Cow;

    // Parses `$input` as one XDR item and yields the `Union` at the root of
    // the walked AST, panicking if the root holds anything else.
    macro_rules! parse {
        ($input: expr) => {{
            let pair = XDRParser::parse(Rule::item, $input)
                .unwrap()
                .next()
                .unwrap();
            let tree = walk(pair).unwrap();
            let first = tree.into_inner().remove(0);
            match first {
                Node::Union(u) => u,
                _ => panic!("not a union in ast root"),
            }
        }};
    }

    /// Asserts an arm's labels, field name and (non-array identifier) type.
    fn assert_arm(arm: &UnionCase, labels: &[&str], field_name: &str, type_name: &str) {
        assert_eq!(arm.case_values, labels);
        assert_eq!(arm.field_name, field_name);
        assert_eq!(
            arm.field_value,
            ArrayType::None(BasicType::Ident(Cow::from(type_name)))
        );
    }

    /// Asserts the discriminant's variable name and identifier type.
    fn assert_switch(switch: &UnionSwitch, var_name: &str, type_name: &str) {
        assert_eq!(switch.var_name, var_name);
        assert_eq!(switch.var_type, BasicType::Ident(Cow::from(type_name)));
    }

    #[test]
    fn test_union() {
        let got = parse!(
            r#"
union createhow4 switch (createmode4 mode) {
case GUARDED4:
fattr4 createattrs;
case EXCLUSIVE4:
verifier4 createverf;
};"#
        );
        assert_eq!(got.name, "createhow4");
        assert!(got.default.is_none());
        assert!(got.void_cases.is_empty());
        assert_eq!(got.cases.len(), 2);
        assert_arm(&got.cases[0], &["GUARDED4"], "createattrs", "fattr4");
        assert_arm(&got.cases[1], &["EXCLUSIVE4"], "createverf", "verifier4");
        assert_switch(&got.switch, "mode", "createmode4");
    }

    #[test]
    fn test_union_fallthrough() {
        let got = parse!(
            r#"
union createhow4 switch (createmode4 mode) {
case UNCHECKED4:
case GUARDED4:
fattr4 createattrs;
case EXCLUSIVE4:
verifier4 createverf;
};"#
        );
        assert_eq!(got.name, "createhow4");
        assert!(got.default.is_none());
        assert!(got.void_cases.is_empty());
        assert_eq!(got.cases.len(), 2);
        // The bare UNCHECKED4 label merges into the GUARDED4 arm.
        assert_arm(
            &got.cases[0],
            &["UNCHECKED4", "GUARDED4"],
            "createattrs",
            "fattr4",
        );
        assert_arm(&got.cases[1], &["EXCLUSIVE4"], "createverf", "verifier4");
        assert_switch(&got.switch, "mode", "createmode4");
    }

    #[test]
    fn test_union_void_default() {
        let got = parse!(
            r#"
union LOCKU4res switch (nfsstat4 status) {
case NFS4_OK:
stateid4 lock_stateid;
default:
void;
};"#
        );
        assert_eq!(got.name, "LOCKU4res");
        // A void default is recorded as a void label, not as `default`.
        assert!(got.default.is_none());
        assert_eq!(got.cases.len(), 1);
        assert_arm(&got.cases[0], &["NFS4_OK"], "lock_stateid", "stateid4");
        assert_eq!(got.void_cases, ["default"]);
        assert_switch(&got.switch, "status", "nfsstat4");
    }

    #[test]
    fn test_union_default() {
        let got = parse!(
            r#"
union LOCKU4res switch (nfsstat4 status) {
case NFS4_OK:
stateid4 lock_stateid;
default:
type_name field_name;
};"#
        );
        assert_eq!(got.name, "LOCKU4res");
        assert_eq!(got.cases.len(), 1);
        assert_arm(&got.cases[0], &["NFS4_OK"], "lock_stateid", "stateid4");
        assert!(got.void_cases.is_empty());
        let default = got.default.as_ref().unwrap();
        assert_arm(default, &["default"], "field_name", "type_name");
        assert_switch(&got.switch, "status", "nfsstat4");
    }

    #[test]
    fn test_union_case_void() {
        let got = parse!(
            r#"
union LOCKU4res switch (nfsstat4 status) {
case NFS4_OK:
stateid4 lock_stateid;
case something:
void;
default:
type_name field_name;
};"#
        );
        assert_eq!(got.name, "LOCKU4res");
        assert_eq!(got.cases.len(), 1);
        assert_arm(&got.cases[0], &["NFS4_OK"], "lock_stateid", "stateid4");
        assert_eq!(got.void_cases, ["something"]);
        let default = got.default.as_ref().unwrap();
        assert_arm(default, &["default"], "field_name", "type_name");
        assert_switch(&got.switch, "status", "nfsstat4");
    }

    #[test]
    fn test_union_case_void_fallthrough() {
        let got = parse!(
            r#"
union LOCKU4res switch (nfsstat4 status) {
case NFS4_OK:
stateid4 lock_stateid;
case another:
case something:
void;
default:
type_name field_name;
};"#
        );
        assert_eq!(got.name, "LOCKU4res");
        assert_eq!(got.cases.len(), 1);
        assert_arm(&got.cases[0], &["NFS4_OK"], "lock_stateid", "stateid4");
        // Both labels of the fall-through run end up as void labels.
        assert_eq!(got.void_cases, ["another", "something"]);
        let default = got.default.as_ref().unwrap();
        assert_arm(default, &["default"], "field_name", "type_name");
        assert_switch(&got.switch, "status", "nfsstat4");
    }
}