from multiprocessing import Pool, cpu_count

# Optional third-party dependencies. Each one falls back to None and prints a
# warning, so the module can still be imported when the package is missing.
try:
    import pyfastx
except ImportError:
    pyfastx = None
    print("Warning: pyfastx module not available. FASTA processing will not work.")

try:
    from rustkmer import Database
except ImportError:
    Database = None
    print("Warning: rustkmer module not available. Some functions will not work.")

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None
    print("Warning: tqdm module not available. Progress bars will not be shown.")
def get_N_positions(seq):
    """Return the 0-based indices of every uppercase 'N' character in *seq*, in order."""
    positions = []
    for idx, base in enumerate(seq):
        if base == 'N':
            positions.append(idx)
    return positions
def get_consecutive_N_regions(seq):
    """Group the uppercase 'N' bases of *seq* into maximal runs of adjacent positions.

    Returns:
        A list of ``{'nstart': i, 'nend': j}`` dicts with inclusive 0-based
        coordinates, in left-to-right order; an empty list if *seq* has no 'N'.
    """
    regions = []
    run_start = None
    last_n = None
    for idx, base in enumerate(seq):
        if base != 'N':
            continue
        if run_start is not None and idx == last_n + 1:
            # Still inside the current run: just extend it.
            last_n = idx
            continue
        if run_start is not None:
            # A gap separates this N from the previous one: close the old run.
            regions.append({'nstart': run_start, 'nend': last_n})
        run_start = last_n = idx
    if run_start is not None:
        regions.append({'nstart': run_start, 'nend': last_n})
    return regions
def reduce_N_positions(seq, max_N=11):
    """Shorten every run of consecutive uppercase 'N' bases to at most *max_N* Ns.

    Runs of length <= max_N and all non-N bases are kept unchanged.

    Args:
        seq: input sequence string.
        max_N: maximum N-run length to keep (non-negative).

    Returns:
        The reduced sequence as a new string.

    Raises:
        ValueError: if ``max_N`` is negative. The previous slice arithmetic
            (``del seq_list[nstart + max_N : ...]``) silently deleted real bases
            preceding the run in that case.
    """
    if max_N < 0:
        raise ValueError(f"max_N must be non-negative, got {max_N}")
    kept = []
    run_len = 0  # length of the N run ending at the current base
    for base in seq:
        if base == 'N':
            run_len += 1
            if run_len > max_N:
                continue  # drop Ns beyond the first max_N of this run
        else:
            run_len = 0
        kept.append(base)
    return ''.join(kept)
def get_kmer_list(seq, kmerlen):
    """Return every overlapping substring of length *kmerlen* in *seq*, left to right.

    Returns an empty list when *seq* is shorter than *kmerlen*.
    """
    last_start = len(seq) - kmerlen
    return [seq[start:start + kmerlen] for start in range(last_start + 1)]
def get_kmer_count_zero(seq, kmerlen, db):
    """Count the k-mers of *seq* that the database reports with a zero count.

    All overlapping k-mers are sent to ``db.query_exact_batch`` in one call.
    The result is iterated as a mapping: each key is used to index back into
    it, and the entry's ``.count`` is read.

    NOTE(review): if the mapping is keyed by k-mer, a k-mer that occurs several
    times in *seq* is counted only once -- confirm this is intended.
    """
    kmers = [seq[pos:pos + kmerlen] for pos in range(len(seq) - kmerlen + 1)]
    counts_by_key = db.query_exact_batch(kmers)
    return sum(1 for key in counts_by_key if counts_by_key[key].count == 0)
