Skip to main content

vgi_rpc/
http.rs

1//! HTTP transport. Implements the vgi-rpc protocol over HTTP:
2//!   `POST /{method}`            unary
3//!   `POST /{method}/init`       stream init (producer or exchange)
4//!   `POST /{method}/exchange`   stream continuation
5//!
6//! Streaming is stateless on the wire: the full `StreamStateKind` is
7//! sealed into an XChaCha20-Poly1305 AEAD token (v4 wire format) carried
8//! in the `vgi_rpc.stream_state#b64` metadata key. Any worker with the
9//! same token key can resume any continuation request — no server-side
10//! session map, no reaper, no cross-worker affinity. The token contents
11//! are confidential as well as authenticated: only the server can read
12//! the serialized state.
13
14use std::sync::Arc;
15
16use arrow_array::RecordBatch;
17use arrow_schema::{Schema, SchemaRef};
18use axum::{
19    body::Bytes,
20    extract::{Path, State},
21    http::{header, HeaderMap, HeaderValue, StatusCode},
22    response::{IntoResponse, Response},
23    routing::post,
24    Router,
25};
26use base64::Engine;
27use rand::RngCore;
28
29use crate::errors::{Result, RpcError};
30use crate::metadata::{CANCEL_KEY, REQUEST_ID_KEY, STATE_KEY};
31use crate::server::{
32    build_error_metadata, build_log_metadata, cast_batch, CallContext, MethodType, Request,
33    RpcServer,
34};
35use crate::stream::{empty_schema, Emitted, OutputCollector, StreamResult, StreamStateKind};
36use crate::wire::{bytes_to_hex, empty_batch, md_get, Metadata, StreamReader, StreamWriter};
37
/// Media type of an Apache Arrow IPC *stream* body.
pub const ARROW_CONTENT_TYPE: &str = "application/vnd.apache.arrow.stream";

// Sticky-session header names (HTTP transport only). axum's `HeaderMap`
// matches header names case-insensitively, so these lowercase spellings
// also match canonical forms such as `VGI-Session` on the wire.

