1mod crc;
35#[cfg(feature = "format-encryption")]
36pub mod encryption;
37pub mod license;
38pub mod piracy;
39#[cfg(feature = "format-signing")]
40pub mod signing;
41#[cfg(feature = "format-streaming")]
42pub mod streaming;
43
44pub use crc::crc32;
45use serde::{Deserialize, Serialize};
46
47use crate::error::{Error, Result};
48
/// Magic bytes at the start of every file: ASCII `ALDF`.
///
/// Written to header bytes `0..4` by `Header::to_bytes` and checked by
/// `Header::from_bytes`.
pub const MAGIC: [u8; 4] = [0x41, 0x4C, 0x44, 0x46];

/// Major format version written by this implementation (header byte 4).
///
/// `Header::from_bytes` rejects any file whose major version is greater than this.
pub const FORMAT_VERSION_MAJOR: u8 = 1;
/// Minor format version written by this implementation (header byte 5).
///
/// The minor version is stored and reported but not validated on load.
pub const FORMAT_VERSION_MINOR: u8 = 2;

/// Size in bytes of the fixed-length header that opens every file.
pub const HEADER_SIZE: usize = 32;
59
/// Bit masks for the single flags byte stored at header offset 21.
///
/// Each constant occupies one distinct bit, so masks can be OR-ed together.
pub mod flags {
    /// Bit 0 — set by `save` when encryption parameters are supplied.
    pub const ENCRYPTED: u8 = 0x01;
    /// Bit 1 — set by `save` when a signing key is supplied.
    pub const SIGNED: u8 = 0x02;
    /// Bit 2 — streaming marker; reported by `Header::is_streaming`.
    pub const STREAMING: u8 = 0x04;
    /// Bit 3 — set by `save` when a license block is supplied.
    pub const LICENSED: u8 = 0x08;
    /// Bit 4 — trueno-native marker; reported by `Header::is_trueno_native`.
    pub const TRUENO_NATIVE: u8 = 0x10;
}
73
/// Kind of dataset stored in a file.
///
/// The discriminant is the on-disk value: it is written little-endian to
/// header bytes `6..8` and mapped back by [`DatasetType::from_u16`].
/// Values are grouped by their high nibble (grouping inferred from the
/// variant names).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u16)]
pub enum DatasetType {
    // 0x000_: structured data.
    Tabular = 0x0001,
    TimeSeries = 0x0002,
    Graph = 0x0003,
    Spatial = 0x0004,

    // 0x001_: text.
    TextCorpus = 0x0010,
    TextClassification = 0x0011,
    TextPairs = 0x0012,
    SequenceLabeling = 0x0013,
    QuestionAnswering = 0x0014,
    Summarization = 0x0015,
    Translation = 0x0016,

    // 0x002_: image and video.
    ImageClassification = 0x0020,
    ObjectDetection = 0x0021,
    Segmentation = 0x0022,
    ImagePairs = 0x0023,
    Video = 0x0024,

    // 0x003_: audio.
    AudioClassification = 0x0030,
    SpeechRecognition = 0x0031,
    SpeakerIdentification = 0x0032,

    // 0x004_: recommendation.
    UserItemRatings = 0x0040,
    ImplicitFeedback = 0x0041,
    SequentialRecs = 0x0042,

    // 0x005_: multimodal pairs.
    ImageText = 0x0050,
    AudioText = 0x0051,
    VideoText = 0x0052,

