use crate::batch::LlamaBatch;
use crate::error::Result;
use crate::sampling::LlamaSampler;
use crate::token::LlamaToken;
use crate::Llama;
impl Llama {
    /// Fill-in-the-middle (FIM) completion: generates the text that belongs between
    /// `prefix` and `suffix` using greedy decoding, capped at 256 generated tokens.
    ///
    /// This is [`Llama::complete_infill_with_limit`] with a limit of 256 tokens; see that
    /// method for the full contract.
    ///
    /// # Errors
    ///
    /// Same as [`Llama::complete_infill_with_limit`].
    pub fn complete_infill(&mut self, prefix: &str, suffix: &str) -> Result<String> {
        /// Generation cap applied when the caller does not choose one.
        const DEFAULT_MAX_TOKENS: usize = 256;
        self.complete_infill_with_limit(prefix, suffix, DEFAULT_MAX_TOKENS)
    }

    /// Fill-in-the-middle (FIM) completion with a caller-chosen cap on generated tokens.
    ///
    /// Steps:
    /// 1. Fetches the model's FIM tokens and builds the FIM prompt from `prefix`/`suffix`.
    /// 2. Calls `seq_rm(0, -1, -1)` on the context before decoding (presumably drops all
    ///    cached positions of sequence 0 — confirm against `LlamaContext::seq_rm`).
    /// 3. Decodes the whole prompt in one batch, requesting logits only for its last token.
    /// 4. Samples greedily, feeding each sampled token back as a single-token batch, until
    ///    the model's EOS token, the FIM end-of-text token (EOS when the model has none), or
    ///    `max_tokens` generated tokens.
    ///
    /// Returns the generated text with leading/trailing whitespace trimmed, or an empty
    /// string when the tokenized prompt is empty, nothing was generated, or `max_tokens` is 0.
    ///
    /// # Errors
    ///
    /// - `LlamaError::Batch` if the model has no FIM tokens or the greedy sampler cannot be
    ///   created.
    /// - Any error from building or tokenizing the prompt, filling a batch, or `decode`.
    pub fn complete_infill_with_limit(
        &mut self,
        prefix: &str,
        suffix: &str,
        max_tokens: usize,
    ) -> Result<String> {
        let fim = self
            .model()
            .fim_tokens()
            .ok_or_else(|| crate::error::LlamaError::Batch("model does not support FIM".into()))?;

        // The result of the reset is deliberately discarded (`let _`).
        let _ = self.context_mut().seq_rm(0, -1, -1);

        let prompt = fim.build_prompt(prefix, suffix)?;
        // NOTE(review): the third argument looks like a "parse special tokens" flag; if
        // `build_prompt` embeds the FIM markers as text they may need special parsing — confirm.
        let tokens = self.model().tokenize(&prompt, true, false)?;
        if tokens.is_empty() {
            return Ok(String::new());
        }

        // Prefill: the whole prompt in one batch. Only the last position needs logits because
        // it is the only prompt position that is sampled from.
        let last = tokens.len() - 1;
        let mut batch = LlamaBatch::new(tokens.len(), 1);
        for (i, &t) in tokens.iter().enumerate() {
            batch
                .add(t, i as i32, &[0], i == last)
                .map_err(crate::error::LlamaError::from)?;
        }
        self.context_mut().decode(&batch)?;

        let mut sampler = LlamaSampler::greedy()
            .ok_or_else(|| crate::error::LlamaError::Batch("greedy sampler init failed".into()))?;
        let ctx_ptr = self.context().raw_handle();
        let eos = self.model().token_eos();
        let eot = fim.eot.unwrap_or(eos);

        // Token ids are detokenized together at the end: converting one token at a time
        // loses characters whose UTF-8 bytes are split across several tokens.
        let mut generated: Vec<LlamaToken> = Vec::new();
        // Index, within the most recently decoded batch, of the logits to sample from:
        // the last prompt token first, then slot 0 of each single-token batch.
        let mut logits_idx = last as i32;
        while generated.len() < max_tokens {
            // SAFETY: `ctx_ptr` comes from this instance's context, which is only decoded
            // into (never replaced) inside this function, and `logits_idx` always names a
            // position that was added with logits enabled in the latest decoded batch.
            // NOTE(review): confirm these are exactly the requirements documented on
            // `LlamaSampler::sample`.
            let tok: LlamaToken = unsafe { sampler.sample(ctx_ptr, logits_idx) };
            // NOTE(review): if `sample` wraps llama.cpp's `llama_sampler_sample`, it already
            // accepts the token; harmless for greedy, but would double-accept for stateful
            // samplers — confirm.
            sampler.accept(tok);
            if tok == eos || tok == eot {
                break;
            }
            generated.push(tok);
            if generated.len() == max_tokens {
                // Nothing is sampled after this token, so decoding it would be wasted work
                // and could fail needlessly when the context is exactly full.
                break;
            }
            // The n-th generated token (1-based) sits right after the prompt, at
            // position `tokens.len() + n - 1`.
            let pos = (tokens.len() + generated.len() - 1) as i32;
            let mut single = LlamaBatch::new(1, 1);
            single
                .add(tok, pos, &[0], true)
                .map_err(crate::error::LlamaError::from)?;
            self.context_mut().decode(&single)?;
            logits_idx = 0;
        }

        if generated.is_empty() {
            return Ok(String::new());
        }

        let mut text = String::new();
        match self.model().detokenize(generated.as_slice(), false) {
            Ok(piece) => text.push_str(&piece),
            Err(_) => {
                // Fall back to piece-by-piece conversion, skipping pieces that fail, so one
                // bad token never discards the whole completion.
                for &t in &generated {
                    if let Ok(piece) = self.model().detokenize(&[t], false) {
                        text.push_str(&piece);
                    }
                }
            }
        }
        // NOTE(review): trimming removes leading indentation / trailing newlines that can be
        // significant when splicing code between prefix and suffix; kept for compatibility.
        Ok(text.trim().to_string())
    }
}
impl Llama {
    /// Mutable access to the inference context owned by this `Llama`, for crate-internal
    /// operations such as decoding batches or removing cached sequence state.
    pub(crate) fn context_mut(&mut self) -> &mut crate::context::LlamaContext<'static> {
        // Destructure through the `&mut Self` so the field is reborrowed, not moved.
        let Self { context, .. } = self;
        context
    }
}