/// `VGI-Session` header, honoured when sticky sessions are enabled.
const SESSION_HEADER: &str = "vgi-session";
/// `VGI-Session-Accept` header, honoured when sticky sessions are enabled.
const SESSION_ACCEPT_HEADER: &str = "vgi-session-accept";
/// `VGI-Session-Close` header (sticky sessions).
const SESSION_CLOSE_HEADER: &str = "vgi-session-close";
/// Prefix of the `VGI-Echo-<name>` headers a session-opening response
/// asks the client to send back on later requests.
const ECHO_HEADER_PREFIX: &str = "vgi-echo-";
/// Capability header advertising sticky-session support.
const STICKY_ENABLED_HEADER: &str = "vgi-sticky-enabled";
/// Capability header advertising the default sticky-session TTL.
const STICKY_DEFAULT_TTL_HEADER: &str = "vgi-sticky-default-ttl";
/// Capability header listing, by name, the headers clients must echo.
const STICKY_ECHO_HEADERS_HEADER: &str = "vgi-sticky-echo-headers";
/// Path segment of the framework-managed sticky-session teardown endpoint
/// (`DELETE {prefix}/__session__`).
const SESSION_ENDPOINT: &str = "__session__";
52
53/// HTTP server state shared across all handlers.
54///
55/// Build via [`HttpState::builder`] (preferred) or [`HttpState::new`] for a
56/// default configuration.
57///
58/// Streaming is stateless: the full `StreamStateKind` travels in every
59/// HTTP continuation request inside an AEAD-sealed state token, so any
60/// worker behind a load balancer can resume any stream. No session map
61/// is held on the server.
62pub struct HttpState {
63    server: Arc<RpcServer>,
64    token_key: [u8; 32],
65    producer_batch_limit: usize,
66    token_ttl: std::time::Duration,
67    max_body_size: usize,
68    /// Wall-clock ceiling for a single HTTP request, enforced by a
69    /// `tower_http::timeout::TimeoutLayer`. A stalled handler or a
70    /// slow-loris client cannot pin a runtime worker indefinitely.
71    request_timeout: std::time::Duration,
72    authenticate: Option<crate::auth::Authenticate>,
73    #[allow(dead_code)]
74    oauth_metadata: Option<Arc<crate::auth::oauth::OAuthResourceMetadata>>,
75    oauth_metadata_json: Option<Vec<u8>>,
76    www_authenticate: Option<String>,
77    cors_origins: Option<String>,
78    cors_max_age: u32,
79    prefix: String,
80    response_compression_level: Option<i32>,
81    landing_page_enabled: bool,
82    describe_page_enabled: bool,
83    health_enabled: bool,
84    max_request_bytes: Option<usize>,
85    max_upload_bytes: Option<usize>,
86    /// Hard cap on the HTTP body size for unary and stream-exchange
87    /// responses (advertised via `VGI-Max-Response-Bytes`).  `None` =
88    /// unbounded.  Externalised payloads do not count toward this — they
89    /// leave only tiny pointer batches on the wire.
90    max_response_bytes: Option<usize>,
91    /// Hard cap on bytes uploaded to external storage during one HTTP
92    /// response (advertised via `VGI-Max-Externalized-Response-Bytes`).
93    /// Always hard — externalised uploads have no escape valve.
94    max_externalized_response_bytes: Option<usize>,
95    upload_url_provider: Option<Arc<dyn crate::external::UploadUrlProvider>>,
96    /// Sticky-session context, `Some` when the server is sticky-enabled.
97    sticky: Option<Arc<crate::sticky::StickyContext>>,
98}
99
100/// Fluent builder for [`HttpState`].
101#[derive(Default)]
102pub struct HttpStateBuilder {
103    server: Option<Arc<RpcServer>>,
104    token_key: Option<[u8; 32]>,
105    producer_batch_limit: Option<usize>,
106    token_ttl: Option<std::time::Duration>,
107    max_body_size: Option<usize>,
108    request_timeout: Option<std::time::Duration>,
109    authenticate: Option<crate::auth::Authenticate>,
110    oauth_metadata: Option<Arc<crate::auth::oauth::OAuthResourceMetadata>>,
111    cors_origins: Option<String>,
112    cors_max_age: Option<u32>,
113    prefix: Option<String>,
114    response_compression_level: Option<i32>,
115    landing_page_enabled: Option<bool>,
116    describe_page_enabled: Option<bool>,
117    health_enabled: Option<bool>,
118    max_request_bytes: Option<usize>,
119    max_upload_bytes: Option<usize>,
120    max_response_bytes: Option<usize>,
121    max_externalized_response_bytes: Option<usize>,
122    upload_url_provider: Option<Arc<dyn crate::external::UploadUrlProvider>>,
123    enable_sticky: Option<bool>,
124    sticky_default_ttl: Option<std::time::Duration>,
125    sticky_echo_headers: Vec<(String, String)>,
126}
127
128impl HttpStateBuilder {
129    pub fn server(mut self, server: Arc<RpcServer>) -> Self {
130        self.server = Some(server);
131        self
132    }
133
134    /// AEAD master key used to seal state tokens. **Must be ≥32 bytes**
135    /// — the XChaCha20-Poly1305 key size; the first 32 bytes are used. A
136    /// shorter slice is a configuration error and panics rather than
137    /// being silently zero-padded into a weak key. When not set, a
138    /// random 32-byte key is generated at `build()` time.
139    pub fn token_key(mut self, key: &[u8]) -> Self {
140        assert!(
141            key.len() >= 32,
142            "token_key: signing key must be at least 32 bytes (got {}). \
143             Generate one with `openssl rand -hex 32`.",
144            key.len()
145        );
146        let mut k = [0u8; 32];
147        k.copy_from_slice(&key[..32]);
148        self.token_key = Some(k);
149        self
150    }
151
152    /// Set the token key from a lowercase-hex string (64 hex chars →
153    /// 32 bytes). Panics on invalid input — intended for startup config,
154    /// not runtime callers.
155    pub fn token_key_hex(self, hex: &str) -> Self {
156        let bytes = decode_hex_key(hex).expect("token_key_hex: invalid hex or wrong length");
157        self.token_key(&bytes)
158    }
159
160    /// Set the token key from a base64-encoded string (standard alphabet,
161    /// padding optional). Panics on invalid input.
162    pub fn token_key_base64(self, b64: &str) -> Self {
163        let bytes =
164            decode_base64_key(b64).expect("token_key_base64: invalid base64 or wrong length");
165        self.token_key(&bytes)
166    }
167
168    /// Read the token key from environment variable `var`. Accepts either
169    /// base64 or lowercase-hex (auto-detected). Panics if the variable is
170    /// unset, empty, or decodes to fewer than 32 bytes. Use this for
171    /// production deployments where the key is supplied by a secret manager.
172    pub fn token_key_from_env(self, var: &str) -> Self {
173        let raw = std::env::var(var)
174            .unwrap_or_else(|_| panic!("token_key_from_env: env var {var} is unset or not UTF-8"));
175        let trimmed = raw.trim();
176        let bytes = decode_base64_key(trimmed)
177            .or_else(|_| decode_hex_key(trimmed))
178            .unwrap_or_else(|e| {
179                panic!("token_key_from_env: {var} is not valid base64 or hex ({e})")
180            });
181        self.token_key(&bytes)
182    }
183
184    /// Maximum data batches per producer HTTP response (0 = unbounded).
185    /// Default `1` to mirror the Python/Go servers.
186    pub fn producer_batch_limit(mut self, n: usize) -> Self {
187        self.producer_batch_limit = Some(n);
188        self
189    }
190
191    /// Maximum age of a state token. Continuation requests with a token
192    /// older than this are rejected. Default `5 minutes`. Set to
193    /// `Duration::ZERO` to disable TTL enforcement.
194    pub fn token_ttl(mut self, ttl: std::time::Duration) -> Self {
195        self.token_ttl = Some(ttl);
196        self
197    }
198
199    /// Maximum request body size (post-decompression) in bytes. Default
200    /// `64 * 1024 * 1024` (64 MiB). Enforced as a hard ceiling on the
201    /// raw request body by a `RequestBodyLimitLayer` — independent of the
202    /// `Content-Length` header, so a chunked upload cannot bypass it.
203    pub fn max_body_size(mut self, n: usize) -> Self {
204        self.max_body_size = Some(n);
205        self
206    }
207
208    /// Wall-clock timeout for a single HTTP request. Default 30 s.
209    pub fn request_timeout(mut self, d: std::time::Duration) -> Self {
210        self.request_timeout = Some(d);
211        self
212    }
213
214    /// Register an authenticate callback run on every request. Not set →
215    /// anonymous for all callers (mirrors the Python `make_wsgi_app` default).
216    pub fn authenticate(mut self, cb: crate::auth::Authenticate) -> Self {
217        self.authenticate = Some(cb);
218        self
219    }
220
221    /// Attach RFC 9728 Protected Resource Metadata. When set, the server
222    /// exposes `/.well-known/oauth-protected-resource` and includes a
223    /// `WWW-Authenticate` header on 401 responses.
224    pub fn oauth_resource_metadata(
225        mut self,
226        metadata: crate::auth::oauth::OAuthResourceMetadata,
227    ) -> Self {
228        self.oauth_metadata = Some(Arc::new(metadata));
229        self
230    }
231
232    /// Enable CORS with the given `Access-Control-Allow-Origin` value.
233    /// Pass `"*"` for a permissive server or a specific origin URL.
234    pub fn cors_origins(mut self, origins: impl Into<String>) -> Self {
235        self.cors_origins = Some(origins.into());
236        self
237    }
238
239    /// Override the preflight cache lifetime (seconds). Default `7200`.
240    pub fn cors_max_age(mut self, seconds: u32) -> Self {
241        self.cors_max_age = Some(seconds);
242        self
243    }
244
245    /// Mount the router under a URL prefix (e.g. `/v1`). Default empty.
246    pub fn prefix(mut self, prefix: impl Into<String>) -> Self {
247        self.prefix = Some(prefix.into());
248        self
249    }
250
251    /// Enable zstd response compression at the given level (1..=22) when
252    /// the client sends `Accept-Encoding: zstd`. Default off.
253    pub fn response_compression_level(mut self, level: i32) -> Self {
254        self.response_compression_level = Some(level);
255        self
256    }
257
258    /// Serve a friendly HTML landing page at `GET /`. Default on.
259    pub fn enable_landing_page(mut self, enabled: bool) -> Self {
260        self.landing_page_enabled = Some(enabled);
261        self
262    }
263
264    /// Serve an API reference HTML page at `GET /describe`. Default on.
265    pub fn enable_describe_page(mut self, enabled: bool) -> Self {
266        self.describe_page_enabled = Some(enabled);
267        self
268    }
269
270    /// Serve a liveness probe at `GET /health`. Default on.
271    pub fn enable_health(mut self, enabled: bool) -> Self {
272        self.health_enabled = Some(enabled);
273        self
274    }
275
276    /// Maximum inline-request body size advertised via the
277    /// `VGI-Max-Request-Bytes` capability header and enforced server-side
278    /// (413 Payload Too Large for non-exempt routes). When set together
279    /// with [`Self::upload_url_provider`], clients can externalize
280    /// oversize requests via `__upload_url__/init` + a pointer batch.
281    pub fn max_request_bytes(mut self, n: usize) -> Self {
282        self.max_request_bytes = Some(n);
283        self
284    }
285
286    /// Advertised upper bound on the size of any single client-vended
287    /// upload (header `VGI-Max-Upload-Bytes`). Advertisement only — no
288    /// server-side enforcement.
289    pub fn max_upload_bytes(mut self, n: usize) -> Self {
290        self.max_upload_bytes = Some(n);
291        self
292    }
293
294    /// HTTP body cap (header `VGI-Max-Response-Bytes`). Hard for unary
295    /// and stream-exchange — overshoot replaces the response with a
296    /// fresh EXCEPTION-only IPC stream surfaced via 200 +
297    /// `X-VGI-RPC-Error: true`. Externalised payloads do not count
298    /// toward this cap.
299    pub fn max_response_bytes(mut self, n: usize) -> Self {
300        self.max_response_bytes = Some(n);
301        self
302    }
303
304    /// Cap on bytes uploaded to external storage during one HTTP
305    /// response (header `VGI-Max-Externalized-Response-Bytes`).  Always
306    /// hard — externalised uploads have no escape valve.
307    pub fn max_externalized_response_bytes(mut self, n: usize) -> Self {
308        self.max_externalized_response_bytes = Some(n);
309        self
310    }
311
312    /// Install an [`UploadUrlProvider`](crate::external::UploadUrlProvider).
313    /// When set, the server exposes `POST /__upload_url__/init` and
314    /// advertises `VGI-Upload-URL-Support: true`.
315    pub fn upload_url_provider(
316        mut self,
317        provider: Arc<dyn crate::external::UploadUrlProvider>,
318    ) -> Self {
319        self.upload_url_provider = Some(provider);
320        self
321    }
322
323    /// Opt in to sticky sessions (HTTP-only). When enabled the server
324    /// advertises `VGI-Sticky-Enabled: true`, honours the `VGI-Session` /
325    /// `VGI-Session-Accept` headers, and exposes `DELETE {prefix}/__session__`.
326    /// Off by default — the non-sticky wire path is unchanged.
327    pub fn enable_sticky(mut self, enabled: bool) -> Self {
328        self.enable_sticky = Some(enabled);
329        self
330    }
331
332    /// Default session TTL when a method calls `ctx.open_session` without
333    /// an explicit TTL. Default 300 s. Advertised via `VGI-Sticky-Default-TTL`.
334    pub fn sticky_default_ttl(mut self, ttl: std::time::Duration) -> Self {
335        self.sticky_default_ttl = Some(ttl);
336        self
337    }
338
339    /// Headers the server tells the client to echo back on every
340    /// subsequent request in a session (emitted as `VGI-Echo-<name>` on
341    /// the session-opening response; advertised by name via
342    /// `VGI-Sticky-Echo-Headers`). Used for client-driven routing
343    /// (e.g. `fly-force-instance-id` on Fly.io).
344    pub fn sticky_echo_headers(
345        mut self,
346        headers: impl IntoIterator<Item = (String, String)>,
347    ) -> Self {
348        self.sticky_echo_headers = headers.into_iter().collect();
349        self
350    }
351
352    pub fn build(self) -> Arc<HttpState> {
353        let server = self.server.expect("HttpStateBuilder::server is required");
354        // A wildcard CORS origin combined with a credentialed auth
355        // callback is unsafe and a browser would refuse it anyway: an
356        // `Access-Control-Allow-Origin: *` response cannot carry
357        // `Allow-Credentials: true`. Fail fast at config time instead of
358        // shipping a server whose authenticated cross-origin requests
359        // silently break.
360        assert!(
361            !(self.cors_origins.as_deref() == Some("*") && self.authenticate.is_some()),
362            "HttpStateBuilder: cors_origins(\"*\") cannot be combined with an \
363             authenticate callback — browsers reject credentialed requests \
364             against a wildcard origin. Configure a specific origin."
365        );
366        let token_key = self.token_key.unwrap_or_else(|| {
367            tracing::warn!(
368                target: "vgi_rpc.http",
369                "no token_key configured; using ephemeral per-process AEAD key — \
370                 state tokens will not survive restart or load-balance across workers"
371            );
372            let mut k = [0u8; 32];
373            rand::thread_rng().fill_bytes(&mut k);
374            k
375        });
376        let oauth_metadata_json = self
377            .oauth_metadata
378            .as_ref()
379            .map(|m| m.to_json().into_bytes());
380        let www_authenticate = self.oauth_metadata.as_ref().map(|m| m.www_authenticate());
381        let sticky = if self.enable_sticky.unwrap_or(false) {
382            let ttl = self
383                .sticky_default_ttl
384                .unwrap_or_else(|| std::time::Duration::from_secs(300));
385            Some(crate::sticky::StickyContext::new(
386                token_key,
387                ttl,
388                self.sticky_echo_headers,
389                server.server_id.clone(),
390            ))
391        } else {
392            None
393        };
394        Arc::new(HttpState {
395            server,
396            token_key,
397            producer_batch_limit: self.producer_batch_limit.unwrap_or(1),
398            token_ttl: self
399                .token_ttl
400                .unwrap_or_else(|| std::time::Duration::from_secs(300)),
401            max_body_size: self.max_body_size.unwrap_or(64 * 1024 * 1024),
402            request_timeout: self
403                .request_timeout
404                .unwrap_or_else(|| std::time::Duration::from_secs(30)),
405            authenticate: self.authenticate,
406            oauth_metadata: self.oauth_metadata,
407            oauth_metadata_json,
408            www_authenticate,
409            cors_origins: self.cors_origins,
410            cors_max_age: self.cors_max_age.unwrap_or(7200),
411            prefix: self.prefix.unwrap_or_default(),
412            response_compression_level: self.response_compression_level,
413            landing_page_enabled: self.landing_page_enabled.unwrap_or(true),
414            describe_page_enabled: self.describe_page_enabled.unwrap_or(true),
415            health_enabled: self.health_enabled.unwrap_or(true),
416            max_request_bytes: self.max_request_bytes,
417            max_upload_bytes: self.max_upload_bytes,
418            max_response_bytes: self.max_response_bytes,
419            max_externalized_response_bytes: self.max_externalized_response_bytes,
420            upload_url_provider: self.upload_url_provider,
421            sticky,
422        })
423    }
424}
425
426impl HttpState {
427    /// Create an `HttpState` with default configuration. See [`HttpState::builder`]
428    /// for the full set of knobs.
429    pub fn new(server: Arc<RpcServer>) -> Arc<Self> {
430        Self::builder().server(server).build()
431    }
432
433    pub fn builder() -> HttpStateBuilder {
434        HttpStateBuilder::default()
435    }
436
437    /// Operator handle for graceful sticky-session drain, or `None` when
438    /// the server is not sticky-enabled. Wire it into a SIGTERM handler.
439    pub fn sticky_drain_handle(&self) -> Option<crate::sticky::DrainHandle> {
440        self.sticky.as_ref().map(|c| c.drain_handle())
441    }
442
443    pub fn token_ttl(&self) -> std::time::Duration {
444        self.token_ttl
445    }
446
447    pub fn max_body_size(&self) -> usize {
448        self.max_body_size
449    }
450
451    /// Seal a v4 state token bound to the supplied auth identity.
452    ///
453    /// `(domain, principal)` are carried as AEAD associated data, so a
454    /// token issued under one identity fails decryption when presented
455    /// by another — same anti-replay guarantee as the prior HMAC subkey
456    /// derivation, expressed via AAD instead of key derivation.
457    pub(crate) fn pack_state_token(
458        &self,
459        auth: &crate::auth::AuthContext,
460        state_bytes: &[u8],
461        output_schema_bytes: &[u8],
462        input_schema_bytes: &[u8],
463        stream_id: &str,
464    ) -> String {
465        let aad = compute_aad(auth);
466        pack_state_token(
467            &self.token_key,
468            &aad,
469            state_bytes,
470            output_schema_bytes,
471            input_schema_bytes,
472            stream_id,
473            current_unix_secs(),
474        )
475    }
476
477    /// Open a v4 state token, decrypting under the current caller's
478    /// identity-derived AAD and enforcing TTL after authenticity.
479    pub(crate) fn unpack_state_token(
480        &self,
481        auth: &crate::auth::AuthContext,
482        token: &str,
483    ) -> Result<UnpackedToken> {
484        let ttl = if self.token_ttl.is_zero() {
485            None
486        } else {
487            Some(self.token_ttl)
488        };
489        let aad = compute_aad(auth);
490        unpack_state_token(&self.token_key, &aad, token, ttl)
491    }
492}
493
494/// Build the AEAD associated data that binds a state token to the
495/// authenticated identity of its issuer. Anonymous and authenticated
496/// callers produce distinct AAD strings so a token minted in one
497/// context cannot be opened in another. Mirrors Python's `_compute_aad`.
498fn compute_aad(auth: &crate::auth::AuthContext) -> Vec<u8> {
499    let prefix = b"vgi_rpc.state.v4\x00";
500    if !auth.authenticated {
501        let mut out = Vec::with_capacity(prefix.len() + b"\x00anonymous".len());
502        out.extend_from_slice(prefix);
503        out.extend_from_slice(b"\x00anonymous");
504        return out;
505    }
506    let mut out =
507        Vec::with_capacity(prefix.len() + 1 + auth.domain.len() + 1 + auth.principal.len());
508    out.extend_from_slice(prefix);
509    out.push(0x01);
510    out.extend_from_slice(auth.domain.as_bytes());
511    out.push(0);
512    out.extend_from_slice(auth.principal.as_bytes());
513    out
514}
515
/// Format-selector byte that opens every state token this crate seals or
/// accepts (the v4 wire format).
pub(crate) const STATE_TOKEN_VERSION: u8 = 4;
518
/// Fields recovered from a v4 state token once the AEAD layer has
/// authenticated it.
#[derive(Debug, Clone)]
pub(crate) struct UnpackedToken {
    /// Serialized stream state.
    pub state_bytes: Vec<u8>,
    /// Serialized output schema.
    pub output_schema_bytes: Vec<u8>,
    /// Serialized input schema.
    pub input_schema_bytes: Vec<u8>,
    /// Identifier of the stream the token belongs to.
    pub stream_id: String,
    /// Issue time in seconds since the UNIX epoch. It is already checked
    /// against the TTL during unpacking, and no visible caller reads it
    /// afterwards, hence the dead-code allowance.
    #[allow(dead_code)]
    pub created_at: u64,
}
529
/// Whole seconds elapsed since the UNIX epoch, according to the system
/// clock. A clock set before 1970 yields `0` instead of an error.
fn current_unix_secs() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
537
538/// Seal a state token (v4 wire format).
539///
540/// On-wire layout (base64-encoded):
541///
542/// ```text
543/// [1]    version = 0x04
544/// [24]   XChaCha20-Poly1305 nonce (random)
545/// [..]   ciphertext = XChaCha20-Poly1305-Seal(plaintext, aad, nonce, key)
546///        plaintext (little-endian):
547///          [8]  created_at (u64 seconds since epoch)
548///          [4]  len(state_bytes)           [N] state_bytes
549///          [4]  len(output_schema_bytes)   [M] output_schema_bytes
550///          [4]  len(input_schema_bytes)    [K] input_schema_bytes
551///          [4]  len(stream_id_bytes)       [L] stream_id_bytes (UTF-8)
552///        [16]   Poly1305 tag (appended by AEAD construction)
553/// ```
554///
555/// `created_at` lives inside the ciphertext so TTL enforcement runs
556/// after authenticity is established. The version byte is not part of
557/// the AAD — it acts as a format selector; a tampered version byte still
558/// fails decryption because [`crypto::open_bytes`] rejects it before
559/// touching the cipher.
560///
561/// The AEAD envelope (version byte + nonce + ciphertext+tag) is owned by
562/// [`crypto`]; only the *plaintext* framing inside the ciphertext is this
563/// function's concern.
564pub(crate) fn pack_state_token(
565    token_key: &[u8; 32],
566    aad: &[u8],
567    state_bytes: &[u8],
568    output_schema_bytes: &[u8],
569    input_schema_bytes: &[u8],
570    stream_id: &str,
571    created_at: u64,
572) -> String {
573    let mut plaintext = Vec::with_capacity(
574        8 + 4
575            + state_bytes.len()
576            + 4
577            + output_schema_bytes.len()
578            + 4
579            + input_schema_bytes.len()
580            + 4
581            + stream_id.len(),
582    );
583    plaintext.extend_from_slice(&created_at.to_le_bytes());
584    plaintext.extend_from_slice(&(state_bytes.len() as u32).to_le_bytes());
585    plaintext.extend_from_slice(state_bytes);
586    plaintext.extend_from_slice(&(output_schema_bytes.len() as u32).to_le_bytes());
587    plaintext.extend_from_slice(output_schema_bytes);
588    plaintext.extend_from_slice(&(input_schema_bytes.len() as u32).to_le_bytes());
589    plaintext.extend_from_slice(input_schema_bytes);
590    plaintext.extend_from_slice(&(stream_id.len() as u32).to_le_bytes());
591    plaintext.extend_from_slice(stream_id.as_bytes());
592
593    crate::crypto::seal_base64(&plaintext, token_key, aad, STATE_TOKEN_VERSION)
594}
595
596/// Open and verify a v4 state token. [`crypto::open_bytes`] authenticates
597/// the payload; every malformed, wrong-version, tampered, wrong-key, or
598/// AAD-mismatched (e.g. cross-principal replay) token surfaces as the same
599/// uniform signature-verification error so callers cannot distinguish
600/// failure modes via timing or message content. Only a base64 decode
601/// failure — observable before any crypto work — stays a distinct
602/// "Malformed state token".
603pub(crate) fn unpack_state_token(
604    token_key: &[u8; 32],
605    aad: &[u8],
606    token: &str,
607    token_ttl: Option<std::time::Duration>,
608) -> Result<UnpackedToken> {
609    let raw = base64::engine::general_purpose::STANDARD
610        .decode(token.as_bytes())
611        .map_err(|_| RpcError::runtime_error("Malformed state token"))?;
612
613    let plaintext = crate::crypto::open_bytes(&raw, token_key, aad, STATE_TOKEN_VERSION)
614        .map_err(|_| RpcError::runtime_error("State token signature verification failed"))?;
615
616    if plaintext.len() < 8 {
617        return Err(RpcError::runtime_error("Malformed state token"));
618    }
619    let created_at = u64::from_le_bytes(plaintext[0..8].try_into().unwrap());
620
621    if let Some(ttl) = token_ttl {
622        let now = current_unix_secs();
623        if now > created_at && now - created_at > ttl.as_secs() {
624            return Err(RpcError::runtime_error("State token expired"));
625        }
626        // A `created_at` in the future is clock skew between workers (or
627        // a tampered host clock). Without this guard the expiry check
628        // above is simply skipped — a token minted on a fast-clocked
629        // worker would dodge the TTL on every normal-clocked peer.
630        const MAX_CLOCK_SKEW_SECS: u64 = 60;
631        if created_at > now && created_at - now > MAX_CLOCK_SKEW_SECS {
632            return Err(RpcError::runtime_error(
633                "State token timestamp is implausibly in the future",
634            ));
635        }
636    }
637
638    let mut pos = 8;
639    let state_bytes = read_segment(&plaintext, &mut pos)?;
640    let output_schema_bytes = read_segment(&plaintext, &mut pos)?;
641    let input_schema_bytes = read_segment(&plaintext, &mut pos)?;
642    let stream_id_bytes = read_segment(&plaintext, &mut pos)?;
643    if pos != plaintext.len() {
644        return Err(RpcError::runtime_error("Malformed state token"));
645    }
646    let stream_id = String::from_utf8(stream_id_bytes)
647        .map_err(|_| RpcError::runtime_error("Malformed state token"))?;
648
649    Ok(UnpackedToken {
650        state_bytes,
651        output_schema_bytes,
652        input_schema_bytes,
653        stream_id,
654        created_at,
655    })
656}
657
658fn read_segment(buf: &[u8], pos: &mut usize) -> Result<Vec<u8>> {
659    if *pos + 4 > buf.len() {
660        return Err(RpcError::runtime_error("Malformed state token"));
661    }
662    let len = u32::from_le_bytes(buf[*pos..*pos + 4].try_into().unwrap()) as usize;
663    *pos += 4;
664    if *pos + len > buf.len() {
665        return Err(RpcError::runtime_error("Malformed state token"));
666    }
667    let out = buf[*pos..*pos + len].to_vec();
668    *pos += len;
669    Ok(out)
670}
671
672/// Serialize an Arrow schema into transportable bytes — wraps it in a
673/// zero-row IPC stream since the stock writer doesn't expose a raw
674/// `Schema.serialize()` path. Round-trip via [`read_schema_bytes`].
675fn write_schema_bytes(schema: &Schema) -> Result<Vec<u8>> {
676    let empty = empty_batch(schema)?;
677    crate::wire::write_one_batch(&empty, None)
678}
679
680/// Inverse of [`write_schema_bytes`].
681fn read_schema_bytes(bytes: &[u8]) -> Result<SchemaRef> {
682    let r = StreamReader::new(bytes)?;
683    Ok(r.schema())
684}
685
686/// A future that resolves when the process receives SIGTERM or SIGINT
687/// (or a Ctrl-C event on non-Unix). Pass to [`axum::serve`]'s
688/// `with_graceful_shutdown` to stop accepting new connections, drain
689/// in-flight requests, and exit cleanly.
690///
691/// ```no_run
692/// # async fn run(state: std::sync::Arc<vgi_rpc::http::HttpState>) {
693/// let app = vgi_rpc::http::build_router(state);
694/// let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap();
695/// axum::serve(listener, app)
696///     .with_graceful_shutdown(vgi_rpc::http::shutdown_signal())
697///     .await
698///     .unwrap();
699/// # }
700/// ```
701pub async fn shutdown_signal() {
702    #[cfg(unix)]
703    {
704        use tokio::signal::unix::{signal, SignalKind};
705        let mut term = match signal(SignalKind::terminate()) {
706            Ok(s) => s,
707            Err(_) => {
708                let _ = tokio::signal::ctrl_c().await;
709                return;
710            }
711        };
712        let mut intr = match signal(SignalKind::interrupt()) {
713            Ok(s) => s,
714            Err(_) => {
715                let _ = tokio::signal::ctrl_c().await;
716                return;
717            }
718        };
719        tokio::select! {
720            _ = term.recv() => {},
721            _ = intr.recv() => {},
722        }
723    }
724    #[cfg(not(unix))]
725    {
726        let _ = tokio::signal::ctrl_c().await;
727    }
728}
729
730/// Serve `state` on `listener`, terminating cleanly on SIGTERM/SIGINT.
731/// Convenience wrapper around [`build_router`] +
732/// [`axum::serve`] + [`shutdown_signal`].
733pub async fn serve_with_shutdown(
734    state: Arc<HttpState>,
735    listener: tokio::net::TcpListener,
736) -> std::io::Result<()> {
737    let app = build_router(state);
738    axum::serve(listener, app)
739        .with_graceful_shutdown(shutdown_signal())
740        .await
741}
742
/// Absolute ceiling on a buffered HTTP response body in the
/// post-processing middleware. Intended to mirror
/// `wire::MAX_IPC_MESSAGE_BYTES`. It is large enough for any reasonable
/// Arrow batch and small enough that the middleware can never be driven
/// to exhaust the heap.
///
/// This is **distinct** from the operator's `max_response_bytes`. That
/// cap is hard for unary and stream-exchange responses, but *soft* for
/// producer streams: a producer may overshoot it by one batch before
/// minting a continuation token. A producer response body can therefore
/// legitimately exceed `max_response_bytes`. To leave room for that, the
/// middleware caps at `max(this, 2 × max_response_bytes)`; see
/// [`response_buffer_ceiling`].
const MAX_RESPONSE_BYTES_HARD_CAP: usize = 256 * 1024 * 1024;
755
756/// Hard ceiling the post-processing middleware buffers a response under.
757/// Always at least [`MAX_RESPONSE_BYTES_HARD_CAP`]; when a (soft)
758/// `max_response_bytes` is configured, it leaves headroom for the
759/// one-batch producer overshoot that the continuation-token design
760/// permits.
761fn response_buffer_ceiling(state: &HttpState) -> usize {
762    match state.max_response_bytes {
763        Some(soft) => MAX_RESPONSE_BYTES_HARD_CAP.max(soft.saturating_mul(2)),
764        None => MAX_RESPONSE_BYTES_HARD_CAP,
765    }
766}
767
768pub fn build_router(state: Arc<HttpState>) -> Router {
769    let body_limit = state.max_body_size;
770    let request_timeout = state.request_timeout;
771    build_router_inner(state.clone())
772        .layer(axum::middleware::from_fn_with_state(
773            state,
774            postprocess_middleware,
775        ))
776        // Hard ceiling on the raw request body, enforced regardless of
777        // the `Content-Length` header (chunked uploads included).
778        .layer(tower_http::limit::RequestBodyLimitLayer::new(body_limit))
779        // Wall-clock ceiling per request so a stalled handler or a
780        // slow-loris client can't pin a runtime worker forever.
781        .layer(tower_http::timeout::TimeoutLayer::with_status_code(
782            StatusCode::REQUEST_TIMEOUT,
783            request_timeout,
784        ))
785        // Convert a panic in handler code into a 500 instead of a
786        // dropped connection. The `unwrap`s on the HTTP hot path are on
787        // infallible `Vec` writers and server-controlled schemas, but
788        // this is the defence-in-depth net for any that slip through.
789        .layer(tower_http::catch_panic::CatchPanicLayer::new())
790}
791
792async fn postprocess_middleware(
793    axum::extract::State(state): axum::extract::State<Arc<HttpState>>,
794    req: axum::http::Request<axum::body::Body>,
795    next: axum::middleware::Next,
796) -> Response {
797    use axum::body::to_bytes;
798    // Bind the server to the HTTP transport on first request. Idempotent
799    // for the (kind, caps) pair so calling it per-request is cheap;
800    // fork-safe for pre-fork deployments because each child fires once.
801    state.server.notify_transport(
802        crate::transport::TransportKind::Http,
803        crate::transport::TransportCapabilities::none(),
804    );
805    let req_headers = req.headers().clone();
806    let req_method = req.method().clone();
807    let req_path = req.uri().path().to_string();
808
809    // Enforce server-advertised max_request_bytes before invoking the
810    // handler. Exempt the upload-URL flow itself and `/health` so
811    // clients can still discover capabilities and request URLs even if
812    // their next request would exceed the limit.
813    if let Some(limit) = state.max_request_bytes {
814        let exempt = req_path.ends_with("/__upload_url__/init")
815            || req_path.contains("/__upload_url__/")
816            || req_path == "/health"
817            || req_path.ends_with("/health");
818        if !exempt {
819            if let Some(cl) = req
820                .headers()
821                .get(header::CONTENT_LENGTH)
822                .and_then(|v| v.to_str().ok())
823                .and_then(|s| s.parse::<usize>().ok())
824            {
825                if cl > limit {
826                    let mut h = HeaderMap::new();
827                    attach_capability_headers(&state, &mut h, &req_method);
828                    attach_cors_headers(&state, &mut h, &req_headers, false);
829                    return (
830                        StatusCode::PAYLOAD_TOO_LARGE,
831                        h,
832                        format!(
833                            "Request body of {cl} bytes exceeds advertised \
834                             max_request_bytes={limit}. Use the upload-URL \
835                             flow (__upload_url__/init) to externalize."
836                        ),
837                    )
838                        .into_response();
839                }
840            }
841        }
842    }
843
844    let resp = next.run(req).await;
845    let (mut parts, body) = resp.into_parts();
846    // Buffer the response under a hard ceiling. `usize::MAX` here let a
847    // large handler response exhaust the heap; on overflow fail loud
848    // with a 500 rather than `unwrap_or_default()` silently shipping an
849    // empty 200. The ceiling is *not* `max_response_bytes` (a soft
850    // producer-side cap the wire may legitimately overshoot) — see
851    // `response_buffer_ceiling`. Externalised payloads leave only tiny
852    // pointer batches on the wire, so they never approach this bound.
853    let response_limit = response_buffer_ceiling(&state);
854    let bytes = match to_bytes(body, response_limit).await {
855        Ok(b) => b,
856        Err(_) => {
857            let mut h = HeaderMap::new();
858            attach_cors_headers(&state, &mut h, &req_headers, false);
859            return (
860                StatusCode::INTERNAL_SERVER_ERROR,
861                h,
862                "response body exceeded the configured size limit",
863            )
864                .into_response();
865        }
866    };
867    let is_arrow = parts
868        .headers
869        .get(header::CONTENT_TYPE)
870        .and_then(|v| v.to_str().ok())
871        == Some(ARROW_CONTENT_TYPE);
872
873    if is_arrow {
874        if let Some(level) = state.response_compression_level {
875            let accepts = req_headers
876                .get(header::ACCEPT_ENCODING)
877                .and_then(|v| v.to_str().ok())
878                .unwrap_or("");
879            if accepts.contains("zstd") {
880                if let Ok(compressed) = zstd::encode_all(std::io::Cursor::new(&bytes), level) {
881                    parts
882                        .headers
883                        .insert(header::CONTENT_ENCODING, HeaderValue::from_static("zstd"));
884                    attach_cors_headers(&state, &mut parts.headers, &req_headers, false);
885                    let body_new = axum::body::Body::from(compressed);
886                    return Response::from_parts(parts, body_new);
887                }
888            }
889        }
890    }
891    attach_cors_headers(&state, &mut parts.headers, &req_headers, false);
892    attach_capability_headers(&state, &mut parts.headers, &req_method);
893    Response::from_parts(parts, axum::body::Body::from(bytes))
894}
895
896/// Attach `VGI-Max-Request-Bytes`, `VGI-Upload-URL-Support`,
897/// `VGI-Max-Upload-Bytes` capability headers when configured. On
898/// `OPTIONS` responses also stamp `Cache-Control: public, max-age=300`
899/// so clients cache discovery results, mirroring the Python
900/// `_CapabilitiesMiddleware`.
901fn attach_capability_headers(
902    state: &Arc<HttpState>,
903    out: &mut HeaderMap,
904    method: &axum::http::Method,
905) {
906    let mut any = false;
907    if let Some(n) = state.max_request_bytes {
908        if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
909            out.insert("vgi-max-request-bytes", v);
910            any = true;
911        }
912    }
913    if let Some(n) = state.max_response_bytes {
914        if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
915            out.insert("vgi-max-response-bytes", v);
916            any = true;
917        }
918    }
919    if let Some(n) = state.max_externalized_response_bytes {
920        if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
921            out.insert("vgi-max-externalized-response-bytes", v);
922            any = true;
923        }
924    }
925    // Always present so capability-aware clients can decide whether to
926    // expect externalised payloads.
927    out.insert(
928        "vgi-externalization-enabled",
929        HeaderValue::from_static(if state.server.external_config().is_some() {
930            "true"
931        } else {
932            "false"
933        }),
934    );
935    if state.upload_url_provider.is_some() {
936        out.insert("vgi-upload-url-support", HeaderValue::from_static("true"));
937        any = true;
938        if let Some(n) = state.max_upload_bytes {
939            if let Ok(v) = HeaderValue::from_str(&n.to_string()) {
940                out.insert("vgi-max-upload-bytes", v);
941            }
942        }
943    }
944    // Sticky-session capabilities. Always emit the enabled flag (negative
945    // form when off) so capability discovery is unambiguous.
946    if let Some(sticky) = state.sticky.as_ref() {
947        out.insert(STICKY_ENABLED_HEADER, HeaderValue::from_static("true"));
948        any = true;
949        if let Ok(v) = HeaderValue::from_str(&sticky.default_ttl.as_secs().to_string()) {
950            out.insert(STICKY_DEFAULT_TTL_HEADER, v);
951        }
952        if !sticky.echo_headers.is_empty() {
953            let names = sticky
954                .echo_headers
955                .iter()
956                .map(|(n, _)| n.as_str())
957                .collect::<Vec<_>>()
958                .join(",");
959            if let Ok(v) = HeaderValue::from_str(&names) {
960                out.insert(STICKY_ECHO_HEADERS_HEADER, v);
961            }
962        }
963    } else {
964        out.insert(STICKY_ENABLED_HEADER, HeaderValue::from_static("false"));
965    }
966    if any && method == axum::http::Method::OPTIONS {
967        out.insert(
968            header::CACHE_CONTROL,
969            HeaderValue::from_static("public, max-age=300"),
970        );
971    }
972}
973
974fn build_router_inner(state: Arc<HttpState>) -> Router {
975    let prefix = state.prefix.clone();
976    let api = Router::new()
977        .route("/:method", post(handle_unary).options(handle_preflight))
978        .route(
979            "/:method/init",
980            post(handle_stream_init).options(handle_preflight),
981        )
982        .route(
983            "/:method/exchange",
984            post(handle_stream_exchange).options(handle_preflight),
985        );
986
987    let api = if state.upload_url_provider.is_some() {
988        api.route(
989            "/__upload_url__/init",
990            post(handle_upload_url).options(handle_preflight),
991        )
992    } else {
993        api
994    };
995
996    let api = if state.sticky.is_some() {
997        api.route(
998            "/__session__",
999            axum::routing::delete(handle_delete_session).options(handle_preflight),
1000        )
1001    } else {
1002        api
1003    };
1004
1005    let mut app = if prefix.is_empty() {
1006        api
1007    } else {
1008        Router::new().nest(&prefix, api)
1009    };
1010
1011    app = app.route(
1012        &format!(
1013            "{}{}",
1014            prefix,
1015            crate::auth::oauth::OAuthResourceMetadata::well_known_path()
1016        ),
1017        axum::routing::get(handle_oauth_metadata),
1018    );
1019
1020    if state.health_enabled {
1021        // Always mount `/health` at the absolute root, regardless of
1022        // the API prefix. Liveness probes / load-balancer health
1023        // checks should never have to know which URL prefix the API
1024        // is under, and the conformance suite verifies it bypasses
1025        // auth even when every RPC endpoint requires it.
1026        app = app.route(
1027            "/health",
1028            axum::routing::get(handle_health).options(handle_preflight),
1029        );
1030    }
1031    if state.landing_page_enabled {
1032        let landing_path = if prefix.is_empty() {
1033            "/".to_string()
1034        } else {
1035            prefix.clone()
1036        };
1037        app = app.route(&landing_path, axum::routing::get(handle_landing));
1038    }
1039    if state.describe_page_enabled {
1040        app = app.route(
1041            &format!("{prefix}/describe"),
1042            axum::routing::get(handle_describe_page),
1043        );
1044    }
1045
1046    app.with_state(state)
1047}
1048
1049/// `DELETE {prefix}/__session__` — idempotent best-effort session teardown.
1050/// Token absent / stale / forged / wrong-principal ⇒ 200 (no info leak);
1051/// a live session ⇒ close it, emit `VGI-Session-Close: true`, 204.
1052async fn handle_delete_session(
1053    State(state): State<Arc<HttpState>>,
1054    headers: HeaderMap,
1055) -> Response {
1056    let auth = match authenticate_request(&state, SESSION_ENDPOINT, &headers) {
1057        Ok(a) => a,
1058        Err(resp) => return resp,
1059    };
1060    let Some(ctx) = state.sticky.as_ref() else {
1061        return StatusCode::OK.into_response();
1062    };
1063    let session_header = headers.get(SESSION_HEADER).and_then(|v| v.to_str().ok());
1064    match crate::sticky::handle_delete(ctx, &auth, session_header) {
1065        crate::sticky::DeleteOutcome::Idempotent => StatusCode::OK.into_response(),
1066        crate::sticky::DeleteOutcome::Closed => {
1067            let mut h = HeaderMap::new();
1068            h.insert(SESSION_CLOSE_HEADER, HeaderValue::from_static("true"));
1069            (StatusCode::NO_CONTENT, h).into_response()
1070        }
1071    }
1072}
1073
1074async fn handle_preflight(State(state): State<Arc<HttpState>>, headers: HeaderMap) -> Response {
1075    let mut h = HeaderMap::new();
1076    attach_cors_headers(&state, &mut h, &headers, true);
1077    (StatusCode::NO_CONTENT, h).into_response()
1078}
1079
1080async fn handle_health(State(state): State<Arc<HttpState>>) -> Response {
1081    let body = serde_json::json!({
1082        "status": "ok",
1083        "server_id": state.server.server_id,
1084        "protocol": state.server.protocol_name(),
1085    })
1086    .to_string();
1087    let mut h = HeaderMap::new();
1088    h.insert(
1089        header::CONTENT_TYPE,
1090        HeaderValue::from_static("application/json"),
1091    );
1092    (StatusCode::OK, h, body).into_response()
1093}
1094
1095async fn handle_landing(State(state): State<Arc<HttpState>>) -> Response {
1096    let body = render_landing(&state);
1097    let mut h = HeaderMap::new();
1098    h.insert(
1099        header::CONTENT_TYPE,
1100        HeaderValue::from_static("text/html; charset=utf-8"),
1101    );
1102    (StatusCode::OK, h, body).into_response()
1103}
1104
1105async fn handle_describe_page(State(state): State<Arc<HttpState>>) -> Response {
1106    let body = render_describe_page(&state);
1107    let mut h = HeaderMap::new();
1108    h.insert(
1109        header::CONTENT_TYPE,
1110        HeaderValue::from_static("text/html; charset=utf-8"),
1111    );
1112    (StatusCode::OK, h, body).into_response()
1113}
1114
1115fn render_landing(state: &Arc<HttpState>) -> String {
1116    let name = if state.server.protocol_name().is_empty() {
1117        "vgi-rpc service"
1118    } else {
1119        state.server.protocol_name()
1120    };
1121    let server_id = &state.server.server_id;
1122    let describe_link = if state.describe_page_enabled {
1123        format!(
1124            r#"<p><a href="{0}/describe">API reference</a></p>"#,
1125            state.prefix
1126        )
1127    } else {
1128        String::new()
1129    };
1130    format!(
1131        "<!doctype html><html><head><meta charset=\"utf-8\"><title>{name}</title></head><body>\
1132         <h1>{name}</h1><p>server_id: <code>{server_id}</code></p>{describe_link}\
1133         </body></html>"
1134    )
1135}
1136
1137fn render_describe_page(state: &Arc<HttpState>) -> String {
1138    let mut body = String::from(
1139        "<!doctype html><html><head><meta charset=\"utf-8\"><title>API reference</title></head><body>",
1140    );
1141    body.push_str(&format!(
1142        "<h1>{}</h1><table><tr><th>method</th><th>type</th><th>doc</th></tr>",
1143        state.server.protocol_name()
1144    ));
1145    for name in state.server.sorted_method_names() {
1146        let m = &state.server.methods()[name];
1147        let kind = match m.method_type {
1148            crate::server::MethodType::Unary => "unary",
1149            _ => "stream",
1150        };
1151        let doc = m.doc.as_deref().unwrap_or("");
1152        body.push_str(&format!(
1153            "<tr><td><code>{name}</code></td><td>{kind}</td><td>{}</td></tr>",
1154            html_escape(doc)
1155        ));
1156    }
1157    body.push_str("</table></body></html>");
1158    body
1159}
1160
/// Escape `s` for safe interpolation into HTML element content *and* quoted
/// attribute values.
///
/// Replaces `&`, `<`, `>`, `"` and `'` with their character references in a
/// single pass. Quotes must be escaped too: without that, a value placed
/// inside `href="..."` could close the attribute.
fn html_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            other => out.push(other),
        }
    }
    out
}
1166
1167fn attach_cors_headers(
1168    state: &Arc<HttpState>,
1169    out: &mut HeaderMap,
1170    req_headers: &HeaderMap,
1171    is_preflight: bool,
1172) {
1173    let Some(origins) = state.cors_origins.as_deref() else {
1174        return;
1175    };
1176    if let Ok(v) = HeaderValue::from_str(origins) {
1177        out.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, v);
1178    }
1179    // When an authenticate callback is configured, requests carry
1180    // credentials (cookie / bearer). A browser only honours those
1181    // cross-origin when `Allow-Credentials: true` is present *and* the
1182    // origin is specific — `HttpStateBuilder::build` rejects the
1183    // `"*"` + auth combination, so the configured origin here is
1184    // always concrete.
1185    if state.authenticate.is_some() {
1186        out.insert(
1187            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
1188            HeaderValue::from_static("true"),
1189        );
1190    }
1191    out.insert(
1192        header::ACCESS_CONTROL_ALLOW_METHODS,
1193        HeaderValue::from_static("POST, GET, OPTIONS"),
1194    );
1195    let requested = req_headers
1196        .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
1197        .and_then(|v| v.to_str().ok())
1198        .unwrap_or("Content-Type, Authorization, Cookie, Accept-Encoding");
1199    if let Ok(v) = HeaderValue::from_str(requested) {
1200        out.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, v);
1201    }
1202    out.insert(
1203        header::ACCESS_CONTROL_EXPOSE_HEADERS,
1204        HeaderValue::from_static("Content-Encoding, WWW-Authenticate"),
1205    );
1206    if is_preflight {
1207        if let Ok(v) = HeaderValue::from_str(&state.cors_max_age.to_string()) {
1208            out.insert(header::ACCESS_CONTROL_MAX_AGE, v);
1209        }
1210    }
1211}
1212
1213async fn handle_oauth_metadata(State(state): State<Arc<HttpState>>) -> Response {
1214    match state.oauth_metadata_json.as_ref() {
1215        Some(body) => {
1216            let mut h = HeaderMap::new();
1217            h.insert(
1218                header::CONTENT_TYPE,
1219                HeaderValue::from_static("application/json"),
1220            );
1221            h.insert(
1222                header::CACHE_CONTROL,
1223                HeaderValue::from_static("public, max-age=60"),
1224            );
1225            (StatusCode::OK, h, body.clone()).into_response()
1226        }
1227        None => (StatusCode::NOT_FOUND, "").into_response(),
1228    }
1229}
1230
/// Split a raw `Cookie:` header value into a name → value map.
///
/// Segments are separated by `;`, and each is split at its first `=`; a
/// segment with no `=` is ignored. Names and values are trimmed, and a value
/// wrapped in double quotes (RFC 6265 quoted form) is unwrapped. If a name
/// repeats, the last occurrence wins. `None` yields an empty map.
fn parse_cookies(raw: Option<&str>) -> std::collections::BTreeMap<String, String> {
    let mut jar = std::collections::BTreeMap::new();
    if let Some(header_value) = raw {
        jar.extend(header_value.split(';').filter_map(|segment| {
            let (name, value) = segment.trim().split_once('=')?;
            let value = value.trim();
            let unquoted = match value.strip_prefix('"').and_then(|v| v.strip_suffix('"')) {
                Some(inner) => inner,
                None => value,
            };
            Some((name.trim().to_string(), unquoted.to_string()))
        }));
    }
    jar
}
1249
1250/// Copy the request headers into a `Vec<(String, String)>` for AuthRequest.
1251fn headers_to_pairs(headers: &HeaderMap) -> Vec<(String, String)> {
1252    headers
1253        .iter()
1254        .filter_map(|(k, v)| {
1255            v.to_str()
1256                .ok()
1257                .map(|s| (k.as_str().to_string(), s.to_string()))
1258        })
1259        .collect()
1260}
1261
1262/// Run the authenticate callback (if any); on error, build a 401 response
1263/// with WWW-Authenticate attached.
1264fn authenticate_request(
1265    state: &Arc<HttpState>,
1266    method: &str,
1267    headers: &HeaderMap,
1268) -> std::result::Result<crate::auth::AuthContext, Response> {
1269    let Some(cb) = state.authenticate.as_ref() else {
1270        return Ok(crate::auth::AuthContext::anonymous());
1271    };
1272    let pairs = headers_to_pairs(headers);
1273    let req = crate::auth::AuthRequest {
1274        method,
1275        headers: &pairs,
1276        peer_addr: None,
1277    };
1278    match (cb)(&req) {
1279        Ok(ctx) => Ok(ctx),
1280        Err(err) => {
1281            let status = match err.error_type.as_str() {
1282                "PermissionError" | "ValueError" => StatusCode::UNAUTHORIZED,
1283                _ => StatusCode::INTERNAL_SERVER_ERROR,
1284            };
1285            let mut h = HeaderMap::new();
1286            // The response body must not echo internal detail. A 401 is
1287            // part of the auth contract, but the verifier's message can
1288            // carry an attacker-supplied `kid` or a raw library error —
1289            // keep that in the logs, return a generic body. A 500 from
1290            // the auth callback (e.g. a JWKS fetch failure) is purely
1291            // internal and gets the same treatment.
1292            let body = if status == StatusCode::UNAUTHORIZED {
1293                if let Some(wa) = state.www_authenticate.as_deref() {
1294                    if let Ok(hv) = HeaderValue::from_str(wa) {
1295                        h.insert(header::WWW_AUTHENTICATE, hv);
1296                    }
1297                }
1298                tracing::info!(
1299                    target: "vgi_rpc.http",
1300                    error = %err.message,
1301                    "request authentication rejected"
1302                );
1303                "authentication failed"
1304            } else {
1305                tracing::error!(
1306                    target: "vgi_rpc.http",
1307                    error = %err.message,
1308                    "authentication callback errored"
1309                );
1310                "internal error during authentication"
1311            };
1312            Err((status, h, body).into_response())
1313        }
1314    }
1315}
1316
1317// ---------------------------------------------------------------------------
1318// Helpers
1319// ---------------------------------------------------------------------------
1320
1321fn arrow_response(status: StatusCode, body: Vec<u8>) -> Response {
1322    let mut headers = HeaderMap::new();
1323    headers.insert(
1324        header::CONTENT_TYPE,
1325        HeaderValue::from_static(ARROW_CONTENT_TYPE),
1326    );
1327    (status, headers, body).into_response()
1328}
1329
1330/// Hard wire-cap enforcement helper for stream-exchange responses.
1331/// Returns the original 200 response when within budget; otherwise
1332/// rebuilds the response as an EXCEPTION-only IPC stream surfaced via
1333/// 200 + `X-VGI-RPC-Error: true`.
1334fn enforce_response_body_cap(
1335    state: &Arc<HttpState>,
1336    schema: &arrow_schema::Schema,
1337    body: Vec<u8>,
1338    method: &str,
1339    server_id: &str,
1340    request_id: &str,
1341) -> Response {
1342    if let Some(limit) = state.max_response_bytes {
1343        if body.len() > limit {
1344            let err = RpcError::runtime_error(format!(
1345                "HTTP body exceeds max_response_bytes ({} > {}) for method {:?}",
1346                body.len(),
1347                limit,
1348                method
1349            ));
1350            return cap_error_response(schema, &err, server_id, request_id);
1351        }
1352    }
1353    arrow_response(StatusCode::OK, body)
1354}
1355
1356/// Build a fresh IPC stream containing only an EXCEPTION batch and emit
1357/// it as 200 + `X-VGI-RPC-Error: true` so RPC clients see the message
1358/// as `RpcError`, not a transport failure. Used by the response-cap
1359/// strict-fail path.
1360fn cap_error_response(
1361    schema: &arrow_schema::Schema,
1362    err: &RpcError,
1363    server_id: &str,
1364    request_id: &str,
1365) -> Response {
1366    let mut buf = Vec::new();
1367    {
1368        let mut sw = StreamWriter::new(&mut buf, schema).unwrap();
1369        let md = build_error_metadata(err, server_id, request_id);
1370        let _ = sw.write(&empty_batch(schema).unwrap(), Some(&md));
1371        let _ = sw.finish();
1372    }
1373    let mut headers = HeaderMap::new();
1374    headers.insert(
1375        header::CONTENT_TYPE,
1376        HeaderValue::from_static(ARROW_CONTENT_TYPE),
1377    );
1378    headers.insert("x-vgi-rpc-error", HeaderValue::from_static("true"));
1379    (StatusCode::OK, headers, buf).into_response()
1380}
1381
1382fn plain_error(status: StatusCode, msg: String) -> Response {
1383    (status, msg).into_response()
1384}
1385
1386fn has_arrow_ct(headers: &HeaderMap) -> bool {
1387    headers
1388        .get(header::CONTENT_TYPE)
1389        .and_then(|v| v.to_str().ok())
1390        .map(|s| s == ARROW_CONTENT_TYPE)
1391        .unwrap_or(false)
1392}
1393
1394fn maybe_decompress(headers: &HeaderMap, body: &Bytes, max_size: usize) -> Result<Vec<u8>> {
1395    let enc = headers
1396        .get(header::CONTENT_ENCODING)
1397        .and_then(|v| v.to_str().ok())
1398        .unwrap_or("");
1399    if body.len() > max_size {
1400        return Err(RpcError::runtime_error(format!(
1401            "Request body exceeds max size ({} bytes > {})",
1402            body.len(),
1403            max_size
1404        )));
1405    }
1406    if enc.eq_ignore_ascii_case("zstd") {
1407        decode_zstd_bounded(body.as_ref(), max_size)
1408    } else {
1409        Ok(body.to_vec())
1410    }
1411}
1412
1413/// Stream-decode a zstd payload, aborting once the decompressed length
1414/// exceeds `max_size`. Defends against zip-bomb-style oversized payloads
1415/// without first allocating the full decompressed result.
1416fn decode_zstd_bounded(input: &[u8], max_size: usize) -> Result<Vec<u8>> {
1417    use std::io::Read;
1418    let mut decoder = zstd::Decoder::new(input)
1419        .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
1420    let mut out = Vec::with_capacity(input.len().min(max_size).min(64 * 1024));
1421    let mut buf = [0u8; 16 * 1024];
1422    loop {
1423        let n = decoder
1424            .read(&mut buf)
1425            .map_err(|e| RpcError::runtime_error(format!("zstd decode: {e}")))?;
1426        if n == 0 {
1427            break;
1428        }
1429        if out.len() + n > max_size {
1430            return Err(RpcError::runtime_error(format!(
1431                "Decompressed body exceeds max size ({}+ bytes > {})",
1432                out.len() + n,
1433                max_size
1434            )));
1435        }
1436        out.extend_from_slice(&buf[..n]);
1437    }
1438    Ok(out)
1439}
1440
1441fn parse_request_from_body(body: &[u8]) -> Result<Request> {
1442    let mut r = StreamReader::new(body)?;
1443    let (batch, metadata) = r
1444        .read_next()?
1445        .ok_or_else(|| RpcError::protocol_error("empty IPC stream"))?;
1446    r.drain()?;
1447    Request::from_read_batch(batch, metadata, false)
1448}
1449
1450fn error_stream_bytes(
1451    schema: &Schema,
1452    err: &RpcError,
1453    server_id: &str,
1454    request_id: &str,
1455) -> Vec<u8> {
1456    let mut buf = Vec::new();
1457    let mut w = StreamWriter::new(&mut buf, schema).unwrap();
1458    let md = build_error_metadata(err, server_id, request_id);
1459    let _ = w.write(&empty_batch(schema).unwrap(), Some(&md));
1460    let _ = w.finish();
1461    drop(w);
1462    buf
1463}
1464
1465/// Build a complete arrow-typed error response. Centralizes the
1466/// `arrow_response(status, error_stream_bytes(Schema::empty(), ...))`
1467/// pattern used by every error-returning branch of the HTTP handlers.
1468fn arrow_error(
1469    state: &Arc<HttpState>,
1470    status: StatusCode,
1471    err: &RpcError,
1472    request_id: &str,
1473) -> Response {
1474    arrow_response(
1475        status,
1476        error_stream_bytes(&Schema::empty(), err, &state.server.server_id, request_id),
1477    )
1478}
1479
1480/// Resolve sticky headers on an incoming request. `Ok(Some(sink))` to
1481/// install on the [`CallContext`]; `Ok(None)` when sticky is disabled;
1482/// `Err(resp)` to short-circuit with a `SessionLostError` response when a
1483/// presented token failed to resolve.
1484fn sticky_for_request(
1485    state: &Arc<HttpState>,
1486    auth: &crate::auth::AuthContext,
1487    headers: &HeaderMap,
1488) -> std::result::Result<Option<Arc<crate::sticky::StickySinkImpl>>, Response> {
1489    let Some(ctx) = state.sticky.as_ref() else {
1490        return Ok(None);
1491    };
1492    let accept = headers
1493        .get(SESSION_ACCEPT_HEADER)
1494        .and_then(|v| v.to_str().ok())
1495        .map(|s| s.trim().eq_ignore_ascii_case("true"))
1496        .unwrap_or(false);
1497    let session_header = headers.get(SESSION_HEADER).and_then(|v| v.to_str().ok());
1498    match crate::sticky::resolve(ctx, auth, accept, session_header) {
1499        crate::sticky::StickyResolution::Sink(s) => Ok(Some(s)),
1500        crate::sticky::StickyResolution::Lost(err) => Err(cap_error_response(
1501            &Schema::empty(),
1502            &err,
1503            &state.server.server_id,
1504            "",
1505        )),
1506    }
1507}
1508
1509/// Stamp `VGI-Session` (+ echo headers) and `VGI-Session-Close` onto a
1510/// response according to the per-request sink's mint/close signals.
1511fn stamp_session_headers(
1512    resp: &mut Response,
1513    state: &Arc<HttpState>,
1514    sink: &Arc<crate::sticky::StickySinkImpl>,
1515) {
1516    let headers = resp.headers_mut();
1517    if let Some(token) = sink.mint_token() {
1518        if let Ok(v) = HeaderValue::from_str(&token) {
1519            headers.insert(SESSION_HEADER, v);
1520        }
1521        if let Some(ctx) = state.sticky.as_ref() {
1522            for (name, value) in &ctx.echo_headers {
1523                let full = format!("{ECHO_HEADER_PREFIX}{name}");
1524                if let (Ok(n), Ok(v)) = (
1525                    axum::http::HeaderName::from_bytes(full.as_bytes()),
1526                    HeaderValue::from_str(value),
1527                ) {
1528                    headers.insert(n, v);
1529                }
1530            }
1531        }
1532    }
1533    if sink.was_closed() {
1534        headers.insert(SESSION_CLOSE_HEADER, HeaderValue::from_static("true"));
1535    }
1536}
1537
1538fn decode_hex_key(s: &str) -> std::result::Result<Vec<u8>, String> {
1539    let s = s.trim();
1540    if s.len() % 2 != 0 {
1541        return Err("hex length must be even".into());
1542    }
1543    let mut out = Vec::with_capacity(s.len() / 2);
1544    let bytes = s.as_bytes();
1545    for pair in bytes.chunks_exact(2) {
1546        let hi = hex_nibble(pair[0])?;
1547        let lo = hex_nibble(pair[1])?;
1548        out.push((hi << 4) | lo);
1549    }
1550    if out.len() < 32 {
1551        return Err(format!(
1552            "signing key must be ≥ 32 bytes (got {} bytes)",
1553            out.len()
1554        ));
1555    }
1556    Ok(out)
1557}
1558
/// Value (0–15) of one ASCII hex digit, in either case; any other byte is an
/// error naming the offending character.
fn hex_nibble(c: u8) -> std::result::Result<u8, String> {
    // `to_digit(16)` accepts exactly `0-9`, `a-f` and `A-F`, and the result
    // is below 16, so narrowing it to `u8` is lossless.
    (c as char)
        .to_digit(16)
        .map(|digit| digit as u8)
        .ok_or_else(|| format!("invalid hex character: {:?}", c as char))
}
1567
1568fn decode_base64_key(s: &str) -> std::result::Result<Vec<u8>, String> {
1569    // Accept both padded and unpadded standard base64.
1570    let s = s.trim().trim_end_matches('=');
1571    let mut padded = s.to_string();
1572    while padded.len() % 4 != 0 {
1573        padded.push('=');
1574    }
1575    let bytes = base64::engine::general_purpose::STANDARD
1576        .decode(padded.as_bytes())
1577        .map_err(|e| format!("base64 decode: {e}"))?;
1578    if bytes.len() < 32 {
1579        return Err(format!(
1580            "signing key must be ≥ 32 bytes (got {} bytes)",
1581            bytes.len()
1582        ));
1583    }
1584    Ok(bytes)
1585}
1586
1587fn new_session_id() -> String {
1588    let mut b = [0u8; 16];
1589    rand::thread_rng().fill_bytes(&mut b);
1590    bytes_to_hex(&b)
1591}
1592
1593// ---------------------------------------------------------------------------
1594// Upload-URL endpoint
1595// ---------------------------------------------------------------------------
1596
/// Reserved method name for the upload-URL endpoint. It is the method name
/// passed to `authenticate_request`, and a request whose non-empty
/// `method` differs from it is rejected.
const UPLOAD_URL_METHOD: &str = "__upload_url__";
/// Upper bound on URL pairs generated per request; the requested `count`
/// is clamped to `[1, MAX_UPLOAD_URL_COUNT]`.
const MAX_UPLOAD_URL_COUNT: i64 = 100;
1599
1600fn upload_url_response_schema() -> Schema {
1601    use arrow_schema::{DataType, Field, TimeUnit};
1602    Schema::new(vec![
1603        Field::new("upload_url", DataType::Utf8, false),
1604        Field::new("download_url", DataType::Utf8, false),
1605        Field::new(
1606            "expires_at",
1607            DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
1608            false,
1609        ),
1610    ])
1611}
1612
1613async fn handle_upload_url(
1614    State(state): State<Arc<HttpState>>,
1615    headers: HeaderMap,
1616    body: Bytes,
1617) -> Response {
1618    let auth = match authenticate_request(&state, UPLOAD_URL_METHOD, &headers) {
1619        Ok(a) => a,
1620        Err(resp) => return resp,
1621    };
1622    let _ = auth;
1623    if !has_arrow_ct(&headers) {
1624        return plain_error(
1625            StatusCode::UNSUPPORTED_MEDIA_TYPE,
1626            "need arrow content type".into(),
1627        );
1628    }
1629    let provider = match state.upload_url_provider.as_ref() {
1630        Some(p) => p.clone(),
1631        None => return plain_error(StatusCode::NOT_FOUND, "upload-url not enabled".into()),
1632    };
1633
1634    let body = match maybe_decompress(&headers, &body, state.max_body_size) {
1635        Ok(b) => b,
1636        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1637    };
1638    let req = match parse_request_from_body(&body) {
1639        Ok(r) => r,
1640        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1641    };
1642    if !req.method.is_empty() && req.method != UPLOAD_URL_METHOD {
1643        let err = RpcError::protocol_error(format!(
1644            "Method mismatch: expected '{UPLOAD_URL_METHOD}', got '{}'",
1645            req.method
1646        ));
1647        return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &req.request_id);
1648    }
1649    // Pull `count` from the int64 column (default 1, clamped to [1, MAX]).
1650    let mut count: i64 = 1;
1651    if let Some(arr) = req.column("count") {
1652        use arrow_array::Array;
1653        if let Some(c) = arr.as_any().downcast_ref::<arrow_array::Int64Array>() {
1654            if !c.is_empty() && !Array::is_null(c, 0) {
1655                count = c.value(0);
1656            }
1657        }
1658    }
1659    count = count.clamp(1, MAX_UPLOAD_URL_COUNT);
1660
1661    // Generate URLs (provider may block on HTTP — same caveat as
1662    // ExternalStorage::upload, hence block_in_place).
1663    let urls_res = tokio::task::block_in_place(|| {
1664        let mut out = Vec::with_capacity(count as usize);
1665        for _ in 0..count {
1666            out.push(provider.generate_upload_url()?);
1667        }
1668        Ok::<_, RpcError>(out)
1669    });
1670
1671    let schema = upload_url_response_schema();
1672    let mut body_buf = Vec::new();
1673    {
1674        let mut sw = match StreamWriter::new(&mut body_buf, &schema) {
1675            Ok(w) => w,
1676            Err(e) => {
1677                return arrow_error(
1678                    &state,
1679                    StatusCode::INTERNAL_SERVER_ERROR,
1680                    &e,
1681                    &req.request_id,
1682                )
1683            }
1684        };
1685        match urls_res {
1686            Ok(urls) => {
1687                use arrow_array::{StringArray, TimestampMicrosecondArray};
1688                let upload_arr = StringArray::from(
1689                    urls.iter()
1690                        .map(|u| u.upload_url.clone())
1691                        .collect::<Vec<_>>(),
1692                );
1693                let download_arr = StringArray::from(
1694                    urls.iter()
1695                        .map(|u| u.download_url.clone())
1696                        .collect::<Vec<_>>(),
1697                );
1698                let expires_arr = TimestampMicrosecondArray::from(
1699                    urls.iter().map(|u| u.expires_at_micros).collect::<Vec<_>>(),
1700                )
1701                .with_timezone("UTC");
1702                let batch = match RecordBatch::try_new(
1703                    Arc::new(schema.clone()),
1704                    vec![
1705                        Arc::new(upload_arr),
1706                        Arc::new(download_arr),
1707                        Arc::new(expires_arr),
1708                    ],
1709                ) {
1710                    Ok(b) => b,
1711                    Err(e) => {
1712                        let err = RpcError::runtime_error(format!("upload-url batch: {e}"));
1713                        let md =
1714                            build_error_metadata(&err, &state.server.server_id, &req.request_id);
1715                        let _ = sw.write(&empty_batch(&schema).unwrap(), Some(&md));
1716                        let _ = sw.finish();
1717                        drop(sw);
1718                        return arrow_response(StatusCode::OK, body_buf);
1719                    }
1720                };
1721                let _ = sw.write(&batch, None);
1722            }
1723            Err(err) => {
1724                let md = build_error_metadata(&err, &state.server.server_id, &req.request_id);
1725                let _ = sw.write(&empty_batch(&schema).unwrap(), Some(&md));
1726            }
1727        }
1728        let _ = sw.finish();
1729    }
1730    arrow_response(StatusCode::OK, body_buf)
1731}
1732
1733// ---------------------------------------------------------------------------
1734// Unary
1735// ---------------------------------------------------------------------------
1736
/// Handle `POST /{method}` by dispatching a unary call and returning its
/// result as one Arrow IPC stream.
///
/// Order of rejections:
/// - authentication (401 before anything else);
/// - content type (415);
/// - sticky-session resolution;
/// - body decompression and request parsing (400).
///
/// A request batch carrying `vgi_rpc.location` metadata is replaced by the
/// referenced batch when the server has an external config. `__describe__`
/// is answered here when introspection is enabled. Unknown or non-unary
/// methods yield a 404 Arrow error.
///
/// The response stream contains the handler's logs (as zero-row batches
/// carrying log metadata). These are followed by either the result batch
/// (possibly externalized) or a zero-row error envelope, so application
/// errors are delivered with status 200.
///
/// The dispatch hook, when configured, is notified before and after the
/// handler runs. `max_response_bytes` is enforced afterwards, and
/// sticky-session headers are stamped on the final response.
async fn handle_unary(
    State(state): State<Arc<HttpState>>,
    Path(method): Path<String>,
    headers: HeaderMap,
    body: Bytes,
) -> Response {
    // Authenticate before any other rejection: an unauthenticated
    // caller should always see 401, regardless of whether they sent
    // the right content type or anything else.
    let auth = match authenticate_request(&state, &method, &headers) {
        Ok(a) => a,
        Err(resp) => return resp,
    };
    if !has_arrow_ct(&headers) {
        return plain_error(
            StatusCode::UNSUPPORTED_MEDIA_TYPE,
            "need arrow content type".into(),
        );
    }
    // Resolve sticky-session headers before dispatch. A presented token
    // that fails to resolve short-circuits with a SessionLostError stream.
    let sticky_sink = match sticky_for_request(&state, &auth, &headers) {
        Ok(s) => s,
        Err(resp) => return resp,
    };
    let cookies = parse_cookies(headers.get(header::COOKIE).and_then(|v| v.to_str().ok()));
    let server = state.server.clone();

    let body = match maybe_decompress(&headers, &body, state.max_body_size) {
        Ok(b) => b,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };
    let mut req = match parse_request_from_body(&body) {
        Ok(r) => r,
        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
    };

    // If the request batch is an external-location pointer (zero rows +
    // `vgi_rpc.location` metadata), fetch the referenced bytes and use
    // the inner batch's columns for parameter extraction. Dispatch
    // metadata (method, request_id) is taken from the outer batch.
    // NOTE(review): when the server has no external config, the pointer
    // batch itself is dispatched unchanged — confirm that is intended.
    if md_get(&req.metadata, crate::metadata::LOCATION_KEY).is_some() {
        if let Some(cfg) = server.external_config().as_ref() {
            let outer_md = req.metadata.clone();
            let outer_batch = req.batch.clone();
            // Resolution may block on I/O, hence block_in_place.
            let resolved = tokio::task::block_in_place(|| {
                crate::external::resolve_external_location(&outer_batch, &outer_md, cfg)
            });
            match resolved {
                Ok((inner_batch, _user_md)) => {
                    req.batch = inner_batch;
                }
                Err(err) => {
                    return arrow_error(&state, StatusCode::BAD_REQUEST, &err, &req.request_id);
                }
            }
        }
    }

    // __describe__ introspection — served as a unary call.
    if server.describe_enabled() && method == crate::introspect::DESCRIBE_METHOD_NAME {
        let (batch, md) = match crate::introspect::build_describe(
            server.protocol_name(),
            server.methods(),
            &server.server_id,
            server.protocol_version(),
        ) {
            Ok(x) => x,
            Err(err) => {
                return arrow_error(
                    &state,
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &err,
                    &req.request_id,
                );
            }
        };
        let mut buf = Vec::new();
        let _ = crate::introspect::write_describe_response(&mut buf, &batch, &md);
        return arrow_response(StatusCode::OK, buf);
    }

    // Stream methods registered under this name are not served here; they
    // are reported as unknown, like names that are not registered at all.
    let Some(info) = server
        .method(&method)
        .filter(|m| m.method_type == MethodType::Unary)
    else {
        let err = RpcError::attribute_error(format!("Unknown method: '{}'", method));
        return arrow_error(&state, StatusCode::NOT_FOUND, &err, &req.request_id);
    };

    let mut ctx = CallContext::with_auth_cookies(&server, &req, auth.clone(), cookies);
    if let Some(s) = sticky_sink.clone() {
        ctx.set_sticky(s);
    }
    let dispatch_info = crate::hooks::DispatchInfo::from_request(&server, &req, "unary", &auth);
    let hook = server.dispatch_hook.clone();
    let hook_token = hook.as_ref().map(|h| h.on_dispatch_start(&dispatch_info));

    let mut stats = crate::hooks::CallStatistics {
        input_batches: 1,
        input_rows: req.batch.num_rows() as u64,
        ..Default::default()
    };

    // Isolate handler panics: convert them to an `RpcError` that flows into the
    // structured Arrow error envelope below, matching the stdio/unix serve loop.
    // Without this a panic would bottom out at the `CatchPanicLayer` as a bare
    // 500 the DuckDB client can't parse as a VGI error.
    let result =
        crate::server::call_guard(|| (info.unary.as_ref().unwrap())(&req, &ctx)).and_then(|r| r);
    let logs = ctx.drain_logs();
    // Error reported to the dispatch hook, if the call failed.
    let mut app_err: Option<RpcError> = None;

    let mut buf = Vec::new();
    {
        let mut sw = StreamWriter::new(&mut buf, &info.result_schema).unwrap();
        // Handler logs come first, each as a zero-row batch with log metadata.
        for log in &logs {
            let md = build_log_metadata(log, &server.server_id, &req.request_id);
            let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
        }
        match result {
            Ok(batch_opt) => {
                // A handler returning `None` yields an empty result batch.
                let out_batch =
                    batch_opt.unwrap_or_else(|| empty_batch(&info.result_schema).unwrap());
                stats.output_batches = 1;
                stats.output_rows = out_batch.num_rows() as u64;
                if let Some(cfg) = server.external_config().as_ref() {
                    // `maybe_externalize_batch` may invoke a blocking
                    // upload (e.g. reqwest::blocking). Allow blocking
                    // in this async handler so the inner client's
                    // tokio runtime can drop without panicking.
                    let externalized = tokio::task::block_in_place(|| {
                        crate::external::maybe_externalize_batch(&out_batch, None, cfg)
                    });
                    match externalized {
                        Ok(Some((ptr, md))) => {
                            let _ = sw.write(&ptr, Some(&md));
                        }
                        Ok(None) => {
                            let _ = sw.write(&out_batch, None);
                        }
                        Err(err) => {
                            let md = build_error_metadata(&err, &server.server_id, &req.request_id);
                            let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
                            app_err = Some(err);
                        }
                    }
                } else {
                    let _ = sw.write(&out_batch, None);
                }
            }
            Err(err) => {
                let md = build_error_metadata(&err, &server.server_id, &req.request_id);
                let _ = sw.write(&empty_batch(&info.result_schema).unwrap(), Some(&md));
                app_err = Some(err);
            }
        }
        let _ = sw.finish();
    }

    // NOTE(review): the hook is notified before the max_response_bytes
    // check below, so a response later replaced by the cap error is
    // reported to the hook with `app_err == None` — confirm intended.
    if let Some(hook) = hook {
        hook.on_dispatch_end(
            hook_token.unwrap_or(0),
            &dispatch_info,
            app_err.as_ref(),
            &stats,
        );
    }
    // Operator-facing wire body cap.  Hard for unary — overshoot
    // replaces the response with an EXCEPTION-only IPC stream surfaced
    // via 200 + `X-VGI-RPC-Error: true`.  Mirrors Python's strict-fail
    // contract; the literal `max_response_bytes` token in the message
    // is what the cross-language conformance suite asserts on.
    if let Some(limit) = state.max_response_bytes {
        if buf.len() > limit {
            let err = RpcError::runtime_error(format!(
                "HTTP body exceeds max_response_bytes ({} > {}) for method {:?}",
                buf.len(),
                limit,
                method
            ));
            let mut resp = cap_error_response(
                &info.result_schema,
                &err,
                &server.server_id,
                &req.request_id,
            );
            if let Some(s) = sticky_sink.as_ref() {
                stamp_session_headers(&mut resp, &state, s);
            }
            return resp;
        }
    }
    let mut resp = arrow_response(StatusCode::OK, buf);
    if let Some(s) = sticky_sink.as_ref() {
        stamp_session_headers(&mut resp, &state, s);
    }
    resp
}
1936
1937// ---------------------------------------------------------------------------
1938// Stream init
1939// ---------------------------------------------------------------------------
1940
1941async fn handle_stream_init(
1942    State(state): State<Arc<HttpState>>,
1943    Path(method): Path<String>,
1944    headers: HeaderMap,
1945    body: Bytes,
1946) -> Response {
1947    let auth = match authenticate_request(&state, &method, &headers) {
1948        Ok(a) => a,
1949        Err(resp) => return resp,
1950    };
1951    if !has_arrow_ct(&headers) {
1952        return plain_error(
1953            StatusCode::UNSUPPORTED_MEDIA_TYPE,
1954            "need arrow content type".into(),
1955        );
1956    }
1957    let sticky_sink = match sticky_for_request(&state, &auth, &headers) {
1958        Ok(s) => s,
1959        Err(resp) => return resp,
1960    };
1961    let auth_for_token = auth.clone();
1962    let cookies = parse_cookies(headers.get(header::COOKIE).and_then(|v| v.to_str().ok()));
1963    let server = state.server.clone();
1964    let body = match maybe_decompress(&headers, &body, state.max_body_size) {
1965        Ok(b) => b,
1966        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1967    };
1968    let req = match parse_request_from_body(&body) {
1969        Ok(r) => r,
1970        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
1971    };
1972
1973    let Some(info) = server
1974        .method(&method)
1975        .filter(|m| m.method_type != MethodType::Unary)
1976    else {
1977        let err = RpcError::attribute_error(format!("Unknown stream method: '{}'", method));
1978        return arrow_error(&state, StatusCode::NOT_FOUND, &err, &req.request_id);
1979    };
1980
1981    let mut ctx = CallContext::with_auth_cookies(&server, &req, auth, cookies);
1982    if let Some(s) = sticky_sink.clone() {
1983        ctx.set_sticky(s);
1984    }
1985    // Isolate handler panics into the structured stream error envelope below
1986    // (see the unary path for rationale).
1987    let init_result =
1988        crate::server::call_guard(|| (info.stream.as_ref().unwrap())(&req, &ctx)).and_then(|r| r);
1989    let init_logs = ctx.drain_logs();
1990
1991    let sr = match init_result {
1992        Ok(s) => s,
1993        Err(err) => {
1994            return arrow_response(
1995                StatusCode::OK,
1996                error_stream_bytes(&empty_schema(), &err, &server.server_id, &req.request_id),
1997            );
1998        }
1999    };
2000
2001    let StreamResult {
2002        output_schema,
2003        input_schema,
2004        state: mut ss,
2005        header,
2006        header_metadata,
2007    } = sr;
2008
2009    let mut body_buf = Vec::new();
2010
2011    // Write header stream (if any) into body_buf.
2012    if let Some(header_batch) = header.as_ref() {
2013        let hdr_schema = header_batch.schema();
2014        let mut hw = StreamWriter::new(&mut body_buf, hdr_schema.as_ref()).unwrap();
2015        for log in &init_logs {
2016            let md = build_log_metadata(log, &server.server_id, &req.request_id);
2017            let _ = hw.write(&empty_batch(hdr_schema.as_ref()).unwrap(), Some(&md));
2018        }
2019        let _ = hw.write(header_batch, header_metadata.as_ref());
2020        let _ = hw.finish();
2021    }
2022
2023    let is_producer = matches!(ss, StreamStateKind::Producer(_));
2024    let stream_id = new_session_id();
2025
2026    let mut finished = false;
2027    let mut init_error: Option<RpcError> = None;
2028    {
2029        let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2030        if header.is_none() {
2031            for log in &init_logs {
2032                let md = build_log_metadata(log, &server.server_id, &req.request_id);
2033                let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2034            }
2035        }
2036        let _ = header_metadata;
2037        if is_producer {
2038            finished = run_producer(
2039                &mut sw,
2040                &mut ss,
2041                &output_schema,
2042                &server,
2043                &req,
2044                state.producer_batch_limit,
2045                sticky_sink.as_ref(),
2046            );
2047        }
2048        if !finished {
2049            match build_continuation_token(
2050                &state,
2051                &auth_for_token,
2052                &ss,
2053                &output_schema,
2054                input_schema.as_ref(),
2055                &stream_id,
2056            ) {
2057                Ok(token) => {
2058                    let md = Metadata::from([(STATE_KEY.to_string(), token)]);
2059                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2060                }
2061                Err(err) => {
2062                    // Handler doesn't implement encode_state — emit as an
2063                    // error envelope so the client sees a useful message
2064                    // instead of a hung stream.
2065                    let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2066                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2067                    init_error = Some(err);
2068                }
2069            }
2070        }
2071        let _ = sw.finish();
2072    }
2073    let _ = init_error; // surfaced via error envelope already
2074
2075    let mut resp = arrow_response(StatusCode::OK, body_buf);
2076    if let Some(s) = sticky_sink.as_ref() {
2077        stamp_session_headers(&mut resp, &state, s);
2078    }
2079    resp
2080}
2081
2082/// Encode a `StreamStateKind` into a signed state token. The token is
2083/// bound to `auth` so a different identity replaying it will fail HMAC
2084/// verification on the next continuation request.
2085fn build_continuation_token(
2086    state: &Arc<HttpState>,
2087    auth: &crate::auth::AuthContext,
2088    ss: &StreamStateKind,
2089    output_schema: &SchemaRef,
2090    input_schema: Option<&SchemaRef>,
2091    stream_id: &str,
2092) -> Result<String> {
2093    let state_bytes = match ss {
2094        StreamStateKind::Producer(p) => p.encode_state()?,
2095        StreamStateKind::Exchange(e) => e.encode_state()?,
2096    };
2097    let out_schema_bytes = write_schema_bytes(output_schema.as_ref())?;
2098    let in_schema_bytes = match input_schema {
2099        Some(s) => write_schema_bytes(s.as_ref())?,
2100        None => Vec::new(),
2101    };
2102    Ok(state.pack_state_token(
2103        auth,
2104        &state_bytes,
2105        &out_schema_bytes,
2106        &in_schema_bytes,
2107        stream_id,
2108    ))
2109}
2110
2111fn run_producer<W: std::io::Write>(
2112    sw: &mut StreamWriter<W>,
2113    ss: &mut StreamStateKind,
2114    output_schema: &SchemaRef,
2115    server: &Arc<RpcServer>,
2116    req: &Request,
2117    limit: usize,
2118    sticky: Option<&Arc<crate::sticky::StickySinkImpl>>,
2119) -> bool {
2120    // Continuation producers run without auth context (session-bound).
2121    let mut ctx = CallContext::for_request(server, req);
2122    if let Some(s) = sticky {
2123        ctx.set_sticky(s.clone());
2124    }
2125    let producer = match ss {
2126        StreamStateKind::Producer(p) => p,
2127        StreamStateKind::Exchange(_) => unreachable!(),
2128    };
2129    // A resumable producer may cap its own per-response batch count (so it
2130    // yields a continuation instead of draining the whole shared work queue).
2131    let limit = producer.batch_limit().unwrap_or(limit);
2132    let mut batches_written = 0usize;
2133    while limit == 0 || batches_written < limit {
2134        let mut out = OutputCollector::new(output_schema.clone(), true);
2135        let result = producer.produce(&mut out, &ctx);
2136        for log in ctx.drain_logs() {
2137            let md = build_log_metadata(&log, &server.server_id, &req.request_id);
2138            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2139        }
2140        if let Err(err) = result {
2141            let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2142            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2143            return true;
2144        }
2145        let finished = out.finished();
2146        let mut emitted_data = false;
2147        for item in out.items.drain(..) {
2148            match item {
2149                Emitted::Log(log) => {
2150                    let md = build_log_metadata(&log, &server.server_id, &req.request_id);
2151                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2152                }
2153                Emitted::Batch { batch, metadata } => {
2154                    let _ = sw.write(&batch, metadata.as_ref());
2155                    emitted_data = true;
2156                }
2157            }
2158        }
2159        if emitted_data {
2160            batches_written += 1;
2161        }
2162        if finished {
2163            return true;
2164        }
2165        if !emitted_data {
2166            // Guard against degenerate producers that neither emit nor finish.
2167            return true;
2168        }
2169    }
2170    false
2171}
2172
2173// ---------------------------------------------------------------------------
2174// Stream exchange / producer continuation / cancel
2175// ---------------------------------------------------------------------------
2176
2177async fn handle_stream_exchange(
2178    State(state): State<Arc<HttpState>>,
2179    Path(method): Path<String>,
2180    headers: HeaderMap,
2181    body: Bytes,
2182) -> Response {
2183    let auth = match authenticate_request(&state, &method, &headers) {
2184        Ok(a) => a,
2185        Err(resp) => return resp,
2186    };
2187    if !has_arrow_ct(&headers) {
2188        return plain_error(
2189            StatusCode::UNSUPPORTED_MEDIA_TYPE,
2190            "need arrow content type".into(),
2191        );
2192    }
2193    let sticky_sink = match sticky_for_request(&state, &auth, &headers) {
2194        Ok(s) => s,
2195        Err(resp) => return resp,
2196    };
2197
2198    let server = state.server.clone();
2199    let body = match maybe_decompress(&headers, &body, state.max_body_size) {
2200        Ok(b) => b,
2201        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
2202    };
2203    // Parse input batch (may be empty-schema for cancel / producer continuation).
2204    let (batch, metadata) = match read_input_batch(&body) {
2205        Ok(x) => x,
2206        Err(e) => return arrow_error(&state, StatusCode::BAD_REQUEST, &e, ""),
2207    };
2208
2209    let Some(token) = md_get(&metadata, STATE_KEY).map(str::to_owned) else {
2210        let err = RpcError::runtime_error("Missing state token in exchange request");
2211        return arrow_error(&state, StatusCode::BAD_REQUEST, &err, "");
2212    };
2213    let cancelled = md_get(&metadata, CANCEL_KEY).is_some();
2214
2215    let unpacked = match state.unpack_state_token(&auth, &token) {
2216        Ok(u) => u,
2217        Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2218    };
2219
2220    // Reconstruct schemas from token-carried IPC bytes.
2221    let output_schema = match read_schema_bytes(&unpacked.output_schema_bytes) {
2222        Ok(s) => s,
2223        Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2224    };
2225    let input_schema: Option<SchemaRef> = if unpacked.input_schema_bytes.is_empty() {
2226        None
2227    } else {
2228        match read_schema_bytes(&unpacked.input_schema_bytes) {
2229            Ok(s) => Some(s),
2230            Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2231        }
2232    };
2233
2234    // Resolve the method's state decoder from URL path.
2235    let Some(info) = server
2236        .method(&method)
2237        .filter(|m| m.method_type != MethodType::Unary)
2238    else {
2239        let err = RpcError::attribute_error(format!("Unknown stream method: '{}'", method));
2240        return arrow_error(&state, StatusCode::NOT_FOUND, &err, "");
2241    };
2242    let Some(decoder) = info.state_decoder.as_ref() else {
2243        let err = RpcError::runtime_error(format!(
2244            "Stream method '{method}' is registered without a state decoder; \
2245             it cannot serve HTTP continuation requests"
2246        ));
2247        return arrow_error(&state, StatusCode::INTERNAL_SERVER_ERROR, &err, "");
2248    };
2249    let mut ss = match decoder(&unpacked.state_bytes) {
2250        Ok(s) => s,
2251        Err(err) => return arrow_error(&state, StatusCode::BAD_REQUEST, &err, ""),
2252    };
2253
2254    let req = Request {
2255        method: method.clone(),
2256        request_id: md_get(&metadata, REQUEST_ID_KEY).unwrap_or("").to_string(),
2257        batch: empty_batch(&Schema::empty()).unwrap(),
2258        metadata: metadata.clone(),
2259    };
2260    let mut ctx = CallContext::for_request(&server, &req);
2261    if let Some(s) = sticky_sink.clone() {
2262        ctx.set_sticky(s);
2263    }
2264
2265    let mut body_buf = Vec::new();
2266
2267    if cancelled {
2268        match &mut ss {
2269            StreamStateKind::Producer(p) => p.on_cancel(&ctx),
2270            StreamStateKind::Exchange(e) => e.on_cancel(&ctx),
2271        }
2272        {
2273            let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2274            let _ = sw.finish();
2275        }
2276        let mut resp = arrow_response(StatusCode::OK, body_buf);
2277        if let Some(s) = sticky_sink.as_ref() {
2278            stamp_session_headers(&mut resp, &state, s);
2279        }
2280        return resp;
2281    }
2282
2283    if matches!(ss, StreamStateKind::Producer(_)) {
2284        // Producer continuation.
2285        let finished;
2286        {
2287            let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2288            finished = run_producer(
2289                &mut sw,
2290                &mut ss,
2291                &output_schema,
2292                &server,
2293                &req,
2294                state.producer_batch_limit,
2295                sticky_sink.as_ref(),
2296            );
2297            if !finished {
2298                match build_continuation_token(
2299                    &state,
2300                    &auth,
2301                    &ss,
2302                    &output_schema,
2303                    input_schema.as_ref(),
2304                    &unpacked.stream_id,
2305                ) {
2306                    Ok(new_token) => {
2307                        let md = Metadata::from([(STATE_KEY.to_string(), new_token)]);
2308                        let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2309                    }
2310                    Err(err) => {
2311                        let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2312                        let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2313                    }
2314                }
2315            }
2316            let _ = sw.finish();
2317        }
2318        let mut resp = arrow_response(StatusCode::OK, body_buf);
2319        if let Some(s) = sticky_sink.as_ref() {
2320            stamp_session_headers(&mut resp, &state, s);
2321        }
2322        return resp;
2323    }
2324
2325    // Exchange continuation.
2326    let casted = match &input_schema {
2327        Some(exp) if batch.schema() != *exp => match cast_batch(&batch, exp) {
2328            Ok(b) => b,
2329            Err(e) => {
2330                let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2331                let md = build_error_metadata(&e, &server.server_id, &req.request_id);
2332                let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2333                let _ = sw.finish();
2334                drop(sw);
2335                return arrow_response(StatusCode::OK, body_buf);
2336            }
2337        },
2338        _ => batch,
2339    };
2340
2341    let mut out = OutputCollector::new(output_schema.clone(), false);
2342    let res = match &mut ss {
2343        StreamStateKind::Exchange(e) => e.exchange(&casted, &mut out, &ctx),
2344        _ => unreachable!(),
2345    };
2346
2347    {
2348        let mut sw = StreamWriter::new(&mut body_buf, output_schema.as_ref()).unwrap();
2349        for log in ctx.drain_logs() {
2350            let md = build_log_metadata(&log, &server.server_id, &req.request_id);
2351            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2352        }
2353        if let Err(err) = res {
2354            let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2355            let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2356        } else {
2357            let new_token = match build_continuation_token(
2358                &state,
2359                &auth,
2360                &ss,
2361                &output_schema,
2362                input_schema.as_ref(),
2363                &unpacked.stream_id,
2364            ) {
2365                Ok(t) => t,
2366                Err(err) => {
2367                    let md = build_error_metadata(&err, &server.server_id, &req.request_id);
2368                    let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2369                    let _ = sw.finish();
2370                    drop(sw);
2371                    return arrow_response(StatusCode::OK, body_buf);
2372                }
2373            };
2374            let mut wrote_data = false;
2375            for item in out.items.drain(..) {
2376                match item {
2377                    Emitted::Log(log) => {
2378                        let md = build_log_metadata(&log, &server.server_id, &req.request_id);
2379                        let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2380                    }
2381                    Emitted::Batch { batch, metadata } => {
2382                        let mut md = metadata.unwrap_or_default();
2383                        md.insert(STATE_KEY.to_string(), new_token.clone());
2384                        let _ = sw.write(&batch, Some(&md));
2385                        wrote_data = true;
2386                    }
2387                }
2388            }
2389            if !wrote_data {
2390                let md = Metadata::from([(STATE_KEY.to_string(), new_token)]);
2391                let _ = sw.write(&empty_batch(output_schema.as_ref()).unwrap(), Some(&md));
2392            }
2393        }
2394        let _ = sw.finish();
2395    }
2396
2397    let mut resp = enforce_response_body_cap(
2398        &state,
2399        output_schema.as_ref(),
2400        body_buf,
2401        &method,
2402        &server.server_id,
2403        "",
2404    );
2405    if let Some(s) = sticky_sink.as_ref() {
2406        stamp_session_headers(&mut resp, &state, s);
2407    }
2408    resp
2409}
2410
2411fn read_input_batch(body: &[u8]) -> Result<(RecordBatch, Metadata)> {
2412    let mut r = StreamReader::new(body)?;
2413    let (batch, metadata) = r
2414        .read_next()?
2415        .ok_or_else(|| RpcError::runtime_error("no batch in exchange request"))?;
2416    r.drain()?;
2417    Ok((batch, metadata))
2418}
2419
#[cfg(test)]
mod tests {
    //! Unit tests for the sealed stream-state token (pack/unpack, tampering,
    //! expiry, principal binding), request-body decompression limits,
    //! token-key parsing, and the response buffer ceiling.

