// coolrouter_cpi/lib.rs

1use anchor_lang::prelude::*;
2
3#[derive(AnchorSerialize, AnchorDeserialize, Clone, Debug)]
4pub struct Message {
5    pub role: String,
6    pub content: String,
7}
8
9pub struct CoolRouterCPI<'info> {
10    pub request_pda: AccountInfo<'info>,
11    pub authority: AccountInfo<'info>,
12    pub caller_program: AccountInfo<'info>,
13    pub system_program: AccountInfo<'info>,
14    pub coolrouter_program: Pubkey,
15    pub callback_accounts: Vec<AccountInfo<'info>>,
16}
17
18impl<'info> CoolRouterCPI<'info> {
19    pub fn new(
20        request_pda: AccountInfo<'info>,
21        authority: AccountInfo<'info>,
22        caller_program: AccountInfo<'info>,
23        system_program: AccountInfo<'info>,
24        coolrouter_program: Pubkey,
25    ) -> Self {
26        Self {
27            request_pda,
28            authority,
29            caller_program,
30            system_program,
31            coolrouter_program,
32            callback_accounts: Vec::new(),
33        }
34    }
35
36    pub fn add_callback_account(mut self, account: AccountInfo<'info>) -> Self {
37        self.callback_accounts.push(account);
38        self
39    }
40
41    pub fn add_callback_accounts(mut self, accounts: Vec<AccountInfo<'info>>) -> Self {
42        self.callback_accounts.extend(accounts);
43        self
44    }
45
46    pub fn create_request(
47        self,
48        request_id: String,
49        provider: String,
50        model_id: String,
51        messages: Vec<Message>,
52        min_votes: u8,
53        approval_threshold: u8,
54    ) -> Result<()> {
55        let data = Self::serialize_create_request(
56            &request_id,
57            &provider,
58            &model_id,
59            &messages,
60            min_votes,
61            approval_threshold,
62        )?;
63
64        let cpi_accounts = vec![
65            self.request_pda.clone(),
66            self.authority.clone(),
67            self.caller_program.clone(),
68            self.system_program.clone(),
69        ];
70
71        let mut account_metas = cpi_accounts
72            .iter()
73            .map(|acc| AccountMeta {
74                pubkey: *acc.key,
75                is_signer: acc.is_signer,
76                is_writable: acc.is_writable,
77            })
78            .collect::<Vec<_>>();
79
80        for acc in &self.callback_accounts {
81            account_metas.push(AccountMeta {
82                pubkey: *acc.key,
83                is_signer: false,
84                is_writable: true,
85            });
86        }
87
88        let ix = anchor_lang::solana_program::instruction::Instruction {
89            program_id: self.coolrouter_program,
90            accounts: account_metas,
91            data,
92        };
93
94        let mut all_accounts = cpi_accounts;
95        all_accounts.extend(self.callback_accounts);
96
97        anchor_lang::solana_program::program::invoke(&ix, &all_accounts)?;
98
99        Ok(())
100    }
101
102    fn serialize_create_request(
103        request_id: &str,
104        provider: &str,
105        model_id: &str,
106        messages: &[Message],
107        min_votes: u8,
108        approval_threshold: u8,
109    ) -> Result<Vec<u8>> {
110        let mut data = Vec::new();
111
112        let discriminator = Self::calculate_discriminator("global:create_request");
113        data.extend_from_slice(&discriminator);
114
115        data.extend_from_slice(&request_id.to_string().try_to_vec()?);
116        data.extend_from_slice(&provider.to_string().try_to_vec()?);
117        data.extend_from_slice(&model_id.to_string().try_to_vec()?);
118        data.extend_from_slice(&messages.to_vec().try_to_vec()?);
119        data.extend_from_slice(&min_votes.try_to_vec()?);
120        data.extend_from_slice(&approval_threshold.try_to_vec()?);
121
122        Ok(data)
123    }
124
125    fn calculate_discriminator(namespace_and_name: &str) -> [u8; 8] {
126        use sha2::{Digest, Sha256};
127        let mut hasher = Sha256::new();
128        hasher.update(namespace_and_name.as_bytes());
129        let hash_result = hasher.finalize();
130        hash_result[..8].try_into().unwrap()
131    }
132}
133
134pub fn create_llm_request<'info>(
135    request_pda: AccountInfo<'info>,
136    authority: AccountInfo<'info>,
137    caller_program: AccountInfo<'info>,
138    system_program: AccountInfo<'info>,
139    coolrouter_program: Pubkey,
140    callback_accounts: Vec<AccountInfo<'info>>,
141    request_id: String,
142    provider: String,
143    model_id: String,
144    messages: Vec<Message>,
145    min_votes: u8,
146    approval_threshold: u8,
147) -> Result<()> {
148    CoolRouterCPI::new(
149        request_pda,
150        authority,
151        caller_program,
152        system_program,
153        coolrouter_program,
154    )
155    .add_callback_accounts(callback_accounts)
156    .create_request(request_id, provider, model_id, messages, min_votes, approval_threshold)
157}