use crate::context::LlamaContext;
use crate::error::Result;
use crate::token::LlamaToken;
impl LlamaContext<'_> {
pub fn logits_ith(&self, i: i32) -> Result<&[f32]> {
let ptr = unsafe { llama_crab_sys::llama_get_logits_ith(self.raw_handle(), i) };
if ptr.is_null() {
return Err(crate::error::LlamaError::Batch(format!(
"no logits at index {i}"
)));
}
let n = self.model().n_vocab() as usize;
Ok(unsafe { std::slice::from_raw_parts(ptr, n) })
}
#[must_use]
pub fn sampled_token_ith(&self, i: i32) -> LlamaToken {
let raw = unsafe { llama_crab_sys::llama_get_sampled_token_ith(self.raw_handle(), i) };
LlamaToken(raw)
}
pub fn sampled_probs_ith(&self, i: i32) -> Result<&[f32]> {
let ptr = unsafe { llama_crab_sys::llama_get_sampled_probs_ith(self.raw_handle(), i) };
if ptr.is_null() {
return Err(crate::error::LlamaError::Batch(format!(
"no sampled probs at index {i}"
)));
}
let n = self.model().n_vocab() as usize;
Ok(unsafe { std::slice::from_raw_parts(ptr, n) })
}
}