1use ghostflow_core::Tensor;
11use crate::linear::Linear;
12use crate::Module;
13
/// Hyper-parameters shared by the NeRF networks, the ray sampler and the renderer.
#[derive(Debug, Clone)]
pub struct NeRFConfig {
    /// Number of depth samples taken along each ray during the coarse pass.
    pub num_samples_coarse: usize,
    /// Number of depth samples per ray during the fine pass (`0` means no fine network is built).
    pub num_samples_fine: usize,
    /// Number of sin/cos frequency bands used to encode positions
    /// (view directions use half as many).
    pub num_freq_bands: usize,
    /// Output width of every trunk layer of the MLP.
    pub hidden_size: usize,
    /// Number of layers in the MLP trunk.
    pub num_layers: usize,
    /// Index of the trunk layer after which the encoded position is concatenated back in.
    pub skip_layer: usize,
    /// Whether the colour head is conditioned on the encoded view direction.
    pub use_view_dirs: bool,
    /// Depth of the first sample along a ray.
    pub near: f32,
    /// Depth of the last sample along a ray.
    pub far: f32,
}

impl Default for NeRFConfig {
    /// Baseline preset: 64 coarse / 128 fine samples, 8 layers of width 256.
    fn default() -> Self {
        Self {
            num_samples_coarse: 64,
            num_samples_fine: 128,
            num_freq_bands: 10,
            hidden_size: 256,
            num_layers: 8,
            skip_layer: 4,
            use_view_dirs: true,
            near: 2.0,
            far: 6.0,
        }
    }
}

impl NeRFConfig {
    /// Small preset: 32 coarse / 64 fine samples, 4 layers of width 128.
    ///
    /// View-direction usage and the near/far bounds are inherited from [`Default`].
    pub fn tiny() -> Self {
        Self {
            num_samples_coarse: 32,
            num_samples_fine: 64,
            num_freq_bands: 6,
            hidden_size: 128,
            num_layers: 4,
            skip_layer: 2,
            ..Self::default()
        }
    }

