use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::config::EncodingConfig;
use crate::types::symbol::ObjectId;
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MigrationMode {
TraditionalOnly,
#[default]
PreferTraditional,
Adaptive,
PreferSymbolNative,
SymbolNativeOnly,
}
impl MigrationMode {
#[must_use]
pub fn should_use_raptorq(&self, hint: Option<bool>, data_size: usize) -> bool {
match (self, hint) {
(_, Some(true)) | (Self::SymbolNativeOnly | Self::PreferSymbolNative, None) => true,
(_, Some(false)) | (Self::TraditionalOnly | Self::PreferTraditional, None) => false,
(Self::Adaptive, None) => data_size > 1024,
}
}
}
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum MigrationFeature {
JoinEncoding,
RaceEncoding,
DistributedRegions,
SymbolCancellation,
SymbolTracing,
EpochBarriers,
}
impl MigrationFeature {
pub fn all() -> impl Iterator<Item = Self> {
[
Self::JoinEncoding,
Self::RaceEncoding,
Self::DistributedRegions,
Self::SymbolCancellation,
Self::SymbolTracing,
Self::EpochBarriers,
]
.into_iter()
}
}
#[derive(Debug, Clone, Default)]
pub struct MigrationConfig {
features: HashSet<MigrationFeature>,
mode: MigrationMode,
overrides: HashMap<String, MigrationMode>,
}
impl MigrationConfig {
#[must_use]
pub fn is_enabled(&self, feature: MigrationFeature) -> bool {
self.features.contains(&feature)
}
#[must_use]
pub fn mode(&self) -> MigrationMode {
self.mode
}
#[must_use]
pub fn enabled_features(&self) -> &HashSet<MigrationFeature> {
&self.features
}
#[must_use]
pub fn mode_for(&self, operation: &str) -> MigrationMode {
self.overrides.get(operation).copied().unwrap_or(self.mode)
}
}
#[derive(Debug, Default)]
pub struct MigrationBuilder {
features: HashSet<MigrationFeature>,
mode: MigrationMode,
overrides: HashMap<String, MigrationMode>,
}
impl MigrationBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn enable(mut self, feature: MigrationFeature) -> Self {
self.features.insert(feature);
self
}
#[must_use]
pub fn disable(mut self, feature: MigrationFeature) -> Self {
self.features.remove(&feature);
self
}
#[must_use]
pub fn with_mode(mut self, mode: MigrationMode) -> Self {
self.mode = mode;
self
}
#[must_use]
pub fn override_operation(mut self, operation: impl Into<String>, mode: MigrationMode) -> Self {
self.overrides.insert(operation.into(), mode);
self
}
#[must_use]
pub fn full_raptorq(mut self) -> Self {
self.features = MigrationFeature::all().collect();
self.mode = MigrationMode::SymbolNativeOnly;
self
}
#[must_use]
pub fn build(self) -> MigrationConfig {
MigrationConfig {
features: self.features,
mode: self.mode,
overrides: self.overrides,
}
}
}
#[must_use]
pub fn configure_migration() -> MigrationBuilder {
MigrationBuilder::new()
}
#[derive(Debug)]
pub enum DualValueError {
SerializationFailed(String),
DeserializationFailed(String),
}
impl std::fmt::Display for DualValueError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SerializationFailed(msg) => write!(f, "serialization failed: {msg}"),
Self::DeserializationFailed(msg) => write!(f, "deserialization failed: {msg}"),
}
}
}
impl std::error::Error for DualValueError {}
pub enum DualValue<T> {
Traditional(T),
SymbolNative {
serialized: Vec<u8>,
object_id: ObjectId,
_phantom: PhantomData<T>,
},
}
impl<T> DualValue<T> {
#[must_use]
pub fn uses_raptorq(&self) -> bool {
matches!(self, Self::SymbolNative { .. })
}
#[must_use]
pub fn is_traditional(&self) -> bool {
matches!(self, Self::Traditional(_))
}
}
impl<T: Clone + Serialize + DeserializeOwned> DualValue<T> {
pub fn get(&self) -> Result<T, DualValueError> {
match self {
Self::Traditional(v) => Ok(v.clone()),
Self::SymbolNative { serialized, .. } => serde_json::from_slice(serialized)
.map_err(|e| DualValueError::DeserializationFailed(e.to_string())),
}
}
pub fn ensure_symbols(&mut self, _config: &EncodingConfig) -> Result<(), DualValueError> {
if let Self::Traditional(v) = self {
let serialized = serde_json::to_vec(v)
.map_err(|e| DualValueError::SerializationFailed(e.to_string()))?;
let object_id = ObjectId::new_for_test(0);
*self = Self::SymbolNative {
serialized,
object_id,
_phantom: PhantomData,
};
}
Ok(())
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for DualValue<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Traditional(v) => f.debug_tuple("Traditional").field(v).finish(),
Self::SymbolNative {
serialized,
object_id,
..
} => f
.debug_struct("SymbolNative")
.field("bytes", &serialized.len())
.field("object_id", object_id)
.finish(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dual_value_traditional() {
let value = DualValue::Traditional(42i32);
assert_eq!(value.get().unwrap(), 42);
assert!(!value.uses_raptorq());
}
#[test]
fn test_dual_value_conversion() {
let mut value = DualValue::Traditional("hello".to_string());
let config = EncodingConfig::default();
value.ensure_symbols(&config).unwrap();
assert!(matches!(value, DualValue::SymbolNative { .. }));
assert_eq!(value.get().unwrap(), "hello".to_string());
}
#[test]
fn test_migration_mode_decisions() {
assert!(!MigrationMode::TraditionalOnly.should_use_raptorq(None, 10000));
assert!(MigrationMode::SymbolNativeOnly.should_use_raptorq(None, 10));
assert!(MigrationMode::TraditionalOnly.should_use_raptorq(Some(true), 10));
assert!(!MigrationMode::SymbolNativeOnly.should_use_raptorq(Some(false), 10));
assert!(!MigrationMode::Adaptive.should_use_raptorq(None, 100));
assert!(MigrationMode::Adaptive.should_use_raptorq(None, 10000));
}
#[test]
fn test_migration_builder() {
let config = configure_migration()
.enable(MigrationFeature::JoinEncoding)
.enable(MigrationFeature::RaceEncoding)
.build();
assert!(config.is_enabled(MigrationFeature::JoinEncoding));
assert!(config.is_enabled(MigrationFeature::RaceEncoding));
assert!(!config.is_enabled(MigrationFeature::DistributedRegions));
}
#[test]
fn test_full_raptorq_mode() {
let config = configure_migration().full_raptorq().build();
for feature in MigrationFeature::all() {
assert!(config.is_enabled(feature));
}
}
#[test]
fn test_migration_mode_default() {
let mode = MigrationMode::default();
assert_eq!(mode, MigrationMode::PreferTraditional);
}
#[test]
fn test_migration_builder_disable() {
let config = configure_migration()
.full_raptorq()
.disable(MigrationFeature::SymbolTracing)
.build();
assert!(config.is_enabled(MigrationFeature::JoinEncoding));
assert!(!config.is_enabled(MigrationFeature::SymbolTracing));
}
#[test]
fn test_per_operation_override() {
let config = configure_migration()
.with_mode(MigrationMode::PreferTraditional)
.override_operation("heavy_join", MigrationMode::PreferSymbolNative)
.build();
assert_eq!(config.mode(), MigrationMode::PreferTraditional);
assert_eq!(
config.mode_for("heavy_join"),
MigrationMode::PreferSymbolNative
);
assert_eq!(
config.mode_for("other_op"),
MigrationMode::PreferTraditional
);
}
#[test]
fn dual_value_error_display_serialization() {
let err = DualValueError::SerializationFailed("bad input".into());
assert_eq!(err.to_string(), "serialization failed: bad input");
}
#[test]
fn dual_value_error_display_deserialization() {
let err = DualValueError::DeserializationFailed("unexpected EOF".into());
assert_eq!(err.to_string(), "deserialization failed: unexpected EOF");
}
#[test]
fn dual_value_error_source_is_none() {
use std::error::Error;
let err = DualValueError::SerializationFailed("x".into());
assert!(err.source().is_none());
}
#[test]
fn dual_value_is_traditional() {
let val = DualValue::Traditional(100u32);
assert!(val.is_traditional());
assert!(!val.uses_raptorq());
}
#[test]
fn dual_value_uses_raptorq_after_ensure_symbols() {
let mut val = DualValue::Traditional(42u32);
let config = EncodingConfig::default();
val.ensure_symbols(&config).unwrap();
assert!(val.uses_raptorq());
assert!(!val.is_traditional());
}
#[test]
fn dual_value_ensure_symbols_idempotent() {
let mut val = DualValue::Traditional(42u32);
let config = EncodingConfig::default();
val.ensure_symbols(&config).unwrap();
assert!(val.uses_raptorq());
val.ensure_symbols(&config).unwrap();
assert!(val.uses_raptorq());
assert_eq!(val.get().unwrap(), 42u32);
}
#[test]
fn dual_value_get_deserialization_failure() {
let bad = DualValue::<u32>::SymbolNative {
serialized: b"not valid json".to_vec(),
object_id: ObjectId::new_for_test(0),
_phantom: PhantomData,
};
let err = bad.get().unwrap_err();
assert!(matches!(err, DualValueError::DeserializationFailed(_)));
}
#[test]
fn dual_value_debug_traditional() {
let val = DualValue::Traditional(99i32);
let dbg = format!("{val:?}");
assert!(dbg.contains("Traditional"), "{dbg}");
assert!(dbg.contains("99"), "{dbg}");
}
#[test]
fn dual_value_debug_symbol_native() {
let mut val = DualValue::Traditional("hello".to_string());
let config = EncodingConfig::default();
val.ensure_symbols(&config).unwrap();
let dbg = format!("{val:?}");
assert!(dbg.contains("SymbolNative"), "{dbg}");
assert!(dbg.contains("bytes"), "{dbg}");
}
#[test]
fn dual_value_ensure_symbols_serialization_failure() {
#[derive(Clone)]
struct FailingSerialize;
impl Serialize for FailingSerialize {
fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
Err(serde::ser::Error::custom("forced serialization failure"))
}
}
impl<'de> serde::Deserialize<'de> for FailingSerialize {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Self)
}
}
let mut val = DualValue::Traditional(FailingSerialize);
let config = EncodingConfig::default();
let err = val.ensure_symbols(&config).unwrap_err();
assert!(matches!(err, DualValueError::SerializationFailed(_)));
assert!(
val.is_traditional(),
"failed conversion should preserve original"
);
}
#[test]
fn migration_config_enabled_features_returns_set() {
let config = configure_migration()
.enable(MigrationFeature::EpochBarriers)
.enable(MigrationFeature::SymbolCancellation)
.build();
let features = config.enabled_features();
assert_eq!(features.len(), 2);
assert!(features.contains(&MigrationFeature::EpochBarriers));
assert!(features.contains(&MigrationFeature::SymbolCancellation));
}
#[test]
fn migration_config_default_has_no_features() {
let config = MigrationConfig::default();
assert!(config.enabled_features().is_empty());
assert_eq!(config.mode(), MigrationMode::PreferTraditional);
}
#[test]
fn adaptive_mode_boundary_at_1024() {
assert!(!MigrationMode::Adaptive.should_use_raptorq(None, 1024));
assert!(MigrationMode::Adaptive.should_use_raptorq(None, 1025));
}
#[test]
fn prefer_symbol_native_without_hint() {
assert!(MigrationMode::PreferSymbolNative.should_use_raptorq(None, 0));
assert!(MigrationMode::PreferSymbolNative.should_use_raptorq(None, 9999));
}
#[test]
fn migration_feature_all_has_six_items() {
assert_eq!(MigrationFeature::all().count(), 6);
}
#[test]
fn migration_feature_all_roundtrip_via_full_raptorq() {
let config = configure_migration().full_raptorq().build();
assert_eq!(config.mode(), MigrationMode::SymbolNativeOnly);
for feature in MigrationFeature::all() {
assert!(
config.is_enabled(feature),
"full_raptorq should enable {feature:?}"
);
}
}
#[test]
fn migration_builder_with_mode() {
let config = MigrationBuilder::new()
.with_mode(MigrationMode::Adaptive)
.build();
assert_eq!(config.mode(), MigrationMode::Adaptive);
}
#[test]
fn migration_builder_multiple_overrides() {
let config = configure_migration()
.override_operation("op_a", MigrationMode::SymbolNativeOnly)
.override_operation("op_b", MigrationMode::TraditionalOnly)
.build();
assert_eq!(config.mode_for("op_a"), MigrationMode::SymbolNativeOnly);
assert_eq!(config.mode_for("op_b"), MigrationMode::TraditionalOnly);
assert_eq!(config.mode_for("op_c"), MigrationMode::PreferTraditional);
}
#[test]
fn migration_mode_debug_clone_copy_default_eq() {
let m = MigrationMode::Adaptive;
let dbg = format!("{m:?}");
assert!(dbg.contains("Adaptive"), "{dbg}");
let copied: MigrationMode = m;
let cloned = m;
assert_eq!(copied, cloned);
assert_eq!(MigrationMode::default(), MigrationMode::PreferTraditional);
assert_ne!(m, MigrationMode::TraditionalOnly);
}
#[test]
fn migration_feature_debug_clone_copy_hash_eq() {
use std::collections::HashSet;
let f = MigrationFeature::JoinEncoding;
let dbg = format!("{f:?}");
assert!(dbg.contains("JoinEncoding"), "{dbg}");
let copied: MigrationFeature = f;
let cloned = f;
assert_eq!(copied, cloned);
let mut set = HashSet::new();
set.insert(MigrationFeature::JoinEncoding);
set.insert(MigrationFeature::RaceEncoding);
set.insert(MigrationFeature::DistributedRegions);
assert_eq!(set.len(), 3);
}
#[test]
fn migration_config_debug_clone_default() {
let c = MigrationConfig::default();
let dbg = format!("{c:?}");
assert!(dbg.contains("MigrationConfig"), "{dbg}");
let cloned = c;
assert_eq!(format!("{cloned:?}"), dbg);
}
}