import json
import math
import os
import sys
import totalreclaw_core
from totalreclaw.crypto import (
DerivedKeys,
compute_auth_key_hash as py_compute_auth_key_hash,
decrypt as py_decrypt,
derive_keys_from_mnemonic as py_derive_keys,
derive_lsh_seed as py_derive_lsh_seed,
encrypt as py_encrypt,
generate_blind_indices as py_blind_indices,
generate_content_fingerprint as py_fingerprint,
)
from totalreclaw.lsh import LSHHasher as PyLSHHasher
from totalreclaw.hermes.debrief import (
parse_debrief_response as py_parse_debrief,
)
FIXTURES_DIR = os.path.join(os.path.dirname(__file__), "fixtures")
with open(os.path.join(FIXTURES_DIR, "crypto_vectors.json")) as f:
VECTORS = json.load(f)
MNEMONIC = VECTORS["key_derivation"]["mnemonic"]
def test_key_derivation_parity():
rust_keys = totalreclaw_core.derive_keys_from_mnemonic(MNEMONIC)
py_keys = py_derive_keys(MNEMONIC)
assert rust_keys["salt"] == py_keys.salt, "salt mismatch"
assert rust_keys["auth_key"] == py_keys.auth_key, "auth_key mismatch"
assert rust_keys["encryption_key"] == py_keys.encryption_key, "encryption_key mismatch"
assert rust_keys["dedup_key"] == py_keys.dedup_key, "dedup_key mismatch"
def test_auth_key_hash_parity():
rust_keys = totalreclaw_core.derive_keys_from_mnemonic(MNEMONIC)
py_keys = py_derive_keys(MNEMONIC)
rust_hash = totalreclaw_core.compute_auth_key_hash(rust_keys["auth_key"])
py_hash = py_compute_auth_key_hash(py_keys.auth_key)
assert rust_hash == py_hash, f"auth_key_hash mismatch: {rust_hash} != {py_hash}"
def test_lsh_seed_parity():
rust_keys = totalreclaw_core.derive_keys_from_mnemonic(MNEMONIC)
py_keys = py_derive_keys(MNEMONIC)
rust_lsh_seed = totalreclaw_core.derive_lsh_seed(MNEMONIC, rust_keys["salt"])
py_lsh_seed = py_derive_lsh_seed(MNEMONIC, py_keys.salt)
assert rust_lsh_seed == py_lsh_seed, "LSH seed mismatch"
def test_encrypt_rust_decrypt_python():
rust_keys = totalreclaw_core.derive_keys_from_mnemonic(MNEMONIC)
plaintext = "Cross-implementation test: Rust -> Python"
encrypted = totalreclaw_core.encrypt(plaintext, rust_keys["encryption_key"])
decrypted = py_decrypt(encrypted, rust_keys["encryption_key"])
assert decrypted == plaintext, f"Rust->Python decrypt failed: {decrypted}"
def test_encrypt_python_decrypt_rust():
py_keys = py_derive_keys(MNEMONIC)
plaintext = "Cross-implementation test: Python -> Rust"
encrypted = py_encrypt(plaintext, py_keys.encryption_key)
decrypted = totalreclaw_core.decrypt(encrypted, py_keys.encryption_key)
assert decrypted == plaintext, f"Python->Rust decrypt failed: {decrypted}"
def test_cross_encryption_unicode():
keys = totalreclaw_core.derive_keys_from_mnemonic(MNEMONIC)
texts = [
"Caf\u00e9 latte",
"\u4f60\u597d\u4e16\u754c",
"\U0001f680 Rocket science \U0001f30d",
"\u00e9\u00e8\u00ea\u00eb \u00fc\u00f6\u00e4",
]
for text in texts:
enc_rust = totalreclaw_core.encrypt(text, keys["encryption_key"])
dec_py = py_decrypt(enc_rust, keys["encryption_key"])
assert dec_py == text, f"Rust->Python failed for: {text}"
enc_py = py_encrypt(text, keys["encryption_key"])
dec_rust = totalreclaw_core.decrypt(enc_py, keys["encryption_key"])
assert dec_rust == text, f"Python->Rust failed for: {text}"
def test_blind_indices_parity():
for tc in VECTORS["blind_indices"]["test_cases"]:
text = tc["text"]
rust_indices = totalreclaw_core.generate_blind_indices(text)
py_indices = py_blind_indices(text)
assert rust_indices == py_indices, (
f"Blind indices mismatch for: {text}\n"
f" Rust: {rust_indices}\n"
f" Python: {py_indices}"
)
def test_blind_indices_parity_extra():
exact_match_texts = [
"The user prefers VSCode over Vim",
"API endpoint: /v1/users/create",
"UPPERCASE TEXT WITH NUMBERS 123",
]
for text in exact_match_texts:
rust_indices = totalreclaw_core.generate_blind_indices(text)
py_indices = py_blind_indices(text)
assert rust_indices == py_indices, f"Blind indices mismatch for: {text}"
subset_texts = [
"Meeting at 3pm on Monday",
"caf\u00e9 \u00e9clair",
]
for text in subset_texts:
rust_indices = totalreclaw_core.generate_blind_indices(text)
py_indices = py_blind_indices(text)
rust_set = set(rust_indices)
py_set = set(py_indices)
assert py_set.issubset(rust_set), (
f"Python indices not subset of Rust for: {text}\n"
f" Extra in Python: {py_set - rust_set}"
)
def test_content_fingerprint_parity():
fp = VECTORS["content_fingerprint"]
dedup_key = bytes.fromhex(fp["dedup_key_hex"])
for tc in fp["test_cases"]:
text = tc["text"]
rust_fp = totalreclaw_core.generate_content_fingerprint(text, dedup_key)
py_fp = py_fingerprint(text, dedup_key)
assert rust_fp == py_fp, f"Fingerprint mismatch for: {text}\n Rust: {rust_fp}\n Python: {py_fp}"
def test_content_fingerprint_parity_extra():
keys = totalreclaw_core.derive_keys_from_mnemonic(MNEMONIC)
texts = [
"Hello, World!",
" extra whitespace ",
"caf\u00e9 br\u00fbl\u00e9e",
"a\t\nb\n\nc",
]
for text in texts:
rust_fp = totalreclaw_core.generate_content_fingerprint(text, keys["dedup_key"])
py_fp = py_fingerprint(text, keys["dedup_key"])
assert rust_fp == py_fp, f"Fingerprint mismatch for: {repr(text)}"
def test_lsh_small_parity():
small = VECTORS["lsh"]["small"]
seed = bytes.fromhex(VECTORS["lsh"]["lsh_seed_hex"])
rust_hasher = totalreclaw_core.LshHasher(
seed, small["dims"], small["n_tables"], small["n_bits"]
)
py_hasher = PyLSHHasher(
seed, small["dims"], small["n_tables"], small["n_bits"]
)
embedding = small["embedding"]
rust_hashes = rust_hasher.hash(embedding)
py_hashes = py_hasher.hash(embedding)
assert rust_hashes == py_hashes, "LSH small hashes mismatch"
def test_lsh_real_1024d_parity():
real = VECTORS["lsh"]["real"]
seed = bytes.fromhex(VECTORS["lsh"]["lsh_seed_hex"])
dims = real["dims"]
embedding = [math.sin(i * 0.1) * 0.5 for i in range(dims)]
rust_hasher = totalreclaw_core.LshHasher(
seed, dims, real["n_tables"], real["n_bits"]
)
py_hasher = PyLSHHasher(seed, dims, real["n_tables"], real["n_bits"])
rust_hashes = rust_hasher.hash(embedding)
py_hashes = py_hasher.hash(embedding)
assert rust_hashes == py_hashes, "LSH 1024d hashes mismatch"
def test_debrief_parse_parity():
test_cases = [
json.dumps([
{"text": "Session was about refactoring the auth module", "type": "summary", "importance": 8},
{"text": "Migration to new API is still pending", "type": "context", "importance": 7},
]),
"[]",
"not json",
'```json\n[{"text": "Summary of session work done today", "type": "summary", "importance": 8}]\n```',
json.dumps([
{"text": "Important item passes filter threshold", "type": "summary", "importance": 8},
{"text": "Low importance item gets filtered out", "type": "context", "importance": 3},
]),
json.dumps([{"text": "Item without importance score value", "type": "summary"}]),
json.dumps([{"text": "Item with invalid type value here", "type": "fact", "importance": 7}]),
]
for response in test_cases:
rust_items = totalreclaw_core.parse_debrief_response(response)
py_items = py_parse_debrief(response)
assert len(rust_items) == len(py_items), (
f"Debrief count mismatch for: {response}\n"
f" Rust: {len(rust_items)}, Python: {len(py_items)}"
)
for i, (r, p) in enumerate(zip(rust_items, py_items)):
assert r["text"] == p.text, f"Debrief[{i}] text mismatch"
assert r["type"] == p.type, f"Debrief[{i}] type mismatch"
assert r["importance"] == p.importance, f"Debrief[{i}] importance mismatch"
def test_debrief_prompt_parity():
from totalreclaw.hermes.debrief import DEBRIEF_SYSTEM_PROMPT as PY_PROMPT
rust_prompt = totalreclaw_core.get_debrief_system_prompt()
assert rust_prompt == PY_PROMPT, "Debrief system prompt mismatch"
if __name__ == "__main__":
test_funcs = [v for k, v in sorted(globals().items()) if k.startswith("test_") and callable(v)]
passed = 0
failed = 0
for fn in test_funcs:
try:
fn()
passed += 1
print(f" PASS {fn.__name__}")
except Exception as e:
failed += 1
print(f" FAIL {fn.__name__}: {e}")
print(f"\n{passed} passed, {failed} failed, {passed + failed} total")
sys.exit(1 if failed else 0)