1use std::{collections::HashMap, fs::File, io, io::BufReader};
2
3use gen_models::{
4 block_group::BlockGroup,
5 db::GraphConnection,
6 path::{Annotation, Path},
7 sample::Sample,
8};
9use noodles::{core::Position, gff};
10
11pub fn gff_attribute_value_to_string(
12 attrs: &gff::feature::record_buf::Attributes,
13 key: &str,
14) -> Option<String> {
15 let key_bytes = key.as_bytes();
16 attrs.as_ref().iter().find_map(|(tag, value)| {
17 let tag_bytes: &[u8] = tag.as_ref();
18 if !tag_bytes.eq_ignore_ascii_case(key_bytes) {
19 return None;
20 }
21 if let Some(value) = value.as_string() {
22 Some(String::from_utf8_lossy(value.as_ref()).to_string())
23 } else {
24 value
25 .iter()
26 .next()
27 .map(|item| String::from_utf8_lossy(item.as_ref()).to_string())
28 }
29 })
30}
31
32pub fn propagate_gff(
33 conn: &GraphConnection,
34 collection_name: &str,
35 from_sample_name: Option<&str>,
36 to_sample_name: &str,
37 gff_input_filename: &str,
38 gff_output_filename: &str,
39) -> io::Result<()> {
40 let mut reader = File::open(gff_input_filename)
41 .map(BufReader::new)
42 .map(gff::io::Reader::new)?;
43
44 let output_file = File::create(gff_output_filename).unwrap();
45 let mut writer = gff::io::Writer::new(output_file);
46
47 let source_block_groups = Sample::get_block_groups(conn, collection_name, from_sample_name);
48 let target_block_groups = Sample::get_block_groups(conn, collection_name, Some(to_sample_name));
49 let source_paths_by_bg_name = source_block_groups
50 .iter()
51 .map(|bg| (bg.name.clone(), BlockGroup::get_current_path(conn, &bg.id)))
52 .collect::<HashMap<String, Path>>();
53 let target_paths_by_bg_name = target_block_groups
54 .iter()
55 .map(|bg| (bg.name.clone(), BlockGroup::get_current_path(conn, &bg.id)))
56 .collect::<HashMap<String, Path>>();
57
58 let mut path_mappings_by_bg_name = HashMap::new();
59 for (name, target_path) in &target_paths_by_bg_name {
60 let source_path = source_paths_by_bg_name.get(name).unwrap();
61 let mapping = source_path.get_mapping_tree(conn, target_path);
62 path_mappings_by_bg_name.insert(name, mapping);
63 }
64
65 let sequence_lengths_by_path_name = target_paths_by_bg_name
66 .iter()
67 .map(|(name, path)| (name.clone(), path.sequence(conn).len() as i64))
68 .collect::<HashMap<String, i64>>();
69
70 for result in reader.record_bufs() {
71 let record = result?;
72 let path_name = record.reference_sequence_name().to_string();
73 let annotation = Annotation {
74 name: "".to_string(),
75 start: record.start().get() as i64,
76 end: record.end().get() as i64,
77 };
78 let mapping_tree = path_mappings_by_bg_name.get(&path_name).unwrap();
79 let sequence_length = sequence_lengths_by_path_name.get(&path_name).unwrap();
80 let propagated_annotation =
81 Path::propagate_annotation(annotation, mapping_tree, *sequence_length).unwrap();
82
83 let score = record.score();
84 let phase = record.phase();
85 let mut updated_record_builder = gff::feature::RecordBuf::builder()
86 .set_reference_sequence_name(path_name)
87 .set_source(record.source().to_string())
88 .set_type(record.ty().to_string())
89 .set_start(
90 Position::new(propagated_annotation.start.try_into().unwrap())
91 .expect("Could not convert start ({start}) to usize for propagation"),
92 )
93 .set_end(
94 Position::new(propagated_annotation.end.try_into().unwrap())
95 .expect("Could not convert end ({end}) to usize for propagation"),
96 )
97 .set_strand(record.strand())
98 .set_attributes(record.attributes().clone());
99
100 if let Some(score) = score {
101 updated_record_builder = updated_record_builder.set_score(score);
102 }
103 if let Some(phase) = phase {
104 updated_record_builder = updated_record_builder.set_phase(phase);
105 }
106
107 writer.write_record(&updated_record_builder.build())?;
108 }
109
110 Ok(())
111}
112
#[cfg(test)]
mod tests {
    use std::{fs::File, io::BufReader, path::PathBuf};

    use gen_core::{
        HashId, NO_CHROMOSOME_INDEX, PATH_END_NODE_ID, PATH_START_NODE_ID, PathBlock, Strand,
    };
    use gen_models::{
        block_group::{BlockGroup, PathChange},
        block_group_edge::{BlockGroupEdge, BlockGroupEdgeData},
        collection::Collection,
        db::GraphConnection,
        edge::Edge,
        node::Node,
        path::Path,
        sample::Sample,
        sequence::Sequence,
        traits::Query,
    };
    use noodles::gff;
    use tempfile::tempdir;

    use super::propagate_gff;
    use crate::test_helpers::get_connection;

