// hotpatch_rs/runtime.rs

use core::ffi::c_void;
use std::{
    collections::{HashMap, HashSet},
    path::{Path, PathBuf},
    sync::{
        atomic::{AtomicPtr, AtomicU64, Ordering},
        Arc, RwLock,
    },
};

use libloading::{Library, Symbol};

use crate::{
    abi::{c_ptr_to_str, PatchModuleV1, MODULE_EXPORT_SYMBOL_V1},
    error::{HotpatchError, Result},
    slot::{unresolved_symbol_stub, PatchSlot},
    PatchFn,
};

/// Summary returned to the caller after a patch module has been loaded and its
/// entries installed.
#[derive(Debug, Clone)]
pub struct ModuleLoadReport {
    /// Name the module reported about itself.
    pub module_name: String,
    /// Version number read from the module descriptor.
    pub module_version: u64,
    /// Names of the symbols whose slots were updated, in the order the module
    /// listed them.
    pub symbols: Vec<String>,
    /// Value of the runtime's generation counter immediately after this load.
    pub generation: u64,
    /// Path the module was loaded from, as supplied by the caller.
    pub source_path: PathBuf,
}

29struct LoadedModule {
30    name: String,
31    version: u64,
32    _library: Library,
33}
34
35pub struct HotpatchRuntime {
36    slots: RwLock<HashMap<String, Arc<PatchSlot>>>,
37    loaded_modules: RwLock<Vec<LoadedModule>>,
38    generation: AtomicU64,
39    host_context: AtomicPtr<c_void>,
40}
41
42impl Default for HotpatchRuntime {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl HotpatchRuntime {
49    pub fn new() -> Self {
50        Self {
51            slots: RwLock::new(HashMap::new()),
52            loaded_modules: RwLock::new(Vec::new()),
53            generation: AtomicU64::new(0),
54            host_context: AtomicPtr::new(core::ptr::null_mut()),
55        }
56    }
57
58    pub fn set_host_context(&self, ptr: *mut c_void) {
59        self.host_context.store(ptr, Ordering::Release);
60    }
61
62    pub fn generation(&self) -> u64 {
63        self.generation.load(Ordering::Acquire)
64    }
65
66    pub fn register_fallback(&self, symbol: impl Into<String>, fallback: PatchFn) -> Result<()> {
67        let symbol = symbol.into();
68        let mut slots = self
69            .slots
70            .write()
71            .expect("slot registry lock poisoned while registering fallback");
72        if slots.contains_key(&symbol) {
73            return Err(HotpatchError::FallbackAlreadyRegistered(symbol));
74        }
75        slots.insert(symbol, Arc::new(PatchSlot::new(fallback)));
76        Ok(())
77    }
78
79    pub fn list_symbols(&self) -> Vec<String> {
80        let slots = self
81            .slots
82            .read()
83            .expect("slot registry lock poisoned while listing symbols");
84        let mut keys: Vec<_> = slots.keys().cloned().collect();
85        keys.sort();
86        keys
87    }
88
89    pub unsafe fn call_raw(
90        &self,
91        symbol: &str,
92        input: *const c_void,
93        output: *mut c_void,
94    ) -> Result<i32> {
95        let slots = self
96            .slots
97            .read()
98            .expect("slot registry lock poisoned while invoking");
99        let slot = slots
100            .get(symbol)
101            .ok_or_else(|| HotpatchError::SymbolNotFound(symbol.to_string()))?;
102        let host_context = self.host_context.load(Ordering::Acquire);
103        Ok(slot.call(host_context, input, output))
104    }
105
106    pub unsafe fn call_typed<I, O>(&self, symbol: &str, input: &I, output: &mut O) -> Result<i32> {
107        self.call_raw(
108            symbol,
109            input as *const I as *const c_void,
110            output as *mut O as *mut c_void,
111        )
112    }
113
114    pub fn load_module<P: AsRef<Path>>(&self, path: P) -> Result<ModuleLoadReport> {
115        let path_ref = path.as_ref();
116        let library = unsafe {
117            Library::new(path_ref).map_err(|source| HotpatchError::LibraryLoad {
118                path: path_ref.display().to_string(),
119                source,
120            })?
121        };
122
123        let module = unsafe {
124            let symbol: Symbol<*const PatchModuleV1> =
125                library
126                    .get(MODULE_EXPORT_SYMBOL_V1)
127                    .map_err(|source| HotpatchError::MissingExport {
128                        symbol: "HOTPATCH_MODULE_V1",
129                        source,
130                    })?;
131
132            let module_ptr = *symbol;
133            if module_ptr.is_null() {
134                return Err(HotpatchError::InvalidModule(
135                    "module export pointer is null".into(),
136                ));
137            }
138            &*module_ptr
139        };
140
141        module.validate_header()?;
142
143        let module_name = unsafe { module.module_name()?.to_owned() };
144        let entries = unsafe { module.entries_slice() };
145
146        let mut seen = HashSet::with_capacity(entries.len());
147        let mut updates = Vec::with_capacity(entries.len());
148
149        for entry in entries {
150            if entry.name.is_null() {
151                return Err(HotpatchError::InvalidModule(
152                    "entry has null symbol name".into(),
153                ));
154            }
155            let name = unsafe { c_ptr_to_str(entry.name)?.to_owned() };
156            if !seen.insert(name.clone()) {
157                return Err(HotpatchError::InvalidModule(format!(
158                    "duplicate entry symbol in module: {name}"
159                )));
160            }
161            updates.push((name, entry.func));
162        }
163
164        if let Some(on_load) = module.on_load {
165            let status = unsafe { on_load(self.host_context.load(Ordering::Acquire)) };
166            if status != 0 {
167                return Err(HotpatchError::LifecycleFailed {
168                    module: module_name,
169                    hook: "on_load",
170                    status,
171                });
172            }
173        }
174
175        {
176            let mut slots = self
177                .slots
178                .write()
179                .expect("slot registry lock poisoned while loading module");
180            for (symbol, func) in &updates {
181                let slot = slots
182                    .entry(symbol.clone())
183                    .or_insert_with(|| Arc::new(PatchSlot::new(unresolved_symbol_stub)));
184                slot.swap(*func);
185            }
186        }
187
188        self.loaded_modules
189            .write()
190            .expect("module registry lock poisoned while loading module")
191            .push(LoadedModule {
192                name: module_name.clone(),
193                version: module.module_version,
194                _library: library,
195            });
196
197        let generation = self.generation.fetch_add(1, Ordering::AcqRel) + 1;
198
199        Ok(ModuleLoadReport {
200            module_name,
201            module_version: module.module_version,
202            symbols: updates.into_iter().map(|(name, _)| name).collect(),
203            generation,
204            source_path: path_ref.to_path_buf(),
205        })
206    }
207
208    pub fn loaded_modules(&self) -> Vec<(String, u64)> {
209        let modules = self
210            .loaded_modules
211            .read()
212            .expect("module registry lock poisoned while listing modules");
213        modules
214            .iter()
215            .map(|m| (m.name.clone(), m.version))
216            .collect()
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Input payload shared by the test patch functions.
    #[repr(C)]
    struct Input {
        value: i32,
    }

    /// Output payload shared by the test patch functions.
    #[repr(C)]
    struct Output {
        value: i32,
    }

    /// Baseline implementation: writes `input.value + 1`.
    unsafe extern "C" fn fallback(_: *mut c_void, input: *const c_void, output: *mut c_void) -> i32 {
        let input = &*(input as *const Input);
        let output = &mut *(output as *mut Output);
        output.value = input.value + 1;
        0
    }

    /// Replacement implementation: writes `input.value + 10`.
    unsafe extern "C" fn patched(_: *mut c_void, input: *const c_void, output: *mut c_void) -> i32 {
        let input = &*(input as *const Input);
        let output = &mut *(output as *mut Output);
        output.value = input.value + 10;
        0
    }

    #[test]
    fn new_runtime_starts_empty() {
        let runtime = HotpatchRuntime::new();
        assert_eq!(runtime.generation(), 0);
        assert!(runtime.list_symbols().is_empty());
        assert!(runtime.loaded_modules().is_empty());
    }

    #[test]
    fn fallback_registration_and_calls_work() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.add", fallback).unwrap();

        let mut out = Output { value: 0 };
        unsafe {
            runtime
                .call_typed("math.add", &Input { value: 5 }, &mut out)
                .unwrap();
        }
        assert_eq!(out.value, 6);
    }

    #[test]
    fn fallback_duplicate_is_rejected() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.add", fallback).unwrap();
        let err = runtime.register_fallback("math.add", fallback).unwrap_err();
        assert!(matches!(err, HotpatchError::FallbackAlreadyRegistered(_)));
    }

    #[test]
    fn list_symbols_is_sorted() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.sub", fallback).unwrap();
        runtime.register_fallback("math.add", fallback).unwrap();
        assert_eq!(
            runtime.list_symbols(),
            vec!["math.add".to_string(), "math.sub".to_string()]
        );
    }

    #[test]
    fn symbol_not_found_is_reported() {
        let runtime = HotpatchRuntime::new();
        let mut out = Output { value: 0 };
        let err = unsafe { runtime.call_typed("does.not.exist", &Input { value: 1 }, &mut out) }
            .unwrap_err();
        assert!(matches!(err, HotpatchError::SymbolNotFound(_)));
    }

    #[test]
    fn patch_slot_swap_changes_behavior() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.add", fallback).unwrap();
        {
            // Swap the slot directly; a read guard suffices because `swap` takes `&self`.
            let slots = runtime.slots.read().unwrap();
            slots.get("math.add").unwrap().swap(patched);
        }

        let mut out = Output { value: 0 };
        unsafe {
            runtime
                .call_typed("math.add", &Input { value: 5 }, &mut out)
                .unwrap();
        }
        assert_eq!(out.value, 15);
    }
}