use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use arrow_array::{Array, ArrayRef, BinaryArray, Int64Array, RecordBatch};
use arrow_schema::SchemaRef;
use vgi_rpc::{
Bytes, CallContext, ExchangeState, OutputCollector, Request, Result, RpcError, StreamResult,
VgiArrow,
};
use crate::aggregate::{AggregateBindParams, AggregateFunction, GROUP_COLUMN_NAME};
use crate::buffering::{BufferingParams, TableBufferingFunction};
use crate::catalog;
use crate::function::{BindParams, ProcessParams, ScalarFunction};
use crate::ipc;
use crate::protocol::dtos::*;
use crate::storage::{default_storage, FunctionStorage};
use crate::table_function::{TableFunction, TableProducer};
use crate::table_in_out::TableInOutFunction;
use crate::wire;
/// Attach name that selects the projection-reproduction catalog. When a
/// primary-catalog attach request uses this name (and neither attach-option
/// defaults nor versioned schemas are configured), these bytes are returned
/// verbatim as the attach opaque data.
const PROJ_REPRO_APP: &str = "projection_repro";
/// Prefix associated with the projection-reproduction catalog.
/// NOTE(review): not referenced in the visible part of this file — confirm where it is used.
const PROJ_REPRO_PREFIX: &str = "proj_repro";
/// Routes catalog, bind and init RPCs to registered function implementations.
///
/// Each function registry maps a function name to the list of implementations
/// registered under that name, in registration order.
pub struct Dispatcher {
/// Name of the primary catalog; its UTF-8 bytes form the primary attach opaque data.
pub catalog_name: String,
/// Scalar function overloads keyed by function name.
pub scalars: HashMap<String, Vec<Arc<dyn ScalarFunction>>>,
/// Table function overloads keyed by function name.
pub tables: HashMap<String, Vec<Arc<dyn TableFunction>>>,
/// Table in-out function overloads keyed by function name.
pub tableinouts: HashMap<String, Vec<Arc<dyn TableInOutFunction>>>,
/// Buffering table functions keyed by name (only the first registration is used).
pub buffering: HashMap<String, Vec<Arc<dyn TableBufferingFunction>>>,
/// Aggregate functions keyed by name (only the first registration is used).
pub aggregates: HashMap<String, Vec<Arc<dyn AggregateFunction>>>,
/// Storage backend passed to functions in bind/process params and used for
/// buffering key/value state.
pub store: Arc<dyn FunctionStorage>,
/// Primary catalog model served by this dispatcher.
pub catalog: catalog::CatalogModel,
/// Additional catalogs that can be attached by name.
pub secondary: Vec<catalog::CatalogModel>,
/// Function names registered together with each secondary catalog; index-aligned
/// with `secondary`. NOTE(review): not read in the visible part of this file.
secondary_functions: Vec<Vec<String>>,
/// Secret types advertised in the primary catalog's attach result.
pub secret_types: Vec<catalog::SecretTypeSpec>,
/// Settings advertised in the primary catalog's attach result.
pub settings: Vec<catalog::SettingSpec>,
/// Monotonic counter mixed into generated execution ids (starts at 1).
exec_counter: AtomicU64,
}
impl Dispatcher {
pub fn new(catalog_name: impl Into<String>) -> Self {
Dispatcher {
catalog_name: catalog_name.into(),
scalars: HashMap::new(),
tables: HashMap::new(),
tableinouts: HashMap::new(),
buffering: HashMap::new(),
aggregates: HashMap::new(),
store: default_storage(),
catalog: catalog::CatalogModel::default(),
secondary: Vec::new(),
secondary_functions: Vec::new(),
secret_types: Vec::new(),
settings: Vec::new(),
exec_counter: AtomicU64::new(1),
}
}
/// Replaces the primary catalog model; the previous model is dropped.
pub fn set_catalog(&mut self, model: catalog::CatalogModel) {
    drop(std::mem::replace(&mut self.catalog, model));
}
pub fn register_secondary_catalog(
&mut self,
model: catalog::CatalogModel,
functions: Vec<String>,
) {
self.secondary.push(model);
self.secondary_functions.push(functions);
}
/// Adds a secret type to those advertised in the primary catalog's attach result.
pub fn register_secret_type(&mut self, spec: catalog::SecretTypeSpec) {
    let specs = &mut self.secret_types;
    specs.push(spec);
}
/// Adds a setting to those advertised in the primary catalog's attach result.
pub fn register_setting(&mut self, spec: catalog::SettingSpec) {
    self.settings.extend(std::iter::once(spec));
}
/// Registers an aggregate function under its own name.
///
/// Later registrations with the same name are appended, but only the first
/// one is ever resolved.
pub fn register_aggregate(&mut self, f: Arc<dyn AggregateFunction>) {
    let key = f.name().to_string();
    self.aggregates.entry(key).or_insert_with(Vec::new).push(f);
}
/// Returns the first aggregate registered under `name`.
///
/// Aggregates are not overload-resolved; the earliest registration wins.
///
/// # Errors
/// A value error `Unknown function: '<name>'` when nothing is registered.
fn resolve_aggregate(&self, name: &str) -> Result<Arc<dyn AggregateFunction>> {
    match self.aggregates.get(name).and_then(|overloads| overloads.first()) {
        Some(found) => Ok(Arc::clone(found)),
        None => Err(RpcError::value_error(format!("Unknown function: '{name}'"))),
    }
}
/// Registers a scalar function, appending it to the overloads already
/// registered under the same name.
pub fn register_scalar(&mut self, f: Arc<dyn ScalarFunction>) {
    let overloads = self.scalars.entry(f.name().to_owned()).or_default();
    overloads.push(f);
}
/// Registers a table function, appending it to the overloads already
/// registered under the same name.
pub fn register_table(&mut self, f: Arc<dyn TableFunction>) {
    let name = String::from(f.name());
    self.tables.entry(name).or_default().push(f);
}
/// Registers a table function only if no table function with that name exists yet.
///
/// Unlike `register_table`, this never adds an overload to an existing name.
pub fn register_table_if_absent(&mut self, f: Arc<dyn TableFunction>) {
    if self.tables.contains_key(f.name()) {
        return;
    }
    let name = f.name().to_string();
    self.tables.insert(name, vec![f]);
}
/// Registers a table in-out function, appending it to the overloads already
/// registered under the same name.
pub fn register_table_in_out(&mut self, f: Arc<dyn TableInOutFunction>) {
    let name = f.name().to_string();
    let overloads = self.tableinouts.entry(name).or_default();
    overloads.push(f);
}
/// Picks the table in-out overload registered under `name` that accepts
/// `args` (and the optional input schema), using `resolve_overload`.
///
/// # Errors
/// * `Unknown function: '<name>'` when no table in-out function has that name.
/// * `No matching overload for '<name>'` when no overload accepts the arguments.
fn resolve_table_in_out(
    &self,
    name: &str,
    args: &crate::arguments::Arguments,
    input_schema: Option<&SchemaRef>,
) -> Result<Arc<dyn TableInOutFunction>> {
    if let Some(overloads) = self.tableinouts.get(name) {
        let picked = crate::overload::resolve_overload(
            overloads.len(),
            |i| overloads[i].argument_specs(),
            args,
            input_schema,
        );
        match picked {
            Some(i) => Ok(Arc::clone(&overloads[i])),
            None => Err(RpcError::value_error(format!(
                "No matching overload for '{name}'"
            ))),
        }
    } else {
        Err(RpcError::value_error(format!("Unknown function: '{name}'")))
    }
}
/// Registers a buffering table function under its own name.
///
/// Later registrations with the same name are appended, but only the first
/// one is ever resolved.
pub fn register_buffering(&mut self, f: Arc<dyn TableBufferingFunction>) {
    let key = String::from(f.name());
    let slot = self.buffering.entry(key).or_default();
    slot.push(f);
}
/// Returns the first buffering table function registered under `name`.
///
/// # Errors
/// A value error `Unknown function: '<name>'` when nothing is registered.
fn resolve_buffering(&self, name: &str) -> Result<Arc<dyn TableBufferingFunction>> {
    if let Some(first) = self.buffering.get(name).and_then(|v| v.first()) {
        return Ok(Arc::clone(first));
    }
    Err(RpcError::value_error(format!("Unknown function: '{name}'")))
}
/// Picks the table-function overload registered under `name` that accepts
/// `args` (and the optional input schema), using `resolve_overload`.
///
/// # Errors
/// * `Unknown function: '<name>'` when no table function has that name.
/// * `No matching overload for '<name>'` when no overload accepts the arguments.
fn resolve_table(
    &self,
    name: &str,
    args: &crate::arguments::Arguments,
    input_schema: Option<&SchemaRef>,
) -> Result<Arc<dyn TableFunction>> {
    let unknown = || RpcError::value_error(format!("Unknown function: '{name}'"));
    let no_match = || RpcError::value_error(format!("No matching overload for '{name}'"));
    let overloads = self.tables.get(name).ok_or_else(unknown)?;
    let chosen = crate::overload::resolve_overload(
        overloads.len(),
        |i| overloads[i].argument_specs(),
        args,
        input_schema,
    )
    .ok_or_else(no_match)?;
    Ok(Arc::clone(&overloads[chosen]))
}
/// Produces a fresh opaque execution id.
///
/// Layout: ASCII `vgi-exec-`, then the process id (u32, little-endian), the
/// wall-clock nanoseconds since the Unix epoch truncated to u64 (LE; 0 if the
/// clock reads before the epoch), and finally the next value of this
/// dispatcher's counter (u64, LE).
fn next_execution_id(&self) -> Vec<u8> {
    const PREFIX: &[u8] = b"vgi-exec-";
    let seq = self.exec_counter.fetch_add(1, Ordering::Relaxed);
    let nanos: u64 = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    let pid = std::process::id();
    let mut id = Vec::with_capacity(PREFIX.len() + 4 + 8 + 8);
    id.extend_from_slice(PREFIX);
    id.extend_from_slice(&pid.to_le_bytes());
    id.extend_from_slice(&nanos.to_le_bytes());
    id.extend_from_slice(&seq.to_le_bytes());
    id
}
/// Picks the scalar overload registered under `name` that accepts `args`
/// (and the optional input schema), using `resolve_overload`.
///
/// # Errors
/// * `Unknown function: '<name>'` when no scalar function has that name.
/// * `No matching overload for '<name>'` when no overload accepts the arguments.
fn resolve_scalar(
    &self,
    name: &str,
    args: &crate::arguments::Arguments,
    input_schema: Option<&SchemaRef>,
) -> Result<Arc<dyn ScalarFunction>> {
    let Some(overloads) = self.scalars.get(name) else {
        return Err(RpcError::value_error(format!("Unknown function: '{name}'")));
    };
    let Some(index) = crate::overload::resolve_overload(
        overloads.len(),
        |i| overloads[i].argument_specs(),
        args,
        input_schema,
    ) else {
        return Err(RpcError::value_error(format!(
            "No matching overload for '{name}'"
        )));
    };
    Ok(overloads[index].clone())
}
fn bind_params(&self, dto: &BindRequest, ctx: &CallContext) -> Result<BindParams> {
Ok(BindParams {
input_schema: opt_schema(&dto.input_schema)?,
arguments: crate::arguments::Arguments::parse(&dto.arguments.0)?,
settings: parse_settings(&dto.settings)?,
secrets: parse_secrets(&dto.secrets)?,
resolved_secrets_provided: dto.resolved_secrets_provided,
auth_principal: principal(ctx),
attach_opaque_data: dto.attach_opaque_data.clone().map(|b| b.into()),
transaction_opaque_data: dto.transaction_opaque_data.clone().map(|b| b.into()),
storage: Some(self.store.clone()),
})
}
pub fn handle_bind(&self, req: &Request, ctx: &CallContext) -> Result<Option<RecordBatch>> {
let dto: BindRequest = boxed(req)?;
let mut params = self.bind_params(&dto, ctx)?;
let ft = normalize_function_type(&dto.function_type.0).unwrap_or_default();
if self.buffering.contains_key(&dto.function_name) {
let f = self.resolve_buffering(&dto.function_name)?;
params.arguments.remap_positional(&f.argument_specs());
let bind = f.on_bind(¶ms)?;
let resp = BindResponse {
output_schema: Bytes::from(ipc::write_schema_ref(&bind.output_schema)?),
opaque_data: Bytes::from(bind.opaque_data),
lookup_secret_types: Vec::new(),
lookup_scopes: Vec::new(),
lookup_names: Vec::new(),
};
return Ok(Some(wire::to_result_batch(resp)?));
}
if self.tableinouts.contains_key(&dto.function_name) {
let f = self.resolve_table_in_out(
&dto.function_name,
¶ms.arguments,
params.input_schema.as_ref(),
)?;
params.arguments.remap_positional(&f.argument_specs());
let bind = f.on_bind(¶ms)?;
let resp = BindResponse {
output_schema: Bytes::from(ipc::write_schema_ref(&bind.output_schema)?),
opaque_data: Bytes::from(bind.opaque_data),
lookup_secret_types: Vec::new(),
lookup_scopes: Vec::new(),
lookup_names: Vec::new(),
};
return Ok(Some(wire::to_result_batch(resp)?));
}
if (ft == "table" || ft == "table_buffering")
|| (!self.scalars.contains_key(&dto.function_name)
&& self.tables.contains_key(&dto.function_name))
{
let f = self.resolve_table(
&dto.function_name,
¶ms.arguments,
params.input_schema.as_ref(),
)?;
params.arguments.remap_positional(&f.argument_specs());
if !params.resolved_secrets_provided {
let lookups = f.secret_lookups(¶ms);
if !lookups.is_empty() {
let resp = BindResponse {
output_schema: Bytes::from(Vec::new()),
opaque_data: Bytes::from(Vec::new()),
lookup_secret_types: lookups
.iter()
.map(|l| l.secret_type.clone())
.collect(),
lookup_scopes: lookups
.iter()
.map(|l| l.scope.clone().unwrap_or_default())
.collect(),
lookup_names: lookups
.iter()
.map(|l| l.name.clone().unwrap_or_default())
.collect(),
};
return Ok(Some(wire::to_result_batch(resp)?));
}
}
let bind = f.on_bind(¶ms)?;
let resp = BindResponse {
output_schema: Bytes::from(ipc::write_schema_ref(&bind.output_schema)?),
opaque_data: Bytes::from(bind.opaque_data),
lookup_secret_types: Vec::new(),
lookup_scopes: Vec::new(),
lookup_names: Vec::new(),
};
return Ok(Some(wire::to_result_batch(resp)?));
}
let f = self.resolve_scalar(
&dto.function_name,
¶ms.arguments,
params.input_schema.as_ref(),
)?;
params.arguments.remap_positional(&f.argument_specs());
if !params.resolved_secrets_provided {
let lookups = f.secret_lookups(¶ms);
if !lookups.is_empty() {
let resp = BindResponse {
output_schema: Bytes::from(Vec::new()),
opaque_data: Bytes::from(Vec::new()),
lookup_secret_types: lookups.iter().map(|l| l.secret_type.clone()).collect(),
lookup_scopes: lookups
.iter()
.map(|l| l.scope.clone().unwrap_or_default())
.collect(),
lookup_names: lookups
.iter()
.map(|l| l.name.clone().unwrap_or_default())
.collect(),
};
return Ok(Some(wire::to_result_batch(resp)?));
}
}
crate::function::validate_type_bounds(&f.argument_specs(), params.input_schema.as_ref())?;
let bind = f.on_bind(¶ms)?;
let resp = BindResponse {
output_schema: Bytes::from(ipc::write_schema_ref(&bind.output_schema)?),
opaque_data: Bytes::from(bind.opaque_data),
lookup_secret_types: Vec::new(),
lookup_scopes: Vec::new(),
lookup_names: Vec::new(),
};
Ok(Some(wire::to_result_batch(resp)?))
}
pub fn handle_init(&self, req: &Request, ctx: &CallContext) -> Result<StreamResult> {
let dto: InitRequest = boxed(req)?;
let bind_call: BindRequest = wire::from_batch(&ipc::read_batch(&dto.bind_call.0)?)?;
let mut bp = self.bind_params(&bind_call, ctx)?;
let output_schema = crate::table_function::project_schema(
&ipc::read_schema(&dto.output_schema.0)?,
&dto.projection_ids,
);
let input_schema = bp.input_schema.clone();
let execution_id = dto
.execution_id
.clone()
.map(|b| b.into())
.unwrap_or_else(|| self.next_execution_id());
let ft = normalize_function_type(&bind_call.function_type.0).unwrap_or_default();
let build_params =
|args: crate::arguments::Arguments, settings, secrets, auth| ProcessParams {
output_schema: output_schema.clone(),
input_schema: input_schema.clone(),
execution_id: execution_id.clone(),
init_opaque_data: dto
.bind_opaque_data
.clone()
.map(|b| b.into())
.unwrap_or_default(),
arguments: args,
settings,
secrets,
auth_principal: auth,
projection_ids: dto.projection_ids.clone(),
pushdown_filters: dto.pushdown_filters.clone().map(|b| b.0),
join_keys: dto
.join_keys
.clone()
.map(|v| v.into_iter().map(|b| b.0).collect())
.unwrap_or_default(),
storage: Some(self.store.clone()),
order_by_column: dto.order_by_column_name.clone(),
order_by_direction: dto.order_by_direction.clone().map(|d| d.0),
order_by_null_order: dto.order_by_null_order.clone().map(|d| d.0),
order_by_limit: dto.order_by_limit,
tablesample_percentage: dto.tablesample_percentage,
tablesample_seed: dto.tablesample_seed,
attach_opaque_data: bind_call.attach_opaque_data.clone().map(|b| b.into()),
at_unit: bind_call.at_unit.clone().filter(|s| !s.is_empty()),
at_value: bind_call.at_value.clone().filter(|s| !s.is_empty()),
};
if self.buffering.contains_key(&bind_call.function_name) {
let f = self.resolve_buffering(&bind_call.function_name)?;
bp.arguments.remap_positional(&f.argument_specs());
let phase = dto.phase.as_ref().map(|d| d.0.clone()).unwrap_or_default();
let header = wire::to_batch(GlobalInitResponse {
execution_id: Bytes::from(execution_id.clone()),
max_workers: 1,
opaque_data: None,
})?;
if phase == crate::protocol::enums::phase::TABLE_BUFFERING_FINALIZE {
let fsid = dto
.finalize_state_id
.clone()
.map(|b| b.0)
.unwrap_or_default();
let bparams = BufferingParams {
execution_id,
storage: self.store.clone(),
output_schema: output_schema.clone(),
arguments: bp.arguments,
settings: bp.settings,
attach_opaque_data: bind_call.attach_opaque_data.clone().map(|b| b.into()),
batch_index: None,
logs: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
};
let auto_apply = f.metadata().auto_apply_filters;
let filters = if auto_apply {
dto.pushdown_filters
.as_ref()
.map(|b| crate::pushdown::PushdownFilters::parse(&b.0))
.transpose()?
} else {
None
};
let producer = f.finalize_producer(&bparams, fsid)?;
let state = TableProducerState {
inner: producer,
filters,
project_to: None,
resume_blob: None,
};
return Ok(
StreamResult::producer(output_schema, Box::new(state)).with_header(header)
);
}
self.store.kv_put(
&execution_id,
b"outsc",
&ipc::write_schema_ref(&output_schema)?,
);
if let Some(insc) = input_schema.as_ref() {
self.store
.kv_put(&execution_id, b"insc", &ipc::write_schema_ref(insc)?);
}
self.store.kv_put(
&execution_id,
b"bufflags",
&[bp.arguments.named_bool("logging").unwrap_or(false) as u8],
);
self.store
.kv_put(&execution_id, b"bufargs", &bind_call.arguments.0);
if let Some(a) = bind_call.attach_opaque_data.as_ref() {
self.store.kv_put(&execution_id, b"bufattach", &a.0);
}
let state = TableProducerState {
inner: Box::new(EmptyProducer),
filters: None,
project_to: None,
resume_blob: None,
};
return Ok(StreamResult::producer(output_schema, Box::new(state)).with_header(header));
}
if self.tableinouts.contains_key(&bind_call.function_name) {
let f = self.resolve_table_in_out(
&bind_call.function_name,
&bp.arguments,
input_schema.as_ref(),
)?;
bp.arguments.remap_positional(&f.argument_specs());
let auto_apply = f.metadata().auto_apply_filters;
let params = build_params(bp.arguments, bp.settings, bp.secrets, bp.auth_principal);
let phase = dto.phase.as_ref().map(|d| d.0.clone()).unwrap_or_default();
if phase == crate::protocol::enums::phase::FINALIZE {
let header = wire::to_batch(GlobalInitResponse {
execution_id: Bytes::from(execution_id.clone()),
max_workers: 1,
opaque_data: None,
})?;
let batches = f.finish(¶ms)?;
let state = TableProducerState {
inner: Box::new(VecProducer { batches, pos: 0 }),
filters: None,
project_to: None,
resume_blob: None,
};
return Ok(
StreamResult::producer(output_schema, Box::new(state)).with_header(header)
);
}
let filters = if auto_apply {
params
.pushdown_filters
.as_ref()
.map(|b| {
crate::pushdown::PushdownFilters::parse_with_join_keys(b, ¶ms.join_keys)
})
.transpose()?
} else {
None
};
let header = wire::to_batch(GlobalInitResponse {
execution_id: Bytes::from(execution_id.clone()),
max_workers: 1,
opaque_data: None,
})?;
let blob = self.exchange_blob(
"table_in_out",
bind_call.function_name.clone(),
&output_schema,
input_schema.as_ref(),
&bind_call,
&dto,
&execution_id,
auto_apply,
)?;
let in_schema = input_schema.unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty()));
let state = TableInOutExchangeState {
func: f,
params,
filters,
blob,
};
return Ok(
StreamResult::exchange(output_schema, in_schema, Box::new(state))
.with_header(header),
);
}
if (ft == "table" || ft == "table_buffering")
|| (!self.scalars.contains_key(&bind_call.function_name)
&& self.tables.contains_key(&bind_call.function_name))
{
let f = self.resolve_table(
&bind_call.function_name,
&bp.arguments,
input_schema.as_ref(),
)?;
bp.arguments.remap_positional(&f.argument_specs());
let max_workers = f.max_workers(&bp);
let auto_apply = f.metadata().auto_apply_filters;
let params = build_params(bp.arguments, bp.settings, bp.secrets, bp.auth_principal);
if dto.execution_id.is_none() {
f.on_init(¶ms)?;
}
let filters = if auto_apply {
params
.pushdown_filters
.as_ref()
.map(|b| {
crate::pushdown::PushdownFilters::parse_with_join_keys(b, ¶ms.join_keys)
})
.transpose()?
} else {
None
};
let project_to = Some(output_schema.clone());
let producer = f.producer(¶ms)?;
let resume_blob = if max_workers > 1 {
Some(self.exchange_blob(
"table",
bind_call.function_name.clone(),
&output_schema,
None,
&bind_call,
&dto,
&execution_id,
auto_apply,
)?)
} else {
None
};
let header = wire::to_batch(GlobalInitResponse {
execution_id: Bytes::from(execution_id),
max_workers,
opaque_data: None,
})?;
let state = TableProducerState {
inner: producer,
filters,
project_to,
resume_blob,
};
return Ok(StreamResult::producer(output_schema, Box::new(state)).with_header(header));
}
let f = self.resolve_scalar(
&bind_call.function_name,
&bp.arguments,
input_schema.as_ref(),
)?;
bp.arguments.remap_positional(&f.argument_specs());
let params = build_params(bp.arguments, bp.settings, bp.secrets, bp.auth_principal);
let header = wire::to_batch(GlobalInitResponse {
execution_id: Bytes::from(execution_id.clone()),
max_workers: 1,
opaque_data: None,
})?;
let blob = self.exchange_blob(
"scalar",
bind_call.function_name.clone(),
&output_schema,
input_schema.as_ref(),
&bind_call,
&dto,
&execution_id,
false,
)?;
let state = ScalarExchangeState {
func: f,
params,
blob,
};
let in_schema = input_schema.unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty()));
Ok(StreamResult::exchange(output_schema, in_schema, Box::new(state)).with_header(header))
}
/// Encodes, with `bincode_encode`, everything needed to rebuild a stream
/// state later.
///
/// `kind` selects the rebuild path in `decode_init_state`. Absent optional
/// inputs (input schema, settings, secrets, bind opaque data, pushdown
/// filters, AT clause) are stored as empty bytes or strings, and
/// `inner_resume` always starts empty.
// Allowed: every input is a distinct piece of init context copied into the blob.
#[allow(clippy::too_many_arguments)]
fn exchange_blob(
    &self,
    kind: &str,
    function_name: String,
    output_schema: &arrow_schema::SchemaRef,
    input_schema: Option<&arrow_schema::SchemaRef>,
    bind_call: &BindRequest,
    dto: &InitRequest,
    execution_id: &[u8],
    auto_apply: bool,
) -> Result<Vec<u8>> {
    let encoded_output = ipc::write_schema_ref(output_schema)?;
    let encoded_input = if let Some(schema) = input_schema {
        ipc::write_schema_ref(schema)?
    } else {
        Vec::new()
    };
    let payload_of = |field: Option<Vec<u8>>| field.unwrap_or_default();
    let blob = ExchangeBlob {
        kind: kind.to_owned(),
        function_name,
        output_schema: encoded_output,
        input_schema: encoded_input,
        arguments: bind_call.arguments.0.clone(),
        settings: payload_of(bind_call.settings.clone().map(|b| b.0)),
        secrets: payload_of(bind_call.secrets.clone().map(|b| b.0)),
        execution_id: Vec::from(execution_id),
        init_opaque: payload_of(dto.bind_opaque_data.clone().map(|b| b.0)),
        pushdown_filters: payload_of(dto.pushdown_filters.clone().map(|b| b.0)),
        auto_apply,
        inner_resume: Vec::new(),
        at_unit: bind_call.at_unit.clone().unwrap_or_default(),
        at_value: bind_call.at_value.clone().unwrap_or_default(),
    };
    vgi_rpc::stream_codec::bincode_encode(&blob)
}
/// Rebuilds a stream state from a blob produced by `exchange_blob`.
///
/// * `kind == "table"` gives a producer. It is restored from `inner_resume`
///   and keeps `bytes` as its resume blob.
/// * `kind == "table_in_out"` gives a table in-out exchange state.
/// * Any other kind gives a scalar exchange state.
///
/// The overload is re-resolved from the stored arguments and input schema.
///
/// NOTE(review): the blob does not carry join keys, projection ids, ordering,
/// tablesample, the auth principal or attach opaque data. The rebuilt
/// `ProcessParams` therefore leave these empty or `None`, and pushdown
/// filters are re-parsed with no join keys, unlike in `handle_init`. Confirm
/// this is intended.
///
/// # Errors
/// Propagates blob/schema/settings/secrets/argument decoding, overload
/// resolution, filter parsing and producer construction errors.
pub fn decode_init_state(&self, bytes: &[u8]) -> Result<vgi_rpc::stream::StreamStateKind> {
    let blob: ExchangeBlob = vgi_rpc::stream_codec::bincode_decode(bytes)?;
    let output_schema = ipc::read_schema(&blob.output_schema)?;
    // Empty byte fields encode "absent" (see `exchange_blob`).
    let input_schema = if blob.input_schema.is_empty() {
        None
    } else {
        Some(ipc::read_schema(&blob.input_schema)?)
    };
    let settings = if blob.settings.is_empty() {
        crate::settings::Settings::default()
    } else {
        crate::settings::Settings::parse(&blob.settings)?
    };
    let secrets = if blob.secrets.is_empty() {
        crate::secrets::Secrets::default()
    } else {
        crate::secrets::Secrets::parse(&blob.secrets)?
    };
    let pushdown = if blob.pushdown_filters.is_empty() {
        None
    } else {
        Some(blob.pushdown_filters.clone())
    };
    let mut args = crate::arguments::Arguments::parse(&blob.arguments)?;
    let make_params = |args: crate::arguments::Arguments| ProcessParams {
        output_schema: output_schema.clone(),
        input_schema: input_schema.clone(),
        execution_id: blob.execution_id.clone(),
        init_opaque_data: blob.init_opaque.clone(),
        arguments: args,
        settings: settings.clone(),
        secrets: secrets.clone(),
        auth_principal: None,
        projection_ids: None,
        pushdown_filters: pushdown.clone(),
        join_keys: Vec::new(),
        storage: Some(self.store.clone()),
        order_by_column: None,
        order_by_direction: None,
        order_by_null_order: None,
        order_by_limit: None,
        tablesample_percentage: None,
        tablesample_seed: None,
        attach_opaque_data: None,
        at_unit: Some(blob.at_unit.clone()).filter(|s| !s.is_empty()),
        at_value: Some(blob.at_value.clone()).filter(|s| !s.is_empty()),
    };
    if blob.kind == "table" {
        let f = self.resolve_table(&blob.function_name, &args, input_schema.as_ref())?;
        args.remap_positional(&f.argument_specs());
        let params = make_params(args);
        let filters = if blob.auto_apply {
            params
                .pushdown_filters
                .as_ref()
                .map(|b| {
                    crate::pushdown::PushdownFilters::parse_with_join_keys(b, &params.join_keys)
                })
                .transpose()?
        } else {
            None
        };
        let project_to = Some(output_schema.clone());
        let mut producer = f.producer(&params)?;
        producer.restore_resume(&blob.inner_resume);
        return Ok(vgi_rpc::stream::StreamStateKind::Producer(Box::new(
            TableProducerState {
                inner: producer,
                filters,
                project_to,
                resume_blob: Some(bytes.to_vec()),
            },
        )));
    }
    if blob.kind == "table_in_out" {
        let f = self.resolve_table_in_out(&blob.function_name, &args, input_schema.as_ref())?;
        args.remap_positional(&f.argument_specs());
        let params = make_params(args);
        let filters = if blob.auto_apply {
            params
                .pushdown_filters
                .as_ref()
                .map(|b| {
                    crate::pushdown::PushdownFilters::parse_with_join_keys(b, &params.join_keys)
                })
                .transpose()?
        } else {
            None
        };
        Ok(vgi_rpc::stream::StreamStateKind::Exchange(Box::new(
            TableInOutExchangeState {
                func: f,
                params,
                filters,
                blob: bytes.to_vec(),
            },
        )))
    } else {
        let f = self.resolve_scalar(&blob.function_name, &args, input_schema.as_ref())?;
        args.remap_positional(&f.argument_specs());
        let params = make_params(args);
        Ok(vgi_rpc::stream::StreamStateKind::Exchange(Box::new(
            ScalarExchangeState {
                func: f,
                params,
                blob: bytes.to_vec(),
            },
        )))
    }
}
/// Primary-catalog attach identity: the UTF-8 bytes of `catalog_name`.
fn attach_bytes(&self) -> Vec<u8> {
    self.catalog_name.clone().into_bytes()
}
/// Selects the catalog a request targets.
///
/// If the request's `attach_opaque_data` decodes as secondary-catalog opaque
/// data naming a registered secondary catalog, that catalog is returned.
/// Otherwise the primary catalog is returned.
fn active_catalog<'a>(&'a self, req: &Request) -> &'a catalog::CatalogModel {
    let opaque = read_binary_col(req, "attach_opaque_data");
    let secondary_name = opaque
        .as_deref()
        .and_then(decode_secondary_opaque)
        .map(|(name, _)| name);
    secondary_name
        .and_then(|name| self.secondary.iter().find(|c| c.name == name))
        .unwrap_or(&self.catalog)
}
/// Lists the schema names of `cat` in declaration order.
///
/// The main schema is prepended when `cat` does not declare it explicitly.
fn catalog_schema_names(cat: &catalog::CatalogModel) -> Vec<String> {
    let has_main = cat.schemas.iter().any(|s| s.name == catalog::MAIN_SCHEMA);
    let implicit_main = (!has_main).then(|| catalog::MAIN_SCHEMA.to_string());
    implicit_main
        .into_iter()
        .chain(cat.schemas.iter().map(|s| s.name.clone()))
        .collect()
}
/// Lists the serialized info of the primary catalog followed by every
/// secondary catalog, in registration order.
pub fn handle_catalog_catalogs(&self, _req: &Request) -> Result<Option<RecordBatch>> {
    let items = std::iter::once(&self.catalog)
        .chain(self.secondary.iter())
        .map(|model| Ok(Bytes::from(catalog::serialize_catalog_info(model)?)))
        .collect::<Result<Vec<_>>>()?;
    Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
pub fn handle_catalog_attach(&self, req: &Request) -> Result<Option<RecordBatch>> {
let dto: CatalogAttachRequest = boxed(req)?;
if let Some(sec) = self.secondary.iter().find(|c| c.name == dto.name) {
let scope = self.next_execution_id();
let result = CatalogAttachResult {
attach_opaque_data: Bytes::from(encode_secondary_opaque(&sec.name, &scope)),
supports_transactions: true,
supports_time_travel: sec.supports_time_travel,
catalog_version_frozen: false,
catalog_version: 1,
attach_opaque_data_required: true,
default_schema: catalog::MAIN_SCHEMA.to_string(),
settings: Vec::new(),
secret_types: Vec::new(),
comment: sec.comment.clone(),
tags: sec.tags.clone(),
supports_column_statistics: false,
resolved_data_version: sec.data_version_spec.clone(),
resolved_implementation_version: sec.implementation_version.clone(),
};
return Ok(Some(wire::to_result_batch(result)?));
}
let (resolved_data_version, resolved_implementation_version) =
self.resolve_versions(&dto)?;
let attach_opaque_data =
if let Some(default_bytes) = &self.catalog.attach_options_default_batch {
let default_batch = ipc::read_batch(default_bytes)?;
let options = dto
.options
.as_ref()
.map(|b| ipc::read_batch(&b.0))
.transpose()?;
let cols: Vec<arrow_array::ArrayRef> = default_batch
.schema()
.fields()
.iter()
.enumerate()
.map(|(i, f)| -> Result<arrow_array::ArrayRef> {
match options.as_ref().and_then(|o| o.column_by_name(f.name())) {
Some(c) => arrow_cast::cast(c, f.data_type())
.map_err(|e| RpcError::runtime_error(e.to_string())),
None => Ok(default_batch.column(i).clone()),
}
})
.collect::<Result<_>>()?;
let merged = RecordBatch::try_new(default_batch.schema(), cols)
.map_err(|e| RpcError::runtime_error(e.to_string()))?;
let id = self.attach_bytes();
let mut v: Vec<u8> = id
.iter()
.copied()
.chain(std::iter::repeat(0))
.take(16)
.collect();
v.push(0);
v.extend_from_slice(&ipc::write_batch(&merged)?);
v
} else if !self.catalog.version_schemas.is_empty() {
let mut v = resolved_data_version
.clone()
.unwrap_or_default()
.into_bytes();
v.push(0);
v.extend_from_slice(&self.attach_bytes());
v
} else if dto.name == PROJ_REPRO_APP {
PROJ_REPRO_APP.as_bytes().to_vec()
} else {
self.attach_bytes()
};
let result = CatalogAttachResult {
attach_opaque_data: Bytes::from(attach_opaque_data),
supports_transactions: true,
supports_time_travel: self.catalog.supports_time_travel,
catalog_version_frozen: false,
catalog_version: 1,
attach_opaque_data_required: true,
default_schema: catalog::MAIN_SCHEMA.to_string(),
settings: self
.settings
.iter()
.map(|s| Ok(Bytes::from(catalog::serialize_setting(s)?)))
.collect::<Result<Vec<_>>>()?,
secret_types: self
.secret_types
.iter()
.map(|s| Ok(Bytes::from(catalog::serialize_secret_type(s)?)))
.collect::<Result<Vec<_>>>()?,
comment: self.catalog.comment.clone(),
tags: self.catalog.tags.clone(),
supports_column_statistics: self
.catalog
.schemas
.iter()
.flat_map(|s| &s.tables)
.any(|t| !t.statistics.is_empty()),
resolved_data_version,
resolved_implementation_version,
};
Ok(Some(wire::to_result_batch(result)?))
}
/// Resolves the `(data_version, implementation_version)` pair for a
/// primary-catalog attach.
///
/// With npm resolution enabled and a non-empty supported list, each version
/// goes through `resolve_version_npm`. Otherwise:
/// * an explicit implementation version must equal the one served, and the
///   served one is returned;
/// * an explicit data version must appear in the supported list. When none is
///   requested, the catalog default is returned. When the catalog supports no
///   data versions at all, the data version is `None`.
///
/// # Errors
/// Value errors for unsupported requested versions, or whatever
/// `resolve_version_npm` reports.
fn resolve_versions(
    &self,
    dto: &CatalogAttachRequest,
) -> Result<(Option<String>, Option<String>)> {
    let cat = &self.catalog;
    let npm_for_impl =
        cat.npm_version_resolution && !cat.supported_implementation_versions.is_empty();
    let resolved_impl = if npm_for_impl {
        Some(catalog::resolve_version_npm(
            dto.implementation_version.as_deref(),
            &cat.supported_implementation_versions,
            cat.implementation_version.as_deref().unwrap_or(""),
            "implementation_version",
        )?)
    } else {
        if let (Some(req), Some(have)) = (&dto.implementation_version, &cat.implementation_version)
        {
            if req != have {
                return Err(RpcError::value_error(format!(
                    "Unsupported implementation_version {req:?}; this worker serves {have:?}"
                )));
            }
        }
        cat.implementation_version.clone()
    };
    let resolved_data = if cat.supported_data_versions.is_empty() {
        None
    } else if cat.npm_version_resolution {
        Some(catalog::resolve_version_npm(
            dto.data_version_spec.as_deref(),
            &cat.supported_data_versions,
            cat.default_data_version.as_deref().unwrap_or(""),
            "data_version_spec",
        )?)
    } else {
        match &dto.data_version_spec {
            Some(req) if !cat.supported_data_versions.contains(req) => {
                return Err(RpcError::value_error(format!(
                    "Unsupported data_version_spec {req:?}; this worker serves one of {:?}",
                    cat.supported_data_versions
                )));
            }
            Some(req) => Some(req.clone()),
            None => cat.default_data_version.clone(),
        }
    };
    Ok((resolved_data, resolved_impl))
}
/// Extracts the data version embedded in a request's attach opaque data.
///
/// This applies only when the primary catalog has versioned schemas. The
/// version is the UTF-8 prefix before the first `0` byte; an empty prefix,
/// a missing separator or invalid UTF-8 yields `None`.
fn req_version(&self, req: &Request) -> Option<String> {
    if self.catalog.version_schemas.is_empty() {
        return None;
    }
    let bytes = read_binary_col(req, "attach_opaque_data")?;
    match bytes.iter().position(|&b| b == 0) {
        None | Some(0) => None,
        Some(sep) => String::from_utf8(bytes[..sep].to_vec()).ok(),
    }
}
/// Looks up schema `name` in the catalog the request targets.
///
/// For the primary catalog, the schema set for the request's data version
/// (`schemas_for`) is searched. Secondary catalogs use their plain schema list.
fn schema_for_req<'a>(&'a self, req: &Request, name: &str) -> Option<&'a catalog::CatSchema> {
    let cat = self.active_catalog(req);
    let is_primary = std::ptr::eq(cat, &self.catalog);
    if !is_primary {
        return cat.schemas.iter().find(|s| s.name == name);
    }
    let version = self.req_version(req);
    self.catalog
        .schemas_for(version.as_deref())
        .iter()
        .find(|s| s.name == name)
}
/// Reports the catalog version, which is always `1`.
pub fn handle_catalog_version(&self, _req: &Request) -> Result<Option<RecordBatch>> {
    let result = CatalogVersionResult { version: 1 };
    let batch = wire::to_result_batch(result)?;
    Ok(Some(batch))
}
/// Begins a transaction, returning a freshly generated id as its opaque data.
pub fn handle_transaction_begin(&self, _req: &Request) -> Result<Option<RecordBatch>> {
    let txn_id = self.next_execution_id();
    let result = CatalogTransactionBeginResult {
        transaction_opaque_data: Some(Bytes::from(txn_id)),
    };
    let batch = wire::to_result_batch(result)?;
    Ok(Some(batch))
}
/// Builds the `SchemaInfo` for schema `name` of catalog `cat`.
///
/// The comment is the schema's own one, falling back to a fixed description
/// for the main schema. Only the primary catalog without versioned schemas
/// gets estimated object counts. Function counts (number of registered names
/// per kind) are attributed to the main schema only.
fn schema_info_for(&self, cat: &catalog::CatalogModel, name: &str) -> SchemaInfo {
    let is_main = name == catalog::MAIN_SCHEMA;
    let schema = cat.schema(name);
    let fallback_comment = if is_main {
        Some("Default schema containing all registered functions")
    } else {
        None
    };
    let comment = schema.and_then(|s| s.comment.as_deref()).or(fallback_comment);
    let is_primary = std::ptr::eq(cat, &self.catalog);
    let attach: Vec<u8> = if is_primary {
        self.attach_bytes()
    } else {
        cat.name.as_bytes().to_vec()
    };
    let mut info = catalog::schema_info(name, comment, &attach);
    if is_primary && self.catalog.version_schemas.is_empty() {
        let count = |n: usize| n as i64;
        let (scalar_fns, aggregate_fns, table_fns) = if is_main {
            (
                count(self.scalars.len()),
                count(self.aggregates.len()),
                count(self.tables.len() + self.tableinouts.len() + self.buffering.len()),
            )
        } else {
            (0, 0, 0)
        };
        let views = count(schema.map_or(0, |s| s.views.len()));
        let macros = count(schema.map_or(0, |s| s.macros.len()));
        let tables = count(schema.map_or(0, |s| s.tables.len()));
        info.estimated_object_count = Some(vec![
            ("view".into(), views),
            ("macro".into(), macros),
            ("table".into(), tables),
            ("scalar_function".into(), scalar_fns),
            ("aggregate_function".into(), aggregate_fns),
            ("table_function".into(), table_fns),
            ("index".into(), 0),
        ]);
    }
    info
}
/// Lists the `SchemaInfo` of every schema in the catalog the request targets.
pub fn handle_catalog_schemas(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let cat = self.active_catalog(req);
    let mut infos: Vec<SchemaInfo> = Vec::new();
    for schema_name in Self::catalog_schema_names(cat) {
        infos.push(self.schema_info_for(cat, &schema_name));
    }
    let items = catalog::serialize_items(infos)?;
    Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
/// Returns the `SchemaInfo` for one schema of the targeted catalog, or an
/// empty item list when the schema does not exist.
pub fn handle_schema_get(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let params: CatalogSchemaNameParams = wire::from_batch(&req.batch)?;
    let cat = self.active_catalog(req);
    let exists = Self::catalog_schema_names(cat)
        .iter()
        .any(|n| n == &params.name);
    let items = if !exists {
        Vec::new()
    } else {
        let info = self.schema_info_for(cat, &params.name);
        catalog::serialize_items(vec![info])?
    };
    Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
/// Lists the views of the requested schema; an unknown schema yields no items.
pub fn handle_contents_views(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let name = read_string_col(req, "name")?;
    let mut infos: Vec<ViewInfo> = Vec::new();
    if let Some(schema) = self.schema_for_req(req, &name) {
        infos.extend(schema.views.iter().map(|v| catalog::view_info(&name, v)));
    }
    let items = catalog::serialize_items(infos)?;
    Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
/// Lists the tables of the requested schema; an unknown schema yields no items.
///
/// # Errors
/// Fails on the first table whose `table_info` cannot be built.
pub fn handle_contents_tables(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let name = read_string_col(req, "name")?;
    let infos: Vec<TableInfo> = self
        .schema_for_req(req, &name)
        .map(|s| {
            s.tables
                .iter()
                .map(|t| catalog::table_info(&name, t))
                .collect::<Result<Vec<TableInfo>>>()
        })
        .transpose()?
        .unwrap_or_default();
    let items = catalog::serialize_items(infos)?;
    Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
/// Returns the `TableInfo` of one table, resolved at the optional AT
/// version; an unknown schema or table yields no items.
pub fn handle_table_get(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let schema_name = read_string_col(req, "schema_name")?;
    let table_name = read_string_col(req, "name")?;
    let at_unit = read_opt_string_col(req, "at_unit");
    let at_value = read_opt_string_col(req, "at_value");
    let found = self
        .schema_for_req(req, &schema_name)
        .and_then(|s| s.tables.iter().find(|t| t.name == table_name));
    let mut infos: Vec<TableInfo> = Vec::with_capacity(1);
    if let Some(table) = found {
        let versioned = Self::at_version(table, at_unit.as_deref(), at_value.as_deref())?;
        infos.push(catalog::table_info(&schema_name, &versioned)?);
    }
    let items = catalog::serialize_items(infos)?;
    Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
/// Returns the scan-function descriptor of a table, resolved at the optional AT version.
///
/// # Errors
/// `Unknown table: '<schema>.<table>'` when the schema or table does not
/// exist, plus version-resolution and serialization errors.
pub fn handle_table_scan_function_get(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let schema_name = read_string_col(req, "schema_name")?;
    let table_name = read_string_col(req, "name")?;
    let at_unit = read_opt_string_col(req, "at_unit");
    let at_value = read_opt_string_col(req, "at_value");
    let Some(table) = self
        .schema_for_req(req, &schema_name)
        .and_then(|s| s.tables.iter().find(|t| t.name == table_name))
    else {
        return Err(RpcError::value_error(format!(
            "Unknown table: '{schema_name}.{table_name}'"
        )));
    };
    let versioned = Self::at_version(table, at_unit.as_deref(), at_value.as_deref())?;
    let result = catalog::scan_function_result(&versioned)?;
    Ok(Some(wire::to_result_batch(result)?))
}
/// Returns a copy of `t` as seen at the requested AT version.
///
/// When `resolve_version` yields a version, its columns, scan function and
/// scan arguments replace the table's. For a non-current version, the
/// not-null, primary-key, unique, check and foreign-key constraints are also
/// cleared. Without a resolved version the table is returned unchanged.
fn at_version(
    t: &catalog::CatTable,
    at_unit: Option<&str>,
    at_value: Option<&str>,
) -> Result<catalog::CatTable> {
    let Some(version) = t.resolve_version(at_unit, at_value)? else {
        return Ok(t.clone());
    };
    let mut table = t.clone();
    table.columns = version.columns.clone();
    table.scan_function = version.scan_function.clone();
    table.scan_arguments = version.scan_arguments.clone();
    let historical = !t.is_current_version(version.version);
    if historical {
        // Presumably the constraints describe only the current schema.
        table.not_null.clear();
        table.primary_key.clear();
        table.unique.clear();
        table.check.clear();
        table.foreign_keys.clear();
    }
    Ok(table)
}
/// Reports the cardinality estimate of a bound table-function call.
///
/// The bind call is decoded from the request, bound, and passed to the first
/// registered overload's `cardinality`. Missing information (unknown function,
/// or no estimate/max) is reported as `-1` in the respective field.
pub fn handle_table_function_cardinality(
    &self,
    req: &Request,
    ctx: &CallContext,
) -> Result<Option<RecordBatch>> {
    let dto: CardinalityRequest = boxed(req)?;
    let bind_batch = ipc::read_batch(&dto.bind_call.0)?;
    let bind_call: BindRequest = wire::from_batch(&bind_batch)?;
    let bound = self.bind_params(&bind_call, ctx)?;
    let card = match self
        .tables
        .get(&bind_call.function_name)
        .and_then(|overloads| overloads.first())
    {
        Some(func) => func.cardinality(&bound),
        None => None,
    };
    let estimate = card.and_then(|c| c.estimate).unwrap_or(-1);
    let max = card.and_then(|c| c.max).unwrap_or(-1);
    let response = crate::protocol::dtos::CardinalityResponse {
        estimate: Some(estimate),
        max: Some(max),
    };
    Ok(Some(wire::to_result_batch(response)?))
}
/// Returns the key/value pairs the first overload of the bound table function
/// reports via `dynamic_to_string` for the request's global execution id.
///
/// An unknown function yields empty `keys` / `values`.
pub fn handle_table_function_dynamic_to_string(
    &self,
    req: &Request,
) -> Result<Option<RecordBatch>> {
    use crate::protocol::dtos::{DynamicToStringRequest, DynamicToStringResponse};
    let dto: DynamicToStringRequest = boxed(req)?;
    let bind_batch = ipc::read_batch(&dto.bind_call.0)?;
    let bind_call: BindRequest = wire::from_batch(&bind_batch)?;
    let pairs = match self
        .tables
        .get(&bind_call.function_name)
        .and_then(|overloads| overloads.first())
    {
        Some(func) => func.dynamic_to_string(&dto.global_execution_id.0, self.store.as_ref()),
        None => Default::default(),
    };
    let mut keys: Vec<String> = Vec::new();
    let mut values: Vec<String> = Vec::new();
    for (key, value) in pairs {
        keys.push(key);
        values.push(value);
    }
    Ok(Some(wire::to_result_batch(DynamicToStringResponse { keys, values })?))
}
/// Returns the serialized column statistics that the first overload of the
/// bound table function reports.
///
/// An unknown function, or one without statistics, yields the default
/// (empty) statistics.
pub fn handle_table_function_statistics(
    &self,
    req: &Request,
    ctx: &CallContext,
) -> Result<Option<RecordBatch>> {
    let dto: CardinalityRequest = boxed(req)?;
    let bind_batch = ipc::read_batch(&dto.bind_call.0)?;
    let bind_call: BindRequest = wire::from_batch(&bind_batch)?;
    let bound = self.bind_params(&bind_call, ctx)?;
    let first = self
        .tables
        .get(&bind_call.function_name)
        .and_then(|overloads| overloads.first());
    let stats = first
        .and_then(|func| func.statistics(&bound))
        .unwrap_or_default();
    let encoded = crate::statistics::serialize_column_statistics(&stats)?;
    Ok(Some(wire::result_batch_from_bytes(&encoded)?))
}
/// Returns the stored column statistics of table `schema_name.name`,
/// serialized with `serialize_column_statistics`.
///
/// An unknown schema or table yields the default (empty) statistics.
///
/// The schema is resolved through `schema_for_req`, the same way as every
/// other table handler (`handle_table_get`, `handle_table_scan_function_get`,
/// ...). It previously read `self.catalog` directly, so a request aimed at a
/// secondary catalog (see `active_catalog`) got the primary catalog's
/// statistics or none at all.
pub fn handle_table_column_statistics_get(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let schema_name = read_string_col(req, "schema_name")?;
    let table_name = read_string_col(req, "name")?;
    let stats = self
        .schema_for_req(req, &schema_name)
        .and_then(|s| s.tables.iter().find(|t| t.name == table_name))
        .map(|t| t.statistics.clone())
        .unwrap_or_default();
    let bytes = crate::statistics::serialize_column_statistics(&stats)?;
    Ok(Some(wire::result_batch_from_bytes(&bytes)?))
}
/// Describes the scan branches of one table at the optional
/// `at_unit` / `at_value` version.
///
/// Each branch is IPC-encoded individually. A table without explicit branch
/// definitions yields a single, non-writable, unfiltered branch built from its
/// own scan function and arguments.
///
/// # Errors
/// Returns a value error `Unknown table: 'schema.table'` when the table is not
/// found. Also propagates version-resolution and encoding errors.
pub fn handle_table_scan_branches_get(&self, req: &Request) -> Result<Option<RecordBatch>> {
    use crate::protocol::dtos::{ScanBranch, ScanBranchesResult};
    let schema_name = read_string_col(req, "schema_name")?;
    let table_name = read_string_col(req, "name")?;
    let at_unit = read_opt_string_col(req, "at_unit");
    let at_value = read_opt_string_col(req, "at_value");
    let Some(base) = self
        .schema_for_req(req, &schema_name)
        .and_then(|schema| schema.tables.iter().find(|tbl| tbl.name == table_name))
    else {
        return Err(RpcError::value_error(format!(
            "Unknown table: '{schema_name}.{table_name}'"
        )));
    };
    let table = Self::at_version(base, at_unit.as_deref(), at_value.as_deref())?;
    let encode = |branch: ScanBranch| -> Result<Bytes> {
        let batch = wire::to_batch(branch)?;
        Ok(Bytes::from(ipc::write_batch(&batch)?))
    };
    let branches: Vec<Bytes> = match table.branches.as_ref() {
        None => vec![encode(ScanBranch {
            function_name: table.scan_function.clone(),
            arguments: Bytes::from(table.scan_arguments.clone()),
            branch_filter: None,
            writable: false,
        })?],
        Some(defs) => {
            let mut encoded = Vec::new();
            for def in defs.iter() {
                encoded.push(encode(ScanBranch {
                    function_name: def.function_name.clone(),
                    arguments: Bytes::from(def.scan_arguments.clone()),
                    branch_filter: def.branch_filter.clone(),
                    writable: def.writable,
                })?);
            }
            encoded
        }
    };
    let result = ScanBranchesResult {
        branches,
        required_extensions: Vec::new(),
    };
    Ok(Some(wire::to_result_batch(result)?))
}
/// Lists the macros of schema `name`, optionally filtered by the `type` column.
///
/// A `type` that normalizes to `"table"` keeps only table macros. One that
/// normalizes to `"scalar"` keeps only scalar macros. Anything else (including
/// a missing column) keeps all macros.
pub fn handle_contents_macros(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let schema_name = read_string_col(req, "name")?;
    let want = normalize_function_type(&read_string_col(req, "type").unwrap_or_default());
    let keep_macro = |is_table_macro: bool| match want.as_deref() {
        Some("table") => is_table_macro,
        Some("scalar") => !is_table_macro,
        _ => true,
    };
    let mut infos: Vec<MacroInfo> = Vec::new();
    if let Some(schema) = self.schema_for_req(req, &schema_name) {
        for mac in schema.macros.iter().filter(|m| keep_macro(m.table_macro)) {
            infos.push(catalog::macro_info(&schema_name, mac));
        }
    }
    let items = catalog::serialize_items(infos)?;
    Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
pub fn handle_contents_functions(&self, req: &Request) -> Result<Option<RecordBatch>> {
let schema_name = read_string_col(req, "name")?;
let type_filter = read_string_col(req, "type").unwrap_or_default();
let is_proj_repro = read_binary_col(req, "attach_opaque_data")
.map(|b| b == PROJ_REPRO_APP.as_bytes())
.unwrap_or(false);
let active = self.active_catalog(req);
let active_sec_fns: Option<&[String]> = self
.secondary
.iter()
.position(|c| std::ptr::eq(c, active))
.and_then(|i| self.secondary_functions.get(i))
.map(|v| v.as_slice());
let all_sec_fns: std::collections::HashSet<&str> = self
.secondary_functions
.iter()
.flatten()
.map(|s| s.as_str())
.collect();
let visible = |name: &str| {
if name.starts_with(PROJ_REPRO_PREFIX) != is_proj_repro {
return false;
}
match active_sec_fns {
Some(fns) => fns.iter().any(|f| f == name),
None => !all_sec_fns.contains(name),
}
};
let mut infos = Vec::new();
if schema_name == catalog::MAIN_SCHEMA {
let want = normalize_function_type(&type_filter);
if want.as_deref() == Some("scalar") || want.is_none() {
let mut names: Vec<&String> = self.scalars.keys().filter(|n| visible(n)).collect();
names.sort();
for name in names {
for f in &self.scalars[name] {
infos.push(catalog::scalar_function_info(f.as_ref())?);
}
}
}
if matches!(want.as_deref(), Some("table") | Some("table_buffering")) || want.is_none()
{
let mut names: Vec<&String> = self.tables.keys().filter(|n| visible(n)).collect();
names.sort();
for name in names {
for f in &self.tables[name] {
infos.push(catalog::table_function_info(f.as_ref())?);
}
}
let mut tio: Vec<&String> =
self.tableinouts.keys().filter(|n| visible(n)).collect();
tio.sort();
for name in tio {
for f in &self.tableinouts[name] {
infos.push(catalog::table_in_out_function_info(f.as_ref())?);
}
}
let mut buf: Vec<&String> = self.buffering.keys().filter(|n| visible(n)).collect();
buf.sort();
for name in buf {
for f in &self.buffering[name] {
infos.push(catalog::buffering_function_info(f.as_ref())?);
}
}
}
if matches!(want.as_deref(), Some("aggregate")) || want.is_none() {
let mut agg: Vec<&String> = self.aggregates.keys().filter(|n| visible(n)).collect();
agg.sort();
for name in agg {
for f in &self.aggregates[name] {
infos.push(catalog::aggregate_function_info(f.as_ref())?);
}
}
}
}
let items = catalog::serialize_items(infos)?;
Ok(Some(wire::to_result_batch(ItemsResult { items })?))
}
/// Resolves the bound output schema of a table-buffering execution.
///
/// Lookup order:
/// 1. The schema cached under `outsc` for `execution_id`.
/// 2. Otherwise, re-bind `f` against an input schema. That schema is the
///    supplied `input_schema`, or else the one cached under `insc`. The bind
///    uses the stored arguments (`bufargs`) and attach data (`bufattach`).
///
/// Stored schemas that fail to decode are treated as absent.
///
/// # Errors
/// Returns a runtime error when neither a cached output schema nor any input
/// schema is available. Also propagates errors from `on_bind`.
fn buffering_output_schema(
    &self,
    execution_id: &[u8],
    f: &dyn TableBufferingFunction,
    input_schema: Option<arrow_schema::SchemaRef>,
) -> Result<arrow_schema::SchemaRef> {
    let load_schema = |key: &[u8]| {
        self.store
            .kv_get(execution_id, key)
            .and_then(|bytes| ipc::read_schema(&bytes).ok())
    };
    if let Some(cached) = load_schema(b"outsc".as_slice()) {
        return Ok(cached);
    }
    let Some(input) = input_schema.or_else(|| load_schema(b"insc".as_slice())) else {
        return Err(RpcError::runtime_error(
            "table-buffering: bound output schema unavailable (sink-init state \
             not found on this worker and no input schema to rebind from)"
                .to_string(),
        ));
    };
    let bound = f.on_bind(&BindParams {
        input_schema: Some(input),
        arguments: self.buffering_arguments(execution_id, f),
        attach_opaque_data: self.store.kv_get(execution_id, b"bufattach"),
        storage: Some(self.store.clone()),
        ..Default::default()
    })?;
    Ok(bound.output_schema)
}
/// Loads the arguments stored under `bufargs` for `execution_id` and remaps
/// positional arguments onto `f`'s argument specs.
///
/// Missing or unparseable stored arguments fall back to
/// `Arguments::default()`.
fn buffering_arguments(
    &self,
    execution_id: &[u8],
    f: &dyn TableBufferingFunction,
) -> crate::arguments::Arguments {
    let mut args = match self.store.kv_get(execution_id, b"bufargs") {
        Some(raw) => crate::arguments::Arguments::parse(&raw).unwrap_or_default(),
        None => crate::arguments::Arguments::default(),
    };
    let specs = f.argument_specs();
    args.remap_positional(&specs);
    args
}
pub fn handle_buffering_process(
&self,
req: &Request,
ctx: &CallContext,
) -> Result<Option<RecordBatch>> {
let dto: TableBufferingProcessRequest = boxed(req)?;
let f = self.resolve_buffering(&dto.function_name)?;
let batch = ipc::read_batch(&dto.input_batch.0)?;
let output_schema =
self.buffering_output_schema(&dto.execution_id.0, f.as_ref(), Some(batch.schema()))?;
let logs = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let params = BufferingParams {
execution_id: dto.execution_id.0.clone(),
storage: self.store.clone(),
output_schema,
arguments: self.buffering_arguments(&dto.execution_id.0, f.as_ref()),
settings: crate::settings::Settings::default(),
attach_opaque_data: self.store.kv_get(&dto.execution_id.0, b"bufattach"),
batch_index: dto.batch_index,
logs: logs.clone(),
};
let state_id = f.process(¶ms, &batch)?;
Self::drain_buffering_logs(&logs, ctx);
Ok(Some(wire::to_result_batch(
TableBufferingProcessResponse {
state_id: Bytes::from(state_id),
},
)?))
}
/// Forwards every buffered log line to the client at `Info` level and empties
/// the buffer.
///
/// If the mutex is poisoned, nothing is sent.
fn drain_buffering_logs(
    logs: &std::sync::Arc<std::sync::Mutex<Vec<String>>>,
    ctx: &CallContext,
) {
    let Ok(mut pending) = logs.lock() else {
        return;
    };
    for line in pending.drain(..) {
        ctx.client_log(vgi_rpc::LogLevel::Info, line);
    }
}
pub fn handle_buffering_combine(
&self,
req: &Request,
ctx: &CallContext,
) -> Result<Option<RecordBatch>> {
let dto: TableBufferingCombineRequest = boxed(req)?;
let f = self.resolve_buffering(&dto.function_name)?;
let output_schema = self.buffering_output_schema(&dto.execution_id.0, f.as_ref(), None)?;
let logs = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let params = BufferingParams {
execution_id: dto.execution_id.0.clone(),
storage: self.store.clone(),
output_schema,
arguments: self.buffering_arguments(&dto.execution_id.0, f.as_ref()),
settings: crate::settings::Settings::default(),
attach_opaque_data: self.store.kv_get(&dto.execution_id.0, b"bufattach"),
batch_index: None,
logs: logs.clone(),
};
let state_ids: Vec<Vec<u8>> = dto.state_ids.into_iter().map(|b| b.0).collect();
let finalize_ids = f.combine(¶ms, &state_ids)?;
Self::drain_buffering_logs(&logs, ctx);
Ok(Some(wire::to_result_batch(
TableBufferingCombineResponse {
finalize_state_ids: finalize_ids.into_iter().map(Bytes::from).collect(),
},
)?))
}
/// Clears all stored state for a table-buffering execution. Returns no batch.
pub fn handle_buffering_destructor(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let request: TableBufferingDestructorRequest = boxed(req)?;
    let execution_id = request.execution_id.0;
    self.store.clear(&execution_id);
    Ok(None)
}
/// Storage key of an aggregate group's state: the group id as 8 little-endian bytes.
fn agg_key(gid: i64) -> Vec<u8> {
    Vec::from(gid.to_le_bytes())
}
/// Binds an aggregate function and allocates an execution id for it.
///
/// The raw argument bytes are stored under `aggargs` so that
/// `handle_aggregate_finalize` can re-parse them later. The response carries
/// the bound output schema and the new execution id.
pub fn handle_aggregate_bind(
    &self,
    req: &Request,
    _ctx: &CallContext,
) -> Result<Option<RecordBatch>> {
    let dto: AggregateBindRequest = boxed(req)?;
    let mut arguments = crate::arguments::Arguments::parse(&dto.arguments.0)?;
    let input_schema = opt_schema(&dto.input_schema)?;
    let func = self.resolve_aggregate(&dto.function_name)?;
    arguments.remap_positional(&func.argument_specs());
    let settings = parse_settings(&dto.settings)?;
    let bound = func.on_bind(&AggregateBindParams {
        arguments,
        input_schema,
        settings,
    })?;
    let execution_id = self.next_execution_id();
    self.store.kv_put(&execution_id, b"aggargs", &dto.arguments.0);
    let response = AggregateBindResponse {
        output_schema: Bytes::from(ipc::write_schema_ref(&bound.output_schema)?),
        execution_id: Bytes::from(execution_id),
    };
    Ok(Some(wire::to_result_batch(response)?))
}
/// Folds one input batch into the stored per-group aggregate states.
///
/// Steps:
/// 1. Load the existing state of every group id that appears in the batch.
/// 2. Let the function update the states.
/// 3. Write every resulting state back under `agg_key(gid)`.
pub fn handle_aggregate_update(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let dto: AggregateUpdateRequest = boxed(req)?;
    let func = self.resolve_aggregate(&dto.function_name)?;
    let input = ipc::read_batch(&dto.input_batch.0)?;
    let (group_ids, value_columns) = split_group_ids(&input)?;
    let exec = &dto.execution_id.0;
    let mut states: HashMap<i64, Vec<u8>> = HashMap::new();
    for &gid in group_ids.values().iter() {
        if states.contains_key(&gid) {
            continue;
        }
        if let Some(saved) = self.store.kv_get(exec, &Self::agg_key(gid)) {
            states.insert(gid, saved);
        }
    }
    func.update(&mut states, &group_ids, &value_columns)?;
    for (gid, state) in states {
        self.store.kv_put(exec, &Self::agg_key(gid), &state);
    }
    Ok(Some(wire::empty_result_batch()?))
}
/// Merges aggregate states pairwise.
///
/// For each row of the merge batch, the state stored under `source_group_id`
/// is combined into the state stored under `target_group_id`. The result is
/// written back to the target key.
///
/// The id columns are looked up by name. If a name is missing, the column is
/// taken by position: column 0 is the source, column 1 the target. Both must
/// be `Int64`.
///
/// Per row:
/// - neither side has a stored state: the row is skipped;
/// - only one side has a state: that state is written to the target unchanged;
/// - both sides have a state: `combine(target, source)` is written to the target.
///
/// # Errors
/// Returns a type error when either id column is missing or is not `Int64`.
/// Also propagates errors from `combine` and request decoding.
pub fn handle_aggregate_combine(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let dto: AggregateCombineRequest = boxed(req)?;
    let f = self.resolve_aggregate(&dto.function_name)?;
    let batch = ipc::read_batch(&dto.merge_batch.0)?;
    // The positional fallback uses `columns().first()` instead of `column(0)`,
    // which panics on a zero-column batch. This matches the `.get(1)` used for
    // the target, so a malformed batch becomes an error rather than a panic.
    let src = batch
        .column_by_name("source_group_id")
        .or_else(|| batch.columns().first())
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>())
        .ok_or_else(|| RpcError::type_error("combine: source_group_id"))?
        .clone();
    let tgt = batch
        .column_by_name("target_group_id")
        .or_else(|| batch.columns().get(1))
        .and_then(|c| c.as_any().downcast_ref::<Int64Array>())
        .ok_or_else(|| RpcError::type_error("combine: target_group_id"))?
        .clone();
    let exec = &dto.execution_id.0;
    // Both columns belong to the same RecordBatch, so they have equal length.
    for (&s, &t) in src.values().iter().zip(tgt.values().iter()) {
        let source = self.store.kv_get(exec, &Self::agg_key(s));
        let target = self.store.kv_get(exec, &Self::agg_key(t));
        let merged = match (target, source) {
            (None, None) => continue,
            (Some(tgt_state), None) => tgt_state,
            (None, Some(src_state)) => src_state,
            (Some(tgt_state), Some(src_state)) => f.combine(tgt_state, src_state)?,
        };
        self.store.kv_put(exec, &Self::agg_key(t), &merged);
    }
    Ok(Some(wire::empty_result_batch()?))
}
/// Produces the final aggregate values for the requested group ids.
///
/// The group ids are read from column 0 of `group_ids_batch`. A group without
/// a stored state is passed to the function as `None`. The arguments stored
/// at bind time (`aggargs`) are re-parsed and remapped onto the function's
/// argument specs; if they are missing or unparseable, the defaults are used.
///
/// # Errors
/// Returns a type error when the group-id batch has no columns or when its
/// first column is not `Int64`. Also propagates decoding errors and errors
/// from `finalize_with_args`.
pub fn handle_aggregate_finalize(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let dto: AggregateFinalizeRequest = boxed(req)?;
    let f = self.resolve_aggregate(&dto.function_name)?;
    let output_schema = ipc::read_schema(&dto.output_schema.0)?;
    let gid_batch = ipc::read_batch(&dto.group_ids_batch.0)?;
    // `columns().first()` instead of `column(0)`: a zero-column batch from a
    // malformed request must surface as an error, not panic the worker.
    let gids = gid_batch
        .columns()
        .first()
        .ok_or_else(|| RpcError::type_error("finalize: group_ids batch has no columns"))?
        .as_any()
        .downcast_ref::<Int64Array>()
        .ok_or_else(|| RpcError::type_error("finalize: group_ids not int64"))?
        .clone();
    let exec = &dto.execution_id.0;
    let states: Vec<Option<Vec<u8>>> = gids
        .values()
        .iter()
        .map(|&gid| self.store.kv_get(exec, &Self::agg_key(gid)))
        .collect();
    let mut agg_args = self
        .store
        .kv_get(exec, b"aggargs")
        .and_then(|b| crate::arguments::Arguments::parse(&b).ok())
        .unwrap_or_default();
    agg_args.remap_positional(&f.argument_specs());
    let result = f.finalize_with_args(&output_schema, &gids, &states, &agg_args)?;
    Ok(Some(wire::to_result_batch(AggregateFinalizeResponse {
        result_batch: Bytes::from(ipc::write_batch(&result)?),
    })?))
}
/// Clears all stored state for an aggregate execution and acknowledges with
/// an empty result batch.
pub fn handle_aggregate_destructor(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let request: AggregateDestructorRequest = boxed(req)?;
    let execution_id = request.execution_id.0;
    self.store.clear(&execution_id);
    let ack = wire::empty_result_batch()?;
    Ok(Some(ack))
}
/// Storage key for one piece of a window partition, in the form
/// `win_<partition_id>_<suffix>`, as UTF-8 bytes.
fn win_key(partition_id: i64, suffix: &str) -> Vec<u8> {
    let mut key = String::from("win_");
    key.push_str(&partition_id.to_string());
    key.push('_');
    key.push_str(suffix);
    key.into_bytes()
}
/// Stores the inputs of one window partition for later window calls.
///
/// What is stored, per partition:
/// - the partition batch under suffix `p`;
/// - the output schema under suffix `o`;
/// - the filter mask, when one is supplied, under suffix `m`.
pub fn handle_aggregate_window_init(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let dto: AggregateWindowInitRequest = boxed(req)?;
    let exec = dto.execution_id.0.as_slice();
    let partition_id = dto.partition_id;
    let put = |suffix: &str, bytes: &[u8]| {
        self.store
            .kv_put(exec, &Self::win_key(partition_id, suffix), bytes)
    };
    put("p", dto.partition_batch.0.as_slice());
    put("o", dto.output_schema.0.as_slice());
    if let Some(mask) = &dto.filter_mask {
        put("m", mask.0.as_slice());
    }
    Ok(Some(wire::empty_result_batch()?))
}
/// Loads a window partition stored by `handle_aggregate_window_init`.
///
/// Returns the partition batch, its output schema, and the optional
/// per-row filter mask.
///
/// Mask decoding: bit `i % 8` of byte `i / 8` (LSB first) is row `i`. Rows
/// beyond the stored bytes count as selected. An empty stored mask means
/// "no mask".
///
/// # Errors
/// Returns a runtime error when the partition batch or the output schema is
/// not stored. Also propagates IPC decoding errors.
fn load_window_partition(
    &self,
    exec: &[u8],
    partition_id: i64,
) -> Result<(RecordBatch, SchemaRef, Option<Vec<bool>>)> {
    let fetch = |suffix: &str| self.store.kv_get(exec, &Self::win_key(partition_id, suffix));
    let Some(partition_bytes) = fetch("p") else {
        return Err(RpcError::runtime_error(format!(
            "aggregate_window: unknown partition_id={partition_id}"
        )));
    };
    let Some(schema_bytes) = fetch("o") else {
        return Err(RpcError::runtime_error("aggregate_window: missing output schema"));
    };
    let partition = ipc::read_batch(&partition_bytes)?;
    let output_schema = ipc::read_schema(&schema_bytes)?;
    let rows = partition.num_rows();
    let mask: Option<Vec<bool>> = match fetch("m") {
        Some(bits) if !bits.is_empty() => Some(
            (0..rows)
                .map(|i| match bits.get(i / 8) {
                    Some(byte) => (byte & (1 << (i % 8))) != 0,
                    None => true,
                })
                .collect(),
        ),
        _ => None,
    };
    Ok((partition, output_schema, mask))
}
/// Evaluates the window aggregate for one output row of a stored partition.
///
/// The row's frames are the pairs of `frame_starts` / `frame_ends`. If the
/// two lists differ in length, the extra entries are ignored (they are
/// zipped). The result is one column wrapped in a batch with the stored
/// output schema.
pub fn handle_aggregate_window(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let dto: AggregateWindowRequest = boxed(req)?;
    let func = self.resolve_aggregate(&dto.function_name)?;
    let (partition, output_schema, mask) =
        self.load_window_partition(&dto.execution_id.0, dto.partition_id)?;
    let frames: Vec<(i64, i64)> = dto
        .frame_starts
        .iter()
        .copied()
        .zip(dto.frame_ends.iter().copied())
        .collect();
    let column = func.window(&partition, &output_schema, &[frames], mask.as_deref())?;
    let batch = RecordBatch::try_new(output_schema, vec![column])
        .map_err(|e| RpcError::runtime_error(e.to_string()))?;
    let result_batch = Bytes::from(ipc::write_batch(&batch)?);
    Ok(Some(wire::to_result_batch(AggregateWindowResponse { result_batch })?))
}
pub fn handle_aggregate_window_batch(&self, req: &Request) -> Result<Option<RecordBatch>> {
let dto: AggregateWindowBatchRequest = boxed(req)?;
let f = self.resolve_aggregate(&dto.function_name)?;
let (partition, output_schema, mask) =
self.load_window_partition(&dto.execution_id.0, dto.partition_id)?;
let mut frames: Vec<Vec<(i64, i64)>> = Vec::with_capacity(dto.count as usize);
let mut off = 0usize;
for r in 0..dto.count as usize {
let n = dto.frames_per_row.get(r).copied().unwrap_or(0) as usize;
let mut subs = Vec::with_capacity(n);
for _ in 0..n {
let s = dto.frame_starts.get(off).copied().unwrap_or(0);
let e = dto.frame_ends.get(off).copied().unwrap_or(0);
subs.push((s, e));
off += 1;
}
frames.push(subs);
}
let col = f.window(&partition, &output_schema, &frames, mask.as_deref())?;
let batch = RecordBatch::try_new(output_schema, vec![col])
.map_err(|e| RpcError::runtime_error(e.to_string()))?;
Ok(Some(wire::to_result_batch(AggregateWindowResponse {
result_batch: Bytes::from(ipc::write_batch(&batch)?),
})?))
}
/// Deletes the stored batch (`p`), output schema (`o`) and mask (`m`) of one
/// window partition.
pub fn handle_aggregate_window_destructor(&self, req: &Request) -> Result<Option<RecordBatch>> {
    const SUFFIXES: [&str; 3] = ["p", "o", "m"];
    let dto: AggregateWindowDestructorRequest = boxed(req)?;
    let exec = &dto.execution_id.0;
    for suffix in SUFFIXES {
        let key = Self::win_key(dto.partition_id, suffix);
        self.store.kv_del(exec, &key);
    }
    Ok(Some(wire::empty_result_batch()?))
}
/// Serializes a byte-to-byte map in a simple length-prefixed layout.
///
/// Layout: a `u64` LE entry count, then for each entry a `u64` LE key length,
/// the key bytes, a `u64` LE value length, and the value bytes. Entries appear
/// in the map's iteration order.
fn ser_state_map(m: &std::collections::HashMap<Vec<u8>, Vec<u8>>) -> Vec<u8> {
    let put_len = |buf: &mut Vec<u8>, len: usize| {
        buf.extend_from_slice(&(len as u64).to_le_bytes());
    };
    let mut buf = Vec::new();
    put_len(&mut buf, m.len());
    for (key, value) in m.iter() {
        for part in [key, value] {
            put_len(&mut buf, part.len());
            buf.extend_from_slice(part);
        }
    }
    buf
}
/// Decodes the layout written by `ser_state_map` back into a map.
///
/// Layout: a `u64` LE entry count, then for each entry a `u64` LE key length,
/// the key bytes, a `u64` LE value length, and the value bytes.
///
/// Decoding is best-effort. It stops at the first truncated or out-of-range
/// field and returns the entries decoded so far. It never panics on
/// malformed input.
fn de_state_map(b: &[u8]) -> std::collections::HashMap<Vec<u8>, Vec<u8>> {
    let mut m = std::collections::HashMap::new();
    // Copies `n` bytes starting at `*off`, advancing `off` only on success.
    // `checked_add` guards against absurd lengths overflowing the offset.
    let rd = |b: &[u8], off: &mut usize, n: usize| -> Option<Vec<u8>> {
        let end = off.checked_add(n)?;
        let s = b.get(*off..end)?.to_vec();
        *off = end;
        Some(s)
    };
    // Reads a `u64` LE length. `try_from` (not `as usize`) rejects lengths that
    // do not fit `usize` on 32-bit targets instead of silently truncating them
    // into a smaller, wrong length that would mis-frame the following fields.
    let rd_len = |b: &[u8], off: &mut usize| -> Option<usize> {
        let raw = rd(b, off, 8)?;
        let arr: [u8; 8] = raw.try_into().ok()?;
        usize::try_from(u64::from_le_bytes(arr)).ok()
    };
    let mut off = 0usize;
    let Some(count) = rd_len(b, &mut off) else {
        return m;
    };
    for _ in 0..count {
        let Some(kl) = rd_len(b, &mut off) else { break };
        let Some(k) = rd(b, &mut off, kl) else { break };
        let Some(vl) = rd_len(b, &mut off) else { break };
        let Some(v) = rd(b, &mut off, vl) else { break };
        m.insert(k, v);
    }
    m
}
/// Opens a streaming-aggregate session for a known aggregate function.
///
/// The partition/order key counts (as LE bytes) and the output schema are
/// stored under a fresh execution id, which is returned to the caller.
pub fn handle_aggregate_streaming_open(&self, req: &Request) -> Result<Option<RecordBatch>> {
    let dto: AggregateStreamingOpenRequest = boxed(req)?;
    // Validate only: fail early if the function is unknown.
    let _ = self.resolve_aggregate(&dto.function_name)?;
    let execution_id = self.next_execution_id();
    let store = &self.store;
    store.kv_put(&execution_id, b"strm_pkc", &dto.partition_key_count.to_le_bytes());
    store.kv_put(&execution_id, b"strm_okc", &dto.order_key_count.to_le_bytes());
    store.kv_put(&execution_id, b"strm_sos", &dto.output_schema.0);
    let response = AggregateStreamingOpenResponse {
        execution_id: Bytes::from(execution_id),
    };
    Ok(Some(wire::to_result_batch(response)?))
}
pub fn handle_aggregate_streaming_chunk(&self, req: &Request) -> Result<Option<RecordBatch>> {
let dto: AggregateStreamingChunkRequest = boxed(req)?;
let f = self.resolve_aggregate(&dto.function_name)?;
let chunk = ipc::read_batch(&dto.input_batch.0)?;
let pkc = self
.store
.kv_get(&dto.execution_id.0, b"strm_pkc")
.and_then(|b| read_le_i64(&b))
.unwrap_or(0) as usize;
let okc = self
.store
.kv_get(&dto.execution_id.0, b"strm_okc")
.and_then(|b| read_le_i64(&b))
.unwrap_or(0) as usize;
let output_schema = self
.store
.kv_get(&dto.execution_id.0, b"strm_sos")
.and_then(|b| ipc::read_schema(&b).ok());
let mut states = self
.store
.kv_get(&dto.execution_id.0, b"strm_state")
.map(|b| Self::de_state_map(&b))
.unwrap_or_default();
let col = f.streaming_chunk(&chunk, pkc, okc, &mut states)?;
self.store.kv_put(
&dto.execution_id.0,
b"strm_state",
&Self::ser_state_map(&states),
);
let schema = output_schema.unwrap_or_else(|| {
Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
"result",
col.data_type().clone(),
true,
)]))
});
let batch = RecordBatch::try_new(schema, vec![col])
.map_err(|e| RpcError::runtime_error(e.to_string()))?;
Ok(Some(wire::to_result_batch(
AggregateStreamingChunkResponse {
result_batch: Bytes::from(ipc::write_batch(&batch)?),
},
)?))
}
/// Ends a streaming-aggregate session by deleting every key it stored: the
/// key counts, the output schema, and the state map.
pub fn handle_aggregate_streaming_close(&self, req: &Request) -> Result<Option<RecordBatch>> {
    const SESSION_KEYS: [&[u8]; 4] = [b"strm_pkc", b"strm_okc", b"strm_sos", b"strm_state"];
    let dto: AggregateStreamingCloseRequest = boxed(req)?;
    let exec = &dto.execution_id.0;
    for key in SESSION_KEYS {
        self.store.kv_del(exec, key);
    }
    Ok(Some(wire::empty_result_batch()?))
}
/// Answers a listing request with an empty item list.
pub fn handle_empty_items(&self, _request: &Request) -> Result<Option<RecordBatch>> {
    let empty = ItemsResult { items: Vec::new() };
    let batch = wire::to_result_batch(empty)?;
    Ok(Some(batch))
}
/// Acknowledges a request that has no result batch.
pub fn handle_void(&self, _request: &Request) -> Result<Option<RecordBatch>> {
    let no_batch: Option<RecordBatch> = None;
    Ok(no_batch)
}
/// Rejects a catalog mutation with a runtime error, because this catalog is
/// read-only.
pub fn handle_read_only(&self, _request: &Request) -> Result<Option<RecordBatch>> {
    let message = "catalog is read-only";
    Err(RpcError::runtime_error(message))
}
}
/// Serializable exchange/producer state. `TableProducerState::encode_state`
/// round-trips it through `vgi_rpc::stream_codec::bincode_decode` /
/// `bincode_encode`.
///
/// NOTE: bincode encodes struct fields positionally, so field order is part of
/// the encoded format. Append new fields at the end; do not reorder.
#[derive(serde::Serialize, serde::Deserialize, Clone)]
pub struct ExchangeBlob {
    /// Presumably which kind of function/exchange this blob belongs to.
    /// TODO: confirm against the producer of this blob.
    pub kind: String,
    /// Name of the function being executed.
    pub function_name: String,
    /// Encoded output schema (presumably IPC schema bytes; verify).
    pub output_schema: Vec<u8>,
    /// Encoded input schema (presumably IPC schema bytes; verify).
    pub input_schema: Vec<u8>,
    /// Encoded call arguments.
    pub arguments: Vec<u8>,
    /// Encoded settings.
    pub settings: Vec<u8>,
    /// Encoded secrets.
    pub secrets: Vec<u8>,
    /// Execution id the state belongs to.
    pub execution_id: Vec<u8>,
    /// Opaque data returned at init time.
    pub init_opaque: Vec<u8>,
    /// Encoded pushdown filters.
    pub pushdown_filters: Vec<u8>,
    /// Presumably whether the pushdown filters are applied automatically.
    /// TODO: confirm.
    pub auto_apply: bool,
    /// Resume token of the inner producer. `TableProducerState::encode_state`
    /// refreshes it from `encode_resume()`.
    pub inner_resume: Vec<u8>,
    /// Time-travel unit (paired with `at_value`).
    pub at_unit: String,
    /// Time-travel value (paired with `at_unit`).
    pub at_value: String,
}
/// Exchange state of a scalar function call. Each incoming batch is run
/// through `func.process`, and the result is emitted as-is.
struct ScalarExchangeState {
    /// The scalar function being evaluated.
    func: Arc<dyn ScalarFunction>,
    /// Processing parameters. `auth_principal` is refreshed on every exchange.
    params: ProcessParams,
    /// Opaque state returned verbatim by `encode_state`.
    blob: Vec<u8>,
}
impl ExchangeState for ScalarExchangeState {
    fn exchange(
        &mut self,
        input: &RecordBatch,
        out: &mut OutputCollector,
        ctx: &CallContext,
    ) -> Result<()> {
        // The same params are reused across exchanges, so the caller identity
        // is re-read from each call's context.
        self.params.auth_principal = principal(ctx);
        let produced = self.func.process(&self.params, input)?;
        out.emit(produced)
    }
    fn encode_state(&self) -> Result<Vec<u8>> {
        let snapshot = self.blob.clone();
        Ok(snapshot)
    }
}
/// Table producer with no rows: the first call to `next_batch` already
/// reports exhaustion.
struct EmptyProducer;
impl TableProducer for EmptyProducer {
    fn next_batch(&mut self, _collector: &mut OutputCollector) -> Result<Option<RecordBatch>> {
        let exhausted: Option<RecordBatch> = None;
        Ok(exhausted)
    }
}
/// Table producer that replays a fixed list of batches in order, then reports
/// exhaustion.
struct VecProducer {
    /// Batches to hand out.
    batches: Vec<RecordBatch>,
    /// Index of the next batch to return.
    pos: usize,
}
impl TableProducer for VecProducer {
    fn next_batch(&mut self, _collector: &mut OutputCollector) -> Result<Option<RecordBatch>> {
        let Some(batch) = self.batches.get(self.pos) else {
            return Ok(None);
        };
        self.pos += 1;
        Ok(Some(batch.clone()))
    }
}
/// Exchange state of a table-in-out function call.
///
/// Every input batch may produce several output batches. Each one is filtered
/// by the optional pushdown filters and then emitted.
struct TableInOutExchangeState {
    /// The table-in-out function being evaluated.
    func: Arc<dyn TableInOutFunction>,
    /// Processing parameters. `auth_principal` is refreshed on every exchange.
    params: ProcessParams,
    /// Pushdown filters applied to every produced batch, if any.
    filters: Option<crate::pushdown::PushdownFilters>,
    /// Opaque state returned verbatim by `encode_state`.
    blob: Vec<u8>,
}
impl ExchangeState for TableInOutExchangeState {
    fn exchange(
        &mut self,
        input: &RecordBatch,
        out: &mut OutputCollector,
        ctx: &CallContext,
    ) -> Result<()> {
        self.params.auth_principal = principal(ctx);
        let produced = self.func.process(&self.params, input)?;
        for batch in produced {
            let filtered = if let Some(filters) = &self.filters {
                filters.apply(&batch)?
            } else {
                batch
            };
            out.emit(filtered)?;
        }
        Ok(())
    }
    fn encode_state(&self) -> Result<Vec<u8>> {
        let snapshot = self.blob.clone();
        Ok(snapshot)
    }
}
/// Producer state of a table-function scan.
///
/// On each call it pulls one batch from `inner`, applies pushdown filters,
/// optionally projects the batch, and emits it. Dynamic per-tick filters take
/// precedence over the static `filters`.
struct TableProducerState {
    /// The underlying table producer.
    inner: Box<dyn TableProducer>,
    /// Static pushdown filters, used when a tick carries no dynamic filters.
    filters: Option<crate::pushdown::PushdownFilters>,
    /// Schema that each emitted batch is projected to, if any.
    project_to: Option<arrow_schema::SchemaRef>,
    /// Bincode-encoded `ExchangeBlob`. When present:
    /// - output is capped at `HTTP_WORKQUEUE_BATCH_LIMIT` batches per call;
    /// - `encode_state` re-encodes the blob with the inner resume token.
    resume_blob: Option<Vec<u8>>,
}
const HTTP_WORKQUEUE_BATCH_LIMIT: usize = 4;
impl vgi_rpc::ProducerState for TableProducerState {
    fn produce(&mut self, out: &mut OutputCollector, ctx: &CallContext) -> Result<()> {
        // Dynamic filters arrive as base64 tick metadata. `None` means no
        // dynamic filters for this tick.
        let tick_filters = ctx
            .tick_metadata("vgi_pushdown_filters")
            .and_then(|encoded| crate::pushdown::PushdownFilters::parse_b64(&encoded, &[]));
        self.inner.on_dynamic_filters(tick_filters.as_ref());
        let Some(raw) = self.inner.next_batch(out)? else {
            out.finish();
            return Ok(());
        };
        // Read metadata right after `next_batch`, so it describes this batch.
        let metadata = self.inner.last_metadata();
        let filtered = match tick_filters.as_ref().or(self.filters.as_ref()) {
            Some(filters) => filters.apply(&raw)?,
            None => raw,
        };
        let shaped = if let Some(target) = &self.project_to {
            crate::table_in_out::project_batch(&filtered, target)?
        } else {
            filtered
        };
        if let Some(m) = metadata {
            out.emit_with_metadata(shaped, m)
        } else {
            out.emit(shaped)
        }
    }
    fn batch_limit(&self) -> Option<usize> {
        self.resume_blob
            .is_some()
            .then_some(HTTP_WORKQUEUE_BATCH_LIMIT)
    }
    fn encode_state(&self) -> Result<Vec<u8>> {
        let Some(saved) = &self.resume_blob else {
            return Ok(Vec::new());
        };
        let mut blob: ExchangeBlob = vgi_rpc::stream_codec::bincode_decode(saved)?;
        blob.inner_resume = self.inner.encode_resume();
        vgi_rpc::stream_codec::bincode_encode(&blob)
    }
}
/// Decodes the first 8 bytes of `b` as a little-endian `i64`.
///
/// Any trailing bytes are ignored. Returns `None` when fewer than 8 bytes are
/// available.
fn read_le_i64(b: &[u8]) -> Option<i64> {
    let head = b.get(..8)?;
    <[u8; 8]>::try_from(head).ok().map(i64::from_le_bytes)
}
/// Decodes the request DTO nested inside the request's binary `request` column.
///
/// Row 0 of that column holds an IPC-encoded batch, which is decoded into `T`.
/// When the `VGI_WIRE_DEBUG` environment variable is set (to a valid Unicode
/// value), the inner batch's field names and types are printed to stderr.
///
/// # Errors
/// Returns a type error when the column is missing, is not binary, or is
/// empty/null at row 0. Also propagates IPC and wire decoding errors.
fn boxed<T: VgiArrow>(req: &Request) -> Result<T> {
    let Some(col) = req.column("request") else {
        return Err(RpcError::type_error("request missing 'request' column"));
    };
    let Some(binary) = col.as_any().downcast_ref::<BinaryArray>() else {
        return Err(RpcError::type_error("'request' column is not binary"));
    };
    if binary.is_empty() || binary.is_null(0) {
        return Err(RpcError::type_error("'request' column is empty"));
    }
    let inner = ipc::read_batch(binary.value(0))?;
    if std::env::var("VGI_WIRE_DEBUG").is_ok() {
        let fields: Vec<String> = inner
            .schema()
            .fields()
            .iter()
            .map(|field| format!("{}:{}", field.name(), field.data_type()))
            .collect();
        eprintln!("[vgi-wire] {} inner schema: {:?}", req.method, fields);
    }
    wire::from_batch::<T>(&inner)
}
fn split_group_ids(batch: &RecordBatch) -> Result<(Int64Array, Vec<ArrayRef>)> {
let (gidx, _) = batch
.schema()
.column_with_name(GROUP_COLUMN_NAME)
.ok_or_else(|| RpcError::type_error("update batch missing group-id column"))?;
let gids = batch
.column(gidx)
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| RpcError::type_error("group-id column not int64"))?
.clone();
let columns: Vec<ArrayRef> = (0..batch.num_columns())
.filter(|&i| i != gidx)
.map(|i| batch.column(i).clone())
.collect();
Ok((gids, columns))
}
const SEC_MARKER: &[u8] = b"\x00sec\x00";
/// Encodes `SEC_MARKER ++ name ++ 0x00 ++ scope`. Presumably this tags attach
/// data with a secondary catalog name (verify against callers).
///
/// `name` must not contain a NUL byte, because `decode_secondary_opaque`
/// splits at the first NUL after the marker.
fn encode_secondary_opaque(name: &str, scope: &[u8]) -> Vec<u8> {
    let parts: [&[u8]; 4] = [SEC_MARKER, name.as_bytes(), &[0], scope];
    parts.concat()
}
/// Inverse of `encode_secondary_opaque`: returns `(name, scope)`.
///
/// Returns `None` when the marker is absent, when no NUL terminates the name,
/// or when the name is not valid UTF-8.
fn decode_secondary_opaque(bytes: &[u8]) -> Option<(String, Vec<u8>)> {
    let rest = bytes.strip_prefix(SEC_MARKER)?;
    let mut pieces = rest.splitn(2, |&byte| byte == 0);
    let name = pieces.next()?;
    let scope = pieces.next()?;
    let name = std::str::from_utf8(name).ok()?.to_owned();
    Some((name, scope.to_vec()))
}
/// Reads row 0 of request column `name` as a `String`.
///
/// # Errors
/// Returns a type error when the column is missing or has no rows. Also
/// propagates errors from `VgiArrow::read`.
fn read_string_col(req: &Request, name: &str) -> Result<String> {
    let col = req
        .column(name)
        .ok_or_else(|| RpcError::type_error(format!("request missing '{name}' column")))?;
    // Reject zero-row columns before touching index 0, as `boxed` and
    // `read_binary_col` already do. Reading row 0 of an empty column is out
    // of bounds.
    if col.is_empty() {
        return Err(RpcError::type_error(format!(
            "request '{name}' column is empty"
        )));
    }
    <String as VgiArrow>::read(col, 0)
}
/// Reads row 0 of request column `name` as a `String`, if present.
///
/// Returns `None` in any of these cases:
/// - the column is missing;
/// - the column has no rows;
/// - row 0 is null;
/// - the value cannot be read.
fn read_opt_string_col(req: &Request, name: &str) -> Option<String> {
    let col = req.column(name)?;
    // Check emptiness first. `is_null(0)` on a zero-length array that has a
    // validity buffer indexes out of bounds and can panic.
    if col.is_empty() || col.is_null(0) {
        return None;
    }
    <String as VgiArrow>::read(col, 0).ok()
}
/// Copies row 0 of binary request column `name`.
///
/// Returns `None` when the column is missing, is not a `BinaryArray`, is
/// empty, or is null at row 0.
fn read_binary_col(req: &Request, name: &str) -> Option<Vec<u8>> {
    let array = req
        .column(name)?
        .as_any()
        .downcast_ref::<arrow_array::BinaryArray>()?;
    if array.is_empty() || !array.is_valid(0) {
        return None;
    }
    Some(array.value(0).to_vec())
}
/// Parses optional encoded settings. An absent field yields the default
/// settings.
fn parse_settings(field: &Option<Bytes>) -> Result<crate::settings::Settings> {
    field.as_ref().map_or_else(
        || Ok(crate::settings::Settings::default()),
        |encoded| crate::settings::Settings::parse(&encoded.0),
    )
}
/// Parses optional encoded secrets. An absent field yields the default
/// (empty) secrets.
fn parse_secrets(field: &Option<Bytes>) -> Result<crate::secrets::Secrets> {
    field.as_ref().map_or_else(
        || Ok(crate::secrets::Secrets::default()),
        |encoded| crate::secrets::Secrets::parse(&encoded.0),
    )
}
/// Caller principal for a call.
///
/// Returns `Some` when the call is authenticated or carries a non-empty
/// principal name, and `None` otherwise. An authenticated call with an empty
/// name yields `Some("")`.
fn principal(ctx: &CallContext) -> Option<String> {
    let auth = &ctx.auth;
    let known = auth.authenticated || !auth.principal.is_empty();
    known.then(|| auth.principal.clone())
}
/// Decodes an optional IPC-encoded schema. An absent or empty field yields
/// `None`.
fn opt_schema(field: &Option<Bytes>) -> Result<Option<SchemaRef>> {
    let Some(encoded) = field.as_ref().filter(|b| !b.0.is_empty()) else {
        return Ok(None);
    };
    let schema = ipc::read_schema(&encoded.0)?;
    Ok(Some(schema))
}
/// Normalizes a function-type filter.
///
/// The value is lowercased and one trailing `_function` is dropped, so
/// `"SCALAR_FUNCTION"` becomes `"scalar"`. An empty filter means "no filter"
/// and yields `None`.
fn normalize_function_type(t: &str) -> Option<String> {
    (!t.is_empty()).then(|| {
        let lowered = t.to_lowercase();
        match lowered.strip_suffix("_function") {
            Some(base) => base.to_owned(),
            None => lowered,
        }
    })
}
#[cfg(test)]
mod buffering_schema_tests {
    use super::*;
    use arrow_schema::{DataType, Field, Schema};

