use std::ffi::{CStr, c_char, c_int, c_void};
use std::sync::{Arc, Mutex, OnceLock};
use crate::bindings::xmlRegisterInputCallbacks;
/// Predicate deciding whether a registered handler wants to serve a URL.
type MatchFn = Box<dyn Fn(&str) -> bool + Sync + Send + 'static>;
/// Loader returning the full contents for a URL, or `None` to decline it.
type OpenFn = Box<dyn Fn(&str) -> Option<Vec<u8>> + Sync + Send + 'static>;

/// One handler pair as stored in the registry.
struct Callback {
    /// Checked first; only handlers whose predicate returns `true` are asked to open.
    match_url: MatchFn,
    /// Produces the document bytes once `match_url` has accepted the URL.
    open: OpenFn,
}
/// Process-wide registry of Rust input handlers, created lazily on first access.
///
/// Entries are appended in registration order; the trampolines walk them newest-first.
fn callbacks() -> &'static Mutex<Vec<Arc<Callback>>> {
    static REGISTRY: OnceLock<Mutex<Vec<Arc<Callback>>>> = OnceLock::new();
    REGISTRY.get_or_init(Mutex::default)
}
/// Copies the registry's current handles so callbacks can run with the lock released.
///
/// `std::sync::Mutex` is not reentrant. Holding the lock while user code runs would deadlock
/// (or panic) a callback that re-enters libxml2 or registers another handler, so the
/// trampolines iterate over this copy instead. Each entry is an `Arc`, so the copy is cheap.
///
/// # Panics
///
/// Panics if the registry mutex is poisoned.
fn snapshot() -> Vec<Arc<Callback>> {
    let registry = callbacks().lock().unwrap();
    registry.iter().map(Arc::clone).collect()
}
/// Registers a Rust handler that libxml2 will consult when it opens a URL.
///
/// `match_url` decides whether the handler is interested in a URL. `open` then produces the
/// document bytes, or returns `None` to decline so that older registered handlers get a
/// chance. Handlers registered later are consulted before earlier ones.
///
/// The first call also runs `crate::init_parser()` and installs this module's trampolines
/// with `xmlRegisterInputCallbacks`. Later calls only append to the Rust-side registry.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
pub fn register_input_callback<M, O>(match_url: M, open: O)
where
    M: Fn(&str) -> bool + Send + Sync + 'static,
    O: Fn(&str) -> Option<Vec<u8>> + Send + Sync + 'static,
{
    let entry = Arc::new(Callback {
        match_url: Box::new(match_url),
        open: Box::new(open),
    });
    // The guard is a temporary, so the lock is released before the one-time install below.
    callbacks().lock().unwrap().push(entry);

    // The trampolines dispatch to the whole registry, so they only need installing once.
    static INSTALLED: OnceLock<()> = OnceLock::new();
    INSTALLED.get_or_init(|| {
        // NOTE(review): parser init presumably has to precede registration so libxml2's own
        // default handlers are already in its table and ours take priority — confirm.
        crate::init_parser();
        // SAFETY: each trampoline is an `extern "C"` fn with the signature libxml2's input
        // callback table expects, and all four are `'static` items.
        unsafe {
            xmlRegisterInputCallbacks(
                Some(trampoline_match),
                Some(trampoline_open),
                Some(trampoline_read),
                Some(trampoline_close),
            );
        }
    });
}
/// Per-document state handed to libxml2 as the opaque input context.
///
/// Created by `trampoline_open`, advanced by `trampoline_read`, and freed by
/// `trampoline_close`.
struct OpenState {
    /// Entire document as returned by the accepting handler's `open` callback.
    bytes: Vec<u8>,
    /// Index of the next byte `trampoline_read` will copy out.
    position: usize,
}
unsafe extern "C" fn trampoline_match(filename: *const c_char) -> c_int {
if filename.is_null() {
return 0;
}
let url = match unsafe { CStr::from_ptr(filename) }.to_str() {
Ok(s) => s,
Err(_) => return 0,
};
for cb in snapshot().iter().rev() {
if (cb.match_url)(url) {
return 1;
}
}
0
}
unsafe extern "C" fn trampoline_open(filename: *const c_char) -> *mut c_void {
if filename.is_null() {
return std::ptr::null_mut();
}
let url = match unsafe { CStr::from_ptr(filename) }.to_str() {
Ok(s) => s,
Err(_) => return std::ptr::null_mut(),
};
for cb in snapshot().iter().rev() {
if !(cb.match_url)(url) {
continue;
}
if let Some(bytes) = (cb.open)(url) {
return Box::into_raw(Box::new(OpenState { bytes, position: 0 })) as *mut c_void;
}
}
std::ptr::null_mut()
}
unsafe extern "C" fn trampoline_read(
context: *mut c_void,
buffer: *mut c_char,
len: c_int,
) -> c_int {
if context.is_null() || buffer.is_null() || len <= 0 {
return -1;
}
let state = unsafe { &mut *(context as *mut OpenState) };
let remaining = state.bytes.len().saturating_sub(state.position);
let n = remaining.min(len as usize);
if n == 0 {
return 0;
}
unsafe {
std::ptr::copy_nonoverlapping(
state.bytes.as_ptr().add(state.position),
buffer as *mut u8,
n,
);
}
state.position += n;
n as c_int
}
/// libxml2 close hook: frees the `OpenState` allocated by `trampoline_open`.
///
/// Returns `0` after releasing `context`, or `-1` if `context` is null.
///
/// # Safety
///
/// `context` must be null or a pointer returned by `trampoline_open` that has not already
/// been closed. It must not be used again after this call.
unsafe extern "C" fn trampoline_close(context: *mut c_void) -> c_int {
    if context.is_null() {
        -1
    } else {
        // SAFETY: non-null, and per the contract it came from `Box::into_raw` in
        // `trampoline_open` and is released exactly once here.
        drop(unsafe { Box::from_raw(context.cast::<OpenState>()) });
        0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bindings::{xmlFreeDoc, xmlReadFile};
    use std::ffi::CString;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Small well-formed document served by the test handlers.
    static SAMPLE_XML: &[u8] = br#"<?xml version="1.0"?>
<root attr="ok"><child/></root>"#;

    /// Parses `url` with libxml2's `xmlReadFile` and reports whether a document came back.
    fn read_file_via_libxml2(url: &str) -> bool {
        let c = CString::new(url).unwrap();
        // SAFETY: `c` outlives the call, and a non-null document is freed exactly once.
        unsafe {
            let doc = xmlReadFile(c.as_ptr(), std::ptr::null(), 0);
            if doc.is_null() {
                return false;
            }
            xmlFreeDoc(doc);
            true
        }
    }

    /// All scenarios share one test because handlers go into a process-wide registry, and
    /// this module offers no way to unregister them.
    #[test]
    fn input_callback_scenarios() {
        // Routing: a matching handler serves known URLs; unknown URLs are expected to fail.
        register_input_callback(
            |url| url.starts_with("embed:///"),
            |url| (url == "embed:///sample.xml").then(|| SAMPLE_XML.to_vec()),
        );
        assert!(read_file_via_libxml2("embed:///sample.xml"));
        assert!(!read_file_via_libxml2("embed:///unknown.xml"));
        assert!(!read_file_via_libxml2("/nonexistent/definitely/missing.xml"));

        // Re-entrancy: an `open` callback may itself parse another document through libxml2.
        static NESTED_READ_OK: AtomicBool = AtomicBool::new(false);
        register_input_callback(
            |url| url == "reentrant:///outer",
            |_url| {
                // Record the nested result rather than asserting here: a panic inside the
                // callback would have to cross libxml2's C frames.
                NESTED_READ_OK.store(
                    read_file_via_libxml2("embed:///sample.xml"),
                    Ordering::SeqCst,
                );
                Some(SAMPLE_XML.to_vec())
            },
        );
        assert!(read_file_via_libxml2("reentrant:///outer"));
        assert!(
            NESTED_READ_OK.load(Ordering::SeqCst),
            "nested read from inside an open callback should succeed",
        );

        // Precedence: when two handlers accept the same URL, the newest one wins.
        static FIRST_OPENED: AtomicUsize = AtomicUsize::new(0);
        static SECOND_OPENED: AtomicUsize = AtomicUsize::new(0);
        register_input_callback(
            |url| url == "ordered:///x",
            |_| {
                FIRST_OPENED.fetch_add(1, Ordering::SeqCst);
                Some(b"<a>first</a>".to_vec())
            },
        );
        register_input_callback(
            |url| url == "ordered:///x",
            |_| {
                SECOND_OPENED.fetch_add(1, Ordering::SeqCst);
                Some(SAMPLE_XML.to_vec())
            },
        );
        assert!(read_file_via_libxml2("ordered:///x"));
        assert_eq!(
            SECOND_OPENED.load(Ordering::SeqCst),
            1,
            "newest registration should run",
        );
        assert_eq!(
            FIRST_OPENED.load(Ordering::SeqCst),
            0,
            "older registration should not be consulted",
        );
    }
}