1use crate::custom_ops::CustomOperationBody;
3use crate::data_types::Type;
4use crate::errors::Result;
5use crate::graphs::{Context, Graph};
6
7use serde::{Deserialize, Serialize};
8
9use super::utils::{constant_scalar, inverse_initial_approximation, multiply_fixed_point};
10
/// Custom operation that divides a dividend by a divisor with Goldschmidt's iterative
/// division algorithm, producing a fixed-point quotient.
///
/// The operation takes two or three arguments: the dividend and the divisor (scalars or
/// arrays sharing the same scalar type of at least 64 bits) and, optionally, an initial
/// approximation of the divisor's reciprocal that must have exactly the divisor's type.
/// The result approximates `dividend * 2^denominator_cap_2k / divisor`.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct GoldschmidtDivision {
    /// Number of Goldschmidt steps. The initial scaling by the reciprocal approximation
    /// counts as the first step and `iterations - 1` refinement steps follow, so this
    /// must be at least 1.
    pub iterations: u64,
    /// Exponent `k` of the fixed-point scale `2^k` used by the reciprocal approximation,
    /// the fixed-point multiplications and the returned quotient. When it is 0 and no
    /// approximation is supplied, the approximation defaults to ones.
    /// NOTE(review): the name suggests divisors are expected to stay below `2^k` —
    /// confirm against `inverse_initial_approximation`.
    pub denominator_cap_2k: u64,
}
57
58#[typetag::serde]
59impl CustomOperationBody for GoldschmidtDivision {
60 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
61 if arguments_types.len() != 2 && arguments_types.len() != 3 {
62 return Err(runtime_error!(
63 "Invalid number of arguments for GoldschmidtDivision, given {}, expected 2 or 3",
64 arguments_types.len()
65 ));
66 }
67
68 let dividend_type = arguments_types[0].clone();
69 let divisor_type = arguments_types[1].clone();
70 if dividend_type.get_scalar_type() != divisor_type.get_scalar_type() {
71 return Err(runtime_error!(
72 "Invalid scalar types for GoldschmidtDivision: dividend scalr type {} and divisor scalar type {} must be the same",
73 dividend_type.get_scalar_type(),
74 divisor_type.get_scalar_type()
75 ));
76 }
77 if !divisor_type.is_scalar() && !divisor_type.is_array() {
78 return Err(runtime_error!(
79 "Divisor in GoldschmidtDivision must be a scalar or an array"
80 ));
81 }
82 if !dividend_type.is_scalar() && !dividend_type.is_array() {
83 return Err(runtime_error!(
84 "Dividend in GoldschmidtDivision must be a scalar or an array"
85 ));
86 }
87
88 let sc = dividend_type.get_scalar_type();
89 if sc.size_in_bits() < 64 {
90 return Err(runtime_error!(
91 "Divisor in GoldshmidtDivision supported only for 64-bit+ types: INT64, UINT64, INT128, UINT128"
92 ));
93 }
94 let has_initial_approximation = arguments_types.len() == 3;
95 if has_initial_approximation {
96 let initial_approximation_t = arguments_types[2].clone();
97 if initial_approximation_t != divisor_type {
98 return Err(runtime_error!(
99 "Divisor and initial approximation must have the same type."
100 ));
101 }
102 }
103
104 let g_initial_approximation =
105 inverse_initial_approximation(&context, divisor_type.clone(), self.denominator_cap_2k)?;
106 let g = context.create_graph()?;
107 let dividend = g.input(dividend_type)?;
108 let divisor = g.input(divisor_type.clone())?;
109 let approximation = if has_initial_approximation {
110 g.input(divisor_type)?
111 } else if self.denominator_cap_2k == 0 {
112 g.ones(divisor_type)?
113 } else {
114 g.call(g_initial_approximation, vec![divisor.clone()])?
115 };
116 let two_power_cap_plus_one =
122 constant_scalar(&g, 1u128 << (self.denominator_cap_2k + 1), sc)?;
123 let mut w = approximation;
124 let mut a = dividend.multiply(w.clone())?;
125 let mut b = divisor.multiply(w.clone())?;
126 for _ in 0..self.iterations - 1 {
127 w = two_power_cap_plus_one.subtract(b.clone())?;
128 a = multiply_fixed_point(a.clone(), w.clone(), self.denominator_cap_2k)?;
129 b = multiply_fixed_point(b.clone(), w.clone(), self.denominator_cap_2k)?;
130 }
131 a.set_as_output()?;
132 g.finalize()?;
133 Ok(g)
134 }
135
136 fn get_name(&self) -> String {
137 format!(
138 "GoldshmidtDivision(iterations={}, cap=2**{})",
139 self.iterations, self.denominator_cap_2k
140 )
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_types::{array_type, scalar_type, ScalarType};
    use crate::data_types::{INT128, INT64, UINT128, UINT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::inline::inline_common::DepthOptimizationLevel;
    use crate::inline::inline_ops::inline_operations;
    use crate::inline::inline_ops::InlineConfig;
    use crate::inline::inline_ops::InlineMode;
    use crate::mpc::mpc_compiler::prepare_for_mpc_evaluation;
    use crate::mpc::mpc_compiler::IOStatus;

    /// Number of Goldschmidt iterations used by every test in this module.
    const ITERATIONS: u64 = 5;

    /// Fixed-point exponent `k` (scale `2^k`) used by the array tests.
    const ARRAY_CAP_2K: u64 = 10;

    /// Asserts that `result` approximates the positive reference value `expected`.
    ///
    /// Uses the integer criterion `|result - expected| * 100 / expected <= 1`, i.e. the
    /// relative error must stay below 2%.
    fn assert_close(result: i128, expected: i128) {
        assert!(
            (result - expected).abs() * 100 / expected <= 1,
            "result {} is too far from expected {}",
            result,
            expected
        );
    }

    /// Checks element-wise that the `UINT64` and `INT64` results of an array division
    /// approximate `expected`; `INT64` results are reinterpreted from their `u64` bits.
    fn assert_array_close(uint64_results: &[u64], int64_results: &[u64], expected: &[i128]) {
        assert_eq!(uint64_results.len(), expected.len());
        assert_eq!(int64_results.len(), expected.len());
        for ((&unsigned, &signed), &e) in uint64_results
            .iter()
            .zip(int64_results)
            .zip(expected)
        {
            assert_close(i128::from(unsigned), e);
            assert_close(i128::from(signed as i64), e);
        }
    }

    /// Builds a context applying `GoldschmidtDivision` to a dividend of type `dividend_t`
    /// and a divisor of type `divisor_t`, instantiates the custom operation and evaluates
    /// the main graph on `inputs` (dividend value first, divisor value second).
    ///
    /// When `initial_approximation` is given it is passed as a constant scalar third
    /// argument, so that path only works with scalar divisors.
    fn evaluate_division(
        dividend_t: Type,
        divisor_t: Type,
        initial_approximation: Option<u64>,
        denominator_cap_2k: u64,
        inputs: Vec<Value>,
    ) -> Result<Value> {
        let c = simple_context(|g| {
            let dividend_node = g.input(dividend_t.clone())?;
            let divisor_node = g.input(divisor_t.clone())?;
            let mut arguments = vec![dividend_node, divisor_node];
            if let Some(approx) = initial_approximation {
                arguments.push(constant_scalar(&g, approx, divisor_t.get_scalar_type())?);
            }
            g.custom_op(
                CustomOperation::new(GoldschmidtDivision {
                    iterations: ITERATIONS,
                    denominator_cap_2k,
                }),
                arguments,
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(mapped_c.get_context().get_main_graph()?, inputs)?;
        Ok(result)
    }

    fn scalar_division_helper(
        dividend: u64,
        divisor: u64,
        initial_approximation: Option<u64>,
        st: ScalarType,
        denominator_cap_2k: u64,
    ) -> Result<Value> {
        evaluate_division(
            scalar_type(st),
            scalar_type(st),
            initial_approximation,
            denominator_cap_2k,
            vec![
                Value::from_scalar(dividend, st)?,
                Value::from_scalar(divisor, st)?,
            ],
        )
    }

    fn array_division_helper_array_scalar(
        dividend: Vec<u64>,
        divisor: u64,
        st: ScalarType,
    ) -> Result<Vec<u64>> {
        let array_t = array_type(vec![dividend.len() as u64], st);
        let result = evaluate_division(
            array_t.clone(),
            scalar_type(st),
            None,
            ARRAY_CAP_2K,
            vec![
                Value::from_flattened_array(&dividend, st)?,
                Value::from_scalar(divisor, st)?,
            ],
        )?;
        result.to_flattened_array_u64(array_t)
    }

    fn array_division_helper_scalar_array(
        dividend: u64,
        divisor: Vec<u64>,
        st: ScalarType,
    ) -> Result<Vec<u64>> {
        let array_t = array_type(vec![divisor.len() as u64], st);
        let result = evaluate_division(
            scalar_type(st),
            array_t.clone(),
            None,
            ARRAY_CAP_2K,
            vec![
                Value::from_scalar(dividend, st)?,
                Value::from_flattened_array(&divisor, st)?,
            ],
        )?;
        result.to_flattened_array_u64(array_t)
    }

    fn array_division_helper_array_array(
        dividend: Vec<u64>,
        divisor: Vec<u64>,
        st: ScalarType,
    ) -> Result<Vec<u64>> {
        let array_t = array_type(vec![divisor.len() as u64], st);
        let result = evaluate_division(
            array_t.clone(),
            array_t.clone(),
            None,
            ARRAY_CAP_2K,
            vec![
                Value::from_flattened_array(&dividend, st)?,
                Value::from_flattened_array(&divisor, st)?,
            ],
        )?;
        result.to_flattened_array_u64(array_t)
    }

    #[test]
    fn test_goldschmidt_division_scalar() {
        let dividend: u64 = 123456;
        for divisor in [1u64, 2, 3, 123, 300, 500, 700] {
            let expected = i128::from(dividend * (1 << 10) / divisor);
            let result_int64 = scalar_division_helper(dividend, divisor, None, INT64, 10)
                .unwrap()
                .to_i64(INT64)
                .unwrap();
            let result_uint64 = scalar_division_helper(dividend, divisor, None, UINT64, 10)
                .unwrap()
                .to_u64(UINT64)
                .unwrap();
            assert_close(i128::from(result_int64), expected);
            assert_close(i128::from(result_uint64), expected);
        }
    }

    #[test]
    fn test_goldschmidt_division_with_initial_approximation() {
        let dividend: u64 = 123456;
        let denominator_cap_2k: u64 = 10;
        for divisor in [2u64, 3, 4, 123] {
            // floor(2^k / divisor) is the natural fixed-point reciprocal to start from.
            let approximation = (1u64 << denominator_cap_2k) / divisor;
            let expected = i128::from(dividend * (1 << denominator_cap_2k) / divisor);
            let result_int64 = scalar_division_helper(
                dividend,
                divisor,
                Some(approximation),
                INT64,
                denominator_cap_2k,
            )
            .unwrap()
            .to_i64(INT64)
            .unwrap();
            let result_uint64 = scalar_division_helper(
                dividend,
                divisor,
                Some(approximation),
                UINT64,
                denominator_cap_2k,
            )
            .unwrap()
            .to_u64(UINT64)
            .unwrap();
            assert_close(i128::from(result_int64), expected);
            assert_close(i128::from(result_uint64), expected);
        }
    }

    #[test]
    fn test_goldschmidt_division_128_bit() {
        let dividend: u64 = 1234567890123456789;
        for denominator_cap_2k in [10u64, 20, 30] {
            for divisor in [1u64, 2, 3, 123, 300, 500, 700] {
                let expected = i128::from(dividend) * (1i128 << denominator_cap_2k)
                    / i128::from(divisor);
                let result_int128 =
                    scalar_division_helper(dividend, divisor, None, INT128, denominator_cap_2k)
                        .unwrap()
                        .to_i128(INT128)
                        .unwrap();
                let result_uint128 =
                    scalar_division_helper(dividend, divisor, None, UINT128, denominator_cap_2k)
                        .unwrap()
                        .to_u128(UINT128)
                        .unwrap();
                assert_close(result_int128, expected);
                assert_close(result_uint128 as i128, expected);
            }
        }
    }

    #[test]
    fn test_goldschmidt_division_array() {
        // Fixed-point reference quotient matching the operation's `2^ARRAY_CAP_2K` scale.
        let reference =
            |dividend: u64, divisor: u64| i128::from(dividend * (1 << ARRAY_CAP_2K) / divisor);

        // Array dividend, scalar divisor.
        let dividends: Vec<u64> = vec![2300, 3200, 57, 71000, 183293, 55511];
        let divisor: u64 = 122;
        let div = array_division_helper_array_scalar(dividends.clone(), divisor, UINT64).unwrap();
        let i_div = array_division_helper_array_scalar(dividends.clone(), divisor, INT64).unwrap();
        let expected: Vec<i128> = dividends.iter().map(|&x| reference(x, divisor)).collect();
        assert_array_close(&div, &i_div, &expected);

        // Scalar dividend, array divisor.
        let dividend: u64 = 1234567;
        let divisors: Vec<u64> = vec![23, 32, 57, 710, 183, 555];
        let div = array_division_helper_scalar_array(dividend, divisors.clone(), UINT64).unwrap();
        let i_div = array_division_helper_scalar_array(dividend, divisors.clone(), INT64).unwrap();
        let expected: Vec<i128> = divisors.iter().map(|&y| reference(dividend, y)).collect();
        assert_array_close(&div, &i_div, &expected);

        // Array dividend, array divisor (element-wise).
        let div =
            array_division_helper_array_array(dividends.clone(), divisors.clone(), UINT64).unwrap();
        let i_div =
            array_division_helper_array_array(dividends.clone(), divisors.clone(), INT64).unwrap();
        let expected: Vec<i128> = dividends
            .iter()
            .zip(&divisors)
            .map(|(&x, &y)| reference(x, y))
            .collect();
        assert_array_close(&div, &i_div, &expected);
    }

    #[test]
    fn test_goldschmidt_division_compiles_end2end() -> Result<()> {
        let c = simple_context(|g| {
            let dividend = g.input(scalar_type(INT64))?;
            let divisor = g.input(scalar_type(INT64))?;
            g.custom_op(
                CustomOperation::new(GoldschmidtDivision {
                    iterations: ITERATIONS,
                    denominator_cap_2k: 10,
                }),
                vec![dividend, divisor],
            )
        })?;
        let inline_config = InlineConfig {
            default_mode: InlineMode::DepthOptimized(DepthOptimizationLevel::Default),
            ..Default::default()
        };
        let instantiated_context = run_instantiation_pass(c)?.get_context();
        let inlined_context = inline_operations(instantiated_context, inline_config.clone())?;
        let _unused = prepare_for_mpc_evaluation(
            inlined_context,
            vec![vec![IOStatus::Shared, IOStatus::Shared]],
            vec![vec![]],
            inline_config,
        )?;
        Ok(())
    }
}