use std::net::{IpAddr, ToSocketAddrs};
use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_schema::{Schema, SchemaRef};
use sha2::{Digest, Sha256};
use crate::errors::{Result, RpcError};
use crate::metadata::{LOCATION_FETCH_MS_KEY, LOCATION_KEY, LOCATION_SHA256_KEY};
use crate::wire::{bytes_to_hex, empty_batch, md_get, write_one_batch, Metadata, StreamReader};
/// Compression applied to an externalized IPC payload before it is uploaded.
///
/// Nothing in the pointer metadata records which variant was used: the resolving
/// side decompresses according to its own configured `Compression`, so producer
/// and consumer must agree on this setting.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Compression {
    /// The payload is the raw Arrow IPC stream bytes.
    None,
    /// The payload is zstd-compressed; the value is the zstd compression level.
    Zstd(i32),
}
/// What an [`ExternalStorage`] backend reports after storing a payload.
#[derive(Clone, Debug)]
pub struct UploadResult {
    /// URL from which the payload can later be fetched. `maybe_externalize_batch`
    /// runs it through the configured `UrlValidator` before recording it.
    pub url: String,
    /// Hex SHA-256 digest reported by the backend. The in-memory test backend
    /// hashes exactly the bytes it was handed (the possibly-compressed payload).
    /// `maybe_externalize_batch` does not read this field; it records its own
    /// digest of the uncompressed IPC bytes instead.
    pub sha256: String,
}
/// Backend that persists externalized batch payloads and returns a fetchable URL.
pub trait ExternalStorage: Send + Sync {
    /// Stores a payload and reports where it can be fetched from.
    ///
    /// Despite the parameter name, `maybe_externalize_batch` passes the bytes
    /// *after* `compression` has been applied, so implementations receive
    /// zstd-compressed data when `compression` is `Zstd`. How (or whether) a
    /// backend uses `compression` is implementation-defined; the in-memory test
    /// backend ignores it.
    fn upload(&self, ipc_bytes: &[u8], compression: Compression) -> Result<UploadResult>;
}
/// A pair of URLs for uploading one object and later downloading it.
///
/// NOTE(review): not referenced elsewhere in this part of the file; presumably
/// handed to clients that upload payloads themselves (e.g. presigned URLs).
#[derive(Clone, Debug)]
pub struct UploadUrl {
    /// URL the payload is written to.
    pub upload_url: String,
    /// URL readers fetch the payload from.
    pub download_url: String,
    /// Expiry of the URLs; by its name presumably microseconds since the Unix
    /// epoch — TODO confirm against the producer.
    pub expires_at_micros: i64,
}
/// Source of fresh [`UploadUrl`] pairs.
///
/// NOTE(review): no implementation or caller is visible here; presumably backed
/// by an object store that can mint time-limited URLs.
pub trait UploadUrlProvider: Send + Sync {
    /// Generates a new upload/download URL pair.
    fn generate_upload_url(&self) -> Result<UploadUrl>;
}
/// Predicate applied to external-location URLs: return `Ok(())` to accept, or an
/// error to reject. It runs both when a pointer is produced
/// (`maybe_externalize_batch`) and before a pointer is fetched
/// (`resolve_external_location`).
pub type UrlValidator = Arc<dyn Fn(&str) -> Result<()> + Send + Sync>;
/// Backend that downloads an externalized payload by URL.
pub trait Fetcher: Send + Sync {
    /// Returns the stored (possibly compressed) bytes at `url`.
    ///
    /// `resolve_external_location` passes its configured `compression` and its
    /// `max_decompressed_bytes` as `max_bytes`. The in-memory test backend rejects
    /// payloads larger than `max_bytes`; other backends presumably should too,
    /// rather than buffering an unbounded response.
    fn fetch(&self, url: &str, compression: Compression, max_bytes: usize) -> Result<Vec<u8>>;
}
/// Settings that control when batches are moved to external storage and how
/// external-location pointers are resolved back into batches.
#[derive(Clone)]
pub struct ExternalLocationConfig {
    /// Minimum serialized IPC size, in bytes, at which a batch is externalized;
    /// smaller batches stay inline.
    pub threshold_bytes: usize,
    /// Compression applied before upload and assumed when decoding a fetch.
    pub compression: Compression,
    /// Where externalized payloads are uploaded.
    pub storage: Arc<dyn ExternalStorage>,
    /// How externalized payloads are downloaded.
    pub fetcher: Arc<dyn Fetcher>,
    /// Check applied to every external-location URL that is produced or consumed.
    pub url_validator: UrlValidator,
    /// Upper bound, in bytes, on both the fetched payload and its decompressed size.
    pub max_decompressed_bytes: usize,
}
impl std::fmt::Debug for ExternalLocationConfig {
    /// Formats every plain-data setting.
    ///
    /// `storage`, `fetcher` and `url_validator` are trait objects without a
    /// `Debug` bound, so they are elided and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExternalLocationConfig")
            .field("threshold_bytes", &self.threshold_bytes)
            .field("compression", &self.compression)
            // Previously omitted, although it is a plain usize that matters when
            // diagnosing "exceeds max_decompressed_bytes" failures.
            .field("max_decompressed_bytes", &self.max_decompressed_bytes)
            .finish_non_exhaustive()
    }
}
pub fn https_only_validator() -> UrlValidator {
Arc::new(|url: &str| {
if url.starts_with("https://") {
Ok(())
} else {
Err(RpcError::value_error(format!(
"external location URL must be https:// ({url})"
)))
}
})
}
pub fn safe_https_validator() -> UrlValidator {
Arc::new(|raw: &str| {
let url = url::Url::parse(raw)
.map_err(|e| RpcError::value_error(format!("invalid external location URL: {e}")))?;
if url.scheme() != "https" {
return Err(RpcError::value_error(
"external location URL must be https://",
));
}
let host = url
.host()
.ok_or_else(|| RpcError::value_error("external location URL has no host"))?;
match host {
url::Host::Ipv4(ip) => reject_unsafe_ip(IpAddr::V4(ip)),
url::Host::Ipv6(ip) => reject_unsafe_ip(IpAddr::V6(ip)),
url::Host::Domain(name) => {
let lname = name.to_ascii_lowercase();
if lname == "localhost" || lname.ends_with(".localhost") {
return Err(RpcError::value_error(
"external location host is not publicly routable",
));
}
let port = url.port_or_known_default().unwrap_or(443);
let addrs = (name, port).to_socket_addrs().map_err(|e| {
RpcError::value_error(format!("external location host does not resolve: {e}"))
})?;
let mut saw_any = false;
for sa in addrs {
saw_any = true;
reject_unsafe_ip(sa.ip())?;
}
if !saw_any {
return Err(RpcError::value_error(
"external location host does not resolve",
));
}
Ok(())
}
}
})
}
/// Rejects addresses that an external-location fetch must never connect to.
///
/// Denied IPv4 ranges:
/// - loopback, RFC 1918 private and link-local (which covers `169.254.169.254`);
/// - `0.0.0.0/8`, broadcast, multicast and documentation;
/// - `100.64.0.0/10` shared/CGNAT space;
/// - `192.0.0.0/24` IETF protocol assignments;
/// - `198.18.0.0/15` benchmarking;
/// - `240.0.0.0/4` reserved.
///
/// Denied IPv6 ranges:
/// - loopback, unspecified and multicast;
/// - `fc00::/7` unique-local and `fe80::/10` link-local;
/// - `fec0::/10` deprecated site-local;
/// - `2001:db8::/32` documentation.
///
/// An IPv6 address that embeds an IPv4 address is also denied when the embedded
/// address is denied. This covers IPv4-mapped, IPv4-compatible, NAT64
/// `64:ff9b::/96` and 6to4 `2002::/16` forms, so an internal IPv4 target cannot be
/// smuggled through an IPv6 literal or AAAA record.
///
/// # Errors
///
/// Returns a value error when `ip` falls in any denied range.
fn reject_unsafe_ip(ip: IpAddr) -> Result<()> {
    /// `true` when an IPv4 address must not be fetched from.
    fn is_unsafe_v4(v4: std::net::Ipv4Addr) -> bool {
        let o = v4.octets();
        v4.is_loopback()
            || v4.is_private()
            || v4.is_link_local()
            || v4.is_broadcast()
            || v4.is_multicast()
            || v4.is_documentation()
            // 0.0.0.0/8 "this network" (RFC 1122); includes the unspecified address.
            || o[0] == 0
            // 100.64.0.0/10 shared address space / CGNAT (RFC 6598).
            || (o[0] == 100 && (o[1] & 0xc0) == 0x40)
            // 192.0.0.0/24 IETF protocol assignments (RFC 6890).
            || (o[0] == 192 && o[1] == 0 && o[2] == 0)
            // 198.18.0.0/15 benchmarking (RFC 2544).
            || (o[0] == 198 && (o[1] & 0xfe) == 18)
            // 240.0.0.0/4 reserved; also covers 255.255.255.255.
            || o[0] >= 240
    }

    /// Extracts an IPv4 address carried inside well-known IPv6 transition formats.
    fn embedded_v4(v6: std::net::Ipv6Addr) -> Option<std::net::Ipv4Addr> {
        let seg = v6.segments();
        let join = |hi: u16, lo: u16| {
            std::net::Ipv4Addr::from((u32::from(hi) << 16) | u32::from(lo))
        };
        if let Some(v4) = v6.to_ipv4() {
            // ::ffff:a.b.c.d (mapped) and ::a.b.c.d (deprecated compatible).
            Some(v4)
        } else if seg[0] == 0x0064 && seg[1] == 0xff9b && seg[2..6].iter().all(|&s| s == 0) {
            // 64:ff9b::/96 NAT64 well-known prefix (RFC 6052).
            Some(join(seg[6], seg[7]))
        } else if seg[0] == 0x2002 {
            // 2002::/16 6to4 (RFC 3056): IPv4 address in bits 16..48.
            Some(join(seg[1], seg[2]))
        } else {
            None
        }
    }

    /// `true` when an IPv6 address must not be fetched from.
    fn is_unsafe_v6(v6: std::net::Ipv6Addr) -> bool {
        let seg0 = v6.segments()[0];
        let seg1 = v6.segments()[1];
        v6.is_loopback()
            || v6.is_unspecified()
            || v6.is_multicast()
            // fc00::/7 unique local.
            || (seg0 & 0xfe00) == 0xfc00
            // fe80::/10 link-local.
            || (seg0 & 0xffc0) == 0xfe80
            // fec0::/10 deprecated site-local; some stacks still route it.
            || (seg0 & 0xffc0) == 0xfec0
            // 2001:db8::/32 documentation.
            || (seg0 == 0x2001 && seg1 == 0x0db8)
            || embedded_v4(v6).map(is_unsafe_v4).unwrap_or(false)
    }

    let unsafe_addr = match ip {
        IpAddr::V4(v4) => is_unsafe_v4(v4),
        IpAddr::V6(v6) => is_unsafe_v6(v6),
    };
    if unsafe_addr {
        return Err(RpcError::value_error(
            "external location host resolves to a non-routable / internal address",
        ));
    }
    Ok(())
}
pub fn any_url_validator() -> UrlValidator {
Arc::new(|_: &str| Ok(()))
}
impl ExternalLocationConfig {
    /// Builds a config around the given storage and fetch backends.
    ///
    /// Defaults:
    /// - batches whose IPC encoding is at least 1 MiB are externalized;
    /// - no compression;
    /// - URLs are checked with [`safe_https_validator`];
    /// - fetched and decompressed payloads are capped at 1 GiB.
    #[must_use]
    pub fn new(storage: Arc<dyn ExternalStorage>, fetcher: Arc<dyn Fetcher>) -> Self {
        Self {
            threshold_bytes: 1024 * 1024,
            compression: Compression::None,
            storage,
            fetcher,
            url_validator: safe_https_validator(),
            max_decompressed_bytes: 1024 * 1024 * 1024,
        }
    }

