use std::io::Write;
use std::path::Path;
use thiserror::Error;
use crate::daf::{DOUBLE_BYTES, RECORD_BYTES};
/// Number of double-precision components in an SPK segment summary
/// (descriptor start and end times).
pub const SPK_ND: u32 = 2;
/// Number of integer components in an SPK segment summary (target, center,
/// frame, data type, start address, end address).
pub const SPK_NI: u32 = 6;
/// Size of one packed SPK summary in doubles: the `SPK_ND` doubles followed by
/// the `SPK_NI` integers packed two per double.
pub const SPK_SUMMARY_DOUBLES: usize = SPK_ND as usize + (SPK_NI as usize).div_ceil(2);
/// Size of one packed SPK summary in bytes (also the width of one name slot).
pub const SPK_SUMMARY_BYTES: usize = SPK_SUMMARY_DOUBLES * DOUBLE_BYTES;
/// Bytes at the start of a summary record taken by its three control doubles.
const SUMMARY_HEADER_BYTES: usize = 24;
/// Maximum number of segment summaries one summary record can hold.
pub const SPK_SUMMARIES_PER_RECORD: usize =
    (RECORD_BYTES - SUMMARY_HEADER_BYTES) / SPK_SUMMARY_BYTES;
/// Number of doubles stored in one DAF record.
pub const DOUBLES_PER_RECORD: usize = RECORD_BYTES / DOUBLE_BYTES;
/// Errors produced while assembling or writing an SPK file.
#[derive(Debug, Error)]
pub enum SpkWriterError {
    /// Underlying I/O failure.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The segment identifier does not fit in a 40-byte DAF name slot.
    #[error("segment id '{0}' is longer than 40 bytes after UTF-8 encoding")]
    SegmentIdTooLong(String),

    /// More segments were queued than a single summary record can describe.
    #[error("too many segments ({got}); single summary record holds at most {max}")]
    TooManySegments {
        /// Number of segments queued on the writer.
        got: usize,
        /// Capacity of one summary record.
        max: usize,
    },

