rai_core/primitives/
vision.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use crate::{Primitive, Shape, Tensor};
use std::any::Any;
use tracing::Level;

/// Nearest-neighbour 1-D upsampling primitive.
///
/// Holds only the target length of the upsampled dimension; the `vjp` below
/// derives the repetition factor as `size / input_len`.
#[derive(Clone, Debug, PartialEq)]
pub struct UpsampleNearest1d {
    /// Target (output) length of the upsampled dimension.
    pub size: usize,
}

impl UpsampleNearest1d {
    /// Builds the primitive for an output length of `size`.
    pub fn new(size: usize) -> Self {
        UpsampleNearest1d { size }
    }
}

impl Primitive for UpsampleNearest1d {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn dot_label(&self) -> String {
        format!("UpsampleNearest1d({})", self.size)
    }

    /// Forward-mode derivative; not implemented yet.
    ///
    /// # Panics
    /// Always panics via `todo!`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        todo!("jvp for UpsampleNearest1d")
    }

    /// Reverse-mode derivative with respect to the single input `primals[0]`.
    ///
    /// With an integer factor `scale = self.size / input_len`, each input element
    /// feeds `scale` consecutive output positions. Its gradient is therefore the
    /// sum of the cotangent over that window. The sum is computed as a grouped
    /// convolution (one group per channel) with an all-ones kernel of width
    /// `scale` and stride `scale`.
    ///
    /// # Panics
    /// Panics if the input length is zero, or if it does not evenly divide
    /// `self.size`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        // Presumably (batch, channels, length) — TODO confirm `shape_before` semantics.
        let [_n, c, size] = x.shape_before::<3>();
        // Guard the remainder below against a zero divisor with a readable message.
        assert!(
            size > 0,
            "UpsampleNearest1d vjp requires a non-empty input length"
        );
        // The window-sum gradient is only valid when every input element is
        // repeated exactly `scale` times, i.e. the factor is an integer.
        assert!(
            self.size % size == 0,
            "UpsampleNearest1d vjp not supported for non integer upscaling factors"
        );
        let scale = self.size / size;
        let kernel = &Tensor::ones([c, 1, scale], x.dtype(), x.device());
        // NOTE(review): assumes conv1d(kernel, padding, stride, dilation, groups) — confirm.
        let cotan_x = cotangent.conv1d(kernel, 0, scale, 1, c);
        vec![cotan_x]
    }
}

/// Nearest-neighbour 2-D upsampling primitive.
///
/// Holds only the target `(height, width)` of the upsampled output; the `vjp`
/// below derives the per-axis repetition factors from it.
#[derive(Clone, Debug, PartialEq)]
pub struct UpsampleNearest2d {
    /// Target (output) size as `(height, width)`.
    pub size: (usize, usize),
}

impl UpsampleNearest2d {
    /// Builds the primitive for an output size of `size`.
    pub fn new(size: (usize, usize)) -> Self {
        UpsampleNearest2d { size }
    }
}

impl Primitive for UpsampleNearest2d {
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        Box::new(self.clone())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn dot_label(&self) -> String {
        format!("UpsampleNearest2d({:?})", self.size)
    }

    /// Forward-mode derivative; not implemented yet.
    ///
    /// # Panics
    /// Always panics via `todo!`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], _tangents: &[Tensor]) -> Tensor {
        todo!("jvp for UpsampleNearest2d")
    }

    /// Reverse-mode derivative with respect to the single input `primals[0]`.
    ///
    /// With integer factors `scale_h = self.size.0 / h` and
    /// `scale_w = self.size.1 / w`, each input element feeds a
    /// `scale_h x scale_w` block of the output. Its gradient is the sum of the
    /// cotangent over that block. The sum is computed as a grouped convolution
    /// (one group per channel) with an all-ones kernel whose size equals the
    /// stride.
    ///
    /// # Panics
    /// Panics if either input spatial dimension is zero, or if either does not
    /// evenly divide the matching entry of `self.size`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        let x = &primals[0];
        // Presumably (batch, channels, height, width) — TODO confirm `shape_before` semantics.
        let [_n, c, h, w] = x.shape_before::<4>();
        // Guard the remainders below against a zero divisor with a readable message.
        assert!(
            h > 0 && w > 0,
            "UpsampleNearest2d vjp requires non-empty spatial dimensions"
        );
        // The block-sum gradient is only valid when both factors are integers;
        // BOTH axes must divide evenly.
        assert!(
            self.size.0 % h == 0 && self.size.1 % w == 0,
            "UpsampleNearest2d vjp not supported for non integer upscaling factors"
        );
        let scale_h = self.size.0 / h;
        let scale_w = self.size.1 / w;
        let kernel = Tensor::ones([c, 1, scale_h, scale_w], x.dtype(), x.device());
        // NOTE(review): assumes conv2d(kernel, padding, stride, dilation, groups) — confirm.
        let cotan_x = cotangent.conv2d(kernel, [0, 0], [scale_h, scale_w], [1, 1], c);
        vec![cotan_x]
    }
}