1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
use cust::{error::CudaError, memory::DeviceCopy, prelude::*};
use thiserror::Error;
/// Errors returned by [`cross_fade`].
#[derive(Error, Debug)]
pub enum TransitionError {
/// The two input slices passed to [`cross_fade`] have different lengths.
#[error("Sizes of the two images are not equal. First image size: {first_image_size}, Second image size: {second_image_size}")]
SizeNotEqual {
/// Number of elements (`len()`) in the first image slice.
first_image_size: usize,
/// Number of elements (`len()`) in the second image slice.
second_image_size: usize,
},
/// A CUDA call failed. The `#[from]` attribute generates `From<CudaError>`, which is what
/// allows `?` to be applied directly to `cust` results inside [`cross_fade`].
#[error("gpu error")]
GPUError(#[from] CudaError),
}
/// PTX text of the GPU module, embedded into the binary at compile time from
/// `../ptx/image.ptx` (the path is resolved relative to this source file).
/// [`cross_fade`] loads it with `Module::from_ptx` and looks up the `cross_fade` kernel in it.
static PTX: &str = include_str!("../ptx/image.ptx");
pub fn cross_fade<T: DeviceCopy + std::fmt::Debug + Default>(
first_image: &[T],
second_image: &[T],
iterations: usize,
) -> Result<Vec<T>, TransitionError> {
let _ctx = cust::quick_init()?;
let module = Module::from_ptx(PTX, &[])?;
let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
if first_image.len() != second_image.len() {
return Err(TransitionError::SizeNotEqual {
first_image_size: first_image.len(),
second_image_size: second_image.len(),
});
}
let image_size = first_image.len();
let first_image_buffer = first_image.as_dbuf()?;
let second_image_buffer = second_image.as_dbuf()?;
let mut output_buffer: UnifiedBuffer<T> =
unsafe { UnifiedBuffer::uninitialized(image_size * iterations) }?;
let func = module.get_function("cross_fade")?;
let (_, block_size) = func.suggested_launch_configuration(0, 0.into())?;
let grid_size = ((image_size * iterations) as u32 + block_size - 1) / block_size;
unsafe {
launch!(
func<<<grid_size, block_size, 0, stream>>>(
first_image_buffer.as_device_ptr(),
first_image_buffer.len(),
second_image_buffer.as_device_ptr(),
second_image_buffer.len(),
iterations,
output_buffer.as_unified_ptr(),
)
)?;
}
stream.synchronize()?;
let output_image = output_buffer.to_vec();
Ok(output_image)
}
#[cfg(test)]
mod tests {
    use super::{cross_fade, TransitionError};

    /// Runs the kernel on a tiny image and checks the shape of the result.
    ///
    /// Needs a CUDA-capable device. Only the output size is asserted, because the blended
    /// values are defined by the kernel and are not known from this file.
    #[test]
    fn test_cross_fade() {
        let first_image: &[u16] = &[100, 255, 5, 76];
        let second_image: &[u16] = &[28, 8, 245, 100];
        let iterations = 720;

        let output = cross_fade(first_image, second_image, iterations).unwrap();

        // The output buffer is allocated with `len * iterations` elements and copied out whole.
        assert_eq!(output.len(), first_image.len() * iterations);

        let split_output = output.chunks_exact(first_image.len()).collect::<Vec<_>>();
        assert_eq!(split_output.len(), iterations);
        println!("output: {:#?}", split_output);
    }

    /// Slices of different lengths must be rejected with the lengths reported back.
    ///
    /// Like the test above, this may need a CUDA device if `cross_fade` initialises the GPU
    /// before it validates its input.
    #[test]
    fn test_cross_fade_rejects_mismatched_lengths() {
        let first_image: &[u16] = &[1, 2, 3];
        let second_image: &[u16] = &[1, 2];

        let result = cross_fade(first_image, second_image, 4);

        assert!(matches!(
            result,
            Err(TransitionError::SizeNotEqual {
                first_image_size: 3,
                second_image_size: 2,
            })
        ));
    }
}