use std::{
ops::{Deref, DerefMut},
os::raw::c_char,
};
use crate::prelude::*;
/// Vector argument that is either a borrowed [`Vector`] or one of libCEED's
/// sentinel vectors.
///
/// `to_raw` maps `Active` to `CEED_VECTOR_ACTIVE` and `None` to
/// `CEED_VECTOR_NONE`.
#[derive(Debug)]
pub enum VectorOpt<'a> {
    /// A concrete vector owned elsewhere.
    Some(&'a Vector<'a>),
    /// Stands for the sentinel handle `CEED_VECTOR_ACTIVE`.
    Active,
    /// Stands for the sentinel handle `CEED_VECTOR_NONE`.
    None,
}

impl<'a> From<&'a Vector<'_>> for VectorOpt<'a> {
    /// Wraps a real vector.
    ///
    /// In debug builds, asserts that the handle is neither of the libCEED
    /// sentinel vectors (use `VectorOpt::Active` / `VectorOpt::None` for those).
    fn from(vector: &'a Vector) -> Self {
        let vec = vector;
        debug_assert!(vec.ptr != unsafe { bind_ceed::CEED_VECTOR_ACTIVE });
        debug_assert!(vec.ptr != unsafe { bind_ceed::CEED_VECTOR_NONE });
        VectorOpt::Some(vec)
    }
}

impl<'a> VectorOpt<'a> {
    /// Consumes the option and yields the raw libCEED handle, substituting the
    /// sentinel handles for `Active` and `None`.
    pub(crate) fn to_raw(self) -> bind_ceed::CeedVector {
        match self {
            VectorOpt::Some(borrowed) => borrowed.ptr,
            // SAFETY: reading an immutable extern static exported by libCEED.
            VectorOpt::Active => unsafe { bind_ceed::CEED_VECTOR_ACTIVE },
            // SAFETY: reading an immutable extern static exported by libCEED.
            VectorOpt::None => unsafe { bind_ceed::CEED_VECTOR_NONE },
        }
    }

    /// Returns `true` when a concrete vector is wrapped.
    pub fn is_some(&self) -> bool {
        matches!(self, VectorOpt::Some(_))
    }

    /// Returns `true` for the `Active` placeholder.
    pub fn is_active(&self) -> bool {
        matches!(self, VectorOpt::Active)
    }

    /// Returns `true` for the `None` placeholder.
    pub fn is_none(&self) -> bool {
        matches!(self, VectorOpt::None)
    }
}
/// Guard that lends a mutable Rust slice to a libCEED vector as its host array
/// (`CopyMode::UsePointer`) for as long as the slice borrow `'a` lasts.
///
/// When the guard is dropped, the array is taken back from the vector
/// (`CeedVectorTakeArray`). Afterwards the extra vector reference held by the
/// guard is released.
pub struct VectorSliceWrapper<'a> {
    /// Extra reference to the wrapped vector, obtained via
    /// `CeedVectorReferenceCopy`. It is released when this field is dropped.
    pub(crate) vector: crate::Vector<'a>,
    /// Keeps the slice exclusively borrowed while libCEED holds a pointer to it.
    pub(crate) _slice: &'a mut [crate::Scalar],
}

impl<'a> Drop for VectorSliceWrapper<'a> {
    fn drop(&mut self) {
        // Detach the user slice from the vector. A null output pointer is
        // passed because the returned array is not needed. `drop` cannot
        // report errors, so the return code is ignored.
        unsafe {
            bind_ceed::CeedVectorTakeArray(
                self.vector.ptr,
                crate::MemType::Host as bind_ceed::CeedMemType,
                std::ptr::null_mut(),
            )
        };
    }
}

impl<'a> VectorSliceWrapper<'a> {
    /// Sets `slice` as the host array of `vec` without copying and returns a
    /// guard that takes the array back when dropped.
    ///
    /// # Panics
    /// Panics if `slice.len()` differs from the vector length.
    ///
    /// # Errors
    /// Returns an error if libCEED fails to create the extra vector reference
    /// or to set the array.
    fn from_vector_and_slice_mut<'b>(
        vec: &'b mut crate::Vector,
        slice: &'a mut [crate::Scalar],
    ) -> crate::Result<Self> {
        assert_eq!(vec.length(), slice.len());
        let (host, copy_mode) = (
            crate::MemType::Host as bind_ceed::CeedMemType,
            crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
        );
        // Take the extra reference *before* handing the slice to libCEED.
        // If setting the array then fails, dropping `vector` just releases the
        // reference. The opposite order could leave `vec` pointing into `slice`
        // with no guard to take the array back.
        let vector = crate::Vector::from_raw(vec.ptr_copy_mut()?)?;
        let ierr = unsafe {
            bind_ceed::CeedVectorSetArray(
                vec.ptr,
                host,
                copy_mode,
                // Derive the pointer from the unique borrow, because libCEED
                // may write through it.
                slice.as_mut_ptr(),
            )
        };
        vec.check_error(ierr)?;
        Ok(Self {
            vector,
            _slice: slice,
        })
    }
}
/// Owning handle to a libCEED vector (`CeedVector`).
///
/// The lifetime parameter is carried only as a marker (`PhantomData`).
/// Dropping the handle destroys the underlying libCEED object unless the
/// handle is a sentinel vector.
#[derive(Debug)]
pub struct Vector<'a> {
    pub(crate) ptr: bind_ceed::CeedVector,
    _lifeline: PhantomData<&'a ()>,
}

