import sys
def _fold_twiddle(m, n, length):
    """Return ``(index, conjugate)`` for the twiddle exponent ``m*n mod length``.

    The generated code only references twiddles ``1 .. length//2``. An
    exponent ``k`` above ``length/2`` is replaced by ``length - k`` and
    flagged so the caller negates the imaginary (sine) part, i.e. treats it
    as the complex conjugate. An index of 0 is possible only when ``length``
    is composite.
    """
    k = (m * n) % length
    if 2 * k > length:  # integer form of `k > length / 2`
        return length - k, True
    return k, False


def _component_lines(length, pairs, comp, cross):
    """Emit the ``b{n}{N-n}{comp}_a`` / ``b{n}{N-n}{comp}_b`` statements.

    ``comp`` is the output component being built ("re" or "im"). The ``_a``
    sum weights ``x[0].comp`` plus the pair sums ``x..p.comp`` by the
    twiddles' ``.re``. ``cross`` is the opposite component, which the
    ``_b`` sum reads from the pair differences ``x..n`` weighted by the
    twiddles' ``.im``. A blank line is appended after the section.
    """
    lines = []
    for n, n_mirror in pairs:
        a_terms = [f"let b{n}{n_mirror}{comp}_a = buffer.get_unchecked(0).{comp}"]
        b_terms = []
        for m, m_mirror in pairs:
            k, conjugate = _fold_twiddle(m, n, length)
            if k == 0:
                # Exponent is a multiple of N: the twiddle is 1 + 0i, so the
                # cosine weight is 1 and the sine term vanishes. Emitting
                # `self.twiddle0` would reference a field the twiddle1..
                # naming scheme presumably does not define.
                a_terms.append(f"x{m}{m_mirror}p.{comp}")
                continue
            a_terms.append(f"self.twiddle{k}.re*x{m}{m_mirror}p.{comp}")
            sign = "-" if conjugate else ""
            b_terms.append(f"{sign}self.twiddle{k}.im*x{m}{m_mirror}n.{cross}")
        # b_terms is never empty: m == 1 gives exponent n, which is nonzero.
        lines.append(" + ".join(a_terms) + ";")
        lines.append(f"let b{n}{n_mirror}{comp}_b = " + " + ".join(b_terms) + ";")
    lines.append("")
    return lines


def generate_butterfly(length):
    """Return the Rust statements of an unrolled length-``length`` butterfly.

    The generated code:

    - reads ``buffer[0 .. length]`` via ``get_unchecked``;
    - combines each mirrored pair ``x[m]`` / ``x[length-m]`` into a sum
      (``x{m}{length-m}p``) and a difference (``x{m}{length-m}n``);
    - weights them with ``self.twiddleK.re`` / ``.im`` for K in
      ``1 .. length//2``;
    - writes ``length`` results back into ``buffer`` in place.

    NOTE(review): assumes the Rust side defines ``self.twiddle1`` ..
    ``self.twiddle{length//2}`` as complex values and has ``Complex`` in
    scope. Confirm against the target file.

    Raises:
        ValueError: if ``length`` is not a positive odd integer. The
            mirrored-pair scheme has no slot for the middle element of an
            even length, and length 0 would write ``buffer[0]`` of an empty
            buffer.
    """
    if length < 1 or length % 2 == 0:
        raise ValueError(
            f"butterfly length must be a positive odd integer, got {length}"
        )
    half = (length + 1) // 2
    pairs = [(n, length - n) for n in range(1, half)]

    lines = []
    # Pair sums and differences: x[m] +/- x[N-m].
    for n, n_mirror in pairs:
        lines.append(
            f"let x{n}{n_mirror}p = *buffer.get_unchecked({n}) + *buffer.get_unchecked({n_mirror});"
        )
        lines.append(
            f"let x{n}{n_mirror}n = *buffer.get_unchecked({n}) - *buffer.get_unchecked({n_mirror});"
        )

    # Output 0 is the plain sum of all inputs.
    sum_terms = ["let sum = *buffer.get_unchecked(0)"]
    sum_terms.extend(f"x{n}{n_mirror}p" for n, n_mirror in pairs)
    lines.append(" + ".join(sum_terms) + ";")

    lines.extend(_component_lines(length, pairs, "re", "im"))
    lines.extend(_component_lines(length, pairs, "im", "re"))

    # Outputs k and N-k share the same a/b partial sums; only the sign with
    # which the sine part is applied differs between the two halves.
    for k in range(1, length):
        if 2 * k > length:
            j, sign_re, sign_im = length - k, "+", "-"
        else:
            j, sign_re, sign_im = k, "-", "+"
        lines.append(f"let out{k}re = b{j}{length - j}re_a {sign_re} b{j}{length - j}re_b;")
        lines.append(f"let out{k}im = b{j}{length - j}im_a {sign_im} b{j}{length - j}im_b;")

    lines.append("*buffer.get_unchecked_mut(0) = sum;")
    for k in range(1, length):
        lines.append(
            f"*buffer.get_unchecked_mut({k}) = Complex{{ re: out{k}re, im: out{k}im }};"
        )
    return lines


def main(argv=None):
    """Print the butterfly for the length given as the first argument.

    Returns a process exit status: 0 on success, 2 on a missing or invalid
    length (the error goes to stderr).
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print(f"usage: {sys.argv[0]} LENGTH", file=sys.stderr)
        return 2
    try:
        lines = generate_butterfly(int(args[0]))
    except ValueError as err:
        print(f"error: {err}", file=sys.stderr)
        return 2
    print("\n".join(lines))
    return 0


if __name__ == "__main__":
    sys.exit(main())