import sys
def make_shuffling_single_f64(len):
    """Print the load / transform / store code lines for a single f64 FFT.

    The emitted text reads elements ``0 .. len-1`` of ``input`` with
    ``read_complex_to_array!``, passes them to ``self.perform_fft_direct``
    and writes the result back to the same indices of ``output`` with
    ``write_complex_to_array!``.

    Args:
        len: Number of indices to list in the generated macro calls.
            (The parameter name shadows the builtin; it is kept for
            interface compatibility.)

    Returns:
        None. The generated code is written to standard output.
    """
    index_list = ", ".join(map(str, range(len)))
    generated = (
        f"let values = read_complex_to_array!(input, {{{index_list}}});",
        "",
        "let out = self.perform_fft_direct(values);",
        "",
        f"write_complex_to_array!(out, output, {{{index_list}}});",
    )
    # One print of the joined lines produces the same bytes as one print per line.
    print("\n".join(generated))
def make_shuffling_single_f32(len):
    """Print the load / transform / store code lines for a single f32 FFT.

    The emitted text reads elements ``0 .. len-1`` of ``input`` with
    ``read_partial1_complex_to_array!``, passes them to
    ``self.perform_parallel_fft_direct`` and writes the result to the same
    indices of ``output`` with ``write_partial_lo_complex_to_array!``.

    Args:
        len: Number of indices to list in the generated macro calls.
            (The parameter name shadows the builtin; it is kept for
            interface compatibility.)

    Returns:
        None. The generated code is written to standard output.
    """
    index_list = ", ".join(map(str, range(len)))
    generated = (
        f"let values = read_partial1_complex_to_array!(input, {{{index_list}}});",
        "",
        "let out = self.perform_parallel_fft_direct(values);",
        "",
        f"write_partial_lo_complex_to_array!(out, output, {{{index_list}}});",
    )
    # One print of the joined lines produces the same bytes as one print per line.
    print("\n".join(generated))
def make_shuffling_parallel_f32(len):
    """Print the unpack / transform / repack code lines for two parallel f32 FFTs.

    The emitted text:
      1. reads ``len`` packed vectors from the even indices ``0, 2, ...`` of
         ``input`` into ``input_packed``;
      2. builds a ``values`` array from ``extract_lo_hi_f32`` /
         ``extract_hi_lo_f32`` combinations of those vectors;
      3. calls ``self.perform_parallel_fft_direct(values)``;
      4. builds ``out_packed`` from ``extract_lo_lo_f32`` /
         ``extract_lo_hi_f32`` / ``extract_hi_hi_f32`` combinations of the
         result and writes it with ``write_complex_to_array_strided!``
         (stride 2) to indices ``0 .. len-1``.

    NOTE(review): both generated arrays get ``2 * int(len / 2) + 1`` entries,
    which equals ``len`` only for odd ``len`` - presumably only odd lengths
    are intended; confirm against callers.

    Args:
        len: Transform length. (The parameter name shadows the builtin; it is
            kept for interface compatibility.)

    Returns:
        None. The generated code is written to standard output.
    """
    read_indices = ", ".join(str(2 * k) for k in range(len))
    write_indices = ", ".join(str(k) for k in range(len))
    half = int(len / 2)
    last = len - 1

    generated = [
        f"let input_packed = read_complex_to_array!(input, {{{read_indices}}});",
        "",
        "let values = [",
    ]
    for k in range(half):
        generated.append(
            f" extract_lo_hi_f32(input_packed[{k}], input_packed[{half + k}]),"
        )
        generated.append(
            f" extract_hi_lo_f32(input_packed[{k}], input_packed[{half + k + 1}]),"
        )
    # The middle packed vector pairs with the last one to supply the final value.
    generated.append(
        f" extract_lo_hi_f32(input_packed[{half}], input_packed[{last}]),"
    )
    generated += [
        "];",
        "",
        "let out = self.perform_parallel_fft_direct(values);",
        "",
        "let out_packed = [",
    ]
    generated.extend(
        f" extract_lo_lo_f32(out[{2 * k}], out[{2 * k + 1}])," for k in range(half)
    )
    # Entry that combines the last output with the first one.
    generated.append(f" extract_lo_hi_f32(out[{last}], out[0]),")
    generated.extend(
        f" extract_hi_hi_f32(out[{2 * k + 1}], out[{2 * k + 2}])," for k in range(half)
    )
    generated += [
        "];",
        "",
        f"write_complex_to_array_strided!(out_packed, output, 2, {{{write_indices}}});",
    ]
    print("\n".join(generated))