    /// Sets the minimum serialized IPC size, in bytes, at which a batch is
    /// uploaded instead of sent inline.
    #[must_use]
    pub fn with_threshold_bytes(mut self, n: usize) -> Self {
        self.threshold_bytes = n;
        self
    }

    /// Sets the compression used for uploads and assumed when decoding fetches.
    #[must_use]
    pub fn with_compression(mut self, c: Compression) -> Self {
        self.compression = c;
        self
    }

    /// Replaces the URL validator applied on both upload and fetch.
    #[must_use]
    pub fn with_url_validator(mut self, v: UrlValidator) -> Self {
        self.url_validator = v;
        self
    }

    /// Sets the cap, in bytes, on both the fetched payload and its decompressed size.
    #[must_use]
    pub fn with_max_decompressed_bytes(mut self, n: usize) -> Self {
        self.max_decompressed_bytes = n;
        self
    }
}
/// Encodes `batch` as a standalone Arrow IPC stream using the crate's
/// `write_one_batch`.
///
/// The second argument is `None`; presumably this means no custom per-batch
/// metadata is attached — TODO confirm against `write_one_batch`.
pub fn serialize_batch_to_ipc(batch: &RecordBatch) -> Result<Vec<u8>> {
    let encoded = write_one_batch(batch, None)?;
    Ok(encoded)
}
/// Decodes the first record batch from an Arrow IPC stream.
///
/// Any per-batch metadata returned by the reader is discarded, and any batches
/// after the first are not read.
///
/// # Errors
///
/// Propagates reader errors, and returns a runtime error when the stream contains
/// no batch at all.
pub fn deserialize_single_batch(ipc_bytes: &[u8]) -> Result<RecordBatch> {
    let mut reader = StreamReader::new(ipc_bytes)?;
    match reader.read_next()? {
        Some((first, _batch_metadata)) => Ok(first),
        None => Err(RpcError::runtime_error("external batch stream is empty")),
    }
}
/// Hashes `bytes` with SHA-256 and renders the digest via `bytes_to_hex`.
fn sha256_hex(bytes: &[u8]) -> String {
    let digest = Sha256::digest(bytes);
    bytes_to_hex(&digest)
}
/// Encodes an IPC payload for upload according to `compression`.
///
/// `None` returns an owned copy of the input. `Zstd(level)` compresses at that
/// level with the zstd content-size flag enabled.
///
/// # Errors
///
/// Returns a runtime error if the zstd encoder cannot be created or configured,
/// or if encoding fails.
fn compress(ipc_bytes: &[u8], compression: Compression) -> Result<Vec<u8>> {
    let level = match compression {
        Compression::None => return Ok(ipc_bytes.to_vec()),
        Compression::Zstd(level) => level,
    };
    let mut encoder = zstd::bulk::Compressor::new(level)
        .map_err(|e| RpcError::runtime_error(format!("zstd encoder: {e}")))?;
    let content_size_flag = zstd::stream::raw::CParameter::ContentSizeFlag(true);
    encoder
        .set_parameter(content_size_flag)
        .map_err(|e| RpcError::runtime_error(format!("zstd contentsize: {e}")))?;
    encoder
        .compress(ipc_bytes)
        .map_err(|e| RpcError::runtime_error(format!("zstd encode: {e}")))
}
/// Reverses [`compress`], refusing to produce more than `max_size` bytes.
///
/// For `None` the input is copied after a size check. For `Zstd` the stream is
/// decoded in 64 KiB chunks, and decoding stops with an error as soon as the
/// accumulated output would exceed `max_size`. A decompression bomb therefore
/// cannot force an unbounded allocation.
///
/// # Errors
///
/// Returns a runtime error when the size cap is exceeded or the zstd stream is
/// malformed.
fn decompress(bytes: &[u8], compression: Compression, max_size: usize) -> Result<Vec<u8>> {
    use std::io::Read;
    match compression {
        Compression::None if bytes.len() > max_size => Err(RpcError::runtime_error(format!(
            "external payload {} bytes exceeds max_decompressed_bytes={max_size}",
            bytes.len()
        ))),
        Compression::None => Ok(bytes.to_vec()),
        Compression::Zstd(_) => {
            let mut decoder = zstd::Decoder::new(bytes)
                .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
            let mut decoded = Vec::new();
            let mut chunk = [0u8; 64 * 1024];
            loop {
                let got = decoder
                    .read(&mut chunk)
                    .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
                if got == 0 {
                    return Ok(decoded);
                }
                // Check before appending so `decoded` never grows past the cap.
                if decoded.len() + got > max_size {
                    return Err(RpcError::runtime_error(format!(
                        "zstd decode: output exceeds max_decompressed_bytes={max_size}"
                    )));
                }
                decoded.extend_from_slice(&chunk[..got]);
            }
        }
    }
}
pub fn pointer_schema() -> SchemaRef {
Arc::new(Schema::empty())
}
pub fn maybe_externalize_batch(
batch: &RecordBatch,
inline_metadata: Option<&Metadata>,
cfg: &ExternalLocationConfig,
) -> Result<Option<(RecordBatch, Metadata)>> {
if batch.num_rows() == 0 {
return Ok(None);
}
let ipc_bytes = serialize_batch_to_ipc(batch)?;
if ipc_bytes.len() < cfg.threshold_bytes {
return Ok(None);
}
let sha = sha256_hex(&ipc_bytes);
let payload = compress(&ipc_bytes, cfg.compression)?;
let upload = cfg.storage.upload(&payload, cfg.compression)?;
(cfg.url_validator)(&upload.url)?;
let mut md: Metadata = inline_metadata.cloned().unwrap_or_default();
md.remove(LOCATION_KEY);
md.remove(LOCATION_SHA256_KEY);
md.remove(LOCATION_FETCH_MS_KEY);
md.insert(LOCATION_KEY.to_string(), upload.url);
md.insert(LOCATION_SHA256_KEY.to_string(), sha);
let ptr = empty_batch(batch.schema().as_ref())?;
Ok(Some((ptr, md)))
}
/// Resolves an external-location pointer back into the batch it refers to.
///
/// If `metadata` has no location key, `batch` and `metadata` are returned
/// unchanged (cloned). Otherwise:
/// 1. the URL is validated with `cfg.url_validator`;
/// 2. the payload is fetched via `cfg.fetcher`, capped at `cfg.max_decompressed_bytes`;
/// 3. it is decompressed per `cfg.compression`;
/// 4. it is checked against the recorded SHA-256 digest, when one is present;
/// 5. it is decoded as a single-batch IPC stream.
///
/// The pointer `batch` itself is not used in that case.
///
/// The returned metadata is `metadata` without the location, digest and
/// fetch-time keys, plus a fresh fetch-time entry. That entry holds the elapsed
/// milliseconds (two decimals) from just before the fetch through IPC decoding.
///
/// # Errors
///
/// Propagates validation, fetch, decompression and decode errors. Returns a
/// runtime error when the digest does not match.
pub fn resolve_external_location(
    batch: &RecordBatch,
    metadata: &Metadata,
    cfg: &ExternalLocationConfig,
) -> Result<(RecordBatch, Metadata)> {
    let Some(url) = md_get(metadata, LOCATION_KEY) else {
        return Ok((batch.clone(), metadata.clone()));
    };
    // Validate before any network access: the URL may come from an untrusted peer.
    (cfg.url_validator)(url)?;
    let start = std::time::Instant::now();
    let compressed = cfg
        .fetcher
        .fetch(url, cfg.compression, cfg.max_decompressed_bytes)?;
    let ipc_bytes = decompress(&compressed, cfg.compression, cfg.max_decompressed_bytes)?;
    if let Some(expected) = md_get(metadata, LOCATION_SHA256_KEY) {
        let actual = sha256_hex(&ipc_bytes);
        // Hex digits are case-insensitive; accept digests written in either case
        // so producers that emit uppercase hex are not spuriously rejected.
        if !expected.eq_ignore_ascii_case(&actual) {
            return Err(RpcError::runtime_error(format!(
                "external location SHA-256 mismatch (expected {expected}, got {actual})"
            )));
        }
    }
    let resolved = deserialize_single_batch(&ipc_bytes)?;
    let fetch_ms = start.elapsed().as_secs_f64() * 1000.0;
    let mut user_md: Metadata = metadata
        .iter()
        .filter(|(k, _)| {
            *k != LOCATION_KEY && *k != LOCATION_SHA256_KEY && *k != LOCATION_FETCH_MS_KEY
        })
        .map(|(k, v)| (k.clone(), v.clone()))
        .collect();
    user_md.insert(LOCATION_FETCH_MS_KEY.to_string(), format!("{fetch_ms:.2}"));
    Ok((resolved, user_md))
}
/// Test backend that keeps uploaded payloads in a process-local map and serves
/// them back, acting as both an [`ExternalStorage`] and a [`Fetcher`].
#[cfg(any(test, feature = "test-utils"))]
pub struct InMemoryStorage {
    /// Stored payloads keyed by the URL returned from `upload`.
    map: std::sync::Mutex<std::collections::HashMap<String, Vec<u8>>>,
    /// Counter used to mint unique object URLs; `new` starts it at 1.
    next_id: std::sync::atomic::AtomicU64,
    /// Prefix prepended to each minted object id.
    base_url: String,
}
#[cfg(any(test, feature = "test-utils"))]
impl InMemoryStorage {
    /// Creates an empty store whose object URLs start with `https://inmem.test/`.
    ///
    /// Returned in an `Arc` so the same instance can be shared as storage and
    /// fetcher.
    pub fn new() -> Arc<Self> {
        let store = Self {
            map: std::sync::Mutex::default(),
            next_id: std::sync::atomic::AtomicU64::new(1),
            base_url: String::from("https://inmem.test/"),
        };
        Arc::new(store)
    }

