mujoco_rs/
renderer.rs

//! Offscreen rendering of MuJoCo scenes to RGB and depth images, implemented by [`MjRenderer`].
2use crate::wrappers::mj_visualization::MjvScene;
3use crate::wrappers::mj_rendering::MjrContext;
4use crate::prelude::*;
5
6use glfw::{Context, InitError, PWindow, WindowHint};
7use bitflags::bitflags;
8use png::Encoder;
9
10use std::io::{self, BufWriter, ErrorKind, Write};
11use std::fmt::Display;
12use std::error::Error;
13use std::path::Path;
14use std::fs::File;
15
16
/// Message of the [`io::ErrorKind::NotFound`] error returned when an RGB image is requested
/// while RGB rendering is disabled.
const RGB_NOT_FOUND_ERR_STR: &str =
    "RGB rendering is not enabled (renderer.with_rgb_rendering(true))";

/// Message of the [`io::ErrorKind::NotFound`] error returned when a depth image is requested
/// while depth rendering is disabled.
const DEPTH_NOT_FOUND_ERR_STR: &str =
    "depth rendering is not enabled (renderer.with_depth_rendering(true))";

/// Message of the [`io::ErrorKind::InvalidInput`] error returned when the requested image
/// shape does not match the renderer's resolution.
const INVALID_INPUT_SIZE: &str =
    "the input width and height don't match the renderer's configuration";
