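//! Documentation: computes the PT -> asset exchange rate of an Exponent market from raw market and vault account data.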
use borsh::{BorshDeserialize, BorshSerialize};
use solana_program::{account_info::AccountInfo, pubkey::Pubkey};
use std::str::FromStr;
use ttv_curve::math::{exchange_rate, logit, proportion_pt};
use ttv_fixed_point::Number;

/// Financial parameters for the market.
///
/// This struct is Borsh-deserialized straight out of the market account data,
/// so the field order below is the wire order and must not be changed.
#[derive(BorshSerialize, BorshDeserialize, Clone, Debug, Default, PartialEq)]
pub struct MarketFinancials {
    /// Expiration timestamp, which is copied from the vault associated with the PT
    pub expiration_ts: u64,

    /// Balance of PT in the market
    /// This amount is tracked separately to prevent bugs from token transfers directly to the market
    pub pt_balance: u64,

    /// Balance of SY in the market
    /// This amount is tracked separately to prevent bugs from token transfers directly to the market
    pub sy_balance: u64,

    /// Initial log of fee rate, which decreases over time
    pub ln_fee_rate_root: f64,

    /// Last seen log of implied rate (APY) for PT
    /// Used to maintain continuity of the APY between trades over time
    pub last_ln_implied_rate: f64,

    /// Initial rate scalar, which increases over time
    pub rate_scalar_root: f64,
}

impl MarketFinancials {
    /// Byte span reserved for these financials inside the market account data.
    ///
    /// NOTE(review): Borsh encodes the six fields as 3 x u64 + 3 x f64 = 48 bytes,
    /// so `deserialize` only consumes the first 48 bytes of this 72-byte span.
    /// The extra 24 bytes presumably come from an older layout or padding;
    /// confirm against the on-chain layout before relying on this value.
    pub const SIZE_OF: usize = 8 + 8 + 8 + 16 + 16 + 16;

    /// SY balance valued in asset terms: `sy_balance * sy_exchange_rate`.
    fn asset_balance(&self, sy_exchange_rate: Number) -> Number {
        let sy_amount = Number::from_natural_u64(self.sy_balance);
        sy_amount * sy_exchange_rate
    }

    /// Rate anchor at time `now`, derived by `ttv_curve::math::find_rate_anchor`
    /// from the current pool balances, the time-decayed rate scalar and the last
    /// recorded log implied rate.
    fn current_rate_anchor(&self, sy_exchange_rate: Number, now: u64) -> f64 {
        let time_left = self.sec_remaining(now);
        let asset_amount = self.asset_balance(sy_exchange_rate).floor_u64();
        let scalar = self.current_rate_scalar(now);
        ttv_curve::math::find_rate_anchor(
            self.pt_balance,
            asset_amount,
            scalar,
            self.last_ln_implied_rate.into(),
            time_left,
        )
    }

    /// Seconds until `expiration_ts`, clamped to zero once `now` is past expiry.
    fn sec_remaining(&self, now: u64) -> u64 {
        self.expiration_ts.saturating_sub(now)
    }

    /// Rate scalar at time `now`, computed from `rate_scalar_root` and the
    /// seconds left until expiry.
    fn current_rate_scalar(&self, now: u64) -> f64 {
        let time_left = self.sec_remaining(now);
        ttv_curve::math::rate_scalar::<f64>(self.rate_scalar_root, time_left)
    }
}

/// Result of [`get_exchange_rate`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GetExchangeRateResult {
    /// PT -> asset exchange rate, computed as the reciprocal of the market
    /// curve's `exchange_rate` at the requested timestamp.
    pub pt_asset_exchange_rate: f64,
}

/// Computes the PT -> asset exchange rate of an Exponent market at a given time.
///
/// The market's [`MarketFinancials`] and the vault's final SY exchange rate are
/// read straight from the raw account data at fixed offsets. The PT proportion of
/// the pool is then priced on the curve using the rate scalar and rate anchor for
/// `current_unix_timestamp`. The curve's rate is inverted before being returned.
///
/// # Panics
///
/// Panics if:
/// - either account is not owned by the Exponent program,
/// - either account's data cannot be borrowed,
/// - either account's 8-byte discriminator is missing or does not match,
/// - either account's data is too short for the fields read here,
/// - the market financials fail to Borsh-deserialize.
///
/// NOTE(review): this does not check that `vault_account_info` is the vault the
/// market belongs to. Callers that accept both accounts from untrusted input
/// should verify that pairing before trusting the returned rate.
pub fn get_exchange_rate(
    market_account_info: &AccountInfo,
    vault_account_info: &AccountInfo,
    current_unix_timestamp: u64,
) -> GetExchangeRateResult {
    // Fixed identifiers and account-layout constants for the Exponent program.
    const PROGRAM_ID: &str = "ExponentnaRg3CQbW6dqQNZKXp7gtZ9DGMp1cwC4HAS7";
    const DISCRIMINATOR_LEN: usize = 8;
    const MARKET_DISCRIMINATOR: [u8; DISCRIMINATOR_LEN] = [212, 4, 132, 126, 169, 121, 121, 20];
    const VAULT_DISCRIMINATOR: [u8; DISCRIMINATOR_LEN] = [211, 8, 232, 43, 2, 152, 117, 119];
    // Byte offset of the Borsh-encoded `MarketFinancials` in the market account.
    const MARKET_FINANCIALS_OFFSET: usize = 364;
    // Byte offset of the final SY exchange rate in the vault account.
    const VAULT_FINAL_SY_RATE_OFFSET: usize = 401;
    // Number of bytes handed to `Number::from_bytes_le`.
    const NUMBER_LEN: usize = 32;

    let program_id =
        Pubkey::from_str(PROGRAM_ID).expect("hard-coded Exponent program id must parse");
    assert_eq!(
        market_account_info.owner, &program_id,
        "Market account not owned by program"
    );
    assert_eq!(
        vault_account_info.owner, &program_id,
        "Vault account not owned by program"
    );

    let market_data = market_account_info.try_borrow_data().unwrap();
    let vault_data = vault_account_info.try_borrow_data().unwrap();

    // `get` rather than indexing: data shorter than a discriminator then reports
    // the discriminator error instead of an anonymous slice-index panic.
    assert_eq!(
        market_data.get(..DISCRIMINATOR_LEN),
        Some(&MARKET_DISCRIMINATOR[..]),
        "Invalid market account discriminator"
    );
    assert_eq!(
        vault_data.get(..DISCRIMINATOR_LEN),
        Some(&VAULT_DISCRIMINATOR[..]),
        "Invalid vault account discriminator"
    );

    // Same length requirements the slicing below imposes, but with clear messages.
    let financials_end = MARKET_FINANCIALS_OFFSET + MarketFinancials::SIZE_OF;
    let sy_rate_end = VAULT_FINAL_SY_RATE_OFFSET + NUMBER_LEN;
    assert!(
        market_data.len() >= financials_end,
        "Market account data too short"
    );
    assert!(vault_data.len() >= sy_rate_end, "Vault account data too short");

    let final_sy_exchange_rate =
        Number::from_bytes_le(&vault_data[VAULT_FINAL_SY_RATE_OFFSET..sy_rate_end]);
    let financials =
        MarketFinancials::deserialize(&mut &market_data[MARKET_FINANCIALS_OFFSET..financials_end])
            .expect("Failed to deserialize market financials");

    // Pool composition: PT balance vs. SY balance valued at the vault's SY rate.
    let asset_balance_u64 = financials.asset_balance(final_sy_exchange_rate).floor_u64();
    let p = proportion_pt::<f64>(financials.pt_balance, asset_balance_u64);
    let l_p = logit(p);

    let current_rate_scalar = financials.current_rate_scalar(current_unix_timestamp);
    let rate_anchor =
        financials.current_rate_anchor(final_sy_exchange_rate, current_unix_timestamp);
    let pt_asset_rate = exchange_rate(l_p, current_rate_scalar, rate_anchor);

    GetExchangeRateResult {
        // The returned value is the reciprocal of the curve's exchange rate.
        pt_asset_exchange_rate: 1.0 / pt_asset_rate,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(target_os = "solana"))]
    use solana_client::rpc_client::RpcClient;

    /// End-to-end check of `get_exchange_rate` against live mainnet accounts.
    ///
    /// Ignored by default because it needs network access to the public Solana
    /// RPC endpoint. Run it explicitly with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires network access to Solana mainnet RPC"]
    #[cfg(not(target_os = "solana"))]
    fn test_exchange_rates_with_real_data() {
        // Initialize RPC client
        let rpc_url = "https://api.mainnet-beta.solana.com".to_string();
        let client = RpcClient::new(rpc_url);

        // Example vault address - replace with a real vault address for testing
        let vault_address =
            Pubkey::from_str("xzJdaENdahnFpPNZCTb36oRUtar3ZnRHMVB4ewDdRBe").unwrap();

        // Derive market address
        let program_id = Pubkey::from_str("ExponentnaRg3CQbW6dqQNZKXp7gtZ9DGMp1cwC4HAS7").unwrap();
        let (market_address, _) =
            Pubkey::find_program_address(&[b"market", vault_address.as_ref()], &program_id);

        // Get current Solana timestamp
        let slot = client.get_slot().expect("Failed to get slot");
        let current_time = client
            .get_block_time(slot)
            .expect("Failed to get block time") as u64;

        // Fetch accounts
        let mut vault_account = client
            .get_account(&vault_address)
            .expect("Failed to get vault account");
        let mut market_account = client
            .get_account(&market_address)
            .expect("Failed to get market account");

        // Create AccountInfo structs
        let mut vault_lamports = vault_account.lamports;
        let mut market_lamports = market_account.lamports;

        let vault_account_info = AccountInfo::new(
            &vault_address,
            false,
            false,
            &mut vault_lamports,
            &mut vault_account.data,
            &program_id,
            false,
            0,
        );

        let market_account_info = AccountInfo::new(
            &market_address,
            false,
            false,
            &mut market_lamports,
            &mut market_account.data,
            &program_id,
            false,
            0,
        );

        // Get exchange rates
        let result = get_exchange_rate(&market_account_info, &vault_account_info, current_time);

        println!("PT -> Asset rate: {}", result.pt_asset_exchange_rate);

        // Sanity check: the test should fail on a nonsensical rate instead of only printing it.
        assert!(
            result.pt_asset_exchange_rate.is_finite() && result.pt_asset_exchange_rate > 0.0,
            "PT -> Asset rate should be a positive finite number, got {}",
            result.pt_asset_exchange_rate
        );
    }
}