// kora_lib/admin/token_util.rs

1use crate::{
2    config::Config,
3    error::KoraError,
4    state::{get_request_signer_with_signer_key, get_signer_pool},
5    token::token::TokenType,
6    transaction::TransactionUtil,
7};
8use solana_client::nonblocking::rpc_client::RpcClient;
9use solana_compute_budget_interface::ComputeBudgetInstruction;
10use solana_keychain::SolanaSigner;
11use solana_message::{Message, VersionedMessage};
12use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
13
14use spl_associated_token_account_interface::{
15    address::get_associated_token_address, instruction::create_associated_token_account,
16};
17use std::{fmt::Display, str::FromStr, sync::Arc};
18
19#[cfg(not(test))]
20use {crate::cache::CacheUtil, crate::state::get_config};
21
22#[cfg(test)]
23use {
24    crate::config::SplTokenConfig, crate::tests::cache_mock::MockCacheUtil as CacheUtil,
25    crate::tests::config_mock::mock_state::get_config,
26};
27
/*
This function is tested via the Makefile, as it's a CLI command and requires a running validator.
*/
31
32const DEFAULT_CHUNK_SIZE: usize = 10;
33
34pub struct ATAToCreate {
35    pub mint: Pubkey,
36    pub ata: Pubkey,
37    pub token_program: Pubkey,
38}
39
40impl Display for ATAToCreate {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        write!(f, "Token {}: ATA {} (Token program: {})", self.mint, self.ata, self.token_program)
43    }
44}
45
46/// Initialize ATAs for all allowed payment tokens for the paymaster
47/// This function initializes ATAs for ALL signers in the pool
48///
49/// Order of priority is:
50/// 1. Payment address provided in config
51/// 2. All signers in pool
52pub async fn initialize_atas(
53    rpc_client: &RpcClient,
54    compute_unit_price: Option<u64>,
55    compute_unit_limit: Option<u32>,
56    chunk_size: Option<usize>,
57    fee_payer_key: Option<String>,
58) -> Result<(), KoraError> {
59    let config = get_config()?;
60
61    let fee_payer = get_request_signer_with_signer_key(fee_payer_key.as_deref())?;
62
63    let addresses_to_initialize_atas = if let Some(payment_address) = &config.kora.payment_address {
64        vec![Pubkey::from_str(payment_address)
65            .map_err(|e| KoraError::InternalServerError(format!("Invalid payment address: {e}")))?]
66    } else {
67        get_signer_pool()?
68            .get_signers_info()
69            .iter()
70            .filter_map(|info| info.public_key.parse().ok())
71            .collect::<Vec<Pubkey>>()
72    };
73
74    initialize_atas_with_chunk_size(
75        rpc_client,
76        &fee_payer,
77        &addresses_to_initialize_atas,
78        compute_unit_price,
79        compute_unit_limit,
80        chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE),
81    )
82    .await
83}
84
85/// Initialize ATAs for all allowed payment tokens for the provided addresses with configurable chunk size
86/// This function does not use cache and directly checks on-chain
87pub async fn initialize_atas_with_chunk_size(
88    rpc_client: &RpcClient,
89    fee_payer: &Arc<solana_keychain::Signer>,
90    addresses_to_initialize_atas: &Vec<Pubkey>,
91    compute_unit_price: Option<u64>,
92    compute_unit_limit: Option<u32>,
93    chunk_size: usize,
94) -> Result<(), KoraError> {
95    let config = get_config()?;
96
97    for address in addresses_to_initialize_atas {
98        println!("Initializing ATAs for address: {address}");
99
100        #[allow(clippy::needless_borrow)]
101        let atas_to_create = find_missing_atas(&config, rpc_client, address).await?;
102
103        if atas_to_create.is_empty() {
104            println!("✓ All required ATAs already exist for address: {address}");
105            continue;
106        }
107
108        create_atas_for_signer(
109            rpc_client,
110            fee_payer,
111            address,
112            &atas_to_create,
113            compute_unit_price,
114            compute_unit_limit,
115            chunk_size,
116        )
117        .await?;
118    }
119
120    println!("✓ Successfully created all ATAs");
121
122    Ok(())
123}
124
125/// Helper function to create ATAs for a single signer
126async fn create_atas_for_signer(
127    rpc_client: &RpcClient,
128    fee_payer: &Arc<solana_keychain::Signer>,
129    address: &Pubkey,
130    atas_to_create: &[ATAToCreate],
131    compute_unit_price: Option<u64>,
132    compute_unit_limit: Option<u32>,
133    chunk_size: usize,
134) -> Result<usize, KoraError> {
135    let instructions = atas_to_create
136        .iter()
137        .map(|ata| {
138            create_associated_token_account(
139                &fee_payer.pubkey(),
140                address,
141                &ata.mint,
142                &ata.token_program,
143            )
144        })
145        .collect::<Vec<Instruction>>();
146
147    // Process instructions in chunks
148    let total_atas = instructions.len();
149    let chunks: Vec<_> = instructions.chunks(chunk_size).collect();
150    let num_chunks = chunks.len();
151
152    println!(
153        "Creating {total_atas} ATAs in {num_chunks} transaction(s) (chunk size: {chunk_size})..."
154    );
155
156    let mut created_atas_idx = 0;
157
158    for (chunk_idx, chunk) in chunks.iter().enumerate() {
159        let chunk_num = chunk_idx + 1;
160        println!("Processing chunk {chunk_num}/{num_chunks}");
161
162        // Build instructions for this chunk with compute budget
163        let mut chunk_instructions = Vec::new();
164
165        // Add compute budget instructions to each chunk
166        if let Some(compute_unit_price) = compute_unit_price {
167            chunk_instructions
168                .push(ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price));
169        }
170        if let Some(compute_unit_limit) = compute_unit_limit {
171            chunk_instructions
172                .push(ComputeBudgetInstruction::set_compute_unit_limit(compute_unit_limit));
173        }
174
175        // Add the ATA creation instructions for this chunk
176        chunk_instructions.extend_from_slice(chunk);
177
178        let blockhash = rpc_client
179            .get_latest_blockhash()
180            .await
181            .map_err(|e| KoraError::RpcError(format!("Failed to get blockhash: {e}")))?;
182
183        let fee_payer_pubkey = fee_payer.pubkey();
184        let message = VersionedMessage::Legacy(Message::new_with_blockhash(
185            &chunk_instructions,
186            Some(&fee_payer_pubkey),
187            &blockhash,
188        ));
189
190        let mut tx = TransactionUtil::new_unsigned_versioned_transaction(message);
191        let message_bytes = tx.message.serialize();
192        let signature = fee_payer
193            .sign_message(&message_bytes)
194            .await
195            .map_err(|e| KoraError::SigningError(e.to_string()))?;
196
197        tx.signatures = vec![signature];
198
199        match rpc_client.send_and_confirm_transaction_with_spinner(&tx).await {
200            Ok(signature) => {
201                println!(
202                    "✓ Chunk {chunk_num}/{num_chunks} successful. Transaction signature: {signature}"
203                );
204
205                // Print the ATAs created in this chunk
206                let chunk_end = std::cmp::min(created_atas_idx + chunk.len(), atas_to_create.len());
207
208                (created_atas_idx..chunk_end).for_each(|i| {
209                    let ATAToCreate { mint, ata, token_program } = &atas_to_create[i];
210                    println!("  - Token {mint}: ATA {ata} (Token program: {token_program})");
211                });
212                created_atas_idx = chunk_end;
213            }
214            Err(e) => {
215                println!("✗ Chunk {chunk_num}/{num_chunks} failed: {e}");
216
217                if created_atas_idx > 0 {
218                    println!("\nSuccessfully created ATAs ({created_atas_idx}/{total_atas}):");
219                    println!(
220                        "{}",
221                        atas_to_create[0..created_atas_idx]
222                            .iter()
223                            .map(|ata| format!("  ✓ {ata}"))
224                            .collect::<Vec<String>>()
225                            .join("\n")
226                    );
227                    println!("\nRemaining ATAs to create: {}", total_atas - created_atas_idx);
228                } else {
229                    println!("No ATAs were successfully created.");
230                }
231
232                println!("This may be a temporary network issue. Please re-run the command to retry ATA creation.");
233                return Err(KoraError::RpcError(format!(
234                    "Failed to send ATA creation transaction for chunk {chunk_num}/{num_chunks}: {e}"
235                )));
236            }
237        }
238    }
239
240    // Show summary of all successfully created ATAs
241    println!("\n🎉 All ATA creation completed successfully!");
242    println!("Successfully created ATAs ({total_atas}/{total_atas}):");
243    println!(
244        "{}",
245        atas_to_create.iter().map(|ata| format!("  ✓ {ata}")).collect::<Vec<String>>().join("\n")
246    );
247
248    Ok(total_atas)
249}
250
251pub async fn find_missing_atas(
252    config: &Config,
253    rpc_client: &RpcClient,
254    payment_address: &Pubkey,
255) -> Result<Vec<ATAToCreate>, KoraError> {
256    // Parse all allowed SPL paid token mints
257    let mut token_mints = Vec::new();
258    for token_str in &config.validation.allowed_spl_paid_tokens {
259        match Pubkey::from_str(token_str) {
260            Ok(mint) => token_mints.push(mint),
261            Err(_) => {
262                println!("⚠️  Skipping invalid token mint: {token_str}");
263                continue;
264            }
265        }
266    }
267
268    if token_mints.is_empty() {
269        println!("✓ No SPL payment tokens configured");
270        return Ok(Vec::new());
271    }
272
273    let mut atas_to_create = Vec::new();
274
275    // Check each token mint for existing ATA
276    for mint in &token_mints {
277        let ata = get_associated_token_address(payment_address, mint);
278
279        match CacheUtil::get_account(config, rpc_client, &ata, false).await {
280            Ok(_) => {
281                println!("✓ ATA already exists for token {mint}: {ata}");
282            }
283            Err(_) => {
284                // Fetch mint account to determine if it's SPL or Token2022
285                let mint_account =
286                    CacheUtil::get_account(config, rpc_client, mint, false).await.map_err(|e| {
287                        KoraError::RpcError(format!("Failed to fetch mint account for {mint}: {e}"))
288                    })?;
289
290                let token_program = TokenType::get_token_program_from_owner(&mint_account.owner)?;
291
292                println!("Creating ATA for token {mint}: {ata}");
293
294                atas_to_create.push(ATAToCreate {
295                    mint: *mint,
296                    ata,
297                    token_program: token_program.program_id(),
298                });
299            }
300        }
301    }
302
303    Ok(atas_to_create)
304}
305
// Unit tests for the ATA helpers. Config, account-cache lookups, and RPC are
// replaced by the mocks imported from `crate::tests`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::{
            create_mock_rpc_client_account_not_found, create_mock_spl_mint_account,
            create_mock_token_account, setup_or_get_test_signer, RpcMockBuilder,
        },
        config_mock::{ConfigMockBuilder, ValidationConfigBuilder},
    };
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    // An empty allowlist means there is nothing to check, so no ATAs are reported missing.
    #[tokio::test]
    async fn test_find_missing_atas_no_spl_tokens() {
        // Bound to `_m` (not `_`) so the value lives until the end of the test;
        // presumably it guards the mocked global config — confirm in config_mock.
        let _m = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(vec![]))
                    .build(),
            )
            .build_and_setup();

        let rpc_client = create_mock_rpc_client_account_not_found();
        let payment_address = Pubkey::new_unique();

        let config = get_config().unwrap();
        let result = find_missing_atas(&config, &rpc_client, &payment_address).await.unwrap();

        assert!(result.is_empty(), "Should return empty vec when no SPL tokens configured");
    }

    // Two allowed mints: the first mint's ATA exists and the second's is missing,
    // so exactly one ATA should be reported.
    #[tokio::test]
    async fn test_find_missing_atas_with_spl_tokens() {
        let allowed_spl_tokens = [Pubkey::new_unique(), Pubkey::new_unique()];

        let _m = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(
                        allowed_spl_tokens.iter().map(|p| p.to_string()).collect(),
                    ))
                    .build(),
            )
            .build_and_setup();

        let cache_ctx = CacheUtil::get_account_context();
        cache_ctx.checkpoint(); // Clear any previous expectations

        let payment_address = Pubkey::new_unique();
        let rpc_client = create_mock_rpc_client_account_not_found();

        // Mocked `CacheUtil::get_account` responses, consumed strictly in call order:
        // 1st call: ATA for the first mint exists (Ok) -> nothing queued.
        // 2nd call: ATA for the second mint is missing (Err).
        // 3rd call: the second mint's account is fetched (Ok) to pick its token program.
        let responses = Arc::new(Mutex::new(VecDeque::from([
            Ok(create_mock_token_account(&Pubkey::new_unique(), &Pubkey::new_unique())),
            Err(KoraError::RpcError("ATA not found".to_string())),
            Ok(create_mock_spl_mint_account(6)),
        ])));

        let responses_clone = responses.clone();
        cache_ctx
            .expect()
            .times(3)
            .returning(move |_, _, _, _| responses_clone.lock().unwrap().pop_front().unwrap());

        let config = get_config().unwrap();
        let result = find_missing_atas(&config, &rpc_client, &payment_address).await;

        assert!(result.is_ok(), "Should handle SPL tokens with proper mocking");
        let atas = result.unwrap();
        assert_eq!(atas.len(), 1, "Should return 1 missing ATAs");
    }

    // Exercises the full build/sign/send path for two ATAs. The send is expected to
    // be rejected by the mock RPC (see the comment on the assertion below).
    #[tokio::test]
    async fn test_create_atas_for_signer_calls_rpc_correctly() {
        let _m = ConfigMockBuilder::new().build_and_setup();

        // Presumably registers the signer that `get_request_signer_with_signer_key(None)`
        // returns below; the return value itself is not needed.
        let _ = setup_or_get_test_signer();

        let address = Pubkey::new_unique();
        let mint1 = Pubkey::new_unique();
        let mint2 = Pubkey::new_unique();

        let atas_to_create = vec![
            ATAToCreate {
                mint: mint1,
                ata: spl_associated_token_account_interface::address::get_associated_token_address(
                    &address, &mint1,
                ),
                token_program: spl_token_interface::id(),
            },
            ATAToCreate {
                mint: mint2,
                ata: spl_associated_token_account_interface::address::get_associated_token_address(
                    &address, &mint2,
                ),
                token_program: spl_token_interface::id(),
            },
        ];

        let rpc_client = RpcMockBuilder::new().with_blockhash().with_send_transaction().build();

        // Two ATAs with a chunk size of 2 -> a single transaction, carrying both the
        // compute-unit price and limit instructions.
        let result = create_atas_for_signer(
            &rpc_client,
            &get_request_signer_with_signer_key(None).unwrap(),
            &address,
            &atas_to_create,
            Some(1000),
            Some(100_000),
            2,
        )
        .await;

        // Should fail with signature validation error since mock signature doesn't match real transaction
        match result {
            Ok(_) => {
                panic!("Expected signature validation error, but got success");
            }
            Err(e) => {
                let error_msg = format!("{e:?}");
                // Check if it's a signature validation error (the mocked signature doesn't match the real transaction signature)
                assert!(
                    error_msg.contains("signature")
                        || error_msg.contains("Signature")
                        || error_msg.contains("invalid")
                        || error_msg.contains("mismatch"),
                    "Expected signature validation error, got: {error_msg}"
                );
            }
        }
    }

    // With `SplTokenConfig::All`, initialization succeeds against a bare RPC mock.
    // NOTE(review): presumably iterating `SplTokenConfig::All` yields no explicit mints,
    // so no ATA transactions are attempted — confirm against SplTokenConfig's iterator.
    #[tokio::test]
    async fn test_initialize_atas_when_all_tokens_are_allowed() {
        let _m = ConfigMockBuilder::new()
            .with_allowed_spl_paid_tokens(SplTokenConfig::All)
            .build_and_setup();

        let _ = setup_or_get_test_signer();

        let rpc_client = RpcMockBuilder::new().build();

        let result = initialize_atas(&rpc_client, None, None, None, None).await;

        assert!(result.is_ok(), "Expected atas init to succeed");
    }
}