llm_base/
util.rs

1//! Utilities for interacting with LLMs and loading them.
2pub use ggml::util::*;
3use std::path::{Path, PathBuf};
4
/// Multiplies one or more numeric terms as `f64` and converts the product to
/// `usize`.
///
/// Accepts a single term (`mulf!(a)`), any number of comma-separated terms
/// (`mulf!(a, b, c)`), and an optional trailing comma.
///
/// NOTE: The original code relies on promotion rules and automatic casts from
/// int to float. What we do instead is use this macro to convert every term of
/// the multiplication to f64, which should have enough precision bits to hold
/// the final value, then cast to usize. I have observed a discrepancy between
/// the ctx_size found using this code, and the one in llama.cpp. The number for
/// rust ends up being slightly lower, but no "out of memory" errors are
/// reported by ggml.
///
/// The `f64` product is converted with `as u64`, which truncates the
/// fractional part and saturates at the bounds of `u64`.
///
/// # Panics
///
/// Panics if the resulting `u64` does not fit into `usize` on the target.
#[macro_export]
#[doc(hidden)]
macro_rules! mulf {
    ($term:expr $(, $terms:expr)* $(,)?) => {
        usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap()
    };
}
19
20use memmap2::{Mmap, MmapAsRawDesc, MmapOptions};
21use thiserror::Error;
22
/// Accumulates raw token bytes until they decode as UTF-8 text.
///
/// A single token is *not* guaranteed to be valid UTF-8 on its own, but the
/// LLM is expected to produce valid UTF-8 across several consecutive tokens.
/// This buffer holds on to the bytes until that happens.
#[derive(Clone, PartialEq, Eq, Default)]
pub struct TokenUtf8Buffer(Vec<u8>);
impl TokenUtf8Buffer {
    /// Creates an empty buffer.
    pub const fn new() -> Self {
        Self(Vec::new())
    }

    /// Appends `token` to the buffer and tries to decode the buffered bytes.
    ///
    /// The whole buffer is tried first. If it is not valid UTF-8, each shorter
    /// suffix is tried in turn (dropping one more leading byte each time), so
    /// an invalid prelude in front of valid text is discarded. On the first
    /// successful decode the text is returned and the buffer is emptied.
    ///
    /// Returns `None`, keeping every buffered byte, when no suffix decodes.
    pub fn push(&mut self, token: &[u8]) -> Option<String> {
        self.0.extend_from_slice(token);

        // Offset 0 is the whole buffer (also covers the empty buffer); the
        // remaining offsets are the non-empty proper suffixes.
        let decoded = std::iter::once(0)
            .chain(1..self.0.len())
            .find_map(|start| std::str::from_utf8(&self.0[start..]).ok().map(str::to_owned))?;

        self.0 = Vec::new();
        Some(decoded)
    }

