1use std::sync::Arc;
15
16use arrow_array::RecordBatch;
17use arrow_schema::{Schema, SchemaRef};
18use axum::{
19 body::Bytes,
20 extract::{Path, State},
21 http::{header, HeaderMap, HeaderValue, StatusCode},
22 response::{IntoResponse, Response},
23 routing::post,
24 Router,
25};
26use base64::Engine;
27use rand::RngCore;
28
29use crate::errors::{Result, RpcError};
30use crate::metadata::{CANCEL_KEY, REQUEST_ID_KEY, STATE_KEY};
31use crate::server::{
32 build_error_metadata, build_log_metadata, cast_batch, CallContext, MethodType, Request,
33 RpcServer,
34};
35use crate::stream::{empty_schema, Emitted, OutputCollector, StreamResult, StreamStateKind};
36use crate::wire::{bytes_to_hex, empty_batch, md_get, Metadata, StreamReader, StreamWriter};
37
/// MIME type of Arrow IPC stream bodies. The response post-processing only
/// zstd-compresses responses that carry exactly this `Content-Type`.
pub const ARROW_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";

/// Request header read by the session DELETE endpoint and handed to the
/// sticky-session layer (`crate::sticky::handle_delete`).
const SESSION_HEADER: &str = "vgi-session";
// NOTE(review): not referenced in the visible part of this file; presumably
// used by the sticky-session handshake further down — confirm.
const SESSION_ACCEPT_HEADER: &str = "vgi-session-accept";
/// Response header set to `"true"` when the session DELETE endpoint closed a session.
const SESSION_CLOSE_HEADER: &str = "vgi-session-close";
// NOTE(review): not referenced in the visible part of this file; presumably a
// prefix for echoed headers in sticky mode — confirm.
const ECHO_HEADER_PREFIX: &str = "vgi-echo-";
/// Capability header advertising whether sticky sessions are enabled (`"true"`/`"false"`).
const STICKY_ENABLED_HEADER: &str = "vgi-sticky-enabled";
/// Capability header advertising the sticky-session default TTL, in whole seconds.
const STICKY_DEFAULT_TTL_HEADER: &str = "vgi-sticky-default-ttl";
/// Capability header listing the configured sticky echo header names, comma-separated.
const STICKY_ECHO_HEADERS_HEADER: &str = "vgi-sticky-echo-headers";
/// Method name passed to the authenticate callback for the session DELETE route.
const SESSION_ENDPOINT: &str = "__session__";
52
53pub struct HttpState {
63 server: Arc<RpcServer>,
64 token_key: [u8; 32],
65 producer_batch_limit: usize,
66 token_ttl: std::time::Duration,
67 max_body_size: usize,
68 request_timeout: std::time::Duration,
72 authenticate: Option<crate::auth::Authenticate>,
73 #[allow(dead_code)]
74 oauth_metadata: Option<Arc<crate::auth::oauth::OAuthResourceMetadata>>,
75 oauth_metadata_json: Option<Vec<u8>>,
76 www_authenticate: Option<String>,
77 cors_origins: Option<String>,
78 cors_max_age: u32,
79 prefix: String,
80 response_compression_level: Option<i32>,
81 landing_page_enabled: bool,
82 describe_page_enabled: bool,
83 health_enabled: bool,
84 max_request_bytes: Option<usize>,
85 max_upload_bytes: Option<usize>,
86 max_response_bytes: Option<usize>,
91 max_externalized_response_bytes: Option<usize>,
95 upload_url_provider: Option<Arc<dyn crate::external::UploadUrlProvider>>,
96 sticky: Option<Arc<crate::sticky::StickyContext>>,
98}
99
100#[derive(Default)]
102pub struct HttpStateBuilder {
103 server: Option<Arc<RpcServer>>,
104 token_key: Option<[u8; 32]>,
105 producer_batch_limit: Option<usize>,
106 token_ttl: Option<std::time::Duration>,
107 max_body_size: Option<usize>,
108 request_timeout: Option<std::time::Duration>,
109 authenticate: Option<crate::auth::Authenticate>,
110 oauth_metadata: Option<Arc<crate::auth::oauth::OAuthResourceMetadata>>,
111 cors_origins: Option<String>,
112 cors_max_age: Option<u32>,
113 prefix: Option<String>,
114 response_compression_level: Option<i32>,
115 landing_page_enabled: Option<bool>,
116 describe_page_enabled: Option<bool>,
117 health_enabled: Option<bool>,
118 max_request_bytes: Option<usize>,
119 max_upload_bytes: Option<usize>,
120 max_response_bytes: Option<usize>,
121 max_externalized_response_bytes: Option<usize>,
122 upload_url_provider: Option<Arc<dyn crate::external::UploadUrlProvider>>,
123 enable_sticky: Option<bool>,
124 sticky_default_ttl: Option<std::time::Duration>,
125 sticky_echo_headers: Vec<(String, String)>,
126}
127
128impl HttpStateBuilder {
129 pub fn server(mut self, server: Arc<RpcServer>) -> Self {
130 self.server = Some(server);
131 self
132 }
133
134 pub fn token_key(mut self, key: &[u8]) -> Self {
140 assert!(
141 key.len() >= 32,
142 "token_key: signing key must be at least 32 bytes (got {}). \
143 Generate one with `openssl rand -hex 32`.",
144 key.len()
145 );
146 let mut k = [0u8; 32];
147 k.copy_from_slice(&key[..32]);
148 self.token_key = Some(k);
149 self
150 }
151
152 pub fn token_key_hex(self, hex: &str) -> Self {
156 let bytes = decode_hex_key(hex).expect("token_key_hex: invalid hex or wrong length");
157 self.token_key(&bytes)
158 }
159
160 pub fn token_key_base64(self, b64: &str) -> Self {
163 let bytes =
164 decode_base64_key(b64).expect("token_key_base64: invalid base64 or wrong length");
165 self.token_key(&bytes)
166 }
167
168 pub fn token_key_from_env(self, var: &str) -> Self {
173 let raw = std::env::var(var)
174 .unwrap_or_else(|_| panic!("token_key_from_env: env var {var} is unset or not UTF-8"));
175 let trimmed = raw.trim();
176 let bytes = decode_base64_key(trimmed)
177 .or_else(|_| decode_hex_key(trimmed))
178 .unwrap_or_else(|e| {
179 panic!("token_key_from_env: {var} is not valid base64 or hex ({e})")
180 });
181 self.token_key(&bytes)
182 }
183
184 pub fn producer_batch_limit(mut self, n: usize) -> Self {
187 self.producer_batch_limit = Some(n);
188 self
189 }
190
191 pub fn token_ttl(mut self, ttl: std::time::Duration) -> Self {
195 self.token_ttl = Some(ttl);
196 self
197 }
198
199 pub fn max_body_size(mut self, n: usize) -> Self {
204 self.max_body_size = Some(n);
205 self
206 }
207
208 pub fn request_timeout(mut self, d: std::time::Duration) -> Self {
210 self.request_timeout = Some(d);
211 self
212 }
213
214 pub fn authenticate(mut self, cb: crate::auth::Authenticate) -> Self {
217 self.authenticate = Some(cb);
218 self
219 }
220
221 pub fn oauth_resource_metadata(
225 mut self,
226 metadata: crate::auth::oauth::OAuthResourceMetadata,
227 ) -> Self {
228 self.oauth_metadata = Some(Arc::new(metadata));
229 self
230 }
231
232 pub fn cors_origins(mut self, origins: impl Into<String>) -> Self {
235 self.cors_origins = Some(origins.into());
236 self
237 }
238
239 pub fn cors_max_age(mut self, seconds: u32) -> Self {
241 self.cors_max_age = Some(seconds);
242 self
243 }
244
245 pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
247 self.prefix = Some(prefix.into());
248 self
249 }
250
251 pub fn response_compression_level(mut self, level: i32) -> Self {
254 self.response_compression_level = Some(level);
255 self
256 }
257
258 pub fn enable_landing_page(mut self, enabled: bool) -> Self {
260 self.landing_page_enabled = Some(enabled);
261 self
262 }
263
264 pub fn enable_describe_page(mut self, enabled: bool) -> Self {
266 self.describe_page_enabled = Some(enabled);
267 self
268 }
269
270 pub fn enable_health(mut self, enabled: bool) -> Self {
272 self.health_enabled = Some(enabled);
273 self
274 }
275
276 pub fn max_request_bytes(mut self, n: usize) -> Self {
282 self.max_request_bytes = Some(n);
283 self
284 }
285
286 pub fn max_upload_bytes(mut self, n: usize) -> Self {
290 self.max_upload_bytes = Some(n);
291 self
292 }
293
294 pub fn max_response_bytes(mut self, n: usize) -> Self {
300 self.max_response_bytes = Some(n);
301 self
302 }
303
304 pub fn max_externalized_response_bytes(mut self, n: usize) -> Self {
308 self.max_externalized_response_bytes = Some(n);
309 self
310 }
311
312 pub fn upload_url_provider(
316 mut self,
317 provider: Arc<dyn crate::external::UploadUrlProvider>,
318 ) -> Self {
319 self.upload_url_provider = Some(provider);
320 self
321 }
322
323 pub fn enable_sticky(mut self, enabled: bool) -> Self {
328 self.enable_sticky = Some(enabled);
329 self
330 }
331
332 pub fn sticky_default_ttl(mut self, ttl: std::time::Duration) -> Self {
335 self.sticky_default_ttl = Some(ttl);
336 self
337 }
338
339 pub fn sticky_echo_headers(
345 mut self,
346 headers: impl IntoIterator<Item = (String, String)>,
347 ) -> Self {
348 self.sticky_echo_headers = headers.into_iter().collect();
349 self
350 }
351
352 pub fn build(self) -> Arc<HttpState> {
353 let server = self.server.expect("HttpStateBuilder::server is required");
354 assert!(
361 !(self.cors_origins.as_deref() == Some("*") && self.authenticate.is_some()),
362 "HttpStateBuilder: cors_origins(\"*\") cannot be combined with an \
363 authenticate callback — browsers reject credentialed requests \
364 against a wildcard origin. Configure a specific origin."
365 );
366 let token_key = self.token_key.unwrap_or_else(|| {
367 tracing::warn!(
368 target: "vgi_rpc.http",
369 "no token_key configured; using ephemeral per-process AEAD key — \
370 state tokens will not survive restart or load-balance across workers"
371 );
372 let mut k = [0u8; 32];
373 rand::thread_rng().fill_bytes(&mut k);
374 k
375 });
376 let oauth_metadata_json = self
377 .oauth_metadata
378 .as_ref()
379 .map(|m| m.to_json().into_bytes());
380 let www_authenticate = self.oauth_metadata.as_ref().map(|m| m.www_authenticate());
381 let sticky = if self.enable_sticky.unwrap_or(false) {
382 let ttl = self
383 .sticky_default_ttl
384 .unwrap_or_else(|| std::time::Duration::from_secs(300));
385 Some(crate::sticky::StickyContext::new(
386 token_key,
387 ttl,
388 self.sticky_echo_headers,
389 server.server_id.clone(),
390 ))
391 } else {
392 None
393 };
394 Arc::new(HttpState {
395 server,
396 token_key,
397 producer_batch_limit: self.producer_batch_limit.unwrap_or(1),
398 token_ttl: self
399 .token_ttl
400 .unwrap_or_else(|| std::time::Duration::from_secs(300)),
401 max_body_size: self.max_body_size.unwrap_or(64 * 1024 * 1024),
402 request_timeout: self
403 .request_timeout
404 .unwrap_or_else(|| std::time::Duration::from_secs(30)),
405 authenticate: self.authenticate,
406 oauth_metadata: self.oauth_metadata,
407 oauth_metadata_json,
408 www_authenticate,
409 cors_origins: self.cors_origins,
410 cors_max_age: self.cors_max_age.unwrap_or(7200),
411 prefix: self.prefix.unwrap_or_default(),
412 response_compression_level: self.response_compression_level,
413 landing_page_enabled: self.landing_page_enabled.unwrap_or(true),
414 describe_page_enabled: self.describe_page_enabled.unwrap_or(true),
415 health_enabled: self.health_enabled.unwrap_or(true),
416 max_request_bytes: self.max_request_bytes,
417 max_upload_bytes: self.max_upload_bytes,
418 max_response_bytes: self.max_response_bytes,
419 max_externalized_response_bytes: self.max_externalized_response_bytes,
420 upload_url_provider: self.upload_url_provider,
421 sticky,
422 })
423 }
424}
425
426impl HttpState {
427 pub fn new(server: Arc<RpcServer>) -> Arc<Self> {
430 Self::builder().server(server).build()
431 }
432
433 pub fn builder() -> HttpStateBuilder {
434 HttpStateBuilder::default()
435 }
436
437 pub fn sticky_drain_handle(&self) -> Option<crate::sticky::DrainHandle> {
440 self.sticky.as_ref().map(|c| c.drain_handle())
441 }
442
443 pub fn token_ttl(&self) -> std::time::Duration {
444 self.token_ttl
445 }
446
447 pub fn max_body_size(&self) -> usize {
448 self.max_body_size
449 }
450
451 pub(crate) fn pack_state_token(
458 &self,
459 auth: &crate::auth::AuthContext,
460 state_bytes: &[u8],
461 output_schema_bytes: &[u8],
462 input_schema_bytes: &[u8],
463 stream_id: &str,
464 ) -> String {
465 let aad = compute_aad(auth);
466 pack_state_token(
467 &self.token_key,
468 &aad,
469 state_bytes,
470 output_schema_bytes,
471 input_schema_bytes,
472 stream_id,
473 current_unix_secs(),
474 )
475 }
476
477 pub(crate) fn unpack_state_token(
480 &self,
481 auth: &crate::auth::AuthContext,
482 token: &str,
483 ) -> Result<UnpackedToken> {
484 let ttl = if self.token_ttl.is_zero() {
485 None
486 } else {
487 Some(self.token_ttl)
488 };
489 let aad = compute_aad(auth);
490 unpack_state_token(&self.token_key, &aad, token, ttl)
491 }
492}
493
494fn compute_aad(auth: &crate::auth::AuthContext) -> Vec<u8> {
499 let prefix = b"vgi_rpc.state.v4\x00";
500 if !auth.authenticated {
501 let mut out = Vec::with_capacity(prefix.len() + b"\x00anonymous".len());
502 out.extend_from_slice(prefix);
503 out.extend_from_slice(b"\x00anonymous");
504 return out;
505 }
506 let mut out =
507 Vec::with_capacity(prefix.len() + 1 + auth.domain.len() + 1 + auth.principal.len());
508 out.extend_from_slice(prefix);
509 out.push(0x01);
510 out.extend_from_slice(auth.domain.as_bytes());
511 out.push(0);
512 out.extend_from_slice(auth.principal.as_bytes());
513 out
514}
515
/// Version byte passed to the crypto layer when sealing and opening state
/// tokens. It matches the `v4` tag in the AAD prefix used by `compute_aad`.
pub(crate) const STATE_TOKEN_VERSION: u8 = 0x04;

