Skip to main content

pump_rust_client/
account_wrapper.rs

1//! Generic newtype that lets `Account<'info, T>` tolerate buffers shorter
2//! than `T`'s current Borsh layout by zero-padding to the expected
3//! Anchor-serialized size before deserializing.
4
5use std::io::{Read, Write};
6use std::ops::{Deref, DerefMut};
7
8use anchor_lang::prelude::Pubkey;
9use anchor_lang::{
10    AccountDeserialize, AccountSerialize, AnchorDeserialize, AnchorSerialize, Discriminator, Owner,
11    Result,
12};
13
/// Transparent newtype around an account payload `T`.
///
/// The single field is public, so callers may construct or destructure the
/// wrapper directly as well as through [`AccountWrapper::new`] and
/// [`AccountWrapper::into_inner`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[repr(transparent)]
pub struct AccountWrapper<T>(pub T);

impl<T> AccountWrapper<T> {
    /// Wraps `inner` as-is.
    pub fn new(inner: T) -> Self {
        AccountWrapper(inner)
    }

    /// Consumes the wrapper and returns the value it held.
    pub fn into_inner(self) -> T {
        let AccountWrapper(inner) = self;
        inner
    }
}
27
28impl<T> Deref for AccountWrapper<T> {
29    type Target = T;
30    fn deref(&self) -> &T {
31        &self.0
32    }
33}
34
35impl<T> DerefMut for AccountWrapper<T> {
36    fn deref_mut(&mut self) -> &mut T {
37        &mut self.0
38    }
39}
40
41impl<T> From<T> for AccountWrapper<T> {
42    fn from(inner: T) -> Self {
43        Self(inner)
44    }
45}
46
47impl<T: Discriminator> Discriminator for AccountWrapper<T> {
48    const DISCRIMINATOR: &'static [u8] = T::DISCRIMINATOR;
49}
50
51impl<T: Owner> Owner for AccountWrapper<T> {
52    fn owner() -> Pubkey {
53        T::owner()
54    }
55}
56
57impl<T: AccountSerialize> AccountSerialize for AccountWrapper<T> {
58    fn try_serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
59        self.0.try_serialize(writer)
60    }
61}
62
63impl<T: AnchorSerialize> AnchorSerialize for AccountWrapper<T> {
64    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
65        self.0.serialize(writer)
66    }
67}
68
69impl<T: AnchorDeserialize> AnchorDeserialize for AccountWrapper<T> {
70    fn deserialize_reader<R: Read>(reader: &mut R) -> std::io::Result<Self> {
71        T::deserialize_reader(reader).map(Self)
72    }
73}
74
// Only compiled when the `idl-build` feature is enabled. The impl body is
// empty, so every item comes from the trait's own defaults, and no bound is
// placed on `T`.
#[cfg(feature = "idl-build")]
impl<T> anchor_lang::IdlBuild for AccountWrapper<T> {}
77
78impl<T> AccountDeserialize for AccountWrapper<T>
79where
80    T: AccountDeserialize + AccountSerialize + Default,
81{
82    fn try_deserialize(buf: &mut &[u8]) -> Result<Self> {
83        let mut canonical = Vec::new();
84        T::default().try_serialize(&mut canonical)?;
85        let expected = canonical.len();
86
87        if buf.len() >= expected {
88            return T::try_deserialize(buf).map(Self);
89        }
90        let mut padded = vec![0u8; expected];
91        padded[..buf.len()].copy_from_slice(buf);
92        let mut slice: &[u8] = &padded;
93        let inner = T::try_deserialize(&mut slice)?;
94        *buf = &buf[buf.len()..];
95        Ok(Self(inner))
96    }
97
98    fn try_deserialize_unchecked(buf: &mut &[u8]) -> Result<Self> {
99        let mut canonical = Vec::new();
100        T::default().try_serialize(&mut canonical)?;
101        let expected = canonical.len();
102
103        if buf.len() >= expected {
104            return T::try_deserialize_unchecked(buf).map(Self);
105        }
106        let mut padded = vec![0u8; expected];
107        padded[..buf.len()].copy_from_slice(buf);
108        let mut slice: &[u8] = &padded;
109        let inner = T::try_deserialize_unchecked(&mut slice)?;
110        *buf = &buf[buf.len()..];
111        Ok(Self(inner))
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use anchor_lang::AccountSerialize;
    use solana_program::pubkey::Pubkey;

    use crate::constants;
    use crate::state::{BondingCurve, BondingCurveFromIdl};

    /// Trailing bytes removed to imitate a pre-extension on-chain account:
    /// `quote_mint` (32) + `is_cashback_coin` (1).
    const STRIPPED_TAIL_LEN: usize = 33;

    /// Deterministic pubkey whose 32 bytes all equal `seed`.
    fn fake_pubkey(seed: u8) -> Pubkey {
        Pubkey::new_from_array([seed; 32])
    }

    /// Builds a curve whose tail fields are all non-default, together with its
    /// full account serialization.
    ///
    /// `is_cashback_coin` is deliberately `true` and `quote_mint` non-zero so
    /// that "field was reset to its default" is observable after truncation.
    fn non_default_tail_fixture() -> (BondingCurveFromIdl, Vec<u8>) {
        let original = BondingCurveFromIdl {
            virtual_token_reserves: 1_000_000_000,
            virtual_quote_reserves: 30_000_000_000,
            real_token_reserves: 800_000_000,
            real_quote_reserves: 0,
            token_total_supply: 1_000_000_000_000,
            complete: false,
            creator: fake_pubkey(7),
            is_mayhem_mode: true,
            is_cashback_coin: true,
            quote_mint: constants::NATIVE_MINT,
        };
        let mut full = Vec::new();
        original.try_serialize(&mut full).expect("serialize");
        (original, full)
    }

    #[test]
    fn try_deserialize_pads_short_buffer() {
        let (original, full) = non_default_tail_fixture();

        // Strip the trailing bytes — simulates a pre-extension on-chain
        // account. The wrapper should zero-pad and decode successfully,
        // yielding default-zero values for fields that lived in the
        // truncated tail.
        let short = &full[..full.len() - STRIPPED_TAIL_LEN];

        let decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize(&mut &short[..]).expect("decode");
        assert_eq!(
            decoded.virtual_token_reserves,
            original.virtual_token_reserves
        );
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_mayhem_mode, original.is_mayhem_mode);
        // Stripped fields land at their `Default` value (zero / Pubkey::default()),
        // not at the non-default values the fixture was built with.
        assert!(original.is_cashback_coin);
        assert!(!decoded.is_cashback_coin);
        assert_eq!(decoded.quote_mint, Pubkey::default());
    }

    #[test]
    fn try_deserialize_unchecked_pads_short_buffer() {
        let (original, full) = non_default_tail_fixture();
        let short = &full[..full.len() - STRIPPED_TAIL_LEN];

        // Same expectations as the checked path: the padding logic is shared.
        let decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize_unchecked(&mut &short[..])
                .expect("decode");
        assert_eq!(
            decoded.virtual_token_reserves,
            original.virtual_token_reserves
        );
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_mayhem_mode, original.is_mayhem_mode);
        assert!(!decoded.is_cashback_coin);
        assert_eq!(decoded.quote_mint, Pubkey::default());
    }

    #[test]
    fn try_deserialize_full_buffer_is_passthrough() {
        let original = BondingCurveFromIdl {
            virtual_token_reserves: 42,
            virtual_quote_reserves: 0,
            real_token_reserves: 0,
            real_quote_reserves: 0,
            token_total_supply: 0,
            complete: true,
            creator: fake_pubkey(3),
            is_mayhem_mode: false,
            is_cashback_coin: true,
            quote_mint: constants::NATIVE_MINT,
        };
        let mut full = Vec::new();
        original.try_serialize(&mut full).expect("serialize");

        // A full-length buffer must round-trip every field, including the
        // tail fields that the short-buffer tests expect to be zeroed.
        let decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize(&mut &full[..]).expect("decode");
        assert_eq!(decoded.virtual_token_reserves, 42);
        assert_eq!(decoded.complete, original.complete);
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_cashback_coin, original.is_cashback_coin);
        assert_eq!(decoded.quote_mint, original.quote_mint);
    }
}