cadapter/
graph_adapter.rs

1extern crate ciphercore_base;
2
3use crate::adapters_utils::{
4    destroy_helper, unsafe_deref, CCustomOperation, COperation, CResult, CResultTrait, CResultVal,
5    CSlice, CStr, CTypedValue, CVec, CVecVal,
6};
7use ciphercore_base::data_types::{ScalarType, Type};
8use ciphercore_base::errors::Result;
9use ciphercore_base::graphs::{Context, Graph, Node};
10use ciphercore_base::runtime_error;
11
12fn graph_method_helper<T, F, R: CResultTrait<T>>(graph_ptr: *mut Graph, op: F) -> R
13where
14    F: FnOnce(Graph) -> Result<T>,
15{
16    let helper = || -> Result<T> {
17        let graph = unsafe_deref(graph_ptr)?;
18        op(graph)
19    };
20    R::new(helper())
21}
22
23fn graph_one_node_method_helper<T, F, R: CResultTrait<T>>(
24    graph_ptr: *mut Graph,
25    a_ptr: *mut Node,
26    op: F,
27) -> R
28where
29    F: FnOnce(Graph, Node) -> Result<T>,
30{
31    let helper = || -> Result<T> {
32        let graph = unsafe_deref(graph_ptr)?;
33        let a = unsafe_deref(a_ptr)?;
34        op(graph, a)
35    };
36    R::new(helper())
37}
38
39fn graph_two_nodes_method_helper<T, F, R: CResultTrait<T>>(
40    graph_ptr: *mut Graph,
41    a_ptr: *mut Node,
42    b_ptr: *mut Node,
43    op: F,
44) -> R
45where
46    F: FnOnce(Graph, Node, Node) -> Result<T>,
47{
48    let helper = || -> Result<T> {
49        let graph = unsafe_deref(graph_ptr)?;
50        let a = unsafe_deref(a_ptr)?;
51        let b = unsafe_deref(b_ptr)?;
52        op(graph, a, b)
53    };
54    R::new(helper())
55}
56
57#[no_mangle]
58pub extern "C" fn graph_input(graph_ptr: *mut Graph, type_ptr: *mut Type) -> CResult<Node> {
59    graph_method_helper(graph_ptr, |g| {
60        let t = unsafe_deref(type_ptr)?;
61        g.input(t)
62    })
63}
64
65#[no_mangle]
66pub extern "C" fn graph_add(
67    graph_ptr: *mut Graph,
68    a_ptr: *mut Node,
69    b_ptr: *mut Node,
70) -> CResult<Node> {
71    graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.add(a, b))
72}
73
74#[no_mangle]
75pub extern "C" fn graph_subtract(
76    graph_ptr: *mut Graph,
77    a_ptr: *mut Node,
78    b_ptr: *mut Node,
79) -> CResult<Node> {
80    graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.subtract(a, b))
81}
82
83#[no_mangle]
84pub extern "C" fn graph_multiply(
85    graph_ptr: *mut Graph,
86    a_ptr: *mut Node,
87    b_ptr: *mut Node,
88) -> CResult<Node> {
89    graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.multiply(a, b))
90}
91#[no_mangle]
92pub extern "C" fn graph_dot(
93    graph_ptr: *mut Graph,
94    a_ptr: *mut Node,
95    b_ptr: *mut Node,
96) -> CResult<Node> {
97    graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.dot(a, b))
98}
99
100#[no_mangle]
101pub extern "C" fn graph_matmul(
102    graph_ptr: *mut Graph,
103    a_ptr: *mut Node,
104    b_ptr: *mut Node,
105) -> CResult<Node> {
106    graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.matmul(a, b))
107}
108
109#[no_mangle]
110pub extern "C" fn graph_truncate(
111    graph_ptr: *mut Graph,
112    a_ptr: *mut Node,
113    scale: u64,
114) -> CResult<Node> {
115    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.truncate(a, scale))
116}
117
118#[no_mangle]
119pub extern "C" fn graph_sum(
120    graph_ptr: *mut Graph,
121    a_ptr: *mut Node,
122    axis: CVecVal<u64>,
123) -> CResult<Node> {
124    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
125        let axis_vec = axis.to_vec()?;
126        g.sum(a, axis_vec)
127    })
128}
129
130#[no_mangle]
131pub extern "C" fn graph_permute_axes(
132    graph_ptr: *mut Graph,
133    a_ptr: *mut Node,
134    axis: CVecVal<u64>,
135) -> CResult<Node> {
136    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
137        let axis_vec = axis.to_vec()?;
138        g.permute_axes(a, axis_vec)
139    })
140}
141
142#[no_mangle]
143pub extern "C" fn graph_get(
144    graph_ptr: *mut Graph,
145    a_ptr: *mut Node,
146    index: CVecVal<u64>,
147) -> CResult<Node> {
148    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
149        let index_vec = index.to_vec()?;
150        g.get(a, index_vec)
151    })
152}
153
154#[no_mangle]
155pub extern "C" fn graph_get_slice(
156    graph_ptr: *mut Graph,
157    a_ptr: *mut Node,
158    cslice: CSlice,
159) -> CResult<Node> {
160    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
161        let slice = cslice.to_slice()?;
162        g.get_slice(a, slice)
163    })
164}
165
166#[no_mangle]
167pub extern "C" fn graph_reshape(
168    graph_ptr: *mut Graph,
169    a_ptr: *mut Node,
170    new_type_ptr: *mut Type,
171) -> CResult<Node> {
172    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
173        let new_t = unsafe_deref(new_type_ptr)?;
174        g.reshape(a, new_t)
175    })
176}
177
178#[no_mangle]
179pub extern "C" fn graph_random(graph_ptr: *mut Graph, output_type_ptr: *mut Type) -> CResult<Node> {
180    graph_method_helper(graph_ptr, |g| {
181        let output_t = unsafe_deref(output_type_ptr)?;
182        g.random(output_t)
183    })
184}
185
186#[no_mangle]
187pub extern "C" fn graph_stack(
188    graph_ptr: *mut Graph,
189    nodes: CVec<Node>,
190    outer_shape: CVecVal<u64>,
191) -> CResult<Node> {
192    graph_method_helper(graph_ptr, |g| {
193        let nodes_vec = nodes.to_vec()?;
194        let outer_shape_vec = outer_shape.to_vec()?;
195        g.stack(nodes_vec, outer_shape_vec)
196    })
197}
198#[no_mangle]
199pub extern "C" fn graph_constant(graph_ptr: *mut Graph, typed_value: CTypedValue) -> CResult<Node> {
200    graph_method_helper(graph_ptr, |g| {
201        let (t, v) = typed_value.to_type_value()?;
202        g.constant(t, v)
203    })
204}
205
206#[no_mangle]
207pub extern "C" fn graph_a2b(graph_ptr: *mut Graph, a_ptr: *mut Node) -> CResult<Node> {
208    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.a2b(a))
209}
210
211#[no_mangle]
212pub extern "C" fn graph_b2a(
213    graph_ptr: *mut Graph,
214    a_ptr: *mut Node,
215    scalar_type_ptr: *mut ScalarType,
216) -> CResult<Node> {
217    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
218        let st = unsafe_deref(scalar_type_ptr)?;
219        g.b2a(a, st)
220    })
221}
222#[no_mangle]
223pub extern "C" fn graph_create_tuple(graph_ptr: *mut Graph, elements: CVec<Node>) -> CResult<Node> {
224    graph_method_helper(graph_ptr, |g| {
225        let elem = elements.to_vec()?;
226        g.create_tuple(elem)
227    })
228}
229
230#[no_mangle]
231pub extern "C" fn graph_create_vector(
232    graph_ptr: *mut Graph,
233    type_ptr: *mut Type,
234    elements: CVec<Node>,
235) -> CResult<Node> {
236    graph_method_helper(graph_ptr, |g| {
237        let t = unsafe_deref(type_ptr)?;
238        let elem = elements.to_vec()?;
239        g.create_vector(t, elem)
240    })
241}
242
243#[no_mangle]
244pub extern "C" fn graph_create_named_tuple(
245    graph_ptr: *mut Graph,
246    elements_nodes: CVec<Node>,
247    elements_names: CVecVal<CStr>,
248) -> CResult<Node> {
249    graph_method_helper(graph_ptr, |g| {
250        let elem_nodes = elements_nodes.to_vec()?;
251        let elem_names = elements_names.to_vec()?;
252        let elem_names_string: Vec<String> = elem_names
253            .iter()
254            .map(|x| -> Result<String> { x.to_string() })
255            .collect::<Result<Vec<String>>>()?;
256        let elem: Vec<(String, Node)> = elem_names_string
257            .iter()
258            .zip(elem_nodes.iter())
259            .map(|(x, y)| ((*x).clone(), (*y).clone()))
260            .collect();
261        g.create_named_tuple(elem)
262    })
263}
264#[no_mangle]
265pub extern "C" fn graph_tuple_get(
266    graph_ptr: *mut Graph,
267    tuple_node_ptr: *mut Node,
268    index: u64,
269) -> CResult<Node> {
270    graph_one_node_method_helper(graph_ptr, tuple_node_ptr, |g, t| g.tuple_get(t, index))
271}
272
273#[no_mangle]
274pub extern "C" fn graph_named_tuple_get(
275    graph_ptr: *mut Graph,
276    tuple_node_ptr: *mut Node,
277    key: CStr,
278) -> CResult<Node> {
279    graph_one_node_method_helper(graph_ptr, tuple_node_ptr, |g, t| {
280        let key_string = key.to_string()?;
281        g.named_tuple_get(t, key_string)
282    })
283}
284
285#[no_mangle]
286pub extern "C" fn graph_vector_get(
287    graph_ptr: *mut Graph,
288    vec_node_ptr: *mut Node,
289    index_node_ptr: *mut Node,
290) -> CResult<Node> {
291    graph_two_nodes_method_helper(graph_ptr, vec_node_ptr, index_node_ptr, |g, v, i| {
292        g.vector_get(v, i)
293    })
294}
295
296#[no_mangle]
297pub extern "C" fn graph_zip(graph_ptr: *mut Graph, elements: CVec<Node>) -> CResult<Node> {
298    graph_method_helper(graph_ptr, |g| g.zip(elements.to_vec()?))
299}
300
301#[no_mangle]
302pub extern "C" fn graph_repeat(graph_ptr: *mut Graph, a_ptr: *mut Node, n: u64) -> CResult<Node> {
303    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.repeat(a, n))
304}
305
306#[no_mangle]
307pub extern "C" fn graph_call(
308    graph_ptr: *mut Graph,
309    callee_ptr: *mut Graph,
310    arguments: CVec<Node>,
311) -> CResult<Node> {
312    graph_method_helper(graph_ptr, |g| {
313        let callee = unsafe_deref(callee_ptr)?;
314        g.call(callee, arguments.to_vec()?)
315    })
316}
317
318#[no_mangle]
319pub extern "C" fn graph_iterate(
320    graph_ptr: *mut Graph,
321    callee_ptr: *mut Graph,
322    state_ptr: *mut Node,
323    input_ptr: *mut Node,
324) -> CResult<Node> {
325    graph_method_helper(graph_ptr, |g| {
326        let callee = unsafe_deref(callee_ptr)?;
327        let state = unsafe_deref(state_ptr)?;
328        let input = unsafe_deref(input_ptr)?;
329        g.iterate(callee, state, input)
330    })
331}
332
333#[no_mangle]
334pub extern "C" fn graph_vector_to_array(graph_ptr: *mut Graph, a_ptr: *mut Node) -> CResult<Node> {
335    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.vector_to_array(a))
336}
337
338#[no_mangle]
339pub extern "C" fn graph_array_to_vector(graph_ptr: *mut Graph, a_ptr: *mut Node) -> CResult<Node> {
340    graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.array_to_vector(a))
341}
342
343#[no_mangle]
344pub extern "C" fn graph_custom_op(
345    graph_ptr: *mut Graph,
346    c_custom_op: CCustomOperation,
347    args: CVec<Node>,
348) -> CResult<Node> {
349    graph_method_helper(graph_ptr, |g| {
350        g.custom_op(c_custom_op.to_custom_op()?, args.to_vec()?)
351    })
352}
353
354#[no_mangle]
355pub extern "C" fn graph_finalize(graph_ptr: *mut Graph) -> CResult<Graph> {
356    graph_method_helper(graph_ptr, |g| g.finalize())
357}
358
359#[no_mangle]
360pub extern "C" fn graph_get_nodes(graph_ptr: *mut Graph) -> CResult<CVec<Node>> {
361    graph_method_helper(graph_ptr, |g| Ok(CVec::from_vec(g.get_nodes())))
362}
363
364#[no_mangle]
365pub extern "C" fn graph_set_output_node(
366    graph_ptr: *mut Graph,
367    n_ptr: *mut Node,
368) -> CResultVal<bool> {
369    graph_one_node_method_helper(graph_ptr, n_ptr, |g, n| {
370        let res = g.set_output_node(n);
371        match res {
372            Ok(_) => Ok(true),
373            Err(e) => Err(e),
374        }
375    })
376}
377
378#[no_mangle]
379pub extern "C" fn graph_get_output_node(graph_ptr: *mut Graph) -> CResult<Node> {
380    graph_method_helper(graph_ptr, |g| g.get_output_node())
381}
382
383#[no_mangle]
384pub extern "C" fn graph_get_id(graph_ptr: *mut Graph) -> CResultVal<u64> {
385    graph_method_helper(graph_ptr, |g| Ok(g.get_id()))
386}
387
388#[no_mangle]
389pub extern "C" fn graph_get_num_nodes(graph_ptr: *mut Graph) -> CResultVal<u64> {
390    graph_method_helper(graph_ptr, |g| Ok(g.get_num_nodes()))
391}
392
393#[no_mangle]
394pub extern "C" fn graph_get_node_by_id(graph_ptr: *mut Graph, id: u64) -> CResult<Node> {
395    graph_method_helper(graph_ptr, |g| g.get_node_by_id(id))
396}
397
398#[no_mangle]
399pub extern "C" fn graph_get_context(graph_ptr: *mut Graph) -> CResult<Context> {
400    graph_method_helper(graph_ptr, |g| Ok(g.get_context()))
401}
402
403#[no_mangle]
404pub extern "C" fn graph_set_as_main(graph_ptr: *mut Graph) -> CResult<Graph> {
405    graph_method_helper(graph_ptr, |g| g.set_as_main())
406}
407
408#[no_mangle]
409pub extern "C" fn graph_set_name(graph_ptr: *mut Graph, name: CStr) -> CResult<Graph> {
410    graph_method_helper(graph_ptr, |g| g.set_name(name.to_str_slice()?))
411}
412
413#[no_mangle]
414pub extern "C" fn graph_get_name(graph_ptr: *mut Graph) -> CResultVal<CStr> {
415    graph_method_helper(graph_ptr, |g| CStr::from_string(g.get_name()?))
416}
417
418#[no_mangle]
419pub extern "C" fn graph_retrieve_node(graph_ptr: *mut Graph, name: CStr) -> CResult<Node> {
420    graph_method_helper(graph_ptr, |g| g.retrieve_node(name.to_str_slice()?))
421}
422
423#[no_mangle]
424pub extern "C" fn create_context() -> CResult<Context> {
425    let context_res = ciphercore_base::graphs::create_context();
426    CResult::new(context_res)
427}
428
429fn context_method_helper<T, F, R: CResultTrait<T>>(context_ptr: *mut Context, op: F) -> R
430where
431    F: FnOnce(Context) -> Result<T>,
432{
433    let helper = || -> Result<T> {
434        let context = unsafe_deref(context_ptr)?;
435        op(context)
436    };
437    R::new(helper())
438}
439
440#[no_mangle]
441pub extern "C" fn context_create_graph(context_ptr: *mut Context) -> CResult<Graph> {
442    context_method_helper(context_ptr, |c| c.create_graph())
443}
444
445#[no_mangle]
446pub extern "C" fn context_finalize(context_ptr: *mut Context) -> CResult<Context> {
447    context_method_helper(context_ptr, |c| c.finalize())
448}
449
450#[no_mangle]
451pub extern "C" fn context_set_main_graph(
452    context_ptr: *mut Context,
453    graph_ptr: *mut Graph,
454) -> CResult<Context> {
455    context_method_helper(context_ptr, |c| {
456        let graph = unsafe_deref(graph_ptr)?;
457        c.set_main_graph(graph)
458    })
459}
460
461#[no_mangle]
462pub extern "C" fn context_get_graphs(context_ptr: *mut Context) -> CResult<CVec<Graph>> {
463    context_method_helper(context_ptr, |c| Ok(CVec::from_vec(c.get_graphs())))
464}
465
466#[no_mangle]
467pub extern "C" fn context_check_finalized(context_ptr: *mut Context) -> CResultVal<bool> {
468    context_method_helper(context_ptr, |c| {
469        let res = c.check_finalized();
470        match res {
471            Ok(_) => Ok(true),
472            Err(e) => Err(e),
473        }
474    })
475}
476
477#[no_mangle]
478pub extern "C" fn context_get_main_graph(context_ptr: *mut Context) -> CResult<Graph> {
479    context_method_helper(context_ptr, |c| c.get_main_graph())
480}
481
482#[no_mangle]
483pub extern "C" fn context_get_num_graphs(context_ptr: *mut Context) -> CResultVal<u64> {
484    context_method_helper(context_ptr, |c| Ok(c.get_num_graphs()))
485}
486
487#[no_mangle]
488pub extern "C" fn context_get_graph_by_id(context_ptr: *mut Context, id: u64) -> CResult<Graph> {
489    context_method_helper(context_ptr, |c| c.get_graph_by_id(id))
490}
491
492#[no_mangle]
493pub extern "C" fn context_get_node_by_global_id(
494    context_ptr: *mut Context,
495    global_id: CVecVal<u64>,
496) -> CResult<Node> {
497    context_method_helper(context_ptr, |c| {
498        let g_id_vec = global_id.to_vec()?;
499        if g_id_vec.len() != 2 {
500            return Err(runtime_error!("Global Id vector should have two elements!"));
501        }
502        c.get_node_by_global_id((g_id_vec[0], g_id_vec[1]))
503    })
504}
505
506#[no_mangle]
507pub extern "C" fn context_to_string(context_ptr: *mut Context) -> CResultVal<CStr> {
508    context_method_helper(context_ptr, |c| {
509        CStr::from_string(serde_json::to_string(&c)?)
510    })
511}
512
513#[no_mangle]
514pub extern "C" fn contexts_deep_equal(
515    context1_ptr: *mut Context,
516    context2_ptr: *mut Context,
517) -> CResultVal<bool> {
518    context_method_helper(context1_ptr, |c| {
519        let c2 = unsafe_deref(context2_ptr)?;
520        Ok(ciphercore_base::graphs::contexts_deep_equal(c, c2))
521    })
522}
523
524#[no_mangle]
525pub extern "C" fn context_set_graph_name(
526    context_ptr: *mut Context,
527    graph_ptr: *mut Graph,
528    name: CStr,
529) -> CResult<Context> {
530    context_method_helper(context_ptr, |c| {
531        let graph = unsafe_deref(graph_ptr)?;
532        c.set_graph_name(graph, name.to_str_slice()?)
533    })
534}
535
536#[no_mangle]
537pub extern "C" fn context_get_graph_name(
538    context_ptr: *mut Context,
539    graph_ptr: *mut Graph,
540) -> CResultVal<CStr> {
541    context_method_helper(context_ptr, |c| {
542        let graph = unsafe_deref(graph_ptr)?;
543        CStr::from_string(c.get_graph_name(graph)?)
544    })
545}
546#[no_mangle]
547pub extern "C" fn context_retrieve_graph(context_ptr: *mut Context, name: CStr) -> CResult<Graph> {
548    context_method_helper(context_ptr, |c| c.retrieve_graph(name.to_str_slice()?))
549}
550
551#[no_mangle]
552pub extern "C" fn context_set_node_name(
553    context_ptr: *mut Context,
554    node_ptr: *mut Node,
555    name: CStr,
556) -> CResult<Context> {
557    context_method_helper(context_ptr, |c| {
558        let node = unsafe_deref(node_ptr)?;
559        c.set_node_name(node, name.to_str_slice()?)
560    })
561}
562
563#[no_mangle]
564pub extern "C" fn context_get_node_name(
565    context_ptr: *mut Context,
566    node_ptr: *mut Node,
567) -> CResultVal<CStr> {
568    context_method_helper(context_ptr, |c| {
569        let node = unsafe_deref(node_ptr)?;
570        CStr::from_string(c.get_node_name(node)?)
571    })
572}
573
574#[no_mangle]
575pub extern "C" fn context_retrieve_node(
576    context_ptr: *mut Context,
577    graph_ptr: *mut Graph,
578    name: CStr,
579) -> CResult<Node> {
580    context_method_helper(context_ptr, |c| {
581        let graph = unsafe_deref(graph_ptr)?;
582        c.retrieve_node(graph, name.to_str_slice()?)
583    })
584}
585
586#[no_mangle]
587pub extern "C" fn context_destroy(context_ptr: *mut Context) {
588    destroy_helper(context_ptr);
589}
590#[no_mangle]
591pub extern "C" fn graph_destroy(graph_ptr: *mut Graph) {
592    destroy_helper(graph_ptr);
593}
594
595#[no_mangle]
596pub extern "C" fn node_destroy(node_ptr: *mut Node) {
597    destroy_helper(node_ptr);
598}
599
600fn node_method_helper<T, F, R: CResultTrait<T>>(node_ptr: *mut Node, op: F) -> R
601where
602    F: FnOnce(Node) -> Result<T>,
603{
604    let helper = || -> Result<T> {
605        let node = unsafe_deref(node_ptr)?;
606        op(node)
607    };
608    R::new(helper())
609}
610fn node_one_node_method_helper<T, F, R: CResultTrait<T>>(
611    node_ptr: *mut Node,
612    b_ptr: *mut Node,
613    op: F,
614) -> R
615where
616    F: FnOnce(Node, Node) -> Result<T>,
617{
618    let helper = || -> Result<T> {
619        let node = unsafe_deref(node_ptr)?;
620        let b = unsafe_deref(b_ptr)?;
621        op(node, b)
622    };
623    R::new(helper())
624}
625#[no_mangle]
626pub extern "C" fn node_get_graph(node_ptr: *mut Node) -> CResult<Graph> {
627    node_method_helper(node_ptr, |n| Ok(n.get_graph()))
628}
629
630#[no_mangle]
631pub extern "C" fn node_get_dependencies(node_ptr: *mut Node) -> CResult<CVec<Node>> {
632    node_method_helper(node_ptr, |n| Ok(CVec::from_vec(n.get_node_dependencies())))
633}
634
635#[no_mangle]
636pub extern "C" fn node_graph_dependencies(node_ptr: *mut Node) -> CResult<CVec<Graph>> {
637    node_method_helper(node_ptr, |n| Ok(CVec::from_vec(n.get_graph_dependencies())))
638}
639#[no_mangle]
640pub extern "C" fn node_get_operation(node_ptr: *mut Node) -> CResult<COperation> {
641    node_method_helper(node_ptr, |n| COperation::from_operation(n.get_operation()))
642}
643
644#[no_mangle]
645pub extern "C" fn node_get_id(node_ptr: *mut Node) -> CResultVal<u64> {
646    node_method_helper(node_ptr, |n| Ok(n.get_id()))
647}
648
649#[no_mangle]
650pub extern "C" fn node_get_global_id(node_ptr: *mut Node) -> CResult<CVecVal<u64>> {
651    node_method_helper(node_ptr, |n| {
652        let ids = n.get_global_id();
653        let ids_vec = vec![ids.0, ids.1];
654        Ok(CVecVal::from_vec(ids_vec))
655    })
656}
657
658#[no_mangle]
659pub extern "C" fn node_get_type(node_ptr: *mut Node) -> CResult<Type> {
660    node_method_helper(node_ptr, |n| n.get_type())
661}
662
663#[no_mangle]
664pub extern "C" fn node_add(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
665    node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.add(b))
666}
667
668#[no_mangle]
669pub extern "C" fn node_subtract(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
670    node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.subtract(b))
671}
672
673#[no_mangle]
674pub extern "C" fn node_multiply(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
675    node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.multiply(b))
676}
677#[no_mangle]
678pub extern "C" fn node_dot(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
679    node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.dot(b))
680}
681
682#[no_mangle]
683pub extern "C" fn node_matmul(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
684    node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.matmul(b))
685}
686
687#[no_mangle]
688pub extern "C" fn node_truncate(node_ptr: *mut Node, scale: u64) -> CResult<Node> {
689    node_method_helper(node_ptr, |a| a.truncate(scale))
690}
691
692#[no_mangle]
693pub extern "C" fn node_sum(node_ptr: *mut Node, axis: CVecVal<u64>) -> CResult<Node> {
694    node_method_helper(node_ptr, |a| a.sum(axis.to_vec()?))
695}
696#[no_mangle]
697pub extern "C" fn node_permute_axes(node_ptr: *mut Node, axis: CVecVal<u64>) -> CResult<Node> {
698    node_method_helper(node_ptr, |a| a.permute_axes(axis.to_vec()?))
699}
700
701#[no_mangle]
702pub extern "C" fn node_get(node_ptr: *mut Node, index: CVecVal<u64>) -> CResult<Node> {
703    node_method_helper(node_ptr, |a| a.get(index.to_vec()?))
704}
705
706#[no_mangle]
707pub extern "C" fn node_get_slice(node_ptr: *mut Node, cslice: CSlice) -> CResult<Node> {
708    node_method_helper(node_ptr, |a| a.get_slice(cslice.to_slice()?))
709}
710
711#[no_mangle]
712pub extern "C" fn node_reshape(node_ptr: *mut Node, type_ptr: *mut Type) -> CResult<Node> {
713    node_method_helper(node_ptr, |a| {
714        let t = unsafe_deref(type_ptr)?;
715        a.reshape(t)
716    })
717}
718
719#[no_mangle]
720pub extern "C" fn node_nop(node_ptr: *mut Node) -> CResult<Node> {
721    node_method_helper(node_ptr, |a| a.nop())
722}
723
724#[no_mangle]
725pub extern "C" fn node_prf(
726    node_ptr: *mut Node,
727    iv: u64,
728    output_type_ptr: *mut Type,
729) -> CResult<Node> {
730    node_method_helper(node_ptr, |a| {
731        let output_type = unsafe_deref(output_type_ptr)?;
732        a.prf(iv, output_type)
733    })
734}
735
736#[no_mangle]
737pub extern "C" fn node_a2b(node_ptr: *mut Node) -> CResult<Node> {
738    node_method_helper(node_ptr, |a| a.a2b())
739}
740
741#[no_mangle]
742pub extern "C" fn node_b2a(node_ptr: *mut Node, scalar_type_ptr: *mut ScalarType) -> CResult<Node> {
743    node_method_helper(node_ptr, |a| {
744        let st = unsafe_deref(scalar_type_ptr)?;
745        a.b2a(st)
746    })
747}
748
749#[no_mangle]
750pub extern "C" fn node_tuple_get(node_ptr: *mut Node, index: u64) -> CResult<Node> {
751    node_method_helper(node_ptr, |a| a.tuple_get(index))
752}
753
754#[no_mangle]
755pub extern "C" fn node_named_tuple_get(node_ptr: *mut Node, key: CStr) -> CResult<Node> {
756    node_method_helper(node_ptr, |a| a.named_tuple_get(key.to_string()?))
757}
758
759#[no_mangle]
760pub extern "C" fn node_vector_get(node_ptr: *mut Node, index_node_ptr: *mut Node) -> CResult<Node> {
761    node_one_node_method_helper(node_ptr, index_node_ptr, |a, index| a.vector_get(index))
762}
763
764#[no_mangle]
765pub extern "C" fn node_array_to_vector(node_ptr: *mut Node) -> CResult<Node> {
766    node_method_helper(node_ptr, |a| a.array_to_vector())
767}
768
769#[no_mangle]
770pub extern "C" fn node_vector_to_array(node_ptr: *mut Node) -> CResult<Node> {
771    node_method_helper(node_ptr, |a| a.vector_to_array())
772}
773
774#[no_mangle]
775pub extern "C" fn node_repeat(node_ptr: *mut Node, n: u64) -> CResult<Node> {
776    node_method_helper(node_ptr, |a| a.repeat(n))
777}
778
779#[no_mangle]
780pub extern "C" fn node_set_as_output(node_ptr: *mut Node) -> CResult<Node> {
781    node_method_helper(node_ptr, |a| a.set_as_output())
782}