    /// Adapt a `&str` callback so that it can be used in a `&[u8]` context.
    ///
    /// The returned closure owns its own buffer: incoming bytes are pushed
    /// into it and `callback` is only invoked once they decode to text.
    /// While nothing can be decoded the closure returns `Ok(())`.
    pub fn adapt_callback<'a, E: std::error::Error + 'static>(
        mut callback: impl FnMut(&str) -> Result<(), E> + 'a,
    ) -> impl FnMut(&[u8]) -> Result<(), E> + 'a {
        let mut pending = Self::new();
        move |token| {
            if let Some(text) = pending.push(token) {
                callback(&text)
            } else {
                Ok(())
            }
        }
    }
}
74
75#[derive(Error, Debug)]
76/// Errors encountered during the loading process.
77pub enum FindAllModelFilesError {
78    #[error("no parent path for {path:?}")]
79    /// There is no parent path for a given path.
80    NoParentPath {
81        /// The path without a parent.
82        path: PathBuf,
83    },
84    #[error("non-specific I/O error")]
85    /// A non-specific IO error.
86    IO(#[from] std::io::Error),
87}
88
89/// Find all the files related to a model.
90pub fn find_all_model_files(main_path: &Path) -> Result<Vec<PathBuf>, FindAllModelFilesError> {
91    let mut main_path_parent =
92        main_path
93            .parent()
94            .ok_or_else(|| FindAllModelFilesError::NoParentPath {
95                path: main_path.to_owned(),
96            })?;
97    if main_path_parent.to_str() == Some("") {
98        main_path_parent = Path::new(".");
99    }
100    Ok(collect_related_paths(
101        main_path,
102        std::fs::read_dir(main_path_parent)?
103            .filter_map(Result::ok)
104            .map(|de| de.path()),
105    ))
106}
107
/// Selects the paths in `directory_paths` that belong to the model at
/// `main_path`.
///
/// A path is related when its file name is exactly the file name of
/// `main_path`, or that name followed by `.` and a decimal part number made
/// only of ASCII digits (e.g. `llama.bin.2`). File names that are not valid
/// UTF-8 never match, and nothing matches if `main_path` has no UTF-8 file
/// name.
///
/// The result is ordered by parent directory, then by part number (the main
/// file counts as part 0), then by path. Part numbers are compared
/// numerically, so `llama.bin.2` comes before `llama.bin.10`.
fn collect_related_paths(
    main_path: &Path,
    directory_paths: impl Iterator<Item = PathBuf>,
) -> Vec<PathBuf> {
    let main_filename = match main_path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        // Without a UTF-8 file name there is nothing to compare against.
        None => return Vec::new(),
    };

    let mut related: Vec<(usize, PathBuf)> = directory_paths
        .filter_map(|path| {
            let index = path
                .file_name()
                .and_then(|name| name.to_str())
                .and_then(|name| model_part_index(name, main_filename))?;
            Some((index, path))
        })
        .collect();

    // Compare part numbers as integers; a plain path sort would be
    // lexicographic and put `.10` ahead of `.2`. The final path comparison
    // keeps the order deterministic when two files share a part number.
    related.sort_by(|(index_a, path_a), (index_b, path_b)| {
        path_a
            .parent()
            .cmp(&path_b.parent())
            .then_with(|| index_a.cmp(index_b))
            .then_with(|| path_a.cmp(path_b))
    });

    related.into_iter().map(|(_, path)| path).collect()
}

/// Returns the part number encoded in `file_name`, or `None` if the file is
/// unrelated to `main_filename`. The main file itself is part 0.
fn model_part_index(file_name: &str, main_filename: &str) -> Option<usize> {
    let suffix = file_name.strip_prefix(main_filename)?;
    if suffix.is_empty() {
        return Some(0);
    }

    let digits = suffix.strip_prefix('.')?;
    // `str::parse::<usize>` also accepts a leading `+`; only plain digits
    // denote a model part.
    if digits.is_empty() || !digits.bytes().all(|byte| byte.is_ascii_digit()) {
        return None;
    }
    digits.parse().ok()
}
137
138/// mmap with MAP_POPULATE
139pub fn mmap_populate<T: MmapAsRawDesc>(file: T) -> Result<Mmap, std::io::Error> {
140    unsafe { MmapOptions::new().populate().map(file) }
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds owned paths from string literals.
    fn paths(raw: &[&str]) -> Vec<PathBuf> {
        raw.iter().copied().map(PathBuf::from).collect()
    }

    /// Only the main file and its numbered parts are kept; other suffixes are dropped.
    #[test]
    fn test_collect_related_paths() {
        let main_path = PathBuf::from("/models/llama.bin");
        let candidates = paths(&[
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
            "/models/llama.bin.tmp",
        ]);
        let expected = paths(&[
            "/models/llama.bin",
            "/models/llama.bin.1",
            "/models/llama.bin.2",
        ]);

        let found = collect_related_paths(&main_path, candidates.into_iter());
        assert_eq!(expected, found);
    }

    /// Complete UTF-8 input is returned immediately.
    #[test]
    fn test_valid_utf8() {
        let mut utf8 = TokenUtf8Buffer::new();
        assert_eq!(utf8.push(b"hello").as_deref(), Some("hello"));
        assert_eq!(utf8.push(&[0xE2, 0x82, 0xAC]).as_deref(), Some("€"));
    }

    /// A multi-byte character split across two tokens is emitted once complete.
    #[test]
    fn test_partial_utf8() {
        let mut utf8 = TokenUtf8Buffer::new();
        assert_eq!(utf8.push(&[0xE2, 0x82]).as_deref(), None);
        assert_eq!(utf8.push(&[0xAC]).as_deref(), Some("€"));
    }

    /// An invalid leading byte is discarded once valid text follows it.
    #[test]
    fn test_invalid_prelude_for_valid_utf8() {
        let mut utf8 = TokenUtf8Buffer::new();
        assert_eq!(utf8.push(&[0xD8]).as_deref(), None);
        assert_eq!(utf8.push(&[0xE2, 0x82]).as_deref(), None);
        assert_eq!(utf8.push(&[0xAC]).as_deref(), Some("€"));
    }
}