use {
    crate::{
        error::ErrorCode,
        state::ExpressRelayMetadata,
        token::transfer_token_if_needed,
        FeeToken,
        Swap,
        SwapArgs,
        FEE_SPLIT_PRECISION,
    },
    anchor_lang::{
        accounts::interface_account::InterfaceAccount,
        prelude::*,
    },
    anchor_spl::token_interface::TokenAccount,
};
/// Swap amounts left to exchange after the swap fees have been deducted from the charged leg.
///
/// Only the leg selected by the swap's fee token is reduced; the other amount is carried over
/// unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PostFeeSwapArgs {
    /// Input-token amount after fees.
    pub input_after_fees:  u64,
    /// Output-token amount after fees.
    pub output_after_fees: u64,
}
impl<'info> Swap<'info> {
    /// Charges the swap fees on the leg of the trade selected by `args.fee_token`, pays them out,
    /// and returns the amounts that remain to be exchanged.
    ///
    /// * `FeeToken::Input`: fees are computed on `args.amount_input` and paid from
    ///   `searcher_input_ta`, signed by `searcher`; `args.amount_output` passes through untouched.
    /// * `FeeToken::Output`: fees are computed on `args.amount_output` and paid from
    ///   `trader_output_ata`, signed by `trader`; `args.amount_input` passes through untouched.
    ///
    /// # Errors
    /// Returns any error from [`ExpressRelayMetadata::compute_swap_fees`] (the fees are computed
    /// before any transfer is attempted) or from the individual fee transfers.
    pub fn transfer_swap_fees(&self, args: &SwapArgs) -> Result<PostFeeSwapArgs> {
        // The fee is always taken from exactly one side of the trade.
        let charged_amount = match args.fee_token {
            FeeToken::Input => args.amount_input,
            FeeToken::Output => args.amount_output,
        };
        let SwapFeesWithRemainingAmount {
            fees,
            remaining_amount,
        } = self
            .express_relay_metadata
            .compute_swap_fees(args.referral_fee_bps, charged_amount)?;

        // Shrink only the charged side, and record which account/signer pays the fees.
        let (amounts_after_fees, fee_payment) = match args.fee_token {
            FeeToken::Input => (
                PostFeeSwapArgs {
                    input_after_fees:  remaining_amount,
                    output_after_fees: args.amount_output,
                },
                TransferSwapFeeArgs {
                    fees,
                    from: &self.searcher_input_ta,
                    authority: &self.searcher,
                },
            ),
            FeeToken::Output => (
                PostFeeSwapArgs {
                    input_after_fees:  args.amount_input,
                    output_after_fees: remaining_amount,
                },
                TransferSwapFeeArgs {
                    fees,
                    from: &self.trader_output_ata,
                    authority: &self.trader,
                },
            ),
        };

        self.transfer_swap_fees_cpi(&fee_payment)?;
        Ok(amounts_after_fees)
    }

    /// Pays out the three fee components in `args`, one transfer each, in the order router,
    /// relayer, Express Relay; stops at the first transfer that fails.
    fn transfer_swap_fees_cpi(&self, args: &TransferSwapFeeArgs<'info, '_>) -> Result<()> {
        let fees = &args.fees;
        self.transfer_swap_fee_cpi(fees.router_fee, &self.router_fee_receiver_ta, args)?;
        self.transfer_swap_fee_cpi(fees.relayer_fee, &self.relayer_fee_receiver_ata, args)?;
        self.transfer_swap_fee_cpi(
            fees.express_relay_fee,
            &self.express_relay_fee_receiver_ata,
            args,
        )
    }

