1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
//
// @author Aurélien Nicolas <aurel@qed-it.com>
// @date 2019

use reading::Messages;
use std::error::Error;
use std::slice;
use writing::CircuitOwned;

#[allow(improper_ctypes)]
extern "C" {
    fn call_gadget(
        call_msg: *const u8,
        constraints_callback: extern fn(context_ptr: *mut Messages, message: *const u8) -> bool,
        constraints_context: *mut Messages,
        witness_callback: extern fn(context_ptr: *mut Messages, message: *const u8) -> bool,
        witness_context: *mut Messages,
        return_callback: extern fn(context_ptr: *mut Messages, message: *const u8) -> bool,
        return_context: *mut Messages,
    ) -> bool;
}

// Decode the 4-byte little-endian length prefix stored at `ptr`.
//
// Contract (not enforced by the signature): `ptr` must be non-null and point
// to at least 4 readable bytes.
fn read_size_prefix(ptr: *const u8) -> u32 {
    // SAFETY: relies on the caller upholding the contract stated above.
    let prefix = unsafe { slice::from_raw_parts(ptr, 4) };
    // Start from the most-significant (last) byte and shift each lower byte in.
    prefix
        .iter()
        .rev()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
}

// Bring arguments from C calls back into the type system.
//
// Returns the context together with a slice covering the whole message: the
// 4-byte size prefix followed by the number of bytes that prefix announces.
//
// Contract (not enforced by the signature): `context_ptr` must be a valid,
// exclusive pointer to a live `CTX`, and `response` must point to a
// size-prefixed message whose bytes stay readable for the caller-chosen
// lifetime `'a`.
fn from_c<'a, CTX>(
    context_ptr: *mut CTX,
    response: *const u8,
) -> (&'a mut CTX, &'a [u8]) {
    // SAFETY: relies on the contract above; presumably this is the context
    // pointer originally registered with `call_gadget`.
    let context = unsafe { &mut *context_ptr };

    // Widen to usize before adding the prefix bytes. In u32 this addition
    // could overflow for a prefix near u32::MAX (a debug panic, or a silently
    // truncated slice in release builds); on 64-bit targets it now cannot.
    let response_len = read_size_prefix(response) as usize + 4;
    // SAFETY: the message is assumed to contain `response_len` readable bytes.
    let buf = unsafe { slice::from_raw_parts(response, response_len) };

    (context, buf)
}

/// Collect the stream of any messages into the context.
///
/// Called from C with a `Messages` context pointer and a pointer to one
/// size-prefixed message. The message bytes (prefix included) are copied into
/// an owned buffer and passed to `Messages::push_message`.
///
/// Returns `true` if `push_message` succeeded, and `false` if it returned an
/// error or if the Rust side panicked.
extern "C" fn callback_c(
    context_ptr: *mut Messages,
    message_ptr: *const u8,
) -> bool {
    // A panic must not unwind into the C caller: unwinding across an
    // `extern "C"` boundary is undefined behavior (newer compilers abort
    // instead). Convert a panic into a failure return value.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let (context, buf) = from_c(context_ptr, message_ptr);
        context.push_message(buf.to_vec()).is_ok()
    }));
    outcome.unwrap_or(false)
}

