gmsol_utils/
token_config.rs

1use std::collections::BTreeSet;
2
3use anchor_lang::prelude::*;
4
5use crate::{
6    chunk_by::chunk_by,
7    fixed_str::{bytes_to_fixed_str, FixedStrError},
8    market::HasMarketMeta,
9    oracle::PriceProviderKind,
10    pubkey::DEFAULT_PUBKEY,
11    swap::HasSwapParams,
12};
13
/// Default heartbeat duration for price updates.
pub const DEFAULT_HEARTBEAT_DURATION: u32 = 30;

/// Default precision for price.
pub const DEFAULT_PRECISION: u8 = 4;

/// Default timestamp adjustment.
pub const DEFAULT_TIMESTAMP_ADJUSTMENT: u32 = 0;

/// Default maximum deviation ratio.
///
/// A ratio of `0` means that no deviation restriction is applied.
pub const DEFAULT_MAX_DEVIATION_RATIO: u32 = 0;

/// Number of feed slots in [`TokenConfig::feeds`].
const MAX_FEEDS: usize = 4;
/// Maximum number of flags, as passed to the `flags!` invocation for [`TokenConfigFlag`].
const MAX_FLAGS: usize = 8;
/// Length in bytes of the fixed-size [`TokenConfig::name`] buffer.
const MAX_NAME_LEN: usize = 32;
29
/// Token config error.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TokenConfigError {
    /// Not found.
    ///
    /// Returned, for example, when a token has no config in the token map, or
    /// when the requested feed slot is out of range or still holds the default pubkey.
    #[error("not found")]
    NotFound,
    /// Invalid provider index.
    ///
    /// Returned when a stored provider byte cannot be converted into a
    /// [`PriceProviderKind`], or when a provider has no slot to write a feed config into.
    #[error("invalid provider index")]
    InvalidProviderIndex,
    /// Fixed str error.
    ///
    /// Wraps the error produced while decoding a fixed-size string buffer.
    #[error(transparent)]
    FixedStr(#[from] FixedStrError),
    /// Exceed max length limit.
    ///
    /// Returned when the number of records of a single provider does not fit in a `u16`.
    #[error("exceed max length limit")]
    ExceedMaxLengthLimit,
    /// Exceed max ratio.
    ///
    /// Returned when a max deviation factor, divided by
    /// [`FeedConfig::RATIO_MULTIPLIER`], does not fit in a `u32`.
    #[error("exceed max ratio")]
    ExceedMaxRatio,
    /// Max deviation factor too small.
    ///
    /// Returned when a max deviation factor is smaller than
    /// [`FeedConfig::RATIO_MULTIPLIER`], so that the stored ratio would be `0`.
    #[error("max deviation factor too small")]
    MaxDeviationFactorTooSmall,
}
53
/// Crate-internal result type whose error is [`TokenConfigError`].
pub(crate) type TokenConfigResult<T> = std::result::Result<T, TokenConfigError>;
55
/// Token Flags.
///
/// The variants carry no explicit discriminants, so with `#[repr(u8)]` their
/// numeric values follow the declaration order, starting from `0`.
#[derive(num_enum::IntoPrimitive)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenConfigFlag {
    /// Is initialized.
    Initialized,
    /// Enabled.
    Enabled,
    /// Is a synthetic asset.
    Synthetic,
    /// Indicates whether price adjustment is allowed.
    AllowPriceAdjustment,
    // CHECK: Cannot have more than `MAX_FLAGS` flags.
}

