// memlink_runtime/safety.rs
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::thread;
6use std::time::Duration;
7
8use crate::error::{Error, Result};
9
/// Default wall-clock budget for a single module call (see `SafetyConfig::call_timeout`).
pub const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(30);
/// Default maximum nested call depth (see `SafetyConfig::max_stack_depth`).
pub const DEFAULT_MAX_STACK_DEPTH: usize = 100;
/// Default memory budget in bytes: 64 MiB (see `SafetyConfig::memory_limit`).
pub const DEFAULT_MEMORY_LIMIT: usize = 64 * 1024 * 1024;
13
/// Configuration for the runtime safety checks applied around module calls.
#[derive(Debug, Clone)]
pub struct SafetyConfig {
    /// Maximum wall-clock time allowed for a single call.
    pub call_timeout: Duration,
    /// Maximum nested call depth accepted by `StackDepth::enter`.
    pub max_stack_depth: usize,
    /// Memory budget in bytes, as tracked by `MemoryTracker`.
    pub memory_limit: usize,
    /// Whether the call timeout should be enforced.
    /// NOTE(review): not read by any code visible in this file — presumably
    /// consulted by the caller before wrapping with `with_timeout`; confirm.
    pub enforce_timeout: bool,
    /// When true, `validate_call_safety` records/checks stack depth.
    pub check_stack_depth: bool,
    /// When true, `validate_call_safety` rejects calls once usage exceeds the limit.
    pub enforce_memory_limit: bool,
}
23
24impl Default for SafetyConfig {
25 fn default() -> Self {
26 SafetyConfig {
27 call_timeout: DEFAULT_CALL_TIMEOUT,
28 max_stack_depth: DEFAULT_MAX_STACK_DEPTH,
29 memory_limit: DEFAULT_MEMORY_LIMIT,
30 enforce_timeout: true,
31 check_stack_depth: false,
32 enforce_memory_limit: true,
33 }
34 }
35}
36
37impl SafetyConfig {
38 pub fn with_call_timeout(mut self, timeout: Duration) -> Self {
39 self.call_timeout = timeout;
40 self
41 }
42
43 pub fn with_max_stack_depth(mut self, depth: usize) -> Self {
44 self.max_stack_depth = depth;
45 self
46 }
47
48 pub fn with_memory_limit(mut self, limit: usize) -> Self {
49 self.memory_limit = limit;
50 self
51 }
52
53 pub fn with_timeout_enforcement(mut self, enabled: bool) -> Self {
54 self.enforce_timeout = enabled;
55 self
56 }
57
58 pub fn with_stack_depth_check(mut self, enabled: bool) -> Self {
59 self.check_stack_depth = enabled;
60 self
61 }
62
63 pub fn with_memory_limit_enforcement(mut self, enabled: bool) -> Self {
64 self.enforce_memory_limit = enabled;
65 self
66 }
67}
68
/// Atomic counter tracking the current nesting depth of module calls.
#[derive(Debug, Default)]
pub struct StackDepth {
    // Current depth. Relaxed atomics are used throughout: this is a pure
    // counter with no ordering requirements relative to other memory.
    depth: AtomicUsize,
}
73
74impl StackDepth {
75 pub fn new() -> Self {
76 StackDepth {
77 depth: AtomicUsize::new(0),
78 }
79 }
80
81 pub fn enter(&self, max_depth: usize) -> Result<()> {
82 let new_depth = self.depth.fetch_add(1, Ordering::Relaxed) + 1;
83 if new_depth > max_depth {
84 self.depth.fetch_sub(1, Ordering::Relaxed);
85 Err(Error::ModulePanicked(format!(
86 "Stack depth {} exceeds maximum {}",
87 new_depth, max_depth
88 )))
89 } else {
90 Ok(())
91 }
92 }
93
94 pub fn exit(&self) {
95 self.depth.fetch_sub(1, Ordering::Relaxed);
96 }
97
98 pub fn current(&self) -> usize {
99 self.depth.load(Ordering::Relaxed)
100 }
101
102 pub fn reset(&self) {
103 self.depth.store(0, Ordering::Relaxed);
104 }
105}
106
/// Byte-accounting tracker that enforces a fixed upper bound on usage.
#[derive(Debug)]
pub struct MemoryTracker {
    // Bytes currently accounted as allocated.
    used: AtomicUsize,
    // Immutable upper bound in bytes, fixed at construction.
    limit: usize,
}
112
113impl MemoryTracker {
114 pub fn new(limit: usize) -> Self {
115 MemoryTracker {
116 used: AtomicUsize::new(0),
117 limit,
118 }
119 }
120
121 pub fn allocate(&self, size: usize) -> Result<()> {
122 let mut current = self.used.load(Ordering::Relaxed);
123
124 loop {
125 let new_used = current + size;
126 if new_used > self.limit {
127 return Err(Error::InvalidModuleFormat(format!(
128 "Memory allocation of {} bytes would exceed limit of {} bytes",
129 size, self.limit
130 )));
131 }
132
133 match self.used.compare_exchange(
134 current,
135 new_used,
136 Ordering::SeqCst,
137 Ordering::Relaxed,
138 ) {
139 Ok(_) => return Ok(()),
140 Err(x) => current = x,
141 }
142 }
143 }
144
145 pub fn free(&self, size: usize) {
146 self.used.fetch_sub(size, Ordering::Relaxed);
147 }
148
149 pub fn used(&self) -> usize {
150 self.used.load(Ordering::Relaxed)
151 }
152
153 pub fn limit(&self) -> usize {
154 self.limit
155 }
156
157 pub fn remaining(&self) -> usize {
158 self.limit - self.used.load(Ordering::Relaxed)
159 }
160
161 pub fn usage_ratio(&self) -> f32 {
162 self.used.load(Ordering::Relaxed) as f32 / self.limit as f32
163 }
164}
165
166pub fn with_timeout<F, R>(timeout: Duration, f: F) -> Result<R>
167where
168 F: FnOnce() -> R + Send + 'static,
169 R: Send + 'static,
170{
171 let timeout_flag = Arc::new(AtomicBool::new(false));
172 let timeout_flag_clone = Arc::clone(&timeout_flag);
173
174 let handle = thread::spawn(move || {
175 let result = f();
176 timeout_flag_clone.store(true, Ordering::Relaxed);
177 result
178 });
179
180 match handle.join_timeout(timeout) {
181 Ok(result) => Ok(result),
182 Err(_) => Err(Error::ModuleCallFailed(-3)),
183 }
184}
185
/// Extension trait adding a bounded-wait join to [`thread::JoinHandle`],
/// which the standard library does not provide natively.
trait JoinHandleTimeout<T> {
    /// Waits up to `timeout` for the thread to finish, then joins it.
    ///
    /// Returns `Err` (with a `&str` payload) if the deadline expires, or the
    /// thread's panic payload if it panicked.
    fn join_timeout(self, timeout: Duration) -> thread::Result<T>;
}

