import json
import os
import numpy as np
# Fixed generator seed so the emitted fixtures are reproducible; main()
# also records it (as a hex string) in the metadata JSON.
SEED = 0xC0FFEE

# Problem shape; arrays are laid out as [B, T, H, BT].  Following the FLA
# solve_tril naming these are presumably batch, sequence length, heads and
# chunk size -- confirm against the spec referenced in the metadata.
B, T, H, BT = 4, 64, 1, 64
def make_strict_lower_inputs(rng, *, batch=None, seq_len=None, heads=None, chunk=None):
    """Draw random strictly-lower-triangular chunk blocks in ``[B, T, H, BT]`` layout.

    For every batch entry ``b``, chunk start ``s`` (a multiple of ``chunk``)
    and head ``h``, the slice ``out[b, s:s + chunk, h, :]`` is a
    ``chunk x chunk`` block whose diagonal and upper triangle are zero and
    whose strict lower triangle holds float32 samples of ``0.1 * N(0, 1)``.

    One ``(chunk, chunk)`` matrix is drawn per (batch, chunk, head), nested
    in that order, so the output is fully determined by the state of *rng*.

    Args:
        rng: NumPy ``Generator`` used for sampling.
        batch, seq_len, heads, chunk: optional overrides; ``None`` falls back
            to the module constants ``B``, ``T``, ``H`` and ``BT``.

    Returns:
        float32 ndarray of shape ``(batch, seq_len, heads, chunk)``.

    Raises:
        ValueError: if ``chunk`` is not positive or ``seq_len`` is not a
            multiple of ``chunk``.
    """
    batch = B if batch is None else batch
    seq_len = T if seq_len is None else seq_len
    heads = H if heads is None else heads
    chunk = BT if chunk is None else chunk
    if chunk <= 0 or seq_len % chunk != 0:
        raise ValueError(
            f"seq_len={seq_len} must be a multiple of a positive chunk={chunk}"
        )

    a_strict = np.zeros((batch, seq_len, heads, chunk), dtype=np.float32)
    for b in range(batch):
        for start in range(0, seq_len, chunk):
            for h in range(heads):
                m = rng.standard_normal((chunk, chunk)).astype(np.float32) * 0.1
                # k=-1 zeroes the diagonal as well as the upper triangle.
                a_strict[b, start:start + chunk, h, :] = np.tril(m, k=-1)
    return a_strict
def reference_inv(a_strict):
    """Return the per-chunk inverse ``(I + A)^-1`` for a ``[B, T, H, BT]`` array.

    Each block ``a_strict[b, s:s + BT, h, :]`` (``s`` a multiple of ``BT``)
    is treated as a ``BT x BT`` matrix ``A`` and replaced by
    ``inv(I + A)``.  All batch entries, chunks and heads are processed, and
    the dimensions come from ``a_strict.shape`` rather than module globals.

    Args:
        a_strict: 4-D array; expected to hold strictly lower-triangular blocks
            (as produced by ``make_strict_lower_inputs``), though any block
            with ``I + A`` invertible is accepted.

    Returns:
        Array with the same shape and dtype as *a_strict*; each inverse is
        rounded to float32 before being stored.

    Raises:
        ValueError: if *a_strict* is not 4-D, or its sequence axis is not a
            multiple of its (positive) last-axis size.
        numpy.linalg.LinAlgError: if some ``I + A`` block is singular.
    """
    batch, seq_len, heads, chunk = a_strict.shape
    if chunk <= 0 or seq_len % chunk != 0:
        raise ValueError(
            f"sequence length {seq_len} must be a multiple of a positive chunk size {chunk}"
        )

    a_inv = np.zeros_like(a_strict)
    eye = np.eye(chunk, dtype=np.float32)
    for b in range(batch):
        for start in range(0, seq_len, chunk):
            rows = slice(start, start + chunk)
            for h in range(heads):
                m_full = eye + a_strict[b, rows, h, :]
                a_inv[b, rows, h, :] = np.linalg.inv(m_full).astype(np.float32)
    return a_inv