    /// Number of stored payloads.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn len(&self) -> usize {
        let guard = self.map.lock().unwrap();
        guard.len()
    }

    /// `true` when nothing has been uploaded.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(any(test, feature = "test-utils"))]
impl ExternalStorage for InMemoryStorage {
    /// Stores a copy of the payload under a freshly minted URL.
    ///
    /// The URL is `base_url` followed by a 16-digit zero-padded hex counter. The
    /// reported digest is the SHA-256 of the bytes exactly as received.
    /// `_compression` is ignored.
    fn upload(&self, ipc_bytes: &[u8], _compression: Compression) -> Result<UploadResult> {
        use std::sync::atomic::Ordering;
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let url = format!("{}{:016x}", self.base_url, id);
        let sha256 = sha256_hex(ipc_bytes);
        let mut guard = self.map.lock().unwrap();
        guard.insert(url.clone(), ipc_bytes.to_vec());
        Ok(UploadResult { url, sha256 })
    }
}
#[cfg(any(test, feature = "test-utils"))]
impl Fetcher for InMemoryStorage {
    /// Returns a copy of the payload stored under `url`.
    ///
    /// The size cap is checked against the stored payload *before* it is copied,
    /// so an oversized object is rejected without duplicating it in memory.
    /// `_compression` is ignored.
    ///
    /// # Errors
    ///
    /// Returns a runtime error if `url` is unknown or the payload exceeds
    /// `max_bytes`.
    fn fetch(&self, url: &str, _compression: Compression, max_bytes: usize) -> Result<Vec<u8>> {
        let map = self.map.lock().unwrap();
        let stored = map
            .get(url)
            .ok_or_else(|| RpcError::runtime_error(format!("inmem fetch miss: {url}")))?;
        if stored.len() > max_bytes {
            return Err(RpcError::runtime_error(format!(
                "inmem fetch payload {} bytes exceeds max_bytes={max_bytes}",
                stored.len()
            )));
        }
        Ok(stored.clone())
    }
}
#[cfg(test)]
mod tests {
    use arrow_array::{Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Builds a single-column, non-nullable Int64 batch holding `0..rows`.
    fn big_batch(rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let col: Arc<dyn arrow_array::Array> =
            Arc::new(Int64Array::from((0..rows as i64).collect::<Vec<_>>()));
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// Config backed by `storage` for both upload and fetch, accepting any URL.
    fn cfg_with(storage: Arc<InMemoryStorage>, threshold: usize) -> ExternalLocationConfig {
        let s: Arc<dyn ExternalStorage> = storage.clone();
        let f: Arc<dyn Fetcher> = storage;
        ExternalLocationConfig::new(s, f)
            .with_threshold_bytes(threshold)
            .with_url_validator(any_url_validator())
    }

    #[test]
    fn small_batch_stays_inline() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024 * 1024);
        let batch = big_batch(10);
        let out = maybe_externalize_batch(&batch, None, &cfg).unwrap();
        assert!(out.is_none());
        assert!(storage.is_empty());
    }

