Skip to main content

krusty_kms_client/
staking.rs

1//! STRK staking delegation pool operations.
2
use crate::abi;
use crate::tx::Tx;
use crate::wallet::utils::{self, core_felt_to_rs, rs_felt_to_core};
use crate::wallet::WalletExecutor;
use krusty_kms_common::address::Address;
use krusty_kms_common::amount::Amount;
use krusty_kms_common::chain::ChainId;
use krusty_kms_common::token::Token;
use krusty_kms_common::{KmsError, Result};
use starknet_rust::core::types::{BlockId, BlockTag, Call, FunctionCall, StarknetError};
use starknet_rust::providers::jsonrpc::{HttpTransport, JsonRpcClient};
use starknet_rust::providers::{Provider, ProviderError};
use std::sync::Arc;
16
17/// Mainnet staking contract address.
18const MAINNET_STAKING_CONTRACT: &str =
19    "0x0594c1582459ea03f77deaf9eb7e3917d6994a03c13405ba42867f83d85f085d";
20/// Sepolia staking contract address.
21const SEPOLIA_STAKING_CONTRACT: &str =
22    "0x03745ab04a431fc02871a139be6b93d9260b0ff3e779ad9c8b377183b23109f1";
23
24/// Get the staking contract address for a given chain.
25pub fn staking_contract_address(chain_id: ChainId) -> Address {
26    match chain_id {
27        ChainId::Mainnet => Address::from_hex(MAINNET_STAKING_CONTRACT).unwrap(),
28        ChainId::Sepolia => Address::from_hex(SEPOLIA_STAKING_CONTRACT).unwrap(),
29    }
30}
31
/// A handle for interacting with a staking delegation pool.
///
/// The `populate_*` methods only build [`Call`]s and perform no I/O; the async query
/// methods use `provider` to read pool state at the latest block.
pub struct Staking {
    /// Shared JSON-RPC client used by the async query methods.
    provider: Arc<JsonRpcClient<HttpTransport>>,
    /// Address of the delegation pool contract that every pool call targets.
    pool_address: Address,
    /// Staking token: its `address` receives the `approve` call and its `decimals`
    /// scale the amounts decoded from pool responses.
    token: Token,
}
38
/// A pool member's staking position.
///
/// Decoded from the pool's `pool_member_info` response by [`Staking::get_position`].
#[derive(Debug, Clone)]
pub struct PoolPosition {
    /// Reward address reported for this member.
    pub reward_address: Address,
    /// Staked amount, scaled by the staking token's decimals.
    pub amount: Amount,
    /// Rewards not yet claimed, scaled by the staking token's decimals.
    pub unclaimed_rewards: Amount,
    /// Commission value reported in the member info
    /// (presumably basis points, like `Staking::get_commission` — TODO confirm).
    pub commission: u16,
}
47
48impl Staking {
49    /// Create a staking handle from a known pool address and token.
50    pub fn new(
51        provider: Arc<JsonRpcClient<HttpTransport>>,
52        pool_address: Address,
53        token: Token,
54    ) -> Self {
55        Self {
56            provider,
57            pool_address,
58            token,
59        }
60    }
61
62    /// Build calls to enter a delegation pool: approve + enter_delegation_pool.
63    pub fn populate_enter(&self, amount: &Amount, reward_address: &Address) -> Vec<Call> {
64        let pool_rs = core_felt_to_rs(self.pool_address.as_felt());
65        let (low, high) = amount.to_u256();
66
67        let approve = Call {
68            to: core_felt_to_rs(self.token.address.as_felt()),
69            selector: *abi::erc20::APPROVE,
70            calldata: vec![pool_rs, core_felt_to_rs(low), core_felt_to_rs(high)],
71        };
72
73        let enter = Call {
74            to: pool_rs,
75            selector: *abi::pool::ENTER_DELEGATION_POOL,
76            calldata: vec![
77                core_felt_to_rs(reward_address.as_felt()),
78                core_felt_to_rs(low),
79                core_felt_to_rs(high),
80            ],
81        };
82
83        vec![approve, enter]
84    }
85
86    /// Build calls to add more stake: approve + add_to_delegation_pool.
87    pub fn populate_add(&self, amount: &Amount) -> Vec<Call> {
88        let pool_rs = core_felt_to_rs(self.pool_address.as_felt());
89        let (low, high) = amount.to_u256();
90
91        let approve = Call {
92            to: core_felt_to_rs(self.token.address.as_felt()),
93            selector: *abi::erc20::APPROVE,
94            calldata: vec![pool_rs, core_felt_to_rs(low), core_felt_to_rs(high)],
95        };
96
97        let add = Call {
98            to: pool_rs,
99            selector: *abi::pool::ADD_TO_DELEGATION_POOL,
100            calldata: vec![core_felt_to_rs(low), core_felt_to_rs(high)],
101        };
102
103        vec![approve, add]
104    }
105
106    /// Build a claim_rewards call.
107    pub fn populate_claim_rewards(&self, reward_address: &Address) -> Call {
108        Call {
109            to: core_felt_to_rs(self.pool_address.as_felt()),
110            selector: *abi::pool::CLAIM_REWARDS,
111            calldata: vec![core_felt_to_rs(reward_address.as_felt())],
112        }
113    }
114
115    /// Build an exit_delegation_pool_intent call.
116    pub fn populate_exit_intent(&self, amount: &Amount) -> Call {
117        let (low, high) = amount.to_u256();
118        Call {
119            to: core_felt_to_rs(self.pool_address.as_felt()),
120            selector: *abi::pool::EXIT_INTENT,
121            calldata: vec![core_felt_to_rs(low), core_felt_to_rs(high)],
122        }
123    }
124
125    /// Build an exit_delegation_pool_action call.
126    pub fn populate_exit(&self) -> Call {
127        Call {
128            to: core_felt_to_rs(self.pool_address.as_felt()),
129            selector: *abi::pool::EXIT_ACTION,
130            calldata: vec![],
131        }
132    }
133
134    /// Check if an address is a pool member.
135    pub async fn is_member(&self, address: &Address) -> Result<bool> {
136        match self.get_position(address).await {
137            Ok(_) => Ok(true),
138            Err(_) => Ok(false),
139        }
140    }
141
142    /// Get a member's staking position.
143    pub async fn get_position(&self, address: &Address) -> Result<PoolPosition> {
144        let pool_rs = core_felt_to_rs(self.pool_address.as_felt());
145        let addr_rs = core_felt_to_rs(address.as_felt());
146
147        let result = self
148            .provider
149            .call(
150                FunctionCall {
151                    contract_address: pool_rs,
152                    entry_point_selector: *abi::pool::POOL_MEMBER_INFO,
153                    calldata: vec![addr_rs],
154                },
155                BlockId::Tag(BlockTag::Latest),
156            )
157            .await
158            .map_err(|e| KmsError::StakingError(e.to_string()))?;
159
160        // Parse pool_member_info response
161        // Expected: reward_address, amount (u256), unclaimed_rewards (u256), commission (u16), ...
162        if result.len() < 6 {
163            return Err(KmsError::StakingError(
164                "Unexpected pool_member_info response length".into(),
165            ));
166        }
167
168        let reward_address = Address::from(rs_felt_to_core(result[0]));
169        let amount_raw = utils::felt_to_u128(&result[1]);
170        let unclaimed_raw = utils::felt_to_u128(&result[3]);
171        let commission = utils::felt_to_u16(&result[5]);
172
173        Ok(PoolPosition {
174            reward_address,
175            amount: Amount::from_raw(amount_raw, self.token.decimals),
176            unclaimed_rewards: Amount::from_raw(unclaimed_raw, self.token.decimals),
177            commission,
178        })
179    }
180
181    /// Get the pool commission rate (basis points).
182    pub async fn get_commission(&self) -> Result<u16> {
183        let pool_rs = core_felt_to_rs(self.pool_address.as_felt());
184        let result = self
185            .provider
186            .call(
187                FunctionCall {
188                    contract_address: pool_rs,
189                    entry_point_selector: *abi::pool::CONTRACT_PARAMETERS,
190                    calldata: vec![],
191                },
192                BlockId::Tag(BlockTag::Latest),
193            )
194            .await
195            .map_err(|e| KmsError::StakingError(e.to_string()))?;
196
197        if result.is_empty() {
198            return Err(KmsError::StakingError(
199                "Empty contract_parameters response".into(),
200            ));
201        }
202
203        // Commission is typically the first field
204        Ok(utils::felt_to_u16(&result[0]))
205    }
206
207    /// Stake: if already a member, adds to pool; otherwise enters as new member.
208    pub async fn stake(
209        &self,
210        wallet: &dyn WalletExecutor,
211        amount: &Amount,
212        reward_address: &Address,
213    ) -> Result<Tx> {
214        let is_existing = self.is_member(wallet.address()).await?;
215        let calls = if is_existing {
216            self.populate_add(amount)
217        } else {
218            self.populate_enter(amount, reward_address)
219        };
220        wallet.execute(calls).await
221    }
222
223    /// The pool address.
224    pub fn pool_address(&self) -> &Address {
225        &self.pool_address
226    }
227
228    /// The staking token.
229    pub fn token(&self) -> &Token {
230        &self.token
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Staking` handle for the STRK mainnet preset and a dummy pool address.
    ///
    /// The provider URL is never contacted: the `populate_*` builders exercised here
    /// perform no RPC calls.
    fn test_staking() -> Staking {
        let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(
            url::Url::parse("http://localhost:5050").unwrap(),
        )));
        let token = krusty_kms_common::token::presets::strk(ChainId::Mainnet);
        let pool = Address::from_hex("0xDEAD").unwrap();
        Staking::new(provider, pool, token)
    }

    #[test]
    fn test_staking_contract_addresses() {
        let mainnet = staking_contract_address(ChainId::Mainnet);
        let sepolia = staking_contract_address(ChainId::Sepolia);
        assert_ne!(mainnet.as_felt(), sepolia.as_felt());
    }

    #[test]
    fn test_populate_enter() {
        let staking = test_staking();
        let pool = core_felt_to_rs(staking.pool_address().as_felt());
        let token = core_felt_to_rs(staking.token().address.as_felt());

        let amount = Amount::from_raw(1_000_000_000_000_000_000, 18);
        let reward = Address::from_hex("0xBEEF").unwrap();
        let calls = staking.populate_enter(&amount, &reward);

        // approve + enter = 2 calls, approve first.
        assert_eq!(calls.len(), 2);
        let (low, high) = amount.to_u256();
        assert_eq!(calls[0].to, token);
        assert_eq!(calls[0].selector, *abi::erc20::APPROVE);
        assert_eq!(
            calls[0].calldata,
            vec![pool, core_felt_to_rs(low), core_felt_to_rs(high)]
        );
        assert_eq!(calls[1].to, pool);
        assert_eq!(calls[1].selector, *abi::pool::ENTER_DELEGATION_POOL);
        assert_eq!(calls[1].calldata[0], core_felt_to_rs(reward.as_felt()));
    }

    #[test]
    fn test_populate_add() {
        let staking = test_staking();
        let pool = core_felt_to_rs(staking.pool_address().as_felt());
        let token = core_felt_to_rs(staking.token().address.as_felt());

        let amount = Amount::from_raw(500, 18);
        let calls = staking.populate_add(&amount);

        // approve + add = 2 calls, approve first.
        assert_eq!(calls.len(), 2);
        let (low, high) = amount.to_u256();
        assert_eq!(calls[0].to, token);
        assert_eq!(calls[0].selector, *abi::erc20::APPROVE);
        assert_eq!(
            calls[0].calldata,
            vec![pool, core_felt_to_rs(low), core_felt_to_rs(high)]
        );
        assert_eq!(calls[1].to, pool);
        assert_eq!(calls[1].selector, *abi::pool::ADD_TO_DELEGATION_POOL);
    }

    #[test]
    fn test_populate_claim_rewards() {
        let staking = test_staking();
        let reward = Address::from_hex("0xBEEF").unwrap();
        let call = staking.populate_claim_rewards(&reward);

        assert_eq!(call.to, core_felt_to_rs(staking.pool_address().as_felt()));
        assert_eq!(call.selector, *abi::pool::CLAIM_REWARDS);
        assert_eq!(call.calldata, vec![core_felt_to_rs(reward.as_felt())]);
    }

    #[test]
    fn test_populate_exit_intent() {
        let staking = test_staking();
        let amount = Amount::from_raw(500, 18);
        let call = staking.populate_exit_intent(&amount);

        assert_eq!(call.to, core_felt_to_rs(staking.pool_address().as_felt()));
        assert_eq!(call.selector, *abi::pool::EXIT_INTENT);
    }

    #[test]
    fn test_populate_exit() {
        let staking = test_staking();
        let call = staking.populate_exit();

        assert_eq!(call.to, core_felt_to_rs(staking.pool_address().as_felt()));
        assert_eq!(call.selector, *abi::pool::EXIT_ACTION);
        assert!(call.calldata.is_empty());
    }
}