1use crate::{
2 config::Config,
3 error::KoraError,
4 state::{get_request_signer_with_signer_key, get_signer_pool},
5 token::token::TokenType,
6 transaction::TransactionUtil,
7};
8use solana_client::nonblocking::rpc_client::RpcClient;
9use solana_compute_budget_interface::ComputeBudgetInstruction;
10use solana_keychain::SolanaSigner;
11use solana_message::{Message, VersionedMessage};
12use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
13
14use spl_associated_token_account_interface::{
15 address::get_associated_token_address, instruction::create_associated_token_account,
16};
17use std::{fmt::Display, str::FromStr, sync::Arc};
18
19#[cfg(not(test))]
20use {crate::cache::CacheUtil, crate::state::get_config};
21
22#[cfg(test)]
23use {
24 crate::config::SplTokenConfig, crate::tests::cache_mock::MockCacheUtil as CacheUtil,
25 crate::tests::config_mock::mock_state::get_config,
26};
27
/// Number of ATA-creation instructions packed into one transaction when the
/// caller of `initialize_atas` does not pass an explicit chunk size.
const DEFAULT_CHUNK_SIZE: usize = 10;
33
34pub struct ATAToCreate {
35 pub mint: Pubkey,
36 pub ata: Pubkey,
37 pub token_program: Pubkey,
38}
39
40impl Display for ATAToCreate {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 write!(f, "Token {}: ATA {} (Token program: {})", self.mint, self.ata, self.token_program)
43 }
44}
45
46pub async fn initialize_atas(
53 rpc_client: &RpcClient,
54 compute_unit_price: Option<u64>,
55 compute_unit_limit: Option<u32>,
56 chunk_size: Option<usize>,
57 fee_payer_key: Option<String>,
58) -> Result<(), KoraError> {
59 let config = get_config()?;
60
61 let fee_payer = get_request_signer_with_signer_key(fee_payer_key.as_deref())?;
62
63 let addresses_to_initialize_atas = if let Some(payment_address) = &config.kora.payment_address {
64 vec![Pubkey::from_str(payment_address)
65 .map_err(|e| KoraError::InternalServerError(format!("Invalid payment address: {e}")))?]
66 } else {
67 get_signer_pool()?
68 .get_signers_info()
69 .iter()
70 .filter_map(|info| info.public_key.parse().ok())
71 .collect::<Vec<Pubkey>>()
72 };
73
74 initialize_atas_with_chunk_size(
75 rpc_client,
76 &fee_payer,
77 &addresses_to_initialize_atas,
78 compute_unit_price,
79 compute_unit_limit,
80 chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE),
81 )
82 .await
83}
84
85pub async fn initialize_atas_with_chunk_size(
88 rpc_client: &RpcClient,
89 fee_payer: &Arc<solana_keychain::Signer>,
90 addresses_to_initialize_atas: &Vec<Pubkey>,
91 compute_unit_price: Option<u64>,
92 compute_unit_limit: Option<u32>,
93 chunk_size: usize,
94) -> Result<(), KoraError> {
95 let config = get_config()?;
96
97 for address in addresses_to_initialize_atas {
98 println!("Initializing ATAs for address: {address}");
99
100 #[allow(clippy::needless_borrow)]
101 let atas_to_create = find_missing_atas(&config, rpc_client, address).await?;
102
103 if atas_to_create.is_empty() {
104 println!("✓ All required ATAs already exist for address: {address}");
105 continue;
106 }
107
108 create_atas_for_signer(
109 rpc_client,
110 fee_payer,
111 address,
112 &atas_to_create,
113 compute_unit_price,
114 compute_unit_limit,
115 chunk_size,
116 )
117 .await?;
118 }
119
120 println!("✓ Successfully created all ATAs");
121
122 Ok(())
123}
124
125async fn create_atas_for_signer(
127 rpc_client: &RpcClient,
128 fee_payer: &Arc<solana_keychain::Signer>,
129 address: &Pubkey,
130 atas_to_create: &[ATAToCreate],
131 compute_unit_price: Option<u64>,
132 compute_unit_limit: Option<u32>,
133 chunk_size: usize,
134) -> Result<usize, KoraError> {
135 let instructions = atas_to_create
136 .iter()
137 .map(|ata| {
138 create_associated_token_account(
139 &fee_payer.pubkey(),
140 address,
141 &ata.mint,
142 &ata.token_program,
143 )
144 })
145 .collect::<Vec<Instruction>>();
146
147 let total_atas = instructions.len();
149 let chunks: Vec<_> = instructions.chunks(chunk_size).collect();
150 let num_chunks = chunks.len();
151
152 println!(
153 "Creating {total_atas} ATAs in {num_chunks} transaction(s) (chunk size: {chunk_size})..."
154 );
155
156 let mut created_atas_idx = 0;
157
158 for (chunk_idx, chunk) in chunks.iter().enumerate() {
159 let chunk_num = chunk_idx + 1;
160 println!("Processing chunk {chunk_num}/{num_chunks}");
161
162 let mut chunk_instructions = Vec::new();
164
165 if let Some(compute_unit_price) = compute_unit_price {
167 chunk_instructions
168 .push(ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price));
169 }
170 if let Some(compute_unit_limit) = compute_unit_limit {
171 chunk_instructions
172 .push(ComputeBudgetInstruction::set_compute_unit_limit(compute_unit_limit));
173 }
174
175 chunk_instructions.extend_from_slice(chunk);
177
178 let blockhash = rpc_client
179 .get_latest_blockhash()
180 .await
181 .map_err(|e| KoraError::RpcError(format!("Failed to get blockhash: {e}")))?;
182
183 let fee_payer_pubkey = fee_payer.pubkey();
184 let message = VersionedMessage::Legacy(Message::new_with_blockhash(
185 &chunk_instructions,
186 Some(&fee_payer_pubkey),
187 &blockhash,
188 ));
189
190 let mut tx = TransactionUtil::new_unsigned_versioned_transaction(message);
191 let message_bytes = tx.message.serialize();
192 let signature = fee_payer
193 .sign_message(&message_bytes)
194 .await
195 .map_err(|e| KoraError::SigningError(e.to_string()))?;
196
197 tx.signatures = vec![signature];
198
199 match rpc_client.send_and_confirm_transaction_with_spinner(&tx).await {
200 Ok(signature) => {
201 println!(
202 "✓ Chunk {chunk_num}/{num_chunks} successful. Transaction signature: {signature}"
203 );
204
205 let chunk_end = std::cmp::min(created_atas_idx + chunk.len(), atas_to_create.len());
207
208 (created_atas_idx..chunk_end).for_each(|i| {
209 let ATAToCreate { mint, ata, token_program } = &atas_to_create[i];
210 println!(" - Token {mint}: ATA {ata} (Token program: {token_program})");
211 });
212 created_atas_idx = chunk_end;
213 }
214 Err(e) => {
215 println!("✗ Chunk {chunk_num}/{num_chunks} failed: {e}");
216
217 if created_atas_idx > 0 {
218 println!("\nSuccessfully created ATAs ({created_atas_idx}/{total_atas}):");
219 println!(
220 "{}",
221 atas_to_create[0..created_atas_idx]
222 .iter()
223 .map(|ata| format!(" ✓ {ata}"))
224 .collect::<Vec<String>>()
225 .join("\n")
226 );
227 println!("\nRemaining ATAs to create: {}", total_atas - created_atas_idx);
228 } else {
229 println!("No ATAs were successfully created.");
230 }
231
232 println!("This may be a temporary network issue. Please re-run the command to retry ATA creation.");
233 return Err(KoraError::RpcError(format!(
234 "Failed to send ATA creation transaction for chunk {chunk_num}/{num_chunks}: {e}"
235 )));
236 }
237 }
238 }
239
240 println!("\n🎉 All ATA creation completed successfully!");
242 println!("Successfully created ATAs ({total_atas}/{total_atas}):");
243 println!(
244 "{}",
245 atas_to_create.iter().map(|ata| format!(" ✓ {ata}")).collect::<Vec<String>>().join("\n")
246 );
247
248 Ok(total_atas)
249}
250
251pub async fn find_missing_atas(
252 config: &Config,
253 rpc_client: &RpcClient,
254 payment_address: &Pubkey,
255) -> Result<Vec<ATAToCreate>, KoraError> {
256 let mut token_mints = Vec::new();
258 for token_str in &config.validation.allowed_spl_paid_tokens {
259 match Pubkey::from_str(token_str) {
260 Ok(mint) => token_mints.push(mint),
261 Err(_) => {
262 println!("⚠️ Skipping invalid token mint: {token_str}");
263 continue;
264 }
265 }
266 }
267
268 if token_mints.is_empty() {
269 println!("✓ No SPL payment tokens configured");
270 return Ok(Vec::new());
271 }
272
273 let mut atas_to_create = Vec::new();
274
275 for mint in &token_mints {
277 let ata = get_associated_token_address(payment_address, mint);
278
279 match CacheUtil::get_account(config, rpc_client, &ata, false).await {
280 Ok(_) => {
281 println!("✓ ATA already exists for token {mint}: {ata}");
282 }
283 Err(_) => {
284 let mint_account =
286 CacheUtil::get_account(config, rpc_client, mint, false).await.map_err(|e| {
287 KoraError::RpcError(format!("Failed to fetch mint account for {mint}: {e}"))
288 })?;
289
290 let token_program = TokenType::get_token_program_from_owner(&mint_account.owner)?;
291
292 println!("Creating ATA for token {mint}: {ata}");
293
294 atas_to_create.push(ATAToCreate {
295 mint: *mint,
296 ata,
297 token_program: token_program.program_id(),
298 });
299 }
300 }
301 }
302
303 Ok(atas_to_create)
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::{
            create_mock_rpc_client_account_not_found, create_mock_spl_mint_account,
            create_mock_token_account, setup_or_get_test_signer, RpcMockBuilder,
        },
        config_mock::{ConfigMockBuilder, ValidationConfigBuilder},
    };
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    /// Builds an `ATAToCreate` for `owner`/`mint` under the SPL Token program.
    fn spl_token_ata(owner: &Pubkey, mint: Pubkey) -> ATAToCreate {
        ATAToCreate {
            mint,
            ata: spl_associated_token_account_interface::address::get_associated_token_address(
                owner, &mint,
            ),
            token_program: spl_token_interface::id(),
        }
    }

    #[tokio::test]
    async fn test_find_missing_atas_no_spl_tokens() {
        // An empty allowlist leaves nothing to look up.
        let _config_guard = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(vec![]))
                    .build(),
            )
            .build_and_setup();

        let rpc_client = create_mock_rpc_client_account_not_found();
        let owner = Pubkey::new_unique();
        let config = get_config().unwrap();

        let missing = find_missing_atas(&config, &rpc_client, &owner).await.unwrap();

        assert!(missing.is_empty(), "Should return empty vec when no SPL tokens configured");
    }

    #[tokio::test]
    async fn test_find_missing_atas_with_spl_tokens() {
        let mints = [Pubkey::new_unique(), Pubkey::new_unique()];

        let _config_guard = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(
                        mints.iter().map(|mint| mint.to_string()).collect(),
                    ))
                    .build(),
            )
            .build_and_setup();

        let cache_ctx = CacheUtil::get_account_context();
        cache_ctx.checkpoint();
        let owner = Pubkey::new_unique();
        let rpc_client = create_mock_rpc_client_account_not_found();

        // Lookups happen in this order: the ATA for mint #1 (exists), the ATA for
        // mint #2 (missing), then mint #2's own account to resolve its token program.
        let queued = Arc::new(Mutex::new(VecDeque::from([
            Ok(create_mock_token_account(&Pubkey::new_unique(), &Pubkey::new_unique())),
            Err(KoraError::RpcError("ATA not found".to_string())),
            Ok(create_mock_spl_mint_account(6)),
        ])));

        let queued_for_mock = Arc::clone(&queued);
        cache_ctx
            .expect()
            .times(3)
            .returning(move |_, _, _, _| queued_for_mock.lock().unwrap().pop_front().unwrap());

        let config = get_config().unwrap();
        let result = find_missing_atas(&config, &rpc_client, &owner).await;

        assert!(result.is_ok(), "Should handle SPL tokens with proper mocking");
        let missing = result.unwrap();
        assert_eq!(missing.len(), 1, "Should return 1 missing ATAs");
    }

    #[tokio::test]
    async fn test_create_atas_for_signer_calls_rpc_correctly() {
        let _config_guard = ConfigMockBuilder::new().build_and_setup();

        let _ = setup_or_get_test_signer();

        let owner = Pubkey::new_unique();
        let atas_to_create = [
            spl_token_ata(&owner, Pubkey::new_unique()),
            spl_token_ata(&owner, Pubkey::new_unique()),
        ];

        let rpc_client = RpcMockBuilder::new().with_blockhash().with_send_transaction().build();

        let result = create_atas_for_signer(
            &rpc_client,
            &get_request_signer_with_signer_key(None).unwrap(),
            &owner,
            &atas_to_create,
            Some(1000),
            Some(100_000),
            2,
        )
        .await;

        // The test expects the mocked RPC to reject the transaction with a
        // signature-related error.
        let Err(err) = result else {
            panic!("Expected signature validation error, but got success");
        };
        let error_msg = format!("{err:?}");
        let mentions_signature_problem = ["signature", "Signature", "invalid", "mismatch"]
            .iter()
            .any(|needle| error_msg.contains(*needle));
        assert!(
            mentions_signature_problem,
            "Expected signature validation error, got: {error_msg}"
        );
    }

    #[tokio::test]
    async fn test_initialize_atas_when_all_tokens_are_allowed() {
        let _config_guard = ConfigMockBuilder::new()
            .with_allowed_spl_paid_tokens(SplTokenConfig::All)
            .build_and_setup();

        let _ = setup_or_get_test_signer();

        let rpc_client = RpcMockBuilder::new().build();

        let outcome = initialize_atas(&rpc_client, None, None, None, None).await;

        assert!(outcome.is_ok(), "Expected atas init to succeed");
    }
}
456}