    // Catch-all for datasets that fit none of the kinds above.
    Custom = 0x00FF,
}
144
145impl DatasetType {
146 #[must_use]
148 pub fn from_u16(value: u16) -> Option<Self> {
149 match value {
150 0x0001 => Some(Self::Tabular),
151 0x0002 => Some(Self::TimeSeries),
152 0x0003 => Some(Self::Graph),
153 0x0004 => Some(Self::Spatial),
154 0x0010 => Some(Self::TextCorpus),
155 0x0011 => Some(Self::TextClassification),
156 0x0012 => Some(Self::TextPairs),
157 0x0013 => Some(Self::SequenceLabeling),
158 0x0014 => Some(Self::QuestionAnswering),
159 0x0015 => Some(Self::Summarization),
160 0x0016 => Some(Self::Translation),
161 0x0020 => Some(Self::ImageClassification),
162 0x0021 => Some(Self::ObjectDetection),
163 0x0022 => Some(Self::Segmentation),
164 0x0023 => Some(Self::ImagePairs),
165 0x0024 => Some(Self::Video),
166 0x0030 => Some(Self::AudioClassification),
167 0x0031 => Some(Self::SpeechRecognition),
168 0x0032 => Some(Self::SpeakerIdentification),
169 0x0040 => Some(Self::UserItemRatings),
170 0x0041 => Some(Self::ImplicitFeedback),
171 0x0042 => Some(Self::SequentialRecs),
172 0x0050 => Some(Self::ImageText),
173 0x0051 => Some(Self::AudioText),
174 0x0052 => Some(Self::VideoText),
175 0x00FF => Some(Self::Custom),
176 _ => None,
177 }
178 }
179
180 #[must_use]
182 pub const fn as_u16(self) -> u16 {
183 self as u16
184 }
185}
186
/// Compression codec applied to the Arrow IPC payload.
///
/// The discriminant is the on-disk value stored in header byte 20.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[repr(u8)]
pub enum Compression {
    /// Payload is stored uncompressed.
    None = 0x00,
    /// Zstandard at level 3; this is the default codec.
    #[default]
    ZstdL3 = 0x01,
    /// Zstandard at level 19.
    ZstdL19 = 0x02,
    /// LZ4 frame format.
    Lz4 = 0x03,
}
201
202impl Compression {
203 #[must_use]
205 pub fn from_u8(value: u8) -> Option<Self> {
206 match value {
207 0x00 => Some(Self::None),
208 0x01 => Some(Self::ZstdL3),
209 0x02 => Some(Self::ZstdL19),
210 0x03 => Some(Self::Lz4),
211 _ => None,
212 }
213 }
214
215 #[must_use]
217 pub const fn as_u8(self) -> u8 {
218 self as u8
219 }
220}
221
/// Fixed-size (`HEADER_SIZE` bytes) file header.
///
/// All multi-byte fields are serialized little-endian. Byte offsets below are
/// those used by `Header::to_bytes` / `Header::from_bytes`; bytes `0..4` hold
/// the magic and are not represented as a field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
    /// Format version as `(major, minor)`; bytes 4 and 5.
    pub version: (u8, u8),
    /// Kind of dataset stored; bytes `6..8`.
    pub dataset_type: DatasetType,
    /// Length in bytes of the serialized metadata section; bytes `8..12`.
    pub metadata_size: u32,
    /// Length in bytes of the payload as stored on disk; bytes `12..16`.
    pub payload_size: u32,
    /// Length in bytes of the payload before compression; bytes `16..20`.
    pub uncompressed_size: u32,
    /// Codec used for the payload; byte 20.
    pub compression: Compression,
    /// Bit set built from the `flags` module constants; byte 21.
    pub flags: u8,
    /// Length in bytes of the serialized schema section; bytes `22..24`.
    pub schema_size: u16,
    /// Total number of rows across all record batches; bytes `24..32`.
    pub num_rows: u64,
}
257
258impl Header {
259 #[must_use]
261 pub fn new(dataset_type: DatasetType) -> Self {
262 contract_pre_configuration!();
264 Self {
265 version: (FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR),
266 dataset_type,
267 metadata_size: 0,
268 payload_size: 0,
269 uncompressed_size: 0,
270 compression: Compression::default(),
271 flags: 0,
272 schema_size: 0,
273 num_rows: 0,
274 }
275 }
276
277 #[must_use]
279 pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
280 contract_pre_configuration!();
282 let mut buf = [0u8; HEADER_SIZE];
283
284 buf[0..4].copy_from_slice(&MAGIC);
286
287 buf[4] = self.version.0;
289 buf[5] = self.version.1;
290
291 let dt = self.dataset_type.as_u16().to_le_bytes();
293 buf[6..8].copy_from_slice(&dt);
294
295 buf[8..12].copy_from_slice(&self.metadata_size.to_le_bytes());
297
298 buf[12..16].copy_from_slice(&self.payload_size.to_le_bytes());
300
301 buf[16..20].copy_from_slice(&self.uncompressed_size.to_le_bytes());
303
304 buf[20] = self.compression.as_u8();
306
307 buf[21] = self.flags;
309
310 buf[22..24].copy_from_slice(&self.schema_size.to_le_bytes());
312
313 buf[24..32].copy_from_slice(&self.num_rows.to_le_bytes());
315
316 buf
317 }
318
319 pub fn from_bytes(buf: &[u8]) -> Result<Self> {
326 if buf.len() < HEADER_SIZE {
327 return Err(Error::Format(format!(
328 "Header too short: {} bytes, expected {}",
329 buf.len(),
330 HEADER_SIZE
331 )));
332 }
333
334 if buf[0..4] != MAGIC {
336 return Err(Error::Format(format!(
337 "Invalid magic: expected {:?}, got {:?}",
338 MAGIC,
339 &buf[0..4]
340 )));
341 }
342
343 contract_pre_configuration!(buf);
345
346 let version = (buf[4], buf[5]);
348 if version.0 > FORMAT_VERSION_MAJOR {
349 return Err(Error::Format(format!(
350 "Unsupported version: {}.{}, max supported: {}.{}",
351 version.0, version.1, FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR
352 )));
353 }
354
355 let dt_value = u16::from_le_bytes([buf[6], buf[7]]);
357 let dataset_type = DatasetType::from_u16(dt_value)
358 .ok_or_else(|| Error::Format(format!("Unknown dataset type: 0x{:04X}", dt_value)))?;
359
360 let metadata_size = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
362
363 let payload_size = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
365
366 let uncompressed_size = u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]);
368
369 let compression = Compression::from_u8(buf[20])
371 .ok_or_else(|| Error::Format(format!("Unknown compression: 0x{:02X}", buf[20])))?;
372
373 let flags = buf[21];
375
376 let schema_size = u16::from_le_bytes([buf[22], buf[23]]);
378
379 let num_rows = u64::from_le_bytes([
381 buf[24], buf[25], buf[26], buf[27], buf[28], buf[29], buf[30], buf[31],
382 ]);
383
384 Ok(Self {
385 version,
386 dataset_type,
387 metadata_size,
388 payload_size,
389 uncompressed_size,
390 compression,
391 flags,
392 schema_size,
393 num_rows,
394 })
395 }
396
397 #[must_use]
399 pub const fn is_encrypted(&self) -> bool {
400 self.flags & flags::ENCRYPTED != 0
401 }
402
403 #[must_use]
405 pub const fn is_signed(&self) -> bool {
406 self.flags & flags::SIGNED != 0
407 }
408
409 #[must_use]
411 pub const fn is_streaming(&self) -> bool {
412 self.flags & flags::STREAMING != 0
413 }
414
415 #[must_use]
417 pub const fn is_licensed(&self) -> bool {
418 self.flags & flags::LICENSED != 0
419 }
420
421 #[must_use]
423 pub const fn is_trueno_native(&self) -> bool {
424 self.flags & flags::TRUENO_NATIVE != 0
425 }
426}
427
/// Optional descriptive metadata stored after the header.
///
/// `save` serializes this with `rmp_serde`; when the caller supplies none,
/// `Metadata::default()` is written instead. Field meanings below are taken
/// from the field names; this file does not validate their contents.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Metadata {
    /// Dataset name.
    pub name: Option<String>,
    /// Dataset version string (distinct from the file format version).
    pub version: Option<String>,
    /// License identifier or text.
    pub license: Option<String>,
    /// Free-form tags; deserializes to an empty list when the field is absent.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Human-readable description.
    pub description: Option<String>,
    /// Citation text.
    pub citation: Option<String>,
    /// Creation timestamp as a string; the format is not enforced here.
    pub created_at: Option<String>,
    /// SHA-256 hex digest; omitted from the serialized form when `None` and
    /// defaulted to `None` when absent on load.
    // NOTE(review): which bytes this digest covers is not visible in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sha256: Option<String>,
}
450
/// Computes the SHA-256 digest of `data` and returns it as a 64-character
/// lowercase hexadecimal string.
#[cfg(feature = "provenance")]
#[must_use]
pub fn sha256_hex(data: &[u8]) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write;

    let mut hasher = Sha256::new();
    hasher.update(data);
    let digest = hasher.finalize();

    // 32 digest bytes, two hex characters each.
    let mut hex = String::with_capacity(64);
    for byte in digest.iter() {
        // Writing into a `String` cannot fail, so the result is ignored.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
477
/// Options controlling how `save` writes a dataset.
#[derive(Debug, Clone)]
pub struct SaveOptions {
    /// Codec applied to the Arrow IPC payload.
    pub compression: Compression,
    /// Descriptive metadata; `Metadata::default()` is written when `None`.
    pub metadata: Option<Metadata>,
    /// When set, the compressed payload is encrypted and the header's
    /// `ENCRYPTED` flag is raised.
    #[cfg(feature = "format-encryption")]
    pub encryption: Option<encryption::EncryptionParams>,
    /// When set, a signature block is appended after the payload and the
    /// header's `SIGNED` flag is raised.
    #[cfg(feature = "format-signing")]
    pub signing_key: Option<signing::SigningKeyPair>,
    /// When set, a license block is appended and the header's `LICENSED`
    /// flag is raised.
    pub license: Option<license::LicenseBlock>,
}
494
495impl Default for SaveOptions {
496 fn default() -> Self {
497 Self {
498 compression: Compression::ZstdL3,
499 metadata: None,
500 #[cfg(feature = "format-encryption")]
501 encryption: None,
502 #[cfg(feature = "format-signing")]
503 signing_key: None,
504 license: None,
505 }
506 }
507}
508
509impl SaveOptions {
510 #[must_use]
512 pub fn with_compression(mut self, compression: Compression) -> Self {
513 self.compression = compression;
514 self
515 }
516
517 #[must_use]
519 pub fn with_metadata(mut self, metadata: Metadata) -> Self {
520 self.metadata = Some(metadata);
521 self
522 }
523
524 #[cfg(feature = "format-encryption")]
526 #[must_use]
527 pub fn with_password(mut self, password: impl Into<String>) -> Self {
528 self.encryption = Some(encryption::EncryptionParams::password(password));
529 self
530 }
531
532 #[cfg(feature = "format-encryption")]
534 #[must_use]
535 pub fn with_recipient(mut self, public_key: [u8; 32]) -> Self {
536 self.encryption = Some(encryption::EncryptionParams::recipient(public_key));
537 self
538 }
539
540 #[cfg(feature = "format-signing")]
542 #[must_use]
543 pub fn with_signing_key(mut self, key: signing::SigningKeyPair) -> Self {
544 self.signing_key = Some(key);
545 self
546 }
547
548 #[must_use]
550 pub fn with_license(mut self, license: license::LicenseBlock) -> Self {
551 self.license = Some(license);
552 self
553 }
554}
555
/// Options controlling how `load_with_options` reads a dataset.
///
/// The default has no decryption parameters, no trusted keys, and license
/// verification disabled.
#[derive(Debug, Clone, Default)]
pub struct LoadOptions {
    /// Parameters used to decrypt the payload; required when the file's
    /// `ENCRYPTED` flag is set.
    #[cfg(feature = "format-encryption")]
    pub decryption: Option<encryption::DecryptionParams>,
    /// Public keys accepted as signers. When this list is empty the
    /// signature block of a signed file is parsed but NOT verified.
    #[cfg(feature = "format-signing")]
    pub trusted_keys: Vec<[u8; 32]>,
    /// When `true`, the license block of a licensed file is verified on load.
    pub verify_license: bool,
}
569
570impl LoadOptions {
571 #[cfg(feature = "format-encryption")]
573 #[must_use]
574 pub fn with_password(mut self, password: impl Into<String>) -> Self {
575 self.decryption = Some(encryption::DecryptionParams::password(password));
576 self
577 }
578
579 #[cfg(feature = "format-encryption")]
581 #[must_use]
582 pub fn with_private_key(mut self, key: [u8; 32]) -> Self {
583 self.decryption = Some(encryption::DecryptionParams::private_key(key));
584 self
585 }
586
587 #[cfg(feature = "format-signing")]
589 #[must_use]
590 pub fn with_trusted_key(mut self, key: [u8; 32]) -> Self {
591 self.trusted_keys.push(key);
592 self
593 }
594
595 #[must_use]
597 pub fn verify_license(mut self) -> Self {
598 self.verify_license = true;
599 self
600 }
601}
602
603#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
609pub fn save<W: std::io::Write>(
610 writer: &mut W,
611 batches: &[arrow::array::RecordBatch],
612 dataset_type: DatasetType,
613 options: &SaveOptions,
614) -> Result<()> {
615 use arrow::ipc::writer::StreamWriter;
616
617 if batches.is_empty() {
618 return Err(Error::EmptyDataset);
619 }
620
621 let schema = batches[0].schema();
622
623 let mut schema_buf = Vec::new();
625 {
626 let mut schema_writer =
627 StreamWriter::try_new(&mut schema_buf, &schema).map_err(Error::Arrow)?;
628 schema_writer.finish().map_err(Error::Arrow)?;
629 }
630
631 let mut payload_buf = Vec::new();
633 {
634 let mut payload_writer =
635 StreamWriter::try_new(&mut payload_buf, &schema).map_err(Error::Arrow)?;
636 for batch in batches {
637 payload_writer.write(batch).map_err(Error::Arrow)?;
638 }
639 payload_writer.finish().map_err(Error::Arrow)?;
640 }
641
642 let uncompressed_size = payload_buf.len() as u32;
643
644 let compressed_payload = compress_payload(payload_buf, options.compression)?;
646
647 let mut header_flags: u8 = 0;
649
650 #[cfg(feature = "format-encryption")]
652 let (final_payload, encryption_header) = if let Some(ref enc_params) = options.encryption {
653 header_flags |= flags::ENCRYPTED;
654 let block = build_encryption_block(&compressed_payload, enc_params)?;
655 let hdr_size = encryption_block_header_size(block[0]);
656 (block[hdr_size..].to_vec(), block[..hdr_size].to_vec())
657 } else {
658 (compressed_payload, Vec::new())
659 };
660 #[cfg(not(feature = "format-encryption"))]
661 let (final_payload, encryption_header): (Vec<u8>, Vec<u8>) = (compressed_payload, Vec::new());
662
663 #[cfg(feature = "format-signing")]
665 if options.signing_key.is_some() {
666 header_flags |= flags::SIGNED;
667 }
668
669 if options.license.is_some() {
671 header_flags |= flags::LICENSED;
672 }
673
674 let metadata_buf = if let Some(ref meta) = options.metadata {
676 rmp_serde::to_vec(meta).map_err(|e| Error::Format(e.to_string()))?
677 } else {
678 rmp_serde::to_vec(&Metadata::default()).map_err(|e| Error::Format(e.to_string()))?
679 };
680
681 let num_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
683
684 let header = Header {
686 version: (FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR),
687 dataset_type,
688 metadata_size: metadata_buf.len() as u32,
689 payload_size: final_payload.len() as u32,
690 uncompressed_size,
691 compression: options.compression,
692 flags: header_flags,
693 schema_size: schema_buf.len() as u16,
694 num_rows,
695 };
696
697 let mut all_data = Vec::new();
699 let header_bytes = header.to_bytes();
700 all_data.extend_from_slice(&header_bytes);
701 all_data.extend_from_slice(&metadata_buf);
702 all_data.extend_from_slice(&schema_buf);
703 all_data.extend_from_slice(&encryption_header);
704 all_data.extend_from_slice(&final_payload);
705
706 #[cfg(feature = "format-signing")]
708 let signature_block: Option<[u8; signing::SignatureBlock::SIZE]> =
709 if let Some(ref key) = options.signing_key {
710 let sig_block = signing::SignatureBlock::sign(&all_data, key);
711 let sig_bytes = sig_block.to_bytes();
712 all_data.extend_from_slice(&sig_bytes);
713 Some(sig_bytes)
714 } else {
715 None
716 };
717 #[cfg(not(feature = "format-signing"))]
718 let signature_block: Option<[u8; 96]> = None;
719
720 let license_bytes: Option<Vec<u8>> = if let Some(ref lic) = options.license {
722 let lic_bytes = lic.to_bytes();
723 all_data.extend_from_slice(&lic_bytes);
724 Some(lic_bytes)
725 } else {
726 None
727 };
728
729 let checksum = crc32(&all_data);
731
732 writer.write_all(&header_bytes).map_err(Error::io_no_path)?;
734 writer.write_all(&metadata_buf).map_err(Error::io_no_path)?;
735 writer.write_all(&schema_buf).map_err(Error::io_no_path)?;
736 writer
737 .write_all(&encryption_header)
738 .map_err(Error::io_no_path)?;
739 writer
740 .write_all(&final_payload)
741 .map_err(Error::io_no_path)?;
742
743 if let Some(ref sig) = signature_block {
744 writer.write_all(sig).map_err(Error::io_no_path)?;
745 }
746
747 if let Some(ref lic) = license_bytes {
748 writer.write_all(lic).map_err(Error::io_no_path)?;
749 }
750
751 writer
752 .write_all(&checksum.to_le_bytes())
753 .map_err(Error::io_no_path)?;
754
755 Ok(())
756}
757
758fn compress_payload(payload: Vec<u8>, compression: Compression) -> Result<Vec<u8>> {
760 match compression {
761 Compression::None => Ok(payload),
762 Compression::ZstdL3 => zstd::encode_all(payload.as_slice(), 3).map_err(Error::io_no_path),
763 Compression::ZstdL19 => zstd::encode_all(payload.as_slice(), 19).map_err(Error::io_no_path),
764 Compression::Lz4 => {
765 let mut encoder = lz4_flex::frame::FrameEncoder::new(Vec::new());
766 std::io::Write::write_all(&mut encoder, &payload).map_err(Error::io_no_path)?;
767 encoder
768 .finish()
769 .map_err(|e| Error::Format(format!("LZ4 compression error: {e}")))
770 }
771 }
772}
773
/// Returns the length in bytes of the encryption header that precedes the
/// ciphertext, given the header's leading mode byte.
///
/// Password mode uses mode byte + 16-byte salt + 12-byte nonce; every other
/// mode byte is sized as recipient mode: mode byte + 32-byte ephemeral public
/// key + 12-byte nonce.
#[cfg(feature = "format-encryption")]
fn encryption_block_header_size(mode: u8) -> usize {
    const MODE_LEN: usize = 1;
    const SALT_LEN: usize = 16;
    const EPHEMERAL_KEY_LEN: usize = 32;
    const NONCE_LEN: usize = 12;

    let key_material_len = if mode == encryption::mode::PASSWORD {
        SALT_LEN
    } else {
        EPHEMERAL_KEY_LEN
    };

    MODE_LEN + key_material_len + NONCE_LEN
}
783
/// Encrypts `plaintext` according to `params` and returns one contiguous
/// block: mode byte, key material, nonce, then ciphertext.
///
/// The key material is the salt in password mode and the ephemeral public
/// key in recipient mode. `encryption_block_header_size` gives the length of
/// the part that precedes the ciphertext.
///
/// # Errors
///
/// Propagates any error returned by the underlying encryption routine.
#[cfg(feature = "format-encryption")]
fn build_encryption_block(
    plaintext: &[u8],
    params: &encryption::EncryptionParams,
) -> Result<Vec<u8>> {
    /// Concatenates the parts of an encryption block in on-disk order.
    fn assemble(mode: u8, key_material: &[u8], nonce: &[u8], ciphertext: &[u8]) -> Vec<u8> {
        let mut block =
            Vec::with_capacity(1 + key_material.len() + nonce.len() + ciphertext.len());
        block.push(mode);
        block.extend_from_slice(key_material);
        block.extend_from_slice(nonce);
        block.extend_from_slice(ciphertext);
        block
    }

    match &params.mode {
        encryption::EncryptionMode::Password(password) => {
            let (mode, salt, nonce, ciphertext) =
                encryption::encrypt_password(plaintext, password)?;
            Ok(assemble(mode, &salt, &nonce, &ciphertext))
        }
        encryption::EncryptionMode::Recipient {
            recipient_public_key,
        } => {
            let (mode, ephemeral_pub, nonce, ciphertext) =
                encryption::encrypt_recipient(plaintext, recipient_public_key)?;
            Ok(assemble(mode, &ephemeral_pub, &nonce, &ciphertext))
        }
    }
}
815
/// Everything recovered from a dataset file by `load_with_options`.
#[derive(Debug)]
pub struct LoadedDataset {
    /// Parsed fixed-size file header.
    pub header: Header,
    /// Deserialized metadata section.
    pub metadata: Metadata,
    /// Record batches decoded from the payload.
    pub batches: Vec<arrow::array::RecordBatch>,
    /// License block, present only when the header's `LICENSED` flag is set.
    pub license: Option<license::LicenseBlock>,
    /// Public key from the signature block, present only when the header's
    /// `SIGNED` flag is set. The signature is verified only if trusted keys
    /// were supplied in the load options; otherwise this key is unverified.
    pub signer_public_key: Option<[u8; 32]>,
}
830
831pub fn load<R: std::io::Read>(reader: &mut R) -> Result<LoadedDataset> {
840 load_with_options(reader, &LoadOptions::default())
841}
842
843#[allow(clippy::too_many_lines)]
850pub fn load_with_options<R: std::io::Read>(
851 reader: &mut R,
852 options: &LoadOptions,
853) -> Result<LoadedDataset> {
854 use arrow::ipc::reader::StreamReader;
855
856 let mut all_data = Vec::new();
858 reader
859 .read_to_end(&mut all_data)
860 .map_err(Error::io_no_path)?;
861
862 if all_data.len() < HEADER_SIZE + 4 {
863 return Err(Error::Format("File too small".to_string()));
864 }
865
866 let checksum_offset = all_data.len() - 4;
868 let stored_checksum = u32::from_le_bytes([
869 all_data[checksum_offset],
870 all_data[checksum_offset + 1],
871 all_data[checksum_offset + 2],
872 all_data[checksum_offset + 3],
873 ]);
874
875 let computed_checksum = crc32(&all_data[..checksum_offset]);
877 if stored_checksum != computed_checksum {
878 return Err(Error::ChecksumMismatch {
879 expected: stored_checksum,
880 actual: computed_checksum,
881 });
882 }
883
884 let header = Header::from_bytes(&all_data[..HEADER_SIZE])?;
886
887 let metadata_start = HEADER_SIZE;
889 let metadata_end = metadata_start + header.metadata_size as usize;
890 let metadata: Metadata = rmp_serde::from_slice(&all_data[metadata_start..metadata_end])
891 .map_err(|e| Error::Format(format!("Metadata parse error: {e}")))?;
892
893 let schema_end = metadata_end + header.schema_size as usize;
895
896 let encryption_header_size = determine_encryption_header_size(&header, &all_data, schema_end)?;
898
899 let payload_start = schema_end + encryption_header_size;
900 let payload_end = payload_start + header.payload_size as usize;
901
902 if payload_end > checksum_offset {
903 return Err(Error::Format("Payload extends beyond data".to_string()));
904 }
905
906 let compressed_payload: Vec<u8> = if header.is_encrypted() {
908 #[cfg(feature = "format-encryption")]
909 {
910 let enc_header = &all_data[schema_end..payload_start];
911 let ciphertext = &all_data[payload_start..payload_end];
912
913 let decryption_params = options.decryption.as_ref().ok_or_else(|| {
914 Error::Format("Dataset is encrypted but no decryption params provided".to_string())
915 })?;
916
917 decrypt_payload(enc_header, ciphertext, decryption_params)?
918 }
919 #[cfg(not(feature = "format-encryption"))]
920 {
921 return Err(Error::Format(
922 "Dataset is encrypted but format-encryption feature is not enabled".to_string(),
923 ));
924 }
925 } else {
926 all_data[payload_start..payload_end].to_vec()
927 };
928
929 let (signer_public_key, license_block) =
931 parse_trailing_blocks(&header, &all_data, payload_end, checksum_offset, options)?;
932
933 let decompressed_payload = decompress_payload(compressed_payload, header.compression)?;
935
936 let cursor = std::io::Cursor::new(decompressed_payload);
938 let stream_reader = StreamReader::try_new(cursor, None).map_err(Error::Arrow)?;
939
940 let batches: Vec<_> = stream_reader
941 .into_iter()
942 .collect::<std::result::Result<Vec<_>, _>>()
943 .map_err(Error::Arrow)?;
944
945 Ok(LoadedDataset {
946 header,
947 metadata,
948 batches,
949 license: license_block,
950 signer_public_key,
951 })
952}
953
954fn determine_encryption_header_size(
956 header: &Header,
957 all_data: &[u8],
958 schema_end: usize,
959) -> Result<usize> {
960 if !header.is_encrypted() {
961 return Ok(0);
962 }
963
964 if all_data.len() <= schema_end {
965 return Err(Error::Format("Missing encryption header".to_string()));
966 }
967
968 #[cfg(feature = "format-encryption")]
969 {
970 Ok(encryption_block_header_size(all_data[schema_end]))
971 }
972 #[cfg(not(feature = "format-encryption"))]
973 {
974 Err(Error::Format(
975 "Dataset is encrypted but format-encryption feature is not enabled".to_string(),
976 ))
977 }
978}
979
/// Parses the optional signature and license blocks that sit between the
/// payload (`payload_end`) and the checksum trailer (`checksum_offset`).
///
/// Returns the signer's public key (when the file is signed) and the license
/// block (when the file is licensed).
///
/// Signature handling: the signer's key is always returned for a signed
/// file, but the signature is only checked when `options.trusted_keys` is
/// non-empty. In that case the signer must be in the list and the signature
/// must verify over all bytes preceding the signature block. With an empty
/// list the returned key is unverified.
///
/// # Errors
///
/// Returns `Error::Format` when a flagged block is missing or truncated, when
/// the signer is not trusted, or when the file is signed but the
/// `format-signing` feature is disabled. Errors from block parsing and from
/// signature or license verification are propagated.
fn parse_trailing_blocks(
    header: &Header,
    all_data: &[u8],
    payload_end: usize,
    checksum_offset: usize,
    options: &LoadOptions,
) -> Result<(Option<[u8; 32]>, Option<license::LicenseBlock>)> {
    // These are only reassigned when `format-signing` is enabled.
    #[allow(unused_mut)]
    let mut trailing_offset = payload_end;
    #[allow(unused_mut)]
    let mut signer_public_key: Option<[u8; 32]> = None;
    let mut license_block: Option<license::LicenseBlock> = None;

    if header.is_signed() {
        #[cfg(feature = "format-signing")]
        {
            // The signature block has a fixed size and directly follows the payload.
            let sig_end = trailing_offset + signing::SignatureBlock::SIZE;
            if sig_end > checksum_offset {
                return Err(Error::Format(
                    "Signature block extends beyond data".to_string(),
                ));
            }

            let sig_block =
                signing::SignatureBlock::from_bytes(&all_data[trailing_offset..sig_end])?;

            // Verification is opt-in: it only runs when trusted keys were supplied.
            if !options.trusted_keys.is_empty() {
                // The signed region is everything before the signature block.
                let signed_data = &all_data[..trailing_offset];
                if !options.trusted_keys.contains(&sig_block.public_key) {
                    return Err(Error::Format("Signer not in trusted keys list".to_string()));
                }
                sig_block.verify(signed_data)?;
            }

            signer_public_key = Some(sig_block.public_key);
            // The license block, if any, starts right after the signature.
            trailing_offset = sig_end;
        }
        #[cfg(not(feature = "format-signing"))]
        {
            return Err(Error::Format(
                "Dataset is signed but format-signing feature is not enabled".to_string(),
            ));
        }
    }

    if header.is_licensed() {
        if trailing_offset >= checksum_offset {
            return Err(Error::Format("Missing license block".to_string()));
        }
        // The license block takes all remaining bytes up to the checksum.
        let lic = license::LicenseBlock::from_bytes(&all_data[trailing_offset..checksum_offset])?;
        if options.verify_license {
            lic.verify()?;
        }
        license_block = Some(lic);
    }

    Ok((signer_public_key, license_block))
}
1039
1040fn decompress_payload(payload: Vec<u8>, compression: Compression) -> Result<Vec<u8>> {
1042 match compression {
1043 Compression::None => Ok(payload),
1044 Compression::ZstdL3 | Compression::ZstdL19 => zstd::decode_all(payload.as_slice())
1045 .map_err(|e| Error::Format(format!("Zstd decompression error: {e}"))),
1046 Compression::Lz4 => {
1047 let mut decoder = lz4_flex::frame::FrameDecoder::new(payload.as_slice());
1048 let mut decompressed = Vec::new();
1049 std::io::Read::read_to_end(&mut decoder, &mut decompressed)
1050 .map_err(|e| Error::Format(format!("LZ4 decompression error: {e}")))?;
1051 Ok(decompressed)
1052 }
1053 }
1054}
1055
/// Decrypts `ciphertext` using the parameters found in `enc_header` and the
/// caller-supplied `params`.
///
/// `enc_header` layout: mode byte, then a 16-byte salt (password mode) or a
/// 32-byte ephemeral public key (recipient mode), then a 12-byte nonce.
///
/// # Errors
///
/// Returns `Error::Format` when the header is empty or too short for its
/// mode, when the supplied parameters do not match the file's encryption
/// mode, or when the mode byte is not recognized. Errors from the decryption
/// routines are propagated.
#[cfg(feature = "format-encryption")]
fn decrypt_payload(
    enc_header: &[u8],
    ciphertext: &[u8],
    params: &encryption::DecryptionParams,
) -> Result<Vec<u8>> {
    const SALT_LEN: usize = 16;
    const EPHEMERAL_KEY_LEN: usize = 32;
    const NONCE_LEN: usize = 12;
    // Full header lengths, including the leading mode byte.
    const PASSWORD_HEADER_LEN: usize = 1 + SALT_LEN + NONCE_LEN;
    const RECIPIENT_HEADER_LEN: usize = 1 + EPHEMERAL_KEY_LEN + NONCE_LEN;

    let mode = match enc_header.first() {
        Some(&byte) => byte,
        None => return Err(Error::Format("Empty encryption header".to_string())),
    };

    match (mode, params) {
        // Mismatched credentials: report them before attempting anything.
        (encryption::mode::PASSWORD, encryption::DecryptionParams::PrivateKey(_)) => Err(
            Error::Format("Dataset encrypted with password but private key provided".to_string()),
        ),
        (encryption::mode::RECIPIENT, encryption::DecryptionParams::Password(_)) => Err(
            Error::Format("Dataset encrypted for recipient but password provided".to_string()),
        ),
        (encryption::mode::PASSWORD, encryption::DecryptionParams::Password(password)) => {
            if enc_header.len() < PASSWORD_HEADER_LEN {
                return Err(Error::Format(
                    "Invalid password encryption header".to_string(),
                ));
            }
            let (salt_bytes, nonce_bytes) = enc_header[1..PASSWORD_HEADER_LEN].split_at(SALT_LEN);
            let mut salt = [0u8; SALT_LEN];
            salt.copy_from_slice(salt_bytes);
            let mut nonce = [0u8; NONCE_LEN];
            nonce.copy_from_slice(nonce_bytes);

            encryption::decrypt_password(ciphertext, password, &salt, &nonce)
        }
        (encryption::mode::RECIPIENT, encryption::DecryptionParams::PrivateKey(private_key)) => {
            if enc_header.len() < RECIPIENT_HEADER_LEN {
                return Err(Error::Format(
                    "Invalid recipient encryption header".to_string(),
                ));
            }
            let (key_bytes, nonce_bytes) =
                enc_header[1..RECIPIENT_HEADER_LEN].split_at(EPHEMERAL_KEY_LEN);
            let mut ephemeral_pub = [0u8; EPHEMERAL_KEY_LEN];
            ephemeral_pub.copy_from_slice(key_bytes);
            let mut nonce = [0u8; NONCE_LEN];
            nonce.copy_from_slice(nonce_bytes);

            encryption::decrypt_recipient(ciphertext, private_key, &ephemeral_pub, &nonce)
        }
        _ => Err(Error::Format(format!("Unknown encryption mode: {mode}"))),
    }
}
1105
1106pub fn save_to_file<P: AsRef<std::path::Path>>(
1112 path: P,
1113 batches: &[arrow::array::RecordBatch],
1114 dataset_type: DatasetType,
1115 options: &SaveOptions,
1116) -> Result<()> {
1117 let file = std::fs::File::create(path.as_ref())
1118 .map_err(|e| Error::io(e, path.as_ref().to_path_buf()))?;
1119 let mut writer = std::io::BufWriter::new(file);
1120 save(&mut writer, batches, dataset_type, options)
1121}
1122
1123pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<LoadedDataset> {
1129 load_from_file_with_options(path, &LoadOptions::default())
1130}
1131
1132pub fn load_from_file_with_options<P: AsRef<std::path::Path>>(
1139 path: P,
1140 options: &LoadOptions,
1141) -> Result<LoadedDataset> {
1142 let file = std::fs::File::open(path.as_ref())
1143 .map_err(|e| Error::io(e, path.as_ref().to_path_buf()))?;
1144 let mut reader = std::io::BufReader::new(file);
1145 load_with_options(&mut reader, options)
1146}
1147
1148#[cfg(test)]
1149mod tests;