use std::{
ffi::{c_char, c_void, CStr},
fmt::Display,
sync::Arc,
};
use anyhow::{bail, ensure, Result};
use toktrie::{InferenceCapabilities, TokEnv, TokRxInfo, TokTrie, TokenizerEnv};
use crate::{
api::{ParserLimits, RegexNode, TopLevelGrammar},
lark_to_llguidance, CommitResult, Constraint, JsonCompileOptions, Logger, ParserFactory,
StopController, TokenParser,
};
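/// Tokenizer backend shared behind a `TokEnv`: a token trie plus an optional
/// C tokenize callback and its user-data pointer.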
struct CTokenizerInner {
trie: TokTrie,
tokenize_fn: LlgTokenizeFn,
tokenize_user_data: *const c_void,
tokenize_assumes_string: bool,
}
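// SAFETY: the trie and flags are plain data; the C callback and its user_data
// pointer must, per the FFI contract, be safe to invoke from any thread.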
unsafe impl Send for CTokenizerInner {}
unsafe impl Sync for CTokenizerInner {}
impl CTokenizerInner {
fn raw_tokenize(&self, s: &[u8]) -> Vec<toktrie::TokenId> {
if let Some(tokenize_fn) = self.tokenize_fn {
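            // Initial size guess; the callback returns the total token count
            // even when it exceeds the buffer we handed it.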
let mut res_toks = vec![0; s.len() / 4 + 5];
let n_toks = tokenize_fn(
self.tokenize_user_data,
s.as_ptr(),
s.len(),
res_toks.as_mut_ptr(),
res_toks.len(),
);
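            // The callback needed more room: grow the buffer to the reported
            // size and tokenize again.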
if n_toks > res_toks.len() {
res_toks.resize(n_toks, 0);
tokenize_fn(
self.tokenize_user_data,
s.as_ptr(),
s.len(),
res_toks.as_mut_ptr(),
res_toks.len(),
);
}
res_toks.truncate(n_toks);
res_toks
} else {
self.trie.greedy_tokenize(s)
}
}
}
impl TokenizerEnv for CTokenizerInner {
fn tok_trie(&self) -> &TokTrie {
&self.trie
}
fn tokenize_bytes(&self, s: &[u8]) -> Vec<toktrie::TokenId> {
if self.tokenize_assumes_string {
self.trie
.tokenize_with_greedy_fallback(s, |s| self.raw_tokenize(s.as_bytes()))
} else {
self.raw_tokenize(s)
}
}
fn tokenize_is_canonical(&self) -> bool {
self.tokenize_fn.is_some()
}
}
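/// C-visible tokenizer handle: a cheaply clonable wrapper around a shared `TokEnv`.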
#[derive(Clone)]
pub struct LlgTokenizer {
pub token_env: TokEnv,
}
impl LlgTokenizer {
fn from_init(init: &LlgTokenizerInit) -> Result<Self> {
ensure!(
init.tokenize_fn.is_some() || init.use_approximate_greedy_tokenize_fn,
"Either tokenize_fn or use_approximate_greedy_tokenize_fn must be set"
);
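        // The vocabulary comes either from explicit token_lens/token_bytes
        // arrays or from a tokenizer JSON blob.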
let tokens = if init.tokenizer_json.is_null() {
ensure!(
!init.token_lens.is_null() && !init.token_bytes.is_null(),
"token_lens and token_bytes must be set"
);
let token_lens =
unsafe { std::slice::from_raw_parts(init.token_lens, init.vocab_size as usize) };
let total_len = token_lens.iter().sum::<u32>();
let token_bytes =
unsafe { std::slice::from_raw_parts(init.token_bytes, total_len as usize) };
let mut tokens = vec![];
let mut ptr = 0;
for len in token_lens {
let token = &token_bytes[ptr..ptr + *len as usize];
tokens.push(token.to_vec());
ptr += *len as usize;
}
tokens
} else {
let tokenizer_json = unsafe { c_str_to_str(init.tokenizer_json, "tokenizer_json") }?;
let tokenizer_json = serde_json::from_str(tokenizer_json)
.map_err(|e| anyhow::anyhow!("Invalid JSON in tokenizer_json: {e}"))?;
let mut token_bytes =
crate::tokenizer_json::token_bytes_from_tokenizer_json(&tokenizer_json)?;
let sz = init.vocab_size as usize;
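            // Pad with empty tokens if the JSON defines fewer than vocab_size entries.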
if token_bytes.len() < sz {
token_bytes.resize(sz, vec![]);
}
token_bytes
};
let trie = TokTrie::from(&TokRxInfo::new(tokens.len() as u32, init.tok_eos), &tokens);
Ok(LlgTokenizer {
token_env: Arc::new(CTokenizerInner {
trie,
tokenize_assumes_string: init.tokenize_assumes_string && init.tokenize_fn.is_some(),
tokenize_fn: init.tokenize_fn,
tokenize_user_data: init.tokenize_user_data,
}),
})
}
fn to_env(&self) -> TokEnv {
self.token_env.clone()
}
}
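/// Token identifier as used throughout the C API.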
pub type LlgToken = u32;
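/// Tokenization callback: writes up to `output_tokens_len` tokens into
/// `output_tokens` and returns the total number of tokens in the input, which
/// may exceed `output_tokens_len` (the caller then retries with a larger buffer).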
pub type LlgTokenizeFn = Option<
extern "C" fn(
user_data: *const c_void,
bytes: *const u8,
bytes_len: usize,
output_tokens: *mut u32,
output_tokens_len: usize,
) -> usize,
>;
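/// Completion callback invoked with the caller-supplied `user_data`.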
pub type LlgCallback = Option<extern "C" fn(user_data: *const c_void)>;
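/// Initialization data for `llg_new_tokenizer()`; provide either
/// `tokenizer_json` or the `token_lens`/`token_bytes` pair.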
#[repr(C)]
pub struct LlgTokenizerInit {
pub vocab_size: u32,
pub tok_eos: LlgToken,
pub token_lens: *const u32,
pub token_bytes: *const u8,
pub tokenizer_json: *const c_char,
pub tokenize_assumes_string: bool,
pub tokenize_fn: LlgTokenizeFn,
pub use_approximate_greedy_tokenize_fn: bool,
pub tokenize_user_data: *const c_void,
}
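/// Initialization data for the constraint constructors; populate with
/// `llg_constraint_init_set_defaults()` and tweak as needed.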
#[derive(Clone)]
#[repr(C)]
pub struct LlgConstraintInit {
pub tokenizer: *const LlgTokenizer,
pub log_buffer_level: u32,
pub log_stderr_level: u32,
pub ff_tokens_ok: bool,
pub backtrack_ok: bool,
pub limits: ParserLimits,
}
impl LlgConstraintInit {
pub fn logger(&self) -> Logger {
Logger::new(self.log_buffer_level, self.log_stderr_level)
}
pub fn inference_capabilities(&self) -> InferenceCapabilities {
InferenceCapabilities {
ff_tokens: self.ff_tokens_ok,
backtrack: self.backtrack_ok,
conditional_ff_tokens: false,
fork: false,
}
}
pub fn tok_env(&self) -> Result<TokEnv> {
if self.tokenizer.is_null() {
bail!("Tokenizer is null");
}
Ok(unsafe { (&*self.tokenizer).to_env() })
}
pub fn build_parser(
&self,
grammar: TopLevelGrammar,
extra_lexemes: Vec<String>,
) -> Result<TokenParser> {
TokenParser::from_llguidance_json(
self.tok_env()?,
grammar,
self.logger(),
self.inference_capabilities(),
self.limits.clone(),
extra_lexemes,
)
}
pub fn build_parser_from_factory(
&self,
factory: &ParserFactory,
grammar: TopLevelGrammar,
) -> Result<TokenParser> {
let mut parser = self.build_parser(grammar, factory.extra_lexemes())?;
factory.post_process_parser(&mut parser);
Ok(parser)
}
pub fn build_constraint(&self, grammar: TopLevelGrammar) -> Result<Constraint> {
let parser = self.build_parser(grammar, vec![])?;
Ok(Constraint::new(parser))
}
}
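/// One entry in a parallel `llg_par_compute_mask()` request.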
#[derive(Clone)]
#[repr(C)]
pub struct LlgConstraintStep {
pub constraint: *mut LlgConstraint,
pub mask_dest: *mut u32,
pub mask_byte_len: usize,
}
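// SAFETY: the constraint and mask pointers are only used by the worker that
// processes this step; the caller must keep them valid until the done callback runs.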
unsafe impl Send for LlgConstraintStep {}
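/// C-visible constraint handle. Owns the parser plus the most recent error,
/// logs, and commit result so the API can return stable pointers into them.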
pub struct LlgConstraint {
local_error: Option<String>,
last_logs: String,
pub(crate) constraint: Option<Constraint>,
last_commit_result: CommitResult,
}
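/// C-visible stop-controller handle: watches committed tokens for stop tokens
/// or a stop regex, keeping the most recent decoded output.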
pub struct LlgStopController {
stop_controller: StopController,
last_result: String,
}
impl Clone for LlgConstraint {
fn clone(&self) -> Self {
LlgConstraint {
local_error: self.local_error.clone(),
last_logs: self.last_logs.clone(),
constraint: self.constraint.clone(),
last_commit_result: self.last_commit_result.clone(),
}
}
}
impl Default for LlgConstraint {
fn default() -> Self {
LlgConstraint {
local_error: None,
last_logs: "\x00".to_string(),
constraint: None,
last_commit_result: CommitResult::default(),
}
}
}
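/// Result of `llg_compute_mask()`: a pointer to the token bitmask (null if no
/// mask was produced), the sampling temperature, and a stop flag.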
#[repr(C)]
pub struct LlgMaskResult {
pub sample_mask: *const u32,
pub temperature: f32,
pub is_stop: bool,
}
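/// Result of `llg_commit_token()`: any forced (fast-forward) tokens and a stop flag.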
#[repr(C)]
pub struct LlgCommitResult {
pub tokens: *const u32,
pub n_tokens: u32,
pub is_stop: bool,
}
impl LlgCommitResult {
pub fn from_commit_result(r: &CommitResult) -> Self {
let len = r.ff_tokens.len() as u32;
LlgCommitResult {
tokens: if len == 0 {
std::ptr::null()
} else {
r.ff_tokens.as_ptr()
},
n_tokens: len,
is_stop: r.stop,
}
}
}
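/// # Safety
/// `c_str` must point to a valid NUL-terminated string; `info` only labels the
/// error message.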
unsafe fn c_str_to_str<'a>(c_str: *const c_char, info: &str) -> Result<&'a str> {
CStr::from_ptr(c_str)
.to_str()
.map_err(|_| anyhow::anyhow!("Invalid UTF-8 in {}", info))
}
fn new_constraint_regex(init: &LlgConstraintInit, regex: *const c_char) -> Result<Constraint> {
let regex = unsafe { c_str_to_str(regex, "regex") }?;
let grammar = TopLevelGrammar::from_regex(RegexNode::Regex(regex.to_string()));
init.build_constraint(grammar)
}
fn new_constraint_lark(init: &LlgConstraintInit, lark: *const c_char) -> Result<Constraint> {
let lark = unsafe { c_str_to_str(lark, "lark") }?;
let grammar = lark_to_llguidance(lark)?;
init.build_constraint(grammar)
}
fn new_constraint_json(init: &LlgConstraintInit, json_schema: *const c_char) -> Result<Constraint> {
let json_schema = unsafe { c_str_to_str(json_schema, "json_schema") }?;
let json_schema = serde_json::from_str(json_schema)
.map_err(|e| anyhow::anyhow!("Invalid JSON in json_schema: {e}"))?;
let opts = JsonCompileOptions::default();
let grammar = opts
.json_to_llg(json_schema)
.map_err(|e| anyhow::anyhow!("Error compiling JSON schema to LLG: {e}"))?;
init.build_constraint(grammar)
}
fn new_constraint(init: &LlgConstraintInit, grammar_json: *const c_char) -> Result<Constraint> {
let grammar_json = unsafe { c_str_to_str(grammar_json, "grammar_json") }?;
let grammar: TopLevelGrammar = serde_json::from_str(grammar_json)
.map_err(|e| anyhow::anyhow!("Invalid JSON in grammar_json: {e}"))?;
init.build_constraint(grammar)
}
fn new_constraint_any(
init: &LlgConstraintInit,
constraint_type: *const c_char,
data: *const c_char,
) -> Result<Constraint> {
let tp = unsafe { c_str_to_str(constraint_type, "constraint_type") }?;
match tp {
"regex" => new_constraint_regex(init, data),
"json" | "json_schema" => new_constraint_json(init, data),
"lark" => new_constraint_lark(init, data),
"llguidance" | "guidance" => new_constraint(init, data),
_ => bail!("unknown constraint type: {tp}"),
}
}
impl LlgConstraint {
fn get_error(&self) -> *const c_char {
match &self.local_error {
Some(e) => e.as_ptr() as *const c_char,
None => std::ptr::null(),
}
}
fn get_error_code(&self) -> i32 {
if self.local_error.is_some() {
-1
} else {
0
}
}
pub(crate) fn set_error(&mut self, e: &str) {
self.constraint = None;
self.local_error = Some(format!("{e}\0"));
}
}
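/// Resets `init` to the default settings, storing the given tokenizer pointer
/// (it must remain valid for as long as the init struct is used).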
#[no_mangle]
pub extern "C" fn llg_constraint_init_set_defaults(
init: &mut LlgConstraintInit,
tokenizer: *const LlgTokenizer,
) {
*init = LlgConstraintInit {
tokenizer,
log_buffer_level: 0,
log_stderr_level: 1,
ff_tokens_ok: false,
backtrack_ok: false,
limits: ParserLimits::default(),
};
}
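/// Boxes a constraint-construction result, converting an error into the stored
/// error string retrievable via `llg_get_error()`.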
pub fn constraint_to_llg(c: Result<Constraint>) -> *mut LlgConstraint {
let mut res = LlgConstraint::default();
match c {
Ok(constraint) => res.constraint = Some(constraint),
Err(e) => res.set_error(&e.to_string()),
};
Box::into_raw(Box::new(res))
}
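/// Creates a constraint from a JSON-serialized `TopLevelGrammar`. Always
/// returns a non-null pointer; check `llg_get_error()` on it before use.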
#[no_mangle]
pub extern "C" fn llg_new_constraint(
init: &LlgConstraintInit,
grammar_json: *const c_char,
) -> *mut LlgConstraint {
constraint_to_llg(new_constraint(init, grammar_json))
}
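/// Creates a constraint that restricts output to the given regular expression.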
#[no_mangle]
pub extern "C" fn llg_new_constraint_regex(
init: &LlgConstraintInit,
regex: *const c_char,
) -> *mut LlgConstraint {
constraint_to_llg(new_constraint_regex(init, regex))
}
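/// Creates a constraint from a JSON schema (compiled via `JsonCompileOptions`).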
#[no_mangle]
pub extern "C" fn llg_new_constraint_json(
init: &LlgConstraintInit,
json_schema: *const c_char,
) -> *mut LlgConstraint {
constraint_to_llg(new_constraint_json(init, json_schema))
}
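/// Creates a constraint from a Lark grammar.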
#[no_mangle]
pub extern "C" fn llg_new_constraint_lark(
init: &LlgConstraintInit,
lark: *const c_char,
) -> *mut LlgConstraint {
constraint_to_llg(new_constraint_lark(init, lark))
}
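/// Dispatches on `constraint_type`: "regex", "json"/"json_schema", "lark", or
/// "llguidance"/"guidance" (grammar JSON).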
#[no_mangle]
pub extern "C" fn llg_new_constraint_any(
init: &LlgConstraintInit,
constraint_type: *const c_char,
data: *const c_char,
) -> *mut LlgConstraint {
constraint_to_llg(new_constraint_any(init, constraint_type, data))
}
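/// Returns the constraint's error message as a NUL-terminated string, or null
/// if there is none; the pointer stays valid until the error changes or the
/// constraint is freed.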
#[no_mangle]
pub extern "C" fn llg_get_error(cc: &LlgConstraint) -> *const c_char {
cc.get_error()
}
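/// Returns the current sampling temperature (0.0 if the constraint errored).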
#[no_mangle]
pub extern "C" fn llg_get_temperature(cc: &LlgConstraint) -> f32 {
cc.constraint.as_ref().map_or(0.0, |c| c.temperature)
}
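/// Returns true if the constraint has stopped (or is in an error state).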
#[no_mangle]
pub extern "C" fn llg_is_stopped(cc: &LlgConstraint) -> bool {
cc.constraint
.as_ref()
.map_or(true, |c| c.step_result().is_stop())
}
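/// Computes the set of tokens allowed by the constraint at the current state
/// and writes it to `res_p`. Returns 0 on success and -1 on error; use
/// `llg_get_error()` for the message. A minimal C-side decode loop might look
/// like this (a sketch only; `sample_with_mask()` stands in for the caller's
/// masked sampling):
///
/// ```c
/// LlgMaskResult mask;
/// while (llg_compute_mask(c, &mask) == 0 && !mask.is_stop) {
///     uint32_t tok = sample_with_mask(logits, mask.sample_mask);
///     LlgCommitResult commit;
///     if (llg_commit_token(c, tok, &commit) != 0)
///         break;
/// }
/// ```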
#[no_mangle]
pub extern "C" fn llg_compute_mask(cc: &mut LlgConstraint, res_p: &mut LlgMaskResult) -> i32 {
if let Some(constraint) = &mut cc.constraint {
match constraint.compute_mask() {
Ok(r) => {
let r = LlgMaskResult {
sample_mask: r
.sample_mask
.as_ref()
.map_or(std::ptr::null(), |m| m.as_ptr()),
is_stop: r.is_stop(),
temperature: constraint.temperature,
};
*res_p = r;
}
Err(e) => cc.set_error(&e.to_string()),
}
}
cc.get_error_code()
}
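/// Commits a token and writes any resulting fast-forward tokens to `res_p`;
/// a token id outside the vocabulary is passed to the parser as "no token".
/// Returns 0 on success and -1 on error.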
#[no_mangle]
pub extern "C" fn llg_commit_token(
cc: &mut LlgConstraint,
token: LlgToken,
res_p: &mut LlgCommitResult,
) -> i32 {
if let Some(constraint) = &mut cc.constraint {
let trie = constraint.parser.token_env.tok_trie();
let token = if token < trie.vocab_size() as LlgToken {
Some(token)
} else {
None
};
match constraint.commit_token(token) {
Ok(r) => {
cc.last_commit_result = r;
let res = LlgCommitResult::from_commit_result(&cc.last_commit_result);
*res_p = res;
}
Err(e) => cc.set_error(&e.to_string()),
}
}
cc.get_error_code()
}
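/// Computes masks for several constraints in parallel (requires the `rayon`
/// feature) and invokes `done_cb` with `user_data` once all steps complete.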
#[no_mangle]
pub extern "C" fn llg_par_compute_mask(
steps: *const LlgConstraintStep,
n_steps: usize,
user_data: *const c_void,
done_cb: LlgCallback,
) {
if steps.is_null() {
panic!("llg_par_compute_mask: steps is null");
}
#[cfg(feature = "rayon")]
{
let steps = unsafe { std::slice::from_raw_parts(steps, n_steps).to_vec() };
crate::ffi_par::par_compute_mask(steps, user_data, done_cb);
}
#[cfg(not(feature = "rayon"))]
{
let _ = (steps, n_steps, user_data, done_cb);
panic!("llg_par_compute_mask: rayon feature is not enabled");
}
}
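/// Clones a constraint, including its current parser state.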
#[no_mangle]
pub extern "C" fn llg_clone_constraint(cc: &LlgConstraint) -> *mut LlgConstraint {
Box::into_raw(Box::new(cc.clone()))
}
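/// Builds a tokenizer from `tok_init`. Returns null on failure, copying the
/// message into `error_string` (which must point to `error_string_len` writable
/// bytes).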
#[no_mangle]
pub extern "C" fn llg_new_tokenizer(
tok_init: &LlgTokenizerInit,
error_string: *mut c_char,
error_string_len: usize,
) -> *mut LlgTokenizer {
match LlgTokenizer::from_init(tok_init) {
Ok(tok) => Box::into_raw(Box::new(tok)),
Err(e) => {
save_error_string(e, error_string, error_string_len);
std::ptr::null_mut()
}
}
}
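/// Clones a tokenizer handle; the underlying tokenizer data is shared, not copied.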
#[no_mangle]
pub extern "C" fn llg_clone_tokenizer(tok: &LlgTokenizer) -> *mut LlgTokenizer {
Box::into_raw(Box::new(LlgTokenizer {
token_env: tok.token_env.clone(),
}))
}
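/// Tokenizes `bytes`, writing at most `output_tokens_len` tokens to
/// `output_tokens`. Returns the total token count, which may exceed
/// `output_tokens_len`; retry with a larger buffer to get them all.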
#[no_mangle]
pub extern "C" fn llg_tokenize_bytes(
tok: &LlgTokenizer,
bytes: *const u8,
bytes_len: usize,
output_tokens: *mut u32,
output_tokens_len: usize,
) -> usize {
let tokens = tok
.token_env
.tokenize_bytes(unsafe { std::slice::from_raw_parts(bytes, bytes_len) });
let n_toks = tokens.len();
let to_copy = std::cmp::min(n_toks, output_tokens_len);
unsafe {
std::ptr::copy_nonoverlapping(tokens.as_ptr(), output_tokens, to_copy);
}
n_toks
}
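/// Variant of `llg_tokenize_bytes()` that tokenizes via
/// `tokenize_bytes_marker()`; the buffer contract is the same.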
#[no_mangle]
pub extern "C" fn llg_tokenize_bytes_marker(
tok: &LlgTokenizer,
bytes: *const u8,
bytes_len: usize,
output_tokens: *mut u32,
output_tokens_len: usize,
) -> usize {
let tokens = tok
.token_env
.tokenize_bytes_marker(unsafe { std::slice::from_raw_parts(bytes, bytes_len) })
.0;
let n_toks = tokens.len();
let to_copy = std::cmp::min(n_toks, output_tokens_len);
unsafe {
std::ptr::copy_nonoverlapping(tokens.as_ptr(), output_tokens, to_copy);
}
n_toks
}
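/// Writes a debug representation of `tokens` into `output` as a NUL-terminated
/// string, truncating as needed. Returns the byte length required for the full
/// string, including the terminating NUL.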
#[no_mangle]
pub extern "C" fn llg_stringify_tokens(
tok: &LlgTokenizer,
tokens: *const u32,
n_tokens: usize,
output: *mut c_char,
output_len: usize,
) -> usize {
let trie = tok.token_env.tok_trie();
let tokens = unsafe { std::slice::from_raw_parts(tokens, n_tokens) };
let s = trie.tokens_dbg(tokens);
let s = s.as_bytes();
    if output_len == 0 {
        // No room even for the NUL terminator; just report the required size.
        return s.len() + 1;
    }
    let len = std::cmp::min(s.len(), output_len - 1);
unsafe {
std::ptr::copy_nonoverlapping(s.as_ptr(), output as *mut u8, len);
*output.add(len) = 0;
}
s.len() + 1
}
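/// Frees a tokenizer created with `llg_new_tokenizer()` or `llg_clone_tokenizer()`.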
#[no_mangle]
pub extern "C" fn llg_free_tokenizer(tok: *mut LlgTokenizer) {
unsafe {
drop(Box::from_raw(tok));
}
}
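/// Frees a constraint created with `llg_new_constraint_*()` or `llg_clone_constraint()`.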
#[no_mangle]
pub extern "C" fn llg_free_constraint(cc: *mut LlgConstraint) {
unsafe {
drop(Box::from_raw(cc));
}
}
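/// Returns the logs accumulated since the previous call as a NUL-terminated
/// string (embedded NULs are escaped); the pointer is valid until the next call
/// on this constraint.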
#[no_mangle]
pub extern "C" fn llg_flush_logs(cc: &mut LlgConstraint) -> *const c_char {
if let Some(constraint) = &mut cc.constraint {
let s = constraint.flush_logs();
if s.contains('\0') {
cc.last_logs = s.replace('\0', "\\0");
} else {
cc.last_logs = s;
}
cc.last_logs.push('\0');
}
cc.last_logs.as_ptr() as *const c_char
}
fn build_stop_controller(
tokenizer: &LlgTokenizer,
stop_tokens: &[u32],
stop_rx: *const c_char,
) -> Result<StopController> {
let stop_rx = if stop_rx.is_null() {
None
} else {
Some(unsafe { c_str_to_str(stop_rx, "stop_rx") }?.to_string())
};
StopController::new(
tokenizer.token_env.clone(),
stop_tokens.to_vec(),
stop_rx,
vec![],
)
}
fn save_error_string(e: impl Display, error_string: *mut c_char, error_string_len: usize) {
if error_string_len > 0 {
let e = e.to_string();
let e = e.as_bytes();
let len = std::cmp::min(e.len(), error_string_len - 1);
unsafe {
std::ptr::copy_nonoverlapping(e.as_ptr(), error_string as *mut u8, len);
*error_string.add(len) = 0;
}
}
}
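/// Creates a stop controller from stop tokens and an optional stop regex.
/// Returns null on failure, copying the message into `error_string`.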
#[no_mangle]
pub extern "C" fn llg_new_stop_controller(
tokenizer: &LlgTokenizer,
stop_tokens: *const u32,
stop_tokens_len: usize,
stop_rx: *const c_char,
error_string: *mut c_char,
error_string_len: usize,
) -> *mut LlgStopController {
    // Allow a null pointer when no stop tokens are passed; from_raw_parts()
    // requires a non-null, well-aligned pointer even for empty slices.
    let stop_tokens: &[u32] = if stop_tokens.is_null() {
        &[]
    } else {
        unsafe { std::slice::from_raw_parts(stop_tokens, stop_tokens_len) }
    };
match build_stop_controller(tokenizer, stop_tokens, stop_rx) {
Ok(stop_controller) => Box::into_raw(Box::new(LlgStopController {
stop_controller,
last_result: String::new(),
})),
Err(e) => {
save_error_string(e, error_string, error_string_len);
std::ptr::null_mut()
}
}
}
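/// Commits a token to the stop controller, returning any newly decoded output
/// as a NUL-terminated string and setting its byte length and the stopped flag
/// via the out-parameters; the pointer is valid until the next call.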
#[no_mangle]
pub extern "C" fn llg_stop_commit_token(
stop_ctrl: &mut LlgStopController,
token: u32,
output_len_p: &mut usize,
is_stopped_p: &mut bool,
) -> *const c_char {
let r = stop_ctrl.stop_controller.commit_token(token);
*output_len_p = r.len();
*is_stopped_p = stop_ctrl.stop_controller.is_stopped();
stop_ctrl.last_result = format!("{r}\0");
stop_ctrl.last_result.as_ptr() as *const c_char
}
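/// Frees a stop controller created with `llg_new_stop_controller()`.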
#[no_mangle]
pub extern "C" fn llg_free_stop_controller(stop_ctrl: *mut LlgStopController) {
unsafe {
drop(Box::from_raw(stop_ctrl));
}
}