use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareStage, MiddlewareVerdict};
use async_trait::async_trait;
use std::path::PathBuf;
/// C-ABI signature of the `nsed_middleware_execute` symbol a dylib
/// middleware must export.
///
/// Arguments, in order:
/// * `ctx_json` / `ctx_len` — the JSON-serialized middleware context (input).
/// * `out_buf` / `out_buf_len` — a caller-owned buffer the library may write
///   its response into.
/// * `out_len` — written by the library with the number of response bytes.
///
/// The `i32` return value is a status code interpreted by the caller.
type MiddlewareExecuteFn = unsafe extern "C" fn(
    ctx_json: *const u8,
    ctx_len: u32,
    out_buf: *mut u8,
    out_buf_len: u32,
    out_len: *mut u32,
) -> i32;

/// Capacity, in bytes, of the output buffer handed to the library (1 MiB).
const MAX_OUTPUT_SIZE: usize = 1 << 20;
/// Middleware implemented by a native dynamic library.
///
/// The library must export `nsed_middleware_execute` with the
/// `MiddlewareExecuteFn` signature. Each invocation serializes the
/// `MiddlewareContext` to JSON, calls the export on a blocking thread, and
/// turns the returned status code and output buffer into a
/// `MiddlewareVerdict`.
pub struct DylibMiddleware {
    /// Name reported by `AgentMiddleware::name`; derived from the file stem.
    display_name: String,
    /// Keeps the library mapped while `execute_fn` may be called. Shared via
    /// `Arc` so every in-flight blocking call holds its own reference.
    lib: std::sync::Arc<libloading::Library>,
    /// Entry point copied out of `lib`; only valid while `lib` is loaded.
    execute_fn: MiddlewareExecuteFn,
    /// Stages returned by `AgentMiddleware::stages`.
    active_stages: Vec<MiddlewareStage>,
}

// SAFETY: the fields are owned data, an `Arc` to the loaded library, and a
// bare function pointer. The remaining requirement is on the foreign code:
// `execute` may call `execute_fn` from several blocking threads at once, so
// the dylib's `nsed_middleware_execute` must be thread-safe. That is part of
// the contract the caller accepts in `DylibMiddleware::load`.
unsafe impl Send for DylibMiddleware {}
unsafe impl Sync for DylibMiddleware {}

impl std::fmt::Debug for DylibMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DylibMiddleware")
            .field("name", &self.display_name)
            .finish()
    }
}

impl DylibMiddleware {
    /// Loads a middleware from the dynamic library at `path`.
    ///
    /// The display name is the file stem of `path`. It falls back to
    /// `"dylib"` when there is no stem or the stem is not valid UTF-8.
    /// `stages` is returned unchanged from `AgentMiddleware::stages`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if the library cannot be loaded, or
    /// if it does not export `nsed_middleware_execute`.
    ///
    /// # Safety
    ///
    /// Loading a library runs its initialization code. The exported
    /// `nsed_middleware_execute` is trusted to:
    /// - match `MiddlewareExecuteFn` exactly;
    /// - write at most `out_buf_len` bytes to `out_buf`;
    /// - be safe to call concurrently from multiple threads.
    ///
    /// The caller must only pass paths to libraries that uphold this.
    pub unsafe fn load(path: &PathBuf, stages: Vec<MiddlewareStage>) -> Result<Self, String> {
        // SAFETY: upheld by the caller per this function's `# Safety` contract.
        let lib = unsafe { libloading::Library::new(path) }
            .map_err(|e| format!("Failed to load dylib {}: {e}", path.display()))?;

        // SAFETY: the caller guarantees the symbol has the `MiddlewareExecuteFn`
        // signature. The pointer is copied out of the `Symbol` (which borrows
        // `lib`), so it stays valid only while `lib` is loaded. The struct owns
        // `lib`, and `execute` holds an extra `Arc` for the duration of a call.
        let execute_fn: MiddlewareExecuteFn = unsafe {
            let sym = lib
                .get::<MiddlewareExecuteFn>(b"nsed_middleware_execute")
                .map_err(|e| {
                    format!(
                        "Symbol 'nsed_middleware_execute' not found in {}: {e}",
                        path.display()
                    )
                })?;
            *sym
        };

        let display_name = path
            .file_stem()
            .and_then(|stem| stem.to_str())
            .unwrap_or("dylib")
            .to_string();

        Ok(Self {
            display_name,
            lib: std::sync::Arc::new(lib),
            execute_fn,
            active_stages: stages,
        })
    }

    /// Maps the dylib's status code and the bytes it wrote to a verdict.
    ///
    /// - `0` with empty output: pass.
    /// - `0` with output: the bytes must be a JSON-encoded `MiddlewareVerdict`.
    ///   Invalid JSON fails closed (block).
    /// - `1`: blocked by the library.
    /// - Any other code: treated as an error and blocked.
    fn verdict_from_output(status: i32, output: &[u8], name: &str) -> MiddlewareVerdict {
        match status {
            0 if output.is_empty() => MiddlewareVerdict::pass(),
            0 => serde_json::from_slice(output).unwrap_or_else(|e| {
                tracing::warn!(
                    middleware = name,
                    error = %e,
                    "Dylib output was not valid JSON verdict — fail closed (block)"
                );
                MiddlewareVerdict::block(
                    "middleware_error",
                    format!("Dylib '{}' returned invalid JSON", name),
                )
            }),
            1 => MiddlewareVerdict::block("dylib_middleware", "Blocked by dylib middleware"),
            code => MiddlewareVerdict::block(
                "dylib_middleware",
                format!("Dylib returned error code {code}"),
            ),
        }
    }
}

