impl PagedMarkovModel {
    /// Smallest n-gram order accepted by `new` / honoured by `load`.
    const MIN_NGRAM_ORDER: usize = 2;
    /// Largest n-gram order accepted by `new` / honoured by `load`.
    const MAX_NGRAM_ORDER: usize = 5;

    /// Converts a megabyte budget into bytes, never going below `MIN_MEMORY_LIMIT`.
    ///
    /// Saturating arithmetic keeps huge inputs from panicking (debug builds) or
    /// wrapping around to a tiny limit (release builds).
    fn limit_bytes_from_mb(memory_limit_mb: usize) -> usize {
        memory_limit_mb
            .saturating_mul(1024 * 1024)
            .max(MIN_MEMORY_LIMIT)
    }

    /// Creates an empty, untrained model that lives entirely in memory.
    ///
    /// `n` is clamped to `2..=5`. The clamped value is also written to the
    /// metadata, so a model that is saved and reloaded keeps the same order.
    /// `memory_limit_mb` is converted to bytes and raised to at least
    /// `MIN_MEMORY_LIMIT`.
    #[must_use]
    pub fn new(n: usize, memory_limit_mb: usize) -> Self {
        let n = n.clamp(Self::MIN_NGRAM_ORDER, Self::MAX_NGRAM_ORDER);
        Self {
            n,
            memory_limit: Self::limit_bytes_from_mb(memory_limit_mb),
            metadata: PagedModelMetadata {
                n,
                total_commands: 0,
                segment_count: 0,
                command_freq: HashMap::new(),
                segment_prefixes: Vec::new(),
            },
            bundle: None,
            segments: HashMap::new(),
            trie: Some(Trie::new()),
            bundle_path: None,
        }
    }

    /// Memory budget in bytes (already clamped to at least `MIN_MEMORY_LIMIT`).
    #[must_use]
    pub fn memory_limit(&self) -> usize {
        self.memory_limit
    }

    /// Trains the model on a batch of full command lines.
    ///
    /// May be called repeatedly. Command frequencies, the prefix trie, the
    /// n-gram segments and `total_commands` (the denominator `suggest` uses for
    /// frequency scores) all accumulate across calls.
    ///
    /// Each command is split on whitespace, and its first token selects the
    /// segment. The segment records the first token under the empty context.
    /// Then, for every following token, it records the preceding (up to)
    /// `n - 1` tokens as context, which is the window `suggest` looks up.
    pub fn train(&mut self, commands: &[String]) {
        // Accumulate rather than overwrite: `command_freq` keeps growing across
        // calls, so overwriting here would let frequency scores exceed 1.
        self.metadata.total_commands += commands.len();
        // An n-gram is (n - 1) context tokens plus the predicted token.
        let context_len = self.n.saturating_sub(1);
        for cmd in commands {
            *self.metadata.command_freq.entry(cmd.clone()).or_insert(0) += 1;
            if let Some(ref mut trie) = self.trie {
                trie.insert(cmd);
            }
            let tokens: Vec<&str> = cmd.split_whitespace().collect();
            let first = match tokens.first() {
                Some(first) => *first,
                // Whitespace-only commands are counted but contribute no n-grams.
                None => continue,
            };
            let prefix = first.to_string();
            let segment = self
                .segments
                .entry(prefix.clone())
                .or_insert_with(|| NgramSegment::new(prefix));
            segment.add(String::new(), first.to_string(), 1);
            // `tokens` is non-empty here, so `len() - 1` cannot underflow.
            for i in 0..tokens.len() - 1 {
                let context_start = (i + 1).saturating_sub(context_len);
                let context = tokens[context_start..=i].join(" ");
                segment.add(context, tokens[i + 1].to_string(), 1);
            }
        }
        self.metadata.segment_count = self.segments.len();
        self.metadata.segment_prefixes = self.segments.keys().cloned().collect();
    }

    /// Writes the model to a bundle at `path`.
    ///
    /// The bundle holds a `"metadata"` entry (JSON-serialized metadata) plus one
    /// `"segment_<prefix>"` entry per in-memory segment. Compression is disabled.
    ///
    /// # Errors
    ///
    /// Returns an `std::io::Error` if the metadata cannot be serialized or the
    /// bundle cannot be built.
    pub fn save(&self, path: &Path) -> std::io::Result<()> {
        let path_str = path.to_string_lossy().to_string();
        let metadata_bytes = serde_json::to_vec(&self.metadata)
            .map_err(|e| std::io::Error::other(format!("Failed to serialize metadata: {e}")))?;
        let mut builder = BundleBuilder::new(&path_str)
            .with_config(BundleConfig::new().with_compression(false))
            .add_model("metadata", metadata_bytes);
        for (prefix, segment) in &self.segments {
            let segment_bytes = segment.to_bytes();
            builder = builder.add_model(format!("segment_{prefix}"), segment_bytes);
        }
        builder
            .build()
            .map_err(|e| std::io::Error::other(format!("Failed to build bundle: {e}")))?;
        Ok(())
    }