    use super::*;
    use std::time::Duration;

    /// Token key used by [`state_with_key`]. Tests that pack a token directly
    /// with the free `pack_state_token` must use the same key, so both read it
    /// from here instead of repeating the literal.
    const TEST_KEY: [u8; 32] = [7u8; 32];

    /// Hex form of the 32 key bytes `0x00..=0x1f`, shared by the hex-key tests.
    const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";

    /// Build an `HttpState` keyed with [`TEST_KEY`], a 50 ms token TTL and a
    /// 1024-byte max body size.
    fn state_with_key() -> Arc<HttpState> {
        let server = Arc::new(RpcServer::builder().server_id("test").build());
        HttpState::builder()
            .server(server)
            .token_key(&TEST_KEY)
            .token_ttl(Duration::from_millis(50))
            .max_body_size(1024)
            .build()
    }

    /// Serialized schema bytes (via `write_schema_bytes`) for a single
    /// non-nullable `Int64` column named `x`.
    fn sample_schema_bytes() -> Vec<u8> {
        use arrow_schema::{DataType, Field, Schema};
        write_schema_bytes(&Schema::new(vec![Field::new("x", DataType::Int64, false)])).unwrap()
    }

    /// Base64-decode a packed token, apply `edit` to its raw bytes, and
    /// re-encode it. This is the shared shape of the tampering tests.
    fn mutate_token(token: &str, edit: impl FnOnce(&mut [u8])) -> String {
        use base64::engine::general_purpose::STANDARD;
        let mut raw = STANDARD
            .decode(token.as_bytes())
            .expect("a freshly packed token is valid base64");
        edit(raw.as_mut_slice());
        STANDARD.encode(raw)
    }

