drift_rs/
utils.rs

1//! SDK utility functions
2
3use base64::Engine;
4use bytemuck::{bytes_of, Pod, Zeroable};
5use serde_json::json;
6use solana_sdk::{
7    account::Account, address_lookup_table::AddressLookupTableAccount, bs58,
8    instruction::Instruction, pubkey::Pubkey, signature::Keypair,
9};
10
11use crate::types::{SdkError, SdkResult};
12
13// kudos @wphan
14/// Try to parse secret `key` string
15///
16/// Returns error if the key cannot be parsed
17pub fn read_keypair_str_multi_format(key: &str) -> SdkResult<Keypair> {
18    // strip out any white spaces and new line/carriage return characters
19    let key = key.replace([' ', '\n', '\r', '[', ']'], "");
20
21    // first try to decode as a byte array
22    if key.contains(',') {
23        // decode the numbers array into json string
24        let bytes: Result<Vec<u8>, _> = key.split(',').map(|x| x.parse::<u8>()).collect();
25        if let Ok(bytes) = bytes {
26            return Keypair::from_bytes(&bytes).map_err(|_| SdkError::InvalidSeed);
27        } else {
28            return Err(SdkError::InvalidSeed);
29        }
30    }
31
32    // try to decode as base58 string
33    if let Ok(bytes) = bs58::decode(key.as_bytes()).into_vec() {
34        return Keypair::from_bytes(&bytes).map_err(|_| SdkError::InvalidSeed);
35    }
36
37    // try to decode as base64 string
38    if let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(key.as_bytes()) {
39        return Keypair::from_bytes(&bytes).map_err(|_| SdkError::InvalidSeed);
40    }
41
42    Err(SdkError::InvalidSeed)
43}
44
45/// Try load a `Keypair` from a file path or given string, supports json format and base58 format.
46pub fn load_keypair_multi_format(path_or_key: &str) -> SdkResult<Keypair> {
47    if let Ok(data) = std::fs::read_to_string(path_or_key) {
48        read_keypair_str_multi_format(data.as_str())
49    } else {
50        read_keypair_str_multi_format(path_or_key)
51    }
52}
53
54const LOOKUP_TABLE_META_SIZE: usize = 56;
55
56/// modified from sdk.1.17.x
57/// https://docs.rs/solana-program/latest/src/solana_program/address_lookup_table/state.rs.html#192
58pub fn deserialize_alt(address: Pubkey, account: &Account) -> SdkResult<AddressLookupTableAccount> {
59    let raw_addresses_data: &[u8] = account.data.get(LOOKUP_TABLE_META_SIZE..).ok_or({
60        // Should be impossible because table accounts must
61        // always be LOOKUP_TABLE_META_SIZE in length
62        SdkError::InvalidAccount
63    })?;
64    let addresses = bytemuck::try_cast_slice(raw_addresses_data).map_err(|_| {
65        // Should be impossible because raw address data
66        // should be aligned and sized in multiples of 32 bytes
67        SdkError::InvalidAccount
68    })?;
69
70    Ok(AddressLookupTableAccount {
71        key: address,
72        addresses: addresses.to_vec(),
73    })
74}
75
/// Convert an `http(s)://` url into the matching `ws(s)://` url with a `/ws` path appended.
///
/// Trailing slashes are removed before `/ws` is appended, so `https://dlob.drift.trade/` becomes
/// `wss://dlob.drift.trade/ws`.
///
/// # Errors
/// Returns `"Invalid URL scheme"` if `url` starts with neither `http://` nor `https://`.
pub fn http_to_ws(url: &str) -> Result<String, &'static str> {
    let ws_url = if let Some(rest) = url.strip_prefix("http://") {
        format!("ws://{}", rest)
    } else if let Some(rest) = url.strip_prefix("https://") {
        format!("wss://{}", rest)
    } else {
        return Err("Invalid URL scheme");
    };

    let mut out = ws_url.trim_end_matches('/').to_owned();
    out.push_str("/ws");
    Ok(out)
}
87
88/// Convert a url string into a Ws equivalent
89pub fn get_ws_url(url: &str) -> SdkResult<String> {
90    if url.starts_with("http://") || url.starts_with("https://") {
91        Ok(url.replacen("http", "ws", 1))
92    } else if url.starts_with("wss://") || url.starts_with("ws://") {
93        Ok(url.to_string())
94    } else {
95        #[cfg(test)]
96        {
97            if url.starts_with("MockSender") {
98                return Ok("ws://mock.sender.com".into());
99            }
100        }
101        Err(SdkError::InvalidUrl)
102    }
103}
104
105/// Convert a url string into an Http equivalent
106pub fn get_http_url(url: &str) -> SdkResult<String> {
107    if url.starts_with("http://") || url.starts_with("https://") {
108        Ok(url.to_string())
109    } else if url.starts_with("ws://") || url.starts_with("wss://") {
110        Ok(url.replacen("ws", "http", 1))
111    } else {
112        Err(SdkError::InvalidUrl)
113    }
114}
115
116pub fn dlob_subscribe_ws_json(market: &str) -> String {
117    json!({
118        "type": "subscribe",
119        "marketType": if market.ends_with("perp") {
120            "perp"
121        } else {
122            "spot"
123        },
124        "channel": "orderbook",
125        "market": market,
126    })
127    .to_string()
128}
129
130pub fn zero_account_to_bytes<T: bytemuck::Pod + anchor_lang::Discriminator>(account: T) -> Vec<u8> {
131    let mut account_data = vec![0; 8 + std::mem::size_of::<T>()];
132    account_data[0..8].copy_from_slice(T::DISCRIMINATOR);
133    account_data[8..].copy_from_slice(bytemuck::bytes_of(&account));
134    account_data
135}
136
137pub mod test_envs {
138    //! test env vars
139    use solana_sdk::signature::Keypair;
140
141    /// solana mainnet endpoint
142    pub fn mainnet_endpoint() -> String {
143        std::env::var("TEST_MAINNET_RPC_ENDPOINT").expect("TEST_MAINNET_RPC_ENDPOINT set")
144    }
145    /// solana devnet endpoint
146    pub fn devnet_endpoint() -> String {
147        std::env::var("TEST_DEVNET_RPC_ENDPOINT")
148            .unwrap_or_else(|_| "https://api.devnet.solana.com".to_string())
149    }
150    /// keypair for integration tests
151    pub fn test_keypair() -> Keypair {
152        let private_key = std::env::var("TEST_PRIVATE_KEY").expect("TEST_PRIVATE_KEY set");
153        Keypair::from_base58_string(private_key.as_str())
154    }
155}
156
/// copy of `solana_sdk::ed25519_instruction::Ed25519SignatureOffsets`
/// it is missing useful constructors and public fields
///
/// `#[repr(C)]` with seven `u16` fields, so `bytemuck::bytes_of` yields the packed offsets record
/// that follows the `[num_signatures, padding]` header in ed25519 program instruction data. Field
/// order must not change.
#[derive(Default, Debug, Copy, Clone, Zeroable, Pod, Eq, PartialEq)]
#[repr(C)]
struct Ed25519SignatureOffsets {
    /// offset to ed25519 signature of 64 bytes
    pub signature_offset: u16,
    /// instruction index to find signature
    pub signature_instruction_index: u16,
    /// offset to public key of 32 bytes
    pub public_key_offset: u16,
    /// instruction index to find public key
    pub public_key_instruction_index: u16,
    /// offset to start of message data
    pub message_data_offset: u16,
    /// size of message data
    pub message_data_size: u16,
    /// index of instruction data to get message data
    pub message_instruction_index: u16,
}
170
171/// Build a new ed25519 verify ix pointing to another ix for data
172///
173/// DEV: this constructor should exist in `solana_sdk::ed25519_instruction` but does not.
174pub fn new_ed25519_ix_ptr(message: &[u8], instruction_index: u16) -> Instruction {
175    let mut instruction_data = Vec::with_capacity(solana_sdk::ed25519_instruction::DATA_START);
176    let signature_offset = 12_usize; // after discriminator??
177    let public_key_offset =
178        signature_offset.saturating_add(solana_sdk::ed25519_instruction::SIGNATURE_SERIALIZED_SIZE);
179    let message_data_size_offset =
180        public_key_offset.saturating_add(solana_sdk::ed25519_instruction::PUBKEY_SERIALIZED_SIZE);
181    let message_data_size = u16::from_le_bytes([
182        message[message_data_size_offset - signature_offset],
183        message[message_data_size_offset - signature_offset + 1],
184    ]);
185    let message_data_offset = message_data_size_offset + 2;
186
187    // add padding byte so that offset structure is aligned
188    let num_signatures: u8 = 1;
189    instruction_data.extend_from_slice(&[num_signatures, 0]);
190    instruction_data.extend_from_slice(bytes_of(&Ed25519SignatureOffsets {
191        signature_offset: signature_offset as u16,
192        signature_instruction_index: instruction_index,
193        public_key_offset: public_key_offset as u16,
194        public_key_instruction_index: instruction_index,
195        message_data_offset: message_data_offset as u16,
196        message_data_size,
197        message_instruction_index: instruction_index,
198    }));
199
200    Instruction {
201        program_id: solana_sdk::ed25519_program::id(),
202        accounts: vec![],
203        data: instruction_data,
204    }
205}
206
#[cfg(test)]
pub mod test_utils {
    //! test utilities
    //!
    //! Helpers for building raw account bytes (optionally prefixed with an anchor discriminator)
    //! and mock Pyth price accounts in unit tests.