def calculate_query_positions(nstart, nend, kmerlen, seq_length):
    """Choose the k-mer window used to query the database for one N run.

    The window is an inclusive ``[left_start, left_end]`` span of ``kmerlen``
    bases that contains the N run ``[nstart, nend]``. The remaining
    ``kmerlen - nlen`` flanking bases are split around the run, and the extra
    base of an odd split goes to the right. If the centred window would run
    off either end of the sequence, it is shifted inward without changing its
    length.

    Args:
        nstart: 0-based index of the first N in the run.
        nend: 0-based index of the last N in the run (inclusive).
        kmerlen: desired window length.
        seq_length: length of the sequence the indices refer to.

    Returns:
        ``(left_start, left_end)``, inclusive 0-based indices. If the sequence
        is shorter than ``kmerlen``, a warning is printed and the window is
        clamped to the whole sequence, so it is shorter than ``kmerlen``.
    """
    nlen = nend - nstart + 1
    kmer_left = kmerlen - nlen  # flanking bases to distribute around the run
    # Floor division keeps the window exactly kmerlen long even when the run
    # is longer than the k-mer (kmer_left < 0). int(kmer_left / 2) truncates
    # toward zero and gave a window two bases too long for odd negative
    # values. For kmer_left >= 0 both forms are identical.
    half = kmer_left // 2
    left_start = nstart - half
    left_end = nend + half + kmer_left % 2
    # Shift the window inward at the sequence boundaries, keeping its length.
    if left_start < 0:
        left_end -= left_start
        left_start = 0
    elif left_end >= seq_length:
        left_start -= left_end - seq_length + 1
        left_end = seq_length - 1
    # After shifting, the window can only fall outside the sequence when the
    # sequence itself is shorter than kmerlen. Clamp both ends so callers
    # never get indices past the sequence.
    if left_start < 0 or left_end > seq_length - 1:
        clamped_start = max(0, left_start)
        clamped_end = min(left_end, seq_length - 1)
        print(f"警告: 计算的序列长度({clamped_end - clamped_start + 1})不等于kmerlen({kmerlen})")
        print(f"调整前: left_start={left_start}, left_end={left_end}")
        left_start, left_end = clamped_start, clamped_end
        print(f"调整后: left_start={left_start}, left_end={left_end}")
    return left_start, left_end
def process_match(res, befor_polish, nstart, nend, kmerlen, db,
                  gap_offset=None, left_start=None):
    """Fill the N run of a local context window with bases from one match and score it.

    Args:
        res: one fuzzy-query match. Its ``.kmer`` is read; ``.count`` is read if present.
        befor_polish: local sequence window that contains the N run.
        nstart: 0-based start of the N run in the full sequence.
        nend: 0-based inclusive end of the N run in the full sequence.
        kmerlen: k-mer length, used both to locate the fill in the match and
            to score the polished window.
        db: database passed through to ``get_kmer_count_zero``.
        gap_offset: index of the first N inside ``befor_polish``. Defaults to
            19, the left-flank length that callers historically sliced.
        left_start: start, in full-sequence coordinates, of the query window
            that produced ``res``. When given, the fill is taken from offset
            ``nstart - left_start`` of the match, which stays correct when the
            query window was shifted at a sequence boundary. When omitted, a
            centred window is assumed (offset ``int((kmerlen - nlen) / 2)``).

    Returns:
        A dict with the following keys:

        - ``match_kmer``: the match's k-mer.
        - ``fill_sequence``: bases copied into the run.
        - ``polished_sequence``: the filled window.
        - ``count``: ``res.count``, or ``None`` if absent.
        - ``zero_count``: number of ``kmerlen``-mers of the polished window
          with zero count in ``db`` (see ``get_kmer_count_zero``).
    """
    match_kmer = res.kmer
    gap_len = nend - nstart + 1
    if gap_offset is None:
        gap_offset = 19
    if left_start is None:
        fill_start = int((kmerlen - gap_len) / 2)
    else:
        fill_start = nstart - left_start
    fill_sequence = match_kmer[fill_start:fill_start + gap_len]
    polished_seq = list(befor_polish)
    for i in range(gap_len):
        polished_seq[gap_offset + i] = fill_sequence[i]
    polished_sequence = ''.join(polished_seq)
    # Score with the same k as the query (was hard-coded to 19).
    zero_count = get_kmer_count_zero(polished_sequence, kmerlen, db)
    return {
        'match_kmer': match_kmer,
        'fill_sequence': fill_sequence,
        'polished_sequence': polished_sequence,
        'count': getattr(res, 'count', None),
        'zero_count': zero_count,
    }


def _count_polymers(seq):
    """Return ``(runs, longest)`` for homopolymer runs of length >= 2 over A/C/G/T.

    Any other character is skipped and breaks a run. ``longest`` is 0 when
    there is no run of length >= 2.
    """
    polymers = 0
    max_polymer_length = 0
    i = 0
    while i < len(seq):
        current_base = seq[i]
        if current_base not in ('A', 'C', 'G', 'T'):
            i += 1
            continue
        j = i
        while j < len(seq) and seq[j] == current_base:
            j += 1
        if j - i >= 2:
            polymers += 1
            max_polymer_length = max(max_polymer_length, j - i)
        i = j
    return polymers, max_polymer_length