    /// Opens a bundle written by `save` and returns a model that pages its
    /// segments in lazily.
    ///
    /// Only the metadata is read eagerly, and the prefix trie is rebuilt from
    /// it. Segments are fetched on demand by `suggest`. The stored n-gram order
    /// is clamped to `2..=5`, so a corrupt or legacy value cannot cause an
    /// underflow later.
    ///
    /// # Errors
    ///
    /// Returns an `std::io::Error` if the bundle cannot be opened or its
    /// metadata cannot be read or parsed.
    pub fn load(path: &Path, memory_limit_mb: usize) -> std::io::Result<Self> {
        let memory_limit = Self::limit_bytes_from_mb(memory_limit_mb);
        let paging_config = PagingConfig::new()
            .with_max_memory(memory_limit)
            .with_prefetch(true);
        let mut bundle = PagedBundle::open(path, paging_config)
            .map_err(|e| std::io::Error::other(format!("Failed to open bundle: {e}")))?;
        let metadata_bytes = bundle
            .get_model("metadata")
            .map_err(|e| std::io::Error::other(format!("Failed to read metadata: {e}")))?;
        let metadata: PagedModelMetadata = serde_json::from_slice(metadata_bytes)
            .map_err(|e| std::io::Error::other(format!("Failed to parse metadata: {e}")))?;
        let mut trie = Trie::new();
        for cmd in metadata.command_freq.keys() {
            trie.insert(cmd);
        }
        Ok(Self {
            n: metadata
                .n
                .clamp(Self::MIN_NGRAM_ORDER, Self::MAX_NGRAM_ORDER),
            memory_limit,
            metadata,
            bundle: Some(bundle),
            segments: HashMap::new(),
            trie: Some(trie),
            bundle_path: Some(path.to_path_buf()),
        })
    }

    /// Returns the segment for `prefix`, reading it from the bundle on first use.
    ///
    /// A segment read from the bundle is cached in `self.segments`. Returns
    /// `Ok(None)` when no bundle is attached or the bundle has no such segment.
    /// The returned value is a clone of the cached segment.
    fn load_segment(&mut self, prefix: &str) -> std::io::Result<Option<NgramSegment>> {
        if let Some(segment) = self.segments.get(prefix) {
            return Ok(Some(segment.clone()));
        }
        if let Some(ref mut bundle) = self.bundle {
            let model_name = format!("segment_{prefix}");
            if bundle.model_names().iter().any(|n| *n == model_name) {
                let bytes = bundle.get_model(&model_name).map_err(|e| {
                    std::io::Error::other(format!("Failed to read segment '{prefix}': {e}"))
                })?;
                let segment = NgramSegment::from_bytes(bytes)?;
                self.segments.insert(prefix.to_string(), segment.clone());
                return Ok(Some(segment));
            }
        }
        Ok(None)
    }

    /// Suggests up to `count` completions for `prefix`, highest score first.
    ///
    /// Two sources contribute:
    /// * trained commands that start with the trimmed prefix, scored by
    ///   `frequency / total_commands`;
    /// * when `prefix` ends with a space, next-token completions from the
    ///   n-gram segment of its first token. These are looked up with the last
    ///   `n - 1` tokens as context and scored at 0.8 × the conditional
    ///   probability. Completions already offered by the first source are
    ///   skipped.
    ///
    /// Takes `&mut self` because segments may be paged in from the bundle.
    /// Segment-loading errors are ignored: the n-gram source is best-effort.
    pub fn suggest(&mut self, prefix: &str, count: usize) -> Vec<(String, f32)> {
        let ends_with_space = prefix.is_empty() || prefix.ends_with(' ');
        let prefix = prefix.trim();
        let tokens: Vec<&str> = prefix.split_whitespace().collect();
        let mut suggestions = Vec::new();

        if let Some(ref trie) = self.trie {
            let total = self.metadata.total_commands.max(1) as f32;
            // Fetch 4x candidates before the final truncate; saturate to avoid overflow.
            for cmd in trie.find_prefix(prefix, count.saturating_mul(4)) {
                let freq = self.metadata.command_freq.get(&cmd).copied().unwrap_or(1);
                suggestions.push((cmd, freq as f32 / total));
            }
        }

        if ends_with_space {
            if let Some(&segment_prefix) = tokens.first() {
                // Same (n - 1)-token window that `train` records as context.
                let context_start = tokens.len().saturating_sub(self.n.saturating_sub(1));
                let context = tokens[context_start..].join(" ");
                if let Ok(Some(segment)) = self.load_segment(segment_prefix) {
                    if let Some(next_tokens) = segment.ngrams.get(&context) {
                        let total: u32 = next_tokens.values().sum();
                        // An empty follower table would otherwise yield NaN scores.
                        if total > 0 {
                            for (token, ngram_count) in next_tokens {
                                let completion = format!("{prefix} {token}");
                                if !suggestions.iter().any(|(s, _)| s == &completion) {
                                    let score = *ngram_count as f32 / total as f32;
                                    suggestions.push((completion, score * 0.8));
                                }
                            }
                        }
                    }
                }
            }
        }

        suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        suggestions.truncate(count);
        suggestions
    }