    use anchor_lang::Discriminator;
    use bytes::BytesMut;
    // helpers from drift-program test_utils.
    /// Build a default mock Pyth price with `agg.price` and `twap` both set to `price * 10^expo`,
    /// and `expo` stored as given.
    ///
    /// NOTE(review): `expo` is cast to `u32` before `pow`, so a negative `expo` would overflow.
    /// Callers presumably pass non-negative exponents — confirm.
    pub fn get_pyth_price(price: i64, expo: i32) -> pyth_test::Price {
        let mut pyth_price = pyth_test::Price::default();
        let price = price * 10_i64.pow(expo as u32);
        pyth_price.agg.price = price;
        pyth_price.twap = price;
        pyth_price.expo = expo;
        pyth_price
    }

    mod pyth_test {
        //! helper structs for pyth oracle prices
        use bytemuck::{Pod, Zeroable};
        use serde::Serialize;

        /// 32-byte account key.
        #[derive(Default, Copy, Clone, Serialize)]
        #[repr(C)]
        pub struct AccKey {
            pub val: [u8; 32],
        }

        /// Price status; defaults to `Trading`.
        #[derive(Copy, Clone, Default, Serialize)]
        #[repr(C)]
        #[allow(dead_code)]
        pub enum PriceStatus {
            Unknown,
            #[default]
            Trading,
            Halted,
            Auction,
        }