def _iter_chunk_blocks(arr):
    """Yield ``(b, start, h, block)`` for every chunk block of a ``[B, T, H, BT]`` array.

    ``block`` is the view ``arr[b, start:start + BT, h, :]``; ``start`` walks
    the sequence axis in steps of ``BT`` (the last-axis size).
    """
    batch, seq_len, heads, chunk = arr.shape
    for b in range(batch):
        for start in range(0, seq_len, chunk):
            for h in range(heads):
                yield b, start, h, arr[b, start:start + chunk, h, :]


def _check_strict_lower(a_strict):
    """Raise AssertionError if any chunk block has a nonzero diagonal or upper triangle."""
    for b, start, h, block in _iter_chunk_blocks(a_strict):
        where = f"input {b} (rows {start}:{start + block.shape[0]}, head {h})"
        if np.any(np.diag(block) != 0.0):
            raise AssertionError(f"{where}: diag not zero")
        if np.any(np.triu(block, k=1) != 0.0):
            raise AssertionError(f"{where}: strict upper not zero")


def _check_inverse(a_strict, a_inv, atol=1e-4):
    """Raise AssertionError if some block of ``(I + A) @ A_inv`` differs from I by more than *atol*."""
    chunk = a_strict.shape[-1]
    eye = np.eye(chunk)
    for b, start, h, block in _iter_chunk_blocks(a_strict):
        inv = a_inv[b, start:start + chunk, h, :]
        # Multiply in float64 so the check measures the reference's error,
        # not float32 rounding inside the verification product itself.
        err = np.abs((eye + block) @ inv.astype(np.float64) - eye).max()
        if err > atol:
            raise AssertionError(
                f"input {b} (rows {start}:{start + chunk}, head {h}): "
                f"|(I + A) @ A_inv - I| = {err:.3e} exceeds {atol:g}"
            )


def main():
    """Generate the solve_tril inversion fixture and write it next to this script.

    Steps:
    1. Draw strictly-lower inputs with a ``SEED``-seeded generator.
    2. Compute the reference inverse.
    3. Self-check both results.
    4. Write three files into the directory containing this script:
       ``chunk_tri_solve_invert_input_a_strict.bin``,
       ``chunk_tri_solve_invert_a_inv_ref.bin`` and
       ``chunk_tri_solve_invert_meta.json``.

    The ``.bin`` files are raw ``ndarray.tofile`` dumps (no header). Dtype,
    shape and byte order are not stored in them, and the metadata JSON does
    not record the dtype either.

    Raises:
        AssertionError: if a generated input block is not strictly lower
            triangular, or the reference is not the inverse of ``I + A``.
            Nothing is written in that case.
    """
    rng = np.random.default_rng(SEED)
    a_strict = make_strict_lower_inputs(rng)
    a_inv_ref = reference_inv(a_strict)

    # Explicit raises instead of ``assert`` so the self-checks still run
    # under ``python -O``.
    _check_strict_lower(a_strict)
    _check_inverse(a_strict, a_inv_ref)

    out_dir = os.path.dirname(os.path.abspath(__file__))
    a_strict.tofile(os.path.join(out_dir, "chunk_tri_solve_invert_input_a_strict.bin"))
    a_inv_ref.tofile(os.path.join(out_dir, "chunk_tri_solve_invert_a_inv_ref.bin"))

    meta = {
        "B": B,
        "T": T,
        "H": H,
        "BT": BT,
        "seed": f"0x{SEED:08X}",
        "spec_source": "FLA solve_tril (vllm/.../solve_tril.py:506-530)",
        "reference_impl": "numpy.linalg.inv(I + A_strict)",
    }
    meta_path = os.path.join(out_dir, "chunk_tri_solve_invert_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    print(
        f"chunk_tri_solve_invert reference: B={B} T={T} H={H} BT={BT} "
        f"max(|A_inv|) = {np.abs(a_inv_ref).max():.4f}"
    )


if __name__ == "__main__":
    main()