use std::pin::Pin;
use crate::prelude::*;
/// Array of input field slices handed to a QFunction closure, one per field slot.
pub type QFunctionInputs<'a> = [&'a [crate::Scalar]; MAX_QFUNCTION_FIELDS];
/// Array of output field slices handed to a QFunction closure, one per field slot.
pub type QFunctionOutputs<'a> = [&'a mut [crate::Scalar]; MAX_QFUNCTION_FIELDS];
/// A single input or output field of a QFunction.
///
/// `#[repr(transparent)]` guarantees that this wrapper has exactly the layout of the
/// raw `CeedQFunctionField` handle. `QFunctionCore::inputs`/`outputs` rely on this when
/// they reinterpret libCEED's field array as a `&[QFunctionField]`; without it the
/// default `repr(Rust)` layout would make that cast unsound.
#[derive(Debug)]
#[repr(transparent)]
pub struct QFunctionField<'a> {
    ptr: bind_ceed::CeedQFunctionField,
    _lifeline: PhantomData<&'a ()>,
}
impl<'a> QFunctionField<'a> {
    /// Returns the field name reported by `CeedQFunctionFieldGetName`.
    ///
    /// # Panics
    /// Panics if the name is not valid UTF-8.
    pub fn name(&self) -> &str {
        let raw_name = unsafe {
            let mut out: *mut std::os::raw::c_char = std::ptr::null_mut();
            bind_ceed::CeedQFunctionFieldGetName(self.ptr, &mut out);
            out
        };
        let c_name = unsafe { CStr::from_ptr(raw_name) };
        c_name.to_str().unwrap()
    }

    /// Returns the field size reported by `CeedQFunctionFieldGetSize`.
    ///
    /// # Panics
    /// Panics if libCEED reports a size that does not fit in `usize`.
    pub fn size(&self) -> usize {
        let mut raw_size = 0;
        unsafe { bind_ceed::CeedQFunctionFieldGetSize(self.ptr, &mut raw_size) };
        usize::try_from(raw_size).unwrap()
    }

    /// Returns the evaluation mode reported by `CeedQFunctionFieldGetEvalMode`.
    pub fn eval_mode(&self) -> crate::EvalMode {
        let mut raw_mode = 0;
        unsafe { bind_ceed::CeedQFunctionFieldGetEvalMode(self.ptr, &mut raw_mode) };
        let as_u32 = raw_mode as u32;
        crate::EvalMode::from_u32(as_u32)
    }
}
/// An optional reference to either a closure-backed [`QFunction`] or a
/// [`QFunctionByName`], convertible to the raw handle libCEED expects
/// (`CEED_QFUNCTION_NONE` when absent).
pub enum QFunctionOpt<'a> {
    SomeQFunction(&'a QFunction<'a>),
    SomeQFunctionByName(&'a QFunctionByName<'a>),
    None,
}

impl<'a> From<&'a QFunction<'_>> for QFunctionOpt<'a> {
    /// Wraps a closure-backed QFunction. In debug builds, asserts that it does
    /// not hold the `CEED_QFUNCTION_NONE` sentinel.
    fn from(qfunc: &'a QFunction) -> Self {
        debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
        Self::SomeQFunction(qfunc)
    }
}

impl<'a> From<&'a QFunctionByName<'_>> for QFunctionOpt<'a> {
    /// Wraps a named QFunction. In debug builds, asserts that it does not hold
    /// the `CEED_QFUNCTION_NONE` sentinel.
    fn from(qfunc: &'a QFunctionByName) -> Self {
        debug_assert!(qfunc.qf_core.ptr != unsafe { bind_ceed::CEED_QFUNCTION_NONE });
        Self::SomeQFunctionByName(qfunc)
    }
}

impl<'a> QFunctionOpt<'a> {
    /// Returns the raw handle, or `CEED_QFUNCTION_NONE` for [`QFunctionOpt::None`].
    pub(crate) fn to_raw(self) -> bind_ceed::CeedQFunction {
        let core = match self {
            Self::SomeQFunction(qfunc) => &qfunc.qf_core,
            Self::SomeQFunctionByName(qfunc) => &qfunc.qf_core,
            Self::None => return unsafe { bind_ceed::CEED_QFUNCTION_NONE },
        };
        core.ptr
    }

