1use core::ffi::c_void;
2use std::{
3 collections::{HashMap, HashSet},
4 path::{Path, PathBuf},
5 sync::{
6 atomic::{AtomicPtr, AtomicU64, Ordering},
7 Arc, RwLock,
8 },
9};
10
11use libloading::{Library, Symbol};
12
13use crate::{
14 abi::{c_ptr_to_str, PatchModuleV1, MODULE_EXPORT_SYMBOL_V1},
15 error::{HotpatchError, Result},
16 slot::{unresolved_symbol_stub, PatchSlot},
17 PatchFn,
18};
19
/// Summary of a successful [`HotpatchRuntime::load_module`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModuleLoadReport {
    /// Name reported by the module's export header.
    pub module_name: String,
    /// Version number reported by the module's export header.
    pub module_version: u64,
    /// Symbols the module patched, in the order its entry table lists them.
    pub symbols: Vec<String>,
    /// Runtime generation counter value right after this load completed.
    pub generation: u64,
    /// Path the module was loaded from, as passed by the caller.
    pub source_path: PathBuf,
}
28
29struct LoadedModule {
30 name: String,
31 version: u64,
32 _library: Library,
33}
34
35pub struct HotpatchRuntime {
36 slots: RwLock<HashMap<String, Arc<PatchSlot>>>,
37 loaded_modules: RwLock<Vec<LoadedModule>>,
38 generation: AtomicU64,
39 host_context: AtomicPtr<c_void>,
40}
41
42impl Default for HotpatchRuntime {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl HotpatchRuntime {
49 pub fn new() -> Self {
50 Self {
51 slots: RwLock::new(HashMap::new()),
52 loaded_modules: RwLock::new(Vec::new()),
53 generation: AtomicU64::new(0),
54 host_context: AtomicPtr::new(core::ptr::null_mut()),
55 }
56 }
57
58 pub fn set_host_context(&self, ptr: *mut c_void) {
59 self.host_context.store(ptr, Ordering::Release);
60 }
61
62 pub fn generation(&self) -> u64 {
63 self.generation.load(Ordering::Acquire)
64 }
65
66 pub fn register_fallback(&self, symbol: impl Into<String>, fallback: PatchFn) -> Result<()> {
67 let symbol = symbol.into();
68 let mut slots = self
69 .slots
70 .write()
71 .expect("slot registry lock poisoned while registering fallback");
72 if slots.contains_key(&symbol) {
73 return Err(HotpatchError::FallbackAlreadyRegistered(symbol));
74 }
75 slots.insert(symbol, Arc::new(PatchSlot::new(fallback)));
76 Ok(())
77 }
78
79 pub fn list_symbols(&self) -> Vec<String> {
80 let slots = self
81 .slots
82 .read()
83 .expect("slot registry lock poisoned while listing symbols");
84 let mut keys: Vec<_> = slots.keys().cloned().collect();
85 keys.sort();
86 keys
87 }
88
89 pub unsafe fn call_raw(
90 &self,
91 symbol: &str,
92 input: *const c_void,
93 output: *mut c_void,
94 ) -> Result<i32> {
95 let slots = self
96 .slots
97 .read()
98 .expect("slot registry lock poisoned while invoking");
99 let slot = slots
100 .get(symbol)
101 .ok_or_else(|| HotpatchError::SymbolNotFound(symbol.to_string()))?;
102 let host_context = self.host_context.load(Ordering::Acquire);
103 Ok(slot.call(host_context, input, output))
104 }
105
106 pub unsafe fn call_typed<I, O>(&self, symbol: &str, input: &I, output: &mut O) -> Result<i32> {
107 self.call_raw(
108 symbol,
109 input as *const I as *const c_void,
110 output as *mut O as *mut c_void,
111 )
112 }
113
114 pub fn load_module<P: AsRef<Path>>(&self, path: P) -> Result<ModuleLoadReport> {
115 let path_ref = path.as_ref();
116 let library = unsafe {
117 Library::new(path_ref).map_err(|source| HotpatchError::LibraryLoad {
118 path: path_ref.display().to_string(),
119 source,
120 })?
121 };
122
123 let module = unsafe {
124 let symbol: Symbol<*const PatchModuleV1> =
125 library
126 .get(MODULE_EXPORT_SYMBOL_V1)
127 .map_err(|source| HotpatchError::MissingExport {
128 symbol: "HOTPATCH_MODULE_V1",
129 source,
130 })?;
131
132 let module_ptr = *symbol;
133 if module_ptr.is_null() {
134 return Err(HotpatchError::InvalidModule(
135 "module export pointer is null".into(),
136 ));
137 }
138 &*module_ptr
139 };
140
141 module.validate_header()?;
142
143 let module_name = unsafe { module.module_name()?.to_owned() };
144 let entries = unsafe { module.entries_slice() };
145
146 let mut seen = HashSet::with_capacity(entries.len());
147 let mut updates = Vec::with_capacity(entries.len());
148
149 for entry in entries {
150 if entry.name.is_null() {
151 return Err(HotpatchError::InvalidModule(
152 "entry has null symbol name".into(),
153 ));
154 }
155 let name = unsafe { c_ptr_to_str(entry.name)?.to_owned() };
156 if !seen.insert(name.clone()) {
157 return Err(HotpatchError::InvalidModule(format!(
158 "duplicate entry symbol in module: {name}"
159 )));
160 }
161 updates.push((name, entry.func));
162 }
163
164 if let Some(on_load) = module.on_load {
165 let status = unsafe { on_load(self.host_context.load(Ordering::Acquire)) };
166 if status != 0 {
167 return Err(HotpatchError::LifecycleFailed {
168 module: module_name,
169 hook: "on_load",
170 status,
171 });
172 }
173 }
174
175 {
176 let mut slots = self
177 .slots
178 .write()
179 .expect("slot registry lock poisoned while loading module");
180 for (symbol, func) in &updates {
181 let slot = slots
182 .entry(symbol.clone())
183 .or_insert_with(|| Arc::new(PatchSlot::new(unresolved_symbol_stub)));
184 slot.swap(*func);
185 }
186 }
187
188 self.loaded_modules
189 .write()
190 .expect("module registry lock poisoned while loading module")
191 .push(LoadedModule {
192 name: module_name.clone(),
193 version: module.module_version,
194 _library: library,
195 });
196
197 let generation = self.generation.fetch_add(1, Ordering::AcqRel) + 1;
198
199 Ok(ModuleLoadReport {
200 module_name,
201 module_version: module.module_version,
202 symbols: updates.into_iter().map(|(name, _)| name).collect(),
203 generation,
204 source_path: path_ref.to_path_buf(),
205 })
206 }
207
208 pub fn loaded_modules(&self) -> Vec<(String, u64)> {
209 let modules = self
210 .loaded_modules
211 .read()
212 .expect("module registry lock poisoned while listing modules");
213 modules
214 .iter()
215 .map(|m| (m.name.clone(), m.version))
216 .collect()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[repr(C)]
    struct Input {
        value: i32,
    }

    #[repr(C)]
    struct Output {
        value: i32,
    }

    /// Fallback implementation: `output = input + 1`.
    unsafe extern "C" fn fallback(_: *mut c_void, input: *const c_void, output: *mut c_void) -> i32 {
        let input = &*(input as *const Input);
        let output = &mut *(output as *mut Output);
        output.value = input.value + 1;
        0
    }

    /// Patched implementation: `output = input + 10`.
    unsafe extern "C" fn patched(_: *mut c_void, input: *const c_void, output: *mut c_void) -> i32 {
        let input = &*(input as *const Input);
        let output = &mut *(output as *mut Output);
        output.value = input.value + 10;
        0
    }

    /// Re-enters the runtime (passed as the host context) and applies
    /// `math.add` twice. Returns -1 if either inner call fails.
    unsafe extern "C" fn add_twice(ctx: *mut c_void, input: *const c_void, output: *mut c_void) -> i32 {
        let runtime = &*(ctx as *const HotpatchRuntime);
        let input = &*(input as *const Input);
        let output = &mut *(output as *mut Output);
        let mut intermediate = Output { value: 0 };
        let first = runtime.call_typed("math.add", input, &mut intermediate);
        let second = runtime.call_typed("math.add", &Input { value: intermediate.value }, output);
        match (first, second) {
            (Ok(0), Ok(status)) => status,
            _ => -1,
        }
    }

    #[test]
    fn fallback_registration_and_calls_work() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.add", fallback).unwrap();

        let mut out = Output { value: 0 };
        unsafe {
            runtime
                .call_typed("math.add", &Input { value: 5 }, &mut out)
                .unwrap();
        }
        assert_eq!(out.value, 6);
    }

    #[test]
    fn fallback_duplicate_is_rejected() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.add", fallback).unwrap();
        let err = runtime.register_fallback("math.add", fallback).unwrap_err();
        assert!(matches!(err, HotpatchError::FallbackAlreadyRegistered(_)));
    }

    #[test]
    fn symbol_not_found_is_reported() {
        let runtime = HotpatchRuntime::new();
        let mut out = Output { value: 0 };
        let err = unsafe { runtime.call_typed("does.not.exist", &Input { value: 1 }, &mut out) }
            .unwrap_err();
        assert!(matches!(err, HotpatchError::SymbolNotFound(_)));
    }

    #[test]
    fn patch_slot_swap_changes_behavior() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.add", fallback).unwrap();
        {
            let slots = runtime.slots.read().unwrap();
            slots.get("math.add").unwrap().swap(patched);
        }

        let mut out = Output { value: 0 };
        unsafe {
            runtime
                .call_typed("math.add", &Input { value: 5 }, &mut out)
                .unwrap();
        }
        assert_eq!(out.value, 15);
    }

    #[test]
    fn fresh_runtime_is_empty() {
        let runtime = HotpatchRuntime::default();
        assert_eq!(runtime.generation(), 0);
        assert!(runtime.list_symbols().is_empty());
        assert!(runtime.loaded_modules().is_empty());
    }

    #[test]
    fn list_symbols_is_sorted() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("zeta.op", fallback).unwrap();
        runtime.register_fallback("alpha.op", fallback).unwrap();
        runtime.register_fallback("math.add", fallback).unwrap();
        assert_eq!(
            runtime.list_symbols(),
            vec![
                "alpha.op".to_string(),
                "math.add".to_string(),
                "zeta.op".to_string()
            ]
        );
    }

    #[test]
    fn patch_functions_can_reenter_the_runtime() {
        let runtime = HotpatchRuntime::new();
        runtime.register_fallback("math.add", fallback).unwrap();
        runtime.register_fallback("math.add_twice", add_twice).unwrap();
        runtime.set_host_context(&runtime as *const HotpatchRuntime as *mut c_void);

        let mut out = Output { value: 0 };
        let status = unsafe { runtime.call_typed("math.add_twice", &Input { value: 5 }, &mut out) }
            .unwrap();
        assert_eq!(status, 0);
        assert_eq!(out.value, 7);
    }
}