1mod crc;
35#[cfg(feature = "format-encryption")]
36pub mod encryption;
37pub mod license;
38pub mod piracy;
39#[cfg(feature = "format-signing")]
40pub mod signing;
41#[cfg(feature = "format-streaming")]
42pub mod streaming;
43
44pub use crc::crc32;
45use serde::{Deserialize, Serialize};
46
47use crate::error::{Error, Result};
48
/// File signature stored in the first four bytes of every dataset: ASCII `ALDF`.
pub const MAGIC: [u8; 4] = *b"ALDF";

/// Highest major format version this build can read; files with a larger major are rejected.
pub const FORMAT_VERSION_MAJOR: u8 = 1;

/// Minor format version written into new files.
pub const FORMAT_VERSION_MINOR: u8 = 2;

/// Size in bytes of the fixed header, spelled out field by field:
/// magic(4) + version(2) + dataset type(2) + metadata size(4) + payload size(4)
/// + uncompressed size(4) + compression(1) + flags(1) + schema size(2) + row count(8).
pub const HEADER_SIZE: usize = 4 + 2 + 2 + 4 + 4 + 4 + 1 + 1 + 2 + 8;
59
/// Bit masks for the header's `flags` byte. Each feature owns exactly one bit.
pub mod flags {
    /// Payload is encrypted and preceded by an encryption header.
    pub const ENCRYPTED: u8 = 1 << 0;
    /// A signature block follows the payload.
    pub const SIGNED: u8 = 1 << 1;
    /// Streaming layout bit.
    pub const STREAMING: u8 = 1 << 2;
    /// A license block follows the payload (after any signature block).
    pub const LICENSED: u8 = 1 << 3;
    /// Trueno-native layout bit.
    pub const TRUENO_NATIVE: u8 = 1 << 4;
}
73
/// Category of data held in a dataset file.
///
/// The `repr(u16)` discriminant is what goes on disk: `Header::to_bytes` writes it
/// little-endian at header bytes 6..8, and `DatasetType::from_u16` maps it back.
/// Variants are grouped by the second hex digit of the discriminant
/// (0x000_, 0x001_, 0x002_, 0x003_, 0x004_, 0x005_), with `Custom` at 0x00FF.
///
/// NOTE(review): the serde derive identifies variants by declaration index in
/// non-self-describing formats, so variant order should be kept stable — confirm before
/// reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u16)]
pub enum DatasetType {
    Tabular = 0x0001,
    TimeSeries = 0x0002,
    Graph = 0x0003,
    Spatial = 0x0004,

    TextCorpus = 0x0010,
    TextClassification = 0x0011,
    TextPairs = 0x0012,
    SequenceLabeling = 0x0013,
    QuestionAnswering = 0x0014,
    Summarization = 0x0015,
    Translation = 0x0016,

    ImageClassification = 0x0020,
    ObjectDetection = 0x0021,
    Segmentation = 0x0022,
    ImagePairs = 0x0023,
    Video = 0x0024,

    AudioClassification = 0x0030,
    SpeechRecognition = 0x0031,
    SpeakerIdentification = 0x0032,

    UserItemRatings = 0x0040,
    ImplicitFeedback = 0x0041,
    SequentialRecs = 0x0042,

    ImageText = 0x0050,
    AudioText = 0x0051,
    VideoText = 0x0052,

