use nom::IResult;
use std::collections::HashMap;
use std::io::Result;
/// Element type of an array, as described by the `'descr'` entry of a `.npy` header.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum DType {
    /// A single NumPy type string with an optional sub-array shape.
    ///
    /// `DType::descr` only renders `shape` when this dtype is a record field.
    Plain {
        /// NumPy type string, e.g. `>f4` or `<u1`.
        ty: String,
        /// Sub-array dimensions; empty for a scalar.
        shape: Vec<u64>
    },
    /// A structured dtype made of named fields, kept in declaration order.
    Record(Vec<Field>)
}

/// One named member of a `DType::Record`.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Field {
    /// Field name as it appears in the header.
    pub name: String,
    /// Type of this field.
    pub dtype: DType
}
impl DType {
pub fn descr(&self) -> String {
use DType::*;
match *self {
Record(ref fields) =>
fields.iter()
.map(|&Field { ref name, ref dtype }|
match *dtype {
Plain { ref ty, ref shape } =>
if shape.len() == 0 {
format!("('{}', '{}'), ", name, ty)
} else {
let shape_str = shape.iter().fold(String::new(), |o,n| o + &format!("{},", n));
format!("('{}', '{}', ({})), ", name, ty, shape_str)
},
Record(_) => unimplemented!("nested record dtypes")
}
)
.fold("[".to_string(), |o, n| o + &n) + "]",
Plain { ref ty, .. } => format!("'{}'", ty),
}
}
pub fn from_descr(descr: Value) -> Result<Self> {
use DType::*;
match descr {
Value::String(string) => Ok(Plain { ty: string, shape: vec![] }),
Value::List(ref list) => Ok(Record(convert_list_to_record_fields(list)?)),
_ => invalid_data("must be string or list")
}
}
}
/// Converts every entry of a record `'descr'` list into a `Field`.
///
/// Each entry must itself be a `Value::List` (a Python list or tuple). The first entry
/// that is not a list, or that `convert_tuple_to_record_field` rejects, aborts the whole
/// conversion with that error.
fn convert_list_to_record_fields(values: &[Value]) -> Result<Vec<Field>> {
    let mut fields = Vec::with_capacity(values.len());
    for entry in values {
        let field = match *entry {
            Value::List(ref tuple) => convert_tuple_to_record_field(tuple)?,
            _ => return invalid_data("list must contain list or tuple"),
        };
        fields.push(field);
    }
    Ok(fields)
}
fn convert_tuple_to_record_field(tuple: &[Value]) -> Result<Field> {
use self::Value::String;
match tuple.len() {
2 | 3 => match (&tuple[0], &tuple[1]) {
(&String(ref name), &String(ref dtype)) =>
Ok(Field { name: name.clone(), dtype: DType::Plain {
ty: dtype.clone(),
shape: if tuple.len() == 2 {
vec![]
} else {
convert_value_to_shape(&tuple[2])?
}
} }),
_ => invalid_data("list entry must contain strings for id and dtype")
},
_ => invalid_data("list entry must contain 2 or 3 items")
}
}
/// Converts a parsed shape (a Python list or tuple of integers) into dimensions.
///
/// Fails with `"shape must be list or tuple"` when `field` is not a list. Otherwise it
/// fails with the error of the first element that is not a positive integer.
fn convert_value_to_shape(field: &Value) -> Result<Vec<u64>> {
    match *field {
        Value::List(ref lengths) => lengths
            .iter()
            .map(convert_value_to_positive_integer)
            .collect(),
        _ => invalid_data("shape must be list or tuple"),
    }
}
/// Converts one shape dimension into a `u64`.
///
/// Only strictly positive `Value::Integer`s are accepted. Zero and negative integers fail
/// with `"number must be positive"`; any other variant fails with `"must be a number"`.
fn convert_value_to_positive_integer(number: &Value) -> Result<u64> {
    match *number {
        // The guard guarantees `n > 0`, so the cast to u64 is lossless.
        Value::Integer(n) if n > 0 => Ok(n as u64),
        Value::Integer(_) => invalid_data("number must be positive"),
        _ => invalid_data("must be a number"),
    }
}
/// Collects a sequence of results into a vector.
///
/// Iteration stops at the first `Err`, which is returned. Later items are never pulled
/// from the iterator.
fn first_error<I, T>(results: I) -> Result<Vec<T>>
where
    I: IntoIterator<Item = Result<T>>,
{
    results.into_iter().collect()
}
/// Returns an `Err` holding an `io::Error` of kind `InvalidData` whose message is
/// `message`.
fn invalid_data<T>(message: &str) -> Result<T> {
    use std::io::{Error, ErrorKind};
    let error = Error::new(ErrorKind::InvalidData, String::from(message));
    Err(error)
}
/// A Python literal as produced by the header parser.
///
/// Python lists and tuples both map to `List`. Dictionaries map to `Map`, whose keys are
/// always strings.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Value {
    /// A quoted string literal, without its quotes.
    String(String),
    /// An integer literal.
    Integer(i64),
    /// `True` or `False`.
    Bool(bool),
    /// A Python list `[...]` or tuple `(...)`.
    List(Vec<Value>),
    /// A Python dict with string keys.
    Map(HashMap<String, Value>),
}
pub fn parse_header(bs: &[u8]) -> IResult<&[u8], Value> {
parser::header(bs)
}
mod parser {
use super::Value;
use nom::*;
// `.npy` preamble: the magic byte 0x93, the ASCII text `NUMPY`, then two version bytes.
// Only version 1.0 (`0x01 0x00`) is matched; any other version makes the `tag!` fail.
// A little-endian u16 header length follows, and exactly that many bytes are handed
// to `item` to be parsed as a Python literal.
// NOTE(review): NumPy format versions 2.0/3.0 use a 4-byte header length and are
// therefore rejected here. Confirm that is intended.
named!(pub header<Value>,
do_parse!(
tag!(&[0x93u8]) >>
tag!(b"NUMPY") >>
tag!(&[0x01u8, 0x00]) >>
hdr: length_value!(le_u16, item) >>
(hdr)
)
);
// Integer literal: ASCII digits only (no sign), with surrounding whitespace skipped by
// `ws!`. The digits are parsed into an i64 via `FromStr`, so an out-of-range literal
// makes `map_res!` (and hence this parser) fail.
named!(pub integer<Value>,
map!(
map_res!(
map_res!(
ws!(digit),
::std::str::from_utf8
),
::std::str::FromStr::from_str
),
Value::Integer
)
);
// Python boolean literal: `True` or `False`, surrounded by optional whitespace.
named!(pub boolean<Value>,
ws!(alt!(
tag!("True") => { |_| Value::Bool(true) } |
tag!("False") => { |_| Value::Bool(false) }
))
);
// String literal in double or single quotes. The quotes are stripped and the contents
// must be valid UTF-8. There is no escape handling: the body simply runs up to the next
// matching quote character.
// NOTE(review): `is_not_s!` presumably needs at least one character, so empty literals
// (`''`) may be rejected. Confirm against the nom version in use.
named!(pub string<Value>,
map!(
map!(
map_res!(
ws!(alt!(
delimited!(tag!("\""),
is_not_s!("\""),
tag!("\"")) |
delimited!(tag!("\'"),
is_not_s!("\'"),
tag!("\'"))
)),
::std::str::from_utf8
),
|s: &str| s.to_string()
),
Value::String
)
);
// Any supported Python literal. Alternatives are tried in this order: integer, boolean,
// string, list, map.
named!(pub item<Value>, alt!(integer | boolean | string | list | map));
// Python list `[...]` or tuple `(...)` of comma-separated items. Both forms produce
// `Value::List`, so lists and tuples are not distinguished. A single trailing comma is
// accepted; the `tag!("")` branch lets it be absent.
named!(pub list<Value>,
map!(
ws!(alt!(
delimited!(tag!("["),
terminated!(separated_list!(tag!(","), item), alt!(tag!(",") | tag!(""))),
tag!("]")) |
delimited!(tag!("("),
terminated!(separated_list!(tag!(","), item), alt!(tag!(",") | tag!(""))),
tag!(")"))
)),
Value::List
)
);
// Python dict `{...}` of `key: item` pairs, with an optional trailing comma. Keys are
// parsed with `string` and unwrapped by `map_opt!` into plain `String`s. The pairs are
// collected into a `HashMap`, so for a repeated key the last value wins.
named!(pub map<Value>,
map!(
ws!(
delimited!(tag!("{"),
terminated!(separated_list!(tag!(","),
separated_pair!(map_opt!(string, |it| match it { Value::String(s) => Some(s), _ => None }), tag!(":"), item)
), alt!(tag!(",") | tag!(""))),
tag!("}"))
),
|v: Vec<_>| Value::Map(v.into_iter().collect())
)
);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn description_of_record_array_as_python_list_of_tuples() {
        let dtype = DType::Record(vec![
            Field {
                name: "float".to_string(),
                dtype: DType::Plain { ty: ">f4".to_string(), shape: vec![] }
            },
            Field {
                name: "byte".to_string(),
                dtype: DType::Plain { ty: "<u1".to_string(), shape: vec![] }
            }
        ]);
        let expected = "[('float', '>f4'), ('byte', '<u1'), ]";
        assert_eq!(dtype.descr(), expected);
    }

