import argparse
import sys
from pathlib import Path
import pyfastx
from tqdm import tqdm
from rustkmer import Database
from gap_filling import (
get_consecutive_N_regions,
reduce_N_positions,
polish_all_gap_regions,
apply_all_polish_results,
get_kmer_count_zero,
get_kmer_list
)
class FastaGapProcessor:
    """Fill internal N gaps of FASTA records using a k-mer database.

    Each record is classified by ``classify_sequence``; only records with at
    least one internal N region are sent through the gap-filling helpers
    imported from ``gap_filling``. Every other record is written to the
    output unchanged. Running counters are kept in ``self.stats``.
    """

    def __init__(self, database_path, kmer_size=19, max_n_per_region=11, top_n=1000):
        """Open the k-mer database and store the processing parameters.

        Args:
            database_path: Path handed to ``rustkmer.Database``.
            kmer_size: k-mer length forwarded to the gap-filling helpers.
            max_n_per_region: Passed as ``max_N`` to ``reduce_N_positions``.
            top_n: Forwarded to ``polish_all_gap_regions``.

        Exits the process with status 1 when the database cannot be opened.
        """
        try:
            self.db = Database(database_path, validate=False)
        except Exception as e:
            print(f"Error loading database {database_path}: {e}")
            sys.exit(1)
        self.kmer_size = kmer_size
        self.max_n_per_region = max_n_per_region
        self.top_n = top_n
        # NOTE: 'total_gaps_filled' accumulates the number of gaps *seen* in
        # processed sequences (it is reported as the total gap count);
        # 'successful_fills' is the number of gaps actually filled.
        self.stats = {
            'total_sequences': 0,
            'processed_sequences': 0,
            'skipped_no_gaps': 0,
            'too_many_end_n': 0,
            'no_n_sequences': 0,
            'total_gaps_filled': 0,
            'successful_fills': 0,
        }

    def classify_sequence(self, sequence):
        """Return the handling category for ``sequence``.

        Returns one of:
            "no_n": the sequence contains no ``N`` (or no N region is reported).
            "too_many_end_n": more than 3 leading or more than 3 trailing Ns.
            "process": at least one N region lies strictly inside the sequence.
            "skip": N regions exist but none of them is internal.
        """
        # Cheap exit first: nothing else matters when there is no N at all.
        if 'N' not in sequence:
            return "no_n"

        total_length = len(sequence)
        leading_n_count = total_length - len(sequence.lstrip('N'))
        trailing_n_count = total_length - len(sequence.rstrip('N'))
        if leading_n_count > 3 or trailing_n_count > 3:
            return "too_many_end_n"

        n_regions = get_consecutive_N_regions(sequence)
        if not n_regions:
            return "no_n"

        # A region is treated as internal when it neither starts at the first
        # base nor ends at/after the last one (per its 'nstart'/'nend' keys).
        has_internal_gap = any(
            region['nstart'] > 0 and region['nend'] < total_length - 1
            for region in n_regions
        )
        return "process" if has_internal_gap else "skip"

    def should_process_sequence(self, sequence):
        """Return True when ``classify_sequence`` yields "process"."""
        return self.classify_sequence(sequence) == "process"

    def fill_gaps_in_sequence(self, sequence):
        """Run the gap-filling pipeline on one sequence.

        Returns:
            A ``(sequence, metadata)`` tuple. ``metadata`` has the keys
            'gaps_processed' (N regions found in the input sequence),
            'gaps_filled' (results whose 'best_result' is not None) and
            'zero_count_kmers' (value of ``get_kmer_count_zero`` on the
            returned sequence). On any exception the original sequence is
            returned with all three counters set to 0.
        """
        try:
            original_gaps = len(get_consecutive_N_regions(sequence))
            reduced_sequence = reduce_N_positions(sequence, max_N=self.max_n_per_region)
            reduced_regions = get_consecutive_N_regions(reduced_sequence)
            if not reduced_regions:
                return sequence, {
                    'gaps_processed': 0,
                    'gaps_filled': 0,
                    'zero_count_kmers': get_kmer_count_zero(sequence, self.kmer_size, self.db),
                }

            all_results = polish_all_gap_regions(
                reduced_sequence,
                reduced_regions,
                self.kmer_size,
                self.db,
                self.top_n,
            )
            filled_sequence = apply_all_polish_results(reduced_sequence, all_results)
            successful_fills = sum(
                1 for result in all_results.values()
                if result['best_result'] is not None
            )
            zero_count = get_kmer_count_zero(filled_sequence, self.kmer_size, self.db)
            return filled_sequence, {
                'gaps_processed': original_gaps,
                'gaps_filled': successful_fills,
                'zero_count_kmers': zero_count,
            }
        except Exception as e:
            # Best effort: report the problem and hand back the input untouched.
            print(f"⚠️ 处理序列时出错: {e}")
            return sequence, {
                'gaps_processed': 0,
                'gaps_filled': 0,
                'zero_count_kmers': 0,
            }

    def _process_record(self, name, sequence):
        """Classify one record, update the stats and return ``(header, sequence)``."""
        seq_type = self.classify_sequence(sequence)
        if seq_type == "process":
            filled_sequence, metadata = self.fill_gaps_in_sequence(sequence)
            self.stats['processed_sequences'] += 1
            self.stats['total_gaps_filled'] += metadata['gaps_processed']
            self.stats['successful_fills'] += metadata['gaps_filled']
            new_header = f"{name} gaps_filled={metadata['gaps_filled']}/{metadata['gaps_processed']} zero_count={metadata['zero_count_kmers']}"
            return new_header, filled_sequence

        # Every other category is passed through unchanged; only the counter differs.
        passthrough_counters = {
            "too_many_end_n": 'too_many_end_n',
            "no_n": 'no_n_sequences',
            "skip": 'skipped_no_gaps',
        }
        self.stats[passthrough_counters[seq_type]] += 1
        return name, sequence

    def process_fasta_file(self, input_path, output_path):
        """Process every record of ``input_path`` and write them to ``output_path``.

        Records are read through ``pyfastx.Fasta``, handled one by one under a
        tqdm progress bar, collected in memory and finally written with
        80 characters per sequence line followed by a blank line per record.
        A record whose handling raises is reported and, when its name and
        sequence were already read, written unchanged so that no input record
        silently disappears from the output. Exits with status 1 when the
        input cannot be read or the output cannot be written.
        """
        print(f"\n📝 单线程版本处理 FASTA 文件: {input_path}")
        print(f"输出将写入: {output_path}")
        try:
            fasta = pyfastx.Fasta(input_path)
            total_sequences = len(fasta)
            self.stats['total_sequences'] = total_sequences
            print(f"📊 在输入文件中找到 {total_sequences} 个序列")
        except Exception as e:
            print(f"❌ 读取 FASTA 文件 {input_path} 时出错: {e}")
            sys.exit(1)

        processed_sequences = []
        error_count = 0
        with tqdm(total=total_sequences, desc="单线程处理序列", unit="seq") as pbar:
            for i in range(total_sequences):
                name = sequence = None
                try:
                    record = fasta[i]  # single index lookup per record
                    name = record.name
                    sequence = str(record.seq)
                    # tqdm.write keeps the progress bar intact, unlike print().
                    tqdm.write(f"Processing sequence: {name}")
                    current_name = name.split()[0][:20]
                    pbar.set_postfix({'current': current_name})
                    output_record = self._process_record(name, sequence)
                except Exception as e:
                    error_count += 1
                    tqdm.write(f"⚠️ 处理序列 {i+1}/{total_sequences} 时出错: {e}")
                    # Keep the original record when it was read successfully.
                    if name is not None and sequence is not None:
                        processed_sequences.append((name, sequence))
                else:
                    processed_sequences.append(output_record)
                finally:
                    pbar.update(1)

        if error_count > 0:
            print(f"⚠️ 总共处理了 {error_count} 个序列时出现错误,但程序继续执行")
        print(f"\n💾 写入 {len(processed_sequences)} 个处理过的序列到 {output_path}")
        try:
            with open(output_path, 'w') as f:
                for header, sequence in processed_sequences:
                    f.write(f">{header}\n")
                    for start in range(0, len(sequence), 80):
                        f.write(f"{sequence[start:start+80]}\n")
                    f.write("\n")
            print("✅ 成功写入输出文件")
        except Exception as e:
            print(f"❌ 写入输出文件时出错: {e}")
            sys.exit(1)
        self.print_statistics()

    def print_statistics(self):
        """Print the counters collected in ``self.stats``.

        The success rate is filled gaps divided by total gaps seen, and is
        printed only when at least one gap was seen.
        """
        stats = self.stats
        print("\n📊 处理统计:")
        print(f" 总序列数: {stats['total_sequences']}")
        print(f" 进行填充处理的序列: {stats['processed_sequences']}")
        print(f" 跳过的序列 (仅末端N≤3): {stats['skipped_no_gaps']}")
        print(f" 跳过序列 (开头/结尾>3个N): {stats['too_many_end_n']}")
        print(f" 跳过序列 (无N): {stats['no_n_sequences']}")
        print(f" 总gap数: {stats['total_gaps_filled']}")
        print(f" 成功填充的gap数: {stats['successful_fills']}")
        # Gaps over gaps: dividing by the sequence count mixed units and
        # could report more than 100%.
        total_gaps = stats['total_gaps_filled']
        if total_gaps > 0:
            success_rate = (stats['successful_fills'] / total_gaps) * 100
            print(f" 成功率: {success_rate:.1f}%")
