1use std::io::Write;
11
12use num_complex::Complex;
13
14use crate::block::BLOCK_SIZE;
15use crate::block::CARD_SIZE;
16use crate::block::SPACE_FILL;
17use crate::block::ZERO_FILL;
18use crate::checksum;
19#[cfg(feature = "compression")]
20use crate::compress::{CompressOptions, compress_image, compress_table};
21use crate::data::Image;
22use crate::data::shape_product;
23use crate::endian::extend_be;
24use crate::endian::push_pq_descriptor;
25use crate::error::FitsError;
26use crate::error::Result;
27use crate::header::Header;
28use crate::keyword::key;
29#[cfg(feature = "compression")]
30use crate::table::BinTable;
31use crate::table::ColumnData;
32
/// Sixteen ASCII zeros stored in the `CHECKSUM` card before the HDU checksum
/// is computed; `write_hdu` overwrites this value in the rendered header bytes
/// (via `patch_checksum`) once the real checksum is known.
const PLACEHOLDER_CHECKSUM: &str = "0000000000000000";
36
37pub(crate) fn render_header(header: &Header) -> Vec<u8> {
40 let mut buf = Vec::with_capacity((header.cards.len() + 1) * CARD_SIZE);
41 for card in &header.cards {
42 for record in card.render_records() {
43 buf.extend_from_slice(&record);
44 }
45 }
46 let mut end = [SPACE_FILL; CARD_SIZE];
47 end[..3].copy_from_slice(b"END");
48 buf.extend_from_slice(&end);
49 pad_to_block(&mut buf, SPACE_FILL);
50 buf
51}
52
53fn pad_to_block(buf: &mut Vec<u8>, fill: u8) {
55 let rem = buf.len() % BLOCK_SIZE;
56 if rem != 0 {
57 buf.resize(buf.len() + (BLOCK_SIZE - rem), fill);
58 }
59}
60
/// One binary-table column handed to `FitsWriter::write_table`.
///
/// Build one with [`WriteColumn::fixed`], [`WriteColumn::vla`] or
/// [`WriteColumn::bits`], then refine it with the `with_*`, `wide` and
/// `scaled` builders.
#[derive(Debug, Clone)]
pub struct WriteColumn {
    /// Column name, written as `TTYPEn`.
    pub name: String,
    /// Physical unit, written as `TUNITn` when present.
    pub unit: Option<String>,
    /// Cell values of a fixed-width column, stored row after row with `repeat`
    /// elements per row (`Text` holds one string per row instead).
    /// For variable-length columns the writer only consults this value's
    /// variant, to choose the `TFORMn` type code.
    pub data: ColumnData,
    /// Repeat count used in `TFORMn`. For `Text` this is the fixed string width
    /// in bytes (longer strings are truncated, shorter ones space-padded); for
    /// bit columns it is the number of bytes needed to hold the bits.
    pub repeat: usize,
    /// Per-row cells of a variable-length array column. `Some` turns the
    /// column into a descriptor column whose data lives in the heap.
    pub vla: Option<Vec<ColumnData>>,
    /// Cell dimensions, written as `TDIMn = '(d1,d2,...)'`.
    pub tdim: Option<Vec<usize>>,
    /// For variable-length columns, use `Q` (64-bit) descriptors instead of
    /// `P` (32-bit) ones.
    pub wide: bool,
    /// When set, the column is declared as a bit array (`TFORMn = '<bits>X'`).
    pub bits: Option<usize>,
    /// Scale factor, written as `TSCALn` when present.
    pub tscale: Option<f64>,
    /// Zero offset, written as `TZEROn` when present.
    pub tzero: Option<f64>,
    /// Integer null value, written as `TNULLn` when present.
    pub tnull: Option<i64>,
}
88
89impl WriteColumn {
90 pub fn fixed(name: impl Into<String>, data: ColumnData, repeat: usize) -> WriteColumn {
92 WriteColumn {
93 name: name.into(),
94 unit: None,
95 data,
96 repeat,
97 vla: None,
98 tdim: None,
99 wide: false,
100 bits: None,
101 tscale: None,
102 tzero: None,
103 tnull: None,
104 }
105 }
106
107 pub fn vla(name: impl Into<String>, rows: Vec<ColumnData>) -> WriteColumn {
110 let tag = rows
112 .first()
113 .cloned()
114 .unwrap_or(ColumnData::Bytes(Vec::new()));
115 assert!(
119 rows.iter()
120 .all(|r| std::mem::discriminant(r) == std::mem::discriminant(&tag)),
121 "VLA column cells must all be the same ColumnData variant"
122 );
123 WriteColumn {
124 data: tag,
125 repeat: 0,
126 vla: Some(rows),
127 ..WriteColumn::fixed(name, ColumnData::Bytes(Vec::new()), 0)
128 }
129 }
130
131 pub fn bits(name: impl Into<String>, data: ColumnData, nbits: usize) -> WriteColumn {
135 WriteColumn {
136 bits: Some(nbits),
137 ..WriteColumn::fixed(name, data, nbits.div_ceil(8))
138 }
139 }
140
141 pub fn with_unit(mut self, unit: impl Into<String>) -> WriteColumn {
143 self.unit = Some(unit.into());
144 self
145 }
146
147 pub fn with_tdim(mut self, shape: Vec<usize>) -> WriteColumn {
149 self.tdim = Some(shape);
150 self
151 }
152
153 pub fn wide(mut self) -> WriteColumn {
155 self.wide = true;
156 self
157 }
158
159 pub fn scaled(mut self, tscale: f64, tzero: f64) -> WriteColumn {
162 self.tscale = Some(tscale);
163 self.tzero = Some(tzero);
164 self
165 }
166
167 pub fn with_null(mut self, tnull: i64) -> WriteColumn {
169 self.tnull = Some(tnull);
170 self
171 }
172}
173
/// One ASCII-table column handed to `FitsWriter::write_ascii_table`.
#[derive(Debug, Clone)]
pub struct AsciiWriteColumn {
    /// Column name, written as `TTYPEn`.
    pub name: String,
    /// Physical unit, written as `TUNITn` when present.
    pub unit: Option<String>,
    /// One value per row. Only `Text`, `I64` and `F64` are accepted; any other
    /// variant makes `write_ascii_table` return an error.
    pub data: ColumnData,
    /// Field width in characters. Values that do not fit are written as `*`s.
    pub width: usize,
    /// Digits after the decimal point for `F64` columns (`TFORMn = 'Fw.d'`).
    pub decimals: usize,
    /// Scale factor, written as `TSCALn` when present.
    pub tscale: Option<f64>,
    /// Zero offset, written as `TZEROn` when present.
    pub tzero: Option<f64>,
    /// Null string, written as `TNULLn`. It is also emitted in place of
    /// non-finite `F64` values (a blank field is used when this is `None`).
    pub tnull: Option<String>,
}
192
/// Writer that emits FITS HDUs (header followed by a block-padded data unit)
/// into the sink `W`.
#[derive(Debug)]
pub struct FitsWriter<W> {
    /// Destination for all rendered bytes.
    sink: W,
    /// Whether a primary HDU has already been emitted by this writer.
    has_primary: bool,
    /// Whether `DATASUM`/`CHECKSUM` keywords are added to each assembled HDU.
    checksum: bool,
    /// Reusable buffer that holds the data unit of the HDU being built, kept
    /// across calls to avoid reallocating.
    scratch: Vec<u8>,
}
205
206impl<W: Write> FitsWriter<W> {
207 pub fn new(sink: W) -> Self {
208 FitsWriter {
209 sink,
210 has_primary: false,
211 checksum: false,
212 scratch: Vec::new(),
213 }
214 }
215
216 pub fn with_checksums(mut self) -> Self {
220 self.checksum = true;
221 self
222 }
223
224 pub fn write_header(&mut self, header: &Header) -> Result<()> {
226 self.sink.write_all(&render_header(header))?;
227 Ok(())
228 }
229
230 pub fn write_data_unit(&mut self, raw: &[u8], fill: u8) -> Result<()> {
233 self.sink.write_all(raw)?;
234 let rem = raw.len() % BLOCK_SIZE;
235 if rem != 0 {
236 self.sink.write_all(&vec![fill; BLOCK_SIZE - rem])?;
237 }
238 Ok(())
239 }
240
241 pub fn write_image(&mut self, image: &Image) -> Result<()> {
246 let expected = shape_product(&image.shape);
247 assert_eq!(
248 image.samples.len(),
249 expected,
250 "image sample count must match the shape product"
251 );
252 let header = image_header(image, !self.has_primary);
253 self.has_primary = true;
254 self.scratch.clear();
255 image.samples.encode_into(&mut self.scratch);
256 self.write_hdu(header, ZERO_FILL)
257 }
258
259 pub fn write_table(&mut self, nrows: usize, columns: &[WriteColumn]) -> Result<()> {
264 self.ensure_primary()?;
265 let mut row_len = 0;
266 for col in columns {
267 row_len += check_column(col, nrows)?;
268 }
269 let mut heap: Vec<u8> = Vec::new();
274 let mut descs: Vec<(u64, u64)> = Vec::new();
275 for r in 0..nrows {
276 for col in columns {
277 if let Some(rows) = &col.vla {
278 let cell = &rows[r];
279 let (n, o) = (cell.element_count() as u64, heap.len() as u64);
280 if !col.wide && (n > u32::MAX as u64 || o > u32::MAX as u64) {
285 return Err(FitsError::DataUnitOverflow);
286 }
287 descs.push((n, o));
288 append_be(&mut heap, cell);
289 }
290 }
291 }
292 self.scratch.clear();
296 self.scratch.reserve(nrows * row_len + heap.len());
297 let mut descs = descs.into_iter();
298 for r in 0..nrows {
299 for col in columns {
300 if col.vla.is_some() {
301 let (n, o) = descs.next().expect("one descriptor per VLA cell");
302 push_pq_descriptor(&mut self.scratch, col.wide, n, o);
303 } else {
304 pack_cell(&mut self.scratch, col, r);
305 }
306 }
307 }
308 self.scratch.extend_from_slice(&heap);
309 let header = bintable_header(nrows, row_len, columns, heap.len());
310 self.write_hdu(header, ZERO_FILL)
311 }
312
313 pub fn write_ascii_table(&mut self, nrows: usize, columns: &[AsciiWriteColumn]) -> Result<()> {
317 self.ensure_primary()?;
318 let mut tbcols = Vec::with_capacity(columns.len());
319 let mut row_len = 0;
320 for col in columns {
321 let count = ascii_count(&col.data)?;
322 if count != nrows {
323 return Err(FitsError::RowWidthMismatch {
324 computed: count,
325 declared: nrows,
326 });
327 }
328 tbcols.push(row_len + 1); row_len += col.width;
330 }
331 let header = ascii_table_header(nrows, row_len, columns, &tbcols);
332 self.scratch.clear();
333 self.scratch.reserve(nrows * row_len);
334 for r in 0..nrows {
335 for col in columns {
336 format_ascii_field(&mut self.scratch, col, r);
337 }
338 }
339 self.write_hdu(header, SPACE_FILL)
340 }
341
342 #[cfg(feature = "compression")]
350 pub fn write_compressed_image(
351 &mut self,
352 image: &Image,
353 cmptype: &str,
354 options: &CompressOptions,
355 ) -> Result<()> {
356 self.ensure_primary()?;
357 let header = compress_image(image, cmptype, options, &mut self.scratch)?;
360 self.write_hdu(header, ZERO_FILL)
361 }
362
363 #[cfg(feature = "compression")]
368 pub fn write_compressed_table(
369 &mut self,
370 header: &Header,
371 table: &BinTable,
372 rows_per_tile: usize,
373 algo: &str,
374 ) -> Result<()> {
375 self.ensure_primary()?;
376 let zheader = compress_table(header, table, rows_per_tile, algo, &mut self.scratch)?;
377 self.write_hdu(zheader, ZERO_FILL)
378 }
379
380 fn ensure_primary(&mut self) -> Result<()> {
383 if !self.has_primary {
384 self.scratch.clear();
385 self.write_hdu(empty_primary_header(), ZERO_FILL)?;
386 self.has_primary = true;
387 }
388 Ok(())
389 }
390
391 fn write_hdu(&mut self, mut header: Header, fill: u8) -> Result<()> {
398 pad_to_block(&mut self.scratch, fill);
399 if self.checksum {
400 header.set(
401 "DATASUM",
402 checksum::accumulate(&self.scratch, 0).to_string(),
403 );
404 header.set("CHECKSUM", PLACEHOLDER_CHECKSUM);
405 }
406 let mut header_bytes = render_header(&header);
407 if self.checksum {
408 let hdu_sum =
411 checksum::accumulate(&self.scratch, checksum::accumulate(&header_bytes, 0));
412 patch_checksum(&mut header_bytes, &checksum::encode(hdu_sum, true));
413 }
414 self.sink.write_all(&header_bytes)?;
415 self.sink.write_all(&self.scratch)?;
416 Ok(())
417 }
418
419 pub fn into_inner(self) -> W {
424 self.sink
425 }
426}
427
428fn empty_primary_header() -> Header {
431 let mut header = Header::new();
432 header
433 .set("SIMPLE", true)
434 .comment("SIMPLE", "file conforms to FITS standard");
435 header.set("BITPIX", 8).set("NAXIS", 0);
436 header
437 .set("EXTEND", true)
438 .comment("EXTEND", "extensions follow");
439 header
440}
441
442fn image_header(image: &Image, primary: bool) -> Header {
446 let mut header = Header::new();
447 if primary {
448 header
449 .set("SIMPLE", true)
450 .comment("SIMPLE", "file conforms to FITS standard");
451 add_image_axes(&mut header, image);
452 header
453 .set("EXTEND", true)
454 .comment("EXTEND", "extensions may follow");
455 } else {
456 header
457 .set("XTENSION", "IMAGE")
458 .comment("XTENSION", "image extension");
459 add_image_axes(&mut header, image);
460 header.set("PCOUNT", 0).set("GCOUNT", 1);
461 }
462 add_scaling(&mut header, image);
463 header
464}
465
466fn add_image_axes(header: &mut Header, image: &Image) {
468 header
469 .set("BITPIX", image.samples.bitpix().code())
470 .comment("BITPIX", "number of bits per data pixel");
471 header
472 .set("NAXIS", image.shape.len() as i64)
473 .comment("NAXIS", "number of data axes");
474 for (i, &n) in image.shape.iter().enumerate() {
475 header.set(key!("NAXIS{}", i + 1).as_str(), n as i64);
476 }
477}
478
479fn add_scaling(header: &mut Header, image: &Image) {
482 if !image.scaling.is_identity() {
483 header.set("BZERO", image.scaling.bzero);
484 header.set("BSCALE", image.scaling.bscale);
485 }
486 if let Some(blank) = image.scaling.blank
488 && image.samples.bitpix().is_integer()
489 {
490 header.set("BLANK", blank);
491 }
492}
493
494fn bintable_header(
496 nrows: usize,
497 row_len: usize,
498 columns: &[WriteColumn],
499 heap_len: usize,
500) -> Header {
501 let mut header = Header::new();
502 header
503 .set("XTENSION", "BINTABLE")
504 .comment("XTENSION", "binary table extension");
505 header.set("BITPIX", 8).set("NAXIS", 2);
506 header
507 .set("NAXIS1", row_len as i64)
508 .comment("NAXIS1", "width of table in bytes");
509 header
510 .set("NAXIS2", nrows as i64)
511 .comment("NAXIS2", "number of rows");
512 header.set("PCOUNT", heap_len as i64).set("GCOUNT", 1);
513 header
514 .set("TFIELDS", columns.len() as i64)
515 .comment("TFIELDS", "number of columns");
516 for (i, col) in columns.iter().enumerate() {
517 let n = i + 1;
518 header.set(key!("TFORM{n}").as_str(), tform_of(col));
519 header.set(key!("TTYPE{n}").as_str(), col.name.as_str());
520 if let Some(unit) = &col.unit {
521 header.set(key!("TUNIT{n}").as_str(), unit.as_str());
522 }
523 if let Some(shape) = &col.tdim {
524 let dims: Vec<String> = shape.iter().map(|d| d.to_string()).collect();
525 header.set(key!("TDIM{n}").as_str(), format!("({})", dims.join(",")));
526 }
527 if let Some(tscale) = col.tscale {
528 header.set(key!("TSCAL{n}").as_str(), tscale);
529 }
530 if let Some(tzero) = col.tzero {
531 header.set(key!("TZERO{n}").as_str(), tzero);
532 }
533 if let Some(tnull) = col.tnull {
534 header.set(key!("TNULL{n}").as_str(), tnull);
535 }
536 }
537 header
538}
539
/// Binary-table type code and per-element byte size for one `ColumnData`
/// variant, as returned by `column_code`.
#[derive(Debug, Clone, Copy)]
struct ColumnCode {
    /// `TFORMn` data type letter (e.g. `'J'` for `I32`).
    letter: char,
    /// Size in bytes of one element in a packed row.
    elem_size: usize,
}
546
547fn column_code(data: &ColumnData) -> ColumnCode {
548 let (letter, elem_size) = match data {
549 ColumnData::Logical(_) => ('L', 1),
550 ColumnData::Bytes(_) => ('B', 1),
551 ColumnData::I16(_) => ('I', 2),
552 ColumnData::I32(_) => ('J', 4),
553 ColumnData::I64(_) => ('K', 8),
554 ColumnData::F32(_) => ('E', 4),
555 ColumnData::F64(_) => ('D', 8),
556 ColumnData::ComplexF32(_) => ('C', 8),
557 ColumnData::ComplexF64(_) => ('M', 16),
558 ColumnData::Text(_) => ('A', 1),
559 };
560 ColumnCode { letter, elem_size }
561}
562
563fn tform_of(col: &WriteColumn) -> String {
564 let code = column_code(&col.data).letter;
565 if let Some(nbits) = col.bits {
566 return format!("{nbits}X");
567 }
568 match &col.vla {
569 Some(rows) => {
571 let max = rows
572 .iter()
573 .map(ColumnData::element_count)
574 .max()
575 .unwrap_or(0);
576 let p = if col.wide { 'Q' } else { 'P' };
577 format!("1{p}{code}({max})")
578 }
579 None => format!("{}{}", col.repeat, code),
580 }
581}
582
583fn check_column(col: &WriteColumn, nrows: usize) -> Result<usize> {
585 let elem = column_code(&col.data).elem_size;
586 if let Some(rows) = &col.vla {
587 if rows.len() != nrows {
588 return Err(FitsError::RowWidthMismatch {
589 computed: rows.len(),
590 declared: nrows,
591 });
592 }
593 return Ok(if col.wide { 16 } else { 8 });
595 }
596 let mismatch = || FitsError::RowWidthMismatch {
597 computed: col.data.element_count(),
598 declared: nrows * col.repeat,
599 };
600 match &col.data {
601 ColumnData::Text(v) => {
602 if v.len() != nrows {
603 return Err(FitsError::RowWidthMismatch {
604 computed: v.len(),
605 declared: nrows,
606 });
607 }
608 Ok(col.repeat) }
610 _ => {
611 if col.data.element_count() != nrows * col.repeat {
612 return Err(mismatch());
613 }
614 Ok(col.repeat * elem)
615 }
616 }
617}
618
619fn append_be(out: &mut Vec<u8>, cell: &ColumnData) {
621 match cell {
622 ColumnData::Logical(v) => out.extend(v.iter().map(|&b| match b {
623 Some(true) => b'T',
624 Some(false) => b'F',
625 None => 0, })),
627 ColumnData::Bytes(v) => out.extend_from_slice(v),
628 ColumnData::I16(v) => extend_be(out, v, i16::to_be_bytes),
629 ColumnData::I32(v) => extend_be(out, v, i32::to_be_bytes),
630 ColumnData::I64(v) => extend_be(out, v, i64::to_be_bytes),
631 ColumnData::F32(v) => extend_be(out, v, f32::to_be_bytes),
632 ColumnData::F64(v) => extend_be(out, v, f64::to_be_bytes),
633 ColumnData::ComplexF32(v) => {
634 for &Complex { re, im } in v {
635 out.extend_from_slice(&re.to_be_bytes());
636 out.extend_from_slice(&im.to_be_bytes());
637 }
638 }
639 ColumnData::ComplexF64(v) => {
640 for &Complex { re, im } in v {
641 out.extend_from_slice(&re.to_be_bytes());
642 out.extend_from_slice(&im.to_be_bytes());
643 }
644 }
645 ColumnData::Text(v) => {
647 for s in v {
648 out.extend_from_slice(s.as_bytes());
649 }
650 }
651 }
652}
653
654fn pack_cell(out: &mut Vec<u8>, col: &WriteColumn, r: usize) {
655 let rep = col.repeat;
656 let base = r * rep;
657 match &col.data {
658 ColumnData::Logical(v) => {
659 for k in 0..rep {
660 out.push(match v[base + k] {
661 Some(true) => b'T',
662 Some(false) => b'F',
663 None => 0, });
665 }
666 }
667 ColumnData::Bytes(v) => out.extend_from_slice(&v[base..base + rep]),
668 ColumnData::I16(v) => extend_be(out, &v[base..base + rep], i16::to_be_bytes),
669 ColumnData::I32(v) => extend_be(out, &v[base..base + rep], i32::to_be_bytes),
670 ColumnData::I64(v) => extend_be(out, &v[base..base + rep], i64::to_be_bytes),
671 ColumnData::F32(v) => extend_be(out, &v[base..base + rep], f32::to_be_bytes),
672 ColumnData::F64(v) => extend_be(out, &v[base..base + rep], f64::to_be_bytes),
673 ColumnData::ComplexF32(v) => {
674 for &Complex { re, im } in &v[base..base + rep] {
675 out.extend_from_slice(&re.to_be_bytes());
676 out.extend_from_slice(&im.to_be_bytes());
677 }
678 }
679 ColumnData::ComplexF64(v) => {
680 for &Complex { re, im } in &v[base..base + rep] {
681 out.extend_from_slice(&re.to_be_bytes());
682 out.extend_from_slice(&im.to_be_bytes());
683 }
684 }
685 ColumnData::Text(v) => {
687 let bytes = v[r].as_bytes();
688 let n = bytes.len().min(rep);
689 out.extend_from_slice(&bytes[..n]);
690 out.extend(std::iter::repeat_n(b' ', rep - n));
691 }
692 }
693}
694
695fn patch_checksum(header_bytes: &mut [u8], encoded: &[u8; 16]) {
698 for card in header_bytes.chunks_exact_mut(CARD_SIZE) {
699 if &card[..8] == b"CHECKSUM" {
700 card[11..27].copy_from_slice(encoded);
701 return;
702 }
703 }
704}
705
706fn ascii_count(data: &ColumnData) -> Result<usize> {
708 match data {
709 ColumnData::Text(v) => Ok(v.len()),
710 ColumnData::I64(v) => Ok(v.len()),
711 ColumnData::F64(v) => Ok(v.len()),
712 _ => Err(FitsError::InvalidValue {
713 card: "ASCII table column must be Text, I64, or F64".to_string(),
714 }),
715 }
716}
717
718fn ascii_table_header(
720 nrows: usize,
721 row_len: usize,
722 columns: &[AsciiWriteColumn],
723 tbcols: &[usize],
724) -> Header {
725 let mut header = Header::new();
726 header
727 .set("XTENSION", "TABLE")
728 .comment("XTENSION", "ASCII table extension");
729 header.set("BITPIX", 8).set("NAXIS", 2);
730 header
731 .set("NAXIS1", row_len as i64)
732 .comment("NAXIS1", "width of table in characters");
733 header
734 .set("NAXIS2", nrows as i64)
735 .comment("NAXIS2", "number of rows");
736 header.set("PCOUNT", 0).set("GCOUNT", 1);
737 header
738 .set("TFIELDS", columns.len() as i64)
739 .comment("TFIELDS", "number of columns");
740 for (i, col) in columns.iter().enumerate() {
741 let n = i + 1;
742 header.set(key!("TBCOL{n}").as_str(), tbcols[i] as i64);
743 header.set(key!("TFORM{n}").as_str(), ascii_tform(col));
744 header.set(key!("TTYPE{n}").as_str(), col.name.as_str());
745 if let Some(unit) = &col.unit {
746 header.set(key!("TUNIT{n}").as_str(), unit.as_str());
747 }
748 if let Some(tscale) = col.tscale {
749 header.set(key!("TSCAL{n}").as_str(), tscale);
750 }
751 if let Some(tzero) = col.tzero {
752 header.set(key!("TZERO{n}").as_str(), tzero);
753 }
754 if let Some(tnull) = &col.tnull {
755 header.set(key!("TNULL{n}").as_str(), tnull.as_str());
756 }
757 }
758 header
759}
760
761fn ascii_tform(col: &AsciiWriteColumn) -> String {
762 match col.data {
763 ColumnData::Text(_) => format!("A{}", col.width),
764 ColumnData::I64(_) => format!("I{}", col.width),
765 ColumnData::F64(_) => format!("F{}.{}", col.width, col.decimals),
766 _ => format!("A{}", col.width), }
768}
769
770fn format_ascii_field(out: &mut Vec<u8>, col: &AsciiWriteColumn, r: usize) {
773 let (text, left) = match &col.data {
774 ColumnData::Text(v) => (v[r].clone(), true),
775 ColumnData::I64(v) => (v[r].to_string(), false),
776 ColumnData::F64(v) if !v[r].is_finite() => (col.tnull.clone().unwrap_or_default(), false),
779 ColumnData::F64(v) => (format!("{:.*}", col.decimals, v[r]), false),
780 _ => (String::new(), true),
781 };
782 let bytes = text.as_bytes();
783 if bytes.len() > col.width {
784 out.extend(std::iter::repeat_n(b'*', col.width));
785 return;
786 }
787 let pad = col.width - bytes.len();
788 if left {
789 out.extend_from_slice(bytes);
790 out.extend(std::iter::repeat_n(b' ', pad));
791 } else {
792 out.extend(std::iter::repeat_n(b' ', pad));
793 out.extend_from_slice(bytes);
794 }
795}
796
// Unit tests are compiled only under `cfg(test)` and live in a separate
// `tests` module file.
#[cfg(test)]
mod tests;