wgebra 0.2.0

Composable WGSL shaders for linear algebra.
Documentation
// This is a WGSL port of https://github.com/wi-re/tbtSVD/blob/master/source/SVD.h

// which is an implementation of "Computing the Singular Value Decomposition of 3x3 matrices with minimal branching and
// elementary floating point operations" from http://pages.cs.wisc.edu/~sifakis/project_pages/svd.html

#define_import_path wgebra::svd3
#import wgebra::quat as Quat


// The SVD of a 3x3 matrix.
struct Svd {
    U: mat3x3<f32>,
    S: vec3<f32>,
    Vt: mat3x3<f32>,
};


// Constants used for calculation of givens quaternions
const GAMMA: f32 = 5.828427124; // sqrt(8)+3;
const CSTAR: f32 = 0.923879532; // cos(pi/8)
const SSTAR: f32 = 0.3826834323; // sin(p/8)
// Threshold value
const SVD_EPSILON: f32 = 1e-6;
// Iteration counts for Jacobi Eigenanlysis and reciprocal square root functions, influence precision
const JACOBI_STEPS: u32 = 12;
const RSQRT_STEPS: u32 = 4;
const RSQRT1_STEPS: u32 = 6;


// The QR decomposition of a 3x3 matrix.
struct QR {
    Q: mat3x3<f32>,
    R: mat3x3<f32>,
};

// A simple symmetrix 3x3 Matrix class (contains no storage for (0, 1) (0, 2) and (1, 2)
struct Symmetric3x3 {
    // TODO: for some reasons, naming these m00, m10, etc. doesn’t compile when
    //       using svd3 as a module from another shader.
    mxx: f32,
    myx: f32,
    myy: f32,
    mzx: f32,
    mzy: f32,
    mzz: f32,
}

// Helper struct to store 2 floats to avoid OUT parameters on functions
struct givens {
    ch: f32,
    sh: f32,
};

/// Calculates the reciprocal square root of x using a fast approximation.
/// The number of newton iterations can be controlled using RSQRT_STEPS.
// A built-in rsqrt function or 1.f/sqrt(x) could be used, however doing this manually allows for exact matching results on CPU and GPU code.
fn rsqrt(val: f32) -> f32 {
    var x = val;
    let xhalf = -0.5f * x;
    var i = bitcast<i32>(x);
    i = 0x5f375a82 - (i >> 1);
    x = bitcast<f32>(i);
    for (var i = 0u; i < RSQRT_STEPS; i++) {
        x = x * fma(x * x, xhalf, 1.5f);
    }
    return x;
}

// See rsqrt. Uses RSQRT1_STEPS to offer a higher precision alternative
fn rsqrt1(val: f32) -> f32 {
    var x = val;
    let xhalf = -0.5f * x;
    var i = bitcast<i32>(x);
    i = 0x5f375a82 - (i >> 1);
    x = bitcast<f32>(i);
    for (var i = 0u; i < RSQRT1_STEPS; i++) {
        x = x * fma(x * x, xhalf, 1.5f);
    }
    return x;
}

// Calculates the square root of x using 1.f/rsqrt1(x).
fn accurateSqrt(x: f32) -> f32 {
    return 1.f / rsqrt1(x);
}

// Helper function used to swap X with Y and Y with  X if c == true
fn condSwap(c: bool, x: ptr<function, f32>, y: ptr<function, f32>) {
    let x0 = *x;
    *x = select(*x, *y, c);
    *y = select(*y, x0, c);
}

// Helper function used to swap X with Y and Y with -X if c == true
fn condNegSwap(c: bool, x: ptr<function, f32>, y: ptr<function, f32>) {
    let x0 = -*x;
    *x = select(*x, *y, c);
    *y = select(*y, x0, c);
}

fn condSwapVec(c: bool, x: ptr<function, vec3<f32>>, y: ptr<function, vec3<f32>>) {
    let x0 = *x;
    *x = select(*x, *y, vec3(c));
    *y = select(*y, x0, vec3(c));
}

fn condNegSwapVec(c: bool, x: ptr<function, vec3<f32>>, y: ptr<function, vec3<f32>>) {
    let x0 = -*x;
    *x = select(*x, *y, vec3(c));
    *y = select(*y, x0, vec3(c));
}

// For an explanation of the math see http://pages.cs.wisc.edu/~sifakis/papers/SVD_TR1690.pdf
// Computing the Singular Value Decomposition of 3 x 3 matrices with minimal branching and elementary floating point operations
// See Algorithm 2 in reference. Given a matrix A this function returns the givens quaternion (x and w component, y and z are 0)
fn approximateGivensQuaternion(A: ptr<function, Symmetric3x3>) -> givens {
    let g = givens(2.f * ((*A).mxx - (*A).myy), (*A).myx);
    var b = GAMMA * g.sh * g.sh < g.ch * g.ch;
    let w = rsqrt(fma(g.ch, g.ch, g.sh * g.sh));
    if (w != w) {
        b = false;
    }

    if b {
        return givens(w * g.ch, w * g.sh);
    } else {
        return givens(CSTAR, SSTAR);
    }
}

