Skip to main content

vgi_rpc/
external.rs

1//! External-location batches.
2//!
3//! Large result / stream batches can be uploaded to an external object
4//! store (S3 / GCS / in-memory) and replaced on the wire with a
5//! zero-row "pointer" batch carrying only metadata:
6//!
7//!   `vgi_rpc.location`         → URL (HTTPS by default).
8//!   `vgi_rpc.location.sha256`  → lowercase hex SHA-256 of the raw IPC bytes.
9//!   `vgi_rpc.location.source`  → debug annotation (filled when resolved).
10//!
11//! The remote payload is an Arrow IPC stream containing one batch with the
12//! original data. When compression is set, that stream is zstd-encoded
13//! before upload — the hash is over the raw (uncompressed) IPC bytes so
14//! integrity checks remain stable across compression changes.
15//!
16//! The crate stays storage-agnostic: users register an [`ExternalStorage`]
17//! implementation and a [`Fetcher`] for resolution. The companion
18//! `vgi-rpc-s3` and `vgi-rpc-gcs` crates ship ready-made backends.
19
20use std::net::{IpAddr, ToSocketAddrs};
21use std::sync::Arc;
22
23use arrow_array::RecordBatch;
24use arrow_schema::{Schema, SchemaRef};
25use sha2::{Digest, Sha256};
26
27use crate::errors::{Result, RpcError};
28use crate::metadata::{LOCATION_FETCH_MS_KEY, LOCATION_KEY, LOCATION_SHA256_KEY};
29use crate::wire::{bytes_to_hex, empty_batch, md_get, write_one_batch, Metadata, StreamReader};
30
/// Body compression applied to an externalized payload before upload.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Compression {
    /// Upload the IPC bytes unchanged.
    None,
    /// Encode the IPC bytes with zstd at the given level (1..=22).
    Zstd(i32),
}
38
/// What an [`ExternalStorage`] backend hands back after a successful upload.
#[derive(Debug, Clone)]
pub struct UploadResult {
    /// URL the payload can be fetched from (typically a pre-signed URL).
    pub url: String,
    /// Lowercase hex SHA-256 of the **raw** (uncompressed) IPC bytes.
    pub sha256: String,
}
47
48/// Pluggable storage backend. Implementations upload Arrow IPC bytes and
49/// return a fetchable URL.
50///
51/// Kept synchronous to preserve compatibility with the current pipe/unix
52/// dispatch loop; async-only backends should offload via a blocking
53/// thread pool.
54pub trait ExternalStorage: Send + Sync {
55    fn upload(&self, ipc_bytes: &[u8], compression: Compression) -> Result<UploadResult>;
56}
57
/// Pre-signed URL pair for client-side data upload.
///
/// Mirrors `vgi_rpc.external.UploadUrl`. The expiry is carried in
/// [`UploadUrl::expires_at_micros`] as a Unix-epoch microseconds
/// timestamp (UTC).
#[derive(Debug, Clone)]
pub struct UploadUrl {
    /// URL of the pair used for uploading.
    pub upload_url: String,
    /// URL of the pair used for downloading.
    pub download_url: String,
    /// Expiration time as microseconds since the Unix epoch (UTC).
    pub expires_at_micros: i64,
}
69
70/// Generates pre-signed upload URL pairs for client-vended uploads.
71///
72/// Mirror of Python `vgi_rpc.external.UploadUrlProvider`. Implementations
73/// must be thread-safe — `generate_upload_url()` may be called concurrently
74/// from different request handlers.
75pub trait UploadUrlProvider: Send + Sync {
76    fn generate_upload_url(&self) -> Result<UploadUrl>;
77}
78
79/// Callback verifying an external location URL. Called on both the
80/// upload path (post-signing) and the fetch path (pre-download). Return
81/// `Err` to reject the URL with a typed [`RpcError`].
82pub type UrlValidator = Arc<dyn Fn(&str) -> Result<()> + Send + Sync>;
83
84/// Pluggable fetcher used to resolve pointer batches back into data.
85///
86/// Takes a URL (and the declared compression) and returns the still-encoded
87/// payload bytes. `vgi-rpc` ships an HTTPS fetcher; tests plug in an
88/// in-memory implementation.
89///
90/// `max_bytes` is a hard ceiling on the number of bytes the fetcher may
91/// read from the remote — implementations **must** abort once it is
92/// exceeded rather than buffering an unbounded response into memory. A
93/// hostile or compromised storage URL would otherwise OOM the process
94/// before decompression's own cap is ever reached.
95pub trait Fetcher: Send + Sync {
96    fn fetch(&self, url: &str, compression: Compression, max_bytes: usize) -> Result<Vec<u8>>;
97}
98
99/// Externalization configuration.
100#[derive(Clone)]
101pub struct ExternalLocationConfig {
102    /// Payload size in bytes at which a batch is eligible for
103    /// externalization. Smaller batches stay inline. Default 1 MiB.
104    pub threshold_bytes: usize,
105    /// Compression applied to the IPC bytes before upload. Default
106    /// `Compression::None`.
107    pub compression: Compression,
108    /// Upload backend. Required when externalization is enabled.
109    pub storage: Arc<dyn ExternalStorage>,
110    /// Resolver used on the read side. Required when resolving inbound
111    /// pointer batches.
112    pub fetcher: Arc<dyn Fetcher>,
113    /// URL validator run on both the upload (post-signing) and fetch
114    /// (pre-download) paths. Defaults to [`safe_https_validator`], which
115    /// rejects non-`https` URLs and internal/non-routable hosts. Reject
116    /// a URL by returning `Err`.
117    pub url_validator: UrlValidator,
118    /// Hard ceiling on the post-decompression size of a fetched
119    /// payload. Zstd frames carry their decompressed size in the
120    /// header and `zstd::decode_all` would otherwise trust it
121    /// eagerly — a small malicious payload claiming gigabytes of
122    /// output would OOM the client. Default 1 GiB. Set to `usize::MAX`
123    /// to disable.
124    pub max_decompressed_bytes: usize,
125}
126
127impl std::fmt::Debug for ExternalLocationConfig {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        f.debug_struct("ExternalLocationConfig")
130            .field("threshold_bytes", &self.threshold_bytes)
131            .field("compression", &self.compression)
132            .finish_non_exhaustive()
133    }
134}
135
136/// Helper: scheme-only HTTPS validator.
137///
138/// **Insufficient for untrusted input** — it accepts `https://localhost`,
139/// `https://169.254.169.254` (cloud metadata), and any internal-network
140/// host. Use it only when the set of external-location URLs is fully
141/// trusted (e.g. URLs your own backend minted). For anything that can
142/// carry a client-supplied location, use [`safe_https_validator`].
143pub fn https_only_validator() -> UrlValidator {
144    Arc::new(|url: &str| {
145        if url.starts_with("https://") {
146            Ok(())
147        } else {
148            Err(RpcError::value_error(format!(
149                "external location URL must be https:// ({url})"
150            )))
151        }
152    })
153}
154
155/// Helper: default-deny HTTPS validator with SSRF protection.
156///
157/// Requires the `https` scheme, then rejects the URL when its host is —
158/// or resolves to — a loopback, private, link-local, unique-local,
159/// carrier-grade-NAT, broadcast, documentation, or unspecified address.
160/// This is the default for [`ExternalLocationConfig::new`] because the
161/// unary HTTP path resolves a *client-supplied* `vgi_rpc.location`
162/// server-side; without this a client could pivot the server into
163/// fetching `https://169.254.169.254/...` or an internal service.
164///
165/// Note: a hostname is resolved here and again at fetch time, so a
166/// DNS-rebinding attacker could still slip through the gap. Pair this
167/// with a redirect-free, size-capped fetcher (the bundled `HttpFetcher`
168/// is both) and, for high-assurance deployments, an egress firewall.
169pub fn safe_https_validator() -> UrlValidator {
170    Arc::new(|raw: &str| {
171        let url = url::Url::parse(raw)
172            .map_err(|e| RpcError::value_error(format!("invalid external location URL: {e}")))?;
173        if url.scheme() != "https" {
174            return Err(RpcError::value_error(
175                "external location URL must be https://",
176            ));
177        }
178        let host = url
179            .host()
180            .ok_or_else(|| RpcError::value_error("external location URL has no host"))?;
181        match host {
182            url::Host::Ipv4(ip) => reject_unsafe_ip(IpAddr::V4(ip)),
183            url::Host::Ipv6(ip) => reject_unsafe_ip(IpAddr::V6(ip)),
184            url::Host::Domain(name) => {
185                let lname = name.to_ascii_lowercase();
186                if lname == "localhost" || lname.ends_with(".localhost") {
187                    return Err(RpcError::value_error(
188                        "external location host is not publicly routable",
189                    ));
190                }
191                let port = url.port_or_known_default().unwrap_or(443);
192                let addrs = (name, port).to_socket_addrs().map_err(|e| {
193                    RpcError::value_error(format!("external location host does not resolve: {e}"))
194                })?;
195                let mut saw_any = false;
196                for sa in addrs {
197                    saw_any = true;
198                    reject_unsafe_ip(sa.ip())?;
199                }
200                if !saw_any {
201                    return Err(RpcError::value_error(
202                        "external location host does not resolve",
203                    ));
204                }
205                Ok(())
206            }
207        }
208    })
209}
210
211/// Reject an IP address that is not safe for the server to dial: any
212/// loopback / private / link-local / unique-local / CGNAT / broadcast /
213/// documentation / unspecified / multicast address.
214fn reject_unsafe_ip(ip: IpAddr) -> Result<()> {
215    let unsafe_addr = match ip {
216        IpAddr::V4(v4) => {
217            let o = v4.octets();
218            v4.is_loopback()
219                || v4.is_private()
220                || v4.is_link_local()
221                || v4.is_unspecified()
222                || v4.is_broadcast()
223                || v4.is_multicast()
224                || v4.is_documentation()
225                // 100.64.0.0/10 — carrier-grade NAT.
226                || (o[0] == 100 && (o[1] & 0xc0) == 0x40)
227        }
228        IpAddr::V6(v6) => {
229            let seg0 = v6.segments()[0];
230            v6.is_loopback()
231                || v6.is_unspecified()
232                || v6.is_multicast()
233                // fc00::/7 — unique local.
234                || (seg0 & 0xfe00) == 0xfc00
235                // fe80::/10 — link local.
236                || (seg0 & 0xffc0) == 0xfe80
237                // IPv4-mapped (::ffff:0:0/96) — classify the embedded v4.
238                || v6
239                    .to_ipv4_mapped()
240                    .map(|m| reject_unsafe_ip(IpAddr::V4(m)).is_err())
241                    .unwrap_or(false)
242        }
243    };
244    if unsafe_addr {
245        return Err(RpcError::value_error(
246            "external location host resolves to a non-routable / internal address",
247        ));
248    }
249    Ok(())
250}
251
252/// Helper: accept any URL (useful for local tests + MinIO).
253pub fn any_url_validator() -> UrlValidator {
254    Arc::new(|_: &str| Ok(()))
255}
256
257impl ExternalLocationConfig {
258    pub fn new(storage: Arc<dyn ExternalStorage>, fetcher: Arc<dyn Fetcher>) -> Self {
259        Self {
260            threshold_bytes: 1024 * 1024,
261            compression: Compression::None,
262            storage,
263            fetcher,
264            url_validator: safe_https_validator(),
265            max_decompressed_bytes: 1024 * 1024 * 1024,
266        }
267    }
268
269    pub fn with_threshold_bytes(mut self, n: usize) -> Self {
270        self.threshold_bytes = n;
271        self
272    }
273
274    pub fn with_compression(mut self, c: Compression) -> Self {
275        self.compression = c;
276        self
277    }
278
279    pub fn with_url_validator(mut self, v: UrlValidator) -> Self {
280        self.url_validator = v;
281        self
282    }
283
284    /// Override the hard ceiling on post-decompression payload size.
285    /// Pass `usize::MAX` to disable. Default is 1 GiB.
286    pub fn with_max_decompressed_bytes(mut self, n: usize) -> Self {
287        self.max_decompressed_bytes = n;
288        self
289    }
290}
291
292// ---------------------------------------------------------------------------
293// Serialize a batch as an IPC stream with no custom metadata.
294// ---------------------------------------------------------------------------
295
296/// Serialize one record batch as a complete IPC stream (schema + batch + EOS).
297pub fn serialize_batch_to_ipc(batch: &RecordBatch) -> Result<Vec<u8>> {
298    // External payloads carry the raw data only; the pointer batch on
299    // the outside owns the metadata. Pass `None` to omit any
300    // `custom_metadata` field on the wire.
301    write_one_batch(batch, None)
302}
303
304/// Read back an IPC stream containing a single batch.
305/// Fetch, decompress, and integrity-check an external-location pointer's
306/// payload, returning the raw inner IPC stream bytes. The inner stream may
307/// contain **multiple** batches (e.g. a peer that externalizes a whole
308/// per-iteration output — logs followed by the data batch), so callers that
309/// need log/exception handling should process the returned bytes as a full
310/// response stream rather than assuming a single batch.
311///
312/// Returns `Ok(None)` when `metadata` carries no `vgi_rpc.location` pointer.
313pub fn fetch_external_ipc_bytes(
314    metadata: &Metadata,
315    cfg: &ExternalLocationConfig,
316) -> Result<Option<Vec<u8>>> {
317    let Some(url) = md_get(metadata, LOCATION_KEY) else {
318        return Ok(None);
319    };
320    (cfg.url_validator)(url)?;
321    let compressed = cfg
322        .fetcher
323        .fetch(url, cfg.compression, cfg.max_decompressed_bytes)?;
324    let ipc_bytes = decompress(&compressed, cfg.compression, cfg.max_decompressed_bytes)?;
325    if let Some(expected) = md_get(metadata, LOCATION_SHA256_KEY) {
326        let actual = sha256_hex(&ipc_bytes);
327        if expected != actual.as_str() {
328            return Err(RpcError::runtime_error(format!(
329                "external location SHA-256 mismatch (expected {expected}, got {actual})"
330            )));
331        }
332    }
333    Ok(Some(ipc_bytes))
334}
335
336pub fn deserialize_single_batch(ipc_bytes: &[u8]) -> Result<RecordBatch> {
337    Ok(deserialize_single_batch_with_metadata(ipc_bytes)?.0)
338}
339
340/// Like [`deserialize_single_batch`] but also returns the batch's per-message
341/// custom metadata — some peers carry keys (e.g. the stream-state token) on the
342/// externalized inner batch rather than the outer pointer.
343pub fn deserialize_single_batch_with_metadata(ipc_bytes: &[u8]) -> Result<(RecordBatch, Metadata)> {
344    let mut r = StreamReader::new(ipc_bytes)?;
345    r.read_next()?
346        .ok_or_else(|| RpcError::runtime_error("external batch stream is empty"))
347}
348
349fn sha256_hex(bytes: &[u8]) -> String {
350    bytes_to_hex(&Sha256::digest(bytes))
351}
352
353fn compress(ipc_bytes: &[u8], compression: Compression) -> Result<Vec<u8>> {
354    match compression {
355        Compression::None => Ok(ipc_bytes.to_vec()),
356        Compression::Zstd(level) => {
357            // Use the bulk API and explicitly include the decompressed
358            // size in the frame header. Python's `zstandard.ZstdDecompressor`
359            // requires `Content-Size` to be present when decompressing
360            // a single frame in one shot.
361            let mut enc = zstd::bulk::Compressor::new(level)
362                .map_err(|e| RpcError::runtime_error(format!("zstd encoder: {e}")))?;
363            enc.set_parameter(zstd::stream::raw::CParameter::ContentSizeFlag(true))
364                .map_err(|e| RpcError::runtime_error(format!("zstd contentsize: {e}")))?;
365            enc.compress(ipc_bytes)
366                .map_err(|e| RpcError::runtime_error(format!("zstd encode: {e}")))
367        }
368    }
369}
370
371fn decompress(bytes: &[u8], compression: Compression, max_size: usize) -> Result<Vec<u8>> {
372    match compression {
373        Compression::None => {
374            if bytes.len() > max_size {
375                return Err(RpcError::runtime_error(format!(
376                    "external payload {} bytes exceeds max_decompressed_bytes={max_size}",
377                    bytes.len()
378                )));
379            }
380            Ok(bytes.to_vec())
381        }
382        // Stream-decode and stop if we exceed the cap. Avoids trusting
383        // the zstd frame header's declared decompressed size (which
384        // `decode_all` would otherwise allocate eagerly), blocking a
385        // remote OOM via a tiny payload claiming gigabytes of output.
386        Compression::Zstd(_) => {
387            use std::io::Read;
388            let mut decoder = zstd::Decoder::new(bytes)
389                .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
390            let mut out = Vec::new();
391            let mut buf = [0u8; 64 * 1024];
392            loop {
393                let n = decoder
394                    .read(&mut buf)
395                    .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
396                if n == 0 {
397                    break;
398                }
399                if out.len() + n > max_size {
400                    return Err(RpcError::runtime_error(format!(
401                        "zstd decode: output exceeds max_decompressed_bytes={max_size}"
402                    )));
403                }
404                out.extend_from_slice(&buf[..n]);
405            }
406            Ok(out)
407        }
408    }
409}
410
411// ---------------------------------------------------------------------------
412// Server-side: externalize large batches
413// ---------------------------------------------------------------------------
414
415/// Pointer-batch schema — a zero-field, zero-row batch. The Python
416/// canonical externalizes into an empty-schema batch with location
417/// metadata; we match that so the on-wire bytes are identical.
418pub fn pointer_schema() -> SchemaRef {
419    Arc::new(Schema::empty())
420}
421
422/// Decide whether to externalize `batch`; return the pointer (zero-row)
423/// batch + pointer metadata when yes, else `None`. The original batch is
424/// left untouched so the caller can emit it inline.
425///
426/// `inline_metadata` — optional custom metadata the caller wants to
427/// attach alongside the location keys (merged; location keys win).
428pub fn maybe_externalize_batch(
429    batch: &RecordBatch,
430    inline_metadata: Option<&Metadata>,
431    cfg: &ExternalLocationConfig,
432) -> Result<Option<(RecordBatch, Metadata)>> {
433    if batch.num_rows() == 0 {
434        return Ok(None);
435    }
436    let ipc_bytes = serialize_batch_to_ipc(batch)?;
437    if ipc_bytes.len() < cfg.threshold_bytes {
438        return Ok(None);
439    }
440    let sha = sha256_hex(&ipc_bytes);
441    let payload = compress(&ipc_bytes, cfg.compression)?;
442    let upload = cfg.storage.upload(&payload, cfg.compression)?;
443    // Validator runs over the final URL.
444    (cfg.url_validator)(&upload.url)?;
445
446    // Build pointer metadata, merging the caller-supplied metadata first.
447    let mut md: Metadata = inline_metadata.cloned().unwrap_or_default();
448    md.remove(LOCATION_KEY);
449    md.remove(LOCATION_SHA256_KEY);
450    md.remove(LOCATION_FETCH_MS_KEY);
451    md.insert(LOCATION_KEY.to_string(), upload.url);
452    md.insert(LOCATION_SHA256_KEY.to_string(), sha);
453
454    // Pointer batch: zero-row but matching the input batch's schema,
455    // matching Python's `make_external_location_batch` shape so the
456    // client's IPC reader sees a consistent column count.
457    let ptr = empty_batch(batch.schema().as_ref())?;
458    Ok(Some((ptr, md)))
459}
460
461// ---------------------------------------------------------------------------
462// Client-side: resolve pointer batches
463// ---------------------------------------------------------------------------
464
465/// Resolve a pointer batch (zero-row batch with `vgi_rpc.location`
466/// metadata) back into the original record batch. Non-pointer batches
467/// are returned untouched.
468///
469/// Returns `(resolved_batch, user_metadata)` where the location keys
470/// have been stripped from the metadata visible to the caller. A
471/// `vgi_rpc.location.fetch_ms` claim is appended so callers / access
472/// logs can observe the fetch latency.
473pub fn resolve_external_location(
474    batch: &RecordBatch,
475    metadata: &Metadata,
476    cfg: &ExternalLocationConfig,
477) -> Result<(RecordBatch, Metadata)> {
478    let Some(url) = md_get(metadata, LOCATION_KEY) else {
479        return Ok((batch.clone(), metadata.clone()));
480    };
481    (cfg.url_validator)(url)?;
482
483    let start = std::time::Instant::now();
484    // Cap the fetched (still-encoded) payload at `max_decompressed_bytes`:
485    // any well-formed compressed body is smaller than its decompressed
486    // form, so this is a safe ceiling that also bounds the uncompressed
487    // case. `decompress` enforces the post-decompression cap on top.
488    let compressed = cfg
489        .fetcher
490        .fetch(url, cfg.compression, cfg.max_decompressed_bytes)?;
491    let ipc_bytes = decompress(&compressed, cfg.compression, cfg.max_decompressed_bytes)?;
492
493    // Integrity check.
494    if let Some(expected) = md_get(metadata, LOCATION_SHA256_KEY) {
495        let actual = sha256_hex(&ipc_bytes);
496        if expected != actual.as_str() {
497            return Err(RpcError::runtime_error(format!(
498                "external location SHA-256 mismatch (expected {expected}, got {actual})"
499            )));
500        }
501    }
502    let (resolved, inner_md) = deserialize_single_batch_with_metadata(&ipc_bytes)?;
503    let fetch_ms = start.elapsed().as_secs_f64() * 1000.0;
504
505    // Start from the outer pointer's non-location keys, then overlay the inner
506    // (externalized) batch's metadata. Implementations differ on where they
507    // carry per-batch keys like `vgi_rpc.stream_state#b64`: the Rust server
508    // stamps them on the outer pointer, the Python server on the inner payload
509    // batch. Merging both (inner wins) recovers the token either way and
510    // matches Python's resolver, which uses the inner batch's metadata.
511    let mut user_md: Metadata = metadata
512        .iter()
513        .filter(|(k, _)| {
514            *k != LOCATION_KEY && *k != LOCATION_SHA256_KEY && *k != LOCATION_FETCH_MS_KEY
515        })
516        .map(|(k, v)| (k.clone(), v.clone()))
517        .collect();
518    for (k, v) in inner_md {
519        if k != LOCATION_KEY && k != LOCATION_SHA256_KEY && k != LOCATION_FETCH_MS_KEY {
520            user_md.insert(k, v);
521        }
522    }
523    user_md.insert(
524        LOCATION_FETCH_MS_KEY.to_string(),
525        format!("{:.2}", fetch_ms),
526    );
527    Ok((resolved, user_md))
528}
529
530// ---------------------------------------------------------------------------
531// In-memory test backend
532// ---------------------------------------------------------------------------
533//
534// Gated behind the `test-utils` feature so it doesn't show up on
535// crates.io / docs.rs for normal users. Internal tests + the
536// `external_integration` integration test enable the feature
537// transitively via the workspace.
538
/// Combined in-memory storage backend and fetcher, used by tests and CI.
/// Thread-safe and free of I/O. **Not for production use.**
#[cfg(any(test, feature = "test-utils"))]
pub struct InMemoryStorage {
    /// Stored payloads keyed by the URL handed out at upload time.
    map: std::sync::Mutex<std::collections::HashMap<String, Vec<u8>>>,
    /// Counter from which the next payload URL is derived.
    next_id: std::sync::atomic::AtomicU64,
    /// Prefix every generated URL starts with.
    base_url: String,
}
547
#[cfg(any(test, feature = "test-utils"))]
impl InMemoryStorage {
    /// Create an empty store whose URLs start at
    /// `https://inmem.test/` with ids counting up from 1.
    pub fn new() -> Arc<Self> {
        let store = Self {
            map: Default::default(),
            next_id: std::sync::atomic::AtomicU64::new(1),
            base_url: String::from("https://inmem.test/"),
        };
        Arc::new(store)
    }

