use std::fmt::Display;
use partial_sort::PartialSort;
use rand::{distributions::WeightedIndex, prelude::Distribution};
use thiserror::Error;
use crate::{
mulf, InferenceError, InferenceParameters, Model, OutputRequest, TokenId, TokenUtf8Buffer,
};
/// Size in bytes of each scratch buffer allocated by `scratch_buffers` (512 MiB).
const SCRATCH_SIZE: usize = 512 << 20;
/// State of one inference run: the key/value memory tensors, the tokens
/// processed so far, and the most recent logits used for sampling.
pub struct InferenceSession {
/// ggml context that the `memory_k`/`memory_v` tensors are allocated from.
/// Not read in this file; NOTE(review): presumably held so the context stays
/// alive as long as the tensors created from it — confirm.
pub(crate) _session_ctx: ggml::Context,
/// Size in bytes passed to `ggml::Context::init` for `_session_ctx`; reused
/// when cloning the session.
pub(crate) memory_size: usize,
/// Element types chosen for the key and value memory.
pub(crate) config: InferenceSessionConfig,
/// Key memory: a 1-D tensor of `n_embd * n_layer * n_ctx` elements.
#[doc(hidden)]
pub memory_k: ggml::Tensor,
/// Value memory: a 1-D tensor of `n_embd * n_layer * n_ctx` elements.
#[doc(hidden)]
pub memory_v: ggml::Tensor,
/// Number of tokens already processed; checked against the model's context
/// size before feeding or sampling. NOTE(review): advanced by
/// `Model::evaluate`, which is not in this file.
#[doc(hidden)]
pub n_past: usize,
/// Initialized to 0 here. NOTE(review): presumably filled in by the model's
/// evaluation code — confirm.
#[doc(hidden)]
pub mem_per_token: usize,
/// Every token fed or sampled in this session, in order.
pub(crate) tokens: Vec<TokenId>,
/// Logits that the sampler reads; `n_vocab` zeros at creation.
/// NOTE(review): overwritten by `Model::evaluate` — confirm.
#[doc(hidden)]
pub last_logits: Vec<f32>,
/// Two scratch buffers of `SCRATCH_SIZE` bytes each.
#[doc(hidden)]
pub scratch: [ggml::Buffer; 2],
}
// SAFETY: NOTE(review): this asserts that the ggml context, its tensors and
// the scratch buffers may all be moved to another thread together. It is
// presumably needed because some field type is not automatically `Send`
// (e.g. it wraps raw pointers); confirm that ggml does not tie these
// allocations to the thread that created them.
unsafe impl Send for InferenceSession {}
impl InferenceSession {
    /// Tokenizes `prompt` and evaluates it with `model` in chunks of
    /// `params.n_batch` tokens, advancing the session state.
    ///
    /// The tokenizer is asked to prepend a beginning-of-sentence token only
    /// when this session has not processed any tokens yet (`n_past == 0`).
    /// After each batch is evaluated, each of its tokens except the model's BOT
    /// token is passed to `callback` as raw bytes, and every token is appended
    /// to the session's token history. A `n_batch` of 0 is treated as 1.
    ///
    /// # Errors
    ///
    /// * Tokenization errors are propagated.
    /// * [`InferenceError::ContextFull`] if `n_past` plus the prompt length
    ///   would reach the model's context size; nothing is evaluated then.
    /// * [`InferenceError::UserCallback`] if `callback` fails. The current batch
    ///   has already been evaluated at that point, but the failing token and
    ///   the rest of its batch are not added to the history.
    pub fn feed_prompt<E: std::error::Error + 'static>(
        &mut self,
        model: &dyn Model,
        params: &InferenceParameters,
        prompt: &str,
        output_request: &mut OutputRequest,
        mut callback: impl FnMut(&[u8]) -> Result<(), E>,
    ) -> Result<(), InferenceError> {
        let beginning_of_sentence = self.n_past == 0;

        let vocab = model.vocabulary();
        let prompt_tokens: Vec<TokenId> = vocab
            .tokenize(prompt, beginning_of_sentence)?
            .iter()
            .map(|(_, tok)| *tok)
            .collect();

        if self.n_past + prompt_tokens.len() >= model.n_context_tokens() {
            return Err(InferenceError::ContextFull);
        }

        // `chunks(0)` panics; a zero batch size is treated as one token per batch.
        let batch_size = params.n_batch.max(1);
        for batch in prompt_tokens.chunks(batch_size) {
            model.evaluate(self, params, batch, output_request);
            for &tk in batch {
                // Do not echo the model's BOT token to the caller.
                let should_call_callback = Some(tk) != model.bot_token_id();
                if should_call_callback {
                    if let Err(e) = callback(vocab.token(tk as usize)) {
                        return Err(InferenceError::UserCallback(Box::new(e)));
                    }
                }
                self.tokens.push(tk);
            }
        }

        Ok(())
    }

    /// Samples the next token from `last_logits`, appends it to the history and
    /// evaluates it, so the session is ready to sample the following token.
    ///
    /// Returns the sampled token's raw bytes from the vocabulary.
    ///
    /// # Errors
    ///
    /// * [`InferenceError::ContextFull`] if there is no room for another token.
    /// * [`InferenceError::EndOfText`] if the sampled token is the model's
    ///   end-of-text token. That token is still recorded and evaluated.
    pub fn infer_next_token<'v>(
        &mut self,
        model: &'v dyn Model,
        params: &InferenceParameters,
        output_request: &mut OutputRequest,
        rng: &mut impl rand::Rng,
    ) -> Result<&'v [u8], InferenceError> {
        if self.n_past + 1 >= model.n_context_tokens() {
            return Err(InferenceError::ContextFull);
        }

        let next_token = self.sample_top_p_top_k(params, rng);
        self.tokens.push(next_token);
        model.evaluate(self, params, &[next_token], output_request);

        if next_token == model.eot_token_id() {
            Err(InferenceError::EndOfText)
        } else {
            Ok(model.vocabulary().token(next_token as usize))
        }
    }

    /// Runs a complete generation. It optionally replays the existing token
    /// history, then feeds `request.prompt`, then samples tokens until
    /// end-of-text or until `request.maximum_token_count` tokens have been
    /// generated.
    ///
    /// Token bytes are accumulated in a [`TokenUtf8Buffer`], and `callback` is
    /// invoked with the text whenever the buffer yields some. `output_request`
    /// is only used while feeding the prompt; sampled tokens are evaluated with
    /// a default `OutputRequest`. When `request.parameters` is `None`, the
    /// model's own inference parameters are used.
    ///
    /// # Errors
    ///
    /// Propagates errors from feeding the prompt and from sampling (for example
    /// [`InferenceError::ContextFull`]). Callback failures are wrapped in
    /// [`InferenceError::UserCallback`].
    pub fn infer<E: std::error::Error + 'static>(
        &mut self,
        model: &dyn Model,
        rng: &mut impl rand::Rng,
        request: &InferenceRequest,
        output_request: &mut OutputRequest,
        mut callback: impl FnMut(&str) -> Result<(), E>,
    ) -> Result<InferenceStats, InferenceError> {
        let maximum_token_count = request.maximum_token_count.unwrap_or(usize::MAX);

        if request.play_back_previous_tokens {
            let mut token_utf8_buf = TokenUtf8Buffer::new();
            for token_id in &self.tokens {
                if let Some(tokens) =
                    token_utf8_buf.push(model.vocabulary().token(*token_id as usize))
                {
                    if let Err(e) = callback(&tokens) {
                        return Err(InferenceError::UserCallback(Box::new(e)));
                    }
                }
            }
        }

        let mut stats = InferenceStats::default();
        // `SystemTime::elapsed` fails if the wall clock moves backwards during a
        // run; `Instant` is monotonic, so measuring with it cannot fail.
        let start_at = std::time::Instant::now();

        let parameters = request.parameters.unwrap_or(model.inference_parameters());

        self.feed_prompt(
            model,
            parameters,
            request.prompt,
            output_request,
            TokenUtf8Buffer::adapt_callback(&mut callback),
        )?;

        stats.feed_prompt_duration = start_at.elapsed();
        // NOTE: `n_past` counts every token in the context, including any
        // processed by earlier calls on this session.
        stats.prompt_tokens = self.n_past;

        let mut tokens_processed = 0;
        let mut token_utf8_buf = TokenUtf8Buffer::new();
        while tokens_processed < maximum_token_count {
            let token = match self.infer_next_token(model, parameters, &mut Default::default(), rng)
            {
                Ok(token) => token,
                Err(InferenceError::EndOfText) => break,
                Err(e) => return Err(e),
            };

            if let Some(tokens) = token_utf8_buf.push(token) {
                if let Err(e) = callback(&tokens) {
                    return Err(InferenceError::UserCallback(Box::new(e)));
                }
            }

            tokens_processed += 1;
        }

        // NOTE: both values are cumulative. The duration is measured from the
        // start of this call, so prompt feeding is included. The token count is
        // the session's total `n_past`.
        stats.predict_duration = start_at.elapsed();
        stats.predict_tokens = self.n_past;

        Ok(stats)
    }

    /// Samples a token id from `last_logits`.
    ///
    /// Each logit is first adjusted:
    ///
    /// * Tokens listed in `params.bias_tokens` use their override value as-is.
    /// * Every other logit is multiplied by `1 / params.temperature`.
    /// * Tokens among the last `params.repetition_penalty_last_n` tokens of the
    ///   history are additionally penalized by `params.repeat_penalty`.
    ///
    /// Then:
    ///
    /// 1. The `params.top_k` highest candidates are kept. A `top_k` of 0 keeps
    ///    all of them, and the value is capped at the vocabulary size.
    /// 2. The candidates are softmax-normalized.
    /// 3. If `params.top_p < 1.0`, the list is cut to the smallest prefix whose
    ///    cumulative probability reaches `top_p`.
    /// 4. One candidate is drawn with `rng`.
    ///
    /// # Panics
    ///
    /// Panics if `last_logits` is empty, or if `WeightedIndex::new` rejects the
    /// computed weights (e.g. non-finite values from a zero temperature).
    pub fn sample_top_p_top_k(
        &self,
        params: &InferenceParameters,
        rng: &mut impl rand::Rng,
    ) -> TokenId {
        let logits = &self.last_logits;
        let n_logits = logits.len();
        let mut logits_id = Vec::<(f32, TokenId)>::with_capacity(n_logits);

        // The repetition-penalty window does not change while scanning the
        // vocabulary, so slice it once instead of once per logit.
        let window_start = self
            .tokens
            .len()
            .saturating_sub(params.repetition_penalty_last_n);
        let recent_tokens = &self.tokens[window_start..];

        {
            let scale = 1.0 / params.temperature;
            for (i, &logit) in logits.iter().enumerate() {
                let tid = i as TokenId;

                let val = if let Some(logit_override) = params.bias_tokens.get(tid) {
                    logit_override
                } else if recent_tokens.contains(&tid) {
                    // Penalize in the direction that lowers the probability:
                    // negative logits grow more negative, positive ones shrink.
                    if logit < 0.0 {
                        logit * scale * params.repeat_penalty
                    } else {
                        logit * scale / params.repeat_penalty
                    }
                } else {
                    logit * scale
                };
                logits_id.push((val, tid));
            }
        }

        // A `top_k` of 0 disables top-k filtering. Otherwise the bound is capped
        // at the number of candidates, so the partial sort and truncate stay
        // within the vector and never leave it empty when logits exist.
        let top_k = if params.top_k == 0 {
            n_logits
        } else {
            params.top_k.min(n_logits)
        };
        {
            // Descending by logit value.
            logits_id.partial_sort(top_k, |a, b| b.0.total_cmp(&a.0));
            logits_id.truncate(top_k);
        }

        let maxl = logits_id
            .iter()
            .map(|x| x.0)
            .max_by(f32::total_cmp)
            .expect("sampling requires at least one logit");

        // Softmax with the maximum subtracted for numerical stability.
        let mut probs: Vec<f32> = logits_id
            .iter()
            .copied()
            .map(|(k, _)| (k - maxl).exp())
            .collect();
        let sum: f32 = probs.iter().copied().sum();
        for p in probs.iter_mut() {
            *p /= sum;
        }

        if params.top_p < 1.0 {
            let mut cumsum = 0.0;
            for i in 0..probs.len() {
                cumsum += probs[i];
                if cumsum >= params.top_p {
                    probs.truncate(i + 1);
                    logits_id.truncate(i + 1);
                    break;
                }
            }

            // Renormalize the surviving probabilities.
            cumsum = 1.0 / cumsum;
            for p in probs.iter_mut() {
                *p *= cumsum;
            }
        }

        let dist = WeightedIndex::new(&probs).expect("WeightedIndex error");
        let idx = dist.sample(rng);

        logits_id[idx].1
    }

    /// Returns a view of the session state. The key/value memory bytes are
    /// borrowed directly rather than copied; the token history and logits are
    /// cloned.
    ///
    /// # Safety
    ///
    /// The returned slices are built from raw pointers into the ggml tensors
    /// `memory_k` and `memory_v`. The caller must ensure those tensors are not
    /// written by anything else while the snapshot is alive. The `&mut self`
    /// borrow only rules out access through this session.
    pub unsafe fn get_snapshot(&mut self) -> InferenceSnapshotRef<'_> {
        // SAFETY: assumes each tensor's data pointer is valid for `nbytes()`
        // bytes for as long as `self` is borrowed; the caller upholds the
        // no-concurrent-writes requirement documented above.
        let memory_k = unsafe {
            std::slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes())
        };
        let memory_v = unsafe {
            std::slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes())
        };

        InferenceSnapshotRef {
            npast: self.n_past,
            config: self.config,
            tokens: self.tokens.clone(),
            logits: self.last_logits.clone(),
            memory_k,
            memory_v,
        }
    }

    /// Rebuilds a session from an owned snapshot.
    ///
    /// A new session is started on `model` with the snapshot's configuration.
    /// The snapshot's key/value memory, `npast`, tokens and logits are then
    /// copied into it.
    ///
    /// # Errors
    ///
    /// [`SnapshotError::MemorySizeMismatch`] if either memory buffer in the
    /// snapshot differs in size from the new session's corresponding tensor.
    pub fn from_snapshot(
        snapshot: InferenceSnapshot,
        model: &dyn Model,
    ) -> Result<Self, SnapshotError> {
        let mut session = model.start_session(snapshot.config);

        if session.memory_k.nbytes() != snapshot.memory_k.len()
            || session.memory_v.nbytes() != snapshot.memory_v.len()
        {
            return Err(SnapshotError::MemorySizeMismatch {
                self_size: session.memory_k.nbytes() + session.memory_v.nbytes(),
                input_size: snapshot.memory_k.len() + snapshot.memory_v.len(),
            });
        }

        // SAFETY: the byte lengths were checked above, so each copy exactly
        // fills its destination tensor of the freshly created session.
        unsafe {
            session.memory_k.write_data(&snapshot.memory_k);
            session.memory_v.write_data(&snapshot.memory_v);
        }

        session.n_past = snapshot.npast;
        session.tokens = snapshot.tokens;
        session.last_logits = snapshot.last_logits;

        Ok(session)
    }
}
impl InferenceSession {
    /// Allocates a new, empty session.
    ///
    /// A dedicated ggml context is sized to hold the key and value memory.
    /// Each is a 1-D tensor of `n_embd * n_layer * n_ctx` elements, with the
    /// element types taken from `config`. The context also gets a fixed
    /// per-layer allowance. The session starts with no tokens, `n_past == 0`,
    /// an `n_vocab`-element logits buffer of zeros, and freshly allocated
    /// scratch buffers.
    pub fn new(
        config: InferenceSessionConfig,
        n_ctx: usize,
        n_layer: usize,
        n_embd: usize,
        n_vocab: usize,
    ) -> InferenceSession {
        let key_bytes = mulf!(
            n_ctx,
            n_layer,
            n_embd,
            ggml::type_sizef(config.memory_k_type.into())
        );
        let value_bytes = mulf!(
            n_ctx,
            n_layer,
            n_embd,
            ggml::type_sizef(config.memory_v_type.into())
        );
        // NOTE(review): presumably headroom for ggml's per-object bookkeeping
        // inside the context — confirm.
        let bookkeeping_bytes = (5 + 10 * n_layer) * 256;
        let ctx_size = key_bytes + value_bytes + bookkeeping_bytes;

        let session_ctx = ggml::Context::init(ctx_size, true);

        let n_elements = n_embd * (n_layer * n_ctx);
        let memory_k = session_ctx.new_tensor_1d(config.memory_k_type.into(), n_elements);
        let memory_v = session_ctx.new_tensor_1d(config.memory_v_type.into(), n_elements);

        Self {
            _session_ctx: session_ctx,
            memory_size: ctx_size,
            config,
            memory_k,
            memory_v,
            n_past: 0,
            mem_per_token: 0,
            tokens: Vec::new(),
            last_logits: vec![0.0; n_vocab],
            scratch: scratch_buffers(),
        }
    }
}
impl Clone for InferenceSession {
    /// Deep-copies the session.
    ///
    /// A new ggml context of the same size is allocated, and key/value memory
    /// tensors with the same types and element counts are created in it. The
    /// contents of `self`'s key/value memory are copied into the new tensors,
    /// so the clone can continue from exactly the same state. Counters, token
    /// history and logits are copied. Scratch buffers are freshly allocated,
    /// not copied.
    fn clone(&self) -> Self {
        let context = ggml::Context::init(self.memory_size, true);
        let mut memory_k =
            context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements());
        let mut memory_v =
            context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements());

        // Same type and element count, so the byte sizes must agree.
        debug_assert_eq!(memory_k.nbytes(), self.memory_k.nbytes());
        debug_assert_eq!(memory_v.nbytes(), self.memory_v.nbytes());

        // Without this copy, the clone would report `self`'s `n_past` and
        // `tokens` while its key/value memory held none of the corresponding
        // cached state.
        // SAFETY: assumes each source tensor's data pointer is valid for
        // `nbytes()` bytes while `self` is borrowed. The destination tensors
        // were just created in a context nothing else references, and their
        // sizes match the sources (asserted above).
        unsafe {
            let k_bytes =
                std::slice::from_raw_parts(self.memory_k.data() as *mut u8, self.memory_k.nbytes());
            memory_k.write_data(k_bytes);

            let v_bytes =
                std::slice::from_raw_parts(self.memory_v.data() as *mut u8, self.memory_v.nbytes());
            memory_v.write_data(v_bytes);
        }

        Self {
            _session_ctx: context,
            memory_size: self.memory_size,
            config: self.config,
            memory_k,
            memory_v,
            n_past: self.n_past,
            mem_per_token: self.mem_per_token,
            tokens: self.tokens.clone(),
            last_logits: self.last_logits.clone(),
            scratch: scratch_buffers(),
        }
    }
}
/// Errors produced while saving or restoring an `InferenceSession` snapshot.
#[derive(Error, Debug)]
pub enum SnapshotError {
/// Wraps an I/O error; `#[from]` provides `From<std::io::Error>`.
#[error("I/O error while reading or writing snapshot")]
IO(#[from] std::io::Error),
/// The snapshot's key/value memory does not match the size of a session
/// created with the snapshot's configuration.
#[error("could not read snapshot due to size mismatch (self={self_size}, input={input_size})")]
MemorySizeMismatch {
/// Combined key + value memory size, in bytes, of the newly created session.
self_size: usize,
/// Combined key + value memory size, in bytes, found in the snapshot.
input_size: usize,
},
}
/// A serializable view of an `InferenceSession`'s state that borrows the raw
/// key/value memory bytes instead of copying them. Produced by
/// `InferenceSession::get_snapshot`.
#[derive(serde::Serialize, Clone, PartialEq)]
pub struct InferenceSnapshotRef<'a> {
/// The session's `n_past`.
pub npast: usize,
/// The session's key/value memory configuration.
pub config: InferenceSessionConfig,
/// The session's token history.
pub tokens: Vec<TokenId>,
/// The session's `last_logits`. NOTE: serialized under the name `logits`,
/// whereas `InferenceSnapshot` names the corresponding field `last_logits`.
pub logits: Vec<f32>,
/// Raw bytes of the key memory tensor.
#[serde(with = "serde_bytes")]
pub memory_k: &'a [u8],
/// Raw bytes of the value memory tensor.
#[serde(with = "serde_bytes")]
pub memory_v: &'a [u8],
}
impl InferenceSnapshotRef<'_> {
    /// Produces an owned [`InferenceSnapshot`]. Every field is cloned, and the
    /// borrowed key/value memory is copied into new `Vec<u8>`s.
    pub fn to_owned(&self) -> InferenceSnapshot {
        let Self {
            npast,
            config,
            tokens,
            logits,
            memory_k,
            memory_v,
        } = self;

        InferenceSnapshot {
            npast: *npast,
            config: *config,
            tokens: tokens.clone(),
            last_logits: logits.clone(),
            memory_k: memory_k.to_vec(),
            memory_v: memory_v.to_vec(),
        }
    }
}
/// An owned snapshot of an `InferenceSession`, as read back from storage and
/// consumed by `InferenceSession::from_snapshot`.
///
/// The field order matches `InferenceSnapshotRef`, which is the type used for
/// writing snapshots.
#[derive(serde::Deserialize, Clone, PartialEq)]
pub struct InferenceSnapshot {
    /// Number of tokens processed when the snapshot was taken.
    pub npast: usize,
    /// Key/value memory configuration of the snapshotted session.
    pub config: InferenceSessionConfig,
    /// Token history of the snapshotted session.
    pub tokens: Vec<TokenId>,
    /// Logits of the snapshotted session.
    ///
    /// `InferenceSnapshotRef` serializes this field under the name `logits`.
    /// Accepting that name as an alias lets formats that match fields by name
    /// read back what the reference type wrote. The canonical `last_logits`
    /// name is still accepted.
    #[serde(alias = "logits")]
    pub last_logits: Vec<f32>,
    /// Raw bytes of the key memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_k: Vec<u8>,
    /// Raw bytes of the value memory tensor.
    #[serde(with = "serde_bytes")]
    pub memory_v: Vec<u8>,
}
/// Element types used for the key and value memory of an `InferenceSession`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct InferenceSessionConfig {
    /// Element type of the key memory tensor.
    pub memory_k_type: ModelKVMemoryType,
    /// Element type of the value memory tensor.
    pub memory_v_type: ModelKVMemoryType,
}

