use std::io::Write;
use num_complex::Complex;
use crate::block::BLOCK_SIZE;
use crate::block::CARD_SIZE;
use crate::block::SPACE_FILL;
use crate::block::ZERO_FILL;
use crate::checksum;
#[cfg(feature = "compression")]
use crate::compress::{CompressOptions, compress_image, compress_table};
use crate::data::Image;
use crate::data::shape_product;
use crate::endian::extend_be;
use crate::endian::push_pq_descriptor;
use crate::error::FitsError;
use crate::error::Result;
use crate::header::Header;
use crate::keyword::key;
#[cfg(feature = "compression")]
use crate::table::BinTable;
use crate::table::ColumnData;
/// Value written to `CHECKSUM` (16 ASCII zeros) before the HDU checksum is
/// computed; the encoded checksum later overwrites it in place in the rendered header.
const PLACEHOLDER_CHECKSUM: &str = "0000000000000000";
/// Serializes `header` to bytes.
///
/// The output contains every card's record(s) in order, then an `END` card
/// padded with spaces to `CARD_SIZE`, then space padding up to the next
/// block boundary.
pub(crate) fn render_header(header: &Header) -> Vec<u8> {
    // Room for one record per card plus END; multi-record cards simply grow the Vec.
    let mut bytes = Vec::with_capacity(CARD_SIZE * (header.cards.len() + 1));
    for record in header.cards.iter().flat_map(|card| card.render_records()) {
        bytes.extend_from_slice(&record);
    }
    // END card: the keyword followed by blanks to fill the rest of the card.
    bytes.extend_from_slice(b"END");
    bytes.resize(bytes.len() + (CARD_SIZE - 3), SPACE_FILL);
    pad_to_block(&mut bytes, SPACE_FILL);
    bytes
}
/// Extends `buf` with `fill` bytes until its length is a multiple of
/// `BLOCK_SIZE`.
///
/// Does nothing when the length already is a multiple, including the empty
/// buffer.
fn pad_to_block(buf: &mut Vec<u8>, fill: u8) {
    let shortfall = (BLOCK_SIZE - buf.len() % BLOCK_SIZE) % BLOCK_SIZE;
    buf.resize(buf.len() + shortfall, fill);
}
/// Description of one binary-table column for [`FitsWriter::write_table`].
///
/// Create it with [`WriteColumn::fixed`], [`WriteColumn::vla`] or
/// [`WriteColumn::bits`]. Refine it with the `with_*`, `wide` and `scaled`
/// builder methods.
#[derive(Clone, Debug)]
pub struct WriteColumn {
    /// Column name, written as `TTYPEn`.
    pub name: String,
    /// Physical unit, written as `TUNITn` when present.
    pub unit: Option<String>,
    /// Cell data for fixed-width columns, stored one row after another.
    /// For VLA columns only its variant matters: it selects the `TFORMn` type letter.
    pub data: ColumnData,
    /// Elements per row (for `Text`, the field width in bytes); 0 for VLA columns.
    pub repeat: usize,
    /// One cell per row for variable-length-array columns; `None` otherwise.
    pub vla: Option<Vec<ColumnData>>,
    /// Cell shape, written as `TDIMn` when present.
    pub tdim: Option<Vec<usize>>,
    /// For VLA columns: use 16-byte `Q` descriptors instead of 8-byte `P` ones.
    pub wide: bool,
    /// For bit columns: bits per row, written as `TFORMn = '<nbits>X'`.
    pub bits: Option<usize>,
    /// Written as `TSCALn` when present.
    pub tscale: Option<f64>,
    /// Written as `TZEROn` when present.
    pub tzero: Option<f64>,
    /// Written as `TNULLn` when present.
    pub tnull: Option<i64>,
}
impl WriteColumn {
    /// Builds a fixed-width column whose rows each take `repeat` consecutive
    /// elements of `data`.
    ///
    /// For `Text` data, `repeat` is instead the field width in bytes and
    /// `data` holds one string per row.
    pub fn fixed(name: impl Into<String>, data: ColumnData, repeat: usize) -> WriteColumn {
        WriteColumn {
            name: name.into(),
            data,
            repeat,
            unit: None,
            vla: None,
            tdim: None,
            wide: false,
            bits: None,
            tscale: None,
            tzero: None,
            tnull: None,
        }
    }