/// Serialize `circuit`, pass it to the foreign `call_gadget`, and collect the
/// messages it reports back.
///
/// All three callback slots (constraints, witness, return) use `callback_c`
/// with the same `Messages` context. The context is created from
/// `circuit.free_variable_id`.
///
/// # Errors
///
/// Returns an error if `circuit.write` fails, or if `call_gadget` returns
/// `false`.
pub fn call_gadget_wrapper(circuit: &CircuitOwned) -> Result<Messages, Box<dyn Error>> {
    let mut message_buf = vec![];
    circuit.write(&mut message_buf)?;

    let mut context = Messages::new(circuit.free_variable_id);
    // Derive a single raw pointer and share it across all three callbacks.
    // Writing `&mut context` three times would create three separate mutable
    // reborrows, each invalidating the raw pointers derived from the previous one.
    let context_ptr: *mut Messages = &mut context;

    // SAFETY: `message_buf` and `context` outlive the call, and `context` is
    // not accessed from Rust until `call_gadget` returns.
    // NOTE(review): assumes the C side does not retain either pointer after
    // returning — confirm against the gadget library.
    let ok = unsafe {
        call_gadget(
            message_buf.as_ptr(),
            callback_c,
            context_ptr,
            callback_c,
            context_ptr,
            callback_c,
            context_ptr,
        )
    };

    if ok {
        Ok(context)
    } else {
        Err("call_gadget failed".into())
    }
}


/// End-to-end test against the C++ gadget (only built with the `cpp` feature):
/// runs one R1CS-generation call and one witness-generation call, then checks
/// the returned messages, constraints, assignments and free variable ids.
#[test]
#[cfg(feature = "cpp")]
fn test_cpp_gadget() {
    use writing::VariablesOwned;

    let mut call = CircuitOwned {
        connections: VariablesOwned {
            variable_ids: vec![100, 101], // Some input variables.
            values: None,
        },
        free_variable_id: 102,
        r1cs_generation: true,
        field_order: None,
    };

    println!("==== R1CS generation ====");

    let r1cs_response = call_gadget_wrapper(&call).unwrap();

    println!("R1CS: Rust received {} messages including {} gadget return.",
             r1cs_response.messages.len(),
             r1cs_response.circuits().len());

    // assert_eq! reports both values on failure, unlike assert!(a == b).
    assert_eq!(r1cs_response.messages.len(), 2);
    assert_eq!(r1cs_response.circuits().len(), 1);

    println!("R1CS: Got constraints:");
    for c in r1cs_response.iter_constraints() {
        println!("{:?} * {:?} = {:?}", c.a, c.b, c.c);
    }

    let free_variable_id_after = r1cs_response
        .last_circuit()
        .expect("R1CS response contains no circuit")
        .free_variable_id();
    println!("R1CS: Free variable id after the call: {}\n", free_variable_id_after);
    assert_eq!(free_variable_id_after, 102 + 1 + 2);


    println!("==== Witness generation ====");

    call.r1cs_generation = false;
    call.connections.values = Some(vec![4, 5, 6, 14, 15, 16 as u8]);

    let witness_response = call_gadget_wrapper(&call).unwrap();

    println!("Assignment: Rust received {} messages including {} gadget return.",
             witness_response.messages.len(),
             witness_response.circuits().len());

    assert_eq!(witness_response.messages.len(), 2);
    assert_eq!(witness_response.circuits().len(), 1);

    {
        let assignment: Vec<_> = witness_response.iter_assignment().collect();

        println!("Assignment: Got witness:");
        for var in assignment.iter() {
            println!("{} = {:?}", var.id, var.value);
        }

        assert_eq!(assignment.len(), 2);
        assert_eq!(assignment[0].value.len(), 3);
        assert_eq!(assignment[0].id, 103 + 0); // First gadget-allocated variable.
        assert_eq!(assignment[1].id, 103 + 1); // Second "
        assert_eq!(assignment[0].value, &[10, 11, 12]); // First element.
        assert_eq!(assignment[1].value, &[8, 7, 6]);    // Second element

        let free_variable_id_after2 = witness_response
            .last_circuit()
            .expect("witness response contains no circuit")
            .free_variable_id();
        println!("Assignment: Free variable id after the call: {}", free_variable_id_after2);
        assert_eq!(free_variable_id_after2, 102 + 1 + 2);
        // Both modes must end with the same free variable id.
        assert_eq!(free_variable_id_after2, free_variable_id_after);

        let out_vars = witness_response.connection_variables().unwrap();
        println!("{:?}", out_vars);
    }
    println!();
}