Skip to main content

llama_cpp_bindings/
llguidance_sampler.rs

1//! Pure Rust llguidance sampler for constrained decoding.
2//!
3//! Implements a custom `llama_sampler` using the `llguidance` and `toktrie` Rust crates
4//! to enforce grammar constraints (JSON schema, regex, Lark, etc.) during token sampling.
5
6use std::ffi::c_void;
7use std::sync::Arc;
8
9use llguidance::Matcher;
10use toktrie::ApproximateTokEnv;
11
12use crate::GrammarError;
13use crate::model::LlamaModel;
14use crate::sampling::LlamaSampler;
15
16/// Internal state for the llguidance sampler.
17struct LlgContext {
18    matcher: Matcher,
19    tok_env: Arc<ApproximateTokEnv>,
20    grammar_kind: String,
21    grammar_data: String,
22}
23
24const unsafe extern "C" fn llg_name(
25    _smpl: *const llama_cpp_bindings_sys::llama_sampler,
26) -> *const std::os::raw::c_char {
27    c"llguidance".as_ptr()
28}
29
30unsafe extern "C" fn llg_accept(
31    smpl: *mut llama_cpp_bindings_sys::llama_sampler,
32    token: llama_cpp_bindings_sys::llama_token,
33) {
34    let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
35
36    if let Err(consume_error) = ctx.matcher.consume_token(token.cast_unsigned()) {
37        tracing::warn!(
38            token = token,
39            error = %consume_error,
40            "llguidance sampler failed to consume token"
41        );
42    }
43}
44
45unsafe extern "C" fn llg_apply(
46    smpl: *mut llama_cpp_bindings_sys::llama_sampler,
47    cur_p: *mut llama_cpp_bindings_sys::llama_token_data_array,
48) {
49    let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
50    let cur_p = unsafe { &mut *cur_p };
51
52    let mask = match ctx.matcher.compute_mask() {
53        Ok(mask) => mask,
54        Err(compute_error) => {
55            tracing::warn!(
56                error = %compute_error,
57                "llguidance sampler failed to compute mask, skipping constraint application"
58            );
59
60            return;
61        }
62    };
63
64    let data = unsafe { std::slice::from_raw_parts_mut(cur_p.data, cur_p.size) };
65    for item in data.iter_mut() {
66        if !mask.is_allowed(item.id.cast_unsigned()) {
67            item.logit = f32::NEG_INFINITY;
68        }
69    }
70}
71
72unsafe extern "C" fn llg_reset(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
73    let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
74
75    if let Err(reset_error) = ctx.matcher.reset() {
76        tracing::warn!(
77            error = %reset_error,
78            "llguidance sampler failed to reset"
79        );
80    }
81}
82
83unsafe extern "C" fn llg_clone(
84    smpl: *const llama_cpp_bindings_sys::llama_sampler,
85) -> *mut llama_cpp_bindings_sys::llama_sampler {
86    let ctx = unsafe { &*(*smpl).ctx.cast::<LlgContext>() };
87    let new_ctx = Box::new(LlgContext {
88        matcher: ctx.matcher.deep_clone(),
89        tok_env: Arc::clone(&ctx.tok_env),
90        grammar_kind: ctx.grammar_kind.clone(),
91        grammar_data: ctx.grammar_data.clone(),
92    });
93    unsafe {
94        llama_cpp_bindings_sys::llama_sampler_init(
95            &raw mut LLG_SAMPLER_I,
96            Box::into_raw(new_ctx).cast::<c_void>(),
97        )
98    }
99}
100
101unsafe extern "C" fn llg_free(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
102    let ctx_ptr = unsafe { (*smpl).ctx.cast::<LlgContext>() };
103    if !ctx_ptr.is_null() {
104        drop(unsafe { Box::from_raw(ctx_ptr) });
105    }
106}
107
108static mut LLG_SAMPLER_I: llama_cpp_bindings_sys::llama_sampler_i =
109    llama_cpp_bindings_sys::llama_sampler_i {
110        name: Some(llg_name),
111        accept: Some(llg_accept),
112        apply: Some(llg_apply),
113        reset: Some(llg_reset),
114        clone: Some(llg_clone),
115        free: Some(llg_free),
116        backend_init: None,
117        backend_accept: None,
118        backend_apply: None,
119        backend_set_input: None,
120    };
121
122/// Create an llguidance-based constrained decoding sampler.
123///
124/// # Errors
125///
126/// Returns `GrammarError` if the parser factory, grammar, or parser cannot be created.
127pub fn create_llg_sampler(
128    model: &LlamaModel,
129    grammar_kind: &str,
130    grammar_data: &str,
131) -> Result<LlamaSampler, GrammarError> {
132    let tok_env = model.approximate_tok_env();
133    let tok_env_dyn: Arc<dyn toktrie::TokenizerEnv + Sync> = tok_env.clone();
134
135    let factory = llguidance::ParserFactory::new_simple(&tok_env_dyn)
136        .map_err(|factory_error| GrammarError::LlguidanceError(factory_error.to_string()))?;
137
138    let grammar = llguidance::api::TopLevelGrammar::from_tagged_str(grammar_kind, grammar_data)
139        .map_err(|parse_error| GrammarError::LlguidanceError(parse_error.to_string()))?;
140
141    let parser = factory
142        .create_parser(grammar)
143        .map_err(|parser_error| GrammarError::LlguidanceError(parser_error.to_string()))?;
144
145    let matcher = Matcher::new(Ok(parser));
146
147    let ctx = Box::new(LlgContext {
148        matcher,
149        tok_env,
150        grammar_kind: grammar_kind.to_string(),
151        grammar_data: grammar_data.to_string(),
152    });
153
154    let sampler = unsafe {
155        llama_cpp_bindings_sys::llama_sampler_init(
156            &raw mut LLG_SAMPLER_I,
157            Box::into_raw(ctx).cast::<c_void>(),
158        )
159    };
160
161    if sampler.is_null() {
162        Err(GrammarError::NullGrammar(
163            "llguidance sampler returned null".to_owned(),
164        ))
165    } else {
166        Ok(LlamaSampler { sampler })
167    }
168}