Skip to main content

rust_eddsa_helper/
lib.rs

1use std::ffi::{CStr, CString};
2use std::os::raw::{c_char, c_int};
3use std::panic::{catch_unwind, AssertUnwindSafe};
4use zeroize::Zeroizing;
5
/// Upper bound, in bytes, on the raw JSON accepted by each FFI entry point; enforced before any
/// deserialization so untrusted callers cannot force unbounded `serde_json` work. (1 MiB)
pub(crate) const MAX_JSON_BYTES: usize = 1 << 20;
/// Largest number of field elements that may be fed to a single Poseidon hash.
pub(crate) const MAX_POSEIDON_FIELD_INPUTS: usize = 32;
/// Longest bit array accepted by Poseidon-over-bits hashing and by EdDSA signing. (64 Ki bits)
pub(crate) const MAX_BITS_LEN: usize = 1 << 16;
12
13pub mod babyjub;
14pub mod blake;
15pub mod eddsa;
16pub mod poseidon_hash;
17
18pub use eddsa::{sign_eddsa, verify_eddsa};
19pub use poseidon_hash::poseidon_hash_bits;
20
21fn write_error_json(output_json: *mut *mut c_char, message: &str) -> c_int {
22    let payload = serde_json::json!({"success": false, "error": message}).to_string();
23    unsafe {
24        match CString::new(payload) {
25            Ok(cstr) => {
26                *output_json = cstr.into_raw();
27                -1
28            }
29            Err(_) => {
30                let fallback = CString::new(r#"{"success":false,"error":"Unknown error"}"#)
31                    .expect("static JSON has no NUL");
32                *output_json = fallback.into_raw();
33                -1
34            }
35        }
36    }
37}
38
39/// FFI function to hash field elements using Poseidon
40///
41/// Input JSON format:
42/// {
43///   "inputs": ["123", "456", ...]  // Array of field element strings (BigInt as decimal string)
44/// }
45///
46/// Output JSON format:
47/// {
48///   "success": true,
49///   "result": "789..."  // Field element as decimal string
50/// }
51///
52/// On error, returns JSON with "success": false and "error": "..."
53///
54/// # Safety
55///
56/// - `input_json` must be a valid, non-null, NUL-terminated C string for the
57///   duration of the call.
58/// - `output_json` must be a valid, non-null, writable pointer for the
59///   duration of the call.
60/// - On success, `*output_json` is set to a newly allocated C string that
61///   **must** be freed with [`poseidon_free_string`].
62#[no_mangle]
63pub unsafe extern "C" fn poseidon_hash(input_json: *const c_char, output_json: *mut *mut c_char) -> c_int {
64    if input_json.is_null() || output_json.is_null() {
65        return -1;
66    }
67
68    let guarded = catch_unwind(AssertUnwindSafe(|| unsafe {
69        let input_str = match CStr::from_ptr(input_json).to_str() {
70            Ok(s) => s,
71            Err(_) => {
72                return write_error_json(output_json, "Invalid UTF-8 input");
73            }
74        };
75
76        if input_str.len() > MAX_JSON_BYTES {
77            return write_error_json(output_json, "Input JSON exceeds maximum allowed size");
78        }
79
80        match poseidon_hash_field_elements(input_str) {
81            Ok(json_str) => match CString::new(json_str) {
82                Ok(cstr) => {
83                    *output_json = cstr.into_raw();
84                    0
85                }
86                Err(_) => write_error_json(output_json, "Failed to create output string"),
87            },
88            Err(e) => write_error_json(output_json, &e),
89        }
90    }));
91    match guarded {
92        Ok(code) => code,
93        Err(_) => write_error_json(output_json, "Rust panic in poseidon_hash"),
94    }
95}
96
/// Releases a string produced by [`poseidon_hash`] or [`poseidon_hash_bits_ffi`].
///
/// A null `ptr` is accepted and ignored.
///
/// # Safety
///
/// `ptr` must be null or a pointer previously returned by [`poseidon_hash`] or
/// [`poseidon_hash_bits_ffi`] that has not already been released. Releasing the same
/// pointer twice is undefined behaviour.
#[no_mangle]
pub unsafe extern "C" fn poseidon_free_string(ptr: *mut c_char) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from `CString::into_raw` in this crate and is
    // reclaimed exactly once; dropping the rebuilt `CString` frees the allocation.
    drop(unsafe { CString::from_raw(ptr) });
}
111
112/// FFI function to hash bits using Poseidon (matching circomlibjs behavior)
113///
114/// Input JSON format:
115/// {
116///   "bits": [0, 1, 0, ...]  // Array of bits (0 or 1)
117/// }
118///
119/// Output JSON format:
120/// {
121///   "success": true,
122///   "result": "789..."  // Field element as decimal string
123/// }
124///
125/// On error, returns JSON with "success": false and "error": "..."
126///
127/// # Safety
128///
129/// - `input_json` must be a valid, non-null, NUL-terminated C string for the
130///   duration of the call.
131/// - `output_json` must be a valid, non-null, writable pointer for the
132///   duration of the call.
133/// - On success, `*output_json` is set to a newly allocated C string that
134///   **must** be freed with [`poseidon_free_string`].
135#[no_mangle]
136pub unsafe extern "C" fn poseidon_hash_bits_ffi(
137    input_json: *const c_char,
138    output_json: *mut *mut c_char,
139) -> c_int {
140    if input_json.is_null() || output_json.is_null() {
141        return -1;
142    }
143
144    let guarded = catch_unwind(AssertUnwindSafe(|| unsafe {
145        let input_str = match CStr::from_ptr(input_json).to_str() {
146            Ok(s) => s,
147            Err(_) => {
148                return write_error_json(output_json, "Invalid UTF-8 input");
149            }
150        };
151
152        if input_str.len() > MAX_JSON_BYTES {
153            return write_error_json(output_json, "Input JSON exceeds maximum allowed size");
154        }
155
156        match poseidon_hash_bits_from_json(input_str) {
157            Ok(json_str) => match CString::new(json_str) {
158                Ok(cstr) => {
159                    *output_json = cstr.into_raw();
160                    0
161                }
162                Err(_) => write_error_json(output_json, "Failed to create output string"),
163            },
164            Err(e) => write_error_json(output_json, &e),
165        }
166    }));
167    match guarded {
168        Ok(code) => code,
169        Err(_) => write_error_json(output_json, "Rust panic in poseidon_hash_bits_ffi"),
170    }
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
174struct PoseidonHashBitsRequest {
175    bits: Vec<u8>,
176}
177
178/// Hash bits using Poseidon (wrapper for poseidon_hash_bits)
179fn poseidon_hash_bits_from_json(input_json: &str) -> Result<String, String> {
180    let request: PoseidonHashBitsRequest = serde_json::from_str(input_json)
181        .map_err(|e| format!("Failed to parse input JSON: {}", e))?;
182
183    if request.bits.is_empty() {
184        return Err("Bits input must not be empty".to_string());
185    }
186    if request.bits.len() > MAX_BITS_LEN {
187        return Err(format!("Bits input exceeds maximum allowed length of {MAX_BITS_LEN}"));
188    }
189
190    let hash = poseidon_hash_bits(&request.bits)?;
191    let hash_str = hash.into_bigint().to_string();
192
193    let output = PoseidonHashResult {
194        success: true,
195        result: Some(hash_str),
196        error: None,
197    };
198
199    serde_json::to_string(&output).map_err(|e| format!("Failed to serialize output: {}", e))
200}
201
202use ark_bn254::Fr;
203use ark_ff::PrimeField;
204use light_poseidon::{Poseidon, PoseidonHasher};
205use serde::{Deserialize, Serialize};
206use std::str::FromStr;
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209struct PoseidonHashRequest {
210    inputs: Vec<String>,
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize)]
214struct PoseidonHashResult {
215    success: bool,
216    #[serde(skip_serializing_if = "Option::is_none")]
217    result: Option<String>,
218    #[serde(skip_serializing_if = "Option::is_none")]
219    error: Option<String>,
220}
221
222/// Hash an array of field elements using Poseidon
223/// This matches the behavior of circomlibjs's Poseidon hash
224fn poseidon_hash_field_elements(input_json: &str) -> Result<String, String> {
225    let request: PoseidonHashRequest = serde_json::from_str(input_json)
226        .map_err(|e| format!("Failed to parse input JSON: {}", e))?;
227
228    if request.inputs.is_empty() {
229        return Err("Input array cannot be empty".to_string());
230    }
231    if request.inputs.len() == 1 {
232        return Err(
233            "Single-input Poseidon hash is disabled to avoid [x] vs [x,0] domain collision"
234                .to_string(),
235        );
236    }
237    if request.inputs.len() > MAX_POSEIDON_FIELD_INPUTS {
238        return Err(format!("Input array exceeds maximum allowed length of {MAX_POSEIDON_FIELD_INPUTS}"));
239    }
240
241    // Convert string inputs to field elements
242    let mut field_elements = Vec::new();
243    for (idx, input_str) in request.inputs.iter().enumerate() {
244        let field_element = Fr::from_str(input_str)
245            .map_err(|_| format!("Failed to parse field element at index {idx}"))?;
246        field_elements.push(field_element);
247    }
248
249    // Create Poseidon instance with Circom parameters
250    let num_inputs = field_elements.len();
251    let mut poseidon = Poseidon::<Fr>::new_circom(num_inputs)
252        .map_err(|e| format!("Failed to create Poseidon instance: {:?}", e))?;
253
254    // Hash the field elements
255    let hash = poseidon
256        .hash(&field_elements)
257        .map_err(|e| format!("Poseidon hash failed: {:?}", e))?;
258
259    // Convert field element to string (decimal representation)
260    let hash_str = hash.into_bigint().to_string();
261
262    let output = PoseidonHashResult {
263        success: true,
264        result: Some(hash_str),
265        error: None,
266    };
267
268    serde_json::to_string(&output).map_err(|e| format!("Failed to serialize output: {}", e))
269}
270
271/// FFI function to sign data using EdDSA on Baby Jubjub
272///
273/// Input JSON format:
274/// {
275///   "bits": [0, 1, 0, ...],  // Array of bits (0 or 1)
276///   "privateKeyHex": "00010203..."  // 64 hex characters (32 bytes)
277/// }
278///
279/// Output JSON format:
280/// {
281///   "success": true,
282///   "result": {
283///     "Ax": "...",
284///     "Ay": "...",
285///     "R8x": "...",
286///     "R8y": "...",
287///     "S": "..."
288///   }
289/// }
290///
291/// On error, returns JSON with "success": false and "error": "..."
292///
293/// # Safety
294///
295/// - `input_json` must be a valid, non-null, NUL-terminated C string for the
296///   duration of the call.
297/// - `output_json` must be a valid, non-null, writable pointer for the
298///   duration of the call.
299/// - On success, `*output_json` is set to a newly allocated C string that
300///   **must** be freed with [`eddsa_free_string`].
301#[no_mangle]
302pub unsafe extern "C" fn eddsa_sign(input_json: *const c_char, output_json: *mut *mut c_char) -> c_int {
303    if input_json.is_null() || output_json.is_null() {
304        return -1;
305    }
306
307    let guarded = catch_unwind(AssertUnwindSafe(|| unsafe {
308        let input_json_owned = match CStr::from_ptr(input_json).to_str() {
309            Ok(s) => Zeroizing::new(s.to_owned()),
310            Err(_) => {
311                return write_error_json(output_json, "Invalid UTF-8 input");
312            }
313        };
314
315        if input_json_owned.len() > MAX_JSON_BYTES {
316            return write_error_json(output_json, "Input JSON exceeds maximum allowed size");
317        }
318
319        match sign_eddsa(&input_json_owned) {
320            Ok(json_str) => match CString::new(json_str) {
321                Ok(cstr) => {
322                    *output_json = cstr.into_raw();
323                    0
324                }
325                Err(_) => write_error_json(output_json, "Failed to create output string"),
326            },
327            Err(e) => write_error_json(output_json, &e),
328        }
329    }));
330    match guarded {
331        Ok(code) => code,
332        Err(_) => write_error_json(output_json, "Rust panic in eddsa_sign"),
333    }
334}
335
336/// FFI function to verify EdDSA signature over pre-hashed Poseidon digest.
337///
338/// # Safety
339///
340/// - `input_json` must be a valid, non-null, NUL-terminated C string for the
341///   duration of the call.
342/// - `output_json` must be a valid, non-null, writable pointer for the
343///   duration of the call.
344/// - On success, `*output_json` is set to a newly allocated C string that
345///   **must** be freed with [`eddsa_free_string`].
346#[no_mangle]
347pub unsafe extern "C" fn eddsa_verify(input_json: *const c_char, output_json: *mut *mut c_char) -> c_int {
348    if input_json.is_null() || output_json.is_null() {
349        return -1;
350    }
351
352    let guarded = catch_unwind(AssertUnwindSafe(|| unsafe {
353        let input_str = match CStr::from_ptr(input_json).to_str() {
354            Ok(s) => s,
355            Err(_) => {
356                return write_error_json(output_json, "Invalid UTF-8 input");
357            }
358        };
359
360        if input_str.len() > MAX_JSON_BYTES {
361            return write_error_json(output_json, "Input JSON exceeds maximum allowed size");
362        }
363
364        match verify_eddsa(input_str) {
365            Ok(json_str) => match CString::new(json_str) {
366                Ok(cstr) => {
367                    *output_json = cstr.into_raw();
368                    0
369                }
370                Err(_) => write_error_json(output_json, "Failed to create output string"),
371            },
372            Err(e) => write_error_json(output_json, &e),
373        }
374    }));
375    match guarded {
376        Ok(code) => code,
377        Err(_) => write_error_json(output_json, "Rust panic in eddsa_verify"),
378    }
379}
380
381/// Free memory allocated by eddsa_sign
382///
383/// # Safety
384///
385/// `ptr` must be a pointer previously returned by [`eddsa_sign`] or
386/// [`eddsa_verify`], or null. Must not be freed more than once.
387#[no_mangle]
388pub unsafe extern "C" fn eddsa_free_string(ptr: *mut c_char) {
389    if !ptr.is_null() {
390        unsafe {
391            let _ = CString::from_raw(ptr);
392        }
393    }
394}
395
396#[cfg(test)]
397mod tests {
398    use super::*;
399    use rand::RngCore;
400    use serde_json::Value;
401    use std::ptr;
402
403    fn random_private_key_hex() -> String {
404        let mut bytes = [0u8; 32];
405        rand::thread_rng().fill_bytes(&mut bytes);
406        hex::encode(bytes)
407    }
408
409    #[test]
410    fn test_eddsa_sign() {
411        let key = random_private_key_hex();
412        let input = format!(r#"{{
413            "operation": "sign",
414            "data": {{
415                "bits": [0, 1, 0, 1],
416                "privateKeyHex": "{key}"
417            }}
418        }}"#);
419
420        let input_cstr = CString::new(input).unwrap();
421        let mut output_ptr: *mut c_char = ptr::null_mut();
422
423        let result = unsafe { eddsa_sign(input_cstr.as_ptr(), &mut output_ptr) };
424
425        assert_eq!(result, 0);
426
427        if !output_ptr.is_null() {
428            let output_cstr = unsafe { CStr::from_ptr(output_ptr) };
429            let output_str = output_cstr.to_str().unwrap();
430            println!("Output: {}", output_str);
431
432            unsafe { eddsa_free_string(output_ptr) };
433        }
434    }
435
436    #[test]
437    fn test_eddsa_verify_returns_error_json_for_malformed_payload() {
438        let input = r#"{
439            "operation": "verify",
440            "data": {
441                "msgHash": "bad-field",
442                "publicKeyAx": "1",
443                "publicKeyAy": "2",
444                "R8x": "3",
445                "R8y": "4",
446                "S": "5"
447            }
448        }"#;
449
450        let input_cstr = CString::new(input).unwrap();
451        let mut output_ptr: *mut c_char = ptr::null_mut();
452
453        let result = unsafe { eddsa_verify(input_cstr.as_ptr(), &mut output_ptr) };
454        assert_eq!(result, -1);
455        assert!(!output_ptr.is_null());
456
457        let output_cstr = unsafe { CStr::from_ptr(output_ptr) };
458        let output_str = output_cstr.to_str().unwrap();
459        let parsed: Value = serde_json::from_str(output_str).unwrap();
460        assert_eq!(parsed.get("success").and_then(|v| v.as_bool()), Some(false));
461        assert!(parsed
462            .get("error")
463            .and_then(|v| v.as_str())
464            .unwrap_or_default()
465            .contains("Failed to parse msgHash as field element"));
466
467        unsafe { eddsa_free_string(output_ptr) };
468    }
469
470    #[test]
471    fn test_poseidon_hash_rejects_single_input() {
472        let input = r#"{"inputs":["7"]}"#;
473        let input_cstr = CString::new(input).unwrap();
474        let mut output_ptr: *mut c_char = ptr::null_mut();
475
476        let result = unsafe { poseidon_hash(input_cstr.as_ptr(), &mut output_ptr) };
477        assert_eq!(result, -1);
478        assert!(!output_ptr.is_null());
479
480        let output_cstr = unsafe { CStr::from_ptr(output_ptr) };
481        let output_str = output_cstr.to_str().unwrap();
482        let parsed: Value = serde_json::from_str(output_str).unwrap();
483        assert_eq!(parsed.get("success").and_then(|v| v.as_bool()), Some(false));
484        assert!(parsed
485            .get("error")
486            .and_then(|v| v.as_str())
487            .unwrap_or_default()
488            .contains("Single-input Poseidon hash is disabled"));
489        unsafe { poseidon_free_string(output_ptr) };
490    }
491
492    #[test]
493    fn test_poseidon_hash_bits_ffi_rejects_empty_bits() {
494        // Regression test for: empty-bits FFI returns a known constant Poseidon digest.
495        // An empty `bits` array must be rejected, not silently hashed as 248 zero-bits,
496        // because that would return a predictable value an attacker can compute.
497        let input = r#"{"bits":[]}"#;
498        let input_cstr = CString::new(input).unwrap();
499        let mut output_ptr: *mut c_char = ptr::null_mut();
500
501        let result = unsafe { poseidon_hash_bits_ffi(input_cstr.as_ptr(), &mut output_ptr) };
502        assert_eq!(result, -1, "empty bits must return an error code");
503        assert!(!output_ptr.is_null());
504
505        let output_cstr = unsafe { CStr::from_ptr(output_ptr) };
506        let output_str = output_cstr.to_str().unwrap();
507        let parsed: Value = serde_json::from_str(output_str).unwrap();
508        assert_eq!(parsed.get("success").and_then(|v| v.as_bool()), Some(false));
509        assert!(
510            parsed
511                .get("error")
512                .and_then(|v| v.as_str())
513                .unwrap_or_default()
514                .contains("must not be empty"),
515            "error message should mention empty input"
516        );
517        unsafe { poseidon_free_string(output_ptr) };
518    }
519
520    #[test]
521    fn test_poseidon_hash_bits_ffi_rejects_oversized_bits() {
522        // MAX_BITS_LEN + 1 bits must be rejected before Poseidon processing.
523        let bits: Vec<u8> = vec![0u8; MAX_BITS_LEN + 1];
524        let bits_json: Vec<String> = bits.iter().map(|b| b.to_string()).collect();
525        let input = format!("{{\"bits\":[{}]}}", bits_json.join(","));
526        let input_cstr = CString::new(input).unwrap();
527        let mut output_ptr: *mut c_char = ptr::null_mut();
528
529        let result = unsafe { poseidon_hash_bits_ffi(input_cstr.as_ptr(), &mut output_ptr) };
530        assert_eq!(result, -1, "oversized bits must return an error code");
531
532        let output_str = unsafe { CStr::from_ptr(output_ptr).to_str().unwrap() };
533        let parsed: Value = serde_json::from_str(output_str).unwrap();
534        assert_eq!(parsed.get("success").and_then(|v| v.as_bool()), Some(false));
535        assert!(
536            parsed.get("error").and_then(|v| v.as_str()).unwrap_or_default()
537                .contains("exceeds maximum allowed length"),
538            "error should mention length limit"
539        );
540        unsafe { poseidon_free_string(output_ptr) };
541    }
542
543    #[test]
544    fn test_poseidon_hash_rejects_too_many_field_inputs() {
545        // MAX_POSEIDON_FIELD_INPUTS + 1 elements must be rejected.
546        let inputs: Vec<String> = (0..=MAX_POSEIDON_FIELD_INPUTS).map(|i| format!("\"{}\"", i)).collect();
547        let input = format!("{{\"inputs\":[{}]}}", inputs.join(","));
548        let input_cstr = CString::new(input).unwrap();
549        let mut output_ptr: *mut c_char = ptr::null_mut();
550
551        let result = unsafe { poseidon_hash(input_cstr.as_ptr(), &mut output_ptr) };
552        assert_eq!(result, -1, "too many inputs must return an error code");
553
554        let output_str = unsafe { CStr::from_ptr(output_ptr).to_str().unwrap() };
555        let parsed: Value = serde_json::from_str(output_str).unwrap();
556        assert_eq!(parsed.get("success").and_then(|v| v.as_bool()), Some(false));
557        assert!(
558            parsed.get("error").and_then(|v| v.as_str()).unwrap_or_default()
559                .contains("exceeds maximum allowed length"),
560            "error should mention length limit"
561        );
562        unsafe { poseidon_free_string(output_ptr) };
563    }
564
565    #[test]
566    fn test_eddsa_sign_rejects_oversized_bits() {
567        // Signing with bits array > MAX_BITS_LEN must be rejected.
568        let bits: Vec<u8> = vec![0u8; MAX_BITS_LEN + 1];
569        let bits_json: Vec<String> = bits.iter().map(|b| b.to_string()).collect();
570        let key = random_private_key_hex();
571        let input = format!(
572            "{{\"operation\":\"sign\",\"data\":{{\"bits\":[{}],\"privateKeyHex\":\"{}\"}}}}",
573            bits_json.join(","),
574            key
575        );
576        let input_cstr = CString::new(input).unwrap();
577        let mut output_ptr: *mut c_char = ptr::null_mut();
578
579        let result = unsafe { eddsa_sign(input_cstr.as_ptr(), &mut output_ptr) };
580        assert_eq!(result, -1, "oversized bits must return an error code");
581
582        let output_str = unsafe { CStr::from_ptr(output_ptr).to_str().unwrap() };
583        let parsed: Value = serde_json::from_str(output_str).unwrap();
584        assert_eq!(parsed.get("success").and_then(|v| v.as_bool()), Some(false));
585        assert!(
586            parsed.get("error").and_then(|v| v.as_str()).unwrap_or_default()
587                .contains("exceeds maximum allowed length"),
588            "error should mention length limit"
589        );
590        unsafe { eddsa_free_string(output_ptr) };
591    }
592
593    #[test]
594    fn test_ffi_rejects_oversized_raw_json() {
595        // Any FFI call whose raw JSON exceeds MAX_JSON_BYTES must be rejected before parsing.
596        // Build a string that is definitely > 1 MiB but is not valid JSON, proving the check
597        // fires before serde ever touches it.
598        let large = "x".repeat(MAX_JSON_BYTES + 1);
599        let input_cstr = CString::new(large).unwrap();
600        let mut output_ptr: *mut c_char = ptr::null_mut();
601
602        let result = unsafe { poseidon_hash_bits_ffi(input_cstr.as_ptr(), &mut output_ptr) };
603        assert_eq!(result, -1, "oversized raw JSON must return an error code");
604
605        let output_str = unsafe { CStr::from_ptr(output_ptr).to_str().unwrap() };
606        let parsed: Value = serde_json::from_str(output_str).unwrap();
607        assert_eq!(parsed.get("success").and_then(|v| v.as_bool()), Some(false));
608        assert!(
609            parsed.get("error").and_then(|v| v.as_str()).unwrap_or_default()
610                .contains("exceeds maximum allowed size"),
611            "error should mention size limit"
612        );
613        unsafe { poseidon_free_string(output_ptr) };
614    }
615}