    /// Builds a variable-length-array column with one cell per row.
    ///
    /// The element type is taken from the first cell. With no cells it falls
    /// back to bytes.
    ///
    /// # Panics
    /// Panics if the cells are not all the same `ColumnData` variant.
    pub fn vla(name: impl Into<String>, rows: Vec<ColumnData>) -> WriteColumn {
        let tag = match rows.first() {
            Some(first) => first.clone(),
            None => ColumnData::Bytes(Vec::new()),
        };
        let kind = std::mem::discriminant(&tag);
        let uniform = rows.iter().all(|cell| std::mem::discriminant(cell) == kind);
        assert!(
            uniform,
            "VLA column cells must all be the same ColumnData variant"
        );
        let mut column = WriteColumn::fixed(name, tag, 0);
        column.vla = Some(rows);
        column
    }

    /// Builds a bit column (`TFORMn = '<nbits>X'`).
    ///
    /// Each row occupies `nbits.div_ceil(8)` elements of `data`. Presumably
    /// `data` is `ColumnData::Bytes` holding the already-packed bits.
    pub fn bits(name: impl Into<String>, data: ColumnData, nbits: usize) -> WriteColumn {
        let mut column = WriteColumn::fixed(name, data, nbits.div_ceil(8));
        column.bits = Some(nbits);
        column
    }

    /// Sets the physical unit (`TUNITn`).
    pub fn with_unit(self, unit: impl Into<String>) -> WriteColumn {
        WriteColumn {
            unit: Some(unit.into()),
            ..self
        }
    }

    /// Sets the cell shape (`TDIMn`).
    pub fn with_tdim(self, shape: Vec<usize>) -> WriteColumn {
        WriteColumn {
            tdim: Some(shape),
            ..self
        }
    }

    /// Switches a VLA column to 64-bit `Q` descriptors.
    pub fn wide(self) -> WriteColumn {
        WriteColumn { wide: true, ..self }
    }

    /// Sets linear scaling (`TSCALn`, `TZEROn`).
    pub fn scaled(self, tscale: f64, tzero: f64) -> WriteColumn {
        WriteColumn {
            tscale: Some(tscale),
            tzero: Some(tzero),
            ..self
        }
    }

    /// Sets the integer null value (`TNULLn`).
    pub fn with_null(self, tnull: i64) -> WriteColumn {
        WriteColumn {
            tnull: Some(tnull),
            ..self
        }
    }
}
/// Description of one ASCII-table column for [`FitsWriter::write_ascii_table`].
#[derive(Clone, Debug)]
pub struct AsciiWriteColumn {
    /// Column name, written as `TTYPEn`.
    pub name: String,
    /// Physical unit, written as `TUNITn` when present.
    pub unit: Option<String>,
    /// One value per row; must be `Text`, `I64`, or `F64`.
    pub data: ColumnData,
    /// Field width in characters.
    pub width: usize,
    /// Digits after the decimal point for `F64` columns (`TFORMn = 'Fw.d'`).
    pub decimals: usize,
    /// Written as `TSCALn` when present.
    pub tscale: Option<f64>,
    /// Written as `TZEROn` when present.
    pub tzero: Option<f64>,
    /// Null marker, written as `TNULLn`.
    /// Also used as the field text for non-finite `F64` values.
    pub tnull: Option<String>,
}
/// Sequential FITS writer over any byte sink.
///
/// HDUs are emitted in call order. The first high-level `write_*` call decides
/// the primary HDU: an image becomes the primary array, while tables and
/// compressed HDUs get an empty primary HDU written in front of them.
#[derive(Debug)]
pub struct FitsWriter<W> {
    /// Destination of every byte written.
    sink: W,
    /// Whether a high-level method has already written a primary HDU.
    has_primary: bool,
    /// Whether `DATASUM`/`CHECKSUM` keywords are added to each HDU.
    checksum: bool,
    /// Reusable buffer holding the data unit of the HDU being written.
    scratch: Vec<u8>,
}
impl<W: Write> FitsWriter<W> {
    /// Creates a writer that emits FITS HDUs to `sink`.
    ///
    /// No primary HDU has been written yet, and checksums are off (see
    /// [`FitsWriter::with_checksums`]).
    pub fn new(sink: W) -> Self {
        FitsWriter {
            sink,
            has_primary: false,
            checksum: false,
            scratch: Vec::new(),
        }
    }