20
21
/// A renderer for rendering 3D scenes.
/// By default, RGB rendering is enabled and depth rendering is disabled.
///
/// Rendering happens offscreen, through the OpenGL context of a hidden GLFW window.
/// Call [`MjRenderer::sync`] to update the scene from an [`MjData`] and render it, then read
/// the result with [`MjRenderer::rgb`] / [`MjRenderer::depth`] (or their `_flat` variants),
/// or save it with the `save_*` methods.
pub struct MjRenderer<'m> {
    // NOTE(review): Rust drops struct fields in declaration order, so `scene`, `user_scene`
    // and `context` are dropped before `window`. Presumably `context` has to release its GL
    // resources while the window's OpenGL context still exists — confirm before reordering.
    /// Scene that is updated and rendered by `sync`; receives the user scene's geoms too.
    scene: MjvScene<'m>,
    /// Scene holding user-drawn, visual-only geoms; copied into `scene` on every `sync`.
    user_scene: MjvScene<'m>,
    /// MuJoCo rendering context (set to render into the offscreen buffer on construction).
    context: MjrContext,
    /// Model the renderer was built for; `sync` only accepts data created from this model.
    model: &'m MjModel,

    /* Glfw */
    /// Hidden window that owns the OpenGL context; made current before each render.
    window: PWindow,

    /* Configuration */
    /// Camera the scene is rendered from.
    camera: MjvCamera,
    /// Visualization options applied when updating the scene.
    option: MjvOption,
    /// Which outputs are enabled. NOTE(review): this mirrors whether `rgb` / `depth` are `Some`.
    flags: RendererFlags,

    /* Storage */
    // Boxed slices (instead of `Vec`) avoid storing an unused capacity; each buffer is
    // `None` when the corresponding output is disabled, so no memory is allocated for it.
    /// Flattened RGB output, 3 bytes per pixel (`3 * width * height` bytes).
    rgb: Option<Box<[u8]>>,
    /// Flattened depth output (`width * height` values), converted to distances after reading.
    depth: Option<Box<[f32]>>,

    /// Output width in pixels.
    width: usize,
    /// Output height in pixels.
    height: usize,
}
47
48impl<'m> MjRenderer<'m> {
49    /// Construct a new renderer.
50    /// `max_geom` represents the maximum number of geoms the [`MjvScene`] can hold, which
51    /// includes both the user-drawn geoms and the required from [`MjData`] state.
52    pub fn new(model: &'m MjModel, width: usize, height: usize, max_geom: usize) -> Result<Self, RendererError> {
53        let mut glfw = glfw::init_no_callbacks()
54            .map_err(|err| RendererError::GlfwInitError(err))?;
55
56        /* Create window for rendering */
57        glfw.window_hint(WindowHint::Visible(false));
58        let (mut _window, _) = match glfw.create_window(width as u32, height as u32, "", glfw::WindowMode::Windowed) {
59            Some(x) => Ok(x),
60            None => Err(RendererError::WindowCreationError)
61        }?;
62
63        _window.make_current();
64        glfw.set_swap_interval(glfw::SwapInterval::None);
65
66        /* Initialize the rendering context to render to the offscreen buffer. */
67        let mut context = MjrContext::new(model);
68        context.offscreen();
69
70        /* The 3D scene for visualization */
71        let scene = MjvScene::new(model, model.ffi().ngeom as usize + max_geom + EXTRA_SCENE_GEOM_SPACE);
72        let user_scene = MjvScene::new(model, max_geom);
73
74        let camera = MjvCamera::new_free(model);
75        let option = MjvOption::default();
76
77        let mut s = Self {
78            scene, user_scene, context, window: _window, model, camera, option,
79            flags: RendererFlags::empty(), rgb: None, depth: None,
80            width, height
81        };
82
83        s = s.with_rgb_rendering(true);
84        Ok(s)
85    }
86
87    /// Enables/disables RGB rendering. To be used on construction.
88    pub fn with_rgb_rendering(mut self, enable: bool) -> Self {
89        // 1. Define the type we want to allocate
90        self.flags.set(RendererFlags::RENDER_RGB, enable);
91        self.rgb = if enable { Some(vec![0; 3 * self.width * self.height].into_boxed_slice()) } else { None } ;
92        self
93    }
94
95    /// Enables/disables depth rendering. To be used on construction.
96    pub fn with_depth_rendering(mut self, enable: bool) -> Self {
97        self.flags.set(RendererFlags::RENDER_DEPTH, enable);
98        self.depth = if enable { Some(vec![0.0; self.width * self.height].into_boxed_slice()) } else { None } ;
99        self
100    }
101
102    /// Returns an immutable reference to the internal scene.
103    pub fn scene(&self) -> &MjvScene<'m>{
104        &self.scene
105    }
106
107    /// Returns an immutable reference to a user scene for drawing custom visual-only geoms.
108    pub fn user_scene(&self) -> &MjvScene<'m>{
109        &self.user_scene
110    }
111
112    /// Returns a mutable reference to a user scene for drawing custom visual-only geoms.
113    pub fn user_scene_mut(&mut self) -> &mut MjvScene<'m>{
114        &mut self.user_scene
115    }
116
117    /// Sets the font size. To be used on construction.
118    pub fn with_font_scale(mut self, font_scale: MjtFontScale) -> Self {
119        self.context.change_font(font_scale);
120        self
121    }
122
123    /// Update the visualization options and return a reference to self. To be used on construction.
124    pub fn with_opts(mut self, options: MjvOption) -> Self {
125        self.option = options;
126        self
127    }
128
129    /// Render images using the `camera`. To be used on construction.
130    pub fn with_camera(mut self, camera: MjvCamera) -> Self  {
131        self.camera = camera;
132        self
133    }
134
135    /// Update the scene with new data from data.
136    pub fn sync(&mut self, data: &mut MjData) {
137        let model_data_ptr = unsafe {  data.model().__raw() };
138        let bound_model_ptr = unsafe { self.model.__raw() };
139        assert_eq!(model_data_ptr, bound_model_ptr, "'data' must be created from the same model as the renderer.");
140
141        self.scene.update(data, &self.option, &MjvPerturb::default(), &mut self.camera);
142
143        /* Draw user scene geoms */
144        sync_geoms(&self.user_scene, &mut self.scene)
145            .expect("could not sync the user scene with the internal scene; this is a bug, please report it.");
146
147        self.render();
148    }
149
150    /// Returns a flattened RGB image of the scene.
151    pub fn rgb_flat(&self) -> Option<&[u8]> {
152        self.rgb.as_deref()
153    }
154
155    /// Returns an RGB image of the scene. This methods accepts two generic parameters <WIDTH, HEIGHT>
156    /// that define the shape of the output slice.
157    pub fn rgb<const WIDTH: usize, const HEIGHT: usize>(&self) -> io::Result<&[[[u8; 3]; WIDTH]; HEIGHT]> {
158        if let Some(flat) = self.rgb_flat() {
159            if flat.len() == WIDTH * HEIGHT * 3 {
160                let p_shaped = flat.as_ptr() as *const [[[u8; 3]; WIDTH]; HEIGHT];
161
162                // SAFETY: The alignment of the output is the same as the original.
163                // The lifetime also matches  'a in &'a self, which prevents data races.
164                // Length (number of elements) matches the output's.
165                Ok(unsafe { p_shaped.as_ref().unwrap() })
166            }
167            else {
168                Err(io::Error::new(io::ErrorKind::InvalidInput, INVALID_INPUT_SIZE))
169            }
170        }
171        else {
172            Err(io::Error::new(io::ErrorKind::NotFound, RGB_NOT_FOUND_ERR_STR))
173        }
174    }
175
176    /// Returns a flattened depth image of the scene.
177    pub fn depth_flat(&self) -> Option<&[f32]> {
178        self.depth.as_deref()
179    }
180
181    /// Returns a depth image of the scene. This methods accepts two generic parameters <WIDTH, HEIGHT>
182    /// that define the shape of the output slice.
183    pub fn depth<const WIDTH: usize, const HEIGHT: usize>(&self) -> io::Result<&[[f32; WIDTH]; HEIGHT]> {
184        if let Some(flat) = self.depth_flat() {
185            if flat.len() == WIDTH * HEIGHT {
186                let p_shaped = flat.as_ptr() as *const [[f32; WIDTH]; HEIGHT];
187
188                // SAFETY: The alignment of the output is the same as the original.
189                // The lifetime matches  'a in &'a self, which prevents data races.
190                // Length (number of elements) matches the output's.
191                Ok(unsafe { p_shaped.as_ref().unwrap() })
192            }
193            else {
194                Err(io::Error::new(io::ErrorKind::InvalidInput, INVALID_INPUT_SIZE))
195            }
196        }
197        else {
198            Err(io::Error::new(io::ErrorKind::NotFound, DEPTH_NOT_FOUND_ERR_STR))
199        }
200    }
201
202    /// Save an RGB image of the scene to a path.
203    /// # Errors
204    /// - [`ErrorKind::NotFound`] when RGB rendering is disabled,
205    /// - other errors related to write.
206    pub fn save_rgb<T: AsRef<Path>>(&self, path: T) -> io::Result<()> {
207        if let Some(rgb) = &self.rgb {
208            let file = File::create(path.as_ref())?;
209            let w = BufWriter::new(file);
210
211            let mut encoder = Encoder::new(w, self.width as u32, self.height as u32);
212            encoder.set_color(png::ColorType::Rgb);
213            encoder.set_depth(png::BitDepth::Eight);
214            encoder.set_compression(png::Compression::NoCompression);
215
216            let mut writer = encoder.write_header()?;
217            writer.write_image_data(rgb)?;
218            Ok(())
219        }
220        else {
221            Err(io::Error::new(ErrorKind::NotFound, RGB_NOT_FOUND_ERR_STR))
222        }
223    }
224
225    /// Save a depth image of the scene to a path. The image is 16-bit PNG, which
226    /// can be converted into depth (distance) data by dividing the grayscale values by 65535.0 an applying reverse denormalization.
227    /// If `normalize` is `true`, then the data is normalized with min-max normalization.
228    /// Use of [`MjRenderer::save_depth_raw`] is recommended if performance is critical, as
229    /// it skips PNG encoding and also saves the true depth values directly.
230    /// # Returns
231    /// An [`Ok`]`((min, max))` is returned, where min and max represent the normalization parameters.
232    /// # Errors
233    /// - [`ErrorKind::NotFound`] when depth rendering is disabled,
234    /// - other errors related to write.
235    pub fn save_depth<T: AsRef<Path>>(&self, path: T, normalize: bool) -> io::Result<(f32, f32)> {
236        if let Some(depth) = &self.depth {
237            let file = File::create(path.as_ref())?;
238            let w = BufWriter::new(file);
239
240            let mut encoder = Encoder::new(w, self.width as u32, self.height as u32);
241            encoder.set_color(png::ColorType::Grayscale);
242            encoder.set_depth(png::BitDepth::Sixteen);
243            encoder.set_compression(png::Compression::NoCompression);
244
245            let (norm, min, max) =
246            if normalize {
247                let max = depth.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
248                let min = depth.iter().cloned().fold(f32::INFINITY, f32::min);
249                (depth.iter().flat_map(|&x| (((x - min) / (max - min) * 65535.0).min(65535.0) as u16).to_be_bytes()).collect::<Box<_>>(), min, max)
250            }
251            else {
252                (depth.iter().flat_map(|&x| ((x * 65535.0).min(65535.0) as u16).to_be_bytes()).collect::<Box<_>>(), 0.0, 1.0)
253            };
254
255            let mut writer = encoder.write_header()?;
256            writer.write_image_data(&norm)?;
257            Ok((min, max))
258        }
259        else {
260            Err(io::Error::new(ErrorKind::NotFound, DEPTH_NOT_FOUND_ERR_STR))
261        }
262    }
263
264    /// Save the raw depth data to the `path`. The data is encoded
265    /// as a sequence of bytes, where groups of four represent a single f32 value.
266    /// The lower bytes of individual f32 appear first (low-endianness).
267    /// # Errors
268    /// - [`ErrorKind::NotFound`] when depth rendering is disabled,
269    /// - other errors related to write.
270    pub fn save_depth_raw<T: AsRef<Path>>(&self, path: T) -> io::Result<()> {
271        if let Some(depth) = &self.depth {
272            let file = File::create(path.as_ref())?;
273            let mut writer = BufWriter::new(file);
274
275            /* Fast conversion to a byte slice to prioritize performance */
276            let p = unsafe { std::slice::from_raw_parts(
277                depth.as_ptr() as *const u8,
278                std::mem::size_of::<f32>() * depth.len()
279            ) };
280
281            writer.write_all(p)?;
282            Ok(())
283        }
284        else {
285            Err(io::Error::new(ErrorKind::NotFound, DEPTH_NOT_FOUND_ERR_STR))
286        }
287    }
288
289    /// Draws the scene to internal arrays.
290    /// Use [`MjRenderer::rgb`] or [`MjRenderer::depth`] to obtain the rendered image.
291    fn render(&mut self) {
292        self.window.make_current();
293        let vp = MjrRectangle::new(0, 0, self.width as i32, self.height as i32);
294        self.scene.render(&vp, &self.context);
295
296        /* Fully flatten everything */
297        let flat_rgb = self.rgb.as_deref_mut();
298        let flat_depth = self.depth.as_deref_mut();
299
300        /* Read to whatever is enabled */
301        self.context.read_pixels(
302            flat_rgb,
303            flat_depth,
304            &vp
305        );
306
307        /* Make depth values be the actual distance in meters */
308        if let Some(depth) = self.depth.as_deref_mut() {
309            let map = &self.model.vis().map;
310            let near = map.znear;
311            let far = map.zfar;
312            for value in depth {
313                let z_ndc = 2.0 * *value - 1.0;
314                *value = 2.0 * near * far / (far + near - z_ndc * (far - near));
315            }
316        }
317    }
318}
319
320
321#[derive(Debug)]
322pub enum RendererError {
323    GlfwInitError(InitError),
324    WindowCreationError
325}
326
327impl Display for RendererError {
328    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329        match self {
330            Self::GlfwInitError(e) => write!(f, "glfw failed to initialize: {}", e),
331            Self::WindowCreationError => write!(f, "failed to create window"),
332        }
333    }
334}
335
336impl Error for RendererError {
337    fn source(&self) -> Option<&(dyn Error + 'static)> {
338        match self {
339            Self::GlfwInitError(e) => Some(e),
340            _ => None
341        }
342    }
343}
344
345bitflags! { 
346    /// Flags that enable features of the renderer.
347    struct RendererFlags: u8 {
348        const RENDER_RGB = 1 << 0;
349        const RENDER_DEPTH = 1 << 1;
350    }
351}
352
353
354
/*
** Tests are disabled: the OpenGL context must be created and used on the
** main thread, but the Rust test harness runs tests on worker threads.
*/
359
360// #[cfg(test)]
361// mod test {
362//     use super::*;
363
364//     const MODEL: &str = "
365//         <mujoco>
366//             <worldbody>
367//             </worldbody>
368//         </mujoco>
369//     ";
370
371//     // #[test]
372//     // fn test_update_normal() {
373//     //     // let model = MjModel::from_xml_string(MODEL).unwrap();
374//     //     // let model2 = MjModel::from_xml_string(MODEL).unwrap();
375//     //     // let mut data = model.make_data();
376//     //     // let mut renderer = Renderer::new(&model, 720, 1280).unwrap();
377        
378//     //     // /* Check if scene updates without errors. */
379//     //     // renderer.update_scene(&mut data);
380//     // }
381// }
382