1use af_ptbuilder::{ProgrammableTransactionBuilder, ptbuilder};
2use af_sui_types::{Argument, ObjectArg, ObjectId};
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5
6use crate::price_info::PriceInfo;
7
/// Magic prefix (ASCII `PNAU`) that identifies an accumulator-format update message.
const ACCUMULATOR_MAGIC: [u8; 4] = *b"PNAU";
9
/// Price-update data ready to be turned into PTB commands.
///
/// Build it from raw update blobs with [`UpdatePayload::new`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub enum UpdatePayload {
    /// A single accumulator-format message (one that starts with `ACCUMULATOR_MAGIC`).
    Accumulator {
        /// The VAA embedded in `message`, sliced out of it by `accumulator_payload`.
        vaa: Bytes,
        /// The complete accumulator message, magic prefix included.
        message: Bytes,
    },
    /// Non-accumulator messages (possibly none), each handled as a standalone VAA.
    Normal(Vec<Bytes>),
}
16
17impl UpdatePayload {
18    pub fn new(binary_update: Vec<Vec<u8>>) -> Result<Self, MixedVaasError> {
20        let mut bytes_vec: Vec<_> = binary_update.into_iter().map(Bytes::from).collect();
21
22        let accumulator_msg = bytes_vec
23            .iter()
24            .position(is_accumulator_msg)
25            .map(|index| bytes_vec.swap_remove(index));
26
27        if accumulator_msg.is_some() && !bytes_vec.is_empty() {
28            return Err(MixedVaasError);
29        }
30
31        let update = accumulator_msg.map_or(Self::Normal(bytes_vec), |bytes| Self::Accumulator {
32            vaa: accumulator_payload(&bytes),
33            message: bytes,
34        });
35        Ok(update)
36    }
37}
38
/// Error returned by [`UpdatePayload::new`] when the input holds either more than one
/// accumulator message, or an accumulator message alongside other messages.
#[derive(thiserror::Error, Debug)]
#[error("Multiple accumulator messages or mixed accumulator and non-accumulator messages")]
pub struct MixedVaasError;
42
43fn accumulator_payload(acc_message: &Bytes) -> Bytes {
45    let trailing_payload_size = acc_message[6] as usize;
48    let vaa_size_offset = 7 + trailing_payload_size + 1; let vaa_size = u16::from_be_bytes([
52        acc_message[vaa_size_offset],
53        acc_message[vaa_size_offset + 1],
54    ]) as usize;
55    let vaa_offset = vaa_size_offset + 2;
56    acc_message.slice(vaa_offset..(vaa_offset + vaa_size))
57}
58
59fn is_accumulator_msg(bytes: &Bytes) -> bool {
60    bytes[..4] == ACCUMULATOR_MAGIC
61}
62
/// Transaction arguments consumed by `update_pyth_price_info`.
///
/// Normally produced by `update_pyth_price_info_args`, which registers the underlying
/// objects as inputs of a [`ProgrammableTransactionBuilder`].
#[derive(Clone, Debug)]
pub struct PtbArguments {
    /// Input argument for the Pyth state object.
    pub pyth_state: Argument,
    /// Input argument for the Wormhole state object.
    pub wormhole_state: Argument,
    /// Input arguments for the price-info objects to update, one per object, in order.
    pub price_info_objects: Vec<Argument>,
    /// Coin from which one update fee per price-info object is split.
    pub fee_coin: Argument,
}
75
#[extension_traits::extension(pub trait ProgrammableTransactionBuilderExt)]
impl ProgrammableTransactionBuilder {
    /// Registers the Pyth/Wormhole objects as inputs and bundles the resulting arguments.
    ///
    /// `pyth_state`, `wormhole_state` and each entry of `price_info_objects` are added as
    /// object inputs, in the given order. `fee_coin` is already an [`Argument`] and is
    /// stored as-is. The returned [`PtbArguments`] is meant for `update_pyth_price_info`
    /// on the same builder.
    ///
    /// # Errors
    ///
    /// Returns an [`af_ptbuilder::Error`] if adding any of the inputs fails (presumably
    /// surfaced by the `ptbuilder!` macro).
    fn update_pyth_price_info_args(
        &mut self,
        pyth_state: ObjectArg,
        wormhole_state: ObjectArg,
        price_info_objects: Vec<ObjectArg>,
        fee_coin: Argument,
    ) -> Result<PtbArguments, af_ptbuilder::Error> {
        // After this call, `pyth_state` / `wormhole_state` name the input `Argument`s
        // rather than the `ObjectArg` parameters.
        ptbuilder!(self {
            input obj pyth_state;
            input obj wormhole_state;
        });
        let mut vars = PtbArguments {
            pyth_state,
            wormhole_state,
            price_info_objects: vec![],
            fee_coin,
        };
        // One object input per price-info object, preserving the caller's order.
        for pio in price_info_objects {
            ptbuilder!(self {
                input obj pio;
            });
            vars.price_info_objects.push(pio);
        }
        Ok(vars)
    }