    /// Turns on `DATASUM`/`CHECKSUM` generation for every HDU written through
    /// the high-level `write_*` methods, including an implicit empty primary HDU.
    ///
    /// The raw [`FitsWriter::write_header`] and [`FitsWriter::write_data_unit`]
    /// calls are not affected.
    pub fn with_checksums(mut self) -> Self {
        self.checksum = true;
        self
    }

    /// Writes `header` followed by an `END` card, space-padded to a block boundary.
    ///
    /// This is a low-level escape hatch. No checksum keywords are added, and the
    /// writer does not record that a primary HDU was written, so a later
    /// high-level table call will still emit its own empty primary HDU.
    ///
    /// # Errors
    /// Returns any I/O error from the sink.
    pub fn write_header(&mut self, header: &Header) -> Result<()> {
        self.sink.write_all(&render_header(header))?;
        Ok(())
    }

    /// Writes `raw` as a data unit, padding it with `fill` to a block boundary.
    ///
    /// The high-level writers use `ZERO_FILL` for images and binary tables and
    /// `SPACE_FILL` for ASCII tables.
    ///
    /// # Errors
    /// Returns any I/O error from the sink.
    pub fn write_data_unit(&mut self, raw: &[u8], fill: u8) -> Result<()> {
        self.sink.write_all(raw)?;
        let rem = raw.len() % BLOCK_SIZE;
        if rem != 0 {
            // At most one block of padding is needed, so a stack buffer avoids a
            // heap allocation per call.
            let padding = [fill; BLOCK_SIZE];
            self.sink.write_all(&padding[..BLOCK_SIZE - rem])?;
        }
        Ok(())
    }

    /// Writes `image` as an HDU.
    ///
    /// It becomes the primary HDU if none has been written yet; otherwise it is
    /// written as an `IMAGE` extension.
    ///
    /// # Panics
    /// Panics if the number of samples differs from the product of `image.shape`.
    ///
    /// # Errors
    /// Returns any I/O error from the sink.
    pub fn write_image(&mut self, image: &Image) -> Result<()> {
        let expected = shape_product(&image.shape);
        assert_eq!(
            image.samples.len(),
            expected,
            "image sample count must match the shape product"
        );
        let header = image_header(image, !self.has_primary);
        self.has_primary = true;
        self.scratch.clear();
        image.samples.encode_into(&mut self.scratch);
        self.write_hdu(header, ZERO_FILL)
    }

    /// Writes a binary-table (`BINTABLE`) extension with `nrows` rows.
    ///
    /// Fixed-width columns are packed row by row, in `columns` order. Each
    /// variable-length (VLA) cell is appended to a heap stored after the table
    /// rows. Its row holds a `P` descriptor (or `Q` when the column is `wide`)
    /// pointing into that heap. An empty primary HDU is written first if none
    /// exists yet; this happens before the columns are validated.
    ///
    /// # Errors
    /// * [`FitsError::RowWidthMismatch`] if a column's data does not match `nrows`.
    /// * [`FitsError::DataUnitOverflow`] if a non-`wide` VLA cell's element count
    ///   or heap offset exceeds `u32::MAX`.
    /// * Any I/O error from the sink.
    pub fn write_table(&mut self, nrows: usize, columns: &[WriteColumn]) -> Result<()> {
        self.ensure_primary()?;
        let mut row_len = 0;
        for col in columns {
            row_len += check_column(col, nrows)?;
        }
        // Pass 1: lay out the heap and remember each VLA cell's (count, offset),
        // in the same row-major order the rows are packed below.
        let mut heap: Vec<u8> = Vec::new();
        let mut descs: Vec<(u64, u64)> = Vec::new();
        for r in 0..nrows {
            for col in columns {
                if let Some(rows) = &col.vla {
                    let cell = &rows[r];
                    let (n, o) = (cell.element_count() as u64, heap.len() as u64);
                    // `P` descriptors only hold 32-bit values.
                    if !col.wide && (n > u32::MAX as u64 || o > u32::MAX as u64) {
                        return Err(FitsError::DataUnitOverflow);
                    }
                    descs.push((n, o));
                    append_be(&mut heap, cell);
                }
            }
        }
        // Pass 2: pack the fixed-width rows, then append the heap.
        self.scratch.clear();
        self.scratch.reserve(nrows * row_len + heap.len());
        let mut descs = descs.into_iter();
        for r in 0..nrows {
            for col in columns {
                if col.vla.is_some() {
                    let (n, o) = descs.next().expect("one descriptor per VLA cell");
                    push_pq_descriptor(&mut self.scratch, col.wide, n, o);
                } else {
                    pack_cell(&mut self.scratch, col, r);
                }
            }
        }
        self.scratch.extend_from_slice(&heap);
        let header = bintable_header(nrows, row_len, columns, heap.len());
        self.write_hdu(header, ZERO_FILL)
    }

