use std::sync::Arc;
use crate::errors::{FnError, ReloadError};
use crate::registry::{PluginRecordSnapshot, PluginRegistry};
use crate::traits::cdc::{CdcLsn, CdcOutputProvider, CdcStartContext, CdcStream};
use crate::traits::crdt::CrdtKindProvider;
use crate::traits::index::{IndexHandle, IndexKindProvider};
use crate::traits::types::LogicalTypeProvider;
/// Handoffs queued for a reload, grouped by kind.
///
/// `ReloadDispatcher::dispatch` drains both lists: every index handoff is
/// processed first, then every CDC stream handoff.
#[derive(Default)]
pub struct ReloadKindHandlers {
/// Index handles to persist and reopen through their replacement provider.
pub index_handles: Vec<IndexHandoff>,
/// CDC streams to checkpoint, shut down, and restart through their replacement provider.
pub cdc_streams: Vec<CdcHandoff>,
}
impl std::fmt::Debug for ReloadKindHandlers {
    /// Prints how many handoffs of each kind are queued instead of the
    /// handoffs themselves.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            index_handles,
            cdc_streams,
        } = self;
        let mut out = formatter.debug_struct("ReloadKindHandlers");
        out.field("index_handles", &index_handles.len());
        out.field("cdc_streams", &cdc_streams.len());
        out.finish()
    }
}
/// A single index handoff consumed by `ReloadDispatcher::dispatch`.
///
/// During dispatch `old.persist()` is called, `old` is dropped, and the
/// resulting bytes are passed to `new.open(..)`.
pub struct IndexHandoff {
/// Name reported alongside the reopened handle in `ReloadOutcome::index_handles`.
pub name: String,
/// Handle being replaced; persisted and then dropped during dispatch.
pub old: Box<dyn IndexHandle>,
/// Provider whose `open` receives the bytes persisted by `old`.
pub new: Arc<dyn IndexKindProvider>,
}
impl std::fmt::Debug for IndexHandoff {
    /// Prints only `name`; the trait-object fields are omitted and the output
    /// is marked as non-exhaustive.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("IndexHandoff");
        out.field("name", &self.name);
        out.finish_non_exhaustive()
    }
}
/// A single CDC stream handoff consumed by `ReloadDispatcher::dispatch`.
///
/// During dispatch `old.checkpoint()` yields an LSN, `old.shutdown()` is
/// called, `old` is dropped, and `new.start(..)` is invoked with a start
/// context built from that LSN.
pub struct CdcHandoff {
/// Name reported alongside the resumed stream in `ReloadOutcome::cdc_streams`.
pub name: String,
/// Stream being replaced; checkpointed, shut down, and then dropped during dispatch.
pub old: Box<dyn CdcStream>,
/// Provider whose `start` receives the checkpointed LSN.
pub new: Arc<dyn CdcOutputProvider>,
}
impl std::fmt::Debug for CdcHandoff {
    /// Prints only `name`; the trait-object fields are omitted and the output
    /// is marked as non-exhaustive.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("CdcHandoff");
        out.field("name", &self.name);
        out.finish_non_exhaustive()
    }
}
/// Result of a successful `ReloadDispatcher::dispatch`.
///
/// Each entry pairs the handoff's name with the object produced by the
/// replacement provider, in the order the handoffs were processed.
#[derive(Default)]
pub struct ReloadOutcome {
/// `(name, handle)` pairs returned by `IndexKindProvider::open`.
pub index_handles: Vec<(String, Box<dyn IndexHandle>)>,
/// `(name, stream)` pairs returned by `CdcOutputProvider::start`.
pub cdc_streams: Vec<(String, Box<dyn CdcStream>)>,
}
impl std::fmt::Debug for ReloadOutcome {
    /// Prints how many handles and streams were produced instead of the
    /// trait objects themselves.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            index_handles,
            cdc_streams,
        } = self;
        let mut out = formatter.debug_struct("ReloadOutcome");
        out.field("index_handles", &index_handles.len());
        out.field("cdc_streams", &cdc_streams.len());
        out.finish()
    }
}
/// Drives a plugin reload: compatibility checks (`check_compat`) followed by
/// the per-kind handoffs (`dispatch`).
#[derive(Debug)]
pub struct ReloadDispatcher<'a> {
// Snapshot of the plugin being replaced; `check_compat` iterates its `crdt_kinds`.
old: &'a PluginRecordSnapshot,
// Registry queried for the replacement CRDT-kind and logical-type providers.
new_registry: &'a PluginRegistry,
// Handoffs consumed by `dispatch`; empty unless set through `with_handlers`.
handlers: ReloadKindHandlers,
}
impl<'a> ReloadDispatcher<'a> {
#[must_use]
pub fn new(old: &'a PluginRecordSnapshot, new_registry: &'a PluginRegistry) -> Self {
Self {
old,
new_registry,
handlers: ReloadKindHandlers::default(),
}
}
#[must_use]
pub fn with_handlers(mut self, handlers: ReloadKindHandlers) -> Self {
self.handlers = handlers;
self
}
pub fn check_compat(&self, old_providers: &OldProviders) -> Result<(), ReloadError> {
for kind in &self.old.crdt_kinds {
let Some(old) = old_providers.crdt_kinds.get(kind) else {
continue;
};
let Some(new) = self.new_registry.crdt_kind(kind) else {
continue;
};
new.schema_compat_check(old.as_ref())
.map_err(|e: FnError| {
ReloadError::schema_incompat(format!("crdt:{}", kind.0), e.message)
})?;
}
for name in &old_providers.logical_type_names {
let Some(old) = old_providers.logical_types.get(name) else {
continue;
};
let Some(new) = self.new_registry.logical_type(name) else {
continue;
};
new.compat_check(old.as_ref()).map_err(|e: FnError| {
ReloadError::schema_incompat(format!("logical-type:{name}"), e.message)
})?;
}
Ok(())
}
pub fn dispatch(mut self) -> Result<ReloadOutcome, ReloadError> {
let mut outcome = ReloadOutcome::default();
for handoff in self.handlers.index_handles.drain(..) {
let bytes = handoff.old.persist().map_err(ReloadError::Persist)?;
drop(handoff.old);
let reopened = handoff.new.open(&bytes).map_err(ReloadError::Persist)?;
outcome.index_handles.push((handoff.name, reopened));
}
for mut handoff in self.handlers.cdc_streams.drain(..) {
let lsn: CdcLsn = handoff.old.checkpoint().map_err(ReloadError::Persist)?;
handoff.old.shutdown().map_err(ReloadError::Persist)?;
drop(handoff.old);
let resumed = handoff
.new
.start(CdcStartContext::new(Some(lsn)))
.map_err(ReloadError::Persist)?;
outcome.cdc_streams.push((handoff.name, resumed));
}
Ok(outcome)
}
}
/// Providers of the plugin being replaced, consulted by
/// `ReloadDispatcher::check_compat`.
#[derive(Default)]
pub struct OldProviders {
/// Old CRDT-kind providers, looked up by the kinds listed in the old snapshot.
pub crdt_kinds:
std::collections::HashMap<crate::traits::crdt::CrdtKind, Arc<dyn CrdtKindProvider>>,
/// Logical type names to check; each name is looked up in `logical_types`.
pub logical_type_names: Vec<smol_str::SmolStr>,
/// Old logical-type providers keyed by name.
pub logical_types: std::collections::HashMap<smol_str::SmolStr, Arc<dyn LogicalTypeProvider>>,
}
impl std::fmt::Debug for OldProviders {
    /// Prints the number of entries in each provider map.
    /// `logical_type_names` is not part of the output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("OldProviders");
        out.field("crdt_kinds", &self.crdt_kinds.len());
        out.field("logical_types", &self.logical_types.len());
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::crdt::{CrdtKind, CrdtOp, CrdtState};
    use datafusion::scalar::ScalarValue;

