# flint-sys 0.9.0
#
# Bindings to the FLINT C library
# Documentation
# This file contains an illustration of David Harvey's truncated fft fitted
# with reduced twiddle factor loads. It also contains ifft_trunc_formula which
# prints formulas that are needed in the inverse truncated transform.

using Nemo

# log2 of the transform length exercised by the self checks in this file
L = 2

# w generates Q[x]/(x^(2^(L-1)) + 1), so w^(2^(L-1)) = -1
Qw, w = PolynomialRing(QQ, "x")
K, w = NumberField(w^(2^(L - 1)) + 1, "w")

# R = K[x0, ..., x_{2^L-1}]; the generators X serve as symbolic fft inputs
R, X = PolynomialRing(K, [Symbol("x", i) for i in 0:(2^L - 1)])

# Return the integer whose l-bit binary expansion is that of a read backwards,
# e.g. revbits(1, 3) == 4 (001 -> 100) and revbits(6, 3) == 3 (110 -> 011).
#  a must satisfy 0 <= a < 2^l (checked by an assert)
#  l = 0 is allowed: then a = 0 and the result is 0
function revbits(a::Int, l::Int)
    @assert 0 <= a < 2^l
    r = 0
    # peel the bits off the low end of a and push them onto the low end of r
    for _ in 1:l
        r = (r << 1) | (a & 1)
        a >>= 1
    end
    return r
end

# Truncated forward fft of length 2^k, performed in place on a strided view of x:
# logical entry i of the transform is stored at x[1+I+S*i].
# on input:
#  the view is assumed to hold at least 2^k entries
#  entries itrunc, ..., 2^k-1 are never read and are taken to be zero
# on output (verified by the check at the end of the function):
#  entries otrunc, ..., 2^k-1 are not defined
#  entry i < otrunc is evalpoly(w^revbits(j*2^k+i,L), inputs 0, ..., itrunc-1)
function tfft!(x::Vector, I::Int, S::Int, k::Int, j::Int, itrunc::Int, otrunc::Int)
    @assert k >= 0
    @assert 0 <= itrunc <= 2^k
    @assert 0 <= otrunc <= 2^k

    # no outputs requested
    otrunc < 1 && return

    # every input is zero, hence so is every requested output
    if itrunc < 1
        for i in 0:otrunc-1
            x[1+I+S*i] = zero(R)
        end
        return
    end

    # keep the inputs around for the answer check
    inx = elem_type(R)[x[1+I+S*i] for i in 0:itrunc-1]

    # a transform of length one changes nothing (and is not checked)
    k < 1 && return

    if k == 1
        # radix 2 butterfly; itrunc >= 1 and otrunc >= 1 already hold here
        lo = 1 + I
        hi = lo + S
        twiddle = w^revbits(2*j, L)
        u = x[lo]
        v = itrunc > 1 ? x[hi] : zero(R)
        tv = twiddle*v
        x[lo] = u + tv
        x[hi] = otrunc > 1 ? u - tv : zero(R)
    else
        # view the data as a 2^kc by 2^kr matrix stored row by row;
        # any kc + kr = k with kc, kr >= 1 will do
        kc = div(k, 2)
        kr = k - kc
        rowlen = 2^kr

        out_rows, out_rest = divrem(otrunc, rowlen)
        in_rows, in_rest = divrem(itrunc, rowlen)
        # number of rows containing at least one requested output
        out_rows_touched = out_rest == 0 ? out_rows : out_rows + 1
        # number of columns containing at least one nonzero input
        in_cols = min(rowlen, itrunc)

        # column transforms of length 2^kc; j is the same for every column
        for a in 0:in_cols-1
            col_itrunc = a < in_rest ? in_rows + 1 : in_rows
            tfft!(x, I + a*S, S*rowlen, kc, j, col_itrunc, out_rows_touched)
        end

        # row transforms of length 2^kr: first the complete rows ...
        for b in 0:out_rows-1
            tfft!(x, I + S*b*rowlen, S, kr, j*2^kc + b, in_cols, rowlen)
        end
        # ... then the partial row, if there is one
        if out_rest > 0
            tfft!(x, I + S*out_rows*rowlen, S, kr, j*2^kc + out_rows, in_cols, out_rest)
        end
    end

    # compare the outputs against direct evaluation at the saved inputs
    for i in 0:otrunc-1
        expected = evalpoly(R(w^revbits(j*2^k+i,L)), inx)
        x[1+I+S*i] == expected || error("error at i = $i with k = $k")
    end
