llama_cpp_bindings/
llguidance_sampler.rs1use std::ffi::c_void;
7use std::sync::Arc;
8
9use llguidance::Matcher;
10use toktrie::ApproximateTokEnv;
11
12use crate::GrammarError;
13use crate::model::LlamaModel;
14use crate::sampling::LlamaSampler;
15
16struct LlgContext {
18 matcher: Matcher,
19 tok_env: Arc<ApproximateTokEnv>,
20 grammar_kind: String,
21 grammar_data: String,
22}
23
24const unsafe extern "C" fn llg_name(
25 _smpl: *const llama_cpp_bindings_sys::llama_sampler,
26) -> *const std::os::raw::c_char {
27 c"llguidance".as_ptr()
28}
29
30unsafe extern "C" fn llg_accept(
31 smpl: *mut llama_cpp_bindings_sys::llama_sampler,
32 token: llama_cpp_bindings_sys::llama_token,
33) {
34 let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
35
36 if let Err(consume_error) = ctx.matcher.consume_token(token.cast_unsigned()) {
37 tracing::warn!(
38 token = token,
39 error = %consume_error,
40 "llguidance sampler failed to consume token"
41 );
42 }
43}
44
45unsafe extern "C" fn llg_apply(
46 smpl: *mut llama_cpp_bindings_sys::llama_sampler,
47 cur_p: *mut llama_cpp_bindings_sys::llama_token_data_array,
48) {
49 let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
50 let cur_p = unsafe { &mut *cur_p };
51
52 let mask = match ctx.matcher.compute_mask() {
53 Ok(mask) => mask,
54 Err(compute_error) => {
55 tracing::warn!(
56 error = %compute_error,
57 "llguidance sampler failed to compute mask, skipping constraint application"
58 );
59
60 return;
61 }
62 };
63
64 let data = unsafe { std::slice::from_raw_parts_mut(cur_p.data, cur_p.size) };
65 for item in data.iter_mut() {
66 if !mask.is_allowed(item.id.cast_unsigned()) {
67 item.logit = f32::NEG_INFINITY;
68 }
69 }
70}
71
72unsafe extern "C" fn llg_reset(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
73 let ctx = unsafe { &mut *(*smpl).ctx.cast::<LlgContext>() };
74
75 if let Err(reset_error) = ctx.matcher.reset() {
76 tracing::warn!(
77 error = %reset_error,
78 "llguidance sampler failed to reset"
79 );
80 }
81}
82
83unsafe extern "C" fn llg_clone(
84 smpl: *const llama_cpp_bindings_sys::llama_sampler,
85) -> *mut llama_cpp_bindings_sys::llama_sampler {
86 let ctx = unsafe { &*(*smpl).ctx.cast::<LlgContext>() };
87 let new_ctx = Box::new(LlgContext {
88 matcher: ctx.matcher.deep_clone(),
89 tok_env: Arc::clone(&ctx.tok_env),
90 grammar_kind: ctx.grammar_kind.clone(),
91 grammar_data: ctx.grammar_data.clone(),
92 });
93 unsafe {
94 llama_cpp_bindings_sys::llama_sampler_init(
95 &raw mut LLG_SAMPLER_I,
96 Box::into_raw(new_ctx).cast::<c_void>(),
97 )
98 }
99}
100
101unsafe extern "C" fn llg_free(smpl: *mut llama_cpp_bindings_sys::llama_sampler) {
102 let ctx_ptr = unsafe { (*smpl).ctx.cast::<LlgContext>() };
103 if !ctx_ptr.is_null() {
104 drop(unsafe { Box::from_raw(ctx_ptr) });
105 }
106}
107
108static mut LLG_SAMPLER_I: llama_cpp_bindings_sys::llama_sampler_i =
109 llama_cpp_bindings_sys::llama_sampler_i {
110 name: Some(llg_name),
111 accept: Some(llg_accept),
112 apply: Some(llg_apply),
113 reset: Some(llg_reset),
114 clone: Some(llg_clone),
115 free: Some(llg_free),
116 backend_init: None,
117 backend_accept: None,
118 backend_apply: None,
119 backend_set_input: None,
120 };
121
122pub fn create_llg_sampler(
128 model: &LlamaModel,
129 grammar_kind: &str,
130 grammar_data: &str,
131) -> Result<LlamaSampler, GrammarError> {
132 let tok_env = model.approximate_tok_env();
133 let tok_env_dyn: Arc<dyn toktrie::TokenizerEnv + Sync> = tok_env.clone();
134
135 let factory = llguidance::ParserFactory::new_simple(&tok_env_dyn)
136 .map_err(|factory_error| GrammarError::LlguidanceError(factory_error.to_string()))?;
137
138 let grammar = llguidance::api::TopLevelGrammar::from_tagged_str(grammar_kind, grammar_data)
139 .map_err(|parse_error| GrammarError::LlguidanceError(parse_error.to_string()))?;
140
141 let parser = factory
142 .create_parser(grammar)
143 .map_err(|parser_error| GrammarError::LlguidanceError(parser_error.to_string()))?;
144
145 let matcher = Matcher::new(Ok(parser));
146
147 let ctx = Box::new(LlgContext {
148 matcher,
149 tok_env,
150 grammar_kind: grammar_kind.to_string(),
151 grammar_data: grammar_data.to_string(),
152 });
153
154 let sampler = unsafe {
155 llama_cpp_bindings_sys::llama_sampler_init(
156 &raw mut LLG_SAMPLER_I,
157 Box::into_raw(ctx).cast::<c_void>(),
158 )
159 };
160
161 if sampler.is_null() {
162 Err(GrammarError::NullGrammar(
163 "llguidance sampler returned null".to_owned(),
164 ))
165 } else {
166 Ok(LlamaSampler { sampler })
167 }
168}