use std::fmt;
use std::sync::Arc;
use crate::enums::error::MinarrowError;
use crate::enums::shape_dim::ShapeDim;
use crate::structs::buffer::Buffer;
use crate::structs::shared_buffer::SharedBuffer;
use crate::traits::{concatenate::Concatenate, shape::Shape};
use crate::{Array, Field, FieldArray, FloatArray, NumericArray, Table, Vec64};
/// Dense `f64` matrix stored column by column in a single `Vec64<f64>` buffer.
///
/// Element `(row, col)` is addressed as `data[col * stride + row]`. The
/// constructors in this file set `stride` to `n_rows` rounded up to a multiple
/// of `ALIGN_ELEMS`, so each column may be followed by unused padding slots.
///
/// All fields are public, so these relationships are conventions rather than
/// enforced invariants.
///
/// NOTE(review): `align(64)` applies to the struct value itself; the alignment
/// of the heap buffer is presumably provided by `Vec64` — confirm.
#[repr(C, align(64))]
#[derive(Clone, PartialEq)]
pub struct Matrix {
/// Number of logical rows.
pub n_rows: usize,
/// Number of logical columns.
pub n_cols: usize,
/// Distance, in elements, between the starts of consecutive columns.
pub stride: usize,
/// Backing buffer holding every column, padding included.
pub data: Vec64<f64>,
/// Optional human-readable name.
pub name: Option<String>,
}
/// Number of `f64` elements that fit in one 64-byte block.
const ALIGN_ELEMS: usize = 64 / std::mem::size_of::<f64>();