    Custom = 0x00FF,
}
144
145impl DatasetType {
146 #[must_use]
148 pub fn from_u16(value: u16) -> Option<Self> {
149 match value {
150 0x0001 => Some(Self::Tabular),
151 0x0002 => Some(Self::TimeSeries),
152 0x0003 => Some(Self::Graph),
153 0x0004 => Some(Self::Spatial),
154 0x0010 => Some(Self::TextCorpus),
155 0x0011 => Some(Self::TextClassification),
156 0x0012 => Some(Self::TextPairs),
157 0x0013 => Some(Self::SequenceLabeling),
158 0x0014 => Some(Self::QuestionAnswering),
159 0x0015 => Some(Self::Summarization),
160 0x0016 => Some(Self::Translation),
161 0x0020 => Some(Self::ImageClassification),
162 0x0021 => Some(Self::ObjectDetection),
163 0x0022 => Some(Self::Segmentation),
164 0x0023 => Some(Self::ImagePairs),
165 0x0024 => Some(Self::Video),
166 0x0030 => Some(Self::AudioClassification),
167 0x0031 => Some(Self::SpeechRecognition),
168 0x0032 => Some(Self::SpeakerIdentification),
169 0x0040 => Some(Self::UserItemRatings),
170 0x0041 => Some(Self::ImplicitFeedback),
171 0x0042 => Some(Self::SequentialRecs),
172 0x0050 => Some(Self::ImageText),
173 0x0051 => Some(Self::AudioText),
174 0x0052 => Some(Self::VideoText),
175 0x00FF => Some(Self::Custom),
176 _ => None,
177 }
178 }
179
180 #[must_use]
182 pub const fn as_u16(self) -> u16 {
183 self as u16
184 }
185}
186
/// Payload compression codec.
///
/// The `repr(u8)` discriminant is stored at header byte 20 (see `Header::to_bytes`).
/// `ZstdL3` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum Compression {
    /// Payload stored as-is.
    None = 0x00,
    /// Zstandard at level 3.
    #[default]
    ZstdL3 = 0x01,
    /// Zstandard at level 19.
    ZstdL19 = 0x02,
    /// LZ4 frame format.
    Lz4 = 0x03,
}
201
202impl Compression {
203 #[must_use]
205 pub fn from_u8(value: u8) -> Option<Self> {
206 match value {
207 0x00 => Some(Self::None),
208 0x01 => Some(Self::ZstdL3),
209 0x02 => Some(Self::ZstdL19),
210 0x03 => Some(Self::Lz4),
211 _ => None,
212 }
213 }
214
215 #[must_use]
217 pub const fn as_u8(self) -> u8 {
218 self as u8
219 }
220}
221
/// Decoded form of the fixed [`HEADER_SIZE`]-byte file header.
///
/// Multi-byte fields are stored little-endian (see `Header::to_bytes` / `Header::from_bytes`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
    /// `(major, minor)` format version, header bytes 4 and 5.
    pub version: (u8, u8),
    /// Dataset category, header bytes 6..8.
    pub dataset_type: DatasetType,
    /// Length in bytes of the MessagePack metadata section that follows the header.
    pub metadata_size: u32,
    /// Length in bytes of the stored payload; `save` sets it excluding any encryption header.
    pub payload_size: u32,
    /// Length in bytes of the Arrow IPC payload before compression.
    pub uncompressed_size: u32,
    /// Codec applied to the payload.
    pub compression: Compression,
    /// Bitset of [`flags`] constants.
    pub flags: u8,
    /// Length in bytes of the Arrow IPC schema section.
    pub schema_size: u16,
    /// Total row count across all record batches.
    pub num_rows: u64,
}
257
258impl Header {
259 #[must_use]
261 pub fn new(dataset_type: DatasetType) -> Self {
262 contract_pre_configuration!();
264 Self {
265 version: (FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR),
266 dataset_type,
267 metadata_size: 0,
268 payload_size: 0,
269 uncompressed_size: 0,
270 compression: Compression::default(),
271 flags: 0,
272 schema_size: 0,
273 num_rows: 0,
274 }
275 }
276
277 #[must_use]
279 pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
280 contract_pre_configuration!();
282 let mut buf = [0u8; HEADER_SIZE];
283
284 buf[0..4].copy_from_slice(&MAGIC);
286
287 buf[4] = self.version.0;
289 buf[5] = self.version.1;
290
291 let dt = self.dataset_type.as_u16().to_le_bytes();
293 buf[6..8].copy_from_slice(&dt);
294
295 buf[8..12].copy_from_slice(&self.metadata_size.to_le_bytes());
297
298 buf[12..16].copy_from_slice(&self.payload_size.to_le_bytes());
300
301 buf[16..20].copy_from_slice(&self.uncompressed_size.to_le_bytes());
303
304 buf[20] = self.compression.as_u8();
306
307 buf[21] = self.flags;
309
310 buf[22..24].copy_from_slice(&self.schema_size.to_le_bytes());
312
313 buf[24..32].copy_from_slice(&self.num_rows.to_le_bytes());
315
316 buf
317 }
318
319 pub fn from_bytes(buf: &[u8]) -> Result<Self> {
326 if buf.len() < HEADER_SIZE {
327 return Err(Error::Format(format!(
328 "Header too short: {} bytes, expected {}",
329 buf.len(),
330 HEADER_SIZE
331 )));
332 }
333
334 if buf[0..4] != MAGIC {
336 return Err(Error::Format(format!(
337 "Invalid magic: expected {:?}, got {:?}",
338 MAGIC,
339 &buf[0..4]
340 )));
341 }
342
343 contract_pre_configuration!(buf);
345
346 let version = (buf[4], buf[5]);
348 if version.0 > FORMAT_VERSION_MAJOR {
349 return Err(Error::Format(format!(
350 "Unsupported version: {}.{}, max supported: {}.{}",
351 version.0, version.1, FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR
352 )));
353 }
354
355 let dt_value = u16::from_le_bytes([buf[6], buf[7]]);
357 let dataset_type = DatasetType::from_u16(dt_value)
358 .ok_or_else(|| Error::Format(format!("Unknown dataset type: 0x{:04X}", dt_value)))?;
359
360 let metadata_size = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
362
363 let payload_size = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
365
366 let uncompressed_size = u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]);
368
369 let compression = Compression::from_u8(buf[20])
371 .ok_or_else(|| Error::Format(format!("Unknown compression: 0x{:02X}", buf[20])))?;
372
373 let flags = buf[21];
375
376 let schema_size = u16::from_le_bytes([buf[22], buf[23]]);
378
379 let num_rows = u64::from_le_bytes([
381 buf[24], buf[25], buf[26], buf[27], buf[28], buf[29], buf[30], buf[31],
382 ]);
383
384 Ok(Self {
385 version,
386 dataset_type,
387 metadata_size,
388 payload_size,
389 uncompressed_size,
390 compression,
391 flags,
392 schema_size,
393 num_rows,
394 })
395 }
396
397 #[must_use]
399 pub const fn is_encrypted(&self) -> bool {
400 self.flags & flags::ENCRYPTED != 0
401 }
402
403 #[must_use]
405 pub const fn is_signed(&self) -> bool {
406 self.flags & flags::SIGNED != 0
407 }
408
409 #[must_use]
411 pub const fn is_streaming(&self) -> bool {
412 self.flags & flags::STREAMING != 0
413 }
414
415 #[must_use]
417 pub const fn is_licensed(&self) -> bool {
418 self.flags & flags::LICENSED != 0
419 }
420
421 #[must_use]
423 pub const fn is_trueno_native(&self) -> bool {
424 self.flags & flags::TRUENO_NATIVE != 0
425 }
426}
427
/// Descriptive metadata stored between the header and the schema section.
///
/// `save` encodes it with `rmp_serde::to_vec` (MessagePack); `load` decodes it with
/// `rmp_serde::from_slice`. An all-`None` default is written when no metadata is supplied.
///
/// NOTE(review): `rmp_serde::to_vec` encodes structs positionally (as arrays), so field order
/// is part of the on-disk format; keep `sha256` (the only skippable field) last — confirm
/// before reordering or adding fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Metadata {
    /// Human-readable dataset name.
    pub name: Option<String>,
    /// Dataset version string.
    pub version: Option<String>,
    /// License identifier or text.
    pub license: Option<String>,
    /// Free-form tags; defaults to empty when absent during deserialization.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Longer description.
    pub description: Option<String>,
    /// Citation text.
    pub citation: Option<String>,
    /// Creation timestamp as a string (format not enforced here).
    pub created_at: Option<String>,
    /// Optional hex digest; omitted from the encoding when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sha256: Option<String>,
}
450
/// Returns the SHA-256 digest of `data` as 64 lowercase hexadecimal characters.
#[cfg(feature = "provenance")]
#[must_use]
pub fn sha256_hex(data: &[u8]) -> String {
    use sha2::{Digest, Sha256};

    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hasher = Sha256::new();
    hasher.update(data);
    let digest = hasher.finalize();

    // Two hex characters per byte: high nibble first, then low nibble.
    let mut hex = String::with_capacity(64);
    for byte in digest.iter() {
        hex.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
    }
    hex
}
477
/// Options controlling how `save` encodes a dataset.
#[derive(Debug, Clone)]
pub struct SaveOptions {
    /// Payload codec; also recorded in the header.
    pub compression: Compression,
    /// Metadata section contents; when `None`, `Metadata::default()` is written.
    pub metadata: Option<Metadata>,
    /// When set, the compressed payload is encrypted and the `ENCRYPTED` flag is set.
    #[cfg(feature = "format-encryption")]
    pub encryption: Option<encryption::EncryptionParams>,
    /// When set, a signature block is appended and the `SIGNED` flag is set.
    #[cfg(feature = "format-signing")]
    pub signing_key: Option<signing::SigningKeyPair>,
    /// When set, a license block is appended and the `LICENSED` flag is set.
    pub license: Option<license::LicenseBlock>,
}
494
495impl Default for SaveOptions {
496 fn default() -> Self {
497 Self {
498 compression: Compression::ZstdL3,
499 metadata: None,
500 #[cfg(feature = "format-encryption")]
501 encryption: None,
502 #[cfg(feature = "format-signing")]
503 signing_key: None,
504 license: None,
505 }
506 }
507}
508
509impl SaveOptions {
510 #[must_use]
512 pub fn with_compression(mut self, compression: Compression) -> Self {
513 self.compression = compression;
514 self
515 }
516
517 #[must_use]
519 pub fn with_metadata(mut self, metadata: Metadata) -> Self {
520 self.metadata = Some(metadata);
521 self
522 }
523
524 #[cfg(feature = "format-encryption")]
526 #[must_use]
527 pub fn with_password(mut self, password: impl Into<String>) -> Self {
528 self.encryption = Some(encryption::EncryptionParams::password(password));
529 self
530 }
531
532 #[cfg(feature = "format-encryption")]
534 #[must_use]
535 pub fn with_recipient(mut self, public_key: [u8; 32]) -> Self {
536 self.encryption = Some(encryption::EncryptionParams::recipient(public_key));
537 self
538 }
539
540 #[cfg(feature = "format-signing")]
542 #[must_use]
543 pub fn with_signing_key(mut self, key: signing::SigningKeyPair) -> Self {
544 self.signing_key = Some(key);
545 self
546 }
547
548 #[must_use]
550 pub fn with_license(mut self, license: license::LicenseBlock) -> Self {
551 self.license = Some(license);
552 self
553 }
554}
555
/// Options controlling how `load_with_options` decodes a dataset.
#[derive(Debug, Clone, Default)]
pub struct LoadOptions {
    /// Required when the dataset is encrypted; loading fails without it.
    #[cfg(feature = "format-encryption")]
    pub decryption: Option<encryption::DecryptionParams>,
    /// When non-empty, a signed file's signer must be in this list and its signature is
    /// verified; when empty, the signature block is parsed but not verified.
    #[cfg(feature = "format-signing")]
    pub trusted_keys: Vec<[u8; 32]>,
    /// When true, `LicenseBlock::verify` is called on a present license block.
    pub verify_license: bool,
}
569
570impl LoadOptions {
571 #[cfg(feature = "format-encryption")]
573 #[must_use]
574 pub fn with_password(mut self, password: impl Into<String>) -> Self {
575 self.decryption = Some(encryption::DecryptionParams::password(password));
576 self
577 }
578
579 #[cfg(feature = "format-encryption")]
581 #[must_use]
582 pub fn with_private_key(mut self, key: [u8; 32]) -> Self {
583 self.decryption = Some(encryption::DecryptionParams::private_key(key));
584 self
585 }
586
587 #[cfg(feature = "format-signing")]
589 #[must_use]
590 pub fn with_trusted_key(mut self, key: [u8; 32]) -> Self {
591 self.trusted_keys.push(key);
592 self
593 }
594
595 #[must_use]
597 pub fn verify_license(mut self) -> Self {
598 self.verify_license = true;
599 self
600 }
601}
602
603#[allow(clippy::cast_possible_truncation)]
609pub fn save<W: std::io::Write>(
610 writer: &mut W,
611 batches: &[arrow::array::RecordBatch],
612 dataset_type: DatasetType,
613 options: &SaveOptions,
614) -> Result<()> {
615 if batches.is_empty() {
616 return Err(Error::EmptyDataset);
617 }
618
619 let schema = batches[0].schema();
620 let schema_buf = serialize_arrow_schema(&schema)?;
621 let payload_buf = serialize_arrow_payload(batches, &schema)?;
622 let uncompressed_size = payload_buf.len() as u32;
623 let compressed_payload = compress_payload(payload_buf, options.compression)?;
624
625 let mut header_flags: u8 = 0;
626 let (final_payload, encryption_header) =
627 apply_encryption_if_requested(compressed_payload, options, &mut header_flags)?;
628 update_header_flags(&mut header_flags, options);
629
630 let metadata_buf = serialize_save_metadata(options)?;
631 let num_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
632 let header = Header {
633 version: (FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR),
634 dataset_type,
635 metadata_size: metadata_buf.len() as u32,
636 payload_size: final_payload.len() as u32,
637 uncompressed_size,
638 compression: options.compression,
639 flags: header_flags,
640 schema_size: schema_buf.len() as u16,
641 num_rows,
642 };
643
644 write_packed_output(
645 writer,
646 &header,
647 &metadata_buf,
648 &schema_buf,
649 &encryption_header,
650 &final_payload,
651 options,
652 )
653}
654
655fn serialize_arrow_schema(schema: &arrow::datatypes::SchemaRef) -> Result<Vec<u8>> {
656 use arrow::ipc::writer::StreamWriter;
657 let mut schema_buf = Vec::new();
658 let mut schema_writer = StreamWriter::try_new(&mut schema_buf, schema).map_err(Error::Arrow)?;
659 schema_writer.finish().map_err(Error::Arrow)?;
660 Ok(schema_buf)
661}
662
663fn serialize_arrow_payload(
664 batches: &[arrow::array::RecordBatch],
665 schema: &arrow::datatypes::SchemaRef,
666) -> Result<Vec<u8>> {
667 use arrow::ipc::writer::StreamWriter;
668 let mut payload_buf = Vec::new();
669 let mut payload_writer =
670 StreamWriter::try_new(&mut payload_buf, schema).map_err(Error::Arrow)?;
671 for batch in batches {
672 payload_writer.write(batch).map_err(Error::Arrow)?;
673 }
674 payload_writer.finish().map_err(Error::Arrow)?;
675 Ok(payload_buf)
676}
677
678#[cfg(feature = "format-encryption")]
679fn apply_encryption_if_requested(
680 compressed_payload: Vec<u8>,
681 options: &SaveOptions,
682 header_flags: &mut u8,
683) -> Result<(Vec<u8>, Vec<u8>)> {
684 if let Some(ref enc_params) = options.encryption {
685 *header_flags |= flags::ENCRYPTED;
686 let block = build_encryption_block(&compressed_payload, enc_params)?;
687 let hdr_size = encryption_block_header_size(block[0]);
688 Ok((block[hdr_size..].to_vec(), block[..hdr_size].to_vec()))
689 } else {
690 Ok((compressed_payload, Vec::new()))
691 }
692}
693
694#[cfg(not(feature = "format-encryption"))]
695fn apply_encryption_if_requested(
696 compressed_payload: Vec<u8>,
697 _options: &SaveOptions,
698 _header_flags: &mut u8,
699) -> Result<(Vec<u8>, Vec<u8>)> {
700 Ok((compressed_payload, Vec::new()))
701}
702
703fn update_header_flags(header_flags: &mut u8, options: &SaveOptions) {
704 #[cfg(feature = "format-signing")]
705 if options.signing_key.is_some() {
706 *header_flags |= flags::SIGNED;
707 }
708 if options.license.is_some() {
709 *header_flags |= flags::LICENSED;
710 }
711}
712
713fn serialize_save_metadata(options: &SaveOptions) -> Result<Vec<u8>> {
714 let meta = options.metadata.as_ref();
715 if let Some(m) = meta {
716 rmp_serde::to_vec(m).map_err(|e| Error::Format(e.to_string()))
717 } else {
718 rmp_serde::to_vec(&Metadata::default()).map_err(|e| Error::Format(e.to_string()))
719 }
720}
721
722fn write_packed_output<W: std::io::Write>(
723 writer: &mut W,
724 header: &Header,
725 metadata_buf: &[u8],
726 schema_buf: &[u8],
727 encryption_header: &[u8],
728 final_payload: &[u8],
729 options: &SaveOptions,
730) -> Result<()> {
731 let all_data = assemble_all_data(
732 header,
733 metadata_buf,
734 schema_buf,
735 encryption_header,
736 final_payload,
737 options,
738 );
739 let checksum = crc32(&all_data);
740 writer.write_all(&all_data).map_err(Error::io_no_path)?;
741 writer
742 .write_all(&checksum.to_le_bytes())
743 .map_err(Error::io_no_path)?;
744 Ok(())
745}
746
747fn assemble_all_data(
748 header: &Header,
749 metadata_buf: &[u8],
750 schema_buf: &[u8],
751 encryption_header: &[u8],
752 final_payload: &[u8],
753 options: &SaveOptions,
754) -> Vec<u8> {
755 let mut all_data = Vec::new();
756 all_data.extend_from_slice(&header.to_bytes());
757 all_data.extend_from_slice(metadata_buf);
758 all_data.extend_from_slice(schema_buf);
759 all_data.extend_from_slice(encryption_header);
760 all_data.extend_from_slice(final_payload);
761 append_signature_if_signing(&mut all_data, options);
762 append_license_if_present(&mut all_data, options);
763 all_data
764}
765
766#[cfg(feature = "format-signing")]
767fn append_signature_if_signing(
768 all_data: &mut Vec<u8>,
769 options: &SaveOptions,
770) -> Option<[u8; signing::SignatureBlock::SIZE]> {
771 let key = options.signing_key.as_ref()?;
772 let sig_block = signing::SignatureBlock::sign(all_data, key);
773 let sig_bytes = sig_block.to_bytes();
774 all_data.extend_from_slice(&sig_bytes);
775 Some(sig_bytes)
776}
777
778#[cfg(not(feature = "format-signing"))]
779fn append_signature_if_signing(
780 _all_data: &mut Vec<u8>,
781 _options: &SaveOptions,
782) -> Option<[u8; 96]> {
783 None
784}
785
786fn append_license_if_present(all_data: &mut Vec<u8>, options: &SaveOptions) -> Option<Vec<u8>> {
787 let lic = options.license.as_ref()?;
788 let lic_bytes = lic.to_bytes();
789 all_data.extend_from_slice(&lic_bytes);
790 Some(lic_bytes)
791}
792
793fn compress_payload(payload: Vec<u8>, compression: Compression) -> Result<Vec<u8>> {
795 match compression {
796 Compression::None => Ok(payload),
797 Compression::ZstdL3 => zstd::encode_all(payload.as_slice(), 3).map_err(Error::io_no_path),
798 Compression::ZstdL19 => zstd::encode_all(payload.as_slice(), 19).map_err(Error::io_no_path),
799 Compression::Lz4 => {
800 let mut encoder = lz4_flex::frame::FrameEncoder::new(Vec::new());
801 std::io::Write::write_all(&mut encoder, &payload).map_err(Error::io_no_path)?;
802 encoder
803 .finish()
804 .map_err(|e| Error::Format(format!("LZ4 compression error: {e}")))
805 }
806 }
807}
808
/// Length of the encryption header that precedes the ciphertext.
///
/// Password mode: mode byte + 16-byte salt + 12-byte nonce (29 bytes).
/// Any other mode: mode byte + 32-byte ephemeral public key + 12-byte nonce (45 bytes).
#[cfg(feature = "format-encryption")]
fn encryption_block_header_size(mode: u8) -> usize {
    const MODE_LEN: usize = 1;
    const NONCE_LEN: usize = 12;
    let key_material_len = if mode == encryption::mode::PASSWORD {
        16
    } else {
        32
    };
    MODE_LEN + key_material_len + NONCE_LEN
}
818
/// Encrypts `plaintext` and returns `mode || key material || nonce || ciphertext`.
///
/// Key material is the 16-byte salt in password mode and the 32-byte ephemeral public key
/// in recipient mode; see `encryption_block_header_size` for the resulting prefix length.
///
/// # Errors
///
/// Propagates errors from `encryption::encrypt_password` / `encryption::encrypt_recipient`.
#[cfg(feature = "format-encryption")]
fn build_encryption_block(
    plaintext: &[u8],
    params: &encryption::EncryptionParams,
) -> Result<Vec<u8>> {
    /// Lays out the encryption block in on-disk order.
    fn assemble(mode: u8, key_material: &[u8], nonce: &[u8], ciphertext: &[u8]) -> Vec<u8> {
        let mut block =
            Vec::with_capacity(1 + key_material.len() + nonce.len() + ciphertext.len());
        block.push(mode);
        block.extend_from_slice(key_material);
        block.extend_from_slice(nonce);
        block.extend_from_slice(ciphertext);
        block
    }

    // The SOURCE had a garbled `¶ms` token here (a mis-decoded `&params`), which does not compile.
    match &params.mode {
        encryption::EncryptionMode::Password(password) => {
            let (mode, salt, nonce, ciphertext) =
                encryption::encrypt_password(plaintext, password)?;
            Ok(assemble(mode, &salt, &nonce, &ciphertext))
        }
        encryption::EncryptionMode::Recipient {
            recipient_public_key,
        } => {
            let (mode, ephemeral_pub, nonce, ciphertext) =
                encryption::encrypt_recipient(plaintext, recipient_public_key)?;
            Ok(assemble(mode, &ephemeral_pub, &nonce, &ciphertext))
        }
    }
}
850
/// Result of loading a dataset file.
#[derive(Debug)]
pub struct LoadedDataset {
    /// The decoded fixed header.
    pub header: Header,
    /// The decoded metadata section.
    pub metadata: Metadata,
    /// Record batches decoded from the (decrypted, decompressed) Arrow IPC payload.
    pub batches: Vec<arrow::array::RecordBatch>,
    /// The license block, present when the header's `LICENSED` flag is set.
    pub license: Option<license::LicenseBlock>,
    /// Public key from the signature block, present when the header's `SIGNED` flag is set.
    /// NOTE(review): only verified when `LoadOptions::trusted_keys` was non-empty.
    pub signer_public_key: Option<[u8; 32]>,
}
865
866pub fn load<R: std::io::Read>(reader: &mut R) -> Result<LoadedDataset> {
875 load_with_options(reader, &LoadOptions::default())
876}
877
878pub fn load_with_options<R: std::io::Read>(
885 reader: &mut R,
886 options: &LoadOptions,
887) -> Result<LoadedDataset> {
888 let (all_data, checksum_offset) = read_and_verify_integrity(reader)?;
889 let (header, metadata, metadata_end) = parse_header_and_metadata(&all_data)?;
890 let (payload_end, compressed_payload) =
891 extract_payload_bytes(&header, &all_data, metadata_end, checksum_offset, options)?;
892 let (signer_public_key, license_block) =
893 parse_trailing_blocks(&header, &all_data, payload_end, checksum_offset, options)?;
894 let decompressed_payload = decompress_payload(compressed_payload, header.compression)?;
895 let batches = parse_arrow_batches(decompressed_payload)?;
896 Ok(LoadedDataset {
897 header,
898 metadata,
899 batches,
900 license: license_block,
901 signer_public_key,
902 })
903}
904
905fn read_and_verify_integrity<R: std::io::Read>(reader: &mut R) -> Result<(Vec<u8>, usize)> {
908 let mut all_data = Vec::new();
909 reader
910 .read_to_end(&mut all_data)
911 .map_err(Error::io_no_path)?;
912
913 if all_data.len() < HEADER_SIZE + 4 {
914 return Err(Error::Format("File too small".to_string()));
915 }
916
917 let checksum_offset = all_data.len() - 4;
918 let stored_checksum = u32::from_le_bytes([
919 all_data[checksum_offset],
920 all_data[checksum_offset + 1],
921 all_data[checksum_offset + 2],
922 all_data[checksum_offset + 3],
923 ]);
924
925 let computed_checksum = crc32(&all_data[..checksum_offset]);
926 if stored_checksum != computed_checksum {
927 return Err(Error::ChecksumMismatch {
928 expected: stored_checksum,
929 actual: computed_checksum,
930 });
931 }
932
933 Ok((all_data, checksum_offset))
934}
935
936fn parse_header_and_metadata(all_data: &[u8]) -> Result<(Header, Metadata, usize)> {
939 let header = Header::from_bytes(&all_data[..HEADER_SIZE])?;
940 let metadata_start = HEADER_SIZE;
941 let metadata_end = metadata_start + header.metadata_size as usize;
942 let metadata: Metadata = rmp_serde::from_slice(&all_data[metadata_start..metadata_end])
943 .map_err(|e| Error::Format(format!("Metadata parse error: {e}")))?;
944 Ok((header, metadata, metadata_end))
945}
946
947fn extract_payload_bytes(
950 header: &Header,
951 all_data: &[u8],
952 metadata_end: usize,
953 checksum_offset: usize,
954 options: &LoadOptions,
955) -> Result<(usize, Vec<u8>)> {
956 let schema_end = metadata_end + header.schema_size as usize;
957 let encryption_header_size = determine_encryption_header_size(header, all_data, schema_end)?;
958 let payload_start = schema_end + encryption_header_size;
959 let payload_end = payload_start + header.payload_size as usize;
960
961 if payload_end > checksum_offset {
962 return Err(Error::Format("Payload extends beyond data".to_string()));
963 }
964
965 let bytes = decrypt_or_copy_payload(
966 header,
967 all_data,
968 schema_end,
969 payload_start,
970 payload_end,
971 options,
972 )?;
973 Ok((payload_end, bytes))
974}
975
976#[cfg(feature = "format-encryption")]
978fn decrypt_or_copy_payload(
979 header: &Header,
980 all_data: &[u8],
981 schema_end: usize,
982 payload_start: usize,
983 payload_end: usize,
984 options: &LoadOptions,
985) -> Result<Vec<u8>> {
986 if !header.is_encrypted() {
987 return Ok(all_data[payload_start..payload_end].to_vec());
988 }
989 let enc_header = &all_data[schema_end..payload_start];
990 let ciphertext = &all_data[payload_start..payload_end];
991 let decryption_params = options.decryption.as_ref().ok_or_else(|| {
992 Error::Format("Dataset is encrypted but no decryption params provided".to_string())
993 })?;
994 decrypt_payload(enc_header, ciphertext, decryption_params)
995}
996
997#[cfg(not(feature = "format-encryption"))]
998fn decrypt_or_copy_payload(
999 header: &Header,
1000 all_data: &[u8],
1001 _schema_end: usize,
1002 payload_start: usize,
1003 payload_end: usize,
1004 _options: &LoadOptions,
1005) -> Result<Vec<u8>> {
1006 if header.is_encrypted() {
1007 return Err(Error::Format(
1008 "Dataset is encrypted but format-encryption feature is not enabled".to_string(),
1009 ));
1010 }
1011 Ok(all_data[payload_start..payload_end].to_vec())
1012}
1013
1014fn parse_arrow_batches(
1016 decompressed_payload: Vec<u8>,
1017) -> Result<Vec<arrow::record_batch::RecordBatch>> {
1018 use arrow::ipc::reader::StreamReader;
1019 let cursor = std::io::Cursor::new(decompressed_payload);
1020 let stream_reader = StreamReader::try_new(cursor, None).map_err(Error::Arrow)?;
1021 stream_reader
1022 .into_iter()
1023 .collect::<std::result::Result<Vec<_>, _>>()
1024 .map_err(Error::Arrow)
1025}
1026
1027fn determine_encryption_header_size(
1029 header: &Header,
1030 all_data: &[u8],
1031 schema_end: usize,
1032) -> Result<usize> {
1033 if !header.is_encrypted() {
1034 return Ok(0);
1035 }
1036
1037 if all_data.len() <= schema_end {
1038 return Err(Error::Format("Missing encryption header".to_string()));
1039 }
1040
1041 #[cfg(feature = "format-encryption")]
1042 {
1043 Ok(encryption_block_header_size(all_data[schema_end]))
1044 }
1045 #[cfg(not(feature = "format-encryption"))]
1046 {
1047 Err(Error::Format(
1048 "Dataset is encrypted but format-encryption feature is not enabled".to_string(),
1049 ))
1050 }
1051}
1052
1053fn parse_trailing_blocks(
1055 header: &Header,
1056 all_data: &[u8],
1057 payload_end: usize,
1058 checksum_offset: usize,
1059 options: &LoadOptions,
1060) -> Result<(Option<[u8; 32]>, Option<license::LicenseBlock>)> {
1061 let mut trailing_offset = payload_end;
1062
1063 let signer_public_key = if header.is_signed() {
1064 parse_signature_block(all_data, &mut trailing_offset, checksum_offset, options)?
1065 } else {
1066 None
1067 };
1068
1069 let license_block = if header.is_licensed() {
1070 Some(parse_license_block(
1071 all_data,
1072 trailing_offset,
1073 checksum_offset,
1074 options,
1075 )?)
1076 } else {
1077 None
1078 };
1079
1080 Ok((signer_public_key, license_block))
1081}
1082
/// Reads the fixed-size signature block at `*trailing_offset` and advances the offset past it.
///
/// When `options.trusted_keys` is non-empty, the signer's public key must appear in it and
/// the signature is verified over `all_data[..*trailing_offset]` (every byte preceding the
/// signature block). When `trusted_keys` is empty, no check is performed.
///
/// NOTE(review): in the empty-`trusted_keys` case the returned key is whatever the file
/// claims and has not been cryptographically verified — confirm callers do not treat
/// `LoadedDataset::signer_public_key` as authenticated then.
///
/// # Errors
///
/// Returns [`Error::Format`] if the block would overlap the checksum trailer or the signer is
/// untrusted; propagates parse and verification errors from `signing::SignatureBlock`.
#[cfg(feature = "format-signing")]
fn parse_signature_block(
    all_data: &[u8],
    trailing_offset: &mut usize,
    checksum_offset: usize,
    options: &LoadOptions,
) -> Result<Option<[u8; 32]>> {
    let sig_end = *trailing_offset + signing::SignatureBlock::SIZE;
    if sig_end > checksum_offset {
        return Err(Error::Format(
            "Signature block extends beyond data".to_string(),
        ));
    }

    let sig_block = signing::SignatureBlock::from_bytes(&all_data[*trailing_offset..sig_end])?;

    if !options.trusted_keys.is_empty() {
        if !options.trusted_keys.contains(&sig_block.public_key) {
            return Err(Error::Format("Signer not in trusted keys list".to_string()));
        }
        // Must run before `trailing_offset` is advanced: the signed region ends here.
        sig_block.verify(&all_data[..*trailing_offset])?;
    }

    *trailing_offset = sig_end;
    Ok(Some(sig_block.public_key))
}

