use std::collections::HashMap;
use serde_json::Value;
/// Describes how single-table-inheritance subtypes are told apart: the name of
/// the attribute/column whose value holds each row's subtype discriminator.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct InheritanceConfig {
    /// Name of the attribute/column holding the discriminator
    /// (`"type"` when built via `Default`).
    pub inheritance_column: String,
}
impl Default for InheritanceConfig {
    /// Uses `"type"` as the discriminator column, i.e. the column name applied
    /// when no override is configured.
    fn default() -> Self {
        Self::new("type")
    }
}
impl InheritanceConfig {
    /// Builds a configuration whose subtype discriminator lives in
    /// `inheritance_column`.
    ///
    /// Accepts anything convertible into an owned `String` (`&str`, `String`, ...).
    #[must_use]
    pub fn new(inheritance_column: impl Into<String>) -> Self {
        let inheritance_column: String = inheritance_column.into();
        Self { inheritance_column }
    }
}
/// Associates a Rust type with the discriminator value stored in the
/// inheritance column for rows of that subtype.
///
/// [`scope_for_type`] writes this value into query conditions, and
/// [`matches_type`] compares stored attributes against it.
pub trait StiType {
/// Returns the discriminator value identifying this subtype (e.g. `"Firm"`).
fn sti_name() -> &'static str;
}
/// Implemented by base types stored with single-table inheritance, exposing
/// which column carries the subtype discriminator.
pub trait SingleTableInheritance {
    /// Returns the inheritance configuration for this base type.
    ///
    /// The default yields `InheritanceConfig::default()`. It is initialized on
    /// first use, and a single instance is shared by every implementor that keeps
    /// this default. Override to use a custom discriminator column.
    fn inheritance_config() -> &'static InheritanceConfig {
        static FALLBACK: std::sync::OnceLock<InheritanceConfig> = std::sync::OnceLock::new();
        FALLBACK.get_or_init(InheritanceConfig::default)
    }
}
/// Converts a record into another representation, typically a base-table row
/// into a concrete subtype, using the target's `From` implementation.
///
/// This is a thin, turbofish-friendly wrapper: `becomes::<Firm, _>(row)`.
#[must_use]
pub fn becomes<T, R>(record: R) -> T
where
    T: From<R>,
{
    // `T: From<R>` gives `R: Into<T>` via the standard blanket impl.
    record.into()
}
/// Adds the discriminator condition for subtype `T` to `conditions` and returns
/// the updated map.
///
/// The key is `config.inheritance_column` and the value is `T::sti_name()` as a
/// JSON string. If `conditions` already holds an entry for that column, it is
/// replaced. All other entries are left untouched.
#[must_use]
pub fn scope_for_type<T>(
    config: &InheritanceConfig,
    mut conditions: HashMap<String, Value>,
) -> HashMap<String, Value>
where
    T: StiType,
{
    let column = config.inheritance_column.clone();
    let discriminator = Value::String(String::from(T::sti_name()));
    conditions.insert(column, discriminator);
    conditions
}
/// Reports whether `attributes` describe a row of subtype `T`.
///
/// Returns `true` only when the configured inheritance column is present, holds
/// a JSON string, and that string equals `T::sti_name()` exactly. A missing
/// column or a non-string value yields `false`.
#[must_use]
pub fn matches_type<T>(config: &InheritanceConfig, attributes: &HashMap<String, Value>) -> bool
where
    T: StiType,
{
    let Some(Value::String(stored)) = attributes.get(&config.inheritance_column) else {
        return false;
    };
    stored.as_str() == T::sti_name()
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::LazyLock;

    use serde_json::json;

    use super::{
        InheritanceConfig, SingleTableInheritance, StiType, becomes, matches_type, scope_for_type,
    };

    /// Base-table row fixture; `record_type` holds the subtype discriminator.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct CompanyRecord {
        id: i64,
        name: String,
        record_type: String,
    }

    /// Subtype fixture whose discriminator is `"Firm"`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct FirmRecord {
        id: i64,
        name: String,
    }

    /// Subtype fixture whose discriminator is `"Client"`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct ClientRecord {
        id: i64,
        name: String,
    }

    impl From<CompanyRecord> for FirmRecord {
        fn from(value: CompanyRecord) -> Self {
            Self {
                id: value.id,
                name: value.name,
            }
        }
    }

    impl From<CompanyRecord> for ClientRecord {
        fn from(value: CompanyRecord) -> Self {
            Self {
                id: value.id,
                name: value.name,
            }
        }
    }

    impl StiType for FirmRecord {
        fn sti_name() -> &'static str {
            "Firm"
        }
    }

    impl StiType for ClientRecord {
        fn sti_name() -> &'static str {
            "Client"
        }
    }

    // The base type overrides the default column to exercise custom configuration.
    impl SingleTableInheritance for CompanyRecord {
        fn inheritance_config() -> &'static InheritanceConfig {
            static CONFIG: LazyLock<InheritanceConfig> =
                LazyLock::new(|| InheritanceConfig::new("record_type"));
            &CONFIG
        }
    }

    #[test]
    fn default_config_uses_type_column() {
        assert_eq!(InheritanceConfig::default().inheritance_column, "type");
    }

    #[test]
    fn trait_default_config_uses_type_column() {
        // Only used as a type to reach the trait's default method; never instantiated.
        #[allow(dead_code)]
        struct PlainRecord;
        impl SingleTableInheritance for PlainRecord {}

        let config = PlainRecord::inheritance_config();
        assert_eq!(config.inheritance_column, "type");
        // The default is a single 'static instance, so repeated calls agree by address.
        assert!(std::ptr::eq(config, PlainRecord::inheritance_config()));
    }

    #[test]
    fn custom_config_uses_custom_column() {
        assert_eq!(
            CompanyRecord::inheritance_config().inheritance_column,
            "record_type"
        );
    }

    #[test]
    fn becomes_casts_between_subtypes() {
        let company = CompanyRecord {
            id: 1,
            name: "Acme".to_owned(),
            record_type: "Firm".to_owned(),
        };
        let firm: FirmRecord = becomes(company);
        assert_eq!(firm.id, 1);
        assert_eq!(firm.name, "Acme");
    }

    #[test]
    fn becomes_casts_to_other_subtype() {
        let company = CompanyRecord {
            id: 2,
            name: "Globex".to_owned(),
            record_type: "Client".to_owned(),
        };
        // The stored discriminator agrees with the subtype we convert into.
        assert_eq!(company.record_type, ClientRecord::sti_name());
        let client: ClientRecord = becomes(company);
        assert_eq!(
            client,
            ClientRecord {
                id: 2,
                name: "Globex".to_owned(),
            }
        );
    }

    #[test]
    fn scope_for_type_adds_discriminator() {
        let scope =
            scope_for_type::<FirmRecord>(CompanyRecord::inheritance_config(), HashMap::new());
        assert_eq!(scope.get("record_type"), Some(&json!("Firm")));
    }

    #[test]
    fn scope_for_type_preserves_existing_conditions() {
        let scope = scope_for_type::<ClientRecord>(
            CompanyRecord::inheritance_config(),
            HashMap::from([("active".to_owned(), json!(true))]),
        );
        assert_eq!(scope.get("active"), Some(&json!(true)));
        assert_eq!(scope.get("record_type"), Some(&json!("Client")));
    }

    #[test]
    fn scope_for_type_overrides_existing_discriminator() {
        let scope = scope_for_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            HashMap::from([("record_type".to_owned(), json!("Client"))]),
        );
        assert_eq!(scope.len(), 1);
        assert_eq!(scope.get("record_type"), Some(&json!("Firm")));
    }

    #[test]
    fn scope_for_type_uses_default_column() {
        let scope = scope_for_type::<ClientRecord>(&InheritanceConfig::default(), HashMap::new());
        assert_eq!(scope.get("type"), Some(&json!("Client")));
        assert!(!scope.contains_key("record_type"));
    }

    #[test]
    fn matches_type_checks_discriminator_column() {
        let attrs = HashMap::from([("record_type".to_owned(), json!("Firm"))]);
        assert!(matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
        assert!(!matches_type::<ClientRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_returns_false_when_discriminator_missing() {
        let attrs = HashMap::new();
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_rejects_non_string_discriminator() {
        for value in [json!(1), json!(null), json!(["Firm"])] {
            let attrs = HashMap::from([("record_type".to_owned(), value)]);
            assert!(!matches_type::<FirmRecord>(
                CompanyRecord::inheritance_config(),
                &attrs
            ));
        }
    }

    #[test]
    fn matches_type_is_exact_and_case_sensitive() {
        let attrs = HashMap::from([("record_type".to_owned(), json!("firm"))]);
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
    }

    #[test]
    fn matches_type_only_reads_configured_column() {
        let attrs = HashMap::from([("type".to_owned(), json!("Firm"))]);
        assert!(!matches_type::<FirmRecord>(
            CompanyRecord::inheritance_config(),
            &attrs
        ));
        assert!(matches_type::<FirmRecord>(
            &InheritanceConfig::default(),
            &attrs
        ));
    }
}