kora_lib/admin/
token_util.rs

1use crate::{
2    error::KoraError,
3    state::{get_request_signer_with_signer_key, get_signer_pool},
4    token::token::TokenType,
5    transaction::TransactionUtil,
6};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_compute_budget_interface::ComputeBudgetInstruction;
9use solana_keychain::SolanaSigner;
10use solana_message::{Message, VersionedMessage};
11use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
12
13use spl_associated_token_account_interface::{
14    address::get_associated_token_address, instruction::create_associated_token_account,
15};
16use std::{fmt::Display, str::FromStr, sync::Arc};
17
18#[cfg(not(test))]
19use {crate::cache::CacheUtil, crate::state::get_config};
20
21#[cfg(test)]
22use {
23    crate::config::SplTokenConfig, crate::tests::cache_mock::MockCacheUtil as CacheUtil,
24    crate::tests::config_mock::mock_state::get_config,
25};
26
/*
This function is tested via the makefile, as it's a CLI command and requires a running validator.
*/
30
31const DEFAULT_CHUNK_SIZE: usize = 10;
32
33pub struct ATAToCreate {
34    pub mint: Pubkey,
35    pub ata: Pubkey,
36    pub token_program: Pubkey,
37}
38
39impl Display for ATAToCreate {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "Token {}: ATA {} (Token program: {})", self.mint, self.ata, self.token_program)
42    }
43}
44
45/// Initialize ATAs for all allowed payment tokens for the paymaster
46/// This function initializes ATAs for ALL signers in the pool
47///
48/// Order of priority is:
49/// 1. Payment address provided in config
50/// 2. All signers in pool
51pub async fn initialize_atas(
52    rpc_client: &RpcClient,
53    compute_unit_price: Option<u64>,
54    compute_unit_limit: Option<u32>,
55    chunk_size: Option<usize>,
56    fee_payer_key: Option<String>,
57) -> Result<(), KoraError> {
58    let config = get_config()?;
59
60    let fee_payer = get_request_signer_with_signer_key(fee_payer_key.as_deref())?;
61
62    let addresses_to_initialize_atas = if let Some(payment_address) = &config.kora.payment_address {
63        vec![Pubkey::from_str(payment_address)
64            .map_err(|e| KoraError::InternalServerError(format!("Invalid payment address: {e}")))?]
65    } else {
66        get_signer_pool()?
67            .get_signers_info()
68            .iter()
69            .filter_map(|info| info.public_key.parse().ok())
70            .collect::<Vec<Pubkey>>()
71    };
72
73    initialize_atas_with_chunk_size(
74        rpc_client,
75        &fee_payer,
76        &addresses_to_initialize_atas,
77        compute_unit_price,
78        compute_unit_limit,
79        chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE),
80    )
81    .await
82}
83
84/// Initialize ATAs for all allowed payment tokens for the provided addresses with configurable chunk size
85/// This function does not use cache and directly checks on-chain
86pub async fn initialize_atas_with_chunk_size(
87    rpc_client: &RpcClient,
88    fee_payer: &Arc<solana_keychain::Signer>,
89    addresses_to_initialize_atas: &Vec<Pubkey>,
90    compute_unit_price: Option<u64>,
91    compute_unit_limit: Option<u32>,
92    chunk_size: usize,
93) -> Result<(), KoraError> {
94    for address in addresses_to_initialize_atas {
95        println!("Initializing ATAs for address: {address}");
96
97        let atas_to_create = find_missing_atas(rpc_client, address).await?;
98
99        if atas_to_create.is_empty() {
100            println!("✓ All required ATAs already exist for address: {address}");
101            continue;
102        }
103
104        create_atas_for_signer(
105            rpc_client,
106            fee_payer,
107            address,
108            &atas_to_create,
109            compute_unit_price,
110            compute_unit_limit,
111            chunk_size,
112        )
113        .await?;
114    }
115
116    println!("✓ Successfully created all ATAs");
117
118    Ok(())
119}
120
121/// Helper function to create ATAs for a single signer
122async fn create_atas_for_signer(
123    rpc_client: &RpcClient,
124    fee_payer: &Arc<solana_keychain::Signer>,
125    address: &Pubkey,
126    atas_to_create: &[ATAToCreate],
127    compute_unit_price: Option<u64>,
128    compute_unit_limit: Option<u32>,
129    chunk_size: usize,
130) -> Result<usize, KoraError> {
131    let instructions = atas_to_create
132        .iter()
133        .map(|ata| {
134            create_associated_token_account(
135                &fee_payer.pubkey(),
136                address,
137                &ata.mint,
138                &ata.token_program,
139            )
140        })
141        .collect::<Vec<Instruction>>();
142
143    // Process instructions in chunks
144    let total_atas = instructions.len();
145    let chunks: Vec<_> = instructions.chunks(chunk_size).collect();
146    let num_chunks = chunks.len();
147
148    println!(
149        "Creating {total_atas} ATAs in {num_chunks} transaction(s) (chunk size: {chunk_size})..."
150    );
151
152    let mut created_atas_idx = 0;
153
154    for (chunk_idx, chunk) in chunks.iter().enumerate() {
155        let chunk_num = chunk_idx + 1;
156        println!("Processing chunk {chunk_num}/{num_chunks}");
157
158        // Build instructions for this chunk with compute budget
159        let mut chunk_instructions = Vec::new();
160
161        // Add compute budget instructions to each chunk
162        if let Some(compute_unit_price) = compute_unit_price {
163            chunk_instructions
164                .push(ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price));
165        }
166        if let Some(compute_unit_limit) = compute_unit_limit {
167            chunk_instructions
168                .push(ComputeBudgetInstruction::set_compute_unit_limit(compute_unit_limit));
169        }
170
171        // Add the ATA creation instructions for this chunk
172        chunk_instructions.extend_from_slice(chunk);
173
174        let blockhash = rpc_client
175            .get_latest_blockhash()
176            .await
177            .map_err(|e| KoraError::RpcError(format!("Failed to get blockhash: {e}")))?;
178
179        let fee_payer_pubkey = fee_payer.pubkey();
180        let message = VersionedMessage::Legacy(Message::new_with_blockhash(
181            &chunk_instructions,
182            Some(&fee_payer_pubkey),
183            &blockhash,
184        ));
185
186        let mut tx = TransactionUtil::new_unsigned_versioned_transaction(message);
187        let message_bytes = tx.message.serialize();
188        let signature = fee_payer
189            .sign_message(&message_bytes)
190            .await
191            .map_err(|e| KoraError::SigningError(e.to_string()))?;
192
193        tx.signatures = vec![signature];
194
195        match rpc_client.send_and_confirm_transaction_with_spinner(&tx).await {
196            Ok(signature) => {
197                println!(
198                    "✓ Chunk {chunk_num}/{num_chunks} successful. Transaction signature: {signature}"
199                );
200
201                // Print the ATAs created in this chunk
202                let chunk_end = std::cmp::min(created_atas_idx + chunk.len(), atas_to_create.len());
203
204                (created_atas_idx..chunk_end).for_each(|i| {
205                    let ATAToCreate { mint, ata, token_program } = &atas_to_create[i];
206                    println!("  - Token {mint}: ATA {ata} (Token program: {token_program})");
207                });
208                created_atas_idx = chunk_end;
209            }
210            Err(e) => {
211                println!("✗ Chunk {chunk_num}/{num_chunks} failed: {e}");
212
213                if created_atas_idx > 0 {
214                    println!("\nSuccessfully created ATAs ({created_atas_idx}/{total_atas}):");
215                    println!(
216                        "{}",
217                        atas_to_create[0..created_atas_idx]
218                            .iter()
219                            .map(|ata| format!("  ✓ {ata}"))
220                            .collect::<Vec<String>>()
221                            .join("\n")
222                    );
223                    println!("\nRemaining ATAs to create: {}", total_atas - created_atas_idx);
224                } else {
225                    println!("No ATAs were successfully created.");
226                }
227
228                println!("This may be a temporary network issue. Please re-run the command to retry ATA creation.");
229                return Err(KoraError::RpcError(format!(
230                    "Failed to send ATA creation transaction for chunk {chunk_num}/{num_chunks}: {e}"
231                )));
232            }
233        }
234    }
235
236    // Show summary of all successfully created ATAs
237    println!("\n🎉 All ATA creation completed successfully!");
238    println!("Successfully created ATAs ({total_atas}/{total_atas}):");
239    println!(
240        "{}",
241        atas_to_create.iter().map(|ata| format!("  ✓ {ata}")).collect::<Vec<String>>().join("\n")
242    );
243
244    Ok(total_atas)
245}
246
247pub async fn find_missing_atas(
248    rpc_client: &RpcClient,
249    payment_address: &Pubkey,
250) -> Result<Vec<ATAToCreate>, KoraError> {
251    let config = get_config()?;
252
253    // Parse all allowed SPL paid token mints
254    let mut token_mints = Vec::new();
255    for token_str in &config.validation.allowed_spl_paid_tokens {
256        match Pubkey::from_str(token_str) {
257            Ok(mint) => token_mints.push(mint),
258            Err(_) => {
259                println!("⚠️  Skipping invalid token mint: {token_str}");
260                continue;
261            }
262        }
263    }
264
265    if token_mints.is_empty() {
266        println!("✓ No SPL payment tokens configured");
267        return Ok(Vec::new());
268    }
269
270    let mut atas_to_create = Vec::new();
271
272    // Check each token mint for existing ATA
273    for mint in &token_mints {
274        let ata = get_associated_token_address(payment_address, mint);
275
276        match CacheUtil::get_account(rpc_client, &ata, false).await {
277            Ok(_) => {
278                println!("✓ ATA already exists for token {mint}: {ata}");
279            }
280            Err(_) => {
281                // Fetch mint account to determine if it's SPL or Token2022
282                let mint_account =
283                    CacheUtil::get_account(rpc_client, mint, false).await.map_err(|e| {
284                        KoraError::RpcError(format!("Failed to fetch mint account for {mint}: {e}"))
285                    })?;
286
287                let token_program = TokenType::get_token_program_from_owner(&mint_account.owner)?;
288
289                println!("Creating ATA for token {mint}: {ata}");
290
291                atas_to_create.push(ATAToCreate {
292                    mint: *mint,
293                    ata,
294                    token_program: token_program.program_id(),
295                });
296            }
297        }
298    }
299
300    Ok(atas_to_create)
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::{
            create_mock_rpc_client_account_not_found, create_mock_spl_mint_account,
            create_mock_token_account, setup_or_get_test_signer, RpcMockBuilder,
        },
        config_mock::{ConfigMockBuilder, ValidationConfigBuilder},
    };
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    /// An empty allowlist leaves nothing to check, so no ATA is reported missing.
    #[tokio::test]
    async fn test_find_missing_atas_no_spl_tokens() {
        let _config_guard = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(vec![]))
                    .build(),
            )
            .build_and_setup();

        let rpc_client = create_mock_rpc_client_account_not_found();
        let owner = Pubkey::new_unique();

        let missing = find_missing_atas(&rpc_client, &owner).await.unwrap();

        assert!(missing.is_empty(), "Should return empty vec when no SPL tokens configured");
    }

    /// Two allowed mints: the first ATA exists, the second is missing and its mint is an SPL
    /// mint, so exactly one ATA is reported.
    #[tokio::test]
    async fn test_find_missing_atas_with_spl_tokens() {
        let allowed_spl_tokens = [Pubkey::new_unique(), Pubkey::new_unique()];

        let _config_guard = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(
                        allowed_spl_tokens.iter().map(|p| p.to_string()).collect(),
                    ))
                    .build(),
            )
            .build_and_setup();

        let cache_ctx = CacheUtil::get_account_context();
        cache_ctx.checkpoint(); // Drop expectations left over from other tests

        let payment_address = Pubkey::new_unique();
        let rpc_client = create_mock_rpc_client_account_not_found();

        // Scripted `get_account` results, consumed in call order:
        // 1. the first mint's ATA lookup succeeds (Ok)
        // 2. the second mint's ATA lookup fails (Err)
        // 3. the second mint's account is an SPL mint (Ok)
        let scripted = Arc::new(Mutex::new(VecDeque::from([
            Ok(create_mock_token_account(&Pubkey::new_unique(), &Pubkey::new_unique())),
            Err(KoraError::RpcError("ATA not found".to_string())),
            Ok(create_mock_spl_mint_account(6)),
        ])));

        let scripted_for_mock = Arc::clone(&scripted);
        cache_ctx
            .expect()
            .times(3)
            .returning(move |_, _, _| scripted_for_mock.lock().unwrap().pop_front().unwrap());

        let outcome = find_missing_atas(&rpc_client, &payment_address).await;

        assert!(outcome.is_ok(), "Should handle SPL tokens with proper mocking");
        let missing = outcome.unwrap();
        assert_eq!(missing.len(), 1, "Should return 1 missing ATAs");
    }

    /// Runs `create_atas_for_signer` against a mocked RPC. The test asserts that the call fails
    /// with a signature-related error, which it attributes to the mocked signature not matching
    /// the real transaction.
    #[tokio::test]
    async fn test_create_atas_for_signer_calls_rpc_correctly() {
        let _config_guard = ConfigMockBuilder::new().build_and_setup();

        let _ = setup_or_get_test_signer();

        let address = Pubkey::new_unique();
        let mints = [Pubkey::new_unique(), Pubkey::new_unique()];

        let atas_to_create: Vec<ATAToCreate> = mints
            .iter()
            .map(|mint| ATAToCreate {
                mint: *mint,
                ata: spl_associated_token_account_interface::address::get_associated_token_address(
                    &address, mint,
                ),
                token_program: spl_token_interface::id(),
            })
            .collect();

        let rpc_client = RpcMockBuilder::new().with_blockhash().with_send_transaction().build();
        let fee_payer = get_request_signer_with_signer_key(None).unwrap();

        let result = create_atas_for_signer(
            &rpc_client,
            &fee_payer,
            &address,
            &atas_to_create,
            Some(1000),
            Some(100_000),
            2,
        )
        .await;

        let Err(err) = result else {
            panic!("Expected signature validation error, but got success");
        };

        let error_msg = format!("{err:?}");
        let mentions_signature_problem = ["signature", "Signature", "invalid", "mismatch"]
            .into_iter()
            .any(|needle| error_msg.contains(needle));
        assert!(
            mentions_signature_problem,
            "Expected signature validation error, got: {error_msg}"
        );
    }

    /// With every SPL token allowed, ATA initialization is expected to succeed against a bare
    /// RPC mock.
    #[tokio::test]
    async fn test_initialize_atas_when_all_tokens_are_allowed() {
        let _config_guard = ConfigMockBuilder::new()
            .with_allowed_spl_paid_tokens(SplTokenConfig::All)
            .build_and_setup();

        let _ = setup_or_get_test_signer();

        let rpc_client = RpcMockBuilder::new().build();

        let outcome = initialize_atas(&rpc_client, None, None, None, None).await;

        assert!(outcome.is_ok(), "Expected atas init to succeed");
    }
}