end

#=
L = 2^k   (in this comment L is the transform length, not the global L of this file, which is its log2)
suppose a[z] = ... = a[L-1] = 0
n <= z
1 <= n + f <= L
IFFT(L, zeta, z, n, f; ah[0], ..., ah[n-1], L*a[n], ..., L*a[z-1])
returns with (L*a[0], ..., L*a[n-1])        if f = 0
             (L*a[0], ..., L*a[n-1], ah[n]) if f = 1

    ifft_trunc_formula generates the input -> output map as an (n+f) by z matrix.
    Special case for L = 2:

        Harvey                                          !!! Here !!!
        ah[0] =       a[0] + a[1]                      ah[0] = a[0] + w*a[1]
        ah[1] = zeta*(a[0] - a[1])                     ah[1] = a[0] - w*a[1]

    if n = 2,
        2*a[0] = ah[0] + zeta^-1*ah[1]                2*a[0] =  ah[0] + ah[1]
        2*a[1] = ah[0] - zeta^-1*ah[1]                2*a[1] = (ah[0] - ah[1])*w^-1

    if n = 1, f = 1, z = 2
        ah[0], 2*a[1]
        2*a[0] = 2*ah[0] - 2*a[1]                     2*a[0] = 2*ah[0] - w*2*a[1]
         ah[1] = zeta*(ah[0] - 2*a[1])                 ah[1] =   ah[0] - w*2*a[1]

    if n = 1, f = 1, z = 1
        2*a[0] = 2*ah[0]
         ah[1] = zeta*(ah[0])

    if n = 0, f = 1, z = 2
        ah[0] = (2*a[0] + 2*a[1])/2

    if n = 0, f = 1, z = 1
        ah[0] = (2*a[0])/2
=#
# Truncated inverse fft of length 2^k, performed in place on the entries
# x[1+I+S*i] of x, using the twiddle convention of tfft!.
# on input:
#  entries 0, ..., n-1 hold transformed values ah[i]
#  entries n, ..., z-1 hold 2^k*a[i]; a[i] is taken to be zero for i >= z
# on output:
#  entries 0, ..., n-1 hold 2^k*a[i]
#  if f is set, entry n holds ah[n]
function itfft!(x::Vector, I::Int, S::Int, k::Int, j::Int, z::Int, n::Int, f::Bool)
    @assert n <= z
    @assert 1 <= z <= 2^k
    @assert 1 <= n + f <= 2^k

    # length one: nothing to do
    k < 1 && return

    if k == 1
        lo = 1 + I
        hi = lo + S
        twiddle = w^revbits(2*j, L)
        if n == 2
            # both transformed values known: plain inverse butterfly
            s = x[lo]
            d = x[hi]
            x[lo] = s + d
            x[hi] = (s - d)*twiddle^-1
        elseif n == 1
            # ah[0] and 2*a[1] known (the latter zero when z = 1)
            u = x[lo]
            tv = twiddle*(z == 2 ? x[hi] : zero(R))
            x[lo] = 2*u - tv
            if f
                x[hi] = u - tv
            end
        elseif n == 0
            # only untransformed data known: a forward step scaled by 1/2
            u = x[lo]
            tv = twiddle*(z == 2 ? x[hi] : zero(R))
            x[lo] = (u + tv)*inv(K(2))
        else
            error("case n=$n, f=$f, z=$z not implemented")
        end
    else
        # view the data as a 2^kc by 2^kr matrix stored row by row;
        # any kc + kr = k with kc, kr >= 1 will do
        kc = div(k, 2)
        kr = k - kc
        rowlen = 2^kr

        nq, nr = divrem(n, rowlen)
        zq, zr = divrem(z, rowlen)
        # is there a last row that is only partially known/wanted?
        partial = (nr + f) > 0
        zcols = min(rowlen, z)
        cut_lo, cut_hi = minmax(nr, zr)

        # rows that are completely known: full inverse transforms
        for b in 0:nq-1
            itfft!(x, I + S*b*rowlen, S, kr, j*2^kc + b, rowlen, rowlen, false)
        end

        # rightmost columns; j is the same for every column
        for a in nr:zcols-1
            itfft!(x, I + S*a, S*rowlen, kc, j, zq + (a < cut_hi), nq, partial)
        end

        # the last, partial row
        if partial
            itfft!(x, I + S*nq*rowlen, S, kr, j*2^kc + nq, zcols, nr, f)
        end

        # leftmost columns; j is the same for every column
        for a in 0:nr-1
            itfft!(x, I + S*a, S*rowlen, kc, j, zq + (a < cut_lo), nq + 1, false)
        end
    end
