1use crate::{ByteForgeConfig, Result};
2use std::collections::VecDeque;
3use ahash::AHashMap;
4
5#[derive(Debug, Clone)]
6pub struct Patch {
7 pub bytes: Vec<u8>,
8 pub start_pos: usize,
9 pub end_pos: usize,
10 pub complexity_score: f32,
11 pub patch_type: PatchType,
12}
13
/// Coarse classification assigned to each [`Patch`] from its boundary signals.
///
/// Variants are checked in the order `Repetitive`, `Structural`, `Semantic`,
/// `Complex`; `Simple` is the fallback when no signal crosses its threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PatchType {
    /// No classification signal crossed its threshold.
    Simple,
    /// High n-gram entropy at the patch boundary.
    Complex,
    /// Closed at a word, line or sentence boundary.
    Semantic,
    /// Dominated by a repeating byte pattern.
    Repetitive,
    /// Closed on a bracket or separator byte.
    Structural,
}
22
23pub struct MultiSignalPatcher {
24 config: ByteForgeConfig,
25 entropy_cache: AHashMap<u64, f32>,
26 pattern_cache: AHashMap<Vec<u8>, f32>,
27 ngram_counts: AHashMap<Vec<u8>, u32>,
28 rolling_hash: u64,
29 window_buffer: VecDeque<u8>,
30}
31
32impl MultiSignalPatcher {
33 pub fn new(config: ByteForgeConfig) -> Self {
34 Self {
35 config,
36 entropy_cache: AHashMap::new(),
37 pattern_cache: AHashMap::new(),
38 ngram_counts: AHashMap::new(),
39 rolling_hash: 0,
40 window_buffer: VecDeque::new(),
41 }
42 }
43
44 pub fn patch_bytes(&mut self, bytes: &[u8]) -> Result<Vec<Patch>> {
45 let mut patches = Vec::new();
46 let mut current_patch_start = 0;
47 let mut i = 0;
48
49 while i < bytes.len() {
50 let signals = self.calculate_signals(&bytes, i)?;
51
52 let should_split = self.should_split_patch(&signals, i - current_patch_start);
53
54 if should_split || i - current_patch_start >= self.config.patch_size_range.1 {
55 if i > current_patch_start {
56 let patch = self.create_patch(&bytes, current_patch_start, i, &signals)?;
57 patches.push(patch);
58 current_patch_start = i;
59 }
60 }
61
62 i += 1;
63 }
64
65 if current_patch_start < bytes.len() {
66 let signals = self.calculate_signals(&bytes, bytes.len() - 1)?;
67 let patch = self.create_patch(&bytes, current_patch_start, bytes.len(), &signals)?;
68 patches.push(patch);
69 }
70
71 Ok(patches)
72 }
73
74 fn calculate_signals(&mut self, bytes: &[u8], pos: usize) -> Result<PatchingSignals> {
75 let entropy = self.calculate_fast_entropy(bytes, pos)?;
76 let compression_ratio = self.calculate_compression_ratio(bytes, pos)?;
77 let semantic_boundary = self.detect_semantic_boundary(bytes, pos)?;
78 let repetition_score = self.calculate_repetition_score(bytes, pos)?;
79 let structural_score = self.calculate_structural_score(bytes, pos)?;
80
81 Ok(PatchingSignals {
82 entropy,
83 compression_ratio,
84 semantic_boundary,
85 repetition_score,
86 structural_score,
87 })
88 }
89
90 fn should_split_patch(&self, signals: &PatchingSignals, current_length: usize) -> bool {
91 if current_length < self.config.patch_size_range.0 {
92 return false;
93 }
94
95 let entropy_trigger = signals.entropy > self.config.entropy_threshold;
96 let compression_trigger = signals.compression_ratio > self.config.compression_threshold;
97 let semantic_trigger = signals.semantic_boundary > self.config.semantic_weight;
98 let repetition_trigger = signals.repetition_score > 0.8;
99 let structural_trigger = signals.structural_score > 0.7;
100
101 let signal_count = [entropy_trigger, compression_trigger, semantic_trigger,
102 repetition_trigger, structural_trigger]
103 .iter()
104 .map(|&x| x as u32)
105 .sum::<u32>();
106
107 signal_count >= 2 || (signal_count >= 1 && current_length >= self.config.patch_size_range.1 / 2)
108 }
109
110 fn calculate_fast_entropy(&mut self, bytes: &[u8], pos: usize) -> Result<f32> {
111 if pos < 3 {
112 return Ok(0.0);
113 }
114
115 let ngram = &bytes[pos.saturating_sub(3)..=pos];
116 let hash = self.hash_bytes(ngram);
117
118 if let Some(&cached_entropy) = self.entropy_cache.get(&hash) {
119 return Ok(cached_entropy);
120 }
121
122 let entropy = self.compute_ngram_entropy(ngram)?;
123 self.entropy_cache.insert(hash, entropy);
124
125 Ok(entropy)
126 }
127
128 fn compute_ngram_entropy(&self, ngram: &[u8]) -> Result<f32> {
129 let mut counts = [0u32; 256];
130 let mut total = 0u32;
131
132 for &byte in ngram {
133 counts[byte as usize] += 1;
134 total += 1;
135 }
136
137 let mut entropy = 0.0f32;
138 for count in counts.iter().filter(|&&c| c > 0) {
139 let p = *count as f32 / total as f32;
140 entropy -= p * p.log2();
141 }
142
143 Ok(entropy)
144 }
145
146 fn calculate_compression_ratio(&self, bytes: &[u8], pos: usize) -> Result<f32> {
147 if pos < 8 {
148 return Ok(0.0);
149 }
150
151 let window = &bytes[pos.saturating_sub(8)..=pos];
152 let original_size = window.len();
153 let compressed_size = self.estimate_compression_size(window)?;
154
155 Ok(1.0 - (compressed_size as f32 / original_size as f32))
156 }
157
158 fn estimate_compression_size(&self, window: &[u8]) -> Result<usize> {
159 let mut unique_bytes = std::collections::HashSet::new();
160 let mut repeat_count = 0;
161
162 for i in 0..window.len() {
163 if !unique_bytes.insert(window[i]) {
164 repeat_count += 1;
165 }
166 }
167
168 Ok(window.len() - repeat_count / 2)
169 }
170
171 fn detect_semantic_boundary(&self, bytes: &[u8], pos: usize) -> Result<f32> {
172 if pos == 0 {
173 return Ok(0.0);
174 }
175
176 let current_byte = bytes[pos];
177 let prev_byte = bytes[pos - 1];
178
179 let is_word_boundary = self.is_word_boundary(prev_byte, current_byte);
180 let is_sentence_boundary = self.is_sentence_boundary(prev_byte, current_byte);
181 let is_line_boundary = current_byte == b'\n';
182
183 let score = if is_sentence_boundary {
184 0.9
185 } else if is_line_boundary {
186 0.8
187 } else if is_word_boundary {
188 0.6
189 } else {
190 0.0
191 };
192
193 Ok(score)
194 }
195
196 fn is_word_boundary(&self, prev: u8, current: u8) -> bool {
197 let prev_is_alphanum = prev.is_ascii_alphanumeric();
198 let current_is_alphanum = current.is_ascii_alphanumeric();
199 let is_space_transition = prev == b' ' || current == b' ';
200
201 (prev_is_alphanum != current_is_alphanum) || is_space_transition
202 }
203
204 fn is_sentence_boundary(&self, prev: u8, current: u8) -> bool {
205 matches!(prev, b'.' | b'!' | b'?') && (current == b' ' || current == b'\n')
206 }
207
208 fn calculate_repetition_score(&self, bytes: &[u8], pos: usize) -> Result<f32> {
209 if pos < 4 {
210 return Ok(0.0);
211 }
212
213 let window_size = 8.min(pos);
214 let window = &bytes[pos - window_size..pos];
215
216 let mut max_repeat_length = 0;
217
218 for pattern_len in 1..=window_size / 2 {
219 let pattern = &window[window_size - pattern_len..];
220 let mut repeat_count = 0;
221
222 for i in (0..window_size - pattern_len).step_by(pattern_len) {
223 if &window[i..i + pattern_len] == pattern {
224 repeat_count += 1;
225 } else {
226 break;
227 }
228 }
229
230 if repeat_count > 1 {
231 max_repeat_length = max_repeat_length.max(pattern_len * repeat_count);
232 }
233 }
234
235 Ok(max_repeat_length as f32 / window_size as f32)
236 }
237
238 fn calculate_structural_score(&self, bytes: &[u8], pos: usize) -> Result<f32> {
239 if pos == 0 {
240 return Ok(0.0);
241 }
242
243 let current_byte = bytes[pos];
244
245 let structural_chars = [b'{', b'}', b'[', b']', b'(', b')', b'<', b'>', b';', b':', b','];
246 let is_structural = structural_chars.contains(¤t_byte);
247
248 if is_structural {
249 Ok(0.8)
250 } else {
251 Ok(0.0)
252 }
253 }
254
255 fn create_patch(&self, bytes: &[u8], start: usize, end: usize, signals: &PatchingSignals) -> Result<Patch> {
256 let patch_bytes = bytes[start..end].to_vec();
257 let complexity_score = self.calculate_complexity_score(signals);
258 let patch_type = self.determine_patch_type(signals);
259
260 Ok(Patch {
261 bytes: patch_bytes,
262 start_pos: start,
263 end_pos: end,
264 complexity_score,
265 patch_type,
266 })
267 }
268
269 fn calculate_complexity_score(&self, signals: &PatchingSignals) -> f32 {
270 let weights = [0.3, 0.2, 0.2, 0.15, 0.15];
271 let values = [
272 signals.entropy,
273 signals.compression_ratio,
274 signals.semantic_boundary,
275 signals.repetition_score,
276 signals.structural_score,
277 ];
278
279 weights.iter().zip(values.iter()).map(|(w, v)| w * v).sum()
280 }
281
282 fn determine_patch_type(&self, signals: &PatchingSignals) -> PatchType {
283 if signals.repetition_score > 0.7 {
284 PatchType::Repetitive
285 } else if signals.structural_score > 0.6 {
286 PatchType::Structural
287 } else if signals.semantic_boundary > 0.5 {
288 PatchType::Semantic
289 } else if signals.entropy > 0.7 {
290 PatchType::Complex
291 } else {
292 PatchType::Simple
293 }
294 }
295
296 fn hash_bytes(&self, bytes: &[u8]) -> u64 {
297 use std::hash::{Hash, Hasher};
298 let mut hasher = ahash::AHasher::default();
299 bytes.hash(&mut hasher);
300 hasher.finish()
301 }
302}
303
/// Per-position boundary signals computed by `MultiSignalPatcher::calculate_signals`.
///
/// A small plain-value struct, so it is `Copy`; `Default` gives the all-zero
/// (no-signal) state.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
struct PatchingSignals {
    /// Shannon entropy (bits) of the 4-byte window ending at the position.
    entropy: f32,
    /// Estimated compressibility of the 9-byte window ending at the position.
    compression_ratio: f32,
    /// Word/line/sentence boundary strength at the position.
    semantic_boundary: f32,
    /// Fraction of the preceding window covered by a repeating pattern.
    repetition_score: f32,
    /// Non-zero when the byte at the position is a bracket or separator.
    structural_score: f32,
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `patches` tile `input` exactly: non-empty, contiguous, in
    /// order, each `bytes` matching its input slice, and together covering all of it.
    fn assert_tiles_input(patches: &[Patch], input: &[u8]) {
        let mut expected_start = 0;
        for patch in patches {
            assert!(!patch.bytes.is_empty());
            assert_eq!(patch.start_pos, expected_start);
            assert_eq!(patch.end_pos, patch.start_pos + patch.bytes.len());
            assert_eq!(&input[patch.start_pos..patch.end_pos], patch.bytes.as_slice());
            expected_start = patch.end_pos;
        }
        assert_eq!(expected_start, input.len());
    }

