Skip to main content

pump_rust_client/
account_wrapper.rs

//! Generic newtype that lets `Account<'info, AccountWrapper<T>>` load account
//! buffers shorter than `T`'s current Borsh layout: short buffers are
//! zero-padded to the Anchor-serialized size of `T::default()` (discriminator
//! plus body) before deserializing.
4
5use std::io::{Read, Write};
6use std::ops::{Deref, DerefMut};
7
8use anchor_lang::prelude::Pubkey;
9use anchor_lang::{
10    AccountDeserialize, AccountSerialize, AnchorDeserialize, AnchorSerialize, Discriminator,
11    Owner, Result,
12};
13
/// Transparent newtype around an account type `T`.
///
/// Every trait impl for the wrapper forwards to `T`, except
/// `AccountDeserialize`, which additionally accepts buffers shorter than the
/// serialized size of `T::default()`.
#[repr(transparent)]
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct AccountWrapper<T>(pub T);
17
18impl<T> AccountWrapper<T> {
19    pub fn new(inner: T) -> Self {
20        Self(inner)
21    }
22
23    pub fn into_inner(self) -> T {
24        self.0
25    }
26}
27
28impl<T> Deref for AccountWrapper<T> {
29    type Target = T;
30    fn deref(&self) -> &T {
31        &self.0
32    }
33}
34
35impl<T> DerefMut for AccountWrapper<T> {
36    fn deref_mut(&mut self) -> &mut T {
37        &mut self.0
38    }
39}
40
41impl<T> From<T> for AccountWrapper<T> {
42    fn from(inner: T) -> Self {
43        Self(inner)
44    }
45}
46
47impl<T: Discriminator> Discriminator for AccountWrapper<T> {
48    const DISCRIMINATOR: &'static [u8] = T::DISCRIMINATOR;
49}
50
51impl<T: Owner> Owner for AccountWrapper<T> {
52    fn owner() -> Pubkey {
53        T::owner()
54    }
55}
56
57impl<T: AccountSerialize> AccountSerialize for AccountWrapper<T> {
58    fn try_serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
59        self.0.try_serialize(writer)
60    }
61}
62
63impl<T: AnchorSerialize> AnchorSerialize for AccountWrapper<T> {
64    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
65        self.0.serialize(writer)
66    }
67}
68
69impl<T: AnchorDeserialize> AnchorDeserialize for AccountWrapper<T> {
70    fn deserialize_reader<R: Read>(reader: &mut R) -> std::io::Result<Self> {
71        T::deserialize_reader(reader).map(Self)
72    }
73}
74
75impl<T> AccountDeserialize for AccountWrapper<T>
76where
77    T: AccountDeserialize + AccountSerialize + Default,
78{
79    fn try_deserialize(buf: &mut &[u8]) -> Result<Self> {
80        let mut canonical = Vec::new();
81        T::default().try_serialize(&mut canonical)?;
82        let expected = canonical.len();
83
84        if buf.len() >= expected {
85            return T::try_deserialize(buf).map(Self);
86        }
87        let mut padded = vec![0u8; expected];
88        padded[..buf.len()].copy_from_slice(buf);
89        let mut slice: &[u8] = &padded;
90        let inner = T::try_deserialize(&mut slice)?;
91        *buf = &buf[buf.len()..];
92        Ok(Self(inner))
93    }
94
95    fn try_deserialize_unchecked(buf: &mut &[u8]) -> Result<Self> {
96        let mut canonical = Vec::new();
97        T::default().try_serialize(&mut canonical)?;
98        let expected = canonical.len();
99
100        if buf.len() >= expected {
101            return T::try_deserialize_unchecked(buf).map(Self);
102        }
103        let mut padded = vec![0u8; expected];
104        padded[..buf.len()].copy_from_slice(buf);
105        let mut slice: &[u8] = &padded;
106        let inner = T::try_deserialize_unchecked(&mut slice)?;
107        *buf = &buf[buf.len()..];
108        Ok(Self(inner))
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use anchor_lang::AccountSerialize;
    use solana_program::pubkey::Pubkey;

    use crate::constants;
    use crate::state::{BondingCurve, BondingCurveFromIdl};

    /// Bytes stripped from the end of a serialized curve to simulate a
    /// pre-extension account: `quote_mint` (32) + `is_cashback_coin` (1).
    const STRIPPED_TAIL_LEN: usize = 32 + 1;

    fn fake_pubkey(seed: u8) -> Pubkey {
        Pubkey::new_from_array([seed; 32])
    }

    /// Anchor-serializes `curve` into a fresh buffer.
    fn serialize(curve: &BondingCurveFromIdl) -> Vec<u8> {
        let mut bytes = Vec::new();
        curve.try_serialize(&mut bytes).expect("serialize");
        bytes
    }

    /// A curve whose stripped tail fields hold non-default values, so tests
    /// can tell "decoded from zero padding" apart from "carried over".
    fn sample_curve() -> BondingCurveFromIdl {
        BondingCurveFromIdl {
            virtual_token_reserves: 1_000_000_000,
            virtual_quote_reserves: 30_000_000_000,
            real_token_reserves: 800_000_000,
            real_quote_reserves: 0,
            token_total_supply: 1_000_000_000_000,
            complete: false,
            creator: fake_pubkey(7),
            is_mayhem_mode: true,
            is_cashback_coin: true,
            quote_mint: constants::NATIVE_MINT,
        }
    }

    /// Serialized `curve` with the trailing `STRIPPED_TAIL_LEN` bytes removed,
    /// mimicking an on-chain account written before the layout was extended.
    fn legacy_bytes(curve: &BondingCurveFromIdl) -> Vec<u8> {
        let mut bytes = serialize(curve);
        bytes.truncate(bytes.len() - STRIPPED_TAIL_LEN);
        bytes
    }

    #[test]
    fn try_deserialize_pads_short_buffer() {
        let original = sample_curve();
        let short = legacy_bytes(&original);

        let decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize(&mut &short[..]).expect("decode");
        assert_eq!(decoded.virtual_token_reserves, original.virtual_token_reserves);
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_mayhem_mode, original.is_mayhem_mode);
        // Stripped fields decode from zero padding (false / Pubkey::default()).
        assert!(!decoded.is_cashback_coin);
        assert_eq!(decoded.quote_mint, Pubkey::default());
    }

    #[test]
    fn try_deserialize_unchecked_pads_short_buffer() {
        let original = sample_curve();
        let short = legacy_bytes(&original);

        let decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize_unchecked(&mut &short[..])
                .expect("decode");
        assert_eq!(decoded.virtual_token_reserves, original.virtual_token_reserves);
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_mayhem_mode, original.is_mayhem_mode);
        assert!(!decoded.is_cashback_coin);
        assert_eq!(decoded.quote_mint, Pubkey::default());
    }

    #[test]
    fn padded_decode_consumes_entire_input() {
        let short = legacy_bytes(&sample_curve());
        let mut cursor: &[u8] = &short;

        let _decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize(&mut cursor).expect("decode");
        assert!(cursor.is_empty());
    }

    #[test]
    fn try_deserialize_rejects_empty_buffer() {
        // Zero-padding an empty buffer yields an all-zero discriminator; the
        // checked path should refuse it (assuming `BondingCurve`'s
        // discriminator is not itself all zeros).
        let mut empty: &[u8] = &[];
        assert!(<BondingCurve as AccountDeserialize>::try_deserialize(&mut empty).is_err());
    }

    #[test]
    fn try_deserialize_full_buffer_is_passthrough() {
        let original = BondingCurveFromIdl {
            virtual_token_reserves: 42,
            virtual_quote_reserves: 0,
            real_token_reserves: 0,
            real_quote_reserves: 0,
            token_total_supply: 0,
            complete: true,
            creator: fake_pubkey(3),
            is_mayhem_mode: false,
            is_cashback_coin: true,
            quote_mint: constants::NATIVE_MINT,
        };
        let full = serialize(&original);

        let decoded = <BondingCurve as AccountDeserialize>::try_deserialize(&mut &full[..])
            .expect("decode");
        assert_eq!(decoded.virtual_token_reserves, 42);
        assert_eq!(decoded.complete, original.complete);
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_cashback_coin, original.is_cashback_coin);
        assert_eq!(decoded.quote_mint, original.quote_mint);
    }
}