#[async_trait]
impl AgentMiddleware for DylibMiddleware {
    /// Serializes `ctx` to JSON, runs the dylib on a blocking thread, and
    /// converts its result into a verdict.
    ///
    /// Every failure path fails closed with a block verdict:
    /// - serialization errors;
    /// - oversized contexts;
    /// - non-zero error codes;
    /// - invalid output;
    /// - a failed blocking task.
    async fn execute(&self, ctx: &MiddlewareContext) -> MiddlewareVerdict {
        let input = match serde_json::to_vec(ctx) {
            Ok(v) => v,
            Err(e) => {
                return MiddlewareVerdict::block(
                    "dylib_middleware",
                    format!("Context serialization error: {e}"),
                );
            }
        };

        // The FFI takes a `u32` length. Refuse oversized input rather than
        // silently passing a truncated length.
        let input_len = match u32::try_from(input.len()) {
            Ok(len) => len,
            Err(_) => {
                return MiddlewareVerdict::block(
                    "dylib_middleware",
                    format!(
                        "Context too large for dylib middleware ({} bytes)",
                        input.len()
                    ),
                );
            }
        };

        let execute_fn = self.execute_fn;
        // A `spawn_blocking` task keeps running even if this future is dropped,
        // for example on cancellation or timeout. `self` could then be dropped
        // and the library unloaded while its code is still executing. Moving
        // an `Arc` clone into the task keeps the library mapped until the call
        // has returned.
        let lib = std::sync::Arc::clone(&self.lib);
        let name = self.display_name.clone();

        let result = tokio::task::spawn_blocking(move || {
            let _lib_guard = lib;
            let mut out_buf = vec![0u8; MAX_OUTPUT_SIZE];
            let mut out_len: u32 = 0;
            // SAFETY:
            // - `input` and `out_buf` outlive the call, and the lengths passed
            //   match their allocations. `out_buf.len()` is MAX_OUTPUT_SIZE,
            //   which fits in `u32`.
            // - `_lib_guard` keeps the code behind `execute_fn` loaded.
            // - The behaviour of the foreign code itself is the contract
            //   accepted in `DylibMiddleware::load`.
            let status = unsafe {
                execute_fn(
                    input.as_ptr(),
                    input_len,
                    out_buf.as_mut_ptr(),
                    out_buf.len() as u32,
                    &mut out_len,
                )
            };
            // Never trust the reported length beyond our buffer.
            let written = (out_len as usize).min(out_buf.len());
            DylibMiddleware::verdict_from_output(status, &out_buf[..written], &name)
        })
        .await;

        match result {
            Ok(verdict) => verdict,
            Err(e) => {
                // NOTE(review): a JoinError can also mean the task was
                // cancelled (e.g. during runtime shutdown), not only a panic.
                tracing::error!(
                    middleware = self.display_name.as_str(),
                    error = %e,
                    "Dylib middleware task panicked"
                );
                MiddlewareVerdict::block("dylib_middleware", "Dylib middleware panicked")
            }
        }
    }

    fn stages(&self) -> Vec<MiddlewareStage> {
        self.active_stages.clone()
    }

    fn name(&self) -> &str {
        &self.display_name
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::Verdict;

    /// A path that does not exist must surface the loader error.
    #[test]
    fn dylib_load_nonexistent_fails() {
        // SAFETY: the path does not exist, so no foreign code is loaded or run.
        let result = unsafe { DylibMiddleware::load(&PathBuf::from("/nonexistent.so"), vec![]) };
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Failed to load"));
    }

    /// A real system library without the entry point must report the
    /// missing symbol.
    ///
    /// The test only runs on targets with a known system library name.
    /// Elsewhere `path` would be unbound and the module would not compile.
    /// On musl Linux `libc.so.6` does not exist, so the load itself would fail.
    #[cfg(any(
        all(target_os = "linux", target_env = "gnu"),
        target_os = "macos",
        target_os = "windows"
    ))]
    #[test]
    fn dylib_load_missing_symbol_fails() {
        #[cfg(target_os = "linux")]
        let path = PathBuf::from("libc.so.6");
        #[cfg(target_os = "macos")]
        let path = PathBuf::from("libSystem.B.dylib");
        #[cfg(target_os = "windows")]
        let path = PathBuf::from("kernel32.dll");
        // SAFETY: these are the platform's standard system libraries, which
        // are already mapped into the test process, so loading them runs no
        // untrusted code. `nsed_middleware_execute` is never called.
        let result = unsafe { DylibMiddleware::load(&path, vec![]) };
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("nsed_middleware_execute"));
    }

    /// Checks that the verdict built for dylib status code `1` is a Block.
    ///
    /// NOTE(review): this only exercises `MiddlewareVerdict::block`, not the
    /// FFI status mapping itself.
    #[test]
    fn verdict_from_status_codes() {
        let v = MiddlewareVerdict::block("dylib_middleware", "Blocked by dylib middleware");
        assert_eq!(v.verdict, Verdict::Block);
    }
}