Skip to main content

pump_rust_client/
account_wrapper.rs

//! Generic newtype that lets `Account<'info, T>` tolerate account buffers
//! shorter than `T`'s current Borsh layout by zero-padding them to the
//! expected Anchor-serialized size before deserializing.
4
5use std::io::{Read, Write};
6use std::ops::{Deref, DerefMut};
7
8use anchor_lang::prelude::Pubkey;
9use anchor_lang::{
10    AccountDeserialize, AccountSerialize, AnchorDeserialize, AnchorSerialize, Discriminator,
11    Owner, Result,
12};
13
/// Transparent newtype around an account type `T`.
///
/// `#[repr(transparent)]` gives the wrapper exactly `T`'s layout. The trait
/// impls forward to `T`; `AccountDeserialize` additionally zero-pads buffers
/// shorter than `T`'s current serialized size before forwarding.
#[repr(transparent)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct AccountWrapper<T>(pub T);
17
18impl<T> AccountWrapper<T> {
19    pub fn new(inner: T) -> Self {
20        Self(inner)
21    }
22
23    pub fn into_inner(self) -> T {
24        self.0
25    }
26}
27
28impl<T> Deref for AccountWrapper<T> {
29    type Target = T;
30    fn deref(&self) -> &T {
31        &self.0
32    }
33}
34
35impl<T> DerefMut for AccountWrapper<T> {
36    fn deref_mut(&mut self) -> &mut T {
37        &mut self.0
38    }
39}
40
41impl<T> From<T> for AccountWrapper<T> {
42    fn from(inner: T) -> Self {
43        Self(inner)
44    }
45}
46
47impl<T: Discriminator> Discriminator for AccountWrapper<T> {
48    const DISCRIMINATOR: &'static [u8] = T::DISCRIMINATOR;
49}
50
51impl<T: Owner> Owner for AccountWrapper<T> {
52    fn owner() -> Pubkey {
53        T::owner()
54    }
55}
56
57impl<T: AccountSerialize> AccountSerialize for AccountWrapper<T> {
58    fn try_serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
59        self.0.try_serialize(writer)
60    }
61}
62
63impl<T: AnchorSerialize> AnchorSerialize for AccountWrapper<T> {
64    fn serialize<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
65        self.0.serialize(writer)
66    }
67}
68
69impl<T: AnchorDeserialize> AnchorDeserialize for AccountWrapper<T> {
70    fn deserialize_reader<R: Read>(reader: &mut R) -> std::io::Result<Self> {
71        T::deserialize_reader(reader).map(Self)
72    }
73}
74
/// Marker impl compiled only with the `idl-build` feature; every trait item
/// keeps its default. (Presumably needed so the wrapper can appear in IDL
/// generation — confirm against anchor's `IdlBuild` docs.)
#[cfg(feature = "idl-build")]
impl<T> ::anchor_lang::IdlBuild for AccountWrapper<T> {
}
77
78impl<T> AccountDeserialize for AccountWrapper<T>
79where
80    T: AccountDeserialize + AccountSerialize + Default,
81{
82    fn try_deserialize(buf: &mut &[u8]) -> Result<Self> {
83        let mut canonical = Vec::new();
84        T::default().try_serialize(&mut canonical)?;
85        let expected = canonical.len();
86
87        if buf.len() >= expected {
88            return T::try_deserialize(buf).map(Self);
89        }
90        let mut padded = vec![0u8; expected];
91        padded[..buf.len()].copy_from_slice(buf);
92        let mut slice: &[u8] = &padded;
93        let inner = T::try_deserialize(&mut slice)?;
94        *buf = &buf[buf.len()..];
95        Ok(Self(inner))
96    }
97
98    fn try_deserialize_unchecked(buf: &mut &[u8]) -> Result<Self> {
99        let mut canonical = Vec::new();
100        T::default().try_serialize(&mut canonical)?;
101        let expected = canonical.len();
102
103        if buf.len() >= expected {
104            return T::try_deserialize_unchecked(buf).map(Self);
105        }
106        let mut padded = vec![0u8; expected];
107        padded[..buf.len()].copy_from_slice(buf);
108        let mut slice: &[u8] = &padded;
109        let inner = T::try_deserialize_unchecked(&mut slice)?;
110        *buf = &buf[buf.len()..];
111        Ok(Self(inner))
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use anchor_lang::AccountSerialize;
    use solana_program::pubkey::Pubkey;

    use crate::constants;
    use crate::state::{BondingCurve, BondingCurveFromIdl};

    /// Bytes stripped to simulate a pre-extension on-chain account:
    /// `quote_mint` (32) + `is_cashback_coin` (1).
    const EXTENSION_TAIL_LEN: usize = 33;

    fn fake_pubkey(seed: u8) -> Pubkey {
        Pubkey::new_from_array([seed; 32])
    }

    /// Sample curve shared by the short-buffer tests.
    fn sample_curve() -> BondingCurveFromIdl {
        BondingCurveFromIdl {
            virtual_token_reserves: 1_000_000_000,
            virtual_quote_reserves: 30_000_000_000,
            real_token_reserves: 800_000_000,
            real_quote_reserves: 0,
            token_total_supply: 1_000_000_000_000,
            complete: false,
            creator: fake_pubkey(7),
            is_mayhem_mode: true,
            is_cashback_coin: false,
            quote_mint: constants::NATIVE_MINT,
        }
    }

    /// Serializes `curve` through `AccountSerialize::try_serialize`.
    fn serialize(curve: &BondingCurveFromIdl) -> Vec<u8> {
        let mut out = Vec::new();
        curve.try_serialize(&mut out).expect("serialize");
        out
    }

    #[test]
    fn try_deserialize_pads_short_buffer() {
        let original = sample_curve();
        let full = serialize(&original);

        // Strip the trailing bytes — simulates a pre-extension on-chain
        // account. The wrapper should zero-pad and decode successfully,
        // yielding default-zero values for fields that lived in the
        // truncated tail.
        let short = &full[..full.len() - EXTENSION_TAIL_LEN];

        let decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize(&mut &short[..]).expect("decode");
        assert_eq!(decoded.virtual_token_reserves, original.virtual_token_reserves);
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_mayhem_mode, original.is_mayhem_mode);
        // Stripped fields land at their `Default` value (zero / Pubkey::default()).
        assert!(!decoded.is_cashback_coin);
        assert_eq!(decoded.quote_mint, Pubkey::default());
    }

    #[test]
    fn try_deserialize_unchecked_pads_short_buffer() {
        let original = sample_curve();
        let full = serialize(&original);
        let short = &full[..full.len() - EXTENSION_TAIL_LEN];

        let decoded =
            <BondingCurve as AccountDeserialize>::try_deserialize_unchecked(&mut &short[..])
                .expect("decode");
        assert_eq!(decoded.virtual_token_reserves, original.virtual_token_reserves);
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_mayhem_mode, original.is_mayhem_mode);
        assert!(!decoded.is_cashback_coin);
        assert_eq!(decoded.quote_mint, Pubkey::default());
    }

    #[test]
    fn try_deserialize_full_buffer_is_passthrough() {
        let original = BondingCurveFromIdl {
            virtual_token_reserves: 42,
            virtual_quote_reserves: 0,
            real_token_reserves: 0,
            real_quote_reserves: 0,
            token_total_supply: 0,
            complete: true,
            creator: fake_pubkey(3),
            is_mayhem_mode: false,
            is_cashback_coin: true,
            quote_mint: constants::NATIVE_MINT,
        };
        let full = serialize(&original);

        let decoded = <BondingCurve as AccountDeserialize>::try_deserialize(&mut &full[..])
            .expect("decode");
        assert_eq!(decoded.virtual_token_reserves, 42);
        assert_eq!(decoded.complete, original.complete);
        assert_eq!(decoded.creator, original.creator);
        assert_eq!(decoded.is_cashback_coin, original.is_cashback_coin);
        assert_eq!(decoded.quote_mint, original.quote_mint);
    }

    #[test]
    fn try_deserialize_rejects_empty_buffer() {
        // An empty buffer must never be padded into a valid-looking account.
        let mut empty: &[u8] = &[];
        assert!(<BondingCurve as AccountDeserialize>::try_deserialize(&mut empty).is_err());
    }
}