import ctypes
import os
import sys
class RypeU64Array(ctypes.Structure):
    """ctypes view of a (pointer, length) pair describing a uint64 buffer."""

    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_uint64)),  # first element of the buffer
        ("len", ctypes.c_size_t),  # number of elements readable through ``data``
    ]

    def to_list(self):
        """Copy the ``len`` values behind ``data`` into a new Python list of ints."""
        count = self.len
        # Slicing a ctypes pointer copies ``count`` elements in C; a zero-length
        # slice never dereferences, so an empty/NULL buffer yields [].
        return self.data[:count]
class RypeMinimizerSetResult(ctypes.Structure):
    """ctypes view of the struct returned by ``rype_extract_minimizer_set``.

    Holds one ``RypeU64Array`` of hashes per strand: ``forward`` first,
    then ``reverse_complement``.
    """

    # ctypes lays fields out in list order, so this order must match the
    # library's struct definition.
    _fields_ = [
        (strand_name, RypeU64Array)
        for strand_name in ("forward", "reverse_complement")
    ]
class RypeStrandResult(ctypes.Structure):
    """ctypes view of one strand's minimizers: parallel hash/position buffers.

    Both ``hashes`` and ``positions`` are read for the same ``len`` elements.
    """

    _fields_ = [
        ("hashes", ctypes.POINTER(ctypes.c_uint64)),
        ("positions", ctypes.POINTER(ctypes.c_uint64)),
        ("len", ctypes.c_size_t),  # element count shared by both buffers
    ]

    def hashes_list(self):
        """Copy the first ``len`` hash values into a Python list."""
        return self.hashes[: self.len]

    def positions_list(self):
        """Copy the first ``len`` position values into a Python list."""
        return self.positions[: self.len]
class RypeStrandMinimizersResult(ctypes.Structure):
    """ctypes view of the struct returned by ``rype_extract_strand_minimizers``.

    One ``RypeStrandResult`` (hashes + positions) per strand: ``forward``
    first, then ``reverse_complement``.
    """

    # ctypes lays fields out in list order, so this order must match the
    # library's struct definition.
    _fields_ = [
        (strand_name, RypeStrandResult)
        for strand_name in ("forward", "reverse_complement")
    ]
def load_librype():
script_dir = os.path.dirname(os.path.abspath(__file__))
lib_path = os.path.normpath(os.path.join(script_dir, "..", "target", "release", "librype.so"))
if not os.path.exists(lib_path):
print(f"ERROR: {lib_path} not found.", file=sys.stderr)
print("Build with: cargo build --release", file=sys.stderr)
sys.exit(1)
lib = ctypes.CDLL(lib_path)
lib.rype_extract_minimizer_set.argtypes = [
ctypes.c_char_p, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_uint64, ]
lib.rype_extract_minimizer_set.restype = ctypes.POINTER(RypeMinimizerSetResult)
lib.rype_minimizer_set_result_free.argtypes = [ctypes.POINTER(RypeMinimizerSetResult)]
lib.rype_minimizer_set_result_free.restype = None
lib.rype_extract_strand_minimizers.argtypes = [
ctypes.c_char_p, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_uint64, ]
lib.rype_extract_strand_minimizers.restype = ctypes.POINTER(RypeStrandMinimizersResult)
lib.rype_strand_minimizers_result_free.argtypes = [ctypes.POINTER(RypeStrandMinimizersResult)]
lib.rype_strand_minimizers_result_free.restype = None
lib.rype_get_last_error.argtypes = []
lib.rype_get_last_error.restype = ctypes.c_char_p
return lib
def extract_minimizer_set(lib, seq, k, w, salt):
    """Return the forward and reverse-complement minimizer hashes for *seq*.

    Args:
        lib: Library handle returned by ``load_librype``.
        seq: The sequence as ``bytes``, ``bytearray`` or ``memoryview``, or an
            ASCII ``str``. Non-bytes inputs are converted to ``bytes`` first.
        k, w, salt: Passed unchanged to ``rype_extract_minimizer_set``.

    Returns:
        ``(forward, reverse_complement)``: two lists of ints. They are copied
        out of the library-owned result before that result is freed.

    Raises:
        RuntimeError: If the library returns NULL. The message includes
            ``rype_get_last_error()`` when one is set.
        UnicodeEncodeError: If *seq* is a ``str`` containing non-ASCII text.
    """
    if isinstance(seq, str):
        seq = seq.encode("ascii")
    elif isinstance(seq, (bytearray, memoryview)):
        # c_char_p accepts only bytes. Copying also makes len() count bytes.
        seq = bytes(seq)

    result_ptr = lib.rype_extract_minimizer_set(seq, len(seq), k, w, salt)
    if not result_ptr:
        err = lib.rype_get_last_error()
        # errors="replace" so a malformed message cannot mask the real failure.
        detail = err.decode(errors="replace") if err else "unknown"
        raise RuntimeError(f"rype_extract_minimizer_set failed: {detail}")
    try:
        result = result_ptr.contents
        return result.forward.to_list(), result.reverse_complement.to_list()
    finally:
        # The lists above are independent copies, so freeing here is safe.
        lib.rype_minimizer_set_result_free(result_ptr)
