1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
use rayon::prelude::*;

use crate::color::Color;
use crate::image::Image;
use crate::ty::Type;

/// Combines two filters: a borrowed left-hand filter `a` and an owned right-hand
/// filter `b`, whose per-sample results are merged by the binary function `f`.
pub struct Join<'a, A, B, F>
where
    A: 'a + Filter,
    B: Filter,
    F: Fn(f64, f64) -> f64,
{
    /// Left-hand filter (borrowed).
    a: &'a A,
    /// Right-hand filter (owned).
    b: B,
    /// Combiner applied as `f(a_result, b_result)`.
    f: F,
}

/// Post-processes a borrowed filter `a` by passing each of its results through `f`.
pub struct AndThen<'a, A, F>
where
    A: 'a + Filter,
    F: Fn(f64) -> f64,
{
    /// Wrapped filter (borrowed).
    a: &'a A,
    /// Unary function applied to every value produced by `a`.
    f: F,
}

impl<'a, A, F> Filter for AndThen<'a, A, F>
where
    A: Filter,
    F: Sync + Fn(f64) -> f64,
{
    /// Evaluates the wrapped filter at `(x, y, c)` and maps that value through `f`.
    fn compute_at<T: Type, C: Color, I: Image<T, C>>(
        &self,
        x: usize,
        y: usize,
        c: usize,
        input: &[&I],
    ) -> f64 {
        let inner = self.a.compute_at(x, y, c, input);
        (self.f)(inner)
    }
}

impl<'a, A, B, F> Filter for Join<'a, A, B, F>
where
    A: Filter,
    B: Filter,
    F: Sync + Fn(f64, f64) -> f64,
{
    /// Evaluates both filters at the same `(x, y, c)` — left first, then right —
    /// and merges the two results with `f`.
    fn compute_at<T: Type, C: Color, I: Image<T, C>>(
        &self,
        x: usize,
        y: usize,
        c: usize,
        input: &[&I],
    ) -> f64 {
        let lhs = self.a.compute_at(x, y, c, input);
        let rhs = self.b.compute_at(x, y, c, input);
        (self.f)(lhs, rhs)
    }
}

/// A per-sample image operation.
///
/// Implementors only provide [`Filter::compute_at`]. The provided methods drive it over
/// every `(x, y, c)` position of an output image, sequentially or in parallel, and
/// compose filters with [`Filter::join`] / [`Filter::and_then`].
pub trait Filter: Sized + Sync {
    /// Computes channel `c` of the output pixel at `(x, y)` from the `input` images.
    fn compute_at<T: Type, C: Color, I: Image<T, C>>(
        &self,
        x: usize,
        y: usize,
        c: usize,
        input: &[&I],
    ) -> f64;

    /// Evaluates the filter sequentially, writing every channel of `output`.
    ///
    /// Each computed value is clamped with the output type's `T::clamp` and stored
    /// through `set_f`.
    fn eval_s<T: Type, C: Color, U: Type, D: Color, I: Image<T, C>, J: Image<U, D>>(
        &self,
        output: &mut I,
        input: &[&J],
    ) {
        let (width, height, channels) = output.shape();
        for y in 0..height {
            for x in 0..width {
                for c in 0..channels {
                    output.set_f(x, y, c, T::clamp(self.compute_at(x, y, c, input)));
                }
            }
        }
    }

    /// Evaluates the filter in parallel with rayon, one pixel per work item.
    ///
    /// Pixel coordinates are derived from the chunk index, which assumes
    /// `output.data_mut()` is row-major with `channels` interleaved values per pixel
    /// (NOTE(review): confirm against `Image`). Each value is clamped, denormalized
    /// and converted with the output type `T`.
    ///
    /// # Panics
    ///
    /// Panics if `output` reports zero channels (rayon rejects a zero chunk size).
    fn eval<
        T: Send + Type,
        C: Color,
        U: Type,
        D: Color,
        I: Sync + Send + Image<T, C>,
        J: Sync + Image<U, D>,
    >(
        &self,
        output: &mut I,
        input: &[&J],
    ) {
        let (width, _height, channels) = output.shape();

        // `par_chunks_mut` lends each pixel as a `&mut [T]` slice, so the hot loop does
        // not allocate (unlike `par_iter_mut().chunks(..)`, which builds a Vec per pixel).
        output
            .data_mut()
            .par_chunks_mut(channels)
            .enumerate()
            .for_each(|(n, pixel)| {
                let (x, y) = (n % width, n / width);
                for (c, value) in pixel.iter_mut().enumerate() {
                    let computed = self.compute_at(x, y, c, input);
                    *value = T::from_float(T::denormalize(T::clamp(computed)));
                }
            });
    }