/// Rounds `n_rows` up to the nearest multiple of `ALIGN_ELEMS`.
///
/// Uses a bit mask, which is valid because `ALIGN_ELEMS` is a power of two.
/// `0` maps to `0`.
#[inline]
const fn aligned_stride(n_rows: usize) -> usize {
    let mask = ALIGN_ELEMS - 1;
    (n_rows + mask) & !mask
}
impl Matrix {
pub fn new(n_rows: usize, n_cols: usize, name: Option<String>) -> Self {
let stride = aligned_stride(n_rows);
let len = stride * n_cols;
let mut data = Vec64::with_capacity(len);
data.0.resize(len, 0.0);
Matrix { n_rows, n_cols, stride, data, name }
}
pub fn from_f64_aligned(data: Vec64<f64>, n_rows: usize, n_cols: usize, name: Option<String>) -> Self {
let stride = aligned_stride(n_rows);
assert_eq!(
data.len(),
stride * n_cols,
"Matrix: padded buffer length does not match stride * n_cols"
);
Matrix { n_rows, n_cols, stride, data, name }
}
pub fn from_f64_unaligned(src: &[f64], n_rows: usize, n_cols: usize, name: Option<String>) -> Self {
assert_eq!(
src.len(),
n_rows * n_cols,
"Matrix shape does not match buffer length"
);
let stride = aligned_stride(n_rows);
if stride == n_rows {
let data = Vec64::from(src);
return Matrix { n_rows, n_cols, stride, data, name };
}
let mut data = Vec64::with_capacity(stride * n_cols);
data.0.resize(stride * n_cols, 0.0);
for col in 0..n_cols {
let src_start = col * n_rows;
let dst_start = col * stride;
data.as_mut_slice()[dst_start..dst_start + n_rows]
.copy_from_slice(&src[src_start..src_start + n_rows]);
}
Matrix { n_rows, n_cols, stride, data, name }
}
/// Returns the element at (`row`, `col`).
///
/// The row and column checks are `debug_assert!`s, so they only run when
/// debug assertions are enabled.
#[inline]
pub fn get(&self, row: usize, col: usize) -> f64 {
    debug_assert!(row < self.n_rows, "Row out of bounds");
    debug_assert!(col < self.n_cols, "Col out of bounds");
    let idx = col * self.stride + row;
    self.data[idx]
}
/// Overwrites the element at (`row`, `col`) with `value`.
///
/// The row and column checks are `debug_assert!`s, so they only run when
/// debug assertions are enabled.
#[inline]
pub fn set(&mut self, row: usize, col: usize, value: f64) {
    debug_assert!(row < self.n_rows, "Row out of bounds");
    debug_assert!(col < self.n_cols, "Col out of bounds");
    let idx = col * self.stride + row;
    self.data[idx] = value;
}
/// Returns `true` when the matrix has zero rows or zero columns.
#[inline]
pub fn is_empty(&self) -> bool {
    matches!((self.n_rows, self.n_cols), (0, _) | (_, 0))
}
/// Returns the number of logical elements, `n_rows * n_cols`.
///
/// Padding slots are not counted.
#[inline]
pub fn len(&self) -> usize {
    let (rows, cols) = (self.n_rows, self.n_cols);
    rows * cols
}
/// Borrows the entire backing buffer, padding slots included.
#[inline]
pub fn as_slice(&self) -> &[f64] {
    let buf: &[f64] = &self.data;
    buf
}
/// Mutably borrows the entire backing buffer, padding slots included.
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [f64] {
    let buf: &mut [f64] = &mut self.data;
    buf
}
/// Returns one borrowed slice per column, each `n_rows` long (padding excluded).
pub fn columns(&self) -> Vec<&[f64]> {
    let mut views = Vec::with_capacity(self.n_cols);
    for col in 0..self.n_cols {
        let start = col * self.stride;
        views.push(&self.data[start..start + self.n_rows]);
    }
    views
}
/// Returns one mutable slice per column, each `n_rows` long (padding excluded).
///
/// The buffer is split into disjoint `stride`-sized chunks with safe slice
/// operations, so the returned slices can never alias.
///
/// # Panics
/// Panics if the public fields are inconsistent: `stride` is smaller than
/// `n_rows`, or `data` is too short to hold `n_cols` columns.
pub fn columns_mut(&mut self) -> Vec<&mut [f64]> {
    let n_rows = self.n_rows;
    let n_cols = self.n_cols;
    let stride = self.stride;

    // With no rows every column is an empty slice; this also avoids asking
    // `chunks_mut` for a chunk size of zero.
    if n_rows == 0 {
        return (0..n_cols).map(|_| <&mut [f64]>::default()).collect();
    }

    let cols: Vec<&mut [f64]> = self
        .data
        .as_mut_slice()
        .chunks_mut(stride)
        .take(n_cols)
        // Drop the padding tail; panics if a chunk is shorter than `n_rows`.
        .map(|chunk| &mut chunk[..n_rows])
        .collect();
    assert_eq!(
        cols.len(),
        n_cols,
        "Matrix buffer is too short for n_cols columns"
    );
    cols
}
/// Borrows column `col` as a slice of `n_rows` values (padding excluded).
///
/// The column index is only checked by a `debug_assert!`.
#[inline]
pub fn col(&self, col: usize) -> &[f64] {
    debug_assert!(col < self.n_cols, "Col out of bounds");
    let start = col * self.stride;
    let end = start + self.n_rows;
    &self.data[start..end]
}
/// Mutably borrows column `col` as a slice of `n_rows` values (padding excluded).
///
/// The column index is only checked by a `debug_assert!`.
#[inline]
pub fn col_mut(&mut self, col: usize) -> &mut [f64] {
    debug_assert!(col < self.n_cols, "Col out of bounds");
    let (stride, n_rows) = (self.stride, self.n_rows);
    let first = col * stride;
    let range = first..first + n_rows;
    &mut self.data[range]
}
/// Copies row `row` into a new `Vec` with one value per column.
///
/// Rows are not contiguous in the column-by-column buffer, so the values are
/// gathered one at a time. The row index is only checked by a `debug_assert!`.
#[inline]
pub fn row(&self, row: usize) -> Vec<f64> {
    debug_assert!(row < self.n_rows, "Row out of bounds");
    let mut out = Vec::with_capacity(self.n_cols);
    for col in 0..self.n_cols {
        out.push(self.data[col * self.stride + row]);
    }
    out
}
/// Sets the matrix name, replacing any previous one.
#[inline]
pub fn set_name(&mut self, name: impl Into<String>) {
    let owned: String = name.into();
    self.name = Some(owned);
}
/// Returns the number of logical columns.
#[inline]
pub fn n_cols(&self) -> usize {
    self.n_cols
}
#[inline]
pub fn n_rows(&self) -> usize {
self.n_rows
}
/// Returns the row count as an `i32`.
///
/// NOTE(review): presumably the `m` argument of BLAS/LAPACK-style routines — confirm.
///
/// # Panics
/// Panics if `n_rows` exceeds `i32::MAX`, instead of silently truncating it.
#[inline]
pub fn m(&self) -> i32 {
    i32::try_from(self.n_rows).expect("Matrix::m: n_rows does not fit in i32")
}
/// Returns the column count as an `i32`.
///
/// NOTE(review): presumably the `n` argument of BLAS/LAPACK-style routines — confirm.
///
/// # Panics
/// Panics if `n_cols` exceeds `i32::MAX`, instead of silently truncating it.
#[inline]
pub fn n(&self) -> i32 {
    i32::try_from(self.n_cols).expect("Matrix::n: n_cols does not fit in i32")
}
/// Returns the column stride as an `i32`.
///
/// NOTE(review): presumably the leading-dimension (`lda`) argument of
/// BLAS/LAPACK-style routines — confirm.
///
/// # Panics
/// Panics if `stride` exceeds `i32::MAX`, instead of silently truncating it.
#[inline]
pub fn lda(&self) -> i32 {
    i32::try_from(self.stride).expect("Matrix::lda: stride does not fit in i32")
}
pub fn to_table(self, fields: Vec<Field>) -> Result<Table, MinarrowError> {
if fields.len() != self.n_cols {
return Err(MinarrowError::ShapeError {
message: format!(
"to_table: expected {} fields for {} columns, got {}",
self.n_cols, self.n_cols, fields.len()
),
});
}
let n_rows = self.n_rows;
let n_cols = self.n_cols;
let stride = self.stride;
let shared = unsafe { SharedBuffer::from_vec64_typed(self.data) };
let mut cols = Vec::with_capacity(n_cols);
for (i, field) in fields.into_iter().enumerate() {
let col_offset = i * stride;
let buf: Buffer<f64> = Buffer::from_shared_column(shared.clone(), col_offset, n_rows);
let float_arr = FloatArray::new(buf, None);
let array = Array::NumericArray(NumericArray::Float64(Arc::new(float_arr)));
cols.push(FieldArray::new(field, array));
}
Ok(Table::new(self.name.unwrap_or_default(), Some(cols)))
}
/// Converts the matrix into a `Table` with generated column names
/// `col_0`, `col_1`, … .
///
/// Exactly `n_cols` fields are generated, so the field-count check in
/// `to_table` cannot fail and its result is unwrapped.
pub fn to_table_gen(self) -> Table {
    let mut fields: Vec<Field> = Vec::with_capacity(self.n_cols);
    for i in 0..self.n_cols {
        fields.push(Field::new(
            format!("col_{}", i),
            crate::ffi::arrow_dtype::ArrowType::Float64,
            false,
            None,
        ));
    }
    self.to_table(fields).unwrap()
}
}
/// Reports the matrix as a rank-2 shape of `n_rows` × `n_cols`.
impl Shape for Matrix {
    fn shape(&self) -> ShapeDim {
        let (rows, cols) = (self.n_rows, self.n_cols);
        ShapeDim::Rank2 { rows, cols }
    }
}
impl Concatenate for Matrix {
fn concat(self, other: Self) -> Result<Self, MinarrowError> {
if self.n_cols != other.n_cols {
return Err(MinarrowError::IncompatibleTypeError {
from: "Matrix",
to: "Matrix",
message: Some(format!(
"Cannot concatenate matrices with different column counts: {} vs {}",
self.n_cols, other.n_cols
)),
});
}
if self.is_empty() && other.is_empty() {
return Ok(Matrix::new(
0,
0,
None,
));
}
let result_n_rows = self.n_rows + other.n_rows;
let result_n_cols = self.n_cols;
let result_stride = aligned_stride(result_n_rows);
let pad = result_stride - result_n_rows;
let mut result_data = Vec64::with_capacity(result_stride * result_n_cols);
for col in 0..result_n_cols {
result_data.extend_from_slice(self.col(col));
result_data.extend_from_slice(other.col(col));
if pad > 0 {
result_data.extend_from_slice(&[0.0; ALIGN_ELEMS][..pad]);
}
}
Ok(Matrix {
n_rows: result_n_rows,
n_cols: result_n_cols,
stride: result_stride,
data: result_data,
name: None,
})
}
}
/// Compact preview: a header line followed by at most 6 rows of at most
/// 8 columns each, with `...` marking truncated rows or columns.
impl fmt::Debug for Matrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MAX_ROWS: usize = 6;
        const MAX_COLS: usize = 8;

        let label = match self.name.as_deref() {
            Some(n) => format!(" '{}'", n),
            None => String::new(),
        };
        write!(
            f,
            "Matrix{}: {} × {} [col-major]",
            label, self.n_rows, self.n_cols
        )?;

        let shown_rows = self.n_rows.min(MAX_ROWS);
        let shown_cols = self.n_cols.min(MAX_COLS);
        let cols_truncated = self.n_cols > MAX_COLS;

        for row in 0..shown_rows {
            f.write_str("\n[")?;
            for col in 0..shown_cols {
                write!(f, " {:8.4}", self.get(row, col))?;
                // A separator follows every value except the matrix's true last column.
                if col + 1 != self.n_cols {
                    f.write_str(",")?;
                }
            }
            if cols_truncated {
                f.write_str(" ...")?;
            }
            f.write_str(" ]")?;
        }
        if self.n_rows > MAX_ROWS {
            f.write_str("\n...")?;
        }
        Ok(())
    }
}
impl From<Vec<FloatArray<f64>>> for Matrix {
fn from(columns: Vec<FloatArray<f64>>) -> Self {
let n_cols = columns.len();
let n_rows = columns.first().map(|c| c.data.len()).unwrap_or(0);
let stride = aligned_stride(n_rows);
let pad = stride - n_rows;
for col in &columns {
assert_eq!(col.data.len(), n_rows, "Column length mismatch");
}
let mut data = Vec64::with_capacity(stride * n_cols);
for col in &columns {
data.extend_from_slice(&col.data);
if pad > 0 {
data.extend_from_slice(&[0.0; ALIGN_ELEMS][..pad]);
}
}
Matrix { n_rows, n_cols, stride, data, name: None }
}
}
impl From<(Vec<FloatArray<f64>>, String)> for Matrix {
fn from((columns, name): (Vec<FloatArray<f64>>, String)) -> Self {
let mut mat = Matrix::from(columns);
mat.name = Some(name);
mat
}
}
impl From<&[FloatArray<f64>]> for Matrix {
fn from(columns: &[FloatArray<f64>]) -> Self {
let n_cols = columns.len();
let n_rows = columns.first().map(|c| c.data.len()).unwrap_or(0);
let stride = aligned_stride(n_rows);
let pad = stride - n_rows;
for col in columns {
assert_eq!(col.data.len(), n_rows, "Column length mismatch");
}
let mut data = Vec64::with_capacity(stride * n_cols);
for col in columns {
data.extend_from_slice(&col.data);
if pad > 0 {
data.extend_from_slice(&[0.0; ALIGN_ELEMS][..pad]);
}
}
Matrix { n_rows, n_cols, stride, data, name: None }
}
}
impl TryFrom<&Table> for Matrix {
type Error = MinarrowError;
fn try_from(table: &Table) -> Result<Self, Self::Error> {
let name = if table.name.is_empty() { None } else { Some(table.name.clone()) };
let n_cols = table.n_cols();
let n_rows = table.n_rows;
let stride = aligned_stride(n_rows);
let pad = stride - n_rows;
let mut data = Vec64::with_capacity(stride * n_cols);
for (col_idx, fa) in table.cols.iter().enumerate() {
let numeric = fa.array.num_ref().map_err(|_| MinarrowError::TypeError {
from: "non-numeric",
to: "Float64",
message: Some(format!("column {} is not numeric", col_idx)),
})?;
let f64_arr = numeric.clone().f64()?;
if f64_arr.data.len() != n_rows {
return Err(MinarrowError::ColumnLengthMismatch {
col: col_idx,
expected: n_rows,
found: f64_arr.data.len(),
});
}
data.extend_from_slice(f64_arr.data.as_slice());
if pad > 0 {
data.extend_from_slice(&[0.0; ALIGN_ELEMS][..pad]);
}
}
Ok(Matrix { n_rows, n_cols, stride, data, name })
}
}
impl TryFrom<Table> for Matrix {
type Error = MinarrowError;
fn try_from(table: Table) -> Result<Self, Self::Error> {
Matrix::try_from(&table)
}
}
impl From<&[Vec<f64>]> for Matrix {
fn from(columns: &[Vec<f64>]) -> Self {
let n_cols = columns.len();
let n_rows = columns.first().map(|c| c.len()).unwrap_or(0);
let stride = aligned_stride(n_rows);
let pad = stride - n_rows;
for col in columns {
assert_eq!(col.len(), n_rows, "Column length mismatch");
}
let mut data = Vec64::with_capacity(stride * n_cols);
for col in columns {
data.extend_from_slice(col);
if pad > 0 {
data.extend_from_slice(&[0.0; ALIGN_ELEMS][..pad]);
}
}
Matrix { n_rows, n_cols, stride, data, name: None }
}
}
/// Builds a matrix from `(packed column-major slice, n_rows, n_cols, name)`.
impl<'a> From<(&'a [f64], usize, usize, Option<String>)> for Matrix {
    /// # Panics
    /// Panics with "Slice shape mismatch" if the slice length is not `n_rows * n_cols`.
    fn from(parts: (&'a [f64], usize, usize, Option<String>)) -> Self {
        let (slice, n_rows, n_cols, name) = parts;
        let expected_len = n_rows * n_cols;
        assert_eq!(slice.len(), expected_len, "Slice shape mismatch");
        Matrix::from_f64_unaligned(slice, n_rows, n_cols, name)
    }
}
/// Iterates by reference over the whole backing buffer in storage order,
/// padding slots included.
impl<'a> IntoIterator for &'a Matrix {
    type Item = &'a f64;
    type IntoIter = std::slice::Iter<'a, f64>;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let buf: &'a [f64] = &self.data;
        buf.iter()
    }
}
/// Iterates by mutable reference over the whole backing buffer in storage
/// order, padding slots included.
impl<'a> IntoIterator for &'a mut Matrix {
    type Item = &'a mut f64;
    type IntoIter = std::slice::IterMut<'a, f64>;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let buf: &'a mut [f64] = &mut self.data;
        buf.iter_mut()
    }
}
impl IntoIterator for Matrix {
type Item = f64;
type IntoIter = <Vec64<f64> as IntoIterator>::IntoIter;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.data.into_iter()
}
}