use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_schema::{Schema, SchemaRef};
use axum::{
body::Bytes,
extract::{Path, State},
http::{header, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
routing::post,
Router,
};
use base64::Engine;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::Sha256;
use crate::errors::{Result, RpcError};
use crate::metadata::{CANCEL_KEY, REQUEST_ID_KEY, STATE_KEY};
use crate::server::{
build_error_metadata, build_log_metadata, cast_batch, CallContext, MethodType, Request,
RpcServer,
};
use crate::stream::{empty_schema, Emitted, OutputCollector, StreamResult, StreamStateKind};
use crate::wire::{bytes_to_hex, empty_batch, md_get, Metadata, StreamReader, StreamWriter};
/// Content type that marks a request or response body as an Arrow IPC stream.
///
/// Handlers in this file reject requests whose `Content-Type` is not exactly
/// this value, and the response middleware only compresses bodies carrying it.
pub const ARROW_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";
/// HMAC-SHA256 instantiation used for state-token signing and key derivation.
type HmacSha256 = Hmac<Sha256>;
/// Immutable, shared configuration for the HTTP transport.
///
/// Built once through [`HttpStateBuilder`] and handed to every handler and to
/// the response middleware behind an `Arc`.
pub struct HttpState {
// RPC server whose methods the HTTP handlers dispatch to.
server: Arc<RpcServer>,
// Base key for HMAC-signing state tokens (a per-principal key is derived from it).
signing_key: [u8; 32],
// Defaults to 1 in `build`; its consumer is outside this part of the file.
producer_batch_limit: usize,
// Maximum accepted age of a state token; a zero duration disables the expiry check.
token_ttl: std::time::Duration,
// Upper bound applied to request bodies, both as received and after zstd decoding.
max_body_size: usize,
// Optional authentication callback; when absent every request is anonymous.
authenticate: Option<crate::auth::Authenticate>,
// Kept for ownership only; the handlers read the two derived fields below.
#[allow(dead_code)]
oauth_metadata: Option<Arc<crate::auth::oauth::OAuthResourceMetadata>>,
// Pre-serialized OAuth resource metadata served by the well-known endpoint.
oauth_metadata_json: Option<Vec<u8>>,
// `WWW-Authenticate` value attached to 401 responses, when OAuth metadata is set.
www_authenticate: Option<String>,
// Value sent as `Access-Control-Allow-Origin`; `None` disables all CORS headers.
cors_origins: Option<String>,
// `Access-Control-Max-Age` (seconds) sent on preflight responses.
cors_max_age: u32,
// Path prefix under which the API routes are nested; empty means the root.
prefix: String,
// zstd level for Arrow responses when the client accepts zstd; `None` disables it.
response_compression_level: Option<i32>,
// Whether the HTML landing page route is registered.
landing_page_enabled: bool,
// Whether the HTML describe page route is registered.
describe_page_enabled: bool,
// Whether the `/health` route is registered.
health_enabled: bool,
// Advertised request limit; a larger declared `Content-Length` is answered with 413.
max_request_bytes: Option<usize>,
// Advertised via `vgi-max-upload-bytes` when an upload-URL provider is configured.
max_upload_bytes: Option<usize>,
// Advertised response limit, also enforced by `enforce_response_body_cap`.
max_response_bytes: Option<usize>,
// Advertised via `vgi-max-externalized-response-bytes`.
max_externalized_response_bytes: Option<usize>,
// Source of pre-signed upload URLs; enables the `__upload_url__/init` route.
upload_url_provider: Option<Arc<dyn crate::external::UploadUrlProvider>>,
}
/// Builder for [`HttpState`].
///
/// Every field starts unset; `build` substitutes the defaults for anything
/// that was not configured. Only `server` is mandatory.
#[derive(Default)]
pub struct HttpStateBuilder {
// Required; `build` panics when it is missing.
server: Option<Arc<RpcServer>>,
// When unset, `build` generates a random per-process key and logs a warning.
signing_key: Option<[u8; 32]>,
// Defaults to 1.
producer_batch_limit: Option<usize>,
// Defaults to 300 seconds.
token_ttl: Option<std::time::Duration>,
// Defaults to 64 MiB.
max_body_size: Option<usize>,
authenticate: Option<crate::auth::Authenticate>,
oauth_metadata: Option<Arc<crate::auth::oauth::OAuthResourceMetadata>>,
cors_origins: Option<String>,
// Defaults to 7200 seconds.
cors_max_age: Option<u32>,
// Defaults to the empty string (routes mounted at the root).
prefix: Option<String>,
response_compression_level: Option<i32>,
// The three page/health toggles default to `true`.
landing_page_enabled: Option<bool>,
describe_page_enabled: Option<bool>,
health_enabled: Option<bool>,
// The limits below are passed through unchanged; `None` means "not advertised".
max_request_bytes: Option<usize>,
max_upload_bytes: Option<usize>,
max_response_bytes: Option<usize>,
max_externalized_response_bytes: Option<usize>,
upload_url_provider: Option<Arc<dyn crate::external::UploadUrlProvider>>,
}
impl HttpStateBuilder {
    /// Sets the RPC server to expose. Required: [`build`](Self::build) panics without it.
    pub fn server(mut self, server: Arc<RpcServer>) -> Self {
        self.server = Some(server);
        self
    }

    /// Sets the base signing key from raw bytes.
    ///
    /// Exactly 32 bytes are stored: longer input is truncated and shorter
    /// input is right-padded with zeros. Callers that need a length check
    /// should use one of the `signing_key_*` decoders, which reject keys
    /// shorter than 32 bytes.
    pub fn signing_key(mut self, key: &[u8]) -> Self {
        let mut k = [0u8; 32];
        let n = key.len().min(k.len());
        k[..n].copy_from_slice(&key[..n]);
        self.signing_key = Some(k);
        self
    }

    /// Sets the signing key from a hex string.
    ///
    /// # Panics
    /// Panics if `hex` is not valid hex or decodes to fewer than 32 bytes.
    pub fn signing_key_hex(self, hex: &str) -> Self {
        let bytes = decode_hex_key(hex).expect("signing_key_hex: invalid hex or wrong length");
        self.signing_key(&bytes)
    }

    /// Sets the signing key from a standard-alphabet base64 string.
    ///
    /// # Panics
    /// Panics if `b64` is not valid base64 or decodes to fewer than 32 bytes.
    pub fn signing_key_base64(self, b64: &str) -> Self {
        let bytes =
            decode_base64_key(b64).expect("signing_key_base64: invalid base64 or wrong length");
        self.signing_key(&bytes)
    }

    /// Reads the signing key from environment variable `var`, accepting either
    /// hex or base64.
    ///
    /// Hex is tried first. Every hex digit is also a base64 character, so a
    /// base64-first attempt would "successfully" decode a hex key into
    /// different bytes and disagree with [`signing_key_hex`](Self::signing_key_hex)
    /// for the same string.
    ///
    /// # Panics
    /// Panics if the variable is unset / not UTF-8, or if its value is neither
    /// valid hex nor valid base64 of at least 32 bytes.
    pub fn signing_key_from_env(self, var: &str) -> Self {
        let raw = std::env::var(var).unwrap_or_else(|_| {
            panic!("signing_key_from_env: env var {var} is unset or not UTF-8")
        });
        let trimmed = raw.trim();
        let bytes = decode_hex_key(trimmed)
            .or_else(|_| decode_base64_key(trimmed))
            .unwrap_or_else(|e| {
                panic!("signing_key_from_env: {var} is not valid base64 or hex ({e})")
            });
        self.signing_key(&bytes)
    }

    /// Sets the producer batch limit (defaults to 1).
    pub fn producer_batch_limit(mut self, n: usize) -> Self {
        self.producer_batch_limit = Some(n);
        self
    }

    /// Sets the maximum state-token age (defaults to 300 s). A zero duration
    /// disables the expiry check.
    pub fn token_ttl(mut self, ttl: std::time::Duration) -> Self {
        self.token_ttl = Some(ttl);
        self
    }

    /// Sets the maximum accepted request body size in bytes (defaults to 64 MiB).
    pub fn max_body_size(mut self, n: usize) -> Self {
        self.max_body_size = Some(n);
        self
    }

    /// Installs the authentication callback. Without one, requests are anonymous.
    pub fn authenticate(mut self, cb: crate::auth::Authenticate) -> Self {
        self.authenticate = Some(cb);
        self
    }

    /// Sets the OAuth resource metadata served from the well-known endpoint and
    /// used to build the `WWW-Authenticate` header.
    pub fn oauth_resource_metadata(
        mut self,
        metadata: crate::auth::oauth::OAuthResourceMetadata,
    ) -> Self {
        self.oauth_metadata = Some(Arc::new(metadata));
        self
    }

    /// Enables CORS headers, using `origins` as the `Access-Control-Allow-Origin` value.
    pub fn cors_origins(mut self, origins: impl Into<String>) -> Self {
        self.cors_origins = Some(origins.into());
        self
    }

    /// Sets `Access-Control-Max-Age` for preflight responses (defaults to 7200 s).
    pub fn cors_max_age(mut self, seconds: u32) -> Self {
        self.cors_max_age = Some(seconds);
        self
    }

    /// Sets the path prefix under which the API routes are mounted.
    pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
        self.prefix = Some(prefix.into());
        self
    }

    /// Enables zstd compression of Arrow responses at the given level.
    pub fn response_compression_level(mut self, level: i32) -> Self {
        self.response_compression_level = Some(level);
        self
    }

    /// Enables or disables the HTML landing page (enabled by default).
    pub fn enable_landing_page(mut self, enabled: bool) -> Self {
        self.landing_page_enabled = Some(enabled);
        self
    }

    /// Enables or disables the HTML describe page (enabled by default).
    pub fn enable_describe_page(mut self, enabled: bool) -> Self {
        self.describe_page_enabled = Some(enabled);
        self
    }

    /// Enables or disables the `/health` endpoint (enabled by default).
    pub fn enable_health(mut self, enabled: bool) -> Self {
        self.health_enabled = Some(enabled);
        self
    }

    /// Sets the advertised (and `Content-Length`-enforced) request size limit.
    pub fn max_request_bytes(mut self, n: usize) -> Self {
        self.max_request_bytes = Some(n);
        self
    }

    /// Sets the advertised upload size limit.
    pub fn max_upload_bytes(mut self, n: usize) -> Self {
        self.max_upload_bytes = Some(n);
        self
    }

    /// Sets the response body size limit.
    pub fn max_response_bytes(mut self, n: usize) -> Self {
        self.max_response_bytes = Some(n);
        self
    }

    /// Sets the advertised limit for externalized responses.
    pub fn max_externalized_response_bytes(mut self, n: usize) -> Self {
        self.max_externalized_response_bytes = Some(n);
        self
    }

    /// Installs the provider of pre-signed upload URLs, which also enables the
    /// `__upload_url__/init` route.
    pub fn upload_url_provider(
        mut self,
        provider: Arc<dyn crate::external::UploadUrlProvider>,
    ) -> Self {
        self.upload_url_provider = Some(provider);
        self
    }

    /// Finalizes the configuration, filling in defaults for unset options.
    ///
    /// When no signing key was configured a random one is generated and a
    /// warning is logged, because tokens signed with it cannot be validated by
    /// another process or after a restart.
    ///
    /// # Panics
    /// Panics if [`server`](Self::server) was never called.
    pub fn build(self) -> Arc<HttpState> {
        let server = self.server.expect("HttpStateBuilder::server is required");
        let signing_key = self.signing_key.unwrap_or_else(|| {
            tracing::warn!(
                target: "vgi_rpc.http",
                "no signing_key configured; using ephemeral per-process key — \
state tokens will not survive restart or load-balance across workers"
            );
            let mut k = [0u8; 32];
            rand::thread_rng().fill_bytes(&mut k);
            k
        });
        // Derive both OAuth artefacts up front so handlers never re-serialize.
        let oauth_metadata_json = self
            .oauth_metadata
            .as_ref()
            .map(|m| m.to_json().into_bytes());
        let www_authenticate = self.oauth_metadata.as_ref().map(|m| m.www_authenticate());
        Arc::new(HttpState {
            server,
            signing_key,
            producer_batch_limit: self.producer_batch_limit.unwrap_or(1),
            token_ttl: self
                .token_ttl
                .unwrap_or_else(|| std::time::Duration::from_secs(300)),
            max_body_size: self.max_body_size.unwrap_or(64 * 1024 * 1024),
            authenticate: self.authenticate,
            oauth_metadata: self.oauth_metadata,
            oauth_metadata_json,
            www_authenticate,
            cors_origins: self.cors_origins,
            cors_max_age: self.cors_max_age.unwrap_or(7200),
            prefix: self.prefix.unwrap_or_default(),
            response_compression_level: self.response_compression_level,
            landing_page_enabled: self.landing_page_enabled.unwrap_or(true),
            describe_page_enabled: self.describe_page_enabled.unwrap_or(true),
            health_enabled: self.health_enabled.unwrap_or(true),
            max_request_bytes: self.max_request_bytes,
            max_upload_bytes: self.max_upload_bytes,
            max_response_bytes: self.max_response_bytes,
            max_externalized_response_bytes: self.max_externalized_response_bytes,
            upload_url_provider: self.upload_url_provider,
        })
    }
}
impl HttpState {
    /// Builds a state with default settings for `server`.
    pub fn new(server: Arc<RpcServer>) -> Arc<Self> {
        HttpStateBuilder::default().server(server).build()
    }

