1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
/*!
# KBNF

This crate provides a constrained decoding engine
which ensures that a language model's output adheres strictly to the format defined by KBNF (Koishi's BNF), an enhanced variant of EBNF.
KBNF includes features that enhance usability, notably embeddable regular expressions and more flexible exceptions.
Here is a quick example of how this crate works:

```rust
fn greedy_decode(logits: &[f32])->u32 {
    logits.iter().enumerate().max_by(|a,b|a.1.partial_cmp(b.1).unwrap()).unwrap().0 as u32
}

use ahash::AHashMap;
use kbnf::{Engine, EngineLike, Grammar, Token, Vocabulary};
let grammar_str = r#"
start ::= "你好"except!('\n\n')'\n\n';
"#;
let mut token_strings: AHashMap<u32, String> = AHashMap::default();
token_strings.extend(
    [
        (1, "你好".to_string()),
        (2, "hello".to_string()),
        (3, "250".to_string()),
        (4, "\n".to_string()),
        (5, "\n\n".to_string()),
    ]
);
let mut tokens = token_strings
    .iter()
    .map(|(k, v)| (*k, Token(v.as_bytes().to_vec().into_boxed_slice())))
    .collect::<AHashMap<u32, _>>();
tokens.insert(3,Token(Box::new([250])));
let vocab = Vocabulary::new(tokens, token_strings).unwrap();
let mut engine = Engine::new(grammar_str, vocab).unwrap();
let mut token = 1; // the prompt token
let mut logits = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // logits obtained from the language model
assert_eq!(
    engine.update_logits(token, &mut logits).unwrap(),
    kbnf::AcceptTokenResult::Ongoing
);
assert_eq!(&format!("{:?}", logits), "[-inf, 0.0, 0.0, 1.0, 0.0, 0.0]");
token = greedy_decode(&logits);
logits = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // new logits obtained from the language model
assert_eq!(
    engine.update_logits(token, &mut logits).unwrap(),
    kbnf::AcceptTokenResult::Ongoing
);
assert_eq!(&format!("{:?}", logits), "[-inf, 0.0, 0.0, 0.0, 1.0, 0.0]");
token = greedy_decode(&logits);
logits = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]; // new logits obtained from the language model
assert_eq!(
    engine.update_logits(token, &mut logits).unwrap(),
    kbnf::AcceptTokenResult::Ongoing
);
assert_eq!(
    &format!("{:?}", logits),
    "[-inf, 1.0, 0.0, 0.0, 0.0, -inf]"
);
token = greedy_decode(&logits);
logits = [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]; // new logits obtained from the language model
assert_eq!(
    engine.update_logits(token, &mut logits).unwrap(),
    kbnf::AcceptTokenResult::Ongoing
);
assert_eq!(&format!("{:?}", logits), "[-inf, 0.0, 0.0, 0.0, 0.0, 1.0]");
token = greedy_decode(&logits);
logits = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; // new logits obtained from the language model
assert_eq!(
    engine.update_logits(token, &mut logits).unwrap(),
    kbnf::AcceptTokenResult::Finished
);
assert_eq!(&format!("{:?}", logits), "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]");
// Currently, if the engine finishes, it will not update the logits.
```

# Overview

The primary types in this crate are [EngineLike] and [Engine]. [EngineLike] defines the behavior of an engine,
while [Engine] is a concrete implementation of [EngineLike]. The most important methods of [Engine] are as follows:
- [Engine::new]: This method creates a new engine from a [KBNF grammar](#kbnf-grammar) string, a [Vocabulary] and default configuration.
[Engine::with_config] allows you to specify a custom configuration.
- [Engine::update_logits]: This method tries to accept a new token and then updates the logits accordingly.
- [Engine::reset]: This method resets the engine to its initial state. Notably, the cache is preserved.

This crate-level documentation is organized as follows:

- [Examples](#examples): This section contains some examples of how to use the crate.
- [KBNF Grammar](#kbnf-grammar): This section enumerates the syntax of KBNF grammar.
- [Performance](#performance): This section discusses how to optimize the performance of the engine.

# Examples

## Get initially allowed token IDs

```rust
use ahash::AHashMap;
use kbnf::{Engine, EngineLike, Grammar, Token, Vocabulary};
let grammar_str = r#"
start ::= except!('\n\n')'\n\n';
"#;
let mut token_strings: AHashMap<u32, String> = AHashMap::default();
token_strings.extend(
    [
        (1, "a".to_string()),
        (2, "hello".to_string()),
        (4, "\n".to_string()),
        (5, "\n\n".to_string()),
    ]
);
let tokens = token_strings
    .iter()
    .map(|(k, v)| (*k, Token(v.as_bytes().to_vec().into_boxed_slice())))
    .collect::<AHashMap<u32, _>>();
let vocab = Vocabulary::new(tokens, token_strings).unwrap();
let mut engine = Engine::new(grammar_str, vocab).unwrap();
let mut logits = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; // The logits of the language model
engine.compute_allowed_token_ids();
assert_eq!(
    engine
        .allowed_token_ids_from_last_computation()
        .ones()
        .collect::<Vec<_>>(),
    vec![1, 2, 4, 5]
);
engine.mask_logits(&mut logits).unwrap(); // mask the logits
assert_eq!(&format!("{:?}", logits), "[-inf, 0.0, 0.0, -inf, 0.0, 0.0]");
```

## Update engine's state with some prompts

```rust
use ahash::AHashMap;
use kbnf::{Engine, EngineLike, Grammar, Token, Vocabulary};
let grammar_str = r#"
start ::= except!('\n\n')'\n\n';
"#;
let mut token_strings: AHashMap<u32, String> = AHashMap::default();
token_strings.extend(
    [
        (1, "a".to_string()),
        (2, "hello".to_string()),
        (4, "\n".to_string()),
        (5, "\n\n".to_string()),
    ],
);
let tokens = token_strings
    .iter()
    .map(|(k, v)| (*k, Token(v.as_bytes().to_vec().into_boxed_slice())))
    .collect::<AHashMap<u32, _>>();
let vocab = Vocabulary::new(tokens, token_strings).unwrap();
let mut engine = Engine::new(grammar_str, vocab).unwrap();
let mut logits = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; // The logits of the language model
engine.try_accept_new_token(2).unwrap();
engine.try_accept_new_token(2).unwrap();
engine.compute_allowed_token_ids();
assert_eq!(
    engine
        .allowed_token_ids_from_last_computation()
        .ones()
        .collect::<Vec<_>>(),
    vec![1, 2, 4, 5]
); // get the IDs
engine.mask_logits(&mut logits).unwrap(); // mask the logits
assert_eq!(&format!("{:?}", logits), "[-inf, 0.0, 0.0, -inf, 0.0, 0.0]");
```

## Reuse an engine for multiple generations

```rust
use ahash::AHashMap;
use kbnf::{Engine, EngineLike, Grammar, Token, Vocabulary};
let grammar_str = r#"
start ::= except!('\n\n')'\n\n';
"#;
let mut token_strings: AHashMap<u32, String> = AHashMap::default();
token_strings.extend(
    [
        (1, "a".to_string()),
        (2, "hello".to_string()),
        (4, "\n".to_string()),
        (5, "\n\n".to_string()),
    ],
);
let tokens = token_strings
    .iter()
    .map(|(k, v)| (*k, Token(v.as_bytes().to_vec().into_boxed_slice())))
    .collect::<AHashMap<u32, _>>();
let vocab = Vocabulary::new(tokens, token_strings).unwrap();
let mut engine = Engine::new(grammar_str, vocab).unwrap();
let mut logits = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; // The logits of the language model
engine.try_accept_new_token(2).unwrap();
engine.try_accept_new_token(5).unwrap();
engine.compute_allowed_token_ids();
assert_eq!(
    engine
        .allowed_token_ids_from_last_computation()
        .ones()
        .collect::<Vec<usize>>(),
    Vec::<usize>::new()
);
engine.reset();
assert_eq!(
    engine.update_logits(2, &mut logits).unwrap(),
    kbnf::AcceptTokenResult::Ongoing
);
assert_eq!(&format!("{:?}", logits), "[-inf, 0.0, 0.0, -inf, 0.0, 0.0]");
```

# KBNF Grammar

KBNF is roughly a superset of [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form). The syntax of KBNF is as follows:

## An informal, quick introduction to terms

- **Terminal**: Terminal is a fancy name for plain, old strings.
- **Nonterminal**: Nonterminal means a symbol that expands into sequences of other symbols.

## Nonterminal definition

Any KBNF grammar is made of nonterminal definitions. **By default, the engine starts from the definition of the nonterminal `start`**.

```ebnf
(*In KBNF,this is a comment.*)
start ::= "A"; (* Defines a nonterminal start that corresponds to a terminal "A". *)
(*The engine will constrain output to be exactly "A".*)
```

A nonterminal can be defined multiple times.

```ebnf
start ::= "A";
start ::= "B";
(*This means nonterminal start can either expand to "A" or "B".
Hence, the engine will constrain the output to be either "A" or "B".*)
```

A nonterminal identifier can contain any number of underscores, ASCII digits and ASCII letters.
However, it cannot start with a digit.

## Terminal

A terminal is a sequence of UTF-8 characters enclosed in double quotes or single quotes.

Currently, these escaped characters are supported:

| Escape sequence | Escaped value            |
|-----------------|--------------------------|
| `\t`            | U+0009 (HT)              |
| `\n`            | U+000A (LF)              |
| `\r`            | U+000D (CR)              |
| `\"`            | U+0022 (QUOTATION MARK)  |
| `\'`            | U+0027 (APOSTROPHE)      |
| `\\`            | U+005C (REVERSE SOLIDUS) |

More escaped characters will be added in the future.

## Concatenation

Two or more symbols in a sequence are concatenated.

```ebnf
start ::= "A" "B"; (* Equivalent to start ::= "AB". *)
```

```ebnf
start ::= "A" start;
(*
The expansion: start -> "A" start -> "A" "A" start -> "A" "A" "A" start -> ...
Hence, the engine will constrain the output to be an infinite sequence of "A"s.
*)
```

## Alternation

Concatenated symbols separated by `|` are alternatives to each other.

```ebnf
start ::= "A" | "B";
(*
The engine will constrain the output to be either "A" or "B".
This is equivalent to:
start ::= "A";
start ::= "B";
*)
```

```ebnf
start ::= "A" start | "B" start;
(*
The engine will constrain the output to be an infinite sequence
that only contains "A" and "B".
*)
```

## Grouping

Symbols enclosed in parentheses are grouped.

```ebnf
start ::= ("A"|"B") "C";
(*
The engine will constrain the output to be either "AC" or "BC".
This is equivalent to:
start ::= "A" "C";
start ::= "B" "C";
*)
```

## Option

Symbols enclosed in square brackets are optional.

```ebnf
start ::= "A" ["B"];
(*
The engine will constrain the output to be either "A" or "AB".
This is equivalent to:
start ::= "A";
start ::= "A" "B";
*)
```

A symbol followed by a `?` is optional.

```ebnf
start ::= "A"? "B";
(*
The engine will constrain the output to be either "B" or "AB".
*)
```

```ebnf
start ::= ("{"start"}")?;
(*
The engine will constrain the output to be a sequence of balanced curly brackets.
*)
```

**NOTE THAT KBNF does not allow the grammar to finish with an empty string.**
Otherwise, the engine will finish immediately, which does not make sense.

## Repetition

Symbols enclosed in curly brackets can be repeated zero or more times.

```ebnf
start ::= "A"{"A"};
```

**NOTE THAT KBNF ends eagerly, so the engine will constrain the output to be exactly one "A".**

```ebnf
start ::= {"A"|"C"} "B";
(*The engine will constrain the output to a sequence
of "A"s and "C"s followed by exactly one "B".*)
```

A symbol followed by a `*` can be repeated zero or more times.

```ebnf
start ::= "A"* "B"; (*The engine will constrain the output to
a sequence of "A"s followed by exactly one "B".*)
```

A symbol followed by a `+` can be repeated one or more times.

```ebnf
start ::= ("A"|"B")+ "C";
(*The engine will constrain the output to
a nonempty sequence of "A"s and "B"s followed by exactly one "C".*)
```

## Regular expression

A UTF-8 string enclosed in `#""` is a regular expression. The supported escaped characters are the same as in [Terminal](#terminal).

```ebnf
start ::= #".+A";
(*
The engine will constrain the output to be
a sequence of any characters followed by exactly one A.
This is equivalent to:
start ::= #".+" "A";
*)
```

The Rust regex crate is used to support regular expressions,
which means [the syntax supported](https://docs.rs/regex/latest/regex/index.html#syntax) might differ from other regex engines.
Notably, the regex crate does not support arbitrary lookarounds. In exchange, linear time matching is guaranteed.
**WARNING: the regular expression is compiled into a DFA which, by its nature, has worst case exponential time and space complexity.**
If you are dealing with untrusted regular expressions,
you should set a memory limit in [Config::regex_config] to prevent DoS attacks.

## Exceptions/except!

Although exception is the formal term, I personally find it confusing, so I will refer to it as "except!".
The `except!` keyword is used to exclude certain strings from the output.

```ebnf
start ::= except!('\n\n')'\n\n';
(*
The engine will constrain the output to be a sequence of characters
that does not contain "\n\n" followed by exactly one "\n\n".
*)
```

**NOTE THAT THE DEFINITION ABOVE DOES ALLOW `\n\n\n`!**
The first `\n` comes from the exception (since `\n != \n\n`), and the following `\n\n` comes from the terminal.
If you want a string that strictly ends with `\n\n`, you should use the following definition:

```ebnf
start ::= #".*\n\n";
```

You can use a nonterminal that directly contains alternations of terminals in `except!`.

```ebnf
start ::= except!(C)C;
C ::= "A"|"B";
(*The engine will constrain the output to be
a sequence of characters that ends with "A" or "B". *)
```

You can also specify the maximum repetition of `except!`.

```ebnf
start ::= except!('\n\n',50)'\n\n';
(*The engine will constrain the output
to be a sequence of bytes of maximum length 50
that does not contain "\n\n" followed by exactly one "\n\n".*)
```

# Performance

## Reducing ambiguity

Grammar structure is the most influential factor in the performance of the engine **asymptotically**.

Practically speaking, if your engine runs abysmally slow for long inputs, you should check the grammar
for [ambiguity](https://en.wikipedia.org/wiki/Ambiguous_grammar). Unfortunately, determining ambiguity is undecidable.
There do exist some heuristics to detect ambiguity, such as
[Shift-Reduce Conflict](https://www.gnu.org/software/bison/manual/html_node/Shift_002fReduce.html) and
[Reduce-Reduce Conflict](https://www.gnu.org/software/bison/manual/html_node/Reduce_002fReduce.html#:~:text=A%20reduce/reduce%20conflict%20occurs,zero%20or%20more%20word%20groupings).
They may be implemented in this crate in the future. Some local disambiguation methods may be implemented in the future as well.

## Reuse an engine for multiple generations with cache enabled

Caches are preserved between [Engine::reset] calls.
Hence, if your grammar and vocabulary are fixed, you should reuse the engine for multiple generations,
so when the engine hits the same state, it can directly fetch the allowed token IDs from the cache without recomputation.

## Prefer regular expressions over context-free grammars

Regular expressions are compiled into a DFA, which has lower overhead than the Earley recognizer.

## Prefer left recursion over right recursion

While the Leo optimization ensures both left and right recursion have linear time complexity,
it still introduces a constant-factor overhead for right recursion.
*/
#![warn(missing_docs)]
#![warn(rustdoc::broken_intra_doc_links)]
pub mod config;
pub mod engine;
pub mod engine_base;
pub mod engine_like;
mod ffi_bindings;
pub mod grammar;
pub mod utils;
pub mod vocabulary;
mod zero;
pub use config::Config;
pub use engine::Engine;
pub use engine_like::AcceptTokenResult;
pub use engine_like::EngineLike;
pub use grammar::Grammar;
#[cfg(feature = "python")]
use pyo3::prelude::*;
#[cfg(feature = "mimalloc")]
use mimalloc::MiMalloc;
pub use vocabulary::Token;
pub use vocabulary::Vocabulary;


