use std::collections::{BTreeMap, BTreeSet};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use bon::Builder;
use simulator_api::{
AccountData, AccountModifications, AgentParams, BinaryEncoding, ContinueParams,
CreateBacktestSessionRequest, CreateBacktestSessionRequestV1, CreateSessionParams,
DiscoveryFilter,
};
use solana_address::Address;
use solana_client::rpc_client::SerializableTransaction;
use thiserror::Error;
use crate::BacktestClientError;
#[derive(Debug, Error)]
pub enum SerializeEncodeError {
#[error("bincode serialization error: {0}")]
Bincode(#[from] bincode::error::EncodeError),
}
fn serialize_to_base64(value: &impl serde::Serialize) -> Result<String, SerializeEncodeError> {
let bytes = bincode::serde::encode_to_vec(
value,
bincode::config::standard()
.with_fixed_int_encoding()
.with_little_endian(),
)?;
Ok(BASE64.encode(&bytes))
}
/// Client-side description of a backtest session to create, assembled with the
/// `bon`-generated builder and converted via `into_params` / `into_request`.
///
/// Exactly one of `end_slot` or `slot_count` must be set. The builder does not
/// enforce this; the check happens when the value is converted to
/// `CreateSessionParams`.
#[derive(Debug, Clone, Builder)]
pub struct CreateSession {
    /// First slot of the session range.
    pub start_slot: u64,
    /// Explicit end slot; must be `>= start_slot`. Mutually exclusive with `slot_count`.
    pub end_slot: Option<u64>,
    /// Alternative to `end_slot`: the end slot is computed as `start_slot + slot_count`.
    /// NOTE(review): if `end_slot` is inclusive server-side, this spans
    /// `slot_count + 1` slots — confirm the intended semantics.
    pub slot_count: Option<u64>,
    /// Set of signer addresses forwarded to the session (empty by default).
    /// NOTE(review): how an empty filter is interpreted is server-defined — confirm.
    #[builder(default)]
    pub signer_filter: BTreeSet<Address>,
    /// Forwarded unchanged as `send_summary`; defaults to `false`.
    #[builder(default)]
    pub send_summary: bool,
    /// When `true`, the request is wrapped in `CreateBacktestSessionRequestV1`
    /// with `parallel` set; otherwise the plain params are converted directly.
    /// Defaults to `false`. Not part of `CreateSessionParams` itself.
    #[builder(default)]
    pub parallel: bool,
    /// Forwarded unchanged; presumably a wait timeout in seconds, per the name (confirm).
    pub capacity_wait_timeout_secs: Option<u16>,
    /// Forwarded unchanged; presumably a disconnect timeout in seconds, per the name (confirm).
    pub disconnect_timeout_secs: Option<u16>,
    /// Forwarded unchanged as `extra_compute_units`.
    pub extra_compute_units: Option<u32>,
    /// Agent parameters forwarded unchanged (empty by default).
    #[builder(default)]
    pub agents: Vec<AgentParams>,
    /// Discovery filters forwarded unchanged (empty by default).
    #[builder(default)]
    pub discoveries: Vec<DiscoveryFilter>,
}
impl CreateSession {
    /// Inserts `address` into the signer filter and hands the session back so
    /// calls can be chained. Adding an address twice has no extra effect,
    /// because the filter is a set.
    pub fn add_signer_filter(self, address: Address) -> Self {
        let mut session = self;
        session.signer_filter.insert(address);
        session
    }

