1use std::sync::{Arc, Mutex};
2use std::collections::HashMap;
3use tracing::info;
4use ndarray::prelude::*;
5use anyhow::{Context, Error, Result, Ok, bail, anyhow};
6use common::message::*;
7use common::geometry::Geometry;
8use common::connection::Connection;
9
10#[derive(Debug)]
12pub struct LibreDR {
13 connection: Connection,
14}
15
16impl LibreDR {
17 pub async fn new(connect: String, unix: bool, tls: bool) -> Result<Self> {
27 info!("LibreDR::new: Connecting to server {connect}");
28 let mut config: HashMap<String, String> = HashMap::new();
29 config.insert(String::from("connect"), connect.to_owned());
30 config.insert(String::from("unix"), String::from(if unix { "true" } else { "false" }));
31 config.insert(String::from("tls"), String::from(if tls { "true" } else { "false" }));
32 let connection = Connection::from_config(&config).await
33 .with_context(|| format!("Failed to connect to server {connect}"))?;
34 Ok(LibreDR {
35 connection,
36 })
37 }
38
39 async fn try_recv_msg_response_data(&mut self, data_cache: &DataCache) -> Result<Message> {
41 loop {
42 let msg_response = self.connection.recv_msg().await?;
43 info!("LibreDR::try_recv_msg_response_data: msg_response {msg_response}");
44 if let Message::RequestData(hash) = msg_response {
45 let data = {
46 let data_cache = data_cache.lock().expect("No task should panic");
47 let entry = data_cache.get(&hash).ok_or(format!("Client: ray_tracing_forward: unexpected hash {hash}"));
48 entry.map(|entry| {
49 entry.1.to_owned()
50 })
51 };
52 let msg_response = Message::ResponseData(Box::new(data.to_owned()));
53 self.connection.send_msg(&msg_response).await?;
54 data.map_err(Error::msg)?;
55 } else {
56 break Ok(msg_response);
57 }
58 }
59 }
60
61 #[allow(clippy::too_many_arguments)]
117 pub async fn ray_tracing_forward(&mut self,
118 geometry: &Geometry,
119 geometry_data_cache: &DataCache,
120 ray: ArrayD<f32>,
121 texture: Array3<f32>,
122 envmap: Array4<f32>,
123 sample_per_pixel: (usize, usize),
124 max_bounce: (usize, usize, usize, usize),
125 switches: (u8, u8, u8, u8),
126 clip_near: (f32, f32, f32),
127 camera_space: bool,
128 requires_grad: bool,
129 srand: i32,
130 low_discrepancy: u32) -> Result<ArrayD<f32>> {
131 let input_shape = ray.shape().to_owned();
132 assert!(input_shape.len() > 1, "ray_tracing_forward: ray should be at least 1D");
133 let ray_channels_input = if camera_space { 33 } else { 18 };
134 assert_eq!(input_shape[0], ray_channels_input,
135 "ray_tracing_forward: ray channel: {}, expected: {}", input_shape[0], ray_channels_input);
136 let ray = ray.to_shape([ray_channels_input, input_shape[1..].iter().product()])?.into_owned();
137 let mut data_cache_content = hashbrown::HashMap::new();
138 {
139 let geometry_data_cache = geometry_data_cache.lock().expect("No task should panic");
140 data_cache_content.extend(geometry_data_cache.iter().map(|(k, v)| {
141 (k.to_owned(), v.to_owned())
142 }));
143 }
144 let ray_data = Data::RayData(ray);
145 let ray_data_hash = ray_data.hash();
146 data_cache_content.insert(ray_data_hash.to_owned(), (0, ray_data));
147 let material_data = Data::MaterialData(texture, envmap);
148 let material_data_hash = material_data.hash();
149 data_cache_content.insert(material_data_hash.to_owned(), (0, material_data));
150 let data_cache = Arc::new(Mutex::new(data_cache_content));
151 let request = RequestRayTracingForward {
152 geometry: geometry.to_owned(),
153 ray: ray_data_hash,
154 material: material_data_hash,
155 sample_per_pixel,
156 max_bounce,
157 switches,
158 clip_near,
159 camera_space,
160 requires_grad,
161 srand,
162 low_discrepancy
163 };
164 let msg_request = Message::RequestTask(RequestTask::RequestRayTracingForward(Box::new(request)));
165 info!("LibreDR::ray_tracing_forward: msg_request {msg_request}");
166 self.connection.send_msg(&msg_request).await?;
167 let msg_response = self.try_recv_msg_response_data(&data_cache).await?;
168 let Message::ResponseTask(response_task) = msg_response else {
169 bail!("Unexpected response {msg_response}");
170 };
171 let response_task = response_task.map_err(Error::msg)?;
172 let ResponseTask::ResponseRayTracingForward(response) = response_task else {
173 bail!("Unexpected response {response_task}");
174 };
175 let mut output_shape = input_shape.to_owned();
176 output_shape[0] = if camera_space { 3 } else { 9 };
177 let response = response.render.to_shape(output_shape)?.into_owned();
178 Ok(response)
179 }
180
181 #[allow(clippy::too_many_arguments)]
203 pub async fn ray_tracing_backward(&mut self, d_ray: ArrayD<f32>) ->
204 Result<(Array3<f32>, Array4<f32>, Option<ArrayD<f32>>)> {
205 let input_shape = d_ray.shape().to_owned();
206 assert!(input_shape.len() > 1, "ray_tracing_backward: d_ray should be at least 1D");
207 assert_eq!(input_shape[0], 3,
208 "ray_tracing_backward: d_ray channel: {}, expected: {}", input_shape[0], 3);
209 let d_ray = d_ray.to_shape([3, input_shape[1..].iter().product()])?.into_owned();
210 let request = RequestRayTracingBackward { d_ray, };
211 let msg_request = Message::RequestTask(RequestTask::RequestRayTracingBackward(Box::new(request)));
212 info!("LibreDR::ray_tracing_backward: msg_request {msg_request}");
213 self.connection.send_msg(&msg_request).await?;
214 let msg_response = self.connection.recv_msg().await?;
215 info!("LibreDR::ray_tracing_backward: msg_response {msg_response}");
216 let Message::ResponseTask(response_task) = msg_response else {
217 bail!("Unexpected response {msg_response}");
218 };
219 let response_task = response_task.map_err(Error::msg)?;
220 let ResponseTask::ResponseRayTracingBackward(response) = response_task else {
221 bail!("Unexpected response {response_task}");
222 };
223 let d_ray_texture = response.d_ray_texture.map(|d_ray_texture| {
224 let mut output_shape = input_shape.to_owned();
225 output_shape[0] = 14;
226 Ok(d_ray_texture.to_shape(output_shape)?.into_owned())
227 }).transpose()?;
228 Ok((
229 response.d_texture.ok_or(anyhow!("LibreDR::ray_tracing_backward: None d_texture returned from server"))?,
230 response.d_envmap.ok_or(anyhow!("LibreDR::ray_tracing_backward: None d_envmap returned from server"))?,
231 d_ray_texture))
232 }
233
234 pub async fn close(&mut self) -> Result<()> {
236 self.connection.send_msg(&Message::Close()).await
237 }
238}