// Only compiled when the `mimalloc` feature is enabled: registers `MiMalloc`
// as the crate's global allocator via `#[global_allocator]`.
#[cfg(feature = "mimalloc")]
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;

// Initializer of the `kbnf` Python extension module; only compiled when the
// `python` feature is enabled. It sets up `pyo3_log` and then registers every
// class exposed to Python, stopping at the first registration error.
//
// NOTE: plain `//` comments are used deliberately instead of `///` so that no
// doc attribute is attached to the `#[pymodule]` function.
#[cfg(feature = "python")]
#[pymodule]
#[pyo3(name = "kbnf")]
fn kbnf(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Logging is initialized first, before any class is registered.
    pyo3_log::init();

    // Configuration classes (`config` module and the engine's own config).
    m.add_class::<Config>()?;
    m.add_class::<config::CompressionConfig>()?;
    m.add_class::<config::Fsa>()?;
    m.add_class::<config::RegexConfig>()?;
    m.add_class::<engine::EngineConfig>()?;

    // The engine itself, plus the result and error classes from `engine_like`.
    m.add_class::<Engine>()?;
    m.add_class::<AcceptTokenResult>()?;
    m.add_class::<engine_like::AcceptTokenError>()?;
    m.add_class::<engine_like::MaskLogitsError>()?;
    m.add_class::<engine_like::UpdateLogitsError>()?;

    // Classes from the `vocabulary` module. The final registration's result
    // is returned directly as the result of the whole initializer.
    m.add_class::<Vocabulary>()?;
    m.add_class::<Token>()
}