Skip to main content

aprender_shell/
paged_model_stats.rs

1
2impl PagedMarkovModel {
3    /// Create a new paged model with given n-gram size and memory limit.
4    ///
5    /// # Arguments
6    /// * `n` - N-gram size (2-5)
7    /// * `memory_limit_mb` - Maximum memory usage in megabytes
8    #[must_use]
9    pub fn new(n: usize, memory_limit_mb: usize) -> Self {
10        let memory_limit = (memory_limit_mb * 1024 * 1024).max(MIN_MEMORY_LIMIT);
11        Self {
12            n: n.clamp(2, 5),
13            memory_limit,
14            metadata: PagedModelMetadata {
15                n,
16                total_commands: 0,
17                segment_count: 0,
18                command_freq: HashMap::new(),
19                segment_prefixes: Vec::new(),
20            },
21            bundle: None,
22            segments: HashMap::new(),
23            trie: Some(Trie::new()),
24            bundle_path: None,
25        }
26    }
27
28    /// Get memory limit in bytes.
29    #[must_use]
30    pub fn memory_limit(&self) -> usize {
31        self.memory_limit
32    }
33
34    /// Train on a list of commands.
35    pub fn train(&mut self, commands: &[String]) {
36        self.metadata.total_commands = commands.len();
37
38        for cmd in commands {
39            // Track command frequency
40            *self.metadata.command_freq.entry(cmd.clone()).or_insert(0) += 1;
41
42            // Add to trie
43            if let Some(ref mut trie) = self.trie {
44                trie.insert(cmd);
45            }
46
47            // Tokenize command
48            let tokens: Vec<&str> = cmd.split_whitespace().collect();
49            if tokens.is_empty() {
50                continue;
51            }
52
53            // Determine segment prefix (first token)
54            let prefix = tokens[0].to_string();
55
56            // Get or create segment
57            let segment = self
58                .segments
59                .entry(prefix.clone())
60                .or_insert_with(|| NgramSegment::new(prefix));
61
62            // Empty context predicts first token
63            segment.add(String::new(), tokens[0].to_string(), 1);
64
65            // Build context n-grams
66            for i in 0..tokens.len() {
67                let context_start = i.saturating_sub(self.n - 1);
68                let context: String = tokens[context_start..=i].join(" ");
69
70                if i + 1 < tokens.len() {
71                    segment.add(context, tokens[i + 1].to_string(), 1);
72                }
73            }
74        }
75
76        // Update metadata
77        self.metadata.segment_count = self.segments.len();
78        self.metadata.segment_prefixes = self.segments.keys().cloned().collect();
79    }
80
81    /// Save model to a paged bundle file.
82    pub fn save(&self, path: &Path) -> std::io::Result<()> {
83        let path_str = path.to_string_lossy().to_string();
84
85        // Add metadata as first model
86        let metadata_bytes = serde_json::to_vec(&self.metadata)
87            .map_err(|e| std::io::Error::other(format!("Failed to serialize metadata: {e}")))?;
88
89        let mut builder = BundleBuilder::new(&path_str)
90            .with_config(BundleConfig::new().with_compression(false))
91            .add_model("metadata", metadata_bytes);
92
93        // Add each segment as a separate model
94        for (prefix, segment) in &self.segments {
95            let segment_bytes = segment.to_bytes();
96            builder = builder.add_model(format!("segment_{prefix}"), segment_bytes);
97        }
98
99        // Build and save
100        builder
101            .build()
102            .map_err(|e| std::io::Error::other(format!("Failed to build bundle: {e}")))?;
103
104        Ok(())
105    }
106
107    /// Load model from a paged bundle file with memory limit.
108    pub fn load(path: &Path, memory_limit_mb: usize) -> std::io::Result<Self> {
109        let memory_limit = (memory_limit_mb * 1024 * 1024).max(MIN_MEMORY_LIMIT);
110
111        // Open as paged bundle
112        let paging_config = PagingConfig::new()
113            .with_max_memory(memory_limit)
114            .with_prefetch(true);
115
116        let mut bundle = PagedBundle::open(path, paging_config)
117            .map_err(|e| std::io::Error::other(format!("Failed to open bundle: {e}")))?;
118
119        // Load metadata (always in memory)
120        let metadata_bytes = bundle
121            .get_model("metadata")
122            .map_err(|e| std::io::Error::other(format!("Failed to read metadata: {e}")))?;
123
124        let metadata: PagedModelMetadata = serde_json::from_slice(metadata_bytes)
125            .map_err(|e| std::io::Error::other(format!("Failed to parse metadata: {e}")))?;
126
127        // Rebuild trie from command_freq
128        let mut trie = Trie::new();
129        for cmd in metadata.command_freq.keys() {
130            trie.insert(cmd);
131        }
132
133        Ok(Self {
134            n: metadata.n,
135            memory_limit,
136            metadata,
137            bundle: Some(bundle),
138            segments: HashMap::new(), // Loaded on demand
139            trie: Some(trie),
140            bundle_path: Some(path.to_path_buf()),
141        })
142    }
143
144    /// Load a specific segment on demand.
145    fn load_segment(&mut self, prefix: &str) -> std::io::Result<Option<NgramSegment>> {
146        if let Some(segment) = self.segments.get(prefix) {
147            return Ok(Some(segment.clone()));
148        }
149
150        if let Some(ref mut bundle) = self.bundle {
151            let model_name = format!("segment_{prefix}");
152            // Check if model exists by looking at model names
153            if bundle.model_names().iter().any(|n| *n == model_name) {
154                let bytes = bundle.get_model(&model_name).map_err(|e| {
155                    std::io::Error::other(format!("Failed to read segment '{prefix}': {e}"))
156                })?;
157                let segment = NgramSegment::from_bytes(bytes)?;
158                self.segments.insert(prefix.to_string(), segment.clone());
159                return Ok(Some(segment));
160            }
161        }
162
163        Ok(None)
164    }
165
166    /// Suggest completions for a prefix.
167    pub fn suggest(&mut self, prefix: &str, count: usize) -> Vec<(String, f32)> {
168        // Check for trailing space BEFORE trimming
169        let ends_with_space = prefix.is_empty() || prefix.ends_with(' ');
170        let prefix = prefix.trim();
171        let tokens: Vec<&str> = prefix.split_whitespace().collect();
172
173        let mut suggestions = Vec::new();
174
175        // Strategy 1: Trie prefix match for exact commands
176        if let Some(ref trie) = self.trie {
177            for cmd in trie.find_prefix(prefix, count * 4) {
178                let freq = self.metadata.command_freq.get(&cmd).copied().unwrap_or(1);
179                let score = freq as f32 / self.metadata.total_commands.max(1) as f32;
180                suggestions.push((cmd, score));
181            }
182        }
183
184        // Strategy 2: N-gram prediction (load segment on demand)
185        if !tokens.is_empty() && ends_with_space {
186            let segment_prefix = tokens[0];
187
188            // Load segment on demand
189            if let Ok(Some(segment)) = self.load_segment(segment_prefix) {
190                let context_start = tokens.len().saturating_sub(self.n - 1);
191                let context = tokens[context_start..].join(" ");
192
193                if let Some(next_tokens) = segment.ngrams.get(&context) {
194                    let total: u32 = next_tokens.values().sum();
195
196                    for (token, ngram_count) in next_tokens {
197                        let completion = format!("{} {}", prefix.trim(), token);
198                        let score = *ngram_count as f32 / total as f32;
199
200                        if !suggestions.iter().any(|(s, _)| s == &completion) {
201                            suggestions.push((completion, score * 0.8));
202                        }
203                    }
204                }
205            }
206        }
207
208        // Sort by score and take top count
209        suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
210        suggestions.truncate(count);
211
212        suggestions
213    }
214
215    /// Get model statistics.
216    #[must_use]
217    pub fn stats(&self) -> PagedModelStats {
218        let loaded_segments = self.segments.len();
219        let total_segments = self.metadata.segment_count;
220        let loaded_bytes: usize = self.segments.values().map(|s| s.size_bytes).sum();
221
222        PagedModelStats {
223            n: self.n,
224            total_commands: self.metadata.total_commands,
225            vocab_size: self.metadata.command_freq.len(),
226            total_segments,
227            loaded_segments,
228            memory_limit: self.memory_limit,
229            loaded_bytes,
230            bundle_path: self.bundle_path.clone(),
231        }
232    }
233
234    /// Get paging statistics from the bundle.
235    pub fn paging_stats(&self) -> Option<PagingStats> {
236        self.bundle.as_ref().map(|b| b.stats().clone())
237    }
238
239    /// Hint that a segment will be needed soon (for prefetching).
240    pub fn prefetch_hint(&mut self, prefix: &str) {
241        if let Some(ref mut bundle) = self.bundle {
242            let _ = bundle.prefetch_hint(&format!("segment_{prefix}"));
243        }
244    }
245
246    /// Total commands trained on.
247    #[must_use]
248    pub fn total_commands(&self) -> usize {
249        self.metadata.total_commands
250    }
251
252    /// N-gram size.
253    #[must_use]
254    pub fn ngram_size(&self) -> usize {
255        self.n
256    }
257
258    /// Vocabulary size.
259    #[must_use]
260    pub fn vocab_size(&self) -> usize {
261        self.metadata.command_freq.len()
262    }
263
264    /// Top commands by frequency.
265    #[must_use]
266    pub fn top_commands(&self, count: usize) -> Vec<(String, u32)> {
267        let mut cmds: Vec<_> = self
268            .metadata
269            .command_freq
270            .iter()
271            .map(|(k, v)| (k.clone(), *v))
272            .collect();
273        cmds.sort_by(|a, b| b.1.cmp(&a.1));
274        cmds.truncate(count);
275        cmds
276    }
277}
278
/// Statistics for a paged model.
///
/// A plain snapshot value; derive-based equality makes it easy to compare
/// snapshots (e.g. before/after paging in a segment).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PagedModelStats {
    /// N-gram size
    pub n: usize,
    /// Total commands trained on
    pub total_commands: usize,
    /// Vocabulary size (number of distinct commands)
    pub vocab_size: usize,
    /// Total number of segments recorded in the model metadata
    pub total_segments: usize,
    /// Number of segments currently held in memory
    pub loaded_segments: usize,
    /// Memory limit in bytes
    pub memory_limit: usize,
    /// Bytes attributed to the segments currently held in memory
    pub loaded_bytes: usize,
    /// Path to bundle file (if loaded from file)
    pub bundle_path: Option<std::path::PathBuf>,
}
299
300impl std::fmt::Display for PagedModelStats {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        writeln!(f, "Paged Model Statistics:")?;
303        writeln!(f, "  N-gram size:      {}", self.n)?;
304        writeln!(f, "  Total commands:   {}", self.total_commands)?;
305        writeln!(f, "  Vocabulary size:  {}", self.vocab_size)?;
306        writeln!(
307            f,
308            "  Segments:         {}/{} loaded",
309            self.loaded_segments, self.total_segments
310        )?;
311        writeln!(
312            f,
313            "  Memory limit:     {:.1} MB",
314            self.memory_limit as f64 / 1024.0 / 1024.0
315        )?;
316        writeln!(
317            f,
318            "  Loaded bytes:     {:.1} KB",
319            self.loaded_bytes as f64 / 1024.0
320        )?;
321        if let Some(ref path) = self.bundle_path {
322            writeln!(f, "  Bundle path:      {}", path.display())?;
323        }
324        Ok(())
325    }
326}