extern crate alloc;
use alloc::vec::Vec;
use core::ffi::{c_char, c_void, CStr};
use core::ptr;
use allocator_api2::{alloc::AllocError, alloc::Allocator, alloc::Layout};
use stringtape::{BytesTape, BytesTapeView, CharsTape, CharsTapeView};
pub use crate::stringzilla::{SortedIdx, Status as SzStatus};
pub type Capability = u32;
pub use crate::stringzilla::Status;
/// Error returned by the wrappers in this file.
///
/// Pairs the status code reported by the C library with an optional
/// human-readable message.
#[derive(Debug)]
pub struct Error {
/// Status code describing the failure.
pub status: Status,
/// Message text, when one was available; `Display` falls back to the
/// `Debug` form of `status` when this is `None`.
pub message: Option<String>,
}
/// Prints the attached message when there is one, otherwise the `Debug`
/// rendering of the status code.
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(text) = self.message.as_deref() {
            write!(f, "{}", text)
        } else {
            write!(f, "{:?}", self.status)
        }
    }
}
// Marker impl so `Error` can be used wherever `std::error::Error` is expected
// (relies on the `Debug` and `Display` impls of `Error`).
impl std::error::Error for Error {}
impl From<Status> for Error {
fn from(status: Status) -> Self {
Error { status, message: None }
}
}
/// Builds an [`Error`] from a status code and the message pointer filled in by
/// a C call.
///
/// The message is only read when `error_msg` is non-null and `status` is not
/// `Status::Success`; otherwise the resulting error has `message: None`.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD instead of dropping
/// the whole message, so a diagnostic from the C side is never silently lost.
fn rust_error_from_c_message(status: Status, error_msg: *const c_char) -> Error {
    let message = if error_msg.is_null() || status == Status::Success {
        None
    } else {
        // SAFETY: the pointer is non-null here. It is assumed to reference a
        // NUL-terminated string that stays valid for the duration of this call
        // (contract of the C API - TODO confirm against the C headers).
        let text = unsafe { CStr::from_ptr(error_msg) };
        Some(text.to_string_lossy().into_owned())
    };
    Error { status, message }
}
/// Character-tape input in any of the offset widths / ownership forms that
/// `LevenshteinDistancesUtf8::compute_into` accepts.
pub enum AnyCharsTape<'a> {
/// Owned tape with `u32` offsets, allocated through `UnifiedAlloc`.
Tape32(CharsTape<u32, UnifiedAlloc>),
/// Owned tape with `u64` offsets, allocated through `UnifiedAlloc`.
Tape64(CharsTape<u64, UnifiedAlloc>),
/// Borrowed view with `u32` offsets.
View32(CharsTapeView<'a, u32>),
/// Borrowed view with `u64` offsets.
View64(CharsTapeView<'a, u64>),
}
/// Byte-tape input in any of the offset widths / ownership forms that the
/// byte-oriented engines' `compute_into` methods accept.
pub enum AnyBytesTape<'a> {
/// Owned tape with `u32` offsets, allocated through `UnifiedAlloc`.
Tape32(BytesTape<u32, UnifiedAlloc>),
/// Owned tape with `u64` offsets, allocated through `UnifiedAlloc`.
Tape64(BytesTape<u64, UnifiedAlloc>),
/// Borrowed view with `u32` offsets.
View32(BytesTapeView<'a, u32>),
/// Borrowed view with `u64` offsets.
View64(BytesTapeView<'a, u64>),
}
/// Owning wrapper around an opaque device-scope handle created by the C
/// library; the handle is released in `Drop` via `szs_device_scope_free`.
pub struct DeviceScope {
// Opaque pointer produced by one of the `szs_device_scope_init_*` calls.
handle: *mut c_void,
}
impl DeviceScope {
pub fn default() -> Result<Self, Error> {
let mut handle = ptr::null_mut();
let mut error_msg: *const c_char = ptr::null();
let status = unsafe { szs_device_scope_init_default(&mut handle, &mut error_msg) };
match status {
Status::Success => Ok(Self { handle }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn cpu_cores(cpu_cores: usize) -> Result<Self, Error> {
let mut handle = ptr::null_mut();
let mut error_msg: *const c_char = ptr::null();
let status = unsafe { szs_device_scope_init_cpu_cores(cpu_cores, &mut handle, &mut error_msg) };
match status {
Status::Success => Ok(Self { handle }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn gpu_device(gpu_device: usize) -> Result<Self, Error> {
let mut handle = ptr::null_mut();
let mut error_msg: *const c_char = ptr::null();
let status = unsafe { szs_device_scope_init_gpu_device(gpu_device, &mut handle, &mut error_msg) };
match status {
Status::Success => Ok(Self { handle }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn get_capabilities(&self) -> Result<Capability, Error> {
let mut capabilities: Capability = 0;
let mut error_msg: *const c_char = ptr::null();
let status = unsafe { szs_device_scope_get_capabilities(self.handle, &mut capabilities, &mut error_msg) };
match status {
Status::Success => Ok(capabilities),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn get_cpu_cores(&self) -> Result<usize, Error> {
let mut cpu_cores: usize = 0;
let mut error_msg: *const c_char = ptr::null();
let status = unsafe { szs_device_scope_get_cpu_cores(self.handle, &mut cpu_cores, &mut error_msg) };
match status {
Status::Success => Ok(cpu_cores),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn get_gpu_device(&self) -> Result<usize, Error> {
let mut gpu_device: usize = 0;
let mut error_msg: *const c_char = ptr::null();
let status = unsafe { szs_device_scope_get_gpu_device(self.handle, &mut gpu_device, &mut error_msg) };
match status {
Status::Success => Ok(gpu_device),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn is_gpu(&self) -> bool {
self.get_gpu_device().is_ok()
}
}
/// Releases the underlying handle; a null handle is left untouched.
impl Drop for DeviceScope {
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        unsafe { szs_device_scope_free(self.handle) };
    }
}
// NOTE(review): these impls assert that the raw handle may be moved to and
// shared between threads; soundness depends on the C library's thread-safety
// guarantees, which are not visible here - confirm.
unsafe impl Send for DeviceScope {}
unsafe impl Sync for DeviceScope {}
/// C-layout descriptor of a callback-driven sequence of strings, passed by
/// pointer to the `*_sequence` FFI entry points.
#[repr(C)]
struct SzSequence {
// Opaque context handed back to the callbacks; the `to_sz_sequence` impls in
// this file store the address of the first slice element here.
handle: *mut c_void,
// Number of items in the sequence.
count: usize,
// Callback returning the start pointer of the item at the given index.
get_start: extern "C" fn(*mut c_void, usize) -> *const u8,
// Callback returning the byte length of the item at the given index.
get_length: extern "C" fn(*mut c_void, usize) -> usize,
// Left null by the `to_sz_sequence` impls in this file.
starts: *const *const u8,
// Left null by the `to_sz_sequence` impls in this file.
lengths: *const usize,
}
/// C-layout descriptor of a tape with `u32` offsets, passed by pointer to the
/// `*_u32tape` FFI entry points. Filled from a tape's `as_raw_parts()`.
#[repr(C)]
#[derive(Copy, Clone)]
struct SzSequenceU32Tape {
// Raw data pointer of the tape (`data_ptr`).
data: *const u8,
// Raw offsets pointer of the tape (`offsets_ptr`).
offsets: *const u32,
// Number of items on the tape (`items_count`).
count: usize,
}
/// C-layout descriptor of a tape with `u64` offsets, passed by pointer to the
/// `*_u64tape` FFI entry points. Filled from a tape's `as_raw_parts()`.
#[repr(C)]
#[derive(Copy, Clone)]
struct SzSequenceU64Tape {
// Raw data pointer of the tape (`data_ptr`).
data: *const u8,
// Raw offsets pointer of the tape (`offsets_ptr`).
offsets: *const u64,
// Number of items on the tape (`items_count`).
count: usize,
}
/// Generates `impl From<$source> for $descriptor` by copying the three raw
/// parts (`data_ptr`, `offsets_ptr`, `items_count`) of the source tape or view
/// into the C-layout descriptor.
///
/// The bracketed list holds the generic parameters of the impl (empty for the
/// owned-tape references, `'a` for the views).
macro_rules! impl_tape_descriptor_from {
    ($descriptor:ident, [$($generics:tt)*], $source:ty) => {
        impl<$($generics)*> From<$source> for $descriptor {
            fn from(source: $source) -> Self {
                let raw = source.as_raw_parts();
                $descriptor {
                    data: raw.data_ptr,
                    offsets: raw.offsets_ptr,
                    count: raw.items_count,
                }
            }
        }
    };
}

// Borrowed owned tapes.
impl_tape_descriptor_from!(SzSequenceU32Tape, [], &BytesTape<u32, UnifiedAlloc>);
impl_tape_descriptor_from!(SzSequenceU32Tape, [], &CharsTape<u32, UnifiedAlloc>);
impl_tape_descriptor_from!(SzSequenceU64Tape, [], &BytesTape<u64, UnifiedAlloc>);
impl_tape_descriptor_from!(SzSequenceU64Tape, [], &CharsTape<u64, UnifiedAlloc>);

// Byte views, by value and by reference.
impl_tape_descriptor_from!(SzSequenceU32Tape, ['a], BytesTapeView<'a, u32>);
impl_tape_descriptor_from!(SzSequenceU64Tape, ['a], BytesTapeView<'a, u64>);
impl_tape_descriptor_from!(SzSequenceU32Tape, ['a], &BytesTapeView<'a, u32>);
impl_tape_descriptor_from!(SzSequenceU64Tape, ['a], &BytesTapeView<'a, u64>);

// Character views, by value and by reference.
impl_tape_descriptor_from!(SzSequenceU32Tape, ['a], CharsTapeView<'a, u32>);
impl_tape_descriptor_from!(SzSequenceU64Tape, ['a], CharsTapeView<'a, u64>);
impl_tape_descriptor_from!(SzSequenceU32Tape, ['a], &CharsTapeView<'a, u32>);
impl_tape_descriptor_from!(SzSequenceU64Tape, ['a], &CharsTapeView<'a, u64>);
/// Callback for `SzSequence::get_start`: returns the address of the first byte
/// of the `index`-th item, where `handle` points at the first element of a
/// `[T]`.
///
/// The caller must pass a `handle` obtained from a live `[T]` and an `index`
/// within that slice.
extern "C" fn sz_sequence_get_start_generic<T: AsRef<[u8]>>(handle: *mut c_void, index: usize) -> *const u8 {
    let first = handle as *const T;
    // SAFETY: relies on the caller contract above - `first.add(index)` stays
    // inside the slice the handle was taken from.
    let item: &T = unsafe { &*first.add(index) };
    item.as_ref().as_ptr()
}
/// Callback for `SzSequence::get_length`: returns the byte length of the
/// `index`-th item, where `handle` points at the first element of a `[T]`.
///
/// The caller must pass a `handle` obtained from a live `[T]` and an `index`
/// within that slice.
extern "C" fn sz_sequence_get_length_generic<T: AsRef<[u8]>>(handle: *mut c_void, index: usize) -> usize {
    let first = handle as *const T;
    // SAFETY: relies on the caller contract above - `first.add(index)` stays
    // inside the slice the handle was taken from.
    let item: &T = unsafe { &*first.add(index) };
    item.as_ref().len()
}
/// Callback for `SzSequence::get_start` over string-like items: returns the
/// address of the first UTF-8 byte of the `index`-th item, where `handle`
/// points at the first element of a `[T]`.
///
/// The caller must pass a `handle` obtained from a live `[T]` and an `index`
/// within that slice.
extern "C" fn sz_sequence_get_start_str<T: AsRef<str>>(handle: *mut c_void, index: usize) -> *const u8 {
    let first = handle as *const T;
    // SAFETY: relies on the caller contract above - `first.add(index)` stays
    // inside the slice the handle was taken from.
    let item: &T = unsafe { &*first.add(index) };
    let text: &str = item.as_ref();
    text.as_bytes().as_ptr()
}
/// Callback for `SzSequence::get_length` over string-like items: returns the
/// UTF-8 byte length of the `index`-th item, where `handle` points at the
/// first element of a `[T]`.
///
/// The caller must pass a `handle` obtained from a live `[T]` and an `index`
/// within that slice.
extern "C" fn sz_sequence_get_length_str<T: AsRef<str>>(handle: *mut c_void, index: usize) -> usize {
    let first = handle as *const T;
    // SAFETY: relies on the caller contract above - `first.add(index)` stays
    // inside the slice the handle was taken from.
    let item: &T = unsafe { &*first.add(index) };
    let text: &str = item.as_ref();
    text.as_bytes().len()
}
/// Conversion of a slice of byte-like items into the callback-driven
/// [`SzSequence`] descriptor.
trait SzSequenceFromBytes {
    /// Builds a descriptor that borrows `self`; it must not outlive the slice.
    fn to_sz_sequence(&self) -> SzSequence;
}

impl<T: AsRef<[u8]>> SzSequenceFromBytes for [T] {
    fn to_sz_sequence(&self) -> SzSequence {
        // The callbacks receive this pointer back and index into the slice.
        let handle = self.as_ptr() as *mut c_void;
        let count = self.len();
        SzSequence {
            handle,
            count,
            get_start: sz_sequence_get_start_generic::<T>,
            get_length: sz_sequence_get_length_generic::<T>,
            starts: ptr::null(),
            lengths: ptr::null(),
        }
    }
}
/// Conversion of a slice of string-like items into the callback-driven
/// [`SzSequence`] descriptor.
trait SzSequenceFromChars {
    /// Builds a descriptor that borrows `self`; it must not outlive the slice.
    fn to_sz_sequence(&self) -> SzSequence;
}

impl<T: AsRef<str>> SzSequenceFromChars for [T] {
    fn to_sz_sequence(&self) -> SzSequence {
        // The callbacks receive this pointer back and index into the slice.
        let handle = self.as_ptr() as *mut c_void;
        let count = self.len();
        SzSequence {
            handle,
            count,
            get_start: sz_sequence_get_start_str::<T>,
            get_length: sz_sequence_get_length_str::<T>,
            starts: ptr::null(),
            lengths: ptr::null(),
        }
    }
}
// Opaque engine handles exchanged with the C library. All of them are plain
// `*mut c_void` aliases, so the compiler does not distinguish between them.
/// Handle written by `szs_fingerprints_init`.
pub type FingerprintsHandle = *mut c_void;
/// Handle written by `szs_levenshtein_distances_init`.
pub type LevenshteinDistancesHandle = *mut c_void;
/// Handle written by `szs_levenshtein_distances_utf8_init`.
pub type LevenshteinDistancesUtf8Handle = *mut c_void;
/// Handle written by `szs_needleman_wunsch_scores_init`.
pub type NeedlemanWunschScoresHandle = *mut c_void;
/// Handle written by `szs_smith_waterman_scores_init`.
pub type SmithWatermanScoresHandle = *mut c_void;
// Raw declarations of the C entry points used by the safe wrappers in this
// file. Conventions visible from the call sites here:
// - every fallible call returns a `Status` and may store a message pointer in
//   `error_message`, which the wrappers read only on failure;
// - `*_init` calls write the new handle through their out-parameter;
// - `a` / `b` / `texts` point at an `SzSequence`, `SzSequenceU32Tape` or
//   `SzSequenceU64Tape`, matching the function-name suffix;
// - `results_stride` is passed as the size in bytes of one result element.
extern "C" {
// --- Library metadata ---
fn szs_version_major() -> i32;
fn szs_version_minor() -> i32;
fn szs_version_patch() -> i32;
// Capability bitmask; `capabilities()` converts it with `capabilities_from_enum`.
fn szs_capabilities() -> u32;
// --- Device scopes ---
fn szs_device_scope_init_default(scope: *mut *mut c_void, error_message: *mut *const c_char) -> Status;
fn szs_device_scope_init_cpu_cores(
cpu_cores: usize,
scope: *mut *mut c_void,
error_message: *mut *const c_char,
) -> Status;
fn szs_device_scope_init_gpu_device(
gpu_device: usize,
scope: *mut *mut c_void,
error_message: *mut *const c_char,
) -> Status;
fn szs_device_scope_get_capabilities(
scope: *mut c_void,
capabilities: *mut Capability,
error_message: *mut *const c_char,
) -> Status;
fn szs_device_scope_get_cpu_cores(
scope: *mut c_void,
cpu_cores: *mut usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_device_scope_get_gpu_device(
scope: *mut c_void,
gpu_device: *mut usize,
error_message: *mut *const c_char,
) -> Status;
// Releases a scope handle; called from `Drop for DeviceScope`.
fn szs_device_scope_free(scope: *mut c_void);
// --- Levenshtein distances over bytes (results are `usize`) ---
// The wrappers in this file always pass a null `alloc`.
fn szs_levenshtein_distances_init(
match_cost: i8,
mismatch_cost: i8,
open_cost: i8,
extend_cost: i8,
alloc: *const c_void,
capabilities: Capability,
engine: *mut LevenshteinDistancesHandle,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_sequence(
engine: LevenshteinDistancesHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut usize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_u32tape(
engine: LevenshteinDistancesHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut usize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_u64tape(
engine: LevenshteinDistancesHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut usize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_free(engine: LevenshteinDistancesHandle);
// --- Levenshtein distances over UTF-8 text (results are `usize`) ---
// The wrappers in this file always pass a null `alloc`.
fn szs_levenshtein_distances_utf8_init(
match_cost: i8,
mismatch_cost: i8,
open_cost: i8,
extend_cost: i8,
alloc: *const c_void,
capabilities: Capability,
engine: *mut LevenshteinDistancesUtf8Handle,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_utf8_sequence(
engine: LevenshteinDistancesUtf8Handle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut usize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_utf8_u32tape(
engine: LevenshteinDistancesUtf8Handle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut usize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_utf8_u64tape(
engine: LevenshteinDistancesUtf8Handle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut usize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_levenshtein_distances_utf8_free(engine: LevenshteinDistancesUtf8Handle);
// --- Needleman-Wunsch scores (results are `isize`) ---
// `subs` receives the address of a `[[i8; 256]; 256]` substitution matrix.
fn szs_needleman_wunsch_scores_init(
subs: *const i8, open_cost: i8,
extend_cost: i8,
alloc: *const c_void,
capabilities: Capability,
engine: *mut NeedlemanWunschScoresHandle,
error_message: *mut *const c_char,
) -> Status;
fn szs_needleman_wunsch_scores_sequence(
engine: NeedlemanWunschScoresHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut isize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_needleman_wunsch_scores_u32tape(
engine: NeedlemanWunschScoresHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut isize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_needleman_wunsch_scores_u64tape(
engine: NeedlemanWunschScoresHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut isize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_needleman_wunsch_scores_free(engine: NeedlemanWunschScoresHandle);
// --- Smith-Waterman scores (results are `isize`) ---
// `subs` receives the address of a `[[i8; 256]; 256]` substitution matrix.
fn szs_smith_waterman_scores_init(
subs: *const i8, open_cost: i8,
extend_cost: i8,
alloc: *const c_void,
capabilities: Capability,
engine: *mut SmithWatermanScoresHandle,
error_message: *mut *const c_char,
) -> Status;
fn szs_smith_waterman_scores_sequence(
engine: SmithWatermanScoresHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut isize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_smith_waterman_scores_u32tape(
engine: SmithWatermanScoresHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut isize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_smith_waterman_scores_u64tape(
engine: SmithWatermanScoresHandle,
device: *mut c_void,
a: *const c_void, b: *const c_void, results: *mut isize,
results_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_smith_waterman_scores_free(engine: SmithWatermanScoresHandle);
// --- Fingerprints ---
// `window_widths` may be null with a count of 0 (that is what the builder
// passes when no widths were configured).
fn szs_fingerprints_init(
dimensions: usize,
alphabet_size: usize,
window_widths: *const usize,
window_widths_count: usize,
alloc: *const c_void, capabilities: Capability,
engine: *mut FingerprintsHandle,
error_message: *mut *const c_char,
) -> Status;
fn szs_fingerprints_sequence(
engine: FingerprintsHandle,
device: *mut c_void, texts: *const c_void, min_hashes: *mut u32,
min_hashes_stride: usize,
min_counts: *mut u32,
min_counts_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_fingerprints_u32tape(
engine: FingerprintsHandle,
device: *mut c_void, texts: *const c_void, min_hashes: *mut u32,
min_hashes_stride: usize,
min_counts: *mut u32,
min_counts_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_fingerprints_u64tape(
engine: FingerprintsHandle,
device: *mut c_void, texts: *const c_void, min_hashes: *mut u32,
min_hashes_stride: usize,
min_counts: *mut u32,
min_counts_stride: usize,
error_message: *mut *const c_char,
) -> Status;
fn szs_fingerprints_free(engine: FingerprintsHandle);
// --- Unified memory, used by `UnifiedAlloc` ---
// A null return is treated as allocation failure by `UnifiedAlloc::allocate`.
fn szs_unified_alloc(size_bytes: usize) -> *mut c_void;
// Must receive the same size that was requested from `szs_unified_alloc`.
fn szs_unified_free(ptr: *mut c_void, size_bytes: usize);
}
/// Zero-sized allocator that forwards to `szs_unified_alloc` /
/// `szs_unified_free`.
pub struct UnifiedAlloc;

unsafe impl Allocator for UnifiedAlloc {
    /// Allocates `layout.size()` bytes through the C library.
    ///
    /// Zero-sized requests never reach the C side: they yield an empty slice
    /// whose address equals the requested alignment.
    // NOTE(review): `layout.align()` is not forwarded to `szs_unified_alloc`;
    // this presumably relies on the C allocator returning sufficiently aligned
    // memory - confirm.
    fn allocate(&self, layout: Layout) -> Result<core::ptr::NonNull<[u8]>, AllocError> {
        match layout.size() {
            0 => {
                let dangling =
                    core::ptr::NonNull::new(layout.align() as *mut u8).ok_or(AllocError)?;
                Ok(core::ptr::NonNull::slice_from_raw_parts(dangling, 0))
            }
            bytes => {
                let raw = unsafe { szs_unified_alloc(bytes) } as *mut u8;
                // A null return from the C allocator maps to `AllocError`.
                let block = core::ptr::NonNull::new(raw).ok_or(AllocError)?;
                Ok(core::ptr::NonNull::slice_from_raw_parts(block, bytes))
            }
        }
    }

    /// Returns a block to the C library; zero-sized blocks were never
    /// allocated there and are skipped.
    unsafe fn deallocate(&self, ptr: core::ptr::NonNull<u8>, layout: Layout) {
        let bytes = layout.size();
        if bytes == 0 {
            return;
        }
        szs_unified_free(ptr.as_ptr() as *mut c_void, bytes);
    }
}
/// `allocator_api2` vector whose storage comes from [`UnifiedAlloc`]; used for
/// the result buffers returned by the engines' `compute` methods.
pub type UnifiedVec<T> = allocator_api2::vec::Vec<T, UnifiedAlloc>;
/// Returns the version reported by the C library as a `SemVer`.
pub fn version() -> crate::stringzilla::SemVer {
    // Queried in major, minor, patch order.
    let (major, minor, patch) =
        unsafe { (szs_version_major(), szs_version_minor(), szs_version_patch()) };
    crate::stringzilla::SemVer { major, minor, patch }
}
/// Returns the library-wide capability mask from `szs_capabilities`,
/// converted with `capabilities_from_enum`.
pub fn capabilities() -> crate::stringzilla::SmallCString {
    let mask = unsafe { szs_capabilities() };
    crate::stringzilla::capabilities_from_enum(mask)
}
/// Owning wrapper around a byte-level Levenshtein engine handle; the handle is
/// released in `Drop` via `szs_levenshtein_distances_free`.
pub struct LevenshteinDistances {
// Opaque engine pointer written by `szs_levenshtein_distances_init`.
handle: LevenshteinDistancesHandle,
}
impl LevenshteinDistances {
pub fn new(
device: &DeviceScope,
match_cost: i8,
mismatch_cost: i8,
open_cost: i8,
extend_cost: i8,
) -> Result<Self, Error> {
let mut handle = ptr::null_mut();
let capabilities = device.get_capabilities().unwrap_or(0);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_levenshtein_distances_init(
match_cost,
mismatch_cost,
open_cost,
extend_cost,
ptr::null(),
capabilities,
&mut handle,
&mut error_msg,
)
};
match status {
Status::Success => Ok(Self { handle }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn compute<T, S>(
&self,
device: &DeviceScope,
sequences_a: T,
sequences_b: T,
) -> Result<UnifiedVec<usize>, Error>
where
T: AsRef<[S]>,
S: AsRef<[u8]>,
{
let seq_a_slice = sequences_a.as_ref();
let seq_b_slice = sequences_b.as_ref();
let num_pairs = seq_a_slice.len().min(seq_b_slice.len());
let mut results = UnifiedVec::with_capacity_in(num_pairs, UnifiedAlloc);
results.resize(num_pairs, 0);
let results_stride = core::mem::size_of::<usize>();
if device.is_gpu() {
let force_64bit = should_use_64bit_for_bytes(seq_a_slice, seq_b_slice);
let tape_a = copy_bytes_into_tape(seq_a_slice, force_64bit)?;
let tape_b = copy_bytes_into_tape(seq_b_slice, force_64bit)?;
self.compute_into(device, tape_a, tape_b, &mut results[..])?;
Ok(results)
} else {
let seq_a = SzSequenceFromBytes::to_sz_sequence(seq_a_slice);
let seq_b = SzSequenceFromBytes::to_sz_sequence(seq_b_slice);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_levenshtein_distances_sequence(
self.handle,
device.handle,
&seq_a as *const _ as *const c_void,
&seq_b as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
match status {
Status::Success => Ok(results),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
}
pub fn compute_into<'a>(
&self,
device: &DeviceScope,
a: AnyBytesTape<'a>,
b: AnyBytesTape<'a>,
results: &mut [usize],
) -> Result<(), Error> {
let mut error_msg: *const c_char = ptr::null();
let results_stride = core::mem::size_of::<usize>();
let a64 = match &a {
AnyBytesTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyBytesTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
let b64 = match &b {
AnyBytesTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyBytesTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a64, b64) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_levenshtein_distances_u64tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
let a32 = match &a {
AnyBytesTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyBytesTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
let b32 = match &b {
AnyBytesTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyBytesTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a32, b32) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_levenshtein_distances_u32tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
Err(Error::from(SzStatus::UnexpectedDimensions))
}
}
/// Releases the engine handle; a null handle is left untouched.
impl Drop for LevenshteinDistances {
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        unsafe { szs_levenshtein_distances_free(self.handle) };
    }
}
// NOTE(review): these impls assert that the engine handle may be moved to and
// shared between threads; soundness depends on the C library's thread-safety
// guarantees, which are not visible here - confirm.
unsafe impl Send for LevenshteinDistances {}
unsafe impl Sync for LevenshteinDistances {}
/// Owning wrapper around a UTF-8 Levenshtein engine handle; the handle is
/// released in `Drop` via `szs_levenshtein_distances_utf8_free`.
pub struct LevenshteinDistancesUtf8 {
// Opaque engine pointer written by `szs_levenshtein_distances_utf8_init`.
handle: LevenshteinDistancesUtf8Handle,
}
impl LevenshteinDistancesUtf8 {
pub fn new(
device: &DeviceScope,
match_cost: i8,
mismatch_cost: i8,
open_cost: i8,
extend_cost: i8,
) -> Result<Self, Error> {
let mut handle = ptr::null_mut();
let capabilities = device.get_capabilities().unwrap_or(0);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_levenshtein_distances_utf8_init(
match_cost,
mismatch_cost,
open_cost,
extend_cost,
ptr::null(),
capabilities,
&mut handle,
&mut error_msg,
)
};
match status {
Status::Success => Ok(Self { handle }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn compute<T, S>(
&self,
device: &DeviceScope,
sequences_a: T,
sequences_b: T,
) -> Result<UnifiedVec<usize>, Error>
where
T: AsRef<[S]>,
S: AsRef<str>,
{
let seq_a_slice = sequences_a.as_ref();
let seq_b_slice = sequences_b.as_ref();
let num_pairs = seq_a_slice.len().min(seq_b_slice.len());
let mut results = UnifiedVec::with_capacity_in(num_pairs, UnifiedAlloc);
results.resize(num_pairs, 0);
let results_stride = core::mem::size_of::<usize>();
if device.is_gpu() {
let force_64bit = should_use_64bit_for_strings(seq_a_slice, seq_b_slice);
let tape_a = copy_chars_into_tape(seq_a_slice, force_64bit)?;
let tape_b = copy_chars_into_tape(seq_b_slice, force_64bit)?;
self.compute_into(device, tape_a, tape_b, &mut results[..])?;
Ok(results)
} else {
let seq_a = SzSequenceFromChars::to_sz_sequence(seq_a_slice);
let seq_b = SzSequenceFromChars::to_sz_sequence(seq_b_slice);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_levenshtein_distances_utf8_sequence(
self.handle,
device.handle,
&seq_a as *const _ as *const c_void,
&seq_b as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
match status {
Status::Success => Ok(results),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
}
pub fn compute_into<'a>(
&self,
device: &DeviceScope,
a: AnyCharsTape<'a>,
b: AnyCharsTape<'a>,
results: &mut [usize],
) -> Result<(), Error> {
let mut error_msg: *const c_char = ptr::null();
let results_stride = core::mem::size_of::<usize>();
let a64 = match &a {
AnyCharsTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyCharsTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
let b64 = match &b {
AnyCharsTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyCharsTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a64, b64) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_levenshtein_distances_utf8_u64tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
let a32 = match &a {
AnyCharsTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyCharsTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
let b32 = match &b {
AnyCharsTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyCharsTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a32, b32) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_levenshtein_distances_utf8_u32tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
Err(Error::from(SzStatus::UnexpectedDimensions))
}
}
/// Releases the engine handle; a null handle is left untouched.
impl Drop for LevenshteinDistancesUtf8 {
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        unsafe { szs_levenshtein_distances_utf8_free(self.handle) };
    }
}
// NOTE(review): these impls assert that the engine handle may be moved to and
// shared between threads; soundness depends on the C library's thread-safety
// guarantees, which are not visible here - confirm.
unsafe impl Send for LevenshteinDistancesUtf8 {}
unsafe impl Sync for LevenshteinDistancesUtf8 {}
/// Owning wrapper around a Needleman-Wunsch scoring engine handle; the handle
/// is released in `Drop` via `szs_needleman_wunsch_scores_free`.
pub struct NeedlemanWunschScores {
// Opaque engine pointer written by `szs_needleman_wunsch_scores_init`.
handle: NeedlemanWunschScoresHandle,
}
impl NeedlemanWunschScores {
pub fn new(
device: &DeviceScope,
substitution_matrix: &[[i8; 256]; 256],
open_cost: i8,
extend_cost: i8,
) -> Result<Self, Error> {
let mut handle = ptr::null_mut();
let capabilities = device.get_capabilities().unwrap_or(0);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_needleman_wunsch_scores_init(
substitution_matrix.as_ptr() as *const i8,
open_cost,
extend_cost,
ptr::null(),
capabilities,
&mut handle,
&mut error_msg,
)
};
match status {
Status::Success => Ok(Self { handle }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn compute<T, S>(
&self,
device: &DeviceScope,
sequences_a: T,
sequences_b: T,
) -> Result<UnifiedVec<isize>, Error>
where
T: AsRef<[S]>,
S: AsRef<[u8]>,
{
let seq_a_slice = sequences_a.as_ref();
let seq_b_slice = sequences_b.as_ref();
let num_pairs = seq_a_slice.len().min(seq_b_slice.len());
let mut results = UnifiedVec::with_capacity_in(num_pairs, UnifiedAlloc);
results.resize(num_pairs, 0);
let results_stride = core::mem::size_of::<isize>();
if device.is_gpu() {
let force_64bit = should_use_64bit_for_bytes(seq_a_slice, seq_b_slice);
let tape_a = copy_bytes_into_tape(seq_a_slice, force_64bit)?;
let tape_b = copy_bytes_into_tape(seq_b_slice, force_64bit)?;
self.compute_into(device, tape_a, tape_b, &mut results[..])?;
Ok(results)
} else {
let seq_a = SzSequenceFromBytes::to_sz_sequence(seq_a_slice);
let seq_b = SzSequenceFromBytes::to_sz_sequence(seq_b_slice);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_needleman_wunsch_scores_sequence(
self.handle,
device.handle,
&seq_a as *const _ as *const c_void,
&seq_b as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
match status {
Status::Success => Ok(results),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
}
pub fn compute_into<'a>(
&self,
device: &DeviceScope,
a: AnyBytesTape<'a>,
b: AnyBytesTape<'a>,
results: &mut [isize],
) -> Result<(), Error> {
let mut error_msg: *const c_char = ptr::null();
let results_stride = core::mem::size_of::<isize>();
let a64 = match &a {
AnyBytesTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyBytesTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
let b64 = match &b {
AnyBytesTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyBytesTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a64, b64) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_needleman_wunsch_scores_u64tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
let a32 = match &a {
AnyBytesTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyBytesTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
let b32 = match &b {
AnyBytesTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyBytesTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a32, b32) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_needleman_wunsch_scores_u32tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
Err(Error::from(SzStatus::UnexpectedDimensions))
}
}
/// Releases the engine handle; a null handle is left untouched.
impl Drop for NeedlemanWunschScores {
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        unsafe { szs_needleman_wunsch_scores_free(self.handle) };
    }
}
// NOTE(review): these impls assert that the engine handle may be moved to and
// shared between threads; soundness depends on the C library's thread-safety
// guarantees, which are not visible here - confirm.
unsafe impl Send for NeedlemanWunschScores {}
unsafe impl Sync for NeedlemanWunschScores {}
/// Owning wrapper around a Smith-Waterman scoring engine handle; the handle is
/// released in `Drop` via `szs_smith_waterman_scores_free`.
pub struct SmithWatermanScores {
// Opaque engine pointer written by `szs_smith_waterman_scores_init`.
handle: SmithWatermanScoresHandle,
}
impl SmithWatermanScores {
pub fn new(
device: &DeviceScope,
substitution_matrix: &[[i8; 256]; 256],
open_cost: i8,
extend_cost: i8,
) -> Result<Self, Error> {
let mut handle = ptr::null_mut();
let capabilities = device.get_capabilities().unwrap_or(0);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_smith_waterman_scores_init(
substitution_matrix.as_ptr() as *const i8,
open_cost,
extend_cost,
ptr::null(),
capabilities,
&mut handle,
&mut error_msg,
)
};
match status {
Status::Success => Ok(Self { handle }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
pub fn compute<T, S>(
&self,
device: &DeviceScope,
sequences_a: T,
sequences_b: T,
) -> Result<UnifiedVec<isize>, Error>
where
T: AsRef<[S]>,
S: AsRef<[u8]>,
{
let seq_a_slice = sequences_a.as_ref();
let seq_b_slice = sequences_b.as_ref();
let num_pairs = seq_a_slice.len().min(seq_b_slice.len());
let mut results = UnifiedVec::with_capacity_in(num_pairs, UnifiedAlloc);
results.resize(num_pairs, 0);
let results_stride = core::mem::size_of::<isize>();
if device.is_gpu() {
let force_64bit = should_use_64bit_for_bytes(seq_a_slice, seq_b_slice);
let tape_a = copy_bytes_into_tape(seq_a_slice, force_64bit)?;
let tape_b = copy_bytes_into_tape(seq_b_slice, force_64bit)?;
self.compute_into(device, tape_a, tape_b, &mut results[..])?;
Ok(results)
} else {
let seq_a = SzSequenceFromBytes::to_sz_sequence(seq_a_slice);
let seq_b = SzSequenceFromBytes::to_sz_sequence(seq_b_slice);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_smith_waterman_scores_sequence(
self.handle,
device.handle,
&seq_a as *const _ as *const c_void,
&seq_b as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
match status {
Status::Success => Ok(results),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
}
pub fn compute_into<'a>(
&self,
device: &DeviceScope,
a: AnyBytesTape<'a>,
b: AnyBytesTape<'a>,
results: &mut [isize],
) -> Result<(), Error> {
let mut error_msg: *const c_char = ptr::null();
let results_stride = core::mem::size_of::<isize>();
let a64 = match &a {
AnyBytesTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyBytesTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
let b64 = match &b {
AnyBytesTape::Tape64(t) => Some(SzSequenceU64Tape::from(t)),
AnyBytesTape::View64(v) => Some(SzSequenceU64Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a64, b64) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_smith_waterman_scores_u64tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
let a32 = match &a {
AnyBytesTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyBytesTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
let b32 = match &b {
AnyBytesTape::Tape32(t) => Some(SzSequenceU32Tape::from(t)),
AnyBytesTape::View32(v) => Some(SzSequenceU32Tape::from(v)),
_ => None,
};
if let (Some(va), Some(vb)) = (a32, b32) {
let need = core::cmp::min(va.count, vb.count);
if results.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let status = unsafe {
szs_smith_waterman_scores_u32tape(
self.handle,
device.handle,
&va as *const _ as *const c_void,
&vb as *const _ as *const c_void,
results.as_mut_ptr(),
results_stride,
&mut error_msg,
)
};
return match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
};
}
Err(Error::from(SzStatus::UnexpectedDimensions))
}
}
/// Releases the native Smith-Waterman engine when the wrapper goes out of scope.
impl Drop for SmithWatermanScores {
    fn drop(&mut self) {
        // A null handle means there is nothing to release.
        if self.handle.is_null() {
            return;
        }
        unsafe { szs_smith_waterman_scores_free(self.handle) };
    }
}
// NOTE(review): these impls assert that the native engine behind `handle` may be moved
// to, and shared between, threads. Nothing visible here proves that — confirm against
// the native library's threading guarantees.
unsafe impl Send for SmithWatermanScores {}
unsafe impl Sync for SmithWatermanScores {}
/// Builder that collects the settings `build` forwards to `szs_fingerprints_init`.
pub struct FingerprintsBuilder {
/// Alphabet size. `new` starts at `0`; the `binary`/`ascii`/`dna`/`protein` presets
/// set 256/128/4/22. NOTE(review): how the native side treats `0` is not visible here.
alphabet_size: usize,
/// Explicit window widths. When `None`, `build` passes a null pointer and length zero.
window_widths: Option<Vec<usize>>,
/// Dimensions value forwarded to the native initializer; `new` starts at 1024.
dimensions: usize,
}
impl FingerprintsBuilder {
pub fn new() -> Self {
Self {
alphabet_size: 0,
window_widths: None,
dimensions: 1024, }
}
pub fn binary(mut self) -> Self {
self.alphabet_size = 256;
self
}
pub fn ascii(mut self) -> Self {
self.alphabet_size = 128;
self
}
pub fn dna(mut self) -> Self {
self.alphabet_size = 4;
self
}
pub fn protein(mut self) -> Self {
self.alphabet_size = 22;
self
}
pub fn alphabet_size(mut self, size: usize) -> Self {
self.alphabet_size = size;
self
}
pub fn window_widths(mut self, widths: &[usize]) -> Self {
self.window_widths = Some(widths.to_vec());
self
}
pub fn dimensions(mut self, dimensions: usize) -> Self {
self.dimensions = dimensions;
self
}
pub fn build(self, device: &DeviceScope) -> Result<Fingerprints, Error> {
let mut engine: FingerprintsHandle = ptr::null_mut();
let capabilities = device.get_capabilities().unwrap_or(0);
let (widths_ptr, widths_len) = match &self.window_widths {
Some(widths) => (widths.as_ptr(), widths.len()),
None => (ptr::null(), 0),
};
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_fingerprints_init(
self.dimensions,
self.alphabet_size,
widths_ptr,
widths_len,
ptr::null(), capabilities,
&mut engine,
&mut error_msg,
)
};
match status {
Status::Success => Ok(Fingerprints { handle: engine }),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
}
/// Owning wrapper around a native fingerprinting engine created by
/// `FingerprintsBuilder::build`.
pub struct Fingerprints {
/// Engine handle produced by `szs_fingerprints_init`; freed in `Drop` when non-null.
handle: FingerprintsHandle,
}
impl Fingerprints {
pub fn builder() -> FingerprintsBuilder {
FingerprintsBuilder::new()
}
pub fn compute<T, S>(
&self,
device: &DeviceScope,
strings: T,
dimensions: usize,
) -> Result<(UnifiedVec<u32>, UnifiedVec<u32>), Error>
where
T: AsRef<[S]>,
S: AsRef<[u8]>,
{
let strings_slice = strings.as_ref();
let num_strings = strings_slice.len();
let hashes_size = num_strings * dimensions;
let counts_size = num_strings * dimensions;
let mut min_hashes = UnifiedVec::with_capacity_in(hashes_size, UnifiedAlloc);
min_hashes.resize(hashes_size, 0);
let mut min_counts = UnifiedVec::with_capacity_in(counts_size, UnifiedAlloc);
min_counts.resize(counts_size, 0);
let hashes_stride = dimensions * core::mem::size_of::<u32>();
let counts_stride = dimensions * core::mem::size_of::<u32>();
if device.is_gpu() {
let total_size: usize = strings_slice.iter().map(|s| s.as_ref().len()).sum();
let force_64bit = total_size > u32::MAX as usize || strings_slice.len() > u32::MAX as usize;
let tape = copy_bytes_into_tape(strings_slice, force_64bit)?;
self.compute_into(device, tape, dimensions, &mut min_hashes[..], &mut min_counts[..])?;
Ok((min_hashes, min_counts))
} else {
let sequence = SzSequenceFromBytes::to_sz_sequence(strings_slice);
let mut error_msg: *const c_char = ptr::null();
let status = unsafe {
szs_fingerprints_sequence(
self.handle,
device.handle,
&sequence as *const _ as *const c_void,
min_hashes.as_mut_ptr(),
hashes_stride,
min_counts.as_mut_ptr(),
counts_stride,
&mut error_msg,
)
};
match status {
Status::Success => Ok((min_hashes, min_counts)),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
}
pub fn compute_into<'a>(
&self,
device: &DeviceScope,
texts: AnyBytesTape<'a>,
dimensions: usize,
min_hashes: &mut [u32],
min_counts: &mut [u32],
) -> Result<(), Error> {
let mut error_msg: *const c_char = ptr::null();
let count = match &texts {
AnyBytesTape::Tape64(t) => SzSequenceU64Tape::from(t).count,
AnyBytesTape::View64(v) => SzSequenceU64Tape::from(v).count,
AnyBytesTape::Tape32(t) => SzSequenceU32Tape::from(t).count,
AnyBytesTape::View32(v) => SzSequenceU32Tape::from(v).count,
};
let need = count * dimensions;
if min_hashes.len() < need || min_counts.len() < need {
return Err(Error::from(SzStatus::UnexpectedDimensions));
}
let hashes_stride = dimensions * core::mem::size_of::<u32>();
let counts_stride = dimensions * core::mem::size_of::<u32>();
let status = match &texts {
AnyBytesTape::Tape64(t) => {
let v = SzSequenceU64Tape::from(t);
unsafe {
szs_fingerprints_u64tape(
self.handle,
device.handle,
&v as *const _ as *const c_void,
min_hashes.as_mut_ptr(),
hashes_stride,
min_counts.as_mut_ptr(),
counts_stride,
&mut error_msg,
)
}
}
AnyBytesTape::View64(vv) => {
let v = SzSequenceU64Tape::from(vv);
unsafe {
szs_fingerprints_u64tape(
self.handle,
device.handle,
&v as *const _ as *const c_void,
min_hashes.as_mut_ptr(),
hashes_stride,
min_counts.as_mut_ptr(),
counts_stride,
&mut error_msg,
)
}
}
AnyBytesTape::Tape32(t) => {
let v = SzSequenceU32Tape::from(t);
unsafe {
szs_fingerprints_u32tape(
self.handle,
device.handle,
&v as *const _ as *const c_void,
min_hashes.as_mut_ptr(),
hashes_stride,
min_counts.as_mut_ptr(),
counts_stride,
&mut error_msg,
)
}
}
AnyBytesTape::View32(vv) => {
let v = SzSequenceU32Tape::from(vv);
unsafe {
szs_fingerprints_u32tape(
self.handle,
device.handle,
&v as *const _ as *const c_void,
min_hashes.as_mut_ptr(),
hashes_stride,
min_counts.as_mut_ptr(),
counts_stride,
&mut error_msg,
)
}
}
};
match status {
Status::Success => Ok(()),
err => Err(rust_error_from_c_message(err, error_msg)),
}
}
}
/// Releases the native fingerprinting engine when the wrapper goes out of scope.
impl Drop for Fingerprints {
    fn drop(&mut self) {
        // A null handle means there is nothing to release.
        if self.handle.is_null() {
            return;
        }
        unsafe { szs_fingerprints_free(self.handle) };
    }
}
// NOTE(review): these impls assert that the native engine behind `handle` may be moved
// to, and shared between, threads. Nothing visible here proves that — confirm against
// the native library's threading guarantees.
unsafe impl Send for Fingerprints {}
unsafe impl Sync for Fingerprints {}
/// Builds a 256x256 substitution matrix with `match_score` on the main diagonal and
/// `mismatch_score` in every other cell.
pub fn error_costs_256x256_diagonal(match_score: i8, mismatch_score: i8) -> [[i8; 256]; 256] {
    // Start from an all-mismatch matrix, then overwrite the diagonal.
    let mut costs = [[mismatch_score; 256]; 256];
    for (symbol, row) in costs.iter_mut().enumerate() {
        row[symbol] = match_score;
    }
    costs
}
/// Builds the 256x256 matrix that scores a match as `0` and any mismatch as `-1`.
pub fn error_costs_256x256_unary() -> [[i8; 256]; 256] {
    const MATCH: i8 = 0;
    const MISMATCH: i8 = -1;
    error_costs_256x256_diagonal(MATCH, MISMATCH)
}
/// Reports whether either batch needs 64-bit offsets: that is, whether its total byte
/// length or its number of entries exceeds `u32::MAX`.
fn should_use_64bit_for_bytes<T: AsRef<[u8]>>(seq_a: &[T], seq_b: &[T]) -> bool {
    const LIMIT: usize = u32::MAX as usize;
    let exceeds_u32 = |batch: &[T]| {
        let total_bytes: usize = batch.iter().map(|s| s.as_ref().len()).sum();
        total_bytes > LIMIT || batch.len() > LIMIT
    };
    // Both batches are measured before combining, as the sums are independent.
    let a_is_wide = exceeds_u32(seq_a);
    let b_is_wide = exceeds_u32(seq_b);
    a_is_wide || b_is_wide
}
/// Reports whether either batch of strings needs 64-bit offsets: that is, whether its
/// total length in bytes or its number of entries exceeds `u32::MAX`.
fn should_use_64bit_for_strings<T: AsRef<str>>(seq_a: &[T], seq_b: &[T]) -> bool {
    const LIMIT: usize = u32::MAX as usize;
    let exceeds_u32 = |batch: &[T]| {
        let total_bytes: usize = batch.iter().map(|s| s.as_ref().len()).sum();
        total_bytes > LIMIT || batch.len() > LIMIT
    };
    // Both batches are measured before combining, as the sums are independent.
    let a_is_wide = exceeds_u32(seq_a);
    let b_is_wide = exceeds_u32(seq_b);
    a_is_wide || b_is_wide
}
fn copy_bytes_into_tape<'a, T>(sequences: &[T], force_64bit: bool) -> Result<AnyBytesTape<'a>, Error>
where
T: AsRef<[u8]>,
{
let total_size: usize = sequences.iter().map(|s| s.as_ref().len()).sum();
let use_64bit = force_64bit || total_size > u32::MAX as usize || sequences.len() > u32::MAX as usize;
if use_64bit {
let mut tape = BytesTape::<u64, UnifiedAlloc>::new_in(UnifiedAlloc);
tape.extend(sequences).map_err(|_| Error::from(SzStatus::BadAlloc))?;
Ok(AnyBytesTape::Tape64(tape))
} else {
let mut tape = BytesTape::<u32, UnifiedAlloc>::new_in(UnifiedAlloc);
tape.extend(sequences).map_err(|_| Error::from(SzStatus::BadAlloc))?;
Ok(AnyBytesTape::Tape32(tape))
}
}
fn copy_chars_into_tape<'a, T: AsRef<str>>(sequences: &[T], force_64bit: bool) -> Result<AnyCharsTape<'a>, Error> {
let total_size: usize = sequences.iter().map(|s| s.as_ref().len()).sum();
let use_64bit = force_64bit || total_size > u32::MAX as usize || sequences.len() > u32::MAX as usize;
if use_64bit {
let mut tape = CharsTape::<u64, UnifiedAlloc>::new_in(UnifiedAlloc);
tape.extend(sequences).map_err(|_| Error::from(SzStatus::BadAlloc))?;
Ok(AnyCharsTape::Tape64(tape))
} else {
let mut tape = CharsTape::<u32, UnifiedAlloc>::new_in(UnifiedAlloc);
tape.extend(sequences).map_err(|_| Error::from(SzStatus::BadAlloc))?;
Ok(AnyCharsTape::Tape32(tape))
}
}
/// Describes which backend feature this build was compiled with.
///
/// Precedence is `cuda`, then `rocm`, then `cpus`; when none of the three features is
/// enabled, a hint asking to enable one is returned instead.
pub fn backend_info() -> &'static str {
    // Each branch is only reached when every earlier feature is disabled, so no
    // extra `not(...)` guards or trailing fallback are needed.
    if cfg!(feature = "cuda") {
        "CUDA GPU acceleration enabled"
    } else if cfg!(feature = "rocm") {
        "ROCm GPU acceleration enabled"
    } else if cfg!(feature = "cpus") {
        "Multi-threaded CPU backend enabled"
    } else {
        "StringZillas not available - enable cpus, cuda, or rocm feature"
    }
}
#[cfg(test)]
mod tests {
    //! Smoke tests for the device scopes and engines.
    //!
    //! Most tests log and return early when a device or an engine cannot be created,
    //! so they also pass on machines without the corresponding backend.
    use super::*;

    /// The backend description is never empty.
    #[test]
    fn test_backend_info() {
        let info = backend_info();
        assert!(!info.is_empty());
        println!("Backend: {}", info);
    }

    /// Default, CPU and GPU scopes report consistent properties when they can be created.
    #[test]
    fn device_scope_creation() {
        match DeviceScope::default() {
            Ok(device) => {
                let caps = device.get_capabilities();
                println!("Default device capabilities: {:?}", caps);
            }
            Err(e) => println!("Default device creation failed: {:?}", e),
        }
        match DeviceScope::cpu_cores(4) {
            Ok(device) => {
                assert!(!device.is_gpu());
                if let Ok(cores) = device.get_cpu_cores() {
                    assert_eq!(cores, 4);
                }
            }
            Err(e) => println!("CPU device creation failed: {:?}", e),
        }
        match DeviceScope::gpu_device(0) {
            Ok(device) => {
                assert!(device.is_gpu());
                if let Ok(gpu_id) = device.get_gpu_device() {
                    assert_eq!(gpu_id, 0);
                }
            }
            Err(e) => println!("GPU device creation failed (expected if no GPU): {:?}", e),
        }
    }

    /// CPU scopes accept zero, one and several cores.
    #[test]
    fn device_scope_validation() {
        let all_cores = DeviceScope::cpu_cores(0);
        assert!(all_cores.is_ok(), "CPU cores 0 should mean all cores");
        let single_core = DeviceScope::cpu_cores(1);
        assert!(single_core.is_ok(), "Single core should be valid");
        let multi_cores = DeviceScope::cpu_cores(4);
        assert!(multi_cores.is_ok(), "Multiple cores should be valid");
    }

    /// Every builder preset and a custom configuration yield an engine.
    #[test]
    fn fingerprint_builder_configurations() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping fingerprint tests - device initialization failed");
                return;
            }
        };
        let default_engine = Fingerprints::builder().build(&device);
        assert!(default_engine.is_ok(), "Default fingerprint engine should initialize");
        let binary_engine = Fingerprints::builder().binary().dimensions(256).build(&device);
        assert!(binary_engine.is_ok(), "Binary fingerprint engine should initialize");
        let ascii_engine = Fingerprints::builder().ascii().dimensions(256).build(&device);
        assert!(ascii_engine.is_ok(), "ASCII fingerprint engine should initialize");
        let dna_engine = Fingerprints::builder()
            .dna()
            .window_widths(&[3, 5, 7])
            .dimensions(192)
            .build(&device);
        assert!(dna_engine.is_ok(), "DNA fingerprint engine should initialize");
        let protein_engine = Fingerprints::builder()
            .protein()
            .window_widths(&[5, 7])
            .dimensions(128)
            .build(&device);
        assert!(protein_engine.is_ok(), "Protein fingerprint engine should initialize");
        let custom_engine = Fingerprints::builder()
            .alphabet_size(16)
            .window_widths(&[4, 6, 8])
            .dimensions(192)
            .build(&device);
        assert!(custom_engine.is_ok(), "Custom fingerprint engine should initialize");
    }

    /// `compute` returns `strings * dimensions` hashes and counts.
    #[test]
    fn fingerprint_computation() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping fingerprint computation test - device initialization failed");
                return;
            }
        };
        let engine = match Fingerprints::builder().binary().dimensions(64).build(&device) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping fingerprint computation test - engine initialization failed");
                return;
            }
        };
        let test_strings = vec!["hello", "world", "test"];
        let result = engine.compute(&device, &test_strings, 64);
        match result {
            Ok((hashes, counts)) => {
                assert_eq!(hashes.len(), 3 * 64);
                assert_eq!(counts.len(), 3 * 64);
                println!("Fingerprint computation successful");
            }
            Err(e) => println!("Fingerprint computation failed: {:?}", e),
        }
    }

    /// The byte-level Levenshtein engine returns one distance per input pair.
    #[test]
    fn levenshtein_distance_engine() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping Levenshtein test - device initialization failed");
                return;
            }
        };
        let engine = match LevenshteinDistances::new(&device, 0, 1, 1, 1) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping Levenshtein test - engine initialization failed");
                return;
            }
        };
        let strings_a = vec!["kitten", "saturday"];
        let strings_b = vec!["sitting", "sunday"];
        let result = engine.compute(&device, &strings_a, &strings_b);
        match result {
            Ok(distances) => {
                assert_eq!(distances.len(), 2);
                println!("Levenshtein distances: {:?}", distances);
            }
            Err(e) => println!("Levenshtein computation failed: {:?}", e),
        }
    }

    /// The UTF-8 Levenshtein engine returns one distance per input pair.
    #[test]
    fn levenshtein_utf8_engine() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping UTF-8 Levenshtein test - device initialization failed");
                return;
            }
        };
        let engine = match LevenshteinDistancesUtf8::new(&device, 0, 1, 1, 1) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping UTF-8 Levenshtein test - engine initialization failed");
                return;
            }
        };
        let strings_a = vec!["café", "naïve"];
        let strings_b = vec!["cafe", "naive"];
        let result = engine.compute(&device, &strings_a, &strings_b);
        match result {
            Ok(distances) => {
                assert_eq!(distances.len(), 2);
                println!("UTF-8 Levenshtein distances: {:?}", distances);
            }
            Err(e) => println!("UTF-8 Levenshtein computation failed: {:?}", e),
        }
    }

    /// The Needleman-Wunsch engine returns one score per input pair.
    #[test]
    fn needleman_wunsch_engine() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping Needleman-Wunsch test - device initialization failed");
                return;
            }
        };
        let mut matrix = [[-1i8; 256]; 256];
        for (i, row) in matrix.iter_mut().enumerate() {
            row[i] = 2;
        }
        let engine = match NeedlemanWunschScores::new(&device, &matrix, -2, -1) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping Needleman-Wunsch test - engine initialization failed");
                return;
            }
        };
        let sequences_a = vec!["ACGT"];
        let sequences_b = vec!["ACGT"];
        let result = engine.compute(&device, &sequences_a, &sequences_b);
        match result {
            Ok(scores) => {
                assert_eq!(scores.len(), 1);
                println!("Needleman-Wunsch score: {:?}", scores);
            }
            Err(e) => println!("Needleman-Wunsch computation failed: {:?}", e),
        }
    }

    /// The Smith-Waterman engine returns one score per input pair.
    #[test]
    fn smith_waterman_engine() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping Smith-Waterman test - device initialization failed");
                return;
            }
        };
        let mut matrix = [[-1i8; 256]; 256];
        for (i, row) in matrix.iter_mut().enumerate() {
            row[i] = 3;
        }
        let engine = match SmithWatermanScores::new(&device, &matrix, -2, -1) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping Smith-Waterman test - engine initialization failed");
                return;
            }
        };
        let sequences_a = vec!["ACGTACGT"];
        let sequences_b = vec!["ACGT"];
        let result = engine.compute(&device, &sequences_a, &sequences_b);
        match result {
            Ok(scores) => {
                assert_eq!(scores.len(), 1);
                println!("Smith-Waterman score: {:?}", scores);
            }
            Err(e) => println!("Smith-Waterman computation failed: {:?}", e),
        }
    }

    /// `UnifiedAlloc` round-trips a regular and a zero-sized allocation.
    #[test]
    fn unified_allocator() {
        let layout = std::alloc::Layout::from_size_align(1024, 8).unwrap();
        let alloc = UnifiedAlloc;
        let result = alloc.allocate(layout);
        match result {
            Ok(memory) => {
                println!("Unified allocation successful: {} bytes", memory.len());
                unsafe { alloc.deallocate(memory.cast(), layout) };
            }
            Err(_) => println!("Unified allocation failed"),
        }
        let zero_layout = std::alloc::Layout::from_size_align(0, 1).unwrap();
        let zero_result = alloc.allocate(zero_layout);
        match zero_result {
            Ok(memory) => {
                assert_eq!(memory.len(), 0);
                unsafe { alloc.deallocate(memory.cast(), zero_layout) };
            }
            Err(_) => println!("Zero-size allocation failed"),
        }
    }

    /// Invalid device requests surface as errors instead of panicking.
    #[test]
    fn error_handling() {
        let valid_cpu = DeviceScope::cpu_cores(0);
        assert!(valid_cpu.is_ok(), "CPU cores 0 should succeed");
        let invalid_gpu = DeviceScope::gpu_device(999);
        match invalid_gpu {
            Ok(_) => println!("GPU device 999 unexpectedly available"),
            Err(e) => println!("GPU device 999 correctly failed: {:?}", e),
        }
        let default_device = DeviceScope::default();
        match default_device {
            Ok(_) => println!("Default device scope created successfully"),
            Err(e) => println!("Default device scope failed: {:?}", e),
        }
    }

    /// One engine and one device scope can be shared by several threads.
    #[test]
    fn thread_safety() {
        use std::sync::Arc;
        use std::thread;
        let device = match DeviceScope::default() {
            Ok(device) => Arc::new(device),
            Err(_) => {
                println!("Skipping thread safety test - device initialization failed");
                return;
            }
        };
        let engine = match Fingerprints::builder().dimensions(64).build(&device) {
            Ok(engine) => Arc::new(engine),
            Err(_) => {
                println!("Skipping thread safety test - engine initialization failed");
                return;
            }
        };
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let device = Arc::clone(&device);
                let engine = Arc::clone(&engine);
                thread::spawn(move || {
                    let test_data = vec![format!("thread_{}_data", i)];
                    engine.compute(&device, &test_data, 64)
                })
            })
            .collect();
        let mut success_count = 0;
        for handle in handles {
            match handle.join().unwrap() {
                Ok(_) => success_count += 1,
                Err(e) => println!("Thread computation failed: {:?}", e),
            }
        }
        println!("Thread safety test: {}/4 threads succeeded", success_count);
    }

    /// A batch of 1000 strings produces correctly sized outputs.
    #[test]
    fn large_batch_processing() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping large batch test - device initialization failed");
                return;
            }
        };
        let engine = match Fingerprints::builder().dimensions(64).build(&device) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping large batch test - engine initialization failed");
                return;
            }
        };
        let large_batch: Vec<String> = (0..1000).map(|i| format!("test_string_{}", i)).collect();
        let large_batch_refs: Vec<&str> = large_batch.iter().map(|s| s.as_str()).collect();
        let result = engine.compute(&device, &large_batch_refs, 64);
        match result {
            Ok((hashes, counts)) => {
                assert_eq!(hashes.len(), 1000 * 64);
                assert_eq!(counts.len(), 1000 * 64);
                println!("Large batch processing successful: 1000 strings processed");
            }
            Err(e) => println!("Large batch processing failed: {:?}", e),
        }
    }

    /// Identical strings agree on at least as many hash slots as similar ones, and
    /// similar ones on at least as many as unrelated ones.
    #[test]
    fn similarity_estimation() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping similarity test - device initialization failed");
                return;
            }
        };
        let engine = match Fingerprints::builder().dimensions(128).build(&device) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping similarity test - engine initialization failed");
                return;
            }
        };
        let test_strings = vec![
            "the quick brown fox",
            "the quick brown fox",
            "the quick brown dog",
            "completely different",
        ];
        let result = engine.compute(&device, &test_strings, 128);
        match result {
            Ok((hashes, _counts)) => {
                let dimensions: usize = 128;
                // Fraction of hash slots in which string `other` agrees with string 0.
                let similarity_to_first = |other: usize| -> f64 {
                    let base = other * dimensions;
                    let matching = (0..dimensions)
                        .filter(|&i| hashes[i] == hashes[base + i])
                        .count();
                    matching as f64 / dimensions as f64
                };
                let similarity_identical = similarity_to_first(1);
                let similarity_similar = similarity_to_first(2);
                let similarity_different = similarity_to_first(3);
                println!("Similarity identical: {:.3}", similarity_identical);
                println!("Similarity similar: {:.3}", similarity_similar);
                println!("Similarity different: {:.3}", similarity_different);
                assert!(similarity_identical >= similarity_similar);
                assert!(similarity_similar >= similarity_different);
            }
            Err(e) => println!("Similarity estimation failed: {:?}", e),
        }
    }

    /// A matrix from `error_costs_256x256_diagonal` drives Needleman-Wunsch as expected.
    #[test]
    fn error_costs_for_needleman_wunsch() {
        let device = match DeviceScope::default() {
            Ok(device) => device,
            Err(_) => {
                println!("Skipping error_costs test - device initialization failed");
                return;
            }
        };
        let matrix = error_costs_256x256_diagonal(2, -1);
        let engine = match NeedlemanWunschScores::new(&device, &matrix, -2, -1) {
            Ok(engine) => engine,
            Err(_) => {
                println!("Skipping error_costs test - NW engine initialization failed");
                return;
            }
        };
        let seq_a = vec!["ABCD"];
        let seq_b = vec!["ABCD"];
        let result = engine.compute(&device, &seq_a, &seq_b);
        match result {
            Ok(scores) => {
                assert!(scores[0] > 0, "Identical sequences should have positive score");
                println!("Error costs matrix integration test passed: score = {}", scores[0]);
            }
            Err(e) => println!("Error costs test failed: {:?}", e),
        }
    }

    /// `compute_into` with 32-bit byte tapes yields the known distances.
    #[test]
    fn levenshtein_compute_into_u32_bytes() {
        let device = match DeviceScope::default() {
            Ok(d) => d,
            Err(_) => return,
        };
        let engine = match LevenshteinDistances::new(&device, 0, 1, 1, 1) {
            Ok(e) => e,
            Err(_) => return,
        };
        let a = [b"kitten".as_ref(), b"saturday".as_ref()];
        let b = [b"sitting".as_ref(), b"sunday".as_ref()];
        let mut ta = BytesTape::<u32, UnifiedAlloc>::new_in(UnifiedAlloc);
        ta.extend(a).unwrap();
        let mut tb = BytesTape::<u32, UnifiedAlloc>::new_in(UnifiedAlloc);
        tb.extend(b).unwrap();
        let mut results: UnifiedVec<usize> = UnifiedVec::with_capacity_in(2, UnifiedAlloc);
        results.resize(2, 0);
        let res = engine.compute_into(
            &device,
            AnyBytesTape::Tape32(ta),
            AnyBytesTape::Tape32(tb),
            &mut results[..],
        );
        if let Ok(()) = res {
            assert_eq!(&results[..], &[3, 3]);
        }
    }

    /// `compute_into` with 64-bit byte tapes yields the known distances.
    #[test]
    fn levenshtein_compute_into_u64_bytes() {
        let device = match DeviceScope::default() {
            Ok(d) => d,
            Err(_) => return,
        };
        let engine = match LevenshteinDistances::new(&device, 0, 1, 1, 1) {
            Ok(e) => e,
            Err(_) => return,
        };
        let a = [b"abc".as_ref(), b"abcdef".as_ref()];
        let b = [b"yabd".as_ref(), b"abcxef".as_ref()];
        let mut ta = BytesTape::<u64, UnifiedAlloc>::new_in(UnifiedAlloc);
        ta.extend(a).unwrap();
        let mut tb = BytesTape::<u64, UnifiedAlloc>::new_in(UnifiedAlloc);
        tb.extend(b).unwrap();
        let mut results: UnifiedVec<usize> = UnifiedVec::with_capacity_in(2, UnifiedAlloc);
        results.resize(2, 0);
        let res = engine.compute_into(
            &device,
            AnyBytesTape::Tape64(ta),
            AnyBytesTape::Tape64(tb),
            &mut results[..],
        );
        if let Ok(()) = res {
            assert_eq!(&results[..], &[2, 1]);
        }
    }
}