def extract_strand_minimizers(lib, seq, k, w, salt):
    """Return per-strand minimizer hashes and positions for *seq*.

    Args:
        lib: Library handle returned by ``load_librype``.
        seq: The sequence as ``bytes``, ``bytearray`` or ``memoryview``, or an
            ASCII ``str``. Non-bytes inputs are converted to ``bytes`` first.
        k, w, salt: Passed unchanged to ``rype_extract_strand_minimizers``.

    Returns:
        ``((fwd_hashes, fwd_positions), (rc_hashes, rc_positions))``: lists
        of ints copied out of the library-owned result before it is freed.

    Raises:
        RuntimeError: If the library returns NULL. The message includes
            ``rype_get_last_error()`` when one is set.
        UnicodeEncodeError: If *seq* is a ``str`` containing non-ASCII text.
    """
    if isinstance(seq, str):
        seq = seq.encode("ascii")
    elif isinstance(seq, (bytearray, memoryview)):
        # c_char_p accepts only bytes. Copying also makes len() count bytes.
        seq = bytes(seq)

    result_ptr = lib.rype_extract_strand_minimizers(seq, len(seq), k, w, salt)
    if not result_ptr:
        err = lib.rype_get_last_error()
        # errors="replace" so a malformed message cannot mask the real failure.
        detail = err.decode(errors="replace") if err else "unknown"
        raise RuntimeError(f"rype_extract_strand_minimizers failed: {detail}")
    try:
        result = result_ptr.contents
        fwd, rc = result.forward, result.reverse_complement
        return (
            (fwd.hashes_list(), fwd.positions_list()),
            (rc.hashes_list(), rc.positions_list()),
        )
    finally:
        # The lists above are independent copies, so freeing here is safe.
        lib.rype_strand_minimizers_result_free(result_ptr)
def _display(seq):
    """Return a printable preview of *seq*: the first 40 bases, plus "..." if longer."""
    return seq[:40].decode() + ("..." if len(seq) > 40 else "")


def _last_error(lib):
    """Return librype's last error message as text, or "unknown" if none is set."""
    err = lib.rype_get_last_error()
    # c_char_p restype yields None for a NULL pointer, so guard before decoding.
    return err.decode(errors="replace") if err else "unknown"


def _first_unsorted_index(values, strict):
    """Return the first index i >= 1 where *values* stops ascending, else None.

    With ``strict=True`` equal neighbours count as a violation. Otherwise only
    a decrease does.
    """
    for i in range(1, len(values)):
        prev, cur = values[i - 1], values[i]
        if cur < prev or (strict and cur == prev):
            return i
    return None


def _check_short_sequence(seq, k, *results):
    """For a sequence shorter than *k*, verify every result list is empty.

    Prints OK/ERROR and returns the number of errors found (0 or 1).
    Returns 0 without printing when ``len(seq) >= k``.
    """
    if len(seq) >= k:
        return 0
    if any(results):
        print(" ERROR: short sequence produced non-empty results")
        return 1
    print(" OK: empty results for short sequence")
    return 0


def _run_minimizer_set_checks(lib, test_sequences, k, w, salt):
    """Run and check ``extract_minimizer_set`` on each sequence; return the error count."""
    errors = 0
    print("\n=== rype_extract_minimizer_set ===")
    for seq, label in test_sequences:
        print(f"\n [{label}] len={len(seq)}: {_display(seq)}")
        try:
            fwd, rc = extract_minimizer_set(lib, seq, k, w, salt)
        except RuntimeError as exc:
            # Count a library failure instead of aborting the remaining checks.
            print(f" ERROR: {exc}")
            errors += 1
            continue
        print(f" fwd: {len(fwd)} hashes, rc: {len(rc)} hashes")
        if fwd:
            print(f" fwd[0:3] = {['0x%016x' % h for h in fwd[:3]]}")
        if rc:
            print(f" rc[0:3] = {['0x%016x' % h for h in rc[:3]]}")
        for name, hashes in (("fwd", fwd), ("rc", rc)):
            # A minimizer *set* is expected to be sorted with no duplicates.
            i = _first_unsorted_index(hashes, strict=True)
            if i is not None:
                print(f" ERROR: {name} not strictly sorted at index {i}")
                errors += 1
        errors += _check_short_sequence(seq, k, fwd, rc)
    return errors


