// aprender_shell/paged_model_stats.rs
2impl PagedMarkovModel {
3 #[must_use]
9 pub fn new(n: usize, memory_limit_mb: usize) -> Self {
10 let memory_limit = (memory_limit_mb * 1024 * 1024).max(MIN_MEMORY_LIMIT);
11 Self {
12 n: n.clamp(2, 5),
13 memory_limit,
14 metadata: PagedModelMetadata {
15 n,
16 total_commands: 0,
17 segment_count: 0,
18 command_freq: HashMap::new(),
19 segment_prefixes: Vec::new(),
20 },
21 bundle: None,
22 segments: HashMap::new(),
23 trie: Some(Trie::new()),
24 bundle_path: None,
25 }
26 }
27
28 #[must_use]
30 pub fn memory_limit(&self) -> usize {
31 self.memory_limit
32 }
33
34 pub fn train(&mut self, commands: &[String]) {
36 self.metadata.total_commands = commands.len();
37
38 for cmd in commands {
39 *self.metadata.command_freq.entry(cmd.clone()).or_insert(0) += 1;
41
42 if let Some(ref mut trie) = self.trie {
44 trie.insert(cmd);
45 }
46
47 let tokens: Vec<&str> = cmd.split_whitespace().collect();
49 if tokens.is_empty() {
50 continue;
51 }
52
53 let prefix = tokens[0].to_string();
55
56 let segment = self
58 .segments
59 .entry(prefix.clone())
60 .or_insert_with(|| NgramSegment::new(prefix));
61
62 segment.add(String::new(), tokens[0].to_string(), 1);
64
65 for i in 0..tokens.len() {
67 let context_start = i.saturating_sub(self.n - 1);
68 let context: String = tokens[context_start..=i].join(" ");
69
70 if i + 1 < tokens.len() {
71 segment.add(context, tokens[i + 1].to_string(), 1);
72 }
73 }
74 }
75
76 self.metadata.segment_count = self.segments.len();
78 self.metadata.segment_prefixes = self.segments.keys().cloned().collect();
79 }
80
81 pub fn save(&self, path: &Path) -> std::io::Result<()> {
83 let path_str = path.to_string_lossy().to_string();
84
85 let metadata_bytes = serde_json::to_vec(&self.metadata)
87 .map_err(|e| std::io::Error::other(format!("Failed to serialize metadata: {e}")))?;
88
89 let mut builder = BundleBuilder::new(&path_str)
90 .with_config(BundleConfig::new().with_compression(false))
91 .add_model("metadata", metadata_bytes);
92
93 for (prefix, segment) in &self.segments {
95 let segment_bytes = segment.to_bytes();
96 builder = builder.add_model(format!("segment_{prefix}"), segment_bytes);
97 }
98
99 builder
101 .build()
102 .map_err(|e| std::io::Error::other(format!("Failed to build bundle: {e}")))?;
103
104 Ok(())
105 }
106
107 pub fn load(path: &Path, memory_limit_mb: usize) -> std::io::Result<Self> {
109 let memory_limit = (memory_limit_mb * 1024 * 1024).max(MIN_MEMORY_LIMIT);
110
111 let paging_config = PagingConfig::new()
113 .with_max_memory(memory_limit)
114 .with_prefetch(true);
115
116 let mut bundle = PagedBundle::open(path, paging_config)
117 .map_err(|e| std::io::Error::other(format!("Failed to open bundle: {e}")))?;
118
119 let metadata_bytes = bundle
121 .get_model("metadata")
122 .map_err(|e| std::io::Error::other(format!("Failed to read metadata: {e}")))?;
123
124 let metadata: PagedModelMetadata = serde_json::from_slice(metadata_bytes)
125 .map_err(|e| std::io::Error::other(format!("Failed to parse metadata: {e}")))?;
126
127 let mut trie = Trie::new();
129 for cmd in metadata.command_freq.keys() {
130 trie.insert(cmd);
131 }
132
133 Ok(Self {
134 n: metadata.n,
135 memory_limit,
136 metadata,
137 bundle: Some(bundle),
138 segments: HashMap::new(), trie: Some(trie),
140 bundle_path: Some(path.to_path_buf()),
141 })
142 }
143
144 fn load_segment(&mut self, prefix: &str) -> std::io::Result<Option<NgramSegment>> {
146 if let Some(segment) = self.segments.get(prefix) {
147 return Ok(Some(segment.clone()));
148 }
149
150 if let Some(ref mut bundle) = self.bundle {
151 let model_name = format!("segment_{prefix}");
152 if bundle.model_names().iter().any(|n| *n == model_name) {
154 let bytes = bundle.get_model(&model_name).map_err(|e| {
155 std::io::Error::other(format!("Failed to read segment '{prefix}': {e}"))
156 })?;
157 let segment = NgramSegment::from_bytes(bytes)?;
158 self.segments.insert(prefix.to_string(), segment.clone());
159 return Ok(Some(segment));
160 }
161 }
162
163 Ok(None)
164 }
165
166 pub fn suggest(&mut self, prefix: &str, count: usize) -> Vec<(String, f32)> {
168 let ends_with_space = prefix.is_empty() || prefix.ends_with(' ');
170 let prefix = prefix.trim();
171 let tokens: Vec<&str> = prefix.split_whitespace().collect();
172
173 let mut suggestions = Vec::new();
174
175 if let Some(ref trie) = self.trie {
177 for cmd in trie.find_prefix(prefix, count * 4) {
178 let freq = self.metadata.command_freq.get(&cmd).copied().unwrap_or(1);
179 let score = freq as f32 / self.metadata.total_commands.max(1) as f32;
180 suggestions.push((cmd, score));
181 }
182 }
183
184 if !tokens.is_empty() && ends_with_space {
186 let segment_prefix = tokens[0];
187
188 if let Ok(Some(segment)) = self.load_segment(segment_prefix) {
190 let context_start = tokens.len().saturating_sub(self.n - 1);
191 let context = tokens[context_start..].join(" ");
192
193 if let Some(next_tokens) = segment.ngrams.get(&context) {
194 let total: u32 = next_tokens.values().sum();
195
196 for (token, ngram_count) in next_tokens {
197 let completion = format!("{} {}", prefix.trim(), token);
198 let score = *ngram_count as f32 / total as f32;
199
200 if !suggestions.iter().any(|(s, _)| s == &completion) {
201 suggestions.push((completion, score * 0.8));
202 }
203 }
204 }
205 }
206 }
207
208 suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
210 suggestions.truncate(count);
211
212 suggestions
213 }
214
215 #[must_use]
217 pub fn stats(&self) -> PagedModelStats {
218 let loaded_segments = self.segments.len();
219 let total_segments = self.metadata.segment_count;
220 let loaded_bytes: usize = self.segments.values().map(|s| s.size_bytes).sum();
221
222 PagedModelStats {
223 n: self.n,
224 total_commands: self.metadata.total_commands,
225 vocab_size: self.metadata.command_freq.len(),
226 total_segments,
227 loaded_segments,
228 memory_limit: self.memory_limit,
229 loaded_bytes,
230 bundle_path: self.bundle_path.clone(),
231 }
232 }
233
234 pub fn paging_stats(&self) -> Option<PagingStats> {
236 self.bundle.as_ref().map(|b| b.stats().clone())
237 }
238
239 pub fn prefetch_hint(&mut self, prefix: &str) {
241 if let Some(ref mut bundle) = self.bundle {
242 let _ = bundle.prefetch_hint(&format!("segment_{prefix}"));
243 }
244 }
245
246 #[must_use]
248 pub fn total_commands(&self) -> usize {
249 self.metadata.total_commands
250 }
251
252 #[must_use]
254 pub fn ngram_size(&self) -> usize {
255 self.n
256 }
257
258 #[must_use]
260 pub fn vocab_size(&self) -> usize {
261 self.metadata.command_freq.len()
262 }
263
264 #[must_use]
266 pub fn top_commands(&self, count: usize) -> Vec<(String, u32)> {
267 let mut cmds: Vec<_> = self
268 .metadata
269 .command_freq
270 .iter()
271 .map(|(k, v)| (k.clone(), *v))
272 .collect();
273 cmds.sort_by(|a, b| b.1.cmp(&a.1));
274 cmds.truncate(count);
275 cmds
276 }
277}
278
/// Snapshot of a paged model's size and memory usage.
///
/// Implements `PartialEq`/`Eq` so two snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PagedModelStats {
    /// N-gram size used by the model.
    pub n: usize,
    /// Command count recorded in the model metadata.
    pub total_commands: usize,
    /// Number of distinct commands in the frequency table.
    pub vocab_size: usize,
    /// Segment count recorded in the model metadata.
    pub total_segments: usize,
    /// Number of segments currently held in memory.
    pub loaded_segments: usize,
    /// Configured memory limit, in bytes.
    pub memory_limit: usize,
    /// Combined reported size, in bytes, of the segments held in memory.
    pub loaded_bytes: usize,
    /// Path of the backing bundle, if the model has one.
    pub bundle_path: Option<std::path::PathBuf>,
}
299
300impl std::fmt::Display for PagedModelStats {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 writeln!(f, "Paged Model Statistics:")?;
303 writeln!(f, " N-gram size: {}", self.n)?;
304 writeln!(f, " Total commands: {}", self.total_commands)?;
305 writeln!(f, " Vocabulary size: {}", self.vocab_size)?;
306 writeln!(
307 f,
308 " Segments: {}/{} loaded",
309 self.loaded_segments, self.total_segments
310 )?;
311 writeln!(
312 f,
313 " Memory limit: {:.1} MB",
314 self.memory_limit as f64 / 1024.0 / 1024.0
315 )?;
316 writeln!(
317 f,
318 " Loaded bytes: {:.1} KB",
319 self.loaded_bytes as f64 / 1024.0
320 )?;
321 if let Some(ref path) = self.bundle_path {
322 writeln!(f, " Bundle path: {}", path.display())?;
323 }
324 Ok(())
325 }
326}