pub mod subscribe_config;
pub mod usage;
use std::{
collections::{BTreeMap, BTreeSet, HashSet},
fmt,
};
use base64::{
DecodeError as Base64DecodeError, Engine as _, engine::general_purpose::STANDARD as BASE64,
};
use serde::{Deserialize, Serialize};
use solana_address::Address;
pub mod ws_compression;
/// Client-to-server request of the backtest protocol.
///
/// Serialized adjacently tagged: `{"method": "<variant>", "params": <payload>}` with
/// camelCase variant names (e.g. `"createBacktestSession"`); unit variants carry no
/// `params`.
///
/// Note: the enum-level `rename_all` renames *variants* only. Fields of struct variants
/// (`session_id`, `last_sequence`, `control_session_id`, `last_sequences`) keep their
/// snake_case Rust names on the wire.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
pub enum BacktestRequest {
    /// Creates a session from either the legacy flat parameters or the V1 form.
    CreateBacktestSession(CreateBacktestSessionRequest),
    /// Continues without naming a session (see `ContinueSessionV1` for the addressed form).
    Continue(ContinueParams),
    /// Continues to the given slot (and optional batch index) without naming a session.
    ContinueTo(ContinueToParams),
    /// `Continue` addressed to a specific session id.
    ContinueSessionV1(ContinueSessionRequestV1),
    /// `ContinueTo` addressed to a specific session id.
    ContinueToSessionV1(ContinueToSessionRequestV1),
    /// Closes a session without naming it.
    CloseBacktestSession,
    /// Closes a specific session by id.
    CloseSessionV1(CloseSessionRequestV1),
    /// Attaches to an existing session by id.
    AttachBacktestSession {
        session_id: String,
        // NOTE(review): presumably the last sequence id the client already received,
        // so delivery can resume after it — confirm against the server.
        last_sequence: Option<u64>,
    },
    /// Resumes a previously attached session (semantics defined by the server — confirm).
    ResumeAttachedSession,
    /// Attaches to a parallel control session.
    AttachParallelControlSessionV2 {
        control_session_id: String,
        // Empty map when omitted. Presumably keyed by child session id — confirm.
        #[serde(default)]
        last_sequences: BTreeMap<String, u64>,
    },
}
/// Versioned payload of [`BacktestRequest::CreateBacktestSession`].
///
/// With `#[serde(untagged)]`, deserialization tries the variants in declaration order.
/// `V1` is tried first and requires a `parallel` field; if that fails, the legacy flat
/// `V0` shape is used.
///
/// This order matters: `CreateSessionParams` does not deny unknown fields, so `V0` would
/// also accept a `V1` payload and silently drop `parallel`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CreateBacktestSessionRequest {
    V1(CreateBacktestSessionRequestV1),
    V0(CreateSessionParams),
}
impl CreateBacktestSessionRequest {
    /// Collapses either wire version into one version-independent options value.
    ///
    /// Legacy `V0` requests have no parallelism flag and are reported as
    /// `parallel: false`.
    pub fn into_request_options(self) -> CreateBacktestSessionRequestOptions {
        let (request, parallel) = match self {
            Self::V1(versioned) => (versioned.request, versioned.parallel),
            Self::V0(legacy) => (legacy, false),
        };
        CreateBacktestSessionRequestOptions { request, parallel }
    }