impl From<&'_ Vector<'_>> for bind_ceed::CeedVector {
    /// Exposes the raw handle; ownership stays with the `Vector`.
    fn from(vector: &Vector) -> Self {
        let raw_handle = vector.ptr;
        raw_handle
    }
}

impl<'a> Drop for Vector<'a> {
    /// Destroys the handle via `CeedVectorDestroy`. The sentinel handles
    /// `CEED_VECTOR_NONE` and `CEED_VECTOR_ACTIVE` are skipped.
    fn drop(&mut self) {
        // SAFETY: reading immutable extern statics exported by libCEED.
        let is_sentinel = self.ptr == unsafe { bind_ceed::CEED_VECTOR_NONE }
            || self.ptr == unsafe { bind_ceed::CEED_VECTOR_ACTIVE };
        if is_sentinel {
            return;
        }
        unsafe { bind_ceed::CeedVectorDestroy(&mut self.ptr) };
    }
}
impl<'a> fmt::Display for Vector<'a> {
    /// Renders the vector through libCEED's `CeedVectorView`, formatting each
    /// entry with the C format string `%12.8f`.
    ///
    /// Returns `fmt::Error` if the in-memory stream cannot be opened.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut ptr = std::ptr::null_mut();
        let mut sizeloc = crate::MAX_BUFFER_LENGTH;
        let format = CString::new("%12.8f").expect("CString::new failed");
        // Borrow the pointer instead of calling `into_raw`. `into_raw` hands
        // ownership to a raw pointer that was never reclaimed, leaking the
        // string on every call. `format` stays alive until the end of this
        // function, so the borrowed pointer remains valid.
        let format_c: *const c_char = format.as_ptr();
        let cstring = unsafe {
            let file = bind_ceed::open_memstream(&mut ptr, &mut sizeloc);
            if file.is_null() {
                // Do not hand a null FILE* to libCEED.
                return Err(fmt::Error);
            }
            bind_ceed::CeedVectorView(self.ptr, format_c, file);
            bind_ceed::fclose(file);
            // NOTE(review): `ptr` is allocated by the C runtime inside
            // `open_memstream`, but `CString::from_raw` releases it through
            // Rust's global allocator. That is only correct when the global
            // allocator is the system malloc. Consider `CStr::from_ptr` plus
            // an explicit `free`.
            CString::from_raw(ptr)
        };
        cstring.to_string_lossy().fmt(f)
    }
}
impl<'a> Vector<'a> {
pub fn create(ceed: &crate::Ceed, n: usize) -> crate::Result<Self> {
let n = isize::try_from(n).unwrap();
let mut ptr = std::ptr::null_mut();
let ierr = unsafe { bind_ceed::CeedVectorCreate(ceed.ptr, n, &mut ptr) };
ceed.check_error(ierr)?;
Ok(Self {
ptr,
_lifeline: PhantomData,
})
}
pub(crate) fn from_raw(ptr: bind_ceed::CeedVector) -> crate::Result<Self> {
Ok(Self {
ptr,
_lifeline: PhantomData,
})
}
fn ptr_copy_mut(&mut self) -> crate::Result<bind_ceed::CeedVector> {
let mut ptr_copy = std::ptr::null_mut();
let ierr = unsafe { bind_ceed::CeedVectorReferenceCopy(self.ptr, &mut ptr_copy) };
self.check_error(ierr)?;
Ok(ptr_copy)
}
pub fn copy_from(&mut self, vec_source: &crate::Vector) -> crate::Result<i32> {
let ierr = unsafe { bind_ceed::CeedVectorCopy(vec_source.ptr, self.ptr) };
self.check_error(ierr)
}
pub fn from_slice(ceed: &crate::Ceed, v: &[crate::Scalar]) -> crate::Result<Self> {
let mut x = Self::create(ceed, v.len())?;
x.set_slice(v)?;
Ok(x)
}
pub fn from_array(ceed: &crate::Ceed, v: &mut [crate::Scalar]) -> crate::Result<Self> {
let x = Self::create(ceed, v.len())?;
let (host, user_pointer) = (
crate::MemType::Host as bind_ceed::CeedMemType,
crate::CopyMode::UsePointer as bind_ceed::CeedCopyMode,
);
let v = v.as_ptr() as *mut crate::Scalar;
let ierr = unsafe { bind_ceed::CeedVectorSetArray(x.ptr, host, user_pointer, v) };
ceed.check_error(ierr)?;
Ok(x)
}
#[doc(hidden)]
fn check_error(&self, ierr: i32) -> crate::Result<i32> {
let mut ptr = std::ptr::null_mut();
unsafe {
bind_ceed::CeedVectorGetCeed(self.ptr, &mut ptr);
}
crate::check_error(ptr, ierr)
}
pub fn length(&self) -> usize {
let mut n = 0;
unsafe { bind_ceed::CeedVectorGetLength(self.ptr, &mut n) };
usize::try_from(n).unwrap()
}
pub fn len(&self) -> usize {
self.length()
}
pub fn set_value(&mut self, value: crate::Scalar) -> crate::Result<i32> {
let ierr = unsafe { bind_ceed::CeedVectorSetValue(self.ptr, value) };
self.check_error(ierr)
}
pub fn set_slice(&mut self, slice: &[crate::Scalar]) -> crate::Result<i32> {
assert_eq!(self.length(), slice.len());
let (host, copy_mode) = (
crate::MemType::Host as bind_ceed::CeedMemType,
crate::CopyMode::CopyValues as bind_ceed::CeedCopyMode,
);
let ierr = unsafe {
bind_ceed::CeedVectorSetArray(
self.ptr,
host,
copy_mode,
slice.as_ptr() as *mut crate::Scalar,
)
};
self.check_error(ierr)
}
pub fn wrap_slice_mut<'b>(
&mut self,
slice: &'b mut [crate::Scalar],
) -> crate::Result<VectorSliceWrapper<'b>> {
crate::VectorSliceWrapper::from_vector_and_slice_mut(self, slice)
}
pub fn sync(&self, mtype: crate::MemType) -> crate::Result<i32> {
let ierr =
unsafe { bind_ceed::CeedVectorSyncArray(self.ptr, mtype as bind_ceed::CeedMemType) };
self.check_error(ierr)
}
pub fn view(&self) -> crate::Result<VectorView> {
VectorView::new(self)
}
pub fn view_mut(&mut self) -> crate::Result<VectorViewMut> {
VectorViewMut::new(self)
}
pub fn norm(&self, ntype: crate::NormType) -> crate::Result<crate::Scalar> {
let mut res: crate::Scalar = 0.0;
let ierr = unsafe {
bind_ceed::CeedVectorNorm(self.ptr, ntype as bind_ceed::CeedNormType, &mut res)
};
self.check_error(ierr)?;
Ok(res)
}
#[allow(unused_mut)]
pub fn scale(mut self, alpha: crate::Scalar) -> crate::Result<Self> {
let ierr = unsafe { bind_ceed::CeedVectorScale(self.ptr, alpha) };
self.check_error(ierr)?;
Ok(self)
}
#[allow(unused_mut)]
pub fn axpy(mut self, alpha: crate::Scalar, x: &crate::Vector) -> crate::Result<Self> {
let ierr = unsafe { bind_ceed::CeedVectorAXPY(self.ptr, alpha, x.ptr) };
self.check_error(ierr)?;
Ok(self)
}
#[allow(unused_mut)]
pub fn axpby(
mut self,
alpha: crate::Scalar,
beta: crate::Scalar,
x: &crate::Vector,
) -> crate::Result<Self> {
let ierr = unsafe { bind_ceed::CeedVectorAXPBY(self.ptr, alpha, beta, x.ptr) };
self.check_error(ierr)?;
Ok(self)
}
#[allow(unused_mut)]
pub fn pointwise_mult(mut self, x: &crate::Vector, y: &crate::Vector) -> crate::Result<Self> {
let ierr = unsafe { bind_ceed::CeedVectorPointwiseMult(self.ptr, x.ptr, y.ptr) };
self.check_error(ierr)?;
Ok(self)
}
#[allow(unused_mut)]
pub fn pointwise_scale(mut self, x: &crate::Vector) -> crate::Result<Self> {
let ierr = unsafe { bind_ceed::CeedVectorPointwiseMult(self.ptr, self.ptr, x.ptr) };
self.check_error(ierr)?;
Ok(self)
}
#[allow(unused_mut)]
pub fn pointwise_square(mut self) -> crate::Result<Self> {
let ierr = unsafe { bind_ceed::CeedVectorPointwiseMult(self.ptr, self.ptr, self.ptr) };
self.check_error(ierr)?;
Ok(self)
}
}
/// Read-only host view of a [`Vector`]'s data. The data is obtained with
/// `CeedVectorGetArrayRead` and released with `CeedVectorRestoreArrayRead`
/// on drop.
#[derive(Debug)]
pub struct VectorView<'a> {
    vec: &'a Vector<'a>,
    array: *const crate::Scalar,
}

