pub use ggml::util::*;
use std::path::{Path, PathBuf};
/// Multiplies one or more numeric terms as `f64` and converts the product to `usize`.
///
/// Every term is cast with `as f64`, so integer and floating-point operands can be
/// mixed (e.g. an element count times a fractional per-element size). The product is
/// then cast `as u64`, which rounds toward zero and saturates (negative/NaN -> 0,
/// too large -> `u64::MAX`). A single term and a trailing comma are accepted.
///
/// # Panics
/// Panics if the truncated product does not fit in `usize` (only possible on targets
/// where `usize` is narrower than 64 bits).
#[macro_export]
#[doc(hidden)]
macro_rules! mulf {
    ($term:expr $(, $terms:expr)* $(,)?) => {
        usize::try_from((($term as f64) $(* ($terms as f64))*) as u64)
            .expect("mulf!: product does not fit in usize")
    };
}
use memmap2::{Mmap, MmapAsRawDesc, MmapOptions};
use thiserror::Error;
/// Accumulates byte fragments (such as per-token byte strings) and releases them as
/// UTF-8 text as soon as they decode.
///
/// A multi-byte character may be split across several fragments; an incomplete
/// character at the end of the buffer is held back until later fragments complete it.
/// Byte sequences that can never become valid UTF-8 are dropped, while valid text on
/// either side of them is kept.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct TokenUtf8Buffer(Vec<u8>);
impl TokenUtf8Buffer {
    /// Creates an empty buffer.
    pub const fn new() -> Self {
        Self(vec![])
    }

    /// Appends `token` to the buffer and returns any text that is now complete.
    ///
    /// - If the whole buffer is valid UTF-8 it is returned in full (including
    ///   `Some("")` for an empty buffer) and the buffer is cleared.
    /// - Otherwise every valid run is collected, invalid byte sequences are skipped,
    ///   and an incomplete character at the very end stays buffered for the next call.
    ///   Returns `None` when no text was produced.
    pub fn push(&mut self, token: &[u8]) -> Option<String> {
        self.0.extend_from_slice(token);

        // Fast path: everything buffered so far decodes cleanly.
        if let Ok(text) = std::str::from_utf8(&self.0) {
            let out = text.to_owned();
            self.0.clear();
            return Some(out);
        }

        let mut out = String::new();
        // Index of the first byte not yet consumed.
        let mut start = 0;
        loop {
            match std::str::from_utf8(&self.0[start..]) {
                Ok(text) => {
                    out.push_str(text);
                    start = self.0.len();
                    break;
                }
                Err(err) => {
                    let valid_end = start + err.valid_up_to();
                    let valid = std::str::from_utf8(&self.0[start..valid_end])
                        .expect("bytes before Utf8Error::valid_up_to are valid UTF-8");
                    out.push_str(valid);
                    match err.error_len() {
                        // A sequence that can never be valid: drop it and keep decoding.
                        Some(invalid_len) => start = valid_end + invalid_len,
                        // Input ends mid-character: keep the tail for a later push.
                        None => {
                            start = valid_end;
                            break;
                        }
                    }
                }
            }
        }
        self.0.drain(..start);

        if out.is_empty() {
            None
        } else {
            Some(out)
        }
    }