    /// Same as [`Self::into_request_options`], but returns the two parts as a tuple.
    pub fn into_request_and_parallel(self) -> (CreateSessionParams, bool) {
        let CreateBacktestSessionRequestOptions { request, parallel } =
            self.into_request_options();
        (request, parallel)
    }
}
impl From<CreateSessionParams> for CreateBacktestSessionRequest {
fn from(value: CreateSessionParams) -> Self {
Self::V0(value)
}
}
impl From<CreateBacktestSessionRequestV1> for CreateBacktestSessionRequest {
fn from(value: CreateBacktestSessionRequestV1) -> Self {
Self::V1(value)
}
}
/// Wire form of a create request that carries an explicit `parallel` flag.
///
/// The session parameters are flattened, so on the wire this is the legacy `V0` object
/// plus a `parallel` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateBacktestSessionRequestV1 {
    #[serde(flatten)]
    pub request: CreateSessionParams,
    /// Required on the wire; its presence is what selects `V1` over `V0`.
    pub parallel: bool,
}
/// Version-independent view of a create request, produced by
/// [`CreateBacktestSessionRequest::into_request_options`].
#[derive(Debug, Clone)]
pub struct CreateBacktestSessionRequestOptions {
    pub request: CreateSessionParams,
    /// Always `false` for legacy `V0` requests.
    pub parallel: bool,
}
/// Session-addressed continue request (`{"sessionId": ..., "request": {...}}`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContinueSessionRequestV1 {
    pub session_id: String,
    pub request: ContinueParams,
}
/// Session-addressed continue-to request (`{"sessionId": ..., "request": {...}}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContinueToSessionRequestV1 {
    pub session_id: String,
    pub request: ContinueToParams,
}
/// Session-addressed close request (`{"sessionId": ...}`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CloseSessionRequestV1 {
    pub session_id: String,
}
/// Transaction filter evaluated with [`DiscoveryFilter::matches`].
///
/// Serialized adjacently tagged as `{"kind": "programExecuted", "value": "<address>"}`.
/// The address is written with its `Display` form and parsed back with `FromStr`.
#[serde_with::serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "camelCase")]
pub enum DiscoveryFilter {
    /// Matches when this program address is among the transaction's invoked programs.
    ProgramExecuted(#[serde_as(as = "serde_with::DisplayFromStr")] Address),
}
/// Per-transaction facts that a [`DiscoveryFilter`] is checked against.
pub struct TxMatchContext<'a> {
    /// Program addresses invoked by the transaction under test.
    pub invoked_programs: &'a HashSet<Address>,
}
impl DiscoveryFilter {
pub fn matches(&self, ctx: &TxMatchContext<'_>) -> bool {
match self {
Self::ProgramExecuted(target) => ctx.invoked_programs.contains(target),
}
}
}
/// Parameters for creating a backtest session (the legacy `V0` create payload).
///
/// Fields are camelCase on the wire. Only `startSlot` and `endSlot` are required; every
/// other field falls back to its default when omitted.
#[serde_with::serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateSessionParams {
    /// First slot of the requested range.
    pub start_slot: u64,
    /// Last slot of the requested range (NOTE(review): assumed inclusive — confirm).
    pub end_slot: u64,
    /// Signer addresses, serialized as their string form; empty when omitted.
    #[serde_as(as = "BTreeSet<serde_with::DisplayFromStr>")]
    #[serde(default)]
    pub signer_filter: BTreeSet<Address>,
    /// `false` when omitted; presumably requests a `SessionSummary` on completion.
    #[serde(default)]
    pub send_summary: bool,
    /// Optional timeout in seconds (per the field name); `None` when omitted.
    #[serde(default)]
    pub capacity_wait_timeout_secs: Option<u16>,
    /// Optional timeout in seconds (per the field name); `None` when omitted.
    #[serde(default)]
    pub disconnect_timeout_secs: Option<u16>,
    /// Optional compute-unit adjustment; `None` when omitted.
    #[serde(default)]
    pub extra_compute_units: Option<u32>,
    /// Agents to run in the session; empty when omitted.
    #[serde(default)]
    pub agents: Vec<AgentParams>,
    /// Discovery filters; empty when omitted and left out of the output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub discoveries: Vec<DiscoveryFilter>,
}
/// Selects which divergences count for fail-fast handling.
///
/// The consuming logic is not in this file. Serialized in kebab-case
/// (`"any-non-benign"`, `"tracked"`); defaults to `AnyNonBenign`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum FailFastDivergenceKind {
    #[default]
    AnyNonBenign,
    Tracked,
}
impl FailFastDivergenceKind {
pub fn as_str(self) -> &'static str {
match self {
Self::AnyNonBenign => "any-non-benign",
Self::Tracked => "tracked",
}
}
pub fn from_str_opt(value: &str) -> Option<Self> {
match value {
"any-non-benign" => Some(Self::AnyNonBenign),
"tracked" => Some(Self::Tracked),
_ => None,
}
}
}
/// Kind of agent to run inside a session; serialized camelCase (`"arb"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum AgentType {
    Arb,
}
/// One route an arbitrage agent may consider.
///
/// Mints and DEX identifiers are plain strings here. Their format is validated
/// elsewhere, if at all.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ArbRouteParams {
    pub base_mint: String,
    pub temp_mint: String,
    /// Empty when omitted.
    #[serde(default)]
    pub buy_dexes: Vec<String>,
    /// Empty when omitted.
    #[serde(default)]
    pub sell_dexes: Vec<String>,
    /// Required lower bound of the input amount.
    pub min_input: u64,
    /// Required upper bound of the input amount.
    pub max_input: u64,
    /// `0` when omitted.
    #[serde(default)]
    pub min_profit: u64,
}
/// Configuration of one agent in [`CreateSessionParams::agents`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentParams {
    pub agent_type: AgentType,
    // NOTE(review): presumably an address string and a serialized keypair respectively;
    // how they interact (either/both) is decided by the consumer — confirm.
    pub wallet: Option<String>,
    pub keypair: Option<String>,
    pub seed_sol_lamports: Option<u64>,
    /// Empty when omitted; presumably mint → amount — confirm.
    #[serde(default)]
    pub seed_token_accounts: BTreeMap<String, u64>,
    /// Empty when omitted.
    #[serde(default)]
    pub arb_routes: Vec<ArbRouteParams>,
}
/// Account state overrides keyed by account address.
///
/// Keys are serialized with the address's `Display` form and parsed with `FromStr`.
#[serde_with::serde_as]
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct AccountModifications(
    #[serde_as(as = "BTreeMap<serde_with::DisplayFromStr, _>")]
    #[serde(default)]
    pub BTreeMap<Address, AccountData>,
);
/// Parameters of a continue request.
///
/// Every field is optional on the wire (camelCase names).
#[serde_with::serde_as]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContinueParams {
    /// Defaults to [`ContinueParams::default_advance_count`] (1) when omitted.
    #[serde(default = "ContinueParams::default_advance_count")]
    pub advance_count: u64,
    /// Encoded transactions to submit; empty when omitted.
    /// NOTE(review): encoding presumably base64 — confirm with the decoder.
    #[serde(default)]
    pub transactions: Vec<String>,
    /// Account overrides to apply; empty when omitted.
    #[serde(default)]
    pub modify_account_states: AccountModifications,
}
impl Default for ContinueParams {
fn default() -> Self {
Self {
advance_count: Self::default_advance_count(),
transactions: Vec::new(),
modify_account_states: AccountModifications(BTreeMap::new()),
}
}
}
impl ContinueParams {
pub fn default_advance_count() -> u64 {
1
}
}
/// Payload of a `paused` event: the slot and, optionally, the batch within it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PausedEvent {
    pub slot: u64,
    /// Left out of the output when `None`; `None` when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub batch_index: Option<u32>,
}
/// Payload of a `discoveryBatch` event: the matched transactions of one batch.
#[serde_with::serde_as]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DiscoveryBatchEvent {
    pub slot: u64,
    pub batch_index: u32,
    /// Filters that matched.
    pub matched: Vec<DiscoveryFilter>,
    /// Raw transaction bytes, text-encoded.
    pub transactions: Vec<EncodedBinary>,
}
/// Target of a continue-to request: a slot and an optional batch index (`None` if omitted).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContinueToParams {
    pub slot: u64,
    #[serde(default)]
    pub batch_index: Option<u32>,
}
/// Text encoding used for binary payloads; serialized lowercase (`"base64"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum BinaryEncoding {
    /// Standard base64 alphabet with padding (the `base64` crate's `STANDARD` engine).
    Base64,
}
impl BinaryEncoding {
pub fn encode(self, bytes: &[u8]) -> String {
match self {
Self::Base64 => BASE64.encode(bytes),
}
}
pub fn decode(self, data: &str) -> Result<Vec<u8>, Base64DecodeError> {
match self {
Self::Base64 => BASE64.decode(data),
}
}
}
/// Binary payload carried as text, tagged with the encoding used to produce it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EncodedBinary {
    /// The encoded text (not the raw bytes).
    pub data: String,
    pub encoding: BinaryEncoding,
}
impl EncodedBinary {
pub fn new(data: String, encoding: BinaryEncoding) -> Self {
Self { data, encoding }
}
pub fn from_bytes(bytes: &[u8], encoding: BinaryEncoding) -> Self {
Self {
data: encoding.encode(bytes),
encoding,
}
}
pub fn decode(&self) -> Result<Vec<u8>, Base64DecodeError> {
self.encoding.decode(&self.data)
}
}
/// Account state supplied for one entry of [`AccountModifications`].
#[serde_with::serde_as]
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AccountData {
    /// Account data bytes, text-encoded.
    pub data: EncodedBinary,
    pub executable: bool,
    pub lamports: u64,
    /// Owner program address, serialized with its `Display` form.
    #[serde_as(as = "serde_with::DisplayFromStr")]
    pub owner: Address,
    // NOTE(review): presumably the account's allocated data size — confirm whether it
    // must equal the decoded `data` length.
    pub space: u64,
}
/// Server-to-client message of the backtest protocol.
///
/// Serialized adjacently tagged as `{"method": "<camelCaseVariant>", "params": <payload>}`.
/// `rename_all` applies to variant names only. Fields of struct variants (e.g.
/// `session_id`, `rpc_endpoint`, `seq_id`) keep their snake_case names on the wire.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
pub enum BacktestResponse {
    /// A session was created. `task_id` is left out of the output when `None`.
    SessionCreated {
        session_id: String,
        rpc_endpoint: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        task_id: Option<String>,
    },
    /// Attached to an existing session. `task_id` is left out of the output when `None`.
    SessionAttached {
        session_id: String,
        rpc_endpoint: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        task_id: Option<String>,
    },
    /// Ids of several created sessions (presumably the reply to a parallel create).
    SessionsCreated {
        session_ids: Vec<String>,
    },
    /// Parallel creation reply that includes a control session.
    ///
    /// `task_ids`, `start_slots`, and `end_slots` default to empty when absent.
    /// NOTE(review): presumably index-aligned with `session_ids` — confirm.
    SessionsCreatedV2 {
        control_session_id: String,
        session_ids: Vec<String>,
        #[serde(default)]
        task_ids: Vec<Option<String>>,
        #[serde(default)]
        start_slots: Vec<u64>,
        #[serde(default)]
        end_slots: Vec<u64>,
    },
    /// Reply to attaching to a parallel control session; `task_ids` defaults to empty.
    ParallelSessionAttachedV2 {
        control_session_id: String,
        session_ids: Vec<String>,
        #[serde(default)]
        task_ids: Vec<Option<String>>,
    },
    /// The session is ready to accept a continue request.
    ReadyForContinue,
    /// Progress notification carrying a slot number.
    SlotNotification(u64),
    /// Execution paused at a slot (and optional batch).
    Paused(PausedEvent),
    /// Transactions matched by discovery filters.
    DiscoveryBatch(DiscoveryBatchEvent),
    /// An error. See [`BacktestResponse::is_terminal`] for which errors count as terminal.
    Error(BacktestError),
    /// Payload-less success acknowledgement.
    Success,
    /// The session finished. Both fields are left out of the output when `None`;
    /// `agent_stats` also defaults to `None` when absent on input.
    Completed {
        #[serde(skip_serializing_if = "Option::is_none")]
        summary: Option<SessionSummary>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        agent_stats: Option<Vec<AgentStatsReport>>,
    },
    /// Progress status update.
    Status {
        status: BacktestStatus,
    },
    /// An event of a specific session, wrapped with that session's id.
    SessionEventV1 {
        session_id: String,
        event: SessionEventV1,
    },
    /// Like `SessionEventV1`, with an additional per-event sequence id.
    SessionEventV2 {
        session_id: String,
        seq_id: u64,
        event: SessionEventKind,
    },
}
impl BacktestResponse {
    /// Returns `true` only for a `Completed` response.
    pub fn is_completed(&self) -> bool {
        matches!(self, Self::Completed { .. })
    }

