#ifndef PARDENSEMATRIX_H
#define PARDENSEMATRIX_H
#include "MatrixDef.h"
#include "DenseMatrix.h"
#include "DenseVector.h"
#include "MPI_Wrappers.h"
#include "ATC_Error.h"
#include <algorithm>
#include <sstream>
#include <vector>
using ATC::ATC_Error;
namespace ATC_matrix {
/**
 * @brief Dense matrix that carries an MPI communicator.
 *
 * Generic version: ParMultMv simply forwards to the serial MultMv; only the
 * double specialization distributes the work over the communicator.
 */
template <typename T>
class ParDenseMatrix : public DenseMatrix<T> {
public:
  /// Communicator associated with this matrix.
  MPI_Comm _comm;

  /// rows x cols matrix; `z` is forwarded unchanged to DenseMatrix<T>.
  ParDenseMatrix(MPI_Comm comm, INDEX rows=0, INDEX cols=0, bool z=1)
    : DenseMatrix<T>(rows, cols, z), _comm(comm) {}
  /// Copy of a dense matrix, tagged with `comm`.
  ParDenseMatrix(MPI_Comm comm, const DenseMatrix<T>& c)
    : DenseMatrix<T>(c), _comm(comm) {}
  /// Dense copy of a sparse matrix, tagged with `comm`.
  ParDenseMatrix(MPI_Comm comm, const SparseMatrix<T>& c)
    : DenseMatrix<T>(c), _comm(comm) {}
  /// Dense copy of a generic matrix, tagged with `comm`.
  ParDenseMatrix(MPI_Comm comm, const Matrix<T>& c)
    : DenseMatrix<T>(c), _comm(comm) {}

  /**
   * @brief Matrix-vector product; serial for a generic T.
   *
   * Forwards all arguments unchanged to MultMv(*this, v, c, At, a, b)
   * (presumably c = a*op(A)*v + b*c with op(A) = A^T when At is true —
   * confirm against MultMv).
   *
   * Declared const so it can be called through the const reference taken by
   * operator*(const ParDenseMatrix<T>&, const Vector<T>&); the double
   * specialization already has this signature.
   */
  void ParMultMv(const Vector<T> &v,
                 DenseVector<T> &c, const bool At, T a, T b) const
  {
    MultMv(*this, v, c, At, a, b);
  }
};
/**
 * @brief Double-precision specialization whose untransposed matrix-vector
 *        product is split over the ranks of the communicator.
 *
 * Each rank multiplies one contiguous slab of the major storage dimension and
 * the partial results are combined with MPI collectives. NOTE(review): this
 * assumes every rank holds the full matrix, v and c — confirm with callers.
 */
template<>
class ParDenseMatrix<double> : public DenseMatrix<double> {
public:
  /// Communicator over which ParMultMv distributes work.
  MPI_Comm _comm;

  /// rows x cols matrix; `z` is forwarded unchanged to DenseMatrix<double>.
  ParDenseMatrix(MPI_Comm comm, INDEX rows=0, INDEX cols=0, bool z=1)
    : DenseMatrix<double>(rows, cols, z), _comm(comm) {}
  /// Copy of a dense matrix, tagged with `comm`.
  ParDenseMatrix(MPI_Comm comm, const DenseMatrix<double>& c)
    : DenseMatrix<double>(c), _comm(comm) {}
  /// Dense copy of a sparse matrix, tagged with `comm`.
  ParDenseMatrix(MPI_Comm comm, const SparseMatrix<double>& c)
    : DenseMatrix<double>(c), _comm(comm) {}
  /// Dense copy of a generic matrix, tagged with `comm`.
  ParDenseMatrix(MPI_Comm comm, const Matrix<double>& c)
    : DenseMatrix<double>(c), _comm(comm) {}

