1pub use ggml::util::*;
3use std::path::{Path, PathBuf};
4
/// Multiplies all of the given terms as `f64` values and converts the product
/// to a `usize`.
///
/// Every argument is cast with `as f64`, the casts are multiplied left to
/// right, and the product is cast with `as u64` (which drops the fractional
/// part) before being converted with `usize::try_from`.
///
/// At least one term followed by a comma is required, e.g. `mulf!(n, 1.5)`.
///
/// # Panics
///
/// Panics if the truncated product does not fit in a `usize`.
#[macro_export]
#[doc(hidden)]
macro_rules! mulf {
    ($term:expr, $($terms:expr),*) => {{
        // Do the arithmetic in floating point so fractional factors are honoured.
        let product: f64 = ($term as f64) $(* ($terms as f64))*;
        usize::try_from(product as u64).unwrap()
    }};
}
19
20use memmap2::{Mmap, MmapAsRawDesc, MmapOptions};
21use thiserror::Error;
22
/// A byte buffer that accumulates token bytes until they can be decoded as
/// UTF-8 text.
///
/// Tokens may end in the middle of a multi-byte character, so bytes are held
/// back until the buffered data (or a suffix of it) decodes successfully.
#[derive(Clone, PartialEq, Eq, Default)]
pub struct TokenUtf8Buffer(Vec<u8>);

impl TokenUtf8Buffer {
    /// Creates an empty buffer.
    pub const fn new() -> Self {
        Self(Vec::new())
    }

    /// Appends `token` to the buffer and tries to decode the buffered bytes.
    ///
    /// Returns `Some(text)` and empties the buffer when either the whole
    /// buffer, or failing that the longest proper suffix of it, is valid
    /// UTF-8. Any bytes in front of such a suffix are discarded. Returns
    /// `None` and keeps all bytes buffered when nothing decodes yet.
    pub fn push(&mut self, token: &[u8]) -> Option<String> {
        self.0.extend_from_slice(token);

        let pending: &[u8] = &self.0;
        let text = std::str::from_utf8(pending)
            .ok()
            .or_else(|| {
                // The whole buffer is not valid: skip leading bytes one at a
                // time and take the first (i.e. longest) suffix that decodes.
                (1..pending.len()).find_map(|start| std::str::from_utf8(&pending[start..]).ok())
            })
            .map(str::to_owned)?;

        // Everything buffered has been either emitted or skipped.
        self.0.clear();
        Some(text)
    }

    /// Wraps a callback that expects `&str` into one that accepts raw token
    /// bytes.
    ///
    /// The returned closure owns a fresh buffer, feeds every byte slice into
    /// [`TokenUtf8Buffer::push`], and invokes `callback` only when `push`
    /// yields text, returning the callback's result. When nothing was decoded
    /// it returns `Ok(())` without calling `callback`.
    pub fn adapt_callback<'a, E: std::error::Error + 'static>(
        mut callback: impl FnMut(&str) -> Result<(), E> + 'a,
    ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a {
        let mut buffer = Self::new();
        move |token| {
            if let Some(text) = buffer.push(token) {
                callback(&text)
            } else {
                Ok(())
            }
        }
    }
}
74
/// Errors that [`find_all_model_files`] can return.
#[derive(Error, Debug)]
pub enum FindAllModelFilesError {
    /// The given path has no parent, so there is no directory to search.
    #[error("no parent path for {path:?}")]
    NoParentPath {
        /// The path for which no parent could be determined.
        path: PathBuf,
    },
    /// An I/O error, converted automatically from [`std::io::Error`].
    #[error("non-specific I/O error")]
    IO(#[from] std::io::Error),
}
88
89pub fn find_all_model_files(main_path: &Path) -> Result<Vec<PathBuf>, FindAllModelFilesError> {
91 let mut main_path_parent =
92 main_path
93 .parent()
94 .ok_or_else(|| FindAllModelFilesError::NoParentPath {
95 path: main_path.to_owned(),
96 })?;
97 if main_path_parent.to_str() == Some("") {
98 main_path_parent = Path::new(".");
99 }
100 Ok(collect_related_paths(
101 main_path,
102 std::fs::read_dir(main_path_parent)?
103 .filter_map(Result::ok)
104 .map(|de| de.path()),
105 ))
106}
107
/// Selects, from `directory_paths`, the files that belong to the model at
/// `main_path` and returns them in part order.
///
/// A path is kept when its file name is exactly the file name of `main_path`,
/// or that name followed by `.` and something that parses as a `usize`
/// (e.g. `model.bin.1`). File names that are not valid UTF-8 never match, and
/// if `main_path` has no UTF-8 file name the result is empty.
///
/// The result is ordered with the main file first, followed by the numbered
/// parts in ascending *numeric* order, so `model.bin.2` comes before
/// `model.bin.10`. Ties (e.g. `.1` and `.01`) are broken by comparing the
/// paths themselves.
fn collect_related_paths(
    main_path: &Path,
    directory_paths: impl Iterator<Item = PathBuf>,
) -> Vec<PathBuf> {
    /// Classifies `filename` relative to `main_filename`: `Some(None)` for the
    /// main file itself, `Some(Some(n))` for part `n`, `None` if unrelated.
    fn part_index(filename: &str, main_filename: &str) -> Option<Option<usize>> {
        let suffix = filename.strip_prefix(main_filename)?;
        if suffix.is_empty() {
            return Some(None);
        }
        suffix.strip_prefix('.')?.parse::<usize>().ok().map(Some)
    }

    let main_filename = match main_path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        // Without a comparable main file name nothing can match.
        None => return Vec::new(),
    };