    /// Returns `true` when this response is terminal.
    ///
    /// Terminal responses are `Completed` and the errors `NoMoreBlocks`,
    /// `AdvanceSlotFailed`, `FinalizeSlotFailed`, and `Internal`. Every other response,
    /// including all other errors, is non-terminal.
    pub fn is_terminal(&self) -> bool {
        matches!(
            self,
            Self::Completed { .. }
                | Self::Error(
                    BacktestError::NoMoreBlocks
                        | BacktestError::AdvanceSlotFailed { .. }
                        | BacktestError::FinalizeSlotFailed { .. }
                        | BacktestError::Internal { .. }
                )
        )
    }
}
impl From<BacktestStatus> for BacktestResponse {
fn from(status: BacktestStatus) -> Self {
Self::Status { status }
}
}
impl From<String> for BacktestResponse {
fn from(message: String) -> Self {
BacktestError::Internal { error: message }.into()
}
}
impl From<&str> for BacktestResponse {
fn from(message: &str) -> Self {
BacktestError::Internal {
error: message.to_string(),
}
.into()
}
}
/// Per-session event carried by [`BacktestResponse::SessionEventV1`].
///
/// Uses the same adjacently tagged `method`/`params` shape as [`BacktestResponse`], with
/// the same payloads as the corresponding response variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
pub enum SessionEventV1 {
    ReadyForContinue,
    SlotNotification(u64),
    Paused(PausedEvent),
    DiscoveryBatch(DiscoveryBatchEvent),
    Error(BacktestError),
    Success,
    /// Both fields are left out of the output when `None`.
    Completed {
        #[serde(skip_serializing_if = "Option::is_none")]
        summary: Option<SessionSummary>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        agent_stats: Option<Vec<AgentStatsReport>>,
    },
    Status {
        status: BacktestStatus,
    },
}
/// Per-session event carried by [`BacktestResponse::SessionEventV2`].
///
/// Same shape as [`SessionEventV1`], except that `Completed` has no `agent_stats` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
pub enum SessionEventKind {
    ReadyForContinue,
    SlotNotification(u64),
    Paused(PausedEvent),
    DiscoveryBatch(DiscoveryBatchEvent),
    Error(BacktestError),
    Success,
    /// `summary` is left out of the output when `None`.
    Completed {
        #[serde(skip_serializing_if = "Option::is_none")]
        summary: Option<SessionSummary>,
    },
    Status {
        status: BacktestStatus,
    },
}
impl SessionEventKind {
pub fn is_terminal(&self) -> bool {
match self {
Self::Completed { .. } => true,
Self::Error(e) => matches!(
e,
BacktestError::NoMoreBlocks
| BacktestError::AdvanceSlotFailed { .. }
| BacktestError::FinalizeSlotFailed { .. }
| BacktestError::Internal { .. }
),
_ => false,
}
}
}
/// A [`BacktestResponse`] tagged with a sequence id.
///
/// The response is flattened, so its `method`/`params` appear next to `seqId`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SequencedResponse {
    pub seq_id: u64,
    #[serde(flatten)]
    pub response: BacktestResponse,
}
/// Progress stage of a session.
///
/// Serialized as camelCase names; `Display` yields a lowercase human-readable phrase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum BacktestStatus {
    StartingRuntime,
    DecodedTransactions,
    AppliedAccountModifications,
    ReadyToExecuteUserTransactions,
    ExecutedUserTransactions,
    ExecutingBlockTransactions,
    ExecutedBlockTransactions,
    ProgramAccountsLoaded,
}
impl std::fmt::Display for BacktestStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
Self::StartingRuntime => "starting runtime",
Self::DecodedTransactions => "decoded transactions",
Self::AppliedAccountModifications => "applied account modifications",
Self::ReadyToExecuteUserTransactions => "ready to execute user transactions",
Self::ExecutedUserTransactions => "executed user transactions",
Self::ExecutingBlockTransactions => "executing block transactions",
Self::ExecutedBlockTransactions => "executed block transactions",
Self::ProgramAccountsLoaded => "program accounts loaded",
};
f.write_str(s)
}
}
/// Per-agent counters reported in `Completed { agent_stats }`.
///
/// Serialized camelCase. The last four counters default to `0` when absent on input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentStatsReport {
    pub name: String,
    pub slots_processed: u64,
    pub opportunities_found: u64,
    pub opportunities_skipped: u64,
    pub no_routes: u64,
    pub txs_produced: u64,
    /// Signed expected gain, presumably keyed by mint address string — confirm.
    pub expected_gain_by_mint: BTreeMap<String, i64>,
    #[serde(default)]
    pub txs_submitted: u64,
    #[serde(default)]
    pub txs_failed: u64,
    #[serde(default)]
    pub txs_simulation_rejected: u64,
    #[serde(default)]
    pub txs_simulation_failed: u64,
}
/// Transaction outcome counters for a finished session.
///
/// The five counters are summed by `total_transactions`. `has_deviations` ignores
/// `log_diff`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionSummary {
    pub correct_simulation: usize,
    pub incorrect_simulation: usize,
    pub execution_errors: usize,
    pub balance_diff: usize,
    pub log_diff: usize,
}
impl SessionSummary {
    /// Returns `true` if any incorrect simulation, execution error, or balance diff
    /// was recorded. Log diffs alone are not counted as deviations.
    pub fn has_deviations(&self) -> bool {
        [self.incorrect_simulation, self.execution_errors, self.balance_diff]
            .iter()
            .any(|&count| count > 0)
    }

