// llama_cpp_bindings/llguidance_sampler.rs

1use std::ffi::c_void;
2use std::sync::Arc;
3
4use llguidance::Matcher;
5use toktrie::ApproximateTokEnv;
6
7use crate::GrammarError;
8use crate::model::LlamaModel;
9use crate::sampling::LlamaSampler;
10
11struct LlgContext {
12    matcher: Matcher,
13    tok_env: Arc<ApproximateTokEnv>,
14    grammar_kind: String,
15    grammar_data: String,
16}
17
18const unsafe extern "C" fn llg_name(
19    _smpl: *const llama_cpp_bindings_sys::llama_sampler,
20) -> *const std::os::raw::c_char {
21    c"llguidance".as_ptr()
22}
23
24unsafe extern "C" fn llg_accept(
25    smpl: *mut llama_cpp_bindings_sys::llama_sampler,
26    token: llama_cpp_bindings_sys::llama_token,
27) {
28    let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
29
30    if let Err(consume_error) = ctx.matcher.consume_token(token.cast_unsigned()) {
31        log::warn!(
32            "llguidance sampler failed to consume token: token={token}, error={consume_error}",
33        );
34    }
35}
36
37unsafe extern "C" fn llg_apply(
38    smpl: *mut llama_cpp_bindings_sys::llama_sampler,
39    cur_p: *mut llama_cpp_bindings_sys::llama_token_data_array,
40) {
41    let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
42    let cur_p = unsafe { &mut *cur_p };
43
44    let mask = match ctx.matcher.compute_mask() {
45        Ok(mask) => mask,
46        Err(compute_error) => {
47            log::warn!(
48                "llguidance sampler failed to compute mask, skipping constraint application: error={compute_error}",
49            );
50
51            return;
52        }
53    };
54
55    let data = unsafe { std::slice::from_raw_parts_mut(cur_p.data, cur_p.size) };
56    for item in data.iter_mut() {
57        if !mask.is_allowed(item.id.cast_unsigned()) {
58            item.logit = f32::NEG_INFINITY;
59        }
60    }
61}
62
63unsafe extern "C" fn llg_reset(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
64    let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
65
66    if let Err(reset_error) = ctx.matcher.reset() {
67        log::warn!("llguidance sampler failed to reset: error={reset_error}");
68    }
69}
70
71unsafe extern "C" fn llg_clone(
72    smpl: *const llama_cpp_bindings_sys::llama_sampler,
73) -> *mut llama_cpp_bindings_sys::llama_sampler {
74    let ctx = unsafe { &*(*smpl).ctx.cast::<LlgContext>() };
75    let new_ctx = Box::new(LlgContext {
76        matcher: ctx.matcher.deep_clone(),
77        tok_env: Arc::clone(&ctx.tok_env),
78        grammar_kind: ctx.grammar_kind.clone(),
79        grammar_data: ctx.grammar_data.clone(),
80    });
81    unsafe {
82        llama_cpp_bindings_sys::llama_sampler_init(
83            &raw mut LLG_SAMPLER_I,
84            Box::into_raw(new_ctx).cast::<c_void>(),
85        )
86    }
87}
88
89unsafe extern "C" fn llg_free(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
90    let ctx_ptr = unsafe { (*smpl).ctx.cast::<LlgContext>() };
91    if !ctx_ptr.is_null() {
92        drop(unsafe { Box::from_raw(ctx_ptr) });
93    }
94}
95
96static mut LLG_SAMPLER_I: llama_cpp_bindings_sys::llama_sampler_i =
97    llama_cpp_bindings_sys::llama_sampler_i {
98        name: Some(llg_name),
99        accept: Some(llg_accept),
100        apply: Some(llg_apply),
101        reset: Some(llg_reset),
102        clone: Some(llg_clone),
103        free: Some(llg_free),
104        backend_init: None,
105        backend_accept: None,
106        backend_apply: None,
107        backend_set_input: None,
108    };
109
110/// # Errors
111///
112/// Returns `GrammarError` if the parser factory, grammar, or parser cannot be created.
113pub fn create_llg_sampler(
114    model: &LlamaModel,
115    grammar_kind: &str,
116    grammar_data: &str,
117) -> Result<LlamaSampler, GrammarError> {
118    let tok_env = model.approximate_tok_env();
119    let tok_env_dyn: Arc<dyn toktrie::TokenizerEnv + Sync> = tok_env.clone();
120
121    let factory = llguidance::ParserFactory::new_simple(&tok_env_dyn)
122        .map_err(|factory_error| GrammarError::LlguidanceError(factory_error.to_string()))?;
123
124    let grammar = llguidance::api::TopLevelGrammar::from_tagged_str(grammar_kind, grammar_data)
125        .map_err(|parse_error| GrammarError::LlguidanceError(parse_error.to_string()))?;
126
127    let parser = factory
128        .create_parser(grammar)
129        .map_err(|parser_error| GrammarError::LlguidanceError(parser_error.to_string()))?;
130
131    let matcher = Matcher::new(Ok(parser));
132
133    let ctx = Box::new(LlgContext {
134        matcher,
135        tok_env,
136        grammar_kind: grammar_kind.to_string(),
137        grammar_data: grammar_data.to_string(),
138    });
139
140    let sampler = unsafe {
141        llama_cpp_bindings_sys::llama_sampler_init(
142            &raw mut LLG_SAMPLER_I,
143            Box::into_raw(ctx).cast::<c_void>(),
144        )
145    };
146
147    if sampler.is_null() {
148        Err(GrammarError::LlguidanceSamplerUnavailable)
149    } else {
150        Ok(LlamaSampler { sampler })
151    }
152}