def main():
    """Command-line entry point: parse arguments and process one FASTA file.

    Exits with status 1 when the input or database path does not exist,
    when processing is interrupted with Ctrl-C, or when processing raises.
    The parent directory of the output path is created if necessary.
    """
    parser = argparse.ArgumentParser(
        description="FASTA Gap Filling Batch Processor (Single Thread Version - FIXED)",
        # The first example must use --database: "--db" is neither a defined
        # option nor a prefix of one, so argparse would reject it.
        epilog="""
使用示例:
%(prog)s --input input.fa --output output.fa --database database.rkdb
%(prog)s -i input.fa -o output.fa -d database.rkdb --max-n 15 --top-n 2000
处理规则:
- 有内部gap的序列: 使用k-mer算法填充gap
- 开头/结尾>3个N的序列: 直接输出原始序列
- 没有N的序列: 直接输出原始序列
- 仅末端N≤3的序列: 直接输出原始序列
修复内容:
- 修复列表索引越界错误
- 增强边界检查逻辑
- 改进序列映射算法
- 添加智能序列过滤机制
- 统一输出所有序列
特性:
- 单线程处理,简单可靠
- 智能序列分类处理
- 适合小文件或调试使用
- 清晰的进度显示
- 完整的统计信息
- 强化的错误处理
""",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='输入 FASTA 文件路径'
    )
    parser.add_argument(
        '--output', '-o',
        required=True,
        help='输出 FASTA 文件路径'
    )
    parser.add_argument(
        '--database', '-d',
        required=True,
        help='k-mer 数据库文件路径'
    )
    parser.add_argument(
        '--kmer-size',
        type=int,
        default=19,
        help='k-mer 大小 (默认: 19)'
    )
    parser.add_argument(
        '--max-n',
        type=int,
        default=11,
        help='每个区域最大 N 数量 (默认: 11)'
    )
    parser.add_argument(
        '--top-n',
        type=int,
        default=1000,
        help='考虑的前 N 个匹配 (默认: 1000)'
    )
    args = parser.parse_args()

    # Validate the paths up front, before the database is opened.
    if not Path(args.input).exists():
        print(f"❌ 输入文件不存在: {args.input}")
        sys.exit(1)
    if not Path(args.database).exists():
        print(f"❌ 数据库文件不存在: {args.database}")
        sys.exit(1)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    processor = FastaGapProcessor(
        database_path=args.database,
        kmer_size=args.kmer_size,
        max_n_per_region=args.max_n,
        top_n=args.top_n
    )
    try:
        processor.process_fasta_file(args.input, args.output)
    except KeyboardInterrupt:
        print("\n⚠️ 用户中断处理")
        sys.exit(1)
    except Exception as e:
        print(f"❌ 处理过程中出错: {e}")
        sys.exit(1)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()