    /// Minimal CRDT state for the tests: a single `i64` counter.
    #[derive(Default)]
    struct CountState {
        v: i64,
    }

    impl CrdtState for CountState {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        /// Adds the byte length of the op payload to the counter.
        fn apply(&mut self, op: &CrdtOp) -> Result<(), FnError> {
            self.v += op.bytes.len() as i64;
            Ok(())
        }

        /// Keeps the larger of the two counters; fails if `other` is not a
        /// `CountState`.
        fn merge(&mut self, other: &dyn CrdtState) -> Result<(), FnError> {
            let other = other
                .as_any()
                .downcast_ref::<CountState>()
                .ok_or_else(|| FnError::new(0x100, "merge: wrong state type"))?;
            if other.v > self.v {
                self.v = other.v;
            }
            Ok(())
        }

        fn value(&self) -> Result<ScalarValue, FnError> {
            Ok(ScalarValue::Int64(Some(self.v)))
        }

        /// Serializes the counter as 8 little-endian bytes.
        fn persist(&self) -> Result<Vec<u8>, FnError> {
            Ok(self.v.to_le_bytes().to_vec())
        }
    }

    /// Provider for `CountState` whose kind string is configurable.
    struct CountProvider {
        kind_str: &'static str,
    }

    impl CrdtKindProvider for CountProvider {
        fn kind(&self) -> CrdtKind {
            CrdtKind::new(self.kind_str)
        }

        fn empty(&self) -> Box<dyn CrdtState> {
            Box::new(CountState::default())
        }