    /// Builds the base fixture: collection `test` with a block group `m123` whose path `m123`
    /// runs from the path start node, through one node holding the reference sequence, to the
    /// path end node.
    fn create_block_group(conn: &GraphConnection) {
        let collection = Collection::create(conn, "test");
        let sequence = "ATCGATCGATCGATCGATCGGGAACACACAGAGA";
        let reference_sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence(sequence)
            .save(conn);
        let node_id = Node::create(
            conn,
            &reference_sequence.hash,
            &HashId::convert_str(&format!(
                "{collection}.m123:{hash}",
                collection = collection.name,
                hash = reference_sequence.hash
            )),
        );
        let block_group = BlockGroup::create(conn, &collection.name, None, "m123");

        // One edge entering the node at offset 0 and one leaving it at the end of the sequence.
        let edge_into = Edge::create(
            conn,
            PATH_START_NODE_ID,
            0,
            Strand::Forward,
            node_id,
            0,
            Strand::Forward,
        );
        let edge_out_of = Edge::create(
            conn,
            node_id,
            reference_sequence.length,
            Strand::Forward,
            PATH_END_NODE_ID,
            0,
            Strand::Forward,
        );

        let new_block_group_edges = vec![
            BlockGroupEdgeData {
                block_group_id: block_group.id,
                edge_id: edge_into.id,
                chromosome_index: 0,
                phased: 0,
            },
            BlockGroupEdgeData {
                block_group_id: block_group.id,
                edge_id: edge_out_of.id,
                chromosome_index: 0,
                phased: 0,
            },
        ];

        BlockGroupEdge::bulk_create(conn, &new_block_group_edges);
        Path::create(
            conn,
            "m123",
            &block_group.id,
            &[edge_into.id, edge_out_of.id],
        );
    }

    /// Creates `child sample` and, in its `m123` block group, replaces path range 15..25 with
    /// the sequence `AA`, then records a new path that routes through the replacement node.
    fn apply_child_sample_update_from_aa_fasta(conn: &GraphConnection) {
        Sample::get_or_create(conn, "child sample");
        let _ = Sample::get_or_create_child(conn, "test", "child sample", None);

        let sample_bg_id = BlockGroup::get_or_create_sample_block_group(
            conn,
            "test",
            "child sample",
            "m123",
            None,
        )
        .expect("should create child block group");
        let sample_path = BlockGroup::get_current_path(conn, &sample_bg_id);
        let tree = sample_path.intervaltree(conn);
        let replacement_sequence = "AA";

        let replacement = Sequence::new()
            .sequence_type("DNA")
            .sequence(replacement_sequence)
            .save(conn);
        let node_id = Node::create(
            conn,
            &replacement.hash,
            &HashId::convert_str(&format!(
                "{path_id}:15-25->{sequence_hash}",
                path_id = sample_path.id,
                sequence_hash = replacement.hash,
            )),
        );
        let change = PathChange {
            block_group_id: sample_bg_id,
            path: sample_path.clone(),
            path_accession: None,
            start: 15,
            end: 25,
            block: PathBlock {
                id: 0,
                node_id,
                block_sequence: replacement_sequence.to_string(),
                sequence_start: 0,
                sequence_end: replacement_sequence.len() as i64,
                path_start: 15,
                path_end: 25,
                strand: Strand::Forward,
            },
            chromosome_index: NO_CHROMOSOME_INDEX,
            phased: 0,
            preserve_edge: true,
        };

        BlockGroup::insert_change(conn, &change, &tree)
            .expect("should apply AA update to child sample");

        // Look up the edges into and out of the replacement node to build the new path.
        let edge_to_insert = Edge::query(
            conn,
            "select * from edges where target_node_id = ?1",
            rusqlite::params![node_id],
        )[0]
        .clone();
        let edge_from_insert = Edge::query(
            conn,
            "select * from edges where source_node_id = ?1",
            rusqlite::params![node_id],
        )[0]
        .clone();
        sample_path.new_path_with(conn, 15, 25, &edge_to_insert, &edge_from_insert);
    }

    /// Propagates `fixtures/simple.gff` from the default sample to `child sample` and checks
    /// the coordinates of the records written to the output file.
    #[test]
    fn simple_propagate() {
        let conn = get_connection();
        create_block_group(&conn);
        apply_child_sample_update_from_aa_fasta(&conn);

        let gff_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("fixtures/simple.gff");
        let temp_dir = tempdir().expect("should create temp directory");
        let output_path = temp_dir.path().join("output.gff");

        propagate_gff(
            &conn,
            "test",
            None,
            "child sample",
            gff_path.to_str().expect("should convert gff path to UTF-8"),
            output_path
                .to_str()
                .expect("should convert output path to UTF-8"),
        )
        .expect("should propagate gff to child sample");

        let mut reader = File::open(output_path)
            .map(BufReader::new)
            .map(gff::io::Reader::new)
            .expect("should read output file");

        let mut record_count = 0usize;
        for (index, result) in reader.record_bufs().enumerate() {
            let record = result.expect("should parse output gff record");
            assert_eq!(record.reference_sequence_name(), "m123");
            if index == 0 {
                assert_eq!(record.source(), "gen-test");
                assert_eq!(record.ty(), "Region");
                assert_eq!(record.start().get(), 1);
                assert_eq!(record.end().get(), 26);
            } else {
                assert_eq!(record.source(), "gen-test");
                assert_eq!(record.ty(), "Gene");
                assert_eq!(record.start().get(), 5);
                assert_eq!(record.end().get(), 15);
            }
            record_count += 1;
        }

        // Every assertion above lives inside the loop, so an empty output file would otherwise
        // let this test pass without checking anything.
        assert!(
            record_count > 0,
            "propagated GFF output should contain at least one record"
        );
    }
}