use serde_json::Value;
use crate::error::{LlamaError, Result};
use super::Llama;
/// Outcome of a single text-completion request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Completion {
    /// The generated text.
    pub text: String,
    /// How many tokens were generated to produce `text`.
    pub n_tokens: usize,
    /// Why generation ended.
    pub stop_reason: StopReason,
}
/// The condition that ended a completion.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum StopReason {
    /// The token budget was used up before any other condition fired.
    Length,
    /// The model produced an end-of-sequence / end-of-turn token.
    Eos,
    /// Generation was halted by a stop condition.
    // NOTE(review): presumably reserved for stop-sequence matching — confirm
    // against the code that constructs this variant.
    Stop,
}
/// Generates up to `max_tokens` tokens continuing `prompt`, using greedy sampling.
///
/// The prompt is tokenized and decoded as one batch on sequence `0`; tokens are
/// then sampled one at a time and fed back into the context until either the
/// budget is exhausted ([`StopReason::Length`]) or the model's EOS/EOT token is
/// sampled ([`StopReason::Eos`]). The terminating EOS/EOT token is not included
/// in the returned text or in `n_tokens`. This function never returns
/// [`StopReason::Stop`].
///
/// With `max_tokens == 0` the prompt is still decoded and an empty completion
/// with [`StopReason::Length`] is returned.
///
/// # Errors
///
/// Returns an error if tokenization, batch construction, decoding, sampler
/// creation, or detokenization fails, if the prompt tokenizes to zero tokens,
/// or if the prompt's token count does not fit in an `i32`.
pub fn create_completion(
    llama: &mut Llama,
    prompt: &str,
    max_tokens: usize,
) -> Result<Completion> {
    // Best-effort reset of sequence 0 before starting; the outcome is
    // deliberately ignored (presumably this clears any previous state for the
    // sequence — confirm against `seq_rm`'s contract).
    let _ = llama.context().seq_rm(0, -1, -1);

    let tokens = llama.model().tokenize(prompt, true, true)?;
    if tokens.is_empty() {
        // Nothing to decode means there are no logits to sample from.
        return Err(LlamaError::Batch("prompt produced no tokens".into()));
    }
    // Positions and logits indices are `i32`; convert once, with a check,
    // instead of silently truncating with `as`.
    let n_prompt = <i32 as std::convert::TryFrom<usize>>::try_from(tokens.len())
        .map_err(|_| LlamaError::Batch("prompt token count exceeds i32::MAX".into()))?;

    // Feed the whole prompt in one batch, flagging only the final token
    // (the last argument of `add` is true just for it).
    let mut batch = crate::batch::LlamaBatch::new(tokens.len(), 1);
    for (&token, pos) in tokens.iter().zip(0_i32..) {
        batch
            .add(token, pos, &[0], pos + 1 == n_prompt)
            .map_err(LlamaError::from)?;
    }
    llama.context().decode(&batch)?;

    let mut sampler = crate::sampling::LlamaSampler::greedy()
        .ok_or_else(|| LlamaError::Batch("sampler_init_greedy returned null".into()))?;
    let ctx_ptr = llama.context().raw_handle();
    let eos = llama.model().token_eos();
    let eot = llama.model().token_eot();

    let mut generated = String::new();
    // Position at which the next generated token will be placed.
    let mut next_pos = n_prompt;
    // Index handed to the sampler: the last prompt token for the first sample,
    // then `0` once we switch to single-token batches.
    let mut logits_idx = n_prompt - 1;
    let mut n_generated = 0_usize;
    let mut stop_reason = StopReason::Length;

    while n_generated < max_tokens {
        // SAFETY: `ctx_ptr` was obtained from `llama`'s context above and `llama`
        // stays borrowed for this whole call, so the handle is presumed to remain
        // valid; the most recent `decode` succeeded on a batch whose entry at
        // `logits_idx` was flagged in `add`.
        // NOTE(review): confirm `sample`'s exact safety contract.
        let next = unsafe { sampler.sample(ctx_ptr, logits_idx) };
        sampler.accept(next);

        if next == eos || next == eot {
            stop_reason = StopReason::Eos;
            break;
        }

        // NOTE(review): detokenizing one token at a time may split a multi-byte
        // character across pieces — confirm how `detokenize` handles that.
        let piece = llama.model().detokenize(&[next], false)?;
        generated.push_str(&piece);
        n_generated += 1;

        if n_generated == max_tokens {
            // Budget exhausted: no further sample will be taken, so decoding
            // this token would be a wasted forward pass (and could fail
            // needlessly if the context is full).
            break;
        }

        let mut single = crate::batch::LlamaBatch::new(1, 1);
        single
            .add(next, next_pos, &[0], true)
            .map_err(LlamaError::from)?;
        llama.context().decode(&single)?;
        next_pos += 1;
        logits_idx = 0;
    }

    Ok(Completion {
        text: generated,
        n_tokens: n_generated,
        stop_reason,
    })
}
/// Converts a JSON schema into a grammar string.
///
/// Delegates to `crate::json_schema::schema_to_grammar`, passing `"root"` as
/// its second argument (presumably the name of the root rule — confirm against
/// that function).
///
/// # Errors
///
/// Returns [`LlamaError::JsonSchemaToGrammar`] carrying the underlying error's
/// string form when the conversion fails.
pub fn json_schema_grammar(schema: &Value) -> Result<String> {
    match crate::json_schema::schema_to_grammar(schema, "root") {
        Ok(grammar) => Ok(grammar),
        Err(err) => Err(LlamaError::JsonSchemaToGrammar(err.to_string())),
    }
}