1use af_ptbuilder::{ProgrammableTransactionBuilder, ptbuilder};
2use af_sui_types::{Argument, ObjectArg, ObjectId};
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5
6use crate::price_info::PriceInfo;
7
/// Four-byte prefix (`0x50 0x4e 0x41 0x55`, ASCII `PNAU`) that marks a message as an
/// accumulator message.
const ACCUMULATOR_MAGIC: [u8; 4] = *b"PNAU";
9
10#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
12pub enum UpdatePayload {
13 Accumulator { vaa: Bytes, message: Bytes },
14 Normal(Vec<Bytes>),
15}
16
17impl UpdatePayload {
18 pub fn new(binary_update: Vec<Vec<u8>>) -> Result<Self, MixedVaasError> {
20 let mut bytes_vec: Vec<_> = binary_update.into_iter().map(Bytes::from).collect();
21
22 let accumulator_msg = bytes_vec
23 .iter()
24 .position(is_accumulator_msg)
25 .map(|index| bytes_vec.swap_remove(index));
26
27 if accumulator_msg.is_some() && !bytes_vec.is_empty() {
28 return Err(MixedVaasError);
29 }
30
31 let update = accumulator_msg.map_or(Self::Normal(bytes_vec), |bytes| Self::Accumulator {
32 vaa: accumulator_payload(&bytes),
33 message: bytes,
34 });
35 Ok(update)
36 }
37}
38
39#[derive(thiserror::Error, Debug)]
40#[error("Multiple accumulator messages or mixed accumulator and non-accumulator messages")]
41pub struct MixedVaasError;
42
43fn accumulator_payload(acc_message: &Bytes) -> Bytes {
45 let trailing_payload_size = acc_message[6] as usize;
48 let vaa_size_offset = 7 + trailing_payload_size + 1; let vaa_size = u16::from_be_bytes([
52 acc_message[vaa_size_offset],
53 acc_message[vaa_size_offset + 1],
54 ]) as usize;
55 let vaa_offset = vaa_size_offset + 2;
56 acc_message.slice(vaa_offset..(vaa_offset + vaa_size))
57}
58
59fn is_accumulator_msg(bytes: &Bytes) -> bool {
60 bytes[..4] == ACCUMULATOR_MAGIC
61}
62
63#[derive(Clone, Debug)]
65pub struct PtbArguments {
66 pub pyth_state: Argument,
68 pub wormhole_state: Argument,
70 pub price_info_objects: Vec<Argument>,
72 pub fee_coin: Argument,
74}
75
76#[extension_traits::extension(pub trait ProgrammableTransactionBuilderExt)]
77impl ProgrammableTransactionBuilder {
78 fn update_pyth_price_info_args(
85 &mut self,
86 pyth_state: ObjectArg,
87 wormhole_state: ObjectArg,
88 price_info_objects: Vec<ObjectArg>,
89 fee_coin: Argument,
90 ) -> Result<PtbArguments, af_ptbuilder::Error> {
91 ptbuilder!(self {
92 input obj pyth_state;
93 input obj wormhole_state;
94 });
95 let mut vars = PtbArguments {
96 pyth_state,
97 wormhole_state,
98 price_info_objects: vec![],
99 fee_coin,
100 };
101 for pio in price_info_objects {
102 ptbuilder!(self {
103 input obj pio;
104 });
105 vars.price_info_objects.push(pio);
106 }
107 Ok(vars)
108 }
109
110 fn update_pyth_price_info(
116 &mut self,
117 pyth_pkg: ObjectId,
118 wormhole_pkg: ObjectId,
119 arguments: PtbArguments,
120 update: UpdatePayload,
121 ) -> Result<(), af_ptbuilder::Error> {
122 let PtbArguments {
123 pyth_state,
124 wormhole_state,
125 price_info_objects,
126 fee_coin,
127 } = arguments;
128
129 ptbuilder!(self {
131 package pyth: pyth_pkg;
132 package wormhole: wormhole_pkg;
133
134 input obj clock: ObjectArg::CLOCK_IMM;
135 });
136
137 let mut price_updates = match update {
138 UpdatePayload::Accumulator { vaa, message } => {
139 ptbuilder!(self {
140 input pure vaa: vaa.as_ref();
141 input pure accumulator_msg: message.as_ref();
142
143 let verified_vaa = wormhole::vaa::parse_and_verify(wormhole_state, vaa, clock);
144 let updates = pyth::pyth::create_authenticated_price_infos_using_accumulator(
145 pyth_state,
146 accumulator_msg,
147 verified_vaa,
148 clock
149 );
150 });
151 updates
152 }
153 UpdatePayload::Normal(bytes) => {
154 let mut verified_vaas = Vec::new();
155 for vaa in bytes {
156 ptbuilder!(self {
157 input pure vaa: vaa.as_ref();
158 let verified = wormhole::vaa::parse_and_verify(wormhole_state, vaa, clock);
159 });
160 verified_vaas.push(verified);
161 }
162 ptbuilder!(self {
163 let verified_vaas = command! MakeMoveVec(None, verified_vaas);
164 let updates = pyth::pyth::create_price_infos_hot_potato(
165 pyth_state,
166 verified_vaas,
167 clock
168 );
169 });
170 updates
171 }
172 };
173
174 ptbuilder!(self {
175 let base_update_fee = pyth::state::get_base_update_fee(pyth_state);
176 });
177 let fee_coins =
178 self.split_coins_into_vec(fee_coin, vec![base_update_fee; price_info_objects.len()]);
179 for (price_info_object, fee) in price_info_objects.into_iter().zip(fee_coins) {
180 ptbuilder!(self {
181 let price_updates_ = pyth::pyth::update_single_price_feed(
182 pyth_state,
183 price_updates,
184 price_info_object,
185 fee,
186 clock,
187 );
188 });
189 price_updates = price_updates_;
190 }
191 ptbuilder!(self {
192 type T = PriceInfo::type_(pyth_pkg.into()).into();
193 pyth::hot_potato_vector::destroy<T>(price_updates);
194 });
195
196 Ok(())
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    // Expected lowercase hex rendering of the accumulator magic prefix.
    const ACCUMULATOR_MAGIC_HEX: &str = "504e4155";

    // Guards against the byte constant drifting from its hex form.
    #[test]
    fn magics_match() {
        let encoded = hex::encode(ACCUMULATOR_MAGIC);
        assert_eq!(encoded, ACCUMULATOR_MAGIC_HEX);
    }
}