use std::{
    collections::HashMap,
    fmt::{Display, Formatter},
    fs::File,
    io::{BufRead, BufReader, Read, Seek, SeekFrom},
    path::{Path, PathBuf},
};
use crate::{
    util::{self, FindAllModelFilesError},
    Hyperparameters, KnownModel, ModelParameters, TokenId, Vocabulary,
};
pub use ggml::ContainerType;
use ggml::{
    format::{LoadError as FormatLoadError, PartialHyperparameters, TensorLoadInfo},
    Context,
};
use memmap2::Mmap;
use thiserror::Error;
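/// How the tensors in a GGML model file are stored. The integer values mirror
/// the `ftype` values written by upstream GGML tools.
///
/// A minimal round-trip sketch (the `llm_base` import path is an assumption;
/// adjust it to wherever this module is exported from):
///
/// ```ignore
/// use llm_base::FileType; // assumed path
///
/// let ftype = FileType::MostlyQ4_0;
/// let raw = i32::from(ftype); // 2
/// assert_eq!(FileType::try_from(raw), Ok(ftype));
/// ```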
#[derive(Debug, PartialEq, Clone, Copy, Eq, Default)]
pub enum FileType {
    /// All tensors are stored as f32.
    F32,
    /// Tensors are mostly stored as f16.
    #[default]
    MostlyF16,
    /// Tensors are mostly stored as q4_0.
    MostlyQ4_0,
    /// Tensors are mostly stored as q4_1.
    MostlyQ4_1,
    /// Tensors are mostly stored as q4_1, with some f16 tensors.
    MostlyQ4_1SomeF16,
    /// Tensors are mostly stored as q4_2.
    MostlyQ4_2,
    /// Tensors are mostly stored as q8_0.
    MostlyQ8_0,
    /// Tensors are mostly stored as q5_0.
    MostlyQ5_0,
    /// Tensors are mostly stored as q5_1.
    MostlyQ5_1,
}
impl From<FileType> for i32 {
    fn from(value: FileType) -> Self {
        match value {
            FileType::F32 => 0,
            FileType::MostlyF16 => 1,
            FileType::MostlyQ4_0 => 2,
            FileType::MostlyQ4_1 => 3,
            FileType::MostlyQ4_1SomeF16 => 4,
            FileType::MostlyQ4_2 => 5,
            // 6 is skipped: it corresponded to the q4_3 format, which was
            // removed upstream.
            FileType::MostlyQ8_0 => 7,
            FileType::MostlyQ5_0 => 8,
            FileType::MostlyQ5_1 => 9,
        }
    }
}
impl TryFrom<i32> for FileType {
    type Error = ();
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(FileType::F32),
            1 => Ok(FileType::MostlyF16),
            2 => Ok(FileType::MostlyQ4_0),
            3 => Ok(FileType::MostlyQ4_1),
            4 => Ok(FileType::MostlyQ4_1SomeF16),
            5 => Ok(FileType::MostlyQ4_2),
            7 => Ok(FileType::MostlyQ8_0),
            8 => Ok(FileType::MostlyQ5_0),
            9 => Ok(FileType::MostlyQ5_1),
            _ => Err(()),
        }
    }
}
impl Display for FileType {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            FileType::F32 => write!(f, "f32"),
            FileType::MostlyF16 => write!(f, "f16"),
            FileType::MostlyQ4_0 => write!(f, "q4_0"),
            FileType::MostlyQ4_1 => write!(f, "q4_1"),
            FileType::MostlyQ4_1SomeF16 => write!(f, "q4_1_with_f16"),
            FileType::MostlyQ4_2 => write!(f, "q4_2"),
            FileType::MostlyQ8_0 => write!(f, "q8_0"),
            FileType::MostlyQ5_0 => write!(f, "q5_0"),
            FileType::MostlyQ5_1 => write!(f, "q5_1"),
        }
    }
}
/// Each variant represents a step within the process of loading the model.
/// These can be used to report progress to the user.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum LoadProgress {
    /// The hyperparameters have been loaded from the model.
    HyperparametersLoaded,
    /// The context has been created.
    ContextSize {
        /// The size of the context, in bytes.
        bytes: usize,
    },
    /// A tensor has been loaded.
    TensorLoaded {
        /// The current tensor (0-indexed).
        current_tensor: usize,
        /// The total number of tensors.
        tensor_count: usize,
    },
    /// The model has been fully loaded.
    Loaded {
        /// The size of the model file, in bytes.
        file_size: u64,
        /// The total number of tensors.
        tensor_count: usize,
    },
}
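/// Each error that can occur while attempting to load a model.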
#[derive(Error, Debug)]
pub enum LoadError {
    #[error("could not open file {path:?}")]
    OpenFileFailed {
        source: std::io::Error,
        path: PathBuf,
    },
    #[error("no parent path for {path:?}")]
    NoParentPath {
        path: PathBuf,
    },
    #[error("unable to read exactly {bytes} bytes")]
    ReadExactFailed {
        source: std::io::Error,
        bytes: usize,
    },
    #[error("non-specific I/O error")]
    Io(#[from] std::io::Error),
    #[error("could not convert bytes to a UTF-8 string")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    #[error("invalid integer conversion")]
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
    #[error("unsupported f16_: {0}")]
    UnsupportedFileType(i32),
    #[error("invalid magic number {magic:#x} for {path:?}")]
    InvalidMagic {
        path: PathBuf,
        magic: u32,
    },
    #[error("invalid file format version {version}")]
    InvalidFormatVersion {
        container_type: ContainerType,
        version: u32,
    },
    #[error("invalid value {ftype} for `f16` in hyperparameters")]
    HyperparametersF16Invalid {
        ftype: i32,
    },
    #[error("unknown tensor `{tensor_name}` in {path:?}")]
    UnknownTensor {
        tensor_name: String,
        path: PathBuf,
    },
    #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")]
    TensorWrongSize {
        tensor_name: String,
        path: PathBuf,
    },
    #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")]
    UnsupportedElementType {
        tensor_name: String,
        ftype: u32,
        path: PathBuf,
    },
    #[error("invariant broken: {invariant} in {path:?}")]
    InvariantBroken {
        path: Option<PathBuf>,
        invariant: String,
    },
    #[error("could not create model from {path:?}")]
    ModelNotCreated {
        path: PathBuf,
    },
    #[error("multipart models are not supported")]
    MultipartNotSupported {
        paths: Vec<PathBuf>,
    },
}
impl From<FindAllModelFilesError> for LoadError {
    fn from(value: FindAllModelFilesError) -> Self {
        match value {
            FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path },
            FindAllModelFilesError::IO(err) => LoadError::Io(err),
        }
    }
}
impl LoadError {
    #[doc(hidden)]
    pub fn from_format_error(value: FormatLoadError<LoadError>, path: PathBuf) -> Self {
        match value {
            FormatLoadError::InvalidMagic(magic) => LoadError::InvalidMagic { path, magic },
            FormatLoadError::InvalidFormatVersion(container_type, version) => {
                LoadError::InvalidFormatVersion {
                    container_type,
                    version,
                }
            }
            FormatLoadError::Io(err) => LoadError::Io(err),
            FormatLoadError::InvalidUtf8(err) => LoadError::InvalidUtf8(err),
            FormatLoadError::InvalidIntegerConversion(err) => {
                LoadError::InvalidIntegerConversion(err)
            }
            FormatLoadError::ImplementationError(err) => err,
            FormatLoadError::UnsupportedElementType { tensor_name, ftype } => {
                LoadError::UnsupportedElementType {
                    path,
                    tensor_name,
                    ftype,
                }
            }
            FormatLoadError::InvariantBroken(invariant) => LoadError::InvariantBroken {
                path: Some(path),
                invariant,
            },
        }
    }
}
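/// Used by models to fetch tensors from a loader while the model is being
/// constructed.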
pub trait TensorLoader<E: std::error::Error> {
    /// Loads a tensor by name, using the shape recorded in the model file.
    fn load(&mut self, name: &str) -> Result<ggml::Tensor, E>;
    /// Loads a tensor by name with a manually specified shape `ne`.
    fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, E>;
    /// Finishes loading, handing back the context, the loaded tensors, and
    /// the memory map (if any) so they can outlive the loader.
    fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>);
}
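/// Loads the specified GGML model from disk, reporting progress through
/// `load_progress_callback`. Multipart models are not supported.
///
/// A minimal usage sketch; `Llama` stands in for any type implementing
/// [`KnownModel`], and `ModelParameters::default()` assumes a `Default`
/// implementation, which may need adjusting for your setup:
///
/// ```ignore
/// use std::path::Path;
///
/// // `Llama` is illustrative, not a type defined in this module.
/// let model: Llama = load(
///     Path::new("/path/to/model.bin"),
///     ModelParameters::default(), // assumed `Default` impl
///     load_progress_callback_stdout,
/// )?;
/// ```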
pub fn load<M: KnownModel>(
    path: &Path,
    params: ModelParameters,
    load_progress_callback: impl FnMut(LoadProgress),
) -> Result<M, LoadError> {
    let paths = util::find_all_model_files(path)?;
    if paths.len() != 1 {
        return Err(LoadError::MultipartNotSupported { paths });
    }
    let file = File::open(path).map_err(|e| LoadError::OpenFileFailed {
        source: e,
        path: path.to_owned(),
    })?;
    let mut reader = BufReader::new(&file);
    let mut loader = Loader::new(load_progress_callback);
    ggml::format::load(&mut reader, &mut loader)
        .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?;
    let Loader {
        hyperparameters,
        vocabulary,
        tensors,
        mut load_progress_callback,
        container_type,
        ..
    } = loader;
    let use_mmap = params.prefer_mmap && container_type.support_mmap();
    // Size the context to hold every tensor's metadata, plus the tensor data
    // itself only when it cannot be memory-mapped (with mmap, the data stays
    // in the mapped file).
    let ctx_size = tensors
        .values()
        .map(|ti| {
            ggml::Tensor::C_TYPE_SIZE
                + ggml::OBJECT_SIZE
                + if use_mmap { 0 } else { ti.calc_size() }
        })
        .sum::<usize>();
    (load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size });
    let context = Context::init(ctx_size, !use_mmap);
    let (mmap, file_size) = {
        let file = File::open(path)?;
        let mmap = if use_mmap {
            // Safety: the file must not be modified or truncated while the
            // map is alive; the map is returned from `finish` and kept with
            // the model for as long as the tensors reference it.
            Some(unsafe { Mmap::map(&file)? })
        } else {
            None
        };
        (mmap, file.metadata()?.len())
    };
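    // An adapter that implements `TensorLoader` by serving tensor data either
    // directly out of the memory map (zero-copy) or by reading it from the
    // file into buffers allocated by the context.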
    struct MmapCompatibleLoader<'a> {
        path: PathBuf,
        file: File,
        tensors: HashMap<String, TensorLoadInfo>,
        context: Context,
        mmap: Option<Mmap>,
        load_progress_callback: &'a mut dyn FnMut(LoadProgress),
        loaded_tensors: HashMap<String, ggml::Tensor>,
    }
    impl TensorLoader<LoadError> for MmapCompatibleLoader<'_> {
        fn load(&mut self, name: &str) -> Result<ggml::Tensor, LoadError> {
            let tensor_dims = self
                .tensors
                .get(name)
                .map(|tensor| tensor.dims().to_vec())
                .ok_or(LoadError::UnknownTensor {
                    tensor_name: String::from(name),
                    path: Default::default(),
                })?;
            self.load_manual(name, &tensor_dims)
        }
        fn load_manual(&mut self, name: &str, ne: &[usize]) -> Result<ggml::Tensor, LoadError> {
            let info = self
                .tensors
                .get(name)
                .ok_or_else(|| LoadError::UnknownTensor {
                    path: self.path.clone(),
                    tensor_name: name.to_owned(),
                })?;
            let dims = ne.len();
            if dims != info.n_dims {
                return Err(LoadError::InvariantBroken {
                    path: Some(self.path.clone()),
                    invariant: format!(
                        "the tensor {name} should have {} dimensions, not {dims}",
                        info.n_dims
                    ),
                });
            }
            let ctx = &self.context;
            let mut tensor = match dims {
                1 => ctx.new_tensor_1d(info.element_type, ne[0]),
                2 => ctx.new_tensor_2d(info.element_type, ne[0], ne[1]),
                3 => ctx.new_tensor_3d(info.element_type, ne[0], ne[1], ne[2]),
                _ => {
                    return Err(LoadError::InvariantBroken {
                        path: Some(self.path.clone()),
                        invariant: format!(
                            "the tensor {name} had an unsupported dimension count: {ne:?}"
                        ),
                    })
                }
            };
            match self.mmap.as_ref() {
                Some(mmap) => unsafe {
                    // With mmap, point the tensor directly at the mapped file;
                    // no copy is made. `finish` returns the map alongside the
                    // context so it outlives the tensors that borrow from it.
                    let ptr = mmap.as_ptr().offset(info.start_offset as isize);
                    tensor.set_data(ptr as *mut std::ffi::c_void);
                },
                None => {
                    // Without mmap, read the tensor's bytes from the file into
                    // the buffer the context allocated for it.
                    let buf: &mut [u8] = unsafe {
                        std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes())
                    };
                    self.file.seek(SeekFrom::Start(info.start_offset))?;
                    self.file.read_exact(buf)?;
                }
            }
            self.loaded_tensors.insert(name.to_owned(), tensor.share());
            (self.load_progress_callback)(LoadProgress::TensorLoaded {
                current_tensor: self.loaded_tensors.len(),
                tensor_count: self.tensors.len(),
            });
            Ok(tensor)
        }
        fn finish(self) -> (Context, HashMap<String, ggml::Tensor>, Option<Mmap>) {
            (self.context, self.loaded_tensors, self.mmap)
        }
    }
    let tensors_len = tensors.len();
    let tl = MmapCompatibleLoader {
        path: path.to_owned(),
        file,
        tensors,
        context,
        mmap,
        load_progress_callback: &mut load_progress_callback,
        loaded_tensors: Default::default(),
    };
    let model = KnownModel::new(hyperparameters, params, vocabulary, tl)?;
    (load_progress_callback)(LoadProgress::Loaded {
        file_size,
        tensor_count: tensors_len,
    });
    Ok(model)
}
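/// A low-level handler for the GGML container format. [`load`] drives it
/// through [`ggml::format::load`], but it can also be driven directly to
/// inspect a model file's hyperparameters, vocabulary, and tensor metadata
/// without constructing a full model.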
pub struct Loader<Hp: Hyperparameters, F: FnMut(LoadProgress)> {
    // Input.
    load_progress_callback: F,
    // Output.
    /// The container type of the model.
    pub container_type: ContainerType,
    /// The hyperparameters of the model.
    pub hyperparameters: Hp,
    /// The vocabulary of the model.
    pub vocabulary: Vocabulary,
    /// Metadata for each tensor in the file, keyed by tensor name.
    pub tensors: HashMap<String, TensorLoadInfo>,
}
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> Loader<Hp, F> {
    pub fn new(load_progress_callback: F) -> Self {
        Self {
            load_progress_callback,
            container_type: ContainerType::Ggjt,
            hyperparameters: Hp::default(),
            vocabulary: Vocabulary::default(),
            tensors: HashMap::default(),
        }
    }
}
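// `ggml::format::load` drives this handler, invoking each callback below as
// the corresponding part of the file is read.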
impl<Hp: Hyperparameters, F: FnMut(LoadProgress)> ggml::format::LoadHandler<LoadError>
    for Loader<Hp, F>
{
    fn container_type(&mut self, container_type: ContainerType) -> Result<(), LoadError> {
        self.container_type = container_type;
        Ok(())
    }
    fn vocabulary_token(&mut self, i: usize, token: Vec<u8>, score: f32) -> Result<(), LoadError> {
        // `TryFromIntError` converts into `LoadError::InvalidIntegerConversion`
        // via `#[from]`, so `?` suffices here.
        let id = TokenId::try_from(i)?;
        self.vocabulary.push_token(id, token, score);
        Ok(())
    }
    fn read_hyperparameters(
        &mut self,
        reader: &mut dyn BufRead,
    ) -> Result<PartialHyperparameters, LoadError> {
        let hyperparameters = Hp::read_ggml(reader)?;
        let partial = PartialHyperparameters {
            n_vocab: hyperparameters.n_vocabulary(),
        };
        self.hyperparameters = hyperparameters;
        (self.load_progress_callback)(LoadProgress::HyperparametersLoaded);
        Ok(partial)
    }
    fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), LoadError> {
        self.tensors.insert(info.name.clone(), info);
        Ok(())
    }
}
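/// A ready-made [`LoadProgress`] callback that prints progress to stdout,
/// suitable for command-line applications.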
pub fn load_progress_callback_stdout(progress: LoadProgress) {
    match progress {
        LoadProgress::HyperparametersLoaded => println!("Loaded hyperparameters"),
        LoadProgress::ContextSize { bytes } => println!(
            "ggml ctx size = {:.2} MB\n",
            bytes as f64 / (1024.0 * 1024.0)
        ),
        LoadProgress::TensorLoaded {
            current_tensor,
            tensor_count,
        } => {
            // Report every eighth tensor to avoid flooding stdout.
            let current_tensor = current_tensor + 1;
            if current_tensor % 8 == 0 {
                println!("Loaded tensor {current_tensor}/{tensor_count}");
            }
        }
        LoadProgress::Loaded {
            file_size: byte_size,
            tensor_count,
        } => {
            println!("Loading of model complete");
            println!(
                "Model size = {:.2} MB / num tensors = {}",
                byte_size as f64 / 1024.0 / 1024.0,
                tensor_count
            );
        }
    };
}