1use cust::{error::CudaError, prelude::*};
2use thiserror::Error;
3
4#[derive(Error, Debug)]
5pub enum TransitionError {
6 #[error("Sizes of the two images are not equal. First image size: {first_image_size}, Second image size: {second_image_size}")]
7 SizeNotEqual {
8 first_image_size: usize,
9 second_image_size: usize,
10 },
11 #[error("gpu error")]
12 GPUError(#[from] CudaError),
13}
14
/// Selects the stride value that [`cross_fade`] uses when sizing the kernel
/// launch and that it forwards to the kernel.
///
/// Equality is structural: `GridStride::Default` and `GridStride::Custom(100)`
/// compare as different values even though `cross_fade` currently resolves
/// `Default` to `100`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridStride {
    /// Use the built-in stride (`cross_fade` resolves this to `100`).
    Default,
    /// Use the given stride instead of the built-in one.
    ///
    /// NOTE(review): presumably the number of output elements each GPU thread
    /// processes — confirm against the kernel source.
    Custom(u64),
}
19
20static PTX: &str = include_str!("../ptx/image.ptx");
21
22pub fn cross_fade(
30 first_image: &[u8],
31 second_image: &[u8],
32 iterations: u16,
33 stride: GridStride,
34) -> Result<Vec<u8>, TransitionError> {
35 let _ctx = cust::quick_init()?;
36 let module = Module::from_ptx(PTX, &[])?;
37 let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
38
39 if first_image.len() != second_image.len() {
40 return Err(TransitionError::SizeNotEqual {
41 first_image_size: first_image.len(),
42 second_image_size: second_image.len(),
43 });
44 }
45
46 let image_size = first_image.len();
47
48 let first_image_buffer = first_image.as_dbuf()?;
49 let second_image_buffer = second_image.as_dbuf()?;
50
51 let mut output_buffer: UnifiedBuffer<u8> =
56 unsafe { UnifiedBuffer::uninitialized(image_size * iterations as usize) }?;
57
58 let func = module.get_function("cross_fade")?;
59 let (_, block_size) = func.suggested_launch_configuration(0, 0.into())?;
60
61 let stride_number = match stride {
62 GridStride::Default => 100,
63 GridStride::Custom(number) => number,
64 };
65
66 let grid_size = u32::try_from(
76 (image_size as u64 * u64::from(iterations) + u64::from(block_size) * 100)
77 / (u64::from(block_size) * stride_number),
78 )
79 .expect(&format!(
80 "image size: {} or iterations: {} are too large",
81 image_size, iterations
82 ));
83
84 unsafe {
85 launch!(
86 func<<<grid_size, block_size, 0, stream>>>(
87 first_image_buffer.as_device_ptr(),
88 first_image_buffer.len(),
89 second_image_buffer.as_device_ptr(),
90 second_image_buffer.len(),
91 iterations,
92 block_size as u64 * grid_size as u64,
93 stride_number,
94 output_buffer.as_unified_ptr(),
95 )
96 )?;
97 }
98
99 stream.synchronize()?;
100
101 let output_image = output_buffer.to_vec();
104
105 Ok(output_image)
106}
107
#[cfg(test)]
mod tests {
    // `super::` resolves no matter where this module sits in the crate tree.
    use super::{cross_fade, GridStride, TransitionError};

    /// Three frames over a 4-byte image: the first frame equals the first
    /// image, the last frame equals the second image, and the middle frame
    /// lies in between.
    #[test]
    fn test_basic() {
        let first_image: &[u8] = &[100, 255, 5, 76];
        let second_image: &[u8] = &[28, 8, 245, 100];
        let iterations = 3;

        let output =
            cross_fade(first_image, second_image, iterations, GridStride::Default).unwrap();

        let split_output = output.chunks_exact(first_image.len()).collect::<Vec<_>>();

        assert_eq!(
            split_output,
            vec![[100, 255, 5, 76], [64, 131, 125, 88], [28, 8, 245, 100]]
        );
    }

    /// Inputs of different lengths must be rejected with `SizeNotEqual`,
    /// reporting both lengths.
    #[test]
    fn test_size_mismatch_is_rejected() {
        let result = cross_fade(&[1, 2, 3], &[1, 2], 3, GridStride::Default);

        assert!(
            matches!(
                result,
                Err(TransitionError::SizeNotEqual {
                    first_image_size: 3,
                    second_image_size: 2,
                })
            ),
            "unexpected result: {:?}",
            result
        );
    }

    /// Large-input smoke test: the final frame must contain some non-zero data.
    #[test]
    fn test_large_random_image() {
        const N: usize = 11059200;
        let first_image: Vec<u8> = (0..N).map(|_| rand::random::<u8>()).collect();
        let second_image: Vec<u8> = (0..N).map(|_| rand::random::<u8>()).collect();
        let iterations = 720;

        let output =
            cross_fade(&first_image, &second_image, iterations, GridStride::Default).unwrap();

        let last_frame = output.chunks_exact(first_image.len()).last().unwrap();

        // Same condition as comparing against an all-zero frame, but without
        // allocating one or printing megabytes of data on failure.
        assert!(
            last_frame.iter().any(|&byte| byte != 0),
            "last frame is entirely zero"
        );
    }
}