1extern crate ciphercore_base;
2
3use crate::adapters_utils::{
4 destroy_helper, unsafe_deref, CCustomOperation, COperation, CResult, CResultTrait, CResultVal,
5 CSlice, CStr, CTypedValue, CVec, CVecVal,
6};
7use ciphercore_base::data_types::{ScalarType, Type};
8use ciphercore_base::errors::Result;
9use ciphercore_base::graphs::{Context, Graph, Node};
10use ciphercore_base::runtime_error;
11
12fn graph_method_helper<T, F, R: CResultTrait<T>>(graph_ptr: *mut Graph, op: F) -> R
13where
14 F: FnOnce(Graph) -> Result<T>,
15{
16 let helper = || -> Result<T> {
17 let graph = unsafe_deref(graph_ptr)?;
18 op(graph)
19 };
20 R::new(helper())
21}
22
23fn graph_one_node_method_helper<T, F, R: CResultTrait<T>>(
24 graph_ptr: *mut Graph,
25 a_ptr: *mut Node,
26 op: F,
27) -> R
28where
29 F: FnOnce(Graph, Node) -> Result<T>,
30{
31 let helper = || -> Result<T> {
32 let graph = unsafe_deref(graph_ptr)?;
33 let a = unsafe_deref(a_ptr)?;
34 op(graph, a)
35 };
36 R::new(helper())
37}
38
39fn graph_two_nodes_method_helper<T, F, R: CResultTrait<T>>(
40 graph_ptr: *mut Graph,
41 a_ptr: *mut Node,
42 b_ptr: *mut Node,
43 op: F,
44) -> R
45where
46 F: FnOnce(Graph, Node, Node) -> Result<T>,
47{
48 let helper = || -> Result<T> {
49 let graph = unsafe_deref(graph_ptr)?;
50 let a = unsafe_deref(a_ptr)?;
51 let b = unsafe_deref(b_ptr)?;
52 op(graph, a, b)
53 };
54 R::new(helper())
55}
56
57#[no_mangle]
58pub extern "C" fn graph_input(graph_ptr: *mut Graph, type_ptr: *mut Type) -> CResult<Node> {
59 graph_method_helper(graph_ptr, |g| {
60 let t = unsafe_deref(type_ptr)?;
61 g.input(t)
62 })
63}
64
65#[no_mangle]
66pub extern "C" fn graph_add(
67 graph_ptr: *mut Graph,
68 a_ptr: *mut Node,
69 b_ptr: *mut Node,
70) -> CResult<Node> {
71 graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.add(a, b))
72}
73
74#[no_mangle]
75pub extern "C" fn graph_subtract(
76 graph_ptr: *mut Graph,
77 a_ptr: *mut Node,
78 b_ptr: *mut Node,
79) -> CResult<Node> {
80 graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.subtract(a, b))
81}
82
83#[no_mangle]
84pub extern "C" fn graph_multiply(
85 graph_ptr: *mut Graph,
86 a_ptr: *mut Node,
87 b_ptr: *mut Node,
88) -> CResult<Node> {
89 graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.multiply(a, b))
90}
91#[no_mangle]
92pub extern "C" fn graph_dot(
93 graph_ptr: *mut Graph,
94 a_ptr: *mut Node,
95 b_ptr: *mut Node,
96) -> CResult<Node> {
97 graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.dot(a, b))
98}
99
100#[no_mangle]
101pub extern "C" fn graph_matmul(
102 graph_ptr: *mut Graph,
103 a_ptr: *mut Node,
104 b_ptr: *mut Node,
105) -> CResult<Node> {
106 graph_two_nodes_method_helper(graph_ptr, a_ptr, b_ptr, |g, a, b| g.matmul(a, b))
107}
108
109#[no_mangle]
110pub extern "C" fn graph_truncate(
111 graph_ptr: *mut Graph,
112 a_ptr: *mut Node,
113 scale: u64,
114) -> CResult<Node> {
115 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.truncate(a, scale))
116}
117
118#[no_mangle]
119pub extern "C" fn graph_sum(
120 graph_ptr: *mut Graph,
121 a_ptr: *mut Node,
122 axis: CVecVal<u64>,
123) -> CResult<Node> {
124 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
125 let axis_vec = axis.to_vec()?;
126 g.sum(a, axis_vec)
127 })
128}
129
130#[no_mangle]
131pub extern "C" fn graph_permute_axes(
132 graph_ptr: *mut Graph,
133 a_ptr: *mut Node,
134 axis: CVecVal<u64>,
135) -> CResult<Node> {
136 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
137 let axis_vec = axis.to_vec()?;
138 g.permute_axes(a, axis_vec)
139 })
140}
141
142#[no_mangle]
143pub extern "C" fn graph_get(
144 graph_ptr: *mut Graph,
145 a_ptr: *mut Node,
146 index: CVecVal<u64>,
147) -> CResult<Node> {
148 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
149 let index_vec = index.to_vec()?;
150 g.get(a, index_vec)
151 })
152}
153
154#[no_mangle]
155pub extern "C" fn graph_get_slice(
156 graph_ptr: *mut Graph,
157 a_ptr: *mut Node,
158 cslice: CSlice,
159) -> CResult<Node> {
160 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
161 let slice = cslice.to_slice()?;
162 g.get_slice(a, slice)
163 })
164}
165
166#[no_mangle]
167pub extern "C" fn graph_reshape(
168 graph_ptr: *mut Graph,
169 a_ptr: *mut Node,
170 new_type_ptr: *mut Type,
171) -> CResult<Node> {
172 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
173 let new_t = unsafe_deref(new_type_ptr)?;
174 g.reshape(a, new_t)
175 })
176}
177
178#[no_mangle]
179pub extern "C" fn graph_random(graph_ptr: *mut Graph, output_type_ptr: *mut Type) -> CResult<Node> {
180 graph_method_helper(graph_ptr, |g| {
181 let output_t = unsafe_deref(output_type_ptr)?;
182 g.random(output_t)
183 })
184}
185
186#[no_mangle]
187pub extern "C" fn graph_stack(
188 graph_ptr: *mut Graph,
189 nodes: CVec<Node>,
190 outer_shape: CVecVal<u64>,
191) -> CResult<Node> {
192 graph_method_helper(graph_ptr, |g| {
193 let nodes_vec = nodes.to_vec()?;
194 let outer_shape_vec = outer_shape.to_vec()?;
195 g.stack(nodes_vec, outer_shape_vec)
196 })
197}
198#[no_mangle]
199pub extern "C" fn graph_constant(graph_ptr: *mut Graph, typed_value: CTypedValue) -> CResult<Node> {
200 graph_method_helper(graph_ptr, |g| {
201 let (t, v) = typed_value.to_type_value()?;
202 g.constant(t, v)
203 })
204}
205
206#[no_mangle]
207pub extern "C" fn graph_a2b(graph_ptr: *mut Graph, a_ptr: *mut Node) -> CResult<Node> {
208 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.a2b(a))
209}
210
211#[no_mangle]
212pub extern "C" fn graph_b2a(
213 graph_ptr: *mut Graph,
214 a_ptr: *mut Node,
215 scalar_type_ptr: *mut ScalarType,
216) -> CResult<Node> {
217 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| {
218 let st = unsafe_deref(scalar_type_ptr)?;
219 g.b2a(a, st)
220 })
221}
222#[no_mangle]
223pub extern "C" fn graph_create_tuple(graph_ptr: *mut Graph, elements: CVec<Node>) -> CResult<Node> {
224 graph_method_helper(graph_ptr, |g| {
225 let elem = elements.to_vec()?;
226 g.create_tuple(elem)
227 })
228}
229
230#[no_mangle]
231pub extern "C" fn graph_create_vector(
232 graph_ptr: *mut Graph,
233 type_ptr: *mut Type,
234 elements: CVec<Node>,
235) -> CResult<Node> {
236 graph_method_helper(graph_ptr, |g| {
237 let t = unsafe_deref(type_ptr)?;
238 let elem = elements.to_vec()?;
239 g.create_vector(t, elem)
240 })
241}
242
243#[no_mangle]
244pub extern "C" fn graph_create_named_tuple(
245 graph_ptr: *mut Graph,
246 elements_nodes: CVec<Node>,
247 elements_names: CVecVal<CStr>,
248) -> CResult<Node> {
249 graph_method_helper(graph_ptr, |g| {
250 let elem_nodes = elements_nodes.to_vec()?;
251 let elem_names = elements_names.to_vec()?;
252 let elem_names_string: Vec<String> = elem_names
253 .iter()
254 .map(|x| -> Result<String> { x.to_string() })
255 .collect::<Result<Vec<String>>>()?;
256 let elem: Vec<(String, Node)> = elem_names_string
257 .iter()
258 .zip(elem_nodes.iter())
259 .map(|(x, y)| ((*x).clone(), (*y).clone()))
260 .collect();
261 g.create_named_tuple(elem)
262 })
263}
264#[no_mangle]
265pub extern "C" fn graph_tuple_get(
266 graph_ptr: *mut Graph,
267 tuple_node_ptr: *mut Node,
268 index: u64,
269) -> CResult<Node> {
270 graph_one_node_method_helper(graph_ptr, tuple_node_ptr, |g, t| g.tuple_get(t, index))
271}
272
273#[no_mangle]
274pub extern "C" fn graph_named_tuple_get(
275 graph_ptr: *mut Graph,
276 tuple_node_ptr: *mut Node,
277 key: CStr,
278) -> CResult<Node> {
279 graph_one_node_method_helper(graph_ptr, tuple_node_ptr, |g, t| {
280 let key_string = key.to_string()?;
281 g.named_tuple_get(t, key_string)
282 })
283}
284
285#[no_mangle]
286pub extern "C" fn graph_vector_get(
287 graph_ptr: *mut Graph,
288 vec_node_ptr: *mut Node,
289 index_node_ptr: *mut Node,
290) -> CResult<Node> {
291 graph_two_nodes_method_helper(graph_ptr, vec_node_ptr, index_node_ptr, |g, v, i| {
292 g.vector_get(v, i)
293 })
294}
295
296#[no_mangle]
297pub extern "C" fn graph_zip(graph_ptr: *mut Graph, elements: CVec<Node>) -> CResult<Node> {
298 graph_method_helper(graph_ptr, |g| g.zip(elements.to_vec()?))
299}
300
301#[no_mangle]
302pub extern "C" fn graph_repeat(graph_ptr: *mut Graph, a_ptr: *mut Node, n: u64) -> CResult<Node> {
303 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.repeat(a, n))
304}
305
306#[no_mangle]
307pub extern "C" fn graph_call(
308 graph_ptr: *mut Graph,
309 callee_ptr: *mut Graph,
310 arguments: CVec<Node>,
311) -> CResult<Node> {
312 graph_method_helper(graph_ptr, |g| {
313 let callee = unsafe_deref(callee_ptr)?;
314 g.call(callee, arguments.to_vec()?)
315 })
316}
317
318#[no_mangle]
319pub extern "C" fn graph_iterate(
320 graph_ptr: *mut Graph,
321 callee_ptr: *mut Graph,
322 state_ptr: *mut Node,
323 input_ptr: *mut Node,
324) -> CResult<Node> {
325 graph_method_helper(graph_ptr, |g| {
326 let callee = unsafe_deref(callee_ptr)?;
327 let state = unsafe_deref(state_ptr)?;
328 let input = unsafe_deref(input_ptr)?;
329 g.iterate(callee, state, input)
330 })
331}
332
333#[no_mangle]
334pub extern "C" fn graph_vector_to_array(graph_ptr: *mut Graph, a_ptr: *mut Node) -> CResult<Node> {
335 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.vector_to_array(a))
336}
337
338#[no_mangle]
339pub extern "C" fn graph_array_to_vector(graph_ptr: *mut Graph, a_ptr: *mut Node) -> CResult<Node> {
340 graph_one_node_method_helper(graph_ptr, a_ptr, |g, a| g.array_to_vector(a))
341}
342
343#[no_mangle]
344pub extern "C" fn graph_custom_op(
345 graph_ptr: *mut Graph,
346 c_custom_op: CCustomOperation,
347 args: CVec<Node>,
348) -> CResult<Node> {
349 graph_method_helper(graph_ptr, |g| {
350 g.custom_op(c_custom_op.to_custom_op()?, args.to_vec()?)
351 })
352}
353
354#[no_mangle]
355pub extern "C" fn graph_finalize(graph_ptr: *mut Graph) -> CResult<Graph> {
356 graph_method_helper(graph_ptr, |g| g.finalize())
357}
358
359#[no_mangle]
360pub extern "C" fn graph_get_nodes(graph_ptr: *mut Graph) -> CResult<CVec<Node>> {
361 graph_method_helper(graph_ptr, |g| Ok(CVec::from_vec(g.get_nodes())))
362}
363
364#[no_mangle]
365pub extern "C" fn graph_set_output_node(
366 graph_ptr: *mut Graph,
367 n_ptr: *mut Node,
368) -> CResultVal<bool> {
369 graph_one_node_method_helper(graph_ptr, n_ptr, |g, n| {
370 let res = g.set_output_node(n);
371 match res {
372 Ok(_) => Ok(true),
373 Err(e) => Err(e),
374 }
375 })
376}
377
378#[no_mangle]
379pub extern "C" fn graph_get_output_node(graph_ptr: *mut Graph) -> CResult<Node> {
380 graph_method_helper(graph_ptr, |g| g.get_output_node())
381}
382
383#[no_mangle]
384pub extern "C" fn graph_get_id(graph_ptr: *mut Graph) -> CResultVal<u64> {
385 graph_method_helper(graph_ptr, |g| Ok(g.get_id()))
386}
387
388#[no_mangle]
389pub extern "C" fn graph_get_num_nodes(graph_ptr: *mut Graph) -> CResultVal<u64> {
390 graph_method_helper(graph_ptr, |g| Ok(g.get_num_nodes()))
391}
392
393#[no_mangle]
394pub extern "C" fn graph_get_node_by_id(graph_ptr: *mut Graph, id: u64) -> CResult<Node> {
395 graph_method_helper(graph_ptr, |g| g.get_node_by_id(id))
396}
397
398#[no_mangle]
399pub extern "C" fn graph_get_context(graph_ptr: *mut Graph) -> CResult<Context> {
400 graph_method_helper(graph_ptr, |g| Ok(g.get_context()))
401}
402
403#[no_mangle]
404pub extern "C" fn graph_set_as_main(graph_ptr: *mut Graph) -> CResult<Graph> {
405 graph_method_helper(graph_ptr, |g| g.set_as_main())
406}
407
408#[no_mangle]
409pub extern "C" fn graph_set_name(graph_ptr: *mut Graph, name: CStr) -> CResult<Graph> {
410 graph_method_helper(graph_ptr, |g| g.set_name(name.to_str_slice()?))
411}
412
413#[no_mangle]
414pub extern "C" fn graph_get_name(graph_ptr: *mut Graph) -> CResultVal<CStr> {
415 graph_method_helper(graph_ptr, |g| CStr::from_string(g.get_name()?))
416}
417
418#[no_mangle]
419pub extern "C" fn graph_retrieve_node(graph_ptr: *mut Graph, name: CStr) -> CResult<Node> {
420 graph_method_helper(graph_ptr, |g| g.retrieve_node(name.to_str_slice()?))
421}
422
423#[no_mangle]
424pub extern "C" fn create_context() -> CResult<Context> {
425 let context_res = ciphercore_base::graphs::create_context();
426 CResult::new(context_res)
427}
428
429fn context_method_helper<T, F, R: CResultTrait<T>>(context_ptr: *mut Context, op: F) -> R
430where
431 F: FnOnce(Context) -> Result<T>,
432{
433 let helper = || -> Result<T> {
434 let context = unsafe_deref(context_ptr)?;
435 op(context)
436 };
437 R::new(helper())
438}
439
440#[no_mangle]
441pub extern "C" fn context_create_graph(context_ptr: *mut Context) -> CResult<Graph> {
442 context_method_helper(context_ptr, |c| c.create_graph())
443}
444
445#[no_mangle]
446pub extern "C" fn context_finalize(context_ptr: *mut Context) -> CResult<Context> {
447 context_method_helper(context_ptr, |c| c.finalize())
448}
449
450#[no_mangle]
451pub extern "C" fn context_set_main_graph(
452 context_ptr: *mut Context,
453 graph_ptr: *mut Graph,
454) -> CResult<Context> {
455 context_method_helper(context_ptr, |c| {
456 let graph = unsafe_deref(graph_ptr)?;
457 c.set_main_graph(graph)
458 })
459}
460
461#[no_mangle]
462pub extern "C" fn context_get_graphs(context_ptr: *mut Context) -> CResult<CVec<Graph>> {
463 context_method_helper(context_ptr, |c| Ok(CVec::from_vec(c.get_graphs())))
464}
465
466#[no_mangle]
467pub extern "C" fn context_check_finalized(context_ptr: *mut Context) -> CResultVal<bool> {
468 context_method_helper(context_ptr, |c| {
469 let res = c.check_finalized();
470 match res {
471 Ok(_) => Ok(true),
472 Err(e) => Err(e),
473 }
474 })
475}
476
477#[no_mangle]
478pub extern "C" fn context_get_main_graph(context_ptr: *mut Context) -> CResult<Graph> {
479 context_method_helper(context_ptr, |c| c.get_main_graph())
480}
481
482#[no_mangle]
483pub extern "C" fn context_get_num_graphs(context_ptr: *mut Context) -> CResultVal<u64> {
484 context_method_helper(context_ptr, |c| Ok(c.get_num_graphs()))
485}
486
487#[no_mangle]
488pub extern "C" fn context_get_graph_by_id(context_ptr: *mut Context, id: u64) -> CResult<Graph> {
489 context_method_helper(context_ptr, |c| c.get_graph_by_id(id))
490}
491
492#[no_mangle]
493pub extern "C" fn context_get_node_by_global_id(
494 context_ptr: *mut Context,
495 global_id: CVecVal<u64>,
496) -> CResult<Node> {
497 context_method_helper(context_ptr, |c| {
498 let g_id_vec = global_id.to_vec()?;
499 if g_id_vec.len() != 2 {
500 return Err(runtime_error!("Global Id vector should have two elements!"));
501 }
502 c.get_node_by_global_id((g_id_vec[0], g_id_vec[1]))
503 })
504}
505
506#[no_mangle]
507pub extern "C" fn context_to_string(context_ptr: *mut Context) -> CResultVal<CStr> {
508 context_method_helper(context_ptr, |c| {
509 CStr::from_string(serde_json::to_string(&c)?)
510 })
511}
512
513#[no_mangle]
514pub extern "C" fn contexts_deep_equal(
515 context1_ptr: *mut Context,
516 context2_ptr: *mut Context,
517) -> CResultVal<bool> {
518 context_method_helper(context1_ptr, |c| {
519 let c2 = unsafe_deref(context2_ptr)?;
520 Ok(ciphercore_base::graphs::contexts_deep_equal(c, c2))
521 })
522}
523
524#[no_mangle]
525pub extern "C" fn context_set_graph_name(
526 context_ptr: *mut Context,
527 graph_ptr: *mut Graph,
528 name: CStr,
529) -> CResult<Context> {
530 context_method_helper(context_ptr, |c| {
531 let graph = unsafe_deref(graph_ptr)?;
532 c.set_graph_name(graph, name.to_str_slice()?)
533 })
534}
535
536#[no_mangle]
537pub extern "C" fn context_get_graph_name(
538 context_ptr: *mut Context,
539 graph_ptr: *mut Graph,
540) -> CResultVal<CStr> {
541 context_method_helper(context_ptr, |c| {
542 let graph = unsafe_deref(graph_ptr)?;
543 CStr::from_string(c.get_graph_name(graph)?)
544 })
545}
546#[no_mangle]
547pub extern "C" fn context_retrieve_graph(context_ptr: *mut Context, name: CStr) -> CResult<Graph> {
548 context_method_helper(context_ptr, |c| c.retrieve_graph(name.to_str_slice()?))
549}
550
551#[no_mangle]
552pub extern "C" fn context_set_node_name(
553 context_ptr: *mut Context,
554 node_ptr: *mut Node,
555 name: CStr,
556) -> CResult<Context> {
557 context_method_helper(context_ptr, |c| {
558 let node = unsafe_deref(node_ptr)?;
559 c.set_node_name(node, name.to_str_slice()?)
560 })
561}
562
563#[no_mangle]
564pub extern "C" fn context_get_node_name(
565 context_ptr: *mut Context,
566 node_ptr: *mut Node,
567) -> CResultVal<CStr> {
568 context_method_helper(context_ptr, |c| {
569 let node = unsafe_deref(node_ptr)?;
570 CStr::from_string(c.get_node_name(node)?)
571 })
572}
573
574#[no_mangle]
575pub extern "C" fn context_retrieve_node(
576 context_ptr: *mut Context,
577 graph_ptr: *mut Graph,
578 name: CStr,
579) -> CResult<Node> {
580 context_method_helper(context_ptr, |c| {
581 let graph = unsafe_deref(graph_ptr)?;
582 c.retrieve_node(graph, name.to_str_slice()?)
583 })
584}
585
586#[no_mangle]
587pub extern "C" fn context_destroy(context_ptr: *mut Context) {
588 destroy_helper(context_ptr);
589}
590#[no_mangle]
591pub extern "C" fn graph_destroy(graph_ptr: *mut Graph) {
592 destroy_helper(graph_ptr);
593}
594
595#[no_mangle]
596pub extern "C" fn node_destroy(node_ptr: *mut Node) {
597 destroy_helper(node_ptr);
598}
599
600fn node_method_helper<T, F, R: CResultTrait<T>>(node_ptr: *mut Node, op: F) -> R
601where
602 F: FnOnce(Node) -> Result<T>,
603{
604 let helper = || -> Result<T> {
605 let node = unsafe_deref(node_ptr)?;
606 op(node)
607 };
608 R::new(helper())
609}
610fn node_one_node_method_helper<T, F, R: CResultTrait<T>>(
611 node_ptr: *mut Node,
612 b_ptr: *mut Node,
613 op: F,
614) -> R
615where
616 F: FnOnce(Node, Node) -> Result<T>,
617{
618 let helper = || -> Result<T> {
619 let node = unsafe_deref(node_ptr)?;
620 let b = unsafe_deref(b_ptr)?;
621 op(node, b)
622 };
623 R::new(helper())
624}
625#[no_mangle]
626pub extern "C" fn node_get_graph(node_ptr: *mut Node) -> CResult<Graph> {
627 node_method_helper(node_ptr, |n| Ok(n.get_graph()))
628}
629
630#[no_mangle]
631pub extern "C" fn node_get_dependencies(node_ptr: *mut Node) -> CResult<CVec<Node>> {
632 node_method_helper(node_ptr, |n| Ok(CVec::from_vec(n.get_node_dependencies())))
633}
634
635#[no_mangle]
636pub extern "C" fn node_graph_dependencies(node_ptr: *mut Node) -> CResult<CVec<Graph>> {
637 node_method_helper(node_ptr, |n| Ok(CVec::from_vec(n.get_graph_dependencies())))
638}
639#[no_mangle]
640pub extern "C" fn node_get_operation(node_ptr: *mut Node) -> CResult<COperation> {
641 node_method_helper(node_ptr, |n| COperation::from_operation(n.get_operation()))
642}
643
644#[no_mangle]
645pub extern "C" fn node_get_id(node_ptr: *mut Node) -> CResultVal<u64> {
646 node_method_helper(node_ptr, |n| Ok(n.get_id()))
647}
648
649#[no_mangle]
650pub extern "C" fn node_get_global_id(node_ptr: *mut Node) -> CResult<CVecVal<u64>> {
651 node_method_helper(node_ptr, |n| {
652 let ids = n.get_global_id();
653 let ids_vec = vec![ids.0, ids.1];
654 Ok(CVecVal::from_vec(ids_vec))
655 })
656}
657
658#[no_mangle]
659pub extern "C" fn node_get_type(node_ptr: *mut Node) -> CResult<Type> {
660 node_method_helper(node_ptr, |n| n.get_type())
661}
662
663#[no_mangle]
664pub extern "C" fn node_add(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
665 node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.add(b))
666}
667
668#[no_mangle]
669pub extern "C" fn node_subtract(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
670 node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.subtract(b))
671}
672
673#[no_mangle]
674pub extern "C" fn node_multiply(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
675 node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.multiply(b))
676}
677#[no_mangle]
678pub extern "C" fn node_dot(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
679 node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.dot(b))
680}
681
682#[no_mangle]
683pub extern "C" fn node_matmul(node_ptr: *mut Node, b_ptr: *mut Node) -> CResult<Node> {
684 node_one_node_method_helper(node_ptr, b_ptr, |a, b| a.matmul(b))
685}
686
687#[no_mangle]
688pub extern "C" fn node_truncate(node_ptr: *mut Node, scale: u64) -> CResult<Node> {
689 node_method_helper(node_ptr, |a| a.truncate(scale))
690}
691
692#[no_mangle]
693pub extern "C" fn node_sum(node_ptr: *mut Node, axis: CVecVal<u64>) -> CResult<Node> {
694 node_method_helper(node_ptr, |a| a.sum(axis.to_vec()?))
695}
696#[no_mangle]
697pub extern "C" fn node_permute_axes(node_ptr: *mut Node, axis: CVecVal<u64>) -> CResult<Node> {
698 node_method_helper(node_ptr, |a| a.permute_axes(axis.to_vec()?))
699}
700
701#[no_mangle]
702pub extern "C" fn node_get(node_ptr: *mut Node, index: CVecVal<u64>) -> CResult<Node> {
703 node_method_helper(node_ptr, |a| a.get(index.to_vec()?))
704}
705
706#[no_mangle]
707pub extern "C" fn node_get_slice(node_ptr: *mut Node, cslice: CSlice) -> CResult<Node> {
708 node_method_helper(node_ptr, |a| a.get_slice(cslice.to_slice()?))
709}
710
711#[no_mangle]
712pub extern "C" fn node_reshape(node_ptr: *mut Node, type_ptr: *mut Type) -> CResult<Node> {
713 node_method_helper(node_ptr, |a| {
714 let t = unsafe_deref(type_ptr)?;
715 a.reshape(t)
716 })
717}
718
719#[no_mangle]
720pub extern "C" fn node_nop(node_ptr: *mut Node) -> CResult<Node> {
721 node_method_helper(node_ptr, |a| a.nop())
722}
723
724#[no_mangle]
725pub extern "C" fn node_prf(
726 node_ptr: *mut Node,
727 iv: u64,
728 output_type_ptr: *mut Type,
729) -> CResult<Node> {
730 node_method_helper(node_ptr, |a| {
731 let output_type = unsafe_deref(output_type_ptr)?;
732 a.prf(iv, output_type)
733 })
734}
735
736#[no_mangle]
737pub extern "C" fn node_a2b(node_ptr: *mut Node) -> CResult<Node> {
738 node_method_helper(node_ptr, |a| a.a2b())
739}
740
741#[no_mangle]
742pub extern "C" fn node_b2a(node_ptr: *mut Node, scalar_type_ptr: *mut ScalarType) -> CResult<Node> {
743 node_method_helper(node_ptr, |a| {
744 let st = unsafe_deref(scalar_type_ptr)?;
745 a.b2a(st)
746 })
747}
748
749#[no_mangle]
750pub extern "C" fn node_tuple_get(node_ptr: *mut Node, index: u64) -> CResult<Node> {
751 node_method_helper(node_ptr, |a| a.tuple_get(index))
752}
753
754#[no_mangle]
755pub extern "C" fn node_named_tuple_get(node_ptr: *mut Node, key: CStr) -> CResult<Node> {
756 node_method_helper(node_ptr, |a| a.named_tuple_get(key.to_string()?))
757}
758
759#[no_mangle]
760pub extern "C" fn node_vector_get(node_ptr: *mut Node, index_node_ptr: *mut Node) -> CResult<Node> {
761 node_one_node_method_helper(node_ptr, index_node_ptr, |a, index| a.vector_get(index))
762}
763
764#[no_mangle]
765pub extern "C" fn node_array_to_vector(node_ptr: *mut Node) -> CResult<Node> {
766 node_method_helper(node_ptr, |a| a.array_to_vector())
767}
768
769#[no_mangle]
770pub extern "C" fn node_vector_to_array(node_ptr: *mut Node) -> CResult<Node> {
771 node_method_helper(node_ptr, |a| a.vector_to_array())
772}
773
774#[no_mangle]
775pub extern "C" fn node_repeat(node_ptr: *mut Node, n: u64) -> CResult<Node> {
776 node_method_helper(node_ptr, |a| a.repeat(n))
777}
778
779#[no_mangle]
780pub extern "C" fn node_set_as_output(node_ptr: *mut Node) -> CResult<Node> {
781 node_method_helper(node_ptr, |a| a.set_as_output())
782}