llama_cpp_bindings/
llguidance_sampler.rs1use std::ffi::c_void;
7use std::sync::Arc;
8
9use llguidance::Matcher;
10use toktrie::ApproximateTokEnv;
11
12use crate::GrammarError;
13use crate::model::LlamaModel;
14use crate::sampling::LlamaSampler;
15
16struct LlgContext {
18 matcher: Matcher,
19 tok_env: Arc<ApproximateTokEnv>,
20 grammar_kind: String,
21 grammar_data: String,
22}
23
24const unsafe extern "C" fn llg_name(
25 _smpl: *const llama_cpp_bindings_sys::llama_sampler,
26) -> *const std::os::raw::c_char {
27 c"llguidance".as_ptr()
28}
29
30unsafe extern "C" fn llg_accept(
31 smpl: *mut llama_cpp_bindings_sys::llama_sampler,
32 token: llama_cpp_bindings_sys::llama_token,
33) {
34 let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
35
36 if let Err(consume_error) = ctx.matcher.consume_token(token.cast_unsigned()) {
37 log::warn!(
38 "llguidance sampler failed to consume token: token={token}, error={consume_error}",
39 );
40 }
41}
42
43unsafe extern "C" fn llg_apply(
44 smpl: *mut llama_cpp_bindings_sys::llama_sampler,
45 cur_p: *mut llama_cpp_bindings_sys::llama_token_data_array,
46) {
47 let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
48 let cur_p = unsafe { &mut *cur_p };
49
50 let mask = match ctx.matcher.compute_mask() {
51 Ok(mask) => mask,
52 Err(compute_error) => {
53 log::warn!(
54 "llguidance sampler failed to compute mask, skipping constraint application: error={compute_error}",
55 );
56
57 return;
58 }
59 };
60
61 let data = unsafe { std::slice::from_raw_parts_mut(cur_p.data, cur_p.size) };
62 for item in data.iter_mut() {
63 if !mask.is_allowed(item.id.cast_unsigned()) {
64 item.logit = f32::NEG_INFINITY;
65 }
66 }
67}
68
69unsafe extern "C" fn llg_reset(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
70 let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
71
72 if let Err(reset_error) = ctx.matcher.reset() {
73 log::warn!("llguidance sampler failed to reset: error={reset_error}");
74 }
75}
76
77unsafe extern "C" fn llg_clone(
78 smpl: *const llama_cpp_bindings_sys::llama_sampler,
79) -> *mut llama_cpp_bindings_sys::llama_sampler {
80 let ctx = unsafe { &*(*smpl).ctx.cast::<LlgContext>() };
81 let new_ctx = Box::new(LlgContext {
82 matcher: ctx.matcher.deep_clone(),
83 tok_env: Arc::clone(&ctx.tok_env),
84 grammar_kind: ctx.grammar_kind.clone(),
85 grammar_data: ctx.grammar_data.clone(),
86 });
87 unsafe {
88 llama_cpp_bindings_sys::llama_sampler_init(
89 &raw mut LLG_SAMPLER_I,
90 Box::into_raw(new_ctx).cast::<c_void>(),
91 )
92 }
93}
94
95unsafe extern "C" fn llg_free(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
96 let ctx_ptr = unsafe { (*smpl).ctx.cast::<LlgContext>() };
97 if !ctx_ptr.is_null() {
98 drop(unsafe { Box::from_raw(ctx_ptr) });
99 }
100}
101
102static mut LLG_SAMPLER_I: llama_cpp_bindings_sys::llama_sampler_i =
103 llama_cpp_bindings_sys::llama_sampler_i {
104 name: Some(llg_name),
105 accept: Some(llg_accept),
106 apply: Some(llg_apply),
107 reset: Some(llg_reset),
108 clone: Some(llg_clone),
109 free: Some(llg_free),
110 backend_init: None,
111 backend_accept: None,
112 backend_apply: None,
113 backend_set_input: None,
114 };
115
116pub fn create_llg_sampler(
122 model: &LlamaModel,
123 grammar_kind: &str,
124 grammar_data: &str,
125) -> Result<LlamaSampler, GrammarError> {
126 let tok_env = model.approximate_tok_env();
127 let tok_env_dyn: Arc<dyn toktrie::TokenizerEnv + Sync> = tok_env.clone();
128
129 let factory = llguidance::ParserFactory::new_simple(&tok_env_dyn)
130 .map_err(|factory_error| GrammarError::LlguidanceError(factory_error.to_string()))?;
131
132 let grammar = llguidance::api::TopLevelGrammar::from_tagged_str(grammar_kind, grammar_data)
133 .map_err(|parse_error| GrammarError::LlguidanceError(parse_error.to_string()))?;
134
135 let parser = factory
136 .create_parser(grammar)
137 .map_err(|parser_error| GrammarError::LlguidanceError(parser_error.to_string()))?;
138
139 let matcher = Matcher::new(Ok(parser));
140
141 let ctx = Box::new(LlgContext {
142 matcher,
143 tok_env,
144 grammar_kind: grammar_kind.to_string(),
145 grammar_data: grammar_data.to_string(),
146 });
147
148 let sampler = unsafe {
149 llama_cpp_bindings_sys::llama_sampler_init(
150 &raw mut LLG_SAMPLER_I,
151 Box::into_raw(ctx).cast::<c_void>(),
152 )
153 };
154
155 if sampler.is_null() {
156 Err(GrammarError::LlguidanceSamplerUnavailable)
157 } else {
158 Ok(LlamaSampler { sampler })
159 }
160}