    /// Number of payloads currently stored.
    pub fn len(&self) -> usize {
        let stored = self.map.lock().unwrap();
        stored.len()
    }

    /// True when nothing has been uploaded yet.
    pub fn is_empty(&self) -> bool {
        self.map.lock().unwrap().is_empty()
    }
}
566
#[cfg(any(test, feature = "test-utils"))]
impl ExternalStorage for InMemoryStorage {
    /// Store a copy of `ipc_bytes` under a freshly generated URL.
    fn upload(&self, ipc_bytes: &[u8], _compression: Compression) -> Result<UploadResult> {
        use std::sync::atomic::Ordering;

        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let url = format!("{}{:016x}", self.base_url, id);
        // The backend only ever sees the bytes it is asked to store,
        // which are already compressed when compression is enabled, so
        // this digest covers the *stored* bytes. The raw-IPC hash that
        // ends up in the pointer metadata is computed by
        // `maybe_externalize_batch`, which does not read this field.
        let sha256 = sha256_hex(ipc_bytes);
        self.map
            .lock()
            .unwrap()
            .insert(url.clone(), ipc_bytes.to_vec());
        Ok(UploadResult { url, sha256 })
    }
}
589
#[cfg(any(test, feature = "test-utils"))]
impl Fetcher for InMemoryStorage {
    /// Return a copy of the payload stored under `url`.
    ///
    /// Fails with an "inmem fetch miss" error for unknown URLs and
    /// refuses payloads larger than `max_bytes`.
    fn fetch(&self, url: &str, _compression: Compression, max_bytes: usize) -> Result<Vec<u8>> {
        let stored = self.map.lock().unwrap();
        let Some(payload) = stored.get(url) else {
            return Err(RpcError::runtime_error(format!("inmem fetch miss: {url}")));
        };
        if payload.len() > max_bytes {
            return Err(RpcError::runtime_error(format!(
                "inmem fetch payload {} bytes exceeds max_bytes={max_bytes}",
                payload.len()
            )));
        }
        Ok(payload.clone())
    }
}
609
#[cfg(test)]
mod tests {
    use arrow_array::{Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Single non-nullable `Int64` column holding `0..rows`.
    fn big_batch(rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let col: Arc<dyn arrow_array::Array> =
            Arc::new(Int64Array::from((0..rows as i64).collect::<Vec<_>>()));
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// Config backed by `storage` for both upload and fetch, with URL
    /// validation disabled so the in-memory host needs no DNS.
    fn cfg_with(storage: Arc<InMemoryStorage>, threshold: usize) -> ExternalLocationConfig {
        let s: Arc<dyn ExternalStorage> = storage.clone();
        let f: Arc<dyn Fetcher> = storage;
        ExternalLocationConfig::new(s, f)
            .with_threshold_bytes(threshold)
            .with_url_validator(any_url_validator())
    }

    /// Pointer metadata pointing at a plaintext (non-https) URL.
    fn plaintext_pointer() -> Metadata {
        let mut md = Metadata::new();
        md.insert(LOCATION_KEY.to_string(), "http://not-secure/x".into());
        md
    }

    #[test]
    fn small_batch_stays_inline() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024 * 1024);
        let batch = big_batch(10);
        let out = maybe_externalize_batch(&batch, None, &cfg).unwrap();
        assert!(out.is_none());
        assert!(storage.is_empty());
    }

