1use anchor_lang::prelude::*;
2
3#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug)]
4pub struct Message {
5 pub role: String,
6 pub content: String,
7}
8
9pub struct CoolRouterCPI<'info> {
10 pub request_pda: AccountInfo<'info>,
11 pub authority: AccountInfo<'info>,
12 pub caller_program: AccountInfo<'info>,
13 pub system_program: AccountInfo<'info>,
14 pub coolrouter_program: Pubkey,
15 pub callback_accounts: Vec<AccountInfo<'info>>,
16}
17
18impl<'info> CoolRouterCPI<'info> {
19 pub fn new(
20 request_pda: AccountInfo<'info>,
21 authority: AccountInfo<'info>,
22 caller_program: AccountInfo<'info>,
23 system_program: AccountInfo<'info>,
24 coolrouter_program: Pubkey,
25 ) -> Self {
26 Self {
27 request_pda,
28 authority,
29 caller_program,
30 system_program,
31 coolrouter_program,
32 callback_accounts: Vec::new(),
33 }
34 }
35
36 pub fn add_callback_account(mut self, account: AccountInfo<'info>) -> Self {
37 self.callback_accounts.push(account);
38 self
39 }
40
41 pub fn add_callback_accounts(mut self, accounts: Vec<AccountInfo<'info>>) -> Self {
42 self.callback_accounts.extend(accounts);
43 self
44 }
45
46 pub fn create_request(
47 self,
48 request_id: String,
49 provider: String,
50 model_id: String,
51 messages: Vec<Message>,
52 min_votes: u8,
53 approval_threshold: u8,
54 ) -> Result<()> {
55 let data = Self::serialize_create_request(
56 &request_id,
57 &provider,
58 &model_id,
59 &messages,
60 min_votes,
61 approval_threshold,
62 )?;
63
64 let cpi_accounts = vec![
65 self.request_pda.clone(),
66 self.authority.clone(),
67 self.caller_program.clone(),
68 self.system_program.clone(),
69 ];
70
71 let mut account_metas = cpi_accounts
72 .iter()
73 .map(|acc| AccountMeta {
74 pubkey: *acc.key,
75 is_signer: acc.is_signer,
76 is_writable: acc.is_writable,
77 })
78 .collect::<Vec<_>>();
79
80 for acc in &self.callback_accounts {
81 account_metas.push(AccountMeta {
82 pubkey: *acc.key,
83 is_signer: false,
84 is_writable: true,
85 });
86 }
87
88 let ix = anchor_lang::solana_program::instruction::Instruction {
89 program_id: self.coolrouter_program,
90 accounts: account_metas,
91 data,
92 };
93
94 let mut all_accounts = cpi_accounts;
95 all_accounts.extend(self.callback_accounts);
96
97 anchor_lang::solana_program::program::invoke(&ix, &all_accounts)?;
98
99 Ok(())
100 }
101
102 fn serialize_create_request(
103 request_id: &str,
104 provider: &str,
105 model_id: &str,
106 messages: &[Message],
107 min_votes: u8,
108 approval_threshold: u8,
109 ) -> Result<Vec<u8>> {
110 let mut data = Vec::new();
111
112 let discriminator = Self::calculate_discriminator("global:create_request");
113 data.extend_from_slice(&discriminator);
114
115 data.extend_from_slice(&request_id.to_string().try_to_vec()?);
116 data.extend_from_slice(&provider.to_string().try_to_vec()?);
117 data.extend_from_slice(&model_id.to_string().try_to_vec()?);
118 data.extend_from_slice(&messages.to_vec().try_to_vec()?);
119 data.extend_from_slice(&min_votes.try_to_vec()?);
120 data.extend_from_slice(&approval_threshold.try_to_vec()?);
121
122 Ok(data)
123 }
124
125 fn calculate_discriminator(namespace_and_name: &str) -> [u8; 8] {
126 use sha2::{Digest, Sha256};
127 let mut hasher = Sha256::new();
128 hasher.update(namespace_and_name.as_bytes());
129 let hash_result = hasher.finalize();
130 hash_result[..8].try_into().unwrap()
131 }
132}
133
134pub fn create_llm_request<'info>(
135 request_pda: AccountInfo<'info>,
136 authority: AccountInfo<'info>,
137 caller_program: AccountInfo<'info>,
138 system_program: AccountInfo<'info>,
139 coolrouter_program: Pubkey,
140 callback_accounts: Vec<AccountInfo<'info>>,
141 request_id: String,
142 provider: String,
143 model_id: String,
144 messages: Vec<Message>,
145 min_votes: u8,
146 approval_threshold: u8,
147) -> Result<()> {
148 CoolRouterCPI::new(
149 request_pda,
150 authority,
151 caller_program,
152 system_program,
153 coolrouter_program,
154 )
155 .add_callback_accounts(callback_accounts)
156 .create_request(request_id, provider, model_id, messages, min_votes, approval_threshold)
157}