use crate::array::Array;
use crate::error::{NumRs2Error, Result};
/// Packs a binary-valued `u8` array into bits, eight elements per output byte.
///
/// Along the selected axis (or over the flattened array when `axis` is `None`),
/// consecutive groups of eight elements become one byte. A trailing group of
/// fewer than eight elements is padded with zero bits. Unlike NumPy's
/// `packbits`, which treats any non-zero value as a set bit, every element here
/// must be exactly `0` or `1`.
///
/// The index arithmetic assumes `Array::to_vec` yields elements in row-major
/// (C) order and that `reshape` interprets its data the same way.
///
/// # Arguments
///
/// * `array` - input array whose elements are all `0` or `1`.
/// * `axis` - axis to pack along; negative values count back from the last
///   axis. `None` packs the flattened array into a 1-D result.
/// * `bitorder` - `"big"` (default) stores the first element of each group in
///   the most significant bit; `"little"` stores it in the least significant bit.
///
/// # Returns
///
/// With `Some(axis)`, an array shaped like `array` except that the packed axis
/// has length `ceil(len / 8)`. With `None`, a 1-D array of `ceil(size / 8)` bytes.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `bitorder` is neither `"big"` nor
///   `"little"`, or if any element is not `0` or `1`.
/// * `NumRs2Error::DimensionMismatch` if `axis` is out of bounds.
pub fn packbits(
    array: &Array<u8>,
    axis: Option<isize>,
    bitorder: Option<&str>,
) -> Result<Array<u8>> {
    let bitorder_str = bitorder.unwrap_or("big");
    if bitorder_str != "big" && bitorder_str != "little" {
        return Err(NumRs2Error::InvalidOperation(format!(
            "bitorder must be 'big' or 'little', got '{}'",
            bitorder_str
        )));
    }
    let big_endian = bitorder_str == "big";

    // Packs up to eight bits, given in element order, into a single byte.
    fn pack_byte(bits: impl Iterator<Item = u8>, big_endian: bool) -> u8 {
        bits.enumerate().fold(0u8, |byte, (offset, bit)| {
            let shift = if big_endian { 7 - offset } else { offset };
            byte | (bit << shift)
        })
    }

    // Read the data once; it is reused by both branches below.
    let data = array.to_vec();
    if data.iter().any(|&val| val > 1) {
        return Err(NumRs2Error::InvalidOperation(
            "packbits requires binary input (0 or 1)".to_string(),
        ));
    }

    match axis {
        Some(ax) => {
            let ndim = array.ndim();
            // Normalise negative axes without relying on wrapping casts.
            let normalized = if ax < 0 { ax + ndim as isize } else { ax };
            if normalized < 0 || normalized >= ndim as isize {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_idx = normalized as usize;

            let shape = array.shape();
            let axis_size = shape[axis_idx];
            let packed_axis_size = axis_size.div_ceil(8);
            // View the array as (outer, axis, inner) in C order.
            let outer_size: usize = shape[..axis_idx].iter().product();
            let inner_size: usize = shape[axis_idx + 1..].iter().product();
            let mut new_shape = shape.clone();
            new_shape[axis_idx] = packed_axis_size;

            let mut packed_data = Vec::with_capacity(outer_size * packed_axis_size * inner_size);
            for outer in 0..outer_size {
                let outer_base = outer * axis_size * inner_size;
                for packed_idx in 0..packed_axis_size {
                    let start_bit = packed_idx * 8;
                    let end_bit = (start_bit + 8).min(axis_size);
                    // The result is laid out as (outer, packed byte, inner) in
                    // C order, so `inner` must be the fastest-varying loop.
                    for inner in 0..inner_size {
                        let bits = (start_bit..end_bit)
                            .map(|bit_idx| data[outer_base + bit_idx * inner_size + inner]);
                        packed_data.push(pack_byte(bits, big_endian));
                    }
                }
            }
            Ok(Array::from_vec(packed_data).reshape(&new_shape))
        }
        None => {
            let packed: Vec<u8> = data
                .chunks(8)
                .map(|chunk| pack_byte(chunk.iter().copied(), big_endian))
                .collect();
            Ok(Array::from_vec(packed))
        }
    }
}
/// Unpacks each byte of a `u8` array into eight `0`/`1` elements.
///
/// This is the inverse of `packbits`: along the selected axis (or over the
/// flattened array when `axis` is `None`), every byte expands into eight
/// elements, optionally truncated to `count` elements.
///
/// The index arithmetic assumes `Array::to_vec` yields elements in row-major
/// (C) order and that `reshape` interprets its data the same way.
///
/// # Arguments
///
/// * `packed` - array of packed bytes.
/// * `axis` - axis to unpack along; negative values count back from the last
///   axis. `None` unpacks the flattened array into a 1-D result.
/// * `count` - number of elements to keep along the unpacked axis (or in total
///   when `axis` is `None`); defaults to `8 * len`. It must not exceed `8 * len`.
/// * `bitorder` - `"big"` (default) reads the most significant bit first;
///   `"little"` reads the least significant bit first.
///
/// # Returns
///
/// With `Some(axis)`, an array shaped like `packed` except that the unpacked
/// axis has length `count` (default `8 * len`). With `None`, a 1-D array of
/// `count` (default `8 * size`) elements.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `bitorder` is invalid or `count` is
///   larger than the number of available bits.
/// * `NumRs2Error::DimensionMismatch` if `axis` is out of bounds.
pub fn unpackbits(
    packed: &Array<u8>,
    axis: Option<isize>,
    count: Option<usize>,
    bitorder: Option<&str>,
) -> Result<Array<u8>> {
    let bitorder_str = bitorder.unwrap_or("big");
    if bitorder_str != "big" && bitorder_str != "little" {
        return Err(NumRs2Error::InvalidOperation(format!(
            "bitorder must be 'big' or 'little', got '{}'",
            bitorder_str
        )));
    }
    let big_endian = bitorder_str == "big";
    // Extracts the bit at position `offset` (0 = first unpacked element) of `byte`.
    let bit_at = |byte: u8, offset: usize| -> u8 {
        let shift = if big_endian { 7 - offset } else { offset };
        (byte >> shift) & 1
    };

    match axis {
        Some(ax) => {
            let ndim = packed.ndim();
            // Normalise negative axes without relying on wrapping casts.
            let normalized = if ax < 0 { ax + ndim as isize } else { ax };
            if normalized < 0 || normalized >= ndim as isize {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis_idx = normalized as usize;

            let shape = packed.shape();
            let packed_axis_size = shape[axis_idx];
            let unpacked_axis_size = count.unwrap_or(packed_axis_size * 8);
            if unpacked_axis_size > packed_axis_size * 8 {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "count ({}) cannot be larger than {} (8 * packed_axis_size)",
                    unpacked_axis_size,
                    packed_axis_size * 8
                )));
            }
            let mut new_shape = shape.clone();
            new_shape[axis_idx] = unpacked_axis_size;
            // View the array as (outer, axis, inner) in C order.
            let outer_size: usize = shape[..axis_idx].iter().product();
            let inner_size: usize = shape[axis_idx + 1..].iter().product();

            let packed_data = packed.to_vec();
            let mut unpacked_data =
                Vec::with_capacity(outer_size * unpacked_axis_size * inner_size);
            for outer in 0..outer_size {
                let outer_base = outer * packed_axis_size * inner_size;
                for bit_idx in 0..unpacked_axis_size {
                    // The result is laid out as (outer, bit, inner) in C order,
                    // so all `inner` lanes for one bit index are emitted together.
                    let row_start = outer_base + (bit_idx / 8) * inner_size;
                    let offset = bit_idx % 8;
                    unpacked_data.extend(
                        packed_data[row_start..row_start + inner_size]
                            .iter()
                            .map(|&byte| bit_at(byte, offset)),
                    );
                }
            }
            Ok(Array::from_vec(unpacked_data).reshape(&new_shape))
        }
        None => {
            let packed_data = packed.to_vec();
            let n_bytes = packed_data.len();
            let n_bits = count.unwrap_or(n_bytes * 8);
            if n_bits > n_bytes * 8 {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "count ({}) cannot be larger than {} (8 * number of bytes)",
                    n_bits,
                    n_bytes * 8
                )));
            }
            let unpacked: Vec<u8> = (0..n_bits)
                .map(|i| bit_at(packed_data[i / 8], i % 8))
                .collect();
            Ok(Array::from_vec(unpacked))
        }
    }
}
/// Converts flat indices into per-dimension coordinate arrays for `shape`.
///
/// For each flat index in `indices`, computes its coordinate along every
/// dimension of `shape`. The result holds one array per dimension, each
/// reshaped to `indices.shape()`.
///
/// # Arguments
///
/// * `indices` - flat indices, each of which must be `< product(shape)`.
/// * `shape` - non-empty shape of the array being indexed.
/// * `order` - `"C"` (default, last dimension varies fastest) or `"F"`
///   (first dimension varies fastest).
///
/// # Returns
///
/// A vector of `shape.len()` arrays. Element `d` holds the coordinate along
/// dimension `d` of every input index.
///
/// # Errors
///
/// `NumRs2Error::InvalidOperation` is returned in any of these cases:
/// * `order` is neither `"C"` nor `"F"`.
/// * `shape` is empty.
/// * The total size of `shape` overflows `usize`.
/// * Any index is out of bounds.
pub fn unravel_index(
    indices: &Array<usize>,
    shape: &[usize],
    order: Option<&str>,
) -> Result<Vec<Array<usize>>> {
    let order_str = order.unwrap_or("C");
    if order_str != "C" && order_str != "F" {
        return Err(NumRs2Error::InvalidOperation(format!(
            "order must be 'C' or 'F', got '{}'",
            order_str
        )));
    }
    if shape.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "shape cannot be empty".to_string(),
        ));
    }

    // A zero-length dimension makes the shape empty regardless of the other
    // dimensions; otherwise the product must fit in usize. A plain
    // `product()` would panic (debug) or silently wrap (release).
    let total_size = if shape.contains(&0) {
        0
    } else {
        shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| {
                NumRs2Error::InvalidOperation(format!(
                    "shape {:?} is too large: total size overflows usize",
                    shape
                ))
            })?
    };

    let indices_data = indices.to_vec();
    for &idx in &indices_data {
        if idx >= total_size {
            return Err(NumRs2Error::InvalidOperation(format!(
                "index {} is out of bounds for array with size {}",
                idx, total_size
            )));
        }
    }

    let n_indices = indices_data.len();
    let mut coordinates: Vec<Vec<usize>> = (0..shape.len())
        .map(|_| Vec::with_capacity(n_indices))
        .collect();
    // Peel coordinates off with divmod, starting from the fastest-varying
    // dimension (last for C order, first for F order). No dimension is zero
    // here: a zero-size shape rejects every index above.
    for &flat_idx in &indices_data {
        let mut remainder = flat_idx;
        if order_str == "C" {
            for (coords, &dim) in coordinates.iter_mut().zip(shape).rev() {
                coords.push(remainder % dim);
                remainder /= dim;
            }
        } else {
            for (coords, &dim) in coordinates.iter_mut().zip(shape) {
                coords.push(remainder % dim);
                remainder /= dim;
            }
        }
    }

    let result = coordinates
        .into_iter()
        .map(|coord_vec| Array::from_vec(coord_vec).reshape(&indices.shape()))
        .collect();
    Ok(result)
}
/// Converts per-dimension coordinate arrays into flat indices for `dims`.
///
/// `multi_index[d]` holds the coordinates along dimension `d`. All coordinate
/// arrays must share one shape, which is also the shape of the result.
///
/// # Arguments
///
/// * `multi_index` - one coordinate array per dimension of `dims`.
/// * `dims` - shape of the array being indexed.
/// * `mode` - how out-of-range coordinates are handled:
///   * `"raise"` (default) returns an error.
///   * `"wrap"` reduces the coordinate modulo the dimension size.
///   * `"clip"` clamps the coordinate to the last valid index.
/// * `order` - `"C"` (default, last dimension varies fastest) or `"F"`
///   (first dimension varies fastest).
///
/// # Returns
///
/// An array of flat indices with the same shape as the coordinate arrays.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` is returned in any of these cases:
///   * `mode` or `order` is invalid.
///   * The number of coordinate arrays does not match `dims.len()`.
///   * `multi_index` is empty.
///   * `"wrap"`/`"clip"` is used while `dims` contains a zero.
///   * The total size of `dims` overflows `usize`.
///   * A coordinate is out of bounds in `"raise"` mode.
/// * `NumRs2Error::ShapeMismatch` if the coordinate arrays differ in shape.
pub fn ravel_multi_index(
    multi_index: &[&Array<usize>],
    dims: &[usize],
    mode: Option<&str>,
    order: Option<&str>,
) -> Result<Array<usize>> {
    let mode_str = mode.unwrap_or("raise");
    if mode_str != "raise" && mode_str != "wrap" && mode_str != "clip" {
        return Err(NumRs2Error::InvalidOperation(format!(
            "mode must be 'raise', 'wrap', or 'clip', got '{}'",
            mode_str
        )));
    }
    let order_str = order.unwrap_or("C");
    if order_str != "C" && order_str != "F" {
        return Err(NumRs2Error::InvalidOperation(format!(
            "order must be 'C' or 'F', got '{}'",
            order_str
        )));
    }
    if multi_index.len() != dims.len() {
        return Err(NumRs2Error::InvalidOperation(format!(
            "number of index arrays ({}) must match number of dimensions ({})",
            multi_index.len(),
            dims.len()
        )));
    }
    if multi_index.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "multi_index cannot be empty".to_string(),
        ));
    }
    // With a zero-length dimension there is no index to wrap or clip into;
    // `coord % 0` would otherwise panic.
    if mode_str != "raise" && dims.contains(&0) {
        return Err(NumRs2Error::InvalidOperation(format!(
            "cannot ravel with mode '{}' when dims has zero entries",
            mode_str
        )));
    }
    // Guarantee that every in-range flat index (and every intermediate value
    // computed below) fits in usize.
    if !dims.contains(&0)
        && dims
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .is_none()
    {
        return Err(NumRs2Error::InvalidOperation(
            "invalid dims: array size defined by dims is larger than the maximum possible size"
                .to_string(),
        ));
    }

    let result_shape = multi_index[0].shape();
    for idx_array in multi_index.iter().skip(1) {
        if idx_array.shape() != result_shape {
            return Err(NumRs2Error::ShapeMismatch {
                expected: result_shape.to_vec(),
                actual: idx_array.shape().to_vec(),
            });
        }
    }

    let n_indices = multi_index[0].size();
    let coord_data: Vec<Vec<usize>> = multi_index.iter().map(|arr| arr.to_vec()).collect();
    // Reused per index: the in-range coordinate for each dimension.
    let mut adjusted = vec![0usize; dims.len()];
    let mut flat_indices = Vec::with_capacity(n_indices);
    for idx in 0..n_indices {
        for (dim, (coords, &size)) in coord_data.iter().zip(dims).enumerate() {
            let coord = coords[idx];
            adjusted[dim] = match mode_str {
                "raise" => {
                    if coord >= size {
                        return Err(NumRs2Error::InvalidOperation(format!(
                            "index {} is out of bounds for axis {} with size {}",
                            coord, dim, size
                        )));
                    }
                    coord
                }
                "wrap" => coord % size,
                "clip" => coord.min(size.saturating_sub(1)),
                _ => unreachable!(),
            };
        }
        // Horner's scheme: fold from the slowest- to the fastest-varying
        // dimension. Each intermediate is bounded by the (checked) total size.
        let pairs = adjusted.iter().zip(dims);
        let flat_idx = if order_str == "C" {
            pairs.fold(0usize, |acc, (&coord, &size)| acc * size + coord)
        } else {
            pairs.rev().fold(0usize, |acc, (&coord, &size)| acc * size + coord)
        };
        flat_indices.push(flat_idx);
    }
    Ok(Array::from_vec(flat_indices).reshape(&result_shape))
}