// vcfexpress 0.3.3
//
// A tool for filtering VCF files using Lua expressions.
// Documentation
use mlua::{AnyUserData, Lua, MetaMethod, UserDataFields, UserDataMethods};
use rust_htslib::bcf::header::{Header, HeaderView};
use rust_htslib::bcf::HeaderRecord;
use std::collections::HashMap;
use std::sync::Arc;

fn handle_hash_get<'a>(
    tbl: &'a HashMap<String, String>,
    key: &str,
    func: &str,
) -> Result<&'a str, mlua::Error> {
    match tbl.get(key) {
        Some(x) => Ok(x),
        None => Err(mlua::Error::ExternalError(Arc::new(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "must specify {} in the argument to add_{}. got {:?}",
                key, func, tbl
            ),
        )))),
    }
}

/// Evaluates to `Some(map)` holding a copy of the record's key/value pairs
/// when `$record` is the `HeaderRecord::$record_type` variant and its `ID`
/// value equals `$key`; evaluates to `None` otherwise.
///
/// `$record` is expected to be a `&HeaderRecord` and `$key` a `&str`.
///
/// The expansion is a plain expression: it never `return`s out of the
/// caller, so it can be used in any expression position, and it compares the
/// ID as `&str` rather than allocating a `String` for every record tested.
macro_rules! find_record_match {
    ($record:expr, $key:expr, $record_type:ident) => {
        match $record {
            HeaderRecord::$record_type { values, .. }
                if values.get("ID").map(String::as_str) == Some($key) =>
            {
                Some(
                    values
                        .into_iter()
                        .map(|(k, v)| (k.to_string(), v.to_string()))
                        .collect::<HashMap<String, String>>(),
                )
            }
            _ => None,
        }
    };
}

fn find_record(
    records: &[HeaderRecord],
    key: &str,
    hdr_type: ::libc::c_uint,
) -> Result<HashMap<String, String>, mlua::Error> {
    let hrec = records
        .iter()
        .filter_map(|x| {
            if hdr_type == rust_htslib::htslib::BCF_HL_INFO {
                find_record_match!(x, key, Info)
            } else if hdr_type == rust_htslib::htslib::BCF_HL_FMT {
                find_record_match!(x, key, Format)
            } else {
                None
            }
        })
        .next();

    if let Some(hrec) = hrec {
        return Ok(hrec);
    }
    Err(mlua::Error::ExternalError(Arc::new(std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        format!("key {}, hdr_type:{:?} not found in header", key, hdr_type),
    ))))
}