        /// Accepts exactly the 8-byte little-endian encoding written by
        /// `CountState::persist`.
        fn from_persisted(&self, bytes: &[u8]) -> Result<Box<dyn CrdtState>, FnError> {
            if bytes.len() != 8 {
                return Err(FnError::new(
                    0x101,
                    format!("expected 8 bytes, got {}", bytes.len()),
                ));
            }
            let mut arr = [0u8; 8];
            arr.copy_from_slice(bytes);
            Ok(Box::new(CountState {
                v: i64::from_le_bytes(arr),
            }))
        }
    }

    /// Provider of kind `"count"` that refuses every persisted payload.
    struct RejectingProvider;

    impl CrdtKindProvider for RejectingProvider {
        fn kind(&self) -> CrdtKind {
            CrdtKind::new("count")
        }

        fn empty(&self) -> Box<dyn CrdtState> {
            Box::new(CountState::default())
        }

        fn from_persisted(&self, _bytes: &[u8]) -> Result<Box<dyn CrdtState>, FnError> {
            Err(FnError::new(0x102, "rejecting all persisted bytes"))
        }
    }

    #[test]
    fn schema_compat_accepts_round_trip() {
        let old = CountProvider { kind_str: "count" };
        let new = CountProvider { kind_str: "count" };
        new.schema_compat_check(&old).expect("compatible");
    }

    // NOTE(review): presumably `schema_compat_check` feeds the old provider's
    // persisted state into `new.from_persisted`, which is why the rejection
    // message surfaces here. Confirm against the trait's default implementation.
    #[test]
    fn schema_compat_rejects_incompatible_round_trip() {
        let old = CountProvider { kind_str: "count" };
        let new = RejectingProvider;
        let err = new.schema_compat_check(&old).unwrap_err();
        assert!(err.message.contains("rejecting"));
    }

    // The registry is freshly constructed and nothing is registered on it
    // here, so the check relies on the "missing replacement is skipped" path.
    #[test]
    fn dispatcher_check_compat_passes_when_all_round_trip() {
        let registry = PluginRegistry::new();
        let snap = PluginRecordSnapshot {
            crdt_kinds: vec![CrdtKind::new("count")],
            ..Default::default()
        };
        let mut olds = OldProviders::default();
        olds.crdt_kinds.insert(
            CrdtKind::new("count"),
            Arc::new(CountProvider { kind_str: "count" }),
        );
        let d = ReloadDispatcher::new(&snap, &registry);
        d.check_compat(&olds).expect("absence is OK");
    }

    #[test]
    fn dispatcher_dispatch_handles_index_handoff() {
        /// Index handle that only remembers the bytes it was opened with.
        struct DummyHandle {
            bytes: Vec<u8>,
        }

        impl IndexHandle for DummyHandle {
            fn probe(
                &self,
                _query: &datafusion::arrow::record_batch::RecordBatch,
                _k: usize,
            ) -> Result<datafusion::arrow::record_batch::RecordBatch, FnError> {
                Err(FnError::new(0, "unused"))
            }

            fn persist(&self) -> Result<Vec<u8>, FnError> {
                Ok(self.bytes.clone())
            }

            fn schema(&self) -> arrow_schema::SchemaRef {
                std::sync::Arc::new(arrow_schema::Schema::empty())
            }
        }

        /// Provider whose `open` wraps the persisted bytes in a `DummyHandle`.
        struct DummyProvider;

        impl IndexKindProvider for DummyProvider {
            fn kind(&self) -> crate::traits::index::IndexKind {
                crate::traits::index::IndexKind::new("dummy")
            }

            fn build(
                &self,
                _source: &datafusion::arrow::record_batch::RecordBatch,
                _options: &str,
            ) -> Result<Box<dyn crate::traits::index::IndexBuild>, FnError> {
                Err(FnError::new(0, "unused"))
            }

            fn open(&self, persisted: &[u8]) -> Result<Box<dyn IndexHandle>, FnError> {
                Ok(Box::new(DummyHandle {
                    bytes: persisted.to_vec(),
                }))
            }
        }

        let snap = PluginRecordSnapshot::default();
        let registry = PluginRegistry::new();
        let mut handlers = ReloadKindHandlers::default();
        handlers.index_handles.push(IndexHandoff {
            name: "i1".to_owned(),
            old: Box::new(DummyHandle {
                bytes: vec![1, 2, 3, 4],
            }),
            new: Arc::new(DummyProvider),
        });
        let outcome = ReloadDispatcher::new(&snap, &registry)
            .with_handlers(handlers)
            .dispatch()
            .expect("handoff");
        // The reopened handle must carry exactly the bytes the old one persisted.
        assert_eq!(outcome.index_handles.len(), 1);
        assert_eq!(outcome.index_handles[0].0, "i1");
        assert_eq!(
            outcome.index_handles[0].1.persist().unwrap(),
            vec![1, 2, 3, 4]
        );
    }
}