// aprender_shell/paged_model.rs
use aprender::bundle::{BundleBuilder, BundleConfig, PagedBundle, PagingConfig, PagingStats};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11use crate::trie::Trie;
12
/// One mebibyte (`1024 * 1024` bytes).
///
/// NOTE(review): presumably the smallest memory budget a paged model accepts;
/// the code that applies this limit is not in this part of the file — confirm.
const MIN_MEMORY_LIMIT: usize = 1024 * 1024;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct NgramSegment {
22 pub prefix: String,
24 pub ngrams: HashMap<String, HashMap<String, u32>>,
26 pub size_bytes: usize,
28}
29
30impl NgramSegment {
31 #[must_use]
33 pub fn new(prefix: String) -> Self {
34 Self {
35 prefix,
36 ngrams: HashMap::new(),
37 size_bytes: 0,
38 }
39 }
40
41 pub fn add(&mut self, context: String, next_token: String, count: u32) {
43 let entry = self.ngrams.entry(context).or_default();
44 *entry.entry(next_token).or_insert(0) += count;
45 self.update_size();
46 }
47
48 fn update_size(&mut self) {
50 self.size_bytes = self
51 .ngrams
52 .iter()
53 .map(|(k, v)| k.len() + v.keys().map(|k2| k2.len() + 4).sum::<usize>())
54 .sum();
55 }
56
57 pub fn to_bytes(&self) -> Vec<u8> {
59 let mut bytes = Vec::new();
61
62 let prefix_bytes = self.prefix.as_bytes();
64 bytes.extend(&(prefix_bytes.len() as u32).to_le_bytes());
65 bytes.extend(prefix_bytes);
66
67 bytes.extend(&(self.ngrams.len() as u32).to_le_bytes());
69
70 for (context, next_tokens) in &self.ngrams {
71 let ctx_bytes = context.as_bytes();
73 bytes.extend(&(ctx_bytes.len() as u32).to_le_bytes());
74 bytes.extend(ctx_bytes);
75
76 bytes.extend(&(next_tokens.len() as u32).to_le_bytes());
78
79 for (token, count) in next_tokens {
80 let tok_bytes = token.as_bytes();
82 bytes.extend(&(tok_bytes.len() as u32).to_le_bytes());
83 bytes.extend(tok_bytes);
84 bytes.extend(&count.to_le_bytes());
86 }
87 }
88
89 bytes
90 }
91
92 pub fn from_bytes(bytes: &[u8]) -> std::io::Result<Self> {
94 let mut pos = 0;
95
96 let read_u32 = |data: &[u8], offset: usize| -> std::io::Result<u32> {
98 let slice = data
99 .get(offset..offset + 4)
100 .ok_or_else(|| std::io::Error::other("Truncated segment data"))?;
101 let arr: [u8; 4] = slice
102 .try_into()
103 .map_err(|_| std::io::Error::other("Invalid byte slice"))?;
104 Ok(u32::from_le_bytes(arr))
105 };
106
107 let prefix_len = read_u32(bytes, pos)? as usize;
109 pos += 4;
110
111 if bytes.len() < pos + prefix_len {
112 return Err(std::io::Error::other("Truncated prefix"));
113 }
114 let prefix = String::from_utf8_lossy(&bytes[pos..pos + prefix_len]).to_string();
115 pos += prefix_len;
116
117 let ngram_count = read_u32(bytes, pos)? as usize;
119 pos += 4;
120
121 let mut ngrams = HashMap::with_capacity(ngram_count);
122
123 for _ in 0..ngram_count {
124 let ctx_len = read_u32(bytes, pos)? as usize;
126 pos += 4;
127
128 if bytes.len() < pos + ctx_len {
129 return Err(std::io::Error::other("Truncated context"));
130 }
131 let context = String::from_utf8_lossy(&bytes[pos..pos + ctx_len]).to_string();
132 pos += ctx_len;
133
134 let token_count = read_u32(bytes, pos)? as usize;
136 pos += 4;
137
138 let mut next_tokens = HashMap::with_capacity(token_count);
139
140 for _ in 0..token_count {
141 let tok_len = read_u32(bytes, pos)? as usize;
143 pos += 4;
144
145 if bytes.len() < pos + tok_len {
146 return Err(std::io::Error::other("Truncated token"));
147 }
148 let token = String::from_utf8_lossy(&bytes[pos..pos + tok_len]).to_string();
149 pos += tok_len;
150
151 let count = read_u32(bytes, pos)?;
153 pos += 4;
154
155 next_tokens.insert(token, count);
156 }
157
158 ngrams.insert(context, next_tokens);
159 }
160
161 let mut segment = Self {
162 prefix,
163 ngrams,
164 size_bytes: 0,
165 };
166 segment.update_size();
167 Ok(segment)
168 }
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct PagedModelMetadata {
174 pub n: usize,
176 pub total_commands: usize,
178 pub segment_count: usize,
180 pub command_freq: HashMap<String, u32>,
182 pub segment_prefixes: Vec<String>,
184}
185
186pub struct PagedMarkovModel {
192 n: usize,
194 memory_limit: usize,
196 metadata: PagedModelMetadata,
198 bundle: Option<PagedBundle>,
200 segments: HashMap<String, NgramSegment>,
202 trie: Option<Trie>,
204 bundle_path: Option<std::path::PathBuf>,
206}
207
208include!("paged_model_stats.rs");
209include!("paged_model_ngram_segment.rs");