import os
import sys
import torch
import numpy as np
import coremltools as ct
from transformers import AutoTokenizer
import torch.nn.functional as F
# Path to the locally cached ANEMLL-converted Qwen3-0.6B CoreML bundle
# (LUT8 quantized, 512-token context).
model_dir = "/Users/mazhewitt/Library/Caches/candle-coreml/clean-anemll--anemll-Qwen-Qwen3-0.6B-LUT888-ctx512_0.3.4"
print(f"Loading models from: {model_dir}")
# NOTE(review): tokenizer is Qwen2.5-0.5B-Instruct while the CoreML weights are
# Qwen3-0.6B — assumes the two share a compatible vocabulary; confirm token ids match.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)
# The model is split into three CoreML parts: token embeddings, the FFN/attention
# stack (stateful — holds the KV cache), and the LM head.
embed_model = ct.models.MLModel(os.path.join(model_dir, "qwen_embeddings.mlmodelc"))
ffn_model = ct.models.MLModel(os.path.join(model_dir, "qwen_FFN_PF_lut8_chunk_01of01.mlmodelc"))
lm_head = ct.models.MLModel(os.path.join(model_dir, "qwen_lm_head_lut8.mlmodelc"))
print("Models loaded successfully")
# Fixed probe prompt; the expected greedy continuation is " dog".
prompt = "The quick brown fox jumps over the lazy"
print(f"\nPrompt: '{prompt}'")
input_ids = tokenizer.encode(prompt, return_tensors='pt')  # shape (1, seq_len)
print(f"Token IDs: {input_ids[0].tolist()}")
print(f"Token count: {len(input_ids[0])}")
# Fresh KV-cache state shared by the prefill and infer calls below.
state = ffn_model.make_state()
context_pos = len(input_ids[0])  # number of real (non-padding) tokens
batch_size = 64       # prefill batch width expected by the FFN model
context_length = 512  # full attention context the model was exported with
print("\n=== PREFILL PHASE ===")
# Right-pad the prompt with token 0 up to batch_size.
# NOTE(review): assumes context_pos <= batch_size (64); a longer prompt would
# make the pad width negative and silently truncate — confirm upstream guard.
padded_input = F.pad(input_ids, (0, batch_size - context_pos), value=0)
print(f"Padded input shape: {padded_input.shape}")
embed_output = embed_model.predict({'input_ids': padded_input.numpy().astype(np.int32)})
hidden_states = torch.from_numpy(embed_output['hidden_states'])
print(f"Embeddings shape: {hidden_states.shape}")
# One position id per slot in the prefill batch: 0..batch_size-1.
position_ids = torch.arange(batch_size, dtype=torch.int32)
def make_causal_mask(context_length, prefetch):
    """Build an additive causal attention mask of shape (1, 1, L, L).

    Position i may attend to positions j <= i + prefetch; every blocked entry
    is set to -65500.0 (close to the most negative float16) so it vanishes
    after softmax.

    NOTE(review): allowed entries are 1.0 rather than the usual 0.0 — this
    reproduces the original script exactly; confirm the exported CoreML model
    expects that convention.

    Args:
        context_length: sequence length L of the square mask.
        prefetch: number of future positions each token may additionally see
            (0 gives a strictly causal mask).

    Returns:
        np.ndarray of dtype float16 and shape (1, 1, L, L).
    """
    rows = np.arange(context_length)[:, None]
    cols = np.arange(context_length)[None, :]
    # Vectorized replacement for the original O(L^2) Python double loop:
    # entry (i, j) is blocked exactly when j >= i + 1 + prefetch.
    mask = np.where(cols > rows + prefetch, -65500.0, 1.0).astype(np.float16)
    return mask[None, None, :, :]