    #[tokio::test]
    async fn pack_unpack_roundtrip() {
        let state = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let payload = b"state-payload";
        let out_schema = sample_schema_bytes();
        let in_schema = sample_schema_bytes();
        let token = state.pack_state_token(&auth, payload, &out_schema, &in_schema, "sid-123");
        let unpacked = state.unpack_state_token(&auth, &token).unwrap();
        assert_eq!(unpacked.state_bytes, payload);
        assert_eq!(unpacked.output_schema_bytes, out_schema);
        assert_eq!(unpacked.input_schema_bytes, in_schema);
        assert_eq!(unpacked.stream_id, "sid-123");
    }

    #[tokio::test]
    async fn unpack_rejects_tampered_ciphertext() {
        // Header layout owned by `crypto`: a 1-byte version, then a 24-byte
        // nonce, then the ciphertext. Flip a bit in the first ciphertext byte.
        const VERSION_LEN: usize = 1;
        const NONCE_LEN: usize = 24;
        let state = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let token = state.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        let tampered = mutate_token(&token, |raw| raw[VERSION_LEN + NONCE_LEN] ^= 0x01);
        assert!(state.unpack_state_token(&auth, &tampered).is_err());
    }

    #[tokio::test]
    async fn unpack_rejects_tampered_nonce() {
        let state = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let token = state.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        // Flip the first nonce byte (just past the version byte); AEAD
        // decryption must reject.
        let tampered = mutate_token(&token, |raw| raw[1] ^= 0x01);
        assert!(state.unpack_state_token(&auth, &tampered).is_err());
    }

