use crate::context::LlamaContext;
use crate::error::{LlamaError, Result};
impl LlamaContext<'_> {
/// Returns the embedding vector that `llama_get_embeddings` exposes for this context.
///
/// The slice covers the first `n_embd` floats behind the returned pointer, where `n_embd`
/// is read from this context's model, and it borrows from `self`.
///
/// # Errors
///
/// Returns [`LlamaError::Embedding`] when the backend returns a null pointer, or when the
/// model's embedding size cannot be represented as a `usize`.
pub fn embeddings(&self) -> Result<&[f32]> {
    // SAFETY: assumes `self.raw()` yields a live context handle for as long as `self` is
    // borrowed — TODO confirm against the `LlamaContext` invariants.
    let ptr = unsafe { llama_crab_sys::llama_get_embeddings(self.raw()) };
    if ptr.is_null() {
        return Err(LlamaError::Embedding(
            "embeddings not enabled (LlamaContextParams::with_embeddings(true))".into(),
        ));
    }
    // A plain `as usize` cast would turn a negative size into an enormous slice length,
    // which is undefined behaviour in `from_raw_parts`; reject such values instead.
    let n = <usize as std::convert::TryFrom<_>>::try_from(self.model().n_embd())
        .map_err(|_| LlamaError::Embedding("model reported an invalid embedding size".into()))?;
    // SAFETY: `ptr` was checked to be non-null above. This relies on the backend buffer
    // holding at least `n` floats and staying valid while `self` is borrowed.
    // NOTE(review): only the first `n` floats are exposed even if the backend stores
    // several output rows contiguously — confirm this is the intended view.
    Ok(unsafe { std::slice::from_raw_parts(ptr, n) })
}
/// Returns the embedding vector that `llama_get_embeddings_seq` exposes for `seq_id`.
///
/// The slice covers `n_embd` floats behind the returned pointer, where `n_embd` is read
/// from this context's model, and it borrows from `self`.
///
/// # Errors
///
/// Returns [`LlamaError::Embedding`] when the backend returns a null pointer for
/// `seq_id`, or when the model's embedding size cannot be represented as a `usize`.
pub fn embeddings_seq(&self, seq_id: i32) -> Result<&[f32]> {
    // SAFETY: assumes `self.raw()` yields a live context handle for as long as `self` is
    // borrowed — TODO confirm against the `LlamaContext` invariants.
    let ptr = unsafe { llama_crab_sys::llama_get_embeddings_seq(self.raw(), seq_id) };
    if ptr.is_null() {
        return Err(LlamaError::Embedding(format!(
            "no embedding for seq {seq_id}"
        )));
    }
    // A plain `as usize` cast would turn a negative size into an enormous slice length,
    // which is undefined behaviour in `from_raw_parts`; reject such values instead.
    let n = <usize as std::convert::TryFrom<_>>::try_from(self.model().n_embd())
        .map_err(|_| LlamaError::Embedding("model reported an invalid embedding size".into()))?;
    // SAFETY: `ptr` was checked to be non-null above. This relies on the backend buffer
    // holding at least `n` floats and staying valid while `self` is borrowed.
    // NOTE(review): llama.cpp documents a shorter buffer for rank pooling — confirm that
    // `n_embd` is the right length for every pooling mode this wrapper supports.
    Ok(unsafe { std::slice::from_raw_parts(ptr, n) })
}
/// Returns the embedding vector that `llama_get_embeddings_ith` exposes for index `i`.
///
/// The slice covers `n_embd` floats behind the returned pointer, where `n_embd` is read
/// from this context's model, and it borrows from `self`.
///
/// # Errors
///
/// Returns [`LlamaError::Embedding`] when the backend returns a null pointer for `i`, or
/// when the model's embedding size cannot be represented as a `usize`.
pub fn embeddings_ith(&self, i: i32) -> Result<&[f32]> {
    // SAFETY: assumes `self.raw()` yields a live context handle for as long as `self` is
    // borrowed — TODO confirm against the `LlamaContext` invariants.
    let ptr = unsafe { llama_crab_sys::llama_get_embeddings_ith(self.raw(), i) };
    if ptr.is_null() {
        return Err(LlamaError::Embedding(format!("no embedding at index {i}")));
    }
    // A plain `as usize` cast would turn a negative size into an enormous slice length,
    // which is undefined behaviour in `from_raw_parts`; reject such values instead.
    let n = <usize as std::convert::TryFrom<_>>::try_from(self.model().n_embd())
        .map_err(|_| LlamaError::Embedding("model reported an invalid embedding size".into()))?;
    // SAFETY: `ptr` was checked to be non-null above. This relies on the backend buffer
    // holding at least `n` floats for this index and staying valid while `self` is
    // borrowed — TODO confirm against the backend contract.
    Ok(unsafe { std::slice::from_raw_parts(ptr, n) })
}
/// Scales `v` in place so that it has unit Euclidean (L2) length.
///
/// The sum of squares is accumulated in `f64`, so components whose squares would overflow
/// or underflow `f32` (for example `1e20` or `1e-30`) are still normalised correctly
/// instead of collapsing to zeros or being silently skipped.
///
/// A vector whose norm is zero (including the empty slice) or NaN is left untouched.
pub fn normalize(v: &mut [f32]) {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    // `norm > 0.0` is false for both 0.0 and NaN, so those inputs are left unchanged.
    if norm > 0.0 {
        for x in v.iter_mut() {
            // Divide in f64 and narrow once; the quotient has magnitude <= 1 and fits f32.
            *x = (f64::from(*x) / norm) as f32;
        }
    }
}
}