    /// Buffering function whose bound output schema is always a single
    /// nullable `Float64` column `s`, whatever the input. Only `on_bind` is
    /// exercised by these tests.
    struct FixedOutput;

    impl crate::buffering::TableBufferingFunction for FixedOutput {
        fn name(&self) -> &str {
            "fixed_output"
        }
        fn metadata(&self) -> crate::function::FunctionMetadata {
            Default::default()
        }
        fn argument_specs(&self) -> Vec<crate::function::ArgSpec> {
            Vec::new()
        }
        fn on_bind(&self, _p: &BindParams) -> Result<crate::function::BindResponse> {
            let only_field = Field::new("s", DataType::Float64, true);
            Ok(crate::function::BindResponse {
                output_schema: Arc::new(Schema::new(vec![only_field])),
                opaque_data: Vec::new(),
            })
        }
        fn process(
            &self,
            _p: &crate::buffering::BufferingParams,
            _b: &arrow_array::RecordBatch,
        ) -> Result<Vec<u8>> {
            unimplemented!()
        }
        fn combine(
            &self,
            _p: &crate::buffering::BufferingParams,
            _s: &[Vec<u8>],
        ) -> Result<Vec<Vec<u8>>> {
            unimplemented!()
        }
        fn finalize_producer(
            &self,
            _p: &crate::buffering::BufferingParams,
            _f: Vec<u8>,
        ) -> Result<Box<dyn crate::table_function::TableProducer>> {
            unimplemented!()
        }
    }

