airs_types/tensor/data_type/
der.rs1use std::fmt::Formatter;
2
3use serde::de::{Error, Visitor};
4
5use super::*;
6
/// Serde visitor that turns a data-type name (a string) into a [`DataType`].
///
/// Carries no state, so it is zero-sized and trivially copyable.
#[derive(Debug, Clone, Copy)]
struct DataTypeVisitor;
8
9impl<'de> Deserialize<'de> for DataType {
10 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: Deserializer<'de> {
11 deserializer.deserialize_any(DataTypeVisitor)
12 }
13}
14
15impl<'de> Visitor<'de> for DataTypeVisitor {
16 type Value = DataType;
17
18 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
19 formatter.write_str("except string")
20 }
21 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E> where E: Error {
22 if eq_insensitive(v, &["f32", "fp32", "float32"]) {
23 Ok(DataType::Float32)
24 } else if eq_insensitive(v, &["u8", "uint8", "unsigned8"]) {
25 Ok(DataType::Unsigned8)
26 } else {
27 Err(Error::custom(format!("unknown data type: {}", v)))
28 }
29 }
30}
31
/// Returns `true` if `input` equals any entry of `set`, ignoring ASCII case.
///
/// Only ASCII letters are case-folded (see [`str::eq_ignore_ascii_case`]);
/// non-ASCII characters must match exactly. An empty `set` never matches.
fn eq_insensitive(input: &str, set: &[&str]) -> bool {
    set.iter().any(|candidate| candidate.eq_ignore_ascii_case(input))
}