1use crate::{
2 bundle::{BundleError, JitoError},
3 config::Config,
4 constant::ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION,
5 fee::fee::{FeeConfigUtil, TransactionFeeUtil},
6 signer::bundle_signer::BundleSigner,
7 token::token::TokenUtil,
8 transaction::{TransactionUtil, VersionedTransactionResolved},
9 usage_limit::UsageTracker,
10 validator::transaction_validator::TransactionValidator,
11 KoraError,
12};
13use solana_client::nonblocking::rpc_client::RpcClient;
14use solana_commitment_config::CommitmentConfig;
15use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
16use std::{collections::HashMap, sync::Arc};
17
/// Outcome of [`BundleProcessor::process_bundle`]: the resolved bundle together with the
/// fee and payment totals that [`BundleProcessor::sign_all`] checks before signing.
pub struct BundleProcessor {
    /// Resolved transactions, in the same order as the encoded transactions they came from.
    pub resolved_transactions: Vec<VersionedTransactionResolved>,
    /// Sum of the per-transaction Kora fee estimates. When several transactions were each
    /// charged a payment-instruction estimate, the extra estimates are removed so the bundle
    /// is charged for a single payment instruction.
    pub total_required_lamports: u64,
    /// Sum of the payments found in the bundle's transactions by
    /// `TokenUtil::find_payment_in_transaction`.
    pub total_payment_lamports: u64,
    /// Sum of the estimated Solana fees of every transaction in the bundle.
    pub total_solana_estimated_fee: u64,
}
24
/// Selects whether [`BundleProcessor::process_bundle`] enforces usage limits.
///
/// The enum only carries an optional borrowed user id, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BundleProcessingMode<'a> {
    /// Run the usage-limit check for every transaction, passing along the optional user id.
    CheckUsage(Option<&'a str>),
    /// Do not run the usage-limit check.
    SkipUsage,
}
29
30impl BundleProcessor {
31 pub fn extract_transactions_to_process(
35 transactions: &[String],
36 sign_only_indices: Option<Vec<usize>>,
37 ) -> Result<(Vec<String>, HashMap<usize, usize>), KoraError> {
38 let indices = sign_only_indices.unwrap_or_else(|| (0..transactions.len()).collect());
39
40 let mut index_to_position: HashMap<usize, usize> = HashMap::with_capacity(indices.len());
42 let mut filtered: Vec<String> = Vec::with_capacity(indices.len());
43
44 for idx in indices {
45 if index_to_position.contains_key(&idx) {
46 continue; }
48 let tx = transactions.get(idx).ok_or_else(|| {
49 KoraError::ValidationError(format!(
50 "sign_only_indices index {} out of bounds (bundle has {} transactions)",
51 idx,
52 transactions.len()
53 ))
54 })?;
55 index_to_position.insert(idx, filtered.len());
56 filtered.push(tx.clone());
57 }
58
59 Ok((filtered, index_to_position))
60 }
61
62 pub fn merge_signed_transactions(
65 original_transactions: &[String],
66 signed_transactions: Vec<String>,
67 index_to_position: &std::collections::HashMap<usize, usize>,
68 ) -> Vec<String> {
69 (0..original_transactions.len())
70 .map(|idx| {
71 if let Some(&position) = index_to_position.get(&idx) {
72 signed_transactions[position].clone()
73 } else {
74 original_transactions[idx].clone()
75 }
76 })
77 .collect()
78 }
79
80 #[allow(clippy::too_many_arguments)]
81 pub async fn process_bundle<'a>(
82 encoded_txs: &[String],
83 fee_payer: Pubkey,
84 payment_destination: &Pubkey,
85 config: &Config,
86 rpc_client: &Arc<RpcClient>,
87 sig_verify: bool,
88 processing_mode: BundleProcessingMode<'a>,
89 ) -> Result<Self, KoraError> {
90 let validator = TransactionValidator::new(config, fee_payer)?;
91 let mut resolved_transactions = Vec::with_capacity(encoded_txs.len());
92 let mut total_required_lamports = 0u64;
93 let mut all_bundle_instructions: Vec<Instruction> = Vec::new();
94 let mut txs_missing_payment_count = 0u64;
95
96 for encoded in encoded_txs {
98 let transaction = TransactionUtil::decode_b64_transaction(encoded)?;
99
100 let mut resolved_tx = VersionedTransactionResolved::from_transaction(
101 &transaction,
102 config,
103 rpc_client,
104 sig_verify,
105 )
106 .await?;
107
108 if let BundleProcessingMode::CheckUsage(user_id) = processing_mode {
110 UsageTracker::check_transaction_usage_limit(
111 config,
112 &mut resolved_tx,
113 user_id,
114 &fee_payer,
115 rpc_client,
116 )
117 .await?;
118 }
119
120 validator.validate_transaction(config, &mut resolved_tx, rpc_client).await?;
121
122 let fee_calc = FeeConfigUtil::estimate_kora_fee(
123 &mut resolved_tx,
124 &fee_payer,
125 config.validation.is_payment_required(),
126 rpc_client,
127 config,
128 )
129 .await?;
130
131 total_required_lamports =
132 total_required_lamports.checked_add(fee_calc.total_fee_lamports).ok_or_else(
133 || KoraError::ValidationError("Bundle fee calculation overflow".to_string()),
134 )?;
135
136 if fee_calc.payment_instruction_fee > 0 {
138 txs_missing_payment_count += 1;
139 }
140
141 all_bundle_instructions.extend(resolved_tx.all_instructions.clone());
142 resolved_transactions.push(resolved_tx);
143 }
144
145 if txs_missing_payment_count > 1 {
149 let overcount =
150 (txs_missing_payment_count - 1) * ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION;
151
152 total_required_lamports =
153 total_required_lamports.checked_sub(overcount).ok_or_else(|| {
154 KoraError::ValidationError("Bundle fee calculation overflow".to_string())
155 })?;
156 }
157
158 let mut total_payment_lamports = 0u64;
160 let mut total_solana_estimated_fee = 0u64;
161 for resolved in resolved_transactions.iter_mut() {
162 if let Some(payment) = TokenUtil::find_payment_in_transaction(
163 config,
164 resolved,
165 rpc_client,
166 payment_destination,
167 Some(&all_bundle_instructions),
168 )
169 .await?
170 {
171 total_payment_lamports =
172 total_payment_lamports.checked_add(payment).ok_or_else(|| {
173 KoraError::ValidationError("Payment calculation overflow".to_string())
174 })?;
175 }
176
177 let fee = TransactionFeeUtil::get_estimate_fee_resolved(rpc_client, resolved).await?;
178 total_solana_estimated_fee =
179 total_solana_estimated_fee.checked_add(fee).ok_or_else(|| {
180 KoraError::ValidationError("Bundle Solana fee calculation overflow".to_string())
181 })?;
182
183 validator.validate_lamport_fee(total_solana_estimated_fee)?;
184 }
185
186 Ok(Self {
187 resolved_transactions,
188 total_required_lamports,
189 total_payment_lamports,
190 total_solana_estimated_fee,
191 })
192 }
193
194 fn validate_payment(&self) -> Result<(), KoraError> {
195 if self.total_payment_lamports < self.total_required_lamports {
196 return Err(BundleError::Jito(JitoError::InsufficientBundlePayment(
197 self.total_required_lamports,
198 self.total_payment_lamports,
199 ))
200 .into());
201 }
202 Ok(())
203 }
204
205 pub async fn sign_all(
206 mut self,
207 signer: &Arc<solana_keychain::Signer>,
208 fee_payer: &Pubkey,
209 rpc_client: &RpcClient,
210 ) -> Result<Vec<VersionedTransactionResolved>, KoraError> {
211 self.validate_payment()?;
212
213 let mut blockhash = None;
214
215 for resolved in self.resolved_transactions.iter_mut() {
216 if blockhash.is_none() && resolved.transaction.signatures.is_empty() {
218 blockhash = Some(
219 rpc_client
220 .get_latest_blockhash_with_commitment(CommitmentConfig::confirmed())
221 .await?
222 .0,
223 );
224 }
225
226 BundleSigner::sign_transaction_for_bundle(resolved, signer, fee_payer, &blockhash)
227 .await?;
228 }
229
230 Ok(self.resolved_transactions)
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a processor over an empty bundle with the given totals.
    fn processor_with(required: u64, paid: u64, solana_fee: u64) -> BundleProcessor {
        BundleProcessor {
            resolved_transactions: Vec::new(),
            total_required_lamports: required,
            total_payment_lamports: paid,
            total_solana_estimated_fee: solana_fee,
        }
    }

    /// Turns string literals into owned "encoded transaction" strings.
    fn txs(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds an index-to-position map from `(original_index, position)` pairs.
    fn positions(pairs: &[(usize, usize)]) -> std::collections::HashMap<usize, usize> {
        pairs.iter().copied().collect()
    }

    #[test]
    fn test_validate_payment_sufficient() {
        assert!(processor_with(1000, 1500, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_exact() {
        assert!(processor_with(1000, 1000, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_insufficient() {
        let outcome = processor_with(2000, 1000, 1000).validate_payment();
        match outcome {
            Err(KoraError::JitoError(msg)) => {
                for needle in ["insufficient", "2000", "1000"] {
                    assert!(msg.contains(needle));
                }
            }
            Err(_) => panic!("expected KoraError::JitoError"),
            Ok(()) => panic!("underpaid bundle must be rejected"),
        }
    }

    #[test]
    fn test_validate_payment_zero_required() {
        assert!(processor_with(0, 0, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_max_values() {
        assert!(processor_with(u64::MAX, u64::MAX, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_one_lamport_short() {
        let outcome = processor_with(1001, 1000, 500).validate_payment();
        assert!(matches!(outcome, Err(KoraError::JitoError(_))));
    }

    #[test]
    fn test_bundle_processor_fields() {
        let processor = processor_with(5000, 6000, 2500);

        assert!(processor.resolved_transactions.is_empty());
        assert_eq!(processor.total_required_lamports, 5000);
        assert_eq!(processor.total_payment_lamports, 6000);
        assert_eq!(processor.total_solana_estimated_fee, 2500);
    }

    #[test]
    fn test_extract_transactions_none_returns_all() {
        let bundle = txs(&["tx0", "tx1", "tx2"]);
        let (selected, mapping) =
            BundleProcessor::extract_transactions_to_process(&bundle, None).unwrap();

        assert_eq!(selected, bundle);
        assert_eq!(mapping.len(), 3);
        for i in 0..3usize {
            assert_eq!(mapping.get(&i), Some(&i));
        }
    }

    #[test]
    fn test_extract_transactions_specific_indices() {
        let bundle = txs(&["tx0", "tx1", "tx2"]);
        let (selected, mapping) =
            BundleProcessor::extract_transactions_to_process(&bundle, Some(vec![0, 2])).unwrap();

        assert_eq!(selected, txs(&["tx0", "tx2"]));
        assert_eq!(mapping.len(), 2);
        assert_eq!(mapping.get(&0), Some(&0));
        assert_eq!(mapping.get(&2), Some(&1));
    }

    #[test]
    fn test_extract_transactions_out_of_bounds() {
        let bundle = txs(&["tx0", "tx1"]);
        let outcome = BundleProcessor::extract_transactions_to_process(&bundle, Some(vec![0, 5]));
        assert!(matches!(outcome, Err(KoraError::ValidationError(_))));
    }

    #[test]
    fn test_extract_transactions_empty_indices() {
        let bundle = txs(&["tx0", "tx1"]);
        let (selected, mapping) =
            BundleProcessor::extract_transactions_to_process(&bundle, Some(Vec::new())).unwrap();

        assert!(selected.is_empty());
        assert!(mapping.is_empty());
    }

    #[test]
    fn test_extract_transactions_duplicate_indices_silently_skipped() {
        let bundle = txs(&["tx0", "tx1"]);
        let (selected, mapping) =
            BundleProcessor::extract_transactions_to_process(&bundle, Some(vec![0, 0, 1]))
                .unwrap();

        // The repeated index 0 is taken once; positions stay dense.
        assert_eq!(selected, txs(&["tx0", "tx1"]));
        assert_eq!(mapping.len(), 2);
        assert_eq!(mapping.get(&0), Some(&0));
        assert_eq!(mapping.get(&1), Some(&1));
    }

    #[test]
    fn test_merge_signed_transactions_preserves_order() {
        let merged = BundleProcessor::merge_signed_transactions(
            &txs(&["tx0", "tx1", "tx2", "tx3"]),
            txs(&["signed_tx0", "signed_tx2"]),
            &positions(&[(0, 0), (2, 1)]),
        );

        assert_eq!(merged, txs(&["signed_tx0", "tx1", "signed_tx2", "tx3"]));
    }

    #[test]
    fn test_merge_signed_transactions_all_signed() {
        let merged = BundleProcessor::merge_signed_transactions(
            &txs(&["tx0", "tx1"]),
            txs(&["signed_tx0", "signed_tx1"]),
            &positions(&[(0, 0), (1, 1)]),
        );

        assert_eq!(merged, txs(&["signed_tx0", "signed_tx1"]));
    }

    #[test]
    fn test_merge_signed_transactions_descending_indices() {
        // Signed transactions arrive in selection order (index 2 first), not bundle order.
        let merged = BundleProcessor::merge_signed_transactions(
            &txs(&["tx0", "tx1", "tx2", "tx3"]),
            txs(&["signed_tx2", "signed_tx0"]),
            &positions(&[(2, 0), (0, 1)]),
        );

        assert_eq!(merged, txs(&["signed_tx0", "tx1", "signed_tx2", "tx3"]));
    }
}