    let mut parts: Vec<(Option<usize>, PathBuf)> = directory_paths
        .filter_map(|path| {
            let index = part_index(path.file_name()?.to_str()?, main_filename)?;
            Some((index, path))
        })
        .collect();

    // Sorting on the parsed index (with `None` < `Some(_)`) keeps the main file
    // first and avoids the lexicographic trap where ".10" sorts before ".2".
    parts.sort();
    parts.into_iter().map(|(_, path)| path).collect()
}
137
138pub fn mmap_populate<T: MmapAsRawDesc>(file: T) -> Result<Mmap, std::io::Error> {
140 unsafe { MmapOptions::new().populate().map(file) }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into owned paths.
    fn paths(raw: &[&str]) -> Vec<PathBuf> {
        raw.iter().map(PathBuf::from).collect()
    }

    /// The main file and its numbered parts are kept; other suffixes are not.
    #[test]
    fn test_collect_related_paths() {
        let main_path = PathBuf::from("/models/llama.bin");
        let listing = paths(&[
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
            "/models/llama.bin.tmp",
        ]);
        let expected = paths(&[
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
        ]);

        let found = collect_related_paths(&main_path, listing.into_iter());
        assert_eq!(expected, found);
    }

    /// Complete UTF-8 sequences are returned as soon as they are pushed.
    #[test]
    fn test_valid_utf8() {
        let mut buffer = TokenUtf8Buffer::new();

        let ascii = buffer.push(b"hello");
        assert_eq!(ascii.as_deref(), Some("hello"));

        let euro = buffer.push(&[0xE2, 0x82, 0xAC]);
        assert_eq!(euro.as_deref(), Some("€"));
    }

    /// A character split across two pushes is only returned once complete.
    #[test]
    fn test_partial_utf8() {
        let mut buffer = TokenUtf8Buffer::new();

        let incomplete = buffer.push(&[0xE2, 0x82]);
        assert_eq!(incomplete.as_deref(), None);

        let completed = buffer.push(&[0xAC]);
        assert_eq!(completed.as_deref(), Some("€"));
    }

    /// Leading bytes that never become valid are dropped once a valid suffix
    /// has been assembled.
    #[test]
    fn test_invalid_prelude_for_valid_utf8() {
        let mut buffer = TokenUtf8Buffer::new();

        let stray = buffer.push(&[0xD8]);
        assert_eq!(stray.as_deref(), None);

        let incomplete = buffer.push(&[0xE2, 0x82]);
        assert_eq!(incomplete.as_deref(), None);

        let completed = buffer.push(&[0xAC]);
        assert_eq!(completed.as_deref(), Some("€"));
    }
}