use std::collections::BTreeMap;
use std::ops::Range;
#[cfg(feature = "async")]
use std::sync::MutexGuard;
use std::sync::{Arc, Mutex, OnceLock};
use bytes::Bytes;
use object_store::path::Path as ObjectPath;
use object_store::{ObjectStore, ObjectStoreExt};
use url::Url;
use crate::decode::DecodeOptions;
use crate::error::{Result, TensogramError};
use crate::framing;
use crate::metadata;
use crate::remote_scan_parse::footer_region_present;
use crate::scan_opts::RemoteScanOptions;
use crate::types::{DataObjectDescriptor, GlobalMetadata, IndexFrame};
use crate::wire::{
DATA_OBJECT_FOOTER_SIZE, DataObjectFlags, FRAME_COMMON_FOOTER_SIZE, FRAME_END,
FRAME_HEADER_SIZE, FrameHeader, FrameType, MAGIC, MessageFlags, POSTAMBLE_SIZE, PREAMBLE_SIZE,
Postamble, Preamble,
};
/// URL schemes that `is_remote_url` classifies as remote; the comparison there
/// is ASCII case-insensitive.
const REMOTE_SCHEMES: &[&str] = &["s3", "s3a", "gs", "az", "azure", "http", "https"];
/// Process-wide Tokio runtime used by the blocking remote I/O helpers.
///
/// The cell stores the *outcome* of the one-time build, so a failed build is
/// remembered (as its error message) rather than retried on later accesses.
static SHARED_RUNTIME: OnceLock<std::result::Result<tokio::runtime::Runtime, String>> =
OnceLock::new();
/// Returns the shared Tokio runtime, building it on first use.
///
/// The runtime is multi-threaded with two workers named
/// `tensogram-remote-io`.
///
/// # Errors
///
/// Returns `TensogramError::Remote` carrying the cached build error message
/// if the runtime could not be constructed.
fn shared_runtime() -> Result<&'static tokio::runtime::Runtime> {
    let build = || {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder
            .worker_threads(2)
            .enable_all()
            .thread_name("tensogram-remote-io");
        builder.build().map_err(|e| format!("tokio runtime: {e}"))
    };
    match SHARED_RUNTIME.get_or_init(build) {
        Ok(runtime) => Ok(runtime),
        Err(message) => Err(TensogramError::Remote(message.clone())),
    }
}
/// Drives `future` to completion on the shared runtime from synchronous code.
///
/// The calling context decides how the future is driven:
/// * no ambient Tokio runtime: block directly on the shared runtime's handle;
/// * ambient multi-thread runtime: block inside `tokio::task::block_in_place`;
/// * any other ambient runtime flavor: block on a scoped helper thread.
///
/// # Errors
///
/// Returns `TensogramError::Remote` if the shared runtime is unavailable, if
/// the future resolves to an `object_store::Error` (stringified), or if the
/// helper thread panics.
fn block_on_shared<T, Fut>(future: Fut) -> Result<T>
where
    T: Send,
    Fut: std::future::Future<Output = std::result::Result<T, object_store::Error>> + Send,
{
    let handle = shared_runtime()?.handle().clone();
    // Single driver closure shared by all three execution strategies.
    let drive = move || {
        handle
            .block_on(future)
            .map_err(|e| TensogramError::Remote(e.to_string()))
    };

    let Ok(current) = tokio::runtime::Handle::try_current() else {
        return drive();
    };
    if current.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread {
        return tokio::task::block_in_place(drive);
    }
    std::thread::scope(|s| {
        s.spawn(drive).join().unwrap_or_else(|_| {
            Err(TensogramError::Remote(
                "remote I/O thread panicked".to_string(),
            ))
        })
    })
}
/// Reports whether `source` starts with one of the known remote URL schemes.
///
/// The scheme is the text before the first `"://"`; it is matched against
/// `REMOTE_SCHEMES` ignoring ASCII case. Strings without `"://"` are never
/// remote.
pub fn is_remote_url(source: &str) -> bool {
    let Some((scheme, _rest)) = source.split_once("://") else {
        return false;
    };
    REMOTE_SCHEMES
        .iter()
        .any(|known| known.eq_ignore_ascii_case(scheme))
}
/// Cached description of one message discovered in the remote object.
#[derive(Debug, Clone)]
struct CachedLayout {
/// Absolute byte offset of the message start within the remote object.
offset: u64,
/// Total message length in bytes (`offset..offset + length` is the message).
length: u64,
/// Parsed preamble of the message.
preamble: Preamble,
/// Index frame, once it has been discovered; `None` until then.
index: Option<IndexFrame>,
/// Global metadata, once it has been discovered; `None` until then.
global_metadata: Option<GlobalMetadata>,
}
/// Result of parsing the postamble fetched by the backward walker.
#[derive(Debug)]
enum BackwardOutcome {
/// The postamble looks plausible; the candidate message's preamble must
/// still be fetched and validated before the hop is committed.
NeedPreambleValidation {
/// Candidate message start (`snap.prev - total_length`).
msg_start: u64,
/// Total length reported by the postamble.
length: u64,
/// `first_footer_offset` as read from the postamble.
first_footer_offset: u64,
},
/// The postamble is malformed or inconsistent; carries a short reason tag.
Format(&'static str),
/// The postamble reports a total length of zero.
Streaming,
}
fn parse_backward_postamble(pa_bytes: &[u8], snap: &ScanSnapshot) -> BackwardOutcome {
let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
if pa_bytes.len() < POSTAMBLE_SIZE {
return BackwardOutcome::Format("short-fetch-bwd");
}
if &pa_bytes[POSTAMBLE_SIZE - crate::wire::END_MAGIC.len()..] != crate::wire::END_MAGIC {
return BackwardOutcome::Format("bad-end-magic-bwd");
}
let postamble = match Postamble::read_from(pa_bytes) {
Ok(p) => p,
Err(_) => return BackwardOutcome::Format("postamble-parse-error"),
};
let total = postamble.total_length;
if total == 0 {
return BackwardOutcome::Streaming;
}
if total < min_message_size {
return BackwardOutcome::Format("length-below-minimum-bwd");
}
let msg_start = match snap.prev.checked_sub(total) {
Some(s) => s,
None => return BackwardOutcome::Format("backward-arith-underflow"),
};
if msg_start < snap.next {
return BackwardOutcome::Format("backward-overlaps-forward");
}
BackwardOutcome::NeedPreambleValidation {
msg_start,
length: total,
first_footer_offset: postamble.first_footer_offset,
}
}
/// Result of validating a backward candidate's preamble.
#[derive(Debug)]
enum BackwardCommit {
/// Validation failed; carries a short reason tag.
Format(&'static str),
/// Validation succeeded; carries the layout ready to be recorded.
Layout(CachedLayout),
}
fn validate_backward_preamble(
preamble_bytes: &[u8],
msg_start: u64,
length: u64,
) -> BackwardCommit {
if preamble_bytes.len() < PREAMBLE_SIZE {
return BackwardCommit::Format("short-fetch-bwd");
}
if &preamble_bytes[..MAGIC.len()] != MAGIC {
return BackwardCommit::Format("bad-magic-bwd");
}
let preamble = match Preamble::read_from(preamble_bytes) {
Ok(p) => p,
Err(_) => return BackwardCommit::Format("preamble-parse-error-bwd"),
};
if preamble.total_length == 0 {
return BackwardCommit::Format("streaming-preamble-non-tail");
}
if preamble.total_length != length {
return BackwardCommit::Format("preamble-postamble-length-mismatch");
}
BackwardCommit::Layout(CachedLayout {
offset: msg_start,
length,
preamble,
index: None,
global_metadata: None,
})
}
/// Scan action selector, compiled only with the `async` feature.
///
/// NOTE(review): the code that consumes this enum is outside the visible part
/// of the file; the variant descriptions below are inferred from names only —
/// confirm against the async implementation.
#[cfg(feature = "async")]
#[derive(Debug)]
enum EagerAction {
/// Presumably: perform an eager forward scan step.
ScanForwardEager,
/// Presumably: perform a bidirectional scan round.
ScanBidir,
/// Presumably: run layout discovery for an already-known message.
Discover,
}
/// Result of parsing the preamble fetched by the forward walker.
#[derive(Debug)]
enum ForwardOutcome {
/// A valid message that ends at or before the caller-supplied bound.
Hit {
/// Message start offset.
offset: u64,
/// Message length from the preamble.
length: u64,
/// Parsed preamble.
preamble: Preamble,
/// `offset + length`.
msg_end: u64,
},
/// A message that fits inside the file but ends past the bound.
HitBeyondBound {
/// Message start offset.
offset: u64,
/// Message length from the preamble.
length: u64,
/// Parsed preamble.
preamble: Preamble,
/// `offset + length`.
msg_end: u64,
},
/// The preamble reports length zero; carries `file_size - pos`.
Streaming(u64),
/// The forward walk cannot continue; carries a short reason tag.
Terminate(&'static str),
}
/// Returns `true` when `fwd` is a `Hit` covering exactly the same byte range
/// (offset and length) as `layout`; every other outcome yields `false`.
fn same_message_as_forward(fwd: &ForwardOutcome, layout: &CachedLayout) -> bool {
    if let ForwardOutcome::Hit { offset, length, .. } = fwd {
        *offset == layout.offset && *length == layout.length
    } else {
        false
    }
}
fn parse_forward_preamble(
preamble_bytes: &[u8],
pos: u64,
file_size: u64,
bound: u64,
) -> ForwardOutcome {
let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
if preamble_bytes.len() < PREAMBLE_SIZE {
return ForwardOutcome::Terminate("short-fetch-fwd");
}
if &preamble_bytes[..MAGIC.len()] != MAGIC {
return ForwardOutcome::Terminate("bad-magic-fwd");
}
let preamble = match Preamble::read_from(preamble_bytes) {
Ok(p) => p,
Err(_) => return ForwardOutcome::Terminate("preamble-parse-error-fwd"),
};
let msg_len = preamble.total_length;
if msg_len == 0 {
let remaining = file_size - pos;
return ForwardOutcome::Streaming(remaining);
}
let end = match pos.checked_add(msg_len) {
Some(e) => e,
None => return ForwardOutcome::Terminate("length-out-of-range-fwd"),
};
if msg_len < min_message_size || end > file_size {
return ForwardOutcome::Terminate("length-out-of-range-fwd");
}
if end > bound {
return ForwardOutcome::HitBeyondBound {
offset: pos,
length: msg_len,
preamble,
msg_end: end,
};
}
ForwardOutcome::Hit {
offset: pos,
length: msg_len,
preamble,
msg_end: end,
}
}
/// Backend that reads messages from a remote object through `object_store`.
pub(crate) struct RemoteBackend {
/// The source string the backend was opened with.
source_url: String,
/// Object store resolved from the source URL.
store: Arc<dyn ObjectStore>,
/// Path of the object inside `store`.
path: ObjectPath,
/// Object size in bytes, taken from the `head` request at open time.
file_size: u64,
/// Scan progress and cached layouts, guarded by a mutex.
state: Mutex<RemoteState>,
/// Scan options supplied at open time.
scan_opts: RemoteScanOptions,
}
/// Mutable scan state: discovered layouts plus forward/backward walker cursors.
#[derive(Debug, Default)]
struct RemoteState {
/// Layouts recorded by the forward walker, in file order; the backward
/// suffix is appended here when the gap closes.
layouts: Vec<CachedLayout>,
/// Offset at which the forward walker reads its next preamble.
next_scan_offset: u64,
/// Layouts recorded by the backward walker, in discovery (reverse file) order.
suffix_rev: Vec<CachedLayout>,
/// Start offset of the most recent backward hop (initialised to the file
/// size when the backend is opened).
prev_scan_offset: u64,
/// Whether the backward walker is still running.
bwd_active: bool,
/// Whether the forward walker has terminated.
fwd_terminated: bool,
/// Set once the forward and backward walkers have met.
gap_closed: bool,
/// Counter bumped (wrapping) on every state transition.
scan_epoch: u64,
}
/// Copy of the walker cursors and epoch, used to detect whether `RemoteState`
/// changed since the snapshot was taken (see `RemoteState::matches`).
#[derive(Debug, Clone, Copy)]
struct ScanSnapshot {
/// `RemoteState::next_scan_offset` at snapshot time.
next: u64,
/// `RemoteState::prev_scan_offset` at snapshot time.
prev: u64,
/// `RemoteState::scan_epoch` at snapshot time.
epoch: u64,
}
/// Emits a debug trace event naming the configured scan mode
/// (`"bidirectional"` or `"forward-only"`).
fn emit_scan_mode(scan_opts: &RemoteScanOptions) {
    let mode = match scan_opts.bidirectional {
        true => "bidirectional",
        false => "forward-only",
    };
    tracing::debug!(
        target: "tensogram::remote_scan",
        mode = mode,
        "remote scan mode",
    );
}
impl RemoteState {
    /// Returns `true` once no further scanning can discover messages: either
    /// the gap has closed, or the forward walker has terminated and no
    /// backward layouts are pending.
    fn scan_complete(&self) -> bool {
        if self.gap_closed {
            return true;
        }
        self.fwd_terminated && self.suffix_rev.is_empty()
    }

    /// Records a message found by the forward walker and advances the forward
    /// cursor to the message end.
    fn record_forward_hop(&mut self, layout: CachedLayout) {
        debug_assert!(
            layout.offset.checked_add(layout.length).is_some(),
            "forward end must fit in u64; validated by scan_next before recording",
        );
        let (hop_offset, hop_length) = (layout.offset, layout.length);
        let end = hop_offset + hop_length;
        let backward_live = self.bwd_active || !self.suffix_rev.is_empty();
        debug_assert!(!backward_live || end <= self.prev_scan_offset);

        self.layouts.push(layout);
        self.next_scan_offset = end;
        self.scan_epoch = self.scan_epoch.wrapping_add(1);
        tracing::debug!(
            target: "tensogram::remote_scan",
            direction = "fwd",
            offset = hop_offset,
            length = hop_length,
            "scan hop",
        );
    }

    /// Records a message found by the backward walker and moves the backward
    /// cursor to that message's start.
    fn record_backward_hop(&mut self, layout: CachedLayout) {
        debug_assert!(self.bwd_active);
        debug_assert!(layout.offset >= self.next_scan_offset);
        debug_assert_eq!(layout.offset + layout.length, self.prev_scan_offset);

        let (hop_offset, hop_length) = (layout.offset, layout.length);
        self.prev_scan_offset = hop_offset;
        self.suffix_rev.push(layout);
        self.scan_epoch = self.scan_epoch.wrapping_add(1);
        tracing::debug!(
            target: "tensogram::remote_scan",
            direction = "bwd",
            offset = hop_offset,
            length = hop_length,
            "scan hop",
        );
    }

    /// Captures the current cursors and epoch.
    fn snapshot(&self) -> ScanSnapshot {
        let (next, prev, epoch) = (
            self.next_scan_offset,
            self.prev_scan_offset,
            self.scan_epoch,
        );
        ScanSnapshot { next, prev, epoch }
    }

    /// Returns `true` when the cursors and epoch still equal `snap`.
    fn matches(&self, snap: &ScanSnapshot) -> bool {
        let current = (
            self.next_scan_offset,
            self.prev_scan_offset,
            self.scan_epoch,
        );
        current == (snap.next, snap.prev, snap.epoch)
    }

    /// Stops the backward walker and discards its pending layouts.
    ///
    /// Does nothing (no epoch bump, no trace event) when the walker is already
    /// inactive and nothing is pending.
    fn disable_backward(&mut self, reason: &'static str) {
        let anything_to_disable = self.bwd_active || !self.suffix_rev.is_empty();
        if !anything_to_disable {
            return;
        }
        self.bwd_active = false;
        self.suffix_rev.clear();
        self.scan_epoch = self.scan_epoch.wrapping_add(1);
        tracing::debug!(
            target: "tensogram::remote_scan",
            reason = reason,
            "backward walker disabled",
        );
    }

    /// Marks the forward walker as terminated (once) and then disables the
    /// backward walker with the same reason.
    fn terminate_forward(&mut self, reason: &'static str) {
        let first_termination = !self.fwd_terminated;
        if first_termination {
            self.fwd_terminated = true;
            self.scan_epoch = self.scan_epoch.wrapping_add(1);
            tracing::debug!(
                target: "tensogram::remote_scan",
                reason = reason,
                "forward walker terminated",
            );
        }
        self.disable_backward(reason);
    }

