use anyhow::{anyhow, Result};
use byteorder::{BigEndian, LittleEndian, ReadBytesExt};
#[cfg(feature = "debug")]
use log::debug;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{borrow::Borrow, collections::BTreeMap, fmt::Display};
pub const FILE_MAGIC_GGML: i32 = 0x67676d6c;
pub const FILE_MAGIC_GGMF: i32 = 0x67676d66;
pub const FILE_MAGIC_GGJT: i32 = 0x67676a74;
pub const FILE_MAGIC_GGLA: i32 = 0x67676C61;
pub const FILE_MAGIC_GGUF_LE: i32 = 0x46554747;
pub const FILE_MAGIC_GGUF_BE: i32 = 0x47475546;
const GGUF_VERSION_V1: i32 = 0x00000001;
const GGUF_VERSION_V2: i32 = 0x00000002;
const GGUF_VERSION_V3: i32 = 0x00000003;
const THOUSAND: u64 = 1000;
const MILLION: u64 = 1_000_000;
const BILLION: u64 = 1_000_000_000;
fn human_number(value: u64) -> String {
    match value {
        _ if value > BILLION => format!("{:.0}B", value as f64 / BILLION as f64),
        _ if value > MILLION => format!("{:.0}M", value as f64 / MILLION as f64),
        _ if value > THOUSAND => format!("{:.0}K", value as f64 / THOUSAND as f64),
        _ => format!("{}", value),
    }
}
fn file_type(ft: u64) -> String {
    match ft {
        0 => "F32",
        1 => "F16",
        2 => "Q4_0",
        3 => "Q4_1",
        4 => "Q4_1_SOME_F16",
        5 => "Q4_2",
        6 => "Q4_3",
        7 => "Q8_0",
        8 => "Q5_0",
        9 => "Q5_1",
        10 => "Q2_K",
        11 => "Q3_K_S",
        12 => "Q3_K_M",
        13 => "Q3_K_L",
        14 => "Q4_K_S",
        15 => "Q4_K_M",
        16 => "Q5_K_S",
        17 => "Q5_K_M",
        18 => "Q6_K",
        _ => "unknown",
    }
    .to_string()
}
#[derive(Default, Debug, Clone)]
pub enum ByteOrder {
    #[default]
    LE,
    BE,
}
#[derive(Debug, Clone)]
pub enum Version {
    V1(V1),
    V2(V2),
    V3(V3),
}
#[derive(Debug, Deserialize, Default, Clone)]
pub struct V1 {
    num_tensor: u32,
    num_kv: u32,
}
#[derive(Debug, Deserialize, Default, Clone)]
pub struct V2 {
    num_tensor: u64,
    num_kv: u64,
}
#[derive(Debug, Deserialize, Default, Clone)]
pub struct V3 {
    num_tensor: u64,
    num_kv: u64,
}
pub struct GGUFContainer {
    bo: ByteOrder,
    version: Version,
    reader: Box<dyn std::io::Read + 'static>,
}
impl GGUFContainer {
    pub fn new(bo: ByteOrder, reader: Box<dyn std::io::Read>) -> Self {
        Self {
            bo,
            version: Version::V1(V1::default()),
            reader,
        }
    }
    pub fn get_version(&self) -> String {
        match &self.version {
            Version::V1(_) => String::from("v1"),
            Version::V2(_) => String::from("v2"),
            Version::V3(_) => String::from("v3"),
        }
    }
    pub fn decode(&mut self) -> Result<GGUFModel> {
        let version = match self.bo {
            ByteOrder::LE => self.reader.read_i32::<LittleEndian>().unwrap(),
            ByteOrder::BE => self.reader.read_i32::<BigEndian>().unwrap(),
        };
        #[cfg(feature = "debug")]
        {
            debug!("version {}", version);
        }
        match version {
            GGUF_VERSION_V1 => {
                let mut buffer: [u32; 2] = [0; 2];
                match self.bo {
                    ByteOrder::LE => {
                        self.reader
                            .read_u32_into::<LittleEndian>(&mut buffer)
                            .unwrap();
                    }
                    ByteOrder::BE => {
                        self.reader.read_u32_into::<BigEndian>(&mut buffer).unwrap();
                    }
                };
                self.version = Version::V1(V1 {
                    num_tensor: buffer[0],
                    num_kv: buffer[1],
                });
            }
            GGUF_VERSION_V2 | GGUF_VERSION_V3 => {
                let mut buffer: [u64; 2] = [0; 2];
                match self.bo {
                    ByteOrder::LE => {
                        self.reader
                            .read_u64_into::<LittleEndian>(&mut buffer)
                            .unwrap();
                    }
                    ByteOrder::BE => {
                        self.reader.read_u64_into::<BigEndian>(&mut buffer).unwrap();
                    }
                };
                if version == GGUF_VERSION_V2 {
                    self.version = Version::V2(V2 {
                        num_tensor: buffer[0],
                        num_kv: buffer[1],
                    });
                } else {
                    self.version = Version::V3(V3 {
                        num_tensor: buffer[0],
                        num_kv: buffer[1],
                    });
                }
            }
            invalid_version => {
                return Err(anyhow!(
                    "invalid version {}, only support version: 1 | 2 | 3",
                    invalid_version
                ))
            }
        };
        let mut model = GGUFModel {
            kv: BTreeMap::new(),
            tensors: Vec::new(),
            parameters: 0,
            bo: self.bo.clone(),
            version: self.version.clone(),
        };
        model.decode(&mut self.reader)?;
        Ok(model)
    }
}
#[derive(Debug, Clone)]
pub struct Tensor {
    pub name: String,
    pub kind: u32,
    pub offset: u64,
    pub size: u64,
    pub shape: Vec<u64>,
}
pub struct GGUFModel {
    kv: BTreeMap<String, Value>,
    tensors: Vec<Tensor>,
    parameters: u64,
    bo: ByteOrder,
    version: Version,
}
#[derive(Debug)]
pub enum MetadataValueType {
    Uint8 = 0,
    Int8 = 1,
    Uint16 = 2,
    Int16 = 3,
    Uint32 = 4,
    Int32 = 5,
    Float32 = 6,
    Bool = 7,
    String = 8,
    Array = 9,
    Uint64 = 10,
    Int64 = 11,
    Float64 = 12,
}
impl TryFrom<u32> for MetadataValueType {
    type Error = anyhow::Error;
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        Ok(match value {
            0 => MetadataValueType::Uint8,
            1 => MetadataValueType::Int8,
            2 => MetadataValueType::Uint16,
            3 => MetadataValueType::Int16,
            4 => MetadataValueType::Uint32,
            5 => MetadataValueType::Int32,
            6 => MetadataValueType::Float32,
            7 => MetadataValueType::Bool,
            8 => MetadataValueType::String,
            9 => MetadataValueType::Array,
            10 => MetadataValueType::Uint64,
            11 => MetadataValueType::Int64,
            12 => MetadataValueType::Float64,
            _ => return Err(anyhow!("unsupport metadata value type")),
        })
    }
}
#[derive(Debug, Serialize)]
pub enum GGMLType {
    F32 = 0,
    F16 = 1,
    Q4_0 = 2,
    Q4_1 = 3,
    Q5_0 = 6,
    Q5_1 = 7,
    Q8_0 = 8,
    Q8_1 = 9,
    Q2K = 10,
    Q3K = 11,
    Q4K = 12,
    Q5K = 13,
    Q6K = 14,
    Q8K = 15,
    I8 = 16,
    I16 = 17,
    I32 = 18,
    Count = 19,
}
impl Display for GGMLType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GGMLType::F32 => write!(f, "F32"),
            GGMLType::F16 => write!(f, "F16"),
            GGMLType::Q4_0 => write!(f, "Q4_0"),
            GGMLType::Q4_1 => write!(f, "Q4_1"),
            GGMLType::Q5_0 => write!(f, "Q5_0"),
            GGMLType::Q5_1 => write!(f, "Q5_1"),
            GGMLType::Q8_0 => write!(f, "Q8_0"),
            GGMLType::Q8_1 => write!(f, "Q8_1"),
            GGMLType::Q2K => write!(f, "Q2K"),
            GGMLType::Q3K => write!(f, "Q3K"),
            GGMLType::Q4K => write!(f, "Q4K"),
            GGMLType::Q5K => write!(f, "Q5K"),
            GGMLType::Q6K => write!(f, "Q6K"),
            GGMLType::Q8K => write!(f, "Q8K"),
            GGMLType::I8 => write!(f, "I8"),
            GGMLType::I16 => write!(f, "I16"),
            GGMLType::I32 => write!(f, "I32"),
            GGMLType::Count => write!(f, "Count"),
        }
    }
}
impl TryFrom<u32> for GGMLType {
    type Error = anyhow::Error;
    fn try_from(value: u32) -> std::prelude::v1::Result<Self, Self::Error> {
        Ok(match value {
            0 => GGMLType::F32,
            1 => GGMLType::F16,
            2 => GGMLType::Q4_0,
            3 => GGMLType::Q4_1,
            6 => GGMLType::Q5_0,
            7 => GGMLType::Q5_1,
            8 => GGMLType::Q8_0,
            9 => GGMLType::Q8_1,
            10 => GGMLType::Q2K,
            11 => GGMLType::Q3K,
            12 => GGMLType::Q4K,
            13 => GGMLType::Q5K,
            14 => GGMLType::Q6K,
            15 => GGMLType::Q8K,
            16 => GGMLType::I8,
            17 => GGMLType::I16,
            18 => GGMLType::I32,
            19 => GGMLType::Count,
            _ => return Err(anyhow!("invalid GGML type")),
        })
    }
}
impl GGUFModel {
    pub(crate) fn decode(&mut self, mut reader: impl std::io::Read) -> Result<()> {
        for _i in 0..self.num_kv() {
            let key = self.read_string(&mut reader);
            let value_type: MetadataValueType = self.read_u32(&mut reader).try_into()?;
            let value = match value_type {
                MetadataValueType::Uint8 => Value::from(self.read_u8(&mut reader)),
                MetadataValueType::Int8 => Value::from(self.read_i8(&mut reader)),
                MetadataValueType::Uint16 => Value::from(self.read_u16(&mut reader)),
                MetadataValueType::Int16 => Value::from(self.read_i16(&mut reader)),
                MetadataValueType::Uint32 => Value::from(self.read_u32(&mut reader)),
                MetadataValueType::Int32 => Value::from(self.read_i32(&mut reader)),
                MetadataValueType::Float32 => Value::from(self.read_f32(&mut reader)),
                MetadataValueType::Bool => Value::from(self.read_bool(&mut reader)),
                MetadataValueType::String => Value::from(self.read_string(&mut reader)),
                MetadataValueType::Array => Value::from(self.read_array(&mut reader, 3)?),
                MetadataValueType::Uint64 => Value::from(self.read_u64(&mut reader)),
                MetadataValueType::Int64 => Value::from(self.read_i64(&mut reader)),
                MetadataValueType::Float64 => Value::from(self.read_f64(&mut reader)),
            };
            #[cfg(feature = "debug")]
            {
                debug!(
                    "kv [{}] vtype {:?} key={}, value={}",
                    _i, value_type, key, value
                );
            }
            self.kv.insert(key, value);
        }
        for _ in 0..self.num_tensor() {
            let name = self.read_string(&mut reader);
            let dims = self.read_u32(&mut reader);
            let mut shape = [1; 4];
            for i in 0..dims {
                shape[i as usize] = self.read_u64(&mut reader);
            }
            let kind = self.read_u32(&mut reader);
            let offset = self.read_u64(&mut reader);
            let block_size = match kind {
                _ if kind < 2 => 1,
                _ if kind < 10 => 32,
                _ => 256,
            };
            let ggml_type_kind: GGMLType = kind.try_into()?;
            let type_size = match ggml_type_kind {
                GGMLType::F32 => 4,
                GGMLType::F16 => 2,
                GGMLType::Q4_0 => 2 + block_size / 2,
                GGMLType::Q4_1 => 2 + 2 + block_size / 2,
                GGMLType::Q5_0 => 2 + 4 + block_size / 2,
                GGMLType::Q5_1 => 2 + 2 + 4 + block_size / 2,
                GGMLType::Q8_0 => 2 + block_size,
                GGMLType::Q8_1 => 4 + 4 + block_size,
                GGMLType::Q2K => block_size / 16 + block_size / 4 + 2 + 2,
                GGMLType::Q3K => block_size / 8 + block_size / 4 + 12 + 2,
                GGMLType::Q4K => 2 + 2 + 12 + block_size / 2,
                GGMLType::Q5K => 2 + 2 + 12 + block_size / 8 + block_size / 2,
                GGMLType::Q6K => block_size / 2 + block_size / 4 + block_size / 16 + 2,
                GGMLType::Q8K => todo!(),
                GGMLType::I8 => todo!(),
                GGMLType::I16 => todo!(),
                GGMLType::I32 => todo!(),
                GGMLType::Count => todo!(),
            };
            let parameters = shape[0] * shape[1] * shape[2] * shape[3];
            let size = parameters * type_size / block_size;
            self.tensors.push(Tensor {
                name,
                kind,
                offset,
                size,
                shape: shape.to_vec(),
            });
            self.parameters += parameters;
        }
        Ok(())
    }
    fn read_u8(&self, mut reader: impl std::io::Read) -> u8 {
        reader.read_u8().unwrap()
    }
    fn read_u32(&self, mut reader: impl std::io::Read) -> u32 {
        match self.bo {
            ByteOrder::LE => reader.read_u32::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_u32::<BigEndian>().unwrap(),
        }
    }
    fn read_f32(&self, mut reader: impl std::io::Read) -> f32 {
        match self.bo {
            ByteOrder::LE => reader.read_f32::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_f32::<BigEndian>().unwrap(),
        }
    }
    fn read_f64(&self, mut reader: impl std::io::Read) -> f64 {
        match self.bo {
            ByteOrder::LE => reader.read_f64::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_f64::<BigEndian>().unwrap(),
        }
    }
    fn read_u64(&self, mut reader: impl std::io::Read) -> u64 {
        match self.bo {
            ByteOrder::LE => reader.read_u64::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_u64::<BigEndian>().unwrap(),
        }
    }
    fn read_i8(&self, mut reader: impl std::io::Read) -> i8 {
        reader.read_i8().unwrap()
    }
    fn read_u16(&self, mut reader: impl std::io::Read) -> u16 {
        match self.bo {
            ByteOrder::LE => reader.read_u16::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_u16::<BigEndian>().unwrap(),
        }
    }
    fn read_i16(&self, mut reader: impl std::io::Read) -> i16 {
        match self.bo {
            ByteOrder::LE => reader.read_i16::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_i16::<BigEndian>().unwrap(),
        }
    }
    fn read_i32(&self, mut reader: impl std::io::Read) -> i32 {
        match self.bo {
            ByteOrder::LE => reader.read_i32::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_i32::<BigEndian>().unwrap(),
        }
    }
    fn read_i64(&self, mut reader: impl std::io::Read) -> i64 {
        match self.bo {
            ByteOrder::LE => reader.read_i64::<LittleEndian>().unwrap(),
            ByteOrder::BE => reader.read_i64::<BigEndian>().unwrap(),
        }
    }
    fn read_bool(&self, mut reader: impl std::io::Read) -> bool {
        reader.read_u8().unwrap() != 0
    }
    fn read_string(&self, mut reader: impl std::io::Read) -> String {
        let name_len = self.read_version_size(&mut reader);
        let mut buffer = vec![0; name_len as usize];
        reader.read_exact(&mut buffer).unwrap();
        String::from_utf8_lossy(&buffer).to_string()
    }
    fn read_array(&self, mut reader: impl std::io::Read, read_count: usize) -> Result<Vec<Value>> {
        let mut data = Vec::new();
        let item_type: MetadataValueType = self.read_u32(&mut reader).try_into()?;
        let array_len = self.read_version_size(&mut reader);
        for _ in 0..array_len {
            let value = match item_type {
                MetadataValueType::Uint8 => Value::from(self.read_u8(&mut reader)),
                MetadataValueType::Int8 => Value::from(self.read_i8(&mut reader)),
                MetadataValueType::Uint16 => Value::from(self.read_u16(&mut reader)),
                MetadataValueType::Int16 => Value::from(self.read_i16(&mut reader)),
                MetadataValueType::Uint32 => Value::from(self.read_u32(&mut reader)),
                MetadataValueType::Int32 => Value::from(self.read_i32(&mut reader)),
                MetadataValueType::Float32 => Value::from(self.read_f32(&mut reader)),
                MetadataValueType::Bool => Value::from(self.read_bool(&mut reader)),
                MetadataValueType::String => Value::from(self.read_string(&mut reader)),
                MetadataValueType::Uint64 => Value::from(self.read_u64(&mut reader)),
                MetadataValueType::Int64 => Value::from(self.read_i64(&mut reader)),
                MetadataValueType::Float64 => Value::from(self.read_f64(&mut reader)),
                _ => return Err(anyhow!("Unsupport item value type: Array")),
            };
            if read_count > 0 && data.len() < read_count {
                data.push(value);
            }
        }
        Ok(data)
    }
    fn read_version_size(&self, mut reader: impl std::io::Read) -> u64 {
        match self.version.borrow() {
            Version::V1(_) => self.read_u32(&mut reader) as u64,
            Version::V2(_) => self.read_u64(&mut reader),
            Version::V3(_) => self.read_u64(&mut reader),
        }
    }
    pub fn get_version(&self) -> String {
        match &self.version {
            Version::V1(_) => String::from("v1"),
            Version::V2(_) => String::from("v2"),
            Version::V3(_) => String::from("v3"),
        }
    }
    pub fn num_kv(&self) -> u64 {
        match &self.version {
            Version::V1(v1) => v1.num_kv as u64,
            Version::V2(v2) => v2.num_kv,
            Version::V3(v3) => v3.num_kv,
        }
    }
    pub fn num_tensor(&self) -> u64 {
        match &self.version {
            Version::V1(v1) => v1.num_tensor as u64,
            Version::V2(v2) => v2.num_tensor,
            Version::V3(v3) => v3.num_tensor,
        }
    }
    pub fn model_family(&self) -> String {
        let arch = self
            .kv
            .get("general.architecture")
            .cloned()
            .unwrap_or(Value::from("unknown"));
        match arch {
            Value::String(arch) => arch,
            _ => String::from("unknown"),
        }
    }
    pub fn model_parameters(&self) -> String {
        if self.parameters > 0 {
            human_number(self.parameters)
        } else {
            String::from("unknown")
        }
    }
    pub fn file_type(&self) -> String {
        if let Some(ft) = self.kv.get("general.file_type") {
            file_type(ft.as_u64().unwrap())
        } else {
            String::from("unknown")
        }
    }
    pub fn metadata(&self) -> &BTreeMap<String, Value> {
        &self.kv
    }
    pub fn tensors(&self) -> &Vec<Tensor> {
        &self.tensors
    }
}
pub fn get_gguf_container(file: &str) -> Result<GGUFContainer> {
    if !std::path::Path::new(file).exists() {
        return Err(anyhow!("file not found"));
    }
    let mut reader = std::fs::File::open(file).unwrap();
    let byte_le = reader.read_i32::<LittleEndian>().unwrap();
    match byte_le {
        FILE_MAGIC_GGML => Err(anyhow!("unsupport ggml format")),
        FILE_MAGIC_GGMF => Err(anyhow!("unsupport ggmf format")),
        FILE_MAGIC_GGJT => Err(anyhow!("unsupport ggjt format")),
        FILE_MAGIC_GGLA => Err(anyhow!("unsupport ggla format")),
        FILE_MAGIC_GGUF_LE => Ok(GGUFContainer::new(ByteOrder::LE, Box::new(reader))),
        FILE_MAGIC_GGUF_BE => Ok(GGUFContainer::new(ByteOrder::BE, Box::new(reader))),
        _ => Err(anyhow!("invalid file magic")),
    }
}
#[cfg(test)]
mod tests {
    use serde_json::json;
    #[test]
    fn test_read_le_v3_gguf() {
        let mut container = super::get_gguf_container("tests/test-le-v3.gguf").unwrap();
        let model = container.decode().unwrap();
        assert_eq!(model.get_version(), "v3");
        assert_eq!(model.model_family(), "llama");
        assert_eq!(model.file_type(), "unknown");
        assert_eq!(model.model_parameters(), "192");
        assert_eq!(
            serde_json::to_value(model.kv).unwrap(),
            json!({
                "general.architecture": "llama", "llama.block_count": 12, "general.alignment": 64, "answer": 42, "answer_in_float": 42.0
            })
        );
    }
}