        /// Corporate action; only `NoCorpAct` exists.
        #[derive(Copy, Clone, Default, Serialize)]
        #[repr(C)]
        pub enum CorpAction {
            #[default]
            NoCorpAct,
        }

        /// A single price/confidence observation with its status and publish slot.
        #[derive(Default, Copy, Clone, Serialize)]
        #[repr(C)]
        pub struct PriceInfo {
            pub price: i64,
            pub conf: u64,
            pub status: PriceStatus,
            pub corp_act: CorpAction,
            pub pub_slot: u64,
        }
        /// One publisher's component entry within a `Price` account.
        #[derive(Default, Copy, Clone, Serialize)]
        #[repr(C)]
        pub struct PriceComp {
            publisher: AccKey,
            agg: PriceInfo,
            latest: PriceInfo,
        }

        /// Price or calculation type; defaults to `Price`.
        #[derive(Copy, Clone, Default, Serialize)]
        #[repr(C)]
        #[allow(dead_code, clippy::upper_case_acronyms)]
        pub enum PriceType {
            Unknown,
            #[default]
            Price,
            TWAP,
            Volatility,
        }

        /// Mock Pyth price account. Field order mirrors the layout described by the per-field
        /// comments and must not be changed.
        #[derive(Default, Copy, Clone, Serialize)]
        #[repr(C)]
        pub struct Price {
            pub magic: u32,       // Pyth magic number.
            pub ver: u32,         // Program version.
            pub atype: u32,       // Account type.
            pub size: u32,        // Price account size.
            pub ptype: PriceType, // Price or calculation type.
            pub expo: i32,        // Price exponent.
            pub num: u32,         // Number of component prices.
            pub unused: u32,
            pub curr_slot: u64,        // Currently accumulating price slot.
            pub valid_slot: u64,       // Valid slot-time of agg. price.
            pub twap: i64,             // Time-weighted average price.
            pub avol: u64,             // Annualized price volatility.
            pub drv0: i64,             // Space for future derived values.
            pub drv1: i64,             // Space for future derived values.
            pub drv2: i64,             // Space for future derived values.
            pub drv3: i64,             // Space for future derived values.
            pub drv4: i64,             // Space for future derived values.
            pub drv5: i64,             // Space for future derived values.
            pub prod: AccKey,          // Product account key.
            pub next: AccKey,          // Next Price account in linked list.
            pub agg_pub: AccKey,       // Quoter who computed last aggregate price.
            pub agg: PriceInfo,        // Aggregate price info.
            pub comp: [PriceComp; 32], // Price components one per quoter.
        }

        // SAFETY: NOTE(review) — `Price` (via `PriceInfo`/`PriceType`) contains fieldless enums.
        // bytemuck's `Pod`/`Zeroable` contracts require every bit pattern (including all-zero)
        // to be valid, which those enums do not guarantee. These impls are only relied on in
        // tests to read bytes OUT of an existing `Price`; never cast arbitrary bytes INTO one.
        // Gated to little-endian targets only.
        #[cfg(target_endian = "little")]
        unsafe impl Zeroable for Price {}

        #[cfg(target_endian = "little")]
        unsafe impl Pod for Price {}
    }