def _run_strand_minimizer_checks(lib, test_sequences, k, w, salt):
    """Run and check ``extract_strand_minimizers`` on each sequence; return the error count."""
    errors = 0
    print("\n=== rype_extract_strand_minimizers ===")
    for seq, label in test_sequences:
        print(f"\n [{label}] len={len(seq)}: {_display(seq)}")
        try:
            (fwd_h, fwd_p), (rc_h, rc_p) = extract_strand_minimizers(lib, seq, k, w, salt)
        except RuntimeError as exc:
            # Count a library failure instead of aborting the remaining checks.
            print(f" ERROR: {exc}")
            errors += 1
            continue
        print(f" fwd: {len(fwd_h)} minimizers, rc: {len(rc_h)} minimizers")
        if fwd_h:
            print(f" fwd[0]: hash=0x{fwd_h[0]:016x}, pos={fwd_p[0]}")
        if rc_h:
            print(f" rc[0]: hash=0x{rc_h[0]:016x}, pos={rc_p[0]}")

        strands = (("fwd", fwd_h, fwd_p), ("rc", rc_h, rc_p))
        for name, hashes, positions in strands:
            if len(hashes) != len(positions):
                print(f" ERROR: {name} hashes/positions length mismatch: {len(hashes)} vs {len(positions)}")
                errors += 1
        for name, _, positions in strands:
            i = _first_unsorted_index(positions, strict=False)
            if i is not None:
                print(f" ERROR: {name} positions not non-decreasing at index {i}")
                errors += 1
        for name, _, positions in strands:
            # Each position is treated as the start of a k-mer within seq.
            bad = next((p for p in positions if p + k > len(seq)), None)
            if bad is not None:
                print(f" ERROR: {name} position {bad} out of bounds (pos+k={bad + k} > len={len(seq)})")
                errors += 1

        # A k-mer starting at p covers [p, p + k). It is invalid if ANY N falls
        # in that window, including an N exactly at p.
        n_code = ord("N")
        n_positions = [i for i, base in enumerate(seq) if base == n_code]
        for name, _, positions in strands:
            hit = next(
                ((p, n) for p in positions for n in n_positions if p <= n < p + k),
                None,
            )
            if hit is not None:
                p, n = hit
                print(f" ERROR: {name} position {p} spans N at {n}")
                errors += 1

        errors += _check_short_sequence(seq, k, fwd_h, rc_h)
    return errors


def _run_error_handling_checks(lib):
    """Check that invalid k and w values are rejected; return the error count."""
    errors = 0
    print("\n=== Error handling ===")
    result_ptr = lib.rype_extract_minimizer_set(b"ACGTACGT", 8, 17, 5, 0)
    if not result_ptr:
        print(f" Invalid k=17 correctly rejected: {_last_error(lib)}")
    else:
        print(" ERROR: invalid k=17 was not rejected")
        lib.rype_minimizer_set_result_free(result_ptr)
        errors += 1
    result_ptr = lib.rype_extract_strand_minimizers(b"ACGTACGT", 8, 16, 0, 0)
    if not result_ptr:
        print(f" Invalid w=0 correctly rejected: {_last_error(lib)}")
    else:
        print(" ERROR: invalid w=0 was not rejected")
        lib.rype_strand_minimizers_result_free(result_ptr)
        errors += 1
    return errors


def main():
    """Smoke-test the librype C ABI; exit with status 1 if any check fails.

    Runs a fixed set of sequences through ``extract_minimizer_set`` and
    ``extract_strand_minimizers``. It checks ordering, bounds, N handling
    and empty output for sequences shorter than k, then confirms that
    invalid parameters are rejected.
    """
    lib = load_librype()
    test_sequences = [
        (b"AAAAACCCCCAAAAACCCCCAAAAACCCCCAAAAACCCCCAAAAACCCCCAAAAACCCCCAAAAACCCCC", "mixed 70bp"),
        (b"AAAATTTTGGGGCCCCAAAATTTTGGGGCCCCAAAATTTTGGGGCCCC", "alternating 48bp"),
        (b"ACGT", "too short for k=16"),
        (b"AAAAACCCCCAAAAACCCCCNAAAACCCCCAAAAACCCCCAAAAACCCCC", "N in middle"),
    ]
    k, w, salt = 16, 5, 0
    print(f"Parameters: k={k}, w={w}, salt=0x{salt:016x}")
    print(f"Test sequences: {len(test_sequences)}")

    errors = 0
    errors += _run_minimizer_set_checks(lib, test_sequences, k, w, salt)
    errors += _run_strand_minimizer_checks(lib, test_sequences, k, w, salt)
    errors += _run_error_handling_checks(lib)

    print(f"\n{'='*40}")
    if errors == 0:
        print("All checks passed.")
    else:
        print(f"FAILED: {errors} error(s) detected.")
        sys.exit(1)
# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()