use std::fmt::Display;

use partial_sort::PartialSort;
use rand::{distributions::WeightedIndex, prelude::Distribution};
use thiserror::Error;

use crate::{
    mulf, InferenceError, InferenceParameters, Model, OutputRequest, TokenId, TokenUtf8Buffer,
};

/// The size, in bytes, of each of the two scratch buffers used during evaluation.
const SCRATCH_SIZE: usize = 512 * 1024 * 1024;
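
/// An inference session tracks the state of a single in-progress generation
/// run against a [`Model`]: the key/value memory tensors, the tokens processed
/// so far, and the logits produced by the last evaluation.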
pub struct InferenceSession {
    /// The GGML context that owns this session's memory tensors.
    pub(crate) _session_ctx: ggml::Context,

    /// The size, in bytes, of the memory context allocated for this session.
    pub(crate) memory_size: usize,

    /// The configuration this session was created with.
    pub(crate) config: InferenceSessionConfig,

    /// The key memory tensor.
    #[doc(hidden)]
    pub memory_k: ggml::Tensor,

    /// The value memory tensor.
    #[doc(hidden)]
    pub memory_v: ggml::Tensor,

    /// How many tokens have been processed so far.
    #[doc(hidden)]
    pub n_past: usize,

    /// How much memory is required per token for the temporary context used
    /// during evaluation.
    #[doc(hidden)]
    pub mem_per_token: usize,

    /// All tokens fed to or generated by this session so far.
    pub(crate) tokens: Vec<TokenId>,

    /// The logits produced by the last evaluation, one entry per vocabulary token.
    #[doc(hidden)]
    pub last_logits: Vec<f32>,

    /// Scratch buffers used by ggml during evaluation.
    #[doc(hidden)]
    pub scratch: [ggml::Buffer; 2],
}
unsafe impl Send for InferenceSession {}
impl InferenceSession {
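    /// Feed a prompt to the model for this session, without sampling any new
    /// tokens. `callback` is invoked with the raw bytes of each prompt token
    /// as it is evaluated.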
    pub fn feed_prompt<E: std::error::Error + 'static>(
        &mut self,
        model: &dyn Model,
        params: &InferenceParameters,
        prompt: &str,
        output_request: &mut OutputRequest,
        mut callback: impl FnMut(&[u8]) -> Result<(), E>,
    ) -> Result<(), InferenceError> {
        let beginning_of_sentence = self.n_past == 0;

        let vocab = model.vocabulary();
        let prompt_tokens: Vec<TokenId> = vocab
            .tokenize(prompt, beginning_of_sentence)?
            .iter()
            .map(|(_, tok)| *tok)
            .collect();

        if self.n_past + prompt_tokens.len() >= model.n_context_tokens() {
            return Err(InferenceError::ContextFull);
        }

        for batch in prompt_tokens.chunks(params.n_batch) {
            model.evaluate(self, params, batch, output_request);
            for &tk in batch {
                // Do not report the beginning-of-text token to the callback.
                let should_call_callback = Some(tk) != model.bot_token_id();

                if should_call_callback {
                    if let Err(e) = callback(vocab.token(tk as usize)) {
                        return Err(InferenceError::UserCallback(Box::new(e)));
                    }
                }

                self.tokens.push(tk);
            }
        }

        Ok(())
    }
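
    /// Infer the next token for this session, returning its raw bytes.
    ///
    /// Returns [`InferenceError::EndOfText`] once the model emits its
    /// end-of-text token, and [`InferenceError::ContextFull`] when the
    /// context window is exhausted.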
    pub fn infer_next_token<'v>(
        &mut self,
        model: &'v dyn Model,
        params: &InferenceParameters,
        output_request: &mut OutputRequest,
        rng: &mut impl rand::Rng,
    ) -> Result<&'v [u8], InferenceError> {
        if self.n_past + 1 >= model.n_context_tokens() {
            return Err(InferenceError::ContextFull);
        }

        let next_token = self.sample_top_p_top_k(params, rng);

        self.tokens.push(next_token);

        model.evaluate(self, params, &[next_token], output_request);

        if next_token == model.eot_token_id() {
            Err(InferenceError::EndOfText)
        } else {
            Ok(model.vocabulary().token(next_token as usize))
        }
    }
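
    /// Generate text by repeatedly sampling tokens until the maximum token
    /// count is reached or the model emits end-of-text, streaming decoded
    /// UTF-8 fragments to `callback`.
    ///
    /// A minimal usage sketch (not compiled as a doctest; it assumes a loaded
    /// `model: &dyn Model` and uses only items defined in this crate):
    ///
    /// ```ignore
    /// let mut session = model.start_session(InferenceSessionConfig::default());
    /// let stats = session.infer(
    ///     model,
    ///     &mut rand::thread_rng(),
    ///     &InferenceRequest {
    ///         prompt: "Rust is a",
    ///         parameters: None,
    ///         play_back_previous_tokens: false,
    ///         maximum_token_count: Some(64),
    ///     },
    ///     &mut OutputRequest::default(),
    ///     |text| -> Result<(), std::convert::Infallible> {
    ///         print!("{text}");
    ///         Ok(())
    ///     },
    /// )?;
    /// println!("\n{stats}");
    /// ```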
    pub fn infer<E: std::error::Error + 'static>(
        &mut self,
        model: &dyn Model,
        rng: &mut impl rand::Rng,
        request: &InferenceRequest,
        output_request: &mut OutputRequest,
        mut callback: impl FnMut(&str) -> Result<(), E>,
    ) -> Result<InferenceStats, InferenceError> {
        let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);
        if request.play_back_previous_tokens {
            // "Play back" the tokens already in this session through the
            // callback, buffering partial UTF-8 sequences until they decode.
            let mut token_utf8_buf = TokenUtf8Buffer::new();
            for token_id in &self.tokens {
                if let Some(tokens) =
                    token_utf8_buf.push(model.vocabulary().token(*token_id as usize))
                {
                    if let Err(e) = callback(&tokens) {
                        return Err(InferenceError::UserCallback(Box::new(e)));
                    }
                }
            }
        }

        let mut stats = InferenceStats::default();
        let start_at = std::time::SystemTime::now();

        let parameters = request.parameters.unwrap_or(model.inference_parameters());

        // Feed the prompt, streaming decoded fragments through the callback.
        self.feed_prompt(
            model,
            parameters,
            request.prompt,
            output_request,
            TokenUtf8Buffer::adapt_callback(&mut callback),
        )?;
        stats.feed_prompt_duration = start_at.elapsed().unwrap();
        stats.prompt_tokens = self.n_past;

        // Generate tokens until the limit is hit or the model signals end-of-text.
        let mut tokens_processed = 0;
        let mut token_utf8_buf = TokenUtf8Buffer::new();
        while tokens_processed < maximum_token_count {
            let token = match self.infer_next_token(model, parameters, &mut Default::default(), rng)
            {
                Ok(token) => token,
                Err(InferenceError::EndOfText) => break,
                Err(e) => return Err(e),
            };

            if let Some(tokens) = token_utf8_buf.push(token) {
                if let Err(e) = callback(&tokens) {
                    return Err(InferenceError::UserCallback(Box::new(e)));
                }
            }

            tokens_processed += 1;
        }
        stats.predict_duration = start_at.elapsed().unwrap();
        stats.predict_tokens = self.n_past;

        Ok(stats)
    }
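
    /// Sample a token from the most recent logits using temperature scaling,
    /// a repetition penalty over the last `repetition_penalty_last_n` tokens,
    /// and top-k followed by top-p (nucleus) filtering.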
    pub fn sample_top_p_top_k(
        &self,
        params: &InferenceParameters,
        rng: &mut impl rand::Rng,
    ) -> TokenId {
        let logits = &self.last_logits;
        let n_logits = logits.len();
        let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits);

        {
            // Apply temperature scaling, per-token bias overrides, and the
            // repetition penalty for recently seen tokens.
            let scale = 1.0 / params.temperature;
            for (i, &logit) in logits.iter().enumerate() {
                let tid = i as TokenId;

                let val = if let Some(logit_override) = params.bias_tokens.get(tid) {
                    logit_override
                } else if self.tokens[self
                    .tokens
                    .len()
                    .saturating_sub(params.repetition_penalty_last_n)..]
                    .contains(&tid)
                {
                    // Penalize recently seen tokens: dividing a positive logit
                    // or multiplying a negative one both push it towards being
                    // less likely.
                    if logits[i] < 0.0 {
                        logit * scale * params.repeat_penalty
                    } else {
                        logit * scale / params.repeat_penalty
                    }
                } else {
                    logit * scale
                };
                logits_id.push((val, tid));
            }
        }

        {
            // Keep only the top_k highest-scoring candidates.
            logits_id.partial_sort(params.top_k, |a, b| b.0.total_cmp(&a.0));
            logits_id.truncate(params.top_k);
        }

        let maxl = logits_id
            .iter()
            .map(|x| x.0)
            .max_by(f32::total_cmp)
            .unwrap();

        // Compute softmax probabilities for the remaining candidates.
        let mut probs: Vec<f32> = logits_id
            .iter()
            .copied()
            .map(|(k, _)| (k - maxl).exp())
            .collect();
        let sum: f32 = probs.iter().copied().sum();

        // Normalize the probabilities.
        for p in probs.iter_mut() {
            *p /= sum;
        }

        // Top-p (nucleus) sampling: keep the smallest prefix whose cumulative
        // probability reaches top_p, then renormalize.
        if params.top_p < 1.0 {
            let mut cumsum = 0.0;
            for i in 0..probs.len() {
                cumsum += probs[i];
                if cumsum >= params.top_p {
                    probs.truncate(i + 1);
                    logits_id.truncate(i + 1);
                    break;
                }
            }

            cumsum = 1.0 / cumsum;
            for p in probs.iter_mut() {
                *p *= cumsum;
            }
        }

        let dist = WeightedIndex::new(&probs).expect("WeightedIndex error");
        let idx = dist.sample(rng);

        logits_id[idx].1
    }
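
    /// Obtain a serializable snapshot that borrows this session's memory
    /// tensors directly.
    ///
    /// # Safety
    ///
    /// The returned snapshot aliases the session's raw tensor memory; the
    /// caller must ensure that memory remains valid and unmodified for as
    /// long as the snapshot's borrowed slices are used.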
    pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> {
        let memory_k = unsafe {
            std::slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes())
        };
        let memory_v = unsafe {
            std::slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes())
        };

        InferenceSnapshotRef {
            npast: self.n_past,
            config: self.config,
            tokens: self.tokens.clone(),
            logits: self.last_logits.clone(),
            memory_k,
            memory_v,
        }
    }
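
    /// Rehydrate an [`InferenceSession`] from an owned [`InferenceSnapshot`],
    /// using `model` to create the new session.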
    pub fn from_snapshot(
        snapshot: InferenceSnapshot,
        model: &dyn Model,
    ) -> Result<Self, SnapshotError> {
        let mut session = model.start_session(snapshot.config);

        if session.memory_k.nbytes() != snapshot.memory_k.len()
            || session.memory_v.nbytes() != snapshot.memory_v.len()
        {
            return Err(SnapshotError::MemorySizeMismatch {
                self_size: session.memory_k.nbytes() + session.memory_v.nbytes(),
                input_size: snapshot.memory_k.len() + snapshot.memory_v.len(),
            });
        }

        // SAFETY: the buffer sizes were checked against the tensors above.
        unsafe {
            session.memory_k.write_data(&snapshot.memory_k);
            session.memory_v.write_data(&snapshot.memory_v);
        }

        session.n_past = snapshot.npast;
        session.tokens = snapshot.tokens;
        session.last_logits = snapshot.last_logits;

        Ok(session)
    }
}
impl InferenceSession {
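    /// Create a new inference session, allocating key/value memory for
    /// `n_ctx` context tokens across `n_layer` layers of width `n_embd`,
    /// with room for `n_vocab` logits.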
    pub fn new(
        config: InferenceSessionConfig,
        n_ctx: usize,
        n_layer: usize,
        n_embd: usize,
        n_vocab: usize,
    ) -> InferenceSession {
        let ctx_size = {
            let mut ctx_size = 0;
            ctx_size += mulf!(
                n_ctx,
                n_layer,
                n_embd,
                ggml::type_sizef(config.memory_k_type.into())
            ); // memory_k
            ctx_size += mulf!(
                n_ctx,
                n_layer,
                n_embd,
                ggml::type_sizef(config.memory_v_type.into())
            ); // memory_v
            ctx_size += (5 + 10 * n_layer) * 256; // object overhead
            ctx_size
        };

        let session_ctx = ggml::Context::init(ctx_size, true);

        // Initialize the key + value memory tensors.
        let n_mem = n_layer * n_ctx;
        let n_elements = n_embd * n_mem;
        let memory_k = session_ctx.new_tensor_1d(config.memory_k_type.into(), n_elements);
        let memory_v = session_ctx.new_tensor_1d(config.memory_v_type.into(), n_elements);

        InferenceSession {
            _session_ctx: session_ctx,
            memory_size: ctx_size,
            config,
            memory_k,
            memory_v,
            n_past: 0,
            mem_per_token: 0,
            tokens: vec![],
            last_logits: vec![0.0; n_vocab],
            scratch: scratch_buffers(),
        }
    }
}
impl Clone for InferenceSession {
    fn clone(&self) -> Self {
        let context = ggml::Context::init(self.memory_size, true);
        let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements());
        let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements());

        Self {
            _session_ctx: context,
            memory_size: self.memory_size,
            config: self.config,
            memory_k,
            memory_v,
            n_past: self.n_past,
            mem_per_token: self.mem_per_token,
            tokens: self.tokens.clone(),
            last_logits: self.last_logits.clone(),
            scratch: scratch_buffers(),
        }
    }
}
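
/// Errors that can occur while reading or writing an inference snapshot.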
#[derive(Error, Debug)]
pub enum SnapshotError {
    /// An I/O error occurred while reading or writing the snapshot.
    #[error("I/O error while reading or writing snapshot")]
    IO(#[from] std::io::Error),
    /// The snapshot's memory does not match the size of the session's memory.
    #[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")]
    MemorySizeMismatch {
        /// The size of the session's memory, in bytes.
        self_size: usize,
        /// The size of the snapshot's memory, in bytes.
        input_size: usize,
    },
}
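
/// A serializable snapshot of an [`InferenceSession`] that borrows the
/// session's memory. Produced by [`InferenceSession::get_snapshot`].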
#[derive(serde::Serialize, Clone, PartialEq)]
pub struct InferenceSnapshotRef<'a> {
    /// How many tokens have been processed.
    pub npast: usize,
    /// The configuration of the session.
    pub config: InferenceSessionConfig,
    /// All tokens fed to or generated by the session so far.
    pub tokens: Vec<TokenId>,
    /// The logits from the last evaluation.
    pub logits: Vec<f32>,
    /// The raw contents of the key memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_k: &'a [u8],
    /// The raw contents of the value memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_v: &'a [u8],
}
impl InferenceSnapshotRef<'_> {
    /// Create an owned [`InferenceSnapshot`] by copying the borrowed memory buffers.
    pub fn to_owned(&self) -> InferenceSnapshot {
        InferenceSnapshot {
            npast: self.npast,
            config: self.config,
            tokens: self.tokens.clone(),
            last_logits: self.logits.clone(),
            memory_k: self.memory_k.to_vec(),
            memory_v: self.memory_v.to_vec(),
        }
    }
}
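
/// An owned, deserializable snapshot of an [`InferenceSession`], typically
/// obtained via [`InferenceSnapshotRef::to_owned`].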
#[derive(serde::Deserialize, Clone, PartialEq)]
pub struct InferenceSnapshot {
    /// How many tokens have been processed.
    pub npast: usize,
    /// The configuration of the session.
    pub config: InferenceSessionConfig,
    /// All tokens fed to or generated by the session so far.
    pub tokens: Vec<TokenId>,
    /// The logits from the last evaluation.
    pub last_logits: Vec<f32>,
    /// The raw contents of the key memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_k: Vec<u8>,
    /// The raw contents of the value memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_v: Vec<u8>,
}
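
/// Configuration for an inference session: which data types to use for the
/// key and value memory tensors.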
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct InferenceSessionConfig {
    /// The type to use for the session's key memory.
    pub memory_k_type: ModelKVMemoryType,
    /// The type to use for the session's value memory.
    pub memory_v_type: ModelKVMemoryType,
}
impl Default for InferenceSessionConfig {
    fn default() -> Self {
        Self {
            memory_k_type: ModelKVMemoryType::Float32,
            memory_v_type: ModelKVMemoryType::Float32,
        }
    }
}
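
/// Settings for a single call to [`InferenceSession::infer`].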
#[derive(Debug, PartialEq, Default, Clone, Copy)]
pub struct InferenceRequest<'a> {
    /// The prompt to feed to the model.
    pub prompt: &'a str,
    /// The parameters to use for inference; the model's own parameters are
    /// used when `None`.
    pub parameters: Option<&'a InferenceParameters>,
    /// Whether to replay the session's existing tokens through the callback
    /// before processing the prompt.
    pub play_back_previous_tokens: bool,
    /// The maximum number of tokens to generate; unlimited when `None`.
    pub maximum_token_count: Option<usize>,
}
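
/// Timing and token-count statistics for an inference run.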
#[derive(Debug, Clone, Copy)]
pub struct InferenceStats {
    /// How long it took to feed the prompt.
    pub feed_prompt_duration: std::time::Duration,
    /// How many tokens had been processed once the prompt was fed.
    pub prompt_tokens: usize,
    /// How long prediction took, measured from the start of the call.
    pub predict_duration: std::time::Duration,
    /// How many tokens had been processed by the end of prediction.
    pub predict_tokens: usize,
}
impl Default for InferenceStats {
    fn default() -> Self {
        Self {
            feed_prompt_duration: std::time::Duration::from_secs(0),
            prompt_tokens: 0,
            predict_duration: std::time::Duration::from_secs(0),
            predict_tokens: 0,
        }
    }
}
impl Display for InferenceStats {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms",
            self.feed_prompt_duration.as_millis(),
            self.prompt_tokens,
            self.predict_duration.as_millis(),
            self.predict_tokens,
            (self.predict_duration.as_millis() as f64) / (self.predict_tokens as f64),
        )
    }
}
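
/// The data type to use for an inference session's key/value memory tensors.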
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ModelKVMemoryType {
    /// 16-bit float.
    Float16,
    /// 32-bit float.
    Float32,
}
impl From<ModelKVMemoryType> for ggml::Type {
    fn from(value: ModelKVMemoryType) -> Self {
        match value {
            ModelKVMemoryType::Float16 => ggml::Type::F16,
            ModelKVMemoryType::Float32 => ggml::Type::F32,
        }
    }
}

/// Allocate the pair of scratch buffers used by ggml during evaluation.
fn scratch_buffers() -> [ggml::Buffer; 2] {
    [
        ggml::Buffer::new(SCRATCH_SIZE),
        ggml::Buffer::new(SCRATCH_SIZE),
    ]
}