ciphercore_base/ops/pwl/
approx_gelu.rs

1//! GELU(x) piecewise-linear approximation.
2use crate::custom_ops::CustomOperationBody;
3use crate::data_types::{Type, INT64};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph};
6
7use serde::{Deserialize, Serialize};
8
9use super::approx_pointwise::{create_approximation, PWLConfig};
10
11/// A structure that defines the custom operation ApproxGelu that computes an approximate Gelu(x / (2 ** precision)) * (2 ** precision) using piecewise-linear approximation.
12///
13/// So far this operation supports only INT64 scalar type.
14/// For background on gelu function, see <https://arxiv.org/pdf/1606.08415v4.pdf>.
15///
16/// # Custom operation arguments
17///
18/// - Node containing a signed 64-bit array or scalar to compute the GELU
19///
20/// # Custom operation returns
21///
22/// New ApproxGelu node
23///
24/// # Example
25///
26/// ```
27/// # use ciphercore_base::graphs::create_context;
28/// # use ciphercore_base::data_types::{scalar_type, array_type, INT64};
29/// # use ciphercore_base::custom_ops::{CustomOperation};
30/// # use ciphercore_base::ops::pwl::approx_gelu::ApproxGelu;
31/// let c = create_context().unwrap();
32/// let g = c.create_graph().unwrap();
33/// let t = array_type(vec![3], INT64);
34/// let x = g.input(t.clone()).unwrap();
35/// let n = g.custom_op(CustomOperation::new(ApproxGelu {precision: 4}), vec![x]).unwrap();
36///
37// TODO: generalize to other types.
38#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
39pub struct ApproxGelu {
40    /// Assume that we're operating in fixed precision arithmetic with denominator 2 ** precision.
41    pub precision: u64,
42}
43
44#[typetag::serde]
45impl CustomOperationBody for ApproxGelu {
46    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
47        if arguments_types.len() != 1 {
48            return Err(runtime_error!("Invalid number of arguments for ApproxGelu"));
49        }
50        let t = arguments_types[0].clone();
51        if !t.is_scalar() && !t.is_array() {
52            return Err(runtime_error!(
53                "Argument in ApproxGelu must be a scalar or an array"
54            ));
55        }
56        let sc = t.get_scalar_type();
57        if sc != INT64 {
58            return Err(runtime_error!(
59                "Argument in ApproxGelu must consist of INT64's"
60            ));
61        }
62        if self.precision > 30 || self.precision == 0 {
63            return Err(runtime_error!("`precision` should be in range [1, 30]."));
64        }
65
66        let g = context.create_graph()?;
67        let arg = g.input(t)?;
68        // Choice of parameters:
69        // -- left/right: our typical use-case is precision=15, leading to minimum value around 3e-5. GELU(-5) is already lower than that, so -4 is a reasonable choice with our precision. On the other side, at 4, Gelu is pretty much linear;
70        // -- log_buckets: we look at max absolute difference to the real sigmoid. It looks as follows:
71        //    log_buckets=4 => 0.0232,
72        //    log_buckets=5 => 0.0059,
73        //    log_buckets=6 => 0.0015.
74        // After 5 segments, we're getting diminishing returns, so it doesn't make sense to go higher (for the sake of performance).
75        // -- flatten_left/flatten_right: GELU is linear on the right and flat on the left.
76        let result = create_approximation(
77            arg,
78            approximate_gelu,
79            -4.0,
80            4.0,
81            self.precision,
82            PWLConfig {
83                log_buckets: 5,
84                flatten_left: true,
85                flatten_right: false,
86            },
87        )?;
88        result.set_as_output()?;
89        g.finalize()?;
90        Ok(g)
91    }
92
93    fn get_name(&self) -> String {
94        format!("ApproxGelu(scaling_factor=2**{})", self.precision)
95    }
96}
97
/// Floating-point reference GELU, using the tanh-based approximation
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
///
/// Returns a finite value for every finite `x`: about `x` for large positive inputs and
/// about `0` for large negative inputs.
fn approximate_gelu(x: f32) -> f32 {
    // It appears there is no Erf in Rust without additional crates. So we use an approximation.
    // See also <https://paperswithcode.com/method/gelu>.
    // The accurate GELU formula is: 0.5 * x * (1 + erf(x / sqrt(2))).
    let tanh_arg = (2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
    // Use the standard library's `tanh`, which saturates to +/-1 for large arguments.
    // Computing it by hand as (e^t - e^-t) / (e^t + e^-t) overflows `exp` once |t| exceeds
    // roughly 88 (|x| around 13 and above), which turns the quotient into inf / inf = NaN.
    0.5 * x * (1.0 + tanh_arg.tanh())
}
108
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{array_type, scalar_type};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Expected fixed-point output for the raw input `raw`, using 10 fractional bits
    /// (denominator 1024) and the floating-point reference `approximate_gelu`.
    fn reference_gelu(raw: i64) -> i64 {
        (approximate_gelu((raw as f32) / 1024.0) * 1024.0) as i64
    }

    /// Absolute difference between two fixed-point values, converted back to real units
    /// for 10 fractional bits.
    fn fixed_point_error(expected: i64, actual: i64) -> f64 {
        ((expected - actual).abs() as f64) / 1024.0
    }

    /// Runs `ApproxGelu` with the given `precision` on a single INT64 scalar.
    fn scalar_helper(arg: i64, precision: u64) -> Result<i64> {
        let context = simple_context(|g| {
            let input = g.input(scalar_type(INT64))?;
            g.custom_op(CustomOperation::new(ApproxGelu { precision }), vec![input])
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let main_graph = instantiated.get_context().get_main_graph()?;
        let inputs = vec![Value::from_scalar(arg, INT64)?];
        random_evaluate(main_graph, inputs)?.to_i64(INT64)
    }

    /// Runs `ApproxGelu` with precision 10 on a one-dimensional INT64 array.
    fn array_helper(arg: Vec<i64>) -> Result<Vec<i64>> {
        let array_t = array_type(vec![arg.len() as u64], INT64);
        let context = simple_context(|g| {
            let input = g.input(array_t.clone())?;
            g.custom_op(CustomOperation::new(ApproxGelu { precision: 10 }), vec![input])
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let main_graph = instantiated.get_context().get_main_graph()?;
        let inputs = vec![Value::from_flattened_array(&arg, INT64)?];
        random_evaluate(main_graph, inputs)?.to_flattened_array_i64(array_t)
    }

    #[test]
    fn test_approx_gelu_scalar() {
        for raw in (-5000..5000).step_by(1000) {
            let expected = reference_gelu(raw);
            let actual = scalar_helper(raw, 10).unwrap();
            let absolute_error = fixed_point_error(expected, actual);
            assert!(absolute_error <= 0.01);
        }
    }

    #[test]
    fn test_approx_gelu_array() {
        let arr: Vec<i64> = (-5000..5000).step_by(100).collect();
        let res = array_helper(arr.clone()).unwrap();
        for (idx, &raw) in arr.iter().enumerate() {
            let expected = reference_gelu(raw);
            // Index (rather than zip) so a too-short result still fails loudly.
            let actual = res[idx];
            let absolute_error = fixed_point_error(expected, actual);
            assert!(absolute_error <= 0.01);
        }
    }
}