def polish_all_gap_regions(seq_before_gap_fill_reduced, N_regions_reduced, kmerlen, db, top_n=1000):
    """Try to fill every N run using database k-mers and pick the best fill per run.

    For each region, a ``kmerlen`` window around the run is sent to
    ``db.fuzzy_query``. At most ``top_n`` matches (highest ``.count`` first)
    are applied to a local context window via ``process_match``. Among the
    candidates with the fewest zero-count k-mers, the one whose fill has the
    shortest longest-homopolymer wins; the first such candidate wins ties.

    Returns:
        A dict keyed by region index. Each value has the keys ``region``,
        ``best_result`` (``None`` if there was no match), ``all_results``,
        ``total_matches`` and ``selected_matches``.
    """
    seq = seq_before_gap_fill_reduced
    all_results = {}
    for region_idx, region in enumerate(N_regions_reduced):
        print(f"\n处理第 {region_idx+1} 个N区域 (位置 {region['nstart']}-{region['nend']})")
        nstart = region['nstart']
        nend = region['nend']
        left_start, left_end = calculate_query_positions(nstart, nend, kmerlen, len(seq))
        query_kmer = seq[left_start:left_end + 1]
        print(f"查询kmer: {query_kmer}")
        print(f"查询kmer长度: {len(query_kmer)}")
        # Scoring context: kmerlen bases left of the run and kmerlen - 1 to the
        # right (19/18 for k=19, as before). The start is clamped at 0 so a run
        # near the sequence start no longer wraps around via a negative slice.
        flank_start = max(0, nstart - kmerlen)
        befor_polish = seq[flank_start:nend + kmerlen]
        gap_offset = nstart - flank_start
        fuzzy_query_res = db.fuzzy_query(query_kmer, max_variants=9999999999, mutations=0)
        matches = fuzzy_query_res.matches
        if not matches:
            print(f"第 {region_idx+1} 个N区域没有找到匹配,无法进行polish")
            all_results[region_idx] = {
                'region': region,
                'best_result': None,
                'all_results': [],
                'total_matches': 0,
                'selected_matches': 0,
            }
            continue
        print(f"找到 {len(matches)} 个匹配")
        if len(matches) <= top_n:
            selected_matches = matches
            print(f"匹配数量少于{top_n},使用全部 {len(selected_matches)} 个匹配进行处理")
        else:
            selected_matches = sorted(matches, key=lambda m: m.count, reverse=True)[:top_n]
            print(f"选择count最高的前 {len(selected_matches)} 个匹配进行处理")
        print("📝 使用单线程处理匹配结果")
        # tqdm is optional (see the import guard); fall back to a plain loop.
        if tqdm is None:
            match_iter = selected_matches
        else:
            match_iter = tqdm(selected_matches, desc="Processing matches")
        # Pass the real window start so the fill is read from the run's actual
        # position in the match, even when the window was boundary-shifted.
        polish_results = [
            process_match(res, befor_polish, nstart, nend, kmerlen, db,
                          gap_offset=gap_offset, left_start=left_start)
            for res in match_iter
        ]
        min_zero_count = min(r['zero_count'] for r in polish_results)
        min_zero_count_results = [r for r in polish_results if r['zero_count'] == min_zero_count]
        best_result = min(min_zero_count_results, key=lambda r: _count_polymers(r['fill_sequence'])[1])
        polymers_count, max_polymer_length = _count_polymers(best_result['fill_sequence'])
        print(f"\n第 {region_idx+1} 个N区域的最佳结果:")
        print(f" 匹配kmer: {best_result['match_kmer']}")
        print(f" 填充序列: {best_result['fill_sequence']}")
        print(f" 零计数kmer数量: {best_result['zero_count']}")
        print(f" 填充序列多聚体数量: {polymers_count}")
        print(f" 最长多聚体长度: {max_polymer_length}")
        all_results[region_idx] = {
            'region': region,
            'best_result': best_result,
            'all_results': polish_results,
            'total_matches': len(matches),
            'selected_matches': len(selected_matches),
        }
    return all_results
