1use crate::{
2 error::KoraError,
3 state::{get_request_signer_with_signer_key, get_signer_pool},
4 token::token::TokenType,
5 transaction::TransactionUtil,
6};
7use solana_client::nonblocking::rpc_client::RpcClient;
8use solana_compute_budget_interface::ComputeBudgetInstruction;
9use solana_keychain::SolanaSigner;
10use solana_message::{Message, VersionedMessage};
11use solana_sdk::{instruction::Instruction, pubkey::Pubkey};
12
13use spl_associated_token_account_interface::{
14 address::get_associated_token_address, instruction::create_associated_token_account,
15};
16use std::{fmt::Display, str::FromStr, sync::Arc};
17
18#[cfg(not(test))]
19use {crate::cache::CacheUtil, crate::state::get_config};
20
21#[cfg(test)]
22use {
23 crate::config::SplTokenConfig, crate::tests::cache_mock::MockCacheUtil as CacheUtil,
24 crate::tests::config_mock::mock_state::get_config,
25};
26
/// Default maximum number of ATA-creation instructions sent per transaction.
///
/// `initialize_atas` uses this when the caller passes no `chunk_size`.
// NOTE(review): presumably chosen to keep each transaction within Solana's transaction-size /
// compute limits — confirm before raising it.
const DEFAULT_CHUNK_SIZE: usize = 10;
32
33pub struct ATAToCreate {
34 pub mint: Pubkey,
35 pub ata: Pubkey,
36 pub token_program: Pubkey,
37}
38
39impl Display for ATAToCreate {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "Token {}: ATA {} (Token program: {})", self.mint, self.ata, self.token_program)
42 }
43}
44
45pub async fn initialize_atas(
52 rpc_client: &RpcClient,
53 compute_unit_price: Option<u64>,
54 compute_unit_limit: Option<u32>,
55 chunk_size: Option<usize>,
56 fee_payer_key: Option<String>,
57) -> Result<(), KoraError> {
58 let config = get_config()?;
59
60 let fee_payer = get_request_signer_with_signer_key(fee_payer_key.as_deref())?;
61
62 let addresses_to_initialize_atas = if let Some(payment_address) = &config.kora.payment_address {
63 vec![Pubkey::from_str(payment_address)
64 .map_err(|e| KoraError::InternalServerError(format!("Invalid payment address: {e}")))?]
65 } else {
66 get_signer_pool()?
67 .get_signers_info()
68 .iter()
69 .filter_map(|info| info.public_key.parse().ok())
70 .collect::<Vec<Pubkey>>()
71 };
72
73 initialize_atas_with_chunk_size(
74 rpc_client,
75 &fee_payer,
76 &addresses_to_initialize_atas,
77 compute_unit_price,
78 compute_unit_limit,
79 chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE),
80 )
81 .await
82}
83
84pub async fn initialize_atas_with_chunk_size(
87 rpc_client: &RpcClient,
88 fee_payer: &Arc<solana_keychain::Signer>,
89 addresses_to_initialize_atas: &Vec<Pubkey>,
90 compute_unit_price: Option<u64>,
91 compute_unit_limit: Option<u32>,
92 chunk_size: usize,
93) -> Result<(), KoraError> {
94 for address in addresses_to_initialize_atas {
95 println!("Initializing ATAs for address: {address}");
96
97 let atas_to_create = find_missing_atas(rpc_client, address).await?;
98
99 if atas_to_create.is_empty() {
100 println!("✓ All required ATAs already exist for address: {address}");
101 continue;
102 }
103
104 create_atas_for_signer(
105 rpc_client,
106 fee_payer,
107 address,
108 &atas_to_create,
109 compute_unit_price,
110 compute_unit_limit,
111 chunk_size,
112 )
113 .await?;
114 }
115
116 println!("✓ Successfully created all ATAs");
117
118 Ok(())
119}
120
121async fn create_atas_for_signer(
123 rpc_client: &RpcClient,
124 fee_payer: &Arc<solana_keychain::Signer>,
125 address: &Pubkey,
126 atas_to_create: &[ATAToCreate],
127 compute_unit_price: Option<u64>,
128 compute_unit_limit: Option<u32>,
129 chunk_size: usize,
130) -> Result<usize, KoraError> {
131 let instructions = atas_to_create
132 .iter()
133 .map(|ata| {
134 create_associated_token_account(
135 &fee_payer.pubkey(),
136 address,
137 &ata.mint,
138 &ata.token_program,
139 )
140 })
141 .collect::<Vec<Instruction>>();
142
143 let total_atas = instructions.len();
145 let chunks: Vec<_> = instructions.chunks(chunk_size).collect();
146 let num_chunks = chunks.len();
147
148 println!(
149 "Creating {total_atas} ATAs in {num_chunks} transaction(s) (chunk size: {chunk_size})..."
150 );
151
152 let mut created_atas_idx = 0;
153
154 for (chunk_idx, chunk) in chunks.iter().enumerate() {
155 let chunk_num = chunk_idx + 1;
156 println!("Processing chunk {chunk_num}/{num_chunks}");
157
158 let mut chunk_instructions = Vec::new();
160
161 if let Some(compute_unit_price) = compute_unit_price {
163 chunk_instructions
164 .push(ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price));
165 }
166 if let Some(compute_unit_limit) = compute_unit_limit {
167 chunk_instructions
168 .push(ComputeBudgetInstruction::set_compute_unit_limit(compute_unit_limit));
169 }
170
171 chunk_instructions.extend_from_slice(chunk);
173
174 let blockhash = rpc_client
175 .get_latest_blockhash()
176 .await
177 .map_err(|e| KoraError::RpcError(format!("Failed to get blockhash: {e}")))?;
178
179 let fee_payer_pubkey = fee_payer.pubkey();
180 let message = VersionedMessage::Legacy(Message::new_with_blockhash(
181 &chunk_instructions,
182 Some(&fee_payer_pubkey),
183 &blockhash,
184 ));
185
186 let mut tx = TransactionUtil::new_unsigned_versioned_transaction(message);
187 let message_bytes = tx.message.serialize();
188 let signature = fee_payer
189 .sign_message(&message_bytes)
190 .await
191 .map_err(|e| KoraError::SigningError(e.to_string()))?;
192
193 tx.signatures = vec![signature];
194
195 match rpc_client.send_and_confirm_transaction_with_spinner(&tx).await {
196 Ok(signature) => {
197 println!(
198 "✓ Chunk {chunk_num}/{num_chunks} successful. Transaction signature: {signature}"
199 );
200
201 let chunk_end = std::cmp::min(created_atas_idx + chunk.len(), atas_to_create.len());
203
204 (created_atas_idx..chunk_end).for_each(|i| {
205 let ATAToCreate { mint, ata, token_program } = &atas_to_create[i];
206 println!(" - Token {mint}: ATA {ata} (Token program: {token_program})");
207 });
208 created_atas_idx = chunk_end;
209 }
210 Err(e) => {
211 println!("✗ Chunk {chunk_num}/{num_chunks} failed: {e}");
212
213 if created_atas_idx > 0 {
214 println!("\nSuccessfully created ATAs ({created_atas_idx}/{total_atas}):");
215 println!(
216 "{}",
217 atas_to_create[0..created_atas_idx]
218 .iter()
219 .map(|ata| format!(" ✓ {ata}"))
220 .collect::<Vec<String>>()
221 .join("\n")
222 );
223 println!("\nRemaining ATAs to create: {}", total_atas - created_atas_idx);
224 } else {
225 println!("No ATAs were successfully created.");
226 }
227
228 println!("This may be a temporary network issue. Please re-run the command to retry ATA creation.");
229 return Err(KoraError::RpcError(format!(
230 "Failed to send ATA creation transaction for chunk {chunk_num}/{num_chunks}: {e}"
231 )));
232 }
233 }
234 }
235
236 println!("\n🎉 All ATA creation completed successfully!");
238 println!("Successfully created ATAs ({total_atas}/{total_atas}):");
239 println!(
240 "{}",
241 atas_to_create.iter().map(|ata| format!(" ✓ {ata}")).collect::<Vec<String>>().join("\n")
242 );
243
244 Ok(total_atas)
245}
246
247pub async fn find_missing_atas(
248 rpc_client: &RpcClient,
249 payment_address: &Pubkey,
250) -> Result<Vec<ATAToCreate>, KoraError> {
251 let config = get_config()?;
252
253 let mut token_mints = Vec::new();
255 for token_str in &config.validation.allowed_spl_paid_tokens {
256 match Pubkey::from_str(token_str) {
257 Ok(mint) => token_mints.push(mint),
258 Err(_) => {
259 println!("⚠️ Skipping invalid token mint: {token_str}");
260 continue;
261 }
262 }
263 }
264
265 if token_mints.is_empty() {
266 println!("✓ No SPL payment tokens configured");
267 return Ok(Vec::new());
268 }
269
270 let mut atas_to_create = Vec::new();
271
272 for mint in &token_mints {
274 let ata = get_associated_token_address(payment_address, mint);
275
276 match CacheUtil::get_account(rpc_client, &ata, false).await {
277 Ok(_) => {
278 println!("✓ ATA already exists for token {mint}: {ata}");
279 }
280 Err(_) => {
281 let mint_account =
283 CacheUtil::get_account(rpc_client, mint, false).await.map_err(|e| {
284 KoraError::RpcError(format!("Failed to fetch mint account for {mint}: {e}"))
285 })?;
286
287 let token_program = TokenType::get_token_program_from_owner(&mint_account.owner)?;
288
289 println!("Creating ATA for token {mint}: {ata}");
290
291 atas_to_create.push(ATAToCreate {
292 mint: *mint,
293 ata,
294 token_program: token_program.program_id(),
295 });
296 }
297 }
298 }
299
300 Ok(atas_to_create)
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{
        common::{
            create_mock_rpc_client_account_not_found, create_mock_spl_mint_account,
            create_mock_token_account, setup_or_get_test_signer, RpcMockBuilder,
        },
        config_mock::{ConfigMockBuilder, ValidationConfigBuilder},
    };
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    /// With an empty SPL token allowlist, `find_missing_atas` returns no ATAs.
    #[tokio::test]
    async fn test_find_missing_atas_no_spl_tokens() {
        // Presumably a guard that keeps the mocked config active. It is bound to `_m` (not `_`)
        // so it is not dropped immediately.
        let _m = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(vec![]))
                    .build(),
            )
            .build_and_setup();

        let rpc_client = create_mock_rpc_client_account_not_found();
        let payment_address = Pubkey::new_unique();

        let result = find_missing_atas(&rpc_client, &payment_address).await.unwrap();

        assert!(result.is_empty(), "Should return empty vec when no SPL tokens configured");
    }

    /// With two allowed tokens, where the first ATA exists and the second does not, exactly
    /// one ATA is reported as missing.
    #[tokio::test]
    async fn test_find_missing_atas_with_spl_tokens() {
        let allowed_spl_tokens = [Pubkey::new_unique(), Pubkey::new_unique()];

        let _m = ConfigMockBuilder::new()
            .with_validation(
                ValidationConfigBuilder::new()
                    .with_allowed_spl_paid_tokens(SplTokenConfig::Allowlist(
                        allowed_spl_tokens.iter().map(|p| p.to_string()).collect(),
                    ))
                    .build(),
            )
            .build_and_setup();

        // Shared static-mock context for `CacheUtil::get_account`. `checkpoint` presumably
        // clears expectations left over from other tests before new ones are set.
        let cache_ctx = CacheUtil::get_account_context();
        cache_ctx.checkpoint();
        let payment_address = Pubkey::new_unique();
        let rpc_client = create_mock_rpc_client_account_not_found();

        // Responses are consumed in `find_missing_atas` call order:
        // 1. first token's ATA lookup -> exists;
        // 2. second token's ATA lookup -> error, i.e. treated as missing;
        // 3. second token's mint lookup -> an SPL mint (determines the token program).
        let responses = Arc::new(Mutex::new(VecDeque::from([
            Ok(create_mock_token_account(&Pubkey::new_unique(), &Pubkey::new_unique())),
            Err(KoraError::RpcError("ATA not found".to_string())),
            Ok(create_mock_spl_mint_account(6)),
        ])));

        let responses_clone = responses.clone();
        cache_ctx
            .expect()
            .times(3)
            .returning(move |_, _, _| responses_clone.lock().unwrap().pop_front().unwrap());

        let result = find_missing_atas(&rpc_client, &payment_address).await;

        assert!(result.is_ok(), "Should handle SPL tokens with proper mocking");
        let atas = result.unwrap();
        assert_eq!(atas.len(), 1, "Should return 1 missing ATAs");
    }

    /// Exercises the full blockhash -> sign -> send path of `create_atas_for_signer` with
    /// two ATAs in one chunk of size 2.
    #[tokio::test]
    async fn test_create_atas_for_signer_calls_rpc_correctly() {
        let _m = ConfigMockBuilder::new().build_and_setup();

        let _ = setup_or_get_test_signer();

        let address = Pubkey::new_unique();
        let mint1 = Pubkey::new_unique();
        let mint2 = Pubkey::new_unique();

        let atas_to_create = vec![
            ATAToCreate {
                mint: mint1,
                ata: spl_associated_token_account_interface::address::get_associated_token_address(
                    &address, &mint1,
                ),
                token_program: spl_token_interface::id(),
            },
            ATAToCreate {
                mint: mint2,
                ata: spl_associated_token_account_interface::address::get_associated_token_address(
                    &address, &mint2,
                ),
                token_program: spl_token_interface::id(),
            },
        ];

        let rpc_client = RpcMockBuilder::new().with_blockhash().with_send_transaction().build();

        let result = create_atas_for_signer(
            &rpc_client,
            &get_request_signer_with_signer_key(None).unwrap(),
            &address,
            &atas_to_create,
            Some(1000),
            Some(100_000),
            2,
        )
        .await;

        // NOTE(review): the mocked send path is expected to reject the transaction with a
        // signature-related error, so reaching the send proves the earlier steps succeeded.
        // The exact error origin lives in the RPC mock — confirm there.
        match result {
            Ok(_) => {
                panic!("Expected signature validation error, but got success");
            }
            Err(e) => {
                let error_msg = format!("{e:?}");
                assert!(
                    error_msg.contains("signature")
                        || error_msg.contains("Signature")
                        || error_msg.contains("invalid")
                        || error_msg.contains("mismatch"),
                    "Expected signature validation error, got: {error_msg}"
                );
            }
        }
    }

    /// With `SplTokenConfig::All`, `initialize_atas` succeeds against a bare RPC mock.
    // Presumably iterating `SplTokenConfig::All` yields no explicit mints, so no ATA lookups
    // or transactions are needed — confirm against `SplTokenConfig`'s iterator.
    #[tokio::test]
    async fn test_initialize_atas_when_all_tokens_are_allowed() {
        let _m = ConfigMockBuilder::new()
            .with_allowed_spl_paid_tokens(SplTokenConfig::All)
            .build_and_setup();

        let _ = setup_or_get_test_signer();

        let rpc_client = RpcMockBuilder::new().build();

        let result = initialize_atas(&rpc_client, None, None, None, None).await;

        assert!(result.is_ok(), "Expected atas init to succeed");
    }
}