ciphercore-base 0.3.1

The base package of CipherCore: computation graphs API, Secure MPC Compiler, utilities for graph evaluation and inspection
Documentation
//! Minimum and maximum operations. They operate on integers represented as bitstrings, interpreted as unsigned or signed depending on the `signed_comparison` flag.
use crate::custom_ops::{CustomOperation, CustomOperationBody};
use crate::data_types::{array_type, Type};
use crate::errors::Result;
use crate::graphs::{Context, Graph, Node};

use super::comparisons::GreaterThan;
use super::multiplexer::Mux;

use serde::{Deserialize, Serialize};

/// A structure that defines the custom operation Min that computes the minimum of length-n bitstring arrays elementwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2,3]`.
///
/// When `signed_comparison` is `false`, bitstrings are compared as unsigned integers;
/// set `signed_comparison` to `true` to compare them as signed integers instead.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New Min node containing, elementwise, the smaller of the two inputs
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::min_max::Min;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(Min {signed_comparison: false}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Min {
    /// Boolean value indicating whether input bitstrings represent signed integers
    pub signed_comparison: bool,
}

/// Makes a `GreaterThan` result usable as the selector input of `Mux`.
///
/// An array-typed comparison result gets a trailing axis of length `1`
/// appended to its shape (e.g. `[2, 3]` becomes `[2, 3, 1]`), keeping its
/// scalar type; a non-array result is returned untouched. Min and Max pad
/// the comparison this way before handing it to Mux.
fn normalize_cmp(cmp: Node) -> Result<Node> {
    let cmp_type = cmp.get_type()?;
    // Non-array (scalar) comparison results need no padding.
    if !cmp_type.is_array() {
        return Ok(cmp);
    }
    let mut padded_shape = cmp_type.get_shape();
    padded_shape.push(1);
    cmp.reshape(array_type(padded_shape, cmp_type.get_scalar_type()))
}

#[typetag::serde]
impl CustomOperationBody for Min {
    /// Builds the graph implementing Min for exactly two input types.
    ///
    /// For inputs `x` and `y`, the graph evaluates `x > y` with `GreaterThan`
    /// (honoring `signed_comparison`). It pads that result with `normalize_cmp`
    /// and passes `[x > y, y, x]` to `Mux`, so the smaller value is selected.
    ///
    /// # Errors
    ///
    /// Returns a runtime error if the number of argument types is not two, or
    /// propagates any error raised while building the graph.
    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
        let (x_type, y_type) = match arguments_types.as_slice() {
            [first, second] => (first.clone(), second.clone()),
            _ => return Err(runtime_error!("Invalid number of arguments for Min")),
        };
        let graph = context.create_graph()?;
        let x = graph.input(x_type)?;
        let y = graph.input(y_type)?;
        let comparison = CustomOperation::new(GreaterThan {
            signed_comparison: self.signed_comparison,
        });
        let x_gt_y = normalize_cmp(graph.custom_op(comparison, vec![x.clone(), y.clone()])?)?;
        let smaller = graph.custom_op(CustomOperation::new(Mux {}), vec![x_gt_y, y, x])?;
        graph.set_output_node(smaller)?;
        graph.finalize()?;
        Ok(graph)
    }

    /// Name of the operation, including the comparison mode.
    fn get_name(&self) -> String {
        format!(
            "Min(signed_comparison={signed})",
            signed = self.signed_comparison
        )
    }
}

/// A structure that defines the custom operation Max that computes the maximum of length-n bitstring arrays elementwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2,3]`.
///
/// When `signed_comparison` is `false`, bitstrings are compared as unsigned integers;
/// set `signed_comparison` to `true` to compare them as signed integers instead.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New Max node containing, elementwise, the larger of the two inputs
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::min_max::Max;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(Max {signed_comparison: true}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Max {
    /// Boolean value indicating whether input bitstrings represent signed integers
    pub signed_comparison: bool,
}

#[typetag::serde]
impl CustomOperationBody for Max {
    /// Builds the graph implementing Max for exactly two input types.
    ///
    /// For inputs `x` and `y`, the graph evaluates `x > y` with `GreaterThan`
    /// (honoring `signed_comparison`). It pads that result with `normalize_cmp`
    /// and passes `[x > y, x, y]` to `Mux`, so the larger value is selected.
    ///
    /// # Errors
    ///
    /// Returns a runtime error if the number of argument types is not two, or
    /// propagates any error raised while building the graph.
    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
        let (x_type, y_type) = match arguments_types.as_slice() {
            [first, second] => (first.clone(), second.clone()),
            _ => return Err(runtime_error!("Invalid number of arguments for Max")),
        };
        let graph = context.create_graph()?;
        let x = graph.input(x_type)?;
        let y = graph.input(y_type)?;
        let comparison = CustomOperation::new(GreaterThan {
            signed_comparison: self.signed_comparison,
        });
        let x_gt_y = normalize_cmp(graph.custom_op(comparison, vec![x.clone(), y.clone()])?)?;
        let larger = graph.custom_op(CustomOperation::new(Mux {}), vec![x_gt_y, x, y])?;
        graph.set_output_node(larger)?;
        graph.finalize()?;
        Ok(graph)
    }

    /// Name of the operation, including the comparison mode.
    fn get_name(&self) -> String {
        format!(
            "Max(signed_comparison={signed})",
            signed = self.signed_comparison
        )
    }
}