    /// Writes an ASCII-table (`TABLE`) extension with `nrows` rows.
    ///
    /// Columns are laid out left to right, each `width` characters wide. The
    /// data unit is padded with spaces. An empty primary HDU is written first
    /// if none exists yet.
    ///
    /// # Errors
    /// * [`FitsError::InvalidValue`] if a column's data is not `Text`, `I64`, or `F64`.
    /// * [`FitsError::RowWidthMismatch`] if a column's length differs from `nrows`.
    /// * Any I/O error from the sink.
    pub fn write_ascii_table(&mut self, nrows: usize, columns: &[AsciiWriteColumn]) -> Result<()> {
        self.ensure_primary()?;
        let mut tbcols = Vec::with_capacity(columns.len());
        let mut row_len = 0;
        for col in columns {
            let count = ascii_count(&col.data)?;
            if count != nrows {
                return Err(FitsError::RowWidthMismatch {
                    computed: count,
                    declared: nrows,
                });
            }
            // TBCOLn is 1-based.
            tbcols.push(row_len + 1);
            row_len += col.width;
        }
        let header = ascii_table_header(nrows, row_len, columns, &tbcols);
        self.scratch.clear();
        self.scratch.reserve(nrows * row_len);
        for r in 0..nrows {
            for col in columns {
                format_ascii_field(&mut self.scratch, col, r);
            }
        }
        self.write_hdu(header, SPACE_FILL)
    }

    /// Compresses `image` with `cmptype` and writes it as a tile-compressed HDU.
    ///
    /// An empty primary HDU is written first if none exists yet.
    ///
    /// # Errors
    /// Returns any error from compression or from the sink.
    #[cfg(feature = "compression")]
    pub fn write_compressed_image(
        &mut self,
        image: &Image,
        cmptype: &str,
        options: &CompressOptions,
    ) -> Result<()> {
        self.ensure_primary()?;
        // After a previous HDU, `scratch` still holds that HDU's padded data.
        // Start empty so the compressed output is never prefixed by stale bytes,
        // matching every other writer.
        self.scratch.clear();
        let header = compress_image(image, cmptype, options, &mut self.scratch)?;
        self.write_hdu(header, ZERO_FILL)
    }

    /// Compresses `table` (described by `header`) in tiles of `rows_per_tile`
    /// rows using `algo`, and writes the result as a compressed-table HDU.
    ///
    /// An empty primary HDU is written first if none exists yet.
    ///
    /// # Errors
    /// Returns any error from compression or from the sink.
    #[cfg(feature = "compression")]
    pub fn write_compressed_table(
        &mut self,
        header: &Header,
        table: &BinTable,
        rows_per_tile: usize,
        algo: &str,
    ) -> Result<()> {
        self.ensure_primary()?;
        // See `write_compressed_image`: never let stale HDU bytes leak in.
        self.scratch.clear();
        let zheader = compress_table(header, table, rows_per_tile, algo, &mut self.scratch)?;
        self.write_hdu(zheader, ZERO_FILL)
    }

    /// Writes an empty (`NAXIS = 0`) primary HDU unless one was already written.
    fn ensure_primary(&mut self) -> Result<()> {
        if !self.has_primary {
            self.scratch.clear();
            self.write_hdu(empty_primary_header(), ZERO_FILL)?;
            self.has_primary = true;
        }
        Ok(())
    }

    /// Writes `header` followed by the data unit currently in `scratch`.
    ///
    /// The data unit is first padded with `fill` to a block boundary. When
    /// checksums are enabled, the following happens in order:
    /// 1. `DATASUM` is set from the padded data.
    /// 2. `CHECKSUM` is set to a placeholder.
    /// 3. The header is rendered.
    /// 4. The placeholder is replaced with the encoded checksum of header + data.
    fn write_hdu(&mut self, mut header: Header, fill: u8) -> Result<()> {
        pad_to_block(&mut self.scratch, fill);
        if self.checksum {
            header.set(
                "DATASUM",
                checksum::accumulate(&self.scratch, 0).to_string(),
            );
            header.set("CHECKSUM", PLACEHOLDER_CHECKSUM);
        }
        let mut header_bytes = render_header(&header);
        if self.checksum {
            let hdu_sum =
                checksum::accumulate(&self.scratch, checksum::accumulate(&header_bytes, 0));
            patch_checksum(&mut header_bytes, &checksum::encode(hdu_sum, true));
        }
        self.sink.write_all(&header_bytes)?;
        self.sink.write_all(&self.scratch)?;
        Ok(())
    }