// NOTE(review): presumably this generates the `TokenConfigFlagContainer` type used by
// `TokenConfig::flags`, backed by a `u8` — confirm against the `flags!` macro definition.
crate::flags!(TokenConfigFlag, MAX_FLAGS, u8);
73
/// Token config.
///
/// Holds the name, flags, decimals, price precision, heartbeat duration and
/// per-provider price feed configs of a token.
#[zero_copy]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "debug", derive(derive_more::Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokenConfig {
    /// Name, stored as a fixed-size byte buffer (decode it with [`TokenConfig::name`]).
    pub name: [u8; MAX_NAME_LEN],
    /// Flags (see [`TokenConfigFlag`]).
    pub flags: TokenConfigFlagContainer,
    /// Token decimals.
    pub token_decimals: u8,
    /// Precision.
    pub precision: u8,
    /// Expected provider, stored as the `u8` value of a [`PriceProviderKind`].
    pub expected_provider: u8,
    /// Price Feeds, indexed by the provider kind cast to `usize`.
    pub feeds: [FeedConfig; MAX_FEEDS],
    /// Heartbeat duration.
    pub heartbeat_duration: u32,
    // NOTE(review): reserved bytes, presumably kept for future fields — confirm.
    #[cfg_attr(feature = "debug", debug(skip))]
    reserved: [u8; 32],
}
96
/// Human-readable, multi-line summary of a [`TokenConfig`].
///
/// Values that cannot be decoded (an invalid name buffer or an unrecognized
/// provider index) are rendered as `*unknown*` instead of failing.
#[cfg(feature = "display")]
impl std::fmt::Display for TokenConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Build the fallback lazily so that no `String` is allocated for it
        // when the provider index is valid.
        let expected_provider = self
            .expected_provider()
            .map(|kind| kind.to_string())
            .unwrap_or_else(|_| "*unknown*".to_string());

        writeln!(f, "Name: {}", self.name().unwrap_or("*unknown*"))?;
        writeln!(f, "Enabled: {}", self.is_enabled())?;
        writeln!(f, "Synthetic: {}", self.is_synthetic())?;
        writeln!(f, "Decimals: {}", self.token_decimals)?;
        writeln!(f, "Precision: {}", self.precision)?;
        writeln!(f, "Heartbeat: {}", self.heartbeat_duration)?;
        writeln!(f, "Expected Provider: {}", expected_provider)
    }
}
116
117impl TokenConfig {
118    /// Get the corresponding price feed config.
119    pub fn get_feed_config(&self, kind: &PriceProviderKind) -> TokenConfigResult<&FeedConfig> {
120        let index = *kind as usize;
121        let config = self.feeds.get(index).ok_or(TokenConfigError::NotFound)?;
122        if config.feed == DEFAULT_PUBKEY {
123            Err(TokenConfigError::NotFound)
124        } else {
125            Ok(config)
126        }
127    }
128
129    /// Get the mutable reference of feed config by kind.
130    pub fn get_feed_config_mut(
131        &mut self,
132        kind: &PriceProviderKind,
133    ) -> TokenConfigResult<&mut FeedConfig> {
134        let index = *kind as usize;
135        let config = self
136            .feeds
137            .get_mut(index)
138            .ok_or(TokenConfigError::NotFound)?;
139        if config.feed == DEFAULT_PUBKEY {
140            Err(TokenConfigError::NotFound)
141        } else {
142            Ok(config)
143        }
144    }
145
146    /// Set feed config.
147    pub fn set_feed_config(
148        &mut self,
149        kind: &PriceProviderKind,
150        new_config: FeedConfig,
151    ) -> TokenConfigResult<()> {
152        let index = *kind as usize;
153        let config = self
154            .feeds
155            .get_mut(index)
156            .ok_or(TokenConfigError::InvalidProviderIndex)?;
157        *config = new_config;
158        Ok(())
159    }
160
161    /// Get the corresponding price feed address.
162    pub fn get_feed(&self, kind: &PriceProviderKind) -> TokenConfigResult<Pubkey> {
163        Ok(self.get_feed_config(kind)?.feed)
164    }
165
166    /// Set expected provider.
167    pub fn set_expected_provider(&mut self, provider: PriceProviderKind) {
168        self.expected_provider = provider as u8;
169    }
170
171    /// Get expected price provider kind.
172    pub fn expected_provider(&self) -> TokenConfigResult<PriceProviderKind> {
173        let kind = PriceProviderKind::try_from(self.expected_provider)
174            .map_err(|_| TokenConfigError::InvalidProviderIndex)?;
175        Ok(kind)
176    }
177
178    /// Get price feed address for the expected provider.
179    pub fn get_expected_feed(&self) -> TokenConfigResult<Pubkey> {
180        self.get_feed(&self.expected_provider()?)
181    }
182
183    /// Set enabled.
184    pub fn set_enabled(&mut self, enable: bool) {
185        self.set_flag(TokenConfigFlag::Enabled, enable)
186    }
187
188    /// Set synthetic.
189    pub fn set_synthetic(&mut self, is_synthetic: bool) {
190        self.set_flag(TokenConfigFlag::Synthetic, is_synthetic)
191    }
192
193    /// Is enabled.
194    pub fn is_enabled(&self) -> bool {
195        self.flag(TokenConfigFlag::Enabled)
196    }
197
198    /// Is synthetic.
199    pub fn is_synthetic(&self) -> bool {
200        self.flag(TokenConfigFlag::Synthetic)
201    }
202
203    /// Returns whether the config is a valid pool token config.
204    pub fn is_valid_pool_token_config(&self) -> bool {
205        !self.is_synthetic()
206    }
207
208    /// Returns `true` if price adjustment is allowed.
209    pub fn is_price_adjustment_allowed(&self) -> bool {
210        self.flag(TokenConfigFlag::AllowPriceAdjustment)
211    }
212
213    /// Set flag
214    pub fn set_flag(&mut self, flag: TokenConfigFlag, value: bool) {
215        self.flags.set_flag(flag, value);
216    }
217
218    /// Get flag.
219    pub fn flag(&self, flag: TokenConfigFlag) -> bool {
220        self.flags.get_flag(flag)
221    }
222
223    /// Token decimals.
224    pub fn token_decimals(&self) -> u8 {
225        self.token_decimals
226    }
227
228    /// Price Precision.
229    pub fn precision(&self) -> u8 {
230        self.precision
231    }
232
233    /// Get timestamp adjustment.
234    pub fn timestamp_adjustment(
235        &self,
236        price_provider: &PriceProviderKind,
237    ) -> TokenConfigResult<u32> {
238        Ok(self.get_feed_config(price_provider)?.timestamp_adjustment())
239    }
240
241    /// Get max deviation factor.
242    pub fn max_deviation_factor(
243        &self,
244        price_provider: &PriceProviderKind,
245    ) -> TokenConfigResult<Option<u128>> {
246        Ok(self.get_feed_config(price_provider)?.max_deviation_factor())
247    }
248
249    /// Heartbeat duration.
250    pub fn heartbeat_duration(&self) -> u32 {
251        self.heartbeat_duration
252    }
253
254    /// Get token name.
255    pub fn name(&self) -> TokenConfigResult<&str> {
256        Ok(bytes_to_fixed_str(&self.name)?)
257    }
258}
259
/// The space of a [`TokenConfig`] is its in-memory size.
impl crate::InitSpace for TokenConfig {
    const INIT_SPACE: usize = std::mem::size_of::<Self>();
}
263
/// Price Feed Config.
#[zero_copy]
#[derive(PartialEq, Eq)]
#[cfg_attr(feature = "debug", derive(derive_more::Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FeedConfig {
    /// Feed. A slot holding the default pubkey is treated as unset by
    /// [`TokenConfig::get_feed_config`].
    #[cfg_attr(
        feature = "serde",
        serde(with = "serde_with::As::<serde_with::DisplayFromStr>")
    )]
    feed: Pubkey,
    /// Timestamp adjustment.
    timestamp_adjustment: u32,
    /// The maximum allowed deviation ratio from the mid-price.
    /// A value of `0` means no restriction is applied.
    ///
    /// It is converted to a factor by multiplying with [`FeedConfig::RATIO_MULTIPLIER`].
    max_deviation_ratio: u32,
    // NOTE(review): reserved bytes, presumably kept for future fields — confirm.
    #[cfg_attr(feature = "debug", debug(skip))]
    #[cfg_attr(feature = "serde", serde(with = "serde_bytes"))]
    reserved: [u8; 24],
}
283
/// Single-line rendering of the feed and its timestamp adjustment.
#[cfg(feature = "display")]
impl std::fmt::Display for FeedConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            feed,
            timestamp_adjustment,
            ..
        } = self;
        write!(
            f,
            "feed = {}, timestamp_adjustment = {}",
            feed, timestamp_adjustment
        )
    }
}
294
295impl FeedConfig {
296    /// Multiplier used to convert a `u32` ratio into a `u128` factor.
297    /// The resulting precision is `FACTOR_DECIMALS` (typically 20) minus `log(RATIO_MULTIPLIER)`.
298    pub const RATIO_MULTIPLIER: u128 = 10u128.pow(12);
299
300    /// Create a new feed config.
301    pub fn new(feed: Pubkey) -> Self {
302        Self {
303            feed,
304            timestamp_adjustment: DEFAULT_TIMESTAMP_ADJUSTMENT,
305            max_deviation_ratio: DEFAULT_MAX_DEVIATION_RATIO,
306            reserved: Default::default(),
307        }
308    }
309
310    /// Set feed id.
311    pub fn with_feed(mut self, feed_id: Pubkey) -> Self {
312        self.feed = feed_id;
313        self
314    }
315
316    /// Set timestamp adjustment,
317    pub fn with_timestamp_adjustment(mut self, timestamp_adjustment: u32) -> Self {
318        self.timestamp_adjustment = timestamp_adjustment;
319        self
320    }
321
322    /// Set max deviation factor
323    pub fn with_max_deviation_factor(
324        mut self,
325        max_deviation_factor: Option<u128>,
326    ) -> TokenConfigResult<Self> {
327        let ratio = match max_deviation_factor {
328            Some(factor) => {
329                let ratio = (factor / Self::RATIO_MULTIPLIER)
330                    .try_into()
331                    .map_err(|_| TokenConfigError::ExceedMaxRatio)?;
332                if ratio == 0 {
333                    return Err(TokenConfigError::MaxDeviationFactorTooSmall);
334                }
335                ratio
336            }
337            None => 0,
338        };
339        self.max_deviation_ratio = ratio;
340        Ok(self)
341    }
342
343    /// Get feed.
344    pub fn feed(&self) -> &Pubkey {
345        &self.feed
346    }
347
348    /// Get timestamp adjustment.
349    pub fn timestamp_adjustment(&self) -> u32 {
350        self.timestamp_adjustment
351    }
352
353    /// Get max deviation factor.
354    pub fn max_deviation_factor(&self) -> Option<u128> {
355        let ratio = self.max_deviation_ratio;
356        if self.max_deviation_ratio == 0 {
357            None
358        } else {
359            Some(u128::from(ratio) * Self::RATIO_MULTIPLIER)
360        }
361    }
362}
363
/// Parameters describing an update of a token config.
///
/// Can be built from defaults or from an existing [`TokenConfig`].
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#[cfg_attr(feature = "debug", derive(Debug))]
pub struct UpdateTokenConfigParams {
    /// Heartbeat duration.
    pub heartbeat_duration: u32,
    /// Price precision.
    pub precision: u8,
    /// Feeds, indexed by the provider kind cast to `usize`.
    pub feeds: Vec<Pubkey>,
    /// Timestamp adjustments, using the same indices as `feeds`.
    pub timestamp_adjustments: Vec<u32>,
    /// Expected price provider, as the `u8` value of a [`PriceProviderKind`].
    pub expected_provider: Option<u8>,
}
378
379impl Default for UpdateTokenConfigParams {
380    fn default() -> Self {
381        Self {
382            heartbeat_duration: DEFAULT_HEARTBEAT_DURATION,
383            precision: DEFAULT_PRECISION,
384            feeds: vec![DEFAULT_PUBKEY; MAX_FEEDS],
385            timestamp_adjustments: vec![DEFAULT_TIMESTAMP_ADJUSTMENT; MAX_FEEDS],
386            expected_provider: None,
387        }
388    }
389}
390
391impl<'a> From<&'a TokenConfig> for UpdateTokenConfigParams {
392    fn from(config: &'a TokenConfig) -> Self {
393        let (feeds, timestamp_adjustments) = config
394            .feeds
395            .iter()
396            .map(|config| (config.feed, config.timestamp_adjustment))
397            .unzip();
398
399        Self {
400            heartbeat_duration: config.heartbeat_duration(),
401            precision: config.precision(),
402            feeds,
403            timestamp_adjustments,
404            expected_provider: Some(config.expected_provider),
405        }
406    }
407}
408
409impl UpdateTokenConfigParams {
410    /// Update the feed address for the given price provider.
411    /// Return error when the feed was not set before.
412    pub fn update_price_feed(
413        mut self,
414        kind: &PriceProviderKind,
415        new_feed: Pubkey,
416        new_timestamp_adjustment: Option<u32>,
417    ) -> TokenConfigResult<Self> {
418        let index = *kind as usize;
419        let feed = self
420            .feeds
421            .get_mut(index)
422            .ok_or(TokenConfigError::NotFound)?;
423        let timestamp_adjustment = self
424            .timestamp_adjustments
425            .get_mut(index)
426            .ok_or(TokenConfigError::NotFound)?;
427        *feed = new_feed;
428        if let Some(new_timestamp_adjustment) = new_timestamp_adjustment {
429            *timestamp_adjustment = new_timestamp_adjustment;
430        }
431        Ok(self)
432    }
433
434    /// Set heartbeat duration.
435    pub fn with_heartbeat_duration(mut self, duration: u32) -> Self {
436        self.heartbeat_duration = duration;
437        self
438    }
439
440    /// Set precision.
441    pub fn with_precision(mut self, precision: u8) -> Self {
442        self.precision = precision;
443        self
444    }
445
446    /// Set expected provider.
447    pub fn with_expected_provider(mut self, provider: PriceProviderKind) -> Self {
448        self.expected_provider = Some(provider as u8);
449        self
450    }
451}
452
453/// Read Token Map.
454pub trait TokenMapAccess {
455    /// Get the config of the given token.
456    fn get(&self, token: &Pubkey) -> Option<&TokenConfig>;
457
458    /// Get token configs for the given market.
459    ///
460    /// Returns the token configs for `index_token`, `long_token` and `short_token`.
461    fn token_configs_for_market(&self, market: &impl HasMarketMeta) -> Option<[&TokenConfig; 3]> {
462        let meta = market.market_meta();
463        let index_token = self.get(&meta.index_token_mint)?;
464        let long_token = self.get(&meta.long_token_mint)?;
465        let short_token = self.get(&meta.short_token_mint)?;
466        Some([index_token, long_token, short_token])
467    }
468
469    /// Sort tokens by provider. This sort is stable.
470    fn sort_tokens_by_provider(&self, tokens: &mut [Pubkey]) -> Result<()> {
471        // Check the existence of token configs.
472        for token in tokens.iter() {
473            require!(self.get(token).is_some(), ErrorCode::RequireViolated);
474        }
475        tokens.sort_by_cached_key(|token| self.get(token).unwrap().expected_provider);
476        Ok(())
477    }
478}
479
/// Tokens with feed.
///
/// When built with [`TokensWithFeed::try_from_records`], `tokens` and `feeds`
/// are ordered by provider, and `providers` / `nums` describe each run of
/// tokens sharing a provider.
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#[cfg_attr(feature = "debug", derive(Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokensWithFeed {
    /// Tokens that require prices,
    /// which must be of the same length with `feeds`.
    pub tokens: Vec<Pubkey>,
    /// Token feeds for the tokens,
    /// which must be of the same length with `tokens`.
    pub feeds: Vec<Pubkey>,
    /// Providers set (as `u8` provider values),
    /// which must be of the same length with `nums`.
    pub providers: Vec<u8>,
    /// The numbers of tokens of each provider.
    pub nums: Vec<u16>,
}
497
/// A record of token config.
#[derive(AnchorSerialize, AnchorDeserialize, Clone)]
#[cfg_attr(feature = "debug", derive(Debug))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TokenRecord {
    /// Token address.
    token: Pubkey,
    /// Feed id.
    feed: Pubkey,
    /// Provider, stored as the `u8` value of a [`PriceProviderKind`].
    provider: u8,
}
507
508impl TokenRecord {
509    /// Create a new [`TokenRecord`]
510    pub fn new(token: Pubkey, feed: Pubkey, provider: PriceProviderKind) -> Self {
511        Self {
512            token,
513            feed,
514            provider: provider as u8,
515        }
516    }
517
518    /// Create a new [`TokenRecord`] from token config,
519    /// using the expected provider and feed.
520    pub fn from_config(token: Pubkey, config: &TokenConfig) -> TokenConfigResult<Self> {
521        Ok(Self::new(
522            token,
523            config.get_expected_feed()?,
524            config.expected_provider()?,
525        ))
526    }
527
528    /// Get token address.
529    pub fn token(&self) -> &Pubkey {
530        &self.token
531    }
532
533    /// Get feed id.
534    pub fn feed(&self) -> &Pubkey {
535        &self.feed
536    }
537
538    /// Get provider kind.
539    pub fn provider_kind(
540        &self,
541    ) -> std::result::Result<PriceProviderKind, num_enum::TryFromPrimitiveError<PriceProviderKind>>
542    {
543        PriceProviderKind::try_from(self.provider)
544    }
545}
546
547impl TokensWithFeed {
548    /// Create from token records.
549    /// # Panic
550    /// Panics if the number of tokens of the same provider exceeds `u16`.
551    pub fn try_from_records(mut records: Vec<TokenRecord>) -> TokenConfigResult<Self> {
552        records.sort_by_cached_key(|r| r.provider);
553        let mut chunks = chunk_by(&records, |a, b| a.provider == b.provider);
554        let capacity = chunks.size_hint().0;
555        let mut providers = Vec::with_capacity(capacity);
556        let mut nums = Vec::with_capacity(capacity);
557        chunks.try_for_each(|chunk| {
558            providers.push(chunk[0].provider);
559            nums.push(
560                u16::try_from(chunk.len()).map_err(|_| TokenConfigError::ExceedMaxLengthLimit)?,
561            );
562            TokenConfigResult::Ok(())
563        })?;
564        Ok(Self {
565            tokens: records.iter().map(|r| r.token).collect(),
566            feeds: records.iter().map(|r| r.feed).collect(),
567            providers,
568            nums,
569        })
570    }
571}
572
573/// Collect token records for the give tokens.
574pub fn token_records<A: TokenMapAccess>(
575    token_map: &A,
576    tokens: &BTreeSet<Pubkey>,
577) -> TokenConfigResult<Vec<TokenRecord>> {
578    tokens
579        .iter()
580        .map(|token| {
581            let config = token_map.get(token).ok_or(TokenConfigError::NotFound)?;
582            TokenRecord::from_config(*token, config)
583        })
584        .collect::<TokenConfigResult<Vec<_>>>()
585}
586
/// Tokens Collector.
pub struct TokensCollector {
    /// Collected tokens. [`TokensCollector::insert_token`] locates entries with
    /// a binary search, so this list is expected to be kept sorted.
    tokens: Vec<Pubkey>,
}
591
592impl TokensCollector {
593    /// Create a new [`TokensCollector`].
594    pub fn new(action: Option<&impl HasSwapParams>, extra_capacity: usize) -> Self {
595        let mut tokens;
596        match action {
597            Some(action) => {
598                let swap = action.swap();
599                tokens = Vec::with_capacity(swap.num_tokens() + extra_capacity);
600                // The tokens in the swap params must be sorted.
601                tokens.extend_from_slice(swap.tokens());
602            }
603            None => tokens = Vec::with_capacity(extra_capacity),
604        }
605
606        Self { tokens }
607    }
608
609    /// Insert a new token.
610    pub fn insert_token(&mut self, token: &Pubkey) -> bool {
611        match self.tokens.binary_search(token) {
612            Ok(_) => false,
613            Err(idx) => {
614                self.tokens.insert(idx, *token);
615                true
616            }
617        }
618    }
619
620    /// Convert to a vec.
621    pub fn into_vec(mut self, token_map: &impl TokenMapAccess) -> TokenConfigResult<Vec<Pubkey>> {
622        token_map
623            .sort_tokens_by_provider(&mut self.tokens)
624            .map_err(|_| TokenConfigError::NotFound)?;
625        Ok(self.tokens)
626    }
627
628    /// Get unique tokens.
629    pub fn unique_tokens(&self) -> BTreeSet<Pubkey> {
630        self.tokens.iter().copied().collect()
631    }
632
633    /// Convert to [`TokensWithFeed`].
634    pub fn to_feeds(&self, token_map: &impl TokenMapAccess) -> TokenConfigResult<TokensWithFeed> {
635        let records = self
636            .tokens
637            .iter()
638            .map(|token| {
639                let config = token_map.get(token).ok_or(TokenConfigError::NotFound)?;
640                TokenRecord::from_config(*token, config)
641            })
642            .collect::<TokenConfigResult<Vec<_>>>()?;
643        TokensWithFeed::try_from_records(records)
644    }
645}
646
/// Max number of treasury token flags.
///
/// Only available with the `treasury` feature.
#[cfg(feature = "treasury")]
pub const MAX_TREASURY_TOKEN_FLAGS: usize = 8;
650
/// Treasury token flags.
///
/// Only available with the `treasury` feature. The variants carry no explicit
/// discriminants, so with `#[repr(u8)]` their numeric values follow the
/// declaration order, starting from `0`. String conversions use `snake_case` names.
#[cfg(feature = "treasury")]
#[cfg_attr(feature = "enum-iter", derive(strum::EnumIter))]
#[cfg_attr(feature = "debug", derive(Debug))]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
#[cfg_attr(feature = "clap", clap(rename_all = "snake_case"))]
#[derive(
    num_enum::IntoPrimitive, Clone, Copy, strum::EnumString, strum::Display, PartialEq, Eq,
)]
#[strum(serialize_all = "snake_case")]
#[repr(u8)]
pub enum TokenFlag {
    /// Allow deposit.
    AllowDeposit,
    /// Allow withdrawal.
    AllowWithdrawal,
    // CHECK: cannot have more than `MAX_TREASURY_TOKEN_FLAGS` flags.
}