// aprender_shell/paged_model.rs

//! Memory-Paged Markov Model for Large Shell Histories
//!
//! Uses aprender's bundle module for efficient memory management when
//! dealing with large shell histories that exceed available RAM.
5
6use aprender::bundle::{BundleBuilder, BundleConfig, PagedBundle, PagingConfig, PagingStats};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11use crate::trie::Trie;
12
/// Minimum memory limit: 1 MiB (2^20 bytes).
const MIN_MEMORY_LIMIT: usize = 1 << 20;
15
16/// N-gram segment for paged storage.
17///
18/// Each segment contains n-grams for a specific context prefix,
19/// allowing on-demand loading.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct NgramSegment {
22    /// Context prefix this segment covers (e.g., "git", "cargo")
23    pub prefix: String,
24    /// N-gram data: context -> (next_token -> count)
25    pub ngrams: HashMap<String, HashMap<String, u32>>,
26    /// Size estimate in bytes
27    pub size_bytes: usize,
28}
29
30impl NgramSegment {
31    /// Create a new empty segment.
32    #[must_use]
33    pub fn new(prefix: String) -> Self {
34        Self {
35            prefix,
36            ngrams: HashMap::new(),
37            size_bytes: 0,
38        }
39    }
40
41    /// Add an n-gram to this segment.
42    pub fn add(&mut self, context: String, next_token: String, count: u32) {
43        let entry = self.ngrams.entry(context).or_default();
44        *entry.entry(next_token).or_insert(0) += count;
45        self.update_size();
46    }
47
48    /// Update size estimate.
49    fn update_size(&mut self) {
50        self.size_bytes = self
51            .ngrams
52            .iter()
53            .map(|(k, v)| k.len() + v.keys().map(|k2| k2.len() + 4).sum::<usize>())
54            .sum();
55    }
56
57    /// Serialize to bytes.
58    pub fn to_bytes(&self) -> Vec<u8> {
59        // Simple binary format: prefix_len(4) + prefix + ngram_count(4) + ngrams
60        let mut bytes = Vec::new();
61
62        // Prefix
63        let prefix_bytes = self.prefix.as_bytes();
64        bytes.extend(&(prefix_bytes.len() as u32).to_le_bytes());
65        bytes.extend(prefix_bytes);
66
67        // N-gram count
68        bytes.extend(&(self.ngrams.len() as u32).to_le_bytes());
69
70        for (context, next_tokens) in &self.ngrams {
71            // Context
72            let ctx_bytes = context.as_bytes();
73            bytes.extend(&(ctx_bytes.len() as u32).to_le_bytes());
74            bytes.extend(ctx_bytes);
75
76            // Next tokens count
77            bytes.extend(&(next_tokens.len() as u32).to_le_bytes());
78
79            for (token, count) in next_tokens {
80                // Token
81                let tok_bytes = token.as_bytes();
82                bytes.extend(&(tok_bytes.len() as u32).to_le_bytes());
83                bytes.extend(tok_bytes);
84                // Count
85                bytes.extend(&count.to_le_bytes());
86            }
87        }
88
89        bytes
90    }
91
92    /// Deserialize from bytes.
93    pub fn from_bytes(bytes: &[u8]) -> std::io::Result<Self> {
94        let mut pos = 0;
95
96        // Helper to read 4 bytes as u32
97        let read_u32 = |data: &[u8], offset: usize| -> std::io::Result<u32> {
98            let slice = data
99                .get(offset..offset + 4)
100                .ok_or_else(|| std::io::Error::other("Truncated segment data"))?;
101            let arr: [u8; 4] = slice
102                .try_into()
103                .map_err(|_| std::io::Error::other("Invalid byte slice"))?;
104            Ok(u32::from_le_bytes(arr))
105        };
106
107        // Read prefix
108        let prefix_len = read_u32(bytes, pos)? as usize;
109        pos += 4;
110
111        if bytes.len() < pos + prefix_len {
112            return Err(std::io::Error::other("Truncated prefix"));
113        }
114        let prefix = String::from_utf8_lossy(&bytes[pos..pos + prefix_len]).to_string();
115        pos += prefix_len;
116
117        // Read n-gram count
118        let ngram_count = read_u32(bytes, pos)? as usize;
119        pos += 4;
120
121        let mut ngrams = HashMap::with_capacity(ngram_count);
122
123        for _ in 0..ngram_count {
124            // Read context
125            let ctx_len = read_u32(bytes, pos)? as usize;
126            pos += 4;
127
128            if bytes.len() < pos + ctx_len {
129                return Err(std::io::Error::other("Truncated context"));
130            }
131            let context = String::from_utf8_lossy(&bytes[pos..pos + ctx_len]).to_string();
132            pos += ctx_len;
133
134            // Read next tokens count
135            let token_count = read_u32(bytes, pos)? as usize;
136            pos += 4;
137
138            let mut next_tokens = HashMap::with_capacity(token_count);
139
140            for _ in 0..token_count {
141                // Read token
142                let tok_len = read_u32(bytes, pos)? as usize;
143                pos += 4;
144
145                if bytes.len() < pos + tok_len {
146                    return Err(std::io::Error::other("Truncated token"));
147                }
148                let token = String::from_utf8_lossy(&bytes[pos..pos + tok_len]).to_string();
149                pos += tok_len;
150
151                // Read count
152                let count = read_u32(bytes, pos)?;
153                pos += 4;
154
155                next_tokens.insert(token, count);
156            }
157
158            ngrams.insert(context, next_tokens);
159        }
160
161        let mut segment = Self {
162            prefix,
163            ngrams,
164            size_bytes: 0,
165        };
166        segment.update_size();
167        Ok(segment)
168    }
169}
170
171/// Model metadata stored in the bundle manifest.
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct PagedModelMetadata {
174    /// N-gram size
175    pub n: usize,
176    /// Total commands trained on
177    pub total_commands: usize,
178    /// Number of segments
179    pub segment_count: usize,
180    /// Command frequency map (kept in memory - relatively small)
181    pub command_freq: HashMap<String, u32>,
182    /// Segment prefixes for index lookup
183    pub segment_prefixes: Vec<String>,
184}
185
/// Memory-paged Markov model for shell command prediction.
///
/// Uses aprender's bundle module to store n-gram data on disk and
/// load segments on-demand, enabling handling of large shell histories
/// without exhausting RAM.
///
/// The model keeps its metadata together with either in-memory segments or
/// a paged bundle; `bundle` and `bundle_path` are optional and documented as
/// being set when the model was loaded from a file.
pub struct PagedMarkovModel {
    /// N-gram size.
    n: usize,
    /// Memory limit in bytes.
    memory_limit: usize,
    /// Model metadata (n-gram size, command counts, segment prefixes).
    metadata: PagedModelMetadata,
    /// Paged bundle (present when loaded from file).
    bundle: Option<PagedBundle>,
    /// In-memory segments (for training/small models).
    // NOTE(review): presumably keyed by segment prefix — confirm against the
    // code that populates this map.
    segments: HashMap<String, NgramSegment>,
    /// Prefix trie for fast lookup.
    trie: Option<Trie>,
    /// Path to the bundle file (if loaded).
    bundle_path: Option<std::path::PathBuf>,
}
207
208include!("paged_model_stats.rs");
209include!("paged_model_ngram_segment.rs");