impl<'a> VectorView<'a> {
    /// Acquires read access to the host array of `vec`.
    fn new(vec: &'a Vector) -> crate::Result<Self> {
        let host = crate::MemType::Host as bind_ceed::CeedMemType;
        let mut array = std::ptr::null();
        let ierr = unsafe { bind_ceed::CeedVectorGetArrayRead(vec.ptr, host, &mut array) };
        vec.check_error(ierr)?;
        Ok(Self { vec, array })
    }
}

impl<'a> Drop for VectorView<'a> {
    /// Hands read access back to libCEED. `drop` cannot report errors, so the
    /// return code is ignored.
    fn drop(&mut self) {
        let handle = self.vec.ptr;
        unsafe { bind_ceed::CeedVectorRestoreArrayRead(handle, &mut self.array) };
    }
}
impl<'a> Deref for VectorView<'a> {
    type Target = [crate::Scalar];

    /// Exposes the viewed data as a slice of `len()` entries.
    fn deref(&self) -> &[crate::Scalar] {
        let len = self.vec.len();
        if len == 0 {
            // `slice::from_raw_parts` requires a non-null pointer even for
            // zero-length slices. Nothing here guarantees `array` is non-null
            // for an empty vector, so return an empty slice directly.
            return &[];
        }
        // SAFETY: `array` was obtained from `CeedVectorGetArrayRead` for this
        // vector and remains valid until `drop` restores it.
        unsafe { std::slice::from_raw_parts(self.array, len) }
    }
}
impl<'a> fmt::Display for VectorView<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "VectorView({:?})", self.deref())
}
}
/// Mutable host view of a [`Vector`]'s data. The data is obtained with
/// `CeedVectorGetArray` and released with `CeedVectorRestoreArray` on drop.
#[derive(Debug)]
pub struct VectorViewMut<'a> {
    vec: &'a Vector<'a>,
    array: *mut crate::Scalar,
}

