airs_types/tensor/data_type/der.rs

1use std::fmt::Formatter;
2
3use serde::de::{Error, Visitor};
4
5use super::*;
6
7struct DataTypeVisitor;
8
9impl<'de> Deserialize<'de> for DataType {
10    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: Deserializer<'de> {
11        deserializer.deserialize_any(DataTypeVisitor)
12    }
13}
14
15impl<'de> Visitor<'de> for DataTypeVisitor {
16    type Value = DataType;
17
18    fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
19        formatter.write_str("except string")
20    }
21    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> where E: Error {
22        if eq_insensitive(v, &["f32", "fp32", "float32"]) {
23            Ok(DataType::Float32)
24        } else if eq_insensitive(v, &["u8", "uint8", "unsigned8"]) {
25            Ok(DataType::Unsigned8)
26        } else {
27            Err(Error::custom(format!("unknown data type: {}", v)))
28        }
29    }
30}
31
/// Returns `true` if `input` equals any entry of `set`, ignoring ASCII case.
///
/// Non-ASCII characters must match exactly. An empty `set` never matches.
fn eq_insensitive(input: &str, set: &[&str]) -> bool {
    set.iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(input))
}