import argparse
import json
from pathlib import Path
def escape_rust_string(s: str) -> str:
    """Escape ``s`` so it can be placed between double quotes in Rust source.

    Backslashes and double quotes are backslash-escaped. Newline, carriage
    return, tab and NUL use Rust's short escapes, and any remaining ASCII
    control character is written as a ``\\u{..}`` escape. All other
    characters are copied through unchanged.

    Args:
        s: The raw token text.

    Returns:
        The escaped text, without surrounding quotes.
    """
    # Short escapes understood by Rust string literals.
    short_escapes = {
        "\\": "\\\\",
        '"': '\\"',
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
        "\0": "\\0",
    }
    pieces = []
    for ch in s:
        if ch in short_escapes:
            pieces.append(short_escapes[ch])
        elif ord(ch) < 0x20 or ord(ch) == 0x7F:
            # Remaining control characters: keep the generated source
            # printable and one token per line.
            pieces.append(f"\\u{{{ord(ch):x}}}")
        else:
            pieces.append(ch)
    return "".join(pieces)
# Hand-written Rust runtime appended verbatim after the generated tables.
# It is emitted exactly as written here; nothing is interpolated into it.
_RUST_RUNTIME_SOURCE = """use crate::ByteReader;
/// Find token by binary searching cumulative probabilities.
fn find_token(transitions: &[(u16, u8)], value: u8) -> u16 {
for (token_id, cumulative) in transitions {
if *cumulative >= value {
return *token_id;
}
}
transitions.last().map_or(0, |(id, _)| *id)
}
/// Get the text for a token, stripping position markers.
fn token_text(token_id: u16) -> &'static str {
let token = TOKENS[token_id as usize];
let without_prefix = token.strip_prefix('^').unwrap_or(token);
without_prefix.strip_suffix('$').unwrap_or(without_prefix)
}
/// Bit reader that wraps a `ByteReader`, buffering bytes and reading bits.
struct BitReader<'a, R: ByteReader> {
reader: &'a mut R,
buffer: Vec<u8>,
bit_pos: usize,
exhausted: bool,
}
impl<'a, R: ByteReader> BitReader<'a, R> {
const fn new(reader: &'a mut R) -> Self {
Self {
reader,
buffer: Vec::new(),
bit_pos: 0,
exhausted: false,
}
}
/// Ensure we have at least `bits` available in the buffer.
fn ensure_bits(&mut self, bits: usize) -> bool {
if self.exhausted {
return self.bits_available() >= bits;
}
let bytes_needed = (self.bit_pos + bits).div_ceil(8);
while self.buffer.len() < bytes_needed {
let mut byte = [0u8; 1];
if self.reader.read(&mut byte) == 0 {
self.exhausted = true;
break;
}
self.buffer.push(byte[0]);
}
self.bits_available() >= bits
}
const fn bits_available(&self) -> usize {
(self.buffer.len() * 8).saturating_sub(self.bit_pos)
}
fn read_u8(&mut self) -> Option<u8> {
if !self.ensure_bits(8) {
return None;
}
let mut result: u8 = 0;
for _ in 0..8 {
let byte_idx = self.bit_pos / 8;
let bit_idx = self.bit_pos % 8;
let bit = (self.buffer[byte_idx] >> (7 - bit_idx)) & 1;
result = (result << 1) | bit;
self.bit_pos += 1;
}
Some(result)
}
fn has_more(&mut self) -> bool {
self.ensure_bits(8)
}
}
/// Generate an English-like word with a minimum target length.
///
/// The output always ends with an end token. If it cannot exactly match
/// the target length, it will stop at the shortest possible length
/// that is >= `target_len` when such an end token is available.
pub fn generate_word_with_target_len<R: ByteReader>(
reader: &mut R,
target_len: usize,
) -> String {
let mut bit_reader = BitReader::new(reader);
let mut result = String::new();
// Select beginning token
let Some(begin_value) = bit_reader.read_u8() else {
return String::new();
};
let first_token = find_token(&BEGIN_TRANSITIONS, begin_value);
result.push_str(token_text(first_token));
let mut current_token = first_token;
let mut current_len = result.len();
loop {
let (end_start, end_len) = END_TRANSITION_INDEX[current_token as usize];
if end_len > 0 {
let end_trans =
&END_TRANSITION_DATA[end_start as usize..(end_start as usize + end_len as usize)];
let mut can_reach_target = current_len >= target_len;
if !can_reach_target {
for (end_id, _) in end_trans {
if current_len + token_text(*end_id).len() >= target_len {
can_reach_target = true;
break;
}
}
}
if can_reach_target {
let value = bit_reader.read_u8().unwrap_or(0);
let mut end_token = find_token(end_trans, value);
if current_len + token_text(end_token).len() < target_len {
if let Some((end_id, _)) = end_trans
.iter()
.find(|(end_id, _)| current_len + token_text(*end_id).len() >= target_len)
{
end_token = *end_id;
} else if let Some((end_id, _)) = end_trans.last() {
end_token = *end_id;
}
}
result.push_str(token_text(end_token));
break;
}
}
let (start, len) = TRANSITION_INDEX[current_token as usize];
if len == 0 {
break;
}
let Some(value) = bit_reader.read_u8() else {
break;
};
let trans = &TRANSITION_DATA[start as usize..(start as usize + len as usize)];
let next_token = find_token(trans, value);
result.push_str(token_text(next_token));
current_token = next_token;
current_len = result.len();
}
result
}
/// Generate an English-like word from a `ByteReader`.
///
/// Reads bytes from the reader and generates tokens until the reader
/// is exhausted. The word consists of a beginning token, zero or more
/// middle tokens, and an end token.
///
/// # Panics
///
/// This function will not panic under normal usage. Internal assertions
/// are guaranteed by the function's control flow.
pub fn generate_word<R: ByteReader>(reader: &mut R) -> String {
let mut bit_reader = BitReader::new(reader);
let mut result = String::new();
// Select beginning token
let Some(begin_value) = bit_reader.read_u8() else {
return String::new();
};
let first_token = find_token(&BEGIN_TRANSITIONS, begin_value);
let mut current_token: Option<u16> = Some(first_token);
result.push_str(token_text(first_token));
// Select middle tokens while we have entropy
while bit_reader.has_more() {
let Some(current) = current_token else {
break;
};
let (start, len) = TRANSITION_INDEX[current as usize];
if len == 0 {
break;
}
let Some(value) = bit_reader.read_u8() else {
break;
};
let trans = &TRANSITION_DATA[start as usize..(start as usize + len as usize)];
let next_token = find_token(trans, value);
current_token = Some(next_token);
result.push_str(token_text(next_token));
}
// Select end token using remaining bits or default
if let Some(current) = current_token {
let (start, len) = END_TRANSITION_INDEX[current as usize];
if len > 0 {
let trans = &END_TRANSITION_DATA[start as usize..(start as usize + len as usize)];
let value = bit_reader.read_u8().unwrap_or(0);
let end_token = find_token(trans, value);
result.push_str(token_text(end_token));
}
}
result
}
"""