impl Default for InferenceSessionConfig {
    /// Both the key and the value memory default to `ModelKVMemoryType::Float32`.
    fn default() -> Self {
        let full_precision = ModelKVMemoryType::Float32;
        Self {
            memory_k_type: full_precision,
            memory_v_type: full_precision,
        }
    }
}
/// Settings for a single `InferenceSession::infer` call.
#[derive(Debug, PartialEq, Default, Clone, Copy)]
pub struct InferenceRequest<'a> {
/// Text fed to the model before sampling begins.
pub prompt: &'a str,
/// Sampling parameters; when `None`, `infer` uses `model.inference_parameters()`.
pub parameters: Option<&'a InferenceParameters>,
/// When true, `infer` first passes the session's existing token history to the callback.
pub play_back_previous_tokens: bool,
/// Maximum number of tokens `infer` samples after the prompt; `None` means no limit.
/// Generation otherwise ends at end-of-text, or with an error such as a full context.
pub maximum_token_count: Option<usize>,
}
/// Timing and token-count statistics returned by `InferenceSession::infer`.
///
/// All fields default to zero.
#[derive(Debug, Clone, Copy, Default)]
pub struct InferenceStats {
    /// Time spent feeding the prompt.
    pub feed_prompt_duration: std::time::Duration,
    /// Token count recorded after the prompt was fed.
    /// NOTE(review): `infer` stores the session's total `n_past` here, which
    /// includes tokens from earlier calls on the same session.
    pub prompt_tokens: usize,
    /// Time recorded at the end of generation.
    /// NOTE(review): `infer` measures this from the start of the call, so it
    /// includes `feed_prompt_duration`.
    pub predict_duration: std::time::Duration,
    /// Token count recorded at the end of generation (`infer` stores the
    /// session's total `n_past`).
    pub predict_tokens: usize,
}