    #[test]
    fn large_batch_externalizes_and_round_trips() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024);
        let batch = big_batch(50_000);

        let (ptr, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        assert_eq!(ptr.num_rows(), 0);
        // Pointer batch carries the original schema (zero-row); cross-language
        // clients expect the column count to match the result schema.
        assert_eq!(ptr.schema().fields().len(), batch.schema().fields().len());
        assert!(md_get(&md, LOCATION_KEY).unwrap().starts_with("https://"));
        assert_eq!(storage.len(), 1);

        let (resolved, user_md) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert_eq!(resolved.num_rows(), batch.num_rows());
        assert!(md_get(&user_md, LOCATION_KEY).is_none());
        assert!(md_get(&user_md, LOCATION_FETCH_MS_KEY).is_some());
    }

    #[test]
    fn zstd_compression_round_trip() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024).with_compression(Compression::Zstd(3));
        let batch = big_batch(20_000);
        let (ptr, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        let (resolved, _) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert_eq!(resolved.num_rows(), batch.num_rows());
    }

    #[test]
    fn https_only_validator_rejects_plaintext() {
        // Install the scheme-only validator explicitly so this test
        // covers the helper it is named after.
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 0).with_url_validator(https_only_validator());
        let batch = big_batch(1);
        let err = resolve_external_location(&batch, &plaintext_pointer(), &cfg).unwrap_err();
        assert!(err.message.contains("https://"));
    }

    #[test]
    fn default_validator_rejects_plaintext() {
        // `ExternalLocationConfig::new` installs `safe_https_validator`.
        let storage = InMemoryStorage::new();
        let s: Arc<dyn ExternalStorage> = storage.clone();
        let f: Arc<dyn Fetcher> = storage;
        let cfg = ExternalLocationConfig::new(s, f).with_threshold_bytes(0);
        // In-memory URL is https, so build a forged metadata entry.
        let batch = big_batch(1);
        let err = resolve_external_location(&batch, &plaintext_pointer(), &cfg).unwrap_err();
        assert!(err.message.contains("https://"));
    }

    #[test]
    fn safe_https_validator_blocks_ssrf_targets() {
        let v = safe_https_validator();
        // Non-https.
        assert!(v("http://example.com/x").is_err());
        // IP-literal internal / non-routable targets.
        assert!(v("https://169.254.169.254/latest/meta-data/").is_err());
        assert!(v("https://127.0.0.1/").is_err());
        assert!(v("https://10.0.0.1/").is_err());
        assert!(v("https://192.168.1.1/").is_err());
        assert!(v("https://[::1]/").is_err());
        assert!(v("https://0.0.0.0/").is_err());
        // Hostname forms of loopback.
        assert!(v("https://localhost/x").is_err());
        assert!(v("https://api.localhost/x").is_err());
        // A public IP literal is allowed.
        assert!(v("https://1.1.1.1/x").is_ok());
    }

    #[test]
    fn sha_mismatch_is_rejected() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024);
        let batch = big_batch(10_000);
        let (ptr, mut md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        // Corrupt the recorded hash.
        if let Some(v) = md.get_mut(LOCATION_SHA256_KEY) {
            *v = "deadbeef".into();
        }
        let err = resolve_external_location(&ptr, &md, &cfg).unwrap_err();
        assert!(err.message.contains("SHA-256 mismatch"));
    }
}