def apply_all_polish_results(seq_before_gap_fill_reduced, all_results):
    """Write each region's best fill into the sequence and return the new sequence.

    Regions are processed from the highest ``nstart`` downward. That way a fill
    whose length differs from its N run (non-zero ``adjustment_level``) does
    not shift the coordinates of regions still to be processed. Regions with
    no ``best_result`` are skipped with a message.

    Args:
        seq_before_gap_fill_reduced: sequence whose N runs the results refer to.
        all_results: mapping of region index to a dict with the keys
            ``region`` (with ``nstart``/``nend``) and ``best_result`` (a dict
            with ``fill_sequence`` and an optional ``adjustment_level``, or
            ``None``).

    Returns:
        The sequence with every available fill spliced in.
    """
    seq_list = list(seq_before_gap_fill_reduced)
    sorted_regions = sorted(all_results.items(), key=lambda item: item[1]['region']['nstart'], reverse=True)
    for region_idx, result_data in sorted_regions:
        region = result_data['region']
        best_result = result_data['best_result']
        if best_result is None:
            print(f"跳过N区域 {region_idx+1} (位置 {region['nstart']}-{region['nend']}): 没有找到匹配")
            continue
        nstart = region['nstart']
        nend = region['nend']
        original_n_len = nend - nstart + 1
        adjustment_level = best_result.get('adjustment_level', 0)
        fill_sequence = best_result['fill_sequence']
        filled_n_len = max(0, original_n_len + adjustment_level)
        if len(fill_sequence) != filled_n_len:
            print(f"警告: 调整后的N区域长度({filled_n_len})与填充序列长度({len(fill_sequence)})不匹配")
            if len(fill_sequence) > filled_n_len:
                fill_sequence = fill_sequence[:filled_n_len]
            else:
                fill_sequence = fill_sequence + 'N' * (filled_n_len - len(fill_sequence))
        # Replace exactly the N run. The old in-place overwrite clobbered real
        # bases after the run when adjustment_level > 0, and left stray Ns
        # when it was < 0. For adjustment_level == 0 the result is identical.
        seq_list[nstart:nend + 1] = fill_sequence
        print(f"已应用N区域 {region_idx+1} (位置 {region['nstart']}-{region['nend']}) 的polish结果")
    return ''.join(seq_list)
DEFAULT_DB_PATH = "/Users/forrest/Data/data/kmer/K19/R1_001.rkdb"
DEFAULT_SEQUENCE = "CGCCGCCCCGGCCCCCGCGCCGCGNNNNNNNGCCACCGCCGCCCGCGCCGCCCGCGCTCGCGCGCACTGTCGCCCGNNNNNNNNNNNGGCCGGCCGTCCGCCCGCGCGCCCGCCGCCCGCNNNNNNNNNNNNNNNNNNNNNNNNCGCCCGCGCGACGCCGACGCCGCACGGCCGCCGNNNNNNNNNNNNNNNNNNNNNNNNNNCCGCGCCACGTCGCCGTGTTCCCNNCGCGC"


def _print_regions(regions):
    """Print one line per N region: 1-based index, inclusive coordinates and length."""
    for i, region in enumerate(regions):
        print(f"Region {i+1}: positions {region['nstart']} to {region['nend']} (length: {region['nend']-region['nstart']+1})")


def main(db_path=DEFAULT_DB_PATH, kmerlen=19, seq_before_gap_fill=DEFAULT_SEQUENCE, top_n=5000):
    """Run the gap-polishing pipeline on one sequence and return the polished result.

    Steps:

    1. Long N runs are shortened (``reduce_N_positions``, default cap).
    2. Each run is filled from the database (``polish_all_gap_regions``).
    3. The fills are spliced in (``apply_all_polish_results``).
    4. Any N runs that remain are reported.

    Args:
        db_path: path of the rustkmer database to open.
        kmerlen: k-mer length; presumably must match the database's k
            (the default path suggests K19). TODO confirm.
        seq_before_gap_fill: sequence containing N gaps.
        top_n: maximum number of matches evaluated per region.

    Raises:
        RuntimeError: if the optional ``rustkmer`` package is not installed.
    """
    if Database is None:
        raise RuntimeError("rustkmer module not available; cannot open the k-mer database")
    db = Database(db_path)
    print("原始序列:")
    print(seq_before_gap_fill)
    N_regions = get_consecutive_N_regions(seq_before_gap_fill)
    print(f"Found {len(N_regions)} N regions:")
    _print_regions(N_regions)
    seq_before_gap_fill_reduced = reduce_N_positions(seq_before_gap_fill)
    print("\n减少N区域后的序列:")
    print(seq_before_gap_fill_reduced)
    N_regions_reduced = get_consecutive_N_regions(seq_before_gap_fill_reduced)
    print(f"Found {len(N_regions_reduced)} N regions:")
    _print_regions(N_regions_reduced)
    all_results = polish_all_gap_regions(seq_before_gap_fill_reduced, N_regions_reduced, kmerlen, db, top_n)
    fully_polished_sequence = apply_all_polish_results(seq_before_gap_fill_reduced, all_results)
    print("\n最终polish后的序列:")
    print(fully_polished_sequence)
    final_N_regions = get_consecutive_N_regions(fully_polished_sequence)
    if final_N_regions:
        print(f"\n警告: 最终序列中仍有 {len(final_N_regions)} 个N区域:")
        _print_regions(final_N_regions)
    else:
        print("\n成功: 最终序列中没有N区域")
    return fully_polished_sequence
# Run the polishing pipeline only when this file is executed as a script.
if "__main__" == __name__:
    main()