    /// Pairs `self` with `other`, combining their per-sample results with `f`.
    fn join<A: Filter, F: Fn(f64, f64) -> f64>(&self, other: A, f: F) -> Join<Self, A, F> {
        Join {
            a: self,
            b: other,
            f,
        }
    }

    /// Wraps `self` so that every value it produces is passed through `f`.
    fn and_then<F: Fn(f64) -> f64>(&self, f: F) -> AndThen<Self, F> {
        AndThen { a: self, f }
    }
}

/// Declares a unit struct `$name` and implements `Filter` for it.
///
/// `$x`, `$y`, `$c` and `$input` name the parameters of the generated `compute_at`,
/// and `$f` is its body. The body may use those names as well as the method's generic
/// parameters `T: Type`, `C: Color` and `I: Image<T, C>`.
///
/// NOTE(review): `Type`, `Color` and `Image` are written unqualified, so they must be
/// in scope wherever the macro is invoked; only `Filter` is reached through `$crate`.
#[macro_export]
macro_rules! image2_filter {
    ($name:ident, $x:ident, $y:ident, $c:ident, $input:ident, $f:expr) => {
        pub struct $name;

        impl $crate::Filter for $name {
            fn compute_at<T, C, I>(
                &self,
                $x: usize,
                $y: usize,
                $c: usize,
                $input: &[&I],
            ) -> f64
            where
                T: Type,
                C: Color,
                I: Image<T, C>,
            {
                $f
            }
        }
    };
}

// Reflects each sample of the first input about the type's maximum value.
image2_filter!(Invert, px, py, ch, images, {
    let sample = images[0].get_f(px, py, ch);
    T::max_f() - sample
});

// Arithmetic mean of the same sample taken from the first two inputs.
image2_filter!(Blend, px, py, ch, images, {
    let first = images[0].get_f(px, py, ch);
    let second = images[1].get_f(px, py, ch);
    (first + second) / 2.0
});

// Weighted sum of channels 0, 1 and 2 (0.21 / 0.72 / 0.07); for 4-channel inputs the
// result is additionally scaled by channel 3. The output channel index is ignored.
image2_filter!(ToGrayscale, px, py, _ch, images, {
    let src = images[0];
    let (r, g, b) = (src.get_f(px, py, 0), src.get_f(px, py, 1), src.get_f(px, py, 2));
    let luma = r * 0.21 + g * 0.72 + b * 0.07;
    if C::channels() == 4 {
        luma * src.get_f(px, py, 3)
    } else {
        luma
    }
});

// Expands the first input to the output's channel count by cycling through its
// channels (a single gray channel fills every colour channel).
image2_filter!(ToColor, x, y, c, input, {
    // Channels are zero-based, so the alpha of a four-channel output is index 3 (the
    // same index `ToGrayscale` and `RgbaToRgb` read alpha from). When the source has
    // no channel at that index, emit the type's maximum (opaque) instead of wrapping
    // around onto a colour channel; a source that does carry it is copied as-is.
    if c >= 3 && c >= C::channels() {
        return T::max_f();
    }

    input[0].get_f(x, y, c % C::channels())
});

// Scales the requested channel of the first input by its channel 3 (alpha).
image2_filter!(RgbaToRgb, px, py, ch, images, {
    let src = images[0];
    let value = src.get_f(px, py, ch);
    let alpha = src.get_f(px, py, 3);
    value * alpha
});

// Swaps channels 0 and 2 of the first input; every other channel passes through.
image2_filter!(RgbToBgr, px, py, ch, images, {
    let source_channel = match ch {
        0 => 2,
        2 => 0,
        other => other,
    };
    images[0].get_f(px, py, source_channel)
});