// ciphercore_base/ops/pwl/approx_gelu.rs

use crate::custom_ops::CustomOperationBody;
use crate::data_types::{Type, INT64};
use crate::errors::Result;
use crate::graphs::{Context, Graph};

use serde::{Deserialize, Serialize};

use super::approx_pointwise::{create_approximation, PWLConfig};
10
11#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
39pub struct ApproxGelu {
40 pub precision: u64,
42}
43
44#[typetag::serde]
45impl CustomOperationBody for ApproxGelu {
46 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
47 if arguments_types.len() != 1 {
48 return Err(runtime_error!("Invalid number of arguments for ApproxGelu"));
49 }
50 let t = arguments_types[0].clone();
51 if !t.is_scalar() && !t.is_array() {
52 return Err(runtime_error!(
53 "Argument in ApproxGelu must be a scalar or an array"
54 ));
55 }
56 let sc = t.get_scalar_type();
57 if sc != INT64 {
58 return Err(runtime_error!(
59 "Argument in ApproxGelu must consist of INT64's"
60 ));
61 }
62 if self.precision > 30 || self.precision == 0 {
63 return Err(runtime_error!("`precision` should be in range [1, 30]."));
64 }
65
66 let g = context.create_graph()?;
67 let arg = g.input(t)?;
68 let result = create_approximation(
77 arg,
78 approximate_gelu,
79 -4.0,
80 4.0,
81 self.precision,
82 PWLConfig {
83 log_buckets: 5,
84 flatten_left: true,
85 flatten_right: false,
86 },
87 )?;
88 result.set_as_output()?;
89 g.finalize()?;
90 Ok(g)
91 }
92
93 fn get_name(&self) -> String {
94 format!("ApproxGelu(scaling_factor=2**{})", self.precision)
95 }
96}
97
/// Floating-point GELU in its tanh form:
/// `GELU(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))`.
///
/// This is the function handed to `create_approximation` as the target of the
/// piecewise-linear fit.
///
/// `f32::tanh` is used instead of expanding tanh as
/// `(e^a − e^−a) / (e^a + e^−a)`. With that expansion, `exp` overflows to
/// infinity once `|a|` exceeds ~88.7 (roughly `|x| > 13`), and the quotient
/// becomes `inf / inf = NaN`. `f32::tanh` saturates cleanly to ±1 instead.
fn approximate_gelu(x: f32) -> f32 {
    let tanh_arg = (2.0 / std::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
    0.5 * x * (1.0 + tanh_arg.tanh())
}
108
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_types::array_type;
    use crate::data_types::scalar_type;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Maximum accepted absolute error, measured in real (unscaled) units.
    const TOLERANCE: f64 = 0.01;

    /// Evaluates `ApproxGelu` with the given `precision` on one INT64 scalar.
    fn scalar_helper(arg: i64, precision: u64) -> Result<i64> {
        let c = simple_context(|g| {
            let i = g.input(scalar_type(INT64))?;
            g.custom_op(CustomOperation::new(ApproxGelu { precision }), vec![i])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_scalar(arg, INT64)?],
        )?;
        result.to_i64(INT64)
    }

    /// Evaluates `ApproxGelu` with the given `precision` on a 1-D INT64 array.
    fn array_helper(arg: &[i64], precision: u64) -> Result<Vec<i64>> {
        let array_t = array_type(vec![arg.len() as u64], INT64);
        let c = simple_context(|g| {
            let i = g.input(array_t.clone())?;
            g.custom_op(CustomOperation::new(ApproxGelu { precision }), vec![i])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(arg, INT64)?],
        )?;
        result.to_flattened_array_i64(array_t)
    }

    /// Computes the reference GELU of the fixed-point `x` and converts it back
    /// to fixed point (truncating) with denominator `2^precision`.
    fn expected_gelu(x: i64, precision: u64) -> i64 {
        let scale = (1u64 << precision) as f32;
        (approximate_gelu(x as f32 / scale) * scale) as i64
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`, where both
    /// are fixed-point values with denominator `2^precision`.
    fn assert_close(x: i64, expected: i64, actual: i64, precision: u64) {
        let absolute_error = (expected - actual).abs() as f64 / (1u64 << precision) as f64;
        assert!(
            absolute_error <= TOLERANCE,
            "ApproxGelu({}) = {}, expected {} (absolute error {})",
            x,
            actual,
            expected,
            absolute_error
        );
    }

    #[test]
    fn test_approx_gelu_scalar() {
        let precision = 10;
        for x in (-5000..5000).step_by(1000) {
            let actual = scalar_helper(x, precision).unwrap();
            assert_close(x, expected_gelu(x, precision), actual, precision);
        }
    }

    #[test]
    fn test_approx_gelu_array() {
        let precision = 10;
        let xs: Vec<i64> = (-5000..5000).step_by(100).collect();
        let actual = array_helper(&xs, precision).unwrap();
        assert_eq!(actual.len(), xs.len());
        for (&x, &y) in xs.iter().zip(&actual) {
            assert_close(x, expected_gelu(x, precision), y, precision);
        }
    }

    #[test]
    fn test_approx_gelu_rejects_invalid_precision() {
        assert!(scalar_helper(0, 0).is_err());
        assert!(scalar_helper(0, 31).is_err());
    }
}