def make_butterfly(len, fft2func, calcfunc, mulfunc, rotatefunc):
    """Print the body of a direct butterfly of the given length.

    The generated code is emitted in this order:
      1. ``let [x{n}p{len-n}, x{n}m{len-n}] = fft2func(values[n], values[len-n]);``
         for each ``n`` in ``1 .. halflen-1``;
      2. products ``t_a{m}_{n}`` of ``self.twiddle{k}re`` with the ``p`` terms;
      3. products ``t_b{m}_{n}`` of ``self.twiddle{k}im`` with the ``m`` terms;
      4. ``t_a{m}``: ``x0`` plus all ``t_a{m}_{n}``, wrapped in ``calcfunc``;
      5. ``t_b{m}``: signed combination of the ``t_b{m}_{n}``, wrapped in
         ``calcfunc``;
      6. ``t_b{m}_rot = self.rotate.<rotatefunc>(t_b{m})``;
      7. ``y0`` and the pairs ``[y{m}, y{len-m}] = fft2func(t_a{m}, t_b{m}_rot)``;
      8. a final array expression listing ``y0 .. y{len-1}``.

    The twiddle index ``k`` is ``(m * n) % len``, replaced by ``len - k`` when
    it exceeds ``len / 2``; in exactly those cases the matching ``t_b`` term is
    subtracted instead of added.

    NOTE(review): for even ``len`` the element ``values[len/2]`` is never
    paired and ``y{len/2}`` is listed in the result without being defined, so
    this presumably targets odd lengths only - confirm with callers.

    Args:
        len: Transform length. All sizes are derived from this argument, so
            the function does not depend on any module-level state. (The name
            shadows the builtin; it is kept for interface compatibility.)
        fft2func: Name of the 2-point butterfly function to emit.
        calcfunc: Name of the macro/function wrapped around each sum.
        mulfunc: Name of the multiply function to emit.
        rotatefunc: Name of the method invoked on ``self.rotate``.

    Returns:
        None. The generated code is written to standard output.
    """
    size = len  # local alias so the rest of the body does not read like the builtin
    halflen = (size + 1) // 2
    half_range = range(1, halflen)

    def folded(m, n):
        # Twiddle index for the (m, n) product, folded into 0 .. size/2,
        # plus whether folding happened (which flips the sign of the b-term).
        k = (m * n) % size
        if 2 * k > size:
            return size - k, True
        return k, False

    for n in half_range:
        print(f"let [x{n}p{size-n}, x{n}m{size-n}] = {fft2func}(values[{n}], values[{size-n}]);")
    print("")

    for m in half_range:
        for n in half_range:
            k, _ = folded(m, n)
            print(f"let t_a{m}_{n} = {mulfunc}(self.twiddle{k}re, x{n}p{size-n});")
    print("")

    for m in half_range:
        for n in half_range:
            k, _ = folded(m, n)
            print(f"let t_b{m}_{n} = {mulfunc}(self.twiddle{k}im, x{n}m{size-n});")
    print("")

    print("let x0 = values[0];")
    for m in half_range:
        terms = " + ".join(["x0"] + [f"t_a{m}_{n}" for n in half_range])
        print(f"let t_a{m} = {calcfunc}({terms});")
    print("")

    for m in half_range:
        # The n == 1 term is always added; later terms take the folding sign.
        terms = f"t_b{m}_1"
        for n in range(2, halflen):
            _, was_folded = folded(m, n)
            sign = " - " if was_folded else " + "
            terms = terms + sign + f"t_b{m}_{n}"
        print(f"let t_b{m} = {calcfunc}({terms});")
    print("")

    for m in half_range:
        print(f"let t_b{m}_rot = self.rotate.{rotatefunc}(t_b{m});")
    print("")

    terms = " + ".join(["x0"] + [f"x{n}p{size-n}" for n in half_range])
    print(f"let y0 = {calcfunc}({terms});")
    for m in half_range:
        print(f"let [y{m}, y{size-m}] = {fft2func}(t_a{m}, t_b{m}_rot);")

    outputs = ", ".join(f"y{n}" for n in range(size))
    print(f"[{outputs}]")
if __name__ == "__main__":
    # Command-line entry point: print every generated code section for the
    # transform length given as the first argument.
    # A missing or non-integer length produces a short message on stderr and
    # a non-zero exit status instead of an IndexError/ValueError traceback.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <fftlen>")
    try:
        fftlen = int(sys.argv[1])
    except ValueError:
        sys.exit(f"error: fftlen must be an integer, got {sys.argv[1]!r}")

    # f32 sections.
    print("\n\n--------------- f32 ---------------")
    print("\n ----- perform_fft_contiguous -----")
    make_shuffling_single_f32(fftlen)
    print("\n ----- perform_parallel_fft_contiguous -----")
    make_shuffling_parallel_f32(fftlen)
    print("\n ----- perform_parallel_fft_direct -----")
    make_butterfly(fftlen, "parallel_fft2_interleaved_f32", "calc_f32!", "_mm_mul_ps", "rotate_both")

    # f64 sections.
    print("\n\n--------------- f64 ---------------")
    print("\n ----- perform_fft_contiguous -----")
    make_shuffling_single_f64(fftlen)
    print("\n ----- perform_parallel_fft_direct -----")
    make_butterfly(fftlen, "solo_fft2_f64", "calc_f64!", "_mm_mul_pd", "rotate")