impl Display for InferenceStats {
    /// Writes one `name: value` pair per line.
    ///
    /// Durations are shown in whole milliseconds. `per_token_duration` is
    /// `predict_duration / predict_tokens`, shown with sub-millisecond
    /// precision, or `0.000ms` when no tokens were predicted.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Use the full-precision duration: `as_millis()` truncates to whole
        // milliseconds, which would make the `{:.3}` format meaningless.
        // Guard against zero tokens, which would otherwise print NaN/inf.
        let per_token_ms = if self.predict_tokens == 0 {
            0.0
        } else {
            self.predict_duration.as_secs_f64() * 1000.0 / self.predict_tokens as f64
        };

        write!(
            f,
            "feed_prompt_duration: {}ms\nprompt_tokens: {}\npredict_duration: {}ms\npredict_tokens: {}\nper_token_duration: {:.3}ms",
            self.feed_prompt_duration.as_millis(),
            self.prompt_tokens,
            self.predict_duration.as_millis(),
            self.predict_tokens,
            per_token_ms,
        )
    }
}
/// Element type used to store the key/value memory of a session.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ModelKVMemoryType {
    /// Maps to `ggml::Type::F16`.
    Float16,
    /// Maps to `ggml::Type::F32`.
    Float32,
}

impl From<ModelKVMemoryType> for ggml::Type {
    /// Converts the memory type into the corresponding ggml tensor type.
    fn from(value: ModelKVMemoryType) -> Self {
        use ModelKVMemoryType::{Float16, Float32};

        match value {
            Float32 => Self::F32,
            Float16 => Self::F16,
        }
    }
}
/// Allocates the two `SCRATCH_SIZE`-byte scratch buffers that every session
/// carries. They are created in order, first then second.
fn scratch_buffers() -> [ggml::Buffer; 2] {
    [(); 2].map(|()| ggml::Buffer::new(SCRATCH_SIZE))
}