use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{punctuated::Punctuated, DeriveInput, Field, Fields, Ident, ItemStruct, Result, Token};
/// Target program framework for the generated code.
///
/// Selects which runtime crate the expansion references and which serialization
/// derives are placed on generated structs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Framework {
    /// Anchor programs (`#[derive(LightAccount)]`).
    Anchor,
    /// Pinocchio programs (`#[derive(LightPinocchioAccount)]`).
    Pinocchio,
}
impl Framework {
fn on_chain_crate(&self) -> TokenStream {
match self {
Framework::Anchor => quote! { light_account },
Framework::Pinocchio => quote! { light_account_pinocchio },
}
}
fn serde_derives(&self) -> TokenStream {
match self {
Framework::Anchor => {
quote! { anchor_lang::AnchorSerialize, anchor_lang::AnchorDeserialize }
}
Framework::Pinocchio => quote! { borsh::BorshSerialize, borsh::BorshDeserialize },
}
}
}
use super::{
traits::{parse_compress_as_overrides, CompressAsFields},
validation::validate_compression_info_field,
};
use crate::{
discriminator,
hasher::derive_light_hasher_sha,
light_pdas::account::utils::{extract_fields_from_derive_input, is_copy_type, is_pubkey_type},
};
fn is_zero_copy(attrs: &[syn::Attribute]) -> bool {
attrs.iter().any(|attr| {
if !attr.path().is_ident("account") {
return false;
}
if let syn::Meta::List(meta_list) = &attr.meta {
return meta_list.tokens.to_string().contains("zero_copy");
}
false
})
}
/// Entry point for `#[derive(LightAccount)]`; generates code for Anchor programs.
///
/// # Errors
/// Returns a `syn::Error` if the input is not a struct, fails `compression_info` field
/// validation, or carries a `#[light_pinocchio(discriminator = ...)]` attribute (which is
/// only accepted by the Pinocchio entry point).
pub fn derive_light_account(input: DeriveInput) -> Result<TokenStream> {
    let framework = Framework::Anchor;
    derive_light_account_internal(input, framework)
}