#[cfg(test)]
mod tests {

    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::{array_type, scalar_type, BIT, INT64, UINT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    use super::*;

    use std::cmp::{max, min};

    /// Builds and instantiates a context whose main graph takes two `UINT64` scalars,
    /// converts them to bitstrings, and outputs the tuple
    /// `(Min unsigned, Min signed, Max unsigned, Max signed)`.
    ///
    /// Covering both comparison modes for both operations makes sure neither the
    /// `signed_comparison` flag nor the Mux argument order goes untested.
    fn scalar_min_max_context() -> Result<Context> {
        let c = simple_context(|g| {
            let i1 = g.input(scalar_type(UINT64))?.a2b()?;
            let i2 = g.input(scalar_type(UINT64))?.a2b()?;
            g.create_tuple(vec![
                g.custom_op(
                    CustomOperation::new(Min {
                        signed_comparison: false,
                    }),
                    vec![i1.clone(), i2.clone()],
                )?,
                g.custom_op(
                    CustomOperation::new(Min {
                        signed_comparison: true,
                    }),
                    vec![i1.clone(), i2.clone()],
                )?,
                g.custom_op(
                    CustomOperation::new(Max {
                        signed_comparison: false,
                    }),
                    vec![i1.clone(), i2.clone()],
                )?,
                g.custom_op(
                    CustomOperation::new(Max {
                        signed_comparison: true,
                    }),
                    vec![i1, i2],
                )?,
            ])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        Ok(mapped_c.get_context())
    }

    #[test]
    fn test_well_formed() -> Result<()> {
        // Pairs include equal values and values whose top bit is set, so the
        // signed and unsigned interpretations disagree on the ordering.
        let test_data: Vec<(u64, u64)> = vec![
            (31, 32),
            (76543, 76544),
            (0, 1),
            (0, 0),
            (761523, 761523),
            (u64::MAX, u64::MAX - 1),
            (u64::MAX - 761522, u64::MAX - 761523),
            (u64::MAX, 0),
        ];
        let context = scalar_min_max_context()?;
        for (u, v) in test_data {
            let outputs = random_evaluate(
                context.get_main_graph()?,
                vec![
                    Value::from_scalar(u, UINT64)?,
                    Value::from_scalar(v, UINT64)?,
                ],
            )?
            .to_vector()?;
            assert_eq!(outputs[0].to_u64(UINT64)?, min(u, v));
            assert_eq!(outputs[1].to_i64(INT64)?, min(u as i64, v as i64));
            assert_eq!(outputs[2].to_u64(UINT64)?, max(u, v));
            assert_eq!(outputs[3].to_i64(INT64)?, max(u as i64, v as i64));
        }
        Ok(())
    }

    #[test]
    fn test_malformed() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        let i1 = g.input(scalar_type(UINT64))?.a2b()?;
        // Too few arguments.
        assert!(g
            .custom_op(
                CustomOperation::new(Min {
                    signed_comparison: false
                }),
                vec![i1.clone()]
            )
            .is_err());
        assert!(g
            .custom_op(
                CustomOperation::new(Max {
                    signed_comparison: false
                }),
                vec![i1.clone()]
            )
            .is_err());
        // Too many arguments.
        assert!(g
            .custom_op(
                CustomOperation::new(Min {
                    signed_comparison: true
                }),
                vec![i1.clone(), i1.clone(), i1.clone()]
            )
            .is_err());
        assert!(g
            .custom_op(
                CustomOperation::new(Max {
                    signed_comparison: true
                }),
                vec![i1.clone(), i1.clone(), i1]
            )
            .is_err());
        Ok(())
    }

    #[test]
    fn test_vector() -> Result<()> {
        // Shapes [1, 3, 64] and [3, 1, 64] broadcast to a [3, 3] grid of 64-bit values.
        let c = simple_context(|g| {
            let i1 = g.input(array_type(vec![1, 3, 64], BIT))?;
            let i2 = g.input(array_type(vec![3, 1, 64], BIT))?;
            g.create_tuple(vec![
                g.custom_op(
                    CustomOperation::new(Min {
                        signed_comparison: false,
                    }),
                    vec![i1.clone(), i2.clone()],
                )?,
                g.custom_op(
                    CustomOperation::new(Max {
                        signed_comparison: false,
                    }),
                    vec![i1, i2],
                )?,
            ])
        })?;
        let context = run_instantiation_pass(c)?.get_context();
        let a = vec![0, 30, 100];
        let b = vec![10, 50, 150];
        let v = random_evaluate(
            context.get_main_graph()?,
            vec![
                Value::from_flattened_array(&a, UINT64)?,
                Value::from_flattened_array(&b, UINT64)?,
            ],
        )?
        .to_vector()?;
        let min_a_b = v[0].to_flattened_array_u64(array_type(vec![3, 3], UINT64))?;
        let max_a_b = v[1].to_flattened_array_u64(array_type(vec![3, 3], UINT64))?;
        assert_eq!(min_a_b, vec![0, 10, 10, 0, 30, 50, 0, 30, 100]);
        assert_eq!(max_a_b, vec![10, 30, 100, 50, 50, 100, 150, 150, 150]);
        Ok(())
    }
}