use std::{
    ffi::{c_char, c_void, CStr},
    fmt::Display,
    sync::Arc,
};

use anyhow::{bail, ensure, Result};
use toktrie::{InferenceCapabilities, TokEnv, TokRxInfo, TokTrie, TokenizerEnv};

use crate::{
    api::{ParserLimits, RegexNode, TopLevelGrammar},
    lark_to_llguidance, CommitResult, Constraint, JsonCompileOptions, Logger, ParserFactory,
    StopController, TokenParser,
};

struct CTokenizerInner {
    trie: TokTrie,
    tokenize_fn: LlgTokenizeFn,
    tokenize_user_data: *const c_void,
    tokenize_assumes_string: bool,
}
// SAFETY: the tokenize callback and its user_data are required to be
// thread-safe (see LlgTokenizeFn), so the raw pointer may cross threads.
unsafe impl Send for CTokenizerInner {}
unsafe impl Sync for CTokenizerInner {}

impl CTokenizerInner {
    fn raw_tokenize(&self, s: &[u8]) -> Vec<toktrie::TokenId> {
        if let Some(tokenize_fn) = self.tokenize_fn {
            // Guess an initial buffer size; the callback reports how many
            // tokens it actually needs.
            let mut res_toks = vec![0; s.len() / 4 + 5];
            let n_toks = tokenize_fn(
                self.tokenize_user_data,
                s.as_ptr(),
                s.len(),
                res_toks.as_mut_ptr(),
                res_toks.len(),
            );

            if n_toks > res_toks.len() {
                // The guess was too small; retry with an exactly-sized buffer.
                res_toks.resize(n_toks, 0);
                tokenize_fn(
                    self.tokenize_user_data,
                    s.as_ptr(),
                    s.len(),
                    res_toks.as_mut_ptr(),
                    res_toks.len(),
                );
            }

            res_toks.truncate(n_toks);
            res_toks
        } else {
            self.trie.greedy_tokenize(s)
        }
    }
}

impl TokenizerEnv for CTokenizerInner {
    fn tok_trie(&self) -> &TokTrie {
        &self.trie
    }

    fn tokenize_bytes(&self, s: &[u8]) -> Vec<toktrie::TokenId> {
        if self.tokenize_assumes_string {
            // The callback cannot handle arbitrary bytes: run UTF-8 fragments
            // through it and fall back to the greedy trie for the rest.
            self.trie
                .tokenize_with_greedy_fallback(s, |s| self.raw_tokenize(s.as_bytes()))
        } else {
            self.raw_tokenize(s)
        }
    }

    fn tokenize_is_canonical(&self) -> bool {
        self.tokenize_fn.is_some()
    }
}

#[derive(Clone)]
pub struct LlgTokenizer {
    pub token_env: TokEnv,
}

impl LlgTokenizer {
    fn from_init(init: &LlgTokenizerInit) -> Result<Self> {
        ensure!(
            init.tokenize_fn.is_some() || init.use_approximate_greedy_tokenize_fn,
            "Either tokenize_fn or use_approximate_greedy_tokenize_fn must be set"
        );
        let tokens = if init.tokenizer_json.is_null() {
            ensure!(
                !init.token_lens.is_null() && !init.token_bytes.is_null(),
                "token_lens and token_bytes must be set"
            );
            let token_lens =
                unsafe { std::slice::from_raw_parts(init.token_lens, init.vocab_size as usize) };
            let total_len = token_lens.iter().sum::<u32>();
            let token_bytes =
                unsafe { std::slice::from_raw_parts(init.token_bytes, total_len as usize) };

            // Split the flat token_bytes buffer into one Vec<u8> per token,
            // using token_lens as the per-token byte counts.
            let mut tokens = vec![];
            let mut ptr = 0;
            for len in token_lens {
                let token = &token_bytes[ptr..ptr + *len as usize];
                tokens.push(token.to_vec());
                ptr += *len as usize;
            }
            tokens
        } else {
            let tokenizer_json = unsafe { c_str_to_str(init.tokenizer_json, "tokenizer_json") }?;
            let tokenizer_json = serde_json::from_str(tokenizer_json)
                .map_err(|e| anyhow::anyhow!("Invalid JSON in tokenizer_json: {e}"))?;
            let mut token_bytes =
                crate::tokenizer_json::token_bytes_from_tokenizer_json(&tokenizer_json)?;

            // Pad with empty tokens if the tokenizer.json vocabulary is
            // smaller than the declared vocab_size.
            let sz = init.vocab_size as usize;
            if token_bytes.len() < sz {
                token_bytes.resize(sz, vec![]);
            }

            token_bytes
        };

        let trie = TokTrie::from(&TokRxInfo::new(tokens.len() as u32, init.tok_eos), &tokens);

        Ok(LlgTokenizer {
            token_env: Arc::new(CTokenizerInner {
                trie,
                tokenize_assumes_string: init.tokenize_assumes_string && init.tokenize_fn.is_some(),
                tokenize_fn: init.tokenize_fn,
                tokenize_user_data: init.tokenize_user_data,
            }),
        })
    }

    fn to_env(&self) -> TokEnv {
        self.token_env.clone()
    }
}

pub type LlgToken = u32;

/// Tokenization function.
/// Will not write more than output_tokens_len tokens (which can be 0).
/// Returns the total number of tokens (which can be more than output_tokens_len).
/// This function has to be thread-safe!
pub type LlgTokenizeFn = Option<
    extern "C" fn(
        user_data: *const c_void,
        bytes: *const u8,
        bytes_len: usize,
        output_tokens: *mut u32,
        output_tokens_len: usize,
    ) -> usize,
>;
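
// Example (not part of the C API): a minimal sketch of a callback matching
// LlgTokenizeFn. It must write at most `output_tokens_len` tokens, yet return
// the total count, so the caller can retry with a larger buffer (see
// CTokenizerInner::raw_tokenize above). The byte-per-token "tokenizer" here is
// a made-up placeholder, not a real tokenizer.
#[allow(dead_code)]
extern "C" fn example_tokenize_fn(
    _user_data: *const c_void,
    bytes: *const u8,
    bytes_len: usize,
    output_tokens: *mut u32,
    output_tokens_len: usize,
) -> usize {
    if bytes_len == 0 {
        return 0;
    }
    let input = unsafe { std::slice::from_raw_parts(bytes, bytes_len) };
    // Pretend each input byte is its own token.
    for (i, b) in input.iter().take(output_tokens_len).enumerate() {
        unsafe { *output_tokens.add(i) = *b as u32 };
    }
    // Report the total number of tokens, even if not all of them fit.
    input.len()
}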

/// Callback type used by llg_par_compute_mask(): invoked with the caller's
/// user_data pointer when the batch is done.
pub type LlgCallback = Option<extern "C" fn(user_data: *const c_void)>;

#[repr(C)]
pub struct LlgTokenizerInit {
    /// The number of tokens in the vocabulary
    pub vocab_size: u32,

    /// The token ID for the end-of-sequence token
    pub tok_eos: LlgToken,

    /// The lengths of the token byte strings (vocab_size elements)
    pub token_lens: *const u32,

    /// The token byte strings, laid out back to back;
    /// the total length is the sum of token_lens
    pub token_bytes: *const u8,

    /// Alternative to token_lens/token_bytes: the contents of a HuggingFace
    /// tokenizer.json file; when non-null, token_lens and token_bytes are ignored
    pub tokenizer_json: *const c_char,

    /// Set to true if tokenize_fn only accepts valid UTF-8 strings;
    /// arbitrary byte sequences are then split into UTF-8 fragments,
    /// with greedy trie tokenization as the fallback
    pub tokenize_assumes_string: bool,

    /// Tokenization function, see LlgTokenizeFn docs;
    /// it has to be thread-safe!
    pub tokenize_fn: LlgTokenizeFn,

    /// Set to true to allow running without tokenize_fn, using (approximate)
    /// greedy tokenization over the token trie instead
    pub use_approximate_greedy_tokenize_fn: bool,

    /// Opaque pointer passed as the first argument to tokenize_fn
    pub tokenize_user_data: *const c_void,
}
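
// Hedged usage sketch: filling in an LlgTokenizerInit from a flat vocabulary
// table, with no tokenize callback and greedy tokenization allowed instead.
// The caller must keep `token_lens` and `token_bytes` alive until
// llg_new_tokenizer() has returned; the field values mirror the checks in
// LlgTokenizer::from_init above.
#[allow(dead_code)]
fn example_tokenizer_init(
    token_lens: &[u32],
    token_bytes: &[u8],
    tok_eos: LlgToken,
) -> LlgTokenizerInit {
    LlgTokenizerInit {
        vocab_size: token_lens.len() as u32,
        tok_eos,
        token_lens: token_lens.as_ptr(),
        token_bytes: token_bytes.as_ptr(),
        tokenizer_json: std::ptr::null(), // using the flat table, not tokenizer.json
        tokenize_assumes_string: false,
        tokenize_fn: None, // requires use_approximate_greedy_tokenize_fn
        use_approximate_greedy_tokenize_fn: true,
        tokenize_user_data: std::ptr::null(),
    }
}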

#[derive(Clone)]
#[repr(C)]
pub struct LlgConstraintInit {
    /// The tokenizer to use; created with llg_new_tokenizer()
    pub tokenizer: *const LlgTokenizer,
    /// The log level for the in-memory log buffer (0 - none, 1 - warnings, 2 - info)
    pub log_buffer_level: u32,
    /// The log level for messages written directly to stderr
    pub log_stderr_level: u32,
    /// Whether the constraint may return fast-forward (forced) tokens
    pub ff_tokens_ok: bool,
    /// Whether the constraint may request backtracking
    pub backtrack_ok: bool,
    /// Resource limits for the parser;
    /// defaults can be set with llg_constraint_init_set_defaults()
    pub limits: ParserLimits,
}

impl LlgConstraintInit {
    pub fn logger(&self) -> Logger {
        Logger::new(self.log_buffer_level, self.log_stderr_level)
    }

    pub fn inference_capabilities(&self) -> InferenceCapabilities {
        InferenceCapabilities {
            ff_tokens: self.ff_tokens_ok,
            backtrack: self.backtrack_ok,
            conditional_ff_tokens: false,
            fork: false,
        }
    }

    pub fn tok_env(&self) -> Result<TokEnv> {
        if self.tokenizer.is_null() {
            bail!("Tokenizer is null");
        }
        Ok(unsafe { (&*self.tokenizer).to_env() })
    }

    pub fn build_parser(
        &self,
        grammar: TopLevelGrammar,
        extra_lexemes: Vec<String>,
    ) -> Result<TokenParser> {
        TokenParser::from_llguidance_json(
            self.tok_env()?,
            grammar,
            self.logger(),
            self.inference_capabilities(),
            self.limits.clone(),
            extra_lexemes,
        )
    }

    pub fn build_parser_from_factory(
        &self,
        factory: &ParserFactory,
        grammar: TopLevelGrammar,
    ) -> Result<TokenParser> {
        let mut parser = self.build_parser(grammar, factory.extra_lexemes())?;
        factory.post_process_parser(&mut parser);
        Ok(parser)
    }

    pub fn build_constraint(&self, grammar: TopLevelGrammar) -> Result<Constraint> {
        let parser = self.build_parser(grammar, vec![])?;
        Ok(Constraint::new(parser))
    }
}
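
// Hedged sketch of the Rust-side call sequence: populate an LlgConstraintInit
// (with the same values llg_constraint_init_set_defaults() uses) and compile a
// regex constraint from it. The "[0-9]+" pattern is just a placeholder.
#[allow(dead_code)]
fn example_regex_constraint(tokenizer: *const LlgTokenizer) -> Result<Constraint> {
    let init = LlgConstraintInit {
        tokenizer,
        log_buffer_level: 0, // no in-memory log buffer
        log_stderr_level: 1, // warnings only
        ff_tokens_ok: false,
        backtrack_ok: false,
        limits: ParserLimits::default(),
    };
    let grammar = TopLevelGrammar::from_regex(RegexNode::Regex("[0-9]+".to_string()));
    init.build_constraint(grammar)
}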

/// One element of the batch passed to llg_par_compute_mask().
#[derive(Clone)]
#[repr(C)]
pub struct LlgConstraintStep {
    /// The constraint to compute the mask for
    pub constraint: *mut LlgConstraint,
    /// Where to write the mask (one bit per vocabulary token)
    pub mask_dest: *mut u32,
    /// The length of the mask_dest buffer in bytes (not u32 elements)
    pub mask_byte_len: usize,
}

unsafe impl Send for LlgConstraintStep {}

pub struct LlgConstraint {
    local_error: Option<String>,
    last_logs: String,
    pub(crate) constraint: Option<Constraint>,
    last_commit_result: CommitResult,
}

pub struct LlgStopController {
    stop_controller: StopController,
    last_result: String,
}

impl Clone for LlgConstraint {
    fn clone(&self) -> Self {
        LlgConstraint {
            local_error: self.local_error.clone(),
            last_logs: self.last_logs.clone(),
            constraint: self.constraint.clone(),
            last_commit_result: self.last_commit_result.clone(),
        }
    }
}

impl Default for LlgConstraint {
    fn default() -> Self {
        LlgConstraint {
            local_error: None,
            // Start with a NUL-terminated empty string, so llg_flush_logs()
            // always returns a valid C string.
            last_logs: "\x00".to_string(),
            constraint: None,
            last_commit_result: CommitResult::default(),
        }
    }
}

#[repr(C)]
pub struct LlgMaskResult {
    /// One bit per vocabulary token;
    /// valid until any further llg_*() call on the constraint
    pub sample_mask: *const u32,
    /// Temperature to use for sampling
    pub temperature: f32,
    /// Whether the sequence should stop
    pub is_stop: bool,
}

/// Represents the result of llg_commit_token()
#[repr(C)]
pub struct LlgCommitResult {
    /// The tokens to append to the sequence;
    /// valid until any further llg_*() call on the constraint
    pub tokens: *const u32,
    /// The number of tokens in the tokens array (can be 0)
    pub n_tokens: u32,
    /// Whether the sequence should stop
    pub is_stop: bool,
}

impl LlgCommitResult {
    pub fn from_commit_result(r: &CommitResult) -> Self {
        let len = r.ff_tokens.len() as u32;
        LlgCommitResult {
            tokens: if len == 0 {
                std::ptr::null()
            } else {
                r.ff_tokens.as_ptr()
            },
            n_tokens: len,
            is_stop: r.stop,
        }
    }
}

/// SAFETY: c_str must be a valid, NUL-terminated C string.
unsafe fn c_str_to_str<'a>(c_str: *const c_char, info: &str) -> Result<&'a str> {
    CStr::from_ptr(c_str)
        .to_str()
        .map_err(|_| anyhow::anyhow!("Invalid UTF-8 in {}", info))
}

fn new_constraint_regex(init: &LlgConstraintInit, regex: *const c_char) -> Result<Constraint> {
    let regex = unsafe { c_str_to_str(regex, "regex") }?;
    let grammar = TopLevelGrammar::from_regex(RegexNode::Regex(regex.to_string()));
    init.build_constraint(grammar)
}

fn new_constraint_lark(init: &LlgConstraintInit, lark: *const c_char) -> Result<Constraint> {
    let lark = unsafe { c_str_to_str(lark, "lark") }?;
    let grammar = lark_to_llguidance(lark)?;
    init.build_constraint(grammar)
}

fn new_constraint_json(init: &LlgConstraintInit, json_schema: *const c_char) -> Result<Constraint> {
    let json_schema = unsafe { c_str_to_str(json_schema, "json_schema") }?;
    let json_schema = serde_json::from_str(json_schema)
        .map_err(|e| anyhow::anyhow!("Invalid JSON in json_schema: {e}"))?;
    let opts = JsonCompileOptions::default();
    let grammar = opts
        .json_to_llg(json_schema)
        .map_err(|e| anyhow::anyhow!("Error compiling JSON schema to LLG: {e}"))?;
    init.build_constraint(grammar)
}

fn new_constraint(init: &LlgConstraintInit, grammar_json: *const c_char) -> Result<Constraint> {
    let grammar_json = unsafe { c_str_to_str(grammar_json, "grammar_json") }?;
    let grammar: TopLevelGrammar = serde_json::from_str(grammar_json)
        .map_err(|e| anyhow::anyhow!("Invalid JSON in grammar_json: {e}"))?;
    init.build_constraint(grammar)
}

fn new_constraint_any(
    init: &LlgConstraintInit,
    constraint_type: *const c_char,
    data: *const c_char,
) -> Result<Constraint> {
    let tp = unsafe { c_str_to_str(constraint_type, "constraint_type") }?;
    match tp {
        "regex" => new_constraint_regex(init, data),
        "json" | "json_schema" => new_constraint_json(init, data),
        "lark" => new_constraint_lark(init, data),
        "llguidance" | "guidance" => new_constraint(init, data),
        _ => bail!("unknown constraint type: {tp}"),
    }
}

impl LlgConstraint {
    fn get_error(&self) -> *const c_char {
        match &self.local_error {
            // set_error() appends a NUL, so this is a valid C string
            Some(e) => e.as_ptr() as *const c_char,
            None => std::ptr::null(),
        }
    }

    fn get_error_code(&self) -> i32 {
        if self.local_error.is_some() {
            -1
        } else {
            0
        }
    }

    pub(crate) fn set_error(&mut self, e: &str) {
        self.constraint = None;
        self.local_error = Some(format!("{e}\0"));
    }
}

/// Set the default values for the LlgConstraintInit:
/// ff_tokens and backtracking disabled, warnings logged to stderr,
/// and all other logging off.
#[no_mangle]
pub extern "C" fn llg_constraint_init_set_defaults(
    init: &mut LlgConstraintInit,
    tokenizer: *const LlgTokenizer,
) {
    *init = LlgConstraintInit {
        tokenizer,
        log_buffer_level: 0,
        log_stderr_level: 1,
        ff_tokens_ok: false,
        backtrack_ok: false,
        limits: ParserLimits::default(),
    };
}

pub fn constraint_to_llg(c: Result<Constraint>) -> *mut LlgConstraint {
    let mut res = LlgConstraint::default();

    match c {
        Ok(constraint) => res.constraint = Some(constraint),
        Err(e) => res.set_error(&e.to_string()),
    };

    Box::into_raw(Box::new(res))
}

/// Create a new constraint from a grammar JSON string.
/// Always returns a non-null value. Call llg_get_error() on the result to check for errors.
#[no_mangle]
pub extern "C" fn llg_new_constraint(
    init: &LlgConstraintInit,
    grammar_json: *const c_char,
) -> *mut LlgConstraint {
    constraint_to_llg(new_constraint(init, grammar_json))
}

/// Create a new constraint from a given regular expression.
/// Always returns a non-null value. Call llg_get_error() on the result to check for errors.
#[no_mangle]
pub extern "C" fn llg_new_constraint_regex(
    init: &LlgConstraintInit,
    regex: *const c_char,
) -> *mut LlgConstraint {
    constraint_to_llg(new_constraint_regex(init, regex))
}

/// Create a new constraint from a given JSON schema.
/// Always returns a non-null value. Call llg_get_error() on the result to check for errors.
#[no_mangle]
pub extern "C" fn llg_new_constraint_json(
    init: &LlgConstraintInit,
    json_schema: *const c_char,
) -> *mut LlgConstraint {
    constraint_to_llg(new_constraint_json(init, json_schema))
}

/// Create a new constraint from a given Lark grammar.
/// Always returns a non-null value. Call llg_get_error() on the result to check for errors.
#[no_mangle]
pub extern "C" fn llg_new_constraint_lark(
    init: &LlgConstraintInit,
    lark: *const c_char,
) -> *mut LlgConstraint {
    constraint_to_llg(new_constraint_lark(init, lark))
}

/// Create a new constraint of the specified type; the type can be one of
/// "regex", "json_schema" (or "json"), "lark", "llguidance" (or "guidance").
/// Always returns a non-null value. Call llg_get_error() on the result to check for errors.
#[no_mangle]
pub extern "C" fn llg_new_constraint_any(
    init: &LlgConstraintInit,
    constraint_type: *const c_char,
    data: *const c_char,
) -> *mut LlgConstraint {
    constraint_to_llg(new_constraint_any(init, constraint_type, data))
}

/// Get the error message from the constraint, or null if there is no error.
/// The pointer remains valid until the constraint is freed with llg_free_constraint().
#[no_mangle]
pub extern "C" fn llg_get_error(cc: &LlgConstraint) -> *const c_char {
    cc.get_error()
}

/// Get the current temperature of the constraint.
/// It is updated by mask computation.
#[no_mangle]
pub extern "C" fn llg_get_temperature(cc: &LlgConstraint) -> f32 {
    cc.constraint.as_ref().map_or(0.0, |c| c.temperature)
}

/// Check if the constraint is stopped (cannot be extended further).
#[no_mangle]
pub extern "C" fn llg_is_stopped(cc: &LlgConstraint) -> bool {
    cc.constraint
        .as_ref()
        .map_or(true, |c| c.step_result().is_stop())
}

/// Compute the mask for the next token sampling.
/// It typically takes up to a millisecond for a 100k tokenizer, so it should be called in the background.
/// Returns 0 on success and -1 on error (use llg_get_error() for the exact error).
/// When 0 is returned, the result is written to *res_p.
#[no_mangle]
pub extern "C" fn llg_compute_mask(cc: &mut LlgConstraint, res_p: &mut LlgMaskResult) -> i32 {
    if let Some(constraint) = &mut cc.constraint {
        match constraint.compute_mask() {
            Ok(r) => {
                let r = LlgMaskResult {
                    sample_mask: r
                        .sample_mask
                        .as_ref()
                        .map_or(std::ptr::null(), |m| m.as_ptr()),
                    is_stop: r.is_stop(),
                    temperature: constraint.temperature,
                };
                *res_p = r;
            }
            Err(e) => cc.set_error(&e.to_string()),
        }
    }
    cc.get_error_code()
}
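
// Hedged caller-side sketch: decoding the mask returned by llg_compute_mask().
// This assumes toktrie's bit layout, where token `t` is allowed iff bit
// `t % 32` of word `t / 32` is set; `n_vocab` is assumed to come from the
// tokenizer (e.g. tok_trie().vocab_size()).
#[allow(dead_code)]
fn example_allowed_tokens(mask: &LlgMaskResult, n_vocab: usize) -> Vec<u32> {
    if mask.sample_mask.is_null() {
        return vec![];
    }
    let words = unsafe { std::slice::from_raw_parts(mask.sample_mask, (n_vocab + 31) / 32) };
    (0..n_vocab)
        .filter(|&t| words[t / 32] & (1u32 << (t % 32)) != 0)
        .map(|t| t as u32)
        .collect()
}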

/// Commit the token sampled under the mask returned by llg_compute_mask().
/// This is fast and can run on the critical path of sampling.
/// Returns 0 on success and -1 on error (use llg_get_error() for the exact error).
/// When 0 is returned, the result is written to *res_p.
#[no_mangle]
pub extern "C" fn llg_commit_token(
    cc: &mut LlgConstraint,
    token: LlgToken,
    res_p: &mut LlgCommitResult,
) -> i32 {
    if let Some(constraint) = &mut cc.constraint {
        let trie = constraint.parser.token_env.tok_trie();
        // Out-of-range token IDs are passed through as None.
        let token = if token < trie.vocab_size() as LlgToken {
            Some(token)
        } else {
            None
        };
        match constraint.commit_token(token) {
            Ok(r) => {
                // Store the result in the constraint, so the pointers in
                // res_p stay valid until the next call.
                cc.last_commit_result = r;
                let res = LlgCommitResult::from_commit_result(&cc.last_commit_result);
                *res_p = res;
            }
            Err(e) => cc.set_error(&e.to_string()),
        }
    }
    cc.get_error_code()
}
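
// Hedged sketch of one decode step an inference engine would run: compute the
// mask, sample under it (sampling elided; `sample_from_mask` is a hypothetical
// stand-in for the engine's sampler), then commit the sampled token.
#[allow(dead_code)]
fn example_decode_step(
    cc: &mut LlgConstraint,
    sample_from_mask: impl Fn(&LlgMaskResult) -> LlgToken,
) -> i32 {
    let mut mask = LlgMaskResult {
        sample_mask: std::ptr::null(),
        temperature: 0.0,
        is_stop: false,
    };
    let err = llg_compute_mask(cc, &mut mask);
    if err != 0 || mask.is_stop {
        return err;
    }
    let token = sample_from_mask(&mask);
    let mut commit = LlgCommitResult {
        tokens: std::ptr::null(),
        n_tokens: 0,
        is_stop: false,
    };
    // On success, commit.tokens/n_tokens list any forced (fast-forward)
    // tokens to append after `token`.
    llg_commit_token(cc, token, &mut commit)
}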

/// Compute masks for several constraints in parallel.
/// Requires the "rayon" feature; panics otherwise.
#[no_mangle]
pub extern "C" fn llg_par_compute_mask(
    steps: *const LlgConstraintStep,
    n_steps: usize,
    user_data: *const c_void,
    done_cb: LlgCallback,
) {
    if steps.is_null() {
        panic!("llg_par_compute_mask: steps is null");
    }

    #[cfg(feature = "rayon")]
    {
        let steps = unsafe { std::slice::from_raw_parts(steps, n_steps).to_vec() };
        crate::ffi_par::par_compute_mask(steps, user_data, done_cb);
    }

    #[cfg(not(feature = "rayon"))]
    {
        let _ = (steps, n_steps, user_data, done_cb);
        panic!("llg_par_compute_mask: rayon feature is not enabled");
    }
}

/// Clone the constraint.
#[no_mangle]
pub extern "C" fn llg_clone_constraint(cc: &LlgConstraint) -> *mut LlgConstraint {
    Box::into_raw(Box::new(cc.clone()))
}

/// Construct a new tokenizer from the given LlgTokenizerInit.
/// Returns NULL on error, writing the message to error_string.
#[no_mangle]
pub extern "C" fn llg_new_tokenizer(
    tok_init: &LlgTokenizerInit,
    error_string: *mut c_char,
    error_string_len: usize,
) -> *mut LlgTokenizer {
    match LlgTokenizer::from_init(tok_init) {
        Ok(tok) => Box::into_raw(Box::new(tok)),
        Err(e) => {
            save_error_string(e, error_string, error_string_len);
            std::ptr::null_mut()
        }
    }
}

/// Clone a tokenizer.
/// This increments a reference count and does a small allocation.
#[no_mangle]
pub extern "C" fn llg_clone_tokenizer(tok: &LlgTokenizer) -> *mut LlgTokenizer {
    Box::into_raw(Box::new(LlgTokenizer {
        token_env: tok.token_env.clone(),
    }))
}

/// Tokenize the given bytes and return the tokens.
/// Always returns the number of tokens that would be written to output_tokens
/// if output_tokens_len was large enough.
#[no_mangle]
pub extern "C" fn llg_tokenize_bytes(
    tok: &LlgTokenizer,
    bytes: *const u8,
    bytes_len: usize,
    output_tokens: *mut u32,
    output_tokens_len: usize,
) -> usize {
    let tokens = tok
        .token_env
        .tokenize_bytes(unsafe { std::slice::from_raw_parts(bytes, bytes_len) });
    let n_toks = tokens.len();
    let to_copy = std::cmp::min(n_toks, output_tokens_len);
    unsafe {
        std::ptr::copy_nonoverlapping(tokens.as_ptr(), output_tokens, to_copy);
    }
    n_toks
}
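
// Hedged caller-side sketch of the two-call convention: if the first call
// reports more tokens than fit, retry with an exactly-sized buffer.
#[allow(dead_code)]
fn example_tokenize_all(tok: &LlgTokenizer, bytes: &[u8]) -> Vec<u32> {
    let mut out = vec![0u32; bytes.len() / 4 + 16];
    let n = llg_tokenize_bytes(tok, bytes.as_ptr(), bytes.len(), out.as_mut_ptr(), out.len());
    if n > out.len() {
        out.resize(n, 0);
        llg_tokenize_bytes(tok, bytes.as_ptr(), bytes.len(), out.as_mut_ptr(), out.len());
    }
    out.truncate(n);
    out
}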

/// Tokenize the given bytes and return the tokens.
/// Special tokens are tokenized when prefixed with the 0xFF marker byte.
/// Always returns the number of tokens that would be written to output_tokens
/// if output_tokens_len was large enough.
#[no_mangle]
pub extern "C" fn llg_tokenize_bytes_marker(
    tok: &LlgTokenizer,
    bytes: *const u8,
    bytes_len: usize,
    output_tokens: *mut u32,
    output_tokens_len: usize,
) -> usize {
    let tokens = tok
        .token_env
        .tokenize_bytes_marker(unsafe { std::slice::from_raw_parts(bytes, bytes_len) })
        .0;
    let n_toks = tokens.len();
    let to_copy = std::cmp::min(n_toks, output_tokens_len);
    unsafe {
        std::ptr::copy_nonoverlapping(tokens.as_ptr(), output_tokens, to_copy);
    }
    n_toks
}

/// Return a string representation of the tokens, useful for debugging.
/// The output is NUL-terminated.
/// Returns the number of bytes that would be written to output if output_len was large enough.
#[no_mangle]
pub extern "C" fn llg_stringify_tokens(
    tok: &LlgTokenizer,
    tokens: *const u32,
    n_tokens: usize,
    output: *mut c_char,
    output_len: usize,
) -> usize {
    let trie = tok.token_env.tok_trie();
    let tokens = unsafe { std::slice::from_raw_parts(tokens, n_tokens) };
    let s = trie.tokens_dbg(tokens);
    let s = s.as_bytes();
    // Guard against output_len == 0 (the subtraction previously underflowed);
    // leave room for the trailing NUL when copying.
    let len = std::cmp::min(s.len(), output_len.saturating_sub(1));
    unsafe {
        std::ptr::copy_nonoverlapping(s.as_ptr(), output as *mut u8, len);
        if output_len > 0 {
            *output.add(len) = 0;
        }
    }
    s.len() + 1
}

/// Free the tokenizer. Should *NOT* be called while there are still constraints using it.
#[no_mangle]
pub extern "C" fn llg_free_tokenizer(tok: *mut LlgTokenizer) {
    unsafe {
        drop(Box::from_raw(tok));
    }
}

/// Free the constraint.
#[no_mangle]
pub extern "C" fn llg_free_constraint(cc: *mut LlgConstraint) {
    unsafe {
        drop(Box::from_raw(cc));
    }
}

/// Get the logs from the constraint since the last call to this function.
/// The logs are NUL-terminated and remain valid until the next call
/// to this function (or until the constraint is freed).
#[no_mangle]
pub extern "C" fn llg_flush_logs(cc: &mut LlgConstraint) -> *const c_char {
    if let Some(constraint) = &mut cc.constraint {
        let s = constraint.flush_logs();
        // Escape any embedded NULs so the result is a valid C string.
        if s.contains('\0') {
            cc.last_logs = s.replace('\0', "\\0");
        } else {
            cc.last_logs = s;
        }
        cc.last_logs.push('\0');
    }
    cc.last_logs.as_ptr() as *const c_char
}

fn build_stop_controller(
    tokenizer: &LlgTokenizer,
    stop_tokens: &[u32],
    stop_rx: *const c_char,
) -> Result<StopController> {
    let stop_rx = if stop_rx.is_null() {
        None
    } else {
        Some(unsafe { c_str_to_str(stop_rx, "stop_rx") }?.to_string())
    };
    StopController::new(
        tokenizer.token_env.clone(),
        stop_tokens.to_vec(),
        stop_rx,
        vec![],
    )
}

fn save_error_string(e: impl Display, error_string: *mut c_char, error_string_len: usize) {
    if error_string_len > 0 {
        let e = e.to_string();
        let e = e.as_bytes();
        let len = std::cmp::min(e.len(), error_string_len - 1);
        unsafe {
            std::ptr::copy_nonoverlapping(e.as_ptr(), error_string as *mut u8, len);
            *error_string.add(len) = 0;
        }
    }
}

/// Create a new stop-sequence controller. `stop_rx` can be NULL if unused.
/// Returns NULL on error, writing the message to error_string.
#[no_mangle]
pub extern "C" fn llg_new_stop_controller(
    tokenizer: &LlgTokenizer,
    stop_tokens: *const u32,
    stop_tokens_len: usize,
    stop_rx: *const c_char,
    error_string: *mut c_char,
    error_string_len: usize,
) -> *mut LlgStopController {
    let stop_tokens = unsafe { std::slice::from_raw_parts(stop_tokens, stop_tokens_len) };
    match build_stop_controller(tokenizer, stop_tokens, stop_rx) {
        Ok(stop_controller) => Box::into_raw(Box::new(LlgStopController {
            stop_controller,
            last_result: String::new(),
        })),
        Err(e) => {
            save_error_string(e, error_string, error_string_len);
            std::ptr::null_mut()
        }
    }
}

/// Commit a token to the stop controller and return the new textual output.
/// The returned string is NUL-terminated; its byte length (excluding the NUL)
/// is written to *output_len_p, and *is_stopped_p is set once a stop token or
/// the stop regex has been hit. The string is valid until the next call
/// to this function (or until the stop controller is freed).
#[no_mangle]
pub extern "C" fn llg_stop_commit_token(
    stop_ctrl: &mut LlgStopController,
    token: u32,
    output_len_p: &mut usize,
    is_stopped_p: &mut bool,
) -> *const c_char {
    let r = stop_ctrl.stop_controller.commit_token(token);
    *output_len_p = r.len();
    *is_stopped_p = stop_ctrl.stop_controller.is_stopped();
    stop_ctrl.last_result = format!("{r}\0");
    stop_ctrl.last_result.as_ptr() as *const c_char
}

/// Free the stop controller.
#[no_mangle]
pub extern "C" fn llg_free_stop_controller(stop_ctrl: *mut LlgStopController) {
    unsafe {
        drop(Box::from_raw(stop_ctrl));
    }
}
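
// Hedged usage sketch for the stop-controller API: feed sampled tokens one at
// a time and accumulate the printable output until a stop condition fires. The
// token slice is a placeholder for a real sampling loop.
#[allow(dead_code)]
fn example_stop_loop(stop_ctrl: &mut LlgStopController, tokens: &[u32]) -> String {
    let mut text = String::new();
    for &t in tokens {
        let mut out_len = 0;
        let mut stopped = false;
        let p = llg_stop_commit_token(stop_ctrl, t, &mut out_len, &mut stopped);
        // The returned pointer covers out_len bytes of valid UTF-8 output.
        let bytes = unsafe { std::slice::from_raw_parts(p as *const u8, out_len) };
        text.push_str(std::str::from_utf8(bytes).unwrap_or(""));
        if stopped {
            break;
        }
    }
    text
}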