    /// A Type 3 segment failed validation; the payload names the problem.
    #[error("invalid Type 3 segment: {0}")]
    BadType3(&'static str),

    /// A Type 9 segment failed validation; the payload names the problem.
    #[error("invalid Type 9 segment: {0}")]
    BadType9(&'static str),
}
/// One Chebyshev record of an SPK Type 3 segment.
///
/// All six coefficient vectors must have the same length (polynomial degree
/// plus one). That length must also match across every record of the segment.
#[derive(Debug, Clone)]
pub struct Type3Record {
    /// Midpoint of the time interval covered by this record.
    pub mid: f64,
    /// Half-length of the time interval covered by this record.
    pub radius: f64,
    /// Coefficients for the X position component.
    pub x: Vec<f64>,
    /// Coefficients for the Y position component.
    pub y: Vec<f64>,
    /// Coefficients for the Z position component.
    pub z: Vec<f64>,
    /// Coefficients for the X velocity component.
    pub vx: Vec<f64>,
    /// Coefficients for the Y velocity component.
    pub vy: Vec<f64>,
    /// Coefficients for the Z velocity component.
    pub vz: Vec<f64>,
}
/// An SPK Type 3 segment: equal-length Chebyshev records laid end to end,
/// beginning at `init`, each spanning `intlen`.
#[derive(Debug, Clone)]
pub struct Type3Segment {
    /// NAIF ID of the target body.
    pub target: i32,
    /// NAIF ID of the center of motion.
    pub center: i32,
    /// NAIF ID of the reference frame.
    pub frame_id: i32,
    /// Descriptor start time (ephemeris time; presumably TDB seconds past J2000).
    pub start_et: f64,
    /// Descriptor end time (same time scale as `start_et`).
    pub end_et: f64,
    /// Identifier written to the DAF name record; at most 40 bytes.
    pub segment_id: String,
    /// Length of the interval covered by each record; must be positive.
    pub intlen: f64,
    /// Start time of the first record's interval.
    pub init: f64,
    /// Coefficient records in file order; must be non-empty.
    pub records: Vec<Type3Record>,
}
/// An SPK Type 9 segment: discrete states at unequally spaced epochs,
/// interpolated by Lagrange polynomials.
#[derive(Debug, Clone)]
pub struct Type9Segment {
    /// NAIF ID of the target body.
    pub target: i32,
    /// NAIF ID of the center of motion.
    pub center: i32,
    /// NAIF ID of the reference frame.
    pub frame_id: i32,
    /// Descriptor start time (ephemeris time; presumably TDB seconds past J2000).
    pub start_et: f64,
    /// Descriptor end time (same time scale as `start_et`).
    pub end_et: f64,
    /// Identifier written to the DAF name record; at most 40 bytes.
    pub segment_id: String,
    /// Interpolation degree; each interpolation window uses `degree + 1` samples.
    pub degree: i32,
    /// Flattened states, six values per epoch (`states.len() == 6 * epochs.len()`).
    pub states: Vec<f64>,
    /// Sample epochs; must be strictly increasing.
    pub epochs: Vec<f64>,
}
/// A validated segment queued on an [`SpkWriter`], tagged by SPK data type.
enum Segment {
    /// Chebyshev position/velocity segment (SPK data type 3).
    Type3(Type3Segment),
    /// Lagrange, unequal-step segment (SPK data type 9).
    Type9(Type9Segment),
}
/// In-memory builder for a single-summary-record SPK file.
///
/// Segments are validated as they are added and serialized by `to_bytes` / `write`.
pub struct SpkWriter {
    /// DAF identification word written at the start of the file record.
    idword: [u8; 8],
    /// Internal file name, space-padded to 60 bytes.
    locifn: [u8; 60],
    /// Segments in the order they will appear in the file.
    segments: Vec<Segment>,
}
impl SpkWriter {
pub fn new_spk(locifn: &str) -> Self {
let mut locifn_bytes = [b' '; 60];
let src = locifn.as_bytes();
let n = src.len().min(60);
locifn_bytes[..n].copy_from_slice(&src[..n]);
Self {
idword: *b"DAF/SPK ",
locifn: locifn_bytes,
segments: Vec::new(),
}
}
pub fn add_type3(&mut self, seg: Type3Segment) -> Result<(), SpkWriterError> {
validate_segment_id(&seg.segment_id)?;
if seg.records.is_empty() {
return Err(SpkWriterError::BadType3("empty records"));
}
let n_coef = seg.records[0].x.len();
if n_coef == 0 {
return Err(SpkWriterError::BadType3("degree < 0"));
}
for r in &seg.records {
if r.x.len() != n_coef
|| r.y.len() != n_coef
|| r.z.len() != n_coef
|| r.vx.len() != n_coef
|| r.vy.len() != n_coef
|| r.vz.len() != n_coef
{
return Err(SpkWriterError::BadType3("inconsistent coefficient count"));
}
}
if seg.intlen <= 0.0 {
return Err(SpkWriterError::BadType3("INTLEN must be > 0"));
}
self.segments.push(Segment::Type3(seg));
Ok(())
}
pub fn add_type9(&mut self, seg: Type9Segment) -> Result<(), SpkWriterError> {
validate_segment_id(&seg.segment_id)?;
let n = seg.epochs.len();
if n == 0 {
return Err(SpkWriterError::BadType9("no epochs"));
}
if seg.states.len() != 6 * n {
return Err(SpkWriterError::BadType9(
"states length != 6 * epochs length",
));
}
if seg.degree < 1 {
return Err(SpkWriterError::BadType9("degree < 1"));
}
if (seg.degree as usize) + 1 > n {
return Err(SpkWriterError::BadType9(
"window (degree+1) exceeds sample count",
));
}
for pair in seg.epochs.windows(2) {
if pair[0].partial_cmp(&pair[1]) != Some(std::cmp::Ordering::Less) {
return Err(SpkWriterError::BadType9(
"epochs must be strictly increasing",
));
}
}
self.segments.push(Segment::Type9(seg));
Ok(())
}
pub fn to_bytes(&self) -> Result<Vec<u8>, SpkWriterError> {
if self.segments.len() > SPK_SUMMARIES_PER_RECORD {
return Err(SpkWriterError::TooManySegments {
got: self.segments.len(),
max: SPK_SUMMARIES_PER_RECORD,
});
}
let data_start_double = 3 * DOUBLES_PER_RECORD as u32 + 1; let mut cursor_double = data_start_double;
let mut segment_meta: Vec<SegmentMeta> = Vec::with_capacity(self.segments.len());
let mut payloads: Vec<Vec<f64>> = Vec::with_capacity(self.segments.len());
for seg in &self.segments {
let (meta_stub, payload) = encode_segment(seg)?;
let start = cursor_double;
let end = start + payload.len() as u32 - 1;
cursor_double = end + 1;
segment_meta.push(SegmentMeta {
start_et: meta_stub.start_et,
end_et: meta_stub.end_et,
target: meta_stub.target,
center: meta_stub.center,
frame_id: meta_stub.frame_id,
data_type: meta_stub.data_type,
start_addr: start as i32,
end_addr: end as i32,
name: meta_stub.name,
});
payloads.push(payload);
}
let total_data_doubles: usize = payloads.iter().map(|p| p.len()).sum();
let data_records = total_data_doubles.div_ceil(DOUBLES_PER_RECORD);
let total_records = 3 + data_records;
let mut buf = vec![0u8; total_records * RECORD_BYTES];
write_file_record(
&mut buf[0..RECORD_BYTES],
&self.idword,
SPK_ND,
SPK_NI,
&self.locifn,
2,
2,
cursor_double,
);
write_summary_record(&mut buf[RECORD_BYTES..2 * RECORD_BYTES], &segment_meta);
write_name_record(&mut buf[2 * RECORD_BYTES..3 * RECORD_BYTES], &segment_meta);
let data_byte_start = 3 * RECORD_BYTES;
let mut double_idx_in_data: usize = 0;
for payload in &payloads {
for &d in payload {
let off = data_byte_start + double_idx_in_data * DOUBLE_BYTES;
buf[off..off + DOUBLE_BYTES].copy_from_slice(&d.to_le_bytes());
double_idx_in_data += 1;
}
}
Ok(buf)
}
pub fn write<P: AsRef<Path>>(&self, path: P) -> Result<(), SpkWriterError> {
let bytes = self.to_bytes()?;
let target = path.as_ref();
let tmp = target.with_extension("tmp");
{
let mut f = std::fs::File::create(&tmp)?;
f.write_all(&bytes)?;
f.sync_all()?;
}
std::fs::rename(&tmp, target)?;
Ok(())
}
}
/// Complete summary and name for one segment, including its resolved DAF
/// data addresses.
struct SegmentMeta {
    /// Descriptor start time.
    start_et: f64,
    /// Descriptor end time.
    end_et: f64,
    /// NAIF ID of the target body.
    target: i32,
    /// NAIF ID of the center of motion.
    center: i32,
    /// NAIF ID of the reference frame.
    frame_id: i32,
    /// SPK data type (3 or 9 in this writer).
    data_type: i32,
    /// 1-based DAF address of the segment's first data double.
    start_addr: i32,
    /// 1-based DAF address of the segment's last data double (inclusive).
    end_addr: i32,
    /// Segment identifier written to the name record.
    name: String,
}
/// Summary fields for one segment before its data addresses are known.
struct SegmentMetaStub {
    /// Descriptor start time.
    start_et: f64,
    /// Descriptor end time.
    end_et: f64,
    /// NAIF ID of the target body.
    target: i32,
    /// NAIF ID of the center of motion.
    center: i32,
    /// NAIF ID of the reference frame.
    frame_id: i32,
    /// SPK data type (3 or 9 in this writer).
    data_type: i32,
    /// Segment identifier written to the name record.
    name: String,
}
/// Checks that a segment identifier fits in a 40-byte DAF name slot.
///
/// # Errors
/// Returns [`SpkWriterError::SegmentIdTooLong`] when `id` is longer than 40
/// bytes in UTF-8.
fn validate_segment_id(id: &str) -> Result<(), SpkWriterError> {
    const MAX_ID_BYTES: usize = 40;
    match id.len() {
        len if len <= MAX_ID_BYTES => Ok(()),
        _ => Err(SpkWriterError::SegmentIdTooLong(id.to_owned())),
    }
}
/// Serializes one queued segment into its summary fields (without addresses)
/// and its flat array of data doubles.
///
/// Currently always returns `Ok`.
fn encode_segment(seg: &Segment) -> Result<(SegmentMetaStub, Vec<f64>), SpkWriterError> {
    let encoded = match seg {
        Segment::Type3(t3) => {
            let stub = SegmentMetaStub {
                start_et: t3.start_et,
                end_et: t3.end_et,
                target: t3.target,
                center: t3.center,
                frame_id: t3.frame_id,
                data_type: 3,
                name: t3.segment_id.clone(),
            };
            (stub, encode_type3(t3))
        }
        Segment::Type9(t9) => {
            let stub = SegmentMetaStub {
                start_et: t9.start_et,
                end_et: t9.end_et,
                target: t9.target,
                center: t9.center,
                frame_id: t9.frame_id,
                data_type: 9,
                name: t9.segment_id.clone(),
            };
            (stub, encode_type9(t9))
        }
    };
    Ok(encoded)
}
/// Flattens a Type 3 segment into DAF data doubles.
///
/// Each record is written as `[mid, radius, x.., y.., z.., vx.., vy.., vz..]`.
/// The records are followed by the trailer `[init, intlen, rsize, n_records]`.
///
/// Indexes `records[0]`, so it panics on an empty segment; `add_type3`
/// rejects those.
fn encode_type3(seg: &Type3Segment) -> Vec<f64> {
    let coeffs_per_component = seg.records[0].x.len();
    let record_size = 2 + 6 * coeffs_per_component;
    let record_count = seg.records.len();
    let mut data = Vec::with_capacity(record_count * record_size + 4);
    for rec in &seg.records {
        data.extend_from_slice(&[rec.mid, rec.radius]);
        for component in [&rec.x, &rec.y, &rec.z, &rec.vx, &rec.vy, &rec.vz] {
            data.extend_from_slice(component);
        }
    }
    data.extend_from_slice(&[
        seg.init,
        seg.intlen,
        record_size as f64,
        record_count as f64,
    ]);
    data
}
/// Flattens a Type 9 segment into DAF data doubles.
///
/// Layout, in order:
/// - all states;
/// - all epochs;
/// - an epoch directory holding every 100th epoch (indices 99, 199, ...),
///   with `(n - 1) / 100` entries;
/// - the degree;
/// - the sample count `n`.
fn encode_type9(seg: &Type9Segment) -> Vec<f64> {
    let count = seg.epochs.len();
    let directory_len = (count - 1) / 100;
    let mut data = Vec::with_capacity(7 * count + directory_len + 2);
    data.extend_from_slice(&seg.states);
    data.extend_from_slice(&seg.epochs);
    data.extend(
        seg.epochs
            .iter()
            .skip(99)
            .step_by(100)
            .take(directory_len)
            .copied(),
    );
    data.extend_from_slice(&[f64::from(seg.degree), count as f64]);
    data
}
/// Fills `rec` with a little-endian DAF file record.
///
/// Byte layout, 0-based (NAIF DAF file record):
/// - `0..8`: identification word
/// - `8..12`: ND
/// - `12..16`: NI
/// - `16..76`: internal file name
/// - `76..80`: FWARD
/// - `80..84`: BWARD
/// - `84..88`: FREE
/// - `88..96`: binary format ("LTL-IEEE")
/// - `96..699`: null block
/// - `699..727`: FTP validation string
/// - `727..1024`: trailing null block
///
/// # Panics
/// Panics if `rec` is not exactly `RECORD_BYTES` long.
// One argument per file-record field keeps the call site readable.
#[allow(clippy::too_many_arguments)]
fn write_file_record(
    rec: &mut [u8],
    idword: &[u8; 8],
    nd: u32,
    ni: u32,
    locifn: &[u8; 60],
    fward: u32,
    bward: u32,
    free: u32,
) {
    /// 0-based offset of FTPSTR: 96 header bytes plus the 603-byte PRENUL block.
    const FTPSTR_OFFSET: usize = 699;
    assert_eq!(rec.len(), RECORD_BYTES);
    rec.fill(0);
    rec[0..8].copy_from_slice(idword);
    rec[8..12].copy_from_slice(&nd.to_le_bytes());
    rec[12..16].copy_from_slice(&ni.to_le_bytes());
    rec[16..76].copy_from_slice(locifn);
    rec[76..80].copy_from_slice(&fward.to_le_bytes());
    rec[80..84].copy_from_slice(&bward.to_le_bytes());
    rec[84..88].copy_from_slice(&free.to_le_bytes());
    rec[88..96].copy_from_slice(b"LTL-IEEE");
    // 28-byte string of line-ending and high-bit bytes that reveals ASCII-mode
    // FTP corruption.
    let ftpstr: &[u8] = b"FTPSTR:\r:\n:\r\n:\r\x00:\x81:\x10\xce:ENDFTP";
    rec[FTPSTR_OFFSET..FTPSTR_OFFSET + ftpstr.len()].copy_from_slice(ftpstr);
}
fn write_summary_record(rec: &mut [u8], segments: &[SegmentMeta]) {
assert_eq!(rec.len(), RECORD_BYTES);
rec[0..8].copy_from_slice(&(0.0_f64).to_le_bytes());
rec[8..16].copy_from_slice(&(0.0_f64).to_le_bytes());
rec[16..24].copy_from_slice(&(segments.len() as f64).to_le_bytes());
for (i, s) in segments.iter().enumerate() {
let off = 24 + i * SPK_SUMMARY_BYTES;
rec[off..off + 8].copy_from_slice(&s.start_et.to_le_bytes());
rec[off + 8..off + 16].copy_from_slice(&s.end_et.to_le_bytes());
let int_off = off + 16;
rec[int_off..int_off + 4].copy_from_slice(&s.target.to_le_bytes());
rec[int_off + 4..int_off + 8].copy_from_slice(&s.center.to_le_bytes());
rec[int_off + 8..int_off + 12].copy_from_slice(&s.frame_id.to_le_bytes());
rec[int_off + 12..int_off + 16].copy_from_slice(&s.data_type.to_le_bytes());
rec[int_off + 16..int_off + 20].copy_from_slice(&s.start_addr.to_le_bytes());
rec[int_off + 20..int_off + 24].copy_from_slice(&s.end_addr.to_le_bytes());
}
}
/// Fills `rec` with a DAF name record.
///
/// The record is space-filled. Each segment's name is copied into its own
/// `SPK_SUMMARY_BYTES`-wide slot, in segment order, and truncated to the slot
/// width.
///
/// # Panics
/// Panics if `rec` is not exactly `RECORD_BYTES` long.
fn write_name_record(rec: &mut [u8], segments: &[SegmentMeta]) {
    assert_eq!(rec.len(), RECORD_BYTES);
    rec.fill(b' ');
    for (slot, meta) in segments.iter().enumerate() {
        let bytes = meta.name.as_bytes();
        let len = bytes.len().min(SPK_SUMMARY_BYTES);
        let start = slot * SPK_SUMMARY_BYTES;
        rec[start..start + len].copy_from_slice(&bytes[..len]);
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip tests: files produced by `SpkWriter` are read back through the
    //! crate's DAF and SPK readers. Also covers validation errors.
    use super::*;
    use crate::daf::DafFile;
    use crate::spk::SpkFile;
    use tempfile::NamedTempFile;

    fn tmp_path() -> NamedTempFile {
        NamedTempFile::new().expect("create tempfile")
    }

    #[test]
    fn type3_roundtrip_via_daf_reader() {
        let mut w = SpkWriter::new_spk("unit-test");
        let segment = Type3Segment {
            target: 1_000_000,
            center: 0,
            frame_id: 1,
            start_et: 0.0,
            end_et: 100.0,
            segment_id: "type3-roundtrip".to_string(),
            init: 0.0,
            intlen: 50.0,
            records: vec![
                Type3Record {
                    mid: 25.0,
                    radius: 25.0,
                    x: vec![1.0, 0.0, 0.0],
                    y: vec![2.0, 0.0, 0.0],
                    z: vec![3.0, 0.0, 0.0],
                    vx: vec![0.001, 0.0, 0.0],
                    vy: vec![0.002, 0.0, 0.0],
                    vz: vec![0.003, 0.0, 0.0],
                },
                Type3Record {
                    mid: 75.0,
                    radius: 25.0,
                    x: vec![1.1, 0.0, 0.0],
                    y: vec![2.1, 0.0, 0.0],
                    z: vec![3.1, 0.0, 0.0],
                    vx: vec![0.0011, 0.0, 0.0],
                    vy: vec![0.0022, 0.0, 0.0],
                    vz: vec![0.0033, 0.0, 0.0],
                },
            ],
        };
        w.add_type3(segment.clone()).unwrap();
        let f = tmp_path();
        w.write(f.path()).unwrap();
        let daf = DafFile::open(f.path()).unwrap();
        let summaries = daf.summaries().unwrap();
        assert_eq!(summaries.len(), 1);
        let s = &summaries[0];
        assert_eq!(s.doubles.len(), 2);
        assert_eq!(s.doubles[0], 0.0);
        assert_eq!(s.doubles[1], 100.0);
        assert_eq!(s.integers[0], 1_000_000);
        assert_eq!(s.integers[1], 0);
        assert_eq!(s.integers[2], 1);
        assert_eq!(s.integers[3], 3);
        assert!(s.name.starts_with("type3-roundtrip"));
        let end_addr = s.integers[5] as u32;
        // Type 3 trailer: INIT, INTLEN, RSIZE, N.
        let trailer = daf.read_doubles(end_addr - 3, end_addr).unwrap();
        assert_eq!(trailer[0], 0.0);
        assert_eq!(trailer[1], 50.0);
        assert_eq!(trailer[2], (2 + 6 * 3) as f64);
        assert_eq!(trailer[3], 2.0);
        let spk = SpkFile::open(f.path()).unwrap();
        assert_eq!(spk.segments().len(), 1);
        let seg = &spk.segments()[0];
        assert_eq!(seg.data_type, 3);
        let st = spk.state(1_000_000, 0, 25.0).unwrap();
        assert!((st[0] - 1.0).abs() < 1e-14);
        assert!((st[1] - 2.0).abs() < 1e-14);
        assert!((st[2] - 3.0).abs() < 1e-14);
        assert!((st[3] - 0.001).abs() < 1e-16);
        assert!((st[4] - 0.002).abs() < 1e-16);
        assert!((st[5] - 0.003).abs() < 1e-16);
    }

    #[test]
    fn type9_roundtrip_via_daf_reader() {
        let n: usize = 20;
        let mut epochs = Vec::with_capacity(n);
        let mut states = Vec::with_capacity(6 * n);
        for i in 0..n {
            let t = i as f64 * 10.0;
            epochs.push(t);
            states.extend_from_slice(&[
                1.0 + 0.5 * t,
                -2.0 + 0.1 * t,
                0.5 - 0.2 * t,
                0.5,
                0.1,
                -0.2,
            ]);
        }
        let mut w = SpkWriter::new_spk("type9-test");
        w.add_type9(Type9Segment {
            target: -1,
            center: 0,
            frame_id: 1,
            start_et: epochs[0],
            end_et: *epochs.last().unwrap(),
            segment_id: "type9-linear".to_string(),
            degree: 3,
            states,
            epochs,
        })
        .unwrap();
        let f = tmp_path();
        w.write(f.path()).unwrap();
        let spk = SpkFile::open(f.path()).unwrap();
        let seg = &spk.segments()[0];
        assert_eq!(seg.data_type, 9);
        let st = spk.state(-1, 0, 55.0).unwrap();
        assert!((st[0] - (1.0 + 0.5 * 55.0)).abs() < 1e-12);
        assert!((st[1] - (-2.0 + 0.1 * 55.0)).abs() < 1e-12);
        assert!((st[2] - (0.5 - 0.2 * 55.0)).abs() < 1e-12);
        assert!((st[3] - 0.5).abs() < 1e-14);
        assert!((st[4] - 0.1).abs() < 1e-14);
        assert!((st[5] + 0.2).abs() < 1e-14);
    }

    #[test]
    fn multiple_segments_in_one_file() {
        let mut w = SpkWriter::new_spk("multi");
        let t3 = Type3Segment {
            target: 100,
            center: 0,
            frame_id: 1,
            start_et: 0.0,
            end_et: 10.0,
            segment_id: "t3a".to_string(),
            init: 0.0,
            intlen: 10.0,
            records: vec![Type3Record {
                mid: 5.0,
                radius: 5.0,
                x: vec![7.0, 0.0],
                y: vec![8.0, 0.0],
                z: vec![9.0, 0.0],
                vx: vec![0.0, 0.0],
                vy: vec![0.0, 0.0],
                vz: vec![0.0, 0.0],
            }],
        };
        w.add_type3(t3).unwrap();
        let epochs: Vec<f64> = (0..10).map(|i| i as f64).collect();
        let states: Vec<f64> = (0..10)
            .flat_map(|i| {
                let t = i as f64;
                [10.0 + t, 20.0, 30.0, 1.0, 0.0, 0.0].into_iter()
            })
            .collect();
        w.add_type9(Type9Segment {
            target: 101,
            center: 0,
            frame_id: 1,
            start_et: 0.0,
            end_et: 9.0,
            segment_id: "t9a".to_string(),
            degree: 1,
            states,
            epochs,
        })
        .unwrap();
        let f = tmp_path();
        w.write(f.path()).unwrap();
        let spk = SpkFile::open(f.path()).unwrap();
        assert_eq!(spk.segments().len(), 2);
        let s1 = spk.state(100, 0, 5.0).unwrap();
        assert!((s1[0] - 7.0).abs() < 1e-14);
        let s2 = spk.state(101, 0, 5.5).unwrap();
        assert!((s2[0] - 15.5).abs() < 1e-12);
    }

    #[test]
    fn file_record_fields_validate() {
        let mut w = SpkWriter::new_spk("header-check");
        w.add_type3(Type3Segment {
            target: 1,
            center: 0,
            frame_id: 1,
            start_et: 0.0,
            end_et: 1.0,
            segment_id: "x".to_string(),
            init: 0.0,
            intlen: 1.0,
            records: vec![Type3Record {
                mid: 0.5,
                radius: 0.5,
                x: vec![0.0],
                y: vec![0.0],
                z: vec![0.0],
                vx: vec![0.0],
                vy: vec![0.0],
                vz: vec![0.0],
            }],
        })
        .unwrap();
        let bytes = w.to_bytes().unwrap();
        assert_eq!(&bytes[0..8], b"DAF/SPK ");
        assert_eq!(u32::from_le_bytes(bytes[8..12].try_into().unwrap()), 2);
        assert_eq!(u32::from_le_bytes(bytes[12..16].try_into().unwrap()), 6);
        assert_eq!(&bytes[88..96], b"LTL-IEEE");
        assert_eq!(u32::from_le_bytes(bytes[76..80].try_into().unwrap()), 2);
        assert_eq!(u32::from_le_bytes(bytes[80..84].try_into().unwrap()), 2);
    }

    #[test]
    fn rejects_too_many_segments() {
        let mut w = SpkWriter::new_spk("overflow");
        for i in 0..(SPK_SUMMARIES_PER_RECORD + 1) {
            w.add_type3(Type3Segment {
                target: i as i32,
                center: 0,
                frame_id: 1,
                start_et: 0.0,
                end_et: 1.0,
                segment_id: format!("s{i}"),
                init: 0.0,
                intlen: 1.0,
                records: vec![Type3Record {
                    mid: 0.5,
                    radius: 0.5,
                    x: vec![0.0],
                    y: vec![0.0],
                    z: vec![0.0],
                    vx: vec![0.0],
                    vy: vec![0.0],
                    vz: vec![0.0],
                }],
            })
            .unwrap();
        }
        let err = w.to_bytes().unwrap_err();
        // `matches!` only returns a bool; it must be asserted to test anything.
        assert!(
            matches!(err, SpkWriterError::TooManySegments { .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn rejects_empty_type3() {
        let mut w = SpkWriter::new_spk("empty");
        let err = w
            .add_type3(Type3Segment {
                target: 1,
                center: 0,
                frame_id: 1,
                start_et: 0.0,
                end_et: 1.0,
                segment_id: "x".to_string(),
                init: 0.0,
                intlen: 1.0,
                records: vec![],
            })
            .unwrap_err();
        assert!(
            matches!(err, SpkWriterError::BadType3(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn rejects_non_monotone_type9_epochs() {
        let mut w = SpkWriter::new_spk("non-mono");
        let err = w
            .add_type9(Type9Segment {
                target: -1,
                center: 0,
                frame_id: 1,
                start_et: 0.0,
                end_et: 1.0,
                segment_id: "x".to_string(),
                degree: 1,
                states: vec![0.0; 6 * 3],
                epochs: vec![0.0, 1.0, 0.5],
            })
            .unwrap_err();
        assert!(
            matches!(err, SpkWriterError::BadType9(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn rejects_segment_id_over_40_bytes() {
        let mut w = SpkWriter::new_spk("long-id");
        let long = "x".repeat(41);
        let err = w
            .add_type3(Type3Segment {
                target: 1,
                center: 0,
                frame_id: 1,
                start_et: 0.0,
                end_et: 1.0,
                segment_id: long,
                init: 0.0,
                intlen: 1.0,
                records: vec![Type3Record {
                    mid: 0.5,
                    radius: 0.5,
                    x: vec![0.0],
                    y: vec![0.0],
                    z: vec![0.0],
                    vx: vec![0.0],
                    vy: vec![0.0],
                    vz: vec![0.0],
                }],
            })
            .unwrap_err();
        assert!(
            matches!(err, SpkWriterError::SegmentIdTooLong(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn type9_directory_present_when_n_over_100() {
        let n = 150;
        let epochs: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let states: Vec<f64> = (0..n)
            .flat_map(|_| [0.0_f64, 0.0, 0.0, 0.0, 0.0, 0.0].into_iter())
            .collect();
        let mut w = SpkWriter::new_spk("n150");
        w.add_type9(Type9Segment {
            target: -1,
            center: 0,
            frame_id: 1,
            start_et: 0.0,
            end_et: (n - 1) as f64,
            segment_id: "n150".to_string(),
            degree: 3,
            states,
            epochs,
        })
        .unwrap();
        let f = tmp_path();
        w.write(f.path()).unwrap();
        let daf = DafFile::open(f.path()).unwrap();
        let s = &daf.summaries().unwrap()[0];
        let end_addr = s.integers[5] as u32;
        let start_addr = s.integers[4] as u32;
        // states + epochs + one directory entry + degree + count.
        let expected_len = 6 * n + n + 1 + 2;
        assert_eq!((end_addr - start_addr + 1) as usize, expected_len);
        let dir_addr = start_addr + (6 * n + n) as u32;
        let dir = daf.read_doubles(dir_addr, dir_addr).unwrap();
        assert_eq!(dir[0], 99.0);
    }
}