    #[tokio::test]
    async fn unpack_rejects_unknown_version() {
        let state = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let token = state.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        let tampered = mutate_token(&token, |raw| raw[0] = 0x99);
        let err = state.unpack_state_token(&auth, &tampered).unwrap_err();
        // Wrong-version tokens map to the same uniform error as every other
        // bad-token mode — callers cannot distinguish failure modes.
        assert!(err.message.contains("signature verification failed"));
    }

    #[tokio::test]
    async fn unpack_rejects_malformed_base64() {
        let state = state_with_key();
        let auth = crate::auth::AuthContext::anonymous();
        let err = state.unpack_state_token(&auth, "not!base64!").unwrap_err();
        assert!(err.message.contains("Malformed"));
    }

    #[tokio::test]
    async fn unpack_rejects_different_key() {
        let server = Arc::new(RpcServer::builder().server_id("t").build());
        let packer = HttpState::builder()
            .server(server.clone())
            .token_key(&[1u8; 32])
            .build();
        let other = HttpState::builder()
            .server(server)
            .token_key(&[2u8; 32])
            .build();
        let auth = crate::auth::AuthContext::anonymous();
        let token = packer.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        assert!(other.unpack_state_token(&auth, &token).is_err());
    }

    #[tokio::test]
    async fn unpack_rejects_expired_token() {
        let state = state_with_key(); // ttl = 50ms
        let auth = crate::auth::AuthContext::anonymous();
        // Pack a token whose created_at is far in the past (0), with the same
        // key and the same AAD the server reconstructs for an anonymous caller.
        let aad = compute_aad(&auth);
        let stale = pack_state_token(&TEST_KEY, &aad, b"s", b"o", b"i", "sid", 0);
        let err = state.unpack_state_token(&auth, &stale).unwrap_err();
        assert!(err.message.contains("expired"), "got: {}", err.message);
    }

