ghostflow_nn/
nerf.rs

1//! NeRF (Neural Radiance Fields)
2//!
3//! Implements Neural Radiance Fields for 3D scene representation:
4//! - Volumetric scene representation
5//! - Novel view synthesis
6//! - Positional encoding
//! - Coarse and fine networks (the fine pass currently samples uniformly; importance sampling is not implemented yet)
8//! - Volume rendering
9
10use ghostflow_core::Tensor;
11use crate::linear::Linear;
12use crate::Module;
13
/// Hyper-parameters shared by the NeRF networks, the ray sampler and the renderer.
#[derive(Debug, Clone)]
pub struct NeRFConfig {
    /// Samples taken along each ray for the coarse pass
    pub num_samples_coarse: usize,
    /// Samples taken along each ray for the fine pass
    pub num_samples_fine: usize,
    /// Frequency bands used by the positional encoding
    pub num_freq_bands: usize,
    /// Width of each hidden layer
    pub hidden_size: usize,
    /// How many hidden layers the MLP has
    pub num_layers: usize,
    /// Hidden layer that carries the skip connection
    pub skip_layer: usize,
    /// Whether the view direction is used
    pub use_view_dirs: bool,
    /// Distance to the near plane
    pub near: f32,
    /// Distance to the far plane
    pub far: f32,
}

impl Default for NeRFConfig {
    /// Baseline preset: 64 coarse / 128 fine samples, 8 hidden layers of width 256.
    fn default() -> Self {
        Self {
            num_samples_coarse: 64,
            num_samples_fine: 128,
            num_freq_bands: 10,
            hidden_size: 256,
            num_layers: 8,
            skip_layer: 4,
            use_view_dirs: true,
            near: 2.0,
            far: 6.0,
        }
    }
}

impl NeRFConfig {
    /// Small preset intended for tests.
    ///
    /// Only the sampling and network sizes differ from [`NeRFConfig::default`];
    /// view directions and the near/far planes are inherited from it.
    pub fn tiny() -> Self {
        Self {
            num_samples_coarse: 32,
            num_samples_fine: 64,
            num_freq_bands: 6,
            hidden_size: 128,
            num_layers: 4,
            skip_layer: 2,
            ..Self::default()
        }
    }

