Skip to main content

bio_forge/io/pdb/
writer.rs

1//! PDB writer utilities that serialize structures and topologies into fixed-width records.
2//!
3//! The module handles optional unit-cell information, deterministic atom serial numbering,
4//! TER record emission, and `CONECT` reconstruction from a [`Topology`] to ensure round-trip
5//! compatibility with downstream crystallography tools.
6
7use crate::io::error::Error;
8use crate::model::{
9    atom::Atom, residue::Residue, structure::Structure, topology::Topology, types::ResidueCategory,
10};
11use std::collections::HashMap;
12use std::io::Write;
13
14/// Writes a [`Structure`] to PDB format, including optional CRYST1 and TER records.
15///
16/// The function traverses chains in their stored order, emits `ATOM` records for polymeric
17/// residues, `HETATM` records for everything else, and appends a final `END` line so the file
18/// is ready for legacy toolchains.
19///
20/// # Arguments
21///
22/// * `writer` - Destination implementing [`Write`], such as a file or in-memory buffer.
23/// * `structure` - Source structure whose chains and atoms are serialized.
24///
25/// # Returns
26///
27/// [`Ok`] if writing succeeded; [`Error`] if IO failures occur.
28///
29/// # Examples
30///
31/// ```
32/// use bio_forge::io::{write_pdb_structure, IoContext, read_pdb_structure};
33/// use std::io::Cursor;
34///
35/// // Parse a minimal PDB and immediately write it back out.
36/// let pdb = "\
37/// ATOM      1  N   GLY A   1       0.000   0.000   0.000  1.00 20.00           N\n\
38/// END\n";
39/// let context = IoContext::new_default();
40/// let mut cursor = Cursor::new(pdb.as_bytes());
41/// let structure = read_pdb_structure(&mut cursor, &context).unwrap();
42/// let mut out = Vec::new();
43/// write_pdb_structure(&mut out, &structure).unwrap();
44/// assert!(String::from_utf8(out).unwrap().contains("END"));
45/// ```
46pub fn write_structure<W: Write>(writer: W, structure: &Structure) -> Result<(), Error> {
47    let mut ctx = WriterContext::new(writer);
48
49    ctx.write_cryst1(structure.box_vectors)?;
50
51    ctx.write_atoms(structure)?;
52
53    ctx.write_end()?;
54
55    Ok(())
56}
57
58/// Writes a [`Topology`] to PDB format, including `CONECT` bonding information.
59///
60/// This convenience helper mirrors [`write_structure`] for coordinate output, then traverses
61/// the topology graph to emit `CONECT` records that retain bonding data for visualization
62/// tools.
63///
64/// # Arguments
65///
66/// * `writer` - Output sink implementing [`Write`].
67/// * `topology` - Source topology whose structure and bonds are serialized.
68///
69/// # Returns
70///
71/// [`Ok`] if writing succeeded; [`Error`] if serialization or IO steps fail.
72///
73/// # Examples
74///
75/// ```
76/// use bio_forge::{Bond, BondOrder, Topology};
77/// use bio_forge::io::{write_pdb_topology, IoContext, read_pdb_structure};
78/// use std::io::Cursor;
79///
80/// let pdb = "\
81/// ATOM      1  N   GLY A   1       0.000   0.000   0.000  1.00 20.00           N\n\
82/// ATOM      2  CA  GLY A   1       1.000   0.000   0.000  1.00 20.00           C\n\
83/// END\n";
84/// let context = IoContext::new_default();
85/// let mut cursor = Cursor::new(pdb.as_bytes());
86/// let structure = read_pdb_structure(&mut cursor, &context).unwrap();
87/// let topo = Topology::new(
88///     structure.clone(),
89///     vec![Bond::new(0, 1, BondOrder::Single)],
90/// );
91/// let mut out = Vec::new();
92/// write_pdb_topology(&mut out, &topo).unwrap();
93/// assert!(String::from_utf8(out).unwrap().contains("CONECT"));
94/// ```
95pub fn write_topology<W: Write>(writer: W, topology: &Topology) -> Result<(), Error> {
96    let mut ctx = WriterContext::new(writer);
97    let structure = topology.structure();
98
99    ctx.write_cryst1(structure.box_vectors)?;
100
101    ctx.write_atoms(structure)?;
102
103    ctx.write_connects(topology)?;
104
105    ctx.write_end()?;
106
107    Ok(())
108}
109
110struct WriterContext<W> {
111    writer: W,
112    current_serial: usize,
113    atom_index_to_serial: HashMap<usize, usize>,
114}
115
116impl<W: Write> WriterContext<W> {
117    /// Constructs a new writer state machine with cleared serial counters.
118    ///
119    /// # Arguments
120    ///
121    /// * `writer` - Output sink that will receive the generated PDB text.
122    fn new(writer: W) -> Self {
123        Self {
124            writer,
125            current_serial: 1,
126            atom_index_to_serial: HashMap::new(),
127        }
128    }
129
130    /// Outputs a `CRYST1` record if orthogonal vectors are available.
131    ///
132    /// # Arguments
133    ///
134    /// * `box_vectors` - Optional 3×3 matrix of unit-cell vectors.
135    ///
136    /// # Returns
137    ///
138    /// [`Ok`] whether or not the record was emitted; [`Error`] only if IO fails.
139    fn write_cryst1(&mut self, box_vectors: Option<[[f64; 3]; 3]>) -> Result<(), Error> {
140        if let Some(vectors) = box_vectors {
141            let v1 = nalgebra::Vector3::from(vectors[0]);
142            let v2 = nalgebra::Vector3::from(vectors[1]);
143            let v3 = nalgebra::Vector3::from(vectors[2]);
144
145            let a = v1.norm();
146            let b = v2.norm();
147            let c = v3.norm();
148
149            let alpha = v2.angle(&v3).to_degrees();
150            let beta = v1.angle(&v3).to_degrees();
151            let gamma = v1.angle(&v2).to_degrees();
152
153            writeln!(
154                self.writer,
155                "CRYST1{:9.3}{:9.3}{:9.3}{:7.2}{:7.2}{:7.2} P 1           1",
156                a, b, c, alpha, beta, gamma
157            )
158            .map_err(|e| Error::from_io(e, None))?;
159        }
160        Ok(())
161    }
162
163    /// Writes all chain atoms and inserts `TER` records after the final standard residue.
164    ///
165    /// # Arguments
166    ///
167    /// * `structure` - Source structure providing chains, residues, and atoms.
168    fn write_atoms(&mut self, structure: &Structure) -> Result<(), Error> {
169        let mut global_idx = 0;
170
171        for chain in structure.iter_chains() {
172            for residue in chain.iter_residues() {
173                for atom in residue.iter_atoms() {
174                    let record_type = match residue.standard_name {
175                        Some(std) if std.is_protein() || std.is_nucleic() => "ATOM  ",
176                        _ => "HETATM",
177                    };
178
179                    let serial = self.current_serial;
180
181                    self.atom_index_to_serial.insert(global_idx, serial);
182
183                    self.write_atom_record(record_type, serial, atom, residue, &chain.id)?;
184
185                    self.current_serial += 1;
186                    global_idx += 1;
187                }
188            }
189
190            if let Some(last_standard) = chain
191                .iter_residues()
192                .rev()
193                .find(|res| res.category == ResidueCategory::Standard)
194            {
195                let serial = self.current_serial;
196                self.write_ter_record(serial, last_standard, &chain.id)?;
197                self.current_serial += 1;
198            }
199        }
200        Ok(())
201    }
202
203    /// Formats a single `ATOM` or `HETATM` entry using the fixed-width PDB layout.
204    ///
205    /// # Arguments
206    ///
207    /// * `record_type` - Either `"ATOM  "` or `"HETATM"` depending on residue type.
208    /// * `serial` - Sequential atom serial number.
209    /// * `atom` - Atom instance providing coordinates and element.
210    /// * `residue` - Parent residue containing residue identifiers.
211    /// * `chain_id` - Author chain identifier string.
212    fn write_atom_record(
213        &mut self,
214        record_type: &str,
215        serial: usize,
216        atom: &Atom,
217        residue: &Residue,
218        chain_id: &str,
219    ) -> Result<(), Error> {
220        let atom_name = if atom.name.len() >= 4 {
221            format!("{:<4}", &atom.name[0..4])
222        } else {
223            format!(" {:<3}", atom.name)
224        };
225
226        let res_name = if residue.name.len() > 3 {
227            &residue.name[0..3]
228        } else {
229            &residue.name
230        };
231
232        let element_str = format!("{:>2}", atom.element.symbol().to_uppercase());
233
234        writeln!(
235            self.writer,
236            "{:6}{:5} {:4}{:1}{:3} {:1}{:4}{:1}   {:8.3}{:8.3}{:8.3}{:6.2}{:6.2}          {:2}",
237            record_type,
238            serial % 100000,
239            atom_name,
240            ' ',
241            res_name,
242            chain_id.chars().next().unwrap_or(' '),
243            residue.id % 10000,
244            residue.insertion_code.unwrap_or(' '),
245            atom.pos.x,
246            atom.pos.y,
247            atom.pos.z,
248            1.00,
249            0.00,
250            element_str
251        )
252        .map_err(|e| Error::from_io(e, None))
253    }
254
255    /// Emits a `TER` record to terminate the current polymer chain.
256    ///
257    /// # Arguments
258    ///
259    /// * `serial` - Serial number assigned to the TER record.
260    /// * `residue` - Residue that concludes the polymer chain.
261    /// * `chain_id` - Chain identifier owning the residue.
262    fn write_ter_record(
263        &mut self,
264        serial: usize,
265        residue: &Residue,
266        chain_id: &str,
267    ) -> Result<(), Error> {
268        let res_name = if residue.name.len() > 3 {
269            &residue.name[0..3]
270        } else {
271            &residue.name
272        };
273
274        writeln!(
275            self.writer,
276            "TER   {:5}      {:3} {:1}{:4}{:1}",
277            serial % 100000,
278            res_name,
279            chain_id.chars().next().unwrap_or(' '),
280            residue.id % 10000,
281            residue.insertion_code.unwrap_or(' ')
282        )
283        .map_err(|e| Error::from_io(e, None))
284    }
285
286    /// Serializes topology bonds into grouped `CONECT` records with deduplicated targets.
287    ///
288    /// # Arguments
289    ///
290    /// * `topology` - Topology providing bond definitions that should be written.
291    ///
292    /// # Returns
293    ///
294    /// [`Ok`] after writing all bonds or [`Error::InconsistentData`] if atom indices are
295    /// missing from the serial map.
296    fn write_connects(&mut self, topology: &Topology) -> Result<(), Error> {
297        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
298
299        for bond in topology.bonds() {
300            let s1 = *self.atom_index_to_serial.get(&bond.a1_idx).ok_or_else(|| {
301                Error::inconsistent_data(
302                    "PDB",
303                    None,
304                    format!(
305                        "bond references atom index {} that was not written",
306                        bond.a1_idx
307                    ),
308                )
309            })?;
310            let s2 = *self.atom_index_to_serial.get(&bond.a2_idx).ok_or_else(|| {
311                Error::inconsistent_data(
312                    "PDB",
313                    None,
314                    format!(
315                        "bond references atom index {} that was not written",
316                        bond.a2_idx
317                    ),
318                )
319            })?;
320
321            adjacency.entry(s1).or_default().push(s2);
322            adjacency.entry(s2).or_default().push(s1);
323        }
324
325        let mut serials: Vec<_> = adjacency.keys().copied().collect();
326        serials.sort();
327
328        for src_serial in serials {
329            let targets = adjacency.get(&src_serial).unwrap();
330            let mut targets = targets.clone();
331            targets.sort();
332            targets.dedup();
333
334            for chunk in targets.chunks(4) {
335                write!(self.writer, "CONECT{:5}", src_serial)
336                    .map_err(|e| Error::from_io(e, None))?;
337                for target in chunk {
338                    write!(self.writer, "{:5}", target).map_err(|e| Error::from_io(e, None))?;
339                }
340                writeln!(self.writer).map_err(|e| Error::from_io(e, None))?;
341            }
342        }
343
344        Ok(())
345    }
346
347    /// Appends the terminal `END` card.
348    ///
349    /// # Returns
350    ///
351    /// [`Ok`] when the line is written; [`Error`] if the underlying writer fails.
352    fn write_end(&mut self) -> Result<(), Error> {
353        writeln!(self.writer, "END   ").map_err(|e| Error::from_io(e, None))
354    }
355}
356
/// Unit tests for the PDB writer: record layout, record ordering, and `CONECT` generation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::atom::Atom;
    use crate::model::chain::Chain;
    use crate::model::residue::Residue;
    use crate::model::topology::{Bond, Topology};
    use crate::model::types::{BondOrder, Element, Point, ResidueCategory, StandardResidue};

    /// Checks the fixed-width `CRYST1` fields: a, b, c (columns 7-33) and alpha, beta, gamma
    /// (columns 34-54).
    fn assert_cryst1_line(line: &str, params: (f64, f64, f64, f64, f64, f64)) {
        assert!(line.starts_with("CRYST1"));
        let (a, b, c, alpha, beta, gamma) = params;
        assert!((parse_float(&line[6..15]) - a).abs() < 1e-3);
        assert!((parse_float(&line[15..24]) - b).abs() < 1e-3);
        assert!((parse_float(&line[24..33]) - c).abs() < 1e-3);
        assert!((parse_float(&line[33..40]) - alpha).abs() < 1e-2);
        assert!((parse_float(&line[40..47]) - beta).abs() < 1e-2);
        assert!((parse_float(&line[47..54]) - gamma).abs() < 1e-2);
    }

    /// Checks every fixed-width field of an `ATOM`/`HETATM` line against expected values.
    // One argument per PDB column group keeps call sites readable; bundling them into a
    // struct would only add boilerplate to the test helpers.
    #[allow(clippy::too_many_arguments)]
    fn assert_atom_line(
        line: &str,
        record: &str,
        serial: usize,
        atom_name: &str,
        res_name: &str,
        chain_id: char,
        res_seq: i32,
        insertion_code: char,
        coords: (f64, f64, f64),
        element: &str,
    ) {
        assert!(line.len() >= 78, "line too short: {line}");
        assert_eq!(&line[0..6], record);
        assert_eq!(line[6..11].trim(), serial.to_string());
        assert_eq!(line[12..16].trim(), atom_name);
        assert_eq!(line[17..20].trim(), res_name);
        assert_eq!(line.chars().nth(21).unwrap(), chain_id);
        assert_eq!(line[22..26].trim(), res_seq.to_string());
        assert_eq!(line.chars().nth(26).unwrap(), insertion_code);
        assert!((parse_float(&line[30..38]) - coords.0).abs() < 1e-3);
        assert!((parse_float(&line[38..46]) - coords.1).abs() < 1e-3);
        assert!((parse_float(&line[46..54]) - coords.2).abs() < 1e-3);
        assert_eq!(line[76..78].trim(), element);
    }

    /// Checks the serial, residue name, chain, residue number and insertion code of a `TER`
    /// line.
    fn assert_ter_line(
        line: &str,
        serial: usize,
        res_name: &str,
        chain_id: char,
        res_seq: i32,
        insertion_code: char,
    ) {
        assert!(line.starts_with("TER   "));
        assert_eq!(line[6..11].trim(), serial.to_string());
        assert_eq!(line[17..20].trim(), res_name);
        assert_eq!(line.chars().nth(21).unwrap(), chain_id);
        assert_eq!(line[22..26].trim(), res_seq.to_string());
        assert_eq!(line.chars().nth(26).unwrap(), insertion_code);
    }

    /// Checks that a `CONECT` line lists `source` followed by exactly `targets`, in order.
    fn assert_conect_line(line: &str, source: usize, targets: &[usize]) {
        assert!(line.starts_with("CONECT"));
        let tokens: Vec<_> = line.split_whitespace().collect();
        assert_eq!(tokens[0], "CONECT");
        assert_eq!(tokens[1].parse::<usize>().unwrap(), source);
        let parsed_targets: Vec<_> = tokens[2..]
            .iter()
            .map(|tok| tok.parse::<usize>().unwrap())
            .collect();
        assert_eq!(parsed_targets, targets);
    }

    /// Parses a whitespace-padded numeric column, panicking if it is not a valid float.
    fn parse_float(slice: &str) -> f64 {
        slice.trim().parse::<f64>().expect("valid float")
    }

    #[test]
    fn write_structure_emits_cryst1_atoms_ter_and_end() {
        let mut structure = Structure::new();
        structure.box_vectors = Some([[10.0, 0.0, 0.0], [0.0, 11.0, 0.0], [0.0, 0.0, 12.0]]);

        let mut chain = Chain::new("A");

        let mut gly = Residue::new(
            1,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        gly.add_atom(Atom::new("N", Element::N, Point::new(1.0, 2.0, 3.0)));
        gly.add_atom(Atom::new("CA", Element::C, Point::new(1.5, 2.5, 3.5)));

        let mut lig = Residue::new(2, None, "LIG", None, ResidueCategory::Hetero);
        lig.add_atom(Atom::new("C1", Element::C, Point::new(4.0, 5.0, 6.0)));

        chain.add_residue(gly);
        chain.add_residue(lig);
        structure.add_chain(chain);

        let mut buffer = Vec::new();
        write_structure(&mut buffer, &structure).expect("writer should succeed");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let lines: Vec<&str> = output.lines().collect();
        // CRYST1 + 3 atom records + TER + END.
        assert_eq!(lines.len(), 6, "unexpected number of lines: {lines:?}");

        assert_cryst1_line(lines[0], (10.0, 11.0, 12.0, 90.0, 90.0, 90.0));
        assert_atom_line(
            lines[1],
            "ATOM  ",
            1,
            "N",
            "GLY",
            'A',
            1,
            ' ',
            (1.0, 2.0, 3.0),
            "N",
        );
        assert_atom_line(
            lines[2],
            "ATOM  ",
            2,
            "CA",
            "GLY",
            'A',
            1,
            ' ',
            (1.5, 2.5, 3.5),
            "C",
        );
        assert_atom_line(
            lines[3],
            "HETATM",
            3,
            "C1",
            "LIG",
            'A',
            2,
            ' ',
            (4.0, 5.0, 6.0),
            "C",
        );
        // The TER record follows all of the chain's atoms and takes the next serial.
        assert_ter_line(lines[4], 4, "GLY", 'A', 1, ' ');
        assert_eq!(lines[5], "END   ");
    }

    #[test]
    fn write_structure_without_box_starts_with_atom_records() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("B");

        let mut ser = Residue::new(
            7,
            Some('A'),
            "SER",
            Some(StandardResidue::SER),
            ResidueCategory::Standard,
        );
        ser.add_atom(Atom::new("OG", Element::O, Point::new(-1.0, 0.5, 2.0)));
        chain.add_residue(ser);
        structure.add_chain(chain);

        let mut buffer = Vec::new();
        write_structure(&mut buffer, &structure).expect("writer should succeed");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let mut lines = output.lines();

        // Without box vectors no CRYST1 record is written.
        let first_line = lines.next().expect("at least one line");
        assert!(first_line.starts_with("ATOM"));
        assert!(lines.any(|line| line == "END   "));
    }

    #[test]
    fn non_polymer_standard_residue_uses_hetatm_record() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("W");

        let mut water = Residue::new(
            42,
            None,
            "HOH",
            Some(StandardResidue::HOH),
            ResidueCategory::Standard,
        );
        water.add_atom(Atom::new("O", Element::O, Point::new(0.0, 0.0, 0.0)));
        chain.add_residue(water);
        structure.add_chain(chain);

        let mut buffer = Vec::new();
        write_structure(&mut buffer, &structure).expect("writer should succeed");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let first_line = output.lines().next().expect("at least one line");
        assert!(
            first_line.starts_with("HETATM"),
            "line should be HETATM: {first_line}"
        );
    }

    #[test]
    fn write_topology_emits_conect_records() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("C");
        let mut ala = Residue::new(
            10,
            None,
            "ALA",
            Some(StandardResidue::ALA),
            ResidueCategory::Standard,
        );
        ala.add_atom(Atom::new("N", Element::N, Point::new(0.0, 0.0, 0.0)));
        ala.add_atom(Atom::new("CA", Element::C, Point::new(1.0, 0.0, 0.0)));
        chain.add_residue(ala);
        structure.add_chain(chain);

        let topology = Topology::new(structure.clone(), vec![Bond::new(0, 1, BondOrder::Single)]);

        let mut buffer = Vec::new();
        write_topology(&mut buffer, &topology).expect("topology writer succeeds");

        let output = String::from_utf8(buffer).expect("valid UTF-8");
        let conect_lines: Vec<&str> = output
            .lines()
            .filter(|line| line.starts_with("CONECT"))
            .collect();

        // A single bond is listed from both endpoints.
        assert_eq!(conect_lines.len(), 2);
        assert_conect_line(conect_lines[0], 1, &[2]);
        assert_conect_line(conect_lines[1], 2, &[1]);
    }

    #[test]
    fn write_connects_returns_error_when_serial_missing() {
        let mut structure = Structure::new();
        let mut chain = Chain::new("D");
        let mut gly = Residue::new(
            5,
            None,
            "GLY",
            Some(StandardResidue::GLY),
            ResidueCategory::Standard,
        );
        gly.add_atom(Atom::new("N", Element::N, Point::new(0.0, 0.0, 0.0)));
        gly.add_atom(Atom::new("CA", Element::C, Point::new(1.0, 1.0, 0.0)));
        chain.add_residue(gly);
        structure.add_chain(chain);

        let topology = Topology::new(structure.clone(), vec![Bond::new(0, 1, BondOrder::Single)]);

        // Write the atoms, then forget their serials so the bond lookup must fail.
        let mut ctx = WriterContext::new(Vec::new());
        ctx.write_atoms(&structure).expect("atoms should write");
        ctx.atom_index_to_serial.clear();

        let err = ctx
            .write_connects(&topology)
            .expect_err("missing serial map should error");

        match err {
            Error::InconsistentData { details, .. } => {
                assert!(details.contains("bond references atom index"));
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}