//! Safe Rust wrapper around the affine wavefronts object provided by libwfa.
use std::os::raw::{c_char, c_int};

use crate::{
    bindings::*, mm_allocator::MMAllocator, penalties::AffinePenalties,
};

/// Errors reported by the `AffineWavefronts` wrapper.
///
/// The type is `Copy` and comparable, so callers can match on it or
/// assert on it directly in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WavefrontError {
    /// The pattern or text handed to `align` is longer than the
    /// length the wavefronts were allocated for.
    InputLengthError,
}

impl std::fmt::Display for WavefrontError {
    /// Writes the human-readable description of the error to `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the message first; the exhaustive match makes the
        // compiler flag any variant added later.
        let message = match *self {
            WavefrontError::InputLengthError => {
                "Pattern and text must be no longer than \
                 the lengths used to allocate the wavefront"
            }
        };
        f.write_str(message)
    }
}

// Marks `WavefrontError` as a standard error type. The body is empty,
// so every `std::error::Error` method keeps its default implementation.
impl std::error::Error for WavefrontError {}

/// Safe wrapper over an `affine_wavefronts_t` instance allocated by
/// a libwfa `mm_allocator`.
///
/// The wrapper stores the pattern and text lengths it was constructed
/// with; `align` rejects inputs longer than those lengths.
pub struct AffineWavefronts<'a> {
    // Raw pointer to the libwfa wavefronts object. It is passed to the
    // C functions by the methods below and released in `Drop`.
    ptr: *mut affine_wavefronts_t,
    // This allocator ref is mainly kept to force the wavefronts to be
    // dropped before the allocator is freed. It is also passed to the
    // C pretty printer in `print_cigar`.
    allocator: &'a MMAllocator,
    // Pattern length given to the constructor; upper bound checked by
    // `align`.
    pattern_len: usize,
    // Text length given to the constructor; upper bound checked by
    // `align`.
    text_len: usize,
}

impl<'a> AffineWavefronts<'a> {
    /// Construct a new set of complete wavefronts
    pub fn new_complete(
        pattern_len: usize,
        text_len: usize,
        penalties: &mut AffinePenalties,
        alloc: &'a MMAllocator,
    ) -> Self {
        assert!(pattern_len > 0 && text_len > 0);
        let stats_ptr = std::ptr::null_mut() as *mut wavefronts_stats_t;
        let ptr = unsafe {
            affine_wavefronts_new_complete(
                pattern_len as c_int,
                text_len as c_int,
                penalties.as_ptr(),
                stats_ptr,
                alloc.alloc_ptr(),
            )
        };
        AffineWavefronts {
            ptr,
            allocator: alloc,
            pattern_len,
            text_len,
        }
    }

    /// Construct a new set of reduced wavefronts
    pub fn new_reduced(
        pattern_len: usize,
        text_len: usize,
        penalties: &mut AffinePenalties,
        min_wavefront_len: i32,
        min_dist_threshold: i32,
        alloc: &'a MMAllocator,
    ) -> Self {
        assert!(pattern_len > 0 && text_len > 0);
        let stats_ptr = std::ptr::null_mut() as *mut wavefronts_stats_t;
        let ptr = unsafe {
            affine_wavefronts_new_reduced(
                pattern_len as c_int,
                text_len as c_int,
                penalties.as_ptr(),
                min_wavefront_len as c_int,
                min_dist_threshold as c_int,
                stats_ptr,
                alloc.alloc_ptr(),
            )
        };

        Self {
            ptr,
            allocator: alloc,
            pattern_len,
            text_len,
        }
    }

    /// Clear the wavefronts
    pub fn clear(&mut self) {
        unsafe {
            affine_wavefronts_clear(self.ptr);
        }
    }

    /// Align the given pattern and text string. Callers need to make
    /// sure the byteslices have the correct length compared to the
    /// lengths used to construct thing wavefronts object.
    ///
    /// Does *not* check that `pattern` and `text` are nul-terminated
    /// CStrings, since the C function used takes the string lengths
    /// as arguments.
    pub fn align(
        &mut self,
        pattern: &[u8],
        text: &[u8],
    ) -> Result<(), WavefrontError> {
        if pattern.len() > self.pattern_len || text.len() > self.text_len {
            return Err(WavefrontError::InputLengthError);
        }
        unsafe {
            affine_wavefronts_align(
                self.ptr,
                pattern.as_ptr() as *const c_char,
                pattern.len() as c_int,
                text.as_ptr() as *const c_char,
                text.len() as c_int,
            );
        }
        Ok(())
    }