    #[test]
    fn description_of_record_array_with_multidimensional_field_shape() {
        let dtype = DType::Record(vec![
            Field {
                name: "m".to_string(),
                dtype: DType::Plain { ty: "<i4".to_string(), shape: vec![2, 3] }
            }
        ]);
        assert_eq!(dtype.descr(), "[('m', '<i4', (2,3,)), ]");
    }

    #[test]
    fn description_of_unstructured_primitive_array() {
        let dtype = DType::Plain { ty: ">f8".to_string(), shape: vec![] };
        assert_eq!(dtype.descr(), "'>f8'");
    }

    #[test]
    fn converts_simple_description_to_record_dtype() {
        let dtype = ">f8".to_string();
        assert_eq!(
            DType::from_descr(Value::String(dtype.clone())).unwrap(),
            DType::Plain { ty: dtype, shape: vec![] }
        );
    }

    #[test]
    fn converts_record_description_to_record_dtype() {
        let descr = parse("[('a', '<u2'), ('b', '<f4')]");
        let expected_dtype = DType::Record(vec![
            Field {
                name: "a".to_string(),
                dtype: DType::Plain { ty: "<u2".to_string(), shape: vec![] }
            },
            Field {
                name: "b".to_string(),
                dtype: DType::Plain { ty: "<f4".to_string(), shape: vec![] }
            }
        ]);
        assert_eq!(DType::from_descr(descr).unwrap(), expected_dtype);
    }