    #[test]
    fn empty_batch_stays_inline_even_at_zero_threshold() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 0);
        let out = maybe_externalize_batch(&big_batch(0), None, &cfg).unwrap();
        assert!(out.is_none());
        assert!(storage.is_empty());
    }

    #[test]
    fn large_batch_externalizes_and_round_trips() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024);
        let batch = big_batch(50_000);
        let (ptr, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        assert_eq!(ptr.num_rows(), 0);
        assert_eq!(ptr.schema().fields().len(), batch.schema().fields().len());
        assert!(md_get(&md, LOCATION_KEY).unwrap().starts_with("https://"));
        assert_eq!(storage.len(), 1);
        let (resolved, user_md) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert_eq!(resolved.num_rows(), batch.num_rows());
        assert!(resolved.column(0) == batch.column(0));
        assert!(md_get(&user_md, LOCATION_KEY).is_none());
        assert!(md_get(&user_md, LOCATION_FETCH_MS_KEY).is_some());
    }

    #[test]
    fn zstd_compression_round_trip() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024).with_compression(Compression::Zstd(3));
        let batch = big_batch(20_000);
        let (ptr, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        let (resolved, _) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert_eq!(resolved.num_rows(), batch.num_rows());
        assert!(resolved.column(0) == batch.column(0));
    }

    #[test]
    fn batch_without_location_passes_through() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 1024);
        let batch = big_batch(3);
        let md = Metadata::new();
        let (out, out_md) = resolve_external_location(&batch, &md, &cfg).unwrap();
        assert_eq!(out.num_rows(), 3);
        assert!(md_get(&out_md, LOCATION_FETCH_MS_KEY).is_none());
    }

    #[test]
    fn user_metadata_survives_round_trip() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 1024);
        let batch = big_batch(10_000);
        let mut inline = Metadata::new();
        inline.insert("app.tag".into(), "keep-me".into());
        inline.insert(LOCATION_FETCH_MS_KEY.to_string(), "1.00".into());
        let (ptr, md) = maybe_externalize_batch(&batch, Some(&inline), &cfg)
            .unwrap()
            .unwrap();
        // A stale fetch time must not leak into a freshly produced pointer.
        assert!(md_get(&md, LOCATION_FETCH_MS_KEY).is_none());
        assert_eq!(
            md_get(&md, "app.tag").map(|v| v.to_string()).as_deref(),
            Some("keep-me")
        );
        let (_, user_md) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert_eq!(
            md_get(&user_md, "app.tag").map(|v| v.to_string()).as_deref(),
            Some("keep-me")
        );
        assert!(md_get(&user_md, LOCATION_SHA256_KEY).is_none());
    }

    #[test]
    fn oversized_external_payload_is_rejected() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 1024);
        let batch = big_batch(10_000);
        let (ptr, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        let tight = cfg.clone().with_max_decompressed_bytes(16);
        assert!(resolve_external_location(&ptr, &md, &tight).is_err());
    }

    #[test]
    fn default_validator_rejects_plaintext_location() {
        let storage = InMemoryStorage::new();
        let s: Arc<dyn ExternalStorage> = storage.clone();
        let f: Arc<dyn Fetcher> = storage;
        let cfg = ExternalLocationConfig::new(s, f);
        let batch = big_batch(1);
        let mut bogus_md = Metadata::new();
        bogus_md.insert(LOCATION_KEY.to_string(), "http://not-secure/x".into());
        let err = resolve_external_location(&batch, &bogus_md, &cfg).unwrap_err();
        assert!(err.message.contains("https://"));
    }

    #[test]
    fn https_only_validator_rejects_plaintext() {
        let v = https_only_validator();
        assert!(v("https://storage.example/obj").is_ok());
        let err = v("http://storage.example/obj").unwrap_err();
        assert!(err.message.contains("https://"));
    }

    #[test]
    fn any_url_validator_accepts_everything() {
        let v = any_url_validator();
        assert!(v("ftp://anything/at/all").is_ok());
        assert!(v("").is_ok());
    }

    #[test]
    fn safe_https_validator_blocks_ssrf_targets() {
        let v = safe_https_validator();
        assert!(v("http://example.com/x").is_err());
        assert!(v("https://169.254.169.254/latest/meta-data/").is_err());
        assert!(v("https://127.0.0.1/").is_err());
        assert!(v("https://10.0.0.1/").is_err());
        assert!(v("https://192.168.1.1/").is_err());
        assert!(v("https://100.64.0.1/").is_err());
        assert!(v("https://[::1]/").is_err());
        assert!(v("https://[::ffff:127.0.0.1]/").is_err());
        assert!(v("https://0.0.0.0/").is_err());
        assert!(v("https://localhost/x").is_err());
        assert!(v("https://api.localhost/x").is_err());
        assert!(v("https://1.1.1.1/x").is_ok());
    }

    #[test]
    fn sha_mismatch_is_rejected() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024);
        let batch = big_batch(10_000);
        let (ptr, mut md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        if let Some(v) = md.get_mut(LOCATION_SHA256_KEY) {
            *v = "deadbeef".into();
        }
        let err = resolve_external_location(&ptr, &md, &cfg).unwrap_err();
        assert!(err.message.contains("SHA-256 mismatch"));
    }
}