1use kornia_tensor::{CpuAllocator, Tensor, Tensor2, Tensor3};
2
3use crate::error::ImageError;
4
/// Width and height of an image.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` (both fields are plain
/// `usize` values), so a size can be used as a key in `HashMap`/`HashSet`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ImageSize {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}
29
30impl std::fmt::Display for ImageSize {
31 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
32 write!(
33 f,
34 "ImageSize {{ width: {}, height: {} }}",
35 self.width, self.height
36 )
37 }
38}
39
40impl From<[usize; 2]> for ImageSize {
41 fn from(size: [usize; 2]) -> Self {
42 ImageSize {
43 width: size[0],
44 height: size[1],
45 }
46 }
47}
48
49impl From<ImageSize> for [u32; 2] {
50 fn from(size: ImageSize) -> Self {
51 [size.width as u32, size.height as u32]
52 }
53}
54
/// Image container whose channel count `C` is fixed at compile time.
///
/// A newtype over a [`Tensor3`] backed by [`CpuAllocator`]. The constructors
/// in this file build the tensor with shape `[height, width, C]`. The wrapped
/// tensor is a public field and can be reached as `image.0`.
#[derive(Clone)]
pub struct Image<T, const C: usize>(pub Tensor3<T, CpuAllocator>);
60
61impl<T, const C: usize> std::ops::Deref for Image<T, C> {
63 type Target = Tensor3<T, CpuAllocator>;
64
65 fn deref(&self) -> &Self::Target {
67 &self.0
68 }
69}
70
71impl<T, const C: usize> std::ops::DerefMut for Image<T, C> {
73 fn deref_mut(&mut self) -> &mut Self::Target {
75 &mut self.0
76 }
77}
78
79impl<T, const C: usize> Image<T, C> {
80 pub fn new(size: ImageSize, data: Vec<T>) -> Result<Self, ImageError>
113 where
114 T: Clone, {
116 if data.len() != size.width * size.height * C {
118 return Err(ImageError::InvalidChannelShape(
119 data.len(),
120 size.width * size.height * C,
121 ));
122 }
123
124 Ok(Self(Tensor3::from_shape_vec(
126 [size.height, size.width, C],
127 data,
128 CpuAllocator,
129 )?))
130 }
131
132 pub fn from_size_val(size: ImageSize, val: T) -> Result<Self, ImageError>
163 where
164 T: Clone + Default,
165 {
166 let data = vec![val; size.width * size.height * C];
167 let image = Image::new(size, data)?;
168
169 Ok(image)
170 }
171
172 pub unsafe fn from_raw_parts(
188 size: ImageSize,
189 data: *const T,
190 len: usize,
191 ) -> Result<Self, ImageError>
192 where
193 T: Clone,
194 {
195 Tensor::from_raw_parts([size.height, size.width, C], data, len, CpuAllocator)?.try_into()
196 }
197
198 pub fn from_size_slice(size: ImageSize, data: &[T]) -> Result<Self, ImageError>
214 where
215 T: Clone,
216 {
217 let tensor: Tensor3<T, CpuAllocator> =
218 Tensor::from_shape_slice([size.height, size.width, C], data, CpuAllocator)?;
219 Image::try_from(tensor)
220 }
221
222 pub fn map<U>(&self, f: impl Fn(&T) -> U) -> Result<Image<U, C>, ImageError>
232 where
233 U: Clone,
234 {
235 let data = self.as_slice().iter().map(f).collect::<Vec<U>>();
236 Image::<U, C>::new(self.size(), data)
237 }
238
239 pub fn cast<U>(&self) -> Result<Image<U, C>, ImageError>
245 where
246 U: num_traits::NumCast + Clone + Copy, T: num_traits::NumCast + Clone + Copy, {
249 let casted_data = self
251 .as_slice()
252 .iter()
253 .map(|&x| {
254 let xu = U::from(x).ok_or(ImageError::CastError)?;
255 Ok(xu)
256 })
257 .collect::<Result<Vec<U>, ImageError>>()?;
258
259 Image::new(self.size(), casted_data)
260 }
261
262 pub fn channel(&self, channel: usize) -> Result<Image<T, 1>, ImageError>
275 where
276 T: Clone,
277 {
278 if channel >= C {
279 return Err(ImageError::ChannelIndexOutOfBounds(channel, C));
280 }
281
282 let channel_data = self
283 .as_slice()
284 .iter()
285 .skip(channel)
286 .step_by(C)
287 .cloned()
288 .collect();
289
290 Image::new(self.size(), channel_data)
291 }
292
293 pub fn split_channels(&self) -> Result<Vec<Image<T, 1>>, ImageError>
315 where
316 T: Clone + Copy, {
318 let mut channels = Vec::with_capacity(C);
319
320 for i in 0..C {
321 channels.push(self.channel(i)?);
322 }
323
324 Ok(channels)
325 }
326
327 pub fn size(&self) -> ImageSize {
329 ImageSize {
330 width: self.shape[1],
331 height: self.shape[0],
332 }
333 }
334
335 pub fn cols(&self) -> usize {
337 self.shape[1]
338 }
339
340 pub fn rows(&self) -> usize {
342 self.shape[0]
343 }
344
345 pub fn width(&self) -> usize {
347 self.cols()
348 }
349
350 pub fn height(&self) -> usize {
352 self.rows()
353 }
354
    /// Returns the channel count, which is the const generic parameter `C`.
    pub fn num_channels(&self) -> usize {
        C
    }
359
360 pub fn cast_and_scale<U>(self, scale: U) -> Result<Image<U, C>, ImageError>
394 where
395 U: num_traits::NumCast + std::ops::Mul<Output = U> + Clone + Copy,
396 T: num_traits::NumCast + Clone + Copy,
397 {
398 let casted_data = self
399 .as_slice()
400 .iter()
401 .map(|&x| {
402 let xu = U::from(x).ok_or(ImageError::CastError)?;
403 Ok(xu * scale)
404 })
405 .collect::<Result<Vec<U>, ImageError>>()?;
406
407 Image::new(self.size(), casted_data)
408 }
409
410 pub fn scale_and_cast<U>(&self, scale: T) -> Result<Image<U, C>, ImageError>
420 where
421 U: num_traits::NumCast + Clone + Copy,
422 T: num_traits::NumCast + std::ops::Mul<Output = T> + Clone + Copy,
423 {
424 let casted_data = self
425 .as_slice()
426 .iter()
427 .map(|&x| {
428 let xu = U::from(x * scale).ok_or(ImageError::CastError)?;
429 Ok(xu)
430 })
431 .collect::<Result<Vec<U>, ImageError>>()?;
432
433 Image::new(self.size(), casted_data)
434 }
435
436 pub fn get_pixel(&self, x: usize, y: usize, ch: usize) -> Result<&T, ImageError> {
451 if x >= self.width() || y >= self.height() {
452 return Err(ImageError::PixelIndexOutOfBounds(
453 x,
454 y,
455 self.width(),
456 self.height(),
457 ));
458 }
459
460 if ch >= C {
461 return Err(ImageError::ChannelIndexOutOfBounds(ch, C));
462 }
463
464 let val = match self.get([y, x, ch]) {
465 Some(v) => v,
466 None => return Err(ImageError::ImageDataNotContiguous),
467 };
468
469 Ok(val)
470 }
471
472 pub fn set_pixel(&mut self, x: usize, y: usize, ch: usize, val: T) -> Result<(), ImageError> {
488 if x >= self.width() || y >= self.height() {
489 return Err(ImageError::PixelIndexOutOfBounds(
490 x,
491 y,
492 self.width(),
493 self.height(),
494 ));
495 }
496
497 if ch >= C {
498 return Err(ImageError::ChannelIndexOutOfBounds(ch, C));
499 }
500
501 let idx = y * self.width() * C + x * C + ch;
502 self.as_slice_mut()[idx] = val;
503
504 Ok(())
505 }
506}
507
508impl<T> TryFrom<Tensor2<T, CpuAllocator>> for Image<T, 1>
510where
511 T: Clone,
512{
513 type Error = ImageError;
514
515 fn try_from(value: Tensor2<T, CpuAllocator>) -> Result<Self, Self::Error> {
516 Self::from_size_slice(
517 ImageSize {
518 width: value.shape[1],
519 height: value.shape[0],
520 },
521 value.as_slice(),
522 )
523 }
524}
525
526impl<T, const C: usize> TryFrom<Tensor3<T, CpuAllocator>> for Image<T, C> {
528 type Error = ImageError;
529
530 fn try_from(value: Tensor3<T, CpuAllocator>) -> Result<Self, Self::Error> {
531 if value.shape[2] != C {
532 return Err(ImageError::InvalidChannelShape(value.shape[2], C));
533 }
534 Ok(Self(value))
535 }
536}
537
538impl<T, const C: usize> TryInto<Tensor3<T, CpuAllocator>> for Image<T, C> {
539 type Error = ImageError;
540
541 fn try_into(self) -> Result<Tensor3<T, CpuAllocator>, Self::Error> {
542 Ok(self.0)
543 }
544}
545
546#[cfg(test)]
547mod tests {
548 use crate::image::{Image, ImageError, ImageSize};
549 use kornia_tensor::{CpuAllocator, Tensor};
550
551 #[test]
552 fn test_image_size() {
553 let image_size = ImageSize {
554 width: 10,
555 height: 20,
556 };
557 assert_eq!(image_size.width, 10);
558 assert_eq!(image_size.height, 20);
559 }
560
561 #[test]
562 fn test_image_smoke() -> Result<(), ImageError> {
563 let image = Image::<u8, 3>::new(
564 ImageSize {
565 width: 10,
566 height: 20,
567 },
568 vec![0u8; 10 * 20 * 3],
569 )?;
570 assert_eq!(image.size().width, 10);
571 assert_eq!(image.size().height, 20);
572 assert_eq!(image.num_channels(), 3);
573
574 Ok(())
575 }
576
577 #[test]
578 fn test_image_from_vec() -> Result<(), ImageError> {
579 let image: Image<f32, 3> = Image::new(
580 ImageSize {
581 height: 3,
582 width: 2,
583 },
584 vec![0.0; 3 * 2 * 3],
585 )?;
586 assert_eq!(image.size().width, 2);
587 assert_eq!(image.size().height, 3);
588 assert_eq!(image.num_channels(), 3);
589
590 Ok(())
591 }
592
593 #[test]
594 fn test_image_cast() -> Result<(), ImageError> {
595 let data = vec![0, 1, 2, 3, 4, 5];
596 let image_u8 = Image::<_, 3>::new(
597 ImageSize {
598 height: 2,
599 width: 1,
600 },
601 data,
602 )?;
603 assert_eq!(image_u8.get([1, 0, 2]), Some(&5u8));
604
605 let image_i32: Image<i32, 3> = image_u8.cast()?;
606 assert_eq!(image_i32.get([1, 0, 2]), Some(&5i32));
607
608 Ok(())
609 }
610
611 #[test]
612 fn test_image_rgbd() -> Result<(), ImageError> {
613 let image = Image::<f32, 4>::new(
614 ImageSize {
615 height: 2,
616 width: 3,
617 },
618 vec![0f32; 2 * 3 * 4],
619 )?;
620 assert_eq!(image.size().width, 3);
621 assert_eq!(image.size().height, 2);
622 assert_eq!(image.num_channels(), 4);
623
624 Ok(())
625 }
626
627 #[test]
628 fn test_image_channel() -> Result<(), ImageError> {
629 let image = Image::<f32, 3>::new(
630 ImageSize {
631 height: 2,
632 width: 1,
633 },
634 vec![0., 1., 2., 3., 4., 5.],
635 )?;
636
637 let channel = image.channel(2)?;
638 assert_eq!(channel.get([1, 0, 0]), Some(&5.0f32));
639
640 Ok(())
641 }
642
643 #[test]
644 fn test_image_split_channels() -> Result<(), ImageError> {
645 let image = Image::<f32, 3>::new(
646 ImageSize {
647 height: 2,
648 width: 1,
649 },
650 vec![0., 1., 2., 3., 4., 5.],
651 )
652 .unwrap();
653 let channels = image.split_channels()?;
654 assert_eq!(channels.len(), 3);
655 assert_eq!(channels[0].get([1, 0, 0]), Some(&3.0f32));
656 assert_eq!(channels[1].get([1, 0, 0]), Some(&4.0f32));
657 assert_eq!(channels[2].get([1, 0, 0]), Some(&5.0f32));
658
659 Ok(())
660 }
661
662 #[test]
663 fn test_scale_and_cast() -> Result<(), ImageError> {
664 let data = vec![0u8, 0, 255, 0, 0, 255];
665 let image_u8 = Image::<u8, 3>::new(
666 ImageSize {
667 height: 2,
668 width: 1,
669 },
670 data,
671 )?;
672 let image_f32 = image_u8.cast_and_scale::<f32>(1. / 255.0)?;
673 assert_eq!(image_f32.get([1, 0, 2]), Some(&1.0f32));
674
675 Ok(())
676 }
677
678 #[test]
679 fn test_cast_and_scale() -> Result<(), ImageError> {
680 let data = vec![0u8, 0, 255, 0, 0, 255];
681 let image_u8 = Image::<u8, 3>::new(
682 ImageSize {
683 height: 2,
684 width: 1,
685 },
686 data,
687 )?;
688 let image_f32 = image_u8.cast_and_scale::<f32>(1. / 255.0)?;
689 assert_eq!(image_f32.get([1, 0, 2]), Some(&1.0f32));
690
691 Ok(())
692 }
693
694 #[test]
695 fn test_image_from_tensor() -> Result<(), ImageError> {
696 let data = vec![0u8, 1, 2, 3, 4, 5];
697 let tensor = Tensor::<u8, 2, _>::from_shape_vec([2, 3], data, CpuAllocator)?;
698
699 let image = Image::<u8, 1>::try_from(tensor.clone())?;
700 assert_eq!(image.size().width, 3);
701 assert_eq!(image.size().height, 2);
702 assert_eq!(image.num_channels(), 1);
703
704 let image_2: Image<u8, 1> = tensor.try_into()?;
705 assert_eq!(image_2.size().width, 3);
706 assert_eq!(image_2.size().height, 2);
707 assert_eq!(image_2.num_channels(), 1);
708
709 Ok(())
710 }
711
712 #[test]
713 fn test_image_from_tensor_3d() -> Result<(), ImageError> {
714 let tensor = Tensor::<u8, 3, CpuAllocator>::from_shape_vec(
715 [2, 3, 4],
716 vec![0u8; 2 * 3 * 4],
717 CpuAllocator,
718 )?;
719
720 let image = Image::<u8, 4>::try_from(tensor.clone())?;
721 assert_eq!(image.size().width, 3);
722 assert_eq!(image.size().height, 2);
723 assert_eq!(image.num_channels(), 4);
724
725 let image_2: Image<u8, 4> = tensor.try_into()?;
726 assert_eq!(image_2.size().width, 3);
727 assert_eq!(image_2.size().height, 2);
728 assert_eq!(image_2.num_channels(), 4);
729
730 Ok(())
731 }
732
733 #[test]
734 fn test_image_from_raw_parts() -> Result<(), ImageError> {
735 let data = vec![0u8, 1, 2, 3, 4, 5];
736 let image =
737 unsafe { Image::<_, 1>::from_raw_parts([2, 3].into(), data.as_ptr(), data.len())? };
738 std::mem::forget(data);
739 assert_eq!(image.size().width, 2);
740 assert_eq!(image.size().height, 3);
741 assert_eq!(image.num_channels(), 1);
742 Ok(())
743 }
744
745 #[test]
746 fn test_get_pixel() -> Result<(), ImageError> {
747 let image = Image::<u8, 3>::new(
748 ImageSize {
749 height: 2,
750 width: 1,
751 },
752 vec![1, 2, 5, 19, 255, 128],
753 )?;
754 assert_eq!(image.get_pixel(0, 0, 0)?, &1);
755 assert_eq!(image.get_pixel(0, 0, 1)?, &2);
756 assert_eq!(image.get_pixel(0, 0, 2)?, &5);
757 assert_eq!(image.get_pixel(0, 1, 0)?, &19);
758 assert_eq!(image.get_pixel(0, 1, 1)?, &255);
759 assert_eq!(image.get_pixel(0, 1, 2)?, &128);
760 Ok(())
761 }
762
763 #[test]
764 fn test_set_pixel() -> Result<(), ImageError> {
765 let mut image = Image::<u8, 3>::new(
766 ImageSize {
767 height: 2,
768 width: 1,
769 },
770 vec![1, 2, 5, 19, 255, 128],
771 )?;
772
773 image.set_pixel(0, 0, 0, 128)?;
774 image.set_pixel(0, 1, 1, 25)?;
775
776 assert_eq!(image.get_pixel(0, 0, 0)?, &128);
777 assert_eq!(image.get_pixel(0, 1, 1)?, &25);
778
779 Ok(())
780 }
781
782 #[test]
783 fn test_image_map() -> Result<(), ImageError> {
784 let image_u8 = Image::<u8, 1>::new(
785 ImageSize {
786 height: 2,
787 width: 1,
788 },
789 vec![0, 128],
790 )?;
791
792 let image_f32 = image_u8.map(|x| (x + 2) as f32)?;
793
794 assert_eq!(image_f32.size().width, 1);
795 assert_eq!(image_f32.size().height, 2);
796 assert_eq!(image_f32.num_channels(), 1);
797 assert_eq!(image_f32.get([0, 0, 0]), Some(&2.0f32));
798 assert_eq!(image_f32.get([1, 0, 0]), Some(&130.0f32));
799
800 Ok(())
801 }
802}