impl<'a> VectorViewMut<'a> {
    /// Acquires read/write access to the host array of `vec`.
    ///
    /// Taking `vec` by `&mut` keeps other Rust borrows of the vector out while
    /// the view is alive.
    fn new(vec: &'a mut Vector) -> crate::Result<Self> {
        let host = crate::MemType::Host as bind_ceed::CeedMemType;
        let mut array = std::ptr::null_mut();
        let ierr = unsafe { bind_ceed::CeedVectorGetArray(vec.ptr, host, &mut array) };
        vec.check_error(ierr)?;
        Ok(Self { vec, array })
    }
}

impl<'a> Drop for VectorViewMut<'a> {
    /// Hands read/write access back to libCEED. `drop` cannot report errors,
    /// so the return code is ignored.
    fn drop(&mut self) {
        let handle = self.vec.ptr;
        unsafe { bind_ceed::CeedVectorRestoreArray(handle, &mut self.array) };
    }
}
impl<'a> Deref for VectorViewMut<'a> {
    type Target = [crate::Scalar];

    /// Exposes the viewed data as a shared slice of `len()` entries.
    fn deref(&self) -> &[crate::Scalar] {
        let len = self.vec.len();
        if len == 0 {
            // `slice::from_raw_parts` requires a non-null pointer even for
            // zero-length slices. Nothing here guarantees `array` is non-null
            // for an empty vector.
            return &[];
        }
        // SAFETY: `array` was obtained from `CeedVectorGetArray` for this
        // vector and remains valid until `drop` restores it.
        unsafe { std::slice::from_raw_parts(self.array, len) }
    }
}

impl<'a> DerefMut for VectorViewMut<'a> {
    /// Exposes the viewed data as a mutable slice of `len()` entries.
    fn deref_mut(&mut self) -> &mut [crate::Scalar] {
        let len = self.vec.len();
        if len == 0 {
            // Same non-null requirement as `from_raw_parts`. An empty `&mut []`
            // is promotable to `'static`, so this needs no pointer at all.
            return &mut [];
        }
        // SAFETY: `array` is valid for `len` writes until `drop`, and
        // `&mut self` guarantees this is the only live mutable slice.
        unsafe { std::slice::from_raw_parts_mut(self.array, len) }
    }
}
impl<'a> fmt::Display for VectorViewMut<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "VectorViewMut({:?})", self.deref())
}
}