def _quantize_cumulative(cumulative) -> int:
    """Convert one cumulative weight from the model into a value in 0..=255.

    A ``float`` that is at most 1.0 is treated as a fraction and scaled by
    255 (truncating). Anything else is taken as an already-scaled number and
    truncated with ``int()``. The result is clamped to 0..=255 so the emitted
    literal always fits the ``u8`` column of the generated tables.
    """
    if isinstance(cumulative, float) and cumulative <= 1.0:
        scaled = int(cumulative * 255)
    else:
        scaled = int(cumulative)
    return max(0, min(255, scaled))


def _flatten_transitions(table: dict, token_count: int):
    """Flatten a per-token transition mapping into an index and a data list.

    ``table`` is looked up with the decimal string of each token id in
    ``range(token_count)``; each value is an iterable of
    ``(next_id, cumulative)`` pairs.

    Returns:
        ``(index, data)`` where ``index[token_id]`` is ``(start, length)``
        into ``data`` (``(0, 0)`` for tokens without an entry) and ``data``
        is the concatenation of ``(next_id, quantized_cumulative)`` pairs.
    """
    index = []
    data = []
    for token_id in range(token_count):
        key = str(token_id)
        if key not in table:
            index.append((0, 0))
            continue
        start = len(data)
        data.extend(
            (next_id, _quantize_cumulative(cumulative))
            for next_id, cumulative in table[key]
        )
        index.append((start, len(data) - start))
    return index, data


def _append_pair_table(lines: list, header: str, pairs) -> None:
    """Append ``header``, one row per pair, the closing ``];`` and a blank line."""
    lines.append(header)
    lines.extend(f" ({first}, {second})," for first, second in pairs)
    lines.append("];")
    lines.append("")