    /// Starts an empty builder.
    pub fn builder() -> HttpStateBuilder {
        HttpStateBuilder::default()
    }

    /// Configured maximum age of state tokens.
    pub fn token_ttl(&self) -> std::time::Duration {
        self.token_ttl
    }

    /// Configured maximum request body size in bytes.
    pub fn max_body_size(&self) -> usize {
        self.max_body_size
    }

    /// Serializes and signs a state token bound to the caller in `auth`,
    /// stamped with the current time.
    pub(crate) fn pack_state_token(
        &self,
        auth: &crate::auth::AuthContext,
        state_bytes: &[u8],
        output_schema_bytes: &[u8],
        input_schema_bytes: &[u8],
        stream_id: &str,
    ) -> String {
        let issued_at = current_unix_secs();
        let caller = principal_binding(auth);
        pack_state_token(
            &self.signing_key,
            &caller,
            state_bytes,
            output_schema_bytes,
            input_schema_bytes,
            stream_id,
            issued_at,
        )
    }

    /// Verifies and decodes a state token for the caller in `auth`.
    ///
    /// A zero `token_ttl` turns the expiry check off.
    pub(crate) fn unpack_state_token(
        &self,
        auth: &crate::auth::AuthContext,
        token: &str,
    ) -> Result<UnpackedToken> {
        let max_age = Some(self.token_ttl).filter(|ttl| !ttl.is_zero());
        let caller = principal_binding(auth);
        unpack_state_token(&self.signing_key, &caller, token, max_age)
    }
}
/// Returns the bytes that tie a state token to its caller:
/// `domain || 0x00 || principal`, or nothing for unauthenticated callers.
fn principal_binding(auth: &crate::auth::AuthContext) -> Vec<u8> {
    if auth.authenticated {
        let domain = auth.domain.as_bytes();
        let principal = auth.principal.as_bytes();
        let mut binding = Vec::with_capacity(domain.len() + 1 + principal.len());
        binding.extend_from_slice(domain);
        // NUL separator keeps ("ab", "c") distinct from ("a", "bc").
        binding.push(0);
        binding.extend_from_slice(principal);
        binding
    } else {
        Vec::new()
    }
}
/// Domain-separation label mixed into the HMAC that derives a per-principal
/// signing key in `derive_signing_key`.
const PRINCIPAL_BINDING_LABEL: &[u8] = b"vgi_rpc.state_token.principal_binding.v1";
/// Derives the key used to sign a token for a given principal binding.
///
/// An empty binding yields the base key unchanged; otherwise the key is
/// `HMAC-SHA256(base, label || 0x00 || binding)`.
fn derive_signing_key(base: &[u8; 32], binding: &[u8]) -> [u8; 32] {
    if binding.is_empty() {
        return *base;
    }
    let mut kdf = HmacSha256::new_from_slice(base).expect("hmac key");
    for part in [PRINCIPAL_BINDING_LABEL, &[0u8][..], binding] {
        kdf.update(part);
    }
    let mut derived = [0u8; 32];
    derived.copy_from_slice(&kdf.finalize().into_bytes());
    derived
}
/// Version byte written first in every state token; tokens carrying any other
/// value are rejected on unpack.
pub(crate) const STATE_TOKEN_VERSION: u8 = 0x03;
/// Smallest possible decoded token: version (1) + created_at (8) + four
/// zero-length segment prefixes (4 each) + HMAC-SHA256 tag (32).
const STATE_TOKEN_MIN_LEN: usize = 1 + 8 + 4 + 4 + 4 + 4 + 32;
/// Contents of a state token after signature verification and decoding.
#[derive(Debug, Clone)]
pub(crate) struct UnpackedToken {
// Opaque serialized stream state.
pub state_bytes: Vec<u8>,
// Serialized output schema, as passed to `pack_state_token`.
pub output_schema_bytes: Vec<u8>,
// Serialized input schema, as passed to `pack_state_token`.
pub input_schema_bytes: Vec<u8>,
// Stream identifier; validated as UTF-8 during unpacking.
pub stream_id: String,
// Issue time in Unix seconds; already checked against the TTL during unpacking.
#[allow(dead_code)]
pub created_at: u64,
}
/// Current wall-clock time as whole seconds since the Unix epoch, or 0 when
/// the system clock reads earlier than the epoch.
fn current_unix_secs() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
/// Serializes, signs and base64-encodes a state token.
///
/// Layout (little-endian): version byte, `created_at` (u64), then four
/// length-prefixed (u32) segments — state, output schema, input schema,
/// stream id — followed by the HMAC-SHA256 tag over everything before it.
/// The MAC key is derived from `signing_key` and `principal_binding`.
pub(crate) fn pack_state_token(
    signing_key: &[u8; 32],
    principal_binding: &[u8],
    state_bytes: &[u8],
    output_schema_bytes: &[u8],
    input_schema_bytes: &[u8],
    stream_id: &str,
    created_at: u64,
) -> String {
    let mac_key = derive_signing_key(signing_key, principal_binding);
    let segments: [&[u8]; 4] = [
        state_bytes,
        output_schema_bytes,
        input_schema_bytes,
        stream_id.as_bytes(),
    ];
    let segments_len: usize = segments.iter().map(|seg| 4 + seg.len()).sum();
    let mut token = Vec::with_capacity(1 + 8 + segments_len);
    token.push(STATE_TOKEN_VERSION);
    token.extend_from_slice(&created_at.to_le_bytes());
    for seg in segments {
        token.extend_from_slice(&(seg.len() as u32).to_le_bytes());
        token.extend_from_slice(seg);
    }
    let mut signer = HmacSha256::new_from_slice(&mac_key).expect("hmac key");
    signer.update(&token);
    let tag = signer.finalize().into_bytes();
    token.extend_from_slice(&tag);
    base64::engine::general_purpose::STANDARD.encode(token)
}
/// Decodes and authenticates a state token produced by `pack_state_token`.
///
/// Steps, in order: base64-decode, length check, HMAC verification with the
/// key derived for `principal_binding`, version check, optional age check
/// against `token_ttl`, then segment parsing. Nothing inside the payload is
/// interpreted before the signature has been verified.
pub(crate) fn unpack_state_token(
    signing_key: &[u8; 32],
    principal_binding: &[u8],
    token: &str,
    token_ttl: Option<std::time::Duration>,
) -> Result<UnpackedToken> {
    const TAG_LEN: usize = 32;
    let malformed = || RpcError::runtime_error("Malformed state token");

    let decoded = base64::engine::general_purpose::STANDARD
        .decode(token.as_bytes())
        .map_err(|_| malformed())?;
    if decoded.len() < STATE_TOKEN_MIN_LEN {
        return Err(malformed());
    }
    let (payload, tag) = decoded.split_at(decoded.len() - TAG_LEN);

    let mac_key = derive_signing_key(signing_key, principal_binding);
    let mut verifier = HmacSha256::new_from_slice(&mac_key).expect("hmac key");
    verifier.update(payload);
    if verifier.verify_slice(tag).is_err() {
        return Err(RpcError::runtime_error(
            "State token signature verification failed",
        ));
    }

    let version = payload[0];
    if version != STATE_TOKEN_VERSION {
        return Err(RpcError::runtime_error(format!(
            "Unsupported state token version {version} (expected {STATE_TOKEN_VERSION})"
        )));
    }

    let created_at = u64::from_le_bytes(payload[1..9].try_into().unwrap());
    if let Some(ttl) = token_ttl {
        // Tokens stamped in the future have age 0 and are never expired here.
        let age = current_unix_secs().saturating_sub(created_at);
        if age > ttl.as_secs() {
            return Err(RpcError::runtime_error("State token expired"));
        }
    }

    let mut cursor = 9;
    let state_bytes = read_segment(payload, &mut cursor)?;
    let output_schema_bytes = read_segment(payload, &mut cursor)?;
    let input_schema_bytes = read_segment(payload, &mut cursor)?;
    let stream_id_raw = read_segment(payload, &mut cursor)?;
    // Trailing bytes after the last segment make the token invalid.
    if cursor != payload.len() {
        return Err(malformed());
    }
    let stream_id = String::from_utf8(stream_id_raw).map_err(|_| malformed())?;

    Ok(UnpackedToken {
        state_bytes,
        output_schema_bytes,
        input_schema_bytes,
        stream_id,
        created_at,
    })
}
/// Reads one length-prefixed segment from `buf` starting at `*pos`.
///
/// The segment is a little-endian `u32` length followed by that many bytes.
/// On success `*pos` is advanced past the segment and its bytes are returned.
///
/// # Errors
/// Returns a "Malformed state token" error when the prefix or the body would
/// run past the end of `buf`. Offsets are computed with checked arithmetic so
/// an oversized length cannot wrap around `usize` (possible on 32-bit targets)
/// and turn into a slice-index panic.
fn read_segment(buf: &[u8], pos: &mut usize) -> Result<Vec<u8>> {
    let malformed = || RpcError::runtime_error("Malformed state token");

    let body_start = pos.checked_add(4).ok_or_else(malformed)?;
    let prefix = buf.get(*pos..body_start).ok_or_else(malformed)?;
    let len = u32::from_le_bytes(prefix.try_into().expect("prefix is exactly 4 bytes")) as usize;

    let body_end = body_start.checked_add(len).ok_or_else(malformed)?;
    let segment = buf.get(body_start..body_end).ok_or_else(malformed)?;

    *pos = body_end;
    Ok(segment.to_vec())
}
/// Serializes `schema` by writing a single zero-row batch of that schema.
fn write_schema_bytes(schema: &Schema) -> Result<Vec<u8>> {
    let placeholder = empty_batch(schema)?;
    let encoded = crate::wire::write_one_batch(&placeholder, None);
    encoded
}
/// Recovers the schema from bytes that `StreamReader` can open as a stream.
fn read_schema_bytes(bytes: &[u8]) -> Result<SchemaRef> {
    let schema = StreamReader::new(bytes)?.schema();
    Ok(schema)
}
/// Resolves when the process is asked to stop.
///
/// On Unix this waits for SIGTERM or SIGINT, whichever arrives first; if
/// either handler cannot be installed it falls back to waiting for Ctrl-C.
/// On other platforms it waits for Ctrl-C only.
pub async fn shutdown_signal() {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};

        let Ok(mut term) = signal(SignalKind::terminate()) else {
            let _ = tokio::signal::ctrl_c().await;
            return;
        };
        let Ok(mut intr) = signal(SignalKind::interrupt()) else {
            let _ = tokio::signal::ctrl_c().await;
            return;
        };
        tokio::select! {
            _ = term.recv() => {},
            _ = intr.recv() => {},
        }
    }
    #[cfg(not(unix))]
    {
        let _ = tokio::signal::ctrl_c().await;
    }
}
/// Serves the router built from `state` on `listener` until
/// [`shutdown_signal`] resolves, then shuts down gracefully.
pub async fn serve_with_shutdown(
    state: Arc<HttpState>,
    listener: tokio::net::TcpListener,
) -> std::io::Result<()> {
    let router = build_router(state);
    let server = axum::serve(listener, router).with_graceful_shutdown(shutdown_signal());
    server.await
}
/// Builds the complete router: all routes wrapped in the post-processing
/// middleware.
pub fn build_router(state: Arc<HttpState>) -> Router {
    let routes = build_router_inner(Arc::clone(&state));
    let postprocess = axum::middleware::from_fn_with_state(state, postprocess_middleware);
    routes.layer(postprocess)
}
/// Middleware wrapped around every route.
///
/// Before the handler runs it notifies the server of the HTTP transport and
/// rejects requests whose declared `Content-Length` exceeds
/// `max_request_bytes` (upload-URL and health paths are exempt). After the
/// handler runs it buffers the response body, zstd-compresses Arrow bodies
/// when compression is configured and the client's `Accept-Encoding` mentions
/// zstd, and attaches the CORS and capability headers.
///
/// Every response — compressed or not — leaves through the same tail, so the
/// capability headers are present on compressed responses as well.
async fn postprocess_middleware(
    axum::extract::State(state): axum::extract::State<Arc<HttpState>>,
    req: axum::http::Request<axum::body::Body>,
    next: axum::middleware::Next,
) -> Response {
    use axum::body::to_bytes;
    state.server.notify_transport(
        crate::transport::TransportKind::Http,
        crate::transport::TransportCapabilities::none(),
    );
    // Captured up front because `next.run` consumes the request.
    let req_headers = req.headers().clone();
    let req_method = req.method().clone();
    let req_path = req.uri().path().to_string();

    if let Some(limit) = state.max_request_bytes {
        let exempt = req_path.contains("/__upload_url__/") || req_path.ends_with("/health");
        if !exempt {
            // Only a declared length is checked here; bodies without a
            // Content-Length are bounded later by `max_body_size`.
            let declared = req_headers
                .get(header::CONTENT_LENGTH)
                .and_then(|v| v.to_str().ok())
                .and_then(|s| s.parse::<usize>().ok());
            if let Some(cl) = declared.filter(|&cl| cl > limit) {
                let mut h = HeaderMap::new();
                attach_capability_headers(&state, &mut h, &req_method);
                attach_cors_headers(&state, &mut h, &req_headers, false);
                return (
                    StatusCode::PAYLOAD_TOO_LARGE,
                    h,
                    format!(
                        "Request body of {cl} bytes exceeds advertised \
max_request_bytes={limit}. Use the upload-URL \
flow (__upload_url__/init) to externalize."
                    ),
                )
                    .into_response();
            }
        }
    }

    let resp = next.run(req).await;
    let (mut parts, body) = resp.into_parts();
    // A body that fails to buffer is replaced by an empty one (best effort).
    let bytes = to_bytes(body, usize::MAX).await.unwrap_or_default();

    let is_arrow = parts
        .headers
        .get(header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        == Some(ARROW_CONTENT_TYPE);
    let client_accepts_zstd = req_headers
        .get(header::ACCEPT_ENCODING)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .contains("zstd");

    // `None` means "send the body as-is": compression is off, not applicable,
    // or the encoder failed.
    let compressed = match state.response_compression_level {
        Some(level) if is_arrow && client_accepts_zstd => {
            zstd::encode_all(std::io::Cursor::new(&bytes), level).ok()
        }
        _ => None,
    };
    let body = match compressed {
        Some(encoded) => {
            parts
                .headers
                .insert(header::CONTENT_ENCODING, HeaderValue::from_static("zstd"));
            // Any length a handler set describes the uncompressed body.
            parts.headers.remove(header::CONTENT_LENGTH);
            axum::body::Body::from(encoded)
        }
        None => axum::body::Body::from(bytes),
    };

    attach_cors_headers(&state, &mut parts.headers, &req_headers, false);
    attach_capability_headers(&state, &mut parts.headers, &req_method);
    Response::from_parts(parts, body)
}
/// Adds the `vgi-*` capability headers describing the configured limits and
/// features to `out`.
///
/// For `OPTIONS` requests a short public `Cache-Control` is added as well,
/// provided at least one limit or the upload-URL capability was advertised.
fn attach_capability_headers(
    state: &Arc<HttpState>,
    out: &mut HeaderMap,
    method: &axum::http::Method,
) {
    /// Inserts `name: n`; reports whether the header was actually written.
    fn put_number(out: &mut HeaderMap, name: &'static str, n: usize) -> bool {
        match HeaderValue::from_str(&n.to_string()) {
            Ok(value) => {
                out.insert(name, value);
                true
            }
            Err(_) => false,
        }
    }

    let limits = [
        ("vgi-max-request-bytes", state.max_request_bytes),
        ("vgi-max-response-bytes", state.max_response_bytes),
        (
            "vgi-max-externalized-response-bytes",
            state.max_externalized_response_bytes,
        ),
    ];
    let mut advertised = false;
    for (name, limit) in limits {
        if let Some(n) = limit {
            if put_number(out, name, n) {
                advertised = true;
            }
        }
    }

    // Always present, and deliberately not counted towards `advertised`.
    let externalization = if state.server.external_config().is_some() {
        "true"
    } else {
        "false"
    };
    out.insert(
        "vgi-externalization-enabled",
        HeaderValue::from_static(externalization),
    );

    if state.upload_url_provider.is_some() {
        out.insert("vgi-upload-url-support", HeaderValue::from_static("true"));
        advertised = true;
        if let Some(n) = state.max_upload_bytes {
            put_number(out, "vgi-max-upload-bytes", n);
        }
    }

    if advertised && method == axum::http::Method::OPTIONS {
        out.insert(
            header::CACHE_CONTROL,
            HeaderValue::from_static("public, max-age=300"),
        );
    }
}
/// Registers every route, without the post-processing middleware.
///
/// The RPC routes (and the upload-URL route when a provider is configured)
/// are nested under the configured prefix. The OAuth metadata, landing and
/// describe routes are registered with the prefix prepended; `/health` is
/// always registered at the root.
fn build_router_inner(state: Arc<HttpState>) -> Router {
    use axum::routing::get;

    let prefix = state.prefix.clone();

    let mut api = Router::new()
        .route("/:method", post(handle_unary).options(handle_preflight))
        .route(
            "/:method/init",
            post(handle_stream_init).options(handle_preflight),
        )
        .route(
            "/:method/exchange",
            post(handle_stream_exchange).options(handle_preflight),
        );
    if state.upload_url_provider.is_some() {
        api = api.route(
            "/__upload_url__/init",
            post(handle_upload_url).options(handle_preflight),
        );
    }

    let mut app = if prefix.is_empty() {
        api
    } else {
        Router::new().nest(&prefix, api)
    };

    let well_known = format!(
        "{prefix}{}",
        crate::auth::oauth::OAuthResourceMetadata::well_known_path()
    );
    app = app.route(&well_known, get(handle_oauth_metadata));

    if state.health_enabled {
        app = app.route("/health", get(handle_health).options(handle_preflight));
    }
    if state.landing_page_enabled {
        // With no prefix the landing page lives at the root.
        let landing_path = if prefix.is_empty() {
            "/"
        } else {
            prefix.as_str()
        };
        app = app.route(landing_path, get(handle_landing));
    }
    if state.describe_page_enabled {
        let describe_path = format!("{prefix}/describe");
        app = app.route(&describe_path, get(handle_describe_page));
    }

    app.with_state(state)
}
/// Answers a CORS preflight with `204 No Content` and the preflight CORS headers.
async fn handle_preflight(State(state): State<Arc<HttpState>>, headers: HeaderMap) -> Response {
    let mut cors = HeaderMap::new();
    attach_cors_headers(&state, &mut cors, &headers, true);
    let reply = (StatusCode::NO_CONTENT, cors);
    reply.into_response()
}
/// Health probe: a small JSON document with the status, server id and
/// protocol name.
async fn handle_health(State(state): State<Arc<HttpState>>) -> Response {
    let mut response_headers = HeaderMap::new();
    response_headers.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("application/json"),
    );
    let document = serde_json::json!({
        "status": "ok",
        "server_id": state.server.server_id,
        "protocol": state.server.protocol_name(),
    });
    (StatusCode::OK, response_headers, document.to_string()).into_response()
}
/// Serves the HTML landing page.
async fn handle_landing(State(state): State<Arc<HttpState>>) -> Response {
    let mut response_headers = HeaderMap::new();
    response_headers.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("text/html; charset=utf-8"),
    );
    let page = render_landing(&state);
    (StatusCode::OK, response_headers, page).into_response()
}
/// Serves the HTML API-reference page.
async fn handle_describe_page(State(state): State<Arc<HttpState>>) -> Response {
    let mut response_headers = HeaderMap::new();
    response_headers.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("text/html; charset=utf-8"),
    );
    let page = render_describe_page(&state);
    (StatusCode::OK, response_headers, page).into_response()
}
/// Renders the landing page: the protocol name (or a generic title when it is
/// empty), the server id, and a link to the describe page when that page is
/// enabled.
///
/// NOTE(review): the protocol name, server id and prefix are interpolated
/// without HTML escaping — confirm they can never contain markup characters.
fn render_landing(state: &Arc<HttpState>) -> String {
    let protocol = state.server.protocol_name();
    let title = if protocol.is_empty() {
        "vgi-rpc service"
    } else {
        protocol
    };
    let id = &state.server.server_id;
    let reference_link = match state.describe_page_enabled {
        true => format!(
            r#"<p><a href="{0}/describe">API reference</a></p>"#,
            state.prefix
        ),
        false => String::new(),
    };
    format!(
        "<!doctype html><html><head><meta charset=\"utf-8\"><title>{title}</title></head><body>\
<h1>{title}</h1><p>server_id: <code>{id}</code></p>{reference_link}\
</body></html>"
    )
}
/// Renders the API-reference page: a table with one row per registered
/// method, in sorted name order, giving its name, kind (`unary` or `stream`)
/// and doc string.
///
/// Every interpolated value — protocol name, method name and doc — goes
/// through `html_escape`, so no registered text is emitted as raw markup.
fn render_describe_page(state: &Arc<HttpState>) -> String {
    let mut body = String::from(
        "<!doctype html><html><head><meta charset=\"utf-8\"><title>API reference</title></head><body>",
    );
    body.push_str(&format!(
        "<h1>{}</h1><table><tr><th>method</th><th>type</th><th>doc</th></tr>",
        html_escape(state.server.protocol_name())
    ));
    for name in state.server.sorted_method_names() {
        let m = &state.server.methods()[name];
        // Anything that is not unary is presented as a stream method.
        let kind = match m.method_type {
            crate::server::MethodType::Unary => "unary",
            _ => "stream",
        };
        let doc = m.doc.as_deref().unwrap_or("");
        body.push_str(&format!(
            "<tr><td><code>{}</code></td><td>{kind}</td><td>{}</td></tr>",
            html_escape(name),
            html_escape(doc)
        ));
    }
    body.push_str("</table></body></html>");
    body
}
/// Escapes text for inclusion in HTML.
///
/// `&`, `<`, `>`, `"` and `'` are replaced by their character references, so
/// the result is safe both as element content and inside a quoted attribute
/// value. All other characters are copied unchanged.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            other => out.push(other),
        }
    }
    out
}
/// Adds CORS response headers to `out`; does nothing when no origins are
/// configured.
///
/// `Access-Control-Allow-Headers` echoes the request's
/// `Access-Control-Request-Headers` when present, otherwise a fixed default
/// list. `Access-Control-Max-Age` is only sent when `is_preflight` is true.
fn attach_cors_headers(
    state: &Arc<HttpState>,
    out: &mut HeaderMap,
    req_headers: &HeaderMap,
    is_preflight: bool,
) {
    let origins = match state.cors_origins.as_deref() {
        Some(configured) => configured,
        None => return,
    };
    if let Ok(allow_origin) = HeaderValue::from_str(origins) {
        out.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin);
    }
    out.insert(
        header::ACCESS_CONTROL_ALLOW_METHODS,
        HeaderValue::from_static("POST, GET, OPTIONS"),
    );

    let allow_headers = match req_headers
        .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
        .and_then(|v| v.to_str().ok())
    {
        Some(requested) => requested,
        None => "Content-Type, Authorization, Cookie, Accept-Encoding",
    };
    if let Ok(value) = HeaderValue::from_str(allow_headers) {
        out.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, value);
    }

    out.insert(
        header::ACCESS_CONTROL_EXPOSE_HEADERS,
        HeaderValue::from_static("Content-Encoding, WWW-Authenticate"),
    );

    if !is_preflight {
        return;
    }
    if let Ok(max_age) = HeaderValue::from_str(&state.cors_max_age.to_string()) {
        out.insert(header::ACCESS_CONTROL_MAX_AGE, max_age);
    }
}
/// Serves the pre-serialized OAuth resource metadata as JSON, cacheable for
/// 60 seconds, or `404` with an empty body when none is configured.
async fn handle_oauth_metadata(State(state): State<Arc<HttpState>>) -> Response {
    let Some(document) = state.oauth_metadata_json.as_ref() else {
        return (StatusCode::NOT_FOUND, "").into_response();
    };
    let mut response_headers = HeaderMap::new();
    response_headers.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("application/json"),
    );
    response_headers.insert(
        header::CACHE_CONTROL,
        HeaderValue::from_static("public, max-age=60"),
    );
    (StatusCode::OK, response_headers, document.clone()).into_response()
}
/// Parses a `Cookie` header value into a name → value map.
///
/// Pairs are separated by `;` and split on the first `=`; names and values
/// are trimmed, parts without `=` are ignored, and a repeated name keeps its
/// last value. A missing header yields an empty map.
fn parse_cookies(raw: Option<&str>) -> std::collections::BTreeMap<String, String> {
    raw.into_iter()
        .flat_map(|header_value| header_value.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
        .collect()
}
/// Copies the headers into owned `(name, value)` pairs, skipping any value
/// that `HeaderValue::to_str` rejects.
fn headers_to_pairs(headers: &HeaderMap) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    for (name, value) in headers.iter() {
        if let Ok(text) = value.to_str() {
            pairs.push((name.as_str().to_string(), text.to_string()));
        }
    }
    pairs
}
/// Runs the configured authentication callback for a request.
///
/// With no callback configured the request is anonymous. On failure the error
/// is converted to a ready-made response: `PermissionError` and `ValueError`
/// become `401` (with `WWW-Authenticate` when configured), anything else
/// becomes `500`; the body is the error message.
fn authenticate_request(
    state: &Arc<HttpState>,
    method: &str,
    headers: &HeaderMap,
) -> std::result::Result<crate::auth::AuthContext, Response> {
    let Some(cb) = state.authenticate.as_ref() else {
        return Ok(crate::auth::AuthContext::anonymous());
    };

    let pairs = headers_to_pairs(headers);
    let req = crate::auth::AuthRequest {
        method,
        headers: &pairs,
        peer_addr: None,
    };
    let failure = match (cb)(&req) {
        Ok(ctx) => return Ok(ctx),
        Err(err) => err,
    };

    let unauthorized = matches!(
        failure.error_type.as_str(),
        "PermissionError" | "ValueError"
    );
    let status = if unauthorized {
        StatusCode::UNAUTHORIZED
    } else {
        StatusCode::INTERNAL_SERVER_ERROR
    };

    let mut response_headers = HeaderMap::new();
    if unauthorized {
        let challenge = state
            .www_authenticate
            .as_deref()
            .and_then(|wa| HeaderValue::from_str(wa).ok());
        if let Some(hv) = challenge {
            response_headers.insert(header::WWW_AUTHENTICATE, hv);
        }
    }
    Err((status, response_headers, failure.message.clone()).into_response())
}
/// Wraps `body` in a response with the given status and the Arrow stream
/// content type.
fn arrow_response(status: StatusCode, body: Vec<u8>) -> Response {
    let content_type = HeaderValue::from_static(ARROW_CONTENT_TYPE);
    let mut response_headers = HeaderMap::new();
    response_headers.insert(header::CONTENT_TYPE, content_type);
    (status, response_headers, body).into_response()
}
/// Returns `body` as a `200` Arrow response, unless it is larger than
/// `max_response_bytes`, in which case an error stream describing the
/// overflow is returned in its place.
fn enforce_response_body_cap(
    state: &Arc<HttpState>,
    schema: &arrow_schema::Schema,
    body: Vec<u8>,
    method: &str,
    server_id: &str,
    request_id: &str,
) -> Response {
    let exceeded = state
        .max_response_bytes
        .filter(|&limit| body.len() > limit);
    match exceeded {
        Some(limit) => {
            let err = RpcError::runtime_error(format!(
                "HTTP body exceeds max_response_bytes ({} > {}) for method {:?}",
                body.len(),
                limit,
                method
            ));
            cap_error_response(schema, &err, server_id, request_id)
        }
        None => arrow_response(StatusCode::OK, body),
    }
}
fn cap_error_response(
schema: &arrow_schema::Schema,
err: &RpcError,
server_id: &str,
request_id: &str,
) -> Response {
let mut buf = Vec::new();
{
let mut sw = StreamWriter::new(&mut buf, schema).unwrap();
let md = build_error_metadata(err, server_id, request_id);
let _ = sw.write(&empty_batch(schema).unwrap(), Some(&md));
let _ = sw.finish();
}
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_static(ARROW_CONTENT_TYPE),
);
headers.insert("x-vgi-rpc-error", HeaderValue::from_static("true"));
(StatusCode::OK, headers, buf).into_response()
}
/// Builds a response with the given status whose body is `msg`.
fn plain_error(status: StatusCode, msg: String) -> Response {
    let reply = (status, msg);
    reply.into_response()
}
/// True when the `Content-Type` header is exactly the Arrow stream type
/// (no parameters, case-sensitive).
fn has_arrow_ct(headers: &HeaderMap) -> bool {
    let content_type = headers
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok());
    matches!(content_type, Some(ct) if ct == ARROW_CONTENT_TYPE)
}
/// Returns the request body as plain bytes, zstd-decoding it when
/// `Content-Encoding` is `zstd` (case-insensitive).
///
/// `max_size` bounds both the body as received and the decoded output.
fn maybe_decompress(headers: &HeaderMap, body: &Bytes, max_size: usize) -> Result<Vec<u8>> {
    let received = body.len();
    if received > max_size {
        return Err(RpcError::runtime_error(format!(
            "Request body exceeds max size ({} bytes > {})",
            received, max_size
        )));
    }

    let encoding = headers
        .get(header::CONTENT_ENCODING)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("");
    if !encoding.eq_ignore_ascii_case("zstd") {
        return Ok(body.to_vec());
    }
    decode_zstd_bounded(body.as_ref(), max_size)
}
fn decode_zstd_bounded(input: &[u8], max_size: usize) -> Result<Vec<u8>> {
use std::io::Read;
let mut decoder = zstd::Decoder::new(input)
.map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
let mut out = Vec::with_capacity(input.len().min(max_size).min(64 * 1024));
let mut buf = [0u8; 16 * 1024];
loop {
let n = decoder
.read(&mut buf)
.map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
if n == 0 {
break;
}
if out.len() + n > max_size {
return Err(RpcError::runtime_error(format!(
"Decompressed body exceeds max size ({}+ bytes > {})",
out.len() + n,
max_size
)));
}
out.extend_from_slice(&buf[..n]);
}
Ok(out)
}
/// Parses an RPC request from an IPC stream body: the first batch becomes the
/// request and the remainder of the stream is drained.
///
/// Fails with a protocol error when the stream contains no batch.
fn parse_request_from_body(body: &[u8]) -> Result<Request> {
    let mut reader = StreamReader::new(body)?;
    let first = reader.read_next()?;
    let (batch, metadata) = match first {
        Some(entry) => entry,
        None => return Err(RpcError::protocol_error("empty IPC stream")),
    };
    reader.drain()?;
    Request::from_read_batch(batch, metadata, false)
}
/// Encodes `err` as an IPC stream of `schema` holding one zero-row batch
/// whose metadata describes the error.
fn error_stream_bytes(
    schema: &Schema,
    err: &RpcError,
    server_id: &str,
    request_id: &str,
) -> Vec<u8> {
    let mut stream = Vec::new();
    {
        // Scoped so the writer releases `stream` before it is returned.
        let mut writer = StreamWriter::new(&mut stream, schema).unwrap();
        let error_md = build_error_metadata(err, server_id, request_id);
        let _ = writer.write(&empty_batch(schema).unwrap(), Some(&error_md));
        let _ = writer.finish();
    }
    stream
}
/// Builds an Arrow error response with the given status, encoding `err` over
/// the empty schema.
fn arrow_error(
    state: &Arc<HttpState>,
    status: StatusCode,
    err: &RpcError,
    request_id: &str,
) -> Response {
    let no_columns = Schema::empty();
    let stream = error_stream_bytes(&no_columns, err, &state.server.server_id, request_id);
    arrow_response(status, stream)
}
/// Decodes a hex-encoded signing key (surrounding whitespace ignored).
///
/// Fails when the length is odd, a character is not a hex digit, or the
/// decoded key is shorter than 32 bytes.
fn decode_hex_key(s: &str) -> std::result::Result<Vec<u8>, String> {
    let digits = s.trim().as_bytes();
    if digits.len() % 2 != 0 {
        return Err("hex length must be even".into());
    }
    let key = digits
        .chunks_exact(2)
        .map(|pair| -> std::result::Result<u8, String> {
            let high = hex_nibble(pair[0])?;
            let low = hex_nibble(pair[1])?;
            Ok((high << 4) | low)
        })
        .collect::<std::result::Result<Vec<u8>, String>>()?;
    match key.len() {
        short if short < 32 => Err(format!(
            "signing key must be ≥ 32 bytes (got {} bytes)",
            short
        )),
        _ => Ok(key),
    }
}
/// Converts one ASCII hex digit (either case) to its value, 0–15.
fn hex_nibble(c: u8) -> std::result::Result<u8, String> {
    let symbol = c as char;
    symbol
        .to_digit(16)
        .map(|value| value as u8)
        .ok_or_else(|| format!("invalid hex character: {:?}", symbol))
}
/// Decodes a standard-alphabet base64 signing key.
///
/// Surrounding whitespace is ignored and padding is normalized, so both
/// padded and unpadded input are accepted. Fails when the text is not valid
/// base64 or the decoded key is shorter than 32 bytes.
fn decode_base64_key(s: &str) -> std::result::Result<Vec<u8>, String> {
    let unpadded = s.trim().trim_end_matches('=');
    // Re-pad to a multiple of four characters.
    let missing = (4 - unpadded.len() % 4) % 4;
    let padded = format!("{unpadded}{}", "=".repeat(missing));

    let key = base64::engine::general_purpose::STANDARD
        .decode(padded.as_bytes())
        .map_err(|e| format!("base64 decode: {e}"))?;
    match key.len() {
        short if short < 32 => Err(format!(
            "signing key must be ≥ 32 bytes (got {} bytes)",
            short
        )),
        _ => Ok(key),
    }
}
/// Generates a session id from 16 random bytes, rendered with `bytes_to_hex`.
fn new_session_id() -> String {
    let mut entropy = [0u8; 16];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut entropy);
    bytes_to_hex(&entropy)
}
/// Method name used for upload-URL requests, both for authentication and for
/// validating the method carried in the request.
const UPLOAD_URL_METHOD: &str = "__upload_url__";
/// Upper bound on the number of upload URLs issued per request; larger
/// requested counts are clamped to this value.
const MAX_UPLOAD_URL_COUNT: i64 = 100;
/// Schema of the upload-URL response: non-nullable `upload_url` and
/// `download_url` strings plus a non-nullable UTC microsecond `expires_at`.
fn upload_url_response_schema() -> Schema {
    use arrow_schema::{DataType, Field, TimeUnit};

    let upload_url = Field::new("upload_url", DataType::Utf8, false);
    let download_url = Field::new("download_url", DataType::Utf8, false);
    let expires_at = Field::new(
        "expires_at",
        DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
        false,
    );
    Schema::new(vec![upload_url, download_url, expires_at])
}
/// Handler for `POST …/__upload_url__/init`: issues one or more upload URLs.
///
/// The request is authenticated, must carry the Arrow content type, and may
/// contain an Int64 `count` column whose first value selects how many URLs to
/// generate (clamped to `1..=MAX_UPLOAD_URL_COUNT`, default 1). The response
/// is an Arrow stream with one row per URL. Failures that occur after the
/// request was parsed are reported in-band: a `200` stream holding a zero-row
/// batch whose metadata describes the error.
async fn handle_upload_url(
State(state): State<Arc<HttpState>>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let auth = match authenticate_request(&state, UPLOAD_URL_METHOD, &headers) {
Ok(a) => a,
Err(resp) => return resp,
};
// Authentication only gates access here; the resulting context is unused.
let _ = auth;
if !has_arrow_ct(&headers) {
return plain_error(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
"need arrow content type".into(),
);
}
let provider = match state.upload_url_provider.as_ref() {
Some(p) => p.clone(),
None => return plain_error(StatusCode::NOT_FOUND, "upload-url not enabled".into()),
};
// No request id is known before the body parses, so early errors use "".
let body = match maybe_decompress(&headers, &body, state.max_body_size) {
Ok(b) => b,
Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
};
let req = match parse_request_from_body(&body) {
Ok(r) => r,
Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
};
// An empty method name is tolerated; any other name must match exactly.
if !req.method.is_empty() && req.method != UPLOAD_URL_METHOD {
let err = RpcError::protocol_error(format!(
"Method mismatch: expected '{UPLOAD_URL_METHOD}', got '{}'",
req.method
));
return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &req.request_id);
}
// Default to one URL; a missing, non-Int64, empty or null `count` keeps the default.
let mut count: i64 = 1;
if let Some(arr) = req.column("count") {
use arrow_array::Array;
if let Some(c) = arr.as_any().downcast_ref::<arrow_array::Int64Array>() {
if !c.is_empty() && !Array::is_null(c, 0) {
count = c.value(0);
}
}
}
count = count.clamp(1, MAX_UPLOAD_URL_COUNT);
// The provider call is synchronous, so it runs under `block_in_place`.
// Generation stops at the first provider error.
let urls_res = tokio::task::block_in_place(|| {
let mut out = Vec::with_capacity(count as usize);
for _ in 0..count {
out.push(provider.generate_upload_url()?);
}
Ok::<_, RpcError>(out)
});
let schema = upload_url_response_schema();
let mut body_buf = Vec::new();
{
// Inner scope: the writer borrows `body_buf` and must be gone before
// the buffer is moved into the response.
let mut sw = match StreamWriter::new(&mut body_buf, &schema) {
Ok(w) => w,
Err(e) => {
return arrow_error(
&state,
StatusCode::INTERNAL_SERVER_ERROR,
&e,
&req.request_id,
)
}
};
match urls_res {
Ok(urls) => {
use arrow_array::{StringArray, TimestampMicrosecondArray};
// One column per response field, in schema order.
let upload_arr = StringArray::from(
urls.iter()
.map(|u| u.upload_url.clone())
.collect::<Vec<_>>(),
);
let download_arr = StringArray::from(
urls.iter()
.map(|u| u.download_url.clone())
.collect::<Vec<_>>(),
);
let expires_arr = TimestampMicrosecondArray::from(
urls.iter().map(|u| u.expires_at_micros).collect::<Vec<_>>(),
)
.with_timezone("UTC");
let batch = match RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(upload_arr),
Arc::new(download_arr),
Arc::new(expires_arr),
],
) {
Ok(b) => b,
Err(e) => {
// Batch assembly failed: report it in-band and return early.
let err = RpcError::runtime_error(format!("upload-url batch: {e}"));
let md =
build_error_metadata(&err, &state.server.server_id, &req.request_id);
let _ = sw.write(&empty_batch(&schema).unwrap(), Some(&md));
let _ = sw.finish();
drop(sw);
return arrow_response(StatusCode::OK, body_buf);
}
};
// Write failures are ignored; the response is still sent with status 200.
let _ = sw.write(&batch, None);
}
Err(err) => {
// Provider failure: zero-row batch carrying the error metadata.
let md = build_error_metadata(&err, &state.server.server_id, &req.request_id);
let _ = sw.write(&empty_batch(&schema).unwrap(), Some(&md));
}
}
let _ = sw.finish();
}
arrow_response(StatusCode::OK, body_buf)
}
/// Handles a unary RPC call whose request is posted as an Arrow IPC stream body.
///
/// Processing order: authenticate, require the Arrow content type, decompress and parse
/// the body, resolve an external payload when the request metadata carries a location
/// key, answer the describe method inline when enabled, then dispatch to the registered
/// unary handler and serialize its logs and result.
///
/// Failures raised by the handler (or by result externalization) are reported inside a
/// `200 OK` body as error metadata on an empty batch. Earlier failures use dedicated
/// statuses: 415 for a wrong content type, 400 for a body that cannot be decompressed,
/// parsed, or resolved, 404 for an unknown or non-unary method, and 500 when the
/// describe payload cannot be built. Authentication failures return whatever response
/// `authenticate_request` produced.
///
/// # Panics
///
/// Panics if the matched method has no unary handler, or if the result stream writer or
/// an empty batch for the result schema cannot be created (all of these are `unwrap`ped).
///
/// NOTE(review): `tokio::task::block_in_place` is used twice below; per the tokio docs it
/// panics on a current-thread runtime — confirm the server always runs multi-threaded.
async fn handle_unary(
State(state): State<Arc<HttpState>>,
Path(method): Path<String>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let auth = match authenticate_request(&state, &method, &headers) {
Ok(a) => a,
Err(resp) => return resp,
};
if !has_arrow_ct(&headers) {
return plain_error(
StatusCode::UNSUPPORTED_MEDIA_TYPE,
"need arrow content type".into(),
);
}
// A missing Cookie header, or one that is not a valid header string, is passed on as `None`.
let cookies = parse_cookies(headers.get(header::COOKIE).and_then(|v| v.to_str().ok()));
let server = state.server.clone();
// The request id is not known until the body has been parsed, so these two error
// responses carry an empty request id.
let body = match maybe_decompress(&headers, &body, state.max_body_size) {
Ok(b) => b,
Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
};
let mut req = match parse_request_from_body(&body) {
Ok(r) => r,
Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
};
// When the metadata carries a location key and the server has an external config, the
// request batch is replaced by the resolved batch. Without an external config the
// request continues with the batch it arrived with.
if md_get(&req.metadata, crate::metadata::LOCATION_KEY).is_some() {
if let Some(cfg) = server.external_config().as_ref() {
let outer_md = req.metadata.clone();
let outer_batch = req.batch.clone();
// Presumably the resolver performs blocking I/O, hence `block_in_place` — confirm.
let resolved = tokio::task::block_in_place(|| {
crate::external::resolve_external_location(&outer_batch, &outer_md, cfg)
});
match resolved {
// The metadata returned alongside the resolved batch is discarded.
Ok((inner_batch, _user_md)) => {
req.batch = inner_batch;
}
Err(err) => {
return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &req.request_id);
}
}
}
}
// The describe method is answered here, before the method registry lookup.
if server.describe_enabled() && method == crate::introspect::DESCRIBE_METHOD_NAME {
let (batch, md) = match crate::introspect::build_describe(
server.protocol_name(),
server.methods(),
&server.server_id,
) {
Ok(x) => x,
Err(err) => {
return arrow_error(
&state,
StatusCode::INTERNAL_SERVER_ERROR,
&err,
&req.request_id,
);
}
};
let mut buf = Vec::new();
// A serialization failure is ignored here; whatever reached `buf` is returned with 200.
let _ = crate::introspect::write_describe_response(&mut buf, &batch, &md);
return arrow_response(StatusCode::OK, buf);
}
// Only methods registered as unary are served by this handler; anything else is a 404.
let Some(info) = server
.method(&method)
.filter(|m| m.method_type == MethodType::Unary)
else {
let err = RpcError::attribute_error(format!("Unknown method: '{}'", method));
return arrow_error(&state, StatusCode::NOT_FOUND, &err, &req.request_id);
};
let ctx = CallContext::with_auth_cookies(&server, &req, auth.clone(), cookies);
let dispatch_info = crate::hooks::DispatchInfo::from_request(&server, &req, "unary", &auth);
// `hook_token` is `Some` exactly when a dispatch hook is installed.
let hook = server.dispatch_hook.clone();
let hook_token = hook.as_ref().map(|h| h.on_dispatch_start(&dispatch_info));
let mut stats = crate::hooks::CallStatistics {
input_batches: 1,
input_rows: req.batch.num_rows() as u64,
..Default::default()
};
let result = (info.unary.as_ref().unwrap())(&req, &ctx);
// Logs are drained after the handler has run so they can be written ahead of the result.
let logs = ctx.drain_logs();
// Set when the handler or the externalization step failed; reported to the hook below.
let mut app_err: Option<RpcError> = None;
let mut buf = Vec::new();
// The writer borrows `buf` mutably, so it lives in its own scope. Write and finish
// errors are ignored throughout this block.
{
let mut sw = StreamWriter::new(&mut buf, &info.result_schema).unwrap();
// Each log record is sent as an empty batch carrying the log in its metadata.
for log in &logs {
let md = build_log_metadata(log, &server.server_id, &req.request_id);
let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
}
match result {
Ok(batch_opt) => {
// A handler that returns no batch is answered with an empty batch.
let out_batch =
batch_opt.unwrap_or_else(|| empty_batch(&info.result_schema).unwrap());
stats.output_batches = 1;
stats.output_rows = out_batch.num_rows() as u64;
if let Some(cfg) = server.external_config().as_ref() {
let externalized = tokio::task::block_in_place(|| {
crate::external::maybe_externalize_batch(&out_batch, None, cfg)
});
match externalized {
// Externalized: the returned batch and metadata are written instead of the result.
Ok(Some((ptr, md))) => {
let _ = sw.write(&ptr, Some(&md));
}
// Not externalized: the result batch is written inline.
Ok(None) => {
let _ = sw.write(&out_batch, None);
}
Err(err) => {
let md = build_error_metadata(&err, &server.server_id, &req.request_id);
let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
app_err = Some(err);
}
}
} else {
let _ = sw.write(&out_batch, None);
}
}
// Handler errors travel in-band: an empty batch with error metadata, still HTTP 200.
Err(err) => {
let md = build_error_metadata(&err, &server.server_id, &req.request_id);
let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
app_err = Some(err);
}
}
let _ = sw.finish();
}
// The `unwrap_or(0)` fallback is never taken in practice: inside this branch a hook
// exists, so a token was produced above.
if let Some(hook) = hook {
hook.on_dispatch_end(
hook_token.unwrap_or(0),
&dispatch_info,
app_err.as_ref(),
&stats,
);
}
// The size cap is checked after the hook has been notified, so the hook does not see
// a cap violation as an error.
if let Some(limit) = state.max_response_bytes {
if buf.len() > limit {
let err = RpcError::runtime_error(format!(
"HTTP body exceeds max_response_bytes ({} > {}) for method {:?}",
buf.len(),
limit,
method
));
return cap_error_response(
&info.result_schema,
&err,
&server.server_id,
&req.request_id,
);
}
}
arrow_response(StatusCode::OK, buf)
}
/// Starts a streaming RPC call and returns the first response body.
///
/// After authentication, content-type validation, decompression and request parsing,
/// the registered stream initializer is invoked. The response body consists of:
///
/// 1. an optional header stream (only when the initializer returned a header batch),
///    which also carries the initializer's log records, followed by
/// 2. the output stream. For producer streams it holds the batches produced by
///    `run_producer`; whenever the stream is not finished it ends with an empty batch
///    whose metadata carries a signed continuation token.
///
/// An initializer failure is reported inside a `200 OK` body as an error batch on the
/// empty schema. A method registered without a stream handler, or a stream writer that
/// cannot be created, yields a `500` Arrow error instead of panicking, matching how the
/// other handlers in this file report server-side faults.
///
/// # Panics
///
/// Panics if an empty batch cannot be built for the header or output schema.
async fn handle_stream_init(
    State(state): State<Arc<HttpState>>,
    Path(method): Path<String>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    let auth = match authenticate_request(&state, &method, &headers) {
        Ok(a) => a,
        Err(resp) => return resp,
    };
    if !has_arrow_ct(&headers) {
        return plain_error(
            StatusCode::UNSUPPORTED_MEDIA_TYPE,
            "need arrow content type".into(),
        );
    }
    let cookies = parse_cookies(headers.get(header::COOKIE).and_then(|v| v.to_str().ok()));
    let server = state.server.clone();
    // The request id is unknown until the body is parsed, hence the empty id below.
    let body = match maybe_decompress(&headers, &body, state.max_body_size) {
        Ok(b) => b,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };
    let req = match parse_request_from_body(&body) {
        Ok(r) => r,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };
    // Every non-unary method type is served by this handler.
    let Some(info) = server
        .method(&method)
        .filter(|m| m.method_type != MethodType::Unary)
    else {
        let err = RpcError::attribute_error(format!("Unknown stream method: '{}'", method));
        return arrow_error(&state, StatusCode::NOT_FOUND, &err, &req.request_id);
    };
    // A stream method without an initializer is a registration fault; answer with a 500
    // rather than unwinding inside the request handler.
    let Some(init_fn) = info.stream.as_ref() else {
        let err = RpcError::runtime_error(format!(
            "Stream method '{method}' is registered without a stream handler"
        ));
        return arrow_error(
            &state,
            StatusCode::INTERNAL_SERVER_ERROR,
            &err,
            &req.request_id,
        );
    };
    // The context takes its own copy of the auth context; the original is kept because
    // the continuation token below is bound to it.
    let ctx = CallContext::with_auth_cookies(&server, &req, auth.clone(), cookies);
    let init_result = init_fn(&req, &ctx);
    let init_logs = ctx.drain_logs();
    let StreamResult {
        output_schema,
        input_schema,
        state: mut ss,
        header,
        header_metadata,
    } = match init_result {
        Ok(sr) => sr,
        Err(err) => {
            return arrow_response(
                StatusCode::OK,
                error_stream_bytes(&empty_schema(), &err, &server.server_id, &req.request_id),
            );
        }
    };

    let mut body_buf = Vec::new();
    // Optional header stream: init logs first, then the header batch itself.
    if let Some(header_batch) = header.as_ref() {
        let hdr_schema = header_batch.schema();
        let mut hw = match StreamWriter::new(&mut body_buf, hdr_schema.as_ref()) {
            Ok(w) => w,
            Err(e) => {
                return arrow_error(
                    &state,
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &e,
                    &req.request_id,
                )
            }
        };
        for log in &init_logs {
            let md = build_log_metadata(log, &server.server_id, &req.request_id);
            let _ = hw.write(&empty_batch(hdr_schema.as_ref()).unwrap(), Some(&md));
        }
        let _ = hw.write(header_batch, header_metadata.as_ref());
        let _ = hw.finish();
    }

    let is_producer = matches!(ss, StreamStateKind::Producer(_));
    let stream_id = new_session_id();
    // The writer borrows `body_buf` mutably, so it lives in its own scope. Write and
    // finish errors are ignored, as elsewhere in this file.
    {
        let mut sw = match StreamWriter::new(&mut body_buf, output_schema.as_ref()) {
            Ok(w) => w,
            Err(e) => {
                return arrow_error(
                    &state,
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &e,
                    &req.request_id,
                )
            }
        };
        // Without a header stream the init logs are delivered on the output stream.
        if header.is_none() {
            for log in &init_logs {
                let md = build_log_metadata(log, &server.server_id, &req.request_id);
                let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
            }
        }
        // Producer streams start emitting immediately; exchange streams wait for input.
        let finished = is_producer
            && run_producer(
                &mut sw,
                &mut ss,
                &output_schema,
                &server,
                &req,
                state.producer_batch_limit,
            );
        if !finished {
            // Hand the client a token so the stream can be resumed by a later request.
            match build_continuation_token(
                &state,
                &auth,
                &ss,
                &output_schema,
                input_schema.as_ref(),
                &stream_id,
            ) {
                Ok(token) => {
                    let md = Metadata::from([(STATE_KEY.to_string(), token)]);
                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                }
                Err(err) => {
                    let md = build_error_metadata(&err, &server.server_id, &req.request_id);
                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                }
            }
        }
        let _ = sw.finish();
    }
    arrow_response(StatusCode::OK, body_buf)
}
/// Serializes the current stream state and both schemas into a signed state token.
///
/// The token is produced by `HttpState::pack_state_token` from the auth context, the
/// encoded stream state, the encoded output schema, the encoded input schema (empty
/// bytes when there is no input schema) and the stream id.
///
/// # Errors
///
/// Returns the error raised while encoding the stream state or either schema.
fn build_continuation_token(
    state: &Arc<HttpState>,
    auth: &crate::auth::AuthContext,
    ss: &StreamStateKind,
    output_schema: &SchemaRef,
    input_schema: Option<&SchemaRef>,
    stream_id: &str,
) -> Result<String> {
    let encoded_state = match ss {
        StreamStateKind::Producer(producer) => producer.encode_state()?,
        StreamStateKind::Exchange(exchange) => exchange.encode_state()?,
    };
    let encoded_output = write_schema_bytes(output_schema.as_ref())?;
    // A missing input schema is represented by an empty byte vector.
    let encoded_input = input_schema
        .map(|schema| write_schema_bytes(schema.as_ref()))
        .transpose()?
        .unwrap_or_default();
    let token = state.pack_state_token(
        auth,
        &encoded_state,
        &encoded_output,
        &encoded_input,
        stream_id,
    );
    Ok(token)
}
/// Drives a producer stream, writing everything it emits to `sw`.
///
/// Each round calls `produce` once, forwards the context's log records, and then writes
/// the collected log and data items in the order they were emitted. Write errors are
/// ignored.
///
/// `limit` caps the number of rounds that emitted at least one data batch; `0` means no
/// cap.
///
/// Returns `true` when the stream must not be continued: the producer reported an error
/// (written as an error batch), it signalled completion, or a round emitted no data
/// batch. Returns `false` when the round limit was reached first.
///
/// # Panics
///
/// Panics if `ss` is not a producer state, or if an empty batch cannot be built for
/// `output_schema`.
fn run_producer<W: std::io::Write>(
    sw: &mut StreamWriter<W>,
    ss: &mut StreamStateKind,
    output_schema: &SchemaRef,
    server: &Arc<RpcServer>,
    req: &Request,
    limit: usize,
) -> bool {
    let ctx = CallContext::for_request(server, req);
    let StreamStateKind::Producer(producer) = ss else {
        unreachable!()
    };

    let mut data_rounds = 0usize;
    loop {
        // A limit of zero disables the cap entirely.
        if limit != 0 && data_rounds >= limit {
            return false;
        }

        let mut collector = OutputCollector::new(output_schema.clone(), true);
        let outcome = producer.produce(&mut collector, &ctx);

        // Context logs are flushed before the outcome is inspected, so they reach the
        // client even when the round failed.
        for record in ctx.drain_logs() {
            let log_md = build_log_metadata(&record, &server.server_id, &req.request_id);
            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&log_md));
        }

        if let Err(err) = outcome {
            let err_md = build_error_metadata(&err, &server.server_id, &req.request_id);
            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&err_md));
            return true;
        }

        // Read the completion flag before the collected items are drained.
        let done = collector.finished();
        let mut wrote_batch = false;
        for emitted in collector.items.drain(..) {
            match emitted {
                Emitted::Batch { batch, metadata } => {
                    let _ = sw.write(&batch, metadata.as_ref());
                    wrote_batch = true;
                }
                Emitted::Log(record) => {
                    let log_md = build_log_metadata(&record, &server.server_id, &req.request_id);
                    let _ =
                        sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&log_md));
                }
            }
        }

        // A round without data ends the stream just like an explicit finish does.
        if done || !wrote_batch {
            return true;
        }
        data_rounds += 1;
    }
}
/// Handles a stream continuation request.
///
/// The client posts a single batch whose metadata carries the signed state token handed
/// out by an earlier response. The token is verified against the caller's auth context
/// and unpacked into the stream state plus the output and input schemas. Then one of
/// three things happens:
///
/// * **cancel** — when the metadata carries the cancel key, the state's `on_cancel` is
///   invoked and an empty output stream is returned;
/// * **producer** — the producer is driven by `run_producer`, and a fresh token is
///   appended unless the stream finished;
/// * **exchange** — the input batch is cast to the expected input schema when it
///   differs, passed to `exchange`, and every emitted data batch is tagged with a fresh
///   token (an empty token batch is written when no data was emitted).
///
/// Failures while validating the request are returned as Arrow error responses with a
/// 4xx/5xx status. Once the request batch has been read, those responses carry the
/// request id taken from its metadata, so clients can correlate them. Failures of the
/// stream itself are reported inside a `200 OK` body as error metadata.
///
/// Only the exchange path is passed through `enforce_response_body_cap`.
///
/// # Panics
///
/// Panics if an output stream writer or an empty batch cannot be created for the output
/// schema.
async fn handle_stream_exchange(
    State(state): State<Arc<HttpState>>,
    Path(method): Path<String>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    let auth = match authenticate_request(&state, &method, &headers) {
        Ok(a) => a,
        Err(resp) => return resp,
    };
    if !has_arrow_ct(&headers) {
        return plain_error(
            StatusCode::UNSUPPORTED_MEDIA_TYPE,
            "need arrow content type".into(),
        );
    }
    let server = state.server.clone();
    // Nothing has been parsed yet, so these two error responses carry no request id.
    let body = match maybe_decompress(&headers, &body, state.max_body_size) {
        Ok(b) => b,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };
    let (batch, metadata) = match read_input_batch(&body) {
        Ok(x) => x,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };
    // Read the request id as soon as the metadata is available so that every later
    // error response can be correlated with the request.
    let request_id = md_get(&metadata, REQUEST_ID_KEY)
        .unwrap_or("")
        .to_string();
    let Some(token) = md_get(&metadata, STATE_KEY).map(str::to_owned) else {
        let err = RpcError::runtime_error("Missing state token in exchange request");
        return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &request_id);
    };
    let cancelled = md_get(&metadata, CANCEL_KEY).is_some();
    let unpacked = match state.unpack_state_token(&auth, &token) {
        Ok(u) => u,
        Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &request_id),
    };
    let output_schema = match read_schema_bytes(&unpacked.output_schema_bytes) {
        Ok(s) => s,
        Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &request_id),
    };
    // Empty input-schema bytes mean the token was issued without an input schema.
    let input_schema: Option<SchemaRef> = if unpacked.input_schema_bytes.is_empty() {
        None
    } else {
        match read_schema_bytes(&unpacked.input_schema_bytes) {
            Ok(s) => Some(s),
            Err(err) => {
                return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &request_id)
            }
        }
    };
    let Some(info) = server
        .method(&method)
        .filter(|m| m.method_type != MethodType::Unary)
    else {
        let err = RpcError::attribute_error(format!("Unknown stream method: '{}'", method));
        return arrow_error(&state, StatusCode::NOT_FOUND, &err, &request_id);
    };
    let Some(decoder) = info.state_decoder.as_ref() else {
        let err = RpcError::runtime_error(format!(
            "Stream method '{method}' is registered without a state decoder; \
it cannot serve HTTP continuation requests"
        ));
        return arrow_error(
            &state,
            StatusCode::INTERNAL_SERVER_ERROR,
            &err,
            &request_id,
        );
    };
    let mut ss = match decoder(&unpacked.state_bytes) {
        Ok(s) => s,
        Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &request_id),
    };
    // The synthetic request carries the caller's metadata but an empty batch; the real
    // input batch is passed to `exchange` separately below.
    let req = Request {
        method: method.clone(),
        request_id,
        batch: empty_batch(&Schema::empty()).unwrap(),
        metadata,
    };
    let ctx = CallContext::for_request(&server, &req);
    let mut body_buf = Vec::new();

    // Cancellation: notify the state and answer with an empty output stream.
    if cancelled {
        match &mut ss {
            StreamStateKind::Producer(p) => p.on_cancel(&ctx),
            StreamStateKind::Exchange(e) => e.on_cancel(&ctx),
        }
        {
            let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
            let _ = sw.finish();
        }
        return arrow_response(StatusCode::OK, body_buf);
    }

    // Producer continuation: the input batch is ignored, the producer simply resumes.
    if matches!(ss, StreamStateKind::Producer(_)) {
        {
            let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
            let finished = run_producer(
                &mut sw,
                &mut ss,
                &output_schema,
                &server,
                &req,
                state.producer_batch_limit,
            );
            if !finished {
                match build_continuation_token(
                    &state,
                    &auth,
                    &ss,
                    &output_schema,
                    input_schema.as_ref(),
                    &unpacked.stream_id,
                ) {
                    Ok(new_token) => {
                        let md = Metadata::from([(STATE_KEY.to_string(), new_token)]);
                        let _ =
                            sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                    }
                    Err(err) => {
                        let md = build_error_metadata(&err, &server.server_id, &req.request_id);
                        let _ =
                            sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                    }
                }
            }
            let _ = sw.finish();
        }
        return arrow_response(StatusCode::OK, body_buf);
    }

    // Exchange: bring the input batch to the schema recorded in the token when it differs.
    let casted = match &input_schema {
        Some(exp) if batch.schema() != *exp => match cast_batch(&batch, exp) {
            Ok(b) => b,
            Err(e) => {
                let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
                let md = build_error_metadata(&e, &server.server_id, &req.request_id);
                let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                let _ = sw.finish();
                drop(sw);
                return arrow_response(StatusCode::OK, body_buf);
            }
        },
        _ => batch,
    };
    let mut out = OutputCollector::new(output_schema.clone(), false);
    let res = match &mut ss {
        StreamStateKind::Exchange(e) => e.exchange(&casted, &mut out, &ctx),
        // Producer states returned from the branch above.
        StreamStateKind::Producer(_) => unreachable!("producer streams return earlier"),
    };
    {
        let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
        for log in ctx.drain_logs() {
            let md = build_log_metadata(&log, &server.server_id, &req.request_id);
            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
        }
        if let Err(err) = res {
            let md = build_error_metadata(&err, &server.server_id, &req.request_id);
            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
        } else {
            // The token is built after `exchange` ran, so it encodes the updated state.
            let new_token = match build_continuation_token(
                &state,
                &auth,
                &ss,
                &output_schema,
                input_schema.as_ref(),
                &unpacked.stream_id,
            ) {
                Ok(t) => t,
                Err(err) => {
                    let md = build_error_metadata(&err, &server.server_id, &req.request_id);
                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                    let _ = sw.finish();
                    drop(sw);
                    return arrow_response(StatusCode::OK, body_buf);
                }
            };
            let mut wrote_data = false;
            for item in out.items.drain(..) {
                match item {
                    Emitted::Log(log) => {
                        let md = build_log_metadata(&log, &server.server_id, &req.request_id);
                        let _ =
                            sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                    }
                    Emitted::Batch { batch, metadata } => {
                        // Every data batch carries the token alongside its own metadata.
                        let mut md = metadata.unwrap_or_default();
                        md.insert(STATE_KEY.to_string(), new_token.clone());
                        let _ = sw.write(&batch, Some(&md));
                        wrote_data = true;
                    }
                }
            }
            // Without any data batch the token still has to reach the client.
            if !wrote_data {
                let md = Metadata::from([(STATE_KEY.to_string(), new_token)]);
                let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
            }
        }
        let _ = sw.finish();
    }
    enforce_response_body_cap(
        &state,
        output_schema.as_ref(),
        body_buf,
        &method,
        &server.server_id,
        req.request_id.as_str(),
    )
}
/// Reads the first batch and its metadata from an Arrow stream body.
///
/// The remainder of the stream is drained after the first batch has been read.
///
/// # Errors
///
/// Fails when the stream cannot be opened, read, or drained, and with
/// "no batch in exchange request" when the stream contains no batch.
fn read_input_batch(body: &[u8]) -> Result<(RecordBatch, Metadata)> {
    let mut reader = StreamReader::new(body)?;
    let Some((first_batch, first_metadata)) = reader.read_next()? else {
        return Err(RpcError::runtime_error("no batch in exchange request"));
    };
    reader.drain()?;
    Ok((first_batch, first_metadata))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::RpcServer;
    use std::time::Duration;

    /// Hex spelling of the 32 bytes `0x00..=0x1f`, used wherever a hex signing key is needed.
    const SEQUENTIAL_KEY_HEX: &str =
        "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";

    /// Builds an `HttpState` with a fixed signing key, a 50 ms token TTL and a 1 KiB body cap.
    fn state_with_key() -> Arc<HttpState> {
        let rpc_server = Arc::new(RpcServer::builder().server_id("test").build());
        HttpState::builder()
            .server(rpc_server)
            .signing_key(&[7u8; 32])
            .token_ttl(Duration::from_millis(50))
            .max_body_size(1024)
            .build()
    }

    /// Serialized form of a schema with a single non-nullable `Int64` column named `x`.
    fn sample_schema_bytes() -> Vec<u8> {
        use arrow_schema::{DataType, Field, Schema};
        let schema = Schema::new(vec![Field::new("x", DataType::Int64, false)]);
        write_schema_bytes(&schema).unwrap()
    }

    /// Packing and then unpacking a token returns every component unchanged.
    #[tokio::test]
    async fn pack_unpack_roundtrip() {
        let http = state_with_key();
        let anon = crate::auth::AuthContext::anonymous();
        let payload = b"state-payload";
        let output_schema = sample_schema_bytes();
        let input_schema = sample_schema_bytes();

        let token = http.pack_state_token(&anon, payload, &output_schema, &input_schema, "sid-123");
        let decoded = http.unpack_state_token(&anon, &token).unwrap();

        assert_eq!(decoded.state_bytes, payload);
        assert_eq!(decoded.output_schema_bytes, output_schema);
        assert_eq!(decoded.input_schema_bytes, input_schema);
        assert_eq!(decoded.stream_id, "sid-123");
    }

    /// Flipping one bit in the final byte of the decoded token makes it invalid.
    #[tokio::test]
    async fn unpack_rejects_tampered_hmac() {
        let http = state_with_key();
        let anon = crate::auth::AuthContext::anonymous();
        let token = http.pack_state_token(&anon, b"s", b"o", b"i", "sid");

        let b64 = base64::engine::general_purpose::STANDARD;
        let mut raw = b64.decode(token.as_bytes()).unwrap();
        *raw.last_mut().unwrap() ^= 0x01;
        let tampered = b64.encode(raw);

        assert!(http.unpack_state_token(&anon, &tampered).is_err());
    }

    /// A token signed with one key is rejected by a state that holds a different key.
    #[tokio::test]
    async fn unpack_rejects_different_key() {
        let rpc_server = Arc::new(RpcServer::builder().server_id("t").build());
        let signer = HttpState::builder()
            .server(rpc_server.clone())
            .signing_key(&[1u8; 32])
            .build();
        let verifier = HttpState::builder()
            .server(rpc_server)
            .signing_key(&[2u8; 32])
            .build();
        let anon = crate::auth::AuthContext::anonymous();

        let token = signer.pack_state_token(&anon, b"s", b"o", b"i", "sid");
        assert!(verifier.unpack_state_token(&anon, &token).is_err());
    }

    /// A token packed with `0` as the trailing argument is rejected as expired.
    #[tokio::test]
    async fn unpack_rejects_expired_token() {
        let http = state_with_key();
        let stale = pack_state_token(&[7u8; 32], &[], b"s", b"o", b"i", "sid", 0);
        let anon = crate::auth::AuthContext::anonymous();

        let err = http.unpack_state_token(&anon, &stale).unwrap_err();
        assert!(err.message.contains("expired"), "got: {}", err.message);
    }

    /// A token issued to one principal is rejected for another principal and for anonymous.
    #[tokio::test]
    async fn unpack_rejects_different_principal() {
        let http = state_with_key();
        let alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let bob = crate::auth::AuthContext::for_principal("bearer", "bob");
        let token = http.pack_state_token(&alice, b"s", b"o", b"i", "sid");

        assert!(http.unpack_state_token(&alice, &token).is_ok());
        assert!(http.unpack_state_token(&bob, &token).is_err());

        let anon = crate::auth::AuthContext::anonymous();
        assert!(http.unpack_state_token(&anon, &token).is_err());
    }

    /// A token issued to an anonymous caller cannot be replayed by an authenticated one.
    #[tokio::test]
    async fn unpack_rejects_authenticated_replay_of_anonymous_token() {
        let http = state_with_key();
        let anon = crate::auth::AuthContext::anonymous();
        let alice = crate::auth::AuthContext::for_principal("bearer", "alice");

        let token = http.pack_state_token(&anon, b"s", b"o", b"i", "sid");
        assert!(http.unpack_state_token(&alice, &token).is_err());
    }

    /// The same principal name under a different auth domain does not validate the token.
    #[tokio::test]
    async fn unpack_rejects_cross_domain_replay() {
        let http = state_with_key();
        let bearer_alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let mtls_alice = crate::auth::AuthContext::for_principal("mtls", "alice");

        let token = http.pack_state_token(&bearer_alice, b"s", b"o", b"i", "sid");
        assert!(http.unpack_state_token(&mtls_alice, &token).is_err());
    }

    /// A 1025-byte body with no headers is rejected against a 1024-byte limit.
    #[tokio::test]
    async fn decompress_rejects_oversize() {
        let no_headers = HeaderMap::new();
        let oversized = Bytes::from(vec![0u8; 1025]);

        let err = super::maybe_decompress(&no_headers, &oversized, 1024).unwrap_err();
        assert!(err.message.contains("exceeds max size"));
    }

    /// 8 MiB of zeros compresses to a tiny frame, yet decoding with a 64 KiB bound fails.
    #[test]
    fn zstd_bounded_rejects_zip_bomb_without_full_alloc() {
        let zeros = vec![0u8; 8 * 1024 * 1024];
        let compressed = zstd::encode_all(zeros.as_slice(), 1).unwrap();
        assert!(compressed.len() < 100_000, "compressed should be tiny");

        let err = super::decode_zstd_bounded(&compressed, 64 * 1024).unwrap_err();
        assert!(
            err.message.contains("exceeds max size"),
            "expected oversize error, got: {}",
            err.message
        );
    }

    /// A payload that fits within the bound decodes to the original bytes.
    #[test]
    fn zstd_bounded_passes_small_payload() {
        let original = b"hello-world".repeat(10);
        let compressed = zstd::encode_all(original.as_slice(), 1).unwrap();

        let decoded = super::decode_zstd_bounded(&compressed, 1024).unwrap();
        assert_eq!(decoded, original);
    }

    /// A 64-character hex string decodes to 32 bytes in the written order.
    #[test]
    fn decode_hex_key_roundtrip() {
        let key = decode_hex_key(SEQUENTIAL_KEY_HEX).unwrap();
        assert_eq!(key.len(), 32);
        assert_eq!(key[0], 0x00);
        assert_eq!(key[31], 0x1f);
    }

    /// A hex key that is too short is rejected.
    #[test]
    fn decode_hex_key_rejects_short() {
        assert!(decode_hex_key("deadbeef").is_err());
    }

    /// A key of the right length made of non-hex characters is rejected.
    #[test]
    fn decode_hex_key_rejects_bad_char() {
        let not_hex = "zz".repeat(32);
        assert!(decode_hex_key(&not_hex).is_err());
    }

    /// Standard padded base64 of 32 bytes is accepted.
    #[test]
    fn decode_base64_key_accepts_padded() {
        let encoded = base64::engine::general_purpose::STANDARD.encode([7u8; 32]);
        let key = decode_base64_key(&encoded).unwrap();
        assert_eq!(key, vec![7u8; 32]);
    }

    /// The same key with its trailing `=` padding stripped is accepted as well.
    #[test]
    fn decode_base64_key_accepts_unpadded() {
        let encoded = base64::engine::general_purpose::STANDARD
            .encode([7u8; 32])
            .trim_end_matches('=')
            .to_string();
        let key = decode_base64_key(&encoded).unwrap();
        assert_eq!(key, vec![7u8; 32]);
    }

    /// A base64 value that decodes to too few bytes is rejected.
    #[test]
    fn decode_base64_key_rejects_short() {
        let encoded = base64::engine::general_purpose::STANDARD.encode(b"short");
        assert!(decode_base64_key(&encoded).is_err());
    }

    /// Two states configured with the same hex key accept each other's tokens.
    #[tokio::test]
    async fn signing_key_hex_round_trips_through_token() {
        let rpc_server = Arc::new(RpcServer::builder().server_id("t").build());
        let issuer = HttpState::builder()
            .server(rpc_server.clone())
            .signing_key_hex(SEQUENTIAL_KEY_HEX)
            .build();
        let verifier = HttpState::builder()
            .server(rpc_server)
            .signing_key_hex(SEQUENTIAL_KEY_HEX)
            .build();
        let anon = crate::auth::AuthContext::anonymous();

        let token = issuer.pack_state_token(&anon, b"s", b"o", b"i", "sid");
        assert_eq!(
            verifier.unpack_state_token(&anon, &token).unwrap().stream_id,
            "sid"
        );
    }
}