1use crate::custom_ops::{CustomOperation, CustomOperationBody, Or};
3use crate::data_types::{array_type, scalar_type, Type, BIT, INT64, UINT64};
4use crate::data_values::Value;
5use crate::errors::Result;
6use crate::graphs::{Context, Graph, GraphAnnotation};
7use crate::ops::utils::{pull_out_bits, put_in_bits};
8
9use serde::{Deserialize, Serialize};
10
11use super::utils::{constant_scalar, multiply_fixed_point, single_bit_to_arithmetic};
12
13#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
54pub struct InverseSqrt {
55 pub iterations: u64,
57 pub denominator_cap_2k: u64,
59}
60
61#[typetag::serde]
62impl CustomOperationBody for InverseSqrt {
63 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
64 if arguments_types.len() != 1 && arguments_types.len() != 2 {
65 return Err(runtime_error!(
66 "Invalid number of arguments for InverseSqrt"
67 ));
68 }
69 let t = arguments_types[0].clone();
70 if !t.is_scalar() && !t.is_array() {
71 return Err(runtime_error!(
72 "Divisor in InverseSqrt must be a scalar or an array"
73 ));
74 }
75 let sc = t.get_scalar_type();
76 if sc != UINT64 && sc != INT64 {
77 return Err(runtime_error!(
78 "Divisor in InverseSqrt must consist of UINT64's or INT64's"
79 ));
80 }
81 let has_initial_approximation = arguments_types.len() == 2;
82 if has_initial_approximation {
83 let divisor_t = arguments_types[1].clone();
84 if divisor_t != t {
85 return Err(runtime_error!(
86 "Divisor and initial approximation must have the same type."
87 ));
88 }
89 }
90 if self.denominator_cap_2k > 31 {
91 return Err(runtime_error!("denominator_cap_2k is too large."));
92 }
93
94 if self.denominator_cap_2k <= 1 {
95 return Err(runtime_error!("denominator_cap_2k is too small."));
96 }
97
98 let bit_type = if t.is_scalar() {
99 scalar_type(BIT)
100 } else {
101 array_type(t.get_shape(), BIT)
102 };
103
104 let g_highest_one_bit = context.create_graph()?;
106 {
107 let input_state = g_highest_one_bit.input(bit_type.clone())?;
108 let input_bit = g_highest_one_bit.input(bit_type.clone())?;
109
110 let one = g_highest_one_bit.ones(scalar_type(BIT))?;
111 let not_input_state = one.add(input_state.clone())?;
112 let output = not_input_state.multiply(input_bit)?;
115 let new_state = input_state.add(output.clone())?;
118 let output_tuple = g_highest_one_bit.create_tuple(vec![new_state, output])?;
119 output_tuple.set_as_output()?;
120 }
121 g_highest_one_bit.add_annotation(GraphAnnotation::AssociativeOperation)?;
122 g_highest_one_bit.finalize()?;
123
124 let g = context.create_graph()?;
125 let divisor = g.input(t.clone())?;
126 let mut approximation = if has_initial_approximation {
127 g.input(t)?
128 } else if self.denominator_cap_2k == 0 {
129 let two = constant_scalar(&g, 2, sc)?;
130 g.zeros(t)?.add(two)?
131 } else {
132 let divisor_bits = pull_out_bits(divisor.a2b()?)?.array_to_vector()?;
133 let mut divisor_bits_reversed = vec![];
134 for i in 0..self.denominator_cap_2k {
135 let index1 = constant_scalar(&g, 2 * self.denominator_cap_2k - 2 * i - 1, UINT64)?;
140 let index2 = constant_scalar(&g, 2 * self.denominator_cap_2k - 2 * i - 2, UINT64)?;
141 let bit1 = divisor_bits.vector_get(index1)?;
142 let bit2 = divisor_bits.vector_get(index2)?;
143 let bit = g.custom_op(CustomOperation::new(Or {}), vec![bit1, bit2])?;
144 divisor_bits_reversed.push(bit);
145 }
146 let zero = g.zeros(bit_type.clone())?;
147 let highest_one_bit_binary = g
148 .iterate(
149 g_highest_one_bit,
150 zero,
151 g.create_vector(bit_type, divisor_bits_reversed)?,
152 )?
153 .tuple_get(1)?
154 .vector_to_array()?;
155 let highest_one_bit = single_bit_to_arithmetic(highest_one_bit_binary, sc)?;
156 let first_approximation_bits = put_in_bits(highest_one_bit)?;
157 let mut powers_of_two = vec![];
158 for i in 0..self.denominator_cap_2k {
159 powers_of_two.push(1u64 << i);
160 }
161 let powers_of_two_node = g.constant(
162 array_type(vec![self.denominator_cap_2k], sc),
163 Value::from_flattened_array(&powers_of_two, sc)?,
164 )?;
165 first_approximation_bits.dot(powers_of_two_node)?
166 };
167 let three_halves = constant_scalar(&g, 3 << (self.denominator_cap_2k - 1), sc)?;
171 for _ in 0..self.iterations {
172 let x = approximation;
173 let ax2 = divisor.clone().multiply(x.clone())?.multiply(x.clone())?;
176 let ax2_norm = g.truncate(ax2, 1 << (self.denominator_cap_2k + 1))?;
177
178 let mult = three_halves.subtract(ax2_norm)?;
179 approximation = multiply_fixed_point(mult, x, self.denominator_cap_2k)?;
180 }
181 approximation.set_as_output()?;
182 g.finalize()?;
183 Ok(g)
184 }
185
186 fn get_name(&self) -> String {
187 format!(
188 "InverseSqrt(iterations={}, cap=2**{})",
189 self.iterations, self.denominator_cap_2k
190 )
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::ScalarType;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::inline::inline_common::DepthOptimizationLevel;
    use crate::inline::inline_ops::{inline_operations, InlineConfig, InlineMode};
    use crate::mpc::mpc_compiler::{prepare_for_mpc_evaluation, IOStatus};

    /// Operation parameters shared by all tests: 5 iterations, denominator 2^10.
    fn reference_op() -> InverseSqrt {
        InverseSqrt {
            iterations: 5,
            denominator_cap_2k: 10,
        }
    }

    /// Reference value `floor(2^10 / sqrt(divisor))` computed in floating point.
    fn expected_inverse_sqrt(divisor: u64) -> i64 {
        (1024.0 / (divisor as f64).powf(0.5)) as i64
    }

    /// Asserts that `actual` is within 1 of `expected`.
    fn assert_within_one(actual: u64, expected: i64) {
        assert!((actual as i64 - expected).abs() <= 1);
    }

    /// Evaluates the operation on a single scalar, optionally feeding an
    /// initial approximation as a constant second argument.
    fn scalar_helper(
        divisor: u64,
        initial_approximation: Option<u64>,
        st: ScalarType,
    ) -> Result<u64> {
        let context = simple_context(|g| {
            let input = g.input(scalar_type(st))?;
            let mut arguments = vec![input];
            if let Some(approx) = initial_approximation {
                arguments.push(constant_scalar(&g, approx, st)?);
            }
            g.custom_op(CustomOperation::new(reference_op()), arguments)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_scalar(divisor, st)?],
        )?;
        if st == UINT64 {
            return output.to_u64(st);
        }
        // Signed results are expected to be non-negative before widening.
        let signed = output.to_i64(st)?;
        assert!(signed >= 0);
        Ok(signed as u64)
    }

    /// Evaluates the operation on a one-dimensional array of divisors.
    fn array_helper(divisor: Vec<u64>, st: ScalarType) -> Result<Vec<u64>> {
        let array_t = array_type(vec![divisor.len() as u64], st);
        let context = simple_context(|g| {
            let input = g.input(array_t.clone())?;
            g.custom_op(CustomOperation::new(reference_op()), vec![input])
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&divisor, st)?],
        )?;
        output.to_flattened_array_u64(array_t)
    }

    #[test]
    fn test_inverse_sqrt_scalar() {
        for divisor in [1, 2, 3, 123, 300, 500, 700] {
            let expected = expected_inverse_sqrt(divisor);
            assert_within_one(scalar_helper(divisor, None, UINT64).unwrap(), expected);
            assert_within_one(scalar_helper(divisor, None, INT64).unwrap(), expected);
        }
    }

    #[test]
    fn test_inverse_sqrt_array() {
        let divisors = vec![23, 32, 57, 71, 183, 555];
        let unsigned_results = array_helper(divisors.clone(), UINT64).unwrap();
        let signed_results = array_helper(divisors.clone(), INT64).unwrap();
        for (idx, &divisor) in divisors.iter().enumerate() {
            let expected = expected_inverse_sqrt(divisor);
            assert_within_one(unsigned_results[idx], expected);
            assert_within_one(signed_results[idx], expected);
        }
    }

    #[test]
    fn test_inverse_sqrt_with_initial_guess() {
        for divisor in [1, 2, 3, 123, 300, 500, 700] {
            // Smallest power of two whose square, times 4 * divisor, reaches 2^20.
            let mut initial_guess = 1;
            while initial_guess * initial_guess * divisor * 4 < 1024 * 1024 {
                initial_guess *= 2;
            }
            let expected = expected_inverse_sqrt(divisor);
            assert_within_one(
                scalar_helper(divisor, Some(initial_guess), UINT64).unwrap(),
                expected,
            );
            assert_within_one(
                scalar_helper(divisor, Some(initial_guess), INT64).unwrap(),
                expected,
            );
        }
    }

    #[test]
    fn test_inverse_sqrt_negative_values_nothing_bad() {
        // Only checks that evaluation succeeds; the value itself is not inspected.
        for divisor in [-1_i64, -100, -1000] {
            scalar_helper(divisor as u64, None, INT64).unwrap();
        }
    }

    #[test]
    fn test_inverse_sqrt_compiles_end2end() -> Result<()> {
        let context = simple_context(|g| {
            let input = g.input(scalar_type(INT64))?;
            g.custom_op(CustomOperation::new(reference_op()), vec![input])
        })?;
        let inline_config = InlineConfig {
            default_mode: InlineMode::DepthOptimized(DepthOptimizationLevel::Default),
            ..Default::default()
        };
        let instantiated_context = run_instantiation_pass(context)?.get_context();
        let inlined_context = inline_operations(instantiated_context, inline_config.clone())?;
        let _unused = prepare_for_mpc_evaluation(
            inlined_context,
            vec![vec![IOStatus::Shared]],
            vec![vec![]],
            inline_config,
        )?;
        Ok(())
    }
}