// cadapter/adapters_utils.rs

1extern crate ciphercore_utils;
2
3use ciphercore_base::custom_ops::CustomOperation;
4use ciphercore_base::data_types::{ScalarType, Type};
5use ciphercore_base::data_values::Value;
6use ciphercore_base::errors::Result;
7use ciphercore_base::graphs::Operation;
8use ciphercore_base::graphs::{Graph, Node, Slice, SliceElement};
9use ciphercore_base::runtime_error;
10use ciphercore_base::typed_value::TypedValue;
11use ciphercore_utils::errors::CiphercoreErrorBody;
12use ciphercore_utils::errors::CiphercoreErrorKind;
13use ciphercore_utils::errors::ErrorWithBody;
14
/// `CVecVal` stores the *values* of a vector's elements contiguously, for passing across
/// the C boundary.
///
/// `ptr` is the base address of `len` consecutive `T` values. When `len == 0` the pointer is
/// never read by `to_vec`, so it may be null or dangling in that case.
#[derive(Clone)]
#[repr(C)]
pub struct CVecVal<T> {
    /// Base address of the element buffer.
    pub ptr: *mut T,
    /// Number of elements in the buffer.
    pub len: usize,
}
22
23impl<T: Copy> CVecVal<T> {
24    pub(crate) fn to_vec(&self) -> Result<Vec<T>> {
25        let mut v = Vec::<T>::new();
26        unsafe {
27            if self.len == 0 {
28                return Ok(v);
29            }
30            if self.ptr.is_null() {
31                return Err(runtime_error!("Base pointer of vector is Null"));
32            }
33            for i in 0..self.len {
34                let x = &*self.ptr.add(i);
35                v.push(*x);
36            }
37        };
38        Ok(v)
39    }
40    pub(crate) fn from_vec(mut v: Vec<T>) -> CVecVal<T> {
41        v.shrink_to_fit();
42        let cvec = CVecVal::<T> {
43            ptr: v.as_mut_ptr(),
44            len: v.len(),
45        };
46        std::mem::forget(v);
47        cvec
48    }
49}
50
51fn cvec_val_destroy_helper<T>(cvec_ptr: *mut CVecVal<T>) -> CResultVal<bool> {
52    let helper = || -> Result<bool> {
53        unsafe {
54            let cvec = Box::from_raw(cvec_ptr);
55            drop(Vec::from_raw_parts(cvec.ptr, cvec.len, cvec.len));
56            drop(cvec);
57        }
58        Ok(true)
59    };
60    CResultVal::new(helper())
61}
62
63#[allow(dead_code)]
64fn vector_from_unsafe_cvec_val<T: Copy>(cvec_ptr: *mut CVecVal<T>) -> Result<Vec<T>> {
65    unsafe_deref(cvec_ptr)?.to_vec()
66}
67
68#[no_mangle]
69pub extern "C" fn cvec_u64_destroy(cvec_ptr: *mut CVecVal<u64>) -> CResultVal<bool> {
70    cvec_val_destroy_helper(cvec_ptr)
71}
72
73#[no_mangle]
74pub extern "C" fn cvec_cstr_destroy(cvec_ptr: *mut CVecVal<CStr>) -> CResultVal<bool> {
75    cvec_val_destroy_helper(cvec_ptr)
76}
77
/// `CVec` stores an array of *pointers* to a vector's elements, for passing across the C
/// boundary.
///
/// `ptr` is the base address of `len` consecutive `*mut T` slots. Each slot points to a
/// separately heap-allocated element.
#[derive(Clone)]
#[repr(C)]
pub struct CVec<T> {
    /// Base address of the pointer array.
    pub ptr: *mut *mut T,
    /// Number of element pointers in the array.
    pub len: usize,
}
85
86impl<T: Clone> CVec<T> {
87    pub(crate) fn to_vec(&self) -> Result<Vec<T>> {
88        let mut v = Vec::<T>::new();
89        unsafe {
90            if self.len == 0 {
91                return Ok(v);
92            }
93            if self.ptr.is_null() {
94                return Err(runtime_error!("Base pointer of vector is Null"));
95            }
96            for i in 0..self.len {
97                if self.ptr.add(i).is_null() {
98                    return Err(runtime_error!("Vector pointer is Null"));
99                }
100                let x = &**self.ptr.add(i);
101                v.push((*x).clone());
102            }
103        };
104        Ok(v)
105    }
106    pub(crate) fn from_vec(v: Vec<T>) -> CVec<T> {
107        let mut ptr_vec: Vec<*mut T> = v.into_iter().map(|x| unsafe_ref(x)).collect();
108        ptr_vec.shrink_to_fit();
109        let cvec = CVec {
110            ptr: ptr_vec.as_mut_ptr(),
111            len: ptr_vec.len(),
112        };
113        std::mem::forget(ptr_vec);
114        cvec
115    }
116}
117
118fn _vector_from_unsafe_cvec<T: Clone>(cvec_ptr: *mut CVec<T>) -> Result<Vec<T>> {
119    unsafe_deref(cvec_ptr)?.to_vec()
120}
121
122fn cvec_destroy_helper<T>(cvec_ptr: *mut CVec<T>) -> CResultVal<bool> {
123    let helper = || -> Result<bool> {
124        unsafe {
125            let cvec = Box::from_raw(cvec_ptr);
126            Vec::from_raw_parts(cvec.ptr, cvec.len, cvec.len);
127            drop(cvec);
128        }
129        Ok(true)
130    };
131    CResultVal::new(helper())
132}
133
134#[no_mangle]
135pub extern "C" fn cvec_type_destroy(cvec_ptr: *mut CVec<Type>) -> CResultVal<bool> {
136    cvec_destroy_helper(cvec_ptr)
137}
138#[no_mangle]
139pub extern "C" fn cvec_node_destroy(cvec_ptr: *mut CVec<Node>) -> CResultVal<bool> {
140    cvec_destroy_helper(cvec_ptr)
141}
142#[no_mangle]
143pub extern "C" fn cvec_graph_destroy(cvec_ptr: *mut CVec<Graph>) -> CResultVal<bool> {
144    cvec_destroy_helper(cvec_ptr)
145}
146#[no_mangle]
147pub extern "C" fn cvec_cslice_element_destroy(
148    cvec_ptr: *mut CVec<CSliceElement>,
149) -> CResultVal<bool> {
150    cvec_destroy_helper(cvec_ptr)
151}
152
/// Moves `x` onto the heap and returns an owning raw pointer to it.
///
/// Ownership passes to the caller. The allocation is released by rebuilding the `Box`
/// (for example via `destroy_helper`).
pub(crate) fn unsafe_ref<T>(x: T) -> *mut T {
    let boxed: Box<T> = Box::new(x);
    Box::into_raw(boxed)
}
156
/// Drops the heap value behind `ptr`, which must have come from `unsafe_ref`/`Box::into_raw`.
///
/// A null pointer is ignored. Rebuilding a `Box` from null is undefined behavior.
pub(crate) fn destroy_helper<T>(ptr: *mut T) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: non-null and, by contract, an unfreed allocation produced by `Box::into_raw`.
    // The explicit `drop` also satisfies `Box::from_raw`'s `#[must_use]`.
    drop(unsafe { Box::from_raw(ptr) });
}
162
163pub(crate) fn unsafe_deref<T: Clone>(ptr: *mut T) -> Result<T> {
164    unsafe {
165        if ptr.is_null() {
166            return Err(runtime_error!("Null pointer passed by C"));
167        }
168        let tmp_ptr = &*ptr;
169        Ok((*tmp_ptr).clone())
170    }
171}
172
173pub(crate) fn unsafe_deref_const<T: Clone>(ptr: *const T) -> Result<T> {
174    unsafe {
175        if ptr.is_null() {
176            return Err(runtime_error!("Null pointer passed by C"));
177        }
178        let tmp_ptr = &*ptr;
179        Ok((*tmp_ptr).clone())
180    }
181}
182#[repr(C)]
183#[derive(Clone, Copy)]
184pub struct CStr {
185    pub ptr: *const u8,
186}
187impl CStr {
188    pub(crate) fn to_str_slice(self) -> Result<&'static str> {
189        let cs = unsafe { std::ffi::CStr::from_ptr(self.ptr as *const i8) };
190        let str_slice = cs.to_str()?;
191        Ok(str_slice)
192    }
193    pub(crate) fn to_string(self) -> Result<String> {
194        let str_slice = self.to_str_slice()?;
195        Ok(str_slice.to_owned())
196    }
197    pub(crate) fn from_string(s: String) -> Result<CStr> {
198        let cs = std::ffi::CString::new(s)?;
199        let p = cs.as_ptr() as *const u8;
200        unsafe_ref(cs);
201        Ok(CStr { ptr: p })
202    }
203}
204
205#[no_mangle]
206pub extern "C" fn cstr_destroy(cstr: CStr) -> CResultVal<bool> {
207    let helper = || -> Result<bool> {
208        unsafe {
209            Box::from_raw(cstr.ptr as *mut u8);
210        }
211        Ok(true)
212    };
213    CResultVal::new(helper())
214}
215#[repr(C)]
216pub struct CiphercoreError {
217    pub kind: CiphercoreErrorKind,
218    pub msg: CStr,
219}
220
221impl CiphercoreError {
222    pub(crate) fn new(body: CiphercoreErrorBody) -> CiphercoreError {
223        let s = format!("{}", body);
224        let cs = std::ffi::CString::new(s).unwrap();
225        let p = cs.as_ptr() as *const u8;
226        std::mem::forget(cs);
227        CiphercoreError {
228            kind: (body.kind),
229            msg: (CStr { ptr: p }),
230        }
231    }
232}
233
234pub trait CResultTrait<T> {
235    fn new(res: Result<T>) -> Self;
236}
237
238// CResult returns a raw pointer in case of success and error message and type in case of error.
239#[repr(C)]
240pub enum CResult<T> {
241    Ok(*mut T),
242    Err(CiphercoreError),
243}
244
245impl<T> CResultTrait<T> for CResult<T> {
246    fn new(res: Result<T>) -> CResult<T> {
247        match res {
248            Ok(x) => CResult::Ok(unsafe_ref(x)),
249            Err(e) => CResult::Err(CiphercoreError::new(e.get_body())),
250        }
251    }
252}
253
254// CResultVal returns a value in case of success and error message and type in case of error.
255#[repr(C)]
256pub enum CResultVal<T> {
257    Ok(T),
258    Err(CiphercoreError),
259}
260
261impl<T> CResultTrait<T> for CResultVal<T> {
262    fn new(res: Result<T>) -> CResultVal<T> {
263        match res {
264            Ok(x) => CResultVal::Ok(x),
265            Err(e) => CResultVal::Err(CiphercoreError::new(e.get_body())),
266        }
267    }
268}
269
270#[repr(C)]
271#[derive(Clone)]
272pub struct CTypedValue {
273    json: CStr,
274}
275impl CTypedValue {
276    pub(crate) fn to_type_value(&self) -> Result<(Type, Value)> {
277        let op_str_slice = self.json.to_str_slice()?;
278        let tv = TypedValue::from_json(&json::parse(op_str_slice)?)?;
279        Ok((tv.t, tv.value))
280    }
281    pub(crate) fn from_type_and_value(t: Type, value: Value) -> Result<CTypedValue> {
282        let tv = TypedValue { value, t };
283        let jstr = CStr::from_string(tv.to_json()?.dump())?;
284        Ok(CTypedValue { json: jstr })
285    }
286}
287
288#[repr(C)]
289pub struct CCustomOperation {
290    json: CStr,
291}
292impl CCustomOperation {
293    pub(crate) fn to_custom_op(&self) -> Result<CustomOperation> {
294        let op_str_slice = self.json.to_str_slice()?;
295        Ok(serde_json::from_str::<CustomOperation>(op_str_slice)?)
296    }
297    pub(crate) fn from_custom_op(cop: CustomOperation) -> Result<CCustomOperation> {
298        Ok(CCustomOperation {
299            json: CStr::from_string(serde_json::to_string(&cop)?)?,
300        })
301    }
302}
303
/// C-compatible encoding of `Option<i64>`: `valid` tells whether `num` holds a value.
///
/// `num` is 0 whenever `valid` is false.
// The snake-case suffix is part of the C-facing name, so the Rust naming lint is silenced.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct COption_i64 {
    valid: bool,
    num: i64,
}

