use crate::model::{ParsingContext, TensorPlusPath};
use crate::pb::tensor_proto::DataType;
use crate::pb::*;
use prost::Message;
use std::convert::{TryFrom, TryInto};
use std::fs;
use std::path::{Path, PathBuf};
use tract_hir::internal::*;
/// Maps an ONNX protobuf element type onto the matching tract `DatumType`.
impl TryFrom<DataType> for DatumType {
    type Error = TractError;

    /// Converts `t`, failing for any ONNX type not listed in the table below.
    fn try_from(t: DataType) -> TractResult<DatumType> {
        let datum_type = match t {
            // Booleans and unsigned integers.
            DataType::Bool => DatumType::Bool,
            DataType::Uint8 => DatumType::U8,
            DataType::Uint16 => DatumType::U16,
            DataType::Uint32 => DatumType::U32,
            DataType::Uint64 => DatumType::U64,
            // Signed integers.
            DataType::Int8 => DatumType::I8,
            DataType::Int16 => DatumType::I16,
            DataType::Int32 => DatumType::I32,
            DataType::Int64 => DatumType::I64,
            // Floating point.
            DataType::Float16 => DatumType::F16,
            DataType::Float => DatumType::F32,
            DataType::Double => DatumType::F64,
            // Text.
            DataType::String => DatumType::String,
            _ => bail!("Unknown DatumType {:?}", t),
        };
        Ok(datum_type)
    }
}
pub fn translate_inference_fact(
    ctx: &ParsingContext,
    t: &type_proto::Tensor,
) -> TractResult<InferenceFact> {
    let mut fact = InferenceFact::default();
    fact = fact.with_datum_type(DataType::from_i32(t.elem_type).unwrap().try_into()?);
    if let Some(shape) = &t.shape {
        let shape: TVec<DimFact> = shape
            .dim
            .iter()
            .map(|d| match &d.value {
                Some(tensor_shape_proto::dimension::Value::DimValue(v)) if *v >= 0 => {
                    DimFact::from(v.to_dim())
                }
                Some(tensor_shape_proto::dimension::Value::DimParam(v)) => {
                    let sym = ctx.symbol_table.sym(v);
                    DimFact::from(sym.to_dim())
                }
                _ => DimFact::default(),
            })
            .collect();
        fact = fact.with_shape(ShapeFactoid::closed(shape));
    }
    Ok(fact)
}
/// Appends the whole content of the file at `p` to `buf` (wasm targets).
///
/// The file is streamed through a `BufReader`; `buf` is grown up front by the
/// size reported in the file metadata when its capacity is insufficient.
#[cfg(target_family="wasm")]
fn extend_bytes_from_path(buf: &mut Vec<u8>, p: impl AsRef<Path>) -> TractResult<()> {
    use std::io::BufRead;
    let handle = fs::File::open(p)?;
    let expected_len = handle.metadata()?.len() as usize;
    if buf.capacity() < expected_len + buf.len() {
        buf.reserve(expected_len);
    }
    let mut reader = std::io::BufReader::new(handle);
    loop {
        // An empty chunk from `fill_buf` signals end of file.
        let chunk = reader.fill_buf()?;
        if chunk.is_empty() {
            break;
        }
        buf.extend_from_slice(chunk);
        let consumed = chunk.len();
        reader.consume(consumed);
    }
    Ok(())
}
/// Appends the whole content of the file at `p` to `buf` (windows / unix,
/// emscripten excluded).
///
/// The file is memory-mapped and its bytes are copied into `buf`.
#[cfg(all(any(windows, unix), not(target_os = "emscripten")))]
fn extend_bytes_from_path(buf: &mut Vec<u8>, p: impl AsRef<Path>) -> TractResult<()> {
    let handle = fs::File::open(p)?;
    // SAFETY: the mapping is only read, and only for the duration of the copy
    // below. NOTE(review): this presumes nothing truncates or rewrites the
    // file while it is mapped — not enforced here, confirm with callers.
    let mapped = unsafe { memmap2::Mmap::map(&handle) }?;
    buf.extend_from_slice(&mapped);
    Ok(())
}
fn get_external_resources(t: &TensorProto, path: &str) -> TractResult<Vec<u8>> {
    let mut tensor_data: Vec<u8> = Vec::new();
    trace!("number of external file needed for this tensor: {}", t.external_data.len());
    for external_data in t.external_data.iter()
    {
        let p = PathBuf::from(format!("{}/{}", path, external_data.value));
        trace!("external file detected: {:?}", p);
        extend_bytes_from_path(&mut tensor_data, p)?;
        trace!("external file loaded");
    }
    Ok(tensor_data)
}
fn create_tensor(shape: Vec<usize>, dt: DatumType, data: &[u8]) -> TractResult<Tensor> {
    unsafe {
        match dt {
            DatumType::U8 => Tensor::from_raw::<u8>(&shape, data),
            DatumType::U16 => Tensor::from_raw::<u16>(&shape, data),
            DatumType::U32 => Tensor::from_raw::<u32>(&shape, data),
            DatumType::U64 => Tensor::from_raw::<u64>(&shape, data),
            DatumType::I8 => Tensor::from_raw::<i8>(&shape, data),
            DatumType::I16 => Tensor::from_raw::<i16>(&shape, data),
            DatumType::I32 => Tensor::from_raw::<i32>(&shape, data),
            DatumType::I64 => Tensor::from_raw::<i64>(&shape, data),
            DatumType::F16 => Tensor::from_raw::<f16>(&shape, data),
            DatumType::F32 => Tensor::from_raw::<f32>(&shape, data),
            DatumType::F64 => Tensor::from_raw::<f64>(&shape, data),
            DatumType::Bool => Ok(Tensor::from_raw::<u8>(&shape, data)?
                .into_array::<u8>()?
                .mapv(|x| x != 0)
                .into()),
            _ => unimplemented!("FIXME, raw tensor loading"),
        }
    }
}
fn common_tryfrom(t: &TensorProto, path: Option<&str>) -> TractResult<Tensor> {
    let dt = DataType::from_i32(t.data_type).unwrap().try_into()?;
    let shape: Vec<usize> = t.dims.iter().map(|&i| i as usize).collect();
    let is_external = t.data_location.is_some() && t.data_location == Some(1);
    if t.raw_data.len() > 0 {
        create_tensor(shape, dt, &t.raw_data)
    } else if is_external {
        if let Some(model_path) = path {
            let external_data = get_external_resources(t, model_path)?;
            create_tensor(shape, dt, &external_data)
        } else {
            bail!("no model path was specified in the parsing context, yet external data was detected. aborting");
        }
    } else {
        use tract_ndarray::Array;
        let it = match dt {
            DatumType::Bool => {
                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x != 0).collect())?
                    .into()
            }
            DatumType::U8 => {
                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as u8).collect())?
                    .into()
            }
            DatumType::U16 => {
                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as u16).collect())?
                    .into()
            }
            DatumType::U32 => Array::from_shape_vec(&*shape, t.int32_data.to_vec())?.into(),
            DatumType::U64 => Array::from_shape_vec(&*shape, t.int64_data.to_vec())?.into(),
            DatumType::I8 => {
                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as i8).collect())?
                    .into()
            }
            DatumType::I16 => {
                Array::from_shape_vec(&*shape, t.int32_data.iter().map(|&x| x as i16).collect())?
                    .into()
            }
            DatumType::I32 => Array::from_shape_vec(&*shape, t.int32_data.to_vec())?.into(),
            DatumType::I64 => Array::from_shape_vec(&*shape, t.int64_data.to_vec())?.into(),
            DatumType::F32 => Array::from_shape_vec(&*shape, t.float_data.to_vec())?.into(),
            DatumType::F64 => Array::from_shape_vec(&*shape, t.double_data.to_vec())?.into(),
            DatumType::String => {
                let strings = t
                    .string_data
                    .iter()
                    .cloned()
                    .map(String::from_utf8)
                    .collect::<Result<Vec<String>, _>>()
                    .context("Invalid UTF8 buffer")?;
                Array::from_shape_vec(&*shape, strings)?.into()
            }
            _ => unimplemented!("FIXME, struct tensor loading"),
        };
        Ok(it)
    }
}
/// Converts a tensor proto bundled with its model path, which allows
/// externally stored tensor data to be located.
impl TryFrom<TensorPlusPath<'_>> for Tensor {
    type Error = TractError;

    /// Forwards to `common_tryfrom` with the model path set.
    fn try_from(st: TensorPlusPath) -> TractResult<Tensor> {
        let model_dir = Some(st.model_path);
        common_tryfrom(st.tensor, model_dir)
    }
}
/// Converts a borrowed tensor proto without any model path; tensors flagged
/// as using external data are rejected by `common_tryfrom` in that case.
impl<'a> TryFrom<&'a TensorProto> for Tensor {
    type Error = TractError;

    /// Forwards to `common_tryfrom` with no model path.
    fn try_from(t: &TensorProto) -> TractResult<Tensor> {
        let no_model_dir: Option<&str> = None;
        common_tryfrom(t, no_model_dir)
    }
}
impl TryFrom<TensorProto> for Tensor {
    type Error = TractError;
    fn try_from(t: TensorProto) -> TractResult<Tensor> {
        (&t).try_into()
    }
}
/// Reads `r` to its end and decodes the bytes as a protobuf `TensorProto`.
///
/// # Errors
///
/// Fails if reading from `r` fails or if the bytes are not a valid
/// `TensorProto` message.
pub fn proto_from_reader<R: ::std::io::Read>(mut r: R) -> TractResult<TensorProto> {
    let mut raw = Vec::new();
    r.read_to_end(&mut raw)?;
    let decoded = TensorProto::decode(bytes::Bytes::from(raw));
    decoded.context("Can not parse protobuf input")
}
pub fn from_reader<R: ::std::io::Read>(r: R) -> TractResult<Tensor> {
    proto_from_reader(r)?.try_into()
}