    /// Wraps a `&str` callback so it can be fed raw byte fragments.
    ///
    /// Fragments are routed through an internal [`TokenUtf8Buffer`]; `callback` is
    /// invoked only when [`push`](Self::push) yields text, otherwise `Ok(())` is
    /// returned without calling it.
    pub fn adapt_callback<'a, E: std::error::Error + 'static>(
        mut callback: impl FnMut(&str) -> Result<(), E> + 'a,
    ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a {
        let mut buffer = Self::new();
        move |token| match buffer.push(token) {
            Some(tokens) => callback(&tokens),
            None => Ok(()),
        }
    }
}
#[derive(Error, Debug)]
pub enum FindAllModelFilesError {
    #[error("no parent path for {path:?}")]
    NoParentPath {
        path: PathBuf,
    },
    #[error("non-specific I/O error")]
    IO(#[from] std::io::Error),
}
/// Lists `main_path` together with its sibling part files (`<name>.<n>`) found in
/// the same directory.
///
/// A bare file name such as `llama.bin` has an empty parent, which is treated as the
/// current directory `.`. Directory entries that fail to read are skipped.
///
/// # Errors
/// - [`FindAllModelFilesError::NoParentPath`] if `main_path` has no parent.
/// - [`FindAllModelFilesError::IO`] if the parent directory cannot be read.
pub fn find_all_model_files(main_path: &Path) -> Result<Vec<PathBuf>, FindAllModelFilesError> {
    let search_dir = match main_path.parent() {
        Some(dir) if dir.as_os_str().is_empty() => Path::new("."),
        Some(dir) => dir,
        None => {
            return Err(FindAllModelFilesError::NoParentPath {
                path: main_path.to_owned(),
            })
        }
    };
    let entries = std::fs::read_dir(search_dir)?;
    let candidates = entries.filter_map(Result::ok).map(|entry| entry.path());
    Ok(collect_related_paths(main_path, candidates))
}
/// Selects the paths in `directory_paths` that belong to the same model as
/// `main_path` and returns them in part order.
///
/// A path is related when its file name is exactly the file name of `main_path`
/// (part 0), or that name followed by `.` and a part number made only of ASCII
/// digits (e.g. `llama.bin.1`). Non-UTF-8 file names never match, and if
/// `main_path` has no UTF-8 file name the result is empty.
///
/// Parts are ordered by their numeric index, so `llama.bin.10` comes after
/// `llama.bin.2` (a plain lexicographic path sort would put it first).
fn collect_related_paths(
    main_path: &Path,
    directory_paths: impl Iterator<Item = PathBuf>,
) -> Vec<PathBuf> {
    let main_filename = match main_path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => return Vec::new(),
    };
    let mut parts: Vec<(usize, PathBuf)> = directory_paths
        .filter_map(|path| {
            let filename = path.file_name()?.to_str()?;
            let index = related_part_index(filename, main_filename)?;
            Some((index, path))
        })
        .collect();
    // Order by part index; the path acts as a deterministic tie-break.
    parts.sort_unstable();
    parts.into_iter().map(|(_, path)| path).collect()
}

/// Returns the part index of `filename` relative to `main_filename`: `0` for an
/// exact match, `n` for `"{main_filename}.{n}"`, and `None` for anything else.
fn related_part_index(filename: &str, main_filename: &str) -> Option<usize> {
    let suffix = filename.strip_prefix(main_filename)?;
    if suffix.is_empty() {
        return Some(0);
    }
    let digits = suffix.strip_prefix('.')?;
    // `usize::from_str` also accepts a leading `+`, which is not a part number.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse().ok()
}
/// Memory-maps `file` into a read-only [`Mmap`], asking memmap2 to prefault the
/// mapping's pages up front via `MmapOptions::populate` (memmap2 documents this as
/// effective on Linux only).
///
/// # Errors
/// Returns any I/O error raised while creating the mapping.
///
/// NOTE(review): memmap2 marks mapping a file as `unsafe` because the file may be
/// modified or truncated by another process while mapped, which is undefined
/// behaviour for the returned `Mmap`. This safe wrapper cannot enforce that; callers
/// must ensure it does not happen.
pub fn mmap_populate<T: MmapAsRawDesc>(file: T) -> Result<Mmap, std::io::Error> {
    let mut options = MmapOptions::new();
    options.populate();
    // SAFETY: relies on the caller keeping the file unmodified for the lifetime of
    // the mapping (see the note above); nothing here can enforce it.
    unsafe { options.map(file) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned paths from string literals for the fixtures below.
    fn to_paths(names: &[&str]) -> Vec<PathBuf> {
        names.iter().map(|name| PathBuf::from(*name)).collect()
    }

    #[test]
    fn test_collect_related_paths() {
        let main_path = Path::new("/models/llama.bin");
        let listing = to_paths(&[
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
            "/models/llama.bin.tmp",
        ]);
        // `.tmp` is not a numeric part suffix and must be filtered out.
        let related = collect_related_paths(main_path, listing.into_iter());
        let expected = to_paths(&[
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
        ]);
        assert_eq!(related, expected);
    }

    #[test]
    fn test_valid_utf8() {
        let mut buf = TokenUtf8Buffer::new();
        let ascii = buf.push(b"hello");
        assert_eq!(ascii, Some(String::from("hello")));
        let euro = buf.push(&[0xE2, 0x82, 0xAC]);
        assert_eq!(euro, Some(String::from("€")));
    }

    #[test]
    fn test_partial_utf8() {
        let mut buf = TokenUtf8Buffer::new();
        // The first two bytes of U+20AC are not yet decodable on their own.
        let head = buf.push(&[0xE2, 0x82]);
        assert_eq!(head, None);
        let tail = buf.push(&[0xAC]);
        assert_eq!(tail, Some(String::from("€")));
    }

    #[test]
    fn test_invalid_prelude_for_valid_utf8() {
        let mut buf = TokenUtf8Buffer::new();
        // 0xD8 starts a two-byte sequence that is never completed.
        let stray = buf.push(&[0xD8]);
        assert_eq!(stray, None);
        let head = buf.push(&[0xE2, 0x82]);
        assert_eq!(head, None);
        let tail = buf.push(&[0xAC]);
        assert_eq!(tail, Some(String::from("€")));
    }
}