    /// Sum of all five outcome counters.
    pub fn total_transactions(&self) -> usize {
        [
            self.correct_simulation,
            self.incorrect_simulation,
            self.execution_errors,
            self.balance_diff,
            self.log_diff,
        ]
        .into_iter()
        .sum()
    }
}
impl std::fmt::Display for SessionSummary {
    /// Writes a multi-line report: a header with the total, then one indented line per
    /// counter (no trailing newline).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Session summary: {} transactions", self.total_transactions())?;
        writeln!(f, "  - {} correct simulation", self.correct_simulation)?;
        writeln!(f, "  - {} incorrect simulation", self.incorrect_simulation)?;
        writeln!(f, "  - {} execution errors", self.execution_errors)?;
        writeln!(f, "  - {} balance diffs", self.balance_diff)?;
        write!(f, "  - {} log diffs", self.log_diff)
    }
}
/// Errors reported by the backtest service, carried in [`BacktestResponse::Error`].
///
/// Serialized externally tagged with camelCase variant names:
/// - unit variants become a bare string (`"noMoreBlocks"`);
/// - struct variants become `{"advanceSlotFailed": {"slot": ..., "error": ...}}`.
///
/// Field names are not renamed. `NoMoreBlocks`, `AdvanceSlotFailed`,
/// `FinalizeSlotFailed`, and `Internal` are the variants classified as terminal by
/// `BacktestResponse::is_terminal` and `SessionEventKind::is_terminal`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum BacktestError {
    /// A submitted transaction (by position) could not be decoded.
    InvalidTransactionEncoding {
        index: usize,
        error: String,
    },
    /// A submitted transaction (by position) decoded but had an invalid format.
    InvalidTransactionFormat {
        index: usize,
        error: String,
    },
    /// Account data could not be decoded with the stated encoding.
    InvalidAccountEncoding {
        address: String,
        encoding: BinaryEncoding,
        error: String,
    },
    InvalidAccountOwner {
        address: String,
        error: String,
    },
    InvalidAccountPubkey {
        address: String,
        error: String,
    },
    /// Terminal.
    NoMoreBlocks,
    /// Terminal.
    AdvanceSlotFailed {
        slot: u64,
        error: String,
    },
    /// Terminal.
    FinalizeSlotFailed {
        slot: u64,
        error: String,
    },
    InvalidRequest {
        error: String,
    },
    /// Terminal. Also produced by the `From<String>` / `From<&str>` conversions into
    /// `BacktestResponse`.
    Internal {
        error: String,
    },
    InvalidBlockhashFormat {
        slot: u64,
        error: String,
    },
    InitializingSysvarsFailed {
        slot: u64,
        error: String,
    },
    /// Error from the "clerk" component (not defined in this file).
    ClerkError {
        error: String,
    },
    SimulationError {
        error: String,
    },
    SessionNotFound {
        session_id: String,
    },
    SessionOwnerMismatch,
    SessionOwnershipBusy {
        reason: String,
    },
}
/// One advertised bundle range, as consumed by [`split_range`].
///
/// No `rename_all` here, so the fields are snake_case on the wire. `split_range` only uses
/// entries whose `max_bundle_end_slot` is `Some(end)` with `end > bundle_start_slot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AvailableRange {
    pub bundle_start_slot: u64,
    /// Presumably a UTC timestamp string for `bundle_start_slot` — format not checked here.
    pub bundle_start_slot_utc: Option<String>,
    pub max_bundle_end_slot: Option<u64>,
    pub max_bundle_end_slot_utc: Option<String>,
    /// Not read by `split_range`.
    pub max_bundle_size: Option<u64>,
}
/// Splits `[requested_start, requested_end]` (both inclusive) into consecutive,
/// gap-free intervals aligned to the advertised bundle ranges.
///
/// Only entries whose `max_bundle_end_slot` is `Some(end)` with
/// `end > bundle_start_slot` are used. Ends that share a start are grouped together.
///
/// The split is anchored at the latest bundle start `<= requested_start` whose widest
/// bundle reaches `requested_start`, so the first interval may begin before
/// `requested_start`. Within the split:
/// - each interval `(s, e)` starts at a bundle start `s` and ends at one of the ends
///   advertised for `s`;
/// - the last interval is the exception: it is clamped to `requested_end`;
/// - each interval begins at the previous interval's end + 1.
///
/// Among all such splits, the one with the most intervals (the finest grid) is returned.
/// On ties, the candidate with the smallest end at each start wins.
///
/// NOTE(review): only that single anchor is tried. If it cannot reach `requested_end`
/// but an earlier, wider bundle could, an error is still returned.
///
/// # Errors
///
/// Returns a descriptive message when:
/// - `requested_end < requested_start`;
/// - no usable range covers `requested_start`;
/// - coverage from the anchor has a gap before `requested_end`;
/// - coverage is contiguous but no chain of bundle boundaries lines up with it.
pub fn split_range(
    ranges: &[AvailableRange],
    requested_start: u64,
    requested_end: u64,
) -> Result<Vec<(u64, u64)>, String> {
    if requested_end < requested_start {
        return Err(format!(
            "invalid range: start_slot {requested_start} > end_slot {requested_end}"
        ));
    }
    // Group usable ends by bundle start. Ordered collections let us scan starts in order
    // and read a start's widest end with `next_back()`.
    let mut ends_by_start: BTreeMap<u64, BTreeSet<u64>> = BTreeMap::new();
    for r in ranges {
        if let Some(end) = r.max_bundle_end_slot
            && end > r.bundle_start_slot
        {
            ends_by_start
                .entry(r.bundle_start_slot)
                .or_default()
                .insert(end);
        }
    }
    // Anchor: the latest start at or before `requested_start` whose widest bundle still
    // reaches `requested_start`.
    let Some((&anchor_start, _)) = ends_by_start.range(..=requested_start).rfind(|(_, ends)| {
        ends.iter()
            .next_back()
            .is_some_and(|&end| end >= requested_start)
    }) else {
        return Err(format!(
            "start_slot {requested_start} is not covered by any available bundle range"
        ));
    };
    // best_from[s] holds the longest split of [s, requested_end] whose first interval
    // starts exactly at bundle start `s`. Starts are visited right-to-left, so each
    // `end + 1` lookup below refers to a larger, already-computed start.
    let mut best_from: BTreeMap<u64, Vec<(u64, u64)>> = BTreeMap::new();
    for (&start, ends) in ends_by_start.range(anchor_start..=requested_end).rev() {
        let mut best: Option<Vec<(u64, u64)>> = None;
        for &end in ends {
            let candidate = if end >= requested_end {
                // This bundle alone reaches the target, so clamp it and stop the chain here.
                Some(vec![(start, requested_end)])
            } else {
                // The next interval must start at `end + 1`. Here `end < requested_end`,
                // so the addition cannot overflow.
                best_from.get(&(end + 1)).map(|rest| {
                    std::iter::once((start, end))
                        .chain(rest.iter().copied())
                        .collect()
                })
            };
            // Ends are visited in ascending order and only strictly longer candidates
            // replace the current best, so ties keep the smallest end.
            if let Some(candidate) = candidate
                && best.as_ref().is_none_or(|b| candidate.len() > b.len())
            {
                best = Some(candidate);
            }
        }
        // A start with no chain that reaches `requested_end` gets no entry.
        if let Some(best) = best {
            best_from.insert(start, best);
        }
    }
    best_from.remove(&anchor_start).ok_or_else(|| {
        // Failure path: walk contiguous coverage from the anchor, using each start's
        // widest end, to tell a real gap apart from a boundary misalignment.
        let mut covered_to = anchor_start.saturating_sub(1);
        for (&start, ends) in ends_by_start.range(anchor_start..=requested_end) {
            if start > covered_to.saturating_add(1) {
                break;
            }
            if let Some(&end) = ends.iter().next_back() {
                covered_to = covered_to.max(end);
            }
        }
        if covered_to < requested_end {
            format!("gap in coverage at slot {}", covered_to + 1)
        } else {
            format!(
                "no gap-free split of [{requested_start}, {requested_end}] aligns with the available bundle ranges"
            )
        }
    })
}
impl From<BacktestError> for BacktestResponse {
fn from(error: BacktestError) -> Self {
Self::Error(error)
}
}
impl std::error::Error for BacktestError {}
impl fmt::Display for BacktestError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
BacktestError::InvalidTransactionEncoding { index, error } => {
write!(f, "invalid transaction encoding at index {index}: {error}")
}
BacktestError::InvalidTransactionFormat { index, error } => {
write!(f, "invalid transaction format at index {index}: {error}")
}
BacktestError::InvalidAccountEncoding {
address,
encoding,
error,
} => write!(
f,
"invalid encoding for account {address} ({encoding:?}): {error}"
),
BacktestError::InvalidAccountOwner { address, error } => {
write!(f, "invalid owner for account {address}: {error}")
}
BacktestError::InvalidAccountPubkey { address, error } => {
write!(f, "invalid account pubkey {address}: {error}")
}
BacktestError::NoMoreBlocks => write!(f, "no more blocks available"),
BacktestError::AdvanceSlotFailed { slot, error } => {
write!(f, "failed to advance to slot {slot}: {error}")
}
BacktestError::FinalizeSlotFailed { slot, error } => {
write!(f, "failed to finalize slot {slot}: {error}")
}
BacktestError::InvalidRequest { error } => write!(f, "invalid request: {error}"),
BacktestError::Internal { error } => write!(f, "internal error: {error}"),
BacktestError::InvalidBlockhashFormat { slot, error } => {
write!(f, "invalid blockhash at slot {slot}: {error}")
}
BacktestError::InitializingSysvarsFailed { slot, error } => {
write!(f, "failed to initialize sysvars at slot {slot}: {error}")
}
BacktestError::ClerkError { error } => write!(f, "clerk error: {error}"),
BacktestError::SimulationError { error } => {
write!(f, "simulation error: {error}")
}
BacktestError::SessionNotFound { session_id } => {
write!(f, "session not found: {session_id}")
}
BacktestError::SessionOwnerMismatch => {
write!(f, "session owner mismatch")
}
BacktestError::SessionOwnershipBusy { reason } => {
write!(f, "session ownership busy: {reason}")
}
}
}
}
// Unit tests:
// - string round-trips for `FailFastDivergenceKind`;
// - table-driven cases for `split_range`;
// - a randomized differential check of `split_range` against a brute-force reference.
#[cfg(test)]
mod tests {
    use super::*;
    // `as_str` and `from_str_opt` must be inverses. Unknown strings are rejected, and
    // the default kind is `AnyNonBenign`.
    #[test]
    fn fail_fast_divergence_kind_str_round_trips() {
        for kind in [
            FailFastDivergenceKind::AnyNonBenign,
            FailFastDivergenceKind::Tracked,
        ] {
            assert_eq!(
                FailFastDivergenceKind::from_str_opt(kind.as_str()),
                Some(kind)
            );
        }
        assert_eq!(FailFastDivergenceKind::from_str_opt("nonsense"), None);
        assert_eq!(
            FailFastDivergenceKind::default(),
            FailFastDivergenceKind::AnyNonBenign
        );
    }
    // Builds an `AvailableRange` with only the slot bounds set.
    fn range(start: u64, end: u64) -> AvailableRange {
        AvailableRange {
            bundle_start_slot: start,
            bundle_start_slot_utc: None,
            max_bundle_end_slot: Some(end),
            max_bundle_end_slot_utc: None,
            max_bundle_size: None,
        }
    }
    // Each case gives the available ranges, the requested [start, end], and the expected
    // split. `None` means `split_range` must return an error.
    #[rstest::rstest]
    #[case::single(vec![range(100, 300)], 100, 300, Some(vec![(100, 300)]))]
    #[case::multi(
        vec![range(100, 200), range(201, 300), range(301, 400)],
        100, 300, Some(vec![(100, 200), (201, 300)])
    )]
    #[case::nested(
        vec![range(100, 500), range(110, 150), range(150, 190), range(501, 900)],
        100, 900, Some(vec![(100, 500), (501, 900)])
    )]
    #[case::prefers_finer_grid(
        vec![range(1_000, 1_999), range(1_500, 3_400), range(2_000, 2_999), range(3_000, 3_999)],
        1_000, 3_999, Some(vec![(1_000, 1_999), (2_000, 2_999), (3_000, 3_999)])
    )]
    #[case::shared_start_prefers_finer(
        vec![range(100, 150), range(100, 120), range(121, 140), range(141, 160)],
        100, 160, Some(vec![(100, 120), (121, 140), (141, 160)])
    )]
    #[case::falls_back_to_coarse(
        vec![range(100, 160), range(100, 120), range(121, 140)],
        100, 160, Some(vec![(100, 160)])
    )]
    #[case::clamps_final_bundle(vec![range(100, 199), range(200, 999)], 100, 450, Some(vec![(100, 199), (200, 450)]))]
    #[case::anchors_mid_bundle(vec![range(150, 350)], 200, 300, Some(vec![(150, 300)]))]
    #[case::anchors_then_continues(
        vec![range(150, 350), range(351, 600)],
        200, 600, Some(vec![(150, 350), (351, 600)])
    )]
    #[case::start_inside_bundle_anchors(vec![range(200, 400)], 300, 400, Some(vec![(200, 400)]))]
    #[case::start_before_first_bundle(vec![range(200, 400)], 100, 400, None)]
    #[case::end_not_covered(vec![range(100, 200)], 100, 300, None)]
    #[case::gap_in_coverage(vec![range(100, 200), range(210, 300)], 100, 300, None)]
    #[case::inverted_range(vec![range(100, 300)], 300, 100, None)]
    fn split_range_cases(
        #[case] ranges: Vec<AvailableRange>,
        #[case] start: u64,
        #[case] end: u64,
        #[case] expected: Option<Vec<(u64, u64)>>,
    ) {
        match expected {
            Some(expected) => assert_eq!(split_range(&ranges, start, end).unwrap(), expected),
            None => assert!(split_range(&ranges, start, end).is_err()),
        }
    }
    // Same bucketing as `split_range`: ends are grouped by start, and ranges whose end is
    // missing or not after their start are dropped.
    fn ends_by_start(ranges: &[AvailableRange]) -> BTreeMap<u64, BTreeSet<u64>> {
        let mut ends: BTreeMap<u64, BTreeSet<u64>> = BTreeMap::new();
        for r in ranges {
            if let Some(end) = r.max_bundle_end_slot
                && end > r.bundle_start_slot
            {
                ends.entry(r.bundle_start_slot).or_default().insert(end);
            }
        }
        ends
    }
    // Exhaustive recursive search for the longest split of [cursor, end] whose first
    // interval starts exactly at bundle start `cursor`. Returns `None` if none exists.
    // `max_by_key` keeps the *last* maximal candidate, which may differ from
    // `split_range`'s tie-break, so the randomized test compares lengths only.
    fn reference_max_split(
        ends: &BTreeMap<u64, BTreeSet<u64>>,
        cursor: u64,
        end: u64,
    ) -> Option<Vec<(u64, u64)>> {
        ends.get(&cursor)?
            .iter()
            .filter_map(|&bundle_end| {
                if bundle_end >= end {
                    Some(vec![(cursor, end)])
                } else {
                    reference_max_split(ends, bundle_end + 1, end).map(|mut rest| {
                        rest.insert(0, (cursor, bundle_end));
                        rest
                    })
                }
            })
            .max_by_key(Vec::len)
    }
    // A split is valid when all of these hold:
    // - it starts at `start` and ends at `end`;
    // - consecutive intervals are contiguous;
    // - every interval lies within the widest bundle advertised for its start.
    fn is_valid_split(
        split: &[(u64, u64)],
        ends: &BTreeMap<u64, BTreeSet<u64>>,
        start: u64,
        end: u64,
    ) -> bool {
        split.first().is_some_and(|&(s, _)| s == start)
            && split.last().is_some_and(|&(_, e)| e == end)
            && split.windows(2).all(|w| w[1].0 == w[0].1 + 1)
            && split.iter().all(|&(s, e)| {
                e >= s
                    && ends
                        .get(&s)
                        .and_then(|bundle_ends| bundle_ends.iter().next_back())
                        .is_some_and(|&max_end| e <= max_end)
            })
    }
    // Randomized differential test over 50,000 small layouts. A fixed-seed 64-bit LCG
    // makes any failure reproducible; `>> 33` keeps only the upper bits of the state.
    #[test]
    fn split_range_matches_reference() {
        let mut seed: u64 = 0x9E3779B97F4A7C15;
        let mut next = || {
            seed = seed
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            seed >> 33
        };
        for _ in 0..50_000 {
            // Up to 5 ranges with starts < 12 and lengths < 6. Zero-length ranges are
            // possible and get filtered out by the bucketing.
            let ranges: Vec<AvailableRange> = (0..next() % 6)
                .map(|_| {
                    let start = next() % 12;
                    range(start, start + next() % 6)
                })
                .collect();
            let start = next() % 12;
            let end = start + next() % 6;
            let got = split_range(&ranges, start, end);
            let ends = ends_by_start(&ranges);
            // Anchor exactly as `split_range` does: the latest start <= `start` whose
            // widest bundle reaches `start`.
            let anchor = ends
                .range(..=start)
                .rfind(|(_, e)| e.iter().next_back().is_some_and(|&x| x >= start))
                .map(|(&s, _)| s);
            let reference = anchor.and_then(|a| reference_max_split(&ends, a, end));
            let layout: Vec<_> = ranges
                .iter()
                .map(|r| (r.bundle_start_slot, r.max_bundle_end_slot))
                .collect();
            // Both must agree on success vs. failure. On success, the split must be valid
            // and as long as the reference optimum.
            match (&got, &reference) {
                (Ok(split), Some(best)) => {
                    assert!(
                        is_valid_split(split, &ends, anchor.unwrap(), end),
                        "invalid split {split:?} for {layout:?} [{start},{end}]"
                    );
                    assert_eq!(
                        split.len(),
                        best.len(),
                        "suboptimal split {split:?} vs {best:?} for {layout:?} [{start},{end}]"
                    );
                }
                (Err(_), None) => {}
                _ => panic!(
                    "disagreement: split_range={got:?}, reference={reference:?} for {layout:?} [{start},{end}]"
                ),
            }
        }
    }
}