use std::panic::{AssertUnwindSafe, catch_unwind};
use sparse_ir::basis::FiniteTempBasis;
use crate::types::{spir_basis, spir_funcs, spir_kernel, spir_sve_result};
use crate::{
SPIR_COMPUTATION_SUCCESS, SPIR_INTERNAL_ERROR, SPIR_INVALID_ARGUMENT, SPIR_STATISTICS_BOSONIC,
SPIR_STATISTICS_FERMIONIC, StatusCode,
};
/// Releases a basis handle previously produced by this module (via `Box::into_raw`).
///
/// A null pointer is accepted and ignored. After this call the handle must not be used again.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_release(basis: *mut spir_basis) {
    if basis.is_null() {
        return;
    }
    // SAFETY: a non-null `basis` is assumed to originate from `Box::into_raw` in this module
    // and to be released exactly once; rebuilding the box hands ownership back so it is dropped.
    drop(unsafe { Box::from_raw(basis) });
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn spir_basis_clone(src: *const spir_basis) -> *mut spir_basis {
if src.is_null() {
return std::ptr::null_mut();
}
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
let src_ref = &*src;
let cloned = (*src_ref).clone();
Box::into_raw(Box::new(cloned))
}));
result.unwrap_or(std::ptr::null_mut())
}
/// Reports whether `obj` holds a basis: returns 1 for a non-null pointer and 0 for null.
///
/// Only the pointer value is inspected; the pointee is never dereferenced.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_is_assigned(obj: *const spir_basis) -> i32 {
    i32::from(!obj.is_null())
}
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_new(
statistics: libc::c_int,
beta: f64,
omega_max: f64,
epsilon: f64,
k: *const spir_kernel,
sve: *const spir_sve_result,
max_size: libc::c_int,
status: *mut StatusCode,
) -> *mut spir_basis {
if status.is_null() {
return std::ptr::null_mut();
}
if beta <= 0.0 || omega_max <= 0.0 || epsilon <= 0.0 {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
if statistics != SPIR_STATISTICS_BOSONIC && statistics != SPIR_STATISTICS_FERMIONIC {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
if k.is_null() {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
let max_size_opt = if max_size < 0 {
None
} else {
Some(max_size as usize)
};
let result = catch_unwind(AssertUnwindSafe(|| unsafe {
let kernel_ref = &*k;
let expected_lambda = beta * omega_max;
let kernel_lambda = kernel_ref.lambda();
if (kernel_lambda - expected_lambda).abs() > 1e-10 {
return Err(format!(
"Kernel lambda ({}) does not match beta * omega_max ({})",
kernel_lambda, expected_lambda
));
}
if let Some(logistic) = kernel_ref.as_logistic() {
if statistics == SPIR_STATISTICS_FERMIONIC {
let basis: FiniteTempBasis<_, _> = if !sve.is_null() {
let sve_ref = &*sve;
FiniteTempBasis::from_sve_result(
**logistic,
beta,
sve_ref.inner().as_ref().clone(),
Some(epsilon),
max_size_opt,
)
} else {
FiniteTempBasis::new(**logistic, beta, Some(epsilon), max_size_opt)
};
Ok(Box::into_raw(Box::new(spir_basis::new_logistic_fermionic(
basis,
))))
} else {
let basis: FiniteTempBasis<_, _> = if !sve.is_null() {
let sve_ref = &*sve;
FiniteTempBasis::from_sve_result(
**logistic,
beta,
sve_ref.inner().as_ref().clone(),
Some(epsilon),
max_size_opt,
)
} else {
FiniteTempBasis::new(**logistic, beta, Some(epsilon), max_size_opt)
};
Ok(Box::into_raw(Box::new(spir_basis::new_logistic_bosonic(
basis,
))))
}
} else if let Some(reg_bose) = kernel_ref.as_regularized_bose() {
if statistics == SPIR_STATISTICS_FERMIONIC {
let basis: FiniteTempBasis<_, _> = if !sve.is_null() {
let sve_ref = &*sve;
FiniteTempBasis::from_sve_result(
**reg_bose,
beta,
sve_ref.inner().as_ref().clone(),
Some(epsilon),
max_size_opt,
)
} else {
FiniteTempBasis::new(**reg_bose, beta, Some(epsilon), max_size_opt)
};
Ok(Box::into_raw(Box::new(
spir_basis::new_regularized_bose_fermionic(basis),
)))
} else {
let basis: FiniteTempBasis<_, _> = if !sve.is_null() {
let sve_ref = &*sve;
FiniteTempBasis::from_sve_result(
**reg_bose,
beta,
sve_ref.inner().as_ref().clone(),
Some(epsilon),
max_size_opt,
)
} else {
FiniteTempBasis::new(**reg_bose, beta, Some(epsilon), max_size_opt)
};
Ok(Box::into_raw(Box::new(
spir_basis::new_regularized_bose_bosonic(basis),
)))
}
} else {
Err("Unknown kernel type".to_string())
}
}));
match result {
Ok(Ok(ptr)) => {
unsafe {
*status = SPIR_COMPUTATION_SUCCESS;
}
ptr
}
Ok(Err(_)) | Err(_) => {
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
}
}
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_new_from_sve_and_regularizer(
statistics: libc::c_int,
beta: f64,
omega_max: f64,
epsilon: f64,
lambda: f64,
ypower: libc::c_int,
_conv_radius: f64,
sve: *const spir_sve_result,
regularizer_funcs: *const spir_funcs,
max_size: libc::c_int,
status: *mut StatusCode,
) -> *mut spir_basis {
if status.is_null() {
return std::ptr::null_mut();
}
if beta <= 0.0 || omega_max <= 0.0 || epsilon <= 0.0 || lambda <= 0.0 {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
if statistics != SPIR_STATISTICS_BOSONIC && statistics != SPIR_STATISTICS_FERMIONIC {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
if sve.is_null() || regularizer_funcs.is_null() {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
let expected_lambda = beta * omega_max;
if (lambda - expected_lambda).abs() > 1e-10 {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
if ypower != 0 && ypower != 1 {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
let max_size_opt = if max_size < 0 {
None
} else {
Some(max_size as usize)
};
let result = catch_unwind(AssertUnwindSafe(|| unsafe {
let sve_ref = &*sve;
let sve_result = sve_ref.inner().as_ref().clone();
let test_omega = omega_max / 2.0;
let _regularizer_value = match (*regularizer_funcs).eval_continuous(test_omega) {
Some(values) if !values.is_empty() => values[0],
_ => {
1.0
}
};
if ypower == 0 {
use sparse_ir::kernel::LogisticKernel;
let kernel = LogisticKernel::new(lambda);
if statistics == SPIR_STATISTICS_FERMIONIC {
let basis =
FiniteTempBasis::<LogisticKernel, sparse_ir::traits::Fermionic>::from_sve_result(
kernel,
beta,
sve_result,
Some(epsilon),
max_size_opt,
);
Ok::<*mut spir_basis, StatusCode>(Box::into_raw(Box::new(
spir_basis::new_logistic_fermionic(basis),
)))
} else {
let basis =
FiniteTempBasis::<LogisticKernel, sparse_ir::traits::Bosonic>::from_sve_result(
kernel,
beta,
sve_result,
Some(epsilon),
max_size_opt,
);
Ok::<*mut spir_basis, StatusCode>(Box::into_raw(Box::new(
spir_basis::new_logistic_bosonic(basis),
)))
}
} else {
use sparse_ir::kernel::RegularizedBoseKernel;
let kernel = RegularizedBoseKernel::new(lambda);
if statistics == SPIR_STATISTICS_FERMIONIC {
let basis =
FiniteTempBasis::<RegularizedBoseKernel, sparse_ir::traits::Fermionic>::from_sve_result(
kernel,
beta,
sve_result,
Some(epsilon),
max_size_opt,
);
Ok::<*mut spir_basis, StatusCode>(Box::into_raw(Box::new(
spir_basis::new_regularized_bose_fermionic(basis),
)))
} else {
let basis =
FiniteTempBasis::<RegularizedBoseKernel, sparse_ir::traits::Bosonic>::from_sve_result(
kernel,
beta,
sve_result,
Some(epsilon),
max_size_opt,
);
Ok::<*mut spir_basis, StatusCode>(Box::into_raw(Box::new(
spir_basis::new_regularized_bose_bosonic(basis),
)))
}
}
}));
match result {
Ok(Ok(ptr)) => {
unsafe {
*status = SPIR_COMPUTATION_SUCCESS;
}
ptr
}
Ok(Err(msg)) => {
debug_eprintln!("Error in spir_basis_new_from_sve_and_regularizer: {}", msg);
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
Err(_) => {
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
}
}
/// Writes the number of basis functions of `b` (`b.size()`) into `*size`.
///
/// Returns `SPIR_INVALID_ARGUMENT` if either pointer is null, `SPIR_INTERNAL_ERROR` if the
/// query panics, and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_size(b: *const spir_basis, size: *mut libc::c_int) -> StatusCode {
    if b.is_null() || size.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: both pointers are non-null; the caller guarantees they are valid.
        let n = unsafe { (*b).size() };
        unsafe { *size = n as libc::c_int };
        SPIR_COMPUTATION_SUCCESS
    }))
    .unwrap_or(SPIR_INTERNAL_ERROR)
}
/// Copies the singular values of `b` into the caller-provided buffer `svals`.
///
/// The function writes every value returned by `svals()` and does not know the buffer
/// length. The tests in this module size the buffer with `spir_basis_get_size`.
///
/// Returns `SPIR_INVALID_ARGUMENT` if either pointer is null, `SPIR_INTERNAL_ERROR` if the
/// copy panics, and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_svals(b: *const spir_basis, svals: *mut f64) -> StatusCode {
    if b.is_null() || svals.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let result = catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: `b` is non-null; the caller guarantees it points to a live basis.
        let sval_vec = unsafe { (*b).svals() };
        debug_println!("spir_basis_get_svals: sval_vec.len() = {}", sval_vec.len());
        if !sval_vec.is_empty() {
            // SAFETY: `svals` is non-null and the caller guarantees room for `len()` values.
            unsafe { std::ptr::copy_nonoverlapping(sval_vec.as_ptr(), svals, sval_vec.len()) };
        }
        SPIR_COMPUTATION_SUCCESS
    }));
    // This handler runs outside `catch_unwind`. It must not `unwrap()` anything, because a
    // panic escaping an `extern "C"` function aborts the process instead of returning an
    // error code. The previous `stderr().flush().unwrap()` calls were removed for this reason
    // (std's stderr is unbuffered anyway).
    result.unwrap_or_else(|panic_payload| {
        debug_eprintln!("Panic in spir_basis_get_svals");
        if let Some(s) = panic_payload.downcast_ref::<String>() {
            debug_eprintln!("Panic message: {}", s);
        } else if let Some(s) = panic_payload.downcast_ref::<&str>() {
            debug_eprintln!("Panic message: {}", s);
        }
        SPIR_INTERNAL_ERROR
    })
}
/// Stores the statistics code reported by `b.statistics()` into `*statistics`.
///
/// NOTE(review): presumably one of the `SPIR_STATISTICS_*` values. The tests in this module
/// read back the same code the basis was created with.
///
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers, `SPIR_INTERNAL_ERROR` on panic,
/// and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_stats(
    b: *const spir_basis,
    statistics: *mut libc::c_int,
) -> StatusCode {
    if b.is_null() || statistics.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let query = || {
        // SAFETY: both pointers are non-null; the caller guarantees they are valid.
        let stats = unsafe { (*b).statistics() };
        unsafe { statistics.write(stats) };
        SPIR_COMPUTATION_SUCCESS
    };
    catch_unwind(AssertUnwindSafe(query)).unwrap_or(SPIR_INTERNAL_ERROR)
}
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_singular_values(
b: *const spir_basis,
svals: *mut f64,
) -> StatusCode {
spir_basis_get_svals(b, svals)
}
/// Writes the number of default imaginary-time sampling points of `b` into `*num_points`.
///
/// Callers use this to size the buffer passed to `spir_basis_get_default_taus`.
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers, `SPIR_INTERNAL_ERROR` on panic,
/// and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_n_default_taus(
    b: *const spir_basis,
    num_points: *mut libc::c_int,
) -> StatusCode {
    if b.is_null() || num_points.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let outcome = catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: both pointers are non-null; validity is the caller's contract.
        let count = unsafe { (*b).default_tau_sampling_points() }.len();
        unsafe { *num_points = count as libc::c_int };
        SPIR_COMPUTATION_SUCCESS
    }));
    outcome.unwrap_or(SPIR_INTERNAL_ERROR)
}
/// Copies the default imaginary-time sampling points of `b` into `points`.
///
/// Every generated point is written without knowing the buffer length, so `points` must hold
/// at least the count reported by `spir_basis_get_n_default_taus`.
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers, `SPIR_INTERNAL_ERROR` on panic,
/// and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_default_taus(
    b: *const spir_basis,
    points: *mut f64,
) -> StatusCode {
    if b.is_null() || points.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let outcome = catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: `b` is non-null and `points` is assumed large enough (see above).
        let taus = unsafe { (*b).default_tau_sampling_points() };
        unsafe { std::ptr::copy_nonoverlapping(taus.as_ptr(), points, taus.len()) };
        SPIR_COMPUTATION_SUCCESS
    }));
    outcome.unwrap_or(SPIR_INTERNAL_ERROR)
}
/// Writes the number of default Matsubara sampling points of `b` into `*num_points`.
///
/// `positive_only` is forwarded to `default_matsubara_sampling_points`. Use the same flag
/// when calling `spir_basis_get_default_matsus`.
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers, `SPIR_INTERNAL_ERROR` on panic,
/// and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_n_default_matsus(
    b: *const spir_basis,
    positive_only: bool,
    num_points: *mut libc::c_int,
) -> StatusCode {
    if b.is_null() || num_points.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: both pointers are non-null; validity is the caller's contract.
        let n = unsafe { (*b).default_matsubara_sampling_points(positive_only) }.len();
        unsafe { *num_points = n as libc::c_int };
        SPIR_COMPUTATION_SUCCESS
    }))
    .unwrap_or(SPIR_INTERNAL_ERROR)
}
/// Copies the default Matsubara sampling points of `b` (as `i64` values) into `points`.
///
/// Every generated point is written, so `points` must hold at least the count reported by
/// `spir_basis_get_n_default_matsus` for the same `positive_only`.
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers, `SPIR_INTERNAL_ERROR` on panic,
/// and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_default_matsus(
    b: *const spir_basis,
    positive_only: bool,
    points: *mut i64,
) -> StatusCode {
    if b.is_null() || points.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let outcome = catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: `b` is non-null and `points` is assumed large enough (see above).
        let freqs = unsafe { (*b).default_matsubara_sampling_points(positive_only) };
        unsafe { std::ptr::copy_nonoverlapping(freqs.as_ptr(), points, freqs.len()) };
        SPIR_COMPUTATION_SUCCESS
    }));
    outcome.unwrap_or(SPIR_INTERNAL_ERROR)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn spir_basis_get_u(
b: *const spir_basis,
status: *mut StatusCode,
) -> *mut spir_funcs {
use crate::types::{BasisType, spir_funcs};
use std::panic::catch_unwind;
if status.is_null() {
return std::ptr::null_mut();
}
if b.is_null() {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
let result = catch_unwind(AssertUnwindSafe(|| unsafe {
let basis_ref = &*b;
let beta = basis_ref.beta();
let funcs = match basis_ref.inner() {
BasisType::LogisticFermionic(basis) => {
spir_funcs::from_u_fermionic(basis.u().clone(), beta)
}
BasisType::LogisticBosonic(basis) => {
spir_funcs::from_u_bosonic(basis.u().clone(), beta)
}
BasisType::RegularizedBoseFermionic(basis) => {
spir_funcs::from_u_fermionic(basis.u().clone(), beta)
}
BasisType::RegularizedBoseBosonic(basis) => {
spir_funcs::from_u_bosonic(basis.u().clone(), beta)
}
BasisType::DLRFermionic(dlr) => spir_funcs::from_dlr_tau_fermionic(
dlr.poles.clone(),
beta,
dlr.wmax,
dlr.pole_weights().to_vec(),
dlr.kernel_ypower(),
),
BasisType::DLRBosonic(dlr) => spir_funcs::from_dlr_tau_bosonic(
dlr.poles.clone(),
beta,
dlr.wmax,
dlr.pole_weights().to_vec(),
dlr.kernel_ypower(),
),
};
Result::<*mut spir_funcs, String>::Ok(Box::into_raw(Box::new(funcs)))
}));
match result {
Ok(Ok(ptr)) => {
unsafe {
*status = SPIR_COMPUTATION_SUCCESS;
}
ptr
}
Ok(Err(msg)) => {
debug_eprintln!("Error in spir_basis_get_u: {}", msg);
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
Err(_) => {
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn spir_basis_get_v(
b: *const spir_basis,
status: *mut StatusCode,
) -> *mut spir_funcs {
use crate::types::{BasisType, spir_funcs};
use std::panic::catch_unwind;
if status.is_null() {
return std::ptr::null_mut();
}
if b.is_null() {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
let result = catch_unwind(AssertUnwindSafe(|| unsafe {
let basis_ref = &*b;
let beta = basis_ref.beta();
let funcs = match basis_ref.inner() {
BasisType::LogisticFermionic(basis) => spir_funcs::from_v(basis.v().clone(), beta),
BasisType::LogisticBosonic(basis) => spir_funcs::from_v(basis.v().clone(), beta),
BasisType::RegularizedBoseFermionic(basis) => {
spir_funcs::from_v(basis.v().clone(), beta)
}
BasisType::RegularizedBoseBosonic(basis) => spir_funcs::from_v(basis.v().clone(), beta),
BasisType::DLRFermionic(_) | BasisType::DLRBosonic(_) => {
return Result::<*mut spir_funcs, String>::Err(
"DLR does not support continuous functions".to_string(),
);
}
};
Result::<*mut spir_funcs, String>::Ok(Box::into_raw(Box::new(funcs)))
}));
match result {
Ok(Ok(ptr)) => {
unsafe {
*status = SPIR_COMPUTATION_SUCCESS;
}
ptr
}
Ok(Err(msg)) => {
debug_eprintln!("Error in spir_basis_get_v: {}", msg);
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
Err(_) => {
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
}
}
/// Writes the number of default omega sampling points (`default_omega_sampling_points()`)
/// of `b` into `*num_points`.
///
/// Callers use this to size the buffer passed to `spir_basis_get_default_ws`.
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers, `SPIR_INTERNAL_ERROR` on panic,
/// and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_n_default_ws(
    b: *const spir_basis,
    num_points: *mut libc::c_int,
) -> StatusCode {
    if b.is_null() || num_points.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let outcome = catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: both pointers are non-null; validity is the caller's contract.
        let count = unsafe { (*b).default_omega_sampling_points() }.len();
        unsafe { *num_points = count as libc::c_int };
        SPIR_COMPUTATION_SUCCESS
    }));
    outcome.unwrap_or(SPIR_INTERNAL_ERROR)
}
/// Copies the default omega sampling points of `b` into `points`.
///
/// Every generated point is written, so `points` must hold at least the count reported by
/// `spir_basis_get_n_default_ws`. Returns `SPIR_INVALID_ARGUMENT` for null pointers,
/// `SPIR_INTERNAL_ERROR` on panic, and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_default_ws(b: *const spir_basis, points: *mut f64) -> StatusCode {
    if b.is_null() || points.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: `b` is non-null and `points` is assumed large enough (see above).
        let omegas = unsafe { (*b).default_omega_sampling_points() };
        unsafe { std::ptr::copy_nonoverlapping(omegas.as_ptr(), points, omegas.len()) };
        SPIR_COMPUTATION_SUCCESS
    }))
    .unwrap_or(SPIR_INTERNAL_ERROR)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn spir_basis_get_uhat(
b: *const spir_basis,
status: *mut StatusCode,
) -> *mut spir_funcs {
use crate::types::{BasisType, spir_funcs};
use std::panic::catch_unwind;
if status.is_null() {
return std::ptr::null_mut();
}
if b.is_null() {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
let result = catch_unwind(AssertUnwindSafe(|| unsafe {
let basis_ref = &*b;
let beta = basis_ref.beta();
let funcs = match basis_ref.inner() {
BasisType::LogisticFermionic(basis) => {
spir_funcs::from_uhat_fermionic(basis.uhat().clone(), beta)
}
BasisType::LogisticBosonic(basis) => {
spir_funcs::from_uhat_bosonic(basis.uhat().clone(), beta)
}
BasisType::RegularizedBoseFermionic(basis) => {
spir_funcs::from_uhat_fermionic(basis.uhat().clone(), beta)
}
BasisType::RegularizedBoseBosonic(basis) => {
spir_funcs::from_uhat_bosonic(basis.uhat().clone(), beta)
}
BasisType::DLRFermionic(dlr) => spir_funcs::from_dlr_matsubara_fermionic(
dlr.poles.clone(),
beta,
dlr.wmax,
dlr.pole_weights().to_vec(),
dlr.kernel_ypower(),
),
BasisType::DLRBosonic(dlr) => spir_funcs::from_dlr_matsubara_bosonic(
dlr.poles.clone(),
beta,
dlr.wmax,
dlr.pole_weights().to_vec(),
dlr.kernel_ypower(),
),
};
Result::<*mut spir_funcs, String>::Ok(Box::into_raw(Box::new(funcs)))
}));
match result {
Ok(Ok(ptr)) => {
unsafe {
*status = SPIR_COMPUTATION_SUCCESS;
}
ptr
}
Ok(Err(msg)) => {
debug_eprintln!("Error in spir_basis_get_uhat: {}", msg);
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
Err(_) => {
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn spir_basis_get_uhat_full(
b: *const spir_basis,
status: *mut StatusCode,
) -> *mut spir_funcs {
use crate::SPIR_NOT_SUPPORTED;
use crate::types::{BasisType, spir_funcs};
use std::panic::catch_unwind;
if status.is_null() {
return std::ptr::null_mut();
}
if b.is_null() {
unsafe {
*status = SPIR_INVALID_ARGUMENT;
}
return std::ptr::null_mut();
}
let result = catch_unwind(AssertUnwindSafe(|| unsafe {
let basis_ref = &*b;
let beta = basis_ref.beta();
let funcs = match basis_ref.inner() {
BasisType::LogisticFermionic(basis) => {
spir_funcs::from_uhat_full_fermionic(basis.uhat_full().clone(), beta)
}
BasisType::LogisticBosonic(basis) => {
spir_funcs::from_uhat_full_bosonic(basis.uhat_full().clone(), beta)
}
BasisType::RegularizedBoseFermionic(basis) => {
spir_funcs::from_uhat_full_fermionic(basis.uhat_full().clone(), beta)
}
BasisType::RegularizedBoseBosonic(basis) => {
spir_funcs::from_uhat_full_bosonic(basis.uhat_full().clone(), beta)
}
BasisType::DLRFermionic(_) | BasisType::DLRBosonic(_) => {
return Result::<*mut spir_funcs, String>::Err(
"uhat_full is only available for IR basis, not DLR".to_string(),
);
}
};
Result::<*mut spir_funcs, String>::Ok(Box::into_raw(Box::new(funcs)))
}));
match result {
Ok(Ok(ptr)) => {
unsafe {
*status = SPIR_COMPUTATION_SUCCESS;
}
ptr
}
Ok(Err(msg)) => {
debug_eprintln!("Error in spir_basis_get_uhat_full: {}", msg);
unsafe {
*status = SPIR_NOT_SUPPORTED;
}
std::ptr::null_mut()
}
Err(_) => {
unsafe {
*status = SPIR_INTERNAL_ERROR;
}
std::ptr::null_mut()
}
}
}
/// Generates up to `n_points` default tau sampling points for `b` and copies them into `points`.
///
/// The generator is asked for `n_points` points. At most `n_points` values are written, and
/// the number actually written goes to `*n_points_returned`. `points` must hold at least
/// `n_points` values.
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers or a negative `n_points`,
/// `SPIR_INTERNAL_ERROR` on panic, and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_default_taus_ext(
    b: *const spir_basis,
    n_points: libc::c_int,
    points: *mut f64,
    n_points_returned: *mut libc::c_int,
) -> StatusCode {
    if b.is_null() || points.is_null() || n_points_returned.is_null() || n_points < 0 {
        return SPIR_INVALID_ARGUMENT;
    }
    let requested = n_points as usize;
    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: all pointers are non-null; `points` is assumed to hold `n_points` values.
        let taus = unsafe { (*b).default_tau_sampling_points_size_requested(requested) };
        // Never write more than the caller asked for, even if more points were generated.
        let count = taus.len().min(requested);
        unsafe {
            std::ptr::copy_nonoverlapping(taus.as_ptr(), points, count);
            *n_points_returned = count as libc::c_int;
        }
        SPIR_COMPUTATION_SUCCESS
    }))
    .unwrap_or(SPIR_INTERNAL_ERROR)
}
/// Writes into `*num_points_returned` how many Matsubara sampling points
/// `default_matsubara_sampling_points_with_mitigate(positive_only, mitigate, L)` produces for `b`.
///
/// `spir_basis_get_default_matsus_ext` makes the same call, so this count sizes its buffer.
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers or a negative `L`,
/// `SPIR_INTERNAL_ERROR` on panic, and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_n_default_matsus_ext(
    b: *const spir_basis,
    positive_only: bool,
    mitigate: bool,
    #[allow(non_snake_case)] L: libc::c_int,
    num_points_returned: *mut libc::c_int,
) -> StatusCode {
    if b.is_null() || num_points_returned.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let Ok(requested) = usize::try_from(L) else {
        return SPIR_INVALID_ARGUMENT;
    };
    let outcome = catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: both pointers are non-null; validity is the caller's contract.
        let freqs = unsafe {
            (*b).default_matsubara_sampling_points_with_mitigate(positive_only, mitigate, requested)
        };
        unsafe { *num_points_returned = freqs.len() as libc::c_int };
        SPIR_COMPUTATION_SUCCESS
    }));
    outcome.unwrap_or(SPIR_INTERNAL_ERROR)
}
/// Generates Matsubara sampling points for `b` (requesting `n_points`) and copies them into
/// `points`, storing the number written in `*n_points_returned`.
///
/// Unlike `spir_basis_get_default_taus_ext`, the count is NOT clamped to `n_points`: every
/// generated point is written. `points` must therefore hold at least the count reported by
/// `spir_basis_get_n_default_matsus_ext` for the same arguments, which is how the tests in
/// this module size it.
/// NOTE(review): confirm whether a caller that sizes `points` by `n_points` alone could
/// overflow it (e.g. with `mitigate = true`).
///
/// Returns `SPIR_INVALID_ARGUMENT` for null pointers or a negative `n_points`,
/// `SPIR_INTERNAL_ERROR` on panic, and `SPIR_COMPUTATION_SUCCESS` otherwise.
#[unsafe(no_mangle)]
pub extern "C" fn spir_basis_get_default_matsus_ext(
    b: *const spir_basis,
    positive_only: bool,
    mitigate: bool,
    n_points: libc::c_int,
    points: *mut i64,
    n_points_returned: *mut libc::c_int,
) -> StatusCode {
    if b.is_null() || points.is_null() || n_points_returned.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    let Ok(requested) = usize::try_from(n_points) else {
        return SPIR_INVALID_ARGUMENT;
    };
    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: all pointers are non-null; `points` is assumed large enough (see above).
        let freqs = unsafe {
            (*b).default_matsubara_sampling_points_with_mitigate(positive_only, mitigate, requested)
        };
        let written = freqs.len();
        unsafe {
            std::ptr::copy_nonoverlapping(freqs.as_ptr(), points, written);
            *n_points_returned = written as libc::c_int;
        }
        SPIR_COMPUTATION_SUCCESS
    }))
    .unwrap_or(SPIR_INTERNAL_ERROR)
}
#[cfg(test)]
mod tests {
    //! FFI-level tests for the `spir_basis_*` entry points.

