cbtop/adversarial/
resources.rs1use std::time::{Duration, Instant};
4
5use super::{AdversarialError, AdversarialResult, ResourceUsage};
6
/// Deterministically corrupts buffers by toggling individual bits.
///
/// Bit positions come from a 64-bit linear congruential generator started at
/// `seed`, so the same `(seed, flip_count)` pair always corrupts a given input
/// in exactly the same way.
#[derive(Debug, Clone)]
pub struct BitFlipInjector {
    /// Initial generator state.
    pub seed: u64,
    /// Number of bit toggles applied per call. Toggles are drawn independently,
    /// so two draws that land on the same bit cancel each other out.
    pub flip_count: usize,
}

impl Default for BitFlipInjector {
    /// Seed `42`, one bit flip per call.
    fn default() -> Self {
        Self::new(42, 1)
    }
}

impl BitFlipInjector {
    /// Creates an injector with an explicit seed and number of bit toggles.
    pub fn new(seed: u64, flip_count: usize) -> Self {
        Self { seed, flip_count }
    }

    /// Returns a copy of `data` with `flip_count` pseudo-random bit toggles applied.
    ///
    /// The input slice is never modified. An empty input is returned unchanged.
    pub fn inject(&self, data: &[u8]) -> Vec<u8> {
        // Knuth's MMIX LCG multiplier and increment.
        const MULTIPLIER: u64 = 6364136223846793005;
        const INCREMENT: u64 = 1442695040888963407;

        let mut corrupted = data.to_vec();
        let len = corrupted.len();
        if len == 0 {
            return corrupted;
        }

        // Every generator state after the seed selects one bit to toggle: the
        // low bits choose the byte, bits 32 and up choose the bit within it.
        let states = std::iter::successors(Some(self.seed), |&state| {
            Some(state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT))
        })
        .skip(1)
        .take(self.flip_count);