    /// `true` unless this is [`QFunctionOpt::None`].
    pub fn is_some(&self) -> bool {
        !matches!(self, Self::None)
    }

    /// `true` only for [`QFunctionOpt::SomeQFunction`].
    pub fn is_some_q_function(&self) -> bool {
        matches!(self, Self::SomeQFunction(_))
    }

    /// `true` only for [`QFunctionOpt::SomeQFunctionByName`].
    pub fn is_some_q_function_by_name(&self) -> bool {
        matches!(self, Self::SomeQFunctionByName(_))
    }

    /// `true` only for [`QFunctionOpt::None`].
    pub fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }
}
/// Owning wrapper around a raw `CeedQFunction` handle, shared by the
/// closure-backed `QFunction` and by `QFunctionByName`.
///
/// Its `Drop` impl calls `CeedQFunctionDestroy` unless the handle is
/// `CEED_QFUNCTION_NONE`.
#[derive(Debug)]
pub(crate) struct QFunctionCore<'a> {
    // Raw libCEED handle.
    ptr: bind_ceed::CeedQFunction,
    // Ties the wrapper to the lifetime `'a` without storing a reference.
    _lifeline: PhantomData<&'a ()>,
}
/// State shared between a closure-backed `QFunction` and the C-ABI `trampoline`.
///
/// `QFunction::create` registers a pinned, boxed instance as the QFunction
/// context data; `trampoline` expects to receive that pointer back as `ctx`.
struct QFunctionTrampolineData {
    // Number of input fields registered through `QFunction::input`.
    number_inputs: usize,
    // Number of output fields registered through `QFunction::output`.
    number_outputs: usize,
    // Size of each registered input field; `trampoline` multiplies it by `q` to get
    // the slice length. Unused slots stay 0.
    input_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // Size of each registered output field; `trampoline` multiplies it by `q` to get
    // the slice length. Unused slots stay 0.
    output_sizes: [usize; MAX_QFUNCTION_FIELDS],
    // The user's closure, invoked by `trampoline`.
    user_f: Box<QFunctionUserClosure>,
}
/// A QFunction whose pointwise kernel is a Rust closure.
///
/// Field order matters. Rust drops fields in declaration order after `Drop::drop`
/// runs, so `trampoline_data` (the memory registered as context data with
/// `CopyMode::UsePointer`) is freed only after the context has been destroyed in
/// `Drop::drop` and `qf_core` has destroyed the CeedQFunction. Do not reorder.
pub struct QFunction<'a> {
    // Owning wrapper for the CeedQFunction handle.
    qf_core: QFunctionCore<'a>,
    // Context handle whose data points at `trampoline_data`; destroyed in `Drop`.
    qf_ctx_ptr: bind_ceed::CeedQFunctionContext,
    // Pinned so the address handed to libCEED stays fixed for the object's lifetime.
    trampoline_data: Pin<Box<QFunctionTrampolineData>>,
}
/// A QFunction created by name through `CeedQFunctionCreateInteriorByName`.
#[derive(Debug)]
pub struct QFunctionByName<'a> {
    // Owning wrapper for the CeedQFunction handle.
    qf_core: QFunctionCore<'a>,
}
impl<'a> Drop for QFunctionCore<'a> {
    /// Calls `CeedQFunctionDestroy` on the wrapped handle, skipping the
    /// `CEED_QFUNCTION_NONE` sentinel.
    fn drop(&mut self) {
        let is_sentinel = unsafe { self.ptr == bind_ceed::CEED_QFUNCTION_NONE };
        if is_sentinel {
            return;
        }
        unsafe {
            bind_ceed::CeedQFunctionDestroy(&mut self.ptr);
        }
    }
}
impl<'a> Drop for QFunction<'a> {
    /// Calls `CeedQFunctionContextDestroy` on the context created in `create`.
    ///
    /// This runs before the struct's fields are dropped. The context is therefore
    /// handled before `qf_core` destroys the CeedQFunction and before
    /// `trampoline_data` (the context's backing memory) is freed.
    fn drop(&mut self) {
        let ctx = &mut self.qf_ctx_ptr;
        unsafe {
            bind_ceed::CeedQFunctionContextDestroy(ctx);
        }
    }
}
impl<'a> fmt::Display for QFunctionCore<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut ptr = std::ptr::null_mut();
let mut sizeloc = crate::MAX_BUFFER_LENGTH;
let cstring = unsafe {
let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
bind_ceed::CeedQFunctionView(self.ptr, file);
bind_ceed::fclose(file);
CString::from_raw(ptr)
};
cstring.to_string_lossy().fmt(f)
}
}
impl<'a> fmt::Display for QFunction<'a> {
    /// Delegates to the `Display` impl of the wrapped `QFunctionCore`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&self.qf_core, f)
    }
}
impl<'a> fmt::Display for QFunctionByName<'a> {
    /// Delegates to the `Display` impl of the wrapped `QFunctionCore`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&self.qf_core, f)
    }
}
impl<'a> QFunctionCore<'a> {
    /// Routes a libCEED return code through the error handling of the `Ceed`
    /// object that owns this QFunction (obtained with `CeedQFunctionGetCeed`).
    #[doc(hidden)]
    fn check_error(&self, ierr: i32) -> crate::Result<i32> {
        let mut ptr = std::ptr::null_mut();
        unsafe {
            bind_ceed::CeedQFunctionGetCeed(self.ptr, &mut ptr);
        }
        crate::check_error(ptr, ierr)
    }

