pub use ggml::util::*;
use std::path::{Path, PathBuf};
/// Multiplies two or more numeric terms as `f64` and converts the product to `usize`.
///
/// Every term is cast to `f64` with `as`, the product is cast to `u64` with
/// `as`, and the result is converted with `usize::try_from`.
///
/// # Panics
///
/// Panics (via `unwrap`) if the `u64` product does not fit in `usize`.
#[macro_export]
#[doc(hidden)]
macro_rules! mulf {
    ($first:expr, $($rest:expr),*) => {{
        // Fold all terms together in floating point before narrowing.
        let product: f64 = ($first as f64) $(* ($rest as f64))*;
        usize::try_from(product as u64).unwrap()
    }};
}
use memmap2::{Mmap, MmapAsRawDesc, MmapOptions};
use thiserror::Error;
/// Accumulates raw token bytes until they can be decoded as UTF-8 text.
///
/// A single token may contain only part of a multi-byte character, so bytes
/// are held back until the buffered data decodes cleanly.
#[derive(Clone, PartialEq, Eq, Default)]
pub struct TokenUtf8Buffer(Vec<u8>);

impl TokenUtf8Buffer {
    /// Creates an empty buffer.
    pub const fn new() -> Self {
        Self(Vec::new())
    }

    /// Appends `token` to the buffer and tries to decode the buffered bytes.
    ///
    /// * If the whole buffer is valid UTF-8, it is returned and the buffer is
    ///   emptied.
    /// * Otherwise the longest valid suffix that starts after the first byte
    ///   is returned, if there is one. The bytes in front of that suffix are
    ///   discarded and the buffer is emptied.
    /// * If no such suffix exists, `None` is returned and the bytes stay
    ///   buffered for the next call.
    pub fn push(&mut self, token: &[u8]) -> Option<String> {
        self.0.extend_from_slice(token);

        let decoded = match std::str::from_utf8(&self.0) {
            Ok(text) => Some(text.to_owned()),
            // Scanning from the earliest offset yields the longest valid
            // suffix. Every candidate is non-empty because `start < len`.
            Err(_) => (1..self.0.len())
                .find_map(|start| std::str::from_utf8(&self.0[start..]).ok())
                .map(str::to_owned),
        };

        if decoded.is_some() {
            // `clear` keeps the allocation so the next token can reuse it.
            self.0.clear();
        }
        decoded
    }

