1use std::net::{IpAddr, ToSocketAddrs};
21use std::sync::Arc;
22
23use arrow_array::RecordBatch;
24use arrow_schema::{Schema, SchemaRef};
25use sha2::{Digest, Sha256};
26
27use crate::errors::{Result, RpcError};
28use crate::metadata::{LOCATION_FETCH_MS_KEY, LOCATION_KEY, LOCATION_SHA256_KEY};
29use crate::wire::{bytes_to_hex, empty_batch, md_get, write_one_batch, Metadata, StreamReader};
30
/// Compression applied to an externalized IPC payload before upload and
/// reversed after fetch.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Compression {
    /// The payload is the serialized IPC stream, byte-for-byte.
    None,
    /// The payload is zstd-compressed; the `i32` is the compression level handed
    /// to the zstd encoder (it is ignored when decompressing).
    Zstd(i32),
}
38
/// Result returned by [`ExternalStorage::upload`].
#[derive(Clone, Debug)]
pub struct UploadResult {
    /// URL under which the uploaded payload can be fetched. It is checked with the
    /// configured [`UrlValidator`] and published in batch metadata under
    /// `LOCATION_KEY`.
    pub url: String,
    /// Hex SHA-256 reported by the storage backend.
    ///
    /// NOTE(review): the in-memory test backend hashes the bytes it received
    /// (possibly compressed). `maybe_externalize_batch` does not read this field;
    /// it records its own digest of the uncompressed IPC bytes instead.
    pub sha256: String,
}
47
/// Backend that persists externalized batch payloads.
pub trait ExternalStorage: Send + Sync {
    /// Stores `ipc_bytes` and returns the URL it can be fetched from.
    ///
    /// `ipc_bytes` is the payload after `compression` has been applied. It is the
    /// raw IPC stream when `compression` is [`Compression::None`].
    ///
    /// # Errors
    /// Implementations return an error when the payload cannot be stored.
    fn upload(&self, ipc_bytes: &[u8], compression: Compression) -> Result<UploadResult>;
}
57
/// A pre-generated pair of upload/download URLs.
///
/// NOTE(review): this type is not used elsewhere in this section. The field
/// semantics below are inferred from their names; confirm against the provider
/// implementations.
#[derive(Clone, Debug)]
pub struct UploadUrl {
    /// URL a client writes the payload to (presumably a pre-signed upload URL).
    pub upload_url: String,
    /// URL the payload can be read back from.
    pub download_url: String,
    /// Expiry timestamp; presumably microseconds since the Unix epoch. TODO confirm.
    pub expires_at_micros: i64,
}
69
/// Source of fresh [`UploadUrl`] pairs.
pub trait UploadUrlProvider: Send + Sync {
    /// Generates a new upload/download URL pair.
    ///
    /// # Errors
    /// Implementations return an error when no URL can be generated.
    fn generate_upload_url(&self) -> Result<UploadUrl>;
}
78
/// Callback that accepts (`Ok(())`) or rejects (`Err`) an external location URL.
///
/// It is invoked in two places: on URLs read from incoming metadata before they
/// are fetched, and on URLs returned by [`ExternalStorage::upload`] before they
/// are published.
pub type UrlValidator = Arc<dyn Fn(&str) -> Result<()> + Send + Sync>;
83
/// Downloads externalized payloads.
pub trait Fetcher: Send + Sync {
    /// Fetches the bytes stored at `url`.
    ///
    /// The returned bytes are still compressed with `compression`; callers in this
    /// module decompress them afterwards. `max_bytes` is the configured
    /// `max_decompressed_bytes`, passed so implementations can refuse oversized
    /// downloads early.
    ///
    /// # Errors
    /// Implementations return an error when the download fails. They should also
    /// return one when the payload exceeds `max_bytes`, as the in-memory test
    /// backend does.
    fn fetch(&self, url: &str, compression: Compression, max_bytes: usize) -> Result<Vec<u8>>;
}
98
/// Settings controlling when batches are externalized and how they are fetched back.
#[derive(Clone)]
pub struct ExternalLocationConfig {
    /// Batches whose serialized IPC size is at least this many bytes are uploaded
    /// instead of being sent inline.
    pub threshold_bytes: usize,
    /// Compression applied before upload and assumed when decompressing fetched
    /// payloads. It is not recorded in metadata, so producer and consumer must
    /// agree on it.
    pub compression: Compression,
    /// Backend used to upload payloads.
    pub storage: Arc<dyn ExternalStorage>,
    /// Backend used to download payloads.
    pub fetcher: Arc<dyn Fetcher>,
    /// Validator applied to every external URL, both incoming and freshly uploaded.
    pub url_validator: UrlValidator,
    /// Upper bound passed to the fetcher and enforced on the decompressed IPC
    /// bytes, guarding against oversized or decompression-bomb payloads.
    pub max_decompressed_bytes: usize,
}
126
127impl std::fmt::Debug for ExternalLocationConfig {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.debug_struct("ExternalLocationConfig")
130 .field("threshold_bytes", &self.threshold_bytes)
131 .field("compression", &self.compression)
132 .finish_non_exhaustive()
133 }
134}
135
136pub fn https_only_validator() -> UrlValidator {
144 Arc::new(|url: &str| {
145 if url.starts_with("https://") {
146 Ok(())
147 } else {
148 Err(RpcError::value_error(format!(
149 "external location URL must be https:// ({url})"
150 )))
151 }
152 })
153}
154
155pub fn safe_https_validator() -> UrlValidator {
170 Arc::new(|raw: &str| {
171 let url = url::Url::parse(raw)
172 .map_err(|e| RpcError::value_error(format!("invalid external location URL: {e}")))?;
173 if url.scheme() != "https" {
174 return Err(RpcError::value_error(
175 "external location URL must be https://",
176 ));
177 }
178 let host = url
179 .host()
180 .ok_or_else(|| RpcError::value_error("external location URL has no host"))?;
181 match host {
182 url::Host::Ipv4(ip) => reject_unsafe_ip(IpAddr::V4(ip)),
183 url::Host::Ipv6(ip) => reject_unsafe_ip(IpAddr::V6(ip)),
184 url::Host::Domain(name) => {
185 let lname = name.to_ascii_lowercase();
186 if lname == "localhost" || lname.ends_with(".localhost") {
187 return Err(RpcError::value_error(
188 "external location host is not publicly routable",
189 ));
190 }
191 let port = url.port_or_known_default().unwrap_or(443);
192 let addrs = (name, port).to_socket_addrs().map_err(|e| {
193 RpcError::value_error(format!("external location host does not resolve: {e}"))
194 })?;
195 let mut saw_any = false;
196 for sa in addrs {
197 saw_any = true;
198 reject_unsafe_ip(sa.ip())?;
199 }
200 if !saw_any {
201 return Err(RpcError::value_error(
202 "external location host does not resolve",
203 ));
204 }
205 Ok(())
206 }
207 }
208 })
209}
210
211fn reject_unsafe_ip(ip: IpAddr) -> Result<()> {
215 let unsafe_addr = match ip {
216 IpAddr::V4(v4) => {
217 let o = v4.octets();
218 v4.is_loopback()
219 || v4.is_private()
220 || v4.is_link_local()
221 || v4.is_unspecified()
222 || v4.is_broadcast()
223 || v4.is_multicast()
224 || v4.is_documentation()
225 || (o[0] == 100 && (o[1] & 0xc0) == 0x40)
227 }
228 IpAddr::V6(v6) => {
229 let seg0 = v6.segments()[0];
230 v6.is_loopback()
231 || v6.is_unspecified()
232 || v6.is_multicast()
233 || (seg0 & 0xfe00) == 0xfc00
235 || (seg0 & 0xffc0) == 0xfe80
237 || v6
239 .to_ipv4_mapped()
240 .map(|m| reject_unsafe_ip(IpAddr::V4(m)).is_err())
241 .unwrap_or(false)
242 }
243 };
244 if unsafe_addr {
245 return Err(RpcError::value_error(
246 "external location host resolves to a non-routable / internal address",
247 ));
248 }
249 Ok(())
250}
251
252pub fn any_url_validator() -> UrlValidator {
254 Arc::new(|_: &str| Ok(()))
255}
256
257impl ExternalLocationConfig {
258 pub fn new(storage: Arc<dyn ExternalStorage>, fetcher: Arc<dyn Fetcher>) -> Self {
259 Self {
260 threshold_bytes: 1024 * 1024,
261 compression: Compression::None,
262 storage,
263 fetcher,
264 url_validator: safe_https_validator(),
265 max_decompressed_bytes: 1024 * 1024 * 1024,
266 }
267 }
268
269 pub fn with_threshold_bytes(mut self, n: usize) -> Self {
270 self.threshold_bytes = n;
271 self
272 }
273
274 pub fn with_compression(mut self, c: Compression) -> Self {
275 self.compression = c;
276 self
277 }
278
279 pub fn with_url_validator(mut self, v: UrlValidator) -> Self {
280 self.url_validator = v;
281 self
282 }
283
284 pub fn with_max_decompressed_bytes(mut self, n: usize) -> Self {
287 self.max_decompressed_bytes = n;
288 self
289 }
290}
291
292pub fn serialize_batch_to_ipc(batch: &RecordBatch) -> Result<Vec<u8>> {
298 write_one_batch(batch, None)
302}
303
304pub fn fetch_external_ipc_bytes(
314 metadata: &Metadata,
315 cfg: &ExternalLocationConfig,
316) -> Result<Option<Vec<u8>>> {
317 let Some(url) = md_get(metadata, LOCATION_KEY) else {
318 return Ok(None);
319 };
320 (cfg.url_validator)(url)?;
321 let compressed = cfg
322 .fetcher
323 .fetch(url, cfg.compression, cfg.max_decompressed_bytes)?;
324 let ipc_bytes = decompress(&compressed, cfg.compression, cfg.max_decompressed_bytes)?;
325 if let Some(expected) = md_get(metadata, LOCATION_SHA256_KEY) {
326 let actual = sha256_hex(&ipc_bytes);
327 if expected != actual.as_str() {
328 return Err(RpcError::runtime_error(format!(
329 "external location SHA-256 mismatch (expected {expected}, got {actual})"
330 )));
331 }
332 }
333 Ok(Some(ipc_bytes))
334}
335
336pub fn deserialize_single_batch(ipc_bytes: &[u8]) -> Result<RecordBatch> {
337 Ok(deserialize_single_batch_with_metadata(ipc_bytes)?.0)
338}
339
340pub fn deserialize_single_batch_with_metadata(ipc_bytes: &[u8]) -> Result<(RecordBatch, Metadata)> {
344 let mut r = StreamReader::new(ipc_bytes)?;
345 r.read_next()?
346 .ok_or_else(|| RpcError::runtime_error("external batch stream is empty"))
347}
348
349fn sha256_hex(bytes: &[u8]) -> String {
350 bytes_to_hex(&Sha256::digest(bytes))
351}
352
353fn compress(ipc_bytes: &[u8], compression: Compression) -> Result<Vec<u8>> {
354 match compression {
355 Compression::None => Ok(ipc_bytes.to_vec()),
356 Compression::Zstd(level) => {
357 let mut enc = zstd::bulk::Compressor::new(level)
362 .map_err(|e| RpcError::runtime_error(format!("zstd encoder: {e}")))?;
363 enc.set_parameter(zstd::stream::raw::CParameter::ContentSizeFlag(true))
364 .map_err(|e| RpcError::runtime_error(format!("zstd contentsize: {e}")))?;
365 enc.compress(ipc_bytes)
366 .map_err(|e| RpcError::runtime_error(format!("zstd encode: {e}")))
367 }
368 }
369}
370
371fn decompress(bytes: &[u8], compression: Compression, max_size: usize) -> Result<Vec<u8>> {
372 match compression {
373 Compression::None => {
374 if bytes.len() > max_size {
375 return Err(RpcError::runtime_error(format!(
376 "external payload {} bytes exceeds max_decompressed_bytes={max_size}",
377 bytes.len()
378 )));
379 }
380 Ok(bytes.to_vec())
381 }
382 Compression::Zstd(_) => {
387 use std::io::Read;
388 let mut decoder = zstd::Decoder::new(bytes)
389 .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
390 let mut out = Vec::new();
391 let mut buf = [0u8; 64 * 1024];
392 loop {
393 let n = decoder
394 .read(&mut buf)
395 .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
396 if n == 0 {
397 break;
398 }
399 if out.len() + n > max_size {
400 return Err(RpcError::runtime_error(format!(
401 "zstd decode: output exceeds max_decompressed_bytes={max_size}"
402 )));
403 }
404 out.extend_from_slice(&buf[..n]);
405 }
406 Ok(out)
407 }
408 }
409}
410
411pub fn pointer_schema() -> SchemaRef {
419 Arc::new(Schema::empty())
420}
421
422pub fn maybe_externalize_batch(
429 batch: &RecordBatch,
430 inline_metadata: Option<&Metadata>,
431 cfg: &ExternalLocationConfig,
432) -> Result<Option<(RecordBatch, Metadata)>> {
433 if batch.num_rows() == 0 {
434 return Ok(None);
435 }
436 let ipc_bytes = serialize_batch_to_ipc(batch)?;
437 if ipc_bytes.len() < cfg.threshold_bytes {
438 return Ok(None);
439 }
440 let sha = sha256_hex(&ipc_bytes);
441 let payload = compress(&ipc_bytes, cfg.compression)?;
442 let upload = cfg.storage.upload(&payload, cfg.compression)?;
443 (cfg.url_validator)(&upload.url)?;
445
446 let mut md: Metadata = inline_metadata.cloned().unwrap_or_default();
448 md.remove(LOCATION_KEY);
449 md.remove(LOCATION_SHA256_KEY);
450 md.remove(LOCATION_FETCH_MS_KEY);
451 md.insert(LOCATION_KEY.to_string(), upload.url);
452 md.insert(LOCATION_SHA256_KEY.to_string(), sha);
453
454 let ptr = empty_batch(batch.schema().as_ref())?;
458 Ok(Some((ptr, md)))
459}
460
461pub fn resolve_external_location(
474 batch: &RecordBatch,
475 metadata: &Metadata,
476 cfg: &ExternalLocationConfig,
477) -> Result<(RecordBatch, Metadata)> {
478 let Some(url) = md_get(metadata, LOCATION_KEY) else {
479 return Ok((batch.clone(), metadata.clone()));
480 };
481 (cfg.url_validator)(url)?;
482
483 let start = std::time::Instant::now();
484 let compressed = cfg
489 .fetcher
490 .fetch(url, cfg.compression, cfg.max_decompressed_bytes)?;
491 let ipc_bytes = decompress(&compressed, cfg.compression, cfg.max_decompressed_bytes)?;
492
493 if let Some(expected) = md_get(metadata, LOCATION_SHA256_KEY) {
495 let actual = sha256_hex(&ipc_bytes);
496 if expected != actual.as_str() {
497 return Err(RpcError::runtime_error(format!(
498 "external location SHA-256 mismatch (expected {expected}, got {actual})"
499 )));
500 }
501 }
502 let (resolved, inner_md) = deserialize_single_batch_with_metadata(&ipc_bytes)?;
503 let fetch_ms = start.elapsed().as_secs_f64() * 1000.0;
504
505 let mut user_md: Metadata = metadata
512 .iter()
513 .filter(|(k, _)| {
514 *k != LOCATION_KEY && *k != LOCATION_SHA256_KEY && *k != LOCATION_FETCH_MS_KEY
515 })
516 .map(|(k, v)| (k.clone(), v.clone()))
517 .collect();
518 for (k, v) in inner_md {
519 if k != LOCATION_KEY && k != LOCATION_SHA256_KEY && k != LOCATION_FETCH_MS_KEY {
520 user_md.insert(k, v);
521 }
522 }
523 user_md.insert(
524 LOCATION_FETCH_MS_KEY.to_string(),
525 format!("{:.2}", fetch_ms),
526 );
527 Ok((resolved, user_md))
528}
529
/// In-process storage and fetcher for tests. Uploaded payloads are kept in a
/// mutex-guarded map keyed by their generated URL.
#[cfg(any(test, feature = "test-utils"))]
pub struct InMemoryStorage {
    /// Uploaded payloads keyed by URL.
    map: std::sync::Mutex<std::collections::HashMap<String, Vec<u8>>>,
    /// Counter used to mint object ids; the first id is 1.
    next_id: std::sync::atomic::AtomicU64,
    /// Prefix of every generated URL.
    base_url: String,
}

