use llama_crab_sys as sys;
use crate::model::LlamaModel;
#[allow(unused_imports)]
use crate::token::LlamaToken;
use super::LlamaSampler;
/// Failure modes reported while preparing or constructing a grammar sampler.
///
/// The human-readable text for each variant comes from the type's
/// `Display` implementation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum GrammarError {
    /// A trigger word contained an interior NUL byte.
    ///
    /// This is also the variant produced by the blanket
    /// `From<std::ffi::NulError>` conversion.
    TriggerWordNullBytes,
    /// A grammar string contained an interior NUL byte.
    ///
    /// `LlamaSampler::grammar` reports this for the grammar text and for the
    /// root rule name alike.
    GrammarNullBytes,
    /// `llama_sampler_init_grammar` handed back a null pointer.
    NullGrammar,
    /// The grammar's root rule could not be found.
    RootNotFound,
    /// A grammar FFI call failed; the payload is the code that is shown in
    /// the error message.
    Ffi(i32),
}
impl std::fmt::Display for GrammarError {
    /// Writes a short, human-readable description of the error.
    ///
    /// Every variant except `Ffi` maps to a fixed message; `Ffi` embeds its
    /// numeric code in the text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            // The only variant that needs formatting, so it is handled on its own.
            Self::Ffi(code) => return write!(f, "grammar FFI failed (code {code})"),
            Self::TriggerWordNullBytes => "trigger word contains a NUL byte",
            Self::GrammarNullBytes => "grammar string contains a NUL byte",
            Self::NullGrammar => "llama_sampler_init_grammar returned null",
            Self::RootNotFound => "grammar root rule not found",
        };
        f.write_str(message)
    }
}
// Lets `GrammarError` be used wherever a `std::error::Error` is expected.
// The impl body is empty, so every trait method keeps its default behavior.
impl std::error::Error for GrammarError {}
impl From<std::ffi::NulError> for GrammarError {
fn from(_: std::ffi::NulError) -> Self {
Self::TriggerWordNullBytes
}
}
impl LlamaSampler {
    /// Creates a sampler from a grammar by calling
    /// `llama_sampler_init_grammar`.
    ///
    /// `grammar` is the grammar text and `grammar_root` the name of its root
    /// rule; both are copied into NUL-terminated C strings before the call.
    /// The native function is invoked with `model.raw()` and the two string
    /// pointers, and the returned pointer is wrapped via `Self::from_raw`.
    ///
    /// Only compiled when the `common` feature is enabled.
    ///
    /// # Errors
    ///
    /// * [`GrammarError::GrammarNullBytes`] if `grammar` **or** `grammar_root`
    ///   contains an interior NUL byte.
    /// * [`GrammarError::NullGrammar`] if the native call returns a null
    ///   pointer.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact caller contract is not visible in this file.
    /// Presumably `model` must refer to a live native model for the duration
    /// of the call (and possibly for the lifetime of the returned sampler) —
    /// confirm against `LlamaModel::raw` and the native API.
    #[cfg(feature = "common")]
    pub unsafe fn grammar(
        model: &LlamaModel,
        grammar: &str,
        grammar_root: &str,
    ) -> Result<Self, GrammarError> {
        use std::ffi::CString;

        let grammar_c = match CString::new(grammar) {
            Ok(text) => text,
            Err(_) => return Err(GrammarError::GrammarNullBytes),
        };
        // A NUL byte in the root name is reported with the same variant as
        // one in the grammar text.
        let root_c = match CString::new(grammar_root) {
            Ok(name) => name,
            Err(_) => return Err(GrammarError::GrammarNullBytes),
        };

        // Both C strings are locals that live until the end of this function,
        // so the pointers passed below stay valid for the whole call.
        let raw = sys::llama_sampler_init_grammar(
            model.raw(),
            grammar_c.as_ptr(),
            root_c.as_ptr(),
        );

        if raw.is_null() {
            Err(GrammarError::NullGrammar)
        } else {
            Ok(Self::from_raw(raw))
        }
    }
}