    /// Wraps a `&str` callback so that it can be fed raw token bytes.
    ///
    /// The returned closure owns a fresh buffer. Each call pushes the bytes
    /// into it and invokes `callback` only when [`TokenUtf8Buffer::push`]
    /// produces text; otherwise it returns `Ok(())` without calling it.
    pub fn adapt_callback<'a, E: std::error::Error + 'static>(
        mut callback: impl FnMut(&str) -> Result<(), E> + 'a,
    ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a {
        let mut buffer = Self::new();
        move |token| match buffer.push(token) {
            Some(text) => callback(&text),
            None => Ok(()),
        }
    }
}
/// Errors that `find_all_model_files` can return.
#[derive(Error, Debug)]
pub enum FindAllModelFilesError {
/// The supplied path has no parent, so there is no directory to search.
#[error("no parent path for {path:?}")]
NoParentPath {
/// The path whose parent could not be determined.
path: PathBuf,
},
/// An I/O error, converted from `std::io::Error` through `#[from]`.
#[error("non-specific I/O error")]
IO(#[from] std::io::Error),
}
/// Lists the model file at `main_path` together with its related files in
/// the same directory.
///
/// The directory containing `main_path` is read and its entries are handed to
/// `collect_related_paths`, which decides which of them belong to the model.
/// An empty parent (a bare file name) is treated as the current directory.
/// Directory entries that fail to read are skipped.
///
/// # Errors
///
/// * [`FindAllModelFilesError::NoParentPath`] if `main_path` has no parent.
/// * [`FindAllModelFilesError::IO`] if the directory cannot be read.
pub fn find_all_model_files(main_path: &Path) -> Result<Vec<PathBuf>, FindAllModelFilesError> {
    let search_dir = match main_path.parent() {
        // A bare file name has an empty parent; search the working directory.
        Some(parent) if parent.as_os_str().is_empty() => Path::new("."),
        Some(parent) => parent,
        None => {
            return Err(FindAllModelFilesError::NoParentPath {
                path: main_path.to_owned(),
            })
        }
    };

    let entries = std::fs::read_dir(search_dir)?;
    // `flatten` drops entries that could not be read and keeps the rest.
    let candidates = entries.flatten().map(|entry| entry.path());

    Ok(collect_related_paths(main_path, candidates))
}
/// Selects from `directory_paths` the files that belong to the model at
/// `main_path`.
///
/// A path is kept when its file name is exactly the main file's name, or the
/// main file's name followed by `.` and a suffix that parses as a `usize`
/// (`model.bin`, `model.bin.1`, `model.bin.2`, ...). Paths whose file name is
/// not valid UTF-8 are skipped, and nothing is kept when `main_path` has no
/// UTF-8 file name.
///
/// Within each directory the unsuffixed file comes first and the parts follow
/// in ascending numeric order, so `model.bin.2` precedes `model.bin.10`.
/// Ties are broken by comparing the full paths.
fn collect_related_paths(
    main_path: &Path,
    directory_paths: impl Iterator<Item = PathBuf>,
) -> Vec<PathBuf> {
    let main_filename = main_path.file_name().and_then(|name| name.to_str());

    // Pair every related path with its part number (`None` for the
    // unsuffixed main file) so the parts can be ordered numerically.
    let mut parts: Vec<(Option<usize>, PathBuf)> = directory_paths
        .filter_map(|path| {
            let suffix = path
                .file_name()?
                .to_str()?
                .strip_prefix(main_filename?)?;
            let index = if suffix.is_empty() {
                None
            } else {
                Some(suffix.strip_prefix('.')?.parse::<usize>().ok()?)
            };
            Some((index, path))
        })
        .collect();

    // A plain path sort is lexicographic and would put `.10` before `.2`.
    // `None` orders before `Some(_)`, which keeps the main file first.
    parts.sort_by(|(a_index, a_path), (b_index, b_path)| {
        (a_path.parent(), a_index, a_path).cmp(&(b_path.parent(), b_index, b_path))
    });

    parts.into_iter().map(|(_, path)| path).collect()
}
/// Memory-maps `file` with the `populate` option enabled.
///
/// # Errors
///
/// Returns the `std::io::Error` reported by `MmapOptions::map`.
pub fn mmap_populate<T: MmapAsRawDesc>(file: T) -> Result<Mmap, std::io::Error> {
    let mut options = MmapOptions::new();
    options.populate();
    // NOTE(review): `MmapOptions::map` is an `unsafe fn`; nothing in this
    // function enforces its preconditions (presumably that the underlying
    // file is not modified while mapped) - confirm callers uphold them.
    unsafe { options.map(file) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the main file and its numerically suffixed siblings are kept.
    #[test]
    fn test_collect_related_paths() {
        let candidates: Vec<PathBuf> = vec![
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
            "/models/llama.bin.tmp",
        ]
        .into_iter()
        .map(PathBuf::from)
        .collect();
        let wanted: Vec<PathBuf> = vec![
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
        ]
        .into_iter()
        .map(PathBuf::from)
        .collect();

        let found = collect_related_paths(Path::new("/models/llama.bin"), candidates.into_iter());

        assert_eq!(wanted, found);
    }

    /// Complete UTF-8 input is returned immediately.
    #[test]
    fn test_valid_utf8() {
        let mut buf = TokenUtf8Buffer::new();

        let ascii = buf.push(b"hello");
        assert_eq!(ascii.as_deref(), Some("hello"));

        let euro = buf.push(&[0xE2, 0x82, 0xAC]);
        assert_eq!(euro.as_deref(), Some("€"));
    }

    /// A character split across two pushes is emitted once it is complete.
    #[test]
    fn test_partial_utf8() {
        let mut buf = TokenUtf8Buffer::new();

        let head = buf.push(&[0xE2, 0x82]);
        assert_eq!(head.as_deref(), None);

        let tail = buf.push(&[0xAC]);
        assert_eq!(tail.as_deref(), Some("€"));
    }

    /// A leading invalid byte does not block the valid text that follows it.
    #[test]
    fn test_invalid_prelude_for_valid_utf8() {
        let mut buf = TokenUtf8Buffer::new();

        let junk = buf.push(&[0xD8]);
        assert_eq!(junk.as_deref(), None);

        let head = buf.push(&[0xE2, 0x82]);
        assert_eq!(head.as_deref(), None);

        let tail = buf.push(&[0xAC]);
        assert_eq!(tail.as_deref(), Some("€"));
    }
}