byteforge/patching.rs

1use crate::{ByteForgeConfig, Result};
2use std::collections::VecDeque;
3use ahash::AHashMap;
4
5#[derive(Debug, Clone)]
6pub struct Patch {
7    pub bytes: Vec<u8>,
8    pub start_pos: usize,
9    pub end_pos: usize,
10    pub complexity_score: f32,
11    pub patch_type: PatchType,
12}
13
/// Coarse category assigned to a `Patch` from the signals measured for it.
///
/// The enum carries no data, so it is `Copy` and can be used as a hash-map or
/// hash-set key (e.g. for per-type statistics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatchType {
    /// Low complexity, common patterns.
    Simple,
    /// High complexity, rare patterns.
    Complex,
    /// Word/sentence boundaries.
    Semantic,
    /// Repeated patterns.
    Repetitive,
    /// Code/markup structure.
    Structural,
}
22
23pub struct MultiSignalPatcher {
24    config: ByteForgeConfig,
25    entropy_cache: AHashMap<u64, f32>,
26    pattern_cache: AHashMap<Vec<u8>, f32>,
27    ngram_counts: AHashMap<Vec<u8>, u32>,
28    rolling_hash: u64,
29    window_buffer: VecDeque<u8>,
30}
31
32impl MultiSignalPatcher {
33    pub fn new(config: ByteForgeConfig) -> Self {
34        Self {
35            config,
36            entropy_cache: AHashMap::new(),
37            pattern_cache: AHashMap::new(),
38            ngram_counts: AHashMap::new(),
39            rolling_hash: 0,
40            window_buffer: VecDeque::new(),
41        }
42    }
43
44    pub fn patch_bytes(&mut self, bytes: &[u8]) -> Result<Vec<Patch>> {
45        let mut patches = Vec::new();
46        let mut current_patch_start = 0;
47        let mut i = 0;
48
49        while i < bytes.len() {
50            let signals = self.calculate_signals(&bytes, i)?;
51            
52            let should_split = self.should_split_patch(&signals, i - current_patch_start);
53            
54            if should_split || i - current_patch_start >= self.config.patch_size_range.1 {
55                if i > current_patch_start {
56                    let patch = self.create_patch(&bytes, current_patch_start, i, &signals)?;
57                    patches.push(patch);
58                    current_patch_start = i;
59                }
60            }
61            
62            i += 1;
63        }
64
65        if current_patch_start < bytes.len() {
66            let signals = self.calculate_signals(&bytes, bytes.len() - 1)?;
67            let patch = self.create_patch(&bytes, current_patch_start, bytes.len(), &signals)?;
68            patches.push(patch);
69        }
70
71        Ok(patches)
72    }
73
74    fn calculate_signals(&mut self, bytes: &[u8], pos: usize) -> Result<PatchingSignals> {
75        let entropy = self.calculate_fast_entropy(bytes, pos)?;
76        let compression_ratio = self.calculate_compression_ratio(bytes, pos)?;
77        let semantic_boundary = self.detect_semantic_boundary(bytes, pos)?;
78        let repetition_score = self.calculate_repetition_score(bytes, pos)?;
79        let structural_score = self.calculate_structural_score(bytes, pos)?;
80
81        Ok(PatchingSignals {
82            entropy,
83            compression_ratio,
84            semantic_boundary,
85            repetition_score,
86            structural_score,
87        })
88    }
89
90    fn should_split_patch(&self, signals: &PatchingSignals, current_length: usize) -> bool {
91        if current_length < self.config.patch_size_range.0 {
92            return false;
93        }
94
95        let entropy_trigger = signals.entropy > self.config.entropy_threshold;
96        let compression_trigger = signals.compression_ratio > self.config.compression_threshold;
97        let semantic_trigger = signals.semantic_boundary > self.config.semantic_weight;
98        let repetition_trigger = signals.repetition_score > 0.8;
99        let structural_trigger = signals.structural_score > 0.7;
100
101        let signal_count = [entropy_trigger, compression_trigger, semantic_trigger, 
102                           repetition_trigger, structural_trigger]
103            .iter()
104            .map(|&x| x as u32)
105            .sum::<u32>();
106
107        signal_count >= 2 || (signal_count >= 1 && current_length >= self.config.patch_size_range.1 / 2)
108    }
109
110    fn calculate_fast_entropy(&mut self, bytes: &[u8], pos: usize) -> Result<f32> {
111        if pos < 3 {
112            return Ok(0.0);
113        }
114
115        let ngram = &bytes[pos.saturating_sub(3)..=pos];
116        let hash = self.hash_bytes(ngram);
117        
118        if let Some(&cached_entropy) = self.entropy_cache.get(&hash) {
119            return Ok(cached_entropy);
120        }
121
122        let entropy = self.compute_ngram_entropy(ngram)?;
123        self.entropy_cache.insert(hash, entropy);
124        
125        Ok(entropy)
126    }
127
128    fn compute_ngram_entropy(&self, ngram: &[u8]) -> Result<f32> {
129        let mut counts = [0u32; 256];
130        let mut total = 0u32;
131        
132        for &byte in ngram {
133            counts[byte as usize] += 1;
134            total += 1;
135        }
136
137        let mut entropy = 0.0f32;
138        for count in counts.iter().filter(|&&c| c > 0) {
139            let p = *count as f32 / total as f32;
140            entropy -= p * p.log2();
141        }
142
143        Ok(entropy)
144    }
145
146    fn calculate_compression_ratio(&self, bytes: &[u8], pos: usize) -> Result<f32> {
147        if pos < 8 {
148            return Ok(0.0);
149        }
150
151        let window = &bytes[pos.saturating_sub(8)..=pos];
152        let original_size = window.len();
153        let compressed_size = self.estimate_compression_size(window)?;
154        
155        Ok(1.0 - (compressed_size as f32 / original_size as f32))
156    }
157
158    fn estimate_compression_size(&self, window: &[u8]) -> Result<usize> {
159        let mut unique_bytes = std::collections::HashSet::new();
160        let mut repeat_count = 0;
161        
162        for i in 0..window.len() {
163            if !unique_bytes.insert(window[i]) {
164                repeat_count += 1;
165            }
166        }
167
168        Ok(window.len() - repeat_count / 2)
169    }
170
171    fn detect_semantic_boundary(&self, bytes: &[u8], pos: usize) -> Result<f32> {
172        if pos == 0 {
173            return Ok(0.0);
174        }
175
176        let current_byte = bytes[pos];
177        let prev_byte = bytes[pos - 1];
178
179        let is_word_boundary = self.is_word_boundary(prev_byte, current_byte);
180        let is_sentence_boundary = self.is_sentence_boundary(prev_byte, current_byte);
181        let is_line_boundary = current_byte == b'\n';
182
183        let score = if is_sentence_boundary {
184            0.9
185        } else if is_line_boundary {
186            0.8
187        } else if is_word_boundary {
188            0.6
189        } else {
190            0.0
191        };
192
193        Ok(score)
194    }
195
196    fn is_word_boundary(&self, prev: u8, current: u8) -> bool {
197        let prev_is_alphanum = prev.is_ascii_alphanumeric();
198        let current_is_alphanum = current.is_ascii_alphanumeric();
199        let is_space_transition = prev == b' ' || current == b' ';
200        
201        (prev_is_alphanum != current_is_alphanum) || is_space_transition
202    }
203
204    fn is_sentence_boundary(&self, prev: u8, current: u8) -> bool {
205        matches!(prev, b'.' | b'!' | b'?') && (current == b' ' || current == b'\n')
206    }
207
208    fn calculate_repetition_score(&self, bytes: &[u8], pos: usize) -> Result<f32> {
209        if pos < 4 {
210            return Ok(0.0);
211        }
212
213        let window_size = 8.min(pos);
214        let window = &bytes[pos - window_size..pos];
215        
216        let mut max_repeat_length = 0;
217        
218        for pattern_len in 1..=window_size / 2 {
219            let pattern = &window[window_size - pattern_len..];
220            let mut repeat_count = 0;
221            
222            for i in (0..window_size - pattern_len).step_by(pattern_len) {
223                if &window[i..i + pattern_len] == pattern {
224                    repeat_count += 1;
225                } else {
226                    break;
227                }
228            }
229            
230            if repeat_count > 1 {
231                max_repeat_length = max_repeat_length.max(pattern_len * repeat_count);
232            }
233        }
234
235        Ok(max_repeat_length as f32 / window_size as f32)
236    }
237
238    fn calculate_structural_score(&self, bytes: &[u8], pos: usize) -> Result<f32> {
239        if pos == 0 {
240            return Ok(0.0);
241        }
242
243        let current_byte = bytes[pos];
244        
245        let structural_chars = [b'{', b'}', b'[', b']', b'(', b')', b'<', b'>', b';', b':', b','];
246        let is_structural = structural_chars.contains(&current_byte);
247        
248        if is_structural {
249            Ok(0.8)
250        } else {
251            Ok(0.0)
252        }
253    }
254
255    fn create_patch(&self, bytes: &[u8], start: usize, end: usize, signals: &PatchingSignals) -> Result<Patch> {
256        let patch_bytes = bytes[start..end].to_vec();
257        let complexity_score = self.calculate_complexity_score(signals);
258        let patch_type = self.determine_patch_type(signals);
259
260        Ok(Patch {
261            bytes: patch_bytes,
262            start_pos: start,
263            end_pos: end,
264            complexity_score,
265            patch_type,
266        })
267    }
268
269    fn calculate_complexity_score(&self, signals: &PatchingSignals) -> f32 {
270        let weights = [0.3, 0.2, 0.2, 0.15, 0.15];
271        let values = [
272            signals.entropy,
273            signals.compression_ratio,
274            signals.semantic_boundary,
275            signals.repetition_score,
276            signals.structural_score,
277        ];
278
279        weights.iter().zip(values.iter()).map(|(w, v)| w * v).sum()
280    }
281
282    fn determine_patch_type(&self, signals: &PatchingSignals) -> PatchType {
283        if signals.repetition_score > 0.7 {
284            PatchType::Repetitive
285        } else if signals.structural_score > 0.6 {
286            PatchType::Structural
287        } else if signals.semantic_boundary > 0.5 {
288            PatchType::Semantic
289        } else if signals.entropy > 0.7 {
290            PatchType::Complex
291        } else {
292            PatchType::Simple
293        }
294    }
295
296    fn hash_bytes(&self, bytes: &[u8]) -> u64 {
297        use std::hash::{Hash, Hasher};
298        let mut hasher = ahash::AHasher::default();
299        bytes.hash(&mut hasher);
300        hasher.finish()
301    }
302}
303
/// Measurements taken at a single byte position by
/// `MultiSignalPatcher::calculate_signals`.
#[derive(Clone, Debug)]
struct PatchingSignals {
    /// Shannon entropy (bits) of the 4-byte n-gram ending at the position.
    entropy: f32,
    /// Estimated savings ratio of the 9-byte window ending at the position.
    compression_ratio: f32,
    /// Text-boundary score for the transition into the position.
    semantic_boundary: f32,
    /// Fraction of the preceding window explained by repetition.
    repetition_score: f32,
    /// 0.8 when the byte is structural punctuation, else 0.0.
    structural_score: f32,
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `patches` are non-empty, contiguous, in order, carry the
    /// bytes of their ranges, and together cover `input` exactly.
    fn assert_tiles_input(input: &[u8], patches: &[Patch]) {
        let mut expected_start = 0;
        for patch in patches {
            assert!(!patch.bytes.is_empty(), "empty patch at {}", patch.start_pos);
            assert_eq!(patch.start_pos, expected_start, "gap or overlap between patches");
            assert_eq!(patch.end_pos, patch.start_pos + patch.bytes.len());
            assert_eq!(&input[patch.start_pos..patch.end_pos], patch.bytes.as_slice());
            expected_start = patch.end_pos;
        }
        assert_eq!(expected_start, input.len(), "patches do not cover the whole input");
    }