    /// Applies the QFunction at `Q` points with input vectors `u` and output
    /// vectors `v` via `CeedQFunctionApply`.
    ///
    /// At most `MAX_QFUNCTION_FIELDS` vectors are taken from each of `u` and `v`;
    /// unfilled slots are passed as null.
    ///
    /// # Panics
    /// Panics if `Q` does not fit in an `i32`.
    pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
        let mut u_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
        for (slot, vector) in u_c.iter_mut().zip(u) {
            *slot = vector.ptr;
        }
        let mut v_c = [std::ptr::null_mut(); MAX_QFUNCTION_FIELDS];
        for (slot, vector) in v_c.iter_mut().zip(v) {
            *slot = vector.ptr;
        }
        let Q = i32::try_from(Q).unwrap();
        let ierr = unsafe {
            bind_ceed::CeedQFunctionApply(self.ptr, Q, u_c.as_mut_ptr(), v_c.as_mut_ptr())
        };
        self.check_error(ierr)
    }

    /// Returns the input fields registered on this QFunction, as reported by
    /// `CeedQFunctionGetFields`.
    pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> {
        let mut num_inputs = 0;
        let mut inputs_ptr = std::ptr::null_mut();
        let ierr = unsafe {
            bind_ceed::CeedQFunctionGetFields(
                self.ptr,
                &mut num_inputs,
                &mut inputs_ptr,
                std::ptr::null_mut() as *mut bind_ceed::CeedInt,
                std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField,
            )
        };
        self.check_error(ierr)?;
        Ok(self.field_slice(inputs_ptr, num_inputs))
    }

    /// Returns the output fields registered on this QFunction, as reported by
    /// `CeedQFunctionGetFields`.
    pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> {
        let mut num_outputs = 0;
        let mut outputs_ptr = std::ptr::null_mut();
        let ierr = unsafe {
            bind_ceed::CeedQFunctionGetFields(
                self.ptr,
                std::ptr::null_mut() as *mut bind_ceed::CeedInt,
                std::ptr::null_mut() as *mut *mut bind_ceed::CeedQFunctionField,
                &mut num_outputs,
                &mut outputs_ptr,
            )
        };
        self.check_error(ierr)?;
        Ok(self.field_slice(outputs_ptr, num_outputs))
    }

    /// Views a libCEED field array as a slice of `QFunctionField`s borrowed from `self`.
    ///
    /// Returns an empty slice when `fields` is null or `count` is not positive.
    /// `slice::from_raw_parts` requires a non-null pointer even for length 0, and a
    /// negative count must not wrap to a huge `usize`.
    fn field_slice(
        &self,
        fields: *mut bind_ceed::CeedQFunctionField,
        count: bind_ceed::CeedInt,
    ) -> &[crate::QFunctionField] {
        let len = usize::try_from(count).unwrap_or(0);
        if fields.is_null() || len == 0 {
            return &[];
        }
        // SAFETY: libCEED owns `len` consecutive field handles at `fields`, which
        // live as long as this QFunction. The cast relies on `QFunctionField`
        // having the same layout as `CeedQFunctionField` (a single handle plus a
        // zero-sized marker).
        unsafe { std::slice::from_raw_parts(fields as *const crate::QFunctionField, len) }
    }
}
/// Signature of a user QFunction closure.
///
/// The closure receives `MAX_QFUNCTION_FIELDS` input slices and
/// `MAX_QFUNCTION_FIELDS` output slices (one per field slot). It returns an `i32`
/// status, which `trampoline` passes back to libCEED as its return value.
pub type QFunctionUserClosure = dyn FnMut(
    [&[crate::Scalar]; MAX_QFUNCTION_FIELDS],
    [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS],
) -> i32;
// Expands to an array literal containing the expression `$e` written 16 times.
// Each element is a separate evaluation of `$e`. This is needed for non-`Copy`
// elements such as `&mut [Scalar]`, where the repeat form `[$e; N]` cannot be used.
// NOTE(review): the count of 16 is hard-coded and presumably must equal
// `MAX_QFUNCTION_FIELDS`; confirm if that constant ever changes.
macro_rules! mut_max_fields {
    ($e:expr) => {
        [
            $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e, $e,
        ]
    };
}
/// C-ABI entry point registered with `CeedQFunctionCreateInterior` in
/// `QFunction::create`; it forwards the call to the user's Rust closure.
///
/// Only the registered fields (`number_inputs` / `number_outputs`) are read from
/// `inputs` / `outputs`. The remaining slots of the arrays passed to the closure are
/// zero-length placeholder slices, so pointers for unregistered fields, which may be
/// null or dangling, are never turned into slices.
///
/// # Safety
/// - `ctx` must point to the live `QFunctionTrampolineData` registered as this
///   QFunction's context data.
/// - `inputs` must point to at least `number_inputs` pointers, the `i`-th addressing
///   `input_sizes[i] * q` readable scalars; likewise `outputs` for writable output
///   scalars.
unsafe extern "C" fn trampoline(
    ctx: *mut ::std::os::raw::c_void,
    q: bind_ceed::CeedInt,
    inputs: *const *const bind_ceed::CeedScalar,
    outputs: *const *mut bind_ceed::CeedScalar,
) -> ::std::os::raw::c_int {
    // SAFETY: per the contract above, `ctx` is the pinned trampoline data. It is only
    // mutated in place, never moved, so a plain `&mut` respects the pin.
    let data = &mut *(ctx as *mut QFunctionTrampolineData);
    // A negative point count would otherwise wrap to an enormous slice length.
    let q = usize::try_from(q).unwrap_or(0);
    let num_inputs = data.number_inputs.min(MAX_QFUNCTION_FIELDS);
    let num_outputs = data.number_outputs.min(MAX_QFUNCTION_FIELDS);

    // Inputs: only registered fields get a view of caller memory.
    let mut inputs_array: [&[crate::Scalar]; MAX_QFUNCTION_FIELDS] =
        [&[0.0; 0]; MAX_QFUNCTION_FIELDS];
    for (i, slot) in inputs_array.iter_mut().enumerate().take(num_inputs) {
        *slot = std::slice::from_raw_parts(*inputs.add(i), data.input_sizes[i] * q);
    }

    // Outputs: `&mut` slices are not `Copy`, so the placeholders come from the macro.
    let mut outputs_array: [&mut [crate::Scalar]; MAX_QFUNCTION_FIELDS] =
        mut_max_fields!(&mut [0.0; 0]);
    for (i, slot) in outputs_array.iter_mut().enumerate().take(num_outputs) {
        *slot = std::slice::from_raw_parts_mut(*outputs.add(i), data.output_sizes[i] * q);
    }

    (data.user_f)(inputs_array, outputs_array)
}
/// Data-destroy callback registered with libCEED for a closure QFunction's context.
///
/// The context data is registered with `CopyMode::UsePointer` and points into the
/// `Pin<Box<QFunctionTrampolineData>>` owned by the `QFunction` itself. That
/// allocation is released when the `QFunction` is dropped, not here, so this callback
/// has nothing to free and always reports success (0).
unsafe extern "C" fn destroy_trampoline(ctx: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int {
    // Intentionally neither dereferenced nor freed: ownership stays with the Rust side.
    let _ = ctx;
    0
}
impl<'a> QFunction<'a> {
    /// Creates a QFunction whose pointwise work is done by the Rust closure `user_f`.
    ///
    /// * `ceed`    - `Ceed` object used to create the QFunction and its context
    /// * `vlength` - vector length passed to `CeedQFunctionCreateInterior`
    /// * `user_f`  - closure invoked through `trampoline` when the QFunction is applied
    ///
    /// # Errors
    /// Returns an error if any libCEED call fails. Handles created by earlier calls
    /// are destroyed before the error is returned.
    ///
    /// # Panics
    /// Panics if `vlength` does not fit in an `i32`.
    pub fn create(
        ceed: &crate::Ceed,
        vlength: usize,
        user_f: Box<QFunctionUserClosure>,
    ) -> crate::Result<Self> {
        let source_c = CString::new("").expect("CString::new failed");
        let mut ptr = std::ptr::null_mut();
        // Field counts and sizes are filled in later by `input` / `output`.
        let trampoline_data = unsafe {
            Pin::new_unchecked(Box::new(QFunctionTrampolineData {
                number_inputs: 0,
                number_outputs: 0,
                input_sizes: [0; MAX_QFUNCTION_FIELDS],
                output_sizes: [0; MAX_QFUNCTION_FIELDS],
                user_f,
            }))
        };
        let vlength = i32::try_from(vlength).unwrap();
        let ierr = unsafe {
            bind_ceed::CeedQFunctionCreateInterior(
                ceed.ptr,
                vlength,
                Some(trampoline),
                source_c.as_ptr(),
                &mut ptr,
            )
        };
        ceed.check_error(ierr)?;
        // From here on, dropping `qf_core` destroys the CeedQFunction, so an early
        // return below no longer leaks it.
        let qf_core = QFunctionCore {
            ptr,
            _lifeline: PhantomData,
        };

        let mut qf_ctx_ptr = std::ptr::null_mut();
        let ierr = unsafe { bind_ceed::CeedQFunctionContextCreate(ceed.ptr, &mut qf_ctx_ptr) };
        ceed.check_error(ierr)?;
        // From here on, dropping `qfunction` destroys the context (`Drop for QFunction`)
        // and then the CeedQFunction, so the remaining `?` returns leak nothing.
        let qfunction = Self {
            qf_core,
            qf_ctx_ptr,
            trampoline_data,
        };

        // The context borrows the pinned trampoline data (UsePointer); the heap
        // address stays valid because `qfunction` owns the `Pin<Box<_>>`.
        let ierr = unsafe {
            bind_ceed::CeedQFunctionContextSetData(
                qfunction.qf_ctx_ptr,
                crate::MemType::Host as bind_ceed::CeedMemType,
                crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
                std::mem::size_of::<QFunctionTrampolineData>(),
                std::mem::transmute(qfunction.trampoline_data.as_ref()),
            )
        };
        ceed.check_error(ierr)?;
        let ierr = unsafe {
            bind_ceed::CeedQFunctionContextSetDataDestroy(
                qfunction.qf_ctx_ptr,
                crate::MemType::Host as bind_ceed::CeedMemType,
                Some(destroy_trampoline),
            )
        };
        ceed.check_error(ierr)?;
        let ierr =
            unsafe { bind_ceed::CeedQFunctionSetContext(qfunction.qf_core.ptr, qfunction.qf_ctx_ptr) };
        ceed.check_error(ierr)?;
        Ok(qfunction)
    }

    /// Applies the QFunction at `Q` points with inputs `u` and outputs `v`.
    ///
    /// # Panics
    /// Panics if `Q` does not fit in an `i32`.
    pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
        self.qf_core.apply(Q, u, v)
    }

    /// Adds an input field named `fieldname` with size `size` and evaluation mode
    /// `emode` (via `CeedQFunctionAddInput`). Returns `self` for chaining.
    ///
    /// `size` is also recorded for `trampoline`, which hands the closure a slice of
    /// `size * q` entries for this field.
    ///
    /// # Panics
    /// Panics if `fieldname` contains a NUL byte, if more than
    /// `MAX_QFUNCTION_FIELDS` inputs are added, or if `size` does not fit in an `i32`.
    pub fn input(
        mut self,
        fieldname: &str,
        size: usize,
        emode: crate::EvalMode,
    ) -> crate::Result<Self> {
        let name_c = CString::new(fieldname).expect("CString::new failed");
        let idx = self.trampoline_data.number_inputs;
        self.trampoline_data.input_sizes[idx] = size;
        self.trampoline_data.number_inputs += 1;
        let (size, emode) = (
            i32::try_from(size).unwrap(),
            emode as bind_ceed::CeedEvalMode,
        );
        let ierr = unsafe {
            bind_ceed::CeedQFunctionAddInput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
        };
        self.qf_core.check_error(ierr)?;
        Ok(self)
    }

    /// Adds an output field named `fieldname` with size `size` and evaluation mode
    /// `emode` (via `CeedQFunctionAddOutput`). Returns `self` for chaining.
    ///
    /// `size` is also recorded for `trampoline`, which hands the closure a mutable
    /// slice of `size * q` entries for this field.
    ///
    /// # Panics
    /// Panics if `fieldname` contains a NUL byte, if more than
    /// `MAX_QFUNCTION_FIELDS` outputs are added, or if `size` does not fit in an `i32`.
    pub fn output(
        mut self,
        fieldname: &str,
        size: usize,
        emode: crate::EvalMode,
    ) -> crate::Result<Self> {
        let name_c = CString::new(fieldname).expect("CString::new failed");
        let idx = self.trampoline_data.number_outputs;
        self.trampoline_data.output_sizes[idx] = size;
        self.trampoline_data.number_outputs += 1;
        let (size, emode) = (
            i32::try_from(size).unwrap(),
            emode as bind_ceed::CeedEvalMode,
        );
        let ierr = unsafe {
            bind_ceed::CeedQFunctionAddOutput(self.qf_core.ptr, name_c.as_ptr(), size, emode)
        };
        self.qf_core.check_error(ierr)?;
        Ok(self)
    }

    /// Returns the input fields registered on this QFunction.
    pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> {
        self.qf_core.inputs()
    }

    /// Returns the output fields registered on this QFunction.
    pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> {
        self.qf_core.outputs()
    }
}
impl<'a> QFunctionByName<'a> {
pub fn create(ceed: &crate::Ceed, name: &str) -> crate::Result<Self> {
let name_c = CString::new(name).expect("CString::new failed");
let mut ptr = std::ptr::null_mut();
let ierr = unsafe {
bind_ceed::CeedQFunctionCreateInteriorByName(ceed.ptr, name_c.as_ptr(), &mut ptr)
};
ceed.check_error(ierr)?;
Ok(Self {
qf_core: QFunctionCore {
ptr,
_lifeline: PhantomData,
},
})
}
pub fn apply(&self, Q: usize, u: &[Vector], v: &[Vector]) -> crate::Result<i32> {
self.qf_core.apply(Q, u, v)
}
pub fn inputs(&self) -> crate::Result<&[crate::QFunctionField]> {
self.qf_core.inputs()
}
pub fn outputs(&self) -> crate::Result<&[crate::QFunctionField]> {
self.qf_core.outputs()
}
}