import sys
import json
import time
def test_version():
    """Check that sci_form reports a version string starting with "sci-form"."""
    import sci_form

    version_str = sci_form.version()
    assert version_str.startswith("sci-form"), f"Bad version: {version_str}"
    print(f" version: {version_str}")
def test_embed_single():
    """Embed ethanol ("CCO") and check counts, element order, positions and determinism.

    Expected values used below: 9 atoms (2 C + 1 O + 6 H), 27 flat coordinate
    values (9 atoms * 3), elements[0] == 6 (C) and elements[2] == 8 (O).
    Embedding twice with the same seed must yield identical coordinates.
    """
    import sci_form

    r = sci_form.embed("CCO", seed=42)
    assert r.is_ok(), f"Embed failed: {r.error}"
    assert r.num_atoms == 9, f"Expected 9 atoms, got {r.num_atoms}"
    # coords is checked as a flat sequence: 3 values per atom.
    assert len(r.coords) == 27, f"Expected 27 coords, got {len(r.coords)}"
    assert len(r.elements) == 9, f"Expected 9 elements, got {len(r.elements)}"
    # Heavy atoms are expected in SMILES order: C, C, O.
    assert r.elements[0] == 6, f"Expected element[0] == 6 (C), got {r.elements[0]}"
    assert r.elements[2] == 8, f"Expected element[2] == 8 (O), got {r.elements[2]}"
    assert len(r.bonds) > 0, "Expected at least one bond"

    pos = r.get_positions()
    assert len(pos) == 9, f"Expected 9 positions, got {len(pos)}"
    assert all(len(p) == 3 for p in pos), "Every position should be an (x, y, z) triple"

    # Same seed must reproduce the exact same geometry.
    r2 = sci_form.embed("CCO", seed=42)
    assert r.coords == r2.coords, "Same seed should give same result"
    print(f" embed CCO: {r.num_atoms} atoms, {len(r.bonds)} bonds, {r.time_ms:.1f}ms")
def test_embed_molecules():
    """Embed a handful of small molecules and verify each total atom count.

    The expected counts include hydrogens (e.g. methane "C" -> 5 atoms).
    """
    import sci_form

    expected_counts = {
        "c1ccccc1": 12,
        "CC(=O)O": 8,
        "CC(=O)Oc1ccccc1C(=O)O": 21,
        "C": 5,
        "N": 4,
        "O": 3,
        "C#N": 3,
        "C=C": 6,
        "c1ccncc1": 11,
    }
    for smiles, n_expected in expected_counts.items():
        result = sci_form.embed(smiles, seed=42)
        assert result.is_ok(), f"Failed for {smiles}: {result.error}"
        n_got = result.num_atoms
        assert n_got == n_expected, f"{smiles}: expected {n_expected} atoms, got {n_got}"
    print(f" embed {len(expected_counts)} molecules: all OK")
def test_batch():
    """Embed several molecules in parallel and check each matches a single embed.

    Uses ``embed_batch`` with two threads and the same seed as ``embed``; every
    batch result must succeed and carry coordinates identical to the
    single-molecule result for the same SMILES.
    """
    import sci_form

    smiles = ["CCO", "c1ccccc1", "CC(=O)O", "C#N", "c1ccncc1"]
    results = sci_form.embed_batch(smiles, seed=42, num_threads=2)
    assert len(results) == len(smiles), f"Expected {len(smiles)} results, got {len(results)}"
    for r in results:
        assert r.is_ok(), f"Batch failed for {r.smiles}: {r.error}"

    # Index once by SMILES. The comparison then does not depend on output order,
    # and a missing entry gives a clear message instead of a bare StopIteration.
    by_smiles = {r.smiles: r for r in results}
    for smi in smiles:
        batch_r = by_smiles.get(smi)
        assert batch_r is not None, f"No batch result reported for {smi}"
        single = sci_form.embed(smi, seed=42)
        assert single.coords == batch_r.coords, f"Batch/single mismatch for {smi}"
    print(f" batch {len(smiles)} molecules: all match single results")
def test_parse():
    """Parse ethanol without embedding and check the atom/bond summary.

    ``sci_form.parse`` returns a dict with "num_atoms", "num_bonds" and an
    "atoms" list whose entries expose an "element" atomic number.
    """
    import sci_form

    info = sci_form.parse("CCO")
    assert info["num_atoms"] == 9
    assert info["num_bonds"] == 8
    atoms = info["atoms"]
    assert len(atoms) == 9

    # Atomic numbers of every atom, in order; counted per element below.
    atomic_numbers = [atom["element"] for atom in atoms]
    assert atomic_numbers.count(6) == 2
    assert atomic_numbers.count(8) == 1
    assert atomic_numbers.count(1) == 6
    print(f" parse CCO: {info['num_atoms']} atoms, {info['num_bonds']} bonds")