    /// Copy the raw bytes of `account` into a new buffer, with no discriminator prefix.
    pub fn get_account_bytes<T: bytemuck::Pod>(account: &mut T) -> BytesMut {
        let mut bytes = BytesMut::new();
        let data = bytemuck::bytes_of_mut(account);
        bytes.extend_from_slice(data);
        bytes
    }

    /// Copy `T::DISCRIMINATOR` followed by the raw bytes of `account` into a new buffer.
    pub fn get_anchor_account_bytes<T: bytemuck::Pod + Discriminator>(account: &mut T) -> BytesMut {
        let mut bytes = BytesMut::new();
        bytes.extend_from_slice(T::DISCRIMINATOR);
        let data = bytemuck::bytes_of_mut(account);
        bytes.extend_from_slice(data);
        bytes
    }

    /// Bind `$name` to an `AccountWithKey` whose data is the raw bytes of `$account`, owned by
    /// `$owner`, keyed by `*$pubkey`.
    ///
    /// `Account` must be in scope at the call site.
    #[macro_export]
    macro_rules! create_account_info {
        ($account:expr, $pubkey:expr, $owner:expr, $name: ident) => {
            let acc = Account {
                data: crate::utils::test_utils::get_account_bytes(&mut $account).to_vec(),
                owner: $owner,
                ..Default::default()
            };
            let $name: crate::ffi::AccountWithKey = (*$pubkey, acc).into();
        };
    }

    /// Bind `$name` to an `AccountWithKey` whose data is the discriminator-prefixed bytes of
    /// `$account`, owned by `constants::PROGRAM_ID`, keyed by `$pubkey` (not dereferenced, unlike
    /// `create_account_info!`).
    ///
    /// `Account` and `constants` must be in scope at the call site. `$type` is accepted but not
    /// used by the expansion.
    #[macro_export]
    macro_rules! create_anchor_account_info {
        ($account:expr, $pubkey:expr, $type:ident, $name: ident) => {
            let owner = constants::PROGRAM_ID;
            let acc = Account {
                data: crate::utils::test_utils::get_anchor_account_bytes(&mut $account).to_vec(),
                owner,
                ..Default::default()
            };
            let $name: crate::ffi::AccountWithKey = ($pubkey, acc).into();
        };
    }
}
355
#[cfg(test)]
mod tests {
    use solana_sdk::signer::Signer;

    use super::*;

    /// Public key that every keypair fixture below encodes.
    const EXPECTED_PUBKEY: &str = "EtiM5qwcrrawQP9FfRErBatNvDgEU656tk5aA8iTgqri";

    /// Parse `keypair_data` and check that it yields the fixture keypair.
    fn assert_parses_to_expected(keypair_data: &str) {
        let keypair = read_keypair_str_multi_format(keypair_data).unwrap();
        assert_eq!(keypair.pubkey().to_string(), EXPECTED_PUBKEY);
    }

    #[test]
    fn test_keypair_from_json_numbers_array() {
        assert_parses_to_expected("[17,188,105,73,182,3,56,125,157,20,12,82,88,197,181,202,251,248,97,103,215,165,233,145,114,254,20,89,100,79,207,168,206,103,77,58,215,94,196,155,224,116,73,74,62,200,30,248,101,102,164,126,6,170,77,190,186,142,107,222,3,242,143,155]");
    }

    #[test]
    fn test_keypair_from_json_comma_separated_numbers() {
        assert_parses_to_expected("17,188,105,73,182,3,56,125,157,20,12,82,88,197,181,202,251,248,97,103,215,165,233,145,114,254,20,89,100,79,207,168,206,103,77,58,215,94,196,155,224,116,73,74,62,200,30,248,101,102,164,126,6,170,77,190,186,142,107,222,3,242,143,155");
    }

    #[test]
    fn test_keypair_from_base58_string() {
        assert_parses_to_expected("MZsY4Vme2Xa417rhh1MUGCru9oYNDxCjH1TZRWJPNSzRmZmodjczVaGuWKgzBsoKxx2ZLQZjUWTkLu44jE5DhSJ");
    }

    #[test]
    fn test_keypair_from_base64_string() {
        assert_parses_to_expected("EbxpSbYDOH2dFAxSWMW1yvv4YWfXpemRcv4UWWRPz6jOZ006117Em+B0SUo+yB74ZWakfgaqTb66jmveA/KPmw==");
    }

    #[test]
    fn test_https_to_ws() {
        assert_eq!(
            http_to_ws("https://dlob.drift.trade").unwrap(),
            "wss://dlob.drift.trade/ws"
        );
        assert_eq!(
            http_to_ws("http://dlob.drift.trade").unwrap(),
            "ws://dlob.drift.trade/ws"
        );
    }
}