    #[tokio::test]
    async fn unpack_rejects_different_principal() {
        let state = state_with_key();
        let alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let bob = crate::auth::AuthContext::for_principal("bearer", "bob");
        let anon = crate::auth::AuthContext::anonymous();
        let token = state.pack_state_token(&alice, b"s", b"o", b"i", "sid");
        assert!(state.unpack_state_token(&alice, &token).is_ok());
        assert!(state.unpack_state_token(&bob, &token).is_err());
        assert!(state.unpack_state_token(&anon, &token).is_err());
    }

    #[tokio::test]
    async fn unpack_rejects_authenticated_replay_of_anonymous_token() {
        let state = state_with_key();
        let anon = crate::auth::AuthContext::anonymous();
        let alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let token = state.pack_state_token(&anon, b"s", b"o", b"i", "sid");
        assert!(state.unpack_state_token(&alice, &token).is_err());
    }

    #[tokio::test]
    async fn unpack_rejects_cross_domain_replay() {
        let state = state_with_key();
        let bearer_alice = crate::auth::AuthContext::for_principal("bearer", "alice");
        let mtls_alice = crate::auth::AuthContext::for_principal("mtls", "alice");
        let token = state.pack_state_token(&bearer_alice, b"s", b"o", b"i", "sid");
        assert!(state.unpack_state_token(&mtls_alice, &token).is_err());
    }