// Function used to apply a givens rotation S. Calculates the weights and updates the quaternion to contain the cumultative rotation
fn jacobiConjugation(x: i32, y: i32, z: i32, S: ptr<function, Symmetric3x3>, q: ptr<function, Quat::Quat>) {
    var g = approximateGivensQuaternion(S);
    let scale = 1.f / fma(g.ch, g.ch, g.sh *  g.sh);
    let a = fma(g.ch, g.ch, -g.sh * g.sh) * scale;
    let b = 2.f * g.sh * g.ch * scale;
    var _S = (*S);
    // perform conjugation S = Q'*S*Q
    (*S).mxx = fma(a, fma(a, _S.mxx, b * _S.myx), b * (fma(a, _S.myx, b * _S.myy)));
    (*S).myx = fma(a, fma(-b, _S.mxx, a * _S.myx), b * (fma(-b, _S.myx, a * _S.myy)));
    (*S).myy = fma(-b, fma(-b, _S.mxx, a * _S.myx), a * (fma(-b, _S.myx, a * _S.myy)));
    (*S).mzx = fma(a, _S.mzx, b * _S.mzy);
    (*S).mzy = fma(-b, _S.mzx, a * _S.mzy);
    (*S).mzz = _S.mzz;
    // update cumulative rotation qV
    var tmp = array<f32, 3>( // TODO: why does it have to be `var` instead of `let` so we can index with z?
        (*q).coords[0] * g.sh,
        (*q).coords[1] * g.sh,
        (*q).coords[2] * g.sh,
    );
    g.sh *= (*q).coords[3];
    // (x,y,z) corresponds to ((0,1,2),(1,2,0),(2,0,1)) for (p,q) = ((0,1),(1,2),(0,2))
    (*q).coords[z] = fma((*q).coords[z], g.ch, g.sh);
    (*q).coords[3] = fma((*q).coords[3], g.ch, -tmp[z]); // w
    (*q).coords[x] = fma((*q).coords[x], g.ch, tmp[y]);
    (*q).coords[y] = fma((*q).coords[y], g.ch, -tmp[x]);
    // re-arrange matrix for next iteration
    _S.mxx = (*S).myy;
    _S.myx = (*S).mzy; _S.myy = (*S).mzz;
    _S.mzx = (*S).myx; _S.mzy = (*S).mzx; _S.mzz = (*S).mxx;
    (*S).mxx = _S.mxx;
    (*S).myx = _S.myx; (*S).myy = _S.myy;
    (*S).mzx = _S.mzx; (*S).mzy = _S.mzy; (*S).mzz = _S.mzz;
}

// Function used to contain the givens permutations and the loop of the jacobi steps controlled by JACOBI_STEPS
// Returns the quaternion q containing the cumultative result used to reconstruct S
fn jacobiEigenanalysis(S: Symmetric3x3) -> Quat::Quat {
    var mat = S;
    var q = Quat::identity();
    for (var i = 0u; i < JACOBI_STEPS; i += 1u) {
        jacobiConjugation(0, 1, 2, &mat, &q);
        jacobiConjugation(1, 2, 0, &mat, &q);
        jacobiConjugation(2, 0, 1, &mat, &q);
    }
    return q;
}

struct SortedSingularValues {
    B: mat3x3<f32>,
    V: mat3x3<f32>,
}


/// Implementation of Algorithm 3
// NOTE: doing this through pointers to B and fails with an internal error in naga
//       when trying to swap their columns.
fn sortSingularValues(B: mat3x3<f32>, V: mat3x3<f32>) -> SortedSingularValues {
    var bx = B.x;
    var by = B.y;
    var bz = B.z;
    var vx = V.x;
    var vy = V.y;
    var vz = V.z;
    var rho1 = dot(bx, bx);
    var rho2 = dot(by, by);
    var rho3 = dot(bz, bz);

    var c = rho1 < rho2;
    condNegSwapVec(c, &bx, &by);
    condNegSwapVec(c, &vx, &vy);
    condSwap(c, &rho1, &rho2);
    c = rho1 < rho3;
    condNegSwapVec(c, &bx, &bz);
    condNegSwapVec(c, &vx, &vz);
    condSwap(c, &rho1, &rho3);
    c = rho2 < rho3;
    condNegSwapVec(c, &by, &bz);
    condNegSwapVec(c, &vy, &vz);

    return SortedSingularValues(mat3x3(bx, by, bz), mat3x3(vx, vy, vz));
}

// Implementation of Algorithm 4
fn QRGivensQuaternion(a1: f32, a2: f32) -> givens {
    // a1 = pivot point on diagonal
    // a2 = lower triangular entry we want to annihilate
    let epsilon = SVD_EPSILON;
    let rho = accurateSqrt(fma(a1, a1, a2 * a2));
    var ch = abs(a1) + max(rho, epsilon);
    var sh = select(0.0, a2, rho > epsilon);
    let b = a1 < 0.f;
    condSwap(b, &sh, &ch);
    let w = rsqrt(fma(ch, ch, sh * sh));
    ch *= w;
    sh *= w;
    return givens(ch, sh);
}