    #[test]
    fn output_schema_recomputed_on_store_miss() {
        let dispatcher = Dispatcher::new("test");
        let exec_id = format!("test-recompute-{}", std::process::id()).into_bytes();
        // Start from a clean slate so no cached schema can be found.
        dispatcher.store.clear(&exec_id);
        let input_schema = Arc::new(Schema::new(vec![Field::new(
            "a",
            DataType::Decimal128(10, 2),
            true,
        )]));
        let schema = dispatcher
            .buffering_output_schema(&exec_id, &FixedOutput, Some(input_schema))
            .expect("recompute via on_bind");
        assert_eq!(schema.fields().len(), 1);
        assert_eq!(schema.field(0).data_type(), &DataType::Float64);
    }

    #[test]
    fn output_schema_errors_without_any_input() {
        let dispatcher = Dispatcher::new("test");
        let exec_id = format!("test-error-{}", std::process::id()).into_bytes();
        dispatcher.store.clear(&exec_id);
        let outcome = dispatcher.buffering_output_schema(&exec_id, &FixedOutput, None);
        assert!(outcome.is_err());
    }
}
#[cfg(test)]
mod malformed_input_tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn de_state_map_roundtrips() {
        let original: HashMap<Vec<u8>, Vec<u8>> = HashMap::from([
            (b"k1".to_vec(), b"value-one".to_vec()),
            (Vec::new(), Vec::new()),
            (vec![0xff, 0x00, 0xfe], vec![1, 2, 3, 4]),
        ]);
        let encoded = Dispatcher::ser_state_map(&original);
        assert_eq!(Dispatcher::de_state_map(&encoded), original);
    }

    #[test]
    fn de_state_map_tolerates_truncation_at_every_offset() {
        let original: HashMap<Vec<u8>, Vec<u8>> = HashMap::from([
            (b"alpha".to_vec(), b"beta".to_vec()),
            (b"gamma".to_vec(), b"delta".to_vec()),
        ]);
        let encoded = Dispatcher::ser_state_map(&original);
        for cut in 0..=encoded.len() {
            let decoded = Dispatcher::de_state_map(&encoded[..cut]);
            for (key, value) in decoded.iter() {
                assert_eq!(
                    original.get(key),
                    Some(value),
                    "decoded a key/value that was never encoded"
                );
            }
        }
    }

    #[test]
    fn de_state_map_rejects_garbage_lengths() {
        // One entry whose key length is u64::MAX: it must not allocate or panic.
        let bad: Vec<u8> = [1u64.to_le_bytes(), u64::MAX.to_le_bytes()].concat();
        assert!(Dispatcher::de_state_map(&bad).is_empty());
        // Arbitrary short byte strings must decode without panicking.
        for len in 0..20usize {
            let noise: Vec<u8> = (0..len).map(|i| (i as u8).wrapping_mul(37)).collect();
            let _ = Dispatcher::de_state_map(&noise);
        }
    }

    #[test]
    fn read_le_i64_is_bounds_safe() {
        assert_eq!(read_le_i64(&7i64.to_le_bytes()), Some(7));
        assert_eq!(read_le_i64(&(-1i64).to_le_bytes()), Some(-1));
        let with_trailer = [42i64.to_le_bytes().as_slice(), b"trailing".as_slice()].concat();
        assert_eq!(read_le_i64(&with_trailer), Some(42));
        for short in 0..8usize {
            assert_eq!(read_le_i64(&[0u8; 8][..short]), None);
        }
        assert_eq!(read_le_i64(&[]), None);
    }
}