#![warn(missing_docs)]
//! An interface for running Llama-style language models locally, built on the
//! [candle](https://github.com/huggingface/candle) machine learning framework.
//! The crate exposes text completion, chat sessions, and structured
//! (constrained) generation through the [`Llama`] handle and its [`LlamaBuilder`].

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod chat;
mod chat_template;
mod gguf_tokenizer;
mod language_model;
mod model;
mod raw;
mod session;
mod source;
mod structured;
mod token_stream;

pub use crate::chat::LlamaChatSession;
use crate::model::LlamaModel;
pub use crate::raw::cache::*;
pub use crate::session::LlamaSession;
use candle_core::Device;
pub use kalosm_common::*;
use kalosm_language_model::{TextCompletionBuilder, TextCompletionModelExt};
use kalosm_model_types::ModelLoadingProgress;
use kalosm_sample::{LiteralParser, StopOn};
use model::LlamaModelError;
use raw::LlamaConfig;
pub use source::*;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::sync::Arc;
use tokenizers::Tokenizer;

/// A prelude of commonly used items from this crate and `kalosm_language_model`.
pub mod prelude {
    pub use crate::session::LlamaSession;
    pub use crate::{Llama, LlamaBuilder, LlamaSource};
    pub use kalosm_language_model::*;
}

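// Inference requests are processed on a dedicated worker thread that owns the
// `LlamaModel` (see `Llama::from_build`). The `Llama` type is a cheap, cloneable
// handle that submits `Task`s to that thread over an unbounded channel.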
enum Task {
    UnstructuredGeneration(UnstructuredGenerationTask),
    StructuredGeneration(StructuredGenerationTask),
}

struct StructuredGenerationTask {
    runner: Box<dyn FnOnce(&mut LlamaModel) + Send>,
}

struct UnstructuredGenerationTask {
    settings: InferenceSettings,
    on_token: Box<dyn FnMut(String) -> Result<(), LlamaModelError> + Send + Sync>,
    finished: tokio::sync::oneshot::Sender<Result<(), LlamaModelError>>,
}

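/// A Llama-style language model loaded in this process. Cloning the handle is
/// cheap: all clones share the same background inference thread.
///
/// # Example
///
/// A minimal completion sketch. It is marked `ignore` because it assumes the
/// crate is used as `kalosm_llama` and that the returned [`TextCompletionBuilder`]
/// can be awaited for the generated text; see `kalosm_language_model` for the
/// exact completion API.
///
/// ```rust,ignore
/// use kalosm_llama::prelude::*;
///
/// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
/// // Download (or load from the local cache) the default model.
/// let model = Llama::new().await?;
///
/// // `Llama` dereferences to a callable, so it can be invoked like a function.
/// let text: String = model("The capital of France is").await?;
/// println!("{text}");
/// # Ok(())
/// # }
/// ```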
#[derive(Clone)]
pub struct Llama {
    config: Arc<LlamaConfig>,
    tokenizer: Arc<Tokenizer>,
    task_sender: tokio::sync::mpsc::UnboundedSender<Task>,
}

impl Llama {
    /// Create a new chat model with the default chat source ([`LlamaSource::llama_3_1_8b_chat`]).
    pub async fn new_chat() -> Result<Self, LlamaSourceError> {
        Llama::builder()
            .with_source(LlamaSource::llama_3_1_8b_chat())
            .build()
            .await
    }

    /// Create a new Phi-3.5 Mini model ([`LlamaSource::phi_3_5_mini_4k_instruct`]).
    pub async fn phi_3() -> Result<Self, LlamaSourceError> {
        Llama::builder()
            .with_source(LlamaSource::phi_3_5_mini_4k_instruct())
            .build()
            .await
    }

    /// Create a new text completion model with the default source ([`LlamaSource::llama_8b`]).
    pub async fn new() -> Result<Self, LlamaSourceError> {
        Llama::builder()
            .with_source(LlamaSource::llama_8b())
            .build()
            .await
    }

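    /// Get the tokenizer used by this model.
    ///
    /// A small sketch of tokenizing a prompt with the underlying
    /// [`tokenizers::Tokenizer`] (marked `ignore`; it assumes the crate is used
    /// as `kalosm_llama`):
    ///
    /// ```rust,ignore
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = kalosm_llama::Llama::new().await?;
    /// // `encode` comes from the `tokenizers` crate; `true` adds special tokens.
    /// let encoding = model.tokenizer().encode("Hello, world!", true)?;
    /// println!("{:?}", encoding.get_ids());
    /// # Ok(())
    /// # }
    /// ```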
    pub fn tokenizer(&self) -> &Arc<Tokenizer> {
        &self.tokenizer
    }

    /// Create a new [`LlamaBuilder`] to configure the source, device, and other
    /// options before loading a model.
    pub fn builder() -> LlamaBuilder {
        LlamaBuilder::default()
    }

    #[allow(clippy::too_many_arguments)]
    fn from_build(mut model: LlamaModel) -> Self {
        let (task_sender, mut task_receiver) = tokio::sync::mpsc::unbounded_channel();
        let config = model.model.config.clone();
        let tokenizer = model.tokenizer.clone();

        // Spawn the worker thread that owns the model. It blocks on the task
        // channel and runs each generation request to completion in order.
        std::thread::spawn({
            move || {
                while let Some(task) = task_receiver.blocking_recv() {
                    match task {
                        Task::UnstructuredGeneration(UnstructuredGenerationTask {
                            settings,
                            on_token,
                            finished,
                        }) => {
                            let result = model._infer(settings, on_token, &finished);
                            if let Err(err) = &result {
                                tracing::error!("Error running model: {err}");
                            }
                            // The receiver may already have been dropped; ignore send errors.
                            _ = finished.send(result);
                        }
                        Task::StructuredGeneration(StructuredGenerationTask { runner }) => {
                            runner(&mut model);
                        }
                    }
                }
            }
        });
        Self {
            task_sender,
            config,
            tokenizer,
        }
    }

    /// Get the default constraints for an assistant response: generation stops
    /// once the model's stop token string has been produced.
    pub fn default_assistant_constraints(&self) -> StopOn<String> {
        let end_token = self.config.stop_token_string.clone();

        StopOn::from(end_token)
    }

    /// Get a parser that matches exactly the end-of-assistant marker (the
    /// model's stop token string).
    pub fn end_assistant_marker_constraints(&self) -> LiteralParser {
        let end_token = self.config.stop_token_string.clone();

        LiteralParser::from(end_token)
    }
}

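// Allow a `Llama` to be called like a function (e.g. `model("Say hi")`). The
// `Deref` target is a closure type that forwards to
// `TextCompletionModelExt::complete`. The closure captures a `MaybeUninit<Self>`,
// so its layout should match `Self`; the `assert_eq!` below verifies that before
// a reference to `self` is reinterpreted as a reference to the closure.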
impl Deref for Llama {
    type Target = dyn Fn(&str) -> TextCompletionBuilder<Self>;

    fn deref(&self) -> &Self::Target {
        // Create an uninitialized `Self` for the closure to capture. It is never
        // read; it only forces the closure to have the same captures as `Self`.
        let uninit_callable = MaybeUninit::<Self>::uninit();
        // The closure the transmuted reference will point to. It reinterprets
        // its own captured memory as a valid `Self` and calls `complete` on it.
        let uninit_closure = move |text: &str| {
            TextCompletionModelExt::complete(unsafe { &*uninit_callable.as_ptr() }, text)
        };

        // Make sure the closure and `Self` really have the same layout before
        // reinterpreting one as the other.
        let size_of_closure = std::alloc::Layout::for_value(&uninit_closure);
        assert_eq!(size_of_closure, std::alloc::Layout::new::<Self>());

        // Helper that borrows the closure's type while taking the lifetime of `self`.
        fn cast_lifetime<'a, T>(_a: &T, b: &'a T) -> &'a T {
            b
        }
        let reference_to_closure = cast_lifetime(
            // Only used to infer the closure's type; never dereferenced.
            &uninit_closure,
            // Reinterpret `&self` as a reference to the layout-identical closure.
            #[allow(clippy::missing_transmute_annotations)]
            unsafe {
                std::mem::transmute(self)
            },
        );

        // Return the closure as an unsized `dyn Fn` reference.
        reference_to_closure as &_
    }
}

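/// A builder for configuring and loading a [`Llama`] model.
///
/// A configuration sketch (marked `ignore`; it assumes the crate is used as
/// `kalosm_llama` and that the chosen model files can be downloaded):
///
/// ```rust,ignore
/// use candle_core::Device;
/// use kalosm_llama::prelude::*;
///
/// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
/// let model = Llama::builder()
///     // Pick which model weights to download or load.
///     .with_source(LlamaSource::llama_3_1_8b_chat())
///     // Force CPU inference instead of the default accelerated device.
///     .with_device(Device::Cpu)
///     .build()
///     .await?;
/// # Ok(())
/// # }
/// ```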
#[derive(Default)]
pub struct LlamaBuilder {
    source: source::LlamaSource,
    device: Option<Device>,
    flash_attn: bool,
}

impl LlamaBuilder {
    /// Set the source to load the model from.
    pub fn with_source(mut self, source: source::LlamaSource) -> Self {
        self.source = source;
        self
    }

    /// Set whether to use flash attention.
    pub fn with_flash_attn(mut self, use_flash_attn: bool) -> Self {
        self.flash_attn = use_flash_attn;
        self
    }

    /// Set the device to run the model on.
    pub fn with_device(mut self, device: Device) -> Self {
        self.device = Some(device);
        self
    }

    /// Resolve the device to use: the explicitly configured device, or the best
    /// accelerated device available on this machine.
    pub(crate) fn get_device(&self) -> Result<Device, LlamaSourceError> {
        match self.device.clone() {
            Some(device) => Ok(device),
            None => Ok(accelerated_device_if_available()?),
        }
    }

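    /// Build the model with a custom handler that is called as loading
    /// progresses (for example while the model files are downloaded).
    ///
    /// A sketch of a custom handler (marked `ignore`; it assumes the crate is
    /// used as `kalosm_llama` and that [`ModelLoadingProgress`] implements
    /// `Debug`):
    ///
    /// ```rust,ignore
    /// use kalosm_llama::prelude::*;
    ///
    /// # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = Llama::builder()
    ///     .with_source(LlamaSource::llama_3_1_8b_chat())
    ///     .build_with_loading_handler(|progress| {
    ///         // Update a progress bar or log the current loading state here.
    ///         println!("{progress:?}");
    ///     })
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```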
    pub async fn build_with_loading_handler(
        self,
        handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
    ) -> Result<Llama, LlamaSourceError> {
        let model = LlamaModel::from_builder(self, handler).await?;

        Ok(Llama::from_build(model))
    }

    /// Build the model with the default loading handler, which displays a
    /// progress bar while the model loads.
    pub async fn build(self) -> Result<Llama, LlamaSourceError> {
        self.build_with_loading_handler(ModelLoadingProgress::multi_bar_loading_indicator())
            .await
    }
}

#[derive(Debug)]
pub(crate) struct InferenceSettings {
    /// The prompt to run inference on.
    prompt: String,

    /// A string that, once generated, stops generation early.
    stop_on: Option<String>,

    /// The sampler used to pick the next token from the model's logits.
    sampler: std::sync::Arc<std::sync::Mutex<dyn llm_samplers::prelude::Sampler>>,

    /// The session (cached state) to run inference in.
    session: LlamaSession,

    /// The maximum number of tokens to generate.
    max_tokens: u32,

    /// The seed to use for sampling, if any.
    seed: Option<u64>,
}

impl InferenceSettings {
    pub fn new(
        prompt: impl Into<String>,
        session: LlamaSession,
        sampler: std::sync::Arc<std::sync::Mutex<dyn llm_samplers::prelude::Sampler>>,
        max_tokens: u32,
        stop_on: Option<String>,
        seed: Option<u64>,
    ) -> Self {
        Self {
            prompt: prompt.into(),
            stop_on,
            sampler,
            session,
            max_tokens,
            seed,
        }
    }
}