use num_complex::Complex64;
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::sync::Arc;
use crate::gemm::{get_backend_handle, spir_gemm_backend};
use crate::types::{BasisType, spir_basis};
use crate::utils::{MemoryOrder, copy_tensor_to_c_array, read_tensor_nd};
use crate::{SPIR_COMPUTATION_SUCCESS, SPIR_INVALID_ARGUMENT, SPIR_NOT_SUPPORTED, StatusCode};
use sparse_ir::dlr::DiscreteLehmannRepresentation;
/// Builds a Discrete Lehmann Representation (DLR) basis on top of an IR basis,
/// using the default pole selection of `DiscreteLehmannRepresentation::new`.
///
/// Accepted inputs are IR bases built from a logistic or a regularized-Bose
/// kernel. The statistics of the input basis carry over to the DLR: fermionic
/// inputs produce a fermionic DLR, and bosonic inputs produce a bosonic DLR.
///
/// # Arguments
/// * `b` - IR basis handle; must be non-null.
/// * `status` - optional out-parameter (may be null) that receives the status code.
///
/// # Returns
/// A newly heap-allocated `spir_basis` owned by the caller (release it with
/// `spir_basis_release`). Returns null on failure, and `*status` is then set to:
/// * `SPIR_INVALID_ARGUMENT` if `b` is null or is not a supported IR basis;
/// * `SPIR_INTERNAL_ERROR` if a panic was caught during construction.
#[unsafe(no_mangle)]
pub extern "C" fn spir_dlr_new(b: *const spir_basis, status: *mut StatusCode) -> *mut spir_basis {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if b.is_null() {
            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
        }
        // SAFETY: `b` is non-null (checked above); the caller guarantees it points
        // to a live `spir_basis` for the duration of this call.
        let basis_ref = unsafe { &*b };

        // Build the DLR and wrap it in its public handle in a single step; the
        // handle constructor fixes the statistics, so no second dispatch is needed.
        let dlr_basis = match basis_ref.inner() {
            BasisType::LogisticFermionic(ir_basis) => spir_basis::new_dlr_fermionic(Arc::new(
                DiscreteLehmannRepresentation::new(ir_basis.as_ref()),
            )),
            BasisType::LogisticBosonic(ir_basis) => spir_basis::new_dlr_bosonic(Arc::new(
                DiscreteLehmannRepresentation::new(ir_basis.as_ref()),
            )),
            BasisType::RegularizedBoseFermionic(ir_basis) => spir_basis::new_dlr_fermionic(
                Arc::new(DiscreteLehmannRepresentation::new(ir_basis.as_ref())),
            ),
            BasisType::RegularizedBoseBosonic(ir_basis) => spir_basis::new_dlr_bosonic(Arc::new(
                DiscreteLehmannRepresentation::new(ir_basis.as_ref()),
            )),
            // DLR bases (or anything else) cannot serve as the source basis.
            _ => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
        };
        (Box::into_raw(Box::new(dlr_basis)), SPIR_COMPUTATION_SUCCESS)
    }));

    let (ptr, code) = match result {
        Ok(pair) => pair,
        // A panic must not unwind across the C boundary; report it instead.
        Err(_) => (std::ptr::null_mut(), crate::SPIR_INTERNAL_ERROR),
    };
    if !status.is_null() {
        // SAFETY: `status` is non-null; the caller guarantees it is writable.
        unsafe {
            *status = code;
        }
    }
    ptr
}
/// Builds a DLR basis on top of an IR basis using caller-supplied pole positions
/// (via `DiscreteLehmannRepresentation::with_poles`).
///
/// Accepted inputs are IR bases built from a logistic or a regularized-Bose
/// kernel. The statistics of the input basis carry over to the DLR.
///
/// # Arguments
/// * `b` - IR basis handle; must be non-null.
/// * `npoles` - number of entries in `poles`; must be positive.
/// * `poles` - pointer to `npoles` pole positions; must be non-null. Every value
///   must be finite.
/// * `status` - optional out-parameter (may be null) that receives the status code.
///
/// # Returns
/// A newly heap-allocated `spir_basis` owned by the caller (release it with
/// `spir_basis_release`). Returns null on failure, and `*status` is then set to:
/// * `SPIR_INVALID_ARGUMENT` for a null pointer, `npoles <= 0`, a non-finite
///   pole, or an unsupported basis type;
/// * `SPIR_INTERNAL_ERROR` if a panic was caught during construction.
#[unsafe(no_mangle)]
pub extern "C" fn spir_dlr_new_with_poles(
    b: *const spir_basis,
    npoles: libc::c_int,
    poles: *const f64,
    status: *mut StatusCode,
) -> *mut spir_basis {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if b.is_null() || poles.is_null() {
            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
        }
        if npoles <= 0 {
            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
        }
        // SAFETY: `b` is non-null; the caller guarantees it points to a live `spir_basis`.
        let basis_ref = unsafe { &*b };
        // SAFETY: `poles` is non-null and `npoles > 0`; the caller guarantees it
        // points to at least `npoles` readable `f64` values.
        let poles_slice = unsafe { std::slice::from_raw_parts(poles, npoles as usize) };

        // NaN or infinite pole positions cannot define a meaningful DLR; reject
        // them as bad input instead of forwarding them into the basis.
        if poles_slice.iter().any(|p| !p.is_finite()) {
            return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT);
        }

        let dlr_basis = match basis_ref.inner() {
            BasisType::LogisticFermionic(ir_basis) => spir_basis::new_dlr_fermionic(Arc::new(
                DiscreteLehmannRepresentation::with_poles(ir_basis.as_ref(), poles_slice.to_vec()),
            )),
            BasisType::LogisticBosonic(ir_basis) => spir_basis::new_dlr_bosonic(Arc::new(
                DiscreteLehmannRepresentation::with_poles(ir_basis.as_ref(), poles_slice.to_vec()),
            )),
            BasisType::RegularizedBoseFermionic(ir_basis) => spir_basis::new_dlr_fermionic(
                Arc::new(DiscreteLehmannRepresentation::with_poles(
                    ir_basis.as_ref(),
                    poles_slice.to_vec(),
                )),
            ),
            BasisType::RegularizedBoseBosonic(ir_basis) => spir_basis::new_dlr_bosonic(Arc::new(
                DiscreteLehmannRepresentation::with_poles(ir_basis.as_ref(), poles_slice.to_vec()),
            )),
            // DLR bases (or anything else) cannot serve as the source basis.
            _ => return (std::ptr::null_mut(), SPIR_INVALID_ARGUMENT),
        };
        (Box::into_raw(Box::new(dlr_basis)), SPIR_COMPUTATION_SUCCESS)
    }));

    let (ptr, code) = match result {
        Ok(pair) => pair,
        // A panic must not unwind across the C boundary; report it instead.
        Err(_) => (std::ptr::null_mut(), crate::SPIR_INTERNAL_ERROR),
    };
    if !status.is_null() {
        // SAFETY: `status` is non-null; the caller guarantees it is writable.
        unsafe {
            *status = code;
        }
    }
    ptr
}
/// Reports how many poles a DLR basis has.
///
/// # Arguments
/// * `dlr` - DLR basis handle (fermionic or bosonic); must be non-null.
/// * `num_poles` - out-parameter receiving the pole count; must be non-null.
///
/// # Returns
/// * `SPIR_COMPUTATION_SUCCESS` after writing the count to `*num_poles`.
/// * `SPIR_INVALID_ARGUMENT` if either pointer is null or `dlr` is not a DLR basis
///   (`*num_poles` is left untouched).
/// * `SPIR_INTERNAL_ERROR` if a panic was caught.
#[unsafe(no_mangle)]
pub extern "C" fn spir_dlr_get_npoles(
    dlr: *const spir_basis,
    num_poles: *mut libc::c_int,
) -> StatusCode {
    let any_null = dlr.is_null() || num_poles.is_null();
    if any_null {
        return SPIR_INVALID_ARGUMENT;
    }

    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: `dlr` is non-null (checked above); the caller guarantees it
        // points to a live `spir_basis`.
        let handle = unsafe { &*dlr };
        let count = match handle.inner() {
            BasisType::DLRFermionic(repr) => repr.poles.len(),
            BasisType::DLRBosonic(repr) => repr.poles.len(),
            _ => return SPIR_INVALID_ARGUMENT,
        };
        // SAFETY: `num_poles` is non-null (checked above) and writable by contract.
        unsafe { num_poles.write(count as libc::c_int) };
        SPIR_COMPUTATION_SUCCESS
    }))
    .unwrap_or(crate::SPIR_INTERNAL_ERROR)
}
/// Copies the pole positions of a DLR basis into a caller-provided buffer.
///
/// The buffer must have room for as many `f64` values as
/// `spir_dlr_get_npoles` reports for the same handle.
///
/// # Arguments
/// * `dlr` - DLR basis handle (fermionic or bosonic); must be non-null.
/// * `poles` - output buffer; must be non-null.
///
/// # Returns
/// * `SPIR_COMPUTATION_SUCCESS` after writing every pole.
/// * `SPIR_INVALID_ARGUMENT` if a pointer is null or `dlr` is not a DLR basis.
/// * `SPIR_INTERNAL_ERROR` if a panic was caught.
#[unsafe(no_mangle)]
pub extern "C" fn spir_dlr_get_poles(dlr: *const spir_basis, poles: *mut f64) -> StatusCode {
    if dlr.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }
    if poles.is_null() {
        return SPIR_INVALID_ARGUMENT;
    }

    catch_unwind(AssertUnwindSafe(|| {
        // SAFETY: `dlr` is non-null (checked above); the caller guarantees it
        // points to a live `spir_basis`.
        let source = match unsafe { &*dlr }.inner() {
            BasisType::DLRFermionic(repr) => &repr.poles,
            BasisType::DLRBosonic(repr) => &repr.poles,
            _ => return SPIR_INVALID_ARGUMENT,
        };
        for (offset, value) in source.iter().copied().enumerate() {
            // SAFETY: the caller guarantees `poles` holds at least
            // `source.len()` elements (the count from `spir_dlr_get_npoles`).
            unsafe { poles.add(offset).write(value) };
        }
        SPIR_COMPUTATION_SUCCESS
    }))
    .unwrap_or(crate::SPIR_INTERNAL_ERROR)
}
/// Converts real (`f64`) IR expansion coefficients into DLR coefficients.
///
/// The input is an `ndim`-dimensional tensor with shape `input_dims`, laid out
/// in the memory order selected by `order`. The transform (`from_ir_nd`) acts
/// along axis `target_dim`, and the result is written to `out` in the same
/// memory order. `out` must be large enough for the transformed tensor
/// (presumably the input shape with `target_dim` replaced by the pole count;
/// TODO confirm against `from_ir_nd`).
///
/// # Arguments
/// * `dlr` - DLR basis handle; must be non-null.
/// * `backend` - GEMM backend handle, forwarded unchecked to `get_backend_handle`.
/// * `order` - memory-order code accepted by `MemoryOrder::from_c_int`.
/// * `ndim` - tensor rank; must be positive.
/// * `input_dims` - `ndim` extents; must be non-null and each entry non-negative.
/// * `target_dim` - axis to transform; must satisfy `0 <= target_dim < ndim`.
/// * `input` / `out` - input and output buffers; must be non-null.
///
/// # Returns
/// * `SPIR_COMPUTATION_SUCCESS` on success.
/// * `SPIR_INVALID_ARGUMENT` for a null pointer, a bad rank/axis/order, or a
///   negative extent.
/// * `SPIR_NOT_SUPPORTED` if `dlr` is not a DLR basis.
/// * `SPIR_INTERNAL_ERROR` if a panic was caught.
#[unsafe(no_mangle)]
pub extern "C" fn spir_ir2dlr_dd(
    dlr: *const spir_basis,
    backend: *const spir_gemm_backend,
    order: libc::c_int,
    ndim: libc::c_int,
    input_dims: *const libc::c_int,
    target_dim: libc::c_int,
    input: *const f64,
    out: *mut f64,
) -> StatusCode {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if dlr.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
            return SPIR_INVALID_ARGUMENT;
        }
        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
            return SPIR_INVALID_ARGUMENT;
        }
        let mem_order = match MemoryOrder::from_c_int(order) {
            Ok(o) => o,
            Err(_) => return SPIR_INVALID_ARGUMENT,
        };
        // SAFETY: `dlr` is non-null; the caller guarantees it points to a live `spir_basis`.
        let dlr_ref = unsafe { &*dlr };
        // SAFETY: `input_dims` is non-null and `ndim > 0`; the caller guarantees
        // it holds `ndim` readable entries.
        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
        // A negative extent would wrap to an enormous `usize` under `as`, and that
        // shape would then be used to read from the raw `input` pointer. Reject it.
        let Some(orig_dims) = dims_slice
            .iter()
            .map(|&d| usize::try_from(d).ok())
            .collect::<Option<Vec<usize>>>()
        else {
            return SPIR_INVALID_ARGUMENT;
        };
        // SAFETY: the caller guarantees `input` holds a tensor of shape
        // `orig_dims` in `mem_order`.
        let input_tensor = unsafe { read_tensor_nd(input, &orig_dims, mem_order) };
        // SAFETY: `backend` is forwarded as-is; its validity is the caller's contract.
        let backend_handle = unsafe { get_backend_handle(backend) };
        let axis = target_dim as usize;
        let result_tensor = match dlr_ref.inner() {
            BasisType::DLRFermionic(repr) => repr.from_ir_nd(backend_handle, &input_tensor, axis),
            BasisType::DLRBosonic(repr) => repr.from_ir_nd(backend_handle, &input_tensor, axis),
            _ => return SPIR_NOT_SUPPORTED,
        };
        // SAFETY: the caller guarantees `out` can hold the whole result tensor.
        unsafe {
            copy_tensor_to_c_array(result_tensor, out, mem_order);
        }
        SPIR_COMPUTATION_SUCCESS
    }));
    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
}
/// Converts complex (`Complex64`) IR expansion coefficients into DLR coefficients.
///
/// The input is an `ndim`-dimensional tensor with shape `input_dims`, laid out
/// in the memory order selected by `order`. The transform (`from_ir_nd`) acts
/// along axis `target_dim`, and the result is written to `out` in the same
/// memory order. `out` must be large enough for the transformed tensor
/// (presumably the input shape with `target_dim` replaced by the pole count;
/// TODO confirm against `from_ir_nd`).
///
/// # Arguments
/// * `dlr` - DLR basis handle; must be non-null.
/// * `backend` - GEMM backend handle, forwarded unchecked to `get_backend_handle`.
/// * `order` - memory-order code accepted by `MemoryOrder::from_c_int`.
/// * `ndim` - tensor rank; must be positive.
/// * `input_dims` - `ndim` extents; must be non-null and each entry non-negative.
/// * `target_dim` - axis to transform; must satisfy `0 <= target_dim < ndim`.
/// * `input` / `out` - input and output buffers; must be non-null.
///
/// # Returns
/// * `SPIR_COMPUTATION_SUCCESS` on success.
/// * `SPIR_INVALID_ARGUMENT` for a null pointer, a bad rank/axis/order, or a
///   negative extent.
/// * `SPIR_NOT_SUPPORTED` if `dlr` is not a DLR basis.
/// * `SPIR_INTERNAL_ERROR` if a panic was caught.
#[unsafe(no_mangle)]
pub extern "C" fn spir_ir2dlr_zz(
    dlr: *const spir_basis,
    backend: *const spir_gemm_backend,
    order: libc::c_int,
    ndim: libc::c_int,
    input_dims: *const libc::c_int,
    target_dim: libc::c_int,
    input: *const Complex64,
    out: *mut Complex64,
) -> StatusCode {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if dlr.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
            return SPIR_INVALID_ARGUMENT;
        }
        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
            return SPIR_INVALID_ARGUMENT;
        }
        let mem_order = match MemoryOrder::from_c_int(order) {
            Ok(o) => o,
            Err(_) => return SPIR_INVALID_ARGUMENT,
        };
        // SAFETY: `dlr` is non-null; the caller guarantees it points to a live `spir_basis`.
        let dlr_ref = unsafe { &*dlr };
        // SAFETY: `input_dims` is non-null and `ndim > 0`; the caller guarantees
        // it holds `ndim` readable entries.
        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
        // A negative extent would wrap to an enormous `usize` under `as`, and that
        // shape would then be used to read from the raw `input` pointer. Reject it.
        let Some(orig_dims) = dims_slice
            .iter()
            .map(|&d| usize::try_from(d).ok())
            .collect::<Option<Vec<usize>>>()
        else {
            return SPIR_INVALID_ARGUMENT;
        };
        // SAFETY: the caller guarantees `input` holds a tensor of shape
        // `orig_dims` in `mem_order`.
        let input_tensor = unsafe { read_tensor_nd(input, &orig_dims, mem_order) };
        // SAFETY: `backend` is forwarded as-is; its validity is the caller's contract.
        let backend_handle = unsafe { get_backend_handle(backend) };
        let axis = target_dim as usize;
        let result_tensor = match dlr_ref.inner() {
            BasisType::DLRFermionic(repr) => repr.from_ir_nd(backend_handle, &input_tensor, axis),
            BasisType::DLRBosonic(repr) => repr.from_ir_nd(backend_handle, &input_tensor, axis),
            _ => return SPIR_NOT_SUPPORTED,
        };
        // SAFETY: the caller guarantees `out` can hold the whole result tensor.
        unsafe {
            copy_tensor_to_c_array(result_tensor, out, mem_order);
        }
        SPIR_COMPUTATION_SUCCESS
    }));
    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
}
/// Converts real (`f64`) DLR coefficients back into IR expansion coefficients.
///
/// The input is an `ndim`-dimensional tensor with shape `input_dims`, laid out
/// in the memory order selected by `order`. The transform (`to_ir_nd`) acts
/// along axis `target_dim`, and the result is written to `out` in the same
/// memory order. `out` must be large enough for the transformed tensor
/// (presumably the input shape with `target_dim` replaced by the IR basis
/// size; TODO confirm against `to_ir_nd`).
///
/// # Arguments
/// * `dlr` - DLR basis handle; must be non-null.
/// * `backend` - GEMM backend handle, forwarded unchecked to `get_backend_handle`.
/// * `order` - memory-order code accepted by `MemoryOrder::from_c_int`.
/// * `ndim` - tensor rank; must be positive.
/// * `input_dims` - `ndim` extents; must be non-null and each entry non-negative.
/// * `target_dim` - axis to transform; must satisfy `0 <= target_dim < ndim`.
/// * `input` / `out` - input and output buffers; must be non-null.
///
/// # Returns
/// * `SPIR_COMPUTATION_SUCCESS` on success.
/// * `SPIR_INVALID_ARGUMENT` for a null pointer, a bad rank/axis/order, or a
///   negative extent.
/// * `SPIR_NOT_SUPPORTED` if `dlr` is not a DLR basis.
/// * `SPIR_INTERNAL_ERROR` if a panic was caught.
#[unsafe(no_mangle)]
pub extern "C" fn spir_dlr2ir_dd(
    dlr: *const spir_basis,
    backend: *const spir_gemm_backend,
    order: libc::c_int,
    ndim: libc::c_int,
    input_dims: *const libc::c_int,
    target_dim: libc::c_int,
    input: *const f64,
    out: *mut f64,
) -> StatusCode {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if dlr.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
            return SPIR_INVALID_ARGUMENT;
        }
        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
            return SPIR_INVALID_ARGUMENT;
        }
        let mem_order = match MemoryOrder::from_c_int(order) {
            Ok(o) => o,
            Err(_) => return SPIR_INVALID_ARGUMENT,
        };
        // SAFETY: `dlr` is non-null; the caller guarantees it points to a live `spir_basis`.
        let dlr_ref = unsafe { &*dlr };
        // SAFETY: `input_dims` is non-null and `ndim > 0`; the caller guarantees
        // it holds `ndim` readable entries.
        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
        // A negative extent would wrap to an enormous `usize` under `as`, and that
        // shape would then be used to read from the raw `input` pointer. Reject it.
        let Some(orig_dims) = dims_slice
            .iter()
            .map(|&d| usize::try_from(d).ok())
            .collect::<Option<Vec<usize>>>()
        else {
            return SPIR_INVALID_ARGUMENT;
        };
        // SAFETY: the caller guarantees `input` holds a tensor of shape
        // `orig_dims` in `mem_order`.
        let input_tensor = unsafe { read_tensor_nd(input, &orig_dims, mem_order) };
        // SAFETY: `backend` is forwarded as-is; its validity is the caller's contract.
        let backend_handle = unsafe { get_backend_handle(backend) };
        let axis = target_dim as usize;
        let result_tensor = match dlr_ref.inner() {
            BasisType::DLRFermionic(repr) => repr.to_ir_nd(backend_handle, &input_tensor, axis),
            BasisType::DLRBosonic(repr) => repr.to_ir_nd(backend_handle, &input_tensor, axis),
            _ => return SPIR_NOT_SUPPORTED,
        };
        // SAFETY: the caller guarantees `out` can hold the whole result tensor.
        unsafe {
            copy_tensor_to_c_array(result_tensor, out, mem_order);
        }
        SPIR_COMPUTATION_SUCCESS
    }));
    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
}
/// Converts complex (`Complex64`) DLR coefficients back into IR expansion coefficients.
///
/// The input is an `ndim`-dimensional tensor with shape `input_dims`, laid out
/// in the memory order selected by `order`. The transform (`to_ir_nd`) acts
/// along axis `target_dim`, and the result is written to `out` in the same
/// memory order. `out` must be large enough for the transformed tensor
/// (presumably the input shape with `target_dim` replaced by the IR basis
/// size; TODO confirm against `to_ir_nd`).
///
/// # Arguments
/// * `dlr` - DLR basis handle; must be non-null.
/// * `backend` - GEMM backend handle, forwarded unchecked to `get_backend_handle`.
/// * `order` - memory-order code accepted by `MemoryOrder::from_c_int`.
/// * `ndim` - tensor rank; must be positive.
/// * `input_dims` - `ndim` extents; must be non-null and each entry non-negative.
/// * `target_dim` - axis to transform; must satisfy `0 <= target_dim < ndim`.
/// * `input` / `out` - input and output buffers; must be non-null.
///
/// # Returns
/// * `SPIR_COMPUTATION_SUCCESS` on success.
/// * `SPIR_INVALID_ARGUMENT` for a null pointer, a bad rank/axis/order, or a
///   negative extent.
/// * `SPIR_NOT_SUPPORTED` if `dlr` is not a DLR basis.
/// * `SPIR_INTERNAL_ERROR` if a panic was caught.
#[unsafe(no_mangle)]
pub extern "C" fn spir_dlr2ir_zz(
    dlr: *const spir_basis,
    backend: *const spir_gemm_backend,
    order: libc::c_int,
    ndim: libc::c_int,
    input_dims: *const libc::c_int,
    target_dim: libc::c_int,
    input: *const Complex64,
    out: *mut Complex64,
) -> StatusCode {
    let result = catch_unwind(AssertUnwindSafe(|| {
        if dlr.is_null() || input_dims.is_null() || input.is_null() || out.is_null() {
            return SPIR_INVALID_ARGUMENT;
        }
        if ndim <= 0 || target_dim < 0 || target_dim >= ndim {
            return SPIR_INVALID_ARGUMENT;
        }
        let mem_order = match MemoryOrder::from_c_int(order) {
            Ok(o) => o,
            Err(_) => return SPIR_INVALID_ARGUMENT,
        };
        // SAFETY: `dlr` is non-null; the caller guarantees it points to a live `spir_basis`.
        let dlr_ref = unsafe { &*dlr };
        // SAFETY: `input_dims` is non-null and `ndim > 0`; the caller guarantees
        // it holds `ndim` readable entries.
        let dims_slice = unsafe { std::slice::from_raw_parts(input_dims, ndim as usize) };
        // A negative extent would wrap to an enormous `usize` under `as`, and that
        // shape would then be used to read from the raw `input` pointer. Reject it.
        let Some(orig_dims) = dims_slice
            .iter()
            .map(|&d| usize::try_from(d).ok())
            .collect::<Option<Vec<usize>>>()
        else {
            return SPIR_INVALID_ARGUMENT;
        };
        // SAFETY: the caller guarantees `input` holds a tensor of shape
        // `orig_dims` in `mem_order`.
        let input_tensor = unsafe { read_tensor_nd(input, &orig_dims, mem_order) };
        // SAFETY: `backend` is forwarded as-is; its validity is the caller's contract.
        let backend_handle = unsafe { get_backend_handle(backend) };
        let axis = target_dim as usize;
        let result_tensor = match dlr_ref.inner() {
            BasisType::DLRFermionic(repr) => repr.to_ir_nd(backend_handle, &input_tensor, axis),
            BasisType::DLRBosonic(repr) => repr.to_ir_nd(backend_handle, &input_tensor, axis),
            _ => return SPIR_NOT_SUPPORTED,
        };
        // SAFETY: the caller guarantees `out` can hold the whole result tensor.
        unsafe {
            copy_tensor_to_c_array(result_tensor, out, mem_order);
        }
        SPIR_COMPUTATION_SUCCESS
    }));
    result.unwrap_or(crate::SPIR_INTERNAL_ERROR)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SPIR_COMPUTATION_SUCCESS;
    use crate::basis::spir_basis_new;
    use crate::kernel::spir_logistic_kernel_new;
    use crate::sve::spir_sve_result_new;

    /// End-to-end smoke test of the DLR C API.
    ///
    /// Builds a logistic kernel, its SVE, an IR basis, and a DLR on top of that
    /// basis. It then queries the DLR poles and evaluates the DLR `u` functions
    /// (at tau = 0.5) and `uhat` functions (at Matsubara index n = 1), checking
    /// every status code and that all returned numbers are finite. Every handle
    /// created here is released at the end.
    #[test]
    fn test_dlr_creation() {
        sparse_ir::gemm::clear_blas_backend();
        unsafe {
            let mut kernel_status = crate::SPIR_INTERNAL_ERROR;
            let kernel = spir_logistic_kernel_new(10.0, &mut kernel_status);
            assert_eq!(kernel_status, SPIR_COMPUTATION_SUCCESS);
            assert!(!kernel.is_null());

            let mut sve_status = crate::SPIR_INTERNAL_ERROR;
            let sve = spir_sve_result_new(kernel, 1e-6, -1, -1, 0, &mut sve_status);
            assert_eq!(sve_status, SPIR_COMPUTATION_SUCCESS);
            assert!(!sve.is_null());

            let mut basis_status = crate::SPIR_INTERNAL_ERROR;
            let basis = spir_basis_new(1, 10.0, 1.0, 1e-6, kernel, sve, -1, &mut basis_status);
            assert_eq!(basis_status, SPIR_COMPUTATION_SUCCESS);
            assert!(!basis.is_null());

            let mut dlr_status = crate::SPIR_INTERNAL_ERROR;
            let dlr = spir_dlr_new(basis, &mut dlr_status);
            assert_eq!(dlr_status, SPIR_COMPUTATION_SUCCESS);
            assert!(!dlr.is_null());

            let mut npoles = 0;
            let status = spir_dlr_get_npoles(dlr, &mut npoles);
            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
            assert!(npoles > 0);
            debug_println!("DLR has {} poles", npoles);

            let mut poles = vec![0.0; npoles as usize];
            let status = spir_dlr_get_poles(dlr, poles.as_mut_ptr());
            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
            assert!(poles.iter().all(|p| p.is_finite()), "non-finite DLR pole");
            debug_println!("First 3 poles: {:?}", &poles[0..3.min(npoles as usize)]);

            let mut u_status = crate::SPIR_INTERNAL_ERROR;
            let u_funcs = crate::basis::spir_basis_get_u(dlr, &mut u_status);
            assert_eq!(u_status, SPIR_COMPUTATION_SUCCESS);
            assert!(!u_funcs.is_null());
            debug_println!("✓ Got u funcs from DLR");

            let tau = 0.5;
            let mut u_values = vec![0.0; npoles as usize];
            let status = crate::funcs::spir_funcs_eval(u_funcs, tau, u_values.as_mut_ptr());
            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
            assert!(u_values.iter().all(|v| v.is_finite()), "non-finite u value");
            debug_println!(
                "✓ Evaluated u at τ={}: {:?}",
                tau,
                &u_values[0..3.min(npoles as usize)]
            );

            let mut uhat_status = crate::SPIR_INTERNAL_ERROR;
            let uhat_funcs = crate::basis::spir_basis_get_uhat(dlr, &mut uhat_status);
            assert_eq!(uhat_status, SPIR_COMPUTATION_SUCCESS);
            assert!(!uhat_funcs.is_null());
            debug_println!("✓ Got uhat funcs from DLR");

            let n_matsu = 1i64;
            let mut uhat_values = vec![num_complex::Complex64::new(0.0, 0.0); npoles as usize];
            let status =
                crate::funcs::spir_funcs_eval_matsu(uhat_funcs, n_matsu, uhat_values.as_mut_ptr());
            assert_eq!(status, SPIR_COMPUTATION_SUCCESS);
            assert!(
                uhat_values.iter().all(|v| v.is_finite()),
                "non-finite uhat value"
            );
            debug_println!(
                "✓ Evaluated uhat at n={}: |uhat|={:?}",
                n_matsu,
                &uhat_values[0..3.min(npoles as usize)]
                    .iter()
                    .map(|v| v.norm())
                    .collect::<Vec<_>>()
            );

            // Release in reverse order of creation.
            crate::funcs::spir_funcs_release(uhat_funcs);
            crate::funcs::spir_funcs_release(u_funcs);
            crate::basis::spir_basis_release(dlr);
            crate::basis::spir_basis_release(basis);
            crate::sve::spir_sve_result_release(sve);
            crate::kernel::spir_kernel_release(kernel);
        }
    }
}