use lazy_static::lazy_static;
use llguidance::{
api::TopLevelGrammar,
earley::SlicedBiasComputer,
toktrie::{ApproximateTokEnv, InferenceCapabilities, TokEnv, TokenizerEnv},
Matcher, ParserFactory, TokenParser,
};
use serde_json::{json, Value};
use std::sync::Arc;
lazy_static! {
    /// Lazily-built parser factory over the tokenizer environment returned by
    /// `llg_test_utils::get_tok_env()`.
    ///
    /// All four inference capabilities are disabled, stderr logging is set to
    /// level 2 and buffer logging to level 0.
    // NOTE(review): the `_PHI` suffix suggests a Phi-family tokenizer — confirm
    // against `llg_test_utils`.
    static ref PARSER_FACTORY_PHI: ParserFactory = {
        let tok_env = llg_test_utils::get_tok_env();
        let capabilities = InferenceCapabilities {
            ff_tokens: false,
            backtrack: false,
            conditional_ff_tokens: false,
            fork: false,
        };
        let slices = SlicedBiasComputer::general_slices();
        let mut factory = ParserFactory::new(tok_env, capabilities, &slices).unwrap();
        factory.set_stderr_log_level(2);
        factory.set_buffer_log_level(0);
        factory
    };
}
lazy_static! {
    /// Lazily-built parser factory used by `make_parser` and `get_tok_env`.
    ///
    /// The tokenizer environment is obtained by name through
    /// `toktrie_hf_downloader::tok_env_from_name`. All four inference
    /// capabilities are disabled, stderr logging is set to level 2 and buffer
    /// logging to level 0.
    ///
    /// # Panics
    ///
    /// Panics on first access if the tokenizer cannot be obtained or the
    /// factory cannot be constructed. The panic messages name the failing step
    /// so that a broken fixture is easy to tell apart from a failing test.
    static ref PARSER_FACTORY: ParserFactory = {
        let env = toktrie_hf_downloader::tok_env_from_name("unsloth/Meta-Llama-3.1-8B-Instruct")
            .expect("failed to load tokenizer unsloth/Meta-Llama-3.1-8B-Instruct");
        let mut fact = ParserFactory::new(
            &env,
            InferenceCapabilities {
                ff_tokens: false,
                backtrack: false,
                conditional_ff_tokens: false,
                fork: false,
            },
            &SlicedBiasComputer::general_slices(),
        )
        .expect("failed to build ParserFactory for unsloth/Meta-Llama-3.1-8B-Instruct");
        fact.set_stderr_log_level(2);
        fact.set_buffer_log_level(0);
        fact
    };
}
/// Builds a `TokenParser` for the given Lark grammar text using
/// `PARSER_FACTORY` and starts it without a prompt.
///
/// Panics if the factory fails to create the parser.
fn make_parser(lark: &str) -> TokenParser {
    let grammar = TopLevelGrammar::from_lark(String::from(lark));
    let mut token_parser = PARSER_FACTORY.create_parser(grammar).unwrap();
    token_parser.start_without_prompt();
    token_parser
}
/// Feeds `tok` to `parser` and asserts that `consume_token` returns 0.
///
/// Panics if `consume_token` fails or returns a non-zero value; the failure
/// message reports both the returned value and the offending token.
// NOTE(review): the returned count is presumably the number of tokens to
// backtrack — confirm against the `TokenParser` API.
fn consume(parser: &mut TokenParser, tok: u32) {
    let n = parser.consume_token(tok).unwrap();
    assert_eq!(n, 0, "consume_token({tok}) returned a non-zero count");
}
/// Exercises `compute_ff_tokens` and `validate_tokens_raw` on a grammar made
/// of explicit token ids with one alternation in the middle.
///
/// Before the alternation the fast-forward tokens are `[1111, 311]`; once the
/// `"s"` branch (token 29879) has been consumed they are `[311, 1111]`.
#[test]
fn test_ff_tokens() {
    let lark = r#"
start: <[1111]> <[311]> ( <[366]> | "s" ) <[311]> <[1111]>
"#;
    let grammar = TopLevelGrammar::from_lark(lark.to_string());
    let mut parser = PARSER_FACTORY_PHI.create_parser(grammar).unwrap();
    parser.start_without_prompt();

    // The forced prefix is reported and validates in full.
    let forced = parser.compute_ff_tokens();
    assert_eq!(forced, vec![1111, 311]);
    let accepted = parser.validate_tokens_raw(&forced).unwrap();
    assert_eq!(accepted, 2);
    for tok in [1111, 311] {
        consume(&mut parser, tok);
    }

    // Either branch of the alternation, followed by the tail, validates.
    for branch in [366, 29879] {
        let accepted = parser.validate_tokens_raw(&[branch, 311, 1111]).unwrap();
        assert_eq!(accepted, 3);
    }

    // After taking the "s" branch the remaining tail is forced.
    consume(&mut parser, 29879);
    let forced = parser.compute_ff_tokens();
    assert_eq!(forced, vec![311, 1111]);
    let accepted = parser.validate_tokens_raw(&forced).unwrap();
    assert_eq!(accepted, 2);
}
fn get_tok_env() -> &'static TokEnv {
PARSER_FACTORY.tok_env()
}
/// Checks that the serialized form of `obj` is accepted token by token by a
/// parser built from the `%json` grammar for `schema`.
///
/// For every token, the computed mask must allow it before it is consumed
/// through `consume`. The token sequence is printed for debugging.
fn json_fwd_test(schema: Value, obj: Value) {
    let schema_text = serde_json::to_string(&schema).unwrap();
    let mut parser = make_parser(&format!("start: %json {}", schema_text));

    let tok_env = get_tok_env();
    let trie = tok_env.tok_trie();
    let obj_text = serde_json::to_string(&obj).unwrap();
    let tokens = tok_env.tokenize(&obj_text);
    println!("\n\ntokens: {}\n", trie.tokens_dbg(&tokens));

    for &tok in &tokens {
        let mask = parser.compute_mask().unwrap();
        assert!(mask.is_allowed(tok));
        consume(&mut parser, tok);
    }
}
/// Object schema with a single optional string property and no additional
/// properties; the instance sets that property.
#[test]
fn test_ff_json1() {
    let schema = json!({
        "type": "object",
        "properties": {
            "someLongPropertyName": {
                "type": "string"
            }
        },
        "additionalProperties": false
    });
    let instance = json!({
        "someLongPropertyName": "123"
    });
    json_fwd_test(schema, instance);
}
/// Object schema with a required string property constrained by a `pattern`;
/// the instance value starts with the required prefix.
#[test]
fn test_ff_json2() {
    let schema = json!({
        "additionalProperties": false,
        "properties": {
            "path": {
                "pattern": "^/contributions",
                "type": "string"
            }
        },
        "required": ["path"],
        "type": "object"
    });
    let instance = json!({"path": "/contributions/foo"});
    json_fwd_test(schema, instance);
}
/// Object schema mixing required and optional properties; the instance
/// provides both required ones and skips one of the optional ones.
#[test]
fn test_ff_json3() {
    let schema = json!({
        "additionalProperties": false,
        "properties": {
            "location": { "type": "string" },
            "retries": { "type": "number" },
            "retrieveDate": { "type": "string" },
            "retryInterval": { "type": "number" }
        },
        "required": [ "location", "retrieveDate" ],
        "type": "object"
    });
    let instance = json!({
        "location": "https://example.com/firmware.bin",
        "retrieveDate": "2022-01-01T12:00:00Z",
        "retryInterval": 300
    });
    json_fwd_test(schema, instance);
}
/// `anyOf` over two object schemas: one allowing string-valued additional
/// properties, the other forbidding additional properties. One instance is
/// run for each alternative.
#[test]
fn test_ff_json4() {
    let schema = json!({
        "anyOf":[{
            "type": "object",
            "properties": {
                "foo": { "type": "number" }
            },
            "additionalProperties": { "type": "string" },
        }, {
            "type": "object",
            "properties": {
                "bar": { "type": "number" }
            },
            "additionalProperties": false,
        }]
    });
    json_fwd_test(schema.clone(), json!({ "foo": 123, "baz": "hello" }));
    // Last use of `schema`: move it instead of making a redundant clone.
    json_fwd_test(schema, json!({ "bar": 123 }));
}
/// Feeds a run of commas to a right-recursive list grammar whose tail may be
/// empty and checks that every token is consumed successfully.
///
/// Tokens go through the shared `consume` helper, so the value returned by
/// `consume_token` is checked to be 0 rather than silently discarded.
#[test]
fn test_ff_early() {
    let lark = r#"
start: lst
lst: "," lst | ""
"#;
    let mut parser = make_parser(lark);
    let tokens = get_tok_env().tokenize(",,,,,,,");
    for &tok in &tokens {
        consume(&mut parser, tok);
    }
}
/// Drives a matcher with a long token stream until `consume_token` fails and
/// checks that the error text contains both `<state>` and `Tokens:`.
///
/// The stream is the first three tokens of the tokenized sample repeated 100
/// times, followed by the complete tokenization. The test fails if no token
/// is ever rejected.
#[test]
fn test_err_state() {
    let lark = r#"
start: /[a-z]*/
"#;
    let tokens = get_tok_env().tokenize("fobarbazqu123");
    let head = [tokens[0], tokens[1], tokens[2]];
    let mut stream = Vec::new();
    for _ in 0..100 {
        stream.extend_from_slice(&head);
    }
    stream.extend_from_slice(&tokens);

    let mut matcher = Matcher::new(Ok(make_parser(lark)));
    // Stop at the first token the matcher rejects.
    let first_failure = stream
        .iter()
        .find_map(|&tok| matcher.consume_token(tok).err());
    match first_failure {
        Some(err) => {
            let e = err.to_string();
            println!("Error: {e}");
            assert!(e.contains("<state>"));
            assert!(e.contains("Tokens:"));
        }
        None => unreachable!(),
    }
}
/// After a successful run of tokens, `test_trigger_lexer_error` must return
/// an error, and the next `consume_token` call must return an error too.
/// Both error texts must contain `<state>` and `synthetic error`.
#[test]
fn test_trigger_lexer_error() {
    /// Prints the error and checks it carries the expected markers.
    fn check_synthetic(err: impl ToString) {
        let e = err.to_string();
        println!("Error: {e}");
        assert!(e.contains("<state>"));
        assert!(e.contains("synthetic error"));
    }

    let lark = r#"
start: /[a-z]*/
"#;
    let tokens = get_tok_env().tokenize("fobarbazqu");
    let mut matcher = Matcher::new(Ok(make_parser(lark)));
    for &tok in &tokens {
        matcher.consume_token(tok).unwrap();
    }

    match matcher.test_trigger_lexer_error() {
        Err(err) => check_synthetic(err),
        Ok(_) => unreachable!(),
    }
    // The same error must be reported by the following call.
    match matcher.consume_token(123) {
        Err(err) => check_synthetic(err),
        Ok(_) => unreachable!(),
    }
}
/// Consumes the tokenization of `"a#"` against a grammar that follows a
/// lowercase lexeme with one of two explicit token ids (those of `#` and
/// `?`); every `consume_token` call must succeed.
#[test]
fn test_lexer_inv_crash() {
    let env = get_tok_env();
    let hash_toks = env.tokenize("#");
    let question_toks = env.tokenize("?");
    let input = env.tokenize("a#");
    // Each punctuation character must map to exactly one token id.
    assert!(hash_toks.len() == 1);
    assert!(question_toks.len() == 1);

    let grammar = format!(
        "start: /[a-z]+/ ( <[{}]> | <[{}]> )",
        hash_toks[0], question_toks[0]
    );
    let mut matcher = Matcher::new(Ok(make_parser(&grammar)));
    input.into_iter().for_each(|tok| {
        matcher.consume_token(tok).unwrap();
    });
}
/// `try_consume_tokens` on input with a trailing token beyond `"stop"` must
/// consume fewer tokens than supplied and leave the matcher stopped; after
/// `reset`, replaying exactly the consumed prefix must stop it again.
#[test]
fn test_stop_when_try_consume_fails() {
    let lark = r#"
start: "blah"* "stop"
"#;
    let mut matcher = Matcher::new(Ok(make_parser(lark)));
    let input = get_tok_env().tokenize("blahblahblahblahstopblah");

    let n_accepted = matcher.try_consume_tokens(&input).unwrap();
    assert!(n_accepted < input.len());
    assert!(matcher.is_stopped());

    matcher.reset().unwrap();
    assert!(!matcher.is_stopped());

    let accepted_prefix = &input[..n_accepted];
    matcher.try_consume_tokens(accepted_prefix).unwrap();
    assert!(matcher.is_stopped());
}
/// Feeds tokens one at a time; the first `try_consume_tokens` call made while
/// the matcher is already stopped must succeed and must not put the matcher
/// into the error state. The test fails if the matcher never stops.
#[test]
fn test_try_consume_after_stop() {
    let lark = r#"
start: "blah"* "stop"
"#;
    let mut matcher = Matcher::new(Ok(make_parser(lark)));
    let input = get_tok_env().tokenize("blahblahblahblahstopblah");

    for tok in &input {
        // Sample the state before the call: we care about a call made while stopped.
        let was_stopped = matcher.is_stopped();
        matcher
            .try_consume_tokens(std::slice::from_ref(tok))
            .unwrap();
        if was_stopped {
            assert!(!matcher.is_error());
            return;
        }
    }
    unreachable!();
}
/// Consuming `tokens + [eos]` in one `try_consume_tokens` call must report
/// the same total as consuming `tokens` and then `[eos]` in two calls.
#[test]
fn test_try_consume_eos_consistency() {
    let lark = r#"start: "a""#;
    let mut matcher = Matcher::new(Ok(make_parser(lark)));

    let tokens = get_tok_env().tokenize("a");
    let eos = get_tok_env().eos_token();
    let mut tokens_with_eos = tokens.clone();
    tokens_with_eos.push(eos);

    // One-shot run, EOS included.
    let total_one_shot = matcher.try_consume_tokens(&tokens_with_eos).unwrap();

    // Two-step run on a fresh matcher state.
    matcher.reset().unwrap();
    let body_count = matcher.try_consume_tokens(&tokens).unwrap();
    assert!(body_count <= total_one_shot);
    let eos_count = matcher.try_consume_tokens(&[eos]).unwrap();
    assert!(eos_count <= 1);

    assert_eq!(body_count + eos_count, total_one_shot);
}
/// With a trie configured with two EOS tokens, the mask returned by
/// `compute_mask_or_eos` after the grammar's only token has been consumed
/// must allow both EOS tokens, and the matcher must report being stopped.
#[test]
fn test_multi_eos_mask_when_stopped() {
    // Single-byte tokenizer with a second EOS id right below the primary one.
    let single_byte = ApproximateTokEnv::single_byte();
    let trie = single_byte.tok_trie();
    let primary_eos = trie.eos_token();
    let extra_eos = primary_eos - 1;
    let trie_with_two_eos = trie.clone().with_eos_tokens(&[primary_eos, extra_eos]);
    let tok_env: TokEnv = Arc::new(ApproximateTokEnv::new(trie_with_two_eos));

    let slices = SlicedBiasComputer::general_slices();
    let factory =
        ParserFactory::new(&tok_env, InferenceCapabilities::default(), &slices).unwrap();
    let grammar = TopLevelGrammar::from_lark(r#"start: "a""#.to_string());
    let mut parser = factory.create_parser(grammar).unwrap();
    parser.start_without_prompt();
    let mut matcher = Matcher::new(Ok(parser));

    let letter_a = u32::from(b'a');
    let initial_mask = matcher.compute_mask().unwrap();
    assert!(initial_mask.is_allowed(letter_a));
    matcher.consume_token(letter_a).unwrap();

    let final_mask = matcher.compute_mask_or_eos().unwrap();
    assert!(
        final_mask.is_allowed(primary_eos),
        "primary EOS should be in stopped mask"
    );
    assert!(
        final_mask.is_allowed(extra_eos),
        "extra EOS should be in stopped mask"
    );
    assert!(matcher.is_stopped());
}
/// With a trie configured with two EOS tokens, `validate_tokens` must accept
/// each of them (returning 1) while the matcher is still running after one
/// letter of an open-ended lowercase lexeme.
#[test]
fn test_multi_eos_validate_tokens() {
    // Single-byte tokenizer with a second EOS id right below the primary one.
    let single_byte = ApproximateTokEnv::single_byte();
    let trie = single_byte.tok_trie();
    let primary_eos = trie.eos_token();
    let extra_eos = primary_eos - 1;
    let trie_with_two_eos = trie.clone().with_eos_tokens(&[primary_eos, extra_eos]);
    let tok_env: TokEnv = Arc::new(ApproximateTokEnv::new(trie_with_two_eos));

    let slices = SlicedBiasComputer::general_slices();
    let factory =
        ParserFactory::new(&tok_env, InferenceCapabilities::default(), &slices).unwrap();
    let grammar = TopLevelGrammar::from_lark(r#"start: /[a-z]*/"#.to_string());
    let mut parser = factory.create_parser(grammar).unwrap();
    parser.start_without_prompt();
    let mut matcher = Matcher::new(Ok(parser));

    matcher.consume_token(u32::from(b'a')).unwrap();
    assert!(
        !matcher.is_stopped(),
        "parser should still be running while lexeme bytes are pending"
    );

    let primary_accepted = matcher.validate_tokens(&[primary_eos]).unwrap();
    assert_eq!(
        primary_accepted, 1,
        "primary EOS must be accepted by validate_tokens"
    );
    let extra_accepted = matcher.validate_tokens(&[extra_eos]).unwrap();
    assert_eq!(
        extra_accepted, 1,
        "extra EOS must also be accepted by validate_tokens"
    );
}