    #[test]
    fn test_basic_patching() {
        let config = ByteForgeConfig::default();
        let mut patcher = MultiSignalPatcher::new(config);

        let text = b"Hello world! This is a test.";
        let patches = patcher.patch_bytes(text).unwrap();

        assert!(!patches.is_empty());
        assert!(patches.iter().all(|p| !p.bytes.is_empty()));
        assert_tiles_input(&patches, text);
    }

    #[test]
    fn test_empty_input_yields_no_patches() {
        let mut patcher = MultiSignalPatcher::new(ByteForgeConfig::default());
        assert!(patcher.patch_bytes(b"").unwrap().is_empty());
    }

    #[test]
    fn test_patches_tile_input_and_respect_max_size() {
        let config = ByteForgeConfig::default();
        // A zero maximum still forces single-byte patches, hence the floor of 1.
        let max_len = config.patch_size_range.1.max(1);
        let mut patcher = MultiSignalPatcher::new(config);

        let input = b"fn main() { let xs = [1, 2, 3]; }\nabababababab. Done!\n".repeat(8);
        let patches = patcher.patch_bytes(&input).unwrap();

        assert_tiles_input(&patches, &input);
        assert!(patches.iter().all(|p| p.bytes.len() <= max_len));
    }

    #[test]
    fn test_semantic_boundary_detection() {
        let config = ByteForgeConfig::default();
        let patcher = MultiSignalPatcher::new(config);

        let text = b"Hello world";
        let score = patcher.detect_semantic_boundary(text, 5).unwrap();
        assert!(score > 0.0);
    }
}