crast 1.0.4

CRAST: Context RNA Alignment Search Tool
Documentation
#! /usr/bin/env python

import os
import utls
from Bio import SeqIO as SqIO
from bx.align import maf
import numpy as nmpy
from math import sqrt
from sklearn.metrics import roc_curve as roc_crv
from sklearn.metrics import auc
import matplotlib.pyplot as pyplt
import seaborn as sbrn

def mn():
  """Compare CRAST, LAST and BLASTN alignments and plot one ROC curve per tool.

  For each program, a "positive" alignment file and a "negative" alignment
  file are loaded from ``ast_dr``. Per the file names, the positive file
  presumably aligns human lncRNAs to mouse ncRNAs and the negative file
  aligns shuffled sequences. Each distinct (target, query) pair is labelled
  with ``utls.is_gn_prdctd``. For each program this prints the
  ``tp fp tn fn`` counts, then the F-measure, and draws a ROC curve using
  each alignment's first tuple element as the score. The figure is saved
  to ``<ast_dr>/imgs/crast_cmprsn_wth_blast_lk_tls.eps``.
  """
  (_, ast_dr, _, _) = utls.gt_drs()
  clr_plt = sbrn.color_palette()
  utls.int_mtpltlb()
  pstv_algn_fls = [
    ast_dr + "/h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
    ast_dr + "/last_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
    ast_dr + "/blastn_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.dat",
  ]
  ngtv_algn_fls = [
    ast_dr + "/shfl_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
    ast_dr + "/last_shfl_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
    ast_dr + "/blastn_shfl_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.dat",
  ]
  prgrms = [
    "CRAST",
    "LAST",
    "BLASTN",
  ]
  rf_sq_fl = ast_dr + "/m_mscls_ncrnas.fa"
  hmlg_rlts = utls.gt_prs_hmlg_rlts(ast_dr + "/h_spns_m_mscls_hmlg_lncrnas.dat")
  img_dr = ast_dr + "/imgs"
  if not os.path.isdir(img_dr):
    os.mkdir(img_dr)
  (fg, ax) = pyplt.subplots()
  for i, (prgrm, pstv_algn_fl, ngtv_algn_fl) in enumerate(zip(prgrms, pstv_algn_fls, ngtv_algn_fls)):
    (pstv_lbls, pstv_scrs) = _lbl_algns(_ld_algns(pstv_algn_fl), hmlg_rlts, rf_sq_fl)
    (ngtv_lbls, ngtv_scrs) = _lbl_algns(_ld_algns(ngtv_algn_fl), hmlg_rlts, rf_sq_fl)
    # Positive-set pairs labelled True are true positives; the rest are false positives.
    tp = float(sum(pstv_lbls))
    fp = len(pstv_lbls) - tp
    # Negative-set pairs labelled True are counted as "fn" and pairs labelled
    # False as "tn" (the original accounting, kept unchanged).
    fn = float(sum(ngtv_lbls))
    tn = len(ngtv_lbls) - fn
    # Single-argument print(...) gives identical output on Python 2 and 3.
    print("%s %s %s %s" % (tp, fp, tn, fn))
    # Undefined ratios (no hits at all) are reported as 0 instead of crashing.
    prcsn = _sf_dv(tp, tp + fp)
    rcl = _sf_dv(tp, tp + fn)
    f_ms = _sf_dv(2 * rcl * prcsn, rcl + prcsn)
    print(f_ms)
    bns = nmpy.array(pstv_lbls + ngtv_lbls, dtype = bool)
    e_vls = nmpy.array(pstv_scrs + ngtv_scrs, dtype = float)
    # NOTE(review): roc_curve treats larger scores as "more positive". If the
    # first tuple element is an e-value (smaller = better, as read by
    # gt_prs_blast_algns), this curve is inverted; confirm, and negate the
    # scores if so.
    fpr, tpr, _ = roc_crv(bns, e_vls, pos_label = True)
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, color = clr_plt[i], label = "%s (%0.2f)" % (prgrm, roc_auc))
  ax.set_xlabel("False positive rate")
  ax.set_ylabel("True positive rate")
  ax.legend(loc = "upper left")
  fg.savefig(img_dr + "/crast_cmprsn_wth_blast_lk_tls.eps")
  pyplt.close(fg)

def _ld_algns(algn_fl):
  """Load alignments, choosing the BLAST tabular parser for "blastn" files."""
  # Test only the file name, so a directory containing "blastn" cannot misroute parsing.
  if "blastn" in os.path.basename(algn_fl):
    return gt_prs_blast_algns(algn_fl)
  return utls.gt_prs_algns(algn_fl)

def _lbl_algns(algns, hmlg_rlts, rf_sq_fl):
  """Label the first alignment of every distinct (target, query) pair.

  ``algns`` holds tuples whose elements are (score, target id, query id).
  Returns ``(lbls, scrs)``: parallel lists in input order. Each label is
  the truthiness of ``utls.is_gn_prdctd`` for that pair, and each score is
  the alignment's first tuple element.
  """
  sn_prs = set()
  lbls = []
  scrs = []
  for algn in algns:
    algn_pr = (algn[1], algn[2])
    if algn_pr in sn_prs:
      continue
    sn_prs.add(algn_pr)
    lbls.append(bool(utls.is_gn_prdctd(hmlg_rlts[algn[2]], algn[1], rf_sq_fl)))
    scrs.append(algn[0])
  return (lbls, scrs)

def _sf_dv(nmr, dnmr):
  """Return ``nmr / dnmr`` as a float, or 0.0 when ``dnmr`` is zero."""
  return float(nmr) / dnmr if dnmr else 0.

def gt_prs_blast_algns(algn_fl):
  """Parse a tab-separated BLAST alignment file into alignment tuples.

  Each data line becomes ``(float(column 11), column 2, column 1)``, i.e.
  ``(float(dt[10]), dt[1], dt[0])``. In BLAST+'s default tabular layout
  (``-outfmt 6`` / ``7``), these are (e-value, subject id, query id). This
  assumes the default column order was used. Blank lines and lines
  starting with ``#`` are skipped.

  Args:
    algn_fl: path of the tabular BLAST output file.

  Returns:
    A list of tuples in file order.

  Raises:
    IndexError: a data line has fewer than 11 tab-separated columns.
    ValueError: column 11 is not numeric.
  """
  import io
  prs_algns = []
  # io.open provides universal newlines on both Python 2 and 3. The original
  # "rU" mode was removed in Python 3.11. On Python 2, io.open yields
  # unicode strings.
  with io.open(algn_fl, "r") as inpt_hndl:
    for ln in inpt_hndl:
      ln = ln.rstrip("\r\n")
      # Skip blank lines and comment/header lines (e.g. -outfmt 7 output).
      if not ln.strip() or ln.startswith("#"):
        continue
      dt = ln.split("\t")
      prs_algns.append((float(dt[10]), dt[1], dt[0]))
  return prs_algns

# Run the benchmark only when this file is executed directly, not on import.
if __name__ == "__main__": mn()