#include "gr_vec.h"
#include "gr_mat.h"
/*
    Solve U X = B for X by back substitution, one column of B at a time,
    where U is n x n upper triangular and B, X have U->r rows and B->c
    columns.

    If unit is nonzero, the diagonal of U is never read (it is treated as
    all ones). Otherwise the diagonal entries are inverted once up front;
    if any gr_inv call does not return GR_SUCCESS, every step instead
    divides by the diagonal entry with gr_div.

    Each column of X is staged in a buffer of shallow copies and written
    back after it has been processed. B[j][i] is always read before the
    staged entry j of column i is overwritten, which appears intended to
    let X alias B (NOTE(review): the recursive variant relies on in-place
    solves through the dispatcher).

    Returns GR_SUCCESS or the OR of the status flags of the failing
    operations. On failure, columns before the failing one hold the
    solution and the failing column may be partially updated.
*/
int
gr_mat_nonsingular_solve_triu_classical(gr_mat_t X,
        const gr_mat_t U, const gr_mat_t B, int unit, gr_ctx_t ctx)
{
    slong i, j, n, m;
    gr_ptr tmp;
    gr_ptr inv = NULL;
    gr_ptr s;
    int use_division = 0;
    int status = GR_SUCCESS;
    slong sz = ctx->sizeof_elem;
    gr_method_void_unary_op set_shallow = GR_VOID_UNARY_OP(ctx, SET_SHALLOW);

    n = U->r;
    m = B->c;

    /* Nothing to solve. Returning here also avoids flint_malloc(0) for
       the staging buffer and useless diagonal inversions when m == 0. */
    if (n == 0 || m == 0)
        return GR_SUCCESS;

    if (!unit)
    {
        GR_TMP_INIT_VEC(inv, n, ctx);

        for (i = 0; i < n; i++)
        {
            status = gr_inv(GR_ENTRY(inv, i, sz), GR_MAT_ENTRY(U, i, i, sz), ctx);

            if (status != GR_SUCCESS)
            {
                /* Some diagonal entry could not be inverted; fall back to
                   dividing by the diagonal (exact division may still
                   succeed where inversion does not). */
                use_division = 1;
                status = GR_SUCCESS;
                break;
            }
        }
    }

    GR_TMP_INIT(s, ctx);
    tmp = flint_malloc(sz * n);

    for (i = 0; i < m; i++)
    {
        /* Stage column i of X; tmp entries share storage with X. */
        for (j = 0; j < n; j++)
            set_shallow(GR_ENTRY(tmp, j, sz), GR_MAT_ENTRY(X, j, i, sz), ctx);

        for (j = n - 1; j >= 0; j--)
        {
            /* Combine B[j][i] with U[j][j+1..n-1] . x[j+1..n-1]; assumes the
               flag 1 selects subtraction (initial - dot) -- TODO confirm
               against the _gr_vec_dot documentation. */
            status |= _gr_vec_dot(s, GR_MAT_ENTRY(B, j, i, sz), 1,
                GR_MAT_ENTRY(U, j, j + 1, sz), GR_ENTRY(tmp, j + 1, sz),
                n - j - 1, ctx);

            if (unit)
            {
                /* x[j] = s: exchange storage instead of copying. */
                gr_swap(GR_ENTRY(tmp, j, sz), s, ctx);
            }
            else if (use_division)
            {
                status |= gr_div(GR_ENTRY(tmp, j, sz), s, GR_MAT_ENTRY(U, j, j, sz), ctx);
            }
            else
            {
                status |= gr_mul(GR_ENTRY(tmp, j, sz), s, GR_ENTRY(inv, j, sz), ctx);
            }

            if (status != GR_SUCCESS)
                break;
        }

        /* Hand the (possibly partial) column back to X so that ownership
           of every entry's storage returns to the matrix. */
        for (j = 0; j < n; j++)
            set_shallow(GR_MAT_ENTRY(X, j, i, sz), GR_ENTRY(tmp, j, sz), ctx);

        if (status != GR_SUCCESS)
            break;
    }

    if (!unit)
    {
        /* Braces kept: GR_TMP_CLEAR_VEC may expand to several statements. */
        GR_TMP_CLEAR_VEC(inv, n, ctx);
    }

    flint_free(tmp);
    GR_TMP_CLEAR(s, ctx);

    return status;
}
/*
    Block (divide-and-conquer) solve of U X = B with U upper triangular.

    With r = n / 2 the matrices are split as

        U = [ Utl  Utr ]    X = [ Xtop ]    B = [ Btop ]
            [  0   Ubr ]        [ Xbot ]        [ Bbot ]

    where Utl is r x r. First Ubr Xbot = Bbot is solved, then
    Xtop = Btop - Utr Xbot, and finally Utl Xtop = Xtop is solved in place.
    Both sub-solves go through the dispatcher gr_mat_nonsingular_solve_triu.

    NOTE(review): for n == 1 the split is 0 + 1, so termination relies on
    the dispatcher selecting a non-recursive base case for small inputs.

    Returns GR_SUCCESS or the OR of the sub-operation status flags. If the
    bottom solve fails, the top block of X is left untouched.
*/
int
gr_mat_nonsingular_solve_triu_recursive(gr_mat_t X,
        const gr_mat_t U, const gr_mat_t B, int unit, gr_ctx_t ctx)
{
    gr_mat_t Utl, Utr, Ubr, Xtop, Xbot, Btop, Bbot, prod;
    slong rows, cols, half;
    int status;

    rows = U->r;
    cols = B->c;

    if (rows == 0 || cols == 0)
        return GR_SUCCESS;

    half = rows / 2;

    gr_mat_window_init(Xtop, X, 0, 0, half, cols, ctx);
    gr_mat_window_init(Xbot, X, half, 0, rows, cols, ctx);
    gr_mat_window_init(Btop, B, 0, 0, half, cols, ctx);
    gr_mat_window_init(Bbot, B, half, 0, rows, cols, ctx);
    gr_mat_window_init(Utl, U, 0, 0, half, half, ctx);
    gr_mat_window_init(Utr, U, 0, half, half, rows, ctx);
    gr_mat_window_init(Ubr, U, half, half, rows, rows, ctx);

    /* Bottom block first: it only depends on Ubr and Bbot. */
    status = gr_mat_nonsingular_solve_triu(Xbot, Ubr, Bbot, unit, ctx);

    if (status == GR_SUCCESS)
    {
        /* Eliminate the already-solved unknowns from the top rows. */
        gr_mat_init(prod, half, cols, ctx);
        status |= gr_mat_mul(prod, Utr, Xbot, ctx);
        status |= gr_mat_sub(Xtop, Btop, prod, ctx);
        gr_mat_clear(prod, ctx);

        status |= gr_mat_nonsingular_solve_triu(Xtop, Utl, Xtop, unit, ctx);
    }

    gr_mat_window_clear(Xtop, ctx);
    gr_mat_window_clear(Xbot, ctx);
    gr_mat_window_clear(Btop, ctx);
    gr_mat_window_clear(Bbot, ctx);
    gr_mat_window_clear(Utl, ctx);
    gr_mat_window_clear(Utr, ctx);
    gr_mat_window_clear(Ubr, ctx);

    return status;
}
/*
    Default implementation: the recursive block algorithm when B has at
    least `cutoff` rows and at least `cutoff` columns, the classical back
    substitution otherwise.
*/
int
gr_mat_nonsingular_solve_triu_generic(gr_mat_t X, const gr_mat_t U,
        const gr_mat_t B, int unit, gr_ctx_t ctx)
{
    /* Size threshold in both dimensions of B; presumably a tuning
       constant for when blocking starts to pay off. */
    const slong cutoff = 10;

    if (B->r >= cutoff && B->c >= cutoff)
        return gr_mat_nonsingular_solve_triu_recursive(X, U, B, unit, ctx);

    return gr_mat_nonsingular_solve_triu_classical(X, U, B, unit, ctx);
}
/*
    Public entry point: forwards to the MAT_NONSINGULAR_SOLVE_TRIU method
    registered in the context's method table and returns its status.
*/
int
gr_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t U,
        const gr_mat_t B, int unit, gr_ctx_t ctx)
{
    int status;

    status = GR_MAT_BINARY_OP_WITH_FLAG(ctx, MAT_NONSINGULAR_SOLVE_TRIU)(X, U, B, unit, ctx);

    return status;
}