/// Decoded contents of a verified state token.
#[derive(Debug, Clone)]
pub(crate) struct UnpackedToken {
    /// Opaque serialized stream state.
    pub state_bytes: Vec<u8>,
    /// Serialized output schema segment.
    pub output_schema_bytes: Vec<u8>,
    /// Serialized input schema segment.
    pub input_schema_bytes: Vec<u8>,
    /// Identifier of the stream the token belongs to.
    pub stream_id: String,
    /// Unix seconds at which the token was packed. Not read in the visible
    /// code; the TTL check uses the raw value during unpacking.
    #[allow(dead_code)]
    pub created_at: u64,
}
529
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Returns 0 if the system clock reads earlier than the epoch.
fn current_unix_secs() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
537
538pub(crate) fn pack_state_token(
565 token_key: &[u8; 32],
566 aad: &[u8],
567 state_bytes: &[u8],
568 output_schema_bytes: &[u8],
569 input_schema_bytes: &[u8],
570 stream_id: &str,
571 created_at: u64,
572) -> String {
573 let mut plaintext = Vec::with_capacity(
574 8 + 4
575 + state_bytes.len()
576 + 4
577 + output_schema_bytes.len()
578 + 4
579 + input_schema_bytes.len()
580 + 4
581 + stream_id.len(),
582 );
583 plaintext.extend_from_slice(&created_at.to_le_bytes());
584 plaintext.extend_from_slice(&(state_bytes.len() as u32).to_le_bytes());
585 plaintext.extend_from_slice(state_bytes);
586 plaintext.extend_from_slice(&(output_schema_bytes.len() as u32).to_le_bytes());
587 plaintext.extend_from_slice(output_schema_bytes);
588 plaintext.extend_from_slice(&(input_schema_bytes.len() as u32).to_le_bytes());
589 plaintext.extend_from_slice(input_schema_bytes);
590 plaintext.extend_from_slice(&(stream_id.len() as u32).to_le_bytes());
591 plaintext.extend_from_slice(stream_id.as_bytes());
592
593 crate::crypto::seal_base64(&plaintext, token_key, aad, STATE_TOKEN_VERSION)
594}
595
596pub(crate) fn unpack_state_token(
604 token_key: &[u8; 32],
605 aad: &[u8],
606 token: &str,
607 token_ttl: Option<std::time::Duration>,
608) -> Result<UnpackedToken> {
609 let raw = base64::engine::general_purpose::STANDARD
610 .decode(token.as_bytes())
611 .map_err(|_| RpcError::runtime_error("Malformed state token"))?;
612
613 let plaintext = crate::crypto::open_bytes(&raw, token_key, aad, STATE_TOKEN_VERSION)
614 .map_err(|_| RpcError::runtime_error("State token signature verification failed"))?;
615
616 if plaintext.len() < 8 {
617 return Err(RpcError::runtime_error("Malformed state token"));
618 }
619 let created_at = u64::from_le_bytes(plaintext[0..8].try_into().unwrap());
620
621 if let Some(ttl) = token_ttl {
622 let now = current_unix_secs();
623 if now > created_at && now - created_at > ttl.as_secs() {
624 return Err(RpcError::runtime_error("State token expired"));
625 }
626 const MAX_CLOCK_SKEW_SECS: u64 = 60;
631 if created_at > now && created_at - now > MAX_CLOCK_SKEW_SECS {
632 return Err(RpcError::runtime_error(
633 "State token timestamp is implausibly in the future",
634 ));
635 }
636 }
637
638 let mut pos = 8;
639 let state_bytes = read_segment(&plaintext, &mut pos)?;
640 let output_schema_bytes = read_segment(&plaintext, &mut pos)?;
641 let input_schema_bytes = read_segment(&plaintext, &mut pos)?;
642 let stream_id_bytes = read_segment(&plaintext, &mut pos)?;
643 if pos != plaintext.len() {
644 return Err(RpcError::runtime_error("Malformed state token"));
645 }
646 let stream_id = String::from_utf8(stream_id_bytes)
647 .map_err(|_| RpcError::runtime_error("Malformed state token"))?;
648
649 Ok(UnpackedToken {
650 state_bytes,
651 output_schema_bytes,
652 input_schema_bytes,
653 stream_id,
654 created_at,
655 })
656}
657
658fn read_segment(buf: &[u8], pos: &mut usize) -> Result<Vec<u8>> {
659 if *pos + 4 > buf.len() {
660 return Err(RpcError::runtime_error("Malformed state token"));
661 }
662 let len = u32::from_le_bytes(buf[*pos..*pos + 4].try_into().unwrap()) as usize;
663 *pos += 4;
664 if *pos + len > buf.len() {
665 return Err(RpcError::runtime_error("Malformed state token"));
666 }
667 let out = buf[*pos..*pos + len].to_vec();
668 *pos += len;
669 Ok(out)
670}
671
672fn write_schema_bytes(schema: &Schema) -> Result<Vec<u8>> {
676 let empty = empty_batch(schema)?;
677 crate::wire::write_one_batch(&empty, None)
678}
679
680fn read_schema_bytes(bytes: &[u8]) -> Result<SchemaRef> {
682 let r = StreamReader::new(bytes)?;
683 Ok(r.schema())
684}
685
686pub async fn shutdown_signal() {
702 #[cfg(unix)]
703 {
704 use tokio::signal::unix::{signal, SignalKind};
705 let mut term = match signal(SignalKind::terminate()) {
706 Ok(s) => s,
707 Err(_) => {
708 let _ = tokio::signal::ctrl_c().await;
709 return;
710 }
711 };
712 let mut intr = match signal(SignalKind::interrupt()) {
713 Ok(s) => s,
714 Err(_) => {
715 let _ = tokio::signal::ctrl_c().await;
716 return;
717 }
718 };
719 tokio::select! {
720 _ = term.recv() => {},
721 _ = intr.recv() => {},
722 }
723 }
724 #[cfg(not(unix))]
725 {
726 let _ = tokio::signal::ctrl_c().await;
727 }
728}
729
730pub async fn serve_with_shutdown(
734 state: Arc<HttpState>,
735 listener: tokio::net::TcpListener,
736) -> std::io::Result<()> {
737 let app = build_router(state);
738 axum::serve(listener, app)
739 .with_graceful_shutdown(shutdown_signal())
740 .await
741}
742
743const MAX_RESPONSE_BYTES_HARD_CAP: usize = 256 * 1024 * 1024;
755
756fn response_buffer_ceiling(state: &HttpState) -> usize {
762 match state.max_response_bytes {
763 Some(soft) => MAX_RESPONSE_BYTES_HARD_CAP.max(soft.saturating_mul(2)),
764 None => MAX_RESPONSE_BYTES_HARD_CAP,
765 }
766}
767
768pub fn build_router(state: Arc<HttpState>) -> Router {
769 let body_limit = state.max_body_size;
770 let request_timeout = state.request_timeout;
771 build_router_inner(state.clone())
772 .layer(axum::middleware::from_fn_with_state(
773 state,
774 postprocess_middleware,
775 ))
776 .layer(tower_http::limit::RequestBodyLimitLayer::new(body_limit))
779 .layer(tower_http::timeout::TimeoutLayer::with_status_code(
782 StatusCode::REQUEST_TIMEOUT,
783 request_timeout,
784 ))
785 .layer(tower_http::catch_panic::CatchPanicLayer::new())
790}
791
792async fn postprocess_middleware(
793 axum::extract::State(state): axum::extract::State<Arc<HttpState>>,
794 req: axum::http::Request<axum::body::Body>,
795 next: axum::middleware::Next,
796) -> Response {
797 use axum::body::to_bytes;
798 state.server.notify_transport(
802 crate::transport::TransportKind::Http,
803 crate::transport::TransportCapabilities::none(),
804 );
805 let req_headers = req.headers().clone();
806 let req_method = req.method().clone();
807 let req_path = req.uri().path().to_string();
808
809 if let Some(limit) = state.max_request_bytes {
814 let exempt = req_path.ends_with("/__upload_url__/init")
815 || req_path.contains("/__upload_url__/")
816 || req_path == "/health"
817 || req_path.ends_with("/health");
818 if !exempt {
819 if let Some(cl) = req
820 .headers()
821 .get(header::CONTENT_LENGTH)
822 .and_then(|v| v.to_str().ok())
823 .and_then(|s| s.parse::<usize>().ok())
824 {
825 if cl > limit {
826 let mut h = HeaderMap::new();
827 attach_capability_headers(&state, &mut h, &req_method);
828 attach_cors_headers(&state, &mut h, &req_headers, false);
829 return (
830 StatusCode::PAYLOAD_TOO_LARGE,
831 h,
832 format!(
833 "Request body of {cl} bytes exceeds advertised \
834 max_request_bytes={limit}. Use the upload-URL \
835 flow (__upload_url__/init) to externalize."
836 ),
837 )
838 .into_response();
839 }
840 }
841 }
842 }
843
844 let resp = next.run(req).await;
845 let (mut parts, body) = resp.into_parts();
846 let response_limit = response_buffer_ceiling(&state);
854 let bytes = match to_bytes(body, response_limit).await {
855 Ok(b) => b,
856 Err(_) => {
857 let mut h = HeaderMap::new();
858 attach_cors_headers(&state, &mut h, &req_headers, false);
859 return (
860 StatusCode::INTERNAL_SERVER_ERROR,
861 h,
862 "response body exceeded the configured size limit",
863 )
864 .into_response();
865 }
866 };
867 let is_arrow = parts
868 .headers
869 .get(header::CONTENT_TYPE)
870 .and_then(|v| v.to_str().ok())
871 == Some(ARROW_CONTENT_TYPE);
872
873 if is_arrow {
874 if let Some(level) = state.response_compression_level {
875 let accepts = req_headers
876 .get(header::ACCEPT_ENCODING)
877 .and_then(|v| v.to_str().ok())
878 .unwrap_or("");
879 if accepts.contains("zstd") {
880 if let Ok(compressed) = zstd::encode_all(std::io::Cursor::new(&bytes), level) {
881 parts
882 .headers
883 .insert(header::CONTENT_ENCODING, HeaderValue::from_static("zstd"));
884 attach_cors_headers(&state, &mut parts.headers, &req_headers, false);
885 let body_new = axum::body::Body::from(compressed);
886 return Response::from_parts(parts, body_new);
887 }
888 }
889 }
890 }
891 attach_cors_headers(&state, &mut parts.headers, &req_headers, false);
892 attach_capability_headers(&state, &mut parts.headers, &req_method);
893 Response::from_parts(parts, axum::body::Body::from(bytes))
894}
895
896fn attach_capability_headers(
902 state: &Arc<HttpState>,
903 out: &mut HeaderMap,
904 method: &axum::http::Method,
905) {
906 let mut any = false;
907 if let Some(n) = state.max_request_bytes {
908 if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
909 out.insert("vgi-max-request-bytes", v);
910 any = true;
911 }
912 }
913 if let Some(n) = state.max_response_bytes {
914 if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
915 out.insert("vgi-max-response-bytes", v);
916 any = true;
917 }
918 }
919 if let Some(n) = state.max_externalized_response_bytes {
920 if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
921 out.insert("vgi-max-externalized-response-bytes", v);
922 any = true;
923 }
924 }
925 out.insert(
928 "vgi-externalization-enabled",
929 HeaderValue::from_static(if state.server.external_config().is_some() {
930 "true"
931 } else {
932 "false"
933 }),
934 );
935 if state.upload_url_provider.is_some() {
936 out.insert("vgi-upload-url-support", HeaderValue::from_static("true"));
937 any = true;
938 if let Some(n) = state.max_upload_bytes {
939 if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
940 out.insert("vgi-max-upload-bytes", v);
941 }
942 }
943 }
944 if let Some(sticky) = state.sticky.as_ref() {
947 out.insert(STICKY_ENABLED_HEADER, HeaderValue::from_static("true"));
948 any = true;
949 if let Ok(v) = HeaderValue::from_str(&sticky.default_ttl.as_secs().to_string()) {
950 out.insert(STICKY_DEFAULT_TTL_HEADER, v);
951 }
952 if !sticky.echo_headers.is_empty() {
953 let names = sticky
954 .echo_headers
955 .iter()
956 .map(|(n, _)| n.as_str())
957 .collect::<Vec<_>>()
958 .join(",");
959 if let Ok(v) = HeaderValue::from_str(&names) {
960 out.insert(STICKY_ECHO_HEADERS_HEADER, v);
961 }
962 }
963 } else {
964 out.insert(STICKY_ENABLED_HEADER, HeaderValue::from_static("false"));
965 }
966 if any && method == axum::http::Method::OPTIONS {
967 out.insert(
968 header::CACHE_CONTROL,
969 HeaderValue::from_static("public, max-age=300"),
970 );
971 }
972}
973
974fn build_router_inner(state: Arc<HttpState>) -> Router {
975 let prefix = state.prefix.clone();
976 let api = Router::new()
977 .route("/:method", post(handle_unary).options(handle_preflight))
978 .route(
979 "/:method/init",
980 post(handle_stream_init).options(handle_preflight),
981 )
982 .route(
983 "/:method/exchange",
984 post(handle_stream_exchange).options(handle_preflight),
985 );
986
987 let api = if state.upload_url_provider.is_some() {
988 api.route(
989 "/__upload_url__/init",
990 post(handle_upload_url).options(handle_preflight),
991 )
992 } else {
993 api
994 };
995
996 let api = if state.sticky.is_some() {
997 api.route(
998 "/__session__",
999 axum::routing::delete(handle_delete_session).options(handle_preflight),
1000 )
1001 } else {
1002 api
1003 };
1004
1005 let mut app = if prefix.is_empty() {
1006 api
1007 } else {
1008 Router::new().nest(&prefix, api)
1009 };
1010
1011 app = app.route(
1012 &format!(
1013 "{}{}",
1014 prefix,
1015 crate::auth::oauth::OAuthResourceMetadata::well_known_path()
1016 ),
1017 axum::routing::get(handle_oauth_metadata),
1018 );
1019
1020 if state.health_enabled {
1021 app = app.route(
1027 "/health",
1028 axum::routing::get(handle_health).options(handle_preflight),
1029 );
1030 }
1031 if state.landing_page_enabled {
1032 let landing_path = if prefix.is_empty() {
1033 "/".to_string()
1034 } else {
1035 prefix.clone()
1036 };
1037 app = app.route(&landing_path, axum::routing::get(handle_landing));
1038 }
1039 if state.describe_page_enabled {
1040 app = app.route(
1041 &format!("{prefix}/describe"),
1042 axum::routing::get(handle_describe_page),
1043 );
1044 }
1045
1046 app.with_state(state)
1047}
1048
1049async fn handle_delete_session(
1053 State(state): State<Arc<HttpState>>,
1054 headers: HeaderMap,
1055) -> Response {
1056 let auth = match authenticate_request(&state, SESSION_ENDPOINT, &headers) {
1057 Ok(a) => a,
1058 Err(resp) => return resp,
1059 };
1060 let Some(ctx) = state.sticky.as_ref() else {
1061 return StatusCode::OK.into_response();
1062 };
1063 let session_header = headers.get(SESSION_HEADER).and_then(|v| v.to_str().ok());
1064 match crate::sticky::handle_delete(ctx, &auth, session_header) {
1065 crate::sticky::DeleteOutcome::Idempotent => StatusCode::OK.into_response(),
1066 crate::sticky::DeleteOutcome::Closed => {
1067 let mut h = HeaderMap::new();
1068 h.insert(SESSION_CLOSE_HEADER, HeaderValue::from_static("true"));
1069 (StatusCode::NO_CONTENT, h).into_response()
1070 }
1071 }
1072}
1073
1074async fn handle_preflight(State(state): State<Arc<HttpState>>, headers: HeaderMap) -> Response {
1075 let mut h = HeaderMap::new();
1076 attach_cors_headers(&state, &mut h, &headers, true);
1077 (StatusCode::NO_CONTENT, h).into_response()
1078}
1079
1080async fn handle_health(State(state): State<Arc<HttpState>>) -> Response {
1081 let body = serde_json::json!({
1082 "status": "ok",
1083 "server_id": state.server.server_id,
1084 "protocol": state.server.protocol_name(),
1085 })
1086 .to_string();
1087 let mut h = HeaderMap::new();
1088 h.insert(
1089 header::CONTENT_TYPE,
1090 HeaderValue::from_static("application/json"),
1091 );
1092 (StatusCode::OK, h, body).into_response()
1093}
1094
1095async fn handle_landing(State(state): State<Arc<HttpState>>) -> Response {
1096 let body = render_landing(&state);
1097 let mut h = HeaderMap::new();
1098 h.insert(
1099 header::CONTENT_TYPE,
1100 HeaderValue::from_static("text/html; charset=utf-8"),
1101 );
1102 (StatusCode::OK, h, body).into_response()
1103}
1104
1105async fn handle_describe_page(State(state): State<Arc<HttpState>>) -> Response {
1106 let body = render_describe_page(&state);
1107 let mut h = HeaderMap::new();
1108 h.insert(
1109 header::CONTENT_TYPE,
1110 HeaderValue::from_static("text/html; charset=utf-8"),
1111 );
1112 (StatusCode::OK, h, body).into_response()
1113}
1114
1115fn render_landing(state: &Arc<HttpState>) -> String {
1116 let name = if state.server.protocol_name().is_empty() {
1117 "vgi-rpc service"
1118 } else {
1119 state.server.protocol_name()
1120 };
1121 let server_id = &state.server.server_id;
1122 let describe_link = if state.describe_page_enabled {
1123 format!(
1124 r#"<p><a href="{0}/describe">API reference</a></p>"#,
1125 state.prefix
1126 )
1127 } else {
1128 String::new()
1129 };
1130 format!(
1131 "<!doctype html><html><head><meta charset=\"utf-8\"><title>{name}</title></head><body>\
1132 <h1>{name}</h1><p>server_id: <code>{server_id}</code></p>{describe_link}\
1133 </body></html>"
1134 )
1135}
1136
1137fn render_describe_page(state: &Arc<HttpState>) -> String {
1138 let mut body = String::from(
1139 "<!doctype html><html><head><meta charset=\"utf-8\"><title>API reference</title></head><body>",
1140 );
1141 body.push_str(&format!(
1142 "<h1>{}</h1><table><tr><th>method</th><th>type</th><th>doc</th></tr>",
1143 state.server.protocol_name()
1144 ));
1145 for name in state.server.sorted_method_names() {
1146 let m = &state.server.methods()[name];
1147 let kind = match m.method_type {
1148 crate::server::MethodType::Unary => "unary",
1149 _ => "stream",
1150 };
1151 let doc = m.doc.as_deref().unwrap_or("");
1152 body.push_str(&format!(
1153 "<tr><td><code>{name}</code></td><td>{kind}</td><td>{}</td></tr>",
1154 html_escape(doc)
1155 ));
1156 }
1157 body.push_str("</table></body></html>");
1158 body
1159}
1160
/// Escapes `s` so it is safe in HTML text content and in quoted attribute values.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with their character references. The
/// previous replacements mapped each character to itself, so nothing was
/// actually escaped, and quotes were not handled at all.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            other => out.push(other),
        }
    }
    out
}
1166
1167fn attach_cors_headers(
1168 state: &Arc<HttpState>,
1169 out: &mut HeaderMap,
1170 req_headers: &HeaderMap,
1171 is_preflight: bool,
1172) {
1173 let Some(origins) = state.cors_origins.as_deref() else {
1174 return;
1175 };
1176 if let Ok(v) = HeaderValue::from_str(origins) {
1177 out.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, v);
1178 }
1179 if state.authenticate.is_some() {
1186 out.insert(
1187 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
1188 HeaderValue::from_static("true"),
1189 );
1190 }
1191 out.insert(
1192 header::ACCESS_CONTROL_ALLOW_METHODS,
1193 HeaderValue::from_static("POST, GET, OPTIONS"),
1194 );
1195 let requested = req_headers
1196 .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
1197 .and_then(|v| v.to_str().ok())
1198 .unwrap_or("Content-Type, Authorization, Cookie, Accept-Encoding");
1199 if let Ok(v) = HeaderValue::from_str(requested) {
1200 out.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, v);
1201 }
1202 out.insert(
1203 header::ACCESS_CONTROL_EXPOSE_HEADERS,
1204 HeaderValue::from_static("Content-Encoding, WWW-Authenticate"),
1205 );
1206 if is_preflight {
1207 if let Ok(v) = HeaderValue::from_str(&state.cors_max_age.to_string()) {
1208 out.insert(header::ACCESS_CONTROL_MAX_AGE, v);
1209 }
1210 }
1211}
1212
1213async fn handle_oauth_metadata(State(state): State<Arc<HttpState>>) -> Response {
1214 match state.oauth_metadata_json.as_ref() {
1215 Some(body) => {
1216 let mut h = HeaderMap::new();
1217 h.insert(
1218 header::CONTENT_TYPE,
1219 HeaderValue::from_static("application/json"),
1220 );
1221 h.insert(
1222 header::CACHE_CONTROL,
1223 HeaderValue::from_static("public, max-age=60"),
1224 );
1225 (StatusCode::OK, h, body.clone()).into_response()
1226 }
1227 None => (StatusCode::NOT_FOUND, "").into_response(),
1228 }
1229}
1230
/// Parses a `Cookie` header value into a name → value map.
///
/// Pairs are separated by `;`, and names and values are whitespace-trimmed.
/// A value wrapped in double quotes is unquoted. Segments without `=` are
/// ignored, and for a repeated name the last occurrence wins. `None` yields an
/// empty map.
fn parse_cookies(raw: Option<&str>) -> std::collections::BTreeMap<String, String> {
    let mut jar = std::collections::BTreeMap::new();
    if let Some(header) = raw {
        jar.extend(header.split(';').filter_map(|cookie| {
            let (name, value) = cookie.trim().split_once('=')?;
            let value = value.trim();
            let unquoted = match value.strip_prefix('"').and_then(|v| v.strip_suffix('"')) {
                Some(inner) => inner,
                None => value,
            };
            Some((name.trim().to_string(), unquoted.to_string()))
        }));
    }
    jar
}
1249
1250fn headers_to_pairs(headers: &HeaderMap) -> Vec<(String, String)> {
1252 headers
1253 .iter()
1254 .filter_map(|(k, v)| {
1255 v.to_str()
1256 .ok()
1257 .map(|s| (k.as_str().to_string(), s.to_string()))
1258 })
1259 .collect()
1260}
1261
1262fn authenticate_request(
1265 state: &Arc<HttpState>,
1266 method: &str,
1267 headers: &HeaderMap,
1268) -> std::result::Result<crate::auth::AuthContext, Response> {
1269 let Some(cb) = state.authenticate.as_ref() else {
1270 return Ok(crate::auth::AuthContext::anonymous());
1271 };
1272 let pairs = headers_to_pairs(headers);
1273 let req = crate::auth::AuthRequest {
1274 method,
1275 headers: &pairs,
1276 peer_addr: None,
1277 };
1278 match (cb)(&req) {
1279 Ok(ctx) => Ok(ctx),
1280 Err(err) => {
1281 let status = match err.error_type.as_str() {
1282 "PermissionError" | "ValueError" => StatusCode::UNAUTHORIZED,
1283 _ => StatusCode::INTERNAL_SERVER_ERROR,
1284 };
1285 let mut h = HeaderMap::new();
1286 let body = if status == StatusCode::UNAUTHORIZED {
1293 if let Some(wa) = state.www_authenticate.as_deref() {
1294 if let Ok(hv) = HeaderValue::from_str(wa) {
1295 h.insert(header::WWW_AUTHENTICATE, hv);
1296 }
1297 }
1298 tracing::info!(
1299 target: "vgi_rpc.http",
1300 error = %err.message,
1301 "request authentication rejected"
1302 );
1303 "authentication failed"
1304 } else {
1305 tracing::error!(
1306 target: "vgi_rpc.http",
1307 error = %err.message,
1308 "authentication callback errored"
1309 );
1310 "internal error during authentication"
1311 };
1312 Err((status, h, body).into_response())
1313 }
1314 }
1315}
1316
1317fn arrow_response(status: StatusCode, body: Vec<u8>) -> Response {
1322 let mut headers = HeaderMap::new();
1323 headers.insert(
1324 header::CONTENT_TYPE,
1325 HeaderValue::from_static(ARROW_CONTENT_TYPE),
1326 );
1327 (status, headers, body).into_response()
1328}
1329
1330fn enforce_response_body_cap(
1335 state: &Arc<HttpState>,
1336 schema: &arrow_schema::Schema,
1337 body: Vec<u8>,
1338 method: &str,
1339 server_id: &str,
1340 request_id: &str,
1341) -> Response {
1342 if let Some(limit) = state.max_response_bytes {
1343 if body.len() > limit {
1344 let err = RpcError::runtime_error(format!(
1345 "HTTP body exceeds max_response_bytes ({} > {}) for method {:?}",
1346 body.len(),
1347 limit,
1348 method
1349 ));
1350 return cap_error_response(schema, &err, server_id, request_id);
1351 }
1352 }
1353 arrow_response(StatusCode::OK, body)
1354}
1355
1356fn cap_error_response(
1361 schema: &arrow_schema::Schema,
1362 err: &RpcError,
1363 server_id: &str,
1364 request_id: &str,
1365) -> Response {
1366 let mut buf = Vec::new();
1367 {
1368 let mut sw = StreamWriter::new(&mut buf, schema).unwrap();
1369 let md = build_error_metadata(err, server_id, request_id);
1370 let _ = sw.write(&empty_batch(schema).unwrap(), Some(&md));
1371 let _ = sw.finish();
1372 }
1373 let mut headers = HeaderMap::new();
1374 headers.insert(
1375 header::CONTENT_TYPE,
1376 HeaderValue::from_static(ARROW_CONTENT_TYPE),
1377 );
1378 headers.insert("x-vgi-rpc-error", HeaderValue::from_static("true"));
1379 (StatusCode::OK, headers, buf).into_response()
1380}
1381
1382fn plain_error(status: StatusCode, msg: String) -> Response {
1383 (status, msg).into_response()
1384}
1385
1386fn has_arrow_ct(headers: &HeaderMap) -> bool {
1387 headers
1388 .get(header::CONTENT_TYPE)
1389 .and_then(|v| v.to_str().ok())
1390 .map(|s| s == ARROW_CONTENT_TYPE)
1391 .unwrap_or(false)
1392}
1393
1394fn maybe_decompress(headers: &HeaderMap, body: &Bytes, max_size: usize) -> Result<Vec<u8>> {
1395 let enc = headers
1396 .get(header::CONTENT_ENCODING)
1397 .and_then(|v| v.to_str().ok())
1398 .unwrap_or("");
1399 if body.len() > max_size {
1400 return Err(RpcError::runtime_error(format!(
1401 "Request body exceeds max size ({} bytes > {})",
1402 body.len(),
1403 max_size
1404 )));
1405 }
1406 if enc.eq_ignore_ascii_case("zstd") {
1407 decode_zstd_bounded(body.as_ref(), max_size)
1408 } else {
1409 Ok(body.to_vec())
1410 }
1411}
1412
1413fn decode_zstd_bounded(input: &[u8], max_size: usize) -> Result<Vec<u8>> {
1417 use std::io::Read;
1418 let mut decoder = zstd::Decoder::new(input)
1419 .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
1420 let mut out = Vec::with_capacity(input.len().min(max_size).min(64 * 1024));
1421 let mut buf = [0u8; 16 * 1024];
1422 loop {
1423 let n = decoder
1424 .read(&mut buf)
1425 .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
1426 if n == 0 {
1427 break;
1428 }
1429 if out.len() + n > max_size {
1430 return Err(RpcError::runtime_error(format!(
1431 "Decompressed body exceeds max size ({}+ bytes > {})",
1432 out.len() + n,
1433 max_size
1434 )));
1435 }
1436 out.extend_from_slice(&buf[..n]);
1437 }
1438 Ok(out)
1439}
1440
1441fn parse_request_from_body(body: &[u8]) -> Result<Request> {
1442 let mut r = StreamReader::new(body)?;
1443 let (batch, metadata) = r
1444 .read_next()?
1445 .ok_or_else(|| RpcError::protocol_error("empty IPC stream"))?;
1446 r.drain()?;
1447 Request::from_read_batch(batch, metadata, false)
1448}
1449
1450fn error_stream_bytes(
1451 schema: &Schema,
1452 err: &RpcError,
1453 server_id: &str,
1454 request_id: &str,
1455) -> Vec<u8> {
1456 let mut buf = Vec::new();
1457 let mut w = StreamWriter::new(&mut buf, schema).unwrap();
1458 let md = build_error_metadata(err, server_id, request_id);
1459 let _ = w.write(&empty_batch(schema).unwrap(), Some(&md));
1460 let _ = w.finish();
1461 drop(w);
1462 buf
1463}
1464
1465fn arrow_error(
1469 state: &Arc<HttpState>,
1470 status: StatusCode,
1471 err: &RpcError,
1472 request_id: &str,
1473) -> Response {
1474 arrow_response(
1475 status,
1476 error_stream_bytes(&Schema::empty(), err, &state.server.server_id, request_id),
1477 )
1478}
1479
1480fn sticky_for_request(
1485 state: &Arc<HttpState>,
1486 auth: &crate::auth::AuthContext,
1487 headers: &HeaderMap,
1488) -> std::result::Result<Option<Arc<crate::sticky::StickySinkImpl>>, Response> {
1489 let Some(ctx) = state.sticky.as_ref() else {
1490 return Ok(None);
1491 };
1492 let accept = headers
1493 .get(SESSION_ACCEPT_HEADER)
1494 .and_then(|v| v.to_str().ok())
1495 .map(|s| s.trim().eq_ignore_ascii_case("true"))
1496 .unwrap_or(false);
1497 let session_header = headers.get(SESSION_HEADER).and_then(|v| v.to_str().ok());
1498 match crate::sticky::resolve(ctx, auth, accept, session_header) {
1499 crate::sticky::StickyResolution::Sink(s) => Ok(Some(s)),
1500 crate::sticky::StickyResolution::Lost(err) => Err(cap_error_response(
1501 &Schema::empty(),
1502 &err,
1503 &state.server.server_id,
1504 "",
1505 )),
1506 }
1507}
1508
1509fn stamp_session_headers(
1512 resp: &mut Response,
1513 state: &Arc<HttpState>,
1514 sink: &Arc<crate::sticky::StickySinkImpl>,
1515) {
1516 let headers = resp.headers_mut();
1517 if let Some(token) = sink.mint_token() {
1518 if let Ok(v) = HeaderValue::from_str(&token) {
1519 headers.insert(SESSION_HEADER, v);
1520 }
1521 if let Some(ctx) = state.sticky.as_ref() {
1522 for (name, value) in &ctx.echo_headers {
1523 let full = format!("{ECHO_HEADER_PREFIX}{name}");
1524 if let (Ok(n), Ok(v)) = (
1525 axum::http::HeaderName::from_bytes(full.as_bytes()),
1526 HeaderValue::from_str(value),
1527 ) {
1528 headers.insert(n, v);
1529 }
1530 }
1531 }
1532 }
1533 if sink.was_closed() {
1534 headers.insert(SESSION_CLOSE_HEADER, HeaderValue::from_static("true"));
1535 }
1536}
1537
/// Decode a hex-encoded signing key. Upper- and lower-case digits are
/// accepted, and surrounding whitespace is ignored.
///
/// # Errors
/// A human-readable message when the input has an odd number of bytes,
/// contains a non-hex character (reported at the first offending byte), or
/// decodes to fewer than 32 bytes.
fn decode_hex_key(s: &str) -> std::result::Result<Vec<u8>, String> {
    let digits = s.trim().as_bytes();
    if digits.len() % 2 == 1 {
        return Err("hex length must be even".into());
    }
    let key = digits
        .chunks_exact(2)
        .map(|pair| -> std::result::Result<u8, String> {
            let hi = hex_nibble(pair[0])?;
            let lo = hex_nibble(pair[1])?;
            Ok((hi << 4) | lo)
        })
        .collect::<std::result::Result<Vec<u8>, String>>()?;
    if key.len() >= 32 {
        Ok(key)
    } else {
        Err(format!(
            "signing key must be ≥ 32 bytes (got {} bytes)",
            key.len()
        ))
    }
}