# Full (1, 1, 512, 512) causal mask; sliced views of it are fed to prefill
# and to the single-token infer call below.
causal_mask = torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)
# Prefill only needs the first batch_size query rows of the mask.
batch_causal_mask = causal_mask[:, :, :batch_size, :]
current_pos = torch.tensor([0], dtype=torch.int32)  # prefill writes starting at slot 0
# Run the whole padded prompt through the FFN stack to populate the KV cache.
ffn_output = ffn_model.predict({
'hidden_states': hidden_states.numpy().astype(np.float16),
'position_ids': position_ids.numpy().astype(np.int32),
'causal_mask': batch_causal_mask.numpy().astype(np.float16),
'current_pos': current_pos.numpy().astype(np.int32)
}, state)
print(f"FFN prefill complete")
print("\n=== INFERENCE PHASE ===")
pos = context_pos
# Re-run the last prompt token through the infer path to produce the hidden
# state used for next-token prediction (chat.py-style prefill-then-infer).
current_token = input_ids[:, pos-1:pos]
print(f"Current token for inference: {current_token[0].tolist()}")
hidden_states = torch.from_numpy(
embed_model.predict({'input_ids': current_token.numpy().astype(np.int32)})['hidden_states']
)
print(f"Infer embedding shape: {hidden_states.shape}")
# One-hot column marking which KV-cache slot this step writes.
update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
update_mask[0, 0, pos-1, 0] = 1.0
position_ids = torch.tensor([pos-1], dtype=torch.int32)
# Single query row of the causal mask for position pos-1.
single_causal_mask = causal_mask[:, :, pos-1:pos, :]
# Inputs for the single-token decode step; 'current_pos' mirrors
# 'position_ids' (both point at the cache slot being written).
ffn_inputs = {
    'hidden_states': hidden_states.numpy().astype(np.float16),
    'update_mask': update_mask.numpy().astype(np.float16),
    'position_ids': position_ids.numpy().astype(np.int32),
    'causal_mask': single_causal_mask.numpy().astype(np.float16),
    'current_pos': position_ids.numpy().astype(np.int32)
}
try:
    ffn_output = ffn_model.predict(ffn_inputs, state)
    hidden_states = torch.from_numpy(ffn_output['output_hidden_states'])
    print(f"FFN infer output shape: {hidden_states.shape}")
except Exception:
    # Some exports of the infer function take no 'update_mask' input;
    # retry without it. (Was a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit.)
    print("Trying without update_mask...")
    ffn_inputs.pop('update_mask')
    ffn_output = ffn_model.predict(ffn_inputs, state)
    hidden_states = torch.from_numpy(ffn_output['output_hidden_states'])
    print(f"FFN output shape: {hidden_states.shape}")
# Project the final hidden state to vocabulary logits. The LM head may emit
# the vocab dimension split across several outputs named logits1..logits16;
# concatenate whichever are present back into one tensor.
lm_output = lm_head.predict({'hidden_states': hidden_states.numpy().astype(np.float16)})
print(f"\nLM head outputs: {list(lm_output.keys())}")
logits_parts = []
for i in range(1, 17):
    key = f'logits{i}'
    if key in lm_output:
        logits_parts.append(torch.from_numpy(lm_output[key]))
if logits_parts:
    logits = torch.cat(logits_parts, dim=-1)
    print(f"Combined logits shape: {logits.shape}")
    # Greedy decoding: argmax over the vocab at the last position.
    next_token = torch.argmax(logits[0, -1, :]).item()
    print(f"\n{'='*50}")
    print(f"RESULT FROM PYTHON/CHAT.PY LOGIC:")
    print(f"{'='*50}")
    print(f"Next token ID: {next_token}")
    print(f"Next token decoded: '{tokenizer.decode([next_token])}'")
    if next_token == 5562:
        print("✅ Python generated 'dog' (token 5562)")
    elif next_token == 3974:
        print("⚠️ Python also generated ' quick' (token 3974) - same as Rust!")
    else:
        decoded = tokenizer.decode([next_token])
        print(f"❓ Python generated unexpected token: {next_token} ('{decoded}')")
    # Show the top-5 candidates to help debug the Rust/Python mismatch.
    logits_flat = logits[0, -1, :]
    top5 = torch.topk(logits_flat, 5)
    print(f"\nTop 5 predictions:")
    for i, (score, idx) in enumerate(zip(top5.values, top5.indices)):
        token_id = idx.item()
        decoded = tokenizer.decode([token_id])
        print(f" {i+1}. Token {token_id} ('{decoded}'): {score:.2f}")
else:
    print("ERROR: No logits found in output")