zki_sieve/plugins/
zkif_vector.rs

1use num_bigint::BigUint;
2use std::collections::BTreeMap;
3use std::convert::TryFrom;
4
5use crate::consumers::evaluator::PlaintextType;
6use crate::plugins::evaluate_plugin::extract_number;
7use crate::structs::count::Count;
8use crate::{Result, TypeId};
9
10/// This function performs the following checks on zkif_vector_add/mul inputs.
11/// - there is no public/private inputs
12/// - `params` are compliant with the plugin vector and the operation add/mul
13/// - `type_id` is defined and is a Field type
14/// - `output_count` and `input_count` are compliant with `plugin(zkif_vector, add/mul, params)`
15/// - `inputs` is compliant with `plugin(zkif_vector, add/mul, params)`
16fn zkif_vector_check<'a>(
17    output_count: &'a [Count],
18    input_count: &'a [Count],
19    inputs: &'a [&BigUint],
20    public_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
21    private_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
22    params: &'a [String],
23    types: &'a [PlaintextType],
24) -> Result<(usize, &'a BigUint)> {
25    // Check that there is no public/private inputs
26    if !public_inputs.is_empty() {
27        return Err("plugin(zkif_vector, add/mul) does not consume any public input.".into());
28    }
29    if !private_inputs.is_empty() {
30        return Err("plugin(zkif_vector, add/mul) does not consume any private input.".into());
31    }
32
33    // Check that params are compliant with the plugin zkif_vector and the operation add/mul
34    if params.len() != 2 {
35        return Err(
36            "plugin(zkif_vector, add/mul) must be declared with 2 params (type_id, length).".into(),
37        );
38    }
39    let param_type_id = u8::try_from(extract_number(&params[0])?)?;
40    let param_len = usize::try_from(extract_number(&params[1])?)?;
41    if param_len == 0 {
42        return Err("plugin(zkif_vector, add/mul) cannot be called without inputs.".into());
43    }
44    // Check that `type_id` is defined and is a Field type.
45    let type_ = types.get(param_type_id as usize).ok_or_else(|| {
46        format!(
47            "plugin(zkif_vector, add/mul) cannot be called with a type id ({}) which is not defined.",
48            param_type_id
49        )
50    })?;
51    let modulo = match type_ {
52        PlaintextType::Field(modulo) => modulo,
53        PlaintextType::PluginType(_, _, _) => {
54            return Err("plugin(zkif_vector, add/mul) cannot be called on a PluginType.".into())
55        }
56    };
57
58    // Check that `output_count` and `input_count` are compliant with `plugin(zkif_vector, add/mul, params)`
59    let expected_output_count = vec![Count::new(param_type_id, u64::try_from(param_len)?)];
60    if *output_count != expected_output_count {
61        return Err(format!(
62            "When calling the plugin(zkif_vector, add/mul, {}, {}), the out parameter in the function signature must be equal to {:?} (and not {:?}).",
63            param_type_id, param_len, expected_output_count, output_count
64        )
65            .into());
66    }
67
68    let expected_input_count = vec![
69        Count::new(param_type_id, u64::try_from(param_len)?),
70        Count::new(param_type_id, u64::try_from(param_len)?),
71    ];
72    if *input_count != expected_input_count {
73        return Err(format!(
74            "When calling the plugin(zkif_vector, add/mul, {}, {}), the in parameter in the function signature must be equal to {:?} (and not {:?}).",
75            param_type_id, param_len, expected_input_count, input_count
76        )
77            .into());
78    }
79
80    // Check that `inputs` is compliant with `plugin(zkif_vector, add/mul, params)`
81    if inputs.len() != 2 * param_len {
82        return Err(format!(
83            "When calling the plugin(zkif_vector, add/mul, {}, {}), we should have {} input values (and not {}).",
84            param_type_id, param_len, 2*param_len, inputs.len()
85        )
86            .into());
87    }
88    Ok((param_len, modulo))
89}
90
91/// @function(vector_add, @out: type_id: length, @in: type_id: length, type_id: length) @plugin(zkif_vector, add, type_id, length)
92/// This function takes as input two vectors `in1` and `in2` of length `length` containing elements from type `type_id`,
93/// This function returns one vector of length `length` such that `out[i] = in1[i] + in2[i] % type_modulo`
94pub fn zkif_vector_add(
95    output_count: &[Count],
96    input_count: &[Count],
97    inputs: &[&BigUint],
98    public_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
99    private_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
100    params: &[String],
101    types: &[PlaintextType],
102) -> Result<Vec<BigUint>> {
103    let (param_len, modulo) = zkif_vector_check(
104        output_count,
105        input_count,
106        inputs,
107        public_inputs,
108        private_inputs,
109        params,
110        types,
111    )?;
112
113    // Evaluate plugin(vector, add)
114    let mut result = vec![];
115    for i in 0..param_len {
116        result.push((inputs[i] + inputs[i + param_len]) % modulo);
117    }
118    Ok(result)
119}
120
121/// @function(vector_mul, @out: type_id: length, @in: type_id: length, type_id: length) @plugin(zkif_vector, mul, type_id, length)
122/// This function takes as input two vectors `in1` and `in2` of length `length` containing elements from type `type_id`,
123/// This function returns one vector of length `length` such that `out[i] = in1[i] * in2[i] % type_modulo`
124pub fn zkif_vector_mul(
125    output_count: &[Count],
126    input_count: &[Count],
127    inputs: &[&BigUint],
128    public_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
129    private_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
130    params: &[String],
131    types: &[PlaintextType],
132) -> Result<Vec<BigUint>> {
133    let (param_len, modulo) = zkif_vector_check(
134        output_count,
135        input_count,
136        inputs,
137        public_inputs,
138        private_inputs,
139        params,
140        types,
141    )?;
142
143    // Evaluate plugin(zkif_vector, mul)
144    let mut result = vec![];
145    for i in 0..param_len {
146        result.push((inputs[i] * inputs[i + param_len]) % modulo);
147    }
148    Ok(result)
149}
150
/// Exercises `zkif_vector_check` on one valid call and on each kind of rejected call.
#[test]
fn test_zkif_vector_check() {
    let output_count = vec![Count::new(0, 2)];
    let input_count = vec![Count::new(0, 2), Count::new(0, 2)];
    let inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
        &BigUint::from_bytes_le(&[4]),
    ];
    let types = [PlaintextType::Field(BigUint::from_bytes_le(&[7]))];
    let params = ["0".to_string(), "2".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    )
    .unwrap();
    let expected_result = (2_usize, &BigUint::from_bytes_le(&[7]));
    assert_eq!(result, expected_result);

    // Try to use the plugin zkif_vector with a non-empty public inputs map
    let mut non_empty_inputs = BTreeMap::new();
    non_empty_inputs.insert(0, vec![BigUint::from_bytes_le(&[1])]);
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &non_empty_inputs,
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with a non-empty private inputs map
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &non_empty_inputs,
        &params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with 3 parameters instead of 2
    let incorrect_params = ["0".to_string(), "2".to_string(), "2".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with 1 parameter instead of 2
    let incorrect_params = ["0".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with an unknown type_id in params
    let incorrect_params = ["1".to_string(), "2".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with a type_id which cannot be parsed into u8
    let incorrect_params = ["a".to_string(), "2".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with a param_len equal to 0
    let incorrect_params = ["0".to_string(), "0".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with an incorrect output_count
    let incorrect_output_count = vec![Count::new(0, 3)];
    let result = zkif_vector_check(
        &incorrect_output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with an incorrect input_count
    let incorrect_input_count = vec![Count::new(0, 2), Count::new(0, 3)];
    let result = zkif_vector_check(
        &output_count,
        &incorrect_input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin zkif_vector with an incorrect number of inputs
    let incorrect_inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
    ];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &incorrect_inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());
}
258
/// Checks `zkif_vector_add` on a valid call and on calls with invalid `params`.
#[test]
fn test_vector_add() {
    let output_count = vec![Count::new(0, 2)];
    let input_count = vec![Count::new(0, 2), Count::new(0, 2)];
    let inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
        &BigUint::from_bytes_le(&[4]),
    ];
    let types = [PlaintextType::Field(BigUint::from_bytes_le(&[7]))];
    let params = ["0".to_string(), "2".to_string()];
    let result = zkif_vector_add(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    )
    .unwrap();
    // [1, 2] + [3, 4] mod 7 = [4, 6]
    let expected_result = vec![BigUint::from_bytes_le(&[4]), BigUint::from_bytes_le(&[6])];
    assert_eq!(result, expected_result);

    // Try to use the plugin(zkif_vector, add, params) with an unknown type_id in params
    let incorrect_params = ["1".to_string(), "2".to_string()];
    let result = zkif_vector_add(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin(zkif_vector, add, params) with 3 parameters instead of 2
    let incorrect_params = ["0".to_string(), "2".to_string(), "2".to_string()];
    let result = zkif_vector_add(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());
}
297
/// Checks `zkif_vector_mul` on a valid call and on calls with invalid `params`.
#[test]
fn test_vector_mul() {
    let output_count = vec![Count::new(0, 2)];
    let input_count = vec![Count::new(0, 2), Count::new(0, 2)];
    let inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
        &BigUint::from_bytes_le(&[4]),
    ];
    let types = [PlaintextType::Field(BigUint::from_bytes_le(&[7]))];
    let params = ["0".to_string(), "2".to_string()];
    let result = zkif_vector_mul(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    )
    .unwrap();
    // [1, 2] * [3, 4] mod 7 = [3, 1]
    let expected_result = vec![BigUint::from_bytes_le(&[3]), BigUint::from_bytes_le(&[1])];
    assert_eq!(result, expected_result);

    // Try to use the plugin(zkif_vector, mul, params) with an unknown type_id in params
    let incorrect_params = ["1".to_string(), "2".to_string()];
    let result = zkif_vector_mul(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Try to use the plugin(zkif_vector, mul, params) with 3 parameters instead of 2
    let incorrect_params = ["0".to_string(), "2".to_string(), "2".to_string()];
    let result = zkif_vector_mul(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());
}