import os
import utls
from Bio import SeqIO as SqIO
from bx.align import maf
import numpy as nmpy
from math import sqrt
from sklearn.metrics import roc_curve as roc_crv
from sklearn.metrics import auc
import matplotlib.pyplot as pyplt
import seaborn as sbrn
def mn():
    """Compare CRAST, LAST and BLASTN alignment results and plot their ROC curves.

    For each program, one "positive" alignment file and its "shfl_"
    counterpart (presumably alignments of shuffled sequences -- confirm) are
    parsed. Every distinct (field 1, field 2) alignment pair is labelled with
    ``utls.is_gn_prdctd``. The tp/fp/tn/fn counts and the F-measure are
    printed, and one ROC curve per program is drawn into
    ``<ast_dr>/imgs/crast_cmprsn_wth_blast_lk_tls.eps``.
    """
    (_, ast_dr, _, _) = utls.gt_drs()
    clr_plt = sbrn.color_palette()
    utls.int_mtpltlb()
    pstv_algn_fls = [
        ast_dr + "/h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
        ast_dr + "/last_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
        ast_dr + "/blastn_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.dat",
    ]
    ngtv_algn_fls = [
        ast_dr + "/shfl_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
        ast_dr + "/last_shfl_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.maf",
        ast_dr + "/blastn_shfl_h_spns_m_mscls_hmlg_lncrnas_2_m_mscls_ncrnas.dat",
    ]
    prgrms = [
        "CRAST",
        "LAST",
        "BLASTN",
    ]
    rf_sq_fl = ast_dr + "/m_mscls_ncrnas.fa"
    hmlg_rlts = utls.gt_prs_hmlg_rlts(ast_dr + "/h_spns_m_mscls_hmlg_lncrnas.dat")
    img_dr = ast_dr + "/imgs"
    if not os.path.isdir(img_dr):
        os.mkdir(img_dr)
    (fg, ax) = pyplt.subplots()
    for i, (pstv_algn_fl, ngtv_algn_fl, prgrm) in enumerate(
        zip(pstv_algn_fls, ngtv_algn_fls, prgrms)
    ):
        pstv_lbls, pstv_scrs = _lbl_unq_algns(_ld_algns(pstv_algn_fl), hmlg_rlts, rf_sq_fl)
        ngtv_lbls, ngtv_scrs = _lbl_unq_algns(_ld_algns(ngtv_algn_fl), hmlg_rlts, rf_sq_fl)
        # In both sets a label is True when the pair is predicted. Positive-set
        # pairs count as tp (predicted) or fp (not predicted). Negative-set
        # pairs count as fn (predicted) or tn (not predicted).
        tp = sum(pstv_lbls)
        fp = len(pstv_lbls) - tp
        fn = sum(ngtv_lbls)
        tn = len(ngtv_lbls) - fn
        # Single-argument print() behaves the same under Python 2 and 3.
        print("%d %d %d %d" % (tp, fp, tn, fn))
        prcsn = _sf_dv(tp, tp + fp)
        rcl = _sf_dv(tp, tp + fn)
        f_ms = _sf_dv(2 * rcl * prcsn, rcl + prcsn)
        print(f_ms)
        # Build the arrays once; nmpy.append in a loop copies the array on every call.
        bns = nmpy.array(pstv_lbls + ngtv_lbls, dtype=bool)
        e_vls = nmpy.array(pstv_scrs + ngtv_scrs, dtype=float)
        # NOTE(review): roc_curve ranks larger scores as more positive.
        # gt_prs_blast_algns stores column 11 of the BLAST table here, which
        # is the E-value (smaller is better) in the standard tabular layout.
        # That would invert the BLASTN curve. The direction of the
        # utls.gt_prs_algns scores is not visible here -- confirm before
        # comparing AUCs.
        fpr, tpr, _ = roc_crv(bns, e_vls, pos_label=True)
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, color=clr_plt[i], label="%s (%0.2f)" % (prgrm, roc_auc))
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.legend(loc="upper left")
    fg.savefig(img_dr + "/crast_cmprsn_wth_blast_lk_tls.eps")
    # Release the figure's memory once it has been written.
    pyplt.close(fg)


def _ld_algns(algn_fl):
    """Parse algn_fl with the BLAST tabular parser if its file name contains "blastn", else with utls."""
    # Test only the file name, so a directory that happens to contain
    # "blastn" does not send MAF files to the BLAST parser.
    if "blastn" in os.path.basename(algn_fl):
        return gt_prs_blast_algns(algn_fl)
    return utls.gt_prs_algns(algn_fl)


def _lbl_unq_algns(algns, hmlg_rlts, rf_sq_fl):
    """Label each distinct (algn[1], algn[2]) pair once, keeping its first alignment.

    Args:
        algns: sequence of alignment tuples whose element 0 is the score,
            element 1 the id passed to utls.is_gn_prdctd, and element 2 the
            key into hmlg_rlts.
        hmlg_rlts: mapping returned by utls.gt_prs_hmlg_rlts.
        rf_sq_fl: reference FASTA path forwarded to utls.is_gn_prdctd.

    Returns:
        (labels, scores): parallel lists. Each label is the truthiness of
        utls.is_gn_prdctd; each score is element 0 of the first alignment
        seen for that pair.
    """
    seen = set()
    lbls = []
    scrs = []
    for algn in algns:
        algn_pr = (algn[1], algn[2])
        if algn_pr in seen:
            continue
        seen.add(algn_pr)
        lbls.append(bool(utls.is_gn_prdctd(hmlg_rlts[algn[2]], algn[1], rf_sq_fl)))
        scrs.append(algn[0])
    return lbls, scrs


def _sf_dv(nmrtr, dnmntr):
    """Return nmrtr / dnmntr as a float, or 0.0 when dnmntr is zero."""
    return float(nmrtr) / dnmntr if dnmntr else 0.0
def gt_prs_blast_algns(algn_fl):
    """Parse a tab-separated BLAST result file into alignment tuples.

    Each data line is split on tabs and turned into
    ``(float(fields[10]), fields[1], fields[0])``. For standard ``-outfmt 6``
    output this is presumably (E-value, subject id, query id) -- confirm
    against how the file was generated. Blank lines and lines starting with
    "#" are skipped.

    Args:
        algn_fl: path to the tab-separated BLAST output.

    Returns:
        list of (float, str, str) tuples, in file order.

    Raises:
        ValueError: if a data line has fewer than 11 tab-separated fields,
            or if field 11 is not a number.
    """
    prs_algns = []
    # Plain "r": the "U" flag is deprecated and rejected by Python >= 3.11.
    # Line endings are stripped explicitly below instead.
    with open(algn_fl, "r") as inpt_hndl:
        for ln_no, ln in enumerate(inpt_hndl, 1):
            ln = ln.rstrip("\r\n")
            # Skip blank lines (e.g. a trailing newline) and "#" header lines
            # such as those in BLAST's commented tabular output.
            if not ln.strip() or ln.startswith("#"):
                continue
            dt = ln.split("\t")
            if len(dt) < 11:
                raise ValueError(
                    "%s:%d: expected at least 11 tab-separated fields, got %d"
                    % (algn_fl, ln_no, len(dt))
                )
            prs_algns.append((float(dt[10]), dt[1], dt[0]))
    return prs_algns
if __name__ == "__main__":
    # Script entry point: run the comparison and write the ROC figure only
    # when executed directly, not when imported.
    mn()