/// Value of a single ASCII hex digit (`0-9`, `a-f`, `A-F`).
fn hex_nibble(c: u8) -> std::result::Result<u8, String> {
    // `char::to_digit(16)` accepts exactly the ASCII hex digits.
    (c as char)
        .to_digit(16)
        .map(|d| d as u8)
        .ok_or_else(|| format!("invalid hex character: {:?}", c as char))
}
1567
1568fn decode_base64_key(s: &str) -> std::result::Result<Vec<u8>, String> {
1569 let s = s.trim().trim_end_matches('=');
1571 let mut padded = s.to_string();
1572 while padded.len() % 4 != 0 {
1573 padded.push('=');
1574 }
1575 let bytes = base64::engine::general_purpose::STANDARD
1576 .decode(padded.as_bytes())
1577 .map_err(|e| format!("base64 decode: {e}"))?;
1578 if bytes.len() < 32 {
1579 return Err(format!(
1580 "signing key must be ≥ 32 bytes (got {} bytes)",
1581 bytes.len()
1582 ));
1583 }
1584 Ok(bytes)
1585}
1586
1587fn new_session_id() -> String {
1588 let mut b = [0u8; 16];
1589 rand::thread_rng().fill_bytes(&mut b);
1590 bytes_to_hex(&b)
1591}
1592
/// Method name served by `handle_upload_url`. A request batch carrying a
/// different, non-empty method name is rejected with a protocol error.
const UPLOAD_URL_METHOD: &str = "__upload_url__";
/// Upper bound for the client-requested `count` in `handle_upload_url`;
/// the count is clamped into `1..=MAX_UPLOAD_URL_COUNT`.
const MAX_UPLOAD_URL_COUNT: i64 = 100;
1599
1600fn upload_url_response_schema() -> Schema {
1601 use arrow_schema::{DataType, Field, TimeUnit};
1602 Schema::new(vec![
1603 Field::new("upload_url", DataType::Utf8, false),
1604 Field::new("download_url", DataType::Utf8, false),
1605 Field::new(
1606 "expires_at",
1607 DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
1608 false,
1609 ),
1610 ])
1611}
1612
/// HTTP handler for the `__upload_url__` endpoint. It asks the configured
/// upload-URL provider for one or more upload/download URL pairs, each with
/// an expiry.
///
/// The request is a single-batch Arrow IPC stream. An optional Int64 `count`
/// column (first row, non-null) selects how many pairs to generate; the
/// default is 1 and the value is clamped to `1..=MAX_UPLOAD_URL_COUNT`.
/// The response is an Arrow IPC stream on `upload_url_response_schema()`
/// with one row per pair, or a zero-row batch carrying error metadata when
/// generation or batch construction fails. Both are returned with HTTP 200.
///
/// Request-level failures:
/// - 415 plain text without the Arrow content type.
/// - 404 plain text when no provider is configured.
/// - 400 Arrow error stream for an unreadable body or a method-name mismatch.
async fn handle_upload_url(
    State(state): State<Arc<HttpState>>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    let auth = match authenticate_request(&state, UPLOAD_URL_METHOD, &headers) {
        Ok(a) => a,
        Err(resp) => return resp,
    };
    // Authentication only gates access here; the context is not used further.
    let _ = auth;
    if !has_arrow_ct(&headers) {
        return plain_error(
            StatusCode::UNSUPPORTED_MEDIA_TYPE,
            "need arrow content type".into(),
        );
    }
    let provider = match state.upload_url_provider.as_ref() {
        Some(p) => p.clone(),
        None => return plain_error(StatusCode::NOT_FOUND, "upload-url not enabled".into()),
    };

    let body = match maybe_decompress(&headers, &body, state.max_body_size) {
        Ok(b) => b,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };
    let req = match parse_request_from_body(&body) {
        Ok(r) => r,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };
    // An empty method name is tolerated; any other name must match exactly.
    if !req.method.is_empty() && req.method != UPLOAD_URL_METHOD {
        let err = RpcError::protocol_error(format!(
            "Method mismatch: expected '{UPLOAD_URL_METHOD}', got '{}'",
            req.method
        ));
        return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &req.request_id);
    }
    // Default to one URL pair. Only an Int64 `count` column is honoured;
    // other column types are ignored.
    let mut count: i64 = 1;
    if let Some(arr) = req.column("count") {
        use arrow_array::Array;
        if let Some(c) = arr.as_any().downcast_ref::<arrow_array::Int64Array>() {
            if !c.is_empty() && !Array::is_null(c, 0) {
                count = c.value(0);
            }
        }
    }
    // Clamping also guarantees `count as usize` below is in range.
    count = count.clamp(1, MAX_UPLOAD_URL_COUNT);

    // The provider is invoked synchronously inside block_in_place.
    // NOTE(review): tokio's block_in_place panics on a current-thread runtime;
    // confirm this handler always runs on the multi-threaded runtime.
    let urls_res = tokio::task::block_in_place(|| {
        let mut out = Vec::with_capacity(count as usize);
        for _ in 0..count {
            out.push(provider.generate_upload_url()?);
        }
        Ok::<_, RpcError>(out)
    });

    let schema = upload_url_response_schema();
    let mut body_buf = Vec::new();
    {
        let mut sw = match StreamWriter::new(&mut body_buf, &schema) {
            Ok(w) => w,
            Err(e) => {
                return arrow_error(
                    &state,
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &e,
                    &req.request_id,
                )
            }
        };
        match urls_res {
            Ok(urls) => {
                use arrow_array::{StringArray, TimestampMicrosecondArray};
                // Column-wise arrays, one row per generated URL pair.
                let upload_arr = StringArray::from(
                    urls.iter()
                        .map(|u| u.upload_url.clone())
                        .collect::<Vec<_>>(),
                );
                let download_arr = StringArray::from(
                    urls.iter()
                        .map(|u| u.download_url.clone())
                        .collect::<Vec<_>>(),
                );
                // The timezone must match the schema's Timestamp(µs, "UTC").
                let expires_arr = TimestampMicrosecondArray::from(
                    urls.iter().map(|u| u.expires_at_micros).collect::<Vec<_>>(),
                )
                .with_timezone("UTC");
                let batch = match RecordBatch::try_new(
                    Arc::new(schema.clone()),
                    vec![
                        Arc::new(upload_arr),
                        Arc::new(download_arr),
                        Arc::new(expires_arr),
                    ],
                ) {
                    Ok(b) => b,
                    Err(e) => {
                        // Report batch-construction failure in-band and stop.
                        let err = RpcError::runtime_error(format!("upload-url batch: {e}"));
                        let md =
                            build_error_metadata(&err, &state.server.server_id, &req.request_id);
                        let _ = sw.write(&empty_batch(&schema).unwrap(), Some(&md));
                        let _ = sw.finish();
                        drop(sw);
                        return arrow_response(StatusCode::OK, body_buf);
                    }
                };
                let _ = sw.write(&batch, None);
            }
            Err(err) => {
                // Provider failure: an in-band error on a zero-row batch.
                let md = build_error_metadata(&err, &state.server.server_id, &req.request_id);
                let _ = sw.write(&empty_batch(&schema).unwrap(), Some(&md));
            }
        }
        let _ = sw.finish();
    }
    arrow_response(StatusCode::OK, body_buf)
}
1732
1733async fn handle_unary(
1738 State(state): State<Arc<HttpState>>,
1739 Path(method): Path<String>,
1740 headers: HeaderMap,
1741 body: Bytes,
1742) -> Response {
1743 let auth = match authenticate_request(&state, &method, &headers) {
1747 Ok(a) => a,
1748 Err(resp) => return resp,
1749 };
1750 if !has_arrow_ct(&headers) {
1751 return plain_error(
1752 StatusCode::UNSUPPORTED_MEDIA_TYPE,
1753 "need arrow content type".into(),
1754 );
1755 }
1756 let sticky_sink = match sticky_for_request(&state, &auth, &headers) {
1759 Ok(s) => s,
1760 Err(resp) => return resp,
1761 };
1762 let cookies = parse_cookies(headers.get(header::COOKIE).and_then(|v| v.to_str().ok()));
1763 let server = state.server.clone();
1764
1765 let body = match maybe_decompress(&headers, &body, state.max_body_size) {
1766 Ok(b) => b,
1767 Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1768 };
1769 let mut req = match parse_request_from_body(&body) {
1770 Ok(r) => r,
1771 Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1772 };
1773
1774 if md_get(&req.metadata, crate::metadata::LOCATION_KEY).is_some() {
1779 if let Some(cfg) = server.external_config().as_ref() {
1780 let outer_md = req.metadata.clone();
1781 let outer_batch = req.batch.clone();
1782 let resolved = tokio::task::block_in_place(|| {
1783 crate::external::resolve_external_location(&outer_batch, &outer_md, cfg)
1784 });
1785 match resolved {
1786 Ok((inner_batch, _user_md)) => {
1787 req.batch = inner_batch;
1788 }
1789 Err(err) => {
1790 return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &req.request_id);
1791 }
1792 }
1793 }
1794 }
1795
1796 if server.describe_enabled() && method == crate::introspect::DESCRIBE_METHOD_NAME {
1798 let (batch, md) = match crate::introspect::build_describe(
1799 server.protocol_name(),
1800 server.methods(),
1801 &server.server_id,
1802 server.protocol_version(),
1803 ) {
1804 Ok(x) => x,
1805 Err(err) => {
1806 return arrow_error(
1807 &state,
1808 StatusCode::INTERNAL_SERVER_ERROR,
1809 &err,
1810 &req.request_id,
1811 );
1812 }
1813 };
1814 let mut buf = Vec::new();
1815 let _ = crate::introspect::write_describe_response(&mut buf, &batch, &md);
1816 return arrow_response(StatusCode::OK, buf);
1817 }
1818
1819 let Some(info) = server
1820 .method(&method)
1821 .filter(|m| m.method_type == MethodType::Unary)
1822 else {
1823 let err = RpcError::attribute_error(format!("Unknown method: '{}'", method));
1824 return arrow_error(&state, StatusCode::NOT_FOUND, &err, &req.request_id);
1825 };
1826
1827 let mut ctx = CallContext::with_auth_cookies(&server, &req, auth.clone(), cookies);
1828 if let Some(s) = sticky_sink.clone() {
1829 ctx.set_sticky(s);
1830 }
1831 let dispatch_info = crate::hooks::DispatchInfo::from_request(&server, &req, "unary", &auth);
1832 let hook = server.dispatch_hook.clone();
1833 let hook_token = hook.as_ref().map(|h| h.on_dispatch_start(&dispatch_info));
1834
1835 let mut stats = crate::hooks::CallStatistics {
1836 input_batches: 1,
1837 input_rows: req.batch.num_rows() as u64,
1838 ..Default::default()
1839 };
1840
1841 let result =
1846 crate::server::call_guard(|| (info.unary.as_ref().unwrap())(&req, &ctx)).and_then(|r| r);
1847 let logs = ctx.drain_logs();
1848 let mut app_err: Option<RpcError> = None;
1849
1850 let mut buf = Vec::new();
1851 {
1852 let mut sw = StreamWriter::new(&mut buf, &info.result_schema).unwrap();
1853 for log in &logs {
1854 let md = build_log_metadata(log, &server.server_id, &req.request_id);
1855 let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
1856 }
1857 match result {
1858 Ok(batch_opt) => {
1859 let out_batch =
1860 batch_opt.unwrap_or_else(|| empty_batch(&info.result_schema).unwrap());
1861 stats.output_batches = 1;
1862 stats.output_rows = out_batch.num_rows() as u64;
1863 if let Some(cfg) = server.external_config().as_ref() {
1864 let externalized = tokio::task::block_in_place(|| {
1869 crate::external::maybe_externalize_batch(&out_batch, None, cfg)
1870 });
1871 match externalized {
1872 Ok(Some((ptr, md))) => {
1873 let _ = sw.write(&ptr, Some(&md));
1874 }
1875 Ok(None) => {
1876 let _ = sw.write(&out_batch, None);
1877 }
1878 Err(err) => {
1879 let md = build_error_metadata(&err, &server.server_id, &req.request_id);
1880 let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
1881 app_err = Some(err);
1882 }
1883 }
1884 } else {
1885 let _ = sw.write(&out_batch, None);
1886 }
1887 }
1888 Err(err) => {
1889 let md = build_error_metadata(&err, &server.server_id, &req.request_id);
1890 let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
1891 app_err = Some(err);
1892 }
1893 }
1894 let _ = sw.finish();
1895 }
1896
1897 if let Some(hook) = hook {
1898 hook.on_dispatch_end(
1899 hook_token.unwrap_or(0),
1900 &dispatch_info,
1901 app_err.as_ref(),
1902 &stats,
1903 );
1904 }
1905 if let Some(limit) = state.max_response_bytes {
1911 if buf.len() > limit {
1912 let err = RpcError::runtime_error(format!(
1913 "HTTP body exceeds max_response_bytes ({} > {}) for method {:?}",
1914 buf.len(),
1915 limit,
1916 method
1917 ));
1918 let mut resp = cap_error_response(
1919 &info.result_schema,
1920 &err,
1921 &server.server_id,
1922 &req.request_id,
1923 );
1924 if let Some(s) = sticky_sink.as_ref() {
1925 stamp_session_headers(&mut resp, &state, s);
1926 }
1927 return resp;
1928 }
1929 }
1930 let mut resp = arrow_response(StatusCode::OK, buf);
1931 if let Some(s) = sticky_sink.as_ref() {
1932 stamp_session_headers(&mut resp, &state, s);
1933 }
1934 resp
1935}
1936
1937async fn handle_stream_init(
1942 State(state): State<Arc<HttpState>>,
1943 Path(method): Path<String>,
1944 headers: HeaderMap,
1945 body: Bytes,
1946) -> Response {
1947 let auth = match authenticate_request(&state, &method, &headers) {
1948 Ok(a) => a,
1949 Err(resp) => return resp,
1950 };
1951 if !has_arrow_ct(&headers) {
1952 return plain_error(
1953 StatusCode::UNSUPPORTED_MEDIA_TYPE,
1954 "need arrow content type".into(),
1955 );
1956 }
1957 let sticky_sink = match sticky_for_request(&state, &auth, &headers) {
1958 Ok(s) => s,
1959 Err(resp) => return resp,
1960 };
1961 let auth_for_token = auth.clone();
1962 let cookies = parse_cookies(headers.get(header::COOKIE).and_then(|v| v.to_str().ok()));
1963 let server = state.server.clone();
1964 let body = match maybe_decompress(&headers, &body, state.max_body_size) {
1965 Ok(b) => b,
1966 Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1967 };
1968 let req = match parse_request_from_body(&body) {
1969 Ok(r) => r,
1970 Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1971 };
1972
1973 let Some(info) = server
1974 .method(&method)
1975 .filter(|m| m.method_type != MethodType::Unary)
1976 else {
1977 let err = RpcError::attribute_error(format!("Unknown stream method: '{}'", method));
1978 return arrow_error(&state, StatusCode::NOT_FOUND, &err, &req.request_id);
1979 };
1980
1981 let mut ctx = CallContext::with_auth_cookies(&server, &req, auth, cookies);
1982 if let Some(s) = sticky_sink.clone() {
1983 ctx.set_sticky(s);
1984 }
1985 let init_result =
1988 crate::server::call_guard(|| (info.stream.as_ref().unwrap())(&req, &ctx)).and_then(|r| r);
1989 let init_logs = ctx.drain_logs();
1990
1991 let sr = match init_result {
1992 Ok(s) => s,
1993 Err(err) => {
1994 return arrow_response(
1995 StatusCode::OK,
1996 error_stream_bytes(&empty_schema(), &err, &server.server_id, &req.request_id),
1997 );
1998 }
1999 };
2000
2001 let StreamResult {
2002 output_schema,
2003 input_schema,
2004 state: mut ss,
2005 header,
2006 header_metadata,
2007 } = sr;
2008
2009 let mut body_buf = Vec::new();
2010
2011 if let Some(header_batch) = header.as_ref() {
2013 let hdr_schema = header_batch.schema();
2014 let mut hw = StreamWriter::new(&mut body_buf, hdr_schema.as_ref()).unwrap();
2015 for log in &init_logs {
2016 let md = build_log_metadata(log, &server.server_id, &req.request_id);
2017 let _ = hw.write(&empty_batch(hdr_schema.as_ref()).unwrap(), Some(&md));
2018 }
2019 let _ = hw.write(header_batch, header_metadata.as_ref());
2020 let _ = hw.finish();
2021 }
2022
2023 let is_producer = matches!(ss, StreamStateKind::Producer(_));
2024 let stream_id = new_session_id();
2025
2026 let mut finished = false;
2027 let mut init_error: Option<RpcError> = None;
2028 {
2029 let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2030 if header.is_none() {
2031 for log in &init_logs {
2032 let md = build_log_metadata(log, &server.server_id, &req.request_id);
2033 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2034 }
2035 }
2036 let _ = header_metadata;
2037 if is_producer {
2038 finished = run_producer(
2039 &mut sw,
2040 &mut ss,
2041 &output_schema,
2042 &server,
2043 &req,
2044 state.producer_batch_limit,
2045 sticky_sink.as_ref(),
2046 );
2047 }
2048 if !finished {
2049 match build_continuation_token(
2050 &state,
2051 &auth_for_token,
2052 &ss,
2053 &output_schema,
2054 input_schema.as_ref(),
2055 &stream_id,
2056 ) {
2057 Ok(token) => {
2058 let md = Metadata::from([(STATE_KEY.to_string(), token)]);
2059 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2060 }
2061 Err(err) => {
2062 let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2066 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2067 init_error = Some(err);
2068 }
2069 }
2070 }
2071 let _ = sw.finish();
2072 }
2073 let _ = init_error; let mut resp = arrow_response(StatusCode::OK, body_buf);
2076 if let Some(s) = sticky_sink.as_ref() {
2077 stamp_session_headers(&mut resp, &state, s);
2078 }
2079 resp
2080}
2081
2082fn build_continuation_token(
2086 state: &Arc<HttpState>,
2087 auth: &crate::auth::AuthContext,
2088 ss: &StreamStateKind,
2089 output_schema: &SchemaRef,
2090 input_schema: Option<&SchemaRef>,
2091 stream_id: &str,
2092) -> Result<String> {
2093 let state_bytes = match ss {
2094 StreamStateKind::Producer(p) => p.encode_state()?,
2095 StreamStateKind::Exchange(e) => e.encode_state()?,
2096 };
2097 let out_schema_bytes = write_schema_bytes(output_schema.as_ref())?;
2098 let in_schema_bytes = match input_schema {
2099 Some(s) => write_schema_bytes(s.as_ref())?,
2100 None => Vec::new(),
2101 };
2102 Ok(state.pack_state_token(
2103 auth,
2104 &state_bytes,
2105 &out_schema_bytes,
2106 &in_schema_bytes,
2107 stream_id,
2108 ))
2109}
2110
/// Drive a producer stream, writing its output into `sw`.
///
/// Each round calls `produce` once with a fresh collector and writes, in
/// order:
/// - logs drained from the context;
/// - on error, an error batch (and the function returns);
/// - otherwise the collector's items (logs and data batches).
///
/// At most `limit` rounds that emitted data are run. The producer's own
/// `batch_limit()` overrides `limit`, and a limit of 0 means unbounded.
///
/// Returns `true` when the stream is over: on a produce error, when the
/// collector reports `finished()`, or when a round emitted no data batch.
/// Returns `false` when the batch limit stopped it; callers then append a
/// continuation token.
///
/// Panics via `unreachable!` if `ss` is not a producer.
fn run_producer<W: std::io::Write>(
    sw: &mut StreamWriter<W>,
    ss: &mut StreamStateKind,
    output_schema: &SchemaRef,
    server: &Arc<RpcServer>,
    req: &Request,
    limit: usize,
    sticky: Option<&Arc<crate::sticky::StickySinkImpl>>,
) -> bool {
    // NOTE(review): built with `for_request`, not `with_auth_cookies` as in
    // the HTTP handlers; confirm producers don't need the request's auth or
    // cookies here.
    let mut ctx = CallContext::for_request(server, req);
    if let Some(s) = sticky {
        ctx.set_sticky(s.clone());
    }
    let producer = match ss {
        StreamStateKind::Producer(p) => p,
        StreamStateKind::Exchange(_) => unreachable!(),
    };
    // A per-producer limit takes precedence over the server-wide one.
    let limit = producer.batch_limit().unwrap_or(limit);
    let mut batches_written = 0usize;
    while limit == 0 || batches_written < limit {
        let mut out = OutputCollector::new(output_schema.clone(), true);
        let result = producer.produce(&mut out, &ctx);
        for log in ctx.drain_logs() {
            let md = build_log_metadata(&log, &server.server_id, &req.request_id);
            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
        }
        if let Err(err) = result {
            let md = build_error_metadata(&err, &server.server_id, &req.request_id);
            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
            return true;
        }
        let finished = out.finished();
        let mut emitted_data = false;
        for item in out.items.drain(..) {
            match item {
                Emitted::Log(log) => {
                    let md = build_log_metadata(&log, &server.server_id, &req.request_id);
                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
                }
                Emitted::Batch { batch, metadata } => {
                    let _ = sw.write(&batch, metadata.as_ref());
                    emitted_data = true;
                }
            }
        }
        // Only rounds that produced data count toward the limit.
        if emitted_data {
            batches_written += 1;
        }
        if finished {
            return true;
        }
        // NOTE(review): a round with no data batch (e.g. only logs) ends the
        // stream even if the producer did not finish. This avoids spinning;
        // confirm it is the intended contract.
        if !emitted_data {
            return true;
        }
    }
    false
}
2172
2173async fn handle_stream_exchange(
2178 State(state): State<Arc<HttpState>>,
2179 Path(method): Path<String>,
2180 headers: HeaderMap,
2181 body: Bytes,
2182) -> Response {
2183 let auth = match authenticate_request(&state, &method, &headers) {
2184 Ok(a) => a,
2185 Err(resp) => return resp,
2186 };
2187 if !has_arrow_ct(&headers) {
2188 return plain_error(
2189 StatusCode::UNSUPPORTED_MEDIA_TYPE,
2190 "need arrow content type".into(),
2191 );
2192 }
2193 let sticky_sink = match sticky_for_request(&state, &auth, &headers) {
2194 Ok(s) => s,
2195 Err(resp) => return resp,
2196 };
2197
2198 let server = state.server.clone();
2199 let body = match maybe_decompress(&headers, &body, state.max_body_size) {
2200 Ok(b) => b,
2201 Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
2202 };
2203 let (batch, metadata) = match read_input_batch(&body) {
2205 Ok(x) => x,
2206 Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
2207 };
2208
2209 let Some(token) = md_get(&metadata, STATE_KEY).map(str::to_owned) else {
2210 let err = RpcError::runtime_error("Missing state token in exchange request");
2211 return arrow_error(&state, StatusCode::BAD_REQUEST, &err, "");
2212 };
2213 let cancelled = md_get(&metadata, CANCEL_KEY).is_some();
2214
2215 let unpacked = match state.unpack_state_token(&auth, &token) {
2216 Ok(u) => u,
2217 Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2218 };
2219
2220 let output_schema = match read_schema_bytes(&unpacked.output_schema_bytes) {
2222 Ok(s) => s,
2223 Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2224 };
2225 let input_schema: Option<SchemaRef> = if unpacked.input_schema_bytes.is_empty() {
2226 None
2227 } else {
2228 match read_schema_bytes(&unpacked.input_schema_bytes) {
2229 Ok(s) => Some(s),
2230 Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2231 }
2232 };
2233
2234 let Some(info) = server
2236 .method(&method)
2237 .filter(|m| m.method_type != MethodType::Unary)
2238 else {
2239 let err = RpcError::attribute_error(format!("Unknown stream method: '{}'", method));
2240 return arrow_error(&state, StatusCode::NOT_FOUND, &err, "");
2241 };
2242 let Some(decoder) = info.state_decoder.as_ref() else {
2243 let err = RpcError::runtime_error(format!(
2244 "Stream method '{method}' is registered without a state decoder; \
2245 it cannot serve HTTP continuation requests"
2246 ));
2247 return arrow_error(&state, StatusCode::INTERNAL_SERVER_ERROR, &err, "");
2248 };
2249 let mut ss = match decoder(&unpacked.state_bytes) {
2250 Ok(s) => s,
2251 Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2252 };
2253
2254 let req = Request {
2255 method: method.clone(),
2256 request_id: md_get(&metadata, REQUEST_ID_KEY).unwrap_or("").to_string(),
2257 batch: empty_batch(&Schema::empty()).unwrap(),
2258 metadata: metadata.clone(),
2259 };
2260 let mut ctx = CallContext::for_request(&server, &req);
2261 if let Some(s) = sticky_sink.clone() {
2262 ctx.set_sticky(s);
2263 }
2264
2265 let mut body_buf = Vec::new();
2266
2267 if cancelled {
2268 match &mut ss {
2269 StreamStateKind::Producer(p) => p.on_cancel(&ctx),
2270 StreamStateKind::Exchange(e) => e.on_cancel(&ctx),
2271 }
2272 {
2273 let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2274 let _ = sw.finish();
2275 }
2276 let mut resp = arrow_response(StatusCode::OK, body_buf);
2277 if let Some(s) = sticky_sink.as_ref() {
2278 stamp_session_headers(&mut resp, &state, s);
2279 }
2280 return resp;
2281 }
2282
2283 if matches!(ss, StreamStateKind::Producer(_)) {
2284 let finished;
2286 {
2287 let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2288 finished = run_producer(
2289 &mut sw,
2290 &mut ss,
2291 &output_schema,
2292 &server,
2293 &req,
2294 state.producer_batch_limit,
2295 sticky_sink.as_ref(),
2296 );
2297 if !finished {
2298 match build_continuation_token(
2299 &state,
2300 &auth,
2301 &ss,
2302 &output_schema,
2303 input_schema.as_ref(),
2304 &unpacked.stream_id,
2305 ) {
2306 Ok(new_token) => {
2307 let md = Metadata::from([(STATE_KEY.to_string(), new_token)]);
2308 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2309 }
2310 Err(err) => {
2311 let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2312 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2313 }
2314 }
2315 }
2316 let _ = sw.finish();
2317 }
2318 let mut resp = arrow_response(StatusCode::OK, body_buf);
2319 if let Some(s) = sticky_sink.as_ref() {
2320 stamp_session_headers(&mut resp, &state, s);
2321 }
2322 return resp;
2323 }
2324
2325 let casted = match &input_schema {
2327 Some(exp) if batch.schema() != *exp => match cast_batch(&batch, exp) {
2328 Ok(b) => b,
2329 Err(e) => {
2330 let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2331 let md = build_error_metadata(&e, &server.server_id, &req.request_id);
2332 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2333 let _ = sw.finish();
2334 drop(sw);
2335 return arrow_response(StatusCode::OK, body_buf);
2336 }
2337 },
2338 _ => batch,
2339 };
2340
2341 let mut out = OutputCollector::new(output_schema.clone(), false);
2342 let res = match &mut ss {
2343 StreamStateKind::Exchange(e) => e.exchange(&casted, &mut out, &ctx),
2344 _ => unreachable!(),
2345 };
2346
2347 {
2348 let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2349 for log in ctx.drain_logs() {
2350 let md = build_log_metadata(&log, &server.server_id, &req.request_id);
2351 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2352 }
2353 if let Err(err) = res {
2354 let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2355 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2356 } else {
2357 let new_token = match build_continuation_token(
2358 &state,
2359 &auth,
2360 &ss,
2361 &output_schema,
2362 input_schema.as_ref(),
2363 &unpacked.stream_id,
2364 ) {
2365 Ok(t) => t,
2366 Err(err) => {
2367 let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2368 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2369 let _ = sw.finish();
2370 drop(sw);
2371 return arrow_response(StatusCode::OK, body_buf);
2372 }
2373 };
2374 let mut wrote_data = false;
2375 for item in out.items.drain(..) {
2376 match item {
2377 Emitted::Log(log) => {
2378 let md = build_log_metadata(&log, &server.server_id, &req.request_id);
2379 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2380 }
2381 Emitted::Batch { batch, metadata } => {
2382 let mut md = metadata.unwrap_or_default();
2383 md.insert(STATE_KEY.to_string(), new_token.clone());
2384 let _ = sw.write(&batch, Some(&md));
2385 wrote_data = true;
2386 }
2387 }
2388 }
2389 if !wrote_data {
2390 let md = Metadata::from([(STATE_KEY.to_string(), new_token)]);
2391 let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2392 }
2393 }
2394 let _ = sw.finish();
2395 }
2396
2397 let mut resp = enforce_response_body_cap(
2398 &state,
2399 output_schema.as_ref(),
2400 body_buf,
2401 &method,
2402 &server.server_id,
2403 "",
2404 );
2405 if let Some(s) = sticky_sink.as_ref() {
2406 stamp_session_headers(&mut resp, &state, s);
2407 }
2408 resp
2409}
2410
2411fn read_input_batch(body: &[u8]) -> Result<(RecordBatch, Metadata)> {
2412 let mut r = StreamReader::new(body)?;
2413 let (batch, metadata) = r
2414 .read_next()?
2415 .ok_or_else(|| RpcError::runtime_error("no batch in exchange request"))?;
2416 r.drain()?;
2417 Ok((batch, metadata))
2418}
2419
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an `HttpState` with a fixed 32-byte token key (`[7u8; 32]`), so tests can
    /// pack and unpack state tokens deterministically.
    ///
    /// The TTL is deliberately generous. Round-trip tests must not fail just because the
    /// pack -> unpack gap crossed the TTL on a slow/loaded machine or at a coarse timestamp
    /// boundary. Expiry is covered separately by `unpack_rejects_expired_token`, which
    /// forges a token whose final argument (presumably the issue timestamp) is 0.
    fn state_with_key() -> Arc<HttpState> {
        use crate::server::RpcServer;
        let server = Arc::new(RpcServer::builder().server_id("test").build());
        HttpState::builder()
            .server(server)
            .token_key(&[7u8; 32])
            .token_ttl(Duration::from_secs(60))
            .max_body_size(1024)
            .build()
    }

    /// Serialized form of a one-column (`x: Int64`, non-nullable) schema, used as a
    /// realistic schema payload inside tokens.
    fn sample_schema_bytes() -> Vec<u8> {
        use arrow_schema::{DataType, Field, Schema};
        write_schema_bytes(&Schema::new(vec![Field::new("x", DataType::Int64, false)])).unwrap()
    }

    /// Every field packed into a token comes back unchanged for the same principal.
    #[tokio::test]
    async fn pack_unpack_roundtrip() {
        let s = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let state_bytes = b"state-payload";
        let out_sch = sample_schema_bytes();
        let in_sch = sample_schema_bytes();
        let token = s.pack_state_token(&auth, state_bytes, &out_sch, &in_sch, "sid-123");
        let unpacked = s.unpack_state_token(&auth, &token).unwrap();
        assert_eq!(unpacked.state_bytes, state_bytes);
        assert_eq!(unpacked.output_schema_bytes, out_sch);
        assert_eq!(unpacked.input_schema_bytes, in_sch);
        assert_eq!(unpacked.stream_id, "sid-123");
    }

    /// Flipping a single bit in the ciphertext portion must cause the token to be rejected.
    #[tokio::test]
    async fn unpack_rejects_tampered_ciphertext() {
        let s = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let token = s.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        let mut bytes = base64::engine::general_purpose::STANDARD
            .decode(token.as_bytes())
            .unwrap();
        // Skip the 1-byte version prefix and the 24-byte nonce to reach the first
        // ciphertext byte. The layout is assumed here; keep it in sync with the packer.
        let cipher_idx = 1 + 24;
        bytes[cipher_idx] ^= 0x01;
        let tampered = base64::engine::general_purpose::STANDARD.encode(bytes);
        assert!(s.unpack_state_token(&auth, &tampered).is_err());
    }

    /// Flipping a bit in the nonce (the bytes right after the version byte) must cause the
    /// token to be rejected.
    #[tokio::test]
    async fn unpack_rejects_tampered_nonce() {
        let s = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let token = s.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        let mut bytes = base64::engine::general_purpose::STANDARD
            .decode(token.as_bytes())
            .unwrap();
        bytes[1] ^= 0x01;
        let tampered = base64::engine::general_purpose::STANDARD.encode(bytes);
        assert!(s.unpack_state_token(&auth, &tampered).is_err());
    }

    /// A token whose leading version byte is unknown must be rejected.
    ///
    /// NOTE: the rejection currently surfaces as a generic "signature verification
    /// failed" error rather than a dedicated version error; this test pins that message.
    #[tokio::test]
    async fn unpack_rejects_unknown_version() {
        let s = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let token = s.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        let mut bytes = base64::engine::general_purpose::STANDARD
            .decode(token.as_bytes())
            .unwrap();
        bytes[0] = 0x99;
        let tampered = base64::engine::general_purpose::STANDARD.encode(bytes);
        let err = s.unpack_state_token(&auth, &tampered).unwrap_err();
        assert!(err.message.contains("signature verification failed"));
    }

    /// Input that is not valid base64 yields a "Malformed" error, not a panic.
    #[tokio::test]
    async fn unpack_rejects_malformed_base64() {
        let s = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let err = s.unpack_state_token(&auth, "not!base64!").unwrap_err();
        assert!(err.message.contains("Malformed"));
    }

    /// A token sealed with one key cannot be opened by a server holding a different key.
    #[tokio::test]
    async fn unpack_rejects_different_key() {
        use crate::server::RpcServer;
        let server = Arc::new(RpcServer::builder().server_id("t").build());
        let a = HttpState::builder()
            .server(server.clone())
            .token_key(&[1u8; 32])
            .build();
        let b = HttpState::builder()
            .server(server)
            .token_key(&[2u8; 32])
            .build();
        let auth = crate::auth::AuthContext::anonymous();
        let tok = a.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        assert!(b.unpack_state_token(&auth, &tok).is_err());
    }

    /// A token packed with the same key/AAD but a zero final argument (presumably the
    /// issue timestamp) is reported as expired.
    #[tokio::test]
    async fn unpack_rejects_expired_token() {
        let s = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let aad = compute_aad(&auth);
        // Same key as `state_with_key`, so only the age check can reject this token.
        let stale = pack_state_token(&[7u8; 32], &aad, b"s", b"o", b"i", "sid", 0);
        let err = s.unpack_state_token(&auth, &stale).unwrap_err();
        assert!(err.message.contains("expired"), "got: {}", err.message);
    }

    /// A token is bound to its principal: other principals (and anonymous callers)
    /// cannot unpack it.
    #[tokio::test]
    async fn unpack_rejects_different_principal() {
        let s = state_with_key();
        let alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let bob = crate::auth::AuthContext::for_principal("bearer", "bob");
        let tok = s.pack_state_token(&alice, b"s", b"o", b"i", "sid");
        assert!(s.unpack_state_token(&alice, &tok).is_ok());
        assert!(s.unpack_state_token(&bob, &tok).is_err());
        let anon = crate::auth::AuthContext::anonymous();
        assert!(s.unpack_state_token(&anon, &tok).is_err());
    }

    /// A token minted for an anonymous caller cannot be replayed by an authenticated one.
    #[tokio::test]
    async fn unpack_rejects_authenticated_replay_of_anonymous_token() {
        let s = state_with_key();
        let anon = crate::auth::AuthContext::anonymous();
        let alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let tok = s.pack_state_token(&anon, b"s", b"o", b"i", "sid");
        assert!(s.unpack_state_token(&alice, &tok).is_err());
    }

    /// The same principal name under a different auth domain ("bearer" vs "mtls") is
    /// treated as a different caller.
    #[tokio::test]
    async fn unpack_rejects_cross_domain_replay() {
        let s = state_with_key();
        let bearer_alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let mtls_alice = crate::auth::AuthContext::for_principal("mtls", "alice");
        let tok = s.pack_state_token(&bearer_alice, b"s", b"o", b"i", "sid");
        assert!(s.unpack_state_token(&mtls_alice, &tok).is_err());
    }

    /// An uncompressed body one byte over the limit is rejected.
    ///
    /// This is a plain synchronous test: `maybe_decompress` does no async work, so no
    /// runtime is needed.
    #[test]
    fn decompress_rejects_oversize() {
        let hdr = HeaderMap::new();
        let body = Bytes::from(vec![0u8; 1025]);
        let err = super::maybe_decompress(&hdr, &body, 1024).unwrap_err();
        assert!(err.message.contains("exceeds max size"));
    }

    /// A highly compressible 8 MiB payload (a tiny zstd frame) must be rejected against a
    /// 64 KiB limit instead of being fully inflated.
    #[test]
    fn zstd_bounded_rejects_zip_bomb_without_full_alloc() {
        let huge = vec![0u8; 8 * 1024 * 1024];
        let compressed = zstd::encode_all(huge.as_slice(), 1).unwrap();
        assert!(compressed.len() < 100_000, "compressed should be tiny");
        let err = super::decode_zstd_bounded(&compressed, 64 * 1024).unwrap_err();
        assert!(
            err.message.contains("exceeds max size"),
            "expected oversize error, got: {}",
            err.message
        );
    }

    /// Payloads under the limit decompress to exactly the original bytes.
    #[test]
    fn zstd_bounded_passes_small_payload() {
        let small = b"hello-world".repeat(10);
        let compressed = zstd::encode_all(small.as_slice(), 1).unwrap();
        let out = super::decode_zstd_bounded(&compressed, 1024).unwrap();
        assert_eq!(out, small);
    }

    /// A 64-hex-digit string decodes to the corresponding 32 bytes, in order.
    #[test]
    fn decode_hex_key_roundtrip() {
        let key =
            decode_hex_key("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")
                .unwrap();
        assert_eq!(key.len(), 32);
        assert_eq!(key[0], 0x00);
        assert_eq!(key[31], 0x1f);
    }

    /// Hex keys shorter than 32 bytes are rejected.
    #[test]
    fn decode_hex_key_rejects_short() {
        assert!(decode_hex_key("deadbeef").is_err());
    }

    /// Right length but non-hex characters are rejected.
    #[test]
    fn decode_hex_key_rejects_bad_char() {
        assert!(decode_hex_key(&"zz".repeat(32)).is_err());
    }

    /// Standard (padded) base64 of a 32-byte key is accepted.
    #[test]
    fn decode_base64_key_accepts_padded() {
        let s = base64::engine::general_purpose::STANDARD.encode([7u8; 32]);
        let out = decode_base64_key(&s).unwrap();
        assert_eq!(out, vec![7u8; 32]);
    }

    /// The same key with its trailing `=` padding stripped is also accepted.
    #[test]
    fn decode_base64_key_accepts_unpadded() {
        let s = base64::engine::general_purpose::STANDARD
            .encode([7u8; 32])
            .trim_end_matches('=')
            .to_string();
        let out = decode_base64_key(&s).unwrap();
        assert_eq!(out, vec![7u8; 32]);
    }

    /// Base64 that decodes to fewer than 32 bytes is rejected.
    #[test]
    fn decode_base64_key_rejects_short() {
        let s = base64::engine::general_purpose::STANDARD.encode(b"short");
        assert!(decode_base64_key(&s).is_err());
    }

    /// Two servers configured with the same hex key interoperate: one packs, the other
    /// unpacks.
    #[tokio::test]
    async fn token_key_hex_round_trips_through_token() {
        use crate::server::RpcServer;
        let server = Arc::new(RpcServer::builder().server_id("t").build());
        let a = HttpState::builder()
            .server(server.clone())
            .token_key_hex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")
            .build();
        let b = HttpState::builder()
            .server(server)
            .token_key_hex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")
            .build();
        let auth = crate::auth::AuthContext::anonymous();
        let tok = a.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        assert_eq!(b.unpack_state_token(&auth, &tok).unwrap().stream_id, "sid");
    }

    /// The internal response buffer ceiling is never below the hard cap.
    ///
    /// The assertions below are consistent with `max(hard cap, 2 * soft limit)`:
    /// - no soft limit gives the hard cap;
    /// - a tiny soft limit still gives the hard cap;
    /// - a soft limit equal to the hard cap doubles it.
    #[test]
    fn response_buffer_ceiling_never_below_hard_cap() {
        use crate::server::RpcServer;
        let mk = |soft: Option<usize>| {
            let server = Arc::new(RpcServer::builder().server_id("t").build());
            let mut b = HttpState::builder().server(server).token_key(&[9u8; 32]);
            if let Some(n) = soft {
                b = b.max_response_bytes(n);
            }
            b.build()
        };
        // No soft limit configured.
        assert_eq!(
            response_buffer_ceiling(&mk(None)),
            MAX_RESPONSE_BYTES_HARD_CAP
        );
        // A soft limit far below the hard cap must not shrink the ceiling.
        assert_eq!(
            response_buffer_ceiling(&mk(Some(8))),
            MAX_RESPONSE_BYTES_HARD_CAP
        );
        // A large soft limit leaves headroom above it.
        let big = MAX_RESPONSE_BYTES_HARD_CAP;
        assert_eq!(response_buffer_ceiling(&mk(Some(big))), big * 2);
    }
}