Skip to main content

oxicuda_nerf/rendering/
ray.rs

//! Ray representation and pinhole camera model.
2
3use crate::error::{NerfError, NerfResult};
4
// ─── Ray ─────────────────────────────────────────────────────────────────────
6
/// A 3D ray with an origin and a direction.
///
/// [`Ray::normalized`] produces a unit-length direction, while [`Ray::new`] stores the
/// direction exactly as given (it only rejects degenerate directions). Because the
/// fields are public, a struct literal bypasses both constructors' checks.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Ray {
    /// Ray origin in world space.
    pub origin: [f32; 3],
    /// Ray direction (should be normalized).
    pub dir: [f32; 3],
}
15
16impl Ray {
17    /// Create a ray with the given origin and direction (not normalized).
18    ///
19    /// # Errors
20    ///
21    /// Returns `ZeroRayDirection` if `|dir| < 1e-8`.
22    pub fn new(origin: [f32; 3], dir: [f32; 3]) -> NerfResult<Self> {
23        let len_sq = dir[0] * dir[0] + dir[1] * dir[1] + dir[2] * dir[2];
24        if len_sq < 1e-16 {
25            return Err(NerfError::ZeroRayDirection);
26        }
27        Ok(Self { origin, dir })
28    }
29
30    /// Create a ray with automatically normalized direction.
31    ///
32    /// # Errors
33    ///
34    /// Returns `ZeroRayDirection` if `|dir| < 1e-8`.
35    pub fn normalized(origin: [f32; 3], dir: [f32; 3]) -> NerfResult<Self> {
36        let len_sq = dir[0] * dir[0] + dir[1] * dir[1] + dir[2] * dir[2];
37        if len_sq < 1e-16 {
38            return Err(NerfError::ZeroRayDirection);
39        }
40        let inv_len = 1.0 / len_sq.sqrt();
41        let ndir = [dir[0] * inv_len, dir[1] * inv_len, dir[2] * inv_len];
42        Ok(Self { origin, dir: ndir })
43    }
44
45    /// Evaluate the point along the ray at parameter `t`: `origin + t * dir`.
46    #[must_use]
47    pub fn at(&self, t: f32) -> [f32; 3] {
48        [
49            self.origin[0] + t * self.dir[0],
50            self.origin[1] + t * self.dir[1],
51            self.origin[2] + t * self.dir[2],
52        ]
53    }
54}
55
// ─── PinholeCamera ───────────────────────────────────────────────────────────
57
/// Pinhole camera intrinsics.
///
/// Construct via [`PinholeCamera::new`] to get validated intrinsics. The fields are
/// public, so a struct literal bypasses that validation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PinholeCamera {
    /// Focal length in x (pixels).
    pub fx: f32,
    /// Focal length in y (pixels).
    pub fy: f32,
    /// Principal point x (pixels).
    pub cx: f32,
    /// Principal point y (pixels).
    pub cy: f32,
    /// Image width in pixels.
    pub width: u32,
    /// Image height in pixels.
    pub height: u32,
}
74
75impl PinholeCamera {
76    /// Create a new pinhole camera model.
77    ///
78    /// # Errors
79    ///
80    /// Returns `InvalidCameraIntrinsics` for non-positive focal lengths or zero dimensions.
81    pub fn new(fx: f32, fy: f32, cx: f32, cy: f32, w: u32, h: u32) -> NerfResult<Self> {
82        if fx <= 0.0 || fy <= 0.0 {
83            return Err(NerfError::InvalidCameraIntrinsics {
84                msg: "focal lengths must be positive".into(),
85            });
86        }
87        if w == 0 || h == 0 {
88            return Err(NerfError::InvalidCameraIntrinsics {
89                msg: "image dimensions must be > 0".into(),
90            });
91        }
92        Ok(Self {
93            fx,
94            fy,
95            cx,
96            cy,
97            width: w,
98            height: h,
99        })
100    }
101
102    /// Generate the ray through pixel `(u, v)` in camera coordinates.
103    ///
104    /// `c2w` is a row-major 3×4 camera-to-world transform:
105    /// ```text
106    /// [ R[0,0] R[0,1] R[0,2] t[0]
107    ///   R[1,0] R[1,1] R[1,2] t[1]
108    ///   R[2,0] R[2,1] R[2,2] t[2] ]
109    /// ```
110    ///
111    /// # Errors
112    ///
113    /// Returns `ZeroRayDirection` if the transformed direction is near-zero.
114    pub fn ray_through_pixel(&self, u: f32, v: f32, c2w: &[f32; 12]) -> NerfResult<Ray> {
115        // Camera-space direction (unnormalized)
116        let dx = (u - self.cx) / self.fx;
117        let dy = (v - self.cy) / self.fy;
118        let dz = 1.0_f32;
119
120        // Rotate to world space using upper-left 3×3 of c2w
121        let wx = c2w[0] * dx + c2w[1] * dy + c2w[2] * dz;
122        let wy = c2w[4] * dx + c2w[5] * dy + c2w[6] * dz;
123        let wz = c2w[8] * dx + c2w[9] * dy + c2w[10] * dz;
124
125        // Origin = translation column of c2w
126        let origin = [c2w[3], c2w[7], c2w[11]];
127
128        Ray::normalized(origin, [wx, wy, wz])
129    }
130
131    /// Generate all W×H rays for the camera with given `c2w` matrix.
132    ///
133    /// Output: `width * height` rays in row-major order (left-to-right, top-to-bottom).
134    ///
135    /// # Errors
136    ///
137    /// Returns `ZeroRayDirection` if any pixel ray has near-zero direction.
138    pub fn generate_rays(&self, c2w: &[f32; 12]) -> NerfResult<Vec<Ray>> {
139        let n = (self.width * self.height) as usize;
140        let mut rays = Vec::with_capacity(n);
141        for row in 0..self.height {
142            for col in 0..self.width {
143                let u = col as f32 + 0.5;
144                let v = row as f32 + 0.5;
145                rays.push(self.ray_through_pixel(u, v, c2w)?);
146            }
147        }
148        Ok(rays)
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity rotation with zero translation, as a row-major 3×4 matrix.
    fn identity_c2w() -> [f32; 12] {
        [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
    }

    /// Component-wise approximate equality for 3-vectors.
    fn assert_close(a: [f32; 3], b: [f32; 3]) {
        for (x, y) in a.iter().zip(b.iter()) {
            assert!((x - y).abs() < 1e-6, "{:?} != {:?}", a, b);
        }
    }

    #[test]
    fn ray_at_origin() {
        let r = Ray::new([0.0, 0.0, 0.0], [0.0, 0.0, 1.0]).unwrap();
        assert_close(r.at(2.0), [0.0, 0.0, 2.0]);
    }

    #[test]
    fn ray_zero_dir_error() {
        assert!(Ray::new([0.0; 3], [0.0; 3]).is_err());
        assert!(Ray::normalized([0.0; 3], [0.0; 3]).is_err());
    }

    #[test]
    fn ray_normalized_has_unit_length() {
        let r = Ray::normalized([0.0; 3], [3.0, 4.0, 0.0]).unwrap();
        assert_close(r.dir, [0.6, 0.8, 0.0]);
    }

    #[test]
    fn camera_principal_ray() {
        let cam = PinholeCamera::new(100.0, 100.0, 50.0, 50.0, 100, 100).unwrap();
        // The ray through the principal point (cx, cy) is the optical axis: +z in
        // camera space, which is +z in world space for the identity transform.
        let ray = cam.ray_through_pixel(50.0, 50.0, &identity_c2w()).unwrap();
        assert_close(ray.dir, [0.0, 0.0, 1.0]);
        assert_close(ray.origin, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn camera_origin_is_translation() {
        let cam = PinholeCamera::new(100.0, 100.0, 50.0, 50.0, 100, 100).unwrap();
        let c2w = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 1.0, 3.0];
        let ray = cam.ray_through_pixel(50.0, 50.0, &c2w).unwrap();
        assert_close(ray.origin, [1.0, 2.0, 3.0]);
    }

    #[test]
    fn camera_rejects_invalid_intrinsics() {
        assert!(PinholeCamera::new(0.0, 1.0, 0.0, 0.0, 1, 1).is_err());
        assert!(PinholeCamera::new(1.0, -1.0, 0.0, 0.0, 1, 1).is_err());
        assert!(PinholeCamera::new(1.0, 1.0, 0.0, 0.0, 0, 1).is_err());
        assert!(PinholeCamera::new(1.0, 1.0, 0.0, 0.0, 1, 0).is_err());
    }

    #[test]
    fn generate_rays_count() {
        let cam = PinholeCamera::new(100.0, 100.0, 50.0, 50.0, 4, 3).unwrap();
        let rays = cam.generate_rays(&identity_c2w()).unwrap();
        assert_eq!(rays.len(), 12);
    }

    #[test]
    fn generate_rays_row_major_order() {
        // Principal point at the image center of a 4×3 image, so the corner rays
        // fall into distinct quadrants.
        let cam = PinholeCamera::new(100.0, 100.0, 2.0, 1.5, 4, 3).unwrap();
        let rays = cam.generate_rays(&identity_c2w()).unwrap();
        // Top-left, top-right, bottom-left, bottom-right pixels.
        assert!(rays[0].dir[0] < 0.0 && rays[0].dir[1] < 0.0);
        assert!(rays[3].dir[0] > 0.0 && rays[3].dir[1] < 0.0);
        assert!(rays[8].dir[0] < 0.0 && rays[8].dir[1] > 0.0);
        assert!(rays[11].dir[0] > 0.0 && rays[11].dir[1] > 0.0);
    }
}