    /// Validates the slot range and converts this description into
    /// `CreateSessionParams`. The `parallel` flag is not carried over.
    ///
    /// # Errors
    ///
    /// Returns `BacktestClientError::InvalidParams` when:
    /// - both `end_slot` and `slot_count` are set;
    /// - neither `end_slot` nor `slot_count` is set;
    /// - `start_slot + slot_count` overflows `u64`;
    /// - the resulting end slot is less than `start_slot`.
    pub fn into_params(self) -> Result<CreateSessionParams, BacktestClientError> {
        let invalid = |message: String| BacktestClientError::InvalidParams { message };
        let start_slot = self.start_slot;

        // The four combinations are disjoint, so arm order does not matter.
        let end_slot = match (self.end_slot, self.slot_count) {
            (Some(explicit_end), None) => explicit_end,
            (None, Some(count)) => start_slot.checked_add(count).ok_or_else(|| {
                invalid("CreateSession: start_slot + slot_count overflow".to_string())
            })?,
            (None, None) => {
                return Err(invalid(
                    "CreateSession: must set end_slot or slot_count".to_string(),
                ));
            }
            (Some(_), Some(_)) => {
                return Err(invalid(
                    "CreateSession: set only one of end_slot or slot_count".to_string(),
                ));
            }
        };

        if start_slot > end_slot {
            return Err(invalid(format!(
                "CreateSession: end_slot ({end_slot}) must be >= start_slot ({start_slot})"
            )));
        }

        let Self {
            signer_filter,
            send_summary,
            capacity_wait_timeout_secs,
            disconnect_timeout_secs,
            extra_compute_units,
            agents,
            discoveries,
            ..
        } = self;

        Ok(CreateSessionParams {
            start_slot,
            end_slot,
            signer_filter,
            send_summary,
            capacity_wait_timeout_secs,
            disconnect_timeout_secs,
            extra_compute_units,
            agents,
            discoveries,
        })
    }

    /// Builds the wire request. When `parallel` is set, the params are wrapped in
    /// `CreateBacktestSessionRequestV1`; otherwise they are converted directly.
    ///
    /// # Errors
    ///
    /// Propagates any validation error from [`CreateSession::into_params`].
    pub fn into_request(self) -> Result<CreateBacktestSessionRequest, BacktestClientError> {
        let parallel = self.parallel;
        let params = self.into_params()?;
        let request: CreateBacktestSessionRequest = if parallel {
            CreateBacktestSessionRequestV1 {
                request: params,
                parallel,
            }
            .into()
        } else {
            params.into()
        };
        Ok(request)
    }
}
/// Client-side description of a "continue" step for a backtest session,
/// assembled with the `bon`-generated builder and the `push_*` /
/// `modify_account` helpers, then converted with `into_params`.
#[derive(Debug, Builder)]
pub struct Continue {
    /// Forwarded as `advance_count`; defaults to `ContinueParams::default_advance_count()`.
    #[builder(default = ContinueParams::default_advance_count())]
    pub advance_count: u64,
    /// Encoded transactions, in the order they were pushed (empty by default).
    #[builder(default)]
    pub transactions: Vec<String>,
    /// Per-address account overrides, forwarded as `modify_account_states`
    /// (empty by default). Later entries for the same address replace earlier ones.
    #[builder(default)]
    pub modify_accounts: BTreeMap<Address, AccountData>,
}
impl Continue {
    /// Appends an already-encoded transaction string as-is.
    ///
    /// The string is neither validated nor re-encoded here.
    /// NOTE(review): callers presumably must pass base64 to match the other
    /// `push_*` helpers — confirm against the server contract.
    pub fn push_transaction_base64(mut self, data: impl Into<String>) -> Self {
        self.transactions.push(data.into());
        self
    }

    /// Appends raw transaction bytes after encoding them with `BinaryEncoding::Base64`.
    pub fn push_transaction_bytes(mut self, bytes: &[u8]) -> Self {
        self.transactions.push(BinaryEncoding::Base64.encode(bytes));
        self
    }

    /// Serializes `transaction` with bincode, base64-encodes the bytes, and appends them.
    ///
    /// # Errors
    ///
    /// Returns [`SerializeEncodeError::Bincode`] if bincode serialization fails.
    /// In that case `self` is consumed and dropped, so any previously queued
    /// transactions and account modifications are lost.
    pub fn push_transaction(
        mut self,
        transaction: &impl SerializableTransaction,
    ) -> Result<Self, SerializeEncodeError> {
        // `transaction` is already a reference, and `SerializableTransaction`
        // requires `Serialize`, so pass it through without borrowing it again.
        // serde's `&T` impl forwards to `T`, so the encoded bytes are unchanged.
        self.transactions.push(serialize_to_base64(transaction)?);
        Ok(self)
    }

    /// Records an account override for `address`, replacing any previous
    /// override for the same address.
    pub fn modify_account(mut self, address: Address, account: AccountData) -> Self {
        self.modify_accounts.insert(address, account);
        self
    }

    /// Converts this description into `ContinueParams`, wrapping the account
    /// overrides in `AccountModifications`.
    pub fn into_params(self) -> ContinueParams {
        ContinueParams {
            advance_count: self.advance_count,
            transactions: self.transactions,
            modify_account_states: AccountModifications(self.modify_accounts),
        }
    }
}