    use super::*;
    use crate::kernel::*;
    use crate::sve::*;
    use crate::{spir_funcs_get_size, spir_funcs_release};
    use std::ptr;

    /// Builds a basis from a kernel plus an explicitly precomputed SVE result.
    #[test]
    fn test_basis_from_sve() {
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        let mut sve_status = SPIR_INTERNAL_ERROR;
        let sve = spir_sve_result_new(kernel, 1e-6, -1, -1, -1, &mut sve_status);
        assert_eq!(sve_status, SPIR_COMPUTATION_SUCCESS);
        let mut basis_status = SPIR_INTERNAL_ERROR;
        // Statistics code 1, beta = 10, omega_max = 1 (beta * omega_max matches lambda = 10),
        // epsilon = 1e-6, no size limit.
        let basis = spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, sve, -1, &mut basis_status);
        assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);
        assert!(!basis.is_null());
        let mut size = 0;
        let status = spir_basis_get_size(basis, &mut size);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(size > 0);
        debug_println!("Basis size: {}", size);
        // The statistics code passed at construction is reported back unchanged.
        let mut stats = -1;
        let status = spir_basis_get_stats(basis, &mut stats);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(stats, 1);
        spir_basis_release(basis);
        spir_sve_result_release(sve);
        spir_kernel_release(kernel);
    }

    /// Builds a basis directly from a kernel (null SVE pointer).
    #[test]
    fn test_basis_from_kernel() {
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        let mut basis_status = SPIR_INTERNAL_ERROR;
        let basis = spir_basis_new(0, 10.0, 1.0, 1e-6, kernel, ptr::null(), -1, &mut basis_status);
        assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);
        assert!(!basis.is_null());
        let mut stats = -1;
        spir_basis_get_stats(basis, &mut stats);
        assert_eq!(stats, 0);
        spir_basis_release(basis);
        spir_kernel_release(kernel);
    }

    /// The count query sizes the buffer for the default tau points.
    #[test]
    fn test_basis_tau_sampling() {
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        let mut basis_status = SPIR_INTERNAL_ERROR;
        let basis = spir_basis_new(0, 10.0, 1.0, 1e-6, kernel, ptr::null(), -1, &mut basis_status);
        assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);
        let mut n_taus = 0;
        let status = spir_basis_get_n_default_taus(basis, &mut n_taus);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(n_taus > 0);
        debug_println!("Number of tau points: {}", n_taus);
        let mut taus = vec![0.0; n_taus as usize];
        let status = spir_basis_get_default_taus(basis, taus.as_mut_ptr());
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        debug_println!("First 5 tau points:");
        for (i, tau) in taus.iter().take(5).enumerate() {
            debug_println!(" tau[{}] = {}", i, tau);
        }
        spir_basis_release(basis);
        spir_kernel_release(kernel);
    }

    /// The count query sizes the buffer for the default Matsubara points.
    #[test]
    fn test_basis_matsubara_sampling() {
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        let mut basis_status = SPIR_INTERNAL_ERROR;
        let basis = spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, ptr::null(), -1, &mut basis_status);
        assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);
        let mut n_matsus = 0;
        let status = spir_basis_get_n_default_matsus(basis, true, &mut n_matsus);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(n_matsus > 0);
        debug_println!("Number of Matsubara points (positive): {}", n_matsus);
        let mut matsus = vec![0i64; n_matsus as usize];
        let status = spir_basis_get_default_matsus(basis, true, matsus.as_mut_ptr());
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        debug_println!("First 5 Matsubara indices:");
        for (i, n) in matsus.iter().take(5).enumerate() {
            debug_println!(" n[{}] = {}", i, n);
        }
        spir_basis_release(basis);
        spir_kernel_release(kernel);
    }

    /// Omega sampling points, plus agreement between the two singular-value entry points.
    #[test]
    fn test_basis_omega_sampling() {
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        let mut basis_status = SPIR_INTERNAL_ERROR;
        let basis = spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, ptr::null(), -1, &mut basis_status);
        assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);
        let mut n_ws = 0;
        let status = spir_basis_get_n_default_ws(basis, &mut n_ws);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(n_ws > 0);
        debug_println!("Number of omega points: {}", n_ws);
        let mut ws = vec![0.0; n_ws as usize];
        let status = spir_basis_get_default_ws(basis, ws.as_mut_ptr());
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        debug_println!("First 5 omega points:");
        for (i, w) in ws.iter().take(5).enumerate() {
            debug_println!(" w[{}] = {}", i, w);
        }
        // Singular-value buffers are sized by the basis size.
        let mut size = 0;
        let status = spir_basis_get_size(basis, &mut size);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        let mut svals = vec![0.0; size as usize];
        let status = spir_basis_get_singular_values(basis, svals.as_mut_ptr());
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        let mut svals2 = vec![0.0; size as usize];
        let status2 = spir_basis_get_svals(basis, svals2.as_mut_ptr());
        assert_eq!(status2, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(svals, svals2);
        debug_println!("✓ get_singular_values matches get_svals");
        spir_basis_release(basis);
        spir_kernel_release(kernel);
    }

    /// `_ext` variants: explicit point counts, mitigate flag, and argument validation.
    #[test]
    fn test_basis_ext_functions() {
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        let mut basis_status = SPIR_INTERNAL_ERROR;
        let basis = spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, ptr::null(), -1, &mut basis_status);
        assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);

        // Request 5 tau points into a buffer of exactly that size.
        let requested_tau = 5;
        let mut tau_points = vec![0.0; requested_tau];
        let mut tau_returned = 0;
        let status = spir_basis_get_default_taus_ext(
            basis,
            requested_tau as libc::c_int,
            tau_points.as_mut_ptr(),
            &mut tau_returned,
        );
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(tau_returned, requested_tau as libc::c_int);
        println!(
            "✓ get_default_taus_ext returned {} tau points (requested {})",
            tau_returned, requested_tau
        );
        debug_println!(" First 3: {:?}", &tau_points[..3]);

        // Matsubara: positive_only = true, mitigate = false, L = 3. Query the count first,
        // then size the buffer from it.
        let requested_matsu = 3;
        let mut matsu_count = 0;
        let status = spir_basis_get_n_default_matsus_ext(
            basis,
            true,
            false,
            requested_matsu,
            &mut matsu_count,
        );
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert!(matsu_count > 0);
        debug_println!("✓ get_n_default_matsus_ext returned count: {}", matsu_count);
        let mut matsu_points = vec![0i64; matsu_count as usize];
        let mut matsu_returned = 0;
        let status = spir_basis_get_default_matsus_ext(
            basis,
            true,
            false,
            requested_matsu,
            matsu_points.as_mut_ptr(),
            &mut matsu_returned,
        );
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(matsu_returned, matsu_count);
        println!(
            "✓ get_default_matsus_ext returned {} matsubara points",
            matsu_returned
        );
        debug_println!(" Points: {:?}", matsu_points);

        // Negative n_points (here also with a null buffer) must be rejected.
        let mut bad_returned = 0;
        let status = spir_basis_get_default_taus_ext(basis, -1, ptr::null_mut(), &mut bad_returned);
        assert_eq!(status, SPIR_INVALID_ARGUMENT);
        debug_println!("✓ Negative n_points correctly rejected");
        spir_basis_release(basis);
        spir_kernel_release(kernel);
    }

    /// `uhat_full` is available for IR bases and has at least as many functions as `uhat`.
    #[test]
    fn test_basis_get_uhat_full() {
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        let mut basis_status = SPIR_INTERNAL_ERROR;
        let basis = spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, ptr::null(), -1, &mut basis_status);
        assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);
        let mut uhat_status = SPIR_INTERNAL_ERROR;
        let uhat_funcs = unsafe { spir_basis_get_uhat(basis, &mut uhat_status) };
        assert_eq!(uhat_status, SPIR_COMPUTATION_SUCCESS);
        assert!(!uhat_funcs.is_null());
        let mut uhat_size = 0;
        let status = spir_funcs_get_size(uhat_funcs, &mut uhat_size);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        let mut uhat_full_status = SPIR_INTERNAL_ERROR;
        let uhat_full_funcs = unsafe { spir_basis_get_uhat_full(basis, &mut uhat_full_status) };
        assert_eq!(uhat_full_status, SPIR_COMPUTATION_SUCCESS);
        assert!(!uhat_full_funcs.is_null());
        let mut uhat_full_size = 0;
        let status = spir_funcs_get_size(uhat_full_funcs, &mut uhat_full_size);
        assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
        assert_eq!(uhat_size, 10);
        assert!(uhat_full_size >= uhat_size);
        unsafe {
            spir_funcs_release(uhat_funcs);
            spir_funcs_release(uhat_full_funcs);
            spir_basis_release(basis);
            spir_kernel_release(kernel);
        }
    }

    /// Builds a basis from an SVE result plus a regularizer, and rejects null inputs.
    #[test]
    fn test_basis_new_from_sve_and_regularizer() {
        use crate::{SPIR_TWORK_AUTO, spir_funcs_from_piecewise_legendre};

        let lambda = 10.0;
        let beta = 1.0;
        let omega_max = lambda / beta;
        let epsilon = 1e-8;
        let mut kernel_status = SPIR_INTERNAL_ERROR;
        let kernel = spir_logistic_kernel_new(lambda, &mut kernel_status);
        assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
        assert!(!kernel.is_null());
        let mut sve_status = SPIR_INTERNAL_ERROR;
        let sve = spir_sve_result_new(kernel, epsilon, -1, -1, SPIR_TWORK_AUTO, &mut sve_status);
        assert_eq!(sve_status, SPIR_COMPUTATION_SUCCESS);
        assert!(!sve.is_null());

        // Regularizer: one order-0 piecewise-Legendre function on a single segment
        // [-omega_max, omega_max] with a single coefficient.
        let n_segments = 1;
        let segments = [-omega_max, omega_max];
        let coeffs = [1.0];
        let nfuncs = 1;
        let order = 0;
        let mut regularizer_status = SPIR_INTERNAL_ERROR;
        let regularizer_funcs = spir_funcs_from_piecewise_legendre(
            segments.as_ptr(),
            n_segments,
            coeffs.as_ptr(),
            nfuncs,
            order,
            &mut regularizer_status,
        );
        assert_eq!(regularizer_status, SPIR_COMPUTATION_SUCCESS);
        assert!(!regularizer_funcs.is_null());

        let mut basis_status = SPIR_INTERNAL_ERROR;
        let basis = spir_basis_new_from_sve_and_regularizer(
            1, // statistics code
            beta,
            omega_max,
            epsilon,
            lambda,
            0,   // ypower
            1.0, // conv_radius (ignored by the implementation)
            sve,
            regularizer_funcs,
            -1, // max_size: no limit
            &mut basis_status,
        );
        if basis_status == SPIR_COMPUTATION_SUCCESS && !basis.is_null() {
            let mut basis_size = 0;
            let status = spir_basis_get_size(basis, &mut basis_size);
            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
            assert!(basis_size > 0);
            spir_basis_release(basis);
        }

        // A null SVE result must be rejected.
        {
            let mut basis_status = SPIR_INTERNAL_ERROR;
            let basis_err = spir_basis_new_from_sve_and_regularizer(
                1,
                beta,
                omega_max,
                epsilon,
                lambda,
                0,
                1.0,
                ptr::null(),
                regularizer_funcs,
                -1,
                &mut basis_status,
            );
            assert_ne!(basis_status, SPIR_COMPUTATION_SUCCESS);
            assert!(basis_err.is_null());
        }

        // A null regularizer must be rejected.
        {
            let mut basis_status = SPIR_INTERNAL_ERROR;
            let basis_err = spir_basis_new_from_sve_and_regularizer(
                1,
                beta,
                omega_max,
                epsilon,
                lambda,
                0,
                1.0,
                sve,
                ptr::null(),
                -1,
                &mut basis_status,
            );
            assert_ne!(basis_status, SPIR_COMPUTATION_SUCCESS);
            assert!(basis_err.is_null());
        }

        unsafe {
            spir_funcs_release(regularizer_funcs);
            spir_sve_result_release(sve);
            spir_kernel_release(kernel);
        }
    }
}