Skip to main content

memlink_runtime/
safety.rs

1//! Safety hardening features for the runtime.
2
3use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::thread;
6use std::time::Duration;
7
8use crate::error::{Error, Result};
9
/// Default value of `SafetyConfig::call_timeout`: thirty seconds.
pub const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(30);

/// Default value of `SafetyConfig::max_stack_depth`.
pub const DEFAULT_MAX_STACK_DEPTH: usize = 100;

/// Default value of `SafetyConfig::memory_limit`: 64 MiB when read as bytes.
pub const DEFAULT_MEMORY_LIMIT: usize = 64 << 20;
13
14#[derive(Debug, Clone)]
15pub struct SafetyConfig {
16    pub call_timeout: Duration,
17    pub max_stack_depth: usize,
18    pub memory_limit: usize,
19    pub enforce_timeout: bool,
20    pub check_stack_depth: bool,
21    pub enforce_memory_limit: bool,
22}
23
24impl Default for SafetyConfig {
25    fn default() -> Self {
26        SafetyConfig {
27            call_timeout: DEFAULT_CALL_TIMEOUT,
28            max_stack_depth: DEFAULT_MAX_STACK_DEPTH,
29            memory_limit: DEFAULT_MEMORY_LIMIT,
30            enforce_timeout: true,
31            check_stack_depth: false,
32            enforce_memory_limit: true,
33        }
34    }
35}
36
37impl SafetyConfig {
38    pub fn with_call_timeout(mut self, timeout: Duration) -> Self {
39        self.call_timeout = timeout;
40        self
41    }
42
43    pub fn with_max_stack_depth(mut self, depth: usize) -> Self {
44        self.max_stack_depth = depth;
45        self
46    }
47
48    pub fn with_memory_limit(mut self, limit: usize) -> Self {
49        self.memory_limit = limit;
50        self
51    }
52
53    pub fn with_timeout_enforcement(mut self, enabled: bool) -> Self {
54        self.enforce_timeout = enabled;
55        self
56    }
57
58    pub fn with_stack_depth_check(mut self, enabled: bool) -> Self {
59        self.check_stack_depth = enabled;
60        self
61    }
62
63    pub fn with_memory_limit_enforcement(mut self, enabled: bool) -> Self {
64        self.enforce_memory_limit = enabled;
65        self
66    }
67}
68
69#[derive(Debug, Default)]
70pub struct StackDepth {
71    depth: AtomicUsize,
72}
73
74impl StackDepth {
75    pub fn new() -> Self {
76        StackDepth {
77            depth: AtomicUsize::new(0),
78        }
79    }
80
81    pub fn enter(&self, max_depth: usize) -> Result<()> {
82        let new_depth = self.depth.fetch_add(1, Ordering::Relaxed) + 1;
83        if new_depth > max_depth {
84            self.depth.fetch_sub(1, Ordering::Relaxed);
85            Err(Error::ModulePanicked(format!(
86                "Stack depth {} exceeds maximum {}",
87                new_depth, max_depth
88            )))
89        } else {
90            Ok(())
91        }
92    }
93
94    pub fn exit(&self) {
95        self.depth.fetch_sub(1, Ordering::Relaxed);
96    }
97
98    pub fn current(&self) -> usize {
99        self.depth.load(Ordering::Relaxed)
100    }
101
102    pub fn reset(&self) {
103        self.depth.store(0, Ordering::Relaxed);
104    }
105}
106
107#[derive(Debug)]
108pub struct MemoryTracker {
109    used: AtomicUsize,
110    limit: usize,
111}
112
113impl MemoryTracker {
114    pub fn new(limit: usize) -> Self {
115        MemoryTracker {
116            used: AtomicUsize::new(0),
117            limit,
118        }
119    }
120
121    pub fn allocate(&self, size: usize) -> Result<()> {
122        let mut current = self.used.load(Ordering::Relaxed);
123
124        loop {
125            let new_used = current + size;
126            if new_used > self.limit {
127                return Err(Error::InvalidModuleFormat(format!(
128                    "Memory allocation of {} bytes would exceed limit of {} bytes",
129                    size, self.limit
130                )));
131            }
132
133            match self.used.compare_exchange(
134                current,
135                new_used,
136                Ordering::SeqCst,
137                Ordering::Relaxed,
138            ) {
139                Ok(_) => return Ok(()),
140                Err(x) => current = x,
141            }
142        }
143    }
144
145    pub fn free(&self, size: usize) {
146        self.used.fetch_sub(size, Ordering::Relaxed);
147    }
148
149    pub fn used(&self) -> usize {
150        self.used.load(Ordering::Relaxed)
151    }
152
153    pub fn limit(&self) -> usize {
154        self.limit
155    }
156
157    pub fn remaining(&self) -> usize {
158        self.limit - self.used.load(Ordering::Relaxed)
159    }
160
161    pub fn usage_ratio(&self) -> f32 {
162        self.used.load(Ordering::Relaxed) as f32 / self.limit as f32
163    }
164}
165
166pub fn with_timeout<F, R>(timeout: Duration, f: F) -> Result<R>
167where
168    F: FnOnce() -> R + Send + 'static,
169    R: Send + 'static,
170{
171    let timeout_flag = Arc::new(AtomicBool::new(false));
172    let timeout_flag_clone = Arc::clone(&timeout_flag);
173
174    let handle = thread::spawn(move || {
175        let result = f();
176        timeout_flag_clone.store(true, Ordering::Relaxed);
177        result
178    });
179
180    match handle.join_timeout(timeout) {
181        Ok(result) => Ok(result),
182        Err(_) => Err(Error::ModuleCallFailed(-3)),
183    }
184}
185
186trait JoinHandleTimeout<T> {
187    fn join_timeout(self, timeout: Duration) -> thread::Result<T>;
188}
189
190impl<T> JoinHandleTimeout<T> for thread::JoinHandle<T> {
191    fn join_timeout(self, timeout: Duration) -> thread::Result<T> {
192        let start = std::time::Instant::now();
193        while start.elapsed() < timeout {
194            if self.is_finished() {
195                return self.join();
196            }
197            thread::sleep(Duration::from_millis(1));
198        }
199        Err(Box::new("Timeout expired"))
200    }
201}
202
203pub fn validate_call_safety(
204    config: &SafetyConfig,
205    stack_depth: &StackDepth,
206    memory_tracker: Option<&MemoryTracker>,
207) -> Result<()> {
208    if config.check_stack_depth {
209        stack_depth.enter(config.max_stack_depth)?;
210    }
211
212    if let Some(tracker) = memory_tracker {
213        if config.enforce_memory_limit && tracker.used() > tracker.limit() {
214            return Err(Error::InvalidModuleFormat(
215                "Memory limit exceeded".to_string()
216            ));
217        }
218    }
219
220    Ok(())
221}