impl COption_i64 {
    /// Decodes into `Option<i64>`. `num` is ignored when `valid` is false.
    pub(crate) fn to_option(&self) -> Option<i64> {
        if self.valid {
            Some(self.num)
        } else {
            None
        }
    }

    /// Encodes an `Option<i64>`. `None` becomes `valid: false, num: 0`.
    pub(crate) fn from_option(op: Option<i64>) -> COption_i64 {
        match op {
            Some(num) => COption_i64 { valid: true, num },
            None => COption_i64 {
                valid: false,
                num: 0,
            },
        }
    }
}

/// Three optional integers, the payload of `CSliceElement::SubArray`.
// C-facing name, see `COption_i64`.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct COption_i64_triplet {
    op1: COption_i64,
    op2: COption_i64,
    op3: COption_i64,
}
338
339#[derive(Clone)]
340#[repr(C)]
341pub enum CSliceElement {
342    SingleIndex(i64),
343    SubArray(COption_i64_triplet),
344    Ellipsis,
345}
346impl CSliceElement {
347    pub(crate) fn to_slice_element(&self) -> SliceElement {
348        match self {
349            Self::SingleIndex(x) => SliceElement::SingleIndex(*x),
350            Self::SubArray(x) => {
351                SliceElement::SubArray(x.op1.to_option(), x.op2.to_option(), x.op3.to_option())
352            }
353            Self::Ellipsis => SliceElement::Ellipsis,
354        }
355    }
356    pub(crate) fn from_slice_element(se: SliceElement) -> CSliceElement {
357        match se {
358            SliceElement::SingleIndex(x) => Self::SingleIndex(x),
359            SliceElement::SubArray(x, y, z) => Self::SubArray(COption_i64_triplet {
360                op1: COption_i64::from_option(x),
361                op2: COption_i64::from_option(y),
362                op3: COption_i64::from_option(z),
363            }),
364            SliceElement::Ellipsis => Self::Ellipsis,
365        }
366    }
367}
368
369#[repr(C)]
370#[derive(Clone)]
371pub struct CSlice {
372    elements: CVec<CSliceElement>,
373}
374impl CSlice {
375    pub(crate) fn to_slice(&self) -> Result<Slice> {
376        let celem_vec = self.elements.to_vec()?;
377        let elem_vec = celem_vec.iter().map(|x| x.to_slice_element()).collect();
378        Ok(elem_vec)
379    }
380    pub(crate) fn from_slice(s: Slice) -> CSlice {
381        let celem_vec = s
382            .iter()
383            .map(|x| CSliceElement::from_slice_element((*x).clone()))
384            .collect();
385        let celem_cvec = CVec::from_vec(celem_vec);
386        CSlice {
387            elements: celem_cvec,
388        }
389    }
390}
391
392#[allow(clippy::not_unsafe_ptr_arg_deref)]
393#[no_mangle]
394pub extern "C" fn c_slice_destroy(cslice_ptr: *mut CSlice) {
395    unsafe {
396        let cslice_ref = Box::from_raw(cslice_ptr);
397        let elements = cslice_ref.elements;
398        let vec_elements = Vec::from_raw_parts(elements.ptr, elements.len, elements.len);
399        for elem in vec_elements {
400            Box::from_raw(elem);
401        }
402    }
403}
404
405#[repr(C)]
406#[derive(Clone)]
407pub struct U64TypePtrTuple {
408    iv: u64,
409    type_ptr: *mut Type,
410}
411
412#[repr(C)]
413pub struct CStrTypePtrTuple {
414    str: CStr,
415    type_ptr: *mut Type,
416}
417#[repr(C)]
418pub enum COperation {
419    Input(*mut Type),
420    Add,
421    Subtract,
422    Multiply,
423    Dot,
424    Matmul,
425    Truncate(u64),
426    Sum(*mut CVecVal<u64>),
427    PermuteAxes(*mut CVecVal<u64>),
428    Get(*mut CVecVal<u64>),
429    GetSlice(*mut CSlice),
430    Reshape(*mut Type),
431    NOP,
432    Random(*mut Type),
433    PRF(*mut U64TypePtrTuple),
434    Stack(*mut CVecVal<u64>),
435    Constant(*mut CTypedValue),
436    A2B,
437    B2A(*mut ScalarType),
438    CreateTuple,
439    CreateNamedTuple(*mut CVecVal<CStr>),
440    CreateVector(*mut Type),
441    TupleGet(u64),
442    NamedTupleGet(CStr),
443    VectorGet,
444    Zip,
445    Repeat(u64),
446    Call,
447    Iterate,
448    ArrayToVector,
449    VectorToArray,
450    Custom(CCustomOperation),
451}
452impl COperation {
453    pub(crate) fn _to_operation(&self) -> Result<Operation> {
454        let op = match self {
455            Self::Input(t_ptr) => Operation::Input(unsafe_deref(*t_ptr)?),
456            Self::Add => Operation::Add,
457            Self::Subtract => Operation::Subtract,
458            Self::Multiply => Operation::Multiply,
459            Self::Dot => Operation::Dot,
460            Self::Matmul => Operation::Matmul,
461            Self::Truncate(x) => Operation::Truncate(*x),
462            Self::Sum(cvec) => Operation::Sum(vector_from_unsafe_cvec_val(*cvec)?),
463            Self::PermuteAxes(cvec) => Operation::PermuteAxes(vector_from_unsafe_cvec_val(*cvec)?),
464            Self::Get(cvec) => Operation::Get(vector_from_unsafe_cvec_val(*cvec)?),
465            Self::GetSlice(cslice) => Operation::GetSlice(unsafe_deref(*cslice)?.to_slice()?),
466            Self::Reshape(t_ptr) => Operation::Reshape(unsafe_deref(*t_ptr)?),
467            Self::NOP => Operation::NOP,
468            Self::Random(t_ptr) => Operation::Random(unsafe_deref(*t_ptr)?),
469            Self::PRF(tuple) => {
470                let tuple_safe = unsafe_deref(*tuple)?;
471                Operation::PRF(tuple_safe.iv, unsafe_deref(tuple_safe.type_ptr)?)
472            }
473            Self::Stack(cvec) => Operation::Stack(vector_from_unsafe_cvec_val(*cvec)?),
474            Self::A2B => Operation::A2B,
475            Self::B2A(st_ptr) => Operation::B2A(unsafe_deref(*st_ptr)?),
476            Self::CreateTuple => Operation::CreateTuple,
477            Self::CreateNamedTuple(cvec_cstr) => {
478                let vec_cstr = vector_from_unsafe_cvec_val(*cvec_cstr)?;
479                let vec_str = vec_cstr
480                    .iter()
481                    .map(|x| -> Result<String> { x.to_string() })
482                    .collect::<Result<Vec<String>>>()?;
483                Operation::CreateNamedTuple(vec_str)
484            }
485            Self::CreateVector(t_ptr) => Operation::CreateVector(unsafe_deref(*t_ptr)?),
486            Self::TupleGet(x) => Operation::TupleGet(*x),
487            Self::NamedTupleGet(cstr) => Operation::NamedTupleGet(cstr.to_string()?),
488            Self::VectorGet => Operation::VectorGet,
489            Self::Zip => Operation::Zip,
490            Self::Repeat(n) => Operation::Repeat(*n),
491            Self::Call => Operation::Call,
492            Self::Iterate => Operation::Iterate,
493            Self::ArrayToVector => Operation::ArrayToVector,
494            Self::VectorToArray => Operation::VectorToArray,
495            Self::Custom(c_cust_op) => Operation::Custom((*c_cust_op).to_custom_op()?),
496            Self::Constant(c_val) => {
497                let c_val_safe = unsafe_deref(*c_val)?;
498                Operation::Constant(
499                    (c_val_safe).to_type_value()?.0,
500                    (c_val_safe).to_type_value()?.1,
501                )
502            }
503        };
504        Ok(op)
505    }
506    pub(crate) fn from_operation(op: Operation) -> Result<COperation> {
507        let cop = match op {
508            Operation::Input(t) => Self::Input(unsafe_ref(t)),
509            Operation::Add => Self::Add,
510            Operation::Subtract => Self::Subtract,
511            Operation::Multiply => Self::Multiply,
512            Operation::Dot => Self::Dot,
513            Operation::Matmul => Self::Matmul,
514            Operation::Truncate(x) => Self::Truncate(x),
515            Operation::Sum(vec) => Self::Sum(unsafe_ref(CVecVal::from_vec(vec))),
516            Operation::PermuteAxes(vec) => Self::PermuteAxes(unsafe_ref(CVecVal::from_vec(vec))),
517            Operation::Get(vec) => Self::Get(unsafe_ref(CVecVal::from_vec(vec))),
518            Operation::GetSlice(slice) => Self::GetSlice(unsafe_ref(CSlice::from_slice(slice))),
519            Operation::Reshape(t) => Self::Reshape(unsafe_ref(t)),
520            Operation::NOP => Self::NOP,
521            Operation::Random(t) => Self::Random(unsafe_ref(t)),
522            Operation::PRF(iv, t) => Self::PRF(unsafe_ref(U64TypePtrTuple {
523                iv,
524                type_ptr: unsafe_ref(t),
525            })),
526            Operation::Stack(vec) => Self::Stack(unsafe_ref(CVecVal::from_vec(vec))),
527            Operation::A2B => Self::A2B,
528            Operation::B2A(st) => Self::B2A(unsafe_ref(st)),
529            Operation::CreateTuple => Self::CreateTuple,
530            Operation::CreateNamedTuple(vec_str) => {
531                let vec_cstr = vec_str
532                    .iter()
533                    .map(|x| -> Result<CStr> { CStr::from_string((*x).clone()) })
534                    .collect::<Result<Vec<CStr>>>()?;
535                let cvec_cstr = CVecVal::from_vec(vec_cstr);
536                Self::CreateNamedTuple(unsafe_ref(cvec_cstr))
537            }
538            Operation::CreateVector(t) => Self::CreateVector(unsafe_ref(t)),
539            Operation::TupleGet(x) => Self::TupleGet(x),
540            Operation::NamedTupleGet(str) => Self::NamedTupleGet(CStr::from_string(str)?),
541            Operation::VectorGet => Self::VectorGet,
542            Operation::Zip => Self::Zip,
543            Operation::Repeat(n) => Self::Repeat(n),
544            Operation::Call => Self::Call,
545            Operation::Iterate => Self::Iterate,
546            Operation::ArrayToVector => Self::ArrayToVector,
547            Operation::VectorToArray => Self::VectorToArray,
548            Operation::Custom(cust_op) => Self::Custom(CCustomOperation::from_custom_op(cust_op)?),
549            Operation::Constant(t, v) => {
550                Self::Constant(unsafe_ref(CTypedValue::from_type_and_value(t, v)?))
551            }
552        };
553        Ok(cop)
554    }
555}
556
557#[no_mangle]
558pub extern "C" fn c_operation_destroy(cop_ptr: *mut COperation) {
559    destroy_helper(cop_ptr);
560}