def test_error_handling():
    """Check how invalid SMILES is reported by both APIs.

    ``embed`` reports the error in the result object (``is_ok()`` is False and
    ``error`` is set). ``parse`` raises ``ValueError``.
    """
    import sci_form

    r = sci_form.embed("INVALID_SMILES_XYZ", seed=42)
    assert not r.is_ok(), "Should have failed for invalid SMILES"
    assert r.error is not None, "Failed embed should carry an error message"

    # Keep the try body to the single call that must raise. Use an explicit
    # raise rather than `assert False`, which would be stripped under `python -O`.
    try:
        sci_form.parse("INVALID_SMILES_XYZ")
    except ValueError:
        pass
    else:
        raise AssertionError("Should have raised ValueError")
    print(" error handling: OK")
def test_coordinate_sanity():
    """Check that every bond in embedded ethanol has a plausible length.

    Each bond is treated as an ``(atom_a, atom_b, order)`` triple. The Euclidean
    distance between its two atoms must lie strictly between 0.5 and 3.0.
    """
    import sci_form
    import math

    r = sci_form.embed("CCO", seed=42)
    # Without these guards a failed embed or empty bond list would pass vacuously.
    assert r.is_ok(), f"Embed failed: {r.error}"
    assert r.bonds, "Expected at least one bond to check"

    pos = r.get_positions()
    for a, b, _order in r.bonds:
        dist = math.sqrt(sum((pos[a][k] - pos[b][k]) ** 2 for k in range(3)))
        assert 0.5 < dist < 3.0, f"Bond {a}-{b} length {dist:.2f} out of range"
    print(" coordinate sanity: all bond lengths in [0.5, 3.0] Å")
def test_performance():
    """Check per-molecule embed time, then time and report a 100-molecule batch.

    Each single embed must succeed and report ``time_ms`` under 50 ms. The
    batch is timed externally and only reported (successes are counted, not
    asserted).
    """
    import sci_form

    smiles = ["CCO", "c1ccccc1", "CC(=O)O", "CC(=O)Oc1ccccc1C(=O)O", "c1ccncc1"]
    for smi in smiles:
        r = sci_form.embed(smi, seed=42)
        # A failed embed's timing is meaningless, so require success first.
        assert r.is_ok(), f"Embed failed for {smi}: {r.error}"
        assert r.time_ms < 50, f"{smi} took {r.time_ms:.1f}ms (> 50ms target)"

    batch_smiles = smiles * 20
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # system clock adjustments and is coarser on some platforms.
    start = time.perf_counter()
    results = sci_form.embed_batch(batch_smiles, seed=42)
    elapsed_ms = (time.perf_counter() - start) * 1000
    avg_ms = elapsed_ms / len(batch_smiles)
    ok = sum(1 for r in results if r.is_ok())
    print(f" performance: {len(batch_smiles)} mols in {elapsed_ms:.0f}ms ({avg_ms:.1f}ms/mol), {ok}/{len(batch_smiles)} OK")
def test_data_roundtrip():
    """Serialize an ethanol embed result to JSON and back.

    Atom count, coordinates and element list must survive the round trip unchanged.
    """
    import sci_form

    result = sci_form.embed("CCO", seed=42)
    payload = dict(
        smiles=result.smiles,
        num_atoms=result.num_atoms,
        coords=result.coords,
        elements=list(result.elements),
        bonds=[(i, j, order) for i, j, order in result.bonds],
        error=result.error,
        time_ms=result.time_ms,
    )
    decoded = json.loads(json.dumps(payload))
    assert decoded["num_atoms"] == result.num_atoms
    assert decoded["coords"] == list(result.coords)
    assert decoded["elements"] == list(result.elements)
    print(" data roundtrip: JSON serialize/deserialize OK")
def main():
    """Run every integration test in order and exit with the overall status.

    Prints a PASS/FAIL line per test and a summary. Any ``Exception`` raised by
    a test is caught so the remaining tests still run. Exits with status 1 if
    any test failed, otherwise 0.
    """
    import traceback

    tests = [
        ("Version", test_version),
        ("Single embed", test_embed_single),
        ("Multiple molecules", test_embed_molecules),
        ("Batch parallel", test_batch),
        ("Parse", test_parse),
        ("Error handling", test_error_handling),
        ("Coordinate sanity", test_coordinate_sanity),
        ("Performance", test_performance),
        ("Data roundtrip", test_data_roundtrip),
    ]
    print("=" * 60)
    print("sci-form Python Integration Tests")
    print("=" * 60)
    passed = 0
    failed = 0
    for name, test_fn in tests:
        try:
            test_fn()
        except Exception as e:
            failed += 1
            # Many tests use bare asserts whose str() is empty. Report the exception
            # type and the traceback so the failing check can actually be located.
            print(f" [FAIL] {name}: {type(e).__name__}: {e}")
            traceback.print_exc(file=sys.stdout)
        else:
            passed += 1
            print(f" [PASS] {name}")
    print("=" * 60)
    print(f"Results: {passed} passed, {failed} failed")
    print("=" * 60)
    sys.exit(1 if failed > 0 else 0)
# Entry point: run the full integration suite when this file is executed directly.
if __name__ == "__main__":
    main()