1use std::collections::HashMap;
2
3use anchor_lang::prelude::{Clock, Pubkey, Result};
4use anchor_spl::token::{Mint, TokenAccount};
5use anyhow::{ensure, Context};
6use mercurial_vault::state::Vault;
7use spl_token_swap::curve::calculator::TradeDirection;
8
9use crate::{
10 context::{get_curve_type, get_first_key, get_second_key, get_trade_fee_bps_bytes},
11 curve::curve_type::CurveType,
12 depeg::update_base_virtual_price,
13 error::PoolError,
14 math::{get_swap_curve, SwapResult},
15 state::Pool,
16};
17
18pub struct VaultInfo {
19 pub lp_amount: u64,
21 pub lp_supply: u64,
23 pub vault: Vault,
25}
26
27#[derive(Clone)]
28pub struct QuoteData {
29 pub pool: Pool,
31 pub vault_a: Vault,
33 pub vault_b: Vault,
35 pub pool_vault_a_lp_token: TokenAccount,
37 pub pool_vault_b_lp_token: TokenAccount,
39 pub vault_a_lp_mint: Mint,
41 pub vault_b_lp_mint: Mint,
43 pub vault_a_token: TokenAccount,
45 pub vault_b_token: TokenAccount,
47 pub clock: Clock,
49 pub stake_data: HashMap<Pubkey, Vec<u8>>,
51}
52
/// Outcome of a simulated swap produced by [`compute_quote`].
///
/// Implements `PartialEq`/`Eq` so quotes can be compared directly (e.g. in tests or when
/// de-duplicating routes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuoteResult {
    /// Amount of the output token the swap is expected to yield.
    pub out_amount: u64,
    /// Total fee charged: trade fee plus owner fee, both computed from the input amount.
    pub fee: u64,
}
60
61pub fn compute_quote(
62 in_token_mint: Pubkey,
63 in_amount: u64,
64 quote_data: QuoteData,
65) -> anyhow::Result<QuoteResult> {
66 let QuoteData {
67 mut pool,
68 vault_a,
69 vault_b,
70 pool_vault_a_lp_token,
71 pool_vault_b_lp_token,
72 vault_a_lp_mint,
73 vault_b_lp_mint,
74 vault_a_token,
75 vault_b_token,
76 clock,
77 stake_data,
78 } = quote_data;
79
80 update_base_virtual_price(&mut pool, &clock, stake_data)?;
81
82 let current_time: u64 = clock.unix_timestamp.try_into()?;
83
84 ensure!(
85 in_token_mint == pool.token_a_mint || in_token_mint == pool.token_b_mint,
86 "In token mint not matches with pool token mints"
87 );
88
89 let token_a_amount = vault_a
90 .get_amount_by_share(
91 current_time,
92 pool_vault_a_lp_token.amount,
93 vault_a_lp_mint.supply,
94 )
95 .context("Fail to get token a amount")?;
96
97 let token_b_amount = vault_b
98 .get_amount_by_share(
99 current_time,
100 pool_vault_b_lp_token.amount,
101 vault_b_lp_mint.supply,
102 )
103 .context("Fail to get token b amount")?;
104
105 let trade_direction = if in_token_mint == pool.token_a_mint {
106 TradeDirection::AtoB
107 } else {
108 TradeDirection::BtoA
109 };
110
111 let (
112 mut in_vault,
113 out_vault,
114 in_vault_lp,
115 in_vault_lp_mint,
116 out_vault_lp_mint,
117 out_vault_token_account,
118 in_token_total_amount,
119 out_token_total_amount,
120 ) = match trade_direction {
121 TradeDirection::AtoB => (
122 vault_a,
123 vault_b,
124 pool_vault_a_lp_token,
125 vault_a_lp_mint,
126 vault_b_lp_mint,
127 vault_b_token,
128 token_a_amount,
129 token_b_amount,
130 ),
131 TradeDirection::BtoA => (
132 vault_b,
133 vault_a,
134 pool_vault_b_lp_token,
135 vault_b_lp_mint,
136 vault_a_lp_mint,
137 vault_a_token,
138 token_b_amount,
139 token_a_amount,
140 ),
141 };
142
143 let trade_fee = pool
144 .fees
145 .trading_fee(in_amount.into())
146 .context("Fail to calculate trading fee")?;
147
148 let owner_fee = pool
149 .fees
150 .owner_trading_fee(in_amount.into())
151 .context("Fail to calculate owner trading fee")?;
152
153 let in_amount_after_owner_fee = in_amount
154 .checked_sub(owner_fee.try_into()?)
155 .context("Fail to calculate in_amount_after_owner_fee")?;
156
157 let before_in_token_total_amount = in_token_total_amount;
158
159 let in_lp = in_vault
160 .get_unmint_amount(
161 current_time,
162 in_amount_after_owner_fee,
163 in_vault_lp_mint.supply,
164 )
165 .context("Fail to get in_vault_lp")?;
166
167 in_vault.total_amount = in_vault
168 .total_amount
169 .checked_add(in_amount_after_owner_fee)
170 .context("Fail to add in_vault.total_amount")?;
171
172 let after_in_token_total_amount = in_vault
173 .get_amount_by_share(
174 current_time,
175 in_lp
176 .checked_add(in_vault_lp.amount)
177 .context("Fail to get new in_vault_lp")?,
178 in_vault_lp_mint
179 .supply
180 .checked_add(in_lp)
181 .context("Fail to get new in_vault_lp_mint")?,
182 )
183 .context("Fail to get after_in_token_total_amount")?;
184
185 let actual_in_amount = after_in_token_total_amount
186 .checked_sub(before_in_token_total_amount)
187 .context("Fail to get actual_in_amount")?;
188
189 let actual_in_amount_after_fee = actual_in_amount
190 .checked_sub(trade_fee.try_into()?)
191 .context("Fail to calculate in_amount_after_fee")?;
192
193 let swap_curve = get_swap_curve(pool.curve_type);
194
195 let SwapResult {
196 destination_amount_swapped,
197 ..
198 } = swap_curve
199 .swap(
200 actual_in_amount_after_fee,
201 in_token_total_amount,
202 out_token_total_amount,
203 trade_direction,
204 )
205 .context("Fail to get swap result")?;
206
207 let out_vault_lp = out_vault
208 .get_unmint_amount(
209 current_time,
210 destination_amount_swapped.try_into()?,
211 out_vault_lp_mint.supply,
212 )
213 .context("Fail to get out_vault_lp")?;
214
215 let out_amount = out_vault
216 .get_amount_by_share(current_time, out_vault_lp, out_vault_lp_mint.supply)
217 .context("Fail to get out_amount")?;
218
219 ensure!(
220 out_amount < out_vault_token_account.amount,
221 "Out amount > vault reserve"
222 );
223
224 let total_fee = trade_fee
225 .checked_add(owner_fee)
226 .context("Fail to calculate total fee")?;
227
228 Ok(QuoteResult {
229 fee: total_fee.try_into()?,
230 out_amount,
231 })
232}
233
234pub fn compute_pool_tokens(
236 current_time: u64,
237 vault_a: VaultInfo,
238 vault_b: VaultInfo,
239) -> Result<(u64, u64)> {
240 let token_a_amount = vault_a
241 .vault
242 .get_amount_by_share(current_time, vault_a.lp_amount, vault_a.lp_supply)
243 .ok_or(PoolError::MathOverflow)?;
244 let token_b_amount = vault_b
245 .vault
246 .get_amount_by_share(current_time, vault_b.lp_amount, vault_b.lp_supply)
247 .ok_or(PoolError::MathOverflow)?;
248 Ok((token_a_amount, token_b_amount))
249}
250
251pub fn derive_admin_token_fee(token_mint: Pubkey, pool: Pubkey) -> (Pubkey, u8) {
252 Pubkey::find_program_address(
253 &["fee".as_ref(), token_mint.as_ref(), pool.as_ref()],
254 &crate::ID,
255 )
256}
257
258pub fn derive_vault_lp(vault: Pubkey, pool: Pubkey) -> (Pubkey, u8) {
259 Pubkey::find_program_address(&[vault.as_ref(), pool.as_ref()], &crate::ID)
260}
261
262pub fn derive_lp_mint(pool: Pubkey) -> (Pubkey, u8) {
263 Pubkey::find_program_address(&["lp_mint".as_ref(), pool.as_ref()], &crate::ID)
264}
265
266#[deprecated(note = "use derive_permissionless_pool_with_fee_tier")]
267pub fn derive_permissionless_pool(
268 curve_type: CurveType,
269 token_a_mint: Pubkey,
270 token_b_mint: Pubkey,
271) -> (Pubkey, u8) {
272 Pubkey::find_program_address(
273 &[
274 &get_curve_type(curve_type).to_le_bytes(),
275 get_first_key(token_a_mint, token_b_mint).as_ref(),
276 get_second_key(token_a_mint, token_b_mint).as_ref(),
277 ],
278 &crate::ID,
279 )
280}
281
282pub fn derive_permissionless_pool_with_fee_tier(
283 curve_type: CurveType,
284 token_a_mint: Pubkey,
285 token_b_mint: Pubkey,
286 trade_fee_bps: u64,
287) -> (Pubkey, u8) {
288 Pubkey::find_program_address(
289 &[
290 &get_curve_type(curve_type).to_le_bytes(),
291 get_first_key(token_a_mint, token_b_mint).as_ref(),
292 get_second_key(token_a_mint, token_b_mint).as_ref(),
293 get_trade_fee_bps_bytes(curve_type, trade_fee_bps)
294 .unwrap()
295 .as_ref(),
296 ],
297 &crate::ID,
298 )
299}