mujoco_rs/wrappers/
mj_plugin.rs1use std::ffi::{CString, c_char, c_int};
3use std::path::Path;
4
5use crate::mujoco_c::{mj_loadAllPluginLibraries, mj_loadPluginLibrary};
6use crate::error::MjPluginError;
7
/// Optional C callback handed to [`load_all_plugin_libraries`].
///
/// The signature mirrors MuJoCo's `mjfPluginLibraryLoadCallback`
/// (`void (*)(const char* filename, int first, int count)`).
/// NOTE(review): per the MuJoCo API docs, MuJoCo invokes it for each library that
/// registers plugins, with `first`/`count` describing the registered plugin range —
/// confirm against the bound MuJoCo version.
///
/// Pass `None` when no callback is needed.
pub type MjPluginLibraryLoadCallback =
    Option<unsafe extern "C" fn(filename: *const c_char, first: c_int, count: c_int)>;
13
14pub fn load_plugin_library<P: AsRef<Path>>(path: P) -> Result<(), MjPluginError> {
32 let s = path.as_ref().to_str().ok_or(MjPluginError::InvalidUtf8Path)?;
33 let c = CString::new(s).map_err(|_| MjPluginError::NullBytePath)?;
34 unsafe { mj_loadPluginLibrary(c.as_ptr()) };
36 Ok(())
37}
38
39pub fn load_all_plugin_libraries<P: AsRef<Path>>(
62 directory: P,
63 callback: MjPluginLibraryLoadCallback,
64) -> Result<(), MjPluginError> {
65 let s = directory.as_ref().to_str().ok_or(MjPluginError::InvalidUtf8Path)?;
66 let c = CString::new(s).map_err(|_| MjPluginError::NullBytePath)?;
67 unsafe { mj_loadAllPluginLibraries(c.as_ptr(), callback) };
70 Ok(())
71}
72
#[cfg(test)]
mod tests {
    use super::{load_all_plugin_libraries, load_plugin_library};
    use crate::error::MjPluginError;

    /// A path made of the bytes `0xFF 0xFE`, which is not valid UTF-8; used to reach the
    /// `InvalidUtf8Path` branch of both loaders.
    #[cfg(unix)]
    fn non_utf8_path() -> &'static std::ffi::OsStr {
        use std::os::unix::ffi::OsStrExt;
        std::ffi::OsStr::from_bytes(&[0xFF, 0xFE])
    }

    /// Loads `<parent of MUJOCO_DYNAMIC_LINK_DIR>/bin/mujoco_plugin`.
    ///
    /// Returns early (effectively skipped) when `MUJOCO_DYNAMIC_LINK_DIR` is not set.
    #[test]
    fn load_all_plugin_libraries_loads_from_directory() {
        let lib_dir = if let Ok(dir) = std::env::var("MUJOCO_DYNAMIC_LINK_DIR") {
            dir
        } else {
            return;
        };
        let lib_parent = std::path::Path::new(&lib_dir)
            .parent()
            .expect("MUJOCO_DYNAMIC_LINK_DIR should have a parent directory");
        let plugin_dir = lib_parent.join("bin/mujoco_plugin");

        load_all_plugin_libraries(&plugin_dir, None).expect("plugin dir should load");
    }

    #[test]
    fn load_plugin_library_null_byte_error() {
        assert!(matches!(
            load_plugin_library("path\0with\0nulls"),
            Err(MjPluginError::NullBytePath)
        ));
    }

    #[test]
    fn load_all_plugin_libraries_null_byte_error() {
        assert!(matches!(
            load_all_plugin_libraries("dir\0null", None),
            Err(MjPluginError::NullBytePath)
        ));
    }

    #[cfg(unix)]
    #[test]
    fn load_plugin_library_invalid_utf8_error() {
        let outcome = load_plugin_library(non_utf8_path());
        assert!(matches!(outcome, Err(MjPluginError::InvalidUtf8Path)));
    }

    #[cfg(unix)]
    #[test]
    fn load_all_plugin_libraries_invalid_utf8_error() {
        let outcome = load_all_plugin_libraries(non_utf8_path(), None);
        assert!(matches!(outcome, Err(MjPluginError::InvalidUtf8Path)));
    }
}
130