// kora_lib/bundle/helper.rs

use crate::{
    bundle::{BundleError, JitoError},
    config::Config,
    constant::ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION,
    fee::fee::{FeeConfigUtil, TransactionFeeUtil},
    signer::bundle_signer::BundleSigner,
    token::token::TokenUtil,
    transaction::{TransactionUtil, VersionedTransactionResolved},
    usage_limit::UsageTracker,
    validator::transaction_validator::TransactionValidator,
    KoraError,
};
use solana_client::nonblocking::rpc_client::RpcClient;
use solana_commitment_config::CommitmentConfig;
use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
use std::{
    collections::{hash_map::Entry, HashMap},
    sync::Arc,
};
17
18pub struct BundleProcessor {
19    pub resolved_transactions: Vec<VersionedTransactionResolved>,
20    pub total_required_lamports: u64,
21    pub total_payment_lamports: u64,
22    pub total_solana_estimated_fee: u64,
23}
24
/// Selects whether `BundleProcessor::process_bundle` enforces usage limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BundleProcessingMode<'a> {
    /// Check the usage limit of every transaction in the bundle. The payload is the
    /// optional user id forwarded to the usage tracker.
    CheckUsage(Option<&'a str>),
    /// Skip usage-limit checks (used when only estimating).
    SkipUsage,
}
29
30impl BundleProcessor {
31    /// Extract transactions at specified indices for processing.
32    /// Returns (filtered_transactions, index_to_position_map).
33    /// If `sign_only_indices` is None, returns all transactions with all indices.
34    pub fn extract_transactions_to_process(
35        transactions: &[String],
36        sign_only_indices: Option<Vec<usize>>,
37    ) -> Result<(Vec<String>, HashMap<usize, usize>), KoraError> {
38        let indices = sign_only_indices.unwrap_or_else(|| (0..transactions.len()).collect());
39
40        // Build map and filtered list (duplicates silently ignored)
41        let mut index_to_position: HashMap<usize, usize> = HashMap::with_capacity(indices.len());
42        let mut filtered: Vec<String> = Vec::with_capacity(indices.len());
43
44        for idx in indices {
45            if index_to_position.contains_key(&idx) {
46                continue; // skip duplicate
47            }
48            let tx = transactions.get(idx).ok_or_else(|| {
49                KoraError::ValidationError(format!(
50                    "sign_only_indices index {} out of bounds (bundle has {} transactions)",
51                    idx,
52                    transactions.len()
53                ))
54            })?;
55            index_to_position.insert(idx, filtered.len());
56            filtered.push(tx.clone());
57        }
58
59        Ok((filtered, index_to_position))
60    }
61
62    /// Merge signed transactions back into the original list, preserving order.
63    /// `index_to_position` maps original transaction index -> position in signed_transactions vec.
64    pub fn merge_signed_transactions(
65        original_transactions: &[String],
66        signed_transactions: Vec<String>,
67        index_to_position: &std::collections::HashMap<usize, usize>,
68    ) -> Vec<String> {
69        (0..original_transactions.len())
70            .map(|idx| {
71                if let Some(&position) = index_to_position.get(&idx) {
72                    signed_transactions[position].clone()
73                } else {
74                    original_transactions[idx].clone()
75                }
76            })
77            .collect()
78    }
79
80    #[allow(clippy::too_many_arguments)]
81    pub async fn process_bundle<'a>(
82        encoded_txs: &[String],
83        fee_payer: Pubkey,
84        payment_destination: &Pubkey,
85        config: &Config,
86        rpc_client: &Arc<RpcClient>,
87        sig_verify: bool,
88        processing_mode: BundleProcessingMode<'a>,
89    ) -> Result<Self, KoraError> {
90        let validator = TransactionValidator::new(config, fee_payer)?;
91        let mut resolved_transactions = Vec::with_capacity(encoded_txs.len());
92        let mut total_required_lamports = 0u64;
93        let mut all_bundle_instructions: Vec<Instruction> = Vec::new();
94        let mut txs_missing_payment_count = 0u64;
95
96        // Phase 1: Decode, resolve, validate, calc fees, collect instructions
97        for encoded in encoded_txs {
98            let transaction = TransactionUtil::decode_b64_transaction(encoded)?;
99
100            let mut resolved_tx = VersionedTransactionResolved::from_transaction(
101                &transaction,
102                config,
103                rpc_client,
104                sig_verify,
105            )
106            .await?;
107
108            // Check usage limit for each transaction in the bundle (skip for estimates)
109            if let BundleProcessingMode::CheckUsage(user_id) = processing_mode {
110                UsageTracker::check_transaction_usage_limit(
111                    config,
112                    &mut resolved_tx,
113                    user_id,
114                    &fee_payer,
115                    rpc_client,
116                )
117                .await?;
118            }
119
120            validator.validate_transaction(config, &mut resolved_tx, rpc_client).await?;
121
122            let fee_calc = FeeConfigUtil::estimate_kora_fee(
123                &mut resolved_tx,
124                &fee_payer,
125                config.validation.is_payment_required(),
126                rpc_client,
127                config,
128            )
129            .await?;
130
131            total_required_lamports =
132                total_required_lamports.checked_add(fee_calc.total_fee_lamports).ok_or_else(
133                    || KoraError::ValidationError("Bundle fee calculation overflow".to_string()),
134                )?;
135
136            // Track how many transactions are missing payment instructions
137            if fee_calc.payment_instruction_fee > 0 {
138                txs_missing_payment_count += 1;
139            }
140
141            all_bundle_instructions.extend(resolved_tx.all_instructions.clone());
142            resolved_transactions.push(resolved_tx);
143        }
144
145        // For bundles, only ONE payment instruction is needed across all transactions.
146        // If multiple transactions are missing payments, we've overcounted by
147        // (txs_missing_payment_count - 1) * ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION
148        if txs_missing_payment_count > 1 {
149            let overcount =
150                (txs_missing_payment_count - 1) * ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION;
151
152            total_required_lamports =
153                total_required_lamports.checked_sub(overcount).ok_or_else(|| {
154                    KoraError::ValidationError("Bundle fee calculation overflow".to_string())
155                })?;
156        }
157
158        // Phase 2: Calculate payments with cross-tx ATA visibility
159        let mut total_payment_lamports = 0u64;
160        let mut total_solana_estimated_fee = 0u64;
161        for resolved in resolved_transactions.iter_mut() {
162            if let Some(payment) = TokenUtil::find_payment_in_transaction(
163                config,
164                resolved,
165                rpc_client,
166                payment_destination,
167                Some(&all_bundle_instructions),
168            )
169            .await?
170            {
171                total_payment_lamports =
172                    total_payment_lamports.checked_add(payment).ok_or_else(|| {
173                        KoraError::ValidationError("Payment calculation overflow".to_string())
174                    })?;
175            }
176
177            let fee = TransactionFeeUtil::get_estimate_fee_resolved(rpc_client, resolved).await?;
178            total_solana_estimated_fee =
179                total_solana_estimated_fee.checked_add(fee).ok_or_else(|| {
180                    KoraError::ValidationError("Bundle Solana fee calculation overflow".to_string())
181                })?;
182
183            validator.validate_lamport_fee(total_solana_estimated_fee)?;
184        }
185
186        Ok(Self {
187            resolved_transactions,
188            total_required_lamports,
189            total_payment_lamports,
190            total_solana_estimated_fee,
191        })
192    }
193
194    fn validate_payment(&self) -> Result<(), KoraError> {
195        if self.total_payment_lamports < self.total_required_lamports {
196            return Err(BundleError::Jito(JitoError::InsufficientBundlePayment(
197                self.total_required_lamports,
198                self.total_payment_lamports,
199            ))
200            .into());
201        }
202        Ok(())
203    }
204
205    pub async fn sign_all(
206        mut self,
207        signer: &Arc<solana_keychain::Signer>,
208        fee_payer: &Pubkey,
209        rpc_client: &RpcClient,
210    ) -> Result<Vec<VersionedTransactionResolved>, KoraError> {
211        self.validate_payment()?;
212
213        let mut blockhash = None;
214
215        for resolved in self.resolved_transactions.iter_mut() {
216            // Get latest blockhash if signatures are empty and blockhash is not set
217            if blockhash.is_none() && resolved.transaction.signatures.is_empty() {
218                blockhash = Some(
219                    rpc_client
220                        .get_latest_blockhash_with_commitment(CommitmentConfig::confirmed())
221                        .await?
222                        .0,
223                );
224            }
225
226            BundleSigner::sign_transaction_for_bundle(resolved, signer, fee_payer, &blockhash)
227                .await?;
228        }
229
230        Ok(self.resolved_transactions)
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a processor without transactions and with the given totals.
    fn processor(required: u64, payment: u64, solana_fee: u64) -> BundleProcessor {
        BundleProcessor {
            resolved_transactions: vec![],
            total_required_lamports: required,
            total_payment_lamports: payment,
            total_solana_estimated_fee: solana_fee,
        }
    }

    /// Turn string literals into the owned strings the bundle helpers work with.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_validate_payment_sufficient() {
        assert!(processor(1000, 1500, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_exact() {
        assert!(processor(1000, 1000, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_insufficient() {
        match processor(2000, 1000, 1000).validate_payment() {
            Err(KoraError::JitoError(msg)) => {
                assert!(msg.contains("insufficient"));
                assert!(msg.contains("2000"));
                assert!(msg.contains("1000"));
            }
            other => panic!("expected KoraError::JitoError, got {:?}", other),
        }
    }

    #[test]
    fn test_validate_payment_zero_required() {
        assert!(processor(0, 0, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_max_values() {
        assert!(processor(u64::MAX, u64::MAX, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_one_lamport_short() {
        let result = processor(1001, 1000, 500).validate_payment();
        assert!(matches!(result, Err(KoraError::JitoError(_))));
    }

    #[test]
    fn test_bundle_processor_fields() {
        let processor = processor(5000, 6000, 2500);

        assert_eq!(processor.total_required_lamports, 5000);
        assert_eq!(processor.total_payment_lamports, 6000);
        assert_eq!(processor.total_solana_estimated_fee, 2500);
        assert!(processor.resolved_transactions.is_empty());
    }

    #[test]
    fn test_extract_transactions_none_returns_all() {
        let txs = strings(&["tx0", "tx1", "tx2"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&txs, None).unwrap();

        assert_eq!(result, txs);
        assert_eq!(index_to_position, HashMap::from([(0, 0), (1, 1), (2, 2)]));
    }

    #[test]
    fn test_extract_transactions_specific_indices() {
        let txs = strings(&["tx0", "tx1", "tx2"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&txs, Some(vec![0, 2])).unwrap();

        assert_eq!(result, strings(&["tx0", "tx2"]));
        assert_eq!(index_to_position, HashMap::from([(0, 0), (2, 1)]));
    }

    #[test]
    fn test_extract_transactions_out_of_bounds() {
        let txs = strings(&["tx0", "tx1"]);
        let result = BundleProcessor::extract_transactions_to_process(&txs, Some(vec![0, 5]));

        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[test]
    fn test_extract_transactions_empty_indices() {
        let txs = strings(&["tx0", "tx1"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&txs, Some(vec![])).unwrap();

        assert!(result.is_empty());
        assert!(index_to_position.is_empty());
    }

    #[test]
    fn test_extract_transactions_duplicate_indices_silently_skipped() {
        let txs = strings(&["tx0", "tx1"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&txs, Some(vec![0, 0, 1])).unwrap();

        // Duplicates are silently skipped, only unique indices are processed:
        // tx0 lands at position 0 and tx1 at position 1 of the filtered list.
        assert_eq!(result, strings(&["tx0", "tx1"]));
        assert_eq!(index_to_position, HashMap::from([(0, 0), (1, 1)]));
    }

    #[test]
    fn test_merge_signed_transactions_preserves_order() {
        let original = strings(&["tx0", "tx1", "tx2", "tx3"]);
        let signed = strings(&["signed_tx0", "signed_tx2"]);
        // index 0 -> position 0, index 2 -> position 1
        let index_to_position = HashMap::from([(0_usize, 0_usize), (2_usize, 1_usize)]);

        let result =
            BundleProcessor::merge_signed_transactions(&original, signed, &index_to_position);

        assert_eq!(result, strings(&["signed_tx0", "tx1", "signed_tx2", "tx3"]));
    }

    #[test]
    fn test_merge_signed_transactions_all_signed() {
        let original = strings(&["tx0", "tx1"]);
        let signed = strings(&["signed_tx0", "signed_tx1"]);
        let index_to_position = HashMap::from([(0_usize, 0_usize), (1_usize, 1_usize)]);

        let result =
            BundleProcessor::merge_signed_transactions(&original, signed, &index_to_position);

        assert_eq!(result, strings(&["signed_tx0", "signed_tx1"]));
    }

    #[test]
    fn test_merge_signed_transactions_descending_indices() {
        let original = strings(&["tx0", "tx1", "tx2", "tx3"]);
        // indices [2, 0] means: signed[0] = tx2, signed[1] = tx0
        let signed = strings(&["signed_tx2", "signed_tx0"]);
        // index 2 -> position 0, index 0 -> position 1
        let index_to_position = HashMap::from([(2_usize, 0_usize), (0_usize, 1_usize)]);

        let result =
            BundleProcessor::merge_signed_transactions(&original, signed, &index_to_position);

        assert_eq!(result, strings(&["signed_tx0", "tx1", "signed_tx2", "tx3"]));
    }
}