    /// Appends the commands that refresh Pyth price-info objects from `update`.
    ///
    /// The generated commands, in order:
    /// 1. verify the VAA(s) with `wormhole::vaa::parse_and_verify`;
    /// 2. build the price-update hot-potato vector:
    ///    - `pyth::pyth::create_authenticated_price_infos_using_accumulator` for
    ///      [`UpdatePayload::Accumulator`];
    ///    - `pyth::pyth::create_price_infos_hot_potato` for [`UpdatePayload::Normal`];
    /// 3. split one coin of `pyth::state::get_base_update_fee` per price-info object off
    ///    `fee_coin`, then call `pyth::pyth::update_single_price_feed` once per object with
    ///    its fee coin;
    /// 4. destroy the hot-potato vector with `pyth::hot_potato_vector::destroy`.
    ///
    /// `pyth_pkg` and `wormhole_pkg` are the packages the Move calls target. `arguments`
    /// should come from `update_pyth_price_info_args` on this same builder.
    ///
    /// NOTE(review): for `UpdatePayload::Normal(vec![])`, the `MakeMoveVec(None, ..)`
    /// command receives no elements and no explicit element type. Confirm whether Sui
    /// accepts that.
    ///
    /// # Errors
    ///
    /// Returns an [`af_ptbuilder::Error`] if building any input or command fails
    /// (presumably surfaced by the `ptbuilder!` macro).
    fn update_pyth_price_info(
        &mut self,
        pyth_pkg: ObjectId,
        wormhole_pkg: ObjectId,
        arguments: PtbArguments,
        update: UpdatePayload,
    ) -> Result<(), af_ptbuilder::Error> {
        let PtbArguments {
            pyth_state,
            wormhole_state,
            price_info_objects,
            fee_coin,
        } = arguments;

        // Declare the two Move packages and the immutable clock input used by the calls below.
        ptbuilder!(self {
            package pyth: pyth_pkg;
            package wormhole: wormhole_pkg;

            input obj clock: ObjectArg::CLOCK_IMM;
        });

        // `price_updates` holds the hot-potato vector. It is threaded through the
        // per-object updates below and destroyed at the end.
        let mut price_updates = match update {
            // Accumulator message: verify its embedded VAA, then hand both the full message
            // and the verified VAA to Pyth.
            UpdatePayload::Accumulator { vaa, message } => {
                ptbuilder!(self {
                    input pure vaa: vaa.as_ref();
                    input pure accumulator_msg: message.as_ref();

                    let verified_vaa = wormhole::vaa::parse_and_verify(wormhole_state, vaa, clock);
                    let updates = pyth::pyth::create_authenticated_price_infos_using_accumulator(
                        pyth_state,
                        accumulator_msg,
                        verified_vaa,
                        clock
                    );
                });
                updates
            }
            // Standalone VAAs: verify each one, gather the results into a Move vector,
            // and hand that vector to Pyth.
            UpdatePayload::Normal(bytes) => {
                let mut verified_vaas = Vec::new();
                for vaa in bytes {
                    ptbuilder!(self {
                        input pure vaa: vaa.as_ref();
                        let verified = wormhole::vaa::parse_and_verify(wormhole_state, vaa, clock);
                    });
                    verified_vaas.push(verified);
                }
                ptbuilder!(self {
                    let verified_vaas = command! MakeMoveVec(None, verified_vaas);
                    let updates = pyth::pyth::create_price_infos_hot_potato(
                        pyth_state,
                        verified_vaas,
                        clock
                    );
                });
                updates
            }
        };

        // Split `fee_coin` into one coin of `base_update_fee` per price-info object.
        ptbuilder!(self {
            let base_update_fee = pyth::state::get_base_update_fee(pyth_state);
        });
        let fee_coins =
            self.split_coins_into_vec(fee_coin, vec![base_update_fee; price_info_objects.len()]);
        // Each call takes the current hot-potato vector and yields the next one, so its
        // result replaces `price_updates` for the following iteration.
        for (price_info_object, fee) in price_info_objects.into_iter().zip(fee_coins) {
            ptbuilder!(self {
                let price_updates_ = pyth::pyth::update_single_price_feed(
                    pyth_state,
                    price_updates,
                    price_info_object,
                    fee,
                    clock,
                );
            });
            price_updates = price_updates_;
        }
        // Finally destroy the hot-potato vector; `T` is the `PriceInfo` type from `pyth_pkg`.
        ptbuilder!(self {
            type T = PriceInfo::type_(pyth_pkg.into()).into();
            pyth::hot_potato_vector::destroy<T>(price_updates);
        });

        Ok(())
    }
}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase hex spelling the accumulator magic is expected to have.
    const ACCUMULATOR_MAGIC_HEX: &str = "504e4155";

    /// `ACCUMULATOR_MAGIC`, hex-encoded in lowercase, must equal `ACCUMULATOR_MAGIC_HEX`.
    #[test]
    fn magics_match() {
        let encoded: String = ACCUMULATOR_MAGIC
            .iter()
            .map(|byte| format!("{byte:02x}"))
            .collect();
        assert_eq!(encoded, ACCUMULATOR_MAGIC_HEX);
    }
}