    /// Joins the two walks: appends the backward layouts, restored to file
    /// order, after the forward layouts and marks the scan as closed.
    fn close_gap(&mut self) {
        debug_assert!(self.bwd_active);
        debug_assert_eq!(self.next_scan_offset, self.prev_scan_offset);

        let pending = std::mem::take(&mut self.suffix_rev);
        self.layouts.extend(pending.into_iter().rev());
        self.gap_closed = true;
        self.bwd_active = false;
        self.scan_epoch = self.scan_epoch.wrapping_add(1);
        tracing::debug!(
            target: "tensogram::remote_scan",
            messages = self.layouts.len(),
            "gap closed",
        );
    }
}
impl std::fmt::Debug for RemoteBackend {
    /// Formats the backend as `RemoteBackend { source, file_size, scan_opts,
    /// messages }`; `messages` is the number of cached layouts, or 0 when the
    /// state lock is poisoned.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message_count = match self.state.lock() {
            Ok(guard) => guard.layouts.len(),
            Err(_) => 0,
        };
        let mut out = f.debug_struct("RemoteBackend");
        out.field("source", &self.source_url);
        out.field("file_size", &self.file_size);
        out.field("scan_opts", &self.scan_opts);
        out.field("messages", &message_count);
        out.finish()
    }
}
impl RemoteBackend {
/// Returns the source string this backend was opened with.
pub(crate) fn source_url(&self) -> &str {
    self.source_url.as_str()
}
/// Locks the scan state (available only with the `async` feature).
///
/// # Errors
///
/// Returns `TensogramError::Remote` when the mutex is poisoned.
#[cfg(feature = "async")]
fn lock_state(&self) -> Result<MutexGuard<'_, RemoteState>> {
    match self.state.lock() {
        Ok(guard) => Ok(guard),
        Err(_) => Err(TensogramError::Remote(
            "remote state lock poisoned".to_string(),
        )),
    }
}
/// Opens a remote object and scans its first message.
///
/// Steps: log the scan mode, parse `source` as a URL, build the object store
/// from `storage_options` (adding `allow_http=true` for `http` URLs unless the
/// caller already set `allow_http`), issue a `head` request for the size, and
/// run one forward scan step.
///
/// # Errors
///
/// Returns `TensogramError::Remote` when the URL is invalid, the store cannot
/// be opened, the `head` request fails, the object is smaller than one
/// preamble plus one postamble, the state lock is poisoned, or the first scan
/// step finds no message.
pub(crate) fn open_with_scan_opts(
    source: &str,
    storage_options: &BTreeMap<String, String>,
    scan_opts: RemoteScanOptions,
) -> Result<Self> {
    emit_scan_mode(&scan_opts);

    let url = match Url::parse(source) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(TensogramError::Remote(format!(
                "invalid URL '{source}': {e}"
            )));
        }
    };

    let mut opts: Vec<(String, String)> = Vec::with_capacity(storage_options.len() + 1);
    for (key, value) in storage_options {
        opts.push((key.clone(), value.clone()));
    }
    // Plain HTTP must be enabled explicitly; respect a caller-provided value.
    let caller_set_allow_http = storage_options.contains_key("allow_http");
    if url.scheme() == "http" && !caller_set_allow_http {
        opts.push(("allow_http".to_string(), "true".to_string()));
    }

    let (boxed_store, path) = object_store::parse_url_opts(&url, opts)
        .map_err(|e| TensogramError::Remote(format!("cannot open '{source}': {e}")))?;
    let store: Arc<dyn ObjectStore> = Arc::from(boxed_store);

    let meta = {
        let head_store = store.clone();
        let head_path = path.clone();
        block_on_shared(async move { head_store.head(&head_path).await })?
    };
    let file_size = meta.size;
    let smallest_message = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
    if file_size < smallest_message {
        return Err(TensogramError::Remote(format!(
            "remote file too small ({file_size} bytes)"
        )));
    }

    // The backward walker starts at end-of-file and is live only in
    // bidirectional mode.
    let initial_state = RemoteState {
        prev_scan_offset: file_size,
        bwd_active: scan_opts.bidirectional,
        ..RemoteState::default()
    };
    let backend = RemoteBackend {
        source_url: source.to_string(),
        store,
        path,
        file_size,
        state: Mutex::new(initial_state),
        scan_opts,
    };

    {
        let mut state = match backend.state.lock() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(TensogramError::Remote(
                    "remote state lock poisoned".to_string(),
                ));
            }
        };
        backend.scan_next_locked(&mut state)?;
        if state.layouts.is_empty() {
            return Err(TensogramError::Remote(
                "no valid messages found in remote file".to_string(),
            ));
        }
    }
    Ok(backend)
}
fn get_range(&self, range: Range<u64>) -> Result<Bytes> {
let store = self.store.clone();
let path = self.path.clone();
block_on_shared(async move { store.get_range(&path, range).await })
}
/// Runs one forward scan step bounded by the whole file.
fn scan_next_locked(&self, state: &mut RemoteState) -> Result<()> {
    let whole_file = self.file_size;
    self.scan_fwd_step_locked(state, whole_file)
}
/// Performs one forward-walker step: reads the preamble at the forward cursor
/// and either records the message or terminates the forward walk.
///
/// `bound` is the exclusive upper limit a fixed-length message may reach.
///
/// Outcomes (all return `Ok(())`):
/// * scan already complete: no-op;
/// * not enough room for a minimal message before `bound`: terminate (`eof`);
/// * short fetch, bad magic, or unparsable preamble: terminate;
/// * streaming message (length zero): verify the end magic at end-of-file,
///   record the message as spanning to end-of-file, then terminate;
/// * fixed-length message in range: record it.
///
/// # Errors
///
/// Propagates errors from the range fetches.
fn scan_fwd_step_locked(&self, state: &mut RemoteState, bound: u64) -> Result<()> {
    if state.scan_complete() {
        return Ok(());
    }
    let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
    let pos = state.next_scan_offset;
    if pos.saturating_add(min_message_size) > bound {
        state.terminate_forward("eof");
        return Ok(());
    }

    let preamble_bytes = self.get_range(pos..pos + PREAMBLE_SIZE as u64)?;
    // A store may hand back fewer bytes than requested; treat that as the end
    // of the forward walk instead of panicking on the slice below. This
    // mirrors the guard in `parse_forward_preamble`.
    if preamble_bytes.len() < PREAMBLE_SIZE {
        state.terminate_forward("short-fetch-fwd");
        return Ok(());
    }
    if &preamble_bytes[..MAGIC.len()] != MAGIC {
        state.terminate_forward("bad-magic-fwd");
        return Ok(());
    }
    let preamble = match Preamble::read_from(&preamble_bytes) {
        Ok(p) => p,
        Err(_) => {
            state.terminate_forward("preamble-parse-error-fwd");
            return Ok(());
        }
    };

    let msg_len = preamble.total_length;
    if msg_len == 0 {
        // Streaming message: its length is unknown, so it is taken to extend
        // to end-of-file, provided the file ends with the end magic.
        let remaining = self.file_size - pos;
        if remaining < min_message_size {
            state.terminate_forward("streaming-tail-too-small");
            return Ok(());
        }
        let end_magic_pos = self.file_size - crate::wire::END_MAGIC.len() as u64;
        let end_bytes =
            self.get_range(end_magic_pos..end_magic_pos + crate::wire::END_MAGIC.len() as u64)?;
        if &end_bytes[..] != crate::wire::END_MAGIC {
            state.terminate_forward("streaming-end-magic-mismatch");
            return Ok(());
        }
        state.record_forward_hop(CachedLayout {
            offset: pos,
            length: remaining,
            preamble,
            index: None,
            global_metadata: None,
        });
        state.terminate_forward("streaming-tail");
        return Ok(());
    }

    match pos.checked_add(msg_len) {
        Some(end) if msg_len >= min_message_size && end <= bound => {}
        _ => {
            state.terminate_forward("length-out-of-range-fwd");
            return Ok(());
        }
    }
    state.record_forward_hop(CachedLayout {
        offset: pos,
        length: msg_len,
        preamble,
        index: None,
        global_metadata: None,
    });
    Ok(())
}
/// Advances the scan by one step: a bidirectional round while both walkers
/// are live, otherwise a forward step limited by `forward_bound`.
fn scan_step_locked(&self, state: &mut RemoteState) -> Result<()> {
    if state.scan_complete() {
        return Ok(());
    }
    let both_walkers_live = state.bwd_active && !state.fwd_terminated;
    if both_walkers_live {
        return self.scan_bidir_round_locked(state);
    }
    let bound = self.forward_bound(state);
    self.scan_fwd_step_locked(state, bound)
}
/// Upper limit for the forward walker: the backward cursor when backward
/// layouts are pending, otherwise the file size.
fn forward_bound(&self, state: &RemoteState) -> u64 {
    if !state.suffix_rev.is_empty() {
        return state.prev_scan_offset;
    }
    self.file_size
}
/// Runs one bidirectional scan round.
///
/// The round fetches the next forward preamble and the postamble ending at the
/// backward cursor concurrently, parses the postamble, and — when it yields a
/// candidate — fetches the candidate's preamble (and, if a footer region is
/// present, its footer bytes) before handing everything to
/// `apply_round_outcomes`.
///
/// # Errors
///
/// Propagates errors from the paired fetch and from the candidate preamble
/// fetch. A failed footer fetch is only logged.
fn scan_bidir_round_locked(&self, state: &mut RemoteState) -> Result<()> {
let snap = state.snapshot();
let bound = snap.prev;
let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
// The remaining gap cannot hold a minimal message.
if snap.prev < snap.next.saturating_add(min_message_size) {
// Cursors met exactly: the two walks join up.
if snap.next == snap.prev {
state.close_gap();
return Ok(());
}
// Non-empty gap smaller than a message: drop the backward walk and let
// a plain forward step deal with the remainder.
state.disable_backward("gap-below-min-message-size");
let recovery_bound = self.forward_bound(state);
return self.scan_fwd_step_locked(state, recovery_bound);
}
let fwd_r = snap.next..snap.next + PREAMBLE_SIZE as u64;
let bwd_r = snap.prev - POSTAMBLE_SIZE as u64..snap.prev;
// Fetch both ends in a single blocking call; the two requests are joined.
let (fwd_bytes, bwd_bytes) = block_on_shared({
let store = self.store.clone();
let path = self.path.clone();
let fwd_r2 = fwd_r.clone();
let bwd_r2 = bwd_r.clone();
async move {
let (f, b) = tokio::join!(
store.get_range(&path, fwd_r2),
store.get_range(&path, bwd_r2),
);
Ok::<_, object_store::Error>((f?, b?))
}
})?;
let bwd_outcome = parse_backward_postamble(&bwd_bytes, &snap);
let (candidate_preamble_bytes, candidate_footer_bytes) = match &bwd_outcome {
BackwardOutcome::NeedPreambleValidation {
msg_start,
length,
first_footer_offset,
} => {
// The candidate's preamble is required for validation, so a fetch
// error here aborts the round.
let preamble = self.get_range(*msg_start..*msg_start + PREAMBLE_SIZE as u64)?;
let footer = if footer_region_present(*first_footer_offset, *length) {
let footer_start = msg_start.saturating_add(*first_footer_offset);
let footer_end = msg_start
.saturating_add(*length)
.saturating_sub(POSTAMBLE_SIZE as u64);
if footer_start < footer_end {
// The footer is an optimisation only: on failure, log and carry on
// without it.
match self.get_range(footer_start..footer_end) {
Ok(b) => Some(b),
Err(e) => {
tracing::debug!(
target: "tensogram::remote_scan",
error = %e,
msg_start = *msg_start,
footer_start = footer_start,
footer_end = footer_end,
"eager footer fetch failed; falling back to lazy"
);
None
}
}
} else {
None
}
} else {
None
};
(Some(preamble), footer)
}
_ => (None, None),
};
let fwd_kind = parse_forward_preamble(&fwd_bytes, snap.next, self.file_size, bound);
self.apply_round_outcomes(
state,
fwd_kind,
bwd_outcome,
candidate_preamble_bytes.as_deref(),
candidate_footer_bytes.as_deref(),
);
Ok(())
}
/// Applies the results of one bidirectional round to `state`.
///
/// The backward outcome is processed first (disabling the backward walker or
/// attempting to commit the candidate), then the forward outcome, and finally
/// the gap is closed if the backward walker is still active and both cursors
/// now coincide.
fn apply_round_outcomes(
&self,
state: &mut RemoteState,
fwd: ForwardOutcome,
bwd: BackwardOutcome,
candidate_preamble_bytes: Option<&[u8]>,
candidate_footer_bytes: Option<&[u8]>,
) {
match bwd {
BackwardOutcome::Format(reason) => state.disable_backward(reason),
BackwardOutcome::Streaming => state.disable_backward("streaming-zero-bwd"),
BackwardOutcome::NeedPreambleValidation {
msg_start,
length,
first_footer_offset: _,
} => {
// Missing preamble bytes are treated as a validation failure.
let validation = candidate_preamble_bytes
.map(|bytes| validate_backward_preamble(bytes, msg_start, length))
.unwrap_or(BackwardCommit::Format("missing-candidate-preamble"));
Self::commit_or_yield_backward(state, &fwd, validation, candidate_footer_bytes);
}
}
match fwd {
ForwardOutcome::Hit {
offset,
length,
preamble,
msg_end,
} => {
debug_assert_eq!(offset, state.next_scan_offset);
debug_assert_eq!(msg_end, offset + length);
state.record_forward_hop(CachedLayout {
offset,
length,
preamble,
index: None,
global_metadata: None,
});
}
ForwardOutcome::HitBeyondBound {
offset,
length,
preamble,
msg_end,
} => {
// The forward message runs past the backward cursor, so the backward
// walk is dropped before the forward hop is recorded.
state.disable_backward("forward-exceeds-backward-bound");
debug_assert_eq!(offset, state.next_scan_offset);
debug_assert_eq!(msg_end, offset + length);
state.record_forward_hop(CachedLayout {
offset,
length,
preamble,
index: None,
global_metadata: None,
});
}
// A streaming preamble is not recorded here; only the backward walker is
// disabled.
ForwardOutcome::Streaming(_remaining) => {
state.disable_backward("streaming-fwd-non-tail");
}
ForwardOutcome::Terminate(reason) => state.terminate_forward(reason),
}
if state.bwd_active && state.next_scan_offset == state.prev_scan_offset {
state.close_gap();
}
}
/// Decides whether a validated backward candidate is recorded.
///
/// * Validation failure: the backward walker is disabled with the reason.
/// * Backward walker already inactive: nothing happens.
/// * Candidate is the very message the forward walker hit this round: the
///   candidate is dropped so the forward walker records it.
/// * Forward hit overlapping the candidate, or a forward hit beyond the bound:
///   the backward walker is disabled.
/// * Otherwise the candidate is recorded, after optionally filling in its
///   footer metadata and index from `candidate_footer_bytes`.
fn commit_or_yield_backward(
state: &mut RemoteState,
fwd: &ForwardOutcome,
validation: BackwardCommit,
candidate_footer_bytes: Option<&[u8]>,
) {
let mut layout = match validation {
BackwardCommit::Format(reason) => {
state.disable_backward(reason);
return;
}
BackwardCommit::Layout(layout) => layout,
};
if !state.bwd_active {
return;
}
// Both walkers found the same message; leave it to the forward walker.
if same_message_as_forward(fwd, &layout) {
return;
}
match fwd {
ForwardOutcome::Hit { msg_end, .. } if *msg_end > layout.offset => {
state.disable_backward("backward-overlaps-forward");
}
ForwardOutcome::HitBeyondBound { .. } => {
state.disable_backward("forward-exceeds-backward-bound");
}
_ => {
if let Some(footer_bytes) = candidate_footer_bytes {
Self::try_populate_eager_footer(&mut layout, footer_bytes);
}
state.record_backward_hop(layout);
}
}
}
/// Best-effort: fills `layout`'s metadata and index from already-fetched
/// footer bytes.
///
/// Nothing is changed unless the preamble flags announce both footer metadata
/// and a footer index, and parsing yields both frames. Parse errors are
/// ignored.
fn try_populate_eager_footer(layout: &mut CachedLayout, footer_bytes: &[u8]) {
    let flags = layout.preamble.flags;
    let footer_indexed =
        flags.has(MessageFlags::FOOTER_METADATA) && flags.has(MessageFlags::FOOTER_INDEX);
    if !footer_indexed {
        return;
    }
    let Ok((Some(meta), Some(idx))) = Self::parse_footer_frames_into(footer_bytes) else {
        return;
    };
    layout.global_metadata = Some(meta);
    layout.index = Some(idx);
    tracing::debug!(
        target: "tensogram::remote_scan",
        action = "footer_eager",
        offset = layout.offset,
        footer_bytes = footer_bytes.len(),
    );
}
/// Scans until message `msg_idx` has a cached layout or the scan completes.
///
/// # Errors
///
/// Propagates scan errors; returns `TensogramError::Framing` when the scan
/// finishes without reaching `msg_idx`.
fn ensure_message_locked(&self, state: &mut RemoteState, msg_idx: usize) -> Result<()> {
    loop {
        if msg_idx < state.layouts.len() {
            return Ok(());
        }
        if state.scan_complete() {
            break;
        }
        self.scan_step_locked(state)?;
    }
    Err(TensogramError::Framing(format!(
        "message index {} out of range (count={})",
        msg_idx,
        state.layouts.len()
    )))
}
/// Keeps stepping the scan until it reports completion.
fn scan_all_locked(&self, state: &mut RemoteState) -> Result<()> {
    loop {
        if state.scan_complete() {
            return Ok(());
        }
        self.scan_step_locked(state)?;
    }
}
/// Forward scan step that also discovers the new message's metadata and
/// index eagerly.
///
/// Reads up to 256 KiB starting at the forward cursor. After recording the
/// message it parses header frames straight from that chunk (header-indexed
/// messages) or fetches the footer from the message suffix (footer-indexed
/// messages). Messages with neither flag pair are recorded without layout.
///
/// All termination conditions (`eof`, bad magic or short fetch, unparsable
/// preamble, invalid length, streaming tail) return `Ok(())` after updating
/// `state`.
///
/// # Errors
///
/// Propagates range-fetch errors and header/footer parse errors.
fn scan_and_discover_next_locked(&self, state: &mut RemoteState) -> Result<()> {
    if state.scan_complete() {
        return Ok(());
    }
    let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
    let pos = state.next_scan_offset;
    if pos.saturating_add(min_message_size) > self.file_size {
        state.terminate_forward("eof");
        return Ok(());
    }

    let chunk_size = (self.file_size - pos).min(256 * 1024);
    let chunk = self.get_range(pos..pos + chunk_size)?;
    if chunk.len() < PREAMBLE_SIZE || &chunk[..MAGIC.len()] != MAGIC {
        state.terminate_forward("bad-magic-fwd");
        return Ok(());
    }
    let preamble = match Preamble::read_from(&chunk[..PREAMBLE_SIZE]) {
        Ok(p) => p,
        Err(_) => {
            state.terminate_forward("preamble-parse-error-fwd");
            return Ok(());
        }
    };

    let msg_len = preamble.total_length;
    if msg_len == 0 {
        // Streaming message: taken to extend to end-of-file, provided the
        // file ends with the end magic.
        let remaining = self.file_size - pos;
        if remaining < min_message_size {
            state.terminate_forward("streaming-tail-too-small");
            return Ok(());
        }
        let end_magic_pos = self.file_size - crate::wire::END_MAGIC.len() as u64;
        let end_bytes =
            self.get_range(end_magic_pos..end_magic_pos + crate::wire::END_MAGIC.len() as u64)?;
        if &end_bytes[..] != crate::wire::END_MAGIC {
            state.terminate_forward("streaming-end-magic-mismatch");
            return Ok(());
        }
        state.record_forward_hop(CachedLayout {
            offset: pos,
            length: remaining,
            preamble,
            index: None,
            global_metadata: None,
        });
        state.terminate_forward("streaming-tail");
        return Ok(());
    }

    match pos.checked_add(msg_len) {
        Some(end) if msg_len >= min_message_size && end <= self.file_size => {}
        _ => {
            state.terminate_forward("length-out-of-range-fwd");
            return Ok(());
        }
    }

    let flags = preamble.flags;
    let msg_idx = state.layouts.len();
    state.record_forward_hop(CachedLayout {
        offset: pos,
        length: msg_len,
        preamble,
        index: None,
        global_metadata: None,
    });

    if flags.has(MessageFlags::HEADER_METADATA) && flags.has(MessageFlags::HEADER_INDEX) {
        // Clamp the header slice to the message. A length that does not fit
        // in `usize` (possible on 32-bit targets) is necessarily larger than
        // the chunk, so the whole chunk is used; a truncating `as` cast would
        // shrink the slice instead.
        let chunk_end =
            usize::try_from(msg_len).map_or(chunk.len(), |len| len.min(chunk.len()));
        Self::parse_header_frames(state, msg_idx, &chunk[..chunk_end])?;
    } else if flags.has(MessageFlags::FOOTER_METADATA) && flags.has(MessageFlags::FOOTER_INDEX)
    {
        self.discover_footer_layout_from_suffix_locked(state, msg_idx)?;
    }
    Ok(())
}
/// Discovers footer metadata and index for message `msg_idx` by fetching the
/// last (up to) 256 KiB of the message in one request.
///
/// The postamble is read from the end of that suffix. If the footer region
/// starts inside the fetched suffix it is parsed in place; otherwise the
/// footer region is fetched with a second request.
///
/// # Errors
///
/// Returns `TensogramError::Remote` for offset overflow, a suffix shorter than
/// a postamble, or a `first_footer_offset` outside the valid range; propagates
/// fetch, postamble-parse and footer-parse errors.
fn discover_footer_layout_from_suffix_locked(
&self,
state: &mut RemoteState,
msg_idx: usize,
) -> Result<()> {
let msg_offset = state.layouts[msg_idx].offset;
let msg_len = state.layouts[msg_idx].length;
let suffix_size = msg_len.min(256 * 1024);
let msg_end = msg_offset
.checked_add(msg_len)
.ok_or_else(|| TensogramError::Remote("message end overflow".to_string()))?;
// `suffix_size <= msg_len <= msg_end`, so this subtraction cannot underflow.
let suffix_start = msg_end - suffix_size;
let suffix = self.get_range(suffix_start..msg_end)?;
if suffix.len() < POSTAMBLE_SIZE {
return Err(TensogramError::Remote(
"suffix too short for postamble".to_string(),
));
}
let pa_bytes = &suffix[suffix.len() - POSTAMBLE_SIZE..];
let postamble = Postamble::read_from(pa_bytes)?;
if postamble.first_footer_offset < PREAMBLE_SIZE as u64 {
return Err(TensogramError::Remote(format!(
"first_footer_offset ({}) is before preamble end ({PREAMBLE_SIZE})",
postamble.first_footer_offset
)));
}
let footer_abs_start = msg_offset
.checked_add(postamble.first_footer_offset)
.ok_or_else(|| TensogramError::Remote("footer offset overflow".to_string()))?;
let footer_abs_end = msg_end - POSTAMBLE_SIZE as u64;
if footer_abs_start >= footer_abs_end {
return Err(TensogramError::Remote(
"first_footer_offset points at or past postamble".to_string(),
));
}
if footer_abs_start >= suffix_start {
// Footer lies entirely within the bytes already fetched.
// NOTE(review): this slice assumes the store returned the full requested
// suffix; with a shorter response `local_start` could exceed `local_end`.
let local_start = (footer_abs_start - suffix_start) as usize;
let local_end = suffix.len() - POSTAMBLE_SIZE;
Self::parse_footer_frames(state, msg_idx, &suffix[local_start..local_end])
} else {
// Footer starts before the fetched suffix: fetch the exact region.
let footer_bytes = self.get_range(footer_abs_start..footer_abs_end)?;
Self::parse_footer_frames(state, msg_idx, &footer_bytes)
}
}
/// Ensures message `msg_idx` is known and has metadata and index cached,
/// scanning eagerly as needed.
///
/// While the message is unknown, each iteration runs a bidirectional round
/// (both walkers live) or a forward scan-and-discover step. If the layout is
/// still incomplete afterwards, `ensure_layout_locked` fills it in.
///
/// # Errors
///
/// Propagates scan and discovery errors; returns `TensogramError::Framing`
/// when `msg_idx` is past the last message.
fn ensure_layout_eager_locked(&self, state: &mut RemoteState, msg_idx: usize) -> Result<()> {
    loop {
        if msg_idx < state.layouts.len() || state.scan_complete() {
            break;
        }
        let both_walkers_live = state.bwd_active && !state.fwd_terminated;
        if both_walkers_live {
            self.scan_bidir_round_locked(state)?;
        } else {
            self.scan_and_discover_next_locked(state)?;
        }
    }

    let Some(layout) = state.layouts.get(msg_idx) else {
        return Err(TensogramError::Framing(format!(
            "message index {} out of range (count={})",
            msg_idx,
            state.layouts.len()
        )));
    };
    if layout.global_metadata.is_some() && layout.index.is_some() {
        return Ok(());
    }
    self.ensure_layout_locked(state, msg_idx)
}
/// Ensures message `msg_idx` has its metadata and index cached, discovering
/// them from the header or footer according to the preamble flags.
///
/// # Errors
///
/// Propagates scan and discovery errors; returns `TensogramError::Remote`
/// when the message is neither header-indexed nor footer-indexed.
fn ensure_layout_locked(&self, state: &mut RemoteState, msg_idx: usize) -> Result<()> {
    self.ensure_message_locked(state, msg_idx)?;

    let cached = &state.layouts[msg_idx];
    if cached.global_metadata.is_some() && cached.index.is_some() {
        return Ok(());
    }

    let flags = cached.preamble.flags;
    let header_indexed =
        flags.has(MessageFlags::HEADER_METADATA) && flags.has(MessageFlags::HEADER_INDEX);
    if header_indexed {
        return self.discover_header_layout_locked(state, msg_idx);
    }
    let footer_indexed =
        flags.has(MessageFlags::FOOTER_METADATA) && flags.has(MessageFlags::FOOTER_INDEX);
    if footer_indexed {
        return self.discover_footer_layout_locked(state, msg_idx);
    }
    Err(TensogramError::Remote(
        "remote access requires header-indexed or footer-indexed messages".to_string(),
    ))
}
/// Discovers footer metadata and index for message `msg_idx` with two
/// requests: the postamble first, then the footer region it points to.
///
/// # Errors
///
/// Returns `TensogramError::Remote` for offset overflow or a
/// `first_footer_offset` outside the valid range; propagates fetch,
/// postamble-parse and footer-parse errors.
fn discover_footer_layout_locked(&self, state: &mut RemoteState, msg_idx: usize) -> Result<()> {
    let (msg_offset, msg_len) = {
        let cached = &state.layouts[msg_idx];
        (cached.offset, cached.length)
    };
    let postamble_len = POSTAMBLE_SIZE as u64;

    let Some(pa_offset) = msg_offset
        .checked_add(msg_len)
        .and_then(|end| end.checked_sub(postamble_len))
    else {
        return Err(TensogramError::Remote(
            "postamble offset overflow".to_string(),
        ));
    };
    let pa_bytes = self.get_range(pa_offset..pa_offset + postamble_len)?;
    let postamble = Postamble::read_from(&pa_bytes)?;

    let first_footer = postamble.first_footer_offset;
    if first_footer < PREAMBLE_SIZE as u64 {
        return Err(TensogramError::Remote(format!(
            "first_footer_offset ({first_footer}) is before preamble end ({PREAMBLE_SIZE})"
        )));
    }
    let Some(footer_start) = msg_offset.checked_add(first_footer) else {
        return Err(TensogramError::Remote(
            "footer offset overflow".to_string(),
        ));
    };
    // The footer region ends where the postamble begins.
    if footer_start >= pa_offset {
        return Err(TensogramError::Remote(
            "first_footer_offset points at or past postamble".to_string(),
        ));
    }

    let footer_bytes = self.get_range(footer_start..pa_offset)?;
    Self::parse_footer_frames(state, msg_idx, &footer_bytes)
}
/// Discovers header metadata and index for message `msg_idx` by fetching the
/// first (up to) 256 KiB of the message and parsing its header frames.
fn discover_header_layout_locked(&self, state: &mut RemoteState, msg_idx: usize) -> Result<()> {
    let (start, fetch_len) = {
        let cached = &state.layouts[msg_idx];
        (cached.offset, cached.length.min(256 * 1024))
    };
    let header_bytes = self.get_range(start..start + fetch_len)?;
    Self::parse_header_frames(state, msg_idx, &header_bytes)
}
/// Walks the frames that follow the preamble in `buf` (the leading bytes of a
/// message) and stores the header metadata and header index into
/// `state.layouts[msg_idx]`.
///
/// The walk stops at the first data-object or preceder frame, or at a frame
/// that extends past the end of `buf`.
///
/// # Errors
///
/// Returns `TensogramError::Remote` for an oversized or undersized frame
/// length, a missing `ENDF` trailer, or when no metadata / index frame was
/// found; propagates frame-header and CBOR decoding errors.
fn parse_header_frames(state: &mut RemoteState, msg_idx: usize, buf: &[u8]) -> Result<()> {
let min_frame_size = FRAME_HEADER_SIZE + FRAME_END.len();
let mut pos = PREAMBLE_SIZE;
while pos + FRAME_HEADER_SIZE <= buf.len() {
// Advance byte-by-byte until the "FR" marker appears.
if &buf[pos..pos + 2] != b"FR" {
pos += 1;
continue;
}
let fh = FrameHeader::read_from(&buf[pos..])?;
let frame_total = usize::try_from(fh.total_length).map_err(|_| {
TensogramError::Remote("frame total_length does not fit in usize".to_string())
})?;
if frame_total < min_frame_size {
return Err(TensogramError::Remote(format!(
"frame total_length ({frame_total}) smaller than minimum ({min_frame_size})"
)));
}
// A frame that does not fit in `buf` ends the walk without an error here.
let frame_end = match pos.checked_add(frame_total) {
Some(end) if end <= buf.len() => end,
_ => break,
};
if &buf[frame_end - FRAME_END.len()..frame_end] != FRAME_END {
return Err(TensogramError::Remote(
"frame missing ENDF trailer".to_string(),
));
}
// NOTE(review): the minimum-size check above uses FRAME_END.len(), but this
// slice subtracts FRAME_COMMON_FOOTER_SIZE; if that constant is larger, a
// minimal frame could make the range start exceed its end — confirm.
let payload = &buf[pos + FRAME_HEADER_SIZE..frame_end - FRAME_COMMON_FOOTER_SIZE];
match fh.frame_type {
FrameType::HeaderMetadata => {
let meta = metadata::cbor_to_global_metadata(payload)?;
state.layouts[msg_idx].global_metadata = Some(meta);
}
FrameType::HeaderIndex => {
let idx = metadata::cbor_to_index(payload)?;
state.layouts[msg_idx].index = Some(idx);
}
// The header region is over once these frame types appear.
FrameType::NTensorFrame | FrameType::PrecederMetadata => {
break;
}
_ => {}
}
// Continue at the next multiple of 8 (relative to `buf`), clamped to the
// buffer length.
let aligned = (frame_end.saturating_add(7)) & !7;
pos = aligned.min(buf.len());
}
if state.layouts[msg_idx].global_metadata.is_none() {
return Err(TensogramError::Remote(
"header region did not contain a metadata frame".to_string(),
));
}
if state.layouts[msg_idx].index.is_none() {
return Err(TensogramError::Remote(
"header region did not contain an index frame (header chunk may be too small)"
.to_string(),
));
}
Ok(())
}
/// Walks the frames in `buf` (a message's footer region) and returns the
/// footer metadata and footer index it contains, each `None` when absent.
///
/// If a frame type occurs more than once, the last occurrence wins. The walk
/// stops at a frame that extends past the end of `buf`.
///
/// # Errors
///
/// Returns `TensogramError::Remote` for an oversized or undersized frame
/// length or a missing `ENDF` trailer; propagates frame-header and CBOR
/// decoding errors.
fn parse_footer_frames_into(
buf: &[u8],
) -> Result<(Option<GlobalMetadata>, Option<IndexFrame>)> {
let min_frame_size = FRAME_HEADER_SIZE + FRAME_END.len();
let mut metadata: Option<GlobalMetadata> = None;
let mut index: Option<IndexFrame> = None;
let mut pos = 0;
while pos + FRAME_HEADER_SIZE <= buf.len() {
// Advance byte-by-byte until the "FR" marker appears.
if &buf[pos..pos + 2] != b"FR" {
pos += 1;
continue;
}
let fh = FrameHeader::read_from(&buf[pos..])?;
let frame_total = usize::try_from(fh.total_length).map_err(|_| {
TensogramError::Remote(
"footer frame total_length does not fit in usize".to_string(),
)
})?;
if frame_total < min_frame_size {
return Err(TensogramError::Remote(format!(
"footer frame total_length ({frame_total}) smaller than minimum ({min_frame_size})"
)));
}
// A frame that does not fit in `buf` ends the walk without an error.
let frame_end = match pos.checked_add(frame_total) {
Some(end) if end <= buf.len() => end,
_ => break,
};
if &buf[frame_end - FRAME_END.len()..frame_end] != FRAME_END {
return Err(TensogramError::Remote(
"footer frame missing ENDF trailer".to_string(),
));
}
// NOTE(review): the minimum-size check above uses FRAME_END.len(), but this
// slice subtracts FRAME_COMMON_FOOTER_SIZE; if that constant is larger, a
// minimal frame could make the range start exceed its end — confirm.
let payload = &buf[pos + FRAME_HEADER_SIZE..frame_end - FRAME_COMMON_FOOTER_SIZE];
match fh.frame_type {
FrameType::FooterMetadata => {
metadata = Some(metadata::cbor_to_global_metadata(payload)?);
}
FrameType::FooterIndex => {
index = Some(metadata::cbor_to_index(payload)?);
}
_ => {}
}
// Continue at the next multiple of 8 (relative to `buf`), clamped to the
// buffer length.
let aligned = (frame_end.saturating_add(7)) & !7;
pos = aligned.min(buf.len());
}
Ok((metadata, index))
}
/// Parses the footer region in `buf` and stores whatever it found into
/// `state.layouts[msg_idx]`, keeping previously cached values for frames that
/// are absent.
///
/// # Errors
///
/// Propagates parse errors; returns `TensogramError::Remote` when, after the
/// update, the layout still lacks metadata or an index.
fn parse_footer_frames(state: &mut RemoteState, msg_idx: usize, buf: &[u8]) -> Result<()> {
    let (found_metadata, found_index) = Self::parse_footer_frames_into(buf)?;
    let cached = &mut state.layouts[msg_idx];

    if found_metadata.is_some() {
        cached.global_metadata = found_metadata;
    }
    if found_index.is_some() {
        cached.index = found_index;
    }

    if cached.global_metadata.is_none() {
        return Err(TensogramError::Remote(
            "footer region did not contain a metadata frame".to_string(),
        ));
    }
    if cached.index.is_none() {
        return Err(TensogramError::Remote(
            "footer region did not contain an index frame".to_string(),
        ));
    }
    Ok(())
}
/// Completes the scan and returns the number of messages found.
///
/// # Errors
///
/// Returns `TensogramError::Remote` when the state lock is poisoned;
/// propagates scan errors.
pub(crate) fn message_count(&self) -> Result<usize> {
    let mut state = match self.state.lock() {
        Ok(guard) => guard,
        Err(_) => {
            return Err(TensogramError::Remote(
                "remote state lock poisoned".to_string(),
            ));
        }
    };
    self.scan_all_locked(&mut state)?;
    let count = state.layouts.len();
    Ok(count)
}
/// Completes the scan and returns the offset and length of every message, in
/// the order they are cached.
///
/// # Errors
///
/// Returns `TensogramError::Remote` when the state lock is poisoned;
/// propagates scan errors.
pub(crate) fn message_layouts(&self) -> Result<Vec<crate::MessageLayout>> {
    let mut state = match self.state.lock() {
        Ok(guard) => guard,
        Err(_) => {
            return Err(TensogramError::Remote(
                "remote state lock poisoned".to_string(),
            ));
        }
    };
    self.scan_all_locked(&mut state)?;

    let mut out = Vec::with_capacity(state.layouts.len());
    for cached in &state.layouts {
        out.push(crate::MessageLayout {
            offset: cached.offset,
            length: cached.length,
        });
    }
    Ok(out)
}
/// Fetches the raw bytes of message `msg_idx`.
///
/// The state lock is held only while the message bounds are looked up; the
/// byte fetch itself happens after the lock is released.
///
/// # Errors
///
/// Returns `TensogramError::Remote` when the state lock is poisoned;
/// propagates scan, range and fetch errors.
pub(crate) fn read_message(&self, msg_idx: usize) -> Result<Vec<u8>> {
    let span = {
        let mut state = match self.state.lock() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(TensogramError::Remote(
                    "remote state lock poisoned".to_string(),
                ));
            }
        };
        self.ensure_message_locked(&mut state, msg_idx)?;
        let cached = &state.layouts[msg_idx];
        cached.offset..cached.offset + cached.length
    };
    self.get_range(span).map(|bytes| bytes.to_vec())
}
/// Returns a copy of the global metadata of message `msg_idx`, discovering
/// the message layout first if necessary.
///
/// # Errors
///
/// Returns `TensogramError::Remote` when the state lock is poisoned or the
/// metadata is missing after discovery; propagates scan and discovery errors.
pub(crate) fn read_metadata(&self, msg_idx: usize) -> Result<GlobalMetadata> {
    let mut state = match self.state.lock() {
        Ok(guard) => guard,
        Err(_) => {
            return Err(TensogramError::Remote(
                "remote state lock poisoned".to_string(),
            ));
        }
    };
    self.ensure_layout_eager_locked(&mut state, msg_idx)?;
    match &state.layouts[msg_idx].global_metadata {
        Some(meta) => Ok(meta.clone()),
        None => Err(TensogramError::Remote("metadata not found".to_string())),
    }
}
/// Returns the global metadata and every object descriptor of message
/// `msg_idx`.
///
/// When the cached layout carries an index, each descriptor is read
/// individually through `read_descriptor_only`, in index order. Without an
/// index the whole message is downloaded and handed to
/// `crate::decode::decode_descriptors`.
///
/// # Errors
/// `TensogramError::Remote` if the state mutex is poisoned, if the index's
/// `offsets` and `lengths` differ in length, or if no metadata is cached for
/// an indexed message; all fetch and decode errors are propagated.
pub(crate) fn read_descriptors(
    &self,
    msg_idx: usize,
) -> Result<(GlobalMetadata, Vec<DataObjectDescriptor>)> {
    // Work on a private copy of the layout so the lock is not held during I/O.
    let snapshot = {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| TensogramError::Remote("remote state lock poisoned".to_string()))?;
        self.ensure_layout_eager_locked(&mut guard, msg_idx)?;
        guard.layouts[msg_idx].clone()
    };
    let Some(index) = snapshot.index.as_ref() else {
        let raw = self.read_message(msg_idx)?;
        return crate::decode::decode_descriptors(&raw);
    };
    let offset_count = index.offsets.len();
    let length_count = index.lengths.len();
    if offset_count != length_count {
        return Err(TensogramError::Remote(format!(
            "corrupt index: offsets.len()={} != lengths.len()={}",
            offset_count, length_count
        )));
    }
    let meta = match snapshot.global_metadata.clone() {
        Some(meta) => meta,
        None => return Err(TensogramError::Remote("metadata not cached".to_string())),
    };
    // `collect` into `Result` stops at the first failing descriptor read.
    let descriptors = index
        .offsets
        .iter()
        .zip(index.lengths.iter())
        .map(|(&frame_offset, &frame_length)| {
            self.read_descriptor_only(
                snapshot.offset,
                snapshot.length,
                frame_offset,
                frame_length,
            )
        })
        .collect::<Result<Vec<_>>>()?;
    Ok((meta, descriptors))
}
/// Reads the descriptor of a single data-object frame without necessarily
/// downloading the whole frame.
///
/// The frame's absolute byte range is computed (and checked against the
/// message bounds) by `checked_frame_range`. Frames of at most
/// `DESCRIPTOR_PREFIX_THRESHOLD` bytes are fetched whole and decoded with
/// `framing::decode_data_object_frame`. Larger frames are probed with
/// separate range reads: the frame header, the frame footer, and finally the
/// CBOR descriptor region located through the footer's `cbor_offset`.
///
/// # Errors
/// Returns `TensogramError::Remote` when the frame is not a data-object
/// frame, the footer is short or lacks the `FRAME_END` trailer, or
/// `cbor_offset` lies outside the frame body. Errors from `get_range`,
/// `FrameHeader::read_from` and the CBOR decoder are propagated.
fn read_descriptor_only(
&self,
msg_offset: u64,
msg_length: u64,
frame_offset_in_msg: u64,
frame_length: u64,
) -> Result<DataObjectDescriptor> {
// Frames up to this size are simply fetched in one request.
const DESCRIPTOR_PREFIX_THRESHOLD: u64 = 64 * 1024;
let range =
Self::checked_frame_range(msg_offset, msg_length, frame_offset_in_msg, frame_length)?;
if frame_length <= DESCRIPTOR_PREFIX_THRESHOLD {
let frame_bytes = self.get_range(range.clone())?;
let (desc, _payload, _mask_region, _consumed) =
framing::decode_data_object_frame(&frame_bytes)?;
return Ok(desc);
}
// Large frame: read the fixed-size header first and confirm the frame type.
let frame_start = range.start;
let frame_end = range.end;
let header_bytes = self.get_range(frame_start..frame_start + FRAME_HEADER_SIZE as u64)?;
let fh = FrameHeader::read_from(&header_bytes)?;
if !fh.frame_type.is_data_object() {
return Err(TensogramError::Remote(format!(
"expected DataObject frame, got {:?}",
fh.frame_type
)));
}
// NOTE(review): this subtraction relies on the frame being longer than
// DATA_OBJECT_FOOTER_SIZE; frame_length exceeds the 64 KiB threshold here,
// which presumably dwarfs the footer size — confirm against `wire`.
let footer_start = frame_end - DATA_OBJECT_FOOTER_SIZE as u64;
let footer_bytes = self.get_range(footer_start..frame_end)?;
if footer_bytes.len() < DATA_OBJECT_FOOTER_SIZE {
return Err(TensogramError::Remote("frame footer too short".to_string()));
}
// The footer must end with the FRAME_END marker.
let endf_pos = footer_bytes.len() - FRAME_END.len();
if &footer_bytes[endf_pos..] != FRAME_END {
return Err(TensogramError::Remote(
"frame missing ENDF trailer".to_string(),
));
}
// The first eight footer bytes hold `cbor_offset` as a big-endian u64,
// interpreted below as an offset from the start of the frame.
let cbor_offset = u64::from_be_bytes(
footer_bytes[..8]
.try_into()
.map_err(|_| TensogramError::Remote("footer cbor_offset truncated".to_string()))?,
);
if cbor_offset < FRAME_HEADER_SIZE as u64 {
return Err(TensogramError::Remote(format!(
"cbor_offset ({cbor_offset}) below frame header size ({FRAME_HEADER_SIZE})"
)));
}
let cbor_after = fh.flags & DataObjectFlags::CBOR_AFTER_PAYLOAD != 0;
let cbor_start = frame_start
.checked_add(cbor_offset)
.ok_or_else(|| TensogramError::Remote("cbor_start overflow".to_string()))?;
if cbor_after {
// CBOR_AFTER_PAYLOAD set: everything from `cbor_start` up to the footer is
// fetched and decoded as the descriptor.
if cbor_start >= footer_start {
return Err(TensogramError::Remote(
"cbor_offset points at or past footer".to_string(),
));
}
let cbor_bytes = self.get_range(cbor_start..footer_start)?;
metadata::cbor_to_object_descriptor(&cbor_bytes)
} else {
// Flag clear: the descriptor's length is not known up front, so read a
// prefix starting at `cbor_start` and grow it until decoding succeeds.
if cbor_start >= footer_start {
return Err(TensogramError::Remote(
"cbor_offset beyond frame body".to_string(),
));
}
let max_cbor_len = footer_start - cbor_start;
let mut prefix_size: u64 = 8192;
loop {
let read_end = (cbor_start + prefix_size).min(footer_start);
let prefix_bytes = self.get_range(cbor_start..read_end)?;
match metadata::cbor_to_object_descriptor(&prefix_bytes) {
Ok(desc) => return Ok(desc),
// Any decode failure is treated as "prefix too short" while there is
// still unread body left; the prefix doubles, capped at the body size.
Err(_) if prefix_size < max_cbor_len => {
prefix_size = (prefix_size * 2).min(max_cbor_len);
}
// The whole region up to the footer was tried: report the decoder's error.
Err(e) => return Err(e),
}
}
}
}
/// Fetches and decodes object `obj_idx` of message `msg_idx`.
///
/// With an index available only the object's frame is downloaded and decoded
/// via `crate::decode::decode_object_from_frame`; decode errors are tagged
/// with the object index through `crate::error::with_object_index`. Without
/// an index the whole message is downloaded and passed to
/// `crate::decode::decode_object`.
///
/// # Errors
/// `TensogramError::Remote` if the state mutex is poisoned or metadata is
/// missing for an indexed message; validation, fetch and decode errors are
/// propagated.
pub(crate) fn read_object(
    &self,
    msg_idx: usize,
    obj_idx: usize,
    options: &DecodeOptions,
) -> Result<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)> {
    // Copy the layout out so no lock is held while talking to the store.
    let snapshot = {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| TensogramError::Remote("remote state lock poisoned".to_string()))?;
        self.ensure_layout_eager_locked(&mut guard, msg_idx)?;
        guard.layouts[msg_idx].clone()
    };
    match snapshot.index.as_ref() {
        None => {
            let raw = self.read_message(msg_idx)?;
            crate::decode::decode_object(&raw, obj_idx, options)
        }
        Some(index) => {
            Self::validate_index_access(index, obj_idx)?;
            let meta = match snapshot.global_metadata.clone() {
                Some(meta) => meta,
                None => return Err(TensogramError::Remote("metadata not cached".to_string())),
            };
            let frame_span = Self::checked_frame_range(
                snapshot.offset,
                snapshot.length,
                index.offsets[obj_idx],
                index.lengths[obj_idx],
            )?;
            let frame = self.get_range(frame_span)?;
            let (desc, decoded) = crate::decode::decode_object_from_frame(&frame, options)
                .map_err(|e| crate::error::with_object_index(e, obj_idx))?;
            Ok((meta, desc, decoded))
        }
    }
}
/// Decodes the same element `ranges` of object `obj_idx` from several
/// messages, fetching all object frames with one `get_ranges` call.
///
/// Every message in `msg_indices` must have an index; results are returned
/// in the order the store returns the fetched frames.
///
/// # Errors
/// `TensogramError::Remote` if the state mutex is poisoned, a message has no
/// index frame, or the store request fails; layout, validation and decode
/// errors are propagated.
pub(crate) fn read_range_batch(
    &self,
    msg_indices: &[usize],
    obj_idx: usize,
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<Vec<(DataObjectDescriptor, Vec<Vec<u8>>)>> {
    // Phase 1 (under the lock): resolve every layout, then compute frame spans.
    let byte_ranges = {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| TensogramError::Remote("remote state lock poisoned".to_string()))?;
        for &msg_idx in msg_indices {
            self.ensure_layout_eager_locked(&mut guard, msg_idx)?;
        }
        let mut spans = Vec::with_capacity(msg_indices.len());
        for &msg_idx in msg_indices {
            let cached = &guard.layouts[msg_idx];
            let Some(index) = cached.index.as_ref() else {
                return Err(TensogramError::Remote(format!(
                    "message {} has no index frame; batch requires indexed messages",
                    msg_idx
                )));
            };
            Self::validate_index_access(index, obj_idx)?;
            spans.push(Self::checked_frame_range(
                cached.offset,
                cached.length,
                index.offsets[obj_idx],
                index.lengths[obj_idx],
            )?);
        }
        spans
    };
    // Phase 2: one batched fetch on the shared runtime.
    let (store, path) = (self.store.clone(), self.path.clone());
    let frames = block_on_shared(async move { store.get_ranges(&path, &byte_ranges).await })?;
    // Phase 3: decode each frame; the first failure aborts the batch.
    frames
        .iter()
        .map(|frame| -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
            let (desc, payload, _mask_region, _consumed) =
                framing::decode_data_object_frame(frame)?;
            let parts =
                crate::decode::decode_range_from_payload(&desc, payload, ranges, options)?;
            Ok((desc, parts))
        })
        .collect()
}
/// Decodes object `obj_idx` from several messages, fetching all object
/// frames with one `get_ranges` call.
///
/// Every message in `msg_indices` must have an index and cached metadata.
/// Each fetched frame is paired with the metadata collected for the message
/// at the same position; decode errors are tagged with `obj_idx` through
/// `crate::error::with_object_index`.
///
/// # Errors
/// `TensogramError::Remote` if the state mutex is poisoned, a message has no
/// index frame or no cached metadata, or the store request fails; layout,
/// validation and decode errors are propagated.
pub(crate) fn read_object_batch(
    &self,
    msg_indices: &[usize],
    obj_idx: usize,
    options: &DecodeOptions,
) -> Result<Vec<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)>> {
    // Phase 1 (under the lock): resolve layouts, then gather spans and metadata.
    let (byte_ranges, metas) = {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| TensogramError::Remote("remote state lock poisoned".to_string()))?;
        for &msg_idx in msg_indices {
            self.ensure_layout_eager_locked(&mut guard, msg_idx)?;
        }
        let mut spans = Vec::with_capacity(msg_indices.len());
        let mut metas = Vec::with_capacity(msg_indices.len());
        for &msg_idx in msg_indices {
            let cached = &guard.layouts[msg_idx];
            let Some(index) = cached.index.as_ref() else {
                return Err(TensogramError::Remote(format!(
                    "message {} has no index frame; batch decode requires indexed messages",
                    msg_idx
                )));
            };
            Self::validate_index_access(index, obj_idx)?;
            spans.push(Self::checked_frame_range(
                cached.offset,
                cached.length,
                index.offsets[obj_idx],
                index.lengths[obj_idx],
            )?);
            match cached.global_metadata.clone() {
                Some(meta) => metas.push(meta),
                None => {
                    return Err(TensogramError::Remote("metadata not cached".to_string()));
                }
            }
        }
        (spans, metas)
    };
    // Phase 2: one batched fetch on the shared runtime.
    let (store, path) = (self.store.clone(), self.path.clone());
    let frames = block_on_shared(async move { store.get_ranges(&path, &byte_ranges).await })?;
    // Phase 3: decode each frame alongside its message metadata.
    frames
        .iter()
        .zip(metas)
        .map(
            |(frame, meta)| -> Result<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)> {
                let (desc, decoded) = crate::decode::decode_object_from_frame(frame, options)
                    .map_err(|e| crate::error::with_object_index(e, obj_idx))?;
                Ok((meta, desc, decoded))
            },
        )
        .collect()
}
/// Decodes the element `ranges` of object `obj_idx` in message `msg_idx`.
///
/// With an index available only the object's frame is downloaded, split with
/// `framing::decode_data_object_frame`, and the requested ranges are decoded
/// from its payload. Without an index the whole message is downloaded and
/// passed to `crate::decode::decode_range`.
///
/// # Errors
/// `TensogramError::Remote` if the state mutex is poisoned; layout,
/// validation, fetch and decode errors are propagated.
pub(crate) fn read_range(
    &self,
    msg_idx: usize,
    obj_idx: usize,
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
    // Copy the layout out so no lock is held while talking to the store.
    let snapshot = {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| TensogramError::Remote("remote state lock poisoned".to_string()))?;
        self.ensure_layout_eager_locked(&mut guard, msg_idx)?;
        guard.layouts[msg_idx].clone()
    };
    match snapshot.index.as_ref() {
        None => {
            let raw = self.read_message(msg_idx)?;
            crate::decode::decode_range(&raw, obj_idx, ranges, options)
        }
        Some(index) => {
            Self::validate_index_access(index, obj_idx)?;
            let frame_span = Self::checked_frame_range(
                snapshot.offset,
                snapshot.length,
                index.offsets[obj_idx],
                index.lengths[obj_idx],
            )?;
            let frame = self.get_range(frame_span)?;
            let (desc, payload, _mask_region, _consumed) =
                framing::decode_data_object_frame(&frame)?;
            let parts =
                crate::decode::decode_range_from_payload(&desc, payload, ranges, options)?;
            Ok((desc, parts))
        }
    }
}
/// Checks that `index` is internally consistent and that `obj_idx` addresses
/// one of its entries.
///
/// # Errors
/// * `TensogramError::Remote` when `offsets` and `lengths` differ in length.
/// * `TensogramError::Object` when `obj_idx` is not below the entry count.
fn validate_index_access(index: &IndexFrame, obj_idx: usize) -> Result<()> {
    let offset_count = index.offsets.len();
    let length_count = index.lengths.len();
    if offset_count != length_count {
        return Err(TensogramError::Remote(format!(
            "corrupt index: offsets.len()={} != lengths.len()={}",
            offset_count, length_count
        )));
    }
    if obj_idx < offset_count {
        Ok(())
    } else {
        Err(TensogramError::Object(format!(
            "object index {} out of range (count={})",
            obj_idx, offset_count
        )))
    }
}
/// Converts a frame's message-relative offset and length into an absolute
/// byte range, rejecting anything that overflows or extends past the end of
/// the enclosing message.
///
/// # Errors
/// `TensogramError::Remote` when computing the frame start, the frame end or
/// the message end overflows `u64`, or when the frame end lies beyond the
/// message end.
fn checked_frame_range(
    msg_offset: u64,
    msg_length: u64,
    frame_offset_in_msg: u64,
    frame_length: u64,
) -> Result<Range<u64>> {
    let start = match msg_offset.checked_add(frame_offset_in_msg) {
        Some(value) => value,
        None => return Err(TensogramError::Remote("frame offset overflow".to_string())),
    };
    let end = match start.checked_add(frame_length) {
        Some(value) => value,
        None => return Err(TensogramError::Remote("frame end overflow".to_string())),
    };
    let msg_end = match msg_offset.checked_add(msg_length) {
        Some(value) => value,
        None => return Err(TensogramError::Remote("message end overflow".to_string())),
    };
    if end <= msg_end {
        Ok(start..end)
    } else {
        Err(TensogramError::Remote(format!(
            "indexed frame {start}..{end} exceeds message boundary {msg_end}"
        )))
    }
}
}
#[cfg(feature = "async")]
impl RemoteBackend {
/// Fetches `range` of the remote object, mapping any store error to
/// `TensogramError::Remote` carrying the error's display text.
async fn get_range_async(&self, range: Range<u64>) -> Result<Bytes> {
    match self.store.get_range(&self.path, range).await {
        Ok(fetched) => Ok(fetched),
        Err(err) => Err(TensogramError::Remote(err.to_string())),
    }
}
pub(crate) async fn open_async_with_scan_opts(
source: &str,
storage_options: &BTreeMap<String, String>,
scan_opts: RemoteScanOptions,
) -> Result<Self> {
emit_scan_mode(&scan_opts);
let url = Url::parse(source)
.map_err(|e| TensogramError::Remote(format!("invalid URL '{source}': {e}")))?;
let mut opts: Vec<(String, String)> = storage_options
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
if url.scheme() == "http" && !opts.iter().any(|(k, _)| k == "allow_http") {
opts.push(("allow_http".to_string(), "true".to_string()));
}
let (store, path) = object_store::parse_url_opts(&url, opts)
.map_err(|e| TensogramError::Remote(format!("cannot open '{source}': {e}")))?;
let store: Arc<dyn ObjectStore> = Arc::from(store);
let meta = store
.head(&path)
.await
.map_err(|e| TensogramError::Remote(e.to_string()))?;
let file_size = meta.size;
if file_size < (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64 {
return Err(TensogramError::Remote(format!(
"remote file too small ({file_size} bytes)"
)));
}
let backend = RemoteBackend {
source_url: source.to_string(),
store,
path,
file_size,
state: Mutex::new(RemoteState {
prev_scan_offset: file_size,
bwd_active: scan_opts.bidirectional,
..RemoteState::default()
}),
scan_opts,
};
backend.scan_next_async().await?;
{
let state = backend.lock_state()?;
if state.layouts.is_empty() {
return Err(TensogramError::Remote(
"no valid messages found in remote file".to_string(),
));
}
}
Ok(backend)
}
/// Runs one forward scan step whose upper bound is the end of the file.
async fn scan_next_async(&self) -> Result<()> {
    let whole_file_bound = self.file_size;
    self.scan_fwd_step_async(whole_file_bound).await
}
/// Advances the forward scan by at most one message, never reading a
/// message that would extend past `bound`.
///
/// The current forward cursor is sampled under the lock, the preamble at
/// that position is fetched without the lock, and the result is committed
/// only if the cursor has not moved in the meantime. Format problems
/// (short or bad magic, unparseable preamble, length out of range) terminate
/// the forward scan with a reason string instead of returning an error.
///
/// A preamble whose `total_length` is zero is treated as a streaming message
/// that spans the rest of the file; it is accepted only if the file ends
/// with `END_MAGIC`, and it always ends the forward scan.
///
/// # Errors
/// Lock poisoning (via `lock_state`) and range-fetch failures are returned
/// to the caller.
async fn scan_fwd_step_async(&self, bound: u64) -> Result<()> {
    let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
    let pos = {
        let state = self.lock_state()?;
        if state.scan_complete() {
            return Ok(());
        }
        state.next_scan_offset
    };
    // Ends the forward scan with `reason`, but only if no other caller has
    // advanced the cursor since `pos` was sampled.
    let stop_forward = |reason: &'static str| -> Result<()> {
        let mut state = self.lock_state()?;
        if state.next_scan_offset == pos {
            state.terminate_forward(reason);
        }
        Ok(())
    };
    if pos.saturating_add(min_message_size) > bound {
        return stop_forward("eof");
    }
    let preamble_bytes = self
        .get_range_async(pos..pos + PREAMBLE_SIZE as u64)
        .await?;
    // Guard against a short response before slicing out the magic, matching
    // the check in `scan_and_discover_next_async`; without it a truncated
    // read would panic here.
    if preamble_bytes.len() < PREAMBLE_SIZE || &preamble_bytes[..MAGIC.len()] != MAGIC {
        return stop_forward("bad-magic-fwd");
    }
    let preamble = match Preamble::read_from(&preamble_bytes) {
        Ok(preamble) => preamble,
        Err(_) => return stop_forward("preamble-parse-error-fwd"),
    };
    let msg_len = preamble.total_length;
    if msg_len == 0 {
        // Streaming message: its extent is "everything up to end of file".
        let remaining = self.file_size - pos;
        if remaining < min_message_size {
            return stop_forward("streaming-tail-too-small");
        }
        let end_magic_len = crate::wire::END_MAGIC.len() as u64;
        let end_magic_pos = self.file_size - end_magic_len;
        let end_bytes = self
            .get_range_async(end_magic_pos..end_magic_pos + end_magic_len)
            .await?;
        if &end_bytes[..] != crate::wire::END_MAGIC {
            return stop_forward("streaming-end-magic-mismatch");
        }
        let mut state = self.lock_state()?;
        if state.scan_complete() || state.next_scan_offset != pos {
            return Ok(());
        }
        state.record_forward_hop(CachedLayout {
            offset: pos,
            length: remaining,
            preamble,
            index: None,
            global_metadata: None,
        });
        state.terminate_forward("streaming-tail");
        return Ok(());
    }
    // A fixed-length message must be at least preamble + postamble long and
    // must end at or before `bound`.
    match pos.checked_add(msg_len) {
        Some(end) if msg_len >= min_message_size && end <= bound => {}
        _ => return stop_forward("length-out-of-range-fwd"),
    }
    let mut state = self.lock_state()?;
    if state.scan_complete() || state.next_scan_offset != pos {
        return Ok(());
    }
    state.record_forward_hop(CachedLayout {
        offset: pos,
        length: msg_len,
        preamble,
        index: None,
        global_metadata: None,
    });
    Ok(())
}
/// Performs one scan step, choosing between a bidirectional round and a
/// forward-only step.
///
/// While the backward scan is active and the forward scan has not
/// terminated, a bidirectional round is run. Otherwise a forward step is
/// run, bounded by `forward_bound`. Does nothing once the scan is complete.
async fn scan_step_async(&self) -> Result<()> {
    let forward_only_bound: Option<u64> = {
        let state = self.lock_state()?;
        if state.scan_complete() {
            return Ok(());
        }
        match (state.bwd_active, state.fwd_terminated) {
            (true, false) => None,
            _ => Some(self.forward_bound(&state)),
        }
    };
    if let Some(bound) = forward_only_bound {
        self.scan_fwd_step_async(bound).await
    } else {
        self.scan_bidir_round_async().await
    }
}
/// Runs one bidirectional scan round: probes the next message from the
/// front and the previous message from the back with concurrent range reads.
///
/// The round works from a snapshot of the scan cursors taken under the lock.
/// All network reads happen without the lock; the results are applied (via
/// `apply_round_outcomes`) only if the state still matches the snapshot, so
/// a round that lost a race with another caller is silently discarded.
///
/// Falls back to a plain forward step when bidirectional scanning is not
/// active, or when the gap between the cursors is too small to hold a
/// message.
///
/// # Errors
/// Lock poisoning and failures of the forward-preamble, backward-postamble
/// and candidate-preamble fetches are returned. A failed eager footer fetch
/// is only logged.
async fn scan_bidir_round_async(&self) -> Result<()> {
// Snapshot the cursors only if this round may really run in both directions.
let snap_opt: Option<ScanSnapshot> = {
let state = self.lock_state()?;
if state.scan_complete() {
return Ok(());
}
if state.bwd_active && !state.fwd_terminated {
Some(state.snapshot())
} else {
None
}
};
let snap = match snap_opt {
Some(snap) => snap,
None => return self.scan_fwd_step_async(self.file_size).await,
};
let bound = snap.prev;
let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
// Gap between the cursors is smaller than the smallest possible message.
if snap.prev < snap.next.saturating_add(min_message_size) {
let recovery_bound = {
let mut state = self.lock_state()?;
if !state.matches(&snap) {
return Ok(());
}
// Cursors have met exactly: nothing is left between them.
if snap.next == snap.prev {
state.close_gap();
return Ok(());
}
// A non-empty gap too small for a message: give up on the backward scan
// and let a forward step deal with the remainder.
state.disable_backward("gap-below-min-message-size");
self.forward_bound(&state)
};
return self.scan_fwd_step_async(recovery_bound).await;
}
// Fetch the preamble at the forward cursor and the postamble just before
// the backward cursor concurrently.
let fwd_r = snap.next..snap.next + PREAMBLE_SIZE as u64;
let bwd_r = snap.prev - POSTAMBLE_SIZE as u64..snap.prev;
let (fwd_res, bwd_res) = tokio::join!(
self.get_range_async(fwd_r.clone()),
self.get_range_async(bwd_r.clone())
);
let fwd_bytes = fwd_res?;
let bwd_bytes = bwd_res?;
let bwd_outcome = parse_backward_postamble(&bwd_bytes, &snap);
// If the postamble names a candidate message, fetch that candidate's
// preamble and, when a footer region is present, the footer bytes as well.
let (candidate_preamble_bytes, candidate_footer_bytes) = match &bwd_outcome {
BackwardOutcome::NeedPreambleValidation {
msg_start,
length,
first_footer_offset,
} => {
let preamble_fut =
self.get_range_async(*msg_start..*msg_start + PREAMBLE_SIZE as u64);
if footer_region_present(*first_footer_offset, *length) {
let footer_start = msg_start.saturating_add(*first_footer_offset);
let footer_end = msg_start
.saturating_add(*length)
.saturating_sub(POSTAMBLE_SIZE as u64);
if footer_start < footer_end {
let footer_fut = self.get_range_async(footer_start..footer_end);
let (preamble_res, footer_res) = tokio::join!(preamble_fut, footer_fut);
// The footer fetch is best-effort: on failure the round continues
// without footer bytes.
let footer = match footer_res {
Ok(b) => Some(b),
Err(e) => {
tracing::debug!(
target: "tensogram::remote_scan",
error = %e,
msg_start = *msg_start,
footer_start = footer_start,
footer_end = footer_end,
"eager footer fetch failed; falling back to lazy"
);
None
}
};
(Some(preamble_res?), footer)
} else {
(Some(preamble_fut.await?), None)
}
} else {
(Some(preamble_fut.await?), None)
}
}
_ => (None, None),
};
// Commit phase: re-take the lock and drop the round if the cursors moved.
let mut state = self.lock_state()?;
if !state.matches(&snap) {
return Ok(());
}
let fwd_outcome = parse_forward_preamble(&fwd_bytes, snap.next, self.file_size, bound);
self.apply_round_outcomes(
&mut state,
fwd_outcome,
bwd_outcome,
candidate_preamble_bytes.as_deref(),
candidate_footer_bytes.as_deref(),
);
Ok(())
}
/// Speculatively scans from both ends using local cursors, committing all
/// discovered layouts in one step at the end.
///
/// `target_fwd_count`, when given, stops the scan once the number of
/// layouts committed at the start plus the forward layouts found locally
/// reaches the target.
///
/// Returns `Ok(true)` when the scan was already complete, the target was
/// already met, or the locally gathered layouts were committed. Returns
/// `Ok(false)` — leaving the shared state untouched — when bidirectional
/// scanning is not active, when any parse or validation step does not yield
/// a clean hit, when the gap between the cursors becomes too small for a
/// message, or when the shared state changed while the scan was running.
///
/// # Errors
/// Lock poisoning and range-fetch failures are returned.
async fn scan_pipelined_async(&self, target_fwd_count: Option<usize>) -> Result<bool> {
let (snap, committed_fwd_at_start) = {
let state = self.lock_state()?;
if state.scan_complete() {
return Ok(true);
}
if let Some(target) = target_fwd_count {
if state.layouts.len() >= target {
return Ok(true);
}
}
if !state.bwd_active || state.fwd_terminated {
return Ok(false);
}
(state.snapshot(), state.layouts.len())
};
let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
// Local cursors and result buffers; nothing touches shared state until the
// final commit.
let mut fwd_cursor = snap.next;
let mut bwd_cursor = snap.prev;
// A backward candidate (msg_start, length) whose preamble has not been
// validated yet; its validation read is issued in the next iteration.
let mut pending: Option<(u64, u64)> = None;
let mut local_fwd: Vec<CachedLayout> = Vec::new();
let mut local_bwd: Vec<CachedLayout> = Vec::new();
loop {
// The cursors met: the whole file has been covered.
if fwd_cursor == bwd_cursor {
break;
}
if let Some(target) = target_fwd_count {
if committed_fwd_at_start + local_fwd.len() >= target {
break;
}
}
// Remaining gap cannot hold a message: abandon the speculative scan.
if bwd_cursor < fwd_cursor.saturating_add(min_message_size) {
return Ok(false);
}
let fwd_r = fwd_cursor..fwd_cursor + PREAMBLE_SIZE as u64;
let bwd_r = bwd_cursor - POSTAMBLE_SIZE as u64..bwd_cursor;
let val_r_opt = pending.map(|(ms, _)| ms..ms + PREAMBLE_SIZE as u64);
// Issue the forward preamble read, the backward postamble read and, if a
// candidate is pending, its preamble validation read concurrently.
let (fwd_bytes, bwd_postamble_bytes, val_bytes_opt) = match val_r_opt {
Some(val_r) => {
let (f, b, v) = tokio::join!(
self.get_range_async(fwd_r),
self.get_range_async(bwd_r),
self.get_range_async(val_r),
);
(f?, b?, Some(v?))
}
None => {
let (f, b) =
tokio::join!(self.get_range_async(fwd_r), self.get_range_async(bwd_r));
(f?, b?, None)
}
};
// Resolve the candidate carried over from the previous iteration.
if let Some(val_bytes) = val_bytes_opt {
let (msg_start, length) = pending.take().expect("pending matched val_r_opt");
match validate_backward_preamble(&val_bytes, msg_start, length) {
BackwardCommit::Layout(layout) => local_bwd.push(layout),
BackwardCommit::Format(_) => return Ok(false),
}
}
let pre_iter_fwd_cursor = fwd_cursor;
let fwd_outcome =
parse_forward_preamble(&fwd_bytes, pre_iter_fwd_cursor, self.file_size, bwd_cursor);
// Anything other than a clean forward hit aborts the speculative scan.
let (fwd_offset, fwd_length) = match fwd_outcome {
ForwardOutcome::Hit {
offset,
length,
preamble,
msg_end,
} => {
local_fwd.push(CachedLayout {
offset,
length,
preamble,
index: None,
global_metadata: None,
});
fwd_cursor = msg_end;
(offset, length)
}
_ => return Ok(false),
};
// The postamble is parsed against the cursor positions as they were when
// this iteration's reads were issued.
let snap_local = ScanSnapshot {
next: pre_iter_fwd_cursor,
prev: bwd_cursor,
epoch: snap.epoch,
};
match parse_backward_postamble(&bwd_postamble_bytes, &snap_local) {
BackwardOutcome::NeedPreambleValidation {
msg_start,
length,
first_footer_offset: _,
} => {
// Both directions found the same message: it is already recorded as a
// forward hit, so the backward candidate is dropped.
if msg_start == fwd_offset && length == fwd_length {
continue;
}
// Backward candidate overlaps the message just accepted from the front.
if msg_start < fwd_cursor {
return Ok(false);
}
pending = Some((msg_start, length));
bwd_cursor = msg_start;
}
_ => return Ok(false),
}
}
// Validate the last backward candidate, which had no following iteration.
if let Some((msg_start, length)) = pending {
let val_bytes = self
.get_range_async(msg_start..msg_start + PREAMBLE_SIZE as u64)
.await?;
match validate_backward_preamble(&val_bytes, msg_start, length) {
BackwardCommit::Layout(layout) => local_bwd.push(layout),
BackwardCommit::Format(_) => return Ok(false),
}
}
// Commit only if nobody else advanced the shared state in the meantime.
let mut state = self.lock_state()?;
if !state.matches(&snap) {
return Ok(false);
}
for layout in local_fwd {
state.record_forward_hop(layout);
}
for layout in local_bwd {
state.record_backward_hop(layout);
}
if state.next_scan_offset == state.prev_scan_offset {
state.close_gap();
}
Ok(true)
}
/// Scans until the layout of message `msg_idx` is known.
///
/// If the message is not yet known and bidirectional scanning is still
/// running, one pipelined scan targeting `msg_idx + 1` forward layouts is
/// attempted first (its boolean outcome is ignored). After that, single scan
/// steps are taken until the layout exists or the scan completes.
///
/// # Errors
/// `TensogramError::Framing` when the scan completes without reaching
/// `msg_idx`; lock and scan errors are propagated.
async fn ensure_message_async(&self, msg_idx: usize) -> Result<()> {
    let try_pipelined = {
        let state = self.lock_state()?;
        let unknown = msg_idx >= state.layouts.len();
        unknown && !state.scan_complete() && state.bwd_active && !state.fwd_terminated
    };
    if try_pipelined {
        let _ = self.scan_pipelined_async(Some(msg_idx + 1)).await?;
    }
    loop {
        {
            // Keep the guard in its own scope so it is gone before the await.
            let state = self.lock_state()?;
            let known = state.layouts.len();
            if msg_idx < known {
                return Ok(());
            }
            if state.scan_complete() {
                return Err(TensogramError::Framing(format!(
                    "message index {} out of range (count={})",
                    msg_idx, known
                )));
            }
        }
        self.scan_step_async().await?;
    }
}
/// Advances the forward scan by one message and, where possible, parses
/// that message's metadata/index frames straight away.
///
/// Up to 256 KiB starting at the forward cursor are fetched in one request.
/// The preamble is validated as in a plain forward step (format problems
/// terminate the forward scan with a reason string). For a message flagged
/// with header metadata and header index, the header frames are parsed from
/// the bytes already fetched; for one flagged with footer metadata and
/// footer index, the footer is discovered from the message's suffix.
///
/// A preamble whose `total_length` is zero is treated as a streaming message
/// spanning the rest of the file; it is accepted only if the file ends with
/// `END_MAGIC`, always ends the forward scan, and gets no eager discovery.
///
/// # Errors
/// Lock poisoning, range-fetch failures and header/footer parse errors are
/// returned to the caller.
async fn scan_and_discover_next_async(&self) -> Result<()> {
    let min_message_size = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
    let pos = {
        let state = self.lock_state()?;
        if state.scan_complete() {
            return Ok(());
        }
        state.next_scan_offset
    };
    // Ends the forward scan with `reason`, but only if no other caller has
    // advanced the cursor since `pos` was sampled.
    let stop_forward = |reason: &'static str| -> Result<()> {
        let mut state = self.lock_state()?;
        if state.next_scan_offset == pos {
            state.terminate_forward(reason);
        }
        Ok(())
    };
    if pos.saturating_add(min_message_size) > self.file_size {
        return stop_forward("eof");
    }
    let chunk_size = (self.file_size - pos).min(256 * 1024);
    let chunk = self.get_range_async(pos..pos + chunk_size).await?;
    if chunk.len() < PREAMBLE_SIZE || &chunk[..MAGIC.len()] != MAGIC {
        return stop_forward("bad-magic-fwd");
    }
    let preamble = match Preamble::read_from(&chunk[..PREAMBLE_SIZE]) {
        Ok(preamble) => preamble,
        Err(_) => return stop_forward("preamble-parse-error-fwd"),
    };
    let msg_len = preamble.total_length;
    if msg_len == 0 {
        // Streaming message: its extent is "everything up to end of file".
        let remaining = self.file_size - pos;
        if remaining < min_message_size {
            return stop_forward("streaming-tail-too-small");
        }
        let end_magic_len = crate::wire::END_MAGIC.len() as u64;
        let end_magic_pos = self.file_size - end_magic_len;
        let end_bytes = self
            .get_range_async(end_magic_pos..end_magic_pos + end_magic_len)
            .await?;
        if &end_bytes[..] != crate::wire::END_MAGIC {
            return stop_forward("streaming-end-magic-mismatch");
        }
        let mut state = self.lock_state()?;
        if state.scan_complete() || state.next_scan_offset != pos {
            return Ok(());
        }
        state.record_forward_hop(CachedLayout {
            offset: pos,
            length: remaining,
            preamble,
            index: None,
            global_metadata: None,
        });
        state.terminate_forward("streaming-tail");
        return Ok(());
    }
    // A fixed-length message must be at least preamble + postamble long and
    // must fit inside the file.
    match pos.checked_add(msg_len) {
        Some(end) if msg_len >= min_message_size && end <= self.file_size => {}
        _ => return stop_forward("length-out-of-range-fwd"),
    }
    let flags = preamble.flags;
    let msg_idx = {
        let mut state = self.lock_state()?;
        if state.scan_complete() || state.next_scan_offset != pos {
            return Ok(());
        }
        let msg_idx = state.layouts.len();
        state.record_forward_hop(CachedLayout {
            offset: pos,
            length: msg_len,
            preamble,
            index: None,
            global_metadata: None,
        });
        msg_idx
    };
    if flags.has(MessageFlags::HEADER_METADATA) && flags.has(MessageFlags::HEADER_INDEX) {
        // Clamp in u64 before narrowing: casting `msg_len` to usize first would
        // truncate lengths above usize::MAX on 32-bit targets. The clamped
        // value never exceeds `chunk.len()`, so the final cast is lossless.
        let chunk_end = msg_len.min(chunk.len() as u64) as usize;
        let mut state = self.lock_state()?;
        // Parse only if the slot still describes this message and nobody has
        // populated it in the meantime.
        if msg_idx < state.layouts.len()
            && state.layouts[msg_idx].offset == pos
            && state.layouts[msg_idx].global_metadata.is_none()
            && state.layouts[msg_idx].index.is_none()
        {
            Self::parse_header_frames(&mut state, msg_idx, &chunk[..chunk_end])?;
        }
    } else if flags.has(MessageFlags::FOOTER_METADATA) && flags.has(MessageFlags::FOOTER_INDEX)
    {
        self.discover_footer_layout_from_suffix_async(msg_idx)
            .await?;
    }
    Ok(())
}
/// Discovers the footer frames of message `msg_idx` by fetching the tail of
/// the message (at most 256 KiB) in one request.
///
/// The postamble is read from the end of the fetched suffix. If the footer
/// region named by `first_footer_offset` starts inside the suffix it is
/// parsed from those bytes; otherwise the footer region is fetched
/// separately. Parsing is skipped when another caller has already populated
/// both the metadata and the index.
///
/// # Errors
/// `TensogramError::Framing` for an unknown `msg_idx`;
/// `TensogramError::Remote` for arithmetic overflow, a suffix shorter than a
/// postamble, or a `first_footer_offset` outside the message body; fetch and
/// parse errors are propagated.
async fn discover_footer_layout_from_suffix_async(&self, msg_idx: usize) -> Result<()> {
    const SUFFIX_PROBE_BYTES: u64 = 256 * 1024;
    let (msg_offset, msg_len) = {
        let state = self.lock_state()?;
        let span = match state.layouts.get(msg_idx) {
            Some(cached) => (cached.offset, cached.length),
            None => {
                return Err(TensogramError::Framing(format!(
                    "message index {} out of range (count={})",
                    msg_idx,
                    state.layouts.len()
                )));
            }
        };
        span
    };
    let suffix_size = msg_len.min(SUFFIX_PROBE_BYTES);
    let msg_end = match msg_offset.checked_add(msg_len) {
        Some(end) => end,
        None => return Err(TensogramError::Remote("message end overflow".to_string())),
    };
    let suffix_start = msg_end - suffix_size;
    let suffix = self.get_range_async(suffix_start..msg_end).await?;
    if suffix.len() < POSTAMBLE_SIZE {
        return Err(TensogramError::Remote(
            "suffix too short for postamble".to_string(),
        ));
    }
    // The postamble occupies the last POSTAMBLE_SIZE bytes of the suffix.
    let postamble_at = suffix.len() - POSTAMBLE_SIZE;
    let postamble = Postamble::read_from(&suffix[postamble_at..])?;
    let first_footer_offset = postamble.first_footer_offset;
    if first_footer_offset < PREAMBLE_SIZE as u64 {
        return Err(TensogramError::Remote(format!(
            "first_footer_offset ({}) is before preamble end ({PREAMBLE_SIZE})",
            first_footer_offset
        )));
    }
    let footer_abs_start = match msg_offset.checked_add(first_footer_offset) {
        Some(start) => start,
        None => return Err(TensogramError::Remote("footer offset overflow".to_string())),
    };
    let footer_abs_end = msg_end - POSTAMBLE_SIZE as u64;
    if footer_abs_start >= footer_abs_end {
        return Err(TensogramError::Remote(
            "first_footer_offset points at or past postamble".to_string(),
        ));
    }
    // A second request is needed only when the footer begins before the suffix.
    let separately_fetched = if footer_abs_start < suffix_start {
        Some(
            self.get_range_async(footer_abs_start..footer_abs_end)
                .await?,
        )
    } else {
        None
    };
    let mut state = self.lock_state()?;
    let cached = &state.layouts[msg_idx];
    if cached.global_metadata.is_some() && cached.index.is_some() {
        return Ok(());
    }
    match separately_fetched.as_ref() {
        Some(footer_bytes) => Self::parse_footer_frames(&mut state, msg_idx, footer_bytes),
        None => {
            let local_start = (footer_abs_start - suffix_start) as usize;
            Self::parse_footer_frames(&mut state, msg_idx, &suffix[local_start..postamble_at])
        }
    }
}
/// Ensures message `msg_idx` has both its metadata and its index cached,
/// scanning forward eagerly when the message is not yet known.
///
/// If the message is unknown and bidirectional scanning is still running, a
/// pipelined scan targeting `msg_idx + 1` forward layouts is attempted first
/// (its boolean outcome is ignored). Then, repeatedly: a known but
/// unpopulated layout is handed to `ensure_layout_async`; an unknown one
/// triggers a bidirectional round or an eager forward step, depending on the
/// scan mode.
///
/// # Errors
/// `TensogramError::Framing` when the scan completes without reaching
/// `msg_idx`; lock, scan and discovery errors are propagated.
async fn ensure_layout_eager_async(&self, msg_idx: usize) -> Result<()> {
    let try_pipelined = {
        let state = self.lock_state()?;
        let unknown = msg_idx >= state.layouts.len();
        unknown && !state.scan_complete() && state.bwd_active && !state.fwd_terminated
    };
    if try_pipelined {
        let _ = self.scan_pipelined_async(Some(msg_idx + 1)).await?;
    }
    loop {
        // Decide the next step under the lock; act on it after releasing it.
        let next_step = {
            let state = self.lock_state()?;
            let step = match state.layouts.get(msg_idx) {
                Some(cached)
                    if cached.global_metadata.is_some() && cached.index.is_some() =>
                {
                    return Ok(());
                }
                Some(_) => EagerAction::Discover,
                None if state.scan_complete() => {
                    return Err(TensogramError::Framing(format!(
                        "message index {} out of range (count={})",
                        msg_idx,
                        state.layouts.len()
                    )));
                }
                None if state.bwd_active && !state.fwd_terminated => EagerAction::ScanBidir,
                None => EagerAction::ScanForwardEager,
            };
            step
        };
        match next_step {
            EagerAction::Discover => return self.ensure_layout_async(msg_idx).await,
            EagerAction::ScanBidir => self.scan_bidir_round_async().await?,
            EagerAction::ScanForwardEager => self.scan_and_discover_next_async().await?,
        }
    }
}
/// Ensures message `msg_idx` is known and has its metadata and index cached.
///
/// After `ensure_message_async` succeeds, an already-populated layout
/// returns immediately. Otherwise the preamble flags select header
/// discovery (header metadata + header index) or footer discovery (footer
/// metadata + footer index).
///
/// # Errors
/// `TensogramError::Remote` when the message is neither header-indexed nor
/// footer-indexed; lock, scan and discovery errors are propagated.
async fn ensure_layout_async(&self, msg_idx: usize) -> Result<()> {
    self.ensure_message_async(msg_idx).await?;
    let flags = {
        let state = self.lock_state()?;
        let cached = &state.layouts[msg_idx];
        let populated = cached.global_metadata.is_some() && cached.index.is_some();
        if populated {
            return Ok(());
        }
        cached.preamble.flags
    };
    if flags.has(MessageFlags::HEADER_METADATA) && flags.has(MessageFlags::HEADER_INDEX) {
        return self.discover_header_layout_async(msg_idx).await;
    }
    if flags.has(MessageFlags::FOOTER_METADATA) && flags.has(MessageFlags::FOOTER_INDEX) {
        return self.discover_footer_layout_async(msg_idx).await;
    }
    Err(TensogramError::Remote(
        "remote access requires header-indexed or footer-indexed messages".to_string(),
    ))
}
/// Fetches the head of message `msg_idx` (at most 256 KiB) and parses its
/// header frames into the cached layout.
///
/// Parsing is skipped when another caller has already populated both the
/// metadata and the index while the fetch was in flight.
///
/// # Errors
/// `TensogramError::Framing` for an unknown `msg_idx`; lock, fetch and parse
/// errors are propagated.
async fn discover_header_layout_async(&self, msg_idx: usize) -> Result<()> {
    const HEADER_PROBE_BYTES: u64 = 256 * 1024;
    let (msg_offset, msg_len) = {
        let state = self.lock_state()?;
        let span = match state.layouts.get(msg_idx) {
            Some(cached) => (cached.offset, cached.length),
            None => {
                return Err(TensogramError::Framing(format!(
                    "message index {} out of range (count={})",
                    msg_idx,
                    state.layouts.len()
                )));
            }
        };
        span
    };
    let probe_len = msg_len.min(HEADER_PROBE_BYTES);
    let header_bytes = self
        .get_range_async(msg_offset..msg_offset + probe_len)
        .await?;
    let mut state = self.lock_state()?;
    let cached = &state.layouts[msg_idx];
    if cached.global_metadata.is_some() && cached.index.is_some() {
        return Ok(());
    }
    Self::parse_header_frames(&mut state, msg_idx, &header_bytes)
}
/// Discovers the footer frames of message `msg_idx` with two range reads:
/// first the postamble, then the footer region it points at.
///
/// Parsing is skipped when another caller has already populated both the
/// metadata and the index while the fetches were in flight.
///
/// # Errors
/// `TensogramError::Framing` for an unknown `msg_idx`;
/// `TensogramError::Remote` for offset arithmetic overflow or a
/// `first_footer_offset` outside the message body; lock, fetch and parse
/// errors are propagated.
async fn discover_footer_layout_async(&self, msg_idx: usize) -> Result<()> {
    let (msg_offset, msg_len) = {
        let state = self.lock_state()?;
        let span = match state.layouts.get(msg_idx) {
            Some(cached) => (cached.offset, cached.length),
            None => {
                return Err(TensogramError::Framing(format!(
                    "message index {} out of range (count={})",
                    msg_idx,
                    state.layouts.len()
                )));
            }
        };
        span
    };
    let postamble_len = POSTAMBLE_SIZE as u64;
    // The postamble sits at the very end of the message.
    let pa_offset = match msg_offset.checked_add(msg_len) {
        Some(msg_end) if msg_end >= postamble_len => msg_end - postamble_len,
        _ => {
            return Err(TensogramError::Remote(
                "postamble offset overflow".to_string(),
            ));
        }
    };
    let pa_bytes = self
        .get_range_async(pa_offset..pa_offset + postamble_len)
        .await?;
    let postamble = Postamble::read_from(&pa_bytes)?;
    let first_footer_offset = postamble.first_footer_offset;
    if first_footer_offset < PREAMBLE_SIZE as u64 {
        return Err(TensogramError::Remote(format!(
            "first_footer_offset ({}) is before preamble end ({})",
            first_footer_offset, PREAMBLE_SIZE
        )));
    }
    let footer_start = match msg_offset.checked_add(first_footer_offset) {
        Some(start) => start,
        None => return Err(TensogramError::Remote("footer offset overflow".to_string())),
    };
    // The footer region ends where the postamble begins.
    if footer_start >= pa_offset {
        return Err(TensogramError::Remote(
            "first_footer_offset points at or past postamble".to_string(),
        ));
    }
    let footer_bytes = self.get_range_async(footer_start..pa_offset).await?;
    let mut state = self.lock_state()?;
    let cached = &state.layouts[msg_idx];
    if cached.global_metadata.is_some() && cached.index.is_some() {
        return Ok(());
    }
    Self::parse_footer_frames(&mut state, msg_idx, &footer_bytes)
}
/// Returns the number of cached message layouts after scanning.
///
/// A pipelined scan is attempted first; if it reports `false`, single scan
/// steps are taken until the scan is complete.
///
/// # Errors
/// Lock and scan errors are propagated.
pub(crate) async fn message_count_async(&self) -> Result<usize> {
    let pipelined_done = self.scan_pipelined_async(None).await?;
    if !pipelined_done {
        loop {
            let finished = self.lock_state()?.scan_complete();
            if finished {
                break;
            }
            self.scan_step_async().await?;
        }
    }
    let count = self.lock_state()?.layouts.len();
    Ok(count)
}
/// Returns the `offset`/`length` pair of every cached message layout after
/// scanning.
///
/// A pipelined scan is attempted first; if it reports `false`, single scan
/// steps are taken until the scan is complete.
///
/// # Errors
/// Lock and scan errors are propagated.
pub(crate) async fn message_layouts_async(&self) -> Result<Vec<crate::MessageLayout>> {
    let pipelined_done = self.scan_pipelined_async(None).await?;
    if !pipelined_done {
        loop {
            let finished = self.lock_state()?.scan_complete();
            if finished {
                break;
            }
            self.scan_step_async().await?;
        }
    }
    let state = self.lock_state()?;
    let mut spans = Vec::with_capacity(state.layouts.len());
    for cached in state.layouts.iter() {
        spans.push(crate::MessageLayout {
            offset: cached.offset,
            length: cached.length,
        });
    }
    Ok(spans)
}
/// Downloads the complete byte span of message `msg_idx`.
///
/// The layout is looked up under the lock once `ensure_message_async` has
/// succeeded; the lock is released before the range is fetched.
///
/// # Errors
/// Lock, scan and fetch errors are propagated.
pub(crate) async fn read_message_async(&self, msg_idx: usize) -> Result<Vec<u8>> {
    self.ensure_message_async(msg_idx).await?;
    let (start, len) = {
        let state = self.lock_state()?;
        let cached = &state.layouts[msg_idx];
        (cached.offset, cached.length)
    };
    let end = start + len;
    self.get_range_async(start..end)
        .await
        .map(|fetched| fetched.to_vec())
}
/// Returns a clone of the cached global metadata of message `msg_idx`,
/// after `ensure_layout_eager_async` has run for it.
///
/// # Errors
/// `TensogramError::Remote` when the layout has no metadata cached; lock,
/// scan and discovery errors are propagated.
pub(crate) async fn read_metadata_async(&self, msg_idx: usize) -> Result<GlobalMetadata> {
    self.ensure_layout_eager_async(msg_idx).await?;
    let state = self.lock_state()?;
    match state.layouts[msg_idx].global_metadata.as_ref() {
        Some(meta) => Ok(meta.clone()),
        None => Err(TensogramError::Remote("metadata not found".to_string())),
    }
}
/// Returns the global metadata and every object descriptor of message
/// `msg_idx`.
///
/// When the cached layout carries an index, each descriptor is read
/// individually through `read_descriptor_only_async`, one after another in
/// index order. Without an index the whole message is downloaded and handed
/// to `crate::decode::decode_descriptors`.
///
/// # Errors
/// `TensogramError::Remote` if the index's `offsets` and `lengths` differ in
/// length or if no metadata is cached for an indexed message; lock, scan,
/// fetch and decode errors are propagated.
pub(crate) async fn read_descriptors_async(
    &self,
    msg_idx: usize,
) -> Result<(GlobalMetadata, Vec<DataObjectDescriptor>)> {
    self.ensure_layout_eager_async(msg_idx).await?;
    // Copy the layout out; the temporary guard is gone after this statement.
    let snapshot = self.lock_state()?.layouts[msg_idx].clone();
    let Some(index) = snapshot.index.as_ref() else {
        let raw = self.read_message_async(msg_idx).await?;
        return crate::decode::decode_descriptors(&raw);
    };
    let offset_count = index.offsets.len();
    let length_count = index.lengths.len();
    if offset_count != length_count {
        return Err(TensogramError::Remote(format!(
            "corrupt index: offsets.len()={} != lengths.len()={}",
            offset_count, length_count
        )));
    }
    let meta = match snapshot.global_metadata.clone() {
        Some(meta) => meta,
        None => return Err(TensogramError::Remote("metadata not cached".to_string())),
    };
    let mut descriptors = Vec::with_capacity(offset_count);
    for (&frame_offset, &frame_length) in index.offsets.iter().zip(index.lengths.iter()) {
        let desc = self
            .read_descriptor_only_async(
                snapshot.offset,
                snapshot.length,
                frame_offset,
                frame_length,
            )
            .await?;
        descriptors.push(desc);
    }
    Ok((meta, descriptors))
}
/// Fetches only the descriptor of one data-object frame, avoiding a download
/// of the full payload for large frames.
///
/// * `msg_offset` / `msg_length` - position and size of the enclosing message.
/// * `frame_offset_in_msg` / `frame_length` - position (relative to the
///   message) and size of the frame, as recorded in the index.
///
/// Frames of at most `DESCRIPTOR_PREFIX_THRESHOLD` bytes are fetched whole and
/// decoded with `framing::decode_data_object_frame`. Larger frames are probed
/// with separate range reads for the frame header, the frame footer and the
/// CBOR descriptor.
///
/// # Errors
/// Returns `TensogramError::Remote` for a non-data-object frame type, a short
/// footer, a missing `FRAME_END` trailer, or a `cbor_offset` that is out of
/// bounds; propagates range-read and CBOR decode errors.
async fn read_descriptor_only_async(
&self,
msg_offset: u64,
msg_length: u64,
frame_offset_in_msg: u64,
frame_length: u64,
) -> Result<DataObjectDescriptor> {
// Frames up to this size are cheaper to fetch in one request than to probe.
const DESCRIPTOR_PREFIX_THRESHOLD: u64 = 64 * 1024;
let range =
Self::checked_frame_range(msg_offset, msg_length, frame_offset_in_msg, frame_length)?;
if frame_length <= DESCRIPTOR_PREFIX_THRESHOLD {
// Small frame: one request, decode everything, keep only the descriptor.
let frame_bytes = self.get_range_async(range.clone()).await?;
let (desc, _payload, _mask_region, _consumed) =
framing::decode_data_object_frame(&frame_bytes)?;
return Ok(desc);
}
let frame_start = range.start;
let frame_end = range.end;
// Large frame, step 1: read and validate the fixed-size frame header.
let header_bytes = self
.get_range_async(frame_start..frame_start + FRAME_HEADER_SIZE as u64)
.await?;
let fh = FrameHeader::read_from(&header_bytes)?;
if !fh.frame_type.is_data_object() {
return Err(TensogramError::Remote(format!(
"expected DataObject frame, got {:?}",
fh.frame_type
)));
}
// Step 2: read the fixed-size footer at the end of the frame.
let footer_start = frame_end - DATA_OBJECT_FOOTER_SIZE as u64;
let footer_bytes = self.get_range_async(footer_start..frame_end).await?;
if footer_bytes.len() < DATA_OBJECT_FOOTER_SIZE {
return Err(TensogramError::Remote("frame footer too short".to_string()));
}
// The footer must end with the FRAME_END marker.
let endf_pos = footer_bytes.len() - FRAME_END.len();
if &footer_bytes[endf_pos..] != FRAME_END {
return Err(TensogramError::Remote(
"frame missing ENDF trailer".to_string(),
));
}
// The first 8 footer bytes are read as a big-endian u64 `cbor_offset`,
// which is added to `frame_start` below to locate the CBOR descriptor.
let cbor_offset = u64::from_be_bytes(
footer_bytes[..8]
.try_into()
.map_err(|_| TensogramError::Remote("footer cbor_offset truncated".to_string()))?,
);
// The descriptor cannot start inside the frame header.
if cbor_offset < FRAME_HEADER_SIZE as u64 {
return Err(TensogramError::Remote(format!(
"cbor_offset ({cbor_offset}) below frame header size ({FRAME_HEADER_SIZE})"
)));
}
let cbor_after = fh.flags & DataObjectFlags::CBOR_AFTER_PAYLOAD != 0;
let cbor_start = frame_start
.checked_add(cbor_offset)
.ok_or_else(|| TensogramError::Remote("cbor_start overflow".to_string()))?;
if cbor_after {
// CBOR_AFTER_PAYLOAD set: everything from `cbor_start` up to the footer
// is fetched and decoded as the descriptor.
if cbor_start >= footer_start {
return Err(TensogramError::Remote(
"cbor_offset points at or past footer".to_string(),
));
}
let cbor_bytes = self.get_range_async(cbor_start..footer_start).await?;
metadata::cbor_to_object_descriptor(&cbor_bytes)
} else {
// Flag not set: the descriptor's length is not known here, so a prefix
// starting at `cbor_start` is fetched and doubled until decoding succeeds
// or the prefix reaches the footer.
if cbor_start >= footer_start {
return Err(TensogramError::Remote(
"cbor_offset beyond frame body".to_string(),
));
}
let max_cbor_len = footer_start - cbor_start;
let mut prefix_size: u64 = 8192;
loop {
// Never read past the footer, whatever the current prefix size is.
let read_end = (cbor_start + prefix_size).min(footer_start);
let prefix_bytes = self.get_range_async(cbor_start..read_end).await?;
match metadata::cbor_to_object_descriptor(&prefix_bytes) {
Ok(desc) => return Ok(desc),
// Any decode error is retried with a larger prefix while room remains.
Err(_) if prefix_size < max_cbor_len => {
prefix_size = (prefix_size * 2).min(max_cbor_len);
}
// The whole available region was tried: report the decode error.
Err(e) => return Err(e),
}
}
}
}
/// Reads and decodes object `obj_idx` of message `msg_idx`.
///
/// With an index available, only the frame that holds the object is fetched
/// and decoded. Otherwise the whole message is downloaded and decoded with
/// `decode_object`.
///
/// # Errors
/// Propagates index validation, range computation, fetch and decode errors.
/// Frame decode errors are tagged with `obj_idx` through
/// `crate::error::with_object_index`.
pub(crate) async fn read_object_async(
    &self,
    msg_idx: usize,
    obj_idx: usize,
    options: &DecodeOptions,
) -> Result<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)> {
    self.ensure_layout_eager_async(msg_idx).await?;
    // Snapshot the layout so no lock is held across the fetches below.
    let layout = self.lock_state()?.layouts[msg_idx].clone();

    let Some(index) = layout.index.as_ref() else {
        let whole_message = self.read_message_async(msg_idx).await?;
        return crate::decode::decode_object(&whole_message, obj_idx, options);
    };

    Self::validate_index_access(index, obj_idx)?;
    let meta = layout
        .global_metadata
        .clone()
        .ok_or_else(|| TensogramError::Remote("metadata not cached".to_string()))?;
    let frame_range = Self::checked_frame_range(
        layout.offset,
        layout.length,
        index.offsets[obj_idx],
        index.lengths[obj_idx],
    )?;
    let frame_bytes = self.get_range_async(frame_range).await?;
    let (desc, decoded) = crate::decode::decode_object_from_frame(&frame_bytes, options)
        .map_err(|e| crate::error::with_object_index(e, obj_idx))?;
    Ok((meta, desc, decoded))
}
/// Makes sure every message listed in `msg_indices` has both its global
/// metadata and its index cached, batching the range reads needed to get there.
///
/// Phases:
/// 1. Discover message layouts until the largest requested index is known or
///    the scan is complete.
/// 2. Reject any requested index beyond the discovered message count.
/// 3. For every message still missing metadata or index, fetch in a single
///    `get_ranges` call either a leading chunk (header-indexed messages) or
///    the postamble (footer-indexed messages).
/// 4. For footer-indexed messages, fetch the footer regions located through
///    the postambles in a second `get_ranges` call and parse them.
///
/// # Errors
/// * `TensogramError::Framing` when a requested index is out of range.
/// * `TensogramError::Remote` when a message is neither header- nor
///   footer-indexed, when offset arithmetic overflows, when a postamble
///   describes an empty or inverted footer region, or when the object store
///   reports an error.
/// * Any error produced while scanning or parsing frames.
pub(crate) async fn ensure_all_layouts_batch_async(&self, msg_indices: &[usize]) -> Result<()> {
    // Upper bound on the leading chunk fetched for a header-indexed message.
    const HEADER_CHUNK_LIMIT: u64 = 256 * 1024;

    let Some(max_idx) = msg_indices.iter().copied().max() else {
        return Ok(());
    };

    // Phase 1: layout discovery up to `max_idx`.
    let needs_discovery = {
        let state = self.lock_state()?;
        state.layouts.len() <= max_idx
            && !state.scan_complete()
            && state.bwd_active
            && !state.fwd_terminated
    };
    if needs_discovery {
        // The pipelined walker's boolean outcome is deliberately ignored: the
        // loop below finishes whatever it left undone.
        let _ = self.scan_pipelined_async(Some(max_idx + 1)).await?;
    }
    loop {
        let bidir = {
            let state = self.lock_state()?;
            if state.layouts.len() > max_idx || state.scan_complete() {
                break;
            }
            state.bwd_active && !state.fwd_terminated
        };
        if bidir {
            self.scan_bidir_round_async().await?;
        } else {
            self.scan_and_discover_next_async().await?;
        }
    }

    // Phase 2: validate the requested indices and pick those needing work.
    let needs_layout: Vec<usize> = {
        let state = self.lock_state()?;
        if let Some(&bad) = msg_indices
            .iter()
            .find(|&&idx| idx >= state.layouts.len())
        {
            return Err(TensogramError::Framing(format!(
                "message index {} out of range (count={})",
                bad,
                state.layouts.len()
            )));
        }
        msg_indices
            .iter()
            .copied()
            .filter(|&idx| {
                state
                    .layouts
                    .get(idx)
                    .is_none_or(|l| l.global_metadata.is_none() || l.index.is_none())
            })
            .collect()
    };
    if needs_layout.is_empty() {
        return Ok(());
    }

    // Phase 3: plan one range per message. The bool in `fetch_map` records
    // whether the range is a header chunk (true) or a postamble (false).
    let mut fetch_ranges: Vec<Range<u64>> = Vec::with_capacity(needs_layout.len());
    let mut fetch_map: Vec<(usize, bool)> = Vec::with_capacity(needs_layout.len());
    {
        let state = self.lock_state()?;
        for &msg_idx in &needs_layout {
            let layout = &state.layouts[msg_idx];
            let flags = layout.preamble.flags;
            if flags.has(MessageFlags::HEADER_METADATA) && flags.has(MessageFlags::HEADER_INDEX)
            {
                let chunk_size = layout.length.min(HEADER_CHUNK_LIMIT);
                // Checked like the postamble arithmetic below, so a corrupt
                // offset cannot wrap around.
                let chunk_end = layout.offset.checked_add(chunk_size).ok_or_else(|| {
                    TensogramError::Remote("header chunk offset overflow".to_string())
                })?;
                fetch_ranges.push(layout.offset..chunk_end);
                fetch_map.push((msg_idx, true));
            } else if flags.has(MessageFlags::FOOTER_METADATA)
                && flags.has(MessageFlags::FOOTER_INDEX)
            {
                let pa_offset = layout
                    .offset
                    .checked_add(layout.length)
                    .and_then(|end| end.checked_sub(POSTAMBLE_SIZE as u64))
                    .ok_or_else(|| {
                        TensogramError::Remote("postamble offset overflow".to_string())
                    })?;
                fetch_ranges.push(pa_offset..pa_offset + POSTAMBLE_SIZE as u64);
                fetch_map.push((msg_idx, false));
            } else {
                return Err(TensogramError::Remote(
                    "remote batch requires header-indexed or footer-indexed messages"
                        .to_string(),
                ));
            }
        }
    }
    let all_bytes = self
        .store
        .get_ranges(&self.path, &fetch_ranges)
        .await
        .map_err(|e| TensogramError::Remote(e.to_string()))?;

    // Parse header chunks right away; for postambles, work out the footer
    // region that has to be fetched in the second round.
    let mut footer_fetches: Vec<(usize, Range<u64>)> = Vec::new();
    {
        let mut state = self.lock_state()?;
        for (bytes, &(msg_idx, is_header)) in all_bytes.iter().zip(fetch_map.iter()) {
            // Another caller may have populated this layout in the meantime.
            if state.layouts[msg_idx].global_metadata.is_some()
                && state.layouts[msg_idx].index.is_some()
            {
                continue;
            }
            if is_header {
                Self::parse_header_frames(&mut state, msg_idx, bytes)?;
            } else {
                let postamble = Postamble::read_from(bytes)?;
                let layout = &state.layouts[msg_idx];
                let footer_start = layout
                    .offset
                    .checked_add(postamble.first_footer_offset)
                    .ok_or_else(|| {
                        TensogramError::Remote("footer offset overflow".to_string())
                    })?;
                let pa_offset = layout
                    .offset
                    .checked_add(layout.length)
                    .and_then(|end| end.checked_sub(POSTAMBLE_SIZE as u64))
                    .ok_or_else(|| {
                        TensogramError::Remote("postamble offset overflow".to_string())
                    })?;
                // `first_footer_offset` comes from the remote file. An empty
                // or inverted range must not reach the batched fetch, where it
                // could fail obscurely or be sliced incorrectly once coalesced
                // with neighbouring ranges.
                if footer_start >= pa_offset {
                    return Err(TensogramError::Remote(format!(
                        "corrupt postamble for message {msg_idx}: first_footer_offset ({}) \
                         leaves no footer region before the postamble",
                        postamble.first_footer_offset
                    )));
                }
                footer_fetches.push((msg_idx, footer_start..pa_offset));
            }
        }
    }

    // Phase 4: fetch and parse the footer regions.
    if !footer_fetches.is_empty() {
        let footer_ranges: Vec<Range<u64>> =
            footer_fetches.iter().map(|(_, r)| r.clone()).collect();
        let footer_bytes = self
            .store
            .get_ranges(&self.path, &footer_ranges)
            .await
            .map_err(|e| TensogramError::Remote(e.to_string()))?;
        let mut state = self.lock_state()?;
        for (bytes, &(msg_idx, _)) in footer_bytes.iter().zip(footer_fetches.iter()) {
            if state.layouts[msg_idx].global_metadata.is_some()
                && state.layouts[msg_idx].index.is_some()
            {
                continue;
            }
            Self::parse_footer_frames(&mut state, msg_idx, bytes)?;
        }
    }
    Ok(())
}
/// Decodes object `obj_idx` from every message in `msg_indices`, fetching all
/// object frames with a single `get_ranges` call.
///
/// Results are returned in the order of `msg_indices`.
///
/// # Errors
/// Returns `TensogramError::Remote` when a message has no index frame, when
/// its metadata is not cached, or when the object store fails. Propagates
/// index validation and range errors, and tags decode errors with `obj_idx`.
pub(crate) async fn read_object_batch_async(
    &self,
    msg_indices: &[usize],
    obj_idx: usize,
    options: &DecodeOptions,
) -> Result<Vec<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)>> {
    self.ensure_all_layouts_batch_async(msg_indices).await?;

    // Gather frame ranges and metadata under one lock, released before I/O.
    let (frame_ranges, metas) = {
        let state = self.lock_state()?;
        let mut frame_ranges = Vec::with_capacity(msg_indices.len());
        let mut metas = Vec::with_capacity(msg_indices.len());
        for &msg_idx in msg_indices {
            let layout = &state.layouts[msg_idx];
            let Some(index) = layout.index.as_ref() else {
                return Err(TensogramError::Remote(format!(
                    "message {} has no index frame; batch decode requires indexed messages",
                    msg_idx
                )));
            };
            Self::validate_index_access(index, obj_idx)?;
            let frame_range = Self::checked_frame_range(
                layout.offset,
                layout.length,
                index.offsets[obj_idx],
                index.lengths[obj_idx],
            )?;
            frame_ranges.push(frame_range);
            let meta = layout
                .global_metadata
                .clone()
                .ok_or_else(|| TensogramError::Remote("metadata not cached".to_string()))?;
            metas.push(meta);
        }
        (frame_ranges, metas)
    };

    let fetched = self
        .store
        .get_ranges(&self.path, &frame_ranges)
        .await
        .map_err(|e| TensogramError::Remote(e.to_string()))?;

    // Decoding stops at the first frame that fails.
    fetched
        .iter()
        .zip(metas)
        .map(
            |(frame_bytes, meta)| -> Result<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)> {
                let (desc, decoded) =
                    crate::decode::decode_object_from_frame(frame_bytes, options)
                        .map_err(|e| crate::error::with_object_index(e, obj_idx))?;
                Ok((meta, desc, decoded))
            },
        )
        .collect()
}
/// Decodes the element `ranges` of object `obj_idx` from every message in
/// `msg_indices`, fetching all object frames with a single `get_ranges` call.
///
/// Results are returned in the order of `msg_indices`.
///
/// # Errors
/// Returns `TensogramError::Remote` when a message has no index frame or the
/// object store fails. Propagates index validation, range computation, frame
/// decode and range decode errors.
pub(crate) async fn read_range_batch_async(
    &self,
    msg_indices: &[usize],
    obj_idx: usize,
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<Vec<(DataObjectDescriptor, Vec<Vec<u8>>)>> {
    self.ensure_all_layouts_batch_async(msg_indices).await?;

    // Resolve every frame range under one lock, released before I/O.
    let frame_ranges = {
        let state = self.lock_state()?;
        let mut collected = Vec::with_capacity(msg_indices.len());
        for &msg_idx in msg_indices {
            let layout = &state.layouts[msg_idx];
            let Some(index) = layout.index.as_ref() else {
                return Err(TensogramError::Remote(format!(
                    "message {} has no index frame; batch range decode requires indexed messages",
                    msg_idx
                )));
            };
            Self::validate_index_access(index, obj_idx)?;
            collected.push(Self::checked_frame_range(
                layout.offset,
                layout.length,
                index.offsets[obj_idx],
                index.lengths[obj_idx],
            )?);
        }
        collected
    };

    let fetched = self
        .store
        .get_ranges(&self.path, &frame_ranges)
        .await
        .map_err(|e| TensogramError::Remote(e.to_string()))?;

    // Decoding stops at the first frame that fails.
    fetched
        .iter()
        .map(
            |frame_bytes| -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
                let (desc, payload, _mask_region, _consumed) =
                    framing::decode_data_object_frame(frame_bytes)?;
                let parts =
                    crate::decode::decode_range_from_payload(&desc, payload, ranges, options)?;
                Ok((desc, parts))
            },
        )
        .collect()
}
/// Decodes the element `ranges` of object `obj_idx` in message `msg_idx`.
///
/// With an index available, only the object's frame is fetched and the ranges
/// are decoded from its payload. Otherwise the whole message is downloaded
/// and handed to `decode_range`.
///
/// # Errors
/// Propagates index validation, range computation, fetch and decode errors.
pub(crate) async fn read_range_async(
    &self,
    msg_idx: usize,
    obj_idx: usize,
    ranges: &[(u64, u64)],
    options: &DecodeOptions,
) -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
    self.ensure_layout_eager_async(msg_idx).await?;
    // Snapshot the layout so no lock is held across the fetches below.
    let layout = self.lock_state()?.layouts[msg_idx].clone();

    let Some(index) = layout.index.as_ref() else {
        let whole_message = self.read_message_async(msg_idx).await?;
        return crate::decode::decode_range(&whole_message, obj_idx, ranges, options);
    };

    Self::validate_index_access(index, obj_idx)?;
    let frame_range = Self::checked_frame_range(
        layout.offset,
        layout.length,
        index.offsets[obj_idx],
        index.lengths[obj_idx],
    )?;
    let frame_bytes = self.get_range_async(frame_range).await?;
    let (desc, payload, _mask_region, _consumed) =
        framing::decode_data_object_frame(&frame_bytes)?;
    let parts = crate::decode::decode_range_from_payload(&desc, payload, ranges, options)?;
    Ok((desc, parts))
}
}
// Unit tests for the scan state bookkeeping (`RemoteState`) and for the pure
// parsing helpers of the backward walker. None of them performs network I/O.
#[cfg(test)]
mod tests {
use super::*;
/// Builds a `CachedLayout` at `offset` with `length` bytes, a version-3
/// preamble with default flags, and neither index nor metadata cached.
fn dummy_layout(offset: u64, length: u64) -> CachedLayout {
CachedLayout {
offset,
length,
preamble: Preamble {
version: 3,
flags: MessageFlags::default(),
reserved: 0,
total_length: length,
},
index: None,
global_metadata: None,
}
}
/// `RemoteScanOptions::default()` must have `bidirectional` set.
#[test]
fn scan_options_default_is_bidirectional() {
let opts = RemoteScanOptions::default();
assert!(
opts.bidirectional,
"default enables the pipelined bidirectional walker",
);
}
/// With every other field left at its default, `scan_complete()` must equal
/// `gap_closed || fwd_terminated` for all four flag combinations.
#[test]
fn scan_complete_truth_table_forward_only_equivalence() {
for fwd_terminated in [false, true] {
for gap_closed in [false, true] {
let s = RemoteState {
fwd_terminated,
gap_closed,
..RemoteState::default()
};
let expected = gap_closed || fwd_terminated;
assert_eq!(
s.scan_complete(),
expected,
"fwd_terminated={fwd_terminated} gap_closed={gap_closed}",
);
}
}
}
/// Recording a forward hop over a layout at 0..100 must move
/// `next_scan_offset` to 100, store the layout and change `scan_epoch`.
#[test]
fn record_forward_hop_advances_cursor_and_bumps_epoch() {
let mut s = RemoteState {
prev_scan_offset: 1000,
..RemoteState::default()
};
let epoch_before = s.scan_epoch;
s.record_forward_hop(dummy_layout(0, 100));
assert_eq!(s.next_scan_offset, 100);
assert_eq!(s.layouts.len(), 1);
assert_ne!(s.scan_epoch, epoch_before, "epoch must bump on forward hop");
}
/// `disable_backward` on an active backward walker must clear `bwd_active`
/// and drop everything held in `suffix_rev`.
#[test]
fn disable_backward_clears_suffix_and_disables() {
let mut s = RemoteState {
prev_scan_offset: 1000,
bwd_active: true,
..RemoteState::default()
};
s.suffix_rev.push(dummy_layout(800, 200));
s.disable_backward("test");
assert!(!s.bwd_active);
assert!(s.suffix_rev.is_empty());
}
/// Calling `disable_backward` on a default state must leave `scan_epoch`
/// untouched.
#[test]
fn disable_backward_is_idempotent_when_already_inactive() {
let mut s = RemoteState::default();
let epoch_before = s.scan_epoch;
s.disable_backward("test");
assert_eq!(
s.scan_epoch, epoch_before,
"no-op disable on already-inactive state must not bump epoch",
);
}
/// `terminate_forward` must set `fwd_terminated`, empty `suffix_rev` and
/// clear `bwd_active`.
#[test]
fn terminate_forward_clears_provisional_backward_state() {
let mut s = RemoteState {
prev_scan_offset: 1000,
bwd_active: true,
..RemoteState::default()
};
s.suffix_rev.push(dummy_layout(800, 200));
s.terminate_forward("test");
assert!(s.fwd_terminated);
assert!(
s.suffix_rev.is_empty(),
"fwd_terminated => suffix_rev empty (bidirectional is never recovery)",
);
assert!(!s.bwd_active);
}
/// Snapshot with `next` and `epoch` at 0 and the given `prev` cursor.
fn snap_at(prev: u64) -> ScanSnapshot {
ScanSnapshot {
next: 0,
prev,
epoch: 0,
}
}
/// A buffer one byte shorter than `POSTAMBLE_SIZE` must be reported as
/// `Format("short-fetch-bwd")`.
#[test]
fn parse_backward_postamble_short_fetch_yields_format() {
let buf = vec![0u8; POSTAMBLE_SIZE - 1];
match parse_backward_postamble(&buf, &snap_at(POSTAMBLE_SIZE as u64)) {
BackwardOutcome::Format("short-fetch-bwd") => {}
other => panic!("expected short-fetch-bwd, got {other:?}"),
}
}
/// A full-size postamble whose trailing bytes are not `END_MAGIC` must be
/// reported as `Format("bad-end-magic-bwd")`.
#[test]
fn parse_backward_postamble_bad_end_magic_yields_format() {
let mut buf = vec![0u8; POSTAMBLE_SIZE];
buf[POSTAMBLE_SIZE - crate::wire::END_MAGIC.len()..].copy_from_slice(b"NOTMAGIC");
match parse_backward_postamble(&buf, &snap_at(POSTAMBLE_SIZE as u64)) {
BackwardOutcome::Format("bad-end-magic-bwd") => {}
other => panic!("expected bad-end-magic-bwd, got {other:?}"),
}
}
/// With `u64::MAX` stored big-endian at bytes 8..16 and a backward cursor of
/// 100, the parser must report `Format("backward-arith-underflow")`.
#[test]
fn parse_backward_postamble_arith_underflow_yields_format() {
let mut buf = vec![0u8; POSTAMBLE_SIZE];
buf[8..16].copy_from_slice(&u64::MAX.to_be_bytes());
buf[POSTAMBLE_SIZE - crate::wire::END_MAGIC.len()..]
.copy_from_slice(crate::wire::END_MAGIC);
match parse_backward_postamble(&buf, &snap_at(100)) {
BackwardOutcome::Format("backward-arith-underflow") => {}
other => panic!("expected backward-arith-underflow, got {other:?}"),
}
}
/// A value of 200 at bytes 8..16 with `prev` = 250 and `next` = 100 must be
/// reported as `Format("backward-overlaps-forward")`.
#[test]
fn parse_backward_postamble_overlap_with_forward_yields_format() {
let mut buf = vec![0u8; POSTAMBLE_SIZE];
let total: u64 = 200;
buf[8..16].copy_from_slice(&total.to_be_bytes());
buf[POSTAMBLE_SIZE - crate::wire::END_MAGIC.len()..]
.copy_from_slice(crate::wire::END_MAGIC);
let snap = ScanSnapshot {
next: 100,
prev: 250,
epoch: 0,
};
match parse_backward_postamble(&buf, &snap) {
BackwardOutcome::Format("backward-overlaps-forward") => {}
other => panic!("expected backward-overlaps-forward, got {other:?}"),
}
}
/// A well-formed postamble (96 at bytes 0..8, 200 at bytes 8..16, valid end
/// magic) with `prev` = 200 must yield `NeedPreambleValidation` with
/// `msg_start` 0, `length` 200 and `first_footer_offset` 96.
#[test]
fn parse_backward_postamble_propagates_first_footer_offset() {
let mut buf = vec![0u8; POSTAMBLE_SIZE];
let footer_offset: u64 = 96;
let total: u64 = 200;
buf[..8].copy_from_slice(&footer_offset.to_be_bytes());
buf[8..16].copy_from_slice(&total.to_be_bytes());
buf[POSTAMBLE_SIZE - crate::wire::END_MAGIC.len()..]
.copy_from_slice(crate::wire::END_MAGIC);
match parse_backward_postamble(&buf, &snap_at(200)) {
BackwardOutcome::NeedPreambleValidation {
msg_start,
length,
first_footer_offset,
} => {
assert_eq!(msg_start, 0);
assert_eq!(length, 200);
assert_eq!(first_footer_offset, footer_offset);
}
other => panic!("expected NeedPreambleValidation, got {other:?}"),
}
}
/// A buffer one byte shorter than `PREAMBLE_SIZE` must be reported as
/// `Format("short-fetch-bwd")`.
#[test]
fn validate_backward_preamble_short_fetch_yields_format() {
let buf = vec![0u8; PREAMBLE_SIZE - 1];
match validate_backward_preamble(&buf, 0, 100) {
BackwardCommit::Format("short-fetch-bwd") => {}
other => panic!("expected short-fetch-bwd, got {other:?}"),
}
}
/// An all-zero preamble-sized buffer must be reported as
/// `Format("bad-magic-bwd")`.
#[test]
fn validate_backward_preamble_bad_magic_yields_format() {
let buf = vec![0u8; PREAMBLE_SIZE];
match validate_backward_preamble(&buf, 0, 100) {
BackwardCommit::Format("bad-magic-bwd") => {}
other => panic!("expected bad-magic-bwd, got {other:?}"),
}
}
/// A serialized preamble whose `total_length` is 0 must be reported as
/// `Format("streaming-preamble-non-tail")`.
#[test]
fn validate_backward_preamble_streaming_at_non_tail_yields_format() {
let preamble = Preamble {
version: crate::wire::WIRE_VERSION,
flags: MessageFlags::default(),
reserved: 0,
total_length: 0,
};
let mut buf = Vec::with_capacity(PREAMBLE_SIZE);
preamble.write_to(&mut buf);
match validate_backward_preamble(&buf, 0, 100) {
BackwardCommit::Format("streaming-preamble-non-tail") => {}
other => panic!("expected streaming-preamble-non-tail, got {other:?}"),
}
}
}
#[cfg(all(test, feature = "async"))]
mod bidir_http_tests {
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use http_body_util::Full;
use hyper::body::Bytes as HyperBytes;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use super::*;
use crate::Dtype;
use crate::encode::{self, EncodeOptions};
use crate::types::{ByteOrder, DataObjectDescriptor, GlobalMetadata};
/// Handle to an in-process HTTP/1 server that serves one in-memory byte
/// buffer and counts the requests it receives.
struct MockObjectStore {
// Incremented by `handle` for every request, whatever its method.
request_count: Arc<AtomicUsize>,
// Incremented by `handle` for every non-HEAD request carrying a `Range` header.
range_request_count: Arc<AtomicUsize>,
// Local address the listener is bound to; `url()` uses its port.
addr: SocketAddr,
}
impl MockObjectStore {
    /// Binds a listener on an ephemeral loopback port and spawns a task that
    /// serves `data` over HTTP/1, one spawned task per accepted connection.
    ///
    /// # Errors
    /// Returns the I/O error from binding or from querying the local address.
    async fn start(data: Vec<u8>) -> std::io::Result<Self> {
        let payload = Arc::new(data);
        let request_count = Arc::new(AtomicUsize::new(0));
        let range_request_count = Arc::new(AtomicUsize::new(0));
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let total_hits = Arc::clone(&request_count);
        let range_hits = Arc::clone(&range_request_count);
        tokio::spawn(async move {
            // The accept loop ends on the first accept error.
            while let Ok((stream, _peer)) = listener.accept().await {
                let io = TokioIo::new(stream);
                let payload = Arc::clone(&payload);
                let total_hits = Arc::clone(&total_hits);
                let range_hits = Arc::clone(&range_hits);
                tokio::spawn(async move {
                    let service = service_fn(move |req: Request<hyper::body::Incoming>| {
                        let payload = Arc::clone(&payload);
                        let total_hits = Arc::clone(&total_hits);
                        let range_hits = Arc::clone(&range_hits);
                        async move { handle(req, payload, total_hits, range_hits) }
                    });
                    // Connection-level errors are ignored on purpose.
                    let _ = http1::Builder::new().serve_connection(io, service).await;
                });
            }
        });

        Ok(MockObjectStore {
            request_count,
            range_request_count,
            addr,
        })
    }