    /// Returns a snapshot of model size and paging state.
    #[must_use]
    pub fn stats(&self) -> PagedModelStats {
        let loaded_segments = self.segments.len();
        let total_segments = self.metadata.segment_count;
        let loaded_bytes: usize = self.segments.values().map(|s| s.size_bytes).sum();
        PagedModelStats {
            n: self.n,
            total_commands: self.metadata.total_commands,
            vocab_size: self.metadata.command_freq.len(),
            total_segments,
            loaded_segments,
            memory_limit: self.memory_limit,
            loaded_bytes,
            bundle_path: self.bundle_path.clone(),
        }
    }

    /// Paging statistics of the attached bundle, or `None` for a model that
    /// lives only in memory.
    #[must_use]
    pub fn paging_stats(&self) -> Option<PagingStats> {
        self.bundle.as_ref().map(|b| b.stats().clone())
    }

    /// Forwards a prefetch hint for the `segment_<prefix>` entry to the bundle.
    ///
    /// This is best-effort: it does nothing without a bundle, and any error
    /// from the bundle is deliberately ignored.
    pub fn prefetch_hint(&mut self, prefix: &str) {
        if let Some(ref mut bundle) = self.bundle {
            let _ = bundle.prefetch_hint(&format!("segment_{prefix}"));
        }
    }

    /// Number of training commands counted so far.
    #[must_use]
    pub fn total_commands(&self) -> usize {
        self.metadata.total_commands
    }

    /// N-gram order in effect (always within `2..=5`).
    #[must_use]
    pub fn ngram_size(&self) -> usize {
        self.n
    }

    /// Number of distinct command lines seen during training.
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        self.metadata.command_freq.len()
    }

    /// Returns the `count` most frequent commands with their frequencies.
    ///
    /// Results are sorted by descending frequency. Ties are broken by command
    /// text so the output does not depend on hash-map iteration order.
    #[must_use]
    pub fn top_commands(&self, count: usize) -> Vec<(String, u32)> {
        let mut cmds: Vec<_> = self
            .metadata
            .command_freq
            .iter()
            .map(|(k, v)| (k.clone(), *v))
            .collect();
        cmds.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        cmds.truncate(count);
        cmds
    }
}
/// Snapshot of a `PagedMarkovModel`'s size and paging state, as produced by
/// `PagedMarkovModel::stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PagedModelStats {
    /// N-gram order used by the model.
    pub n: usize,
    /// Training-command count reported by the model (`total_commands`).
    pub total_commands: usize,
    /// Number of distinct command lines seen during training.
    pub vocab_size: usize,
    /// Number of n-gram segments recorded in the model metadata.
    pub total_segments: usize,
    /// Number of segments currently held in memory.
    pub loaded_segments: usize,
    /// Memory budget in bytes.
    pub memory_limit: usize,
    /// Combined `size_bytes` of the segments currently held in memory.
    pub loaded_bytes: usize,
    /// Path of the bundle the model was loaded from; `None` for a model built in memory.
    pub bundle_path: Option<std::path::PathBuf>,
}
impl std::fmt::Display for PagedModelStats {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "Paged Model Statistics:")?;
writeln!(f, " N-gram size: {}", self.n)?;
writeln!(f, " Total commands: {}", self.total_commands)?;
writeln!(f, " Vocabulary size: {}", self.vocab_size)?;
writeln!(
f,
" Segments: {}/{} loaded",
self.loaded_segments, self.total_segments
)?;
writeln!(
f,
" Memory limit: {:.1} MB",
self.memory_limit as f64 / 1024.0 / 1024.0
)?;
writeln!(
f,
" Loaded bytes: {:.1} KB",
self.loaded_bytes as f64 / 1024.0
)?;
if let Some(ref path) = self.bundle_path {
writeln!(f, " Bundle path: {}", path.display())?;
}
Ok(())
}
}