end

println("-------- fft ----------")
# run tfft! on the symbolic inputs X for every pair of truncations and compare
# each requested output with a direct evaluation of the truncated input
for itrunc in 1:2^L
    @show itrunc
    coeffs = X[1:itrunc]
    for otrunc in 1:2^L
        x = copy(X)
        tfft!(x, 0, 1, L, 0, itrunc, otrunc)
        for (pos, out) in enumerate(x[1:otrunc])
            point = R(w^revbits(pos - 1, L))
            @assert out == evalpoly(point, coeffs)
        end
    end
end

println("-------- ifft ---------")
# round trip: tfft! followed by itfft! with the same truncation must return
# 2^L times the original inputs in the first trunc entries
for trunc in 1:2^L
    @show trunc
    x = copy(X)
    tfft!(x, 0, 1, L, 0, trunc, trunc)
    itfft!(x, 0, 1, L, 0, trunc, trunc, false)
    # to inspect the result: show(stdout, "text/plain", x); println()
    scale = 2^L
    for pos in 1:trunc
        @assert x[pos] == scale*X[pos]
    end
end

# Print, for a transform of length 2^k, the matrix of the linear map
#   (ah[0], ..., ah[n-1], 2^k*a[n], ..., 2^k*a[z-1])
#     -> (2^k*a[0], ..., 2^k*a[n-1]), followed by ah[n] when f is set,
# where a[i] = 0 for i >= z and ah[i] = sum_j a[j]*(r^revbits(i,k)*w)^j.
# Here r satisfies r^(2^(k-1)) = -1 and w is an indeterminate.
function ifft_trunc_formula(k::Int, n::Int, z::Int, f::Bool)
    @assert n <= z
    @assert 1 <= z <= 2^k
    @assert 1 <= n + f <= 2^k
    len = 2^k

    # build F = Q(r)(w) and move r and w into it
    Qr, rpoly = PolynomialRing(QQ, "r")
    Kr, rnum = NumberField(rpoly^(2^(k-1)) + 1, "r")
    Rwr, wpoly = PolynomialRing(Kr, "w")
    F = FractionField(Rwr)
    r = F(Rwr(rnum))
    w = F(wpoly)

    # evaluation point belonging to output index i
    node(i) = r^revbits(i, k)*w

    # M sends (a[0], ..., a[len-1]) to (ah[0..n-1], len*a[n..len-1])
    M = zero_matrix(F, len, len)
    for i in 0:n-1
        rowpt = node(i)
        for j in 0:len-1
            M[1+i,1+j] = rowpt^j
        end
    end
    for i in n:len-1
        M[1+i,1+i] = len
    end

    # N sends (a[0], ..., a[len-1]) to (len*a[0..n-1]) and, if f, ah[n]
    N = zero_matrix(F, n + f, len)
    for i in 0:n-1
        N[1+i,1+i] = len
    end
    if f
        extrapt = node(n)
        for j in 0:len-1
            N[1+n,1+j] = extrapt^j
        end
    end

    println("\nk = $k, n = $n, z = $z, f = $f")
    # only the first z input columns matter since the remaining a[i] vanish
    formula = (N*inv(M))[:, 1:z]
    show(stdout, "text/plain", formula)
    println()
    return nothing
end

# print the inverse truncated transform formulas; each tuple is (n, z, f)
println("\n-------- radix 2 --------------------")
for (n, z, f) in [
    (2, 2, false),
    (1, 2, true),
    (1, 2, false),
    (1, 1, true),
    (1, 1, false),
]
    ifft_trunc_formula(1, n, z, f)
end

println("\n-------- radix 4 (r^2 = -1) ---------")
for (n, z, f) in [
    (4, 4, false),
    (1, 1, false),
    (1, 1, true),
    (2, 4, false),
    (2, 2, false),
    (3, 3, false),
    (3, 3, true),
    (2, 2, true),
    (1, 4, false),
    (3, 4, false),
    (0, 4, true),
    (2, 4, true),
    (3, 4, true),
    (1, 4, true),
]
    ifft_trunc_formula(2, n, z, f)
end

nothing