    /// URL of the served object on this server.
    fn url(&self) -> String {
        let port = self.addr.port();
        format!("http://127.0.0.1:{}/test.tgm", port)
    }

    /// Total number of requests handled so far.
    fn request_count(&self) -> usize {
        self.request_count.load(Ordering::SeqCst)
    }

    /// Number of handled requests that carried a `Range` header.
    fn range_request_count(&self) -> usize {
        self.range_request_count.load(Ordering::SeqCst)
    }
}
/// Answers one request against the in-memory object `data`.
///
/// * HEAD: 200 with `Content-Length` and `Accept-Ranges: bytes`, empty body.
/// * Request with a `Range` header: 206 with the requested slice when
///   `parse_range` accepts the header, otherwise 416 with `bytes */<len>`.
/// * Anything else: 200 with the whole object.
///
/// `request_count` is bumped for every request. `range_request_count` is
/// bumped for every non-HEAD request that carries a `Range` header.
fn handle(
    req: Request<hyper::body::Incoming>,
    data: Arc<Vec<u8>>,
    request_count: Arc<AtomicUsize>,
    range_request_count: Arc<AtomicUsize>,
) -> std::io::Result<Response<Full<HyperBytes>>> {
    request_count.fetch_add(1, Ordering::SeqCst);
    let total = data.len();

    if req.method() == hyper::Method::HEAD {
        return Response::builder()
            .status(StatusCode::OK)
            .header("Content-Length", total)
            .header("Accept-Ranges", "bytes")
            .body(Full::new(HyperBytes::new()))
            .map_err(std::io::Error::other);
    }

    let Some(range_header) = req.headers().get("Range") else {
        // Plain GET: serve the complete object.
        return Response::builder()
            .status(StatusCode::OK)
            .header("Content-Length", total)
            .body(Full::new(HyperBytes::copy_from_slice(&data)))
            .map_err(std::io::Error::other);
    };

    range_request_count.fetch_add(1, Ordering::SeqCst);
    let requested = range_header.to_str().unwrap_or("");
    let response = match parse_range(requested, total) {
        Some((start, end_exclusive)) => {
            let slice = &data[start..end_exclusive];
            Response::builder()
                .status(StatusCode::PARTIAL_CONTENT)
                .header(
                    "Content-Range",
                    format!("bytes {}-{}/{}", start, end_exclusive - 1, total),
                )
                .header("Content-Length", slice.len())
                .body(Full::new(HyperBytes::copy_from_slice(slice)))
        }
        None => Response::builder()
            .status(StatusCode::RANGE_NOT_SATISFIABLE)
            .header("Content-Range", format!("bytes */{}", total))
            .body(Full::new(HyperBytes::new())),
    };
    response.map_err(std::io::Error::other)
}
/// Parses a single-range `Range` header value against an object of `total`
/// bytes and returns the half-open byte interval `(start, end_exclusive)`.
///
/// Accepted forms are `bytes=a-b`, `bytes=a-` and the suffix form `bytes=-n`.
/// `None` is returned when the `bytes=` prefix is missing, `total` is 0, a
/// number fails to parse, the suffix length is 0, `a >= total`, or `b < a`.
/// An end past the object is clamped to the last byte, and a suffix longer
/// than the object selects the whole object.
fn parse_range(header: &str, total: usize) -> Option<(usize, usize)> {
    let spec = header.strip_prefix("bytes=")?;
    if total == 0 {
        return None;
    }

    // Suffix form: the last `n` bytes of the object.
    if let Some(suffix_len) = spec.strip_prefix('-') {
        let n = suffix_len.parse::<usize>().ok()?;
        return (n != 0).then(|| (total.saturating_sub(n), total));
    }

    let (first, last) = spec.split_once('-')?;
    let start = first.parse::<usize>().ok()?;
    if start >= total {
        return None;
    }
    // Open-ended form: from `start` to the end of the object.
    if last.is_empty() {
        return Some((start, total));
    }
    // The header's end is inclusive; convert to exclusive after clamping.
    let end_inclusive = last.parse::<usize>().ok()?;
    (end_inclusive >= start).then(|| (start, end_inclusive.min(total - 1) + 1))
}
/// Encodes a single-object message holding a `Float32` tensor of `shape`
/// whose payload bytes are all `fill`, using default encode options.
///
/// Strides are derived from `shape` with the last axis having stride 1.
/// Panics if encoding fails.
fn make_message(shape: Vec<u64>, fill: u8) -> Vec<u8> {
    // Each stride is the product of the extents of all later axes.
    let mut strides = vec![1u64; shape.len()];
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }

    // Four bytes per Float32 element.
    let payload_len = shape.iter().product::<u64>() as usize * 4;
    let payload = vec![fill; payload_len];

    let descriptor = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: shape.len() as u64,
        shape: shape.clone(),
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        extra: BTreeMap::new(),
        ..Default::default()
    };
    encode::encode(
        &global,
        &[(&descriptor, &payload)],
        &EncodeOptions::default(),
    )
    .expect("encode test message")
}
/// Joins the encoded messages in `parts` back to back, in order, into one
/// contiguous buffer.
fn concat_messages(parts: Vec<Vec<u8>>) -> Vec<u8> {
    parts.concat()
}
/// Storage-options map with no entries, passed when opening the mock server.
fn empty_storage() -> BTreeMap<String, String> {
    BTreeMap::default()
}
fn forward_layouts(url: &str) -> Vec<(u64, u64)> {
let backend =
RemoteBackend::open_with_scan_opts(url, &empty_storage(), RemoteScanOptions::default())
.expect("open forward-only");
backend.message_count().expect("count");
let state = backend.state.lock().expect("lock");
state.layouts.iter().map(|l| (l.offset, l.length)).collect()
}
/// Opens `url` with the bidirectional walker enabled, scans the whole file
/// and returns the `(offset, length)` of every discovered layout.
fn bidir_layouts(url: &str) -> Vec<(u64, u64)> {
    let scan_opts = RemoteScanOptions {
        bidirectional: true,
    };
    let backend = RemoteBackend::open_with_scan_opts(url, &empty_storage(), scan_opts)
        .expect("open bidirectional");
    // Force a full scan so `layouts` is complete before it is read.
    backend.message_count().expect("count");
    let guard = backend.state.lock().expect("lock");
    let mut spans = Vec::with_capacity(guard.layouts.len());
    for layout in guard.layouts.iter() {
        spans.push((layout.offset, layout.length));
    }
    spans
}
/// A single-message file scanned bidirectionally must yield exactly one
/// layout.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_one_message_no_duplicate() {
    let server = MockObjectStore::start(make_message(vec![4], 42))
        .await
        .expect("start");
    let url = server.url();
    let layouts = bidir_layouts(&url);
    assert_eq!(
        layouts.len(),
        1,
        "1-message file must produce exactly one layout, got {layouts:?}",
    );
}
/// For a two-message file, `bidir_layouts` must report the same layouts as
/// `forward_layouts`, and there must be two of them.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_two_messages_match_forward_only() {
    let file = concat_messages(vec![make_message(vec![4], 10), make_message(vec![8], 20)]);
    let server = MockObjectStore::start(file).await.expect("start");
    let url = server.url();
    let forward = forward_layouts(&url);
    let both_ways = bidir_layouts(&url);
    assert_eq!(
        forward, both_ways,
        "2-message file: bidirectional must match forward-only"
    );
    assert_eq!(both_ways.len(), 2);
}
/// For a three-message file (odd count), `bidir_layouts` must report the same
/// layouts as `forward_layouts`, and there must be three of them.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_three_messages_odd_count_meet() {
    let messages = vec![
        make_message(vec![4], 10),
        make_message(vec![8], 20),
        make_message(vec![16], 30),
    ];
    let server = MockObjectStore::start(concat_messages(messages))
        .await
        .expect("start");
    let url = server.url();
    let forward = forward_layouts(&url);
    let both_ways = bidir_layouts(&url);
    assert_eq!(
        forward, both_ways,
        "3-message file: bidirectional must match forward-only"
    );
    assert_eq!(both_ways.len(), 3);
}
/// For a ten-message file with messages of different sizes, `bidir_layouts`
/// must report the same layouts as `forward_layouts`, and there must be ten.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_ten_messages_match_forward_only() {
    let mut messages: Vec<Vec<u8>> = Vec::with_capacity(10);
    for i in 0u8..10 {
        messages.push(make_message(vec![4 + u64::from(i)], i));
    }
    let server = MockObjectStore::start(concat_messages(messages))
        .await
        .expect("start");
    let url = server.url();
    let forward = forward_layouts(&url);
    let both_ways = bidir_layouts(&url);
    assert_eq!(
        forward, both_ways,
        "10-message file: bidirectional must match forward-only"
    );
    assert_eq!(both_ways.len(), 10);
}
/// A single message whose length fields are zeroed in both the postamble and
/// the preamble must still produce exactly one layout, and the server must
/// have been contacted.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_streaming_postamble_yields_backward() {
    let mut file = make_message(vec![4], 99);

    // Zero the 8-byte field at offset 8 inside the postamble.
    let postamble_total = file.len() - POSTAMBLE_SIZE + 8;
    file[postamble_total..postamble_total + 8].fill(0);
    // Zero the 8-byte field at offset 16 of the preamble.
    let preamble_total = 16;
    file[preamble_total..preamble_total + 8].fill(0);

    let server = MockObjectStore::start(file).await.expect("start");
    let url = server.url();
    let bidir = bidir_layouts(&url);
    assert_eq!(
        bidir.len(),
        1,
        "streaming-tail must still produce a single forward-discovered layout, got {bidir:?}",
    );
    assert!(server.request_count() > 0);
}
/// Overwriting the final 8 bytes of a two-message file must not stop the
/// scan from finding both messages, and range requests must have been issued.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_corrupt_end_magic_yields_backward() {
    let mut file = concat_messages(vec![make_message(vec![4], 10), make_message(vec![8], 20)]);
    let tail_start = file.len() - 8;
    file[tail_start..].copy_from_slice(b"BADMAGIC");

    let server = MockObjectStore::start(file).await.expect("start");
    let scan_opts = RemoteScanOptions {
        bidirectional: true,
    };
    let backend = RemoteBackend::open_with_scan_opts(&server.url(), &empty_storage(), scan_opts)
        .expect("open");
    let count = backend.message_count().expect("count");
    assert_eq!(
        count, 2,
        "corrupt END_MAGIC: forward path still finds both messages",
    );
    assert!(server.range_request_count() > 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn bidir_concurrent_readers_consistent() {
let parts: Vec<Vec<u8>> = (0..10)
.map(|i| make_message(vec![4 + i as u64], i as u8))
.collect();
let buf = concat_messages(parts);
let server = MockObjectStore::start(buf).await.expect("start");
let backend = Arc::new(
RemoteBackend::open_with_scan_opts(
&server.url(),
&empty_storage(),
RemoteScanOptions {
bidirectional: true,
},
)
.expect("open"),
);
let mut handles = Vec::new();
for idx in 0..10 {
let backend = backend.clone();
handles.push(tokio::spawn(async move {
backend.read_message_async(idx).await
}));
}
for h in handles {
h.await.expect("join").expect("read_message_async");
}
let final_count = backend.message_count_async().await.expect("count");
assert_eq!(
final_count, 10,
"concurrent readers must converge to 10 layouts"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn bidir_message_count_race_with_partial_target() {
let parts: Vec<Vec<u8>> = (0..20)
.map(|i| make_message(vec![4 + i as u64], i as u8))
.collect();
let buf = concat_messages(parts);
let server = MockObjectStore::start(buf).await.expect("start");
let backend = Arc::new(
RemoteBackend::open_with_scan_opts(
&server.url(),
&empty_storage(),
RemoteScanOptions {
bidirectional: true,
},
)
.expect("open"),
);
let mut handles = Vec::new();
for _ in 0..8 {
let backend = backend.clone();
handles.push(tokio::spawn(
async move { backend.read_metadata_async(5).await },
));
}
let count_handle = {
let backend = backend.clone();
tokio::spawn(async move { backend.message_count_async().await })
};
for h in handles {
h.await.expect("join").expect("read_metadata_async");
}
let count = count_handle.await.expect("join").expect("count");
assert_eq!(
count, 20,
"message_count_async must return the FULL count even when racing target-driven readers; \
a partial result would mean the pipelined scanner returned Ok(true) on snapshot mismatch"
);
}
/// On a single-message file, `scan_pipelined_async(None)` must return `true`
/// and leave exactly one layout in the state.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pipelined_one_message_completes_without_fallback() {
    let server = MockObjectStore::start(make_message(vec![4], 42))
        .await
        .expect("start");
    let scan_opts = RemoteScanOptions {
        bidirectional: true,
    };
    let backend = RemoteBackend::open_with_scan_opts(&server.url(), &empty_storage(), scan_opts)
        .expect("open bidirectional");

    let completed = backend.scan_pipelined_async(None).await.expect("pipelined");
    assert!(
        completed,
        "1-message file: pipelined walker must complete without falling back \
         (same-message-meet must not be misread as backward-overlaps-forward)",
    );
    let guard = backend.lock_state().expect("lock");
    assert_eq!(guard.layouts.len(), 1);
}
/// On a three-message file, `scan_pipelined_async(None)` must return `true`
/// and leave exactly three layouts in the state.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pipelined_three_messages_odd_count_meet_completes_without_fallback() {
    let messages = vec![
        make_message(vec![4], 10),
        make_message(vec![8], 20),
        make_message(vec![16], 30),
    ];
    let server = MockObjectStore::start(concat_messages(messages))
        .await
        .expect("start");
    let scan_opts = RemoteScanOptions {
        bidirectional: true,
    };
    let backend = RemoteBackend::open_with_scan_opts(&server.url(), &empty_storage(), scan_opts)
        .expect("open bidirectional");

    let completed = backend.scan_pipelined_async(None).await.expect("pipelined");
    assert!(
        completed,
        "3-message file: pipelined walker must complete without falling back \
         on the odd-count meet",
    );
    let guard = backend.lock_state().expect("lock");
    assert_eq!(guard.layouts.len(), 3);
}
/// Crafts a two-message file whose final postamble points at a fake preamble
/// planted in the middle of the second message, then checks that
/// `scan_pipelined_async(Some(2))` returns `false`, that the message count is
/// still 2, and that `bwd_active` ends up cleared.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pipelined_overlapping_backward_candidate_bails_to_fallback() {
let msg_a = make_message(vec![10], 7);
let msg_a_len = msg_a.len();
let msg_b = make_message(vec![10], 8);
let msg_b_len = msg_b.len();
// `msg_b` is cloned because its preamble bytes are reused below.
let mut buf = concat_messages(vec![msg_a, msg_b.clone()]);
let file_len = buf.len() as u64;
// The fake message starts halfway through msg_b and runs to the end of the file.
let fake_offset = (msg_a_len as u64) + (msg_b_len as u64) / 2;
let claimed_total = file_len - fake_offset;
let fake = fake_offset as usize;
// Plant a copy of msg_b's preamble at the fake start ...
buf[fake..fake + PREAMBLE_SIZE].copy_from_slice(&msg_b[..PREAMBLE_SIZE]);
// ... and overwrite the 8 bytes at offset 16 of that copy (the offset sibling
// tests name `preamble_total_off`) with the claimed length.
buf[fake + 16..fake + 16 + 8].copy_from_slice(&claimed_total.to_be_bytes());
// Write the same claimed length into the 8 bytes at offset 8 of the file's
// final postamble.
let pa_total_off = buf.len() - POSTAMBLE_SIZE + 8;
buf[pa_total_off..pa_total_off + 8].copy_from_slice(&claimed_total.to_be_bytes());
let server = MockObjectStore::start(buf).await.expect("start");
let backend = RemoteBackend::open_with_scan_opts(
&server.url(),
&empty_storage(),
RemoteScanOptions {
bidirectional: true,
},
)
.expect("open bidirectional");
let pipelined_ok = backend
.scan_pipelined_async(Some(2))
.await
.expect("pipelined");
assert!(
!pipelined_ok,
"pipelined walker MUST bail (Ok(false)) on a backward candidate that \
overlaps the just-committed forward message; otherwise the post-loop \
drain on a target-driven walk would commit overlapping layouts",
);
let count = backend.message_count_async().await.expect("count");
assert_eq!(
count, 2,
"after fallback only msg_a and msg_b are committed"
);
let state = backend.lock_state().expect("lock");
assert!(
!state.bwd_active,
"backward must be disabled after the per-round walker detects the overlap",
);
}
/// Writing the value 1 into the 8-byte field at offset 8 of the postamble
/// must not stop the scan from finding the single message.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_postamble_length_below_minimum_yields_backward() {
    let mut file = make_message(vec![4], 7);
    let postamble_total = file.len() - POSTAMBLE_SIZE + 8;
    file[postamble_total..postamble_total + 8].copy_from_slice(&1u64.to_be_bytes());

    let server = MockObjectStore::start(file).await.expect("start");
    let scan_opts = RemoteScanOptions {
        bidirectional: true,
    };
    let backend = RemoteBackend::open_with_scan_opts(&server.url(), &empty_storage(), scan_opts)
        .expect("open");
    let count = backend.message_count().expect("count");
    assert_eq!(
        count, 1,
        "below-min postamble: forward still finds the message"
    );
}
/// The last postamble of a two-message file is rewritten to claim a length
/// 8 bytes shorter than the second message really is. Both messages must
/// still be found.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_preamble_postamble_length_mismatch_yields_backward() {
    let first = make_message(vec![4], 11);
    let second = make_message(vec![8], 22);
    // Captured before `second` is moved into the concatenation.
    let understated_len = second.len() as u64 - 8;

    let mut file = concat_messages(vec![first, second]);
    let postamble_total = file.len() - POSTAMBLE_SIZE + 8;
    file[postamble_total..postamble_total + 8].copy_from_slice(&understated_len.to_be_bytes());

    let server = MockObjectStore::start(file).await.expect("start");
    let scan_opts = RemoteScanOptions {
        bidirectional: true,
    };
    let backend = RemoteBackend::open_with_scan_opts(&server.url(), &empty_storage(), scan_opts)
        .expect("open");
    let count = backend.message_count().expect("count");
    assert_eq!(
        count, 2,
        "length-mismatch backward yield: forward still finds both messages",
    );
}
/// A message with zeroed length fields sits between two normal messages.
/// The bidirectional scan must report exactly as many messages as a scan with
/// the bidirectional walker disabled.
///
/// The baseline backend sets `bidirectional: false` explicitly. The default
/// options enable the bidirectional walker, so a default-built baseline would
/// compare the walker against itself.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_streaming_in_middle_no_recovery() {
    let normal_a = make_message(vec![4], 1);

    let mut streaming = make_message(vec![8], 2);
    // Zero the 8-byte field at offset 16 of the preamble.
    let preamble_total_off = 16;
    streaming[preamble_total_off..preamble_total_off + 8].fill(0);
    // Zero the 8-byte field at offset 8 inside the postamble.
    let total_off = streaming.len() - POSTAMBLE_SIZE + 8;
    streaming[total_off..total_off + 8].fill(0);

    let normal_b = make_message(vec![16], 3);
    let buf = concat_messages(vec![normal_a, streaming, normal_b]);
    let server = MockObjectStore::start(buf).await.expect("start");

    let fwd_backend = RemoteBackend::open_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions {
            bidirectional: false,
        },
    )
    .expect("open forward-only");
    let fwd_count = fwd_backend.message_count().expect("count");

    let bidir_backend = RemoteBackend::open_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions {
            bidirectional: true,
        },
    )
    .expect("open bidirectional");
    let bidir_count = bidir_backend.message_count().expect("count");

    assert_eq!(
        bidir_count, fwd_count,
        "streaming-in-middle: bidirectional must match forward-only ({fwd_count} layouts)",
    );
}
/// Encodes one float32 `ntensor` object of the given `shape`, every payload
/// byte set to `fill`, through `StreamingEncoder` and returns the bytes left
/// in the cursor after `finish_with_backfill`.
fn make_streaming_message(shape: Vec<u64>, fill: u8) -> Vec<u8> {
    use crate::streaming::StreamingEncoder;

    // Strides in elements: innermost axis is 1, each outer axis is the
    // product of the extents inside it. An empty shape gives empty strides.
    let mut strides = vec![1u64; shape.len()];
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    // Four bytes per float32 element.
    let payload = vec![fill; shape.iter().product::<u64>() as usize * 4];
    let descriptor = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: shape.len() as u64,
        shape,
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let sink = std::io::Cursor::new(Vec::<u8>::new());
    let mut encoder = StreamingEncoder::new(sink, &global, &EncodeOptions::default())
        .expect("streaming encoder");
    encoder
        .write_object(&descriptor, &payload)
        .expect("write object");
    encoder
        .finish_with_backfill()
        .expect("finish_with_backfill")
        .into_inner()
}
/// Reports whether the cached layout at `msg_idx` already holds both its
/// global metadata and its index. Panics if the state lock is poisoned or
/// `msg_idx` is out of range.
fn metadata_index_populated(backend: &RemoteBackend, msg_idx: usize) -> bool {
    let guard = backend.state.lock().expect("lock");
    let entry = &guard.layouts[msg_idx];
    let has_metadata = entry.global_metadata.is_some();
    let has_index = entry.index.is_some();
    has_metadata && has_index
}
/// Six messages produced by `make_streaming_message`, opened bidirectionally:
/// once the count has been taken, at least one cached layout must already
/// carry both metadata and index.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_eager_footer_populates_metadata_index_on_footer_indexed_file() {
    let mut messages: Vec<Vec<u8>> = Vec::with_capacity(6);
    for i in 0..6 {
        messages.push(make_streaming_message(vec![4 + i as u64], i as u8));
    }
    let server = MockObjectStore::start(concat_messages(messages))
        .await
        .expect("start");
    let backend = RemoteBackend::open_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions {
            bidirectional: true,
        },
    )
    .expect("open bidirectional");
    let total = backend.message_count().expect("count");
    assert_eq!(total, 6);
    // No read_* call has been issued, so anything populated here was filled
    // in during open / counting.
    let any_pre_populated = (0..total).any(|idx| metadata_index_populated(&backend, idx));
    assert!(
        any_pre_populated,
        "footer-indexed file: at least one layout must have eager-populated metadata + index",
    );
}
/// Four messages produced by `make_message`, opened bidirectionally: after
/// the count is taken, no cached layout may hold metadata or an index yet.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_header_indexed_skips_eager_footer_apply() {
    let messages = vec![
        make_message(vec![4], 1),
        make_message(vec![8], 2),
        make_message(vec![16], 3),
        make_message(vec![32], 4),
    ];
    let server = MockObjectStore::start(concat_messages(messages))
        .await
        .expect("start");
    let backend = RemoteBackend::open_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions {
            bidirectional: true,
        },
    )
    .expect("open bidirectional");
    assert_eq!(backend.message_count().expect("count"), 4);

    let guard = backend.state.lock().expect("lock");
    for (i, layout) in guard.layouts.iter().enumerate() {
        assert!(
            layout.global_metadata.is_none(),
            "layout[{i}]: header-indexed must remain lazy, but metadata was populated",
        );
        assert!(
            layout.index.is_none(),
            "layout[{i}]: header-indexed must remain lazy, but index was populated",
        );
    }
}
/// Builds three streaming-encoded messages, reads the big-endian u64 stored
/// in the first 8 postamble bytes of the middle one (used here as the offset
/// of its first footer frame), bit-flips the 8 bytes at that offset, and
/// checks that a bidirectional open still reports all three messages.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn bidir_corrupt_footer_falls_back_to_lazy_without_poisoning_layout() {
    let mut parts: Vec<Vec<u8>> = (0..3)
        .map(|i| make_streaming_message(vec![4 + i as u64], i as u8))
        .collect();

    let victim = &mut parts[1];
    let postamble_at = victim.len() - POSTAMBLE_SIZE;
    let offset_slot = <[u8; 8]>::try_from(&victim[postamble_at..postamble_at + 8])
        .expect("postamble first_footer_offset slot");
    let footer_at = u64::from_be_bytes(offset_slot) as usize;
    victim[footer_at..footer_at + 8]
        .iter_mut()
        .for_each(|b| *b ^= 0xFF);

    let server = MockObjectStore::start(concat_messages(parts))
        .await
        .expect("start");
    let backend = RemoteBackend::open_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions {
            bidirectional: true,
        },
    )
    .expect("open bidirectional");
    let discovered = backend.message_count().expect("count");
    assert_eq!(
        discovered, 3,
        "corrupt footer must not abort layout commit: all 3 messages still discovered",
    );
}
/// Walks every synchronous read entry point of `RemoteBackend` against a
/// single-message object (one float32 tensor of shape `[4]`) and checks the
/// counts and byte lengths that come back.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_sync_read_api_roundtrip() {
let buf = make_message(vec![4], 0x42);
let server = MockObjectStore::start(buf).await.expect("start");
let backend = RemoteBackend::open_with_scan_opts(
&server.url(),
&empty_storage(),
RemoteScanOptions::default(),
)
.expect("open");
// Message-level accessors.
assert_eq!(backend.message_count().expect("count"), 1);
assert_eq!(backend.message_layouts().expect("layouts").len(), 1);
assert!(!backend.read_message(0).expect("read_message").is_empty());
let _meta = backend.read_metadata(0).expect("read_metadata");
let (_m, descs) = backend.read_descriptors(0).expect("read_descriptors");
assert_eq!(descs.len(), 1);
assert_eq!(descs[0].dtype, Dtype::Float32);
let opts = crate::DecodeOptions::default();
// Whole-object decode: 4 float32 elements => 16 bytes.
let (_m, desc, decoded) = backend.read_object(0, 0, &opts).expect("read_object");
assert_eq!(desc.shape, vec![4]);
assert_eq!(decoded.len(), 4 * 4);
// Range read; `(0, 4)` is presumably (start, count) in elements and yields
// 16 bytes here.
let (_d, parts) = backend
.read_range(0, 0, &[(0, 4)], &opts)
.expect("read_range");
assert_eq!(parts.len(), 1);
assert_eq!(parts[0].len(), 4 * 4);
// Batch variants over the single message index.
let obj_batch = backend
.read_object_batch(&[0], 0, &opts)
.expect("read_object_batch");
assert_eq!(obj_batch.len(), 1);
assert_eq!(obj_batch[0].2.len(), 4 * 4);
let range_batch = backend
.read_range_batch(&[0], 0, &[(0, 4)], &opts)
.expect("read_range_batch");
assert_eq!(range_batch.len(), 1);
assert_eq!(range_batch[0].1[0].len(), 4 * 4);
}
/// Async counterpart of the synchronous round-trip: opens a single-message
/// object (one float32 tensor of shape `[4]`) and awaits every `*_async` read
/// entry point, checking counts and byte lengths.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_async_read_api_roundtrip() {
let buf = make_message(vec![4], 0x07);
let server = MockObjectStore::start(buf).await.expect("start");
// The backend itself is opened through the synchronous constructor.
let backend = RemoteBackend::open_with_scan_opts(
&server.url(),
&empty_storage(),
RemoteScanOptions::default(),
)
.expect("open");
assert_eq!(backend.message_count_async().await.expect("count"), 1);
assert_eq!(
backend
.message_layouts_async()
.await
.expect("layouts")
.len(),
1
);
let _meta = backend
.read_metadata_async(0)
.await
.expect("read_metadata_async");
let (_m, descs) = backend
.read_descriptors_async(0)
.await
.expect("read_descriptors_async");
assert_eq!(descs.len(), 1);
let opts = crate::DecodeOptions::default();
// Whole-object decode: 4 float32 elements => 16 bytes.
let (_m, _d, decoded) = backend
.read_object_async(0, 0, &opts)
.await
.expect("read_object_async");
assert_eq!(decoded.len(), 4 * 4);
// Range read; `(0, 4)` yields 16 bytes here.
let (_d, parts) = backend
.read_range_async(0, 0, &[(0, 4)], &opts)
.await
.expect("read_range_async");
assert_eq!(parts[0].len(), 4 * 4);
// Batch variants over the single message index.
let obj_batch = backend
.read_object_batch_async(&[0], 0, &opts)
.await
.expect("read_object_batch_async");
assert_eq!(obj_batch.len(), 1);
assert_eq!(obj_batch[0].2.len(), 4 * 4);
let range_batch = backend
.read_range_batch_async(&[0], 0, &[(0, 4)], &opts)
.await
.expect("read_range_batch_async");
assert_eq!(range_batch.len(), 1);
assert_eq!(range_batch[0].1[0].len(), 4 * 4);
}
/// An 8-byte remote object must be refused by the synchronous open.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_open_too_small_file_errors() {
    let tiny = vec![0u8; 8];
    let server = MockObjectStore::start(tiny).await.expect("start");
    let outcome = RemoteBackend::open_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions::default(),
    );
    assert!(outcome.is_err(), "tiny remote file must be rejected");
}
/// Opens a two-message object through the async constructor and checks the
/// async message count; `source_url` is called only to exercise the accessor.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_open_async_roundtrip() {
    let first = make_message(vec![4], 1);
    let second = make_message(vec![8], 2);
    let server = MockObjectStore::start(concat_messages(vec![first, second]))
        .await
        .expect("start");
    let backend = RemoteBackend::open_async_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions::default(),
    )
    .await
    .expect("open_async");
    let total = backend.message_count_async().await.expect("count");
    assert_eq!(total, 2);
    let _ = backend.source_url();
}
/// An 8-byte remote object must be refused by the async open as well.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_open_async_too_small_errors() {
    let tiny = vec![0u8; 8];
    let server = MockObjectStore::start(tiny).await.expect("start");
    let outcome = RemoteBackend::open_async_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions::default(),
    )
    .await;
    assert!(outcome.is_err(), "tiny remote file must be rejected (async)");
}
/// A 128-byte all-zero object is large enough to scan but holds no message;
/// the async open must return an error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn remote_open_async_no_valid_messages_errors() {
    let zeros = vec![0u8; 128];
    let server = MockObjectStore::start(zeros).await.expect("start");
    let outcome = RemoteBackend::open_async_with_scan_opts(
        &server.url(),
        &empty_storage(),
        RemoteScanOptions::default(),
    )
    .await;
    assert!(
        outcome.is_err(),
        "no valid messages must be rejected (async)"
    );
}
}
#[cfg(test)]
mod inmem_tests {
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use object_store::memory::InMemory;
use object_store::path::Path as ObjectPath;
use super::*;
use crate::Dtype;
use crate::encode::{self, EncodeOptions};
use crate::types::{ByteOrder, DataObjectDescriptor, GlobalMetadata};
const TEST_PATH: &str = "file.tgm";
/// Drives `fut` to completion on the crate's shared tokio runtime and returns
/// its output. Panics if the shared runtime could not be built.
fn run<F>(fut: F) -> F::Output
where
    F: std::future::Future,
{
    let runtime = shared_runtime().expect("shared runtime");
    runtime.block_on(fut)
}
fn make_message(shape: Vec<u64>, fill: u8) -> Vec<u8> {
let strides = strides_for(&shape);
let desc = DataObjectDescriptor {
obj_type: "ntensor".to_string(),
ndim: shape.len() as u64,
shape: shape.clone(),
strides,
dtype: Dtype::Float32,
byte_order: ByteOrder::native(),
encoding: "none".to_string(),
filter: "none".to_string(),
compression: "none".to_string(),
params: BTreeMap::new(),
masks: None,
};
let num_bytes = shape.iter().product::<u64>() as usize * 4;
let data = vec![fill; num_bytes];
let meta = GlobalMetadata {
extra: BTreeMap::new(),
..Default::default()
};
encode::encode(&meta, &[(&desc, &data)], &EncodeOptions::default())
.expect("encode test message")
}
fn make_streaming_message(shape: Vec<u64>, fill: u8) -> Vec<u8> {
use crate::streaming::StreamingEncoder;
let strides = strides_for(&shape);
let desc = DataObjectDescriptor {
obj_type: "ntensor".to_string(),
ndim: shape.len() as u64,
shape: shape.clone(),
strides,
dtype: Dtype::Float32,
byte_order: ByteOrder::native(),
encoding: "none".to_string(),
filter: "none".to_string(),
compression: "none".to_string(),
params: BTreeMap::new(),
masks: None,
};
let num_bytes = shape.iter().product::<u64>() as usize * 4;
let data = vec![fill; num_bytes];
let meta = GlobalMetadata {
extra: BTreeMap::new(),
..Default::default()
};
let cursor = std::io::Cursor::new(Vec::<u8>::new());
let mut enc = StreamingEncoder::new(cursor, &meta, &EncodeOptions::default())
.expect("streaming encoder");
enc.write_object(&desc, &data).expect("write object");
let cursor = enc.finish_with_backfill().expect("finish_with_backfill");
cursor.into_inner()
}
/// Computes strides (in elements) for `shape`: the last axis has stride 1 and
/// every other axis has the product of the extents to its right. An empty
/// shape yields an empty vector.
fn strides_for(shape: &[u64]) -> Vec<u64> {
    let mut strides = vec![1u64; shape.len()];
    // Walk from the innermost axis outwards, carrying the running product.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}
/// Builds a `CachedLayout` at `offset` spanning `length` bytes, with a
/// version-3 preamble whose `total_length` equals `length` and with neither
/// index nor global metadata populated.
fn dummy_layout(offset: u64, length: u64) -> CachedLayout {
    let preamble = Preamble {
        version: 3,
        flags: MessageFlags::default(),
        reserved: 0,
        total_length: length,
    };
    CachedLayout {
        offset,
        length,
        preamble,
        index: None,
        global_metadata: None,
    }
}
/// Joins `parts` back to back into one byte buffer, preserving order.
///
/// Uses the standard `slice::concat`, which allocates the output once at its
/// final size instead of growing it part by part.
fn concat(parts: Vec<Vec<u8>>) -> Vec<u8> {
    parts.concat()
}
/// Stores `buf` in a fresh `InMemory` object store under `TEST_PATH` and
/// returns a `RemoteBackend` over it without running any scan.
/// `prev_scan_offset` starts at the object size and `bwd_active` mirrors
/// `scan_opts.bidirectional`; all other state is default.
fn raw_backend(buf: Vec<u8>, scan_opts: RemoteScanOptions) -> RemoteBackend {
    let file_size = buf.len() as u64;
    let path = ObjectPath::from(TEST_PATH);
    let store = Arc::new(InMemory::new());

    // Upload on the shared runtime; the future owns its own handles so the
    // originals remain available for the backend below.
    let upload = {
        let store = store.clone();
        let path = path.clone();
        async move {
            store.put(&path, buf.into()).await.expect("put");
        }
    };
    run(upload);

    let state = Mutex::new(RemoteState {
        prev_scan_offset: file_size,
        bwd_active: scan_opts.bidirectional,
        ..RemoteState::default()
    });
    RemoteBackend {
        source_url: "memory://file.tgm".to_string(),
        store,
        path,
        file_size,
        state,
        scan_opts,
    }
}
/// Test-side open: rejects buffers shorter than `PREAMBLE_SIZE +
/// POSTAMBLE_SIZE`, builds a `raw_backend`, runs one `scan_next_locked`
/// step, and fails if that step produced no layout.
fn open_backend(buf: Vec<u8>, scan_opts: RemoteScanOptions) -> Result<RemoteBackend> {
    let min_len = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
    if (buf.len() as u64) < min_len {
        return Err(TensogramError::Remote("too small".to_string()));
    }

    let backend = raw_backend(buf, scan_opts);
    {
        // Scoped so the guard is released before the backend is returned.
        let mut guard = backend.state.lock().expect("lock");
        backend.scan_next_locked(&mut guard)?;
        if guard.layouts.is_empty() {
            return Err(TensogramError::Remote("no valid messages".to_string()));
        }
    }
    Ok(backend)
}
/// Scan options with `bidirectional` switched off.
fn forward_opts() -> RemoteScanOptions {
    let bidirectional = false;
    RemoteScanOptions { bidirectional }
}
/// Scan options with `bidirectional` switched on.
fn bidir_opts() -> RemoteScanOptions {
    let bidirectional = true;
    RemoteScanOptions { bidirectional }
}
/// Snapshots the `(offset, length)` pair of every cached layout, in the order
/// the backend stores them.
fn layout_offsets(backend: &RemoteBackend) -> Vec<(u64, u64)> {
    let guard = backend.state.lock().expect("lock");
    let mut spans = Vec::with_capacity(guard.layouts.len());
    for layout in guard.layouts.iter() {
        spans.push((layout.offset, layout.length));
    }
    spans
}
/// An 8-byte buffer is below the minimum size and must be refused.
#[test]
fn sync_open_too_small_buffer_rejected() {
    let outcome = open_backend(vec![0u8; 8], forward_opts());
    assert!(outcome.is_err(), "tiny buffer must be rejected");
}
/// One scan step over 128 zero bytes must succeed, record no layout, and set
/// `fwd_terminated`.
#[test]
fn sync_scan_bad_magic_terminates_forward() {
    let backend = raw_backend(vec![0u8; 128], forward_opts());
    let mut guard = backend.state.lock().expect("lock");
    backend.scan_next_locked(&mut guard).expect("scan");
    assert!(guard.layouts.is_empty(), "bad magic => no layouts");
    assert!(guard.fwd_terminated, "bad magic terminates forward");
}
/// One valid message followed by 10 trailing zero bytes: counting must yield
/// exactly one layout and leave the forward scan terminated.
#[test]
fn sync_scan_eof_when_remaining_below_min() {
    // Extend the encoded message in place; it is not needed separately, so
    // no clone is taken.
    let mut buf = make_message(vec![4], 1);
    buf.extend_from_slice(&[0u8; 10]);
    let backend = raw_backend(buf, forward_opts());
    backend.message_count().expect("count");
    let state = backend.state.lock().expect("lock");
    assert_eq!(state.layouts.len(), 1, "tail too small => single message");
    assert!(state.fwd_terminated);
}
/// Overwrites byte 8 of a valid message with 0xFF (per the test name, this
/// makes the preamble unparsable) and checks that one scan step terminates
/// the forward scan without recording a layout.
#[test]
fn sync_scan_preamble_parse_error_terminates() {
    let mut bytes = make_message(vec![4], 1);
    bytes[8] = 0xFF;
    let backend = raw_backend(bytes, forward_opts());
    let mut state = backend.state.lock().expect("lock");
    backend.scan_next_locked(&mut state).expect("scan");
    assert!(state.fwd_terminated, "bad preamble terminates forward");
    assert!(state.layouts.is_empty());
}
/// Writes a value 10 000 larger than the buffer into bytes 16..24
/// (presumably the preamble total-length field) and checks that one scan
/// step terminates the forward scan without recording a layout.
#[test]
fn sync_scan_length_out_of_range_terminates() {
    let mut bytes = make_message(vec![4], 1);
    let oversized = (bytes.len() as u64) + 10_000;
    bytes[16..24].copy_from_slice(&oversized.to_be_bytes());
    let backend = raw_backend(bytes, forward_opts());
    let mut state = backend.state.lock().expect("lock");
    backend.scan_next_locked(&mut state).expect("scan");
    assert!(state.fwd_terminated);
    assert!(state.layouts.is_empty());
}
/// Builds a streaming-encoded message and then zeroes bytes 16..24 and the 8
/// bytes at postamble offset +8 (presumably the preamble and postamble
/// total-length slots).
fn make_streaming_preamble(shape: Vec<u64>, fill: u8) -> Vec<u8> {
    let mut bytes = make_streaming_message(shape, fill);
    bytes[16..24].fill(0);
    let postamble_slot = bytes.len() - POSTAMBLE_SIZE + 8;
    bytes[postamble_slot..postamble_slot + 8].fill(0);
    bytes
}
/// A lone message with zeroed length slots must be recorded as a single
/// layout starting at offset 0 and spanning the whole buffer.
#[test]
fn sync_scan_streaming_tail_recorded() {
    let buf = make_streaming_preamble(vec![4], 7);
    // Only the length is needed after the backend takes the buffer, so
    // remember it instead of cloning the bytes.
    let expected_len = buf.len() as u64;
    let backend = raw_backend(buf, forward_opts());
    backend.message_count().expect("count");
    let layouts = layout_offsets(&backend);
    assert_eq!(layouts.len(), 1, "streaming tail => single layout");
    assert_eq!(layouts[0].0, 0);
    assert_eq!(layouts[0].1, expected_len, "tail spans to EOF");
}
/// Same zeroed-length message, but with its final 8 bytes replaced by
/// `BADMAGIC`: one scan step must terminate the forward scan and record no
/// layout.
#[test]
fn sync_scan_streaming_end_magic_mismatch_terminates() {
    let mut bytes = make_streaming_preamble(vec![4], 7);
    let tail_at = bytes.len() - 8;
    bytes[tail_at..].copy_from_slice(b"BADMAGIC");
    let backend = raw_backend(bytes, forward_opts());
    let mut state = backend.state.lock().expect("lock");
    backend.scan_next_locked(&mut state).expect("scan");
    assert!(state.fwd_terminated);
    assert!(state.layouts.is_empty(), "mismatch => no layout");
}
/// One valid message followed by 64 zero bytes: metadata for message 0 must
/// be readable on an unscanned backend, and the total count must stay at one.
#[test]
fn sync_eager_discover_bad_magic_terminates() {
    // Extend the encoded message in place; it is not needed separately, so
    // no clone is taken.
    let mut buf = make_message(vec![4], 1);
    buf.extend_from_slice(&[0u8; 64]);
    let backend = raw_backend(buf, forward_opts());
    // Only the success of the call matters; the value itself is unused.
    let _meta = backend.read_metadata(0).expect("metadata");
    let count = backend.message_count().expect("count");
    assert_eq!(count, 1, "junk after msg 0 => only one message");
}
/// Reads metadata for message 0 on an unscanned backend whose only message
/// has zeroed length slots; the read must succeed and leave exactly one
/// layout that spans the whole buffer.
#[test]
fn sync_eager_discover_streaming_preamble_branch() {
    let buf = make_streaming_preamble(vec![4], 5);
    // Only the length is needed after the backend takes the buffer, so
    // remember it instead of cloning the bytes.
    let expected_len = buf.len() as u64;
    let backend = raw_backend(buf, forward_opts());
    let _meta = backend
        .read_metadata(0)
        .expect("footer frames still recoverable on streaming tail");
    let state = backend.state.lock().expect("lock");
    assert_eq!(state.layouts.len(), 1, "streaming tail recorded");
    assert_eq!(
        state.layouts[0].length,
        expected_len,
        "streaming tail spans to EOF",
    );
}
/// After reading metadata and descriptors of a single streaming-encoded
/// message, its cached layout must hold both metadata and index.
#[test]
fn sync_eager_discover_streaming_recorded() {
    let bytes = make_streaming_message(vec![6], 3);
    let backend = open_backend(bytes, forward_opts()).expect("open");
    // Only the success of the metadata read matters here.
    let _meta = backend.read_metadata(0).expect("metadata");
    let (_m, descriptors) = backend.read_descriptors(0).expect("descriptors");
    assert_eq!(descriptors.len(), 1);
    let state = backend.state.lock().expect("lock");
    assert!(state.layouts[0].global_metadata.is_some());
    assert!(state.layouts[0].index.is_some());
}
/// Two streaming-encoded messages back to back: reading metadata and
/// descriptors of message 1 must succeed and populate its cached layout.
#[test]
fn sync_eager_discover_footer_indexed_second_message() {
    let first = make_streaming_message(vec![4], 1);
    let second = make_streaming_message(vec![8], 2);
    let backend = open_backend(concat(vec![first, second]), forward_opts()).expect("open");
    let _meta = backend.read_metadata(1).expect("metadata for message 1");
    let (_m, descriptors) = backend.read_descriptors(1).expect("descriptors");
    assert_eq!(descriptors.len(), 1);
    let state = backend.state.lock().expect("lock");
    assert!(state.layouts[1].global_metadata.is_some());
    assert!(state.layouts[1].index.is_some());
}
/// Corrupts byte 8 of the second of two messages: message 1 must be
/// unreadable and the backend must report a single message.
#[test]
fn sync_eager_discover_preamble_parse_error_at_second_message() {
    let first = make_message(vec![4], 1);
    let second_start = first.len();
    let mut bytes = concat(vec![first, make_message(vec![8], 2)]);
    bytes[second_start + 8] = 0xFF;
    let backend = open_backend(bytes, forward_opts()).expect("open");
    let outcome = backend.read_metadata(1);
    assert!(outcome.is_err(), "corrupt 2nd preamble => message 1 missing");
    assert_eq!(backend.message_count().expect("count"), 1);
}
/// Writes a value 50 000 larger than the whole buffer into bytes 16..24 of
/// the second message (presumably its preamble total-length field): message
/// 1 must be unreadable and the backend must report a single message.
#[test]
fn sync_eager_discover_length_out_of_range_at_second_message() {
    let good = make_message(vec![4], 1);
    // Only the length of the first message is needed after concatenation,
    // so it is moved rather than cloned.
    let good_len = good.len();
    let second = make_message(vec![8], 2);
    let mut buf = concat(vec![good, second]);
    let huge = (buf.len() as u64) + 50_000;
    buf[good_len + 16..good_len + 24].copy_from_slice(&huge.to_be_bytes());
    let backend = open_backend(buf, forward_opts()).expect("open");
    let err = backend.read_metadata(1);
    assert!(err.is_err(), "overrun 2nd length => message 1 missing");
    assert_eq!(backend.message_count().expect("count"), 1);
}
/// The `Debug` rendering of a backend must mention the type name and the
/// word `messages`.
#[test]
fn sync_debug_format_includes_message_count() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts()).expect("open");
    backend.message_count().expect("count");
    let rendered = format!("{backend:?}");
    assert!(rendered.contains("RemoteBackend"), "debug includes type name");
    assert!(rendered.contains("messages"), "debug includes message count");
}
/// Batch object and range reads across two messages (float32 shapes `[4]`
/// and `[8]`) must return one entry per message with the expected byte
/// lengths.
#[test]
fn sync_read_object_and_range_batch_multi_message() {
    let bytes = concat(vec![make_message(vec![4], 10), make_message(vec![8], 20)]);
    let backend = open_backend(bytes, forward_opts()).expect("open");
    assert_eq!(backend.message_count().expect("count"), 2);

    let decode_opts = crate::DecodeOptions::default();
    let objects = backend
        .read_object_batch(&[0, 1], 0, &decode_opts)
        .expect("read_object_batch");
    assert_eq!(objects.len(), 2);
    // 4 bytes per float32 element.
    assert_eq!(objects[0].2.len(), 4 * 4);
    assert_eq!(objects[1].2.len(), 8 * 4);

    let ranges = backend
        .read_range_batch(&[0, 1], 0, &[(0, 4)], &decode_opts)
        .expect("read_range_batch");
    assert_eq!(ranges.len(), 2);
    assert_eq!(ranges[0].1[0].len(), 4 * 4);
}
/// Asking for object index 5 in a one-object message must fail.
#[test]
fn sync_read_object_out_of_range_obj_idx_errors() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts()).expect("open");
    let decode_opts = crate::DecodeOptions::default();
    let outcome = backend.read_object(0, 5, &decode_opts);
    assert!(outcome.is_err(), "obj idx out of range must error");
}
/// Asking for message index 99 in a one-message buffer must fail.
#[test]
fn sync_read_message_index_out_of_range_errors() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts()).expect("open");
    let outcome = backend.read_message(99);
    assert!(outcome.is_err(), "msg idx out of range must error");
}
/// Descriptors of a message holding a 20 000-element float32 tensor must
/// still be readable and report the original shape.
#[test]
fn sync_read_descriptors_large_frame_prefix_path() {
    let elements = 20_000u64;
    let bytes = make_message(vec![elements], 9);
    let backend = open_backend(bytes, forward_opts()).expect("open");
    let (_m, descriptors) = backend
        .read_descriptors(0)
        .expect("read_descriptors large frame");
    assert_eq!(descriptors.len(), 1);
    assert_eq!(descriptors[0].shape, vec![elements]);
}
/// Descriptor for a float32 `ntensor` of the given `shape`, native byte
/// order, strides from `strides_for`, and no encoding, filter, compression,
/// parameters or masks.
fn one_desc(shape: Vec<u64>) -> DataObjectDescriptor {
    let strides = strides_for(&shape);
    let ndim = shape.len() as u64;
    let none = || "none".to_string();
    DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim,
        shape,
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: none(),
        filter: none(),
        compression: none(),
        params: BTreeMap::new(),
        masks: None,
    }
}
/// Stores a single raw `frame` in a fresh `InMemory` object store under
/// `TEST_PATH` and returns a forward-only `RemoteBackend` over it without
/// running any scan. `prev_scan_offset` starts at the object size; all other
/// state is default.
fn frame_backend(frame: Vec<u8>) -> RemoteBackend {
    let file_size = frame.len() as u64;
    let path = ObjectPath::from(TEST_PATH);
    let store = Arc::new(InMemory::new());

    // Upload on the shared runtime; the future owns its own handles so the
    // originals remain available for the backend below.
    let upload = {
        let store = store.clone();
        let path = path.clone();
        async move {
            store.put(&path, frame.into()).await.expect("put");
        }
    };
    run(upload);

    let state = Mutex::new(RemoteState {
        prev_scan_offset: file_size,
        ..RemoteState::default()
    });
    RemoteBackend {
        source_url: "memory://frame.tgm".to_string(),
        store,
        path,
        file_size,
        state,
        scan_opts: forward_opts(),
    }
}
/// Encodes an 80 000-byte data-object frame with the third encoder argument
/// set to `true` (per the test name, descriptor CBOR placed before the
/// payload) and checks that `read_descriptor_only` recovers the shape.
#[test]
fn read_descriptor_only_large_cbor_before_prefix_loop() {
    let descriptor = one_desc(vec![20_000]);
    let payload = vec![0u8; 80_000];
    let encoded = framing::encode_data_object_frame(&descriptor, &payload, true, false)
        .expect("cbor-before frame");
    let total = encoded.len() as u64;
    let backend = frame_backend(encoded);
    // The frame is the whole object: message and frame both span 0..total.
    let parsed = backend
        .read_descriptor_only(0, total, 0, total)
        .expect("descriptor via prefix loop");
    assert_eq!(parsed.shape, vec![20_000]);
}
/// Encodes an 80 000-byte data-object frame with the third encoder argument
/// set to `false` (per the test name, descriptor CBOR placed after the
/// payload) and checks that `read_descriptor_only` recovers the shape.
#[test]
fn read_descriptor_only_large_cbor_after() {
    let descriptor = one_desc(vec![20_000]);
    let payload = vec![0u8; 80_000];
    let encoded = framing::encode_data_object_frame(&descriptor, &payload, false, false)
        .expect("cbor-after frame");
    let total = encoded.len() as u64;
    let backend = frame_backend(encoded);
    let parsed = backend
        .read_descriptor_only(0, total, 0, total)
        .expect("descriptor cbor-after");
    assert_eq!(parsed.shape, vec![20_000]);
}
/// Overwrites byte 2 of a large data-object frame with the
/// `FrameType::HeaderMetadata` discriminant; `read_descriptor_only` must then
/// fail.
#[test]
fn read_descriptor_only_large_non_data_object_frame_errors() {
    let descriptor = one_desc(vec![20_000]);
    let payload = vec![0u8; 80_000];
    let mut encoded =
        framing::encode_data_object_frame(&descriptor, &payload, false, false).expect("frame");
    encoded[2] = FrameType::HeaderMetadata as u8;
    let total = encoded.len() as u64;
    let backend = frame_backend(encoded);
    let outcome = backend.read_descriptor_only(0, total, 0, total);
    assert!(outcome.is_err(), "non-data-object large frame must error");
}
/// Zeroes the 8 bytes at the start of the data-object footer (presumably the
/// stored CBOR offset); `read_descriptor_only` must then fail.
#[test]
fn read_descriptor_only_large_cbor_offset_below_header_errors() {
    let descriptor = one_desc(vec![20_000]);
    let payload = vec![0u8; 80_000];
    let mut encoded =
        framing::encode_data_object_frame(&descriptor, &payload, false, false).expect("frame");
    let footer_at = encoded.len() - DATA_OBJECT_FOOTER_SIZE;
    encoded[footer_at..footer_at + 8].fill(0);
    let total = encoded.len() as u64;
    let backend = frame_backend(encoded);
    let outcome = backend.read_descriptor_only(0, total, 0, total);
    assert!(outcome.is_err(), "cbor_offset below header size must error");
}
/// Overwrites the 8 bytes at the start of the data-object footer (presumably
/// the stored CBOR offset) with the footer's own frame-relative position;
/// `read_descriptor_only` must then fail.
#[test]
fn read_descriptor_only_large_cbor_offset_past_footer_errors() {
    let descriptor = one_desc(vec![20_000]);
    let payload = vec![0u8; 80_000];
    let mut encoded =
        framing::encode_data_object_frame(&descriptor, &payload, false, false).expect("frame");
    let footer_at = encoded.len() - DATA_OBJECT_FOOTER_SIZE;
    let bogus_offset = footer_at as u64;
    encoded[footer_at..footer_at + 8].copy_from_slice(&bogus_offset.to_be_bytes());
    let total = encoded.len() as u64;
    let backend = frame_backend(encoded);
    let outcome = backend.read_descriptor_only(0, total, 0, total);
    assert!(outcome.is_err(), "cbor_offset at/past footer must error");
}
/// Replaces the final 4 bytes of a large data-object frame with `XXXX`;
/// `read_descriptor_only` must then fail.
#[test]
fn read_descriptor_only_large_bad_endf_errors() {
    let descriptor = one_desc(vec![20_000]);
    let payload = vec![0u8; 80_000];
    let mut encoded =
        framing::encode_data_object_frame(&descriptor, &payload, false, false).expect("frame");
    let trailer_at = encoded.len() - 4;
    encoded[trailer_at..].copy_from_slice(b"XXXX");
    let total = encoded.len() as u64;
    let backend = frame_backend(encoded);
    let outcome = backend.read_descriptor_only(0, total, 0, total);
    assert!(outcome.is_err(), "bad ENDF on large frame must error");
}
/// Hand-assembles a frame: a `FrameHeader` (version 1, flags 0) whose
/// `total_length` is header size + payload length + common footer size,
/// followed by `payload`, 8 zero bytes, and `FRAME_END`.
fn build_frame(frame_type: FrameType, payload: &[u8]) -> Vec<u8> {
    let declared_len = FRAME_HEADER_SIZE + payload.len() + FRAME_COMMON_FOOTER_SIZE;
    let header = FrameHeader {
        frame_type,
        version: 1,
        flags: 0,
        total_length: declared_len as u64,
    };
    let mut bytes = Vec::with_capacity(declared_len);
    header.write_to(&mut bytes);
    bytes.extend_from_slice(payload);
    // Eight zero bytes sit between the payload and the end marker.
    bytes.extend_from_slice(&[0u8; 8]);
    bytes.extend_from_slice(FRAME_END);
    bytes
}
/// Overwrites bytes 8..16 of a footer-metadata frame with the value 1
/// (presumably the header's total-length field); parsing must fail.
#[test]
fn parse_footer_frames_into_frame_too_small_errors() {
    let mut frame = build_frame(FrameType::FooterMetadata, &[1, 2, 3, 4]);
    frame[8..16].copy_from_slice(&1u64.to_be_bytes());
    let outcome = RemoteBackend::parse_footer_frames_into(&frame);
    assert!(outcome.is_err(), "below-minimum frame total_length must error");
}
/// Replaces the final 4 bytes of a footer-metadata frame with `XXXX`;
/// parsing must fail.
#[test]
fn parse_footer_frames_into_missing_endf_errors() {
    let mut frame = build_frame(FrameType::FooterMetadata, &[1, 2, 3, 4]);
    let trailer_at = frame.len() - 4;
    frame[trailer_at..].copy_from_slice(b"XXXX");
    let outcome = RemoteBackend::parse_footer_frames_into(&frame);
    assert!(outcome.is_err(), "missing ENDF trailer must error");
}
/// A lone header-metadata frame must parse without error and yield neither
/// metadata nor index.
#[test]
fn parse_footer_frames_into_skips_non_footer_frames() {
    let frame = build_frame(FrameType::HeaderMetadata, &[9, 9, 9, 9]);
    let (metadata, index) = RemoteBackend::parse_footer_frames_into(&frame).expect("parse ok");
    let nothing_found = metadata.is_none() && index.is_none();
    assert!(nothing_found, "non-footer skipped");
}
/// Parsing an empty footer region into a state holding one dummy layout must
/// fail.
#[test]
fn parse_footer_frames_missing_metadata_errors() {
    let mut remote_state = RemoteState::default();
    remote_state.layouts.push(dummy_layout(0, 100));
    let outcome = RemoteBackend::parse_footer_frames(&mut remote_state, 0, &[]);
    assert!(
        outcome.is_err(),
        "footer region without a metadata frame must error",
    );
}
/// Parsing `PREAMBLE_SIZE + 32` zero bytes as a header region must fail.
#[test]
fn parse_header_frames_missing_metadata_errors() {
    let mut remote_state = RemoteState::default();
    remote_state.layouts.push(dummy_layout(0, 100));
    let zeros = vec![0u8; PREAMBLE_SIZE + 32];
    let outcome = RemoteBackend::parse_header_frames(&mut remote_state, 0, &zeros);
    assert!(outcome.is_err(), "header region without metadata must error");
}
/// CBOR bytes produced by `global_metadata_to_cbor` for a `GlobalMetadata`
/// with an empty `extra` map and default values elsewhere.
fn valid_metadata_cbor() -> Vec<u8> {
    let empty = GlobalMetadata {
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let encoded = crate::metadata::global_metadata_to_cbor(&empty);
    encoded.expect("meta cbor")
}
/// A header region holding only a valid metadata frame (after
/// `PREAMBLE_SIZE` zero bytes) must fail to parse, yet leave the parsed
/// metadata stored on the layout.
#[test]
fn parse_header_frames_missing_index_errors() {
    let metadata_frame = build_frame(FrameType::HeaderMetadata, &valid_metadata_cbor());
    let mut region = vec![0u8; PREAMBLE_SIZE];
    region.extend_from_slice(&metadata_frame);
    let mut state = RemoteState::default();
    state.layouts.push(dummy_layout(0, region.len() as u64));
    let outcome = RemoteBackend::parse_header_frames(&mut state, 0, &region);
    assert!(outcome.is_err(), "header region without an index must error");
    assert!(state.layouts[0].global_metadata.is_some());
}
/// A footer region holding only a valid metadata frame must fail to parse,
/// yet leave the parsed metadata stored on the layout.
#[test]
fn parse_footer_frames_missing_index_errors() {
    let metadata_frame = build_frame(FrameType::FooterMetadata, &valid_metadata_cbor());
    let mut state = RemoteState::default();
    state.layouts.push(dummy_layout(0, 100));
    let outcome = RemoteBackend::parse_footer_frames(&mut state, 0, &metadata_frame);
    assert!(outcome.is_err(), "footer region without an index must error");
    assert!(state.layouts[0].global_metadata.is_some());
}
/// `checked_frame_range` must reject an overflowing frame offset and a frame
/// that runs past the message, and must return `30..60` for
/// `(10, 100, 20, 30)`. The arguments are presumably (message offset,
/// message length, frame offset, frame length).
#[test]
fn checked_frame_range_overflow_rejected() {
    let overflowing = RemoteBackend::checked_frame_range(0, 100, u64::MAX, 10);
    assert!(overflowing.is_err(), "frame offset overflow must error");

    let past_end = RemoteBackend::checked_frame_range(0, 100, 50, 100);
    assert!(past_end.is_err(), "frame beyond message boundary must error");

    let range = RemoteBackend::checked_frame_range(10, 100, 20, 30).expect("valid");
    assert_eq!(range, 30..60);
}
/// An index with two offsets but only one length must be rejected.
#[test]
fn validate_index_access_mismatched_lengths_errors() {
    let lopsided = IndexFrame {
        offsets: vec![0, 100],
        lengths: vec![50],
    };
    let outcome = RemoteBackend::validate_index_access(&lopsided, 0);
    assert!(outcome.is_err(), "mismatched offsets/lengths must error");
}
/// Plants default metadata and a lopsided index (two offsets, one length) on
/// the cached layout of message 0; `read_descriptors` must then fail.
#[test]
fn sync_read_descriptors_corrupt_index_errors() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts()).expect("open");
    {
        // Scoped so the guard is released before `read_descriptors` runs.
        let mut guard = backend.state.lock().expect("lock");
        let layout = &mut guard.layouts[0];
        layout.global_metadata = Some(GlobalMetadata::default());
        layout.index = Some(IndexFrame {
            offsets: vec![0, 100],
            lengths: vec![50],
        });
    }
    let outcome = backend.read_descriptors(0);
    assert!(outcome.is_err(), "corrupt index must error in read_descriptors");
}
/// Seven messages scanned forward-only and bidirectionally must produce the
/// same `(offset, length)` list, seven entries long.
#[test]
fn sync_bidir_matches_forward_only_multi_message() {
    let mut messages: Vec<Vec<u8>> = Vec::with_capacity(7);
    for i in 0..7 {
        messages.push(make_message(vec![4 + i as u64], i as u8));
    }
    let bytes = concat(messages);

    // Each backend takes ownership of its buffer, so one copy is needed.
    let forward = open_backend(bytes.clone(), forward_opts()).expect("open fwd");
    forward.message_count().expect("count");
    let both_ways = open_backend(bytes, bidir_opts()).expect("open bidir");
    both_ways.message_count().expect("count");

    assert_eq!(
        layout_offsets(&forward),
        layout_offsets(&both_ways),
        "bidirectional must match forward-only layouts",
    );
    assert_eq!(layout_offsets(&both_ways).len(), 7);
}
/// A single message scanned bidirectionally must be counted exactly once.
#[test]
fn sync_bidir_one_message_no_duplicate() {
    let backend = open_backend(make_message(vec![4], 42), bidir_opts()).expect("open");
    let total = backend.message_count().expect("count");
    assert_eq!(total, 1, "1-message bidir must not duplicate");
}
/// Replaces the final 8 bytes of a two-message buffer with `BADMAGIC`; a
/// bidirectional open must still report both messages.
#[test]
fn sync_bidir_corrupt_postamble_falls_back_to_forward() {
    let mut bytes = concat(vec![make_message(vec![4], 1), make_message(vec![8], 2)]);
    let tail_at = bytes.len() - 8;
    bytes[tail_at..].copy_from_slice(b"BADMAGIC");
    let backend = open_backend(bytes, bidir_opts()).expect("open");
    let total = backend.message_count().expect("count");
    assert_eq!(total, 2, "backward corruption => forward still scans");
}
/// Six streaming-encoded messages scanned bidirectionally: once counted, at
/// least one of the six cached layouts must already hold metadata and index.
#[test]
fn sync_bidir_footer_indexed_eager_populates() {
    let mut messages: Vec<Vec<u8>> = Vec::with_capacity(6);
    for i in 0..6 {
        messages.push(make_streaming_message(vec![4 + i as u64], i as u8));
    }
    let backend = open_backend(concat(messages), bidir_opts()).expect("open");
    assert_eq!(backend.message_count().expect("count"), 6);

    let guard = backend.state.lock().expect("lock");
    let is_populated = |i: usize| {
        let layout = &guard.layouts[i];
        layout.global_metadata.is_some() && layout.index.is_some()
    };
    assert!(
        (0..6).any(is_populated),
        "eager footer must populate at least one layout"
    );
}
/// Overwrites the first 8 postamble bytes of a streaming-encoded message
/// with the value 1 (per the test name, the first-footer offset); the open
/// must still succeed but reading metadata must fail.
#[test]
fn sync_footer_discovery_first_footer_offset_too_small_errors() {
    let mut bytes = make_streaming_message(vec![6], 3);
    let postamble_at = bytes.len() - POSTAMBLE_SIZE;
    bytes[postamble_at..postamble_at + 8].copy_from_slice(&1u64.to_be_bytes());
    let backend = open_backend(bytes, forward_opts()).expect("open");
    let outcome = backend.read_metadata(0);
    assert!(
        outcome.is_err(),
        "first_footer_offset below preamble end must error",
    );
}
/// Overwrites the first 8 postamble bytes of a streaming-encoded message
/// with the message's own length (per the test name, a first-footer offset
/// past the postamble); the open must still succeed but reading metadata
/// must fail.
#[test]
fn sync_footer_discovery_first_footer_offset_past_postamble_errors() {
    let mut bytes = make_streaming_message(vec![6], 3);
    let whole_len = bytes.len() as u64;
    let postamble_at = bytes.len() - POSTAMBLE_SIZE;
    bytes[postamble_at..postamble_at + 8].copy_from_slice(&whole_len.to_be_bytes());
    let backend = open_backend(bytes, forward_opts()).expect("open");
    let outcome = backend.read_metadata(0);
    assert!(
        outcome.is_err(),
        "first_footer_offset past postamble must error",
    );
}
}
#[cfg(all(test, feature = "async"))]
mod inmem_async_tests {
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use object_store::memory::InMemory;
use object_store::path::Path as ObjectPath;
use super::*;
use crate::Dtype;
use crate::encode::{self, EncodeOptions};
use crate::types::{ByteOrder, DataObjectDescriptor, GlobalMetadata};
const TEST_PATH: &str = "file.tgm";
/// Computes strides (in elements) for `shape`: the last axis has stride 1 and
/// every other axis has the product of the extents to its right. An empty
/// shape yields an empty vector.
fn strides_for(shape: &[u64]) -> Vec<u64> {
    let rank = shape.len();
    let mut strides = vec![1u64; rank];
    let mut axis = rank;
    // Fill from the second-innermost axis outwards.
    while axis > 1 {
        axis -= 1;
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}
/// Encodes a single-object message via `encode::encode`: a `Float32`
/// "ntensor" of the given `shape` whose payload is `fill` repeated for every
/// byte (four bytes per element), with default global metadata and default
/// encode options.
fn make_message(shape: Vec<u64>, fill: u8) -> Vec<u8> {
    let element_count = shape.iter().product::<u64>() as usize;
    let payload = vec![fill; element_count * 4];
    // `shape` is moved in last so the earlier fields can still borrow it.
    let descriptor = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: shape.len() as u64,
        strides: strides_for(&shape),
        shape,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        extra: BTreeMap::new(),
        ..Default::default()
    };
    let objects = [(&descriptor, &payload)];
    encode::encode(&global, &objects, &EncodeOptions::default()).expect("encode test message")
}
/// Builds a single-object message with `StreamingEncoder` writing into an
/// in-memory cursor and finishing via `finish_with_backfill`. The object is a
/// `Float32` "ntensor" of the given `shape` filled with `fill` bytes.
fn make_streaming_message(shape: Vec<u64>, fill: u8) -> Vec<u8> {
    use crate::streaming::StreamingEncoder;

    let element_count = shape.iter().product::<u64>() as usize;
    let payload = vec![fill; element_count * 4];
    // `shape` is moved in last so the earlier fields can still borrow it.
    let descriptor = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: shape.len() as u64,
        strides: strides_for(&shape),
        shape,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        masks: None,
    };
    let global = GlobalMetadata {
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let sink = std::io::Cursor::new(Vec::<u8>::new());
    let mut encoder = StreamingEncoder::new(sink, &global, &EncodeOptions::default())
        .expect("streaming encoder");
    encoder
        .write_object(&descriptor, &payload)
        .expect("write object");
    encoder
        .finish_with_backfill()
        .expect("finish_with_backfill")
        .into_inner()
}
/// Joins `parts` end to end into one contiguous byte buffer, preserving
/// order. An empty list yields an empty buffer.
fn concat(parts: Vec<Vec<u8>>) -> Vec<u8> {
    // `slice::concat` sizes the output once from the summed lengths instead
    // of growing it on every append.
    parts.concat()
}
/// Stores `buf` in a fresh `InMemory` object store under `TEST_PATH` and
/// wraps it in a `RemoteBackend` without running any scan. The initial state
/// has `prev_scan_offset` at the file size and `bwd_active` mirroring
/// `scan_opts.bidirectional`.
async fn raw_backend(buf: Vec<u8>, scan_opts: RemoteScanOptions) -> RemoteBackend {
    let file_size = buf.len() as u64;
    let path = ObjectPath::from(TEST_PATH);
    let store = Arc::new(InMemory::new());
    store.put(&path, buf.into()).await.expect("put");

    let initial_state = RemoteState {
        prev_scan_offset: file_size,
        bwd_active: scan_opts.bidirectional,
        ..RemoteState::default()
    };
    RemoteBackend {
        source_url: "memory://file.tgm".to_string(),
        store,
        path,
        file_size,
        state: Mutex::new(initial_state),
        scan_opts,
    }
}
/// Builds a backend over `buf` and runs one `scan_next_async` step.
///
/// # Errors
/// Returns `TensogramError::Remote` when `buf` is shorter than
/// `PREAMBLE_SIZE + POSTAMBLE_SIZE`, when the scan step fails, or when the
/// scan step leaves no layouts recorded.
async fn open_backend(buf: Vec<u8>, scan_opts: RemoteScanOptions) -> Result<RemoteBackend> {
    let minimum_len = (PREAMBLE_SIZE + POSTAMBLE_SIZE) as u64;
    if (buf.len() as u64) < minimum_len {
        return Err(TensogramError::Remote("too small".to_string()));
    }

    let backend = raw_backend(buf, scan_opts).await;
    backend.scan_next_async().await?;

    // The guard is a temporary here, so the lock is released before
    // `backend` is handed back to the caller.
    let nothing_found = backend.lock_state()?.layouts.is_empty();
    if nothing_found {
        return Err(TensogramError::Remote("no valid messages".to_string()));
    }
    Ok(backend)
}
/// Scan options with bidirectional scanning switched off.
fn forward_opts() -> RemoteScanOptions {
    let bidirectional = false;
    RemoteScanOptions { bidirectional }
}
/// Scan options with bidirectional scanning switched on.
fn bidir_opts() -> RemoteScanOptions {
    let bidirectional = true;
    RemoteScanOptions { bidirectional }
}
/// A 128-byte all-zero file: one scan step must succeed, mark the forward
/// scan as terminated, and record no layouts.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_scan_bad_magic_terminates() {
    let zeroes = vec![0u8; 128];
    let backend = raw_backend(zeroes, forward_opts()).await;
    backend.scan_next_async().await.expect("scan");

