use std::ffi::c_void;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use flodl_sys::{self as ffi, FlodlTensor};
use crate::tensor::{
check_err, current_cuda_device, set_current_cuda_device,
Device, Result, Tensor, TensorError,
};
use super::cuda_stream::CudaStream;
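/// Reduction operation applied by NCCL collectives.
///
/// The discriminants are passed straight through to the FFI layer as `i32`
/// (see the `op as i32` casts below), so they must stay in sync with the
/// values expected by `flodl_sys`.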
#[derive(Clone, Copy, Debug)]
#[repr(i32)]
pub enum ReduceOp {
Sum = 0,
Prod = 1,
Max = 2,
Min = 3,
Avg = 4,
}
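/// Single-process, multi-GPU NCCL communicator covering one CUDA device per rank.
///
/// Collective methods take one tensor per device, in the same order as the
/// device list passed to [`NcclComms::new`], and operate on them in place.
///
/// Illustrative sketch, mirroring the tests at the bottom of this file
/// (assumes two usable CUDA devices; not compiled as a doctest):
///
/// ```ignore
/// let comms = NcclComms::new(&[Device::CUDA(0), Device::CUDA(1)])?;
/// let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
/// let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
/// let t0 = Tensor::full(&[128], 1.0, opts0)?;
/// let t1 = Tensor::full(&[128], 2.0, opts1)?;
/// comms.all_reduce(&[&t0, &t1], ReduceOp::Sum)?; // both tensors now hold 3.0
/// ```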
pub struct NcclComms {
handle: *mut c_void,
devices: Vec<Device>,
}
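// SAFETY: `handle` is an opaque FFI pointer owned exclusively by this struct,
// so moving the communicator to another thread is assumed safe. The type is
// not marked `Sync`, so shared references stay on a single thread.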
unsafe impl Send for NcclComms {}
impl NcclComms {
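    /// Test-only constructor wrapping an existing raw handle; the caller is
    /// responsible for the handle's validity and for it matching `devices`.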
#[cfg(test)]
pub(crate) unsafe fn from_raw(handle: *mut c_void, devices: Vec<Device>) -> Self {
NcclComms { handle, devices }
}
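    /// Initializes a communicator clique over the given CUDA devices.
    ///
    /// At least two devices are required and all of them must be CUDA devices;
    /// the current CUDA device is saved and restored around the FFI call.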
pub fn new(devices: &[Device]) -> Result<Self> {
if devices.len() < 2 {
return Err(TensorError::new(
"NcclComms requires at least 2 devices",
));
}
let mut devlist: Vec<i32> = Vec::with_capacity(devices.len());
for &dev in devices {
match dev {
Device::CUDA(idx) => devlist.push(idx as i32),
Device::CPU => {
return Err(TensorError::new(
"NcclComms requires CUDA devices, got CPU",
))
}
}
}
let mut handle: *mut c_void = ptr::null_mut();
let saved = current_cuda_device();
let err = unsafe {
ffi::flodl_nccl_init(
devlist.len() as i32,
devlist.as_ptr(),
&mut handle,
)
};
set_current_cuda_device(saved);
check_err(err)?;
Ok(NcclComms {
handle,
devices: devices.to_vec(),
})
}
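    /// All-reduces one tensor per device in place using `op`.
    ///
    /// A null stream list is passed to the FFI call, which is assumed to make
    /// the backend use its default streams; use
    /// [`all_reduce_on_streams`](Self::all_reduce_on_streams) to supply
    /// explicit streams.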
pub fn all_reduce(&self, tensors: &[&Tensor], op: ReduceOp) -> Result<()> {
self.validate_tensors(tensors, "all_reduce")?;
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let saved = current_cuda_device();
let err = unsafe {
ffi::flodl_nccl_all_reduce(
self.handle,
handles.as_mut_ptr(),
ptr::null_mut(),
op as i32,
)
};
set_current_cuda_device(saved);
check_err(err)
}
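    /// Like [`all_reduce`](Self::all_reduce), but enqueues the collective on
    /// the given CUDA streams, one per device in rank order.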
pub fn all_reduce_on_streams(
&self,
tensors: &[&Tensor],
op: ReduceOp,
streams: &[&CudaStream],
) -> Result<()> {
self.validate_tensors(tensors, "all_reduce_on_streams")?;
if streams.len() != self.devices.len() {
return Err(TensorError::new(&format!(
"all_reduce_on_streams: expected {} streams, got {}",
self.devices.len(), streams.len()
)));
}
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let mut stream_ptrs: Vec<*mut c_void> = streams.iter().map(|s| s.as_ptr()).collect();
let saved = current_cuda_device();
let err = unsafe {
ffi::flodl_nccl_all_reduce(
self.handle,
handles.as_mut_ptr(),
stream_ptrs.as_mut_ptr(),
op as i32,
)
};
set_current_cuda_device(saved);
check_err(err)
}
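    /// Broadcasts the tensor at index `root` to all other tensors in place.
    ///
    /// `root` is an index into the device list, not a CUDA device ordinal.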
pub fn broadcast(&self, tensors: &[&Tensor], root: usize) -> Result<()> {
self.validate_tensors(tensors, "broadcast")?;
if root >= self.devices.len() {
return Err(TensorError::new(&format!(
"broadcast: root {} out of range (have {} devices)",
root, self.devices.len()
)));
}
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let saved = current_cuda_device();
let err = unsafe {
ffi::flodl_nccl_broadcast(
self.handle,
handles.as_mut_ptr(),
ptr::null_mut(),
root as i32,
)
};
set_current_cuda_device(saved);
check_err(err)
}
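    /// Like [`broadcast`](Self::broadcast), but enqueues the collective on the
    /// given CUDA streams, one per device in rank order.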
pub fn broadcast_on_streams(
&self,
tensors: &[&Tensor],
root: usize,
streams: &[&CudaStream],
) -> Result<()> {
self.validate_tensors(tensors, "broadcast_on_streams")?;
if root >= self.devices.len() {
            return Err(TensorError::new(&format!(
                "broadcast_on_streams: root {} out of range (have {} devices)",
                root, self.devices.len()
            )));
}
if streams.len() != self.devices.len() {
return Err(TensorError::new(&format!(
"broadcast_on_streams: expected {} streams, got {}",
self.devices.len(), streams.len()
)));
}
let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let mut stream_ptrs: Vec<*mut c_void> = streams.iter().map(|s| s.as_ptr()).collect();
let saved = current_cuda_device();
let err = unsafe {
ffi::flodl_nccl_broadcast(
self.handle,
handles.as_mut_ptr(),
stream_ptrs.as_mut_ptr(),
root as i32,
)
};
set_current_cuda_device(saved);
check_err(err)
}
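    /// Number of devices (ranks) in the clique.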
pub fn size(&self) -> usize {
self.devices.len()
}
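    /// Devices participating in the clique, in rank order.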
pub fn devices(&self) -> &[Device] {
&self.devices
}
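    /// Checks that exactly one tensor per device was supplied. Only the count
    /// is validated; callers are responsible for passing the tensors in device
    /// order.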
fn validate_tensors(&self, tensors: &[&Tensor], op: &str) -> Result<()> {
if tensors.len() != self.devices.len() {
return Err(TensorError::new(&format!(
"{}: expected {} tensors (one per device), got {}",
op, self.devices.len(), tensors.len()
)));
}
Ok(())
}
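    /// Consumes the clique and splits it into one [`NcclRankComm`] per device,
    /// so that each rank can issue collectives from its own thread.
    ///
    /// The clique handle itself is released when `self` is dropped at the end
    /// of this call; the per-rank handles are assumed to remain valid
    /// independently of it.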
pub fn split(self) -> Result<Vec<NcclRankComm>> {
let mut comms = Vec::with_capacity(self.devices.len());
for i in 0..self.devices.len() {
let mut rank_handle: *mut c_void = ptr::null_mut();
let err = unsafe {
ffi::flodl_nccl_split_rank(
self.handle,
i as i32,
&mut rank_handle,
)
};
check_err(err)?;
let abort_handle = Arc::new(NcclAbortHandle {
ptr: rank_handle,
aborted: AtomicBool::new(false),
});
comms.push(NcclRankComm {
handle: rank_handle,
rank: i,
world_size: self.devices.len(),
abort_handle,
});
}
Ok(comms)
}
}
impl Drop for NcclComms {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe { ffi::flodl_nccl_destroy(self.handle) };
self.handle = ptr::null_mut();
}
}
}
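/// Size in bytes of an NCCL unique id as exposed by the FFI layer.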
pub const NCCL_UNIQUE_ID_BYTES: usize = 128;
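/// Opaque rendezvous token identifying one NCCL clique.
///
/// Create it once (typically on rank 0), distribute the bytes out-of-band to
/// every rank, and pass it to [`NcclRankComm::init_rank`]. The `Debug` impl
/// omits the raw bytes.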
#[derive(Clone)]
pub struct NcclUniqueId {
bytes: [u8; NCCL_UNIQUE_ID_BYTES],
}
unsafe impl Send for NcclUniqueId {}
unsafe impl Sync for NcclUniqueId {}
impl NcclUniqueId {
pub fn new() -> Result<Self> {
let mut bytes = [0u8; NCCL_UNIQUE_ID_BYTES];
let err = unsafe { ffi::flodl_nccl_get_unique_id(bytes.as_mut_ptr()) };
check_err(err)?;
Ok(NcclUniqueId { bytes })
}
pub fn as_bytes(&self) -> &[u8; NCCL_UNIQUE_ID_BYTES] {
&self.bytes
}
}
impl std::fmt::Debug for NcclUniqueId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NcclUniqueId").finish()
}
}
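/// Handle that lets another thread abort an in-flight [`NcclRankComm`], e.g.
/// from a watchdog when a peer rank dies.
///
/// The atomic flag makes [`abort`](Self::abort) idempotent and lets
/// [`NcclRankComm`]'s `Drop` impl skip the normal destroy once the
/// communicator has been aborted.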
pub struct NcclAbortHandle {
ptr: *mut c_void,
aborted: AtomicBool,
}
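// SAFETY: the raw communicator pointer is only used through the FFI abort
// call, which is guarded by the atomic `aborted` flag and is assumed to be
// callable from any thread.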
unsafe impl Send for NcclAbortHandle {}
unsafe impl Sync for NcclAbortHandle {}
impl NcclAbortHandle {
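    /// Aborts the underlying communicator. Safe to call multiple times and
    /// after the communicator has been dropped; only the first call reaches
    /// the FFI layer.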
pub fn abort(&self) -> Result<()> {
        if self.aborted.swap(true, Ordering::AcqRel) {
            return Ok(());
        }
let err = unsafe { ffi::flodl_nccl_abort_rank(self.ptr) };
check_err(err)
}
pub fn is_aborted(&self) -> bool {
self.aborted.load(Ordering::Acquire)
}
fn mark_destroyed(&self) {
self.aborted.store(true, Ordering::Release);
}
}
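/// Communicator for a single rank of a multi-rank NCCL clique, obtained from
/// [`NcclComms::split`] or [`NcclRankComm::init_rank`].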
pub struct NcclRankComm {
handle: *mut c_void,
rank: usize,
world_size: usize,
abort_handle: Arc<NcclAbortHandle>,
}
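// SAFETY: the communicator handle is owned exclusively by this struct (the
// shared abort handle has its own synchronization), so moving it to another
// thread is assumed safe.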
unsafe impl Send for NcclRankComm {}
impl NcclRankComm {
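    /// Creates the communicator for rank `rank` of a `world_size`-wide clique.
    ///
    /// Every rank must call this with the same [`NcclUniqueId`] (distributed
    /// out-of-band) and a distinct rank; initialization rendezvouses with the
    /// other ranks, so each rank should run on its own thread or process.
    ///
    /// Illustrative sketch of the thread-per-rank pattern used by the tests
    /// below (assumes two usable CUDA devices; not compiled as a doctest):
    ///
    /// ```ignore
    /// let uid = NcclUniqueId::new()?;
    /// let (uid0, uid1) = (uid.clone(), uid);
    /// let h0 = std::thread::spawn(move || {
    ///     set_current_cuda_device(0);
    ///     NcclRankComm::init_rank(0, 2, &uid0).unwrap()
    /// });
    /// let h1 = std::thread::spawn(move || {
    ///     set_current_cuda_device(1);
    ///     NcclRankComm::init_rank(1, 2, &uid1).unwrap()
    /// });
    /// let (comm0, comm1) = (h0.join().unwrap(), h1.join().unwrap());
    /// ```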
pub fn init_rank(rank: usize, world_size: usize, uid: &NcclUniqueId) -> Result<Self> {
if rank >= world_size {
return Err(TensorError::new(&format!(
"NcclRankComm: rank {} >= world_size {}", rank, world_size
)));
}
if world_size < 2 {
return Err(TensorError::new(
"NcclRankComm requires world_size >= 2"
));
}
let mut handle: *mut c_void = ptr::null_mut();
let err = unsafe {
ffi::flodl_nccl_init_rank(
rank as i32,
world_size as i32,
uid.bytes.as_ptr(),
&mut handle,
)
};
check_err(err)?;
let abort_handle = Arc::new(NcclAbortHandle {
ptr: handle,
aborted: AtomicBool::new(false),
});
Ok(NcclRankComm { handle, rank, world_size, abort_handle })
}
pub fn rank(&self) -> usize {
self.rank
}
pub fn world_size(&self) -> usize {
self.world_size
}
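    /// Returns a cloneable handle that can abort this communicator from
    /// another thread (e.g. a watchdog) without borrowing `self`.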
pub fn abort_handle(&self) -> Arc<NcclAbortHandle> {
self.abort_handle.clone()
}
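    /// All-reduces this rank's tensors in place using `op`.
    ///
    /// Unlike [`NcclComms::all_reduce`], the slice holds only this rank's
    /// tensors; each position in the slice is reduced against the same
    /// position on the other ranks, so all ranks are expected to pass the same
    /// number of tensors (see `test_nccl_rank_comm_multi_tensor_batch`).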
pub fn all_reduce(&self, tensors: &[&Tensor], op: ReduceOp) -> Result<()> {
let mut handles: Vec<ffi::FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_nccl_all_reduce_rank(
self.handle,
handles.as_mut_ptr(),
handles.len() as i32,
ptr::null_mut(),
op as i32,
)
};
check_err(err)
}
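    /// Like [`all_reduce`](Self::all_reduce), but enqueues the collective on
    /// the given CUDA stream instead of the default one.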
pub fn all_reduce_on_stream(
&self,
tensors: &[&Tensor],
op: ReduceOp,
stream: &CudaStream,
) -> Result<()> {
let mut handles: Vec<ffi::FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
let err = unsafe {
ffi::flodl_nccl_all_reduce_rank(
self.handle,
handles.as_mut_ptr(),
handles.len() as i32,
stream.as_ptr(),
op as i32,
)
};
check_err(err)
}
}
impl Drop for NcclRankComm {
fn drop(&mut self) {
if !self.handle.is_null() && !self.abort_handle.is_aborted() {
unsafe { ffi::flodl_nccl_destroy_rank(self.handle) };
self.handle = ptr::null_mut();
}
self.abort_handle.mark_destroyed();
}
}
impl std::fmt::Debug for NcclRankComm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NcclRankComm")
.field("rank", &self.rank)
.field("world_size", &self.world_size)
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{test_device, cuda_device_count, cuda_synchronize, TensorOptions, DType};
use crate::distributed::ddp::NCCL_LOCK;
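    /// Returns true only when at least two CUDA devices are present and both
    /// can actually run compute kernels; multi-GPU tests bail out otherwise.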
fn require_multi_gpu() -> bool {
if !test_device().is_cuda() || cuda_device_count() < 2 {
return false;
}
for i in 0..2 {
let opts = TensorOptions { dtype: DType::Float32, device: Device::CUDA(i) };
if Tensor::zeros(&[1], opts).is_err() {
eprintln!("Device CUDA({i}) cannot run compute kernels, skipping multi-GPU test");
return false;
}
}
true
}
#[test]
fn test_nccl_requires_two_devices() {
let result = NcclComms::new(&[Device::CUDA(0)]);
assert!(result.is_err(), "NcclComms should require 2+ devices");
}
#[test]
fn test_nccl_rejects_cpu() {
let result = NcclComms::new(&[Device::CPU, Device::CPU]);
assert!(result.is_err(), "NcclComms should reject CPU devices");
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_init_destroy() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let comms = NcclComms::new(&[Device::CUDA(0), Device::CUDA(1)]).unwrap();
assert_eq!(comms.size(), 2);
assert_eq!(comms.devices(), &[Device::CUDA(0), Device::CUDA(1)]);
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_broadcast() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let comms = NcclComms::new(&[Device::CUDA(0), Device::CUDA(1)]).unwrap();
let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
let t0 = Tensor::full(&[64], 42.0, opts0).unwrap();
let t1 = Tensor::zeros(&[64], opts1).unwrap();
comms.broadcast(&[&t0, &t1], 0).unwrap();
cuda_synchronize(0);
cuda_synchronize(1);
let vals0 = t0.to_f32_vec().unwrap();
let vals1 = t1.to_f32_vec().unwrap();
assert!(vals0.iter().all(|&v| (v - 42.0).abs() < 1e-5),
"device 0 should still have 42.0");
assert!(vals1.iter().all(|&v| (v - 42.0).abs() < 1e-5),
"device 1 should have 42.0 after broadcast");
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_all_reduce_sum() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let comms = NcclComms::new(&[Device::CUDA(0), Device::CUDA(1)]).unwrap();
let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
let t0 = Tensor::full(&[128], 1.0, opts0).unwrap();
let t1 = Tensor::full(&[128], 2.0, opts1).unwrap();
comms.all_reduce(&[&t0, &t1], ReduceOp::Sum).unwrap();
cuda_synchronize(0);
cuda_synchronize(1);
let vals0 = t0.to_f32_vec().unwrap();
let vals1 = t1.to_f32_vec().unwrap();
assert!(vals0.iter().all(|&v| (v - 3.0).abs() < 1e-5),
"device 0 should have 3.0 after AllReduce Sum");
assert!(vals1.iter().all(|&v| (v - 3.0).abs() < 1e-5),
"device 1 should have 3.0 after AllReduce Sum");
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_all_reduce_avg() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let comms = NcclComms::new(&[Device::CUDA(0), Device::CUDA(1)]).unwrap();
let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
let t0 = Tensor::full(&[64], 10.0, opts0).unwrap();
let t1 = Tensor::full(&[64], 20.0, opts1).unwrap();
comms.all_reduce(&[&t0, &t1], ReduceOp::Avg).unwrap();
cuda_synchronize(0);
cuda_synchronize(1);
let vals0 = t0.to_f32_vec().unwrap();
let vals1 = t1.to_f32_vec().unwrap();
assert!(vals0.iter().all(|&v| (v - 15.0).abs() < 1e-5),
"device 0 should have 15.0 after AllReduce Avg");
assert!(vals1.iter().all(|&v| (v - 15.0).abs() < 1e-5),
"device 1 should have 15.0 after AllReduce Avg");
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_all_reduce_on_streams() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let comms = NcclComms::new(&[Device::CUDA(0), Device::CUDA(1)]).unwrap();
let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
let stream0 = CudaStream::new(Device::CUDA(0), false).unwrap();
let stream1 = CudaStream::new(Device::CUDA(1), false).unwrap();
let t0 = Tensor::full(&[32], 5.0, opts0).unwrap();
let t1 = Tensor::full(&[32], 7.0, opts1).unwrap();
comms.all_reduce_on_streams(
&[&t0, &t1], ReduceOp::Sum, &[&stream0, &stream1],
).unwrap();
stream0.synchronize().unwrap();
stream1.synchronize().unwrap();
let vals0 = t0.to_f32_vec().unwrap();
let vals1 = t1.to_f32_vec().unwrap();
assert!(vals0.iter().all(|&v| (v - 12.0).abs() < 1e-5),
"device 0 should have 12.0 after AllReduce Sum on streams");
assert!(vals1.iter().all(|&v| (v - 12.0).abs() < 1e-5),
"device 1 should have 12.0 after AllReduce Sum on streams");
}
#[test]
fn test_nccl_rank_comm_rejects_invalid_rank() {
let result = NcclRankComm::init_rank(2, 2, &NcclUniqueId { bytes: [0; NCCL_UNIQUE_ID_BYTES] });
assert!(result.is_err(), "rank >= world_size should fail");
}
#[test]
fn test_nccl_rank_comm_rejects_world_size_one() {
let result = NcclRankComm::init_rank(0, 1, &NcclUniqueId { bytes: [0; NCCL_UNIQUE_ID_BYTES] });
assert!(result.is_err(), "world_size < 2 should fail");
}
#[test]
fn test_nccl_unique_id_clone() {
fn assert_send_sync_clone<T: Send + Sync + Clone>() {}
assert_send_sync_clone::<NcclUniqueId>();
}
#[test]
fn test_nccl_rank_comm_send() {
fn assert_send<T: Send>() {}
assert_send::<NcclRankComm>();
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_rank_comm_init_and_reduce() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let uid = NcclUniqueId::new().unwrap();
let uid0 = uid.clone();
let uid1 = uid;
let h0 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(0);
NcclRankComm::init_rank(0, 2, &uid0).unwrap()
});
let h1 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(1);
NcclRankComm::init_rank(1, 2, &uid1).unwrap()
});
let comm0 = h0.join().unwrap();
let comm1 = h1.join().unwrap();
assert_eq!(comm0.rank(), 0);
assert_eq!(comm0.world_size(), 2);
assert_eq!(comm1.rank(), 1);
let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
let t0 = Tensor::full(&[64], 10.0, opts0).unwrap();
let t1 = Tensor::full(&[64], 20.0, opts1).unwrap();
let t0_clone = t0.clone();
let t1_clone = t1.clone();
let h0 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(0);
comm0.all_reduce(&[&t0_clone], ReduceOp::Avg).unwrap();
cuda_synchronize(0);
});
let h1 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(1);
comm1.all_reduce(&[&t1_clone], ReduceOp::Avg).unwrap();
cuda_synchronize(1);
});
h0.join().unwrap();
h1.join().unwrap();
let vals0 = t0.to_f32_vec().unwrap();
let vals1 = t1.to_f32_vec().unwrap();
assert!(vals0.iter().all(|&v| (v - 15.0).abs() < 1e-5),
"rank 0 should have 15.0 after AllReduce Avg, got {}", vals0[0]);
assert!(vals1.iter().all(|&v| (v - 15.0).abs() < 1e-5),
"rank 1 should have 15.0 after AllReduce Avg, got {}", vals1[0]);
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_rank_comm_on_stream() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let uid = NcclUniqueId::new().unwrap();
let uid0 = uid.clone();
let uid1 = uid;
let h0 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(0);
NcclRankComm::init_rank(0, 2, &uid0).unwrap()
});
let h1 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(1);
NcclRankComm::init_rank(1, 2, &uid1).unwrap()
});
let comm0 = h0.join().unwrap();
let comm1 = h1.join().unwrap();
let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
let stream0 = CudaStream::new(Device::CUDA(0), false).unwrap();
let stream1 = CudaStream::new(Device::CUDA(1), false).unwrap();
let t0 = Tensor::full(&[32], 3.0, opts0).unwrap();
let t1 = Tensor::full(&[32], 7.0, opts1).unwrap();
let t0c = t0.clone();
let t1c = t1.clone();
let h0 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(0);
comm0.all_reduce_on_stream(&[&t0c], ReduceOp::Sum, &stream0).unwrap();
stream0.synchronize().unwrap();
});
let h1 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(1);
comm1.all_reduce_on_stream(&[&t1c], ReduceOp::Sum, &stream1).unwrap();
stream1.synchronize().unwrap();
});
h0.join().unwrap();
h1.join().unwrap();
let vals0 = t0.to_f32_vec().unwrap();
let vals1 = t1.to_f32_vec().unwrap();
assert!(vals0.iter().all(|&v| (v - 10.0).abs() < 1e-5),
"rank 0 should have 10.0 after Sum, got {}", vals0[0]);
assert!(vals1.iter().all(|&v| (v - 10.0).abs() < 1e-5),
"rank 1 should have 10.0 after Sum, got {}", vals1[0]);
}
#[test]
#[ignore = "NCCL init needs exclusive GPU; run with: fdl cuda-test-all"]
fn test_nccl_rank_comm_multi_tensor_batch() {
if !require_multi_gpu() { return; }
let _lock = NCCL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let uid = NcclUniqueId::new().unwrap();
let uid0 = uid.clone();
let uid1 = uid;
let h0 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(0);
NcclRankComm::init_rank(0, 2, &uid0).unwrap()
});
let h1 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(1);
NcclRankComm::init_rank(1, 2, &uid1).unwrap()
});
let comm0 = h0.join().unwrap();
let comm1 = h1.join().unwrap();
let opts0 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(0) };
let opts1 = TensorOptions { dtype: DType::Float32, device: Device::CUDA(1) };
let a0 = Tensor::full(&[16], 1.0, opts0).unwrap();
let b0 = Tensor::full(&[8], 100.0, opts0).unwrap();
let a1 = Tensor::full(&[16], 3.0, opts1).unwrap();
let b1 = Tensor::full(&[8], 200.0, opts1).unwrap();
let a0c = a0.clone();
let b0c = b0.clone();
let a1c = a1.clone();
let b1c = b1.clone();
let h0 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(0);
comm0.all_reduce(&[&a0c, &b0c], ReduceOp::Avg).unwrap();
cuda_synchronize(0);
});
let h1 = std::thread::spawn(move || {
crate::tensor::set_current_cuda_device(1);
comm1.all_reduce(&[&a1c, &b1c], ReduceOp::Avg).unwrap();
cuda_synchronize(1);
});
h0.join().unwrap();
h1.join().unwrap();
let va0 = a0.to_f32_vec().unwrap();
let vb0 = b0.to_f32_vec().unwrap();
assert!(va0.iter().all(|&v| (v - 2.0).abs() < 1e-5), "a0 should be 2.0");
assert!(vb0.iter().all(|&v| (v - 150.0).abs() < 1e-5), "b0 should be 150.0");
let va1 = a1.to_f32_vec().unwrap();
let vb1 = b1.to_f32_vec().unwrap();
assert!(va1.iter().all(|&v| (v - 2.0).abs() < 1e-5), "a1 should be 2.0");
assert!(vb1.iter().all(|&v| (v - 150.0).abs() < 1e-5), "b1 should be 150.0");
}
}