impl<T> JoinHandleTimeout<T> for thread::JoinHandle<T> {
    fn join_timeout(self, timeout: Duration) -> thread::Result<T> {
        // Poll `is_finished` rather than blocking in `join`; the 1ms sleep
        // bounds both CPU use and the extra completion latency.
        let start = std::time::Instant::now();
        while start.elapsed() < timeout {
            if self.is_finished() {
                return self.join();
            }
            thread::sleep(Duration::from_millis(1));
        }
        // Fix: one final check after the loop, so a thread that finished
        // during the last sleep yields its result instead of a spurious
        // timeout error.
        if self.is_finished() {
            return self.join();
        }
        Err(Box::new("Timeout expired"))
    }
}
202
203pub fn validate_call_safety(
204 config: &SafetyConfig,
205 stack_depth: &StackDepth,
206 memory_tracker: Option<&MemoryTracker>,
207) -> Result<()> {
208 if config.check_stack_depth {
209 stack_depth.enter(config.max_stack_depth)?;
210 }
211
212 if let Some(tracker) = memory_tracker {
213 if config.enforce_memory_limit && tracker.used() > tracker.limit() {
214 return Err(Error::InvalidModuleFormat(
215 "Memory limit exceeded".to_string()
216 ));
217 }
218 }
219
220 Ok(())
221}