1use crate::io::error::Error;
8use crate::model::{
9 atom::Atom, residue::Residue, structure::Structure, topology::Topology, types::ResidueCategory,
10};
11use std::collections::HashMap;
12use std::io::Write;
13
14pub fn write_structure<W: Write>(writer: W, structure: &Structure) -> Result<(), Error> {
47 let mut ctx = WriterContext::new(writer);
48
49 ctx.write_cryst1(structure.box_vectors)?;
50
51 ctx.write_atoms(structure)?;
52
53 ctx.write_end()?;
54
55 Ok(())
56}
57
58pub fn write_topology<W: Write>(writer: W, topology: &Topology) -> Result<(), Error> {
96 let mut ctx = WriterContext::new(writer);
97 let structure = topology.structure();
98
99 ctx.write_cryst1(structure.box_vectors)?;
100
101 ctx.write_atoms(structure)?;
102
103 ctx.write_connects(topology)?;
104
105 ctx.write_end()?;
106
107 Ok(())
108}
109
/// Mutable state shared by the PDB record writers in this file.
struct WriterContext<W> {
    /// Output sink that every record is written to.
    writer: W,
    /// Serial number that the next atom or `TER` record will receive; starts at 1.
    current_serial: usize,
    /// Maps an atom's zero-based position in chain/residue/atom iteration order to the
    /// serial it was written with; consulted when emitting `CONECT` records.
    atom_index_to_serial: HashMap<usize, usize>,
}
115
116impl<W: Write> WriterContext<W> {
117 fn new(writer: W) -> Self {
123 Self {
124 writer,
125 current_serial: 1,
126 atom_index_to_serial: HashMap::new(),
127 }
128 }
129
130 fn write_cryst1(&mut self, box_vectors: Option<[[f64; 3]; 3]>) -> Result<(), Error> {
140 if let Some(vectors) = box_vectors {
141 let v1 = nalgebra::Vector3::from(vectors[0]);
142 let v2 = nalgebra::Vector3::from(vectors[1]);
143 let v3 = nalgebra::Vector3::from(vectors[2]);
144
145 let a = v1.norm();
146 let b = v2.norm();
147 let c = v3.norm();
148
149 let alpha = v2.angle(&v3).to_degrees();
150 let beta = v1.angle(&v3).to_degrees();
151 let gamma = v1.angle(&v2).to_degrees();
152
153 writeln!(
154 self.writer,
155 "CRYST1{:9.3}{:9.3}{:9.3}{:7.2}{:7.2}{:7.2} P 1 1",
156 a, b, c, alpha, beta, gamma
157 )
158 .map_err(|e| Error::from_io(e, None))?;
159 }
160 Ok(())
161 }
162
163 fn write_atoms(&mut self, structure: &Structure) -> Result<(), Error> {
169 let mut global_idx = 0;
170
171 for chain in structure.iter_chains() {
172 for residue in chain.iter_residues() {
173 for atom in residue.iter_atoms() {
174 let record_type = match residue.standard_name {
175 Some(std) if std.is_protein() || std.is_nucleic() => "ATOM ",
176 _ => "HETATM",
177 };
178
179 let serial = self.current_serial;
180
181 self.atom_index_to_serial.insert(global_idx, serial);
182
183 self.write_atom_record(record_type, serial, atom, residue, &chain.id)?;
184
185 self.current_serial += 1;
186 global_idx += 1;
187 }
188 }
189
190 if let Some(last_standard) = chain
191 .iter_residues()
192 .rev()
193 .find(|res| res.category == ResidueCategory::Standard)
194 {
195 let serial = self.current_serial;
196 self.write_ter_record(serial, last_standard, &chain.id)?;
197 self.current_serial += 1;
198 }
199 }
200 Ok(())
201 }
202
203 fn write_atom_record(
213 &mut self,
214 record_type: &str,
215 serial: usize,
216 atom: &Atom,
217 residue: &Residue,
218 chain_id: &str,
219 ) -> Result<(), Error> {
220 let atom_name = if atom.name.len() >= 4 {
221 format!("{:<4}", &atom.name[0..4])
222 } else {
223 format!(" {:<3}", atom.name)
224 };
225
226 let res_name = if residue.name.len() > 3 {
227 &residue.name[0..3]
228 } else {
229 &residue.name
230 };
231
232 let element_str = format!("{:>2}", atom.element.symbol().to_uppercase());
233
234 writeln!(
235 self.writer,
236 "{:6}{:5} {:4}{:1}{:3} {:1}{:4}{:1} {:8.3}{:8.3}{:8.3}{:6.2}{:6.2} {:2}",
237 record_type,
238 serial % 100000,
239 atom_name,
240 ' ',
241 res_name,
242 chain_id.chars().next().unwrap_or(' '),
243 residue.id % 10000,
244 residue.insertion_code.unwrap_or(' '),
245 atom.pos.x,
246 atom.pos.y,
247 atom.pos.z,
248 1.00,
249 0.00,
250 element_str
251 )
252 .map_err(|e| Error::from_io(e, None))
253 }
254
255 fn write_ter_record(
263 &mut self,
264 serial: usize,
265 residue: &Residue,
266 chain_id: &str,
267 ) -> Result<(), Error> {
268 let res_name = if residue.name.len() > 3 {
269 &residue.name[0..3]
270 } else {
271 &residue.name
272 };
273
274 writeln!(
275 self.writer,
276 "TER {:5} {:3} {:1}{:4}{:1}",
277 serial % 100000,
278 res_name,
279 chain_id.chars().next().unwrap_or(' '),
280 residue.id % 10000,
281 residue.insertion_code.unwrap_or(' ')
282 )
283 .map_err(|e| Error::from_io(e, None))
284 }
285
286 fn write_connects(&mut self, topology: &Topology) -> Result<(), Error> {
297 let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
298
299 for bond in topology.bonds() {
300 let s1 = *self.atom_index_to_serial.get(&bond.a1_idx).ok_or_else(|| {
301 Error::inconsistent_data(
302 "PDB",
303 None,
304 format!(
305 "bond references atom index {} that was not written",
306 bond.a1_idx
307 ),
308 )
309 })?;
310 let s2 = *self.atom_index_to_serial.get(&bond.a2_idx).ok_or_else(|| {
311 Error::inconsistent_data(
312 "PDB",
313 None,
314 format!(
315 "bond references atom index {} that was not written",
316 bond.a2_idx
317 ),
318 )
319 })?;
320
321 adjacency.entry(s1).or_default().push(s2);
322 adjacency.entry(s2).or_default().push(s1);
323 }
324
325 let mut serials: Vec<_> = adjacency.keys().copied().collect();
326 serials.sort();
327
328 for src_serial in serials {
329 let targets = adjacency.get(&src_serial).unwrap();
330 let mut targets = targets.clone();
331 targets.sort();
332 targets.dedup();
333
334 for chunk in targets.chunks(4) {
335 write!(self.writer, "CONECT{:5}", src_serial)
336 .map_err(|e| Error::from_io(e, None))?;
337 for target in chunk {
338 write!(self.writer, "{:5}", target).map_err(|e| Error::from_io(e, None))?;
339 }
340 writeln!(self.writer).map_err(|e| Error::from_io(e, None))?;
341 }
342 }
343
344 Ok(())
345 }
346
347 fn write_end(&mut self) -> Result<(), Error> {
353 writeln!(self.writer, "END ").map_err(|e| Error::from_io(e, None))
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::atom::Atom;
    use crate::model::chain::Chain;
    use crate::model::residue::Residue;
    use crate::model::topology::{Bond, Topology};
    use crate::model::types::{BondOrder, Element, Point, ResidueCategory, StandardResidue};

    /// Checks that `line` is a `CRYST1` record whose cell lengths and angles match `params`,
    /// given as `(a, b, c, alpha, beta, gamma)`.
    fn assert_cryst1_line(line: &str, params: (f64, f64, f64, f64, f64, f64)) {
        assert!(line.starts_with("CRYST1"));
        let (a, b, c, alpha, beta, gamma) = params;
        assert!((parse_float(&line[6..15]) - a).abs() < 1e-3);
        assert!((parse_float(&line[15..24]) - b).abs() < 1e-3);
        assert!((parse_float(&line[24..33]) - c).abs() < 1e-3);
        assert!((parse_float(&line[33..40]) - alpha).abs() < 1e-2);
        assert!((parse_float(&line[40..47]) - beta).abs() < 1e-2);
        assert!((parse_float(&line[47..54]) - gamma).abs() < 1e-2);
    }

    /// Checks the fixed-column fields of an `ATOM`/`HETATM` line.
    ///
    /// `record` may be given with or without trailing padding: the record name occupies a
    /// six-column field, so the comparison ignores trailing blanks on both sides.
    fn assert_atom_line(
        line: &str,
        record: &str,
        serial: usize,
        atom_name: &str,
        res_name: &str,
        chain_id: char,
        res_seq: i32,
        insertion_code: char,
        coords: (f64, f64, f64),
        element: &str,
    ) {
        assert!(line.len() >= 78, "line too short: {line}");
        assert_eq!(line[0..6].trim_end(), record.trim_end());
        assert_eq!(line[6..11].trim(), serial.to_string());
        assert_eq!(line[12..16].trim(), atom_name);
        assert_eq!(line[17..20].trim(), res_name);
        assert_eq!(line.chars().nth(21).unwrap(), chain_id);
        assert_eq!(line[22..26].trim(), res_seq.to_string());
        assert_eq!(line.chars().nth(26).unwrap(), insertion_code);
        assert!((parse_float(&line[30..38]) - coords.0).abs() < 1e-3);
        assert!((parse_float(&line[38..46]) - coords.1).abs() < 1e-3);
        assert!((parse_float(&line[46..54]) - coords.2).abs() < 1e-3);
        assert_eq!(line[76..78].trim(), element);
    }

    /// Checks the fixed-column fields of a `TER` line.
    fn assert_ter_line(
        line: &str,
        serial: usize,
        res_name: &str,
        chain_id: char,
        res_seq: i32,
        insertion_code: char,
    ) {
        assert!(line.starts_with("TER "));
        assert_eq!(line[6..11].trim(), serial.to_string());
        assert_eq!(line[17..20].trim(), res_name);
        assert_eq!(line.chars().nth(21).unwrap(), chain_id);
        assert_eq!(line[22..26].trim(), res_seq.to_string());
        assert_eq!(line.chars().nth(26).unwrap(), insertion_code);
    }

    /// Checks that `line` is a `CONECT` record from `source` to exactly `targets`, in order.
    ///
    /// Fields are split on whitespace, so this helper is only meant for serials short
    /// enough that neighbouring five-column fields do not touch.
    fn assert_conect_line(line: &str, source: usize, targets: &[usize]) {
        assert!(line.starts_with("CONECT"));
        let tokens: Vec<_> = line.split_whitespace().collect();
        assert_eq!(tokens[0], "CONECT");
        assert_eq!(tokens[1].parse::<usize>().unwrap(), source);
        let parsed_targets: Vec<_> = tokens[2..]
            .iter()
            .map(|tok| tok.parse::<usize>().unwrap())
            .collect();
        assert_eq!(parsed_targets, targets);
    }

    /// Parses a fixed-width numeric field, ignoring its padding.
    fn parse_float(slice: &str) -> f64 {
        slice.trim().parse::<f64>().expect("valid float")
    }

    /// A boxed structure with one polymer residue and one ligand produces, in order:
    /// CRYST1, two ATOM lines, one HETATM line, a TER naming the polymer residue, and END.
    #[test]
    fn write_structure_emits_cryst1_atoms_ter_and_end() {
        let mut structure = Structure::new();
        structure.box_vectors = Some([[10.0, 0.0, 0.0], [0.0, 11.0, 0.0], [0.0, 0.0, 12.0]]);

        let mut chain = Chain::new("A");

        let mut gly = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        gly.add_atom(Atom::new("N", Element::N, Point::new(1.0, 2.0, 3.0)));
        gly.add_atom(Atom::new("CA", Element::C, Point::new(1.5, 2.5, 3.5)));

        let mut lig = Residue::new(2, None, "LIG", None, ResidueCategory::Hetero);
        lig.add_atom(Atom::new("C1", Element::C, Point::new(4.0, 5.0, 6.0)));

        chain.add_residue(gly);
        chain.add_residue(lig);
        structure.add_chain(chain);

        let mut buffer = Vec::new();
        write_structure(&mut buffer, &structure).expect("writer should succeed");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let lines: Vec<&str> = output.lines().collect();
        assert_eq!(lines.len(), 6, "unexpected number of lines: {lines:?}");

        assert_cryst1_line(lines[0], (10.0, 11.0, 12.0, 90.0, 90.0, 90.0));
        assert_atom_line(
            lines[1],
            "ATOM",
            1,
            "N",
            "GLY",
            'A',
            1,
            ' ',
            (1.0, 2.0, 3.0),
            "N",
        );
        assert_atom_line(
            lines[2],
            "ATOM",
            2,
            "CA",
            "GLY",
            'A',
            1,
            ' ',
            (1.5, 2.5, 3.5),
            "C",
        );
        assert_atom_line(
            lines[3],
            "HETATM",
            3,
            "C1",
            "LIG",
            'A',
            2,
            ' ',
            (4.0, 5.0, 6.0),
            "C",
        );
        // The TER record takes the serial after the last atom and names the last residue
        // of category Standard, which here is GLY rather than the trailing ligand.
        assert_ter_line(lines[4], 4, "GLY", 'A', 1, ' ');
        assert_eq!(lines[5], "END ");
    }

    /// Without box vectors no CRYST1 record is written, so output starts with an atom line.
    #[test]
    fn write_structure_without_box_starts_with_atom_records() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("B");

        let mut ser = Residue::new(
            7,
            Some('A'),
            "SER",
            Some(StandardResidue::SER),
            ResidueCategory::Standard,
        );
        ser.add_atom(Atom::new("OG", Element::O, Point::new(-1.0, 0.5, 2.0)));
        chain.add_residue(ser);
        structure.add_chain(chain);

        let mut buffer = Vec::new();
        write_structure(&mut buffer, &structure).expect("writer should succeed");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let mut lines = output.lines();

        let first_line = lines.next().expect("at least one line");
        assert!(first_line.starts_with("ATOM"));
        assert!(lines.any(|line| line == "END "));
    }

    /// A standard residue that is neither protein nor nucleic (water) is written as HETATM.
    #[test]
    fn non_polymer_standard_residue_uses_hetatm_record() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("W");

        let mut water = Residue::new(
            42,
            None,
            "HOH",
            Some(StandardResidue::HOH),
            ResidueCategory::Standard,
        );
        water.add_atom(Atom::new("O", Element::O, Point::new(0.0, 0.0, 0.0)));
        chain.add_residue(water);
        structure.add_chain(chain);

        let mut buffer = Vec::new();
        write_structure(&mut buffer, &structure).expect("writer should succeed");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let first_line = output.lines().next().expect("at least one line");
        assert!(
            first_line.starts_with("HETATM"),
            "line should be HETATM: {first_line}"
        );
    }

    /// A single bond is reported from both of its atoms, ordered by source serial.
    #[test]
    fn write_topology_emits_conect_records() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("C");
        let mut ala = Residue::new(
            10,
            None,
            "ALA",
            Some(StandardResidue::ALA),
            ResidueCategory::Standard,
        );
        ala.add_atom(Atom::new("N", Element::N, Point::new(0.0, 0.0, 0.0)));
        ala.add_atom(Atom::new("CA", Element::C, Point::new(1.0, 0.0, 0.0)));
        chain.add_residue(ala);
        structure.add_chain(chain);

        // The structure is not needed afterwards, so it is moved rather than cloned.
        let topology = Topology::new(structure, vec![Bond::new(0, 1, BondOrder::Single)]);

        let mut buffer = Vec::new();
        write_topology(&mut buffer, &topology).expect("topology writer succeeds");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let conect_lines: Vec<&str> = output
            .lines()
            .filter(|line| line.starts_with("CONECT"))
            .collect();

        assert_eq!(conect_lines.len(), 2);
        assert_conect_line(conect_lines[0], 1, &[2]);
        assert_conect_line(conect_lines[1], 2, &[1]);
    }

    /// A bond whose atom index has no recorded serial yields an inconsistent-data error.
    #[test]
    fn write_connects_returns_error_when_serial_missing() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("D");
        let mut gly = Residue::new(
            5,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        gly.add_atom(Atom::new("N", Element::N, Point::new(0.0, 0.0, 0.0)));
        gly.add_atom(Atom::new("CA", Element::C, Point::new(1.0, 1.0, 0.0)));
        chain.add_residue(gly);
        structure.add_chain(chain);

        let topology = Topology::new(structure.clone(), vec![Bond::new(0, 1, BondOrder::Single)]);

        let mut ctx = WriterContext::new(Vec::new());
        ctx.write_atoms(&structure).expect("atoms should write");
        // Simulate a writer that never recorded any serials.
        ctx.atom_index_to_serial.clear();

        let err = ctx
            .write_connects(&topology)
            .expect_err("missing serial map should error");

        match err {
            Error::InconsistentData { details, .. } => {
                assert!(details.contains("bond references atom index"));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}