1use std::panic::catch_unwind;
6
7use sparse_ir::kernel::SVEHints;
8
9use crate::types::spir_kernel;
10use crate::{SPIR_COMPUTATION_SUCCESS, SPIR_INTERNAL_ERROR, SPIR_INVALID_ARGUMENT, StatusCode};
11
12#[unsafe(no_mangle)]
36pub extern "C" fn spir_logistic_kernel_new(
37 lambda: f64,
38 status: *mut StatusCode,
39) -> *mut spir_kernel {
40 if status.is_null() {
42 return std::ptr::null_mut();
43 }
44
45 if lambda <= 0.0 {
46 unsafe {
47 *status = SPIR_INVALID_ARGUMENT;
48 }
49 return std::ptr::null_mut();
50 }
51
52 let result = catch_unwind(|| {
54 let kernel = spir_kernel::new_logistic(lambda);
55 Box::into_raw(Box::new(kernel))
56 });
57
58 match result {
59 Ok(ptr) => {
60 unsafe {
61 *status = SPIR_COMPUTATION_SUCCESS;
62 }
63 ptr
64 }
65 Err(_) => {
66 unsafe {
67 *status = SPIR_INTERNAL_ERROR;
68 }
69 std::ptr::null_mut()
70 }
71 }
72}
73
74#[unsafe(no_mangle)]
83pub extern "C" fn spir_reg_bose_kernel_new(
84 lambda: f64,
85 status: *mut StatusCode,
86) -> *mut spir_kernel {
87 if status.is_null() {
88 return std::ptr::null_mut();
89 }
90
91 if lambda <= 0.0 {
92 unsafe {
93 *status = SPIR_INVALID_ARGUMENT;
94 }
95 return std::ptr::null_mut();
96 }
97
98 let result = catch_unwind(|| {
99 let kernel = spir_kernel::new_regularized_bose(lambda);
100 Box::into_raw(Box::new(kernel))
101 });
102
103 match result {
104 Ok(ptr) => {
105 unsafe {
106 *status = SPIR_COMPUTATION_SUCCESS;
107 }
108 ptr
109 }
110 Err(_) => {
111 unsafe {
112 *status = SPIR_INTERNAL_ERROR;
113 }
114 std::ptr::null_mut()
115 }
116 }
117}
118
119#[unsafe(no_mangle)]
130pub extern "C" fn spir_kernel_get_lambda(
131 kernel: *const spir_kernel,
132 lambda_out: *mut f64,
133) -> StatusCode {
134 if kernel.is_null() || lambda_out.is_null() {
135 return SPIR_INVALID_ARGUMENT;
136 }
137
138 let result = catch_unwind(|| unsafe {
139 let k = &*kernel;
140 *lambda_out = k.lambda();
141 SPIR_COMPUTATION_SUCCESS
142 });
143
144 result.unwrap_or(SPIR_INTERNAL_ERROR)
145}
146
147#[unsafe(no_mangle)]
160pub extern "C" fn spir_kernel_compute(
161 kernel: *const spir_kernel,
162 x: f64,
163 y: f64,
164 out: *mut f64,
165) -> StatusCode {
166 if kernel.is_null() || out.is_null() {
167 return SPIR_INVALID_ARGUMENT;
168 }
169
170 let result = catch_unwind(|| unsafe {
171 let k = &*kernel;
172 *out = k.compute(x, y);
173 SPIR_COMPUTATION_SUCCESS
174 });
175
176 result.unwrap_or(SPIR_INTERNAL_ERROR)
177}
178
179#[unsafe(no_mangle)]
185pub extern "C" fn spir_kernel_release(kernel: *mut spir_kernel) {
186 if !kernel.is_null() {
187 unsafe {
188 let _ = Box::from_raw(kernel);
191 }
192 }
193}
194
195#[unsafe(no_mangle)]
197pub extern "C" fn spir_kernel_clone(src: *const spir_kernel) -> *mut spir_kernel {
198 if src.is_null() {
199 return std::ptr::null_mut();
200 }
201
202 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
203 let src_ref = &*src;
204 let cloned = (*src_ref).clone();
205 Box::into_raw(Box::new(cloned))
206 }));
207
208 result.unwrap_or(std::ptr::null_mut())
209}
210
211#[unsafe(no_mangle)]
213pub extern "C" fn spir_kernel_is_assigned(obj: *const spir_kernel) -> i32 {
214 if obj.is_null() {
215 return 0;
216 }
217
218 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
219 let _ = &*obj;
220 1
221 }));
222
223 result.unwrap_or(0)
224}
225
226#[unsafe(no_mangle)]
240pub extern "C" fn spir_kernel_get_domain(
241 k: *const spir_kernel,
242 xmin: *mut f64,
243 xmax: *mut f64,
244 ymin: *mut f64,
245 ymax: *mut f64,
246) -> StatusCode {
247 if k.is_null() || xmin.is_null() || xmax.is_null() || ymin.is_null() || ymax.is_null() {
248 return SPIR_INVALID_ARGUMENT;
249 }
250
251 let result = catch_unwind(|| unsafe {
252 let kernel = &*k;
253 let (xmin_val, xmax_val, ymin_val, ymax_val) = kernel.domain();
254 *xmin = xmin_val;
255 *xmax = xmax_val;
256 *ymin = ymin_val;
257 *ymax = ymax_val;
258 SPIR_COMPUTATION_SUCCESS
259 });
260
261 result.unwrap_or(SPIR_INTERNAL_ERROR)
262}
263
264#[unsafe(no_mangle)]
281pub extern "C" fn spir_kernel_get_sve_hints_segments_x(
282 k: *const spir_kernel,
283 epsilon: f64,
284 segments: *mut f64,
285 n_segments: *mut libc::c_int,
286) -> StatusCode {
287 if k.is_null() || n_segments.is_null() {
288 return SPIR_INVALID_ARGUMENT;
289 }
290
291 if epsilon <= 0.0 || !epsilon.is_finite() {
292 return SPIR_INVALID_ARGUMENT;
293 }
294
295 let result = catch_unwind(|| unsafe {
296 let kernel = &*k;
297
298 let segs = match kernel.inner() {
300 crate::types::KernelType::Logistic(k) => {
301 use sparse_ir::kernel::KernelProperties;
302 let hints = k.sve_hints::<f64>(epsilon);
303 hints.segments_x()
304 }
305 crate::types::KernelType::RegularizedBose(k) => {
306 use sparse_ir::kernel::KernelProperties;
307 let hints = k.sve_hints::<f64>(epsilon);
308 hints.segments_x()
309 }
310 };
311
312 if segments.is_null() {
313 *n_segments = (segs.len() - 1) as libc::c_int;
315 return SPIR_COMPUTATION_SUCCESS;
316 }
317
318 if *n_segments < (segs.len() - 1) as libc::c_int {
320 return SPIR_INVALID_ARGUMENT;
321 }
322
323 for (i, &seg) in segs.iter().enumerate() {
324 *segments.add(i) = seg;
325 }
326 *n_segments = (segs.len() - 1) as libc::c_int;
327 SPIR_COMPUTATION_SUCCESS
328 });
329
330 result.unwrap_or(SPIR_INTERNAL_ERROR)
331}
332
333#[unsafe(no_mangle)]
350pub extern "C" fn spir_kernel_get_sve_hints_segments_y(
351 k: *const spir_kernel,
352 epsilon: f64,
353 segments: *mut f64,
354 n_segments: *mut libc::c_int,
355) -> StatusCode {
356 if k.is_null() || n_segments.is_null() {
357 return SPIR_INVALID_ARGUMENT;
358 }
359
360 if epsilon <= 0.0 || !epsilon.is_finite() {
361 return SPIR_INVALID_ARGUMENT;
362 }
363
364 let result = catch_unwind(|| unsafe {
365 let kernel = &*k;
366
367 let segs = match kernel.inner() {
369 crate::types::KernelType::Logistic(k) => {
370 use sparse_ir::kernel::KernelProperties;
371 let hints = k.sve_hints::<f64>(epsilon);
372 hints.segments_y()
373 }
374 crate::types::KernelType::RegularizedBose(k) => {
375 use sparse_ir::kernel::KernelProperties;
376 let hints = k.sve_hints::<f64>(epsilon);
377 hints.segments_y()
378 }
379 };
380
381 if segments.is_null() {
382 *n_segments = (segs.len() - 1) as libc::c_int;
384 return SPIR_COMPUTATION_SUCCESS;
385 }
386
387 if *n_segments < (segs.len() - 1) as libc::c_int {
389 return SPIR_INVALID_ARGUMENT;
390 }
391
392 for (i, &seg) in segs.iter().enumerate() {
393 *segments.add(i) = seg;
394 }
395 *n_segments = (segs.len() - 1) as libc::c_int;
396 SPIR_COMPUTATION_SUCCESS
397 });
398
399 result.unwrap_or(SPIR_INTERNAL_ERROR)
400}
401
402#[unsafe(no_mangle)]
414pub extern "C" fn spir_kernel_get_sve_hints_nsvals(
415 k: *const spir_kernel,
416 epsilon: f64,
417 nsvals: *mut libc::c_int,
418) -> StatusCode {
419 if k.is_null() || nsvals.is_null() {
420 return SPIR_INVALID_ARGUMENT;
421 }
422
423 if epsilon <= 0.0 || !epsilon.is_finite() {
424 return SPIR_INVALID_ARGUMENT;
425 }
426
427 let result = catch_unwind(|| unsafe {
428 let kernel = &*k;
429
430 let n = match kernel.inner() {
432 crate::types::KernelType::Logistic(k) => {
433 use sparse_ir::kernel::KernelProperties;
434 let hints = k.sve_hints::<f64>(epsilon);
435 hints.nsvals()
436 }
437 crate::types::KernelType::RegularizedBose(k) => {
438 use sparse_ir::kernel::KernelProperties;
439 let hints = k.sve_hints::<f64>(epsilon);
440 hints.nsvals()
441 }
442 };
443
444 *nsvals = n as libc::c_int;
445 SPIR_COMPUTATION_SUCCESS
446 });
447
448 result.unwrap_or(SPIR_INTERNAL_ERROR)
449}
450
451#[unsafe(no_mangle)]
463pub extern "C" fn spir_kernel_get_sve_hints_ngauss(
464 k: *const spir_kernel,
465 epsilon: f64,
466 ngauss: *mut libc::c_int,
467) -> StatusCode {
468 if k.is_null() || ngauss.is_null() {
469 return SPIR_INVALID_ARGUMENT;
470 }
471
472 if epsilon <= 0.0 || !epsilon.is_finite() {
473 return SPIR_INVALID_ARGUMENT;
474 }
475
476 let result = catch_unwind(|| unsafe {
477 let kernel = &*k;
478
479 let n = match kernel.inner() {
481 crate::types::KernelType::Logistic(k) => {
482 use sparse_ir::kernel::KernelProperties;
483 let hints = k.sve_hints::<f64>(epsilon);
484 hints.ngauss()
485 }
486 crate::types::KernelType::RegularizedBose(k) => {
487 use sparse_ir::kernel::KernelProperties;
488 let hints = k.sve_hints::<f64>(epsilon);
489 hints.ngauss()
490 }
491 };
492
493 *ngauss = n as libc::c_int;
494 SPIR_COMPUTATION_SUCCESS
495 });
496
497 result.unwrap_or(SPIR_INTERNAL_ERROR)
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use std::ptr;

    /// Signature shared by the two SVE-hint segment getters.
    type SegmentGetter =
        extern "C" fn(*const spir_kernel, f64, *mut f64, *mut libc::c_int) -> StatusCode;

    /// Builds a logistic kernel and asserts that construction succeeded.
    fn make_logistic(lambda: f64) -> *mut spir_kernel {
        let mut status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(lambda, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!kernel.is_null());
        kernel
    }

    /// Runs the query-then-fill protocol of a segment getter. Checks that the
    /// boundaries are strictly increasing and span [0, 1], and that an
    /// insufficient capacity is rejected.
    fn check_segments(getter: SegmentGetter) {
        let epsilon = 1e-8;
        let kernel = make_logistic(10.0);

        // A null buffer only reports the number of segments.
        let mut n_segments = 0;
        let status = getter(kernel, epsilon, ptr::null_mut(), &mut n_segments);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(n_segments > 0);

        // A buffer with room for n_segments + 1 boundaries gets filled.
        let mut segments = vec![0.0; (n_segments + 1) as usize];
        let mut n_segments_out = n_segments + 1;
        let status = getter(kernel, epsilon, segments.as_mut_ptr(), &mut n_segments_out);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(n_segments_out, n_segments);

        assert_eq!(segments.len(), (n_segments + 1) as usize);
        assert!(segments[0].abs() < 1e-10);
        assert!((segments[n_segments as usize] - 1.0).abs() < 1e-10);
        assert!(segments.windows(2).all(|w| w[1] > w[0]));

        // A capacity below the segment count is rejected without writing.
        let mut too_small = n_segments - 1;
        let status = getter(kernel, epsilon, segments.as_mut_ptr(), &mut too_small);
        assert_eq!(status, SPIR_INVALID_ARGUMENT);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_logistic_kernel_creation() {
        let mut status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut status);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!kernel.is_null());

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_regularized_bose_kernel_creation() {
        let mut status = SPIR_INTERNAL_ERROR;
        let kernel = spir_reg_bose_kernel_new(10.0, &mut status);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!kernel.is_null());

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_lambda() {
        let kernel = make_logistic(10.0);

        let mut lambda = 0.0;
        let status = spir_kernel_get_lambda(kernel, &mut lambda);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(lambda, 10.0);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_compute() {
        let kernel = make_logistic(10.0);

        let mut result = 0.0;
        let status = spir_kernel_compute(kernel, 0.5, 0.5, &mut result);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(result > 0.0);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_null_pointer_errors() {
        let kernel = spir_logistic_kernel_new(10.0, ptr::null_mut());
        assert!(kernel.is_null());

        let mut lambda = 0.0;
        let status = spir_kernel_get_lambda(ptr::null(), &mut lambda);
        assert_eq!(status, SPIR_INVALID_ARGUMENT);

        let mut value = 0.0;
        let status = spir_kernel_compute(ptr::null(), 0.5, 0.5, &mut value);
        assert_eq!(status, SPIR_INVALID_ARGUMENT);

        // Releasing a null handle is a harmless no-op.
        spir_kernel_release(ptr::null_mut());
    }

    #[test]
    fn test_invalid_lambda() {
        let mut status = SPIR_COMPUTATION_SUCCESS;

        let kernel = spir_logistic_kernel_new(0.0, &mut status);
        assert_eq!(status, SPIR_INVALID_ARGUMENT);
        assert!(kernel.is_null());

        let kernel = spir_logistic_kernel_new(-1.0, &mut status);
        assert_eq!(status, SPIR_INVALID_ARGUMENT);
        assert!(kernel.is_null());
    }

    #[test]
    fn test_kernel_domain() {
        let kernel = make_logistic(10.0);

        let mut xmin = 0.0;
        let mut xmax = 0.0;
        let mut ymin = 0.0;
        let mut ymax = 0.0;
        let status = spir_kernel_get_domain(kernel, &mut xmin, &mut xmax, &mut ymin, &mut ymax);

        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(xmin, -1.0);
        assert_eq!(xmax, 1.0);
        assert_eq!(ymin, -1.0);
        assert_eq!(ymax, 1.0);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_clone_and_is_assigned() {
        let kernel = make_logistic(10.0);

        let copy = spir_kernel_clone(kernel);
        assert!(!copy.is_null());

        let mut lambda = 0.0;
        let status = spir_kernel_get_lambda(copy, &mut lambda);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(lambda, 10.0);

        assert_eq!(spir_kernel_is_assigned(kernel), 1);
        assert_eq!(spir_kernel_is_assigned(copy), 1);
        assert_eq!(spir_kernel_is_assigned(ptr::null()), 0);
        assert!(spir_kernel_clone(ptr::null()).is_null());

        spir_kernel_release(copy);
        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_nsvals() {
        let kernel = make_logistic(10.0);

        let mut nsvals = 0;
        let status = spir_kernel_get_sve_hints_nsvals(kernel, 1e-8, &mut nsvals);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(nsvals > 0);
        assert!(nsvals >= 10);
        assert!(nsvals <= 1000);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_ngauss() {
        let kernel = make_logistic(10.0);

        let mut ngauss_coarse = 0;
        let status = spir_kernel_get_sve_hints_ngauss(kernel, 1e-6, &mut ngauss_coarse);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(ngauss_coarse > 0);
        assert_eq!(ngauss_coarse, 10);

        let mut ngauss_fine = 0;
        let status = spir_kernel_get_sve_hints_ngauss(kernel, 1e-10, &mut ngauss_fine);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(ngauss_fine > 0);
        assert_eq!(ngauss_fine, 16);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_segments_x() {
        check_segments(spir_kernel_get_sve_hints_segments_x);
    }

    #[test]
    fn test_kernel_get_sve_hints_segments_y() {
        check_segments(spir_kernel_get_sve_hints_segments_y);
    }

    #[test]
    fn test_kernel_get_sve_hints_with_regularized_bose() {
        let epsilon = 1e-8;

        let mut status = SPIR_INTERNAL_ERROR;
        let kernel = spir_reg_bose_kernel_new(10.0, &mut status);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(!kernel.is_null());

        let mut nsvals = 0;
        let status = spir_kernel_get_sve_hints_nsvals(kernel, epsilon, &mut nsvals);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(nsvals > 0);

        let mut ngauss = 0;
        let status = spir_kernel_get_sve_hints_ngauss(kernel, epsilon, &mut ngauss);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(ngauss > 0);

        let mut n_segments_x = 0;
        let status = spir_kernel_get_sve_hints_segments_x(
            kernel,
            epsilon,
            ptr::null_mut(),
            &mut n_segments_x,
        );
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(n_segments_x > 0);

        let mut n_segments_y = 0;
        let status = spir_kernel_get_sve_hints_segments_y(
            kernel,
            epsilon,
            ptr::null_mut(),
            &mut n_segments_y,
        );
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(n_segments_y > 0);

        spir_kernel_release(kernel);
    }

    #[test]
    fn test_kernel_get_sve_hints_error_handling() {
        let epsilon = 1e-8;
        let kernel = make_logistic(10.0);

        let mut nsvals = 0;
        let status = spir_kernel_get_sve_hints_nsvals(ptr::null(), epsilon, &mut nsvals);
        assert_ne!(status, SPIR_COMPUTATION_SUCCESS);

        let status = spir_kernel_get_sve_hints_nsvals(kernel, epsilon, ptr::null_mut());
        assert_ne!(status, SPIR_COMPUTATION_SUCCESS);

        // Non-positive and non-finite tolerances are all rejected.
        for bad_epsilon in [-1.0, 0.0, f64::NAN] {
            let mut nsvals = 0;
            let status = spir_kernel_get_sve_hints_nsvals(kernel, bad_epsilon, &mut nsvals);
            assert_ne!(status, SPIR_COMPUTATION_SUCCESS);
        }

        spir_kernel_release(kernel);
    }
}