    /// Large preset: 128 coarse / 256 fine samples, 10 layers of width 512.
    ///
    /// View-direction usage and the near/far bounds are inherited from [`Default`].
    pub fn large() -> Self {
        Self {
            num_samples_coarse: 128,
            num_samples_fine: 256,
            num_freq_bands: 15,
            hidden_size: 512,
            num_layers: 10,
            skip_layer: 5,
            ..Self::default()
        }
    }
}
84
85pub struct PositionalEncoder {
87 num_freq_bands: usize,
88 include_input: bool,
89}
90
91impl PositionalEncoder {
92 pub fn new(num_freq_bands: usize, include_input: bool) -> Self {
94 PositionalEncoder {
95 num_freq_bands,
96 include_input,
97 }
98 }
99
100 pub fn output_dim(&self, input_dim: usize) -> usize {
102 let encoded_dim = input_dim * self.num_freq_bands * 2; if self.include_input {
104 encoded_dim + input_dim
105 } else {
106 encoded_dim
107 }
108 }
109
110 pub fn encode(&self, x: &Tensor) -> Result<Tensor, String> {
112 let x_data = x.data_f32();
113 let dims = x.dims();
114 let input_dim = dims[dims.len() - 1];
115 let batch_size = x_data.len() / input_dim;
116
117 let output_dim = self.output_dim(input_dim);
118 let mut result = Vec::with_capacity(batch_size * output_dim);
119
120 for i in 0..batch_size {
121 let start = i * input_dim;
122 let end = start + input_dim;
123 let input_slice = &x_data[start..end];
124
125 if self.include_input {
127 result.extend_from_slice(input_slice);
128 }
129
130 for freq in 0..self.num_freq_bands {
132 let freq_scale = 2.0_f32.powi(freq as i32) * std::f32::consts::PI;
133 for &val in input_slice.iter() {
134 let scaled = val * freq_scale;
135 result.push(scaled.sin());
136 result.push(scaled.cos());
137 }
138 }
139 }
140
141 let mut output_dims = dims.to_vec();
142 output_dims[dims.len() - 1] = output_dim;
143
144 Tensor::from_slice(&result, &output_dims)
145 .map_err(|e| format!("Failed to create encoded tensor: {:?}", e))
146 }
147}
148
149pub struct NeRFMLP {
151 config: NeRFConfig,
152 pos_encoder: PositionalEncoder,
153 dir_encoder: Option<PositionalEncoder>,
154 layers: Vec<Linear>,
155 density_layer: Linear,
156 rgb_layers: Vec<Linear>,
157}
158
159impl NeRFMLP {
160 pub fn new(config: NeRFConfig) -> Self {
162 let pos_encoder = PositionalEncoder::new(config.num_freq_bands, true);
163 let pos_dim = pos_encoder.output_dim(3); let dir_encoder = if config.use_view_dirs {
166 Some(PositionalEncoder::new(config.num_freq_bands / 2, true))
167 } else {
168 None
169 };
170
171 let mut layers = Vec::new();
173 let mut in_dim = pos_dim;
174
175 for i in 0..config.num_layers {
176 let out_dim = config.hidden_size;
177 layers.push(Linear::new(in_dim, out_dim));
178
179 if i == config.skip_layer {
181 in_dim = config.hidden_size + pos_dim;
182 } else {
183 in_dim = config.hidden_size;
184 }
185 }
186
187 let density_layer = Linear::new(config.hidden_size, 1);
189
190 let mut rgb_layers = Vec::new();
192 if config.use_view_dirs {
193 let dir_dim = dir_encoder.as_ref().unwrap().output_dim(3);
194 rgb_layers.push(Linear::new(config.hidden_size + dir_dim, config.hidden_size / 2));
195 rgb_layers.push(Linear::new(config.hidden_size / 2, 3));
196 } else {
197 rgb_layers.push(Linear::new(config.hidden_size, 3));
198 }
199
200 NeRFMLP {
201 config,
202 pos_encoder,
203 dir_encoder,
204 layers,
205 density_layer,
206 rgb_layers,
207 }
208 }
209
210 pub fn forward(&self, positions: &Tensor, directions: Option<&Tensor>) -> Result<(Tensor, Tensor), String> {
212 let mut x = self.pos_encoder.encode(positions)?;
214 let encoded_pos = x.clone();
215
216 for (i, layer) in self.layers.iter().enumerate() {
218 x = layer.forward(&x);
219 x = x.relu();
220
221 if i == self.config.skip_layer {
223 x = self.concatenate(&x, &encoded_pos)?;
224 }
225 }
226
227 let density = self.density_layer.forward(&x);
229 let density = density.relu();
230
231 let rgb = if self.config.use_view_dirs && directions.is_some() {
233 let dirs = directions.unwrap();
234 let encoded_dirs = self.dir_encoder.as_ref().unwrap().encode(dirs)?;
235 let mut rgb_input = self.concatenate(&x, &encoded_dirs)?;
236
237 for (i, layer) in self.rgb_layers.iter().enumerate() {
238 rgb_input = layer.forward(&rgb_input);
239 if i < self.rgb_layers.len() - 1 {
240 rgb_input = rgb_input.relu();
241 }
242 }
243 rgb_input.sigmoid() } else {
245 let mut rgb = self.rgb_layers[0].forward(&x);
246 rgb = rgb.sigmoid();
247 rgb
248 };
249
250 Ok((rgb, density))
251 }
252
253 fn concatenate(&self, a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
255 let a_data = a.data_f32();
256 let b_data = b.data_f32();
257 let a_dims = a.dims();
258 let b_dims = b.dims();
259
260 if a_dims.len() != b_dims.len() {
261 return Err("Tensors must have same number of dimensions".to_string());
262 }
263
264 for i in 0..a_dims.len() - 1 {
265 if a_dims[i] != b_dims[i] {
266 return Err(format!("Dimension mismatch at axis {}", i));
267 }
268 }
269
270 let a_last = a_dims[a_dims.len() - 1];
271 let b_last = b_dims[b_dims.len() - 1];
272 let batch_size = a_data.len() / a_last;
273
274 let mut result = Vec::with_capacity(batch_size * (a_last + b_last));
275
276 for i in 0..batch_size {
277 let a_start = i * a_last;
278 let b_start = i * b_last;
279 result.extend_from_slice(&a_data[a_start..a_start + a_last]);
280 result.extend_from_slice(&b_data[b_start..b_start + b_last]);
281 }
282
283 let mut output_dims = a_dims.to_vec();
284 output_dims[a_dims.len() - 1] = a_last + b_last;
285
286 Tensor::from_slice(&result, &output_dims)
287 .map_err(|e| format!("Failed to concatenate: {:?}", e))
288 }
289}
290
291pub struct RaySampler {
293 near: f32,
294 far: f32,
295}
296
297impl RaySampler {
298 pub fn new(near: f32, far: f32) -> Self {
300 RaySampler { near, far }
301 }
302
303 pub fn sample_along_rays(&self, ray_origins: &Tensor, ray_directions: &Tensor, num_samples: usize) -> Result<(Tensor, Tensor), String> {
305 let origins_data = ray_origins.data_f32();
306 let directions_data = ray_directions.data_f32();
307 let dims = ray_origins.dims();
308 let num_rays = dims[0];
309
310 let mut depths = Vec::with_capacity(num_rays * num_samples);
312 let step = (self.far - self.near) / (num_samples - 1) as f32;
313
314 for _ in 0..num_rays {
315 for i in 0..num_samples {
316 let depth = self.near + step * i as f32;
317 depths.push(depth);
318 }
319 }
320
321 let mut positions = Vec::with_capacity(num_rays * num_samples * 3);
323
324 for ray_idx in 0..num_rays {
325 let o_start = ray_idx * 3;
326 let d_start = ray_idx * 3;
327
328 for sample_idx in 0..num_samples {
329 let depth = depths[ray_idx * num_samples + sample_idx];
330
331 for dim in 0..3 {
332 let pos = origins_data[o_start + dim] + depth * directions_data[d_start + dim];
333 positions.push(pos);
334 }
335 }
336 }
337
338 let positions_tensor = Tensor::from_slice(&positions, &[num_rays, num_samples, 3])
339 .map_err(|e| format!("Failed to create positions: {:?}", e))?;
340
341 let depths_tensor = Tensor::from_slice(&depths, &[num_rays, num_samples])
342 .map_err(|e| format!("Failed to create depths: {:?}", e))?;
343
344 Ok((positions_tensor, depths_tensor))
345 }
346}
347
348pub struct VolumeRenderer;
350
351impl VolumeRenderer {
352 pub fn render_rays(rgb: &Tensor, density: &Tensor, depths: &Tensor) -> Result<Tensor, String> {
354 let rgb_data = rgb.data_f32();
355 let density_data = density.data_f32();
356 let depths_data = depths.data_f32();
357
358 let dims = rgb.dims();
359 let num_rays = dims[0];
360 let num_samples = dims[1];
361
362 let mut rendered = Vec::with_capacity(num_rays * 3);
363
364 for ray_idx in 0..num_rays {
365 let mut accumulated_rgb = [0.0f32; 3];
366 let mut accumulated_alpha = 0.0f32;
367
368 for sample_idx in 0..num_samples {
369 let delta = if sample_idx < num_samples - 1 {
371 depths_data[ray_idx * num_samples + sample_idx + 1]
372 - depths_data[ray_idx * num_samples + sample_idx]
373 } else {
374 1e10 };
376
377 let density_idx = ray_idx * num_samples + sample_idx;
379 let sigma = density_data[density_idx];
380
381 let alpha = 1.0 - (-sigma * delta).exp();
383
384 let transmittance = 1.0 - accumulated_alpha;
386
387 let rgb_start = (ray_idx * num_samples + sample_idx) * 3;
389 for c in 0..3 {
390 accumulated_rgb[c] += transmittance * alpha * rgb_data[rgb_start + c];
391 }
392
393 accumulated_alpha += transmittance * alpha;
395
396 if accumulated_alpha > 0.999 {
398 break;
399 }
400 }
401
402 rendered.extend_from_slice(&accumulated_rgb);
403 }
404
405 Tensor::from_slice(&rendered, &[num_rays, 3])
406 .map_err(|e| format!("Failed to create rendered image: {:?}", e))
407 }
408}
409
410pub struct NeRF {
412 config: NeRFConfig,
413 coarse_network: NeRFMLP,
414 fine_network: Option<NeRFMLP>,
415 sampler: RaySampler,
416}
417
418impl NeRF {
419 pub fn new(config: NeRFConfig) -> Self {
421 let coarse_network = NeRFMLP::new(config.clone());
422 let fine_network = if config.num_samples_fine > 0 {
423 Some(NeRFMLP::new(config.clone()))
424 } else {
425 None
426 };
427 let sampler = RaySampler::new(config.near, config.far);
428
429 NeRF {
430 config,
431 coarse_network,
432 fine_network,
433 sampler,
434 }
435 }
436
437 pub fn render(&self, ray_origins: &Tensor, ray_directions: &Tensor) -> Result<Tensor, String> {
439 let (positions_coarse, depths_coarse) = self.sampler.sample_along_rays(
441 ray_origins,
442 ray_directions,
443 self.config.num_samples_coarse,
444 )?;
445
446 let num_rays = ray_origins.dims()[0];
448 let positions_flat = self.reshape_for_network(&positions_coarse)?;
449 let directions_flat = self.repeat_directions(ray_directions, self.config.num_samples_coarse)?;
450
451 let (rgb_coarse, density_coarse) = self.coarse_network.forward(&positions_flat, Some(&directions_flat))?;
453
454 let rgb_coarse = self.reshape_from_network(&rgb_coarse, num_rays, self.config.num_samples_coarse, 3)?;
456 let density_coarse = self.reshape_from_network(&density_coarse, num_rays, self.config.num_samples_coarse, 1)?;
457
458 let rendered_coarse = VolumeRenderer::render_rays(&rgb_coarse, &density_coarse, &depths_coarse)?;
460
461 if let Some(ref fine_net) = self.fine_network {
463 let (positions_fine, depths_fine) = self.sampler.sample_along_rays(
466 ray_origins,
467 ray_directions,
468 self.config.num_samples_fine,
469 )?;
470
471 let positions_flat = self.reshape_for_network(&positions_fine)?;
472 let directions_flat = self.repeat_directions(ray_directions, self.config.num_samples_fine)?;
473
474 let (rgb_fine, density_fine) = fine_net.forward(&positions_flat, Some(&directions_flat))?;
475
476 let rgb_fine = self.reshape_from_network(&rgb_fine, num_rays, self.config.num_samples_fine, 3)?;
477 let density_fine = self.reshape_from_network(&density_fine, num_rays, self.config.num_samples_fine, 1)?;
478
479 VolumeRenderer::render_rays(&rgb_fine, &density_fine, &depths_fine)
480 } else {
481 Ok(rendered_coarse)
482 }
483 }
484
485 fn reshape_for_network(&self, x: &Tensor) -> Result<Tensor, String> {
487 let data = x.data_f32();
488 let dims = x.dims();
489 let new_dims = vec![dims[0] * dims[1], dims[2]];
490 Tensor::from_slice(&data, &new_dims)
491 .map_err(|e| format!("Failed to reshape: {:?}", e))
492 }
493
494 fn reshape_from_network(&self, x: &Tensor, num_rays: usize, num_samples: usize, channels: usize) -> Result<Tensor, String> {
496 let data = x.data_f32();
497 let new_dims = vec![num_rays, num_samples, channels];
498 Tensor::from_slice(&data, &new_dims)
499 .map_err(|e| format!("Failed to reshape: {:?}", e))
500 }
501
502 fn repeat_directions(&self, directions: &Tensor, num_samples: usize) -> Result<Tensor, String> {
504 let data = directions.data_f32();
505 let dims = directions.dims();
506 let num_rays = dims[0];
507
508 let mut result = Vec::with_capacity(num_rays * num_samples * 3);
509
510 for ray_idx in 0..num_rays {
511 let start = ray_idx * 3;
512 for _ in 0..num_samples {
513 result.extend_from_slice(&data[start..start + 3]);
514 }
515 }
516
517 Tensor::from_slice(&result, &[num_rays * num_samples, 3])
518 .map_err(|e| format!("Failed to repeat directions: {:?}", e))
519 }
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// Presets expose the expected sample counts and hidden width.
    #[test]
    fn test_nerf_config() {
        let default_cfg = NeRFConfig::default();
        assert_eq!(default_cfg.num_samples_coarse, 64);
        assert_eq!(default_cfg.num_samples_fine, 128);

        let tiny_cfg = NeRFConfig::tiny();
        assert_eq!(tiny_cfg.hidden_size, 128);
    }

    /// Encoded width is the raw input plus a sin/cos pair per component per band.
    #[test]
    fn test_positional_encoder() {
        let encoder = PositionalEncoder::new(4, true);
        assert_eq!(encoder.output_dim(3), 3 + 3 * 4 * 2);

        let point = Tensor::from_slice(&[0.5f32, 0.3, 0.1], &[1, 3]).unwrap();
        let encoded = encoder.encode(&point).unwrap();
        assert_eq!(encoded.dims(), &[1, 27]);
    }

    /// Sampling two rays with eight samples yields the expected tensor shapes.
    #[test]
    fn test_ray_sampler() {
        let sampler = RaySampler::new(2.0, 6.0);

        let ray_origins =
            Tensor::from_slice(&[0.0f32, 0.0, 0.0, 1.0, 1.0, 1.0], &[2, 3]).unwrap();
        let ray_dirs = Tensor::from_slice(&[0.0f32, 0.0, 1.0, 1.0, 0.0, 0.0], &[2, 3]).unwrap();

        let (positions, depths) = sampler
            .sample_along_rays(&ray_origins, &ray_dirs, 8)
            .unwrap();

        assert_eq!(positions.dims(), &[2, 8, 3]);
        assert_eq!(depths.dims(), &[2, 8]);
    }

    /// The MLP returns three colour channels and one density per input point.
    #[test]
    fn test_nerf_mlp() {
        let mlp = NeRFMLP::new(NeRFConfig::tiny());

        let positions = Tensor::randn(&[4, 3]);
        let directions = Tensor::randn(&[4, 3]);

        let (rgb, density) = mlp.forward(&positions, Some(&directions)).unwrap();

        assert_eq!(rgb.dims(), &[4, 3]);
        assert_eq!(density.dims(), &[4, 1]);
    }

    /// Rendering one ray with three samples yields a single RGB triple.
    #[test]
    fn test_volume_renderer() {
        let rgb = Tensor::from_slice(
            &[1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            &[1, 3, 3],
        )
        .unwrap();
        let density = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3, 1]).unwrap();
        let depths = Tensor::from_slice(&[2.0f32, 3.0, 4.0], &[1, 3]).unwrap();

        let image = VolumeRenderer::render_rays(&rgb, &density, &depths).unwrap();
        assert_eq!(image.dims(), &[1, 3]);
    }

    /// The full pipeline renders one RGB triple for one ray.
    #[test]
    fn test_nerf_model() {
        let nerf = NeRF::new(NeRFConfig::tiny());

        let ray_origins = Tensor::from_slice(&[0.0f32, 0.0, 0.0], &[1, 3]).unwrap();
        let ray_directions = Tensor::from_slice(&[0.0f32, 0.0, 1.0], &[1, 3]).unwrap();

        let image = nerf.render(&ray_origins, &ray_directions).unwrap();
        assert_eq!(image.dims(), &[1, 3]);
    }
}