    /// Large preset intended for high quality.
    ///
    /// Only the sampling and network sizes differ from [`NeRFConfig::default`];
    /// view directions and the near/far planes are inherited from it.
    pub fn large() -> Self {
        Self {
            num_samples_coarse: 128,
            num_samples_fine: 256,
            num_freq_bands: 15,
            hidden_size: 512,
            num_layers: 10,
            skip_layer: 5,
            ..Self::default()
        }
    }
}
84
85/// Positional encoding for NeRF
86pub struct PositionalEncoder {
87    num_freq_bands: usize,
88    include_input: bool,
89}
90
91impl PositionalEncoder {
92    /// Create new positional encoder
93    pub fn new(num_freq_bands: usize, include_input: bool) -> Self {
94        PositionalEncoder {
95            num_freq_bands,
96            include_input,
97        }
98    }
99    
100    /// Get output dimension
101    pub fn output_dim(&self, input_dim: usize) -> usize {
102        let encoded_dim = input_dim * self.num_freq_bands * 2; // sin and cos
103        if self.include_input {
104            encoded_dim + input_dim
105        } else {
106            encoded_dim
107        }
108    }
109    
110    /// Encode positions
111    pub fn encode(&self, x: &Tensor) -> Result<Tensor, String> {
112        let x_data = x.data_f32();
113        let dims = x.dims();
114        let input_dim = dims[dims.len() - 1];
115        let batch_size = x_data.len() / input_dim;
116        
117        let output_dim = self.output_dim(input_dim);
118        let mut result = Vec::with_capacity(batch_size * output_dim);
119        
120        for i in 0..batch_size {
121            let start = i * input_dim;
122            let end = start + input_dim;
123            let input_slice = &x_data[start..end];
124            
125            // Include original input if requested
126            if self.include_input {
127                result.extend_from_slice(input_slice);
128            }
129            
130            // Apply positional encoding: [sin(2^0*pi*x), cos(2^0*pi*x), sin(2^1*pi*x), cos(2^1*pi*x), ...]
131            for freq in 0..self.num_freq_bands {
132                let freq_scale = 2.0_f32.powi(freq as i32) * std::f32::consts::PI;
133                for &val in input_slice.iter() {
134                    let scaled = val * freq_scale;
135                    result.push(scaled.sin());
136                    result.push(scaled.cos());
137                }
138            }
139        }
140        
141        let mut output_dims = dims.to_vec();
142        output_dims[dims.len() - 1] = output_dim;
143        
144        Tensor::from_slice(&result, &output_dims)
145            .map_err(|e| format!("Failed to create encoded tensor: {:?}", e))
146    }
147}
148
149/// NeRF MLP network
150pub struct NeRFMLP {
151    config: NeRFConfig,
152    pos_encoder: PositionalEncoder,
153    dir_encoder: Option<PositionalEncoder>,
154    layers: Vec<Linear>,
155    density_layer: Linear,
156    rgb_layers: Vec<Linear>,
157}
158
159impl NeRFMLP {
160    /// Create new NeRF MLP
161    pub fn new(config: NeRFConfig) -> Self {
162        let pos_encoder = PositionalEncoder::new(config.num_freq_bands, true);
163        let pos_dim = pos_encoder.output_dim(3); // 3D positions
164        
165        let dir_encoder = if config.use_view_dirs {
166            Some(PositionalEncoder::new(config.num_freq_bands / 2, true))
167        } else {
168            None
169        };
170        
171        // Main MLP layers
172        let mut layers = Vec::new();
173        let mut in_dim = pos_dim;
174        
175        for i in 0..config.num_layers {
176            let out_dim = config.hidden_size;
177            layers.push(Linear::new(in_dim, out_dim));
178            
179            // Skip connection
180            if i == config.skip_layer {
181                in_dim = config.hidden_size + pos_dim;
182            } else {
183                in_dim = config.hidden_size;
184            }
185        }
186        
187        // Density output
188        let density_layer = Linear::new(config.hidden_size, 1);
189        
190        // RGB layers
191        let mut rgb_layers = Vec::new();
192        if config.use_view_dirs {
193            let dir_dim = dir_encoder.as_ref().unwrap().output_dim(3);
194            rgb_layers.push(Linear::new(config.hidden_size + dir_dim, config.hidden_size / 2));
195            rgb_layers.push(Linear::new(config.hidden_size / 2, 3));
196        } else {
197            rgb_layers.push(Linear::new(config.hidden_size, 3));
198        }
199        
200        NeRFMLP {
201            config,
202            pos_encoder,
203            dir_encoder,
204            layers,
205            density_layer,
206            rgb_layers,
207        }
208    }
209    
210    /// Forward pass
211    pub fn forward(&self, positions: &Tensor, directions: Option<&Tensor>) -> Result<(Tensor, Tensor), String> {
212        // Encode positions
213        let mut x = self.pos_encoder.encode(positions)?;
214        let encoded_pos = x.clone();
215        
216        // Pass through main MLP
217        for (i, layer) in self.layers.iter().enumerate() {
218            x = layer.forward(&x);
219            x = x.relu();
220            
221            // Skip connection
222            if i == self.config.skip_layer {
223                x = self.concatenate(&x, &encoded_pos)?;
224            }
225        }
226        
227        // Density output (with ReLU activation)
228        let density = self.density_layer.forward(&x);
229        let density = density.relu();
230        
231        // RGB output
232        let rgb = if self.config.use_view_dirs && directions.is_some() {
233            let dirs = directions.unwrap();
234            let encoded_dirs = self.dir_encoder.as_ref().unwrap().encode(dirs)?;
235            let mut rgb_input = self.concatenate(&x, &encoded_dirs)?;
236            
237            for (i, layer) in self.rgb_layers.iter().enumerate() {
238                rgb_input = layer.forward(&rgb_input);
239                if i < self.rgb_layers.len() - 1 {
240                    rgb_input = rgb_input.relu();
241                }
242            }
243            rgb_input.sigmoid() // RGB values in [0, 1]
244        } else {
245            let mut rgb = self.rgb_layers[0].forward(&x);
246            rgb = rgb.sigmoid();
247            rgb
248        };
249        
250        Ok((rgb, density))
251    }
252    
253    /// Concatenate tensors along last dimension
254    fn concatenate(&self, a: &Tensor, b: &Tensor) -> Result<Tensor, String> {
255        let a_data = a.data_f32();
256        let b_data = b.data_f32();
257        let a_dims = a.dims();
258        let b_dims = b.dims();
259        
260        if a_dims.len() != b_dims.len() {
261            return Err("Tensors must have same number of dimensions".to_string());
262        }
263        
264        for i in 0..a_dims.len() - 1 {
265            if a_dims[i] != b_dims[i] {
266                return Err(format!("Dimension mismatch at axis {}", i));
267            }
268        }
269        
270        let a_last = a_dims[a_dims.len() - 1];
271        let b_last = b_dims[b_dims.len() - 1];
272        let batch_size = a_data.len() / a_last;
273        
274        let mut result = Vec::with_capacity(batch_size * (a_last + b_last));
275        
276        for i in 0..batch_size {
277            let a_start = i * a_last;
278            let b_start = i * b_last;
279            result.extend_from_slice(&a_data[a_start..a_start + a_last]);
280            result.extend_from_slice(&b_data[b_start..b_start + b_last]);
281        }
282        
283        let mut output_dims = a_dims.to_vec();
284        output_dims[a_dims.len() - 1] = a_last + b_last;
285        
286        Tensor::from_slice(&result, &output_dims)
287            .map_err(|e| format!("Failed to concatenate: {:?}", e))
288    }
289}
290
291/// Ray sampler for NeRF
292pub struct RaySampler {
293    near: f32,
294    far: f32,
295}
296
297impl RaySampler {
298    /// Create new ray sampler
299    pub fn new(near: f32, far: f32) -> Self {
300        RaySampler { near, far }
301    }
302    
303    /// Sample points along rays
304    pub fn sample_along_rays(&self, ray_origins: &Tensor, ray_directions: &Tensor, num_samples: usize) -> Result<(Tensor, Tensor), String> {
305        let origins_data = ray_origins.data_f32();
306        let directions_data = ray_directions.data_f32();
307        let dims = ray_origins.dims();
308        let num_rays = dims[0];
309        
310        // Generate sample depths
311        let mut depths = Vec::with_capacity(num_rays * num_samples);
312        let step = (self.far - self.near) / (num_samples - 1) as f32;
313        
314        for _ in 0..num_rays {
315            for i in 0..num_samples {
316                let depth = self.near + step * i as f32;
317                depths.push(depth);
318            }
319        }
320        
321        // Compute 3D sample positions: origin + depth * direction
322        let mut positions = Vec::with_capacity(num_rays * num_samples * 3);
323        
324        for ray_idx in 0..num_rays {
325            let o_start = ray_idx * 3;
326            let d_start = ray_idx * 3;
327            
328            for sample_idx in 0..num_samples {
329                let depth = depths[ray_idx * num_samples + sample_idx];
330                
331                for dim in 0..3 {
332                    let pos = origins_data[o_start + dim] + depth * directions_data[d_start + dim];
333                    positions.push(pos);
334                }
335            }
336        }
337        
338        let positions_tensor = Tensor::from_slice(&positions, &[num_rays, num_samples, 3])
339            .map_err(|e| format!("Failed to create positions: {:?}", e))?;
340        
341        let depths_tensor = Tensor::from_slice(&depths, &[num_rays, num_samples])
342            .map_err(|e| format!("Failed to create depths: {:?}", e))?;
343        
344        Ok((positions_tensor, depths_tensor))
345    }
346}
347
348/// Volume renderer for NeRF
349pub struct VolumeRenderer;
350
351impl VolumeRenderer {
352    /// Render rays using volume rendering equation
353    pub fn render_rays(rgb: &Tensor, density: &Tensor, depths: &Tensor) -> Result<Tensor, String> {
354        let rgb_data = rgb.data_f32();
355        let density_data = density.data_f32();
356        let depths_data = depths.data_f32();
357        
358        let dims = rgb.dims();
359        let num_rays = dims[0];
360        let num_samples = dims[1];
361        
362        let mut rendered = Vec::with_capacity(num_rays * 3);
363        
364        for ray_idx in 0..num_rays {
365            let mut accumulated_rgb = [0.0f32; 3];
366            let mut accumulated_alpha = 0.0f32;
367            
368            for sample_idx in 0..num_samples {
369                // Get depth interval
370                let delta = if sample_idx < num_samples - 1 {
371                    depths_data[ray_idx * num_samples + sample_idx + 1] 
372                        - depths_data[ray_idx * num_samples + sample_idx]
373                } else {
374                    1e10 // Large value for last sample
375                };
376                
377                // Get density and RGB
378                let density_idx = ray_idx * num_samples + sample_idx;
379                let sigma = density_data[density_idx];
380                
381                // Compute alpha (opacity)
382                let alpha = 1.0 - (-sigma * delta).exp();
383                
384                // Compute transmittance
385                let transmittance = 1.0 - accumulated_alpha;
386                
387                // Accumulate color
388                let rgb_start = (ray_idx * num_samples + sample_idx) * 3;
389                for c in 0..3 {
390                    accumulated_rgb[c] += transmittance * alpha * rgb_data[rgb_start + c];
391                }
392                
393                // Accumulate alpha
394                accumulated_alpha += transmittance * alpha;
395                
396                // Early stopping if fully opaque
397                if accumulated_alpha > 0.999 {
398                    break;
399                }
400            }
401            
402            rendered.extend_from_slice(&accumulated_rgb);
403        }
404        
405        Tensor::from_slice(&rendered, &[num_rays, 3])
406            .map_err(|e| format!("Failed to create rendered image: {:?}", e))
407    }
408}
409
410/// Complete NeRF model
411pub struct NeRF {
412    config: NeRFConfig,
413    coarse_network: NeRFMLP,
414    fine_network: Option<NeRFMLP>,
415    sampler: RaySampler,
416}
417
418impl NeRF {
419    /// Create new NeRF model
420    pub fn new(config: NeRFConfig) -> Self {
421        let coarse_network = NeRFMLP::new(config.clone());
422        let fine_network = if config.num_samples_fine > 0 {
423            Some(NeRFMLP::new(config.clone()))
424        } else {
425            None
426        };
427        let sampler = RaySampler::new(config.near, config.far);
428        
429        NeRF {
430            config,
431            coarse_network,
432            fine_network,
433            sampler,
434        }
435    }
436    
437    /// Render rays
438    pub fn render(&self, ray_origins: &Tensor, ray_directions: &Tensor) -> Result<Tensor, String> {
439        // Coarse sampling
440        let (positions_coarse, depths_coarse) = self.sampler.sample_along_rays(
441            ray_origins,
442            ray_directions,
443            self.config.num_samples_coarse,
444        )?;
445        
446        // Reshape for network
447        let num_rays = ray_origins.dims()[0];
448        let positions_flat = self.reshape_for_network(&positions_coarse)?;
449        let directions_flat = self.repeat_directions(ray_directions, self.config.num_samples_coarse)?;
450        
451        // Coarse network forward pass
452        let (rgb_coarse, density_coarse) = self.coarse_network.forward(&positions_flat, Some(&directions_flat))?;
453        
454        // Reshape back
455        let rgb_coarse = self.reshape_from_network(&rgb_coarse, num_rays, self.config.num_samples_coarse, 3)?;
456        let density_coarse = self.reshape_from_network(&density_coarse, num_rays, self.config.num_samples_coarse, 1)?;
457        
458        // Render coarse
459        let rendered_coarse = VolumeRenderer::render_rays(&rgb_coarse, &density_coarse, &depths_coarse)?;
460        
461        // Fine sampling (if enabled)
462        if let Some(ref fine_net) = self.fine_network {
463            // For simplicity, use uniform sampling for fine network too
464            // In practice, you'd use importance sampling based on coarse weights
465            let (positions_fine, depths_fine) = self.sampler.sample_along_rays(
466                ray_origins,
467                ray_directions,
468                self.config.num_samples_fine,
469            )?;
470            
471            let positions_flat = self.reshape_for_network(&positions_fine)?;
472            let directions_flat = self.repeat_directions(ray_directions, self.config.num_samples_fine)?;
473            
474            let (rgb_fine, density_fine) = fine_net.forward(&positions_flat, Some(&directions_flat))?;
475            
476            let rgb_fine = self.reshape_from_network(&rgb_fine, num_rays, self.config.num_samples_fine, 3)?;
477            let density_fine = self.reshape_from_network(&density_fine, num_rays, self.config.num_samples_fine, 1)?;
478            
479            VolumeRenderer::render_rays(&rgb_fine, &density_fine, &depths_fine)
480        } else {
481            Ok(rendered_coarse)
482        }
483    }
484    
485    /// Reshape tensor for network input
486    fn reshape_for_network(&self, x: &Tensor) -> Result<Tensor, String> {
487        let data = x.data_f32();
488        let dims = x.dims();
489        let new_dims = vec![dims[0] * dims[1], dims[2]];
490        Tensor::from_slice(&data, &new_dims)
491            .map_err(|e| format!("Failed to reshape: {:?}", e))
492    }
493    
494    /// Reshape tensor from network output
495    fn reshape_from_network(&self, x: &Tensor, num_rays: usize, num_samples: usize, channels: usize) -> Result<Tensor, String> {
496        let data = x.data_f32();
497        let new_dims = vec![num_rays, num_samples, channels];
498        Tensor::from_slice(&data, &new_dims)
499            .map_err(|e| format!("Failed to reshape: {:?}", e))
500    }
501    
502    /// Repeat directions for each sample
503    fn repeat_directions(&self, directions: &Tensor, num_samples: usize) -> Result<Tensor, String> {
504        let data = directions.data_f32();
505        let dims = directions.dims();
506        let num_rays = dims[0];
507        
508        let mut result = Vec::with_capacity(num_rays * num_samples * 3);
509        
510        for ray_idx in 0..num_rays {
511            let start = ray_idx * 3;
512            for _ in 0..num_samples {
513                result.extend_from_slice(&data[start..start + 3]);
514            }
515        }
516        
517        Tensor::from_slice(&result, &[num_rays * num_samples, 3])
518            .map_err(|e| format!("Failed to repeat directions: {:?}", e))
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// Presets expose the expected sample counts and hidden width.
    #[test]
    fn test_nerf_config() {
        let defaults = NeRFConfig::default();
        assert_eq!(defaults.num_samples_coarse, 64);
        assert_eq!(defaults.num_samples_fine, 128);

        let small = NeRFConfig::tiny();
        assert_eq!(small.hidden_size, 128);
    }

    /// Four bands on a 3-vector give 3 raw values plus 24 encoded ones.
    #[test]
    fn test_positional_encoder() {
        let pe = PositionalEncoder::new(4, true);
        assert_eq!(pe.output_dim(3), 3 + 3 * 4 * 2); // input + encoded

        let point = [0.5f32, 0.3, 0.1];
        let input = Tensor::from_slice(&point, &[1, 3]).unwrap();
        let output = pe.encode(&input).unwrap();
        assert_eq!(output.dims(), &[1, 27]); // 3 + 24
    }

    /// Two rays with eight samples each yield [2, 8, 3] positions and [2, 8] depths.
    #[test]
    fn test_ray_sampler() {
        let ray_sampler = RaySampler::new(2.0, 6.0);

        let origin_values = [0.0f32, 0.0, 0.0, 1.0, 1.0, 1.0];
        let direction_values = [0.0f32, 0.0, 1.0, 1.0, 0.0, 0.0];
        let ray_o = Tensor::from_slice(&origin_values, &[2, 3]).unwrap();
        let ray_d = Tensor::from_slice(&direction_values, &[2, 3]).unwrap();

        let (points, t_vals) = ray_sampler.sample_along_rays(&ray_o, &ray_d, 8).unwrap();

        assert_eq!(points.dims(), &[2, 8, 3]);
        assert_eq!(t_vals.dims(), &[2, 8]);
    }

    /// The MLP returns three colour channels and one density per input row.
    #[test]
    fn test_nerf_mlp() {
        let network = NeRFMLP::new(NeRFConfig::tiny());

        let pts = Tensor::randn(&[4, 3]);
        let view_dirs = Tensor::randn(&[4, 3]);

        let (color, sigma) = network.forward(&pts, Some(&view_dirs)).unwrap();

        assert_eq!(color.dims(), &[4, 3]);
        assert_eq!(sigma.dims(), &[4, 1]);
    }

    /// One ray with three samples renders to a single RGB triple.
    #[test]
    fn test_volume_renderer() {
        let color_values = [1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let sigma_values = [1.0f32, 2.0, 3.0];
        let depth_values = [2.0f32, 3.0, 4.0];

        let colors = Tensor::from_slice(&color_values, &[1, 3, 3]).unwrap();
        let sigmas = Tensor::from_slice(&sigma_values, &[1, 3, 1]).unwrap();
        let t_vals = Tensor::from_slice(&depth_values, &[1, 3]).unwrap();

        let image = VolumeRenderer::render_rays(&colors, &sigmas, &t_vals).unwrap();
        assert_eq!(image.dims(), &[1, 3]);
    }

    /// The full model renders one RGB triple for a single ray.
    #[test]
    fn test_nerf_model() {
        let model = NeRF::new(NeRFConfig::tiny());

        let origin = [0.0f32, 0.0, 0.0];
        let direction = [0.0f32, 0.0, 1.0];
        let ray_o = Tensor::from_slice(&origin, &[1, 3]).unwrap();
        let ray_d = Tensor::from_slice(&direction, &[1, 3]).unwrap();

        let image = model.render(&ray_o, &ray_d).unwrap();
        assert_eq!(image.dims(), &[1, 3]); // RGB output
    }
}