    fn edit_cigar(&self) -> &edit_cigar_t {
        unsafe {
            let wf_ref = self.ptr.as_ref().unwrap();
            &wf_ref.edit_cigar
        }
    }

    /// Returns the cigar string for the wavefront alignment as a
    /// vector of bytes. Note that each operation is repeated however
    /// many times it applies, i.e. instead of "3M1X" you get "MMMX".
    pub fn cigar_bytes_raw(&self) -> Vec<u8> {
        let slice = unsafe { self.cigar_slice() };
        slice.into()
    }

    pub fn cigar_bytes(&self) -> Vec<u8> {
        let slice = unsafe { self.cigar_slice() };
        if slice.is_empty() {
            Vec::new()
        } else {
            compress_cigar(slice).unwrap()
        }
    }

    /// Returns a slice to the cigar string for the wavefront
    /// alignment. Unsafe as the slice is pointing to the
    /// `edit_cigar_t` managed by libwfa.
    pub unsafe fn cigar_slice(&self) -> &[u8] {
        let cigar = self.edit_cigar();
        let ops_ptr = cigar.operations as *mut u8;
        let start = ops_ptr.offset(cigar.begin_offset as isize);
        let len = (cigar.end_offset - cigar.begin_offset) as usize;
        std::slice::from_raw_parts(start, len)
    }

    /// Returns the alignment score
    pub fn edit_cigar_score(
        &mut self,
        penalties: &mut AffinePenalties,
    ) -> isize {
        let penalties = penalties as *mut AffinePenalties;
        let penalties_ptr: *mut affine_penalties_t = penalties.cast();
        let score = unsafe {
            let wf_ref = self.ptr.as_mut().unwrap();
            let cigar = &mut wf_ref.edit_cigar as *mut edit_cigar_t;
            edit_cigar_score_gap_affine(cigar, penalties_ptr)
        };

        score as isize
    }

    /// Prints the alignment using the C library pretty printer. For
    /// now it only prints to stderr.
    pub fn print_cigar(&mut self, pattern: &[u8], text: &[u8]) {
        unsafe {
            let wf_ref = self.ptr.as_mut().unwrap();
            let cg_mut = &mut wf_ref.edit_cigar as *mut edit_cigar_t;
            edit_cigar_print_pretty(
                stderr,
                pattern.as_ptr() as *const c_char,
                pattern.len() as c_int,
                text.as_ptr() as *const c_char,
                text.len() as c_int,
                cg_mut,
                self.allocator.alloc_ptr(),
            );
        }
    }
}

/// "Compresses" a cigar string produced by WFA (like "MMMXXII") into
/// a regular cigar string with operation counts (e.g. "3M2X2I").
///
/// Each maximal run of identical bytes becomes the decimal run length
/// followed by the operation. Returns `None` if `cigar` is empty.
fn compress_cigar(cigar: &[u8]) -> Option<Vec<u8>> {
    if cigar.is_empty() {
        return None;
    }

    let mut result: Vec<u8> = Vec::new();
    let mut rest = cigar;
    while let Some(&op) = rest.first() {
        // Length of the run of `op` at the front of `rest`. It is at
        // least 1, and being a `usize` it cannot overflow for any slice.
        let run_len = rest.iter().take_while(|&&b| b == op).count();
        // Write straight into the output instead of building a temporary
        // `String` per run. `char::from` keeps the original formatting
        // of the operation byte.
        std::io::Write::write_fmt(
            &mut result,
            format_args!("{}{}", run_len, char::from(op)),
        )
        .expect("writing to a Vec<u8> cannot fail");
        rest = &rest[run_len..];
    }

    Some(result)
}

impl Drop for AffineWavefronts<'_> {
    /// Releases the wrapped wavefronts through libwfa's
    /// `affine_wavefronts_delete`.
    fn drop(&mut self) {
        // Never hand a null pointer to the C destructor; there is
        // nothing to release in that case.
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: `self.ptr` is non-null and this is the only place that
        // releases it, so it is freed exactly once.
        unsafe { affine_wavefronts_delete(self.ptr) }
    }
}