kalosm_llama/lib.rs

//! # RLlama
//!
//! RLlama is a Rust implementation of the quantized [Llama 7B](https://llama.ai/news/announcing-llama-7b/) language model.
//!
//! Llama 7B is a very small but performant language model that can be easily run on your local machine.
//!
//! This library uses [Candle](https://github.com/huggingface/candle) to run Llama.
//!
//! ## Usage
//!
//! ```rust, no_run
//! use kalosm_llama::prelude::*;
//!
//! #[tokio::main]
//! async fn main() {
//!     let mut model = Llama::new().await.unwrap();
//!     let prompt = "The capital of France is ";
//!     let mut stream = model(prompt);
//!
//!     print!("{prompt}");
//!     while let Some(token) = stream.next().await {
//!         print!("{token}");
//!     }
//! }
//! ```

#![warn(missing_docs)]

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod chat;
mod chat_template;
mod gguf_tokenizer;
mod language_model;
mod model;
mod raw;
mod session;
mod source;
mod structured;
mod token_stream;

pub use crate::chat::LlamaChatSession;
use crate::model::LlamaModel;
pub use crate::raw::cache::*;
pub use crate::session::LlamaSession;
use candle_core::Device;
pub use kalosm_common::*;
use kalosm_language_model::{TextCompletionBuilder, TextCompletionModelExt};
use kalosm_model_types::ModelLoadingProgress;
use kalosm_sample::{LiteralParser, StopOn};
use model::LlamaModelError;
use raw::LlamaConfig;
pub use source::*;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::sync::Arc;
use tokenizers::Tokenizer;

/// A prelude of commonly used items in kalosm-llama.
pub mod prelude {
    pub use crate::session::LlamaSession;
    pub use crate::{Llama, LlamaBuilder, LlamaSource};
    pub use kalosm_language_model::*;
}

enum Task {
    UnstructuredGeneration(UnstructuredGenerationTask),
    StructuredGeneration(StructuredGenerationTask),
}

struct StructuredGenerationTask {
    runner: Box<dyn FnOnce(&mut LlamaModel) + Send>,
}

struct UnstructuredGenerationTask {
    settings: InferenceSettings,
    on_token: Box<dyn FnMut(String) -> Result<(), LlamaModelError> + Send + Sync>,
    finished: tokio::sync::oneshot::Sender<Result<(), LlamaModelError>>,
}

/// A quantized Llama language model with support for streaming generation.
#[derive(Clone)]
pub struct Llama {
    config: Arc<LlamaConfig>,
    tokenizer: Arc<Tokenizer>,
    task_sender: tokio::sync::mpsc::UnboundedSender<Task>,
}

impl Llama {
    /// Create a default chat model.
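    ///
    /// A minimal loading sketch (assuming a Tokio runtime; error handling elided):
    ///
    /// ```rust, no_run
    /// use kalosm_llama::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// // Downloads the default chat-tuned weights on first use, then loads them.
    /// let model = Llama::new_chat().await.unwrap();
    /// # }
    /// ```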
    pub async fn new_chat() -> Result<Self, LlamaSourceError> {
        Llama::builder()
            .with_source(LlamaSource::llama_3_1_8b_chat())
            .build()
            .await
    }

    /// Create a default phi-3 chat model.
    pub async fn phi_3() -> Result<Self, LlamaSourceError> {
        Llama::builder()
            .with_source(LlamaSource::phi_3_5_mini_4k_instruct())
            .build()
            .await
    }

    /// Create a default text generation model.
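    ///
    /// A short streaming sketch that mirrors the crate-level example (assuming a Tokio runtime):
    ///
    /// ```rust, no_run
    /// use kalosm_llama::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let mut model = Llama::new().await.unwrap();
    /// // Calling the model directly (via `Deref`) starts a text completion.
    /// let mut stream = model("The capital of France is ");
    /// while let Some(token) = stream.next().await {
    ///     print!("{token}");
    /// }
    /// # }
    /// ```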
    pub async fn new() -> Result<Self, LlamaSourceError> {
        Llama::builder()
            .with_source(LlamaSource::llama_8b())
            .build()
            .await
    }

    /// Get the tokenizer for the model.
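    ///
    /// A small sketch of encoding a prompt with the returned `tokenizers::Tokenizer`
    /// (assuming a Tokio runtime):
    ///
    /// ```rust, no_run
    /// use kalosm_llama::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let model = Llama::new().await.unwrap();
    /// // Encode a prompt into token ids without adding special tokens.
    /// let encoding = model.tokenizer().encode("Hello, world!", false).unwrap();
    /// println!("{:?}", encoding.get_ids());
    /// # }
    /// ```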
    pub fn tokenizer(&self) -> &Arc<Tokenizer> {
        &self.tokenizer
    }

    /// Create a new builder for a Llama model.
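    ///
    /// A sketch of configuring the builder before loading; the source shown here is one of
    /// the built-in presets used elsewhere in this crate:
    ///
    /// ```rust, no_run
    /// use kalosm_llama::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let model = Llama::builder()
    ///     // Pick which model weights to download and load.
    ///     .with_source(LlamaSource::llama_3_1_8b_chat())
    ///     .build()
    ///     .await
    ///     .unwrap();
    /// # }
    /// ```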
    pub fn builder() -> LlamaBuilder {
        LlamaBuilder::default()
    }

    #[allow(clippy::too_many_arguments)]
    fn from_build(mut model: LlamaModel) -> Self {
        let (task_sender, mut task_receiver) = tokio::sync::mpsc::unbounded_channel();
        let config = model.model.config.clone();
        let tokenizer = model.tokenizer.clone();

        // Spawn a dedicated worker thread that owns the model and processes
        // generation tasks sent over the channel.
        std::thread::spawn({
            move || {
                while let Some(task) = task_receiver.blocking_recv() {
                    match task {
                        Task::UnstructuredGeneration(UnstructuredGenerationTask {
                            settings,
                            on_token,
                            finished,
                        }) => {
                            let result = model._infer(settings, on_token, &finished);
                            if let Err(err) = &result {
                                tracing::error!("Error running model: {err}");
                            }
                            _ = finished.send(result);
                        }
                        Task::StructuredGeneration(StructuredGenerationTask { runner }) => {
                            runner(&mut model);
                        }
                    }
                }
            }
        });
        Self {
            task_sender,
            config,
            tokenizer,
        }
    }

    /// Get the default constraints for an assistant response. It parses any text until the end of the assistant's response.
    pub fn default_assistant_constraints(&self) -> StopOn<String> {
        let end_token = self.config.stop_token_string.clone();

        StopOn::from(end_token)
    }

    /// Get the constraints that end the assistant's response.
    pub fn end_assistant_marker_constraints(&self) -> LiteralParser {
        let end_token = self.config.stop_token_string.clone();

        LiteralParser::from(end_token)
    }
}

impl Deref for Llama {
    type Target = dyn Fn(&str) -> TextCompletionBuilder<Self>;

    fn deref(&self) -> &Self::Target {
        // https://github.com/dtolnay/case-studies/tree/master/callable-types

        // Create an empty allocation for Self.
        let uninit_callable = MaybeUninit::<Self>::uninit();
        // Move a closure that captures just self into the uninitialized memory. Closures create an anonymous type that
        // implements FnOnce. In this case, the layout of that type should be the same as Self because self is the only capture.
        let uninit_closure = move |text: &str| {
            TextCompletionModelExt::complete(unsafe { &*uninit_callable.as_ptr() }, text)
        };

        // Make sure the layout of the closure and Self is the same.
        let size_of_closure = std::alloc::Layout::for_value(&uninit_closure);
        assert_eq!(size_of_closure, std::alloc::Layout::new::<Self>());

        // Then cast the lifetime of the closure to the lifetime of &self.
        fn cast_lifetime<'a, T>(_a: &T, b: &'a T) -> &'a T {
            b
        }
        let reference_to_closure = cast_lifetime(
            {
                // The real closure that we will never use.
                &uninit_closure
            },
            #[allow(clippy::missing_transmute_annotations)]
            // We transmute self into a reference to the closure. This is safe because we know that the closure has the same memory layout as Self so &Closure == &Self.
            unsafe {
                std::mem::transmute(self)
            },
        );

        // Cast the closure to a trait object.
        reference_to_closure as &_
    }
}

/// A builder with configuration for a Llama model.
#[derive(Default)]
pub struct LlamaBuilder {
    source: source::LlamaSource,
    device: Option<Device>,
    flash_attn: bool,
}

impl LlamaBuilder {
    /// Set the source for the model.
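    ///
    /// For example, switching to another preset referenced in this crate (a sketch; any
    /// `LlamaSource` works here):
    ///
    /// ```rust, no_run
    /// use kalosm_llama::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// let model = Llama::builder()
    ///     .with_source(LlamaSource::phi_3_5_mini_4k_instruct())
    ///     .build()
    ///     .await
    ///     .unwrap();
    /// # }
    /// ```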
    pub fn with_source(mut self, source: source::LlamaSource) -> Self {
        self.source = source;
        self
    }

    /// Set whether to use Flash Attention.
    pub fn with_flash_attn(mut self, use_flash_attn: bool) -> Self {
        self.flash_attn = use_flash_attn;
        self
    }

    /// Set the device to run the model with. (Defaults to an accelerator if available, otherwise the CPU)
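    ///
    /// A sketch of forcing CPU execution (this assumes `candle_core` is available to the
    /// example; any other `candle_core::Device` works the same way):
    ///
    /// ```rust, no_run
    /// use candle_core::Device;
    /// use kalosm_llama::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// // Run on the CPU instead of letting the builder pick an accelerator.
    /// let model = Llama::builder()
    ///     .with_device(Device::Cpu)
    ///     .build()
    ///     .await
    ///     .unwrap();
    /// # }
    /// ```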
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Get the device or the default device if not set.
    pub(crate) fn get_device(&self) -> Result<Device, LlamaSourceError> {
        match self.device.clone() {
            Some(device) => Ok(device),
            None => Ok(accelerated_device_if_available()?),
        }
    }

    /// Build the model with a handler for progress as the download and loading progresses.
    ///
    /// ```rust, no_run
    /// use kalosm::language::*;
    /// # #[tokio::main]
    /// # async fn main() -> Result<(), anyhow::Error> {
    /// // Create a new llama model with a loading handler
    /// let model = Llama::builder()
    ///     .build_with_loading_handler(|progress| match progress {
    ///         ModelLoadingProgress::Downloading { source, progress } => {
    ///             let progress_percent = (progress.progress * 100) as u32;
    ///             let elapsed = progress.start_time.elapsed().as_secs_f32();
    ///             println!("Downloading file {source} {progress_percent}% ({elapsed}s)");
    ///         }
    ///         ModelLoadingProgress::Loading { progress } => {
    ///             let progress = (progress * 100.0) as u32;
    ///             println!("Loading model {progress}%");
    ///         }
    ///     })
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn build_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Llama, LlamaSourceError> {
        let model = LlamaModel::from_builder(self, handler).await?;

        Ok(Llama::from_build(model))
    }

    /// Build the model (this will download the model if it is not already downloaded)
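    ///
    /// A minimal sketch using the default source and the built-in progress bar:
    ///
    /// ```rust, no_run
    /// use kalosm_llama::prelude::*;
    ///
    /// # #[tokio::main]
    /// # async fn main() {
    /// // Downloads the default weights on first use, then loads them.
    /// let model = Llama::builder().build().await.unwrap();
    /// # }
    /// ```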
    pub async fn build(self) -> Result<Llama, LlamaSourceError> {
        self.build_with_loading_handler(ModelLoadingProgress::multi_bar_loading_indicator())
            .await
    }
}

#[derive(Debug)]
pub(crate) struct InferenceSettings {
    prompt: String,

    /// The token to stop on.
    stop_on: Option<String>,

    /// The sampler to use.
    sampler: std::sync::Arc<std::sync::Mutex<dyn llm_samplers::prelude::Sampler>>,

    /// The session to use.
    session: LlamaSession,

    /// The maximum number of tokens to generate.
    max_tokens: u32,

    /// The seed to use.
    seed: Option<u64>,
}

impl InferenceSettings {
    pub fn new(
        prompt: impl Into<String>,
        session: LlamaSession,
        sampler: std::sync::Arc<std::sync::Mutex<dyn llm_samplers::prelude::Sampler>>,
        max_tokens: u32,
        stop_on: Option<String>,
        seed: Option<u64>,
    ) -> Self {
        Self {
            prompt: prompt.into(),
            stop_on,
            sampler,
            session,
            max_tokens,
            seed,
        }
    }
}