1use num_bigint::BigUint;
2use std::collections::BTreeMap;
3use std::convert::TryFrom;
4
5use crate::consumers::evaluator::PlaintextType;
6use crate::plugins::evaluate_plugin::extract_number;
7use crate::structs::count::Count;
8use crate::{Result, TypeId};
9
10fn zkif_vector_check<'a>(
17 output_count: &'a [Count],
18 input_count: &'a [Count],
19 inputs: &'a [&BigUint],
20 public_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
21 private_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
22 params: &'a [String],
23 types: &'a [PlaintextType],
24) -> Result<(usize, &'a BigUint)> {
25 if !public_inputs.is_empty() {
27 return Err("plugin(zkif_vector, add/mul) does not consume any public input.".into());
28 }
29 if !private_inputs.is_empty() {
30 return Err("plugin(zkif_vector, add/mul) does not consume any private input.".into());
31 }
32
33 if params.len() != 2 {
35 return Err(
36 "plugin(zkif_vector, add/mul) must be declared with 2 params (type_id, length).".into(),
37 );
38 }
39 let param_type_id = u8::try_from(extract_number(¶ms[0])?)?;
40 let param_len = usize::try_from(extract_number(¶ms[1])?)?;
41 if param_len == 0 {
42 return Err("plugin(zkif_vector, add/mul) cannot be called without inputs.".into());
43 }
44 let type_ = types.get(param_type_id as usize).ok_or_else(|| {
46 format!(
47 "plugin(zkif_vector, add/mul) cannot be called with a type id ({}) which is not defined.",
48 param_type_id
49 )
50 })?;
51 let modulo = match type_ {
52 PlaintextType::Field(modulo) => modulo,
53 PlaintextType::PluginType(_, _, _) => {
54 return Err("plugin(zkif_vector, add/mul) cannot be called on a PluginType.".into())
55 }
56 };
57
58 let expected_output_count = vec![Count::new(param_type_id, u64::try_from(param_len)?)];
60 if *output_count != expected_output_count {
61 return Err(format!(
62 "When calling the plugin(zkif_vector, add/mul, {}, {}), the out parameter in the function signature must be equal to {:?} (and not {:?}).",
63 param_type_id, param_len, expected_output_count, output_count
64 )
65 .into());
66 }
67
68 let expected_input_count = vec![
69 Count::new(param_type_id, u64::try_from(param_len)?),
70 Count::new(param_type_id, u64::try_from(param_len)?),
71 ];
72 if *input_count != expected_input_count {
73 return Err(format!(
74 "When calling the plugin(zkif_vector, add/mul, {}, {}), the in parameter in the function signature must be equal to {:?} (and not {:?}).",
75 param_type_id, param_len, expected_input_count, input_count
76 )
77 .into());
78 }
79
80 if inputs.len() != 2 * param_len {
82 return Err(format!(
83 "When calling the plugin(zkif_vector, add/mul, {}, {}), we should have {} input values (and not {}).",
84 param_type_id, param_len, 2*param_len, inputs.len()
85 )
86 .into());
87 }
88 Ok((param_len, modulo))
89}
90
91pub fn zkif_vector_add(
95 output_count: &[Count],
96 input_count: &[Count],
97 inputs: &[&BigUint],
98 public_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
99 private_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
100 params: &[String],
101 types: &[PlaintextType],
102) -> Result<Vec<BigUint>> {
103 let (param_len, modulo) = zkif_vector_check(
104 output_count,
105 input_count,
106 inputs,
107 public_inputs,
108 private_inputs,
109 params,
110 types,
111 )?;
112
113 let mut result = vec![];
115 for i in 0..param_len {
116 result.push((inputs[i] + inputs[i + param_len]) % modulo);
117 }
118 Ok(result)
119}
120
121pub fn zkif_vector_mul(
125 output_count: &[Count],
126 input_count: &[Count],
127 inputs: &[&BigUint],
128 public_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
129 private_inputs: &BTreeMap<TypeId, Vec<BigUint>>,
130 params: &[String],
131 types: &[PlaintextType],
132) -> Result<Vec<BigUint>> {
133 let (param_len, modulo) = zkif_vector_check(
134 output_count,
135 input_count,
136 inputs,
137 public_inputs,
138 private_inputs,
139 params,
140 types,
141 )?;
142
143 let mut result = vec![];
145 for i in 0..param_len {
146 result.push((inputs[i] * inputs[i + param_len]) % modulo);
147 }
148 Ok(result)
149}
150
#[test]
fn test_zkif_vector_check() {
    let output_count = vec![Count::new(0, 2)];
    let input_count = vec![Count::new(0, 2), Count::new(0, 2)];
    let inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
        &BigUint::from_bytes_le(&[4]),
    ];
    let types = [PlaintextType::Field(BigUint::from_bytes_le(&[7]))];
    let params = ["0".to_string(), "2".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    )
    .unwrap();
    let expected_result = (2_usize, &BigUint::from_bytes_le(&[7]));
    assert_eq!(result, expected_result);

    // Undefined type id.
    let incorrect_params = ["1".to_string(), "2".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Non-numeric type id.
    let incorrect_params = ["a".to_string(), "2".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Zero length.
    let incorrect_params = ["0".to_string(), "0".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Wrong number of params.
    let incorrect_params = ["0".to_string()];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());

    // Public and private inputs are not allowed.
    let mut non_empty_inputs = BTreeMap::new();
    non_empty_inputs.insert(0, vec![BigUint::from_bytes_le(&[1])]);
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &non_empty_inputs,
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &non_empty_inputs,
        &params,
        &types,
    );
    assert!(result.is_err());

    // Output count does not match the declared length.
    let incorrect_output_count = vec![Count::new(0, 3)];
    let result = zkif_vector_check(
        &incorrect_output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());

    // Input count does not match the declared length.
    let incorrect_input_count = vec![Count::new(0, 2), Count::new(0, 3)];
    let result = zkif_vector_check(
        &output_count,
        &incorrect_input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());

    // Wrong number of input values.
    let incorrect_inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
    ];
    let result = zkif_vector_check(
        &output_count,
        &input_count,
        &incorrect_inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    );
    assert!(result.is_err());

    // A length whose doubling overflows `usize` must yield an error, not a panic.
    let huge_len = u64::try_from(usize::MAX).unwrap();
    let huge_output_count = vec![Count::new(0, huge_len)];
    let huge_input_count = vec![Count::new(0, huge_len), Count::new(0, huge_len)];
    let huge_params = ["0".to_string(), usize::MAX.to_string()];
    let result = zkif_vector_check(
        &huge_output_count,
        &huge_input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &huge_params,
        &types,
    );
    assert!(result.is_err());
}
258
#[test]
fn test_vector_add() {
    let output_count = vec![Count::new(0, 2)];
    let input_count = vec![Count::new(0, 2), Count::new(0, 2)];
    let inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
        &BigUint::from_bytes_le(&[4]),
    ];
    let types = [PlaintextType::Field(BigUint::from_bytes_le(&[7]))];
    let params = ["0".to_string(), "2".to_string()];
    let result = zkif_vector_add(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    )
    .unwrap();
    // [1, 2] + [3, 4] mod 7 == [4, 6]
    let expected_result = vec![BigUint::from_bytes_le(&[4]), BigUint::from_bytes_le(&[6])];
    assert_eq!(result, expected_result);

    // Undefined type id.
    let incorrect_params = ["1".to_string(), "2".to_string()];
    let result = zkif_vector_add(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());
}
297
#[test]
fn test_vector_mul() {
    let output_count = vec![Count::new(0, 2)];
    let input_count = vec![Count::new(0, 2), Count::new(0, 2)];
    let inputs = [
        &BigUint::from_bytes_le(&[1]),
        &BigUint::from_bytes_le(&[2]),
        &BigUint::from_bytes_le(&[3]),
        &BigUint::from_bytes_le(&[4]),
    ];
    let types = [PlaintextType::Field(BigUint::from_bytes_le(&[7]))];
    let params = ["0".to_string(), "2".to_string()];
    let result = zkif_vector_mul(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &params,
        &types,
    )
    .unwrap();
    // [1, 2] * [3, 4] mod 7 == [3, 8 mod 7] == [3, 1]
    let expected_result = vec![BigUint::from_bytes_le(&[3]), BigUint::from_bytes_le(&[1])];
    assert_eq!(result, expected_result);

    // Undefined type id.
    let incorrect_params = ["1".to_string(), "2".to_string()];
    let result = zkif_vector_mul(
        &output_count,
        &input_count,
        &inputs,
        &BTreeMap::new(),
        &BTreeMap::new(),
        &incorrect_params,
        &types,
    );
    assert!(result.is_err());
}