import jax.numpy as jnp
import numpy as np
import pytest
from grok import make_recsys_attn_mask
class TestMakeRecsysAttnMask:
    """Checks the mask layout produced by ``make_recsys_attn_mask``.

    As encoded by the assertions below, the mask has shape
    ``(1, 1, seq_len, seq_len)`` and entry ``[query, key]`` is 1 when the
    query position attends to the key position and 0 otherwise.  Positions
    before ``candidate_start_offset`` form the user+history prefix; the
    remaining positions are candidates.
    """

    @staticmethod
    def _mask_plane(total_len, offset):
        """Build the mask and return its single 2-D ``[query, key]`` slice."""
        return make_recsys_attn_mask(total_len, offset)[0, 0]

    def test_output_shape(self):
        """The square grid is preceded by two singleton axes."""
        total_len, offset = 10, 5
        result = make_recsys_attn_mask(total_len, offset)
        assert result.shape == (1, 1, total_len, total_len)

    def test_user_history_has_causal_attention(self):
        """Within the prefix, a position sees itself and earlier positions only."""
        offset = 5
        grid = self._mask_plane(8, offset)
        for i in range(offset):
            for j in range(offset):
                if j > i:
                    # Keys strictly after the query must be masked out.
                    assert grid[i, j] == 0, (
                        f"Position {i} should NOT attend to future position {j}"
                    )
                else:
                    assert grid[i, j] == 1, f"Position {i} should attend to position {j}"

    def test_candidates_attend_to_user_history(self):
        """Every candidate row is fully open over the prefix columns."""
        total_len, offset = 8, 5
        grid = self._mask_plane(total_len, offset)
        prefix_positions = range(offset)
        for candidate_pos in range(offset, total_len):
            for history_pos in prefix_positions:
                assert grid[candidate_pos, history_pos] == 1, (
                    f"Candidate at {candidate_pos} should attend to user+history at {history_pos}"
                )

    def test_candidates_attend_to_themselves(self):
        """The diagonal stays set inside the candidate region."""
        total_len, offset = 8, 5
        grid = self._mask_plane(total_len, offset)
        for candidate_pos in range(offset, total_len):
            assert grid[candidate_pos, candidate_pos] == 1, (
                f"Candidate at {candidate_pos} should attend to itself"
            )

    def test_candidates_do_not_attend_to_other_candidates(self):
        """Off-diagonal entries of the candidate/candidate square are zero."""
        total_len, offset = 8, 5
        grid = self._mask_plane(total_len, offset)
        candidate_positions = range(offset, total_len)
        for query_pos in candidate_positions:
            for key_pos in candidate_positions:
                if query_pos == key_pos:
                    # Self-attention is covered by its own test.
                    continue
                assert grid[query_pos, key_pos] == 0, (
                    f"Candidate at {query_pos} should NOT attend to candidate at {key_pos}"
                )

    def test_full_mask_structure(self):
        """Compare the whole 6x6 grid (offset 3) against a literal reference."""
        grid = self._mask_plane(6, 3)
        reference = np.array(
            [
                # prefix rows: lower-triangular
                [1, 0, 0, 0, 0, 0],
                [1, 1, 0, 0, 0, 0],
                [1, 1, 1, 0, 0, 0],
                # candidate rows: full prefix plus own diagonal entry
                [1, 1, 1, 1, 0, 0],
                [1, 1, 1, 0, 1, 0],
                [1, 1, 1, 0, 0, 1],
            ],
            dtype=np.float32,
        )
        np.testing.assert_array_equal(
            np.array(grid),
            reference,
            err_msg="Full mask structure does not match expected pattern",
        )

    def test_dtype_preserved(self):
        """The ``dtype`` keyword is reflected in the returned mask's dtype."""
        total_len, offset = 5, 3
        requested = (jnp.float32, jnp.float16)
        # Build every mask first, then check them in the same order.
        built = [
            make_recsys_attn_mask(total_len, offset, dtype=wanted)
            for wanted in requested
        ]
        for result, wanted in zip(built, requested):
            assert result.dtype == wanted

    def test_single_candidate(self):
        """With one trailing candidate the grid is plain lower-triangular."""
        grid = self._mask_plane(4, 3)
        reference = np.array(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            dtype=np.float32,
        )
        np.testing.assert_array_equal(np.array(grid), reference)

    def test_all_candidates(self):
        """With a one-position prefix, each candidate sees column 0 and itself."""
        grid = self._mask_plane(4, 1)
        reference = np.array(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 0, 1, 0],
                [1, 0, 0, 1],
            ],
            dtype=np.float32,
        )
        np.testing.assert_array_equal(np.array(grid), reference)
if __name__ == "__main__":
    # Running this file as a script hands it to pytest in verbose mode.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)