        for state in states {
            let byte = (state as usize) % len;
            let bit = ((state >> 32) as usize) % 8;
            corrupted[byte] ^= 1u8 << bit;
        }
        corrupted
    }

    /// Bit-flips the little-endian byte image of `data` and decodes it back
    /// into `f32`s.
    ///
    /// The result has the same length as `data`. Corrupted values may become
    /// NaN, infinite or subnormal.
    pub fn inject_floats(&self, data: &[f32]) -> Vec<f32> {
        let mut raw = Vec::with_capacity(data.len() * 4);
        for value in data {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        self.inject(&raw)
            .chunks_exact(4)
            .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }
}
71
72#[derive(Debug, Clone)]
74pub struct ResourceLimiter {
75 pub max_stack_depth: usize,
77 pub max_memory_bytes: usize,
79 pub timeout: Duration,
81 current_depth: usize,
83 current_memory: usize,
85 start_time: Option<Instant>,
87}
88
89impl Default for ResourceLimiter {
90 fn default() -> Self {
91 Self {
92 max_stack_depth: 1000,
93 max_memory_bytes: 1024 * 1024 * 1024, timeout: Duration::from_secs(60),
95 current_depth: 0,
96 current_memory: 0,
97 start_time: None,
98 }
99 }
100}
101
102impl ResourceLimiter {
103 pub fn new() -> Self {
105 Self::default()
106 }
107
108 pub fn with_max_depth(mut self, max_depth: usize) -> Self {
110 self.max_stack_depth = max_depth;
111 self
112 }
113
114 pub fn with_max_memory(mut self, max_bytes: usize) -> Self {
116 self.max_memory_bytes = max_bytes;
117 self
118 }
119
120 pub fn with_timeout(mut self, timeout: Duration) -> Self {
122 self.timeout = timeout;
123 self
124 }
125
126 pub fn start_operation(&mut self) {
128 self.start_time = Some(Instant::now());
129 }
130
131 pub fn check_timeout(&self, operation: &str) -> AdversarialResult<()> {
133 if let Some(start) = self.start_time {
134 let elapsed = start.elapsed();
135 if elapsed > self.timeout {
136 return Err(AdversarialError::Timeout {
137 operation: operation.to_string(),
138 elapsed,
139 limit: self.timeout,
140 });
141 }
142 }
143 Ok(())
144 }
145
146 pub fn enter_recursion(&mut self) -> AdversarialResult<()> {
148 self.current_depth += 1;
149 if self.current_depth > self.max_stack_depth {
150 return Err(AdversarialError::StackOverflow {
151 depth: self.current_depth,
152 max_depth: self.max_stack_depth,
153 });
154 }
155 Ok(())
156 }
157
158 pub fn exit_recursion(&mut self) {
160 if self.current_depth > 0 {
161 self.current_depth -= 1;
162 }
163 }
164
165 pub fn request_memory(&mut self, bytes: usize) -> AdversarialResult<()> {
167 let new_total = self.current_memory.saturating_add(bytes);
168 if new_total > self.max_memory_bytes {
169 return Err(AdversarialError::ResourceExhausted {
170 resource: format!(
171 "memory: requested {bytes} bytes, would exceed limit of {} bytes",
172 self.max_memory_bytes
173 ),
174 });
175 }
176 self.current_memory = new_total;
177 Ok(())
178 }
179
180 pub fn release_memory(&mut self, bytes: usize) {
182 self.current_memory = self.current_memory.saturating_sub(bytes);
183 }
184
185 pub fn usage(&self) -> ResourceUsage {
187 ResourceUsage {
188 stack_depth: self.current_depth,
189 memory_bytes: self.current_memory,
190 elapsed: self.start_time.map(|s| s.elapsed()),
191 }
192 }
193
194 pub fn reset(&mut self) {
196 self.current_depth = 0;
197 self.current_memory = 0;
198 self.start_time = None;
199 }
200}
201
202#[derive(Debug, Clone)]
204pub struct CancellationToken {
205 cancelled: std::sync::Arc<std::sync::atomic::AtomicBool>,
206}
207
208impl Default for CancellationToken {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
214impl CancellationToken {
215 pub fn new() -> Self {
217 Self {
218 cancelled: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
219 }
220 }
221
222 pub fn cancel(&self) {
224 self.cancelled
225 .store(true, std::sync::atomic::Ordering::SeqCst);
226 }
227
228 pub fn is_cancelled(&self) -> bool {
230 self.cancelled.load(std::sync::atomic::Ordering::SeqCst)
231 }
232
233 pub fn check(&self, operation: &str) -> AdversarialResult<()> {
235 if self.is_cancelled() {
236 return Err(AdversarialError::Cancelled {
237 operation: operation.to_string(),
238 });
239 }
240 Ok(())
241 }
242
243 pub fn clone_token(&self) -> Self {
245 Self {
246 cancelled: std::sync::Arc::clone(&self.cancelled),
247 }
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct RecoveryHandler<S: Clone> {
254 checkpoint: Option<S>,
256}
257
258impl<S: Clone> Default for RecoveryHandler<S> {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264impl<S: Clone> RecoveryHandler<S> {
265 pub fn new() -> Self {
267 Self { checkpoint: None }
268 }
269
270 pub fn checkpoint(&mut self, state: S) {
272 self.checkpoint = Some(state);
273 }
274
275 pub fn recover(&self) -> AdversarialResult<S> {
277 self.checkpoint
278 .clone()
279 .ok_or_else(|| AdversarialError::RecoveryFailed {
280 original_error: "unknown".to_string(),
281 recovery_error: "no checkpoint available".to_string(),
282 })
283 }
284
285 pub fn try_with_recovery<F, T, E>(&self, operation: F) -> AdversarialResult<T>
287 where
288 F: FnOnce() -> Result<T, E>,
289 E: std::fmt::Display,
290 {
291 match operation() {
292 Ok(result) => Ok(result),
293 Err(e) => Err(AdversarialError::RecoveryFailed {
294 original_error: e.to_string(),
295 recovery_error: if self.checkpoint.is_some() {
296 "operation failed, checkpoint available".to_string()
297 } else {
298 "operation failed, no checkpoint".to_string()
299 },
300 }),
301 }
302 }
303
304 pub fn has_checkpoint(&self) -> bool {
306 self.checkpoint.is_some()
307 }
308}