    /// Moves `fee` units of `mint_fee` from `args.from` to `receiver_ta`, authorized by
    /// `args.authority`, through `transfer_token_if_needed` with `token_program_fee`.
    fn transfer_swap_fee_cpi(
        &self,
        fee: u64,
        receiver_ta: &InterfaceAccount<'info, TokenAccount>,
        args: &TransferSwapFeeArgs<'info, '_>,
    ) -> Result<()> {
        let token_program = &self.token_program_fee;
        let mint = &self.mint_fee;
        transfer_token_if_needed(args.from, receiver_ta, token_program, args.authority, mint, fee)?;
        Ok(())
    }
}
/// One payer's fee obligation for a swap: the account the fees are drawn from, the signer that
/// authorizes those transfers, and the fee breakdown to pay out.
pub struct TransferSwapFeeArgs<'info, 'a> {
    /// Token account every fee transfer is drawn from.
    pub from:      &'a InterfaceAccount<'info, TokenAccount>,
    /// Signer passed as the transfer authority for the debits from `from`.
    pub authority: &'a Signer<'info>,
    /// Per-recipient fee amounts to transfer.
    pub fees:      SwapFees,
}
/// Result of splitting swap fees out of an amount: the fees plus whatever is left over.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SwapFeesWithRemainingAmount {
    /// The fee components deducted from the amount.
    pub fees:             SwapFees,
    /// The amount minus the router fee and the platform fee.
    pub remaining_amount: u64,
}
/// Breakdown of the fees charged on one swap leg, in token base units.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SwapFees {
    /// Referral share, sent to the router fee receiver.
    pub router_fee:        u64,
    /// Relayer's portion of the platform fee.
    pub relayer_fee:       u64,
    /// Express Relay's portion of the platform fee (platform fee minus the relayer's portion).
    pub express_relay_fee: u64,
}
impl ExpressRelayMetadata {
    /// Splits the swap fees out of `amount` and returns them with the amount left over.
    ///
    /// With `P = FEE_SPLIT_PRECISION` and every division rounding down:
    /// * `router_fee        = amount * referral_fee_bps / P`
    /// * `platform_fee      = amount * self.swap_platform_fee_bps / P`
    /// * `relayer_fee       = platform_fee * self.split_relayer / P`
    /// * `express_relay_fee = platform_fee - relayer_fee`
    /// * `remaining_amount  = amount - router_fee - platform_fee`
    ///
    /// The products are formed in `u128`, so an amount whose product with a fee rate exceeds
    /// `u64::MAX` is still handled instead of failing with a spurious overflow.
    ///
    /// # Errors
    /// * `ErrorCode::InvalidReferralFee` if `referral_fee_bps > FEE_SPLIT_PRECISION`.
    /// * `ProgramError::ArithmeticOverflow` if a fee does not fit in `u64`, if
    ///   `router_fee + platform_fee` exceeds `amount`, or if `relayer_fee` exceeds `platform_fee`.
    pub fn compute_swap_fees(
        &self,
        referral_fee_bps: u16,
        amount: u64,
    ) -> Result<SwapFeesWithRemainingAmount> {
        if u64::from(referral_fee_bps) > FEE_SPLIT_PRECISION {
            return Err(ErrorCode::InvalidReferralFee.into());
        }
        let router_fee = scale_by_fee_split(amount, referral_fee_bps.into())?;
        let platform_fee = scale_by_fee_split(amount, self.swap_platform_fee_bps)?;
        let relayer_fee = scale_by_fee_split(platform_fee, self.split_relayer)?;
        // Checked: the metadata rates are not validated here, so the fees may exceed `amount`.
        let remaining_amount = amount
            .checked_sub(router_fee)
            .ok_or(ProgramError::ArithmeticOverflow)?
            .checked_sub(platform_fee)
            .ok_or(ProgramError::ArithmeticOverflow)?;
        // Checked: `split_relayer > FEE_SPLIT_PRECISION` would make the relayer share too large.
        let express_relay_fee = platform_fee
            .checked_sub(relayer_fee)
            .ok_or(ProgramError::ArithmeticOverflow)?;
        Ok(SwapFeesWithRemainingAmount {
            fees: SwapFees {
                router_fee,
                relayer_fee,
                express_relay_fee,
            },
            remaining_amount,
        })
    }
}

/// Returns `floor(value * numerator / FEE_SPLIT_PRECISION)`.
///
/// The product of two `u64`s always fits in a `u128`, so the only possible failure is a quotient
/// that does not fit back into `u64` (which requires `numerator > FEE_SPLIT_PRECISION`); that is
/// reported as `ProgramError::ArithmeticOverflow`.
fn scale_by_fee_split(value: u64, numerator: u64) -> Result<u64> {
    let scaled = u128::from(value) * u128::from(numerator) / u128::from(FEE_SPLIT_PRECISION);
    let share = u64::try_from(scaled).map_err(|_| ProgramError::ArithmeticOverflow)?;
    Ok(share)
}