import logging
import numpy as np
from grok import TransformerConfig
from recsys_model import HashConfig
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from runners import (
RecsysRetrievalInferenceRunner,
RetrievalModelRunner,
create_example_batch,
create_example_corpus,
ACTIONS,
)
def main():
    """Run a local end-to-end demo of the Phoenix retrieval model.

    Builds a small retrieval model config, initializes an inference runner,
    creates an example user batch and a random candidate corpus, then
    retrieves and pretty-prints the top-k scored posts per user.

    No arguments; no return value — output goes to stdout.
    """
    # Model / batch dimensions for the demo.
    emb_size = 128
    num_actions = len(ACTIONS)
    history_seq_len = 32
    candidate_seq_len = 8
    # Shared by the model config and the example batch below.
    product_surface_vocab_size = 16

    hash_config = HashConfig(
        num_user_hashes=2,
        num_item_hashes=2,
        num_author_hashes=2,
    )
    retrieval_model_config = PhoenixRetrievalModelConfig(
        emb_size=emb_size,
        history_seq_len=history_seq_len,
        candidate_seq_len=candidate_seq_len,
        hash_config=hash_config,
        product_surface_vocab_size=product_surface_vocab_size,
        model=TransformerConfig(
            emb_size=emb_size,
            widening_factor=2,
            key_size=64,
            num_q_heads=2,
            num_kv_heads=2,
            num_layers=2,
            attn_output_multiplier=0.125,
        ),
    )
    inference_runner = RecsysRetrievalInferenceRunner(
        runner=RetrievalModelRunner(
            model=retrieval_model_config,
            # NOTE(review): fractional per-device batch size — presumably a
            # sharding convention of RetrievalModelRunner; confirm semantics.
            bs_per_device=0.125,
        ),
        name="retrieval_local",
    )

    print("Initializing retrieval model...")
    inference_runner.initialize()
    print("Model initialized!")

    print("\n" + "=" * 70)
    print("RETRIEVAL SYSTEM DEMO")
    print("=" * 70)

    # Build a tiny example batch of users with synthetic histories.
    batch_size = 2
    example_batch, example_embeddings = create_example_batch(
        batch_size=batch_size,
        emb_size=emb_size,
        history_len=history_seq_len,
        num_candidates=candidate_seq_len,
        num_actions=num_actions,
        num_user_hashes=hash_config.num_user_hashes,
        num_item_hashes=hash_config.num_item_hashes,
        num_author_hashes=hash_config.num_author_hashes,
        product_surface_vocab_size=product_surface_vocab_size,
    )
    # A zero first-hash marks a padded (invalid) history slot.
    valid_history_count = int((example_batch.history_post_hashes[:, :, 0] != 0).sum())
    print(f"\nUsers have viewed {valid_history_count} posts total in their history")

    print("\n" + "-" * 70)
    print("STEP 1: Creating Candidate Corpus")
    print("-" * 70)
    corpus_size = 1000
    corpus_embeddings, corpus_post_ids = create_example_corpus(
        corpus_size=corpus_size,
        emb_size=emb_size,
        seed=456,
    )
    print(f"Corpus size: {corpus_size} posts")
    print(f"Corpus embeddings shape: {corpus_embeddings.shape}")
    inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)

    print("\n" + "-" * 70)
    print("STEP 2: Retrieving Top-K Candidates")
    print("-" * 70)
    top_k = 10
    retrieval_output = inference_runner.retrieve(
        example_batch,
        example_embeddings,
        top_k=top_k,
    )

    print(f"\nRetrieved top {top_k} candidates for each of {batch_size} users:")
    top_k_indices = np.array(retrieval_output.top_k_indices)
    top_k_scores = np.array(retrieval_output.top_k_scores)
    for user_idx in range(batch_size):
        print(f"\n  User {user_idx + 1}:")
        print(f"    {'Rank':<6} {'Post ID':<12} {'Score':<12}")
        print(f"    {'-' * 30}")
        for rank in range(top_k):
            post_id = top_k_indices[user_idx, rank]
            score = top_k_scores[user_idx, rank]
            # Render a 20-char bar assuming score is in [-1, 1]
            # (cosine-similarity-like); NOTE(review): confirm score range.
            bar = "█" * int((score + 1) * 10) + "░" * (20 - int((score + 1) * 10))
            print(f"    {rank + 1:<6} {post_id:<12} {bar} {score:.4f}")

    print("\n" + "=" * 70)
    print("Demo complete!")
    print("=" * 70)
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging so runner progress
    # messages are visible, then run the demo.
    logging.basicConfig(level=logging.INFO)
    main()