use aprender::bundle::{BundleBuilder, BundleConfig, PagedBundle, PagingConfig, PagingStats};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use crate::trie::Trie;
/// 1 MiB, expressed in bytes (1024 * 1024).
///
/// NOTE(review): the name suggests this is the smallest memory limit the paged
/// model accepts; the code that uses it is not visible in this part of the
/// file — confirm against the usage site.
const MIN_MEMORY_LIMIT: usize = 1024 * 1024;
/// A group of n-gram counts stored together under one `prefix`.
///
/// Counts are kept as a two-level map: context string -> next token -> count.
/// The segment has a compact binary form (see `to_bytes` / `from_bytes`) in
/// addition to its serde derives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NgramSegment {
/// Label identifying this segment; written first in the binary encoding.
/// NOTE(review): presumably the shared context prefix used to route n-grams
/// to this segment — confirm against the code that builds segments.
pub prefix: String,
/// Observation counts: `ngrams[context][next_token]` is how many times
/// `next_token` was recorded after `context`.
pub ngrams: HashMap<String, HashMap<String, u32>>,
/// Approximate payload size maintained by `update_size`: the byte length of
/// every context key plus, for each next-token entry, the token's byte length
/// and 4 bytes for its `u32` count. Map overhead and `prefix` are not counted.
pub size_bytes: usize,
}
impl NgramSegment {
    /// Creates an empty segment labelled with `prefix`.
    #[must_use]
    pub fn new(prefix: String) -> Self {
        Self {
            prefix,
            ngrams: HashMap::new(),
            size_bytes: 0,
        }
    }

    /// Records `count` additional observations of `next_token` following `context`.
    ///
    /// The stored count saturates at `u32::MAX` instead of overflowing, and
    /// `size_bytes` is refreshed afterwards.
    pub fn add(&mut self, context: String, next_token: String, count: u32) {
        let slot = self
            .ngrams
            .entry(context)
            .or_default()
            .entry(next_token)
            .or_insert(0);
        // Saturate: a plain `+=` would panic in debug builds and silently wrap
        // to a small (wrong) count in release builds.
        *slot = slot.saturating_add(count);
        self.update_size();
    }

    /// Recomputes `size_bytes` from scratch: every context key's byte length
    /// plus, per next-token entry, the token's byte length and 4 bytes for the
    /// `u32` count.
    fn update_size(&mut self) {
        self.size_bytes = self
            .ngrams
            .iter()
            .map(|(k, v)| k.len() + v.keys().map(|k2| k2.len() + 4).sum::<usize>())
            .sum();
    }

    /// Serializes the segment into its little-endian binary form.
    ///
    /// Layout (every integer is a little-endian `u32`):
    ///
    /// ```text
    /// prefix_len, prefix bytes,
    /// ngram_count, then per context:
    ///     ctx_len, ctx bytes,
    ///     token_count, then per token:
    ///         tok_len, tok bytes, count
    /// ```
    ///
    /// Entries are written in `HashMap` iteration order, so the exact byte
    /// sequence is not stable between calls. Lengths and counts are stored as
    /// `u32`; values above `u32::MAX` would be truncated by the cast.
    pub fn to_bytes(&self) -> Vec<u8> {
        // Writes a length/count field in the on-disk `u32` representation.
        fn put_len(out: &mut Vec<u8>, len: usize) {
            out.extend_from_slice(&(len as u32).to_le_bytes());
        }

        let mut out = Vec::new();

        put_len(&mut out, self.prefix.len());
        out.extend_from_slice(self.prefix.as_bytes());

        put_len(&mut out, self.ngrams.len());
        for (context, next_tokens) in &self.ngrams {
            put_len(&mut out, context.len());
            out.extend_from_slice(context.as_bytes());

            put_len(&mut out, next_tokens.len());
            for (token, count) in next_tokens {
                put_len(&mut out, token.len());
                out.extend_from_slice(token.as_bytes());
                out.extend_from_slice(&count.to_le_bytes());
            }
        }
        out
    }