    #[tokio::test]
    async fn decompress_rejects_oversize() {
        let headers = HeaderMap::new();
        let body = Bytes::from(vec![0u8; 1025]);
        let err = super::maybe_decompress(&headers, &body, 1024).unwrap_err();
        assert!(err.message.contains("exceeds max size"));
    }

    #[test]
    fn zstd_bounded_rejects_zip_bomb_without_full_alloc() {
        // 8 MiB of zeroes compresses to a tiny payload — small enough to
        // pass the encoded-size check but it would blow past the limit
        // when fully decompressed.
        let huge = vec![0u8; 8 * 1024 * 1024];
        let compressed = zstd::encode_all(huge.as_slice(), 1).unwrap();
        assert!(compressed.len() < 100_000, "compressed should be tiny");
        let err = super::decode_zstd_bounded(&compressed, 64 * 1024).unwrap_err();
        assert!(
            err.message.contains("exceeds max size"),
            "expected oversize error, got: {}",
            err.message
        );
    }

    #[test]
    fn zstd_bounded_passes_small_payload() {
        let small = b"hello-world".repeat(10);
        let compressed = zstd::encode_all(small.as_slice(), 1).unwrap();
        let out = super::decode_zstd_bounded(&compressed, 1024).unwrap();
        assert_eq!(out, small);
    }

