use quote::quote;
use prax_schema::ast::{FieldType, Model, ScalarType};
use crate::plugins::{Plugin, PluginContext, PluginOutput};
/// Code-generation plugin named `"serde"`.
///
/// Its `Plugin` implementation emits a `_serde_helpers` module once at
/// start-up and, for models containing date/time, UUID, decimal or bytes
/// scalars, a `_serde_info` module describing their wire formats.
pub struct SerdePlugin;
impl Plugin for SerdePlugin {
/// Returns this plugin's name, `"serde"`.
fn name(&self) -> &'static str {
    const PLUGIN_NAME: &str = "serde";
    PLUGIN_NAME
}
/// Returns the name of the environment variable tied to this plugin,
/// `PRAX_PLUGIN_SERDE`.
///
/// NOTE(review): presumably the plugin loader reads this variable to toggle
/// the plugin; how it is interpreted is not visible here.
fn env_var(&self) -> &'static str {
    const ENV_VAR_NAME: &str = "PRAX_PLUGIN_SERDE";
    ENV_VAR_NAME
}
/// Returns a one-line, human-readable summary of what this plugin adds.
fn description(&self) -> &'static str {
    const SUMMARY: &str = "Adds custom serialization helpers for dates, decimals, and UUIDs";
    SUMMARY
}
/// Emits the shared `_serde_helpers` module.
///
/// The returned tokens define one public sub-module per supported value
/// kind (`iso_date`, `iso_datetime`, `uuid_string`, `decimal_string`,
/// `base64_bytes`), each exposing a `serialize`/`deserialize` function pair
/// that round-trips the value through a string. The context argument is
/// not used.
///
/// The emitted code refers to `serde`, `chrono`, `uuid`, `rust_decimal` and
/// `base64` by crate path, so those names must resolve wherever the tokens
/// end up being compiled.
///
/// NOTE(review): the function pairs have the signature shape that serde's
/// `#[serde(with = "...")]` field attribute expects; nothing in this file
/// attaches that attribute, so confirm where the helpers are wired in.
fn on_start(&self, _ctx: &PluginContext) -> PluginOutput {
// Only plain `//` comments are used inside `quote!`: they are discarded by
// the lexer, whereas `///` would become `#[doc]` tokens in the output.
PluginOutput::with_tokens(quote! {
pub mod _serde_helpers {
// chrono::NaiveDate <-> "YYYY-MM-DD" string.
pub mod iso_date {
use serde::{self, Deserialize, Deserializer, Serializer};
// strftime-style pattern shared by both directions.
const FORMAT: &str = "%Y-%m-%d";
pub fn serialize<S>(
date: &chrono::NaiveDate,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let s = date.format(FORMAT).to_string();
serializer.serialize_str(&s)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<chrono::NaiveDate, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
// Parse failures are surfaced as the deserializer's own error type.
chrono::NaiveDate::parse_from_str(&s, FORMAT)
.map_err(serde::de::Error::custom)
}
}
// chrono::DateTime<Utc> <-> RFC 3339 string.
pub mod iso_datetime {
use serde::{self, Deserialize, Deserializer, Serializer};
pub fn serialize<S>(
dt: &chrono::DateTime<chrono::Utc>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&dt.to_rfc3339())
}
pub fn deserialize<'de, D>(
deserializer: D,
) -> Result<chrono::DateTime<chrono::Utc>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
// The parsed value keeps the offset from the input; convert it to UTC
// so the result matches the declared return type.
chrono::DateTime::parse_from_rfc3339(&s)
.map(|dt| dt.with_timezone(&chrono::Utc))
.map_err(serde::de::Error::custom)
}
}
// uuid::Uuid <-> its `to_string()` / `parse()` text form.
pub mod uuid_string {
use serde::{self, Deserialize, Deserializer, Serializer};
pub fn serialize<S>(uuid: &uuid::Uuid, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&uuid.to_string())
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<uuid::Uuid, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
// Target type of `parse` is inferred from the function's return type.
s.parse().map_err(serde::de::Error::custom)
}
}
// rust_decimal::Decimal <-> its `to_string()` / `parse()` text form.
pub mod decimal_string {
use serde::{self, Deserialize, Deserializer, Serializer};
pub fn serialize<S>(
decimal: &rust_decimal::Decimal,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&decimal.to_string())
}
pub fn deserialize<'de, D>(
deserializer: D,
) -> Result<rust_decimal::Decimal, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
// Target type of `parse` is inferred from the function's return type.
s.parse().map_err(serde::de::Error::custom)
}
}
// Byte slices <-> base64 text, using the `STANDARD` engine both ways.
pub mod base64_bytes {
use serde::{self, Deserialize, Deserializer, Serializer};
use base64::{Engine as _, engine::general_purpose::STANDARD};
pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&STANDARD.encode(bytes))
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
STANDARD.decode(&s).map_err(serde::de::Error::custom)
}
}
}
})
}
/// Emits a `_serde_info` module summarising how a model's special scalar
/// fields are serialized.
///
/// The model's fields are checked for `DateTime`, `Date`, `Uuid`, `Decimal`
/// and `Bytes` scalars. If none are present an empty `PluginOutput::new()`
/// is returned. Otherwise the output defines `pub const HINTS: &str`,
/// holding one note per kind found, always in the order listed above and
/// joined with `", "`. The context argument is not used.
fn on_model(&self, _ctx: &PluginContext, model: &Model) -> PluginOutput {
    // One (seen, note) slot per scalar kind; the array order is the order
    // in which notes appear in the final string.
    let mut notes = [
        (false, "DateTime fields use RFC 3339 format"),
        (false, "Date fields use ISO 8601 format (YYYY-MM-DD)"),
        (false, "UUID fields are serialized as strings"),
        (false, "Decimal fields are serialized as strings"),
        (false, "Bytes fields are serialized as base64"),
    ];

    // Single pass over the fields, flagging the slot each one belongs to.
    for field in model.fields.values() {
        let slot = match &field.field_type {
            FieldType::Scalar(ScalarType::DateTime) => 0,
            FieldType::Scalar(ScalarType::Date) => 1,
            FieldType::Scalar(ScalarType::Uuid) => 2,
            FieldType::Scalar(ScalarType::Decimal) => 3,
            FieldType::Scalar(ScalarType::Bytes) => 4,
            _ => continue,
        };
        notes[slot].0 = true;
    }

    let active: Vec<&str> = notes
        .iter()
        .filter_map(|&(seen, text)| if seen { Some(text) } else { None })
        .collect();

    // Nothing noteworthy in this model: contribute no tokens at all.
    if active.is_empty() {
        return PluginOutput::new();
    }

    let summary = active.join(", ");
    PluginOutput::with_tokens(quote! {
        pub mod _serde_info {
            pub const HINTS: &str = #summary;
        }
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use prax_schema::Schema;
    use prax_schema::ast::{Field, Ident, Span, TypeModifier};

    /// Placeholder span; these tests never look at source positions.
    fn make_span() -> Span {
        Span::new(0, 0)
    }

    /// Wraps `name` in an identifier carrying the placeholder span.
    fn make_ident(name: &str) -> Ident {
        Ident::new(name, make_span())
    }

    /// Builds a required scalar field with no attributes.
    fn make_field(name: &str, scalar: ScalarType) -> Field {
        Field::new(
            make_ident(name),
            FieldType::Scalar(scalar),
            TypeModifier::Required,
            vec![],
            make_span(),
        )
    }

    /// `on_start` must emit every helper module, not just a subset.
    #[test]
    fn test_serde_plugin_start() {
        let schema = Schema::new();
        let config = crate::plugins::PluginConfig::new();
        let ctx = PluginContext::new(&schema, &config);
        let plugin = SerdePlugin;
        let output = plugin.on_start(&ctx);
        let code = output.tokens.to_string();
        for expected in [
            "_serde_helpers",
            "iso_date",
            "iso_datetime",
            "uuid_string",
            "decimal_string",
            "base64_bytes",
        ] {
            assert!(code.contains(expected), "missing `{}` in generated code", expected);
        }
    }

    /// A `DateTime` field produces the info module with the RFC 3339 note.
    #[test]
    fn test_serde_plugin_model_with_datetime() {
        let schema = Schema::new();
        let config = crate::plugins::PluginConfig::new();
        let ctx = PluginContext::new(&schema, &config);
        let mut model = Model::new(make_ident("Event"), make_span());
        model.add_field(make_field("created_at", ScalarType::DateTime));
        let plugin = SerdePlugin;
        let output = plugin.on_model(&ctx, &model);
        let code = output.tokens.to_string();
        assert!(code.contains("_serde_info"));
        assert!(code.contains("HINTS"));
        assert!(code.contains("DateTime fields use RFC 3339 format"));
    }

    /// A model with no fields yields no `_serde_info` module.
    #[test]
    fn test_serde_plugin_model_without_hints() {
        let schema = Schema::new();
        let config = crate::plugins::PluginConfig::new();
        let ctx = PluginContext::new(&schema, &config);
        let model = Model::new(make_ident("Empty"), make_span());
        let plugin = SerdePlugin;
        let output = plugin.on_model(&ctx, &model);
        let code = output.tokens.to_string();
        assert!(!code.contains("_serde_info"));
    }

    /// Hints follow the plugin's fixed order, not field declaration order.
    #[test]
    fn test_serde_plugin_model_hint_order() {
        let schema = Schema::new();
        let config = crate::plugins::PluginConfig::new();
        let ctx = PluginContext::new(&schema, &config);
        let mut model = Model::new(make_ident("Event"), make_span());
        // Deliberately declare the UUID field before the DateTime field.
        model.add_field(make_field("id", ScalarType::Uuid));
        model.add_field(make_field("created_at", ScalarType::DateTime));
        let plugin = SerdePlugin;
        let output = plugin.on_model(&ctx, &model);
        let code = output.tokens.to_string();
        assert!(code.contains(
            "DateTime fields use RFC 3339 format, UUID fields are serialized as strings"
        ));
    }
}