Skip to main content

kora_lib/bundle/
helper.rs

1use crate::{
2    bundle::{BundleError, JitoError},
3    config::Config,
4    constant::ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION,
5    fee::fee::{FeeConfigUtil, TransactionFeeUtil},
6    lighthouse::LighthouseUtil,
7    plugin::{PluginExecutionContext, TransactionPluginRunner},
8    signer::bundle_signer::BundleSigner,
9    token::token::TokenUtil,
10    transaction::{TransactionUtil, VersionedTransactionResolved},
11    usage_limit::UsageTracker,
12    validator::transaction_validator::TransactionValidator,
13    KoraError,
14};
15use solana_client::nonblocking::rpc_client::RpcClient;
16use solana_commitment_config::CommitmentConfig;
17use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
18use std::{collections::HashMap, sync::Arc};
19
20pub struct BundleProcessor {
21    pub resolved_transactions: Vec<VersionedTransactionResolved>,
22    pub total_required_lamports: u64,
23    pub total_payment_lamports: u64,
24    pub total_solana_estimated_fee: u64,
25}
26
/// Selects whether `BundleProcessor::process_bundle` enforces per-user usage limits.
///
/// The only payload is an `Option<&str>`, so the mode is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BundleProcessingMode<'a> {
    /// Check every transaction of the bundle against the usage limit, passing the
    /// given (optional) user id to the usage tracker.
    CheckUsage(Option<&'a str>),
    /// Do not check usage limits (used for fee estimates).
    SkipUsage,
}
31
32impl BundleProcessor {
33    /// Extract transactions at specified indices for processing.
34    /// Returns (filtered_transactions, index_to_position_map).
35    /// If `sign_only_indices` is None, returns all transactions with all indices.
36    pub fn extract_transactions_to_process(
37        transactions: &[String],
38        sign_only_indices: Option<Vec<usize>>,
39    ) -> Result<(Vec<String>, HashMap<usize, usize>), KoraError> {
40        let indices = sign_only_indices.unwrap_or_else(|| (0..transactions.len()).collect());
41
42        // Build map and filtered list (duplicates silently ignored)
43        let mut index_to_position: HashMap<usize, usize> = HashMap::with_capacity(indices.len());
44        let mut filtered: Vec<String> = Vec::with_capacity(indices.len());
45
46        for idx in indices {
47            if index_to_position.contains_key(&idx) {
48                continue; // skip duplicate
49            }
50            let tx = transactions.get(idx).ok_or_else(|| {
51                KoraError::ValidationError(format!(
52                    "sign_only_indices index {} out of bounds (bundle has {} transactions)",
53                    idx,
54                    transactions.len()
55                ))
56            })?;
57            index_to_position.insert(idx, filtered.len());
58            filtered.push(tx.clone());
59        }
60
61        Ok((filtered, index_to_position))
62    }
63
64    /// Merge signed transactions back into the original list, preserving order.
65    /// `index_to_position` maps original transaction index -> position in signed_transactions vec.
66    pub fn merge_signed_transactions(
67        original_transactions: &[String],
68        signed_transactions: Vec<String>,
69        index_to_position: &std::collections::HashMap<usize, usize>,
70    ) -> Vec<String> {
71        (0..original_transactions.len())
72            .map(|idx| {
73                if let Some(&position) = index_to_position.get(&idx) {
74                    signed_transactions[position].clone()
75                } else {
76                    original_transactions[idx].clone()
77                }
78            })
79            .collect()
80    }
81
82    #[allow(clippy::too_many_arguments)]
83    pub async fn process_bundle<'a>(
84        encoded_txs: &[String],
85        fee_payer: Pubkey,
86        payment_destination: &Pubkey,
87        config: &Config,
88        rpc_client: &Arc<RpcClient>,
89        sig_verify: bool,
90        plugin_context: Option<PluginExecutionContext>,
91        processing_mode: BundleProcessingMode<'a>,
92    ) -> Result<Self, KoraError> {
93        let validator = TransactionValidator::new(config, fee_payer)?;
94        let plugin_runner = TransactionPluginRunner::from_config(config);
95        let mut resolved_transactions = Vec::with_capacity(encoded_txs.len());
96        let mut total_required_lamports = 0u64;
97        let mut all_bundle_instructions: Vec<Instruction> = Vec::new();
98        let mut txs_missing_payment_count = 0u64;
99
100        // Phase 1: Decode, resolve, validate, calc fees, collect instructions
101        for encoded in encoded_txs {
102            let transaction = TransactionUtil::decode_b64_transaction(encoded)?;
103
104            let mut resolved_tx = VersionedTransactionResolved::from_transaction(
105                &transaction,
106                config,
107                rpc_client,
108                sig_verify,
109            )
110            .await?;
111
112            // Check usage limit for each transaction in the bundle (skip for estimates)
113            if let BundleProcessingMode::CheckUsage(user_id) = processing_mode {
114                UsageTracker::check_transaction_usage_limit(
115                    config,
116                    &mut resolved_tx,
117                    user_id,
118                    &fee_payer,
119                    rpc_client,
120                )
121                .await?;
122            }
123
124            validator.validate_transaction(config, &mut resolved_tx, rpc_client).await?;
125            if let Some(context) = plugin_context {
126                plugin_runner
127                    .run(&mut resolved_tx, config, rpc_client, &fee_payer, context)
128                    .await?;
129            }
130
131            let fee_calc = FeeConfigUtil::estimate_kora_fee(
132                &mut resolved_tx,
133                &fee_payer,
134                config.validation.is_payment_required(),
135                rpc_client,
136                config,
137            )
138            .await?;
139
140            total_required_lamports =
141                total_required_lamports.checked_add(fee_calc.total_fee_lamports).ok_or_else(
142                    || KoraError::ValidationError("Bundle fee calculation overflow".to_string()),
143                )?;
144
145            // Track how many transactions are missing payment instructions
146            if fee_calc.payment_instruction_fee > 0 {
147                txs_missing_payment_count += 1;
148            }
149
150            all_bundle_instructions.extend(resolved_tx.all_instructions.clone());
151            resolved_transactions.push(resolved_tx);
152        }
153
154        // For bundles, only ONE payment instruction is needed across all transactions.
155        // If multiple transactions are missing payments, we've overcounted by
156        // (txs_missing_payment_count - 1) * ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION
157        if txs_missing_payment_count > 1 {
158            let overcount =
159                (txs_missing_payment_count - 1) * ESTIMATED_LAMPORTS_FOR_PAYMENT_INSTRUCTION;
160
161            total_required_lamports =
162                total_required_lamports.checked_sub(overcount).ok_or_else(|| {
163                    KoraError::ValidationError("Bundle fee calculation overflow".to_string())
164                })?;
165        }
166
167        // Phase 2: Calculate payments with cross-tx ATA visibility
168        let mut total_payment_lamports = 0u64;
169        let mut total_solana_estimated_fee = 0u64;
170        for resolved in resolved_transactions.iter_mut() {
171            if let Some(payment) = TokenUtil::find_payment_in_transaction(
172                config,
173                resolved,
174                rpc_client,
175                payment_destination,
176                Some(&all_bundle_instructions),
177            )
178            .await?
179            {
180                total_payment_lamports =
181                    total_payment_lamports.checked_add(payment).ok_or_else(|| {
182                        KoraError::ValidationError("Payment calculation overflow".to_string())
183                    })?;
184            }
185
186            let fee = TransactionFeeUtil::get_estimate_fee_resolved(rpc_client, resolved).await?;
187            total_solana_estimated_fee =
188                total_solana_estimated_fee.checked_add(fee).ok_or_else(|| {
189                    KoraError::ValidationError("Bundle Solana fee calculation overflow".to_string())
190                })?;
191
192            validator.validate_lamport_fee(total_solana_estimated_fee)?;
193        }
194
195        Ok(Self {
196            resolved_transactions,
197            total_required_lamports,
198            total_payment_lamports,
199            total_solana_estimated_fee,
200        })
201    }
202
203    fn validate_payment(&self) -> Result<(), KoraError> {
204        if self.total_payment_lamports < self.total_required_lamports {
205            return Err(BundleError::Jito(JitoError::InsufficientBundlePayment(
206                self.total_required_lamports,
207                self.total_payment_lamports,
208            ))
209            .into());
210        }
211        Ok(())
212    }
213
214    pub async fn sign_all(
215        mut self,
216        signer: &Arc<solana_keychain::Signer>,
217        fee_payer: &Pubkey,
218        rpc_client: &RpcClient,
219        config: &Config,
220        will_send: bool,
221    ) -> Result<Vec<VersionedTransactionResolved>, KoraError> {
222        self.validate_payment()?;
223
224        let mut blockhash = None;
225        let tx_count = self.resolved_transactions.len();
226
227        for (i, resolved) in self.resolved_transactions.iter_mut().enumerate() {
228            // Get latest blockhash if signatures are empty and blockhash is not set
229            if blockhash.is_none() && resolved.transaction.signatures.is_empty() {
230                blockhash = Some(
231                    rpc_client
232                        .get_latest_blockhash_with_commitment(CommitmentConfig::confirmed())
233                        .await?
234                        .0,
235                );
236            }
237
238            // Add lighthouse assertion only to last transaction in bundle
239            if i == tx_count - 1 {
240                LighthouseUtil::add_fee_payer_assertion(
241                    &mut resolved.transaction,
242                    rpc_client,
243                    fee_payer,
244                    self.total_solana_estimated_fee,
245                    &config.kora.lighthouse,
246                    will_send,
247                )
248                .await?;
249            }
250
251            BundleSigner::sign_transaction_for_bundle(resolved, signer, fee_payer, &blockhash)
252                .await?;
253        }
254
255        Ok(self.resolved_transactions)
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a processor with no transactions and the given lamport totals.
    fn processor_with(required: u64, paid: u64, solana_fee: u64) -> BundleProcessor {
        BundleProcessor {
            resolved_transactions: vec![],
            total_required_lamports: required,
            total_payment_lamports: paid,
            total_solana_estimated_fee: solana_fee,
        }
    }

    /// Build owned transaction strings from string literals.
    fn txs(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_validate_payment_sufficient() {
        assert!(processor_with(1000, 1500, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_exact() {
        assert!(processor_with(1000, 1000, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_insufficient() {
        let result = processor_with(2000, 1000, 1000).validate_payment();
        assert!(result.is_err());
        match result.unwrap_err() {
            KoraError::JitoError(msg) => {
                assert!(msg.contains("insufficient"));
                assert!(msg.contains("2000"));
                assert!(msg.contains("1000"));
            }
            _ => panic!("expected KoraError::JitoError"),
        }
    }

    #[test]
    fn test_validate_payment_zero_required() {
        assert!(processor_with(0, 0, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_max_values() {
        assert!(processor_with(u64::MAX, u64::MAX, 1000).validate_payment().is_ok());
    }

    #[test]
    fn test_validate_payment_one_lamport_short() {
        let result = processor_with(1001, 1000, 500).validate_payment();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), KoraError::JitoError(_)));
    }

    #[test]
    fn test_bundle_processor_fields() {
        let processor = processor_with(5000, 6000, 2500);

        assert_eq!(processor.total_required_lamports, 5000);
        assert_eq!(processor.total_payment_lamports, 6000);
        assert_eq!(processor.total_solana_estimated_fee, 2500);
        assert!(processor.resolved_transactions.is_empty());
    }

    #[test]
    fn test_extract_transactions_none_returns_all() {
        let all = txs(&["tx0", "tx1", "tx2"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&all, None).unwrap();
        assert_eq!(result, all);
        assert_eq!(index_to_position.len(), 3);
        assert_eq!(index_to_position.get(&0), Some(&0));
        assert_eq!(index_to_position.get(&1), Some(&1));
        assert_eq!(index_to_position.get(&2), Some(&2));
    }

    #[test]
    fn test_extract_transactions_none_on_empty_bundle() {
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&[], None).unwrap();
        assert!(result.is_empty());
        assert!(index_to_position.is_empty());
    }

    #[test]
    fn test_extract_transactions_specific_indices() {
        let all = txs(&["tx0", "tx1", "tx2"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&all, Some(vec![0, 2])).unwrap();
        assert_eq!(result, txs(&["tx0", "tx2"]));
        assert_eq!(index_to_position.len(), 2);
        assert_eq!(index_to_position.get(&0), Some(&0));
        assert_eq!(index_to_position.get(&2), Some(&1));
    }

    #[test]
    fn test_extract_transactions_out_of_bounds() {
        let all = txs(&["tx0", "tx1"]);
        let result = BundleProcessor::extract_transactions_to_process(&all, Some(vec![0, 5]));
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), KoraError::ValidationError(_)));
    }

    #[test]
    fn test_extract_transactions_out_of_bounds_on_empty_bundle() {
        let result = BundleProcessor::extract_transactions_to_process(&[], Some(vec![0]));
        assert!(matches!(result, Err(KoraError::ValidationError(_))));
    }

    #[test]
    fn test_extract_transactions_empty_indices() {
        let all = txs(&["tx0", "tx1"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&all, Some(vec![])).unwrap();
        assert!(result.is_empty());
        assert!(index_to_position.is_empty());
    }

    #[test]
    fn test_extract_transactions_duplicate_indices_silently_skipped() {
        let all = txs(&["tx0", "tx1"]);
        let (result, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&all, Some(vec![0, 0, 1])).unwrap();
        // Duplicates are silently skipped, only unique indices processed
        assert_eq!(result, txs(&["tx0", "tx1"]));
        assert_eq!(index_to_position.len(), 2);
        assert_eq!(index_to_position.get(&0), Some(&0)); // tx0 at position 0 in filtered
        assert_eq!(index_to_position.get(&1), Some(&1)); // tx1 at position 1 in filtered
    }

    #[test]
    fn test_merge_signed_transactions_preserves_order() {
        let original = txs(&["tx0", "tx1", "tx2", "tx3"]);
        let signed = txs(&["signed_tx0", "signed_tx2"]);
        // index 0 -> position 0, index 2 -> position 1
        let index_to_position =
            std::collections::HashMap::from([(0_usize, 0_usize), (2_usize, 1_usize)]);

        let result =
            BundleProcessor::merge_signed_transactions(&original, signed, &index_to_position);

        assert_eq!(result, txs(&["signed_tx0", "tx1", "signed_tx2", "tx3"]));
    }

    #[test]
    fn test_merge_signed_transactions_all_signed() {
        let original = txs(&["tx0", "tx1"]);
        let signed = txs(&["signed_tx0", "signed_tx1"]);
        let index_to_position =
            std::collections::HashMap::from([(0_usize, 0_usize), (1_usize, 1_usize)]);

        let result =
            BundleProcessor::merge_signed_transactions(&original, signed, &index_to_position);
        assert_eq!(result, txs(&["signed_tx0", "signed_tx1"]));
    }

    #[test]
    fn test_merge_signed_transactions_nothing_signed() {
        let original = txs(&["tx0", "tx1"]);
        let result = BundleProcessor::merge_signed_transactions(
            &original,
            vec![],
            &std::collections::HashMap::new(),
        );
        assert_eq!(result, original);
    }

    #[test]
    fn test_merge_signed_transactions_descending_indices() {
        let original = txs(&["tx0", "tx1", "tx2", "tx3"]);
        // indices [2, 0] means: signed[0] = tx2, signed[1] = tx0
        let signed = txs(&["signed_tx2", "signed_tx0"]);
        // index 2 -> position 0, index 0 -> position 1
        let index_to_position =
            std::collections::HashMap::from([(2_usize, 0_usize), (0_usize, 1_usize)]);

        let result =
            BundleProcessor::merge_signed_transactions(&original, signed, &index_to_position);

        assert_eq!(result, txs(&["signed_tx0", "tx1", "signed_tx2", "tx3"]));
    }

    #[test]
    fn test_extract_then_merge_round_trip() {
        let original = txs(&["tx0", "tx1", "tx2"]);
        // Descending order with a duplicate: exercises both the ordering and the dedup path.
        let (to_sign, index_to_position) =
            BundleProcessor::extract_transactions_to_process(&original, Some(vec![2, 0, 2]))
                .unwrap();
        assert_eq!(to_sign, txs(&["tx2", "tx0"]));

        let signed: Vec<String> = to_sign.iter().map(|tx| format!("signed_{}", tx)).collect();
        let merged =
            BundleProcessor::merge_signed_transactions(&original, signed, &index_to_position);

        assert_eq!(merged, txs(&["signed_tx0", "tx1", "signed_tx2"]));
    }
}