1use aprender::bundle::{BundleBuilder, BundleConfig, PagedBundle, PagingConfig, PagingStats};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11use crate::trie::Trie;
12
/// Smallest memory budget, in bytes, that a model will accept (1 MiB).
///
/// Requested limits below this value are raised to it.
const MIN_MEMORY_LIMIT: usize = 1 << 20;
15
/// One shard of the n-gram table, holding every transition recorded under a
/// single prefix.
///
/// The model's training code in this file uses the first whitespace-separated
/// token of a command as the prefix, so all n-grams of commands starting with
/// the same word live in the same segment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NgramSegment {
    /// Prefix identifying this segment.
    pub prefix: String,
    /// Transition counts: context string -> (next token -> observation count).
    pub ngrams: HashMap<String, HashMap<String, u32>>,
    /// Approximate payload size in bytes: the byte lengths of all context and
    /// token strings plus 4 bytes for each stored count.
    pub size_bytes: usize,
}
29
30impl NgramSegment {
31 #[must_use]
33 pub fn new(prefix: String) -> Self {
34 Self {
35 prefix,
36 ngrams: HashMap::new(),
37 size_bytes: 0,
38 }
39 }
40
41 pub fn add(&mut self, context: String, next_token: String, count: u32) {
43 let entry = self.ngrams.entry(context).or_default();
44 *entry.entry(next_token).or_insert(0) += count;
45 self.update_size();
46 }
47
48 fn update_size(&mut self) {
50 self.size_bytes = self
51 .ngrams
52 .iter()
53 .map(|(k, v)| k.len() + v.keys().map(|k2| k2.len() + 4).sum::<usize>())
54 .sum();
55 }
56
57 pub fn to_bytes(&self) -> Vec<u8> {
59 let mut bytes = Vec::new();
61
62 let prefix_bytes = self.prefix.as_bytes();
64 bytes.extend(&(prefix_bytes.len() as u32).to_le_bytes());
65 bytes.extend(prefix_bytes);
66
67 bytes.extend(&(self.ngrams.len() as u32).to_le_bytes());
69
70 for (context, next_tokens) in &self.ngrams {
71 let ctx_bytes = context.as_bytes();
73 bytes.extend(&(ctx_bytes.len() as u32).to_le_bytes());
74 bytes.extend(ctx_bytes);
75
76 bytes.extend(&(next_tokens.len() as u32).to_le_bytes());
78
79 for (token, count) in next_tokens {
80 let tok_bytes = token.as_bytes();
82 bytes.extend(&(tok_bytes.len() as u32).to_le_bytes());
83 bytes.extend(tok_bytes);
84 bytes.extend(&count.to_le_bytes());
86 }
87 }
88
89 bytes
90 }
91
92 pub fn from_bytes(bytes: &[u8]) -> std::io::Result<Self> {
94 let mut pos = 0;
95
96 let read_u32 = |data: &[u8], offset: usize| -> std::io::Result<u32> {
98 let slice = data
99 .get(offset..offset + 4)
100 .ok_or_else(|| std::io::Error::other("Truncated segment data"))?;
101 let arr: [u8; 4] = slice
102 .try_into()
103 .map_err(|_| std::io::Error::other("Invalid byte slice"))?;
104 Ok(u32::from_le_bytes(arr))
105 };
106
107 let prefix_len = read_u32(bytes, pos)? as usize;
109 pos += 4;
110
111 if bytes.len() < pos + prefix_len {
112 return Err(std::io::Error::other("Truncated prefix"));
113 }
114 let prefix = String::from_utf8_lossy(&bytes[pos..pos + prefix_len]).to_string();
115 pos += prefix_len;
116
117 let ngram_count = read_u32(bytes, pos)? as usize;
119 pos += 4;
120
121 let mut ngrams = HashMap::with_capacity(ngram_count);
122
123 for _ in 0..ngram_count {
124 let ctx_len = read_u32(bytes, pos)? as usize;
126 pos += 4;
127
128 if bytes.len() < pos + ctx_len {
129 return Err(std::io::Error::other("Truncated context"));
130 }
131 let context = String::from_utf8_lossy(&bytes[pos..pos + ctx_len]).to_string();
132 pos += ctx_len;
133
134 let token_count = read_u32(bytes, pos)? as usize;
136 pos += 4;
137
138 let mut next_tokens = HashMap::with_capacity(token_count);
139
140 for _ in 0..token_count {
141 let tok_len = read_u32(bytes, pos)? as usize;
143 pos += 4;
144
145 if bytes.len() < pos + tok_len {
146 return Err(std::io::Error::other("Truncated token"));
147 }
148 let token = String::from_utf8_lossy(&bytes[pos..pos + tok_len]).to_string();
149 pos += tok_len;
150
151 let count = read_u32(bytes, pos)?;
153 pos += 4;
154
155 next_tokens.insert(token, count);
156 }
157
158 ngrams.insert(context, next_tokens);
159 }
160
161 let mut segment = Self {
162 prefix,
163 ngrams,
164 size_bytes: 0,
165 };
166 segment.update_size();
167 Ok(segment)
168 }
169}
170
/// Summary information about a paged model.
///
/// `PagedMarkovModel::save` serializes this as JSON and stores it in the
/// bundle under the name `"metadata"`; `PagedMarkovModel::load` reads it back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PagedModelMetadata {
    /// N-gram size the model was created with.
    pub n: usize,
    /// Number of commands the model was trained on.
    pub total_commands: usize,
    /// Number of n-gram segments.
    pub segment_count: usize,
    /// How many times each full command line was seen.
    pub command_freq: HashMap<String, u32>,
    /// Prefixes of the segments (one per segment).
    pub segment_prefixes: Vec<String>,
}
185
/// N-gram command model whose n-gram table is split into per-prefix segments
/// that can be read from a bundle file on demand.
pub struct PagedMarkovModel {
    /// N-gram size; `new` clamps it to the range 2..=5.
    n: usize,
    /// Memory budget in bytes (never below `MIN_MEMORY_LIMIT`).
    memory_limit: usize,
    /// Command frequencies and segment bookkeeping.
    metadata: PagedModelMetadata,
    /// Backing bundle; `None` for a model built with `new`, `Some` after `load`.
    bundle: Option<PagedBundle>,
    /// Segments currently held in memory, keyed by prefix. Filled by `train`
    /// and, for loaded models, by `load_segment` as prefixes are requested.
    segments: HashMap<String, NgramSegment>,
    /// Trie over full command lines, used for prefix completion.
    trie: Option<Trie>,
    /// Path the model was loaded from, if any.
    bundle_path: Option<std::path::PathBuf>,
}
207
208impl PagedMarkovModel {
209 #[must_use]
215 pub fn new(n: usize, memory_limit_mb: usize) -> Self {
216 let memory_limit = (memory_limit_mb * 1024 * 1024).max(MIN_MEMORY_LIMIT);
217 Self {
218 n: n.clamp(2, 5),
219 memory_limit,
220 metadata: PagedModelMetadata {
221 n,
222 total_commands: 0,
223 segment_count: 0,
224 command_freq: HashMap::new(),
225 segment_prefixes: Vec::new(),
226 },
227 bundle: None,
228 segments: HashMap::new(),
229 trie: Some(Trie::new()),
230 bundle_path: None,
231 }
232 }
233
234 #[must_use]
236 #[allow(dead_code)]
237 pub fn memory_limit(&self) -> usize {
238 self.memory_limit
239 }
240
241 pub fn train(&mut self, commands: &[String]) {
243 self.metadata.total_commands = commands.len();
244
245 for cmd in commands {
246 *self.metadata.command_freq.entry(cmd.clone()).or_insert(0) += 1;
248
249 if let Some(ref mut trie) = self.trie {
251 trie.insert(cmd);
252 }
253
254 let tokens: Vec<&str> = cmd.split_whitespace().collect();
256 if tokens.is_empty() {
257 continue;
258 }
259
260 let prefix = tokens[0].to_string();
262
263 let segment = self
265 .segments
266 .entry(prefix.clone())
267 .or_insert_with(|| NgramSegment::new(prefix));
268
269 segment.add(String::new(), tokens[0].to_string(), 1);
271
272 for i in 0..tokens.len() {
274 let context_start = i.saturating_sub(self.n - 1);
275 let context: String = tokens[context_start..=i].join(" ");
276
277 if i + 1 < tokens.len() {
278 segment.add(context, tokens[i + 1].to_string(), 1);
279 }
280 }
281 }
282
283 self.metadata.segment_count = self.segments.len();
285 self.metadata.segment_prefixes = self.segments.keys().cloned().collect();
286 }
287
288 pub fn save(&self, path: &Path) -> std::io::Result<()> {
290 let path_str = path.to_string_lossy().to_string();
291
292 let metadata_bytes = serde_json::to_vec(&self.metadata)
294 .map_err(|e| std::io::Error::other(format!("Failed to serialize metadata: {e}")))?;
295
296 let mut builder = BundleBuilder::new(&path_str)
297 .with_config(BundleConfig::new().with_compression(false))
298 .add_model("metadata", metadata_bytes);
299
300 for (prefix, segment) in &self.segments {
302 let segment_bytes = segment.to_bytes();
303 builder = builder.add_model(format!("segment_{prefix}"), segment_bytes);
304 }
305
306 builder
308 .build()
309 .map_err(|e| std::io::Error::other(format!("Failed to build bundle: {e}")))?;
310
311 Ok(())
312 }
313
314 pub fn load(path: &Path, memory_limit_mb: usize) -> std::io::Result<Self> {
316 let memory_limit = (memory_limit_mb * 1024 * 1024).max(MIN_MEMORY_LIMIT);
317
318 let paging_config = PagingConfig::new()
320 .with_max_memory(memory_limit)
321 .with_prefetch(true);
322
323 let mut bundle = PagedBundle::open(path, paging_config)
324 .map_err(|e| std::io::Error::other(format!("Failed to open bundle: {e}")))?;
325
326 let metadata_bytes = bundle
328 .get_model("metadata")
329 .map_err(|e| std::io::Error::other(format!("Failed to read metadata: {e}")))?;
330
331 let metadata: PagedModelMetadata = serde_json::from_slice(metadata_bytes)
332 .map_err(|e| std::io::Error::other(format!("Failed to parse metadata: {e}")))?;
333
334 let mut trie = Trie::new();
336 for cmd in metadata.command_freq.keys() {
337 trie.insert(cmd);
338 }
339
340 Ok(Self {
341 n: metadata.n,
342 memory_limit,
343 metadata,
344 bundle: Some(bundle),
345 segments: HashMap::new(), trie: Some(trie),
347 bundle_path: Some(path.to_path_buf()),
348 })
349 }
350
351 fn load_segment(&mut self, prefix: &str) -> std::io::Result<Option<NgramSegment>> {
353 if let Some(segment) = self.segments.get(prefix) {
354 return Ok(Some(segment.clone()));
355 }
356
357 if let Some(ref mut bundle) = self.bundle {
358 let model_name = format!("segment_{prefix}");
359 if bundle.model_names().iter().any(|n| *n == model_name) {
361 let bytes = bundle.get_model(&model_name).map_err(|e| {
362 std::io::Error::other(format!("Failed to read segment '{prefix}': {e}"))
363 })?;
364 let segment = NgramSegment::from_bytes(bytes)?;
365 self.segments.insert(prefix.to_string(), segment.clone());
366 return Ok(Some(segment));
367 }
368 }
369
370 Ok(None)
371 }
372
373 pub fn suggest(&mut self, prefix: &str, count: usize) -> Vec<(String, f32)> {
375 let ends_with_space = prefix.is_empty() || prefix.ends_with(' ');
377 let prefix = prefix.trim();
378 let tokens: Vec<&str> = prefix.split_whitespace().collect();
379
380 let mut suggestions = Vec::new();
381
382 if let Some(ref trie) = self.trie {
384 for cmd in trie.find_prefix(prefix, count * 4) {
385 let freq = self.metadata.command_freq.get(&cmd).copied().unwrap_or(1);
386 let score = freq as f32 / self.metadata.total_commands.max(1) as f32;
387 suggestions.push((cmd, score));
388 }
389 }
390
391 if !tokens.is_empty() && ends_with_space {
393 let segment_prefix = tokens[0];
394
395 if let Ok(Some(segment)) = self.load_segment(segment_prefix) {
397 let context_start = tokens.len().saturating_sub(self.n - 1);
398 let context = tokens[context_start..].join(" ");
399
400 if let Some(next_tokens) = segment.ngrams.get(&context) {
401 let total: u32 = next_tokens.values().sum();
402
403 for (token, ngram_count) in next_tokens {
404 let completion = format!("{} {}", prefix.trim(), token);
405 let score = *ngram_count as f32 / total as f32;
406
407 if !suggestions.iter().any(|(s, _)| s == &completion) {
408 suggestions.push((completion, score * 0.8));
409 }
410 }
411 }
412 }
413 }
414
415 suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
417 suggestions.truncate(count);
418
419 suggestions
420 }
421
422 #[must_use]
424 pub fn stats(&self) -> PagedModelStats {
425 let loaded_segments = self.segments.len();
426 let total_segments = self.metadata.segment_count;
427 let loaded_bytes: usize = self.segments.values().map(|s| s.size_bytes).sum();
428
429 PagedModelStats {
430 n: self.n,
431 total_commands: self.metadata.total_commands,
432 vocab_size: self.metadata.command_freq.len(),
433 total_segments,
434 loaded_segments,
435 memory_limit: self.memory_limit,
436 loaded_bytes,
437 bundle_path: self.bundle_path.clone(),
438 }
439 }
440
441 pub fn paging_stats(&self) -> Option<PagingStats> {
443 self.bundle.as_ref().map(|b| b.stats().clone())
444 }
445
446 #[allow(dead_code)]
448 pub fn prefetch_hint(&mut self, prefix: &str) {
449 if let Some(ref mut bundle) = self.bundle {
450 let _ = bundle.prefetch_hint(&format!("segment_{prefix}"));
451 }
452 }
453
454 #[must_use]
456 #[allow(dead_code)]
457 pub fn total_commands(&self) -> usize {
458 self.metadata.total_commands
459 }
460
461 #[must_use]
463 #[allow(dead_code)]
464 pub fn ngram_size(&self) -> usize {
465 self.n
466 }
467
468 #[must_use]
470 #[allow(dead_code)]
471 pub fn vocab_size(&self) -> usize {
472 self.metadata.command_freq.len()
473 }
474
475 #[must_use]
477 pub fn top_commands(&self, count: usize) -> Vec<(String, u32)> {
478 let mut cmds: Vec<_> = self
479 .metadata
480 .command_freq
481 .iter()
482 .map(|(k, v)| (k.clone(), *v))
483 .collect();
484 cmds.sort_by(|a, b| b.1.cmp(&a.1));
485 cmds.truncate(count);
486 cmds
487 }
488}
489
/// Point-in-time statistics reported by `PagedMarkovModel::stats`.
#[derive(Debug, Clone)]
pub struct PagedModelStats {
    /// N-gram size of the model.
    pub n: usize,
    /// Number of commands the model was trained on.
    pub total_commands: usize,
    /// Number of distinct command lines.
    pub vocab_size: usize,
    /// Number of segments recorded in the model metadata.
    pub total_segments: usize,
    /// Number of segments currently held in memory.
    pub loaded_segments: usize,
    /// Memory budget in bytes.
    pub memory_limit: usize,
    /// Sum of `size_bytes` over the segments currently in memory.
    pub loaded_bytes: usize,
    /// Bundle file the model was loaded from, if any.
    pub bundle_path: Option<std::path::PathBuf>,
}
510
511impl std::fmt::Display for PagedModelStats {
512 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
513 writeln!(f, "Paged Model Statistics:")?;
514 writeln!(f, " N-gram size: {}", self.n)?;
515 writeln!(f, " Total commands: {}", self.total_commands)?;
516 writeln!(f, " Vocabulary size: {}", self.vocab_size)?;
517 writeln!(
518 f,
519 " Segments: {}/{} loaded",
520 self.loaded_segments, self.total_segments
521 )?;
522 writeln!(
523 f,
524 " Memory limit: {:.1} MB",
525 self.memory_limit as f64 / 1024.0 / 1024.0
526 )?;
527 writeln!(
528 f,
529 " Loaded bytes: {:.1} KB",
530 self.loaded_bytes as f64 / 1024.0
531 )?;
532 if let Some(ref path) = self.bundle_path {
533 writeln!(f, " Bundle path: {}", path.display())?;
534 }
535 Ok(())
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Turns string literals into the owned command list `train` expects.
    fn owned(lines: &[&str]) -> Vec<String> {
        lines.iter().map(|line| (*line).to_string()).collect()
    }

    /// Builds a 3-gram model with a 10 MB limit, trained on `lines`.
    fn trained_model(lines: &[&str]) -> PagedMarkovModel {
        let mut model = PagedMarkovModel::new(3, 10);
        model.train(&owned(lines));
        model
    }

    #[test]
    fn test_ngram_segment_new() {
        let segment = NgramSegment::new("git".to_string());

        assert_eq!(segment.prefix, "git");
        assert!(segment.ngrams.is_empty());
        assert_eq!(segment.size_bytes, 0);
    }

    #[test]
    fn test_ngram_segment_add() {
        let mut segment = NgramSegment::new("git".to_string());
        for next in ["commit", "commit", "push"] {
            segment.add("git".to_string(), next.to_string(), 1);
        }

        assert_eq!(segment.ngrams.len(), 1);
        let after_git = segment.ngrams.get("git").unwrap();
        assert_eq!(after_git.get("commit"), Some(&2));
        assert_eq!(after_git.get("push"), Some(&1));
    }

    #[test]
    fn test_ngram_segment_serialization() {
        let mut segment = NgramSegment::new("cargo".to_string());
        segment.add("cargo".to_string(), "build".to_string(), 5);
        segment.add("cargo".to_string(), "test".to_string(), 3);
        segment.add("cargo build".to_string(), "--release".to_string(), 2);

        let restored = NgramSegment::from_bytes(&segment.to_bytes()).unwrap();

        assert_eq!(restored.prefix, "cargo");
        assert_eq!(restored.ngrams.len(), 2);

        let after_cargo = restored.ngrams.get("cargo").unwrap();
        assert_eq!(after_cargo.get("build"), Some(&5));
        assert_eq!(after_cargo.get("test"), Some(&3));

        let after_build = restored.ngrams.get("cargo build").unwrap();
        assert_eq!(after_build.get("--release"), Some(&2));
    }

    #[test]
    fn test_paged_model_new() {
        let model = PagedMarkovModel::new(3, 10);

        assert_eq!(model.ngram_size(), 3);
        assert!(model.memory_limit() >= MIN_MEMORY_LIMIT);
    }

    #[test]
    fn test_paged_model_train() {
        let model = trained_model(&[
            "git status",
            "git commit -m test",
            "git push",
            "cargo build",
            "cargo test",
        ]);

        assert_eq!(model.total_commands(), 5);
        assert_eq!(model.vocab_size(), 5);

        // One segment per distinct first token.
        assert!(model.segments.contains_key("git"));
        assert!(model.segments.contains_key("cargo"));
    }

    #[test]
    fn test_paged_model_suggest() {
        let mut model = trained_model(&[
            "git status",
            "git status",
            "git commit -m fix",
            "git push",
        ]);

        let suggestions = model.suggest("git ", 3);
        assert!(!suggestions.is_empty());

        let has_status = suggestions.iter().any(|(s, _)| s.contains("status"));
        assert!(has_status);
    }

    #[test]
    fn test_paged_model_save_load() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.apbundle");

        trained_model(&["git status", "git commit", "cargo build"])
            .save(&path)
            .unwrap();

        let mut loaded = PagedMarkovModel::load(&path, 10).unwrap();

        assert_eq!(loaded.total_commands(), 3);
        assert_eq!(loaded.vocab_size(), 3);
        assert_eq!(loaded.ngram_size(), 3);

        let suggestions = loaded.suggest("git ", 3);
        assert!(!suggestions.is_empty());
    }

    #[test]
    fn test_paged_model_stats() {
        let model = trained_model(&["git status", "cargo build", "docker run"]);

        let stats = model.stats();
        assert_eq!(stats.n, 3);
        assert_eq!(stats.total_commands, 3);
        assert_eq!(stats.vocab_size, 3);
        assert_eq!(stats.total_segments, 3);
    }

    #[test]
    fn test_paged_model_on_demand_loading() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("ondemand.apbundle");

        trained_model(&[
            "git status",
            "git commit",
            "cargo build",
            "cargo test",
            "docker run",
            "kubectl get pods",
        ])
        .save(&path)
        .unwrap();

        let mut loaded = PagedMarkovModel::load(&path, 1).unwrap();

        // Nothing is resident until a suggestion asks for a segment.
        assert_eq!(loaded.stats().loaded_segments, 0);

        let _ = loaded.suggest("git ", 3);
        assert!(loaded.segments.contains_key("git"));

        let _ = loaded.suggest("cargo ", 3);
        assert!(loaded.segments.contains_key("cargo"));
    }

    #[test]
    fn test_paged_model_prefetch_hint() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("prefetch.apbundle");

        trained_model(&["git status", "cargo build"])
            .save(&path)
            .unwrap();

        let mut loaded = PagedMarkovModel::load(&path, 10).unwrap();
        loaded.prefetch_hint("git");

        let suggestions = loaded.suggest("git ", 3);
        assert!(!suggestions.is_empty());
    }

    #[test]
    fn test_paged_model_top_commands() {
        let model = trained_model(&[
            "git status",
            "git status",
            "git status",
            "cargo build",
            "cargo build",
            "docker run",
        ]);

        let top = model.top_commands(2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].0, "git status");
        assert_eq!(top[0].1, 3);
        assert_eq!(top[1].0, "cargo build");
        assert_eq!(top[1].1, 2);
    }

    #[test]
    fn test_paged_model_empty_commands() {
        let mut model = trained_model(&[]);

        assert_eq!(model.total_commands(), 0);
        assert_eq!(model.vocab_size(), 0);

        let suggestions = model.suggest("git ", 3);
        assert!(suggestions.is_empty());
    }

    #[test]
    fn test_ngram_segment_empty_bytes() {
        assert!(NgramSegment::from_bytes(&[]).is_err());
    }

    #[test]
    fn test_paged_model_stats_display() {
        let model = trained_model(&["git status"]);

        let display = model.stats().to_string();

        for label in ["N-gram size:", "Total commands:", "Memory limit:"] {
            assert!(display.contains(label));
        }
    }
}