libredr/
client.rs

1use std::sync::{Arc, Mutex};
2use std::collections::HashMap;
3use tracing::info;
4use ndarray::prelude::*;
5use anyhow::{Context, Error, Result, Ok, bail, anyhow};
6use common::message::*;
7use common::geometry::Geometry;
8use common::connection::Connection;
9
10/// Rust interface for LibreDR client
11#[derive(Debug)]
12pub struct LibreDR {
13  connection: Connection,
14}
15
16impl LibreDR {
17  /// Construct `LibreDR` by connecting to LibreDR server\
18  /// Return `Error` if connection failed\
19  /// # Examples
20  /// ```
21  /// async {
22  ///   let client_tcp = LibreDR::new(String::from("127.0.0.1:9001"), false, false).await?;
23  ///   let client_unix = LibreDR::new(String::from("/var/run/libredr_client.sock"), true, false).await?;
24  /// }
25  /// ```
26  pub async fn new(connect: String, unix: bool, tls: bool) -> Result<Self> {
27    info!("LibreDR::new: Connecting to server {connect}");
28    let mut config: HashMap<String, String> = HashMap::new();
29    config.insert(String::from("connect"), connect.to_owned());
30    config.insert(String::from("unix"), String::from(if unix { "true" } else { "false" }));
31    config.insert(String::from("tls"), String::from(if tls { "true" } else { "false" }));
32    let connection = Connection::from_config(&config).await
33      .with_context(|| format!("Failed to connect to server {connect}"))?;
34    Ok(LibreDR {
35      connection,
36    })
37  }
38
39  /// Receive messages, response `RequestData` task, until receive a different task
40  async fn try_recv_msg_response_data(&mut self, data_cache: &DataCache) -> Result<Message> {
41    loop {
42      let msg_response = self.connection.recv_msg().await?;
43      info!("LibreDR::try_recv_msg_response_data: msg_response {msg_response}");
44      if let Message::RequestData(hash) = msg_response {
45        let data = {
46          let data_cache = data_cache.lock().expect("No task should panic");
47          let entry = data_cache.get(&hash).ok_or(format!("Client: ray_tracing_forward: unexpected hash {hash}"));
48          entry.map(|entry| {
49            entry.1.to_owned()
50          })
51        };
52        let msg_response = Message::ResponseData(Box::new(data.to_owned()));
53        self.connection.send_msg(&msg_response).await?;
54        data.map_err(Error::msg)?;
55      } else {
56        break Ok(msg_response);
57      }
58    }
59  }
60
61  /// Create a [`RequestRayTracingForward`] task and wait for response
62  ///
63  /// # Arguments
64  /// * `ray` - ray parameters
65  ///   * if `camera_space` is `false` 18 * `image_shape`
66  ///     * including ray position 9 * `image_shape`
67  ///     * including ray direction 9 * `image_shape`
68  ///   * if `camera_space` is `true`, add another (1 + 14) channels
69  ///     * including ray depth 1 * `image_shape` (if depth <= 0, treat as hit miss)
70  ///     * including ray material 14 * `image_shape`
71  /// * `texture` - (3 + 3 + 3 + 1 + 3 + 1) * `texture_resolution` * `texture_resolution` (must be square image)
72  ///   * including normal + diffuse + specular + roughness + intensity + window
73  /// * `envmap` - 3 * 6 * `envmap_resolution` * `envmap_resolution`
74  ///   * (must be box unwrapped 6 square images)
75  /// * `sample_per_pixel` - `sample_per_pixel_forward`, `sample_per_pixel_backward`
76  /// * `max_bounce` - `max_bounce_forward`, `max_bounce_backward`, `max_bounce_low_discrepancy`, `skip_bounce`
77  /// * `switches` - tuple of 4 switches to determine hit miss and reflection behavior
78  ///   * render::MISS_* - determine how to deal with ray hit miss
79  ///     * [`common::render::MISS_NONE`]
80  ///     * [`common::render::MISS_ENVMAP`]
81  ///   * render::REFLECTION_NORMAL_* - determine how to get surface normal
82  ///     * [`common::render::REFLECTION_NORMAL_FACE`]
83  ///     * [`common::render::REFLECTION_NORMAL_VERTEX`]
84  ///     * [`common::render::REFLECTION_NORMAL_TEXTURE`]
85  ///   * render::REFLECTION_DIFFUSE_* - determine diffuse reflection model
86  ///     * [`common::render::REFLECTION_DIFFUSE_NONE`]
87  ///     * [`common::render::REFLECTION_DIFFUSE_LAMBERTIAN`]
88  ///   * render::REFLECTION_SPECULAR_* - determine specular reflection model
89  ///     * [`common::render::REFLECTION_SPECULAR_NONE`]
90  ///     * [`common::render::REFLECTION_SPECULAR_PHONG`]
91  ///     * [`common::render::REFLECTION_SPECULAR_BLINN_PHONG`]
92  ///     * [`common::render::REFLECTION_SPECULAR_TORRANCE_SPARROW_PHONG`]
93  ///     * [`common::render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BLINN_PHONG`]
94  ///     * [`common::render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BECKMANN`]
95  /// * `clip_near` - clip near distance for camera
96  ///   * `clip_near` can be a single float number (same for all bounces),
97  ///   * or tuple of 3 float numbers (first bounce, second bounce, and other bounces)
98  /// * `camera_space` - if `true`, the first bounce uses the depth and material given by the ray
99  /// * `requires_grad` - if `true`, worker will save intermediate data, the next task must be `ray_tracing_backward`
100  /// * `srand` - random seed
101  ///   * if srand >= 0, the same random seed is used for every pixel
102  ///   * if srand < 0, use different seed for each pixel
103  /// * `low_discrepancy` - (optional) start id of Halton low discrepancy sequence.
104  ///   * The default value is the same as `sample_per_pixel_forward`.
105  ///   * if combine multiple rendered images to reduce noise, this value can be set to: \
106  ///       1 * `sample_per_pixel_forward`, 2 * `sample_per_pixel_forward`, 3 * `sample_per_pixel_forward`, ...
107  ///
108  /// # Return
109  /// Return shape will be,
110  /// * if `camera_space` is `true`
111  ///   * render image 3 * `image_shape`
112  /// * if `camera_space` is `false`, add another
113  ///   * ray texture coordinate 2 * `image_shape`
114  ///   * ray depth (Euclidean distance) 1 * `image_shape`
115  ///   * ray normal 3 * `image_shape`
116  #[allow(clippy::too_many_arguments)]
117  pub async fn ray_tracing_forward(&mut self,
118      geometry: &Geometry,
119      geometry_data_cache: &DataCache,
120      ray: ArrayD<f32>,
121      texture: Array3<f32>,
122      envmap: Array4<f32>,
123      sample_per_pixel: (usize, usize),
124      max_bounce: (usize, usize, usize, usize),
125      switches: (u8, u8, u8, u8),
126      clip_near: (f32, f32, f32),
127      camera_space: bool,
128      requires_grad: bool,
129      srand: i32,
130      low_discrepancy: u32) -> Result<ArrayD<f32>> {
131    let input_shape = ray.shape().to_owned();
132    assert!(input_shape.len() > 1, "ray_tracing_forward: ray should be at least 1D");
133    let ray_channels_input = if camera_space { 33 } else { 18 };
134    assert_eq!(input_shape[0], ray_channels_input,
135      "ray_tracing_forward: ray channel: {}, expected: {}", input_shape[0], ray_channels_input);
136    let ray = ray.to_shape([ray_channels_input, input_shape[1..].iter().product()])?.into_owned();
137    let mut data_cache_content = hashbrown::HashMap::new();
138    {
139      let geometry_data_cache = geometry_data_cache.lock().expect("No task should panic");
140      data_cache_content.extend(geometry_data_cache.iter().map(|(k, v)| {
141        (k.to_owned(), v.to_owned())
142      }));
143    }
144    let ray_data = Data::RayData(ray);
145    let ray_data_hash = ray_data.hash();
146    data_cache_content.insert(ray_data_hash.to_owned(), (0, ray_data));
147    let material_data = Data::MaterialData(texture, envmap);
148    let material_data_hash = material_data.hash();
149    data_cache_content.insert(material_data_hash.to_owned(), (0, material_data));
150    let data_cache = Arc::new(Mutex::new(data_cache_content));
151    let request = RequestRayTracingForward {
152      geometry: geometry.to_owned(),
153      ray: ray_data_hash,
154      material: material_data_hash,
155      sample_per_pixel,
156      max_bounce,
157      switches,
158      clip_near,
159      camera_space,
160      requires_grad,
161      srand,
162      low_discrepancy
163    };
164    let msg_request = Message::RequestTask(RequestTask::RequestRayTracingForward(Box::new(request)));
165    info!("LibreDR::ray_tracing_forward: msg_request {msg_request}");
166    self.connection.send_msg(&msg_request).await?;
167    let msg_response = self.try_recv_msg_response_data(&data_cache).await?;
168    let Message::ResponseTask(response_task) = msg_response else {
169      bail!("Unexpected response {msg_response}");
170    };
171    let response_task = response_task.map_err(Error::msg)?;
172    let ResponseTask::ResponseRayTracingForward(response) = response_task else {
173      bail!("Unexpected response {response_task}");
174    };
175    let mut output_shape = input_shape.to_owned();
176    output_shape[0] = if camera_space { 3 } else { 9 };
177    let response = response.render.to_shape(output_shape)?.into_owned();
178    Ok(response)
179  }
180
181  /// Create a [`RequestRayTracingBackward`] task and wait for response.
182  ///
183  /// Must be called consecutive to a [`RequestRayTracingForward`] task with `requires_grad` set to `true`. \
184  /// To create multiple [`RequestRayTracingForward`] tasks and backward together, multiple client connections are
185  /// required.
186  ///
187  /// # Arguments
188  /// * `d_ray` - gradient of image 3 * `image_shape` (must ensure same `image_shape` as [`RequestRayTracingForward`])
189  ///
190  /// # Return
191  /// Return shape will be,
192  /// * if `camera_space` is `false` for [`RequestRayTracingForward`] task
193  ///   * 1st return value (3 + 3 + 3 + 1 + 3 + 1) * `texture_resolution` * `texture_resolution`
194  ///     * (same `texture_resolution` as [`RequestRayTracingForward`])
195  ///     * including d_normal + d_diffuse + d_specular + d_roughness + d_intensity + d_window
196  ///   * 2nd return value 3 * 6 * `envmap_resolution` * `envmap_resolution`
197  ///     * (same `envmap_resolution` as [`RequestRayTracingForward`])
198  ///     * including d_envmap
199  /// * if `camera_space` is `true` for [`RequestRayTracingForward`] task, add another
200  ///   * 3rd return value 14 * `image_shape` (same shape as [`RequestRayTracingForward`])
201  ///     * including d_ray_texture
202  #[allow(clippy::too_many_arguments)]
203  pub async fn ray_tracing_backward(&mut self, d_ray: ArrayD<f32>) ->
204      Result<(Array3<f32>, Array4<f32>, Option<ArrayD<f32>>)> {
205    let input_shape = d_ray.shape().to_owned();
206    assert!(input_shape.len() > 1, "ray_tracing_backward: d_ray should be at least 1D");
207    assert_eq!(input_shape[0], 3,
208      "ray_tracing_backward: d_ray channel: {}, expected: {}", input_shape[0], 3);
209    let d_ray = d_ray.to_shape([3, input_shape[1..].iter().product()])?.into_owned();
210    let request = RequestRayTracingBackward { d_ray, };
211    let msg_request = Message::RequestTask(RequestTask::RequestRayTracingBackward(Box::new(request)));
212    info!("LibreDR::ray_tracing_backward: msg_request {msg_request}");
213    self.connection.send_msg(&msg_request).await?;
214    let msg_response = self.connection.recv_msg().await?;
215    info!("LibreDR::ray_tracing_backward: msg_response {msg_response}");
216    let Message::ResponseTask(response_task) = msg_response else {
217      bail!("Unexpected response {msg_response}");
218    };
219    let response_task = response_task.map_err(Error::msg)?;
220    let ResponseTask::ResponseRayTracingBackward(response) = response_task else {
221      bail!("Unexpected response {response_task}");
222    };
223    let d_ray_texture = response.d_ray_texture.map(|d_ray_texture| {
224      let mut output_shape = input_shape.to_owned();
225      output_shape[0] = 14;
226      Ok(d_ray_texture.to_shape(output_shape)?.into_owned())
227    }).transpose()?;
228    Ok((
229      response.d_texture.ok_or(anyhow!("LibreDR::ray_tracing_backward: None d_texture returned from server"))?,
230      response.d_envmap.ok_or(anyhow!("LibreDR::ray_tracing_backward: None d_envmap returned from server"))?,
231      d_ray_texture))
232  }
233
234  /// Send [`Message::Close`] to server to close cleanly
235  pub async fn close(&mut self) -> Result<()> {
236    self.connection.send_msg(&Message::Close()).await
237  }
238}