/// Without `format-signing`, a signed dataset cannot be loaded.
#[cfg(not(feature = "format-signing"))]
fn parse_signature_block(
    _all_data: &[u8],
    _trailing_offset: &mut usize,
    _checksum_offset: usize,
    _options: &LoadOptions,
) -> Result<Option<[u8; 32]>> {
    Err(Error::Format(
        "Dataset is signed but format-signing feature is not enabled".to_string(),
    ))
}
1121
1122fn parse_license_block(
1123 all_data: &[u8],
1124 trailing_offset: usize,
1125 checksum_offset: usize,
1126 options: &LoadOptions,
1127) -> Result<license::LicenseBlock> {
1128 if trailing_offset >= checksum_offset {
1129 return Err(Error::Format("Missing license block".to_string()));
1130 }
1131 let lic = license::LicenseBlock::from_bytes(&all_data[trailing_offset..checksum_offset])?;
1132 if options.verify_license {
1133 lic.verify()?;
1134 }
1135 Ok(lic)
1136}
1137
1138fn decompress_payload(payload: Vec<u8>, compression: Compression) -> Result<Vec<u8>> {
1140 match compression {
1141 Compression::None => Ok(payload),
1142 Compression::ZstdL3 | Compression::ZstdL19 => zstd::decode_all(payload.as_slice())
1143 .map_err(|e| Error::Format(format!("Zstd decompression error: {e}"))),
1144 Compression::Lz4 => {
1145 let mut decoder = lz4_flex::frame::FrameDecoder::new(payload.as_slice());
1146 let mut decompressed = Vec::new();
1147 std::io::Read::read_to_end(&mut decoder, &mut decompressed)
1148 .map_err(|e| Error::Format(format!("LZ4 decompression error: {e}")))?;
1149 Ok(decompressed)
1150 }
1151 }
1152}
1153
/// Decrypts `ciphertext` using the parameters embedded in `enc_header` and the caller's
/// `params`.
///
/// Header layouts: password mode = mode | salt(16) | nonce(12); recipient mode =
/// mode | ephemeral public key(32) | nonce(12).
#[cfg(feature = "format-encryption")]
fn decrypt_payload(
    enc_header: &[u8],
    ciphertext: &[u8],
    params: &encryption::DecryptionParams,
) -> Result<Vec<u8>> {
    let mode = match enc_header.first() {
        Some(&byte) => byte,
        None => return Err(Error::Format("Empty encryption header".to_string())),
    };

    match (mode, params) {
        (encryption::mode::PASSWORD, encryption::DecryptionParams::Password(password)) => {
            let fields = enc_header.get(1..29).ok_or_else(|| {
                Error::Format("Invalid password encryption header".to_string())
            })?;
            let (salt_bytes, nonce_bytes) = fields.split_at(16);
            let mut salt = [0u8; 16];
            let mut nonce = [0u8; 12];
            salt.copy_from_slice(salt_bytes);
            nonce.copy_from_slice(nonce_bytes);
            encryption::decrypt_password(ciphertext, password, &salt, &nonce)
        }
        (encryption::mode::RECIPIENT, encryption::DecryptionParams::PrivateKey(private_key)) => {
            let fields = enc_header.get(1..45).ok_or_else(|| {
                Error::Format("Invalid recipient encryption header".to_string())
            })?;
            let (pub_bytes, nonce_bytes) = fields.split_at(32);
            let mut ephemeral_pub = [0u8; 32];
            let mut nonce = [0u8; 12];
            ephemeral_pub.copy_from_slice(pub_bytes);
            nonce.copy_from_slice(nonce_bytes);
            encryption::decrypt_recipient(ciphertext, private_key, &ephemeral_pub, &nonce)
        }
        (encryption::mode::PASSWORD, encryption::DecryptionParams::PrivateKey(_)) => Err(
            Error::Format("Dataset encrypted with password but private key provided".to_string()),
        ),
        (encryption::mode::RECIPIENT, encryption::DecryptionParams::Password(_)) => Err(
            Error::Format("Dataset encrypted for recipient but password provided".to_string()),
        ),
        _ => Err(Error::Format(format!("Unknown encryption mode: {mode}"))),
    }
}
1203
1204pub fn save_to_file<P: AsRef<std::path::Path>>(
1210 path: P,
1211 batches: &[arrow::array::RecordBatch],
1212 dataset_type: DatasetType,
1213 options: &SaveOptions,
1214) -> Result<()> {
1215 let file = std::fs::File::create(path.as_ref())
1216 .map_err(|e| Error::io(e, path.as_ref().to_path_buf()))?;
1217 let mut writer = std::io::BufWriter::new(file);
1218 save(&mut writer, batches, dataset_type, options)
1219}
1220
1221pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<LoadedDataset> {
1227 load_from_file_with_options(path, &LoadOptions::default())
1228}
1229
1230pub fn load_from_file_with_options<P: AsRef<std::path::Path>>(
1237 path: P,
1238 options: &LoadOptions,
1239) -> Result<LoadedDataset> {
1240 let file = std::fs::File::open(path.as_ref())
1241 .map_err(|e| Error::io(e, path.as_ref().to_path_buf()))?;
1242 let mut reader = std::io::BufReader::new(file);
1243 load_with_options(&mut reader, options)
1244}
1245
1246#[cfg(test)]
1247mod tests;