image_transitions/
lib.rs

1use cust::{error::CudaError, prelude::*};
2use thiserror::Error;
3
4#[derive(Error, Debug)]
5pub enum TransitionError {
6    #[error("Sizes of the two images are not equal. First image size: {first_image_size}, Second image size: {second_image_size}")]
7    SizeNotEqual {
8        first_image_size: usize,
9        second_image_size: usize,
10    },
11    #[error("gpu error")]
12    GPUError(#[from] CudaError),
13}
14
/// Controls how much work each GPU thread is asked to do in [`cross_fade`].
///
/// The stride is passed to the kernel, and the number of launched blocks is divided by it.
/// A larger stride therefore means fewer blocks, each doing more work.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridStride {
    /// Use the built-in stride of 100.
    Default,
    /// Use the given stride.
    Custom(u64),
}
19
20static PTX: &str = include_str!("../ptx/image.ptx");
21
22// This function takes the raw data of 2 images and generates all the intermediate frames necessary
23// for a cross fade affect. The way this works,
24// ```output_image = (1 - alpha) * first_image + alpha * second image```.
25// We have a different alpha (which determines the transparency) for each iteration. We calculate
26// what the value should be based on the number of iterations. The output vector is flat vector
27// with all the raw intermediate frames concatenated together.
28// This function is generic over a constant called STRIDE. This paramete
29pub fn cross_fade(
30    first_image: &[u8],
31    second_image: &[u8],
32    iterations: u16,
33    stride: GridStride,
34) -> Result<Vec<u8>, TransitionError> {
35    let _ctx = cust::quick_init()?;
36    let module = Module::from_ptx(PTX, &[])?;
37    let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
38
39    if first_image.len() != second_image.len() {
40        return Err(TransitionError::SizeNotEqual {
41            first_image_size: first_image.len(),
42            second_image_size: second_image.len(),
43        });
44    }
45
46    let image_size = first_image.len();
47
48    let first_image_buffer = first_image.as_dbuf()?;
49    let second_image_buffer = second_image.as_dbuf()?;
50
51    // SAFETY: We are making sure not to access this buffer until we have passed it to the kernel
52    // and written the values to it
53    // iterations can be usize since we don't expect this to work on anything less than 32-bit
54    // systems
55    let mut output_buffer: UnifiedBuffer<u8> =
56        unsafe { UnifiedBuffer::uninitialized(image_size * iterations as usize) }?;
57
58    let func = module.get_function("cross_fade")?;
59    let (_, block_size) = func.suggested_launch_configuration(0, 0.into())?;
60
61    let stride_number = match stride {
62        GridStride::Default => 100,
63        GridStride::Custom(number) => number,
64    };
65
66    // We want the grid size to be large enough that we can parallely compute for all iterations of
67    // the entire image. We divide it by stride so that we don't take up too much vram. We want
68    // each thread to do multiple iterations
69    // We are casting grid_size to u32 because that's the maximum supported by cuda. This should
70    // likely be fine since we don't expect any of the values like image_size, block_size and
71    // iterations to be larger than u32. The reason we are casting everything to u64 first is
72    // because it's possible for ```image_size * iterations``` to go higher than u32 and we don't
73    // want overflow to happen. But since we are later diving this number by ```block_size * stride```
74    // we expect converting to u32 should usually work
75    let grid_size = u32::try_from(
76        (image_size as u64 * u64::from(iterations) + u64::from(block_size) * 100)
77            / (u64::from(block_size) * stride_number),
78    )
79    .expect(&format!(
80        "image size: {} or iterations: {} are too large",
81        image_size, iterations
82    ));
83
84    unsafe {
85        launch!(
86            func<<<grid_size, block_size, 0, stream>>>(
87                first_image_buffer.as_device_ptr(),
88                first_image_buffer.len(),
89                second_image_buffer.as_device_ptr(),
90                second_image_buffer.len(),
91                iterations,
92                block_size as u64 * grid_size as u64,
93                stride_number,
94                output_buffer.as_unified_ptr(),
95            )
96        )?;
97    }
98
99    stream.synchronize()?;
100
101    // TODO: see if we can avoid an allocation here
102    // This is safe since we're accessing the unified memory after stream synchonization
103    let output_image = output_buffer.to_vec();
104
105    Ok(output_image)
106}
107
#[cfg(test)]
mod tests {
    use crate::cross_fade;
    use crate::GridStride;
    use crate::TransitionError;

    /// Deterministic byte pattern, so that failures are reproducible.
    fn pattern(len: usize, seed: usize) -> Vec<u8> {
        (0..len).map(|i| ((i * 31 + seed * 17) % 251) as u8).collect()
    }

    #[test]
    fn test_basic() {
        let first_image: &[u8] = &[100, 255, 5, 76];
        let second_image: &[u8] = &[28, 8, 245, 100];
        let iterations = 3;

        let output =
            cross_fade(first_image, second_image, iterations, GridStride::Default).unwrap();

        let split_output = output.chunks_exact(first_image.len()).collect::<Vec<_>>();

        assert_eq!(
            split_output,
            vec![[100, 255, 5, 76], [64, 131, 125, 88], [28, 8, 245, 100]]
        );
    }

    #[test]
    fn test_size_mismatch() {
        let result = cross_fade(&[1, 2, 3], &[1, 2], 3, GridStride::Default);
        assert!(matches!(
            result,
            Err(TransitionError::SizeNotEqual {
                first_image_size: 3,
                second_image_size: 2,
            })
        ));
    }

    #[test]
    fn test_custom_stride_matches_default() {
        // Chosen so that `image_size * iterations` is not a multiple of
        // `block_size * stride` for common block sizes, which exercises the grid rounding.
        let image_size = 257_013;
        let iterations: u16 = 5;
        let first_image = pattern(image_size, 1);
        let second_image = pattern(image_size, 2);

        let expected =
            cross_fade(&first_image, &second_image, iterations, GridStride::Default).unwrap();
        assert_eq!(expected.len(), image_size * usize::from(iterations));

        for &stride in &[1u64, 7, 1000] {
            let output = cross_fade(
                &first_image,
                &second_image,
                iterations,
                GridStride::Custom(stride),
            )
            .unwrap();
            // assert! instead of assert_eq! so a failure does not dump megabytes of bytes.
            assert!(
                output == expected,
                "stride {} produced a different result than the default stride",
                stride
            );
        }
    }

    #[test]
    fn test_large_random_image() {
        // 11_059_200 bytes per frame * 720 frames is roughly 8 GB of unified memory, plus an
        // equally large host copy, so this test needs a machine with plenty of memory.
        const N: usize = 11_059_200;
        let first_image: Vec<u8> = (0..N).map(|_| rand::random::<u8>()).collect();
        let second_image: Vec<u8> = (0..N).map(|_| rand::random::<u8>()).collect();
        let iterations: u16 = 720;

        let output =
            cross_fade(&first_image, &second_image, iterations, GridStride::Default).unwrap();
        assert_eq!(output.len(), N * usize::from(iterations));

        let frames = output.chunks_exact(N).collect::<Vec<_>>();
        assert_eq!(frames.len(), usize::from(iterations));
        // Same endpoints as in `test_basic`. assert! avoids printing whole frames on failure.
        assert!(
            frames[0] == first_image.as_slice(),
            "first frame should equal the first image"
        );
        assert!(
            frames[frames.len() - 1] == second_image.as_slice(),
            "last frame should equal the second image"
        );
    }
}