def generate_rust_code(model: dict) -> str:
    """Render the n-gram model as the text of a Rust module.

    Args:
        model: Mapping with the keys ``"id_to_token"`` (sequence of token
            strings), ``"begin_transitions"`` (sequence of
            ``(token_id, cumulative)`` pairs), and ``"transitions"`` /
            ``"end_transitions"`` (mappings from a token id written as a
            decimal string to a sequence of ``(next_id, cumulative)`` pairs).

    Returns:
        The generated source: module docs, the ``TOKENS`` array, the
        begin/middle/end transition tables, then the fixed Rust runtime.

    Raises:
        KeyError: If one of the four model keys is missing.
    """
    # Read every key up front so a malformed model fails before any work.
    id_to_token = model["id_to_token"]
    begin_transitions = model["begin_transitions"]
    transitions = model["transitions"]
    end_transitions = model["end_transitions"]
    token_count = len(id_to_token)

    lines = [
        "//! N-gram model for generating English-like words from entropy.",
        "//!",
        "//! This module is auto-generated by `models/generate_rust.py`",
        "//! Do not edit manually.",
        "",
        "/// Token vocabulary - maps token ID to token string.",
        "/// Tokens 0-255 are beginning tokens (^prefix).",
        "/// Tokens 256-511 are end tokens (suffix$).",
        "/// Tokens 512-1023 are middle tokens.",
        f"pub const TOKENS: [&str; {token_count}] = [",
    ]
    for token in id_to_token:
        escaped = escape_rust_string(token)
        lines.append(f' "{escaped}",')
    lines.append("];")
    lines.append("")

    begin_pairs = [
        (token_id, _quantize_cumulative(cumulative))
        for token_id, cumulative in begin_transitions
    ]
    lines.append("/// Beginning token transitions.")
    lines.append("/// Format: (`token_id`, `cumulative_probability` as u8)")
    _append_pair_table(
        lines,
        f"pub const BEGIN_TRANSITIONS: [(u16, u8); {len(begin_pairs)}] = [",
        begin_pairs,
    )

    transition_index, transition_data = _flatten_transitions(transitions, token_count)
    lines.append("/// Index into `TRANSITION_DATA`: (start, length) for each token.")
    _append_pair_table(
        lines,
        f"pub const TRANSITION_INDEX: [(u32, u16); {len(transition_index)}] = [",
        transition_index,
    )
    lines.append("/// Transition data: (`next_token_id`, `cumulative_probability` as u8)")
    _append_pair_table(
        lines,
        f"pub static TRANSITION_DATA: [(u16, u8); {len(transition_data)}] = [",
        transition_data,
    )

    end_index, end_data = _flatten_transitions(end_transitions, token_count)
    lines.append("/// Index into `END_TRANSITION_DATA`: (start, length) for each token.")
    _append_pair_table(
        lines,
        f"pub const END_TRANSITION_INDEX: [(u32, u16); {len(end_index)}] = [",
        end_index,
    )
    lines.append("/// End transition data: (`end_token_id`, `cumulative_probability` as u8)")
    _append_pair_table(
        lines,
        f"pub static END_TRANSITION_DATA: [(u16, u8); {len(end_data)}] = [",
        end_data,
    )

    lines.append(_RUST_RUNTIME_SOURCE)
    return "\n".join(lines)
def main() -> int:
    """Command-line entry point: read a model JSON file and write the Rust module.

    Parses ``sys.argv`` for the model path and an optional ``-o/--output``
    path, loads the model, renders it with ``generate_rust_code`` and writes
    the result as UTF-8.

    Returns:
        ``0`` on success, ``1`` if the model file does not exist.
    """
    import sys

    parser = argparse.ArgumentParser(
        description="Generate Rust code from n-gram model"
    )
    parser.add_argument(
        "model",
        type=Path,
        help="Path to model JSON file",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=Path("../src/english_word.rs"),
        help="Output Rust file (default: ../src/english_word.rs)",
    )
    args = parser.parse_args()

    if not args.model.exists():
        # Diagnostics go to stderr so they do not mix with progress output.
        print(f"Error: Model file {args.model} does not exist", file=sys.stderr)
        return 1

    print(f"Loading model from {args.model}...")
    model = json.loads(args.model.read_text(encoding="utf-8"))

    print("Generating Rust code...")
    rust_code = generate_rust_code(model)

    print(f"Writing to {args.output}...")
    args.output.write_text(rust_code, encoding="utf-8")

    # len() counts characters of the generated text; the label is kept as-is.
    print(f"Generated {len(rust_code)} bytes of Rust code")
    print(f" Tokens: {len(model['id_to_token'])}")
    print(f" Begin transitions: {len(model['begin_transitions'])}")
    print(f" Middle transitions: {len(model['transitions'])}")
    print(f" End transitions: {len(model['end_transitions'])}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so a missing
    # model file is reported to the calling shell; None maps to status 0.
    raise SystemExit(main())