import json
import sys
from collections import defaultdict
def normalize(text):
    """Return *text* with leading marker characters removed, casefolded.

    Any leading run of spaces, underscores and "▁" (U+2581, commonly used by
    SentencePiece-style tokenizers as a word-boundary marker) is dropped, so
    "▁Paris", " paris" and "paris" all compare equal. ``str.casefold`` is then
    applied for caseless matching (e.g. "Straße" -> "strasse").
    """
    # lstrip removes the whole leading run of characters from the set in a
    # single pass, instead of re-slicing the string once per stripped char.
    return text.lstrip(" _▁").casefold()
def _best_match_cosine(tokens, wanted):
    """Return the largest cosine among *tokens* whose normalized text equals
    *wanted*, or None when no token matches."""
    best = None
    for entry in tokens:
        if normalize(entry["text"]) != wanted:
            continue
        score = entry["cosine"]
        if best is None or score > best:
            best = score
    return best


def find_for_target(features, target):
    """Rank features by how strongly their top tokens match *target*.

    Args:
        features: Iterable of feature dicts, each holding a ``"top_tokens"``
            list whose entries carry ``"text"`` and ``"cosine"`` keys.
        target: Word to look for; compared to token texts after
            ``normalize`` has been applied to both sides.

    Returns:
        List of ``(best_cosine, feature)`` tuples, one per feature with at
        least one matching top token, ordered by ``best_cosine`` descending.
        Ties keep their input order because the sort is stable.
    """
    wanted = normalize(target)
    scored = []
    for feature in features:
        best = _best_match_cosine(feature["top_tokens"], wanted)
        if best is not None:
            scored.append((best, feature))
    # Negated key (rather than reverse=True) keeps the original ordering rule.
    scored.sort(key=lambda pair: -pair[0])
    return scored
def _format_hit(score, feat):
    """Render one ranked feature (from find_for_target) as a report line."""
    preview = ", ".join(
        f"{tok['text'].strip()}({tok['cosine']:.3f})"
        for tok in feat["top_tokens"][:5]
    )
    meta = feat["feature"]
    return (
        f" L{meta['layer']:>2}:{meta['index']:>5} "
        f"cos_to_target={score:.4f} max_cosine={feat['max_cosine']:.4f} "
        f"[{preview}]"
    )


def show(raw_path, targets, topn=5):
    """Print, for each target word, the features whose top tokens match it best.

    Args:
        raw_path: Path to a raw scan JSON file, read as UTF-8. It must contain
            a ``"features"`` list; an optional ``"top_k"`` value is only used
            in the "no hits" message.
        targets: Iterable of target words to look up.
        topn: Maximum number of matching features printed per target.
    """
    with open(raw_path, encoding="utf-8") as fh:
        scan = json.load(fh)
    features = scan["features"]
    print(f"=== {raw_path} ({len(features)} features scanned) ===")

    for word in targets:
        ranked = find_for_target(features, word)
        print(f"\n Top {topn} features for inject_word \"{word}\":")
        if not ranked:
            top_k = scan.get('top_k', '?')
            print(f" (no feature has \"{word}\" in its top-{top_k})")
            continue
        for score, feat in ranked[:topn]:
            print(_format_hit(score, feat))
def main(argv=None):
    """Command-line entry point.

    Usage: ``pick_inject_feature.py <raw_scan_json> <target_word> [...]``

    Args:
        argv: Command-line arguments *excluding* the program name. Defaults
            to ``sys.argv[1:]`` so existing no-argument calls are unchanged.

    Exits with status 1, after printing usage to stderr, when fewer than two
    arguments are supplied.
    """
    # Token texts can contain non-ASCII characters such as "▁", so force
    # UTF-8 output. Only real text streams have reconfigure(); skip it when
    # stdout has been replaced (e.g. by io.StringIO) instead of crashing.
    reconfigure = getattr(sys.stdout, "reconfigure", None)
    if reconfigure is not None:
        reconfigure(encoding="utf-8")

    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        # Usage errors belong on stderr so they don't pollute piped output.
        print(
            "Usage: pick_inject_feature.py <raw_scan_json> "
            "<target_word> [<target_word> ...]",
            file=sys.stderr,
        )
        sys.exit(1)

    raw_path, *targets = args
    show(raw_path, targets)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()