    /// Consumes the writer and returns the sink.
    ///
    /// No flush is performed here; flush the returned sink if it buffers.
    pub fn into_inner(self) -> W {
        self.sink
    }
}
/// Header for a data-less primary HDU (`NAXIS = 0`) that announces extensions.
fn empty_primary_header() -> Header {
    let mut header = Header::new();
    header.set("SIMPLE", true).comment("SIMPLE", "file conforms to FITS standard");
    header.set("BITPIX", 8);
    header.set("NAXIS", 0);
    header.set("EXTEND", true).comment("EXTEND", "extensions follow");
    header
}
/// Builds the header for an image HDU.
///
/// When `primary` is true this is a primary HDU (`SIMPLE`/`EXTEND`);
/// otherwise it is an `IMAGE` extension (`XTENSION`/`PCOUNT`/`GCOUNT`). Axis
/// keywords, and any scaling keywords, come from `image` in both cases.
fn image_header(image: &Image, primary: bool) -> Header {
    let mut header = Header::new();
    match primary {
        true => {
            header.set("SIMPLE", true).comment("SIMPLE", "file conforms to FITS standard");
            add_image_axes(&mut header, image);
            header.set("EXTEND", true).comment("EXTEND", "extensions may follow");
        }
        false => {
            header.set("XTENSION", "IMAGE").comment("XTENSION", "image extension");
            add_image_axes(&mut header, image);
            header.set("PCOUNT", 0).set("GCOUNT", 1);
        }
    }
    add_scaling(&mut header, image);
    header
}
/// Adds `BITPIX`, `NAXIS` and one `NAXISn` card per axis of `image.shape`.
fn add_image_axes(header: &mut Header, image: &Image) {
    let bitpix = image.samples.bitpix().code();
    header.set("BITPIX", bitpix).comment("BITPIX", "number of bits per data pixel");
    let naxis = image.shape.len() as i64;
    header.set("NAXIS", naxis).comment("NAXIS", "number of data axes");
    for (axis, &len) in (1usize..).zip(image.shape.iter()) {
        header.set(key!("NAXIS{}", axis).as_str(), len as i64);
    }
}
fn add_scaling(header: &mut Header, image: &Image) {
if !image.scaling.is_identity() {
header.set("BZERO", image.scaling.bzero);
header.set("BSCALE", image.scaling.bscale);
}
if let Some(blank) = image.scaling.blank
&& image.samples.bitpix().is_integer()
{
header.set("BLANK", blank);
}
}
/// Builds the `BINTABLE` extension header.
///
/// `row_len` is the row width in bytes (`NAXIS1`), `nrows` is the row count
/// (`NAXIS2`) and `heap_len` is the heap size (`PCOUNT`). Per-column keywords
/// are numbered from 1 in `columns` order.
fn bintable_header(
    nrows: usize,
    row_len: usize,
    columns: &[WriteColumn],
    heap_len: usize,
) -> Header {
    let mut header = Header::new();
    header.set("XTENSION", "BINTABLE").comment("XTENSION", "binary table extension");
    header.set("BITPIX", 8).set("NAXIS", 2);
    header.set("NAXIS1", row_len as i64).comment("NAXIS1", "width of table in bytes");
    header.set("NAXIS2", nrows as i64).comment("NAXIS2", "number of rows");
    header.set("PCOUNT", heap_len as i64).set("GCOUNT", 1);
    header.set("TFIELDS", columns.len() as i64).comment("TFIELDS", "number of columns");

    for (n, col) in (1usize..).zip(columns) {
        header.set(key!("TFORM{n}").as_str(), tform_of(col));
        header.set(key!("TTYPE{n}").as_str(), col.name.as_str());
        if let Some(unit) = col.unit.as_deref() {
            header.set(key!("TUNIT{n}").as_str(), unit);
        }
        if let Some(shape) = &col.tdim {
            let dims = shape
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join(",");
            header.set(key!("TDIM{n}").as_str(), format!("({dims})"));
        }
        if let Some(tscale) = col.tscale {
            header.set(key!("TSCAL{n}").as_str(), tscale);
        }
        if let Some(tzero) = col.tzero {
            header.set(key!("TZERO{n}").as_str(), tzero);
        }
        if let Some(tnull) = col.tnull {
            header.set(key!("TNULL{n}").as_str(), tnull);
        }
    }
    header
}
/// Binary-table type information for a column: the `TFORM` letter and the
/// byte size of one element.
#[derive(Clone, Copy, Debug)]
struct ColumnCode {
    /// `TFORM` data-type letter (e.g. `J`, `D`, `A`).
    letter: char,
    /// Bytes occupied by one element of this type in a table row.
    elem_size: usize,
}
/// Maps a column's data variant to its binary-table `TFORM` letter and the
/// size in bytes of one element.
fn column_code(data: &ColumnData) -> ColumnCode {
    let code = |letter, elem_size| ColumnCode { letter, elem_size };
    match data {
        ColumnData::Logical(_) => code('L', 1),
        ColumnData::Bytes(_) => code('B', 1),
        ColumnData::I16(_) => code('I', 2),
        ColumnData::I32(_) => code('J', 4),
        ColumnData::I64(_) => code('K', 8),
        ColumnData::F32(_) => code('E', 4),
        ColumnData::F64(_) => code('D', 8),
        ColumnData::ComplexF32(_) => code('C', 8),
        ColumnData::ComplexF64(_) => code('M', 16),
        ColumnData::Text(_) => code('A', 1),
    }
}
/// Builds the binary-table `TFORMn` value for `col`.
///
/// * Bit columns: `<nbits>X`.
/// * VLA columns: `1P<t>(<max>)`, or `1Q…` when `wide`, where `<max>` is the
///   largest cell's element count.
/// * Otherwise: `<repeat><t>`.
fn tform_of(col: &WriteColumn) -> String {
    if let Some(nbits) = col.bits {
        return format!("{nbits}X");
    }
    let letter = column_code(&col.data).letter;
    let Some(cells) = &col.vla else {
        return format!("{}{}", col.repeat, letter);
    };
    let longest = cells
        .iter()
        .map(ColumnData::element_count)
        .max()
        .unwrap_or(0);
    let descriptor = if col.wide { 'Q' } else { 'P' };
    format!("1{descriptor}{letter}({longest})")
}
/// Validates `col` against `nrows` and returns the bytes it contributes to
/// each table row.
///
/// * VLA columns need one cell per row and occupy one descriptor
///   (16 bytes if `wide`, else 8).
/// * `Text` columns need one string per row and occupy `repeat` bytes.
/// * Other columns need exactly `nrows * repeat` elements.
///
/// # Errors
/// [`FitsError::RowWidthMismatch`] when the counts disagree.
fn check_column(col: &WriteColumn, nrows: usize) -> Result<usize> {
    if let Some(cells) = &col.vla {
        return if cells.len() == nrows {
            Ok(if col.wide { 16 } else { 8 })
        } else {
            Err(FitsError::RowWidthMismatch {
                computed: cells.len(),
                declared: nrows,
            })
        };
    }
    if let ColumnData::Text(strings) = &col.data {
        if strings.len() != nrows {
            return Err(FitsError::RowWidthMismatch {
                computed: strings.len(),
                declared: nrows,
            });
        }
        return Ok(col.repeat);
    }
    let expected = nrows * col.repeat;
    let actual = col.data.element_count();
    if actual != expected {
        return Err(FitsError::RowWidthMismatch {
            computed: actual,
            declared: expected,
        });
    }
    Ok(col.repeat * column_code(&col.data).elem_size)
}
/// Appends every element of `cell` to `out` in big-endian FITS encoding.
///
/// Logicals become `T`/`F`, or 0 for undefined. Complex values are written
/// as real part then imaginary part. Text is the concatenation of the
/// strings' bytes.
fn append_be(out: &mut Vec<u8>, cell: &ColumnData) {
    match cell {
        ColumnData::Logical(flags) => {
            for flag in flags {
                out.push(match flag {
                    Some(true) => b'T',
                    Some(false) => b'F',
                    None => 0,
                });
            }
        }
        ColumnData::Bytes(bytes) => out.extend_from_slice(bytes),
        ColumnData::I16(xs) => extend_be(out, xs, i16::to_be_bytes),
        ColumnData::I32(xs) => extend_be(out, xs, i32::to_be_bytes),
        ColumnData::I64(xs) => extend_be(out, xs, i64::to_be_bytes),
        ColumnData::F32(xs) => extend_be(out, xs, f32::to_be_bytes),
        ColumnData::F64(xs) => extend_be(out, xs, f64::to_be_bytes),
        ColumnData::ComplexF32(zs) => {
            for z in zs {
                out.extend_from_slice(&z.re.to_be_bytes());
                out.extend_from_slice(&z.im.to_be_bytes());
            }
        }
        ColumnData::ComplexF64(zs) => {
            for z in zs {
                out.extend_from_slice(&z.re.to_be_bytes());
                out.extend_from_slice(&z.im.to_be_bytes());
            }
        }
        ColumnData::Text(strings) => out.extend(strings.iter().flat_map(|s| s.as_bytes())),
    }
}
/// Appends row `r` of the fixed-width column `col` to `out` in big-endian FITS
/// encoding.
///
/// The row covers elements `r * repeat .. (r + 1) * repeat` of `col.data`.
/// For `Text` columns, the row's string is truncated or space-padded to
/// exactly `repeat` bytes.
fn pack_cell(out: &mut Vec<u8>, col: &WriteColumn, r: usize) {
    let rep = col.repeat;
    let start = r * rep;
    let span = start..start + rep;
    match &col.data {
        ColumnData::Logical(flags) => out.extend(flags[span].iter().map(|flag| match flag {
            Some(true) => b'T',
            Some(false) => b'F',
            None => 0,
        })),
        ColumnData::Bytes(bytes) => out.extend_from_slice(&bytes[span]),
        ColumnData::I16(xs) => extend_be(out, &xs[span], i16::to_be_bytes),
        ColumnData::I32(xs) => extend_be(out, &xs[span], i32::to_be_bytes),
        ColumnData::I64(xs) => extend_be(out, &xs[span], i64::to_be_bytes),
        ColumnData::F32(xs) => extend_be(out, &xs[span], f32::to_be_bytes),
        ColumnData::F64(xs) => extend_be(out, &xs[span], f64::to_be_bytes),
        ColumnData::ComplexF32(zs) => {
            for z in &zs[span] {
                out.extend_from_slice(&z.re.to_be_bytes());
                out.extend_from_slice(&z.im.to_be_bytes());
            }
        }
        ColumnData::ComplexF64(zs) => {
            for z in &zs[span] {
                out.extend_from_slice(&z.re.to_be_bytes());
                out.extend_from_slice(&z.im.to_be_bytes());
            }
        }
        ColumnData::Text(strings) => {
            let bytes = strings[r].as_bytes();
            let kept = &bytes[..bytes.len().min(rep)];
            out.extend_from_slice(kept);
            out.resize(out.len() + (rep - kept.len()), b' ');
        }
    }
}
/// Overwrites the value of the first `CHECKSUM` card in `header_bytes` with
/// `encoded`.
///
/// Bytes 11..27 of that card are replaced. This assumes the rendered card
/// stores the 16-character string value right after an opening quote at
/// byte 10 (FITS fixed format). Does nothing if no `CHECKSUM` card exists.
fn patch_checksum(header_bytes: &mut [u8], encoded: &[u8; 16]) {
    let target = header_bytes
        .chunks_exact_mut(CARD_SIZE)
        .find(|card| card.starts_with(b"CHECKSUM"));
    if let Some(card) = target {
        card[11..27].copy_from_slice(encoded);
    }
}
/// Returns the number of rows held by an ASCII-table column's data.
///
/// # Errors
/// [`FitsError::InvalidValue`] if the data is not `Text`, `I64`, or `F64`,
/// the only types this writer formats into ASCII tables.
fn ascii_count(data: &ColumnData) -> Result<usize> {
    let rows = match data {
        ColumnData::Text(strings) => strings.len(),
        ColumnData::I64(ints) => ints.len(),
        ColumnData::F64(floats) => floats.len(),
        _ => {
            return Err(FitsError::InvalidValue {
                card: "ASCII table column must be Text, I64, or F64".to_string(),
            });
        }
    };
    Ok(rows)
}
/// Builds the ASCII `TABLE` extension header.
///
/// `row_len` is the row width in characters (`NAXIS1`) and `nrows` is the
/// row count (`NAXIS2`). `tbcols[i]` is the 1-based starting column of
/// `columns[i]` (`TBCOLn`).
fn ascii_table_header(
    nrows: usize,
    row_len: usize,
    columns: &[AsciiWriteColumn],
    tbcols: &[usize],
) -> Header {
    let mut header = Header::new();
    header.set("XTENSION", "TABLE").comment("XTENSION", "ASCII table extension");
    header.set("BITPIX", 8).set("NAXIS", 2);
    header.set("NAXIS1", row_len as i64).comment("NAXIS1", "width of table in characters");
    header.set("NAXIS2", nrows as i64).comment("NAXIS2", "number of rows");
    header.set("PCOUNT", 0).set("GCOUNT", 1);
    header.set("TFIELDS", columns.len() as i64).comment("TFIELDS", "number of columns");

    for (idx, col) in columns.iter().enumerate() {
        let n = idx + 1;
        let start = tbcols[idx] as i64;
        header.set(key!("TBCOL{n}").as_str(), start);
        header.set(key!("TFORM{n}").as_str(), ascii_tform(col));
        header.set(key!("TTYPE{n}").as_str(), col.name.as_str());
        if let Some(unit) = col.unit.as_deref() {
            header.set(key!("TUNIT{n}").as_str(), unit);
        }
        if let Some(tscale) = col.tscale {
            header.set(key!("TSCAL{n}").as_str(), tscale);
        }
        if let Some(tzero) = col.tzero {
            header.set(key!("TZERO{n}").as_str(), tzero);
        }
        if let Some(tnull) = col.tnull.as_deref() {
            header.set(key!("TNULL{n}").as_str(), tnull);
        }
    }
    header
}
/// Builds the ASCII-table `TFORMn` value.
///
/// Integers use `Iw`, floats use `Fw.d`, and text (or any other variant)
/// uses `Aw`.
fn ascii_tform(col: &AsciiWriteColumn) -> String {
    let w = col.width;
    match &col.data {
        ColumnData::I64(_) => format!("I{w}"),
        ColumnData::F64(_) => format!("F{w}.{}", col.decimals),
        _ => format!("A{w}"),
    }
}
/// Appends row `r` of `col` to `out` as one ASCII-table field of exactly
/// `col.width` bytes.
///
/// * `Text` is left-justified and space-padded. A string longer than the field
///   is truncated to its first `col.width` bytes, the same rule binary-table
///   `A` cells follow in `pack_cell`.
/// * `I64` and `F64` are right-justified. `F64` uses `col.decimals` digits
///   after the decimal point. Non-finite `F64` values are written as the
///   column's `TNULL` string, or left blank when there is none.
/// * A number (or `TNULL` string) too wide for the field is replaced by `*`s.
/// * Any other data variant produces a blank field.
///
/// # Panics
/// Panics if `r` is out of range for the column data.
fn format_ascii_field(out: &mut Vec<u8>, col: &AsciiWriteColumn, r: usize) {
    use std::borrow::Cow;

    // Borrow text and TNULL values instead of cloning them per field.
    let (text, is_text): (Cow<'_, str>, bool) = match &col.data {
        ColumnData::Text(v) => (Cow::Borrowed(v[r].as_str()), true),
        ColumnData::I64(v) => (Cow::Owned(v[r].to_string()), false),
        ColumnData::F64(v) if !v[r].is_finite() => {
            (Cow::Borrowed(col.tnull.as_deref().unwrap_or("")), false)
        }
        ColumnData::F64(v) => (Cow::Owned(format!("{:.*}", col.decimals, v[r])), false),
        _ => (Cow::Borrowed(""), true),
    };
    let bytes = text.as_bytes();
    let width = col.width;
    if bytes.len() > width {
        if is_text {
            // Keep the leading characters rather than discarding the whole string.
            out.extend_from_slice(&bytes[..width]);
        } else {
            // Numeric overflow: Fortran-style asterisk fill.
            out.extend(std::iter::repeat_n(b'*', width));
        }
        return;
    }
    let pad = std::iter::repeat_n(b' ', width - bytes.len());
    if is_text {
        out.extend_from_slice(bytes);
        out.extend(pad);
    } else {
        out.extend(pad);
        out.extend_from_slice(bytes);
    }
}
// Unit tests for this module live in a separate `tests` submodule file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;