    #[test]
    fn record_description_with_one_dimensional_field_shape_declaration() {
        let descr = parse("[('a', '>f8', (1,))]");
        let expected_dtype = DType::Record(vec![
            Field {
                name: "a".to_string(),
                dtype: DType::Plain { ty: ">f8".to_string(), shape: vec![1] }
            }
        ]);
        assert_eq!(DType::from_descr(descr).unwrap(), expected_dtype);
    }

    #[test]
    fn record_description_with_multidimensional_field_shape_declaration() {
        // Built directly from `Value`s so this test does not depend on the parser.
        let descr = Value::List(vec![Value::List(vec![
            Value::String("m".to_string()),
            Value::String("<i4".to_string()),
            Value::List(vec![Value::Integer(2), Value::Integer(3)]),
        ])]);
        let expected_dtype = DType::Record(vec![
            Field {
                name: "m".to_string(),
                dtype: DType::Plain { ty: "<i4".to_string(), shape: vec![2, 3] }
            }
        ]);
        assert_eq!(DType::from_descr(descr).unwrap(), expected_dtype);
    }

    #[test]
    fn errors_on_value_variants_that_cannot_be_converted() {
        let no_dtype = Value::Bool(false);
        assert!(DType::from_descr(no_dtype).is_err());
    }

    #[test]
    fn errors_on_map_description() {
        let map_descr = Value::Map(::std::collections::HashMap::new());
        assert!(DType::from_descr(map_descr).is_err());
    }

    #[test]
    fn errors_when_record_list_does_not_contain_lists() {
        let faulty_list = parse("['a', 123]");
        assert!(DType::from_descr(faulty_list).is_err());
    }

    #[test]
    fn errors_when_record_list_entry_contains_too_few_items() {
        let faulty_list = parse("[('a')]");
        assert!(DType::from_descr(faulty_list).is_err());
    }

    #[test]
    fn errors_when_record_list_entry_contains_too_many_items() {
        let faulty_list = parse("[('a', 1, 2, 3)]");
        assert!(DType::from_descr(faulty_list).is_err());
    }

    #[test]
    fn errors_when_record_list_entry_contains_non_strings_for_id_or_dtype() {
        let faulty_list = parse("[(1, 2)]");
        assert!(DType::from_descr(faulty_list).is_err());
    }

    #[test]
    fn errors_when_shape_is_not_a_list() {
        let no_shape = parse("1");
        assert!(convert_value_to_shape(&no_shape).is_err());
    }

    #[test]
    fn errors_when_shape_contains_non_number() {
        let bad_shape = Value::List(vec![Value::Integer(2), Value::Bool(true)]);
        assert!(convert_value_to_shape(&bad_shape).is_err());
    }

    #[test]
    fn errors_when_shape_number_is_not_a_number() {
        let no_number = parse("[]");
        assert!(convert_value_to_positive_integer(&no_number).is_err());
    }

    #[test]
    fn errors_when_shape_number_is_not_positive() {
        assert!(convert_value_to_positive_integer(&parse("0")).is_err());
    }

    #[test]
    fn errors_when_shape_number_is_negative() {
        assert!(convert_value_to_positive_integer(&Value::Integer(-1)).is_err());
    }

    // Parses a Python literal with the header parser, panicking if it cannot be parsed.
    fn parse(source: &str) -> Value {
        parser::item(source.as_bytes())
            .to_result()
            .expect("could not parse Python expression")
    }
}