    #[test]
    fn test_basic_patching() {
        let config = ByteForgeConfig::default();
        let mut patcher = MultiSignalPatcher::new(config);

        let text = b"Hello world! This is a test.";
        let patches = patcher.patch_bytes(text).unwrap();

        assert!(!patches.is_empty());
        assert_tiles_input(text, &patches);
    }

    #[test]
    fn test_empty_input_yields_no_patches() {
        let mut patcher = MultiSignalPatcher::new(ByteForgeConfig::default());
        let patches = patcher.patch_bytes(b"").unwrap();
        assert!(patches.is_empty());
    }

    #[test]
    fn test_semantic_boundary_detection() {
        let config = ByteForgeConfig::default();
        let patcher = MultiSignalPatcher::new(config);

        // Word boundary: 'o' -> ' ' at index 5.
        let score = patcher.detect_semantic_boundary(b"Hello world", 5).unwrap();
        assert!(score > 0.0);

        // Sentence boundary: '.' -> ' ' at index 3 outranks a word boundary.
        let sentence = patcher.detect_semantic_boundary(b"Hi. There", 3).unwrap();
        assert!(sentence > score);

        // Inside a word: 'e' -> 'l' is not a boundary.
        let inside = patcher.detect_semantic_boundary(b"Hello", 2).unwrap();
        assert_eq!(inside, 0.0);
    }
}