    /// Parses a segment from the binary form produced by [`NgramSegment::to_bytes`].
    ///
    /// Strings that are not valid UTF-8 are decoded lossily (invalid sequences
    /// become U+FFFD). Bytes after the last declared entry are ignored, and
    /// `size_bytes` is recomputed from the decoded contents.
    ///
    /// # Errors
    ///
    /// Returns an `std::io::Error` when the input ends before a declared field:
    /// `"Truncated segment data"` for a missing integer field, or
    /// `"Truncated prefix"` / `"Truncated context"` / `"Truncated token"` for a
    /// string shorter than its declared length.
    pub fn from_bytes(bytes: &[u8]) -> std::io::Result<Self> {
        // Smallest possible encoding of one map entry at either level: two
        // `u32` fields (length + count) around an empty string.
        const MIN_ENTRY_BYTES: usize = 8;

        // Reads one little-endian u32 at `*pos` and advances past it.
        fn read_u32(data: &[u8], pos: &mut usize) -> std::io::Result<u32> {
            let start = *pos;
            let field = start
                .checked_add(4)
                .and_then(|end| data.get(start..end))
                .ok_or_else(|| std::io::Error::other("Truncated segment data"))?;
            let mut raw = [0u8; 4];
            // `field` is exactly 4 bytes long by construction of the range above.
            raw.copy_from_slice(field);
            *pos = start + 4;
            Ok(u32::from_le_bytes(raw))
        }

        // Reads a u32 length followed by that many bytes of (lossy) UTF-8.
        fn read_string(
            data: &[u8],
            pos: &mut usize,
            truncated_msg: &'static str,
        ) -> std::io::Result<String> {
            let len = read_u32(data, pos)? as usize;
            let start = *pos;
            let raw = start
                .checked_add(len)
                .and_then(|end| data.get(start..end))
                .ok_or_else(|| std::io::Error::other(truncated_msg))?;
            *pos = start + len;
            Ok(String::from_utf8_lossy(raw).into_owned())
        }

        // Caps a declared element count by what the remaining input could
        // actually hold, so a corrupt header cannot force a huge allocation.
        fn bounded_capacity(declared: usize, data: &[u8], pos: usize) -> usize {
            declared.min(data.len().saturating_sub(pos) / MIN_ENTRY_BYTES)
        }

        let mut pos = 0usize;
        let prefix = read_string(bytes, &mut pos, "Truncated prefix")?;

        let ngram_count = read_u32(bytes, &mut pos)? as usize;
        let mut ngrams = HashMap::with_capacity(bounded_capacity(ngram_count, bytes, pos));
        for _ in 0..ngram_count {
            let context = read_string(bytes, &mut pos, "Truncated context")?;

            let token_count = read_u32(bytes, &mut pos)? as usize;
            let mut next_tokens =
                HashMap::with_capacity(bounded_capacity(token_count, bytes, pos));
            for _ in 0..token_count {
                let token = read_string(bytes, &mut pos, "Truncated token")?;
                let count = read_u32(bytes, &mut pos)?;
                next_tokens.insert(token, count);
            }
            ngrams.insert(context, next_tokens);
        }

        let mut segment = Self {
            prefix,
            ngrams,
            size_bytes: 0,
        };
        segment.update_size();
        Ok(segment)
    }
}
/// Serializable summary information for a paged model.
///
/// NOTE(review): the code that fills these fields is not visible in this part
/// of the file; the per-field descriptions below are inferred from names and
/// types only and should be confirmed against the builder/loader.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PagedModelMetadata {
/// Presumably the n-gram order of the model — TODO confirm.
pub n: usize,
/// Presumably the number of commands seen during training — TODO confirm.
pub total_commands: usize,
/// Presumably the number of segments in the model — TODO confirm.
pub segment_count: usize,
/// Map from a command string to a `u32` frequency; presumably how often each
/// command occurred — TODO confirm.
pub command_freq: HashMap<String, u32>,
/// Presumably the `prefix` of every stored segment — TODO confirm.
pub segment_prefixes: Vec<String>,
}
/// Markov model state combining metadata, an optional paged bundle, and a map
/// of n-gram segments.
///
/// NOTE(review): the methods operating on this struct are not visible in this
/// part of the file; the field notes below are inferred from names and types
/// and should be confirmed against the implementation.
pub struct PagedMarkovModel {
// Presumably the n-gram order — TODO confirm.
n: usize,
// Presumably a memory budget in bytes (cf. `MIN_MEMORY_LIMIT`) — TODO confirm.
memory_limit: usize,
// Summary information describing the model.
metadata: PagedModelMetadata,
// Paged bundle handle, when one is attached; `None` otherwise.
bundle: Option<PagedBundle>,
// Segments keyed by a string; presumably the segment's `prefix` — TODO confirm.
segments: HashMap<String, NgramSegment>,
// Optional trie; presumably used for command prefix lookup — TODO confirm.
trie: Option<Trie>,
// Path associated with the bundle, when known.
bundle_path: Option<std::path::PathBuf>,
}
// The remaining items of this module are spliced in textually from the files
// below via `include!`, so they are compiled as if written here and share this
// file's imports and private items.
include!("paged_model_stats.rs");
include!("paged_model_ngram_segment.rs");