    let guard = backend.lock_state().expect("lock");
    assert!(guard.fwd_terminated, "bad magic terminates forward");
    assert!(guard.layouts.is_empty());
}
/// Corrupts byte 8 of an otherwise valid message and checks that one scan
/// step terminates the forward scan without recording a layout.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_scan_preamble_parse_error_terminates() {
    let mut bytes = make_message(vec![4], 1);
    // NOTE(review): byte 8 is presumably a preamble field that fails
    // parsing when set to 0xFF — confirm against the wire layout.
    bytes[8] = 0xFF;
    let backend = raw_backend(bytes, forward_opts()).await;
    backend.scan_next_async().await.expect("scan");

    let guard = backend.lock_state().expect("lock");
    assert!(guard.fwd_terminated);
    assert!(guard.layouts.is_empty());
}
/// Writes a value 10 000 larger than the buffer into bytes 16..24 and checks
/// that one scan step terminates the forward scan without recording a layout.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_scan_length_out_of_range_terminates() {
    let mut bytes = make_message(vec![4], 1);
    // NOTE(review): bytes 16..24 are presumably the preamble's big-endian
    // total-length field — confirm against the wire layout.
    let oversized = (bytes.len() as u64) + 10_000;
    bytes[16..24].copy_from_slice(&oversized.to_be_bytes());
    let backend = raw_backend(bytes, forward_opts()).await;
    backend.scan_next_async().await.expect("scan");

    let guard = backend.lock_state().expect("lock");
    assert!(guard.fwd_terminated);
    assert!(guard.layouts.is_empty());
}
/// Zeroes bytes 16..24 and the eight bytes at offset 8 inside the trailing
/// `POSTAMBLE_SIZE` region, then expects a single layout whose length spans
/// the whole file.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_scan_streaming_tail_recorded() {
    let mut bytes = make_streaming_message(vec![4], 7);
    let zero = 0u64.to_be_bytes();
    // NOTE(review): both spans are presumably the total-length fields of the
    // preamble and postamble — confirm against the wire layout.
    bytes[16..24].copy_from_slice(&zero);
    let postamble_len_at = bytes.len() - POSTAMBLE_SIZE + 8;
    bytes[postamble_len_at..postamble_len_at + 8].copy_from_slice(&zero);

    let expected_len = bytes.len() as u64;
    let backend = raw_backend(bytes, forward_opts()).await;
    backend.message_count_async().await.expect("count");

    let guard = backend.lock_state().expect("lock");
    assert_eq!(guard.layouts.len(), 1);
    assert_eq!(guard.layouts[0].length, expected_len, "tail spans to EOF");
}
/// Exercises the async read API on a two-message file (shapes `[4]` and
/// `[8]`): count, layouts, raw message, metadata, descriptors, full object
/// decode and a range read.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_read_api_multi_message() {
    let first = make_message(vec![4], 10);
    let second = make_message(vec![8], 20);
    let backend = open_backend(concat(vec![first, second]), forward_opts())
        .await
        .expect("open");

    assert_eq!(backend.message_count_async().await.expect("count"), 2);

    let layouts = backend.message_layouts_async().await.expect("layouts");
    assert_eq!(layouts.len(), 2);

    let raw = backend.read_message_async(1).await.expect("read_message");
    assert!(!raw.is_empty());

    let opts = crate::DecodeOptions::default();
    let _meta = backend.read_metadata_async(0).await.expect("metadata");

    let (_m, descs) = backend
        .read_descriptors_async(1)
        .await
        .expect("descriptors");
    assert_eq!(descs.len(), 1);

    // Message 0 holds four Float32 elements.
    let (_m, _d, decoded) = backend
        .read_object_async(0, 0, &opts)
        .await
        .expect("read_object_async");
    assert_eq!(decoded.len(), 4 * 4);

    // The single requested range of message 1 decodes to 8 * 4 bytes.
    let (_d, parts) = backend
        .read_range_async(1, 0, &[(0, 8)], &opts)
        .await
        .expect("read_range_async");
    assert_eq!(parts[0].len(), 8 * 4);
}
/// Runs the batch object and batch range reads against a three-message file
/// (shapes `[4]`, `[8]`, `[16]`) and checks result counts and byte lengths.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_batch_apis_multi_message() {
    let messages = vec![
        make_message(vec![4], 10),
        make_message(vec![8], 20),
        make_message(vec![16], 30),
    ];
    let backend = open_backend(concat(messages), forward_opts())
        .await
        .expect("open");
    let opts = crate::DecodeOptions::default();

    let objects = backend
        .read_object_batch_async(&[0, 1, 2], 0, &opts)
        .await
        .expect("read_object_batch_async");
    assert_eq!(objects.len(), 3);
    assert_eq!(objects[2].2.len(), 16 * 4);

    let ranges = backend
        .read_range_batch_async(&[0, 2], 0, &[(0, 4)], &opts)
        .await
        .expect("read_range_batch_async");
    assert_eq!(ranges.len(), 2);
    assert_eq!(ranges[0].1[0].len(), 4 * 4);
}
/// `ensure_all_layouts_batch_async` must succeed when given no indices.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_batch_empty_indices_is_noop() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts())
        .await
        .expect("open");
    let outcome = backend.ensure_all_layouts_batch_async(&[]).await;
    outcome.expect("empty batch is a no-op");
}
/// `ensure_all_layouts_batch_async` must fail when one of the requested
/// indices (9, on a one-message file) does not exist.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_batch_out_of_range_index_errors() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts())
        .await
        .expect("open");
    let outcome = backend.ensure_all_layouts_batch_async(&[0, 9]).await;
    assert!(outcome.is_err(), "out-of-range batch index must error");
}
/// Reads descriptors from a message holding one 20 000-element object and
/// checks that exactly one descriptor with that shape comes back.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_read_descriptors_large_frame_prefix_path() {
    let elements = 20_000u64;
    let backend = open_backend(make_message(vec![elements], 9), forward_opts())
        .await
        .expect("open");

    let (_m, descs) = backend
        .read_descriptors_async(0)
        .await
        .expect("descriptors");
    assert_eq!(descs.len(), 1);
    assert_eq!(descs[0].shape, vec![elements]);
}
/// Requesting message 42 from a one-message file must return an error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_read_message_out_of_range_errors() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts())
        .await
        .expect("open");
    let outcome = backend.read_message_async(42).await;
    assert!(outcome.is_err(), "out-of-range message index must error");
}
/// Plants an index with two offsets but only one length into the cached
/// layout of message 0 and expects `read_descriptors_async` to fail.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_read_descriptors_corrupt_index_errors() {
    let backend = open_backend(make_message(vec![4], 1), forward_opts())
        .await
        .expect("open");

    // Scoped so the lock is released before the read below.
    {
        let mut guard = backend.lock_state().expect("lock");
        let layout = &mut guard.layouts[0];
        layout.global_metadata = Some(GlobalMetadata::default());
        layout.index = Some(IndexFrame {
            offsets: vec![0, 100],
            lengths: vec![50],
        });
    }

    let outcome = backend.read_descriptors_async(0).await;
    assert!(outcome.is_err(), "corrupt index must error (async)");
}
/// Descriptor for a `Float32` "ntensor" of the given `shape`, native byte
/// order, with encoding, filter and compression all set to "none".
fn one_desc(shape: Vec<u64>) -> DataObjectDescriptor {
    let ndim = shape.len() as u64;
    let strides = strides_for(&shape);
    let none = || "none".to_string();
    DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim,
        shape,
        strides,
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: none(),
        filter: none(),
        compression: none(),
        params: BTreeMap::new(),
        masks: None,
    }
}
/// Stores a single encoded frame in a fresh `InMemory` object store under
/// `TEST_PATH` and wraps it in a forward-only `RemoteBackend` without
/// scanning. `prev_scan_offset` starts at the frame length.
async fn frame_backend(frame: Vec<u8>) -> RemoteBackend {
    let file_size = frame.len() as u64;
    let path = ObjectPath::from(TEST_PATH);
    let store = Arc::new(InMemory::new());
    store.put(&path, frame.into()).await.expect("put");

    let initial_state = RemoteState {
        prev_scan_offset: file_size,
        ..RemoteState::default()
    };
    RemoteBackend {
        source_url: "memory://frame.tgm".to_string(),
        store,
        path,
        file_size,
        state: Mutex::new(initial_state),
        scan_opts: forward_opts(),
    }
}
/// Encodes an 80 000-byte data-object frame with the third argument of
/// `encode_data_object_frame` set to `true` (the "cbor-before" variant, per
/// the expect message) and reads back only its descriptor.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_read_descriptor_only_large_cbor_before_prefix_loop() {
    let descriptor = one_desc(vec![20_000]);
    let body = vec![0u8; 80_000];
    let frame = framing::encode_data_object_frame(&descriptor, &body, true, false)
        .expect("cbor-before frame");
    let frame_len = frame.len() as u64;

    let backend = frame_backend(frame).await;
    let outcome = backend
        .read_descriptor_only_async(0, frame_len, 0, frame_len)
        .await;
    let got = outcome.expect("descriptor via prefix loop");
    assert_eq!(got.shape, vec![20_000]);
}
/// Encodes an 80 000-byte data-object frame with the third argument of
/// `encode_data_object_frame` set to `false` (the "cbor-after" variant, per
/// the expect message) and reads back only its descriptor.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_read_descriptor_only_large_cbor_after() {
    let descriptor = one_desc(vec![20_000]);
    let body = vec![0u8; 80_000];
    let frame = framing::encode_data_object_frame(&descriptor, &body, false, false)
        .expect("cbor-after frame");
    let frame_len = frame.len() as u64;

    let backend = frame_backend(frame).await;
    let outcome = backend
        .read_descriptor_only_async(0, frame_len, 0, frame_len)
        .await;
    let got = outcome.expect("descriptor cbor-after");
    assert_eq!(got.shape, vec![20_000]);
}
/// Overwrites byte 2 of a large data-object frame with
/// `FrameType::HeaderMetadata` and expects the descriptor-only read to fail.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_read_descriptor_only_large_non_data_object_errors() {
    let descriptor = one_desc(vec![20_000]);
    let body = vec![0u8; 80_000];
    let mut frame =
        framing::encode_data_object_frame(&descriptor, &body, false, false).expect("frame");
    // NOTE(review): byte 2 is presumably the frame-type field of the frame
    // header — confirm against the wire layout.
    frame[2] = FrameType::HeaderMetadata as u8;
    let frame_len = frame.len() as u64;

    let backend = frame_backend(frame).await;
    let outcome = backend
        .read_descriptor_only_async(0, frame_len, 0, frame_len)
        .await;
    assert!(outcome.is_err(), "non-data-object large frame must error");
}
/// Reads metadata and descriptors from a single streaming-encoded message
/// and checks that the cached layout then holds both global metadata and an
/// index.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_footer_indexed_read_metadata() {
    let backend = open_backend(make_streaming_message(vec![6], 3), forward_opts())
        .await
        .expect("open");

    let _meta = backend.read_metadata_async(0).await.expect("metadata");
    let (_m, descs) = backend
        .read_descriptors_async(0)
        .await
        .expect("descriptors");
    assert_eq!(descs.len(), 1);

    let guard = backend.lock_state().expect("lock");
    let layout = &guard.layouts[0];
    assert!(layout.global_metadata.is_some());
    assert!(layout.index.is_some());
}
/// Overwrites the first eight bytes of the trailing `POSTAMBLE_SIZE` region
/// (the `first_footer_offset`, per the assertion message) with `1` and
/// expects `read_metadata_async(0)` to fail.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_footer_first_footer_offset_too_small_errors() {
    let mut bytes = make_streaming_message(vec![6], 3);
    let postamble_at = bytes.len() - POSTAMBLE_SIZE;
    let bogus_offset = 1u64.to_be_bytes();
    bytes[postamble_at..postamble_at + 8].copy_from_slice(&bogus_offset);

    let backend = open_backend(bytes, forward_opts()).await.expect("open");
    let outcome = backend.read_metadata_async(0).await;
    assert!(
        outcome.is_err(),
        "first_footer_offset below preamble must error"
    );
}
/// Overwrites the first eight bytes of the trailing `POSTAMBLE_SIZE` region
/// (the `first_footer_offset`, per the assertion message) with the full
/// message length and expects `read_metadata_async(0)` to fail.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_footer_first_footer_offset_past_postamble_errors() {
    let mut bytes = make_streaming_message(vec![6], 3);
    let total_len = bytes.len() as u64;
    let postamble_at = bytes.len() - POSTAMBLE_SIZE;
    let bogus_offset = total_len.to_be_bytes();
    bytes[postamble_at..postamble_at + 8].copy_from_slice(&bogus_offset);

    let backend = open_backend(bytes, forward_opts()).await.expect("open");
    let outcome = backend.read_metadata_async(0).await;
    assert!(
        outcome.is_err(),
        "first_footer_offset past postamble must error"
    );
}
/// One valid message followed by 64 zero bytes: metadata for message 0 must
/// be readable and the message count must stay at one.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_eager_discover_bad_magic_terminates() {
    let mut bytes = make_message(vec![4], 1);
    let trailing_junk = [0u8; 64];
    bytes.extend_from_slice(&trailing_junk);

    let backend = raw_backend(bytes, forward_opts()).await;
    let _meta = backend.read_metadata_async(0).await.expect("metadata");
    let found = backend.message_count_async().await.expect("count");
    assert_eq!(found, 1, "junk after msg 0 => only one message");
}
/// Two messages where byte 8 of the second is set to 0xFF: reading metadata
/// for message 1 must fail and only one message must be counted.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_eager_discover_preamble_parse_error_at_second_message() {
    let first = make_message(vec![4], 1);
    let second_at = first.len();
    let mut bytes = concat(vec![first, make_message(vec![8], 2)]);
    // NOTE(review): offset 8 within a message is presumably a preamble field
    // that fails parsing when set to 0xFF — confirm against the wire layout.
    bytes[second_at + 8] = 0xFF;

    let backend = open_backend(bytes, forward_opts()).await.expect("open");
    let outcome = backend.read_metadata_async(1).await;
    assert!(outcome.is_err(), "corrupt 2nd preamble => message 1 missing");
    assert_eq!(backend.message_count_async().await.expect("count"), 1);
}
/// Two messages where bytes 16..24 of the second hold a value 50 000 larger
/// than the whole file: reading metadata for message 1 must fail and only
/// one message must be counted.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_eager_discover_length_out_of_range_at_second_message() {
    let first = make_message(vec![4], 1);
    let second_at = first.len();
    let mut bytes = concat(vec![first, make_message(vec![8], 2)]);
    // NOTE(review): offsets 16..24 within a message are presumably the
    // preamble's big-endian total-length field — confirm against the wire
    // layout.
    let oversized = (bytes.len() as u64) + 50_000;
    bytes[second_at + 16..second_at + 24].copy_from_slice(&oversized.to_be_bytes());

    let backend = open_backend(bytes, forward_opts()).await.expect("open");
    let outcome = backend.read_metadata_async(1).await;
    assert!(outcome.is_err(), "overrun 2nd length => message 1 missing");
    assert_eq!(backend.message_count_async().await.expect("count"), 1);
}
/// Zeroes bytes 16..24 and the eight bytes at offset 8 inside the trailing
/// `POSTAMBLE_SIZE` region, then reads metadata without a prior scan. One
/// layout spanning the whole file must result.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_eager_discover_streaming_preamble_branch() {
    let mut bytes = make_streaming_message(vec![4], 5);
    let zero = 0u64.to_be_bytes();
    // NOTE(review): both spans are presumably the total-length fields of the
    // preamble and postamble — confirm against the wire layout.
    bytes[16..24].copy_from_slice(&zero);
    let postamble_len_at = bytes.len() - POSTAMBLE_SIZE + 8;
    bytes[postamble_len_at..postamble_len_at + 8].copy_from_slice(&zero);

    let expected_len = bytes.len() as u64;
    let backend = raw_backend(bytes, forward_opts()).await;
    let _meta = backend
        .read_metadata_async(0)
        .await
        .expect("footer frames recoverable on streaming tail");

    let guard = backend.lock_state().expect("lock");
    assert_eq!(guard.layouts.len(), 1);
    assert_eq!(guard.layouts[0].length, expected_len);
}
/// Same zeroed-length setup as the streaming-preamble test, but the final
/// eight bytes of the file are replaced with `BADMAGIC`; reading metadata
/// for message 0 must then fail.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_eager_discover_streaming_end_magic_mismatch() {
    let mut bytes = make_streaming_message(vec![4], 5);
    let zero = 0u64.to_be_bytes();
    // NOTE(review): both spans are presumably the total-length fields of the
    // preamble and postamble — confirm against the wire layout.
    bytes[16..24].copy_from_slice(&zero);
    let postamble_len_at = bytes.len() - POSTAMBLE_SIZE + 8;
    bytes[postamble_len_at..postamble_len_at + 8].copy_from_slice(&zero);

    let end_magic_at = bytes.len() - 8;
    bytes[end_magic_at..].copy_from_slice(b"BADMAGIC");

    let backend = raw_backend(bytes, forward_opts()).await;
    let outcome = backend.read_metadata_async(0).await;
    assert!(outcome.is_err(), "streaming end-magic mismatch => no layout");
}
/// Two concatenated streaming messages: metadata and descriptors for
/// message 1 must be readable, after which its cached layout must hold both
/// global metadata and an index.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_footer_indexed_eager_discover_second_message() {
    let messages = vec![
        make_streaming_message(vec![4], 1),
        make_streaming_message(vec![8], 2),
    ];
    let backend = open_backend(concat(messages), forward_opts())
        .await
        .expect("open");

    let _meta = backend
        .read_metadata_async(1)
        .await
        .expect("footer metadata for message 1");
    let (_m, descs) = backend
        .read_descriptors_async(1)
        .await
        .expect("descriptors");
    assert_eq!(descs.len(), 1);

    let guard = backend.lock_state().expect("lock");
    let layout = &guard.layouts[1];
    assert!(layout.global_metadata.is_some());
    assert!(layout.index.is_some());
}
/// Batch-decodes object 0 of both messages in a file made of two
/// streaming-encoded messages and expects two results.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_footer_indexed_batch() {
    let messages = vec![
        make_streaming_message(vec![4], 1),
        make_streaming_message(vec![8], 2),
    ];
    let backend = open_backend(concat(messages), forward_opts())
        .await
        .expect("open");
    let opts = crate::DecodeOptions::default();

    let decoded = backend
        .read_object_batch_async(&[0, 1], 0, &opts)
        .await
        .expect("footer batch decode");
    assert_eq!(decoded.len(), 2);
}
/// Scans the same eight-message file forward-only and bidirectionally and
/// requires both to report the same layouts (offset and length per message).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_bidir_matches_forward_only() {
    let mut parts: Vec<Vec<u8>> = Vec::new();
    for i in 0..8 {
        parts.push(make_message(vec![4 + i as u64], i as u8));
    }
    let bytes = concat(parts);

    let forward = open_backend(bytes.clone(), forward_opts())
        .await
        .expect("open fwd");
    let forward_layouts = forward.message_layouts_async().await.expect("fwd layouts");

    let bidir = open_backend(bytes, bidir_opts()).await.expect("open bidir");
    let bidir_layouts = bidir.message_layouts_async().await.expect("bidir layouts");

    assert_eq!(
        forward_layouts.len(),
        bidir_layouts.len(),
        "bidir must find the same number of messages",
    );
    for (from_fwd, from_bidir) in forward_layouts.iter().zip(bidir_layouts.iter()) {
        assert_eq!(from_fwd.offset, from_bidir.offset);
        assert_eq!(from_fwd.length, from_bidir.length);
    }
    assert_eq!(bidir_layouts.len(), 8);
}
/// On a one-message file opened bidirectionally, `scan_pipelined_async(None)`
/// must return `true` and leave exactly one layout recorded.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_pipelined_one_message_completes_without_fallback() {
    let backend = open_backend(make_message(vec![4], 42), bidir_opts())
        .await
        .expect("open");

    let completed = backend.scan_pipelined_async(None).await.expect("pipelined");
    assert!(completed, "1-message file: pipeline completes without fallback");

    let guard = backend.lock_state().expect("lock");
    assert_eq!(guard.layouts.len(), 1);
}
/// Replaces the last eight bytes of a two-message file with `BADMAGIC` and
/// checks that a bidirectional open still counts both messages.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_bidir_corrupt_end_magic_forward_still_scans() {
    let mut bytes = concat(vec![make_message(vec![4], 1), make_message(vec![8], 2)]);
    let end_magic_at = bytes.len() - 8;
    bytes[end_magic_at..].copy_from_slice(b"BADMAGIC");

    let backend = open_backend(bytes, bidir_opts()).await.expect("open");
    let found = backend.message_count_async().await.expect("count");
    assert_eq!(found, 2, "backward corruption => forward still scans");
}
/// Opens six concatenated streaming messages with bidirectional scanning and
/// checks that all six are counted.
// NOTE(review): `sync_bidir_footer_indexed_eager_populates` additionally
// asserts that at least one layout has both `global_metadata` and `index`
// set; this async variant only checks the count despite its name. Confirm
// whether the population check was left out on purpose.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn async_bidir_footer_indexed_eager_populates() {
    let mut parts: Vec<Vec<u8>> = Vec::new();
    for i in 0..6 {
        parts.push(make_streaming_message(vec![4 + i as u64], i as u8));
    }
    let backend = open_backend(concat(parts), bidir_opts())
        .await
        .expect("open");
    let found = backend.message_count_async().await.expect("count");
    assert_eq!(found, 6);
}
}