  /**
   * @brief c = a * A * v + b * c, distributed over `_comm`.
   *
   * @param v  input vector (expected length: nCols())
   * @param c  in/out. If its size differs from nRows() it is resized and
   *           zeroed, so `b` then has no effect.
   * @param At when true the call is forwarded unchanged to the serial
   *           MultMv (the transposed product is not distributed)
   * @param a  scale applied to A * v
   * @param b  scale applied to the incoming c (applied exactly once)
   */
  void ParMultMv(const Vector<double> &v, DenseVector<double> &c,
                 const bool At, double a, double b) const
  {
    if (At) {
      MultMv(*this, v, c, At, a, b);
      return;
    }

    const INDEX nRows = this->nRows();
    const INDEX nCols = this->nCols();

    // b is applied here and only here; both kernels below merely ADD a*A*v
    // onto c.
    if (c.size() != nRows) {
      c.resize(nRows);
      c.zero();
    } else {
      c *= b;
    }

    const int nProcs = MPI_Wrappers::size(_comm);
    const int myRank = MPI_Wrappers::rank(_comm);

#ifdef COL_STORAGE
    const int nMajor = nCols;
    const int nMinor = nRows;
    int ParDenseMatrix::*majorField = &ParDenseMatrix::_nCols;
    int ParDenseMatrix::*minorField = &ParDenseMatrix::_nRows;
#else
    const int nMajor = nRows;
    const int nMinor = nCols;
    int ParDenseMatrix::*majorField = &ParDenseMatrix::_nRows;
    int ParDenseMatrix::*minorField = &ParDenseMatrix::_nCols;
#endif

    // Balanced block partition of the major dimension: the first `extra`
    // ranks get one additional slice. Computed without forming i * nMajor,
    // which can overflow int for large matrices on many ranks. Ranks may get
    // zero slices when nMajor < nProcs.
    std::vector<int> majorCounts(nProcs);
    std::vector<int> offsets(nProcs);
    const int base = nMajor / nProcs;
    const int extra = nMajor % nProcs;
    for (int i = 0; i < nProcs; i++) {
      majorCounts[i] = base + (i < extra ? 1 : 0);
      offsets[i] = i * base + std::min(i, extra);
    }

    const int myNMajor = majorCounts[myRank];
    const int myMajorOffset = offsets[myRank];

    // A_local is a non-owning view onto this rank's slab of our storage.
    ParDenseMatrix<double> A_local(_comm);

    // Resets the borrowed pointer on every exit path, including exceptions
    // thrown by MultMv, so A_local's destructor never frees our storage.
    struct BorrowGuard {
      double *&ptr;
      explicit BorrowGuard(double *&p) : ptr(p) {}
      ~BorrowGuard() { ptr = NULL; }
    };
    BorrowGuard borrowGuard(A_local._data);

    A_local._data = this->_data + myMajorOffset * nMinor;
    A_local.*majorField = myNMajor;
    A_local.*minorField = nMinor;

#ifdef COL_STORAGE
    // Column slabs: multiply our columns by the matching entries of v. The
    // sum of these partial products over all ranks is a*A*v.
    DenseVector<double> v_local(myNMajor);
    for (int i = 0; i < myNMajor; i++)
      v_local(i) = v(myMajorOffset + i);

    DenseVector<double> c_local;
    MultMv(A_local, v_local, c_local, false, a, 0.0);

    DenseVector<double> c_sum(nRows);
    MPI_Wrappers::allsum(_comm, c_local.ptr(), c_sum.ptr(), c_local.size());

    // c already holds b*c: accumulate rather than overwrite.
    for (INDEX i = 0; i < nRows; i++)
      c(i) += c_sum(i);
#else
    // Row slabs: compute our rows of a*A*v + c with the full v, then
    // concatenate all slabs into c on every rank.
    DenseVector<double> c_local(myNMajor);
    for (int i = 0; i < myNMajor; i++)
      c_local(i) = c(myMajorOffset + i);

    // beta = 1: these entries were already scaled by b above.
    MultMv(A_local, v, c_local, false, a, 1.0);

    MPI_Wrappers::allgatherv(_comm, c_local.ptr(), c_local.size(), c.ptr(),
                             &majorCounts[0], &offsets[0]);
#endif
  }
};
/**
 * @brief Matrix-vector product A * b, routed through A.ParMultMv.
 *
 * The result vector starts empty, and ParMultMv is asked for
 * 1.0 * A * b + 0.0 * result with no transpose.
 */
template<typename T>
DenseVector<T> operator*(const ParDenseMatrix<T> &A, const Vector<T> &b)
{
  DenseVector<T> product;
  const bool transposeA = false;
  A.ParMultMv(b, product, transposeA, 1.0, 0.0);
  return product;
}
} #endif