/// Entry point for `#[derive(LightPinocchioAccount)]`; generates code for Pinocchio
/// programs.
///
/// Unlike [`derive_light_account`], this accepts a custom
/// `#[light_pinocchio(discriminator = [...])]` attribute.
///
/// # Errors
/// Returns a `syn::Error` if the input is not a struct, fails `compression_info` field
/// validation, or has a malformed `light_pinocchio` discriminator attribute.
pub fn derive_light_pinocchio_account(input: DeriveInput) -> Result<TokenStream> {
    let framework = Framework::Pinocchio;
    derive_light_account_internal(input, framework)
}
fn parse_pinocchio_discriminator(attrs: &[syn::Attribute]) -> Result<Option<Vec<u8>>> {
for attr in attrs {
if !attr.path().is_ident("light_pinocchio") {
continue;
}
let meta_list = attr.meta.require_list()?;
let nested: Punctuated<syn::Meta, Token![,]> =
meta_list.parse_args_with(Punctuated::parse_terminated)?;
for meta in &nested {
if let syn::Meta::NameValue(nv) = meta {
if nv.path.is_ident("discriminator") {
if let syn::Expr::Array(arr) = &nv.value {
let bytes: Vec<u8> = arr
.elems
.iter()
.map(|e| {
if let syn::Expr::Lit(lit) = e {
if let syn::Lit::Int(i) = &lit.lit {
return i
.base10_parse::<u8>()
.map_err(|err| syn::Error::new_spanned(i, err));
}
}
if let syn::Expr::Cast(cast) = e {
if let syn::Expr::Lit(lit) = cast.expr.as_ref() {
if let syn::Lit::Int(i) = &lit.lit {
return i
.base10_parse::<u8>()
.map_err(|err| syn::Error::new_spanned(i, err));
}
}
}
Err(syn::Error::new_spanned(e, "expected integer literal"))
})
.collect::<Result<Vec<u8>>>()?;
if bytes.is_empty() {
return Err(syn::Error::new_spanned(
arr,
"discriminator must have at least one byte",
));
}
if bytes.len() > 8 {
return Err(syn::Error::new_spanned(
arr,
"discriminator must not exceed 8 bytes",
));
}
return Ok(Some(bytes));
}
return Err(syn::Error::new_spanned(
&nv.value,
"discriminator must be an array like [1u8]",
));
}
}
}
}
Ok(None)
}
fn derive_light_account_internal(input: DeriveInput, framework: Framework) -> Result<TokenStream> {
let item_struct = derive_input_to_item_struct(&input)?;
let hasher_impl = derive_light_hasher_sha(item_struct.clone())?;
let discriminator_impl = if let Some(disc_bytes) = parse_pinocchio_discriminator(&input.attrs)?
{
if framework != Framework::Pinocchio {
return Err(syn::Error::new_spanned(
&input.ident,
"#[light_pinocchio(discriminator = [...])] is only valid with \
#[derive(LightPinocchioAccount)], not with #[derive(LightAccount)]",
));
}
let mut padded = [0u8; 8];
let copy_len = disc_bytes.len().min(8);
padded[..copy_len].copy_from_slice(&disc_bytes[..copy_len]);
let discriminator_tokens: proc_macro2::TokenStream = format!("{padded:?}").parse().unwrap();
let slice_tokens: proc_macro2::TokenStream = format!("{disc_bytes:?}").parse().unwrap();
let struct_name = &input.ident;
let (impl_gen, type_gen, where_clause) = input.generics.split_for_impl();
quote! {
impl #impl_gen LightDiscriminator for #struct_name #type_gen #where_clause {
const LIGHT_DISCRIMINATOR: [u8; 8] = #discriminator_tokens;
const LIGHT_DISCRIMINATOR_SLICE: &'static [u8] = &#slice_tokens;
fn discriminator() -> [u8; 8] { Self::LIGHT_DISCRIMINATOR }
}
}
} else {
discriminator::anchor_discriminator(item_struct)?
};
let light_account_impl = generate_light_account_impl(&input, framework)?;
let anchor_serde_impls = if framework == Framework::Anchor && is_zero_copy(&input.attrs) {
generate_anchor_serde_for_zero_copy(&input)?
} else {
quote! {}
};
Ok(quote! {
#hasher_impl
#discriminator_impl
#light_account_impl
#anchor_serde_impls
})
}
/// Rebuilds a `syn::ItemStruct` from a derive input. This lets the input be passed to
/// generators that take a full struct item (the hasher and the default discriminator).
///
/// # Errors
/// Returns an error spanning the whole input when it is an enum or a union.
fn derive_input_to_item_struct(input: &DeriveInput) -> Result<ItemStruct> {
    let syn::Data::Struct(data) = &input.data else {
        return Err(syn::Error::new_spanned(
            input,
            "LightAccount can only be derived for structs",
        ));
    };
    Ok(ItemStruct {
        attrs: input.attrs.clone(),
        vis: input.vis.clone(),
        struct_token: data.struct_token,
        ident: input.ident.clone(),
        generics: input.generics.clone(),
        // Cloning `Fields` directly covers named, tuple and unit structs alike.
        fields: data.fields.clone(),
        semi_token: data.semi_token,
    })
}
fn generate_anchor_serde_for_zero_copy(input: &DeriveInput) -> Result<TokenStream> {
let struct_name = &input.ident;
let fields = extract_fields_from_derive_input(input)?;
let serialize_fields: Vec<_> = fields
.iter()
.filter_map(|f| {
let name = f.ident.as_ref()?;
Some(quote! {
anchor_lang::AnchorSerialize::serialize(&self.#name, writer)?;
})
})
.collect();
let deserialize_fields: Vec<_> = fields
.iter()
.filter_map(|f| {
let name = f.ident.as_ref()?;
Some(quote! {
#name: anchor_lang::AnchorDeserialize::deserialize_reader(reader)?
})
})
.collect();
Ok(quote! {
impl anchor_lang::AnchorSerialize for #struct_name {
fn serialize<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
#(#serialize_fields)*
Ok(())
}
}
impl anchor_lang::AnchorDeserialize for #struct_name {
fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
Ok(Self {
#(#deserialize_fields,)*
})
}
}
})
}
/// Generates the packed struct plus the `LightAccount`, `Pack`, `Unpack`,
/// `HasCompressionInfo`, `Size`, `CompressAs` and `CompressedInitSpace` impls for a struct.
///
/// Layout selection:
/// - Pinocchio accounts, and Anchor accounts marked `#[account(zero_copy)]`, use
///   `AccountType::PdaZeroCopy` with `INIT_SPACE = size_of::<Self>()`.
/// - All other Anchor accounts use `AccountType::Pda` with `INIT_SPACE` taken from
///   `anchor_lang::Space`.
///
/// In both cases the generated code includes a compile-time assertion that rejects
/// accounts larger than 800 bytes.
///
/// # Errors
/// Propagates errors from field extraction, `compression_info` validation,
/// `#[compress_as(...)]` parsing, and packed-struct / pack / unpack generation.
fn generate_light_account_impl(input: &DeriveInput, framework: Framework) -> Result<TokenStream> {
    let struct_name = &input.ident;
    let packed_struct_name = format_ident!("Packed{}", struct_name);
    let fields = extract_fields_from_derive_input(input)?;
    // Pinocchio accounts always use the in-memory layout; Anchor accounts only when they
    // are marked `#[account(zero_copy)]`.
    let zero_copy_layout = match framework {
        Framework::Pinocchio => true,
        Framework::Anchor => is_zero_copy(&input.attrs),
    };
    // Called for its validation error; the returned value is not needed here.
    let _compression_info_first = validate_compression_info_field(fields, struct_name)?;
    let compress_as_fields = parse_compress_as_overrides(&input.attrs)?;
    let has_pubkey_fields = fields
        .iter()
        .filter(|f| {
            f.ident
                .as_ref()
                .is_none_or(|name| name != "compression_info")
        })
        .any(|f| is_pubkey_type(&f.ty));
    let packed_struct =
        generate_packed_struct(&packed_struct_name, fields, has_pubkey_fields, framework)?;
    let pack_body = generate_pack_body(&packed_struct_name, fields, has_pubkey_fields, framework)?;
    let unpack_body = generate_unpack_body(struct_name, fields, has_pubkey_fields, framework)?;
    let compress_as_assignments =
        generate_compress_as_assignments(fields, &compress_as_fields, framework);
    let compress_as_impl_body =
        generate_compress_as_impl_body(fields, &compress_as_fields, framework);
    let on_chain_crate = framework.on_chain_crate();

    // The size expression checked by the assertion, the account type, and `INIT_SPACE` all
    // follow from the layout decision above.
    let (size_expr, account_type_token, init_space_token) = if zero_copy_layout {
        (
            quote! { core::mem::size_of::<#struct_name>() },
            quote! { #on_chain_crate::AccountType::PdaZeroCopy },
            quote! { core::mem::size_of::<Self>() },
        )
    } else {
        (
            quote! { <#struct_name as anchor_lang::Space>::INIT_SPACE },
            quote! { #on_chain_crate::AccountType::Pda },
            quote! { <Self as anchor_lang::Space>::INIT_SPACE },
        )
    };
    let size_assertion = quote! {
        const _: () = {
            assert!(
                #size_expr <= 800,
                "Compressed account size exceeds 800 byte limit"
            );
        };
    };

    let light_account_impl = quote! {
        #packed_struct
        #size_assertion
        impl #on_chain_crate::LightAccount for #struct_name {
            const ACCOUNT_TYPE: #on_chain_crate::AccountType = #account_type_token;
            type Packed = #packed_struct_name;
            const INIT_SPACE: usize = #init_space_token;
            #[inline]
            fn compression_info(&self) -> &#on_chain_crate::CompressionInfo {
                &self.compression_info
            }
            #[inline]
            fn compression_info_mut(&mut self) -> &mut #on_chain_crate::CompressionInfo {
                &mut self.compression_info
            }
            fn set_decompressed(&mut self, config: &#on_chain_crate::LightConfig, current_slot: u64) {
                self.compression_info = #on_chain_crate::CompressionInfo::new_from_config(config, current_slot);
                #compress_as_assignments
            }
            #[cfg(not(target_os = "solana"))]
            #[inline(never)]
            fn pack<AM: #on_chain_crate::AccountMetaTrait>(
                &self,
                accounts: &mut #on_chain_crate::interface::instruction::PackedAccounts<AM>,
            ) -> std::result::Result<Self::Packed, #on_chain_crate::LightSdkTypesError> {
                #pack_body
            }
            #[inline(never)]
            fn unpack<A: #on_chain_crate::AccountInfoTrait>(
                packed: &Self::Packed,
                accounts: &#on_chain_crate::packed_accounts::ProgramPackedAccounts<A>,
            ) -> std::result::Result<Self, #on_chain_crate::LightSdkTypesError> {
                #unpack_body
            }
        }
        #[cfg(not(target_os = "solana"))]
        impl<AM: #on_chain_crate::AccountMetaTrait> #on_chain_crate::Pack<AM> for #struct_name {
            type Packed = #packed_struct_name;
            fn pack(
                &self,
                remaining_accounts: &mut #on_chain_crate::interface::instruction::PackedAccounts<AM>,
            ) -> std::result::Result<Self::Packed, #on_chain_crate::LightSdkTypesError> {
                <Self as #on_chain_crate::LightAccount>::pack(self, remaining_accounts)
            }
        }
        impl<AI: #on_chain_crate::AccountInfoTrait> #on_chain_crate::Unpack<AI> for #packed_struct_name {
            type Unpacked = #struct_name;
            fn unpack(
                &self,
                remaining_accounts: &[AI],
            ) -> std::result::Result<Self::Unpacked, #on_chain_crate::LightSdkTypesError> {
                let accounts = #on_chain_crate::packed_accounts::ProgramPackedAccounts {
                    accounts: remaining_accounts
                };
                <#struct_name as #on_chain_crate::LightAccount>::unpack(self, &accounts)
            }
        }
        impl #on_chain_crate::HasCompressionInfo for #struct_name {
            fn compression_info(&self) -> std::result::Result<&#on_chain_crate::CompressionInfo, #on_chain_crate::LightSdkTypesError> {
                Ok(&self.compression_info)
            }
            fn compression_info_mut(&mut self) -> std::result::Result<&mut #on_chain_crate::CompressionInfo, #on_chain_crate::LightSdkTypesError> {
                Ok(&mut self.compression_info)
            }
            fn compression_info_mut_opt(&mut self) -> &mut Option<#on_chain_crate::CompressionInfo> {
                panic!("compression_info_mut_opt not supported for LightAccount types (use compression_info_mut instead)")
            }
            fn set_compression_info_none(&mut self) -> std::result::Result<(), #on_chain_crate::LightSdkTypesError> {
                self.compression_info = #on_chain_crate::CompressionInfo::compressed();
                Ok(())
            }
        }
        impl #on_chain_crate::Size for #struct_name {
            #[inline]
            fn size(&self) -> std::result::Result<usize, #on_chain_crate::LightSdkTypesError> {
                Ok(<Self as #on_chain_crate::LightAccount>::INIT_SPACE)
            }
        }
        impl #on_chain_crate::CompressAs for #struct_name {
            type Output = Self;
            fn compress_as(&self) -> std::borrow::Cow<'_, Self::Output> {
                #compress_as_impl_body
            }
        }
        impl #on_chain_crate::CompressedInitSpace for #struct_name {
            const COMPRESSED_INIT_SPACE: usize = <Self as #on_chain_crate::LightAccount>::INIT_SPACE;
        }
    };
    Ok(light_account_impl)
}
/// Generates the `Packed*` struct: every named field except `compression_info`, with
/// pubkey-typed fields replaced by a `u8`. The packed pubkey value is whatever
/// `insert_or_get_read_only` returns in the generated `pack` body.
///
/// If no field remains, a unit struct is emitted instead. The struct derives
/// `Debug`, `Clone` and the framework's serialization traits.
///
/// `has_pubkey_fields` is the caller's precomputed "any non-`compression_info` field is a
/// pubkey" flag. When it is `false`, the per-field pubkey check is skipped and every type
/// is kept verbatim.
fn generate_packed_struct(
    packed_struct_name: &Ident,
    fields: &Punctuated<Field, Token![,]>,
    has_pubkey_fields: bool,
    framework: Framework,
) -> Result<TokenStream> {
    let serde_derives = framework.serde_derives();
    let packed_fields: Vec<TokenStream> = fields
        .iter()
        .filter_map(|field| {
            let field_name = field.ident.as_ref()?;
            if field_name == "compression_info" {
                return None;
            }
            let field_type = &field.ty;
            let packed_type = if has_pubkey_fields && is_pubkey_type(field_type) {
                quote! { u8 }
            } else {
                quote! { #field_type }
            };
            Some(quote! { pub #field_name: #packed_type })
        })
        .collect();
    if packed_fields.is_empty() {
        return Ok(quote! {
            #[derive(Debug, Clone, #serde_derives)]
            pub struct #packed_struct_name;
        });
    }
    Ok(quote! {
        #[derive(Debug, Clone, #serde_derives)]
        pub struct #packed_struct_name {
            #(#packed_fields,)*
        }
    })
}
/// Builds the body of `LightAccount::pack`, which converts `self` into its packed struct.
///
/// - `compression_info` is dropped.
/// - Pubkey fields are passed through `AM::pubkey_from_bytes` into
///   `accounts.insert_or_get_read_only`, and the result is stored in the packed field.
///   Anchor pubkeys are first converted with `.to_bytes()`; Pinocchio pubkeys are passed
///   as-is.
/// - Other fields are copied (`Copy` types) or cloned.
///
/// With no pubkeys and nothing to assign, the unit-struct form `Ok(Packed)` is produced.
fn generate_pack_body(
    packed_struct_name: &Ident,
    fields: &Punctuated<Field, Token![,]>,
    has_pubkey_fields: bool,
    framework: Framework,
) -> Result<TokenStream> {
    let mut assignments = Vec::with_capacity(fields.len());
    for field in fields {
        let Some(name) = field.ident.as_ref() else {
            continue;
        };
        if name == "compression_info" {
            continue;
        }
        let ty = &field.ty;
        let value = if is_pubkey_type(ty) {
            let key_bytes = match framework {
                Framework::Anchor => quote! { self.#name.to_bytes() },
                Framework::Pinocchio => quote! { self.#name },
            };
            quote! { accounts.insert_or_get_read_only(AM::pubkey_from_bytes(#key_bytes)) }
        } else if is_copy_type(ty) {
            quote! { self.#name }
        } else {
            quote! { self.#name.clone() }
        };
        assignments.push(quote! { #name: #value });
    }

    let body = if !has_pubkey_fields && assignments.is_empty() {
        quote! {
            Ok(#packed_struct_name)
        }
    } else {
        quote! {
            Ok(#packed_struct_name {
                #(#assignments,)*
            })
        }
    };
    Ok(body)
}
/// Builds the body of `LightAccount::unpack`, which rebuilds the original struct from its
/// packed form.
///
/// - `compression_info` is not stored in the packed struct; it is set to
///   `CompressionInfo::compressed()`.
/// - Pubkey fields hold the value produced during packing. Each one is resolved through
///   `accounts.get_u8(...)` (with `"Struct: field"` as context), and any lookup failure is
///   reported as `LightSdkTypesError::InvalidInstructionData`.
/// - All other fields are copied (`Copy` types) or cloned from `packed`.
///
/// `has_pubkey_fields` must be the same flag passed to `generate_packed_struct`. When it
/// is `false`, no field is treated as a pubkey, mirroring the packed struct's layout.
fn generate_unpack_body(
    struct_name: &Ident,
    fields: &Punctuated<Field, Token![,]>,
    has_pubkey_fields: bool,
    framework: Framework,
) -> Result<TokenStream> {
    let on_chain_crate = framework.on_chain_crate();
    let unpack_assignments: Vec<TokenStream> = fields
        .iter()
        .filter_map(|field| {
            let field_name = field.ident.as_ref()?;
            let field_type = &field.ty;
            if field_name == "compression_info" {
                return Some(quote! {
                    #field_name: #on_chain_crate::CompressionInfo::compressed()
                });
            }
            let value = if has_pubkey_fields && is_pubkey_type(field_type) {
                let error_msg = format!("{struct_name}: {field_name}");
                let key_conversion = match framework {
                    Framework::Anchor => quote! { solana_pubkey::Pubkey::from(account.key()) },
                    Framework::Pinocchio => quote! { account.key() },
                };
                quote! {
                    {
                        let account = accounts
                            .get_u8(packed.#field_name, #error_msg)
                            .map_err(|_| #on_chain_crate::LightSdkTypesError::InvalidInstructionData)?;
                        #key_conversion
                    }
                }
            } else if is_copy_type(field_type) {
                quote! { packed.#field_name }
            } else {
                quote! { packed.#field_name.clone() }
            };
            Some(quote! { #field_name: #value })
        })
        .collect();
    Ok(quote! {
        Ok(#struct_name {
            #(#unpack_assignments,)*
        })
    })
}
fn generate_compress_as_assignments(
fields: &Punctuated<Field, Token![,]>,
compress_as_fields: &Option<CompressAsFields>,
_framework: Framework,
) -> TokenStream {
let Some(overrides) = compress_as_fields else {
return quote! {};
};
let assignments: Vec<_> = fields
.iter()
.filter_map(|field| {
let field_name = field.ident.as_ref()?;
if field_name == "compression_info" {
return None;
}
if field.attrs.iter().any(|attr| attr.path().is_ident("skip")) {
return None;
}
let override_field = overrides.fields.iter().find(|f| &f.name == field_name)?;
let value = &override_field.value;
Some(quote! {
self.#field_name = #value;
})
})
.collect();
quote! { #(#assignments)* }
}
/// Builds the body of `CompressAs::compress_as`.
///
/// The generated code does the following:
/// 1. clones `self`;
/// 2. sets `compression_info` to `CompressionInfo::compressed()`;
/// 3. applies every `#[compress_as(field = value)]` override, skipping `compression_info`
///    and fields marked `#[skip]`;
/// 4. returns the result as `Cow::Owned`.
///
/// Without any `#[compress_as]` attribute, step 3 is empty.
fn generate_compress_as_impl_body(
    fields: &Punctuated<Field, Token![,]>,
    compress_as_fields: &Option<CompressAsFields>,
    framework: Framework,
) -> TokenStream {
    let on_chain_crate = framework.on_chain_crate();
    // The clone/reset/return template is the same whether or not there are overrides; an
    // empty assignment list simply expands to nothing.
    let assignments: Vec<TokenStream> = match compress_as_fields {
        None => Vec::new(),
        Some(overrides) => fields
            .iter()
            .filter_map(|field| {
                let field_name = field.ident.as_ref()?;
                if field_name == "compression_info" {
                    return None;
                }
                if field.attrs.iter().any(|attr| attr.path().is_ident("skip")) {
                    return None;
                }
                let override_field = overrides.fields.iter().find(|f| &f.name == field_name)?;
                let value = &override_field.value;
                Some(quote! {
                    result.#field_name = #value;
                })
            })
            .collect(),
    };
    quote! {
        let mut result = self.clone();
        result.compression_info = #on_chain_crate::CompressionInfo::compressed();
        #(#assignments)*
        std::borrow::Cow::Owned(result)
    }
}
#[cfg(test)]
mod tests {
    //! Expansion tests for `LightAccount` / `LightPinocchioAccount`. They inspect the
    //! stringified token output, so expected fragments use proc_macro2's token spacing
    //! (e.g. `pub owner : u8`).
    use syn::parse_quote;

    use super::*;

    #[test]
    fn test_light_pinocchio_custom_discriminator() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = [1u8])]
            pub struct OneByteRecord {
                pub compression_info: CompressionInfo,
                pub owner: [u8; 32],
            }
        };
        let result = derive_light_pinocchio_account(input);
        assert!(
            result.is_ok(),
            "LightPinocchioAccount with custom discriminator should succeed: {:?}",
            result.err()
        );
        let output = result.unwrap().to_string();
        assert!(
            output.contains("LIGHT_DISCRIMINATOR"),
            "Should have LIGHT_DISCRIMINATOR"
        );
        assert!(
            output.contains("1 , 0 , 0 , 0 , 0 , 0 , 0 , 0")
                || output.contains("1, 0, 0, 0, 0, 0, 0, 0"),
            "LIGHT_DISCRIMINATOR should be [1,0,0,0,0,0,0,0]"
        );
        assert!(
            output.contains("LIGHT_DISCRIMINATOR_SLICE"),
            "Should have LIGHT_DISCRIMINATOR_SLICE"
        );
        assert!(
            output.contains("& [1u8]") || output.contains("& [1]"),
            "LIGHT_DISCRIMINATOR_SLICE should be &[1] (1 byte), got: {output}"
        );
    }

    #[test]
    fn test_light_pinocchio_custom_discriminator_multi_byte_with_cast() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = [1 as u8, 2, 255u8])]
            pub struct MultiByteRecord {
                pub compression_info: CompressionInfo,
                pub owner: [u8; 32],
            }
        };
        let result = derive_light_pinocchio_account(input);
        assert!(
            result.is_ok(),
            "Multi-byte discriminator with casts should succeed: {:?}",
            result.err()
        );
        let output = result.unwrap().to_string();
        assert!(
            output.contains("[1 , 2 , 255 , 0 , 0 , 0 , 0 , 0]"),
            "LIGHT_DISCRIMINATOR should be zero-padded to 8 bytes, got: {output}"
        );
        assert!(
            output.contains("& [1 , 2 , 255]"),
            "LIGHT_DISCRIMINATOR_SLICE should keep exactly the written bytes, got: {output}"
        );
    }

    #[test]
    fn test_light_pinocchio_custom_discriminator_empty_rejected() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = [])]
            pub struct EmptyDisc {
                pub compression_info: CompressionInfo,
                pub owner: [u8; 32],
            }
        };
        let result = derive_light_pinocchio_account(input);
        assert!(
            result.is_err(),
            "Empty discriminator array should be rejected"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("at least one byte"),
            "Error should mention 'at least one byte', got: {err}"
        );
    }

    #[test]
    fn test_light_pinocchio_custom_discriminator_too_long_rejected() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = [1, 2, 3, 4, 5, 6, 7, 8, 9])]
            pub struct TooLongDisc {
                pub compression_info: CompressionInfo,
                pub owner: [u8; 32],
            }
        };
        let result = derive_light_pinocchio_account(input);
        assert!(
            result.is_err(),
            "Discriminator longer than 8 bytes should be rejected"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("exceed 8 bytes"),
            "Error should mention max length, got: {err}"
        );
    }

    #[test]
    fn test_light_pinocchio_discriminator_not_array_rejected() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = 1)]
            pub struct NotArrayDisc {
                pub compression_info: CompressionInfo,
                pub owner: [u8; 32],
            }
        };
        let result = derive_light_pinocchio_account(input);
        assert!(
            result.is_err(),
            "Non-array discriminator value should be rejected"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("must be an array"),
            "Error should ask for an array, got: {err}"
        );
    }

    #[test]
    fn test_light_pinocchio_discriminator_byte_out_of_range_rejected() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = [256])]
            pub struct OutOfRangeDisc {
                pub compression_info: CompressionInfo,
                pub owner: [u8; 32],
            }
        };
        let result = derive_light_pinocchio_account(input);
        assert!(
            result.is_err(),
            "Discriminator byte that does not fit in u8 should be rejected"
        );
    }

    #[test]
    fn test_light_pinocchio_discriminator_non_literal_rejected() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = [FOO])]
            pub struct NonLiteralDisc {
                pub compression_info: CompressionInfo,
                pub owner: [u8; 32],
            }
        };
        let result = derive_light_pinocchio_account(input);
        assert!(
            result.is_err(),
            "Non-literal discriminator element should be rejected"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("expected integer literal"),
            "Error should mention integer literal, got: {err}"
        );
    }

    #[test]
    fn test_light_pinocchio_discriminator_rejected_on_anchor() {
        let input: DeriveInput = parse_quote! {
            #[light_pinocchio(discriminator = [1u8])]
            pub struct AnchorRecord {
                pub compression_info: CompressionInfo,
                pub owner: Pubkey,
            }
        };
        let result = derive_light_account(input);
        assert!(
            result.is_err(),
            "#[light_pinocchio(discriminator)] should be rejected with LightAccount (Anchor)"
        );
    }

    #[test]
    fn test_light_account_basic() {
        let input: DeriveInput = parse_quote! {
            pub struct UserRecord {
                pub compression_info: CompressionInfo,
                pub owner: Pubkey,
                pub name: String,
                pub score: u64,
            }
        };
        let result = derive_light_account(input);
        assert!(result.is_ok(), "LightAccount should succeed");
        let output = result.unwrap().to_string();
        assert!(output.contains("DataHasher"), "Should implement DataHasher");
        assert!(
            output.contains("ToByteArray"),
            "Should implement ToByteArray"
        );
        assert!(
            output.contains("LightDiscriminator"),
            "Should implement LightDiscriminator"
        );
        assert!(
            output.contains("LIGHT_DISCRIMINATOR"),
            "Should have discriminator constant"
        );
        assert!(
            output.contains("impl light_account :: LightAccount for UserRecord"),
            "Should implement LightAccount trait"
        );
        assert!(
            output.contains("PackedUserRecord"),
            "Should generate Packed struct"
        );
        assert!(
            output.contains("ACCOUNT_TYPE"),
            "Should have ACCOUNT_TYPE constant"
        );
        assert!(
            output.contains("INIT_SPACE"),
            "Should have INIT_SPACE constant"
        );
        assert!(
            output.contains("800"),
            "Should have 800-byte size assertion"
        );
        assert!(
            output.contains("compression_info"),
            "Should have compression_info methods"
        );
        assert!(output.contains("fn pack"), "Should have pack method");
        assert!(output.contains("fn unpack"), "Should have unpack method");
        assert!(
            output.contains("set_decompressed"),
            "Should have set_decompressed method"
        );
    }

    #[test]
    fn test_light_account_with_compress_as() {
        let input: DeriveInput = parse_quote! {
            #[compress_as(start_time = 0, score = 0)]
            pub struct GameSession {
                pub compression_info: CompressionInfo,
                pub session_id: u64,
                pub player: Pubkey,
                pub start_time: u64,
                pub score: u64,
            }
        };
        let result = derive_light_account(input);
        assert!(
            result.is_ok(),
            "LightAccount with compress_as should succeed"
        );
        let output = result.unwrap().to_string();
        assert!(
            output.contains("LightAccount"),
            "Should implement LightAccount"
        );
    }

    #[test]
    fn test_compress_as_overrides_applied_only_to_listed_fields() {
        let input: DeriveInput = parse_quote! {
            #[compress_as(start_time = 0, score = 0)]
            pub struct GameSession {
                pub compression_info: CompressionInfo,
                pub session_id: u64,
                pub player: Pubkey,
                pub start_time: u64,
                pub score: u64,
            }
        };
        let output = derive_light_account(input)
            .expect("LightAccount with compress_as should succeed")
            .to_string();
        assert!(
            output.contains("result . score ="),
            "compress_as should override score, got: {output}"
        );
        assert!(
            output.contains("result . start_time ="),
            "compress_as should override start_time, got: {output}"
        );
        assert!(
            !output.contains("result . session_id ="),
            "Fields without an override must not be reassigned, got: {output}"
        );
    }

    #[test]
    fn test_light_account_no_pubkey_fields() {
        let input: DeriveInput = parse_quote! {
            pub struct SimpleRecord {
                pub compression_info: CompressionInfo,
                pub id: u64,
                pub value: u32,
            }
        };
        let result = derive_light_account(input);
        assert!(
            result.is_ok(),
            "LightAccount without Pubkey fields should succeed"
        );
        let output = result.unwrap().to_string();
        assert!(output.contains("DataHasher"), "Should implement DataHasher");
        assert!(
            output.contains("LightDiscriminator"),
            "Should implement LightDiscriminator"
        );
        assert!(
            output.contains("LightAccount"),
            "Should implement LightAccount"
        );
    }

    #[test]
    fn test_packed_struct_without_pubkeys_keeps_field_types() {
        let input: DeriveInput = parse_quote! {
            pub struct SimpleRecord {
                pub compression_info: CompressionInfo,
                pub id: u64,
                pub value: u32,
            }
        };
        let output = derive_light_account(input)
            .expect("LightAccount without Pubkey fields should succeed")
            .to_string();
        assert!(
            output.contains("pub struct PackedSimpleRecord"),
            "Should generate PackedSimpleRecord"
        );
        assert!(
            output.contains("pub id : u64"),
            "Packed struct should keep id as u64, got: {output}"
        );
        assert!(
            output.contains("pub value : u32"),
            "Packed struct should keep value as u32, got: {output}"
        );
    }

    #[test]
    fn test_light_account_enum_fails() {
        let input: DeriveInput = parse_quote! {
            pub enum NotAStruct {
                A,
                B,
            }
        };
        let result = derive_light_account(input);
        assert!(result.is_err(), "LightAccount should fail for enums");
    }

    #[test]
    fn test_light_account_missing_compression_info() {
        let input: DeriveInput = parse_quote! {
            pub struct MissingCompressionInfo {
                pub id: u64,
                pub value: u32,
            }
        };
        let result = derive_light_account(input);
        assert!(
            result.is_err(),
            "Should fail without compression_info field"
        );
    }

    #[test]
    fn test_light_account_compression_info_in_middle_fails() {
        let input: DeriveInput = parse_quote! {
            pub struct BadLayout {
                pub id: u64,
                pub compression_info: CompressionInfo,
                pub value: u32,
            }
        };
        let result = derive_light_account(input);
        assert!(
            result.is_err(),
            "Should fail when compression_info is in middle"
        );
    }

    #[test]
    fn test_packed_struct_excludes_compression_info() {
        let input: DeriveInput = parse_quote! {
            pub struct UserRecord {
                pub compression_info: CompressionInfo,
                pub owner: Pubkey,
                pub score: u64,
            }
        };
        let result = derive_light_account(input);
        assert!(result.is_ok());
        let output = result.unwrap().to_string();
        assert!(
            output.contains("pub struct PackedUserRecord"),
            "Should generate PackedUserRecord"
        );
        assert!(
            output.contains("pub owner : u8"),
            "Packed struct should have owner as u8"
        );
    }
}