#[cfg(any(test, feature = "test-utils"))]
impl InMemoryStorage {
    /// Creates an empty store whose URLs start with `https://inmem.test/`.
    pub fn new() -> Arc<Self> {
        let storage = Self {
            map: std::sync::Mutex::new(std::collections::HashMap::new()),
            next_id: std::sync::atomic::AtomicU64::new(1),
            base_url: String::from("https://inmem.test/"),
        };
        Arc::new(storage)
    }

    /// Number of stored payloads.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn len(&self) -> usize {
        let guard = self.map.lock().unwrap();
        guard.len()
    }

    /// Returns `true` when nothing has been uploaded.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
566
#[cfg(any(test, feature = "test-utils"))]
impl ExternalStorage for InMemoryStorage {
    /// Stores the payload verbatim under a freshly minted URL: `base_url` followed
    /// by the id as 16 zero-padded hex digits. The reported digest is the SHA-256
    /// of exactly the bytes received.
    fn upload(&self, ipc_bytes: &[u8], _compression: Compression) -> Result<UploadResult> {
        use std::sync::atomic::Ordering;

        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let url = format!("{}{:016x}", self.base_url, id);
        let sha256 = sha256_hex(ipc_bytes);
        let mut objects = self.map.lock().unwrap();
        objects.insert(url.clone(), ipc_bytes.to_vec());
        Ok(UploadResult { url, sha256 })
    }
}
589
#[cfg(any(test, feature = "test-utils"))]
impl Fetcher for InMemoryStorage {
    /// Returns a copy of the payload stored at `url`, refusing payloads longer
    /// than `max_bytes`. The compression argument is ignored because bytes are
    /// stored verbatim.
    fn fetch(&self, url: &str, _compression: Compression, max_bytes: usize) -> Result<Vec<u8>> {
        // Release the lock before the size check.
        let stored = {
            let objects = self.map.lock().unwrap();
            objects.get(url).cloned()
        };
        let Some(bytes) = stored else {
            return Err(RpcError::runtime_error(format!("inmem fetch miss: {url}")));
        };
        if bytes.len() > max_bytes {
            return Err(RpcError::runtime_error(format!(
                "inmem fetch payload {} bytes exceeds max_bytes={max_bytes}",
                bytes.len()
            )));
        }
        Ok(bytes)
    }
}
609
#[cfg(test)]
mod tests {
    use arrow_array::{ArrayRef, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Single-column, non-nullable `Int64` batch holding `0..rows`.
    fn big_batch(rows: usize) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let col: ArrayRef = Arc::new(Int64Array::from((0..rows as i64).collect::<Vec<_>>()));
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// Config that uses `storage` for both upload and fetch. It has the given
    /// threshold and a permissive URL validator, so no DNS lookups happen.
    fn cfg_with(storage: Arc<InMemoryStorage>, threshold: usize) -> ExternalLocationConfig {
        let s: Arc<dyn ExternalStorage> = storage.clone();
        let f: Arc<dyn Fetcher> = storage;
        ExternalLocationConfig::new(s, f)
            .with_threshold_bytes(threshold)
            .with_url_validator(any_url_validator())
    }

    #[test]
    fn small_batch_stays_inline() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024 * 1024);
        let batch = big_batch(10);
        let out = maybe_externalize_batch(&batch, None, &cfg).unwrap();
        assert!(out.is_none());
        assert!(storage.is_empty());
    }

    #[test]
    fn empty_batch_stays_inline_even_with_zero_threshold() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 0);
        let out = maybe_externalize_batch(&big_batch(0), None, &cfg).unwrap();
        assert!(out.is_none());
        assert!(storage.is_empty());
    }

    #[test]
    fn large_batch_externalizes_and_round_trips() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024);
        let batch = big_batch(50_000);

        let (ptr, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        assert_eq!(ptr.num_rows(), 0);
        assert_eq!(ptr.schema().fields().len(), batch.schema().fields().len());
        assert!(md_get(&md, LOCATION_KEY).unwrap().starts_with("https://"));
        assert_eq!(storage.len(), 1);

        let (resolved, user_md) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert_eq!(resolved.num_rows(), batch.num_rows());
        assert!(md_get(&user_md, LOCATION_KEY).is_none());
        assert!(md_get(&user_md, LOCATION_SHA256_KEY).is_none());
        assert!(md_get(&user_md, LOCATION_FETCH_MS_KEY).is_some());
    }

    #[test]
    fn user_metadata_survives_round_trip() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 1024);
        let batch = big_batch(10_000);
        let mut inline = Metadata::new();
        inline.insert("user.key".to_string(), "v".to_string());

        let (ptr, md) = maybe_externalize_batch(&batch, Some(&inline), &cfg)
            .unwrap()
            .unwrap();
        assert!(md_get(&md, "user.key").is_some());
        let (_, user_md) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert!(md_get(&user_md, "user.key").is_some());
    }

    #[test]
    fn resolve_passes_through_without_location() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 1024);
        let batch = big_batch(3);
        let mut md = Metadata::new();
        md.insert("user.key".to_string(), "v".to_string());

        let (out, out_md) = resolve_external_location(&batch, &md, &cfg).unwrap();
        assert_eq!(out.num_rows(), 3);
        assert!(md_get(&out_md, "user.key").is_some());
        assert!(md_get(&out_md, LOCATION_FETCH_MS_KEY).is_none());
    }

    #[test]
    fn zstd_compression_round_trip() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024).with_compression(Compression::Zstd(3));
        let batch = big_batch(20_000);
        let (ptr, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        let (resolved, _) = resolve_external_location(&ptr, &md, &cfg).unwrap();
        assert_eq!(resolved.num_rows(), batch.num_rows());
    }

    #[test]
    fn fetch_external_ipc_bytes_returns_none_without_location() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 0);
        assert!(fetch_external_ipc_bytes(&Metadata::new(), &cfg)
            .unwrap()
            .is_none());
    }

    #[test]
    fn fetch_external_ipc_bytes_returns_decodable_stream() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage, 1024).with_compression(Compression::Zstd(3));
        let batch = big_batch(10_000);
        let (_, md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        let fetched = fetch_external_ipc_bytes(&md, &cfg)
            .unwrap()
            .expect("pointer metadata carries a location");
        let decoded = deserialize_single_batch(&fetched).unwrap();
        assert_eq!(decoded.num_rows(), batch.num_rows());
    }

    #[test]
    fn decompress_enforces_size_limit() {
        assert!(decompress(&[1, 2, 3], Compression::None, 2).is_err());
        assert_eq!(
            decompress(&[1, 2, 3], Compression::None, 3).unwrap(),
            vec![1, 2, 3]
        );

        // A highly compressible payload must not be inflated past the limit.
        let zeros = vec![0u8; 256 * 1024];
        let packed = compress(&zeros, Compression::Zstd(3)).unwrap();
        assert!(packed.len() < zeros.len());
        assert!(decompress(&packed, Compression::Zstd(3), 1024).is_err());
        assert_eq!(
            decompress(&packed, Compression::Zstd(3), zeros.len()).unwrap(),
            zeros
        );
    }

    #[test]
    fn default_validator_rejects_plaintext() {
        let storage = InMemoryStorage::new();
        let s: Arc<dyn ExternalStorage> = storage.clone();
        let f: Arc<dyn Fetcher> = storage;
        let cfg = ExternalLocationConfig::new(s, f).with_threshold_bytes(0);
        let batch = big_batch(1);
        let mut bogus_md = Metadata::new();
        bogus_md.insert(LOCATION_KEY.to_string(), "http://not-secure/x".to_string());
        let err = resolve_external_location(&batch, &bogus_md, &cfg).unwrap_err();
        assert!(err.message.contains("https://"));
    }

    #[test]
    fn https_only_validator_rejects_plaintext() {
        let v = https_only_validator();
        let err = v("http://not-secure/x").unwrap_err();
        assert!(err.message.contains("https://"));
        assert!(v("https://example.com/x").is_ok());
    }

    #[test]
    fn safe_https_validator_blocks_ssrf_targets() {
        let v = safe_https_validator();
        assert!(v("http://example.com/x").is_err());
        assert!(v("https://169.254.169.254/latest/meta-data/").is_err());
        assert!(v("https://127.0.0.1/").is_err());
        assert!(v("https://10.0.0.1/").is_err());
        assert!(v("https://192.168.1.1/").is_err());
        assert!(v("https://[::1]/").is_err());
        assert!(v("https://0.0.0.0/").is_err());
        assert!(v("https://localhost/x").is_err());
        assert!(v("https://api.localhost/x").is_err());
        assert!(v("https://1.1.1.1/x").is_ok());
    }

    #[test]
    fn sha_mismatch_is_rejected() {
        let storage = InMemoryStorage::new();
        let cfg = cfg_with(storage.clone(), 1024);
        let batch = big_batch(10_000);
        let (ptr, mut md) = maybe_externalize_batch(&batch, None, &cfg)
            .unwrap()
            .unwrap();
        *md.get_mut(LOCATION_SHA256_KEY)
            .expect("pointer metadata carries a digest") = "deadbeef".into();
        let err = resolve_external_location(&ptr, &md, &cfg).unwrap_err();
        assert!(err.message.contains("SHA-256 mismatch"));
    }
}