pub(crate) fn register_header(lua: &Lua) -> mlua::Result<()> {
    lua.register_userdata_type::<HeaderView>(|reg| {
        reg.add_function(
            "info_get",
            |_lua: &Lua, (ud, find_key): (AnyUserData, String)| {
                // get the HREC
                let this = ud.borrow_mut::<HeaderView>()?;
                find_record(
                    &this.header_records(),
                    &find_key,
                    rust_htslib::htslib::BCF_HL_INFO,
                )
            },
        );
        reg.add_function(
            "format_get",
            |_lua: &Lua, (ud, find_key): (AnyUserData, String)| {
                // get the HREC
                let this = ud.borrow_mut::<HeaderView>()?;
                find_record(
                    &this.header_records(),
                    &find_key,
                    rust_htslib::htslib::BCF_HL_FMT,
                )
            },
        );
        reg.add_meta_function(MetaMethod::ToString, |_lua, this: AnyUserData| {
            let this = this.borrow::<HeaderView>()?;

            let mut kstr = rust_htslib::htslib::kstring_t {
                l: 0,
                m: 0,
                s: std::ptr::null_mut(),
            };
            if unsafe { rust_htslib::htslib::bcf_hdr_format(this.inner, 0, &mut kstr) } != 0 {
                return Err(mlua::Error::ExternalError(Arc::new(
                    std::io::Error::last_os_error(),
                )));
            }
            let s = unsafe {
                String::from_utf8_unchecked(
                    std::slice::from_raw_parts(kstr.s as *const u8, kstr.l as usize).to_vec(),
                )
            };

            Ok(s)
        });
        reg.add_field_method_get("samples", |_lua, this: &HeaderView| {
            let samples = this
                .samples()
                .iter()
                .map(|&x| String::from_utf8_lossy(x).to_string())
                .collect::<Vec<_>>();
            Ok(samples)
        });
        reg.add_field_method_set(
            "samples",
            |_lua, this: &mut HeaderView, samples: Vec<String>| {
                let sample_bytes = samples.iter().map(|x| x.as_bytes()).collect::<Vec<_>>();
                match Header::from_template_subset(this, &sample_bytes) {
                    Ok(h) => {
                        //_ = unsafe { rust_htslib::htslib::bcf_hdr_sync(h.inner) };
                        let header_t = unsafe { rust_htslib::htslib::bcf_hdr_dup(h.inner) };
                        *this = HeaderView::new(header_t);
                        eprintln!(
                            "samples from c directly: {:?}",
                            this.samples()
                                .iter()
                                .map(|&x| String::from_utf8_lossy(x).to_string())
                                .collect::<Vec<_>>()
                        );
                        Ok(())
                    }
                    Err(e) => Err(mlua::Error::ExternalError(Arc::new(e))),
                }
            },
        );
        reg.add_function_mut(
            "add_info",
            |_lua, (ud, tbl): (AnyUserData, HashMap<String, String>)| {
                let this = ud.borrow_mut::<HeaderView>()?;
                let c_str = std::ffi::CString::new(format!(
                    r#"##INFO=<ID={},Number={},Type={},Description={}>"#,
                    handle_hash_get(&tbl, "ID", "info")?,
                    handle_hash_get(&tbl, "Number", "info")?,
                    handle_hash_get(&tbl, "Type", "info")?,
                    handle_hash_get(&tbl, "Description", "info")?,
                ))
                .expect("CString::new failed");
                let ret =
                    unsafe { rust_htslib::htslib::bcf_hdr_append(this.inner, c_str.as_ptr()) };
                if ret != 0 {
                    log::error!("Error adding INFO field for {:?}: {}", tbl, ret);
                    return Err(mlua::Error::ExternalError(Arc::new(
                        std::io::Error::last_os_error(),
                    )));
                }
                let ret = unsafe { rust_htslib::htslib::bcf_hdr_sync(this.inner) };
                if ret != 0 {
                    log::warn!(
                        "Error syncing header after adding INFO field for {:?}: {}",
                        tbl,
                        ret
                    );
                }
                Ok(())
            },
        );
        reg.add_function_mut(
            "add_filter",
            |_lua, (ud, tbl): (AnyUserData, HashMap<String, String>)| {
                let this = ud.borrow_mut::<HeaderView>()?;
                let c_str = std::ffi::CString::new(format!(
                    r#"##FILTER=<ID={},Description="{}">"#,
                    handle_hash_get(&tbl, "ID", "filter")?,
                    handle_hash_get(&tbl, "Description", "filter")?,
                ))
                .expect("CString::new failed");
                let ret =
                    unsafe { rust_htslib::htslib::bcf_hdr_append(this.inner, c_str.as_ptr()) };
                if ret != 0 {
                    log::error!("Error adding FILTER field for {:?}: {}", tbl, ret);
                    return Err(mlua::Error::ExternalError(Arc::new(
                        std::io::Error::last_os_error(),
                    )));
                }
                let ret = unsafe { rust_htslib::htslib::bcf_hdr_sync(this.inner) };
                if ret != 0 {
                    log::warn!(
                        "Error syncing header after adding FILTER field for {:?}: {}",
                        tbl,
                        ret
                    );
                }
                Ok(())
            },
        );
        reg.add_function_mut(
            "add_format",
            |_lua, (ud, tbl): (AnyUserData, HashMap<String, String>)| {
                let this = ud.borrow_mut::<HeaderView>()?;
                let c_str = std::ffi::CString::new(format!(
                    r#"##FORMAT=<ID={},Number={},Type={},Description="{}">"#,
                    handle_hash_get(&tbl, "ID", "format")?,
                    handle_hash_get(&tbl, "Number", "format")?,
                    handle_hash_get(&tbl, "Type", "format")?,
                    handle_hash_get(&tbl, "Description", "format")?,
                ))
                .expect("CString::new failed");
                let ret =
                    unsafe { rust_htslib::htslib::bcf_hdr_append(this.inner, c_str.as_ptr()) };
                if ret != 0 {
                    log::error!("Error adding FORMAT field for {:?}: {}", tbl, ret);
                    return Err(mlua::Error::ExternalError(Arc::new(
                        std::io::Error::last_os_error(),
                    )));
                }
                _ = unsafe { rust_htslib::htslib::bcf_hdr_sync(this.inner) };
                Ok(())
            },
        );
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use mlua::Lua;

    /// Creates a Lua state with the header userdata registered, plus a small
    /// header (one contig, two samples) and a `HeaderView` over a duplicate
    /// of it.
    fn setup() -> (Lua, Header, HeaderView) {
        let lua = Lua::new();
        register_header(&lua).unwrap();

        let mut header = Header::new();
        header.push_record(r#"##contig=<ID=chr1,length=10000>"#.as_bytes());
        header.push_sample(b"Sample1");
        header.push_sample(b"Sample2");
        unsafe { rust_htslib::htslib::bcf_hdr_sync(header.inner) };

        let duplicate = unsafe { rust_htslib::htslib::bcf_hdr_dup(header.inner) };
        let header_view = HeaderView::new(duplicate);

        (lua, header, header_view)
    }

    /// Compiles `script` as a chunk named `name`, binds the test header to
    /// the Lua global `header`, runs the chunk and returns its string result.
    ///
    /// Panics with `error in <name>: ...` if compiling or running fails.
    fn run_script(name: &str, script: &str) -> String {
        let (lua, _header, mut header_view) = setup();
        let globals = lua.globals();

        let chunk = lua
            .load(script)
            .set_name(name)
            .into_function()
            .unwrap_or_else(|err| panic!("error in {}: {:?}", name, err));

        lua.scope(|scope| {
            globals.set(
                "header",
                scope.create_any_userdata_ref_mut(&mut header_view)?,
            )?;
            let output: String = chunk.call(())?;
            Ok(output)
        })
        .unwrap_or_else(|err| panic!("error in {}: {:?}", name, err))
    }

    #[test]
    fn test_lua_header_samples() {
        let result = run_script(
            "test_lua_header_samples",
            r#"
            -- TODO: this is broken
            header.samples = {"Sample1"};
            return table.concat(header.samples, ",")
            "#,
        );
        assert_eq!(result, "Sample1");
    }

    #[test]
    fn test_add_info() {
        let result = run_script(
            "test_add_info",
            r#"
            header:add_info({ID="TEST", Number="1", Type="Integer", Description="Test field"});
            return tostring(header)
            "#,
        );
        let expected = "##fileformat=VCFv4.2\n##FILTER=<ID=PASS,Description=\"All filters passed\">\n##contig=<ID=chr1,length=10000>\n##INFO=<ID=TEST,Number=1,Type=Integer,Description=Test field>\n#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSample1\tSample2\n";
        assert_eq!(result, expected);
    }

    #[test]
    fn test_add_filter() {
        let result = run_script(
            "test_add_filter",
            r#"
            header:add_filter({ID="LowQual", Description="Qual less than 50"});
            return tostring(header)
            "#,
        );
        assert!(result.contains("##FILTER=<ID=LowQual"));
    }

    #[test]
    fn test_add_format() {
        let result = run_script(
            "test_add_format",
            r#"
            header:add_format({ID="TEST", Number="1", Type="Integer", Description="Test field"});
            return tostring(header)
            "#,
        );
        let expected = "##fileformat=VCFv4.2\n##FILTER=<ID=PASS,Description=\"All filters passed\">\n##contig=<ID=chr1,length=10000>\n##FORMAT=<ID=TEST,Number=1,Type=Integer,Description=\"Test field\">\n#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSample1\tSample2\n";
        assert_eq!(result, expected);
    }
}