// Implements a QR decomposition of a Matrix, see Sec 4.2
fn QRDecomposition(in_B: mat3x3<f32>) -> QR {
    var B = in_B;

    // first givens rotation (ch,0,0,sh)
    let g1 = QRGivensQuaternion(B.x.x, B.x.y);
    var a = fma(-2.f, g1.sh * g1.sh, 1.f);
    var b = 2.f * g1.ch * g1.sh;
    // apply B = Q' * B
    var r00 = fma(a, B.x.x, b * B.x.y);  var r01 = fma(a, B.y.x, b * B.y.y);	var r02 = fma(a, B.z.x, b * B.z.y);
    var r10 = fma(-b, B.x.x, a * B.x.y); var r11 = fma(-b, B.y.x, a * B.y.y);	var r12 = fma(-b, B.z.x, a * B.z.y);
    var r20 = B.x.z;					 var r21 = B.y.z;						var r22 = B.z.z;
    // second givens rotation (ch,0,-sh,0)
    let g2 = QRGivensQuaternion(r00, r20);
    a = fma(-2.f, g2.sh * g2.sh, 1.f);
    b = 2.f * g2.ch * g2.sh;
    // apply B = Q' * B;
    var b00 = fma(a, r00, b * r20);  var b01 = fma(a, r01, b * r21);  var b02 = fma(a, r02, b * r22);
    var b10 = r10;                   var b11 = r11;                   var b12 = r12;
    var b20 = fma(-b, r00, a * r20); var b21 = fma(-b, r01, a * r21); var b22 = fma(-b, r02, a * r22);
    // third givens rotation (ch,sh,0,0)
    let g3 = QRGivensQuaternion(b11, b21);
    a = fma(-2.f, g3.sh * g3.sh, 1.f);
    b = 2.f * g3.ch * g3.sh;
    // R is now set to desired value
    r00 = b00;                   r01 = b01;                  r02 = b02;
    r10 = fma(a, b10, b * b20);  r11 = fma(a, b11, b * b21);  r12 = fma(a, b12, b * b22);
    r20 = fma(-b, b10, a * b20); r21 = fma(-b, b11, a * b21); r22 = fma(-b, b12, a * b22);
    // construct the cumulative rotation Q=Q1 * Q2 * Q3
    // the number of floating point operations for three quaternion multiplications
    // is more or less comparable to the explicit form of the joined matrix.
    // certainly more memory-efficient!
    let sh12 = 2.f * fma(g1.sh, g1.sh, -0.5f);
    let sh22 = 2.f * fma(g2.sh, g2.sh, -0.5f);
    let sh32 = 2.f * fma(g3.sh, g3.sh, -0.5f);

    let q00 = sh12 * sh22;
    let q01 = fma(4.f * g2.ch * g3.ch, sh12 * g2.sh * g3.sh, 2.f * g1.ch * g1.sh * sh32);
    let q02 = fma(4.f * g1.ch * g3.ch, g1.sh * g3.sh, -2.f * g2.ch * sh12 * g2.sh * sh32);

    let q10 = -2.f * g1.ch * g1.sh * sh22;
    let q11 = fma(-8.f * g1.ch * g2.ch * g3.ch, g1.sh * g2.sh * g3.sh, sh12 * sh32);
    let q12 = fma(-2.f * g3.ch, g3.sh, 4.f * g1.sh * fma(g3.ch * g1.sh, g3.sh, g1.ch * g2.ch*g2.sh*sh32));

    let q20 = 2.f * g2.ch * g2.sh;
    let q21 = -2.f * g3.ch * sh22 * g3.sh;
    let q22 = sh22 * sh32;

    let Q = mat3x3(
        vec3(q00, q10, q20),
        vec3(q01, q11, q21),
        vec3(q02, q12, q22),
    );
    let R = mat3x3(
        vec3(r00, r10, r20),
        vec3(r01, r11, r21),
        vec3(r02, r12, r22),
    );

    return QR(Q, R);
}

// Computes the SVD of a 3x3 matrix.
fn svd(A: mat3x3<f32>) -> Svd {
    let ata = transpose(A) * A;
    let ata_sym = Symmetric3x3(
        ata.x.x,
        ata.x.y, ata.y.y,
        ata.x.z, ata.y.z, ata.z.z
    );
    let V = Quat::toMatrix(jacobiEigenanalysis(ata_sym));
    let B = A * V;
    let sorted = sortSingularValues(B, V);
    let qr = QRDecomposition(sorted.B);
    let S = vec3(qr.R.x.x, qr.R.y.y, qr.R.z.z);
    return Svd(qr.Q, S, transpose(sorted.V));
}

// Rebuilds the matrix this svd is the decomposition of.
fn recompose(svd: Svd) -> mat3x3<f32> {
    let U_S = mat3x3(svd.U.x * svd.S.x, svd.U.y * svd.S.y, svd.U.z * svd.S.z);
    return U_S * svd.Vt;
}