    #[test]
    fn decode_hex_key_roundtrip() {
        let key = decode_hex_key(TEST_KEY_HEX).unwrap();
        assert_eq!(key.len(), 32);
        assert_eq!(key[0], 0x00);
        assert_eq!(key[31], 0x1f);
    }

    #[test]
    fn decode_hex_key_rejects_short() {
        assert!(decode_hex_key("deadbeef").is_err());
    }

    #[test]
    fn decode_hex_key_rejects_bad_char() {
        assert!(decode_hex_key(&"zz".repeat(32)).is_err());
    }

    #[test]
    fn decode_base64_key_accepts_padded() {
        let encoded = base64::engine::general_purpose::STANDARD.encode([7u8; 32]);
        let out = decode_base64_key(&encoded).unwrap();
        assert_eq!(out, vec![7u8; 32]);
    }

    #[test]
    fn decode_base64_key_accepts_unpadded() {
        let encoded = base64::engine::general_purpose::STANDARD
            .encode([7u8; 32])
            .trim_end_matches('=')
            .to_string();
        let out = decode_base64_key(&encoded).unwrap();
        assert_eq!(out, vec![7u8; 32]);
    }

    #[test]
    fn decode_base64_key_rejects_short() {
        let encoded = base64::engine::general_purpose::STANDARD.encode(b"short");
        assert!(decode_base64_key(&encoded).is_err());
    }

    #[tokio::test]
    async fn token_key_hex_round_trips_through_token() {
        let server = Arc::new(RpcServer::builder().server_id("t").build());
        let packer = HttpState::builder()
            .server(server.clone())
            .token_key_hex(TEST_KEY_HEX)
            .build();
        let unpacker = HttpState::builder()
            .server(server)
            .token_key_hex(TEST_KEY_HEX)
            .build();
        let auth = crate::auth::AuthContext::anonymous();
        let token = packer.pack_state_token(&auth, b"s", b"o", b"i", "sid");
        assert_eq!(unpacker.unpack_state_token(&auth, &token).unwrap().stream_id, "sid");
    }

    #[test]
    fn response_buffer_ceiling_never_below_hard_cap() {
        let mk = |soft: Option<usize>| {
            let server = Arc::new(RpcServer::builder().server_id("t").build());
            let mut builder = HttpState::builder().server(server).token_key(&[9u8; 32]);
            if let Some(n) = soft {
                builder = builder.max_response_bytes(n);
            }
            builder.build()
        };
        // No soft cap → exactly the hard cap.
        assert_eq!(
            response_buffer_ceiling(&mk(None)),
            MAX_RESPONSE_BYTES_HARD_CAP
        );
        // A small soft cap must NOT shrink the middleware ceiling — the
        // soft cap is a producer knob the wire may legitimately
        // overshoot.
        assert_eq!(
            response_buffer_ceiling(&mk(Some(8))),
            MAX_RESPONSE_BYTES_HARD_CAP
        );
        // A soft cap equal to the hard cap yields a ceiling of twice that
        // value, leaving overshoot headroom above the hard cap.
        let big = MAX_RESPONSE_BYTES_HARD_CAP;
        assert_eq!(response_buffer_ceiling(&mk(Some(big))), big * 2);
    }
}