Skip to main content

mujoco_rs/wrappers/
mj_plugin.rs

1//! MuJoCo plugin library loading.
2use std::ffi::{CString, c_char, c_int};
3use std::path::Path;
4
5use crate::mujoco_c::{mj_loadAllPluginLibraries, mj_loadPluginLibrary};
6use crate::error::MjPluginError;
7
/// Optional per-library notification hook accepted by [`load_all_plugin_libraries`].
///
/// It is invoked once for each loaded library with:
/// - `filename`: the library's file name,
/// - `first`: the index of the first plugin that library registered,
/// - `count`: how many plugins that library registered.
///
/// `None` disables the notification.
pub type MjPluginLibraryLoadCallback = Option<
    unsafe extern "C" fn(filename: *const c_char, first: c_int, count: c_int),
>;
13
14/// Loads a single MuJoCo plugin shared library.
15///
16/// # Errors
17/// Returns [`MjPluginError`] if `path` is not valid UTF-8 or contains a null byte.
18///
19/// # Examples
20///
21/// Load the PID actuator plugin before using PID-based actuators in a model:
22///
23/// ```no_run
24/// use mujoco_rs::prelude::*;
25///
26/// load_plugin_library("path/to/mujoco/bin/mujoco_plugin/libactuator.so")
27///     .expect("failed to load actuator plugin");
28///
29/// let model = MjModel::from_xml("model.xml").expect("could not load the model");
30/// ```
31pub fn load_plugin_library<P: AsRef<Path>>(path: P) -> Result<(), MjPluginError> {
32    let s = path.as_ref().to_str().ok_or(MjPluginError::InvalidUtf8Path)?;
33    let c = CString::new(s).map_err(|_| MjPluginError::NullBytePath)?;
34    // SAFETY: `c` is a valid null-terminated string; MuJoCo does not retain the pointer.
35    unsafe { mj_loadPluginLibrary(c.as_ptr()) };
36    Ok(())
37}
38
39/// Loads all MuJoCo plugin shared libraries found in `directory`.
40///
41/// Pass `None` for `callback` to omit per-library notification.
42///
43/// # Errors
44/// Returns [`MjPluginError`] if `directory` is not valid UTF-8 or contains a null byte.
45///
46/// # Examples
47///
48/// Load all MuJoCo plugins from the plugin directory (e.g. to enable PID actuators,
49/// cable elasticity simulation, SDF collision shapes, or custom sensors):
50///
51/// ```no_run
52/// use mujoco_rs::prelude::*;
53///
54/// // Load all MuJoCo plugins from the plugin directory.
55/// // Adjust the path to match your MuJoCo installation.
56/// load_all_plugin_libraries("path/to/mujoco/bin/mujoco_plugin", None)
57///     .expect("failed to load plugin libraries");
58///
59/// let model = MjModel::from_xml("model.xml").expect("could not load the model");
60/// ```
61pub fn load_all_plugin_libraries<P: AsRef<Path>>(
62    directory: P,
63    callback: MjPluginLibraryLoadCallback,
64) -> Result<(), MjPluginError> {
65    let s = directory.as_ref().to_str().ok_or(MjPluginError::InvalidUtf8Path)?;
66    let c = CString::new(s).map_err(|_| MjPluginError::NullBytePath)?;
67    // SAFETY: `c` is a valid null-terminated string. `callback`, if non-null, matches the
68    // expected signature and is valid for the duration of the call.
69    unsafe { mj_loadAllPluginLibraries(c.as_ptr(), callback) };
70    Ok(())
71}
72
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::{load_all_plugin_libraries, load_plugin_library};
    use crate::error::MjPluginError;

    /// Returns a path made of the bytes `0xFF 0xFE`, which is not valid UTF-8 but is a
    /// legal path on Unix.
    #[cfg(unix)]
    fn non_utf8_path() -> &'static std::ffi::OsStr {
        use std::os::unix::ffi::OsStrExt;
        std::ffi::OsStr::from_bytes(b"\xFF\xFE")
    }

    /// Verifies that [`load_all_plugin_libraries`] works.
    ///
    /// The plugin directory is derived as `<parent of MUJOCO_DYNAMIC_LINK_DIR>/bin/mujoco_plugin`.
    /// When `MUJOCO_DYNAMIC_LINK_DIR` cannot be read, the test returns without doing anything.
    #[test]
    fn load_all_plugin_libraries_loads_from_directory() {
        if let Ok(link_dir) = std::env::var("MUJOCO_DYNAMIC_LINK_DIR") {
            let plugin_dir = Path::new(&link_dir)
                .parent()
                .expect("MUJOCO_DYNAMIC_LINK_DIR should have a parent directory")
                .join("bin/mujoco_plugin");

            load_all_plugin_libraries(&plugin_dir, None).expect("plugin dir should load");
        }
    }

    /// A path containing null bytes is rejected before anything is passed to MuJoCo.
    #[test]
    fn load_plugin_library_null_byte_error() {
        assert!(matches!(
            load_plugin_library("path\0with\0nulls"),
            Err(MjPluginError::NullBytePath)
        ));
    }

    /// A directory containing a null byte is rejected before anything is passed to MuJoCo.
    #[test]
    fn load_all_plugin_libraries_null_byte_error() {
        assert!(matches!(
            load_all_plugin_libraries("dir\0null", None),
            Err(MjPluginError::NullBytePath)
        ));
    }

    /// A non-UTF-8 library path is reported as `InvalidUtf8Path`.
    #[cfg(unix)]
    #[test]
    fn load_plugin_library_invalid_utf8_error() {
        assert!(matches!(
            load_plugin_library(non_utf8_path()),
            Err(MjPluginError::InvalidUtf8Path)
        ));
    }

    /// A non-UTF-8 plugin directory is reported as `InvalidUtf8Path`.
    #[cfg(unix)]
    #[test]
    fn load_all_plugin_libraries_invalid_utf8_error() {
        assert!(matches!(
            load_all_plugin_libraries(non_utf8_path(), None),
            Err(MjPluginError::InvalidUtf8Path)
        ));
    }
}
130