1use crate::api::api;
9use crate::error::BittensorError;
10use crate::extrinsics::ExtrinsicResponse;
11use crate::utils::{normalize_weights, NormalizedWeight};
12use subxt::OnlineClient;
13use subxt::PolkadotConfig;
14
/// Inputs for the `set_weights` extrinsic.
///
/// `uids` and `weights` are paired by position. [`WeightsParams::new`] rejects
/// inputs whose lengths differ, but the fields are public, so a value built or
/// mutated directly is not guaranteed to uphold that.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WeightsParams {
    /// Network UID forwarded as the first argument of the extrinsic.
    pub netuid: u16,
    /// UIDs that receive a weight; entry `i` pairs with `weights[i]`.
    pub uids: Vec<u16>,
    /// Raw weights; `set_weights` runs them through `to_normalized` before submitting.
    pub weights: Vec<u16>,
    /// Version key forwarded unchanged to the extrinsic (`new` sets it to `0`).
    pub version_key: u64,
}
27
28impl WeightsParams {
29 pub fn new(netuid: u16, uids: Vec<u16>, weights: Vec<u16>) -> Result<Self, &'static str> {
41 if uids.len() != weights.len() {
42 return Err("UIDs and weights must have the same length");
43 }
44 Ok(Self {
45 netuid,
46 uids,
47 weights,
48 version_key: 0,
49 })
50 }
51
52 pub fn with_version_key(mut self, version_key: u64) -> Self {
54 self.version_key = version_key;
55 self
56 }
57
58 pub fn to_normalized(&self) -> Vec<NormalizedWeight> {
60 let weight_pairs: Vec<(u16, u16)> = self
61 .uids
62 .iter()
63 .zip(self.weights.iter())
64 .map(|(u, w)| (*u, *w))
65 .collect();
66 normalize_weights(&weight_pairs)
67 }
68}
69
/// Inputs shared by the `commit_weights` and `reveal_weights` extrinsics.
///
/// `commit_weights` reads only `netuid` and `commit_hash`; `reveal_weights`
/// reads `netuid`, `uids`, `weights`, `salt` and `version_key`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommitRevealParams {
    /// Network UID forwarded as the first argument of both extrinsics.
    pub netuid: u16,
    /// 32-byte commitment submitted by `commit_weights`.
    pub commit_hash: [u8; 32],
    /// UIDs to reveal; empty when built with `new_with_hash`.
    pub uids: Vec<u16>,
    /// Weights to reveal; empty when built with `new_with_hash`.
    pub weights: Vec<u16>,
    /// Salt to reveal; empty when built with `new_with_hash`.
    pub salt: Vec<u16>,
    /// Version key forwarded by `reveal_weights` (`new_with_hash` sets it to `0`).
    pub version_key: u64,
}
86
87impl CommitRevealParams {
88 pub fn new_with_hash(netuid: u16, commit_hash: [u8; 32]) -> Self {
90 Self {
91 netuid,
92 commit_hash,
93 uids: Vec::new(),
94 weights: Vec::new(),
95 salt: Vec::new(),
96 version_key: 0,
97 }
98 }
99
100 pub fn new_with_weights(
104 netuid: u16,
105 uids: Vec<u16>,
106 weights: Vec<u16>,
107 salt: Vec<u16>,
108 version_key: u64,
109 ) -> Self {
110 let commit_hash = compute_commit_hash(&uids, &weights, &salt, version_key);
111
112 Self {
113 netuid,
114 commit_hash,
115 uids,
116 weights,
117 salt,
118 version_key,
119 }
120 }
121}
122
123fn compute_commit_hash(uids: &[u16], weights: &[u16], salt: &[u16], version_key: u64) -> [u8; 32] {
125 use sp_core::keccak_256;
126
127 let mut data = Vec::new();
128
129 for uid in uids {
130 data.extend_from_slice(&uid.to_le_bytes());
131 }
132
133 for weight in weights {
134 data.extend_from_slice(&weight.to_le_bytes());
135 }
136
137 for s in salt {
138 data.extend_from_slice(&s.to_le_bytes());
139 }
140
141 data.extend_from_slice(&version_key.to_le_bytes());
142
143 keccak_256(&data)
144}
145
146pub async fn set_weights<S>(
170 client: &OnlineClient<PolkadotConfig>,
171 signer: &S,
172 params: WeightsParams,
173) -> Result<ExtrinsicResponse<()>, BittensorError>
174where
175 S: subxt::tx::Signer<PolkadotConfig>,
176{
177 let normalized = params.to_normalized();
178 let (dests, values): (Vec<u16>, Vec<u16>) =
179 normalized.into_iter().map(|w| (w.uid, w.weight)).unzip();
180
181 let call =
182 api::tx()
183 .subtensor_module()
184 .set_weights(params.netuid, dests, values, params.version_key);
185
186 let tx_hash = client
187 .tx()
188 .sign_and_submit_default(&call, signer)
189 .await
190 .map_err(|e| BittensorError::TxSubmissionError {
191 message: format!("Failed to submit set_weights: {}", e),
192 })?;
193
194 Ok(ExtrinsicResponse::success()
195 .with_message("Weights set successfully")
196 .with_extrinsic_hash(&format!("{:?}", tx_hash))
197 .with_data(()))
198}
199
200pub async fn commit_weights<S>(
205 client: &OnlineClient<PolkadotConfig>,
206 signer: &S,
207 params: CommitRevealParams,
208) -> Result<ExtrinsicResponse<[u8; 32]>, BittensorError>
209where
210 S: subxt::tx::Signer<PolkadotConfig>,
211{
212 let commit_hash_h256 = subxt::utils::H256::from_slice(¶ms.commit_hash);
213
214 let call = api::tx()
215 .subtensor_module()
216 .commit_weights(params.netuid, commit_hash_h256);
217
218 let tx_hash = client
219 .tx()
220 .sign_and_submit_default(&call, signer)
221 .await
222 .map_err(|e| BittensorError::TxSubmissionError {
223 message: format!("Failed to submit commit_weights: {}", e),
224 })?;
225
226 Ok(ExtrinsicResponse::success()
227 .with_message("Weights committed successfully")
228 .with_extrinsic_hash(&format!("{:?}", tx_hash))
229 .with_data(params.commit_hash))
230}
231
232pub async fn reveal_weights<S>(
234 client: &OnlineClient<PolkadotConfig>,
235 signer: &S,
236 params: CommitRevealParams,
237) -> Result<ExtrinsicResponse<()>, BittensorError>
238where
239 S: subxt::tx::Signer<PolkadotConfig>,
240{
241 if params.uids.is_empty() || params.weights.is_empty() || params.salt.is_empty() {
242 return Err(BittensorError::ConfigError {
243 field: "params".to_string(),
244 message: "UIDs, weights, and salt are required for reveal".to_string(),
245 });
246 }
247
248 let call = api::tx().subtensor_module().reveal_weights(
249 params.netuid,
250 params.uids,
251 params.weights,
252 params.salt,
253 params.version_key,
254 );
255
256 let tx_hash = client
257 .tx()
258 .sign_and_submit_default(&call, signer)
259 .await
260 .map_err(|e| BittensorError::TxSubmissionError {
261 message: format!("Failed to submit reveal_weights: {}", e),
262 })?;
263
264 Ok(ExtrinsicResponse::success()
265 .with_message("Weights revealed successfully")
266 .with_extrinsic_hash(&format!("{:?}", tx_hash))
267 .with_data(()))
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores its inputs and defaults `version_key` to zero.
    #[test]
    fn test_weights_params() {
        let (uids, weights) = (vec![0, 1, 2], vec![100, 200, 300]);
        let built = WeightsParams::new(1, uids, weights).unwrap();

        assert_eq!(built.netuid, 1);
        assert_eq!(built.uids.len(), 3);
        assert_eq!(built.weights.len(), 3);
        assert_eq!(built.version_key, 0);
    }

    /// `with_version_key` overrides the default version key.
    #[test]
    fn test_weights_params_builder() {
        let base = WeightsParams::new(1, vec![0], vec![100]).unwrap();
        let versioned = base.with_version_key(42);
        assert_eq!(versioned.version_key, 42);
    }

    /// Equal raw weights normalize to values at most one unit apart.
    #[test]
    fn test_normalize_weights() {
        let normalized = WeightsParams::new(1, vec![0, 1], vec![100, 100])
            .unwrap()
            .to_normalized();
        assert_eq!(normalized.len(), 2);

        let first = normalized[0].weight as i32;
        let second = normalized[1].weight as i32;
        assert!((first - second).abs() < 2);
    }

    /// `new_with_weights` keeps the reveal data and derives a non-zero hash.
    #[test]
    fn test_commit_reveal_params() {
        let uids = vec![0, 1, 2];
        let weights = vec![100, 200, 300];
        let salt = vec![1, 2, 3];
        let built = CommitRevealParams::new_with_weights(1, uids, weights, salt, 0);

        assert_eq!(built.netuid, 1);
        assert_eq!(built.uids.len(), 3);
        assert_ne!(built.commit_hash, [0u8; 32]);
    }

    /// Identical inputs hash identically; changing one weight changes the hash.
    #[test]
    fn test_commit_hash_deterministic() {
        let baseline = compute_commit_hash(&[0, 1], &[100, 200], &[1, 2], 0);
        let repeated = compute_commit_hash(&[0, 1], &[100, 200], &[1, 2], 0);
        assert_eq!(baseline, repeated);

        let altered = compute_commit_hash(&[0, 1], &[100, 201], &[1, 2], 0);
        assert_ne!(baseline, altered);
    }

    /// `new_with_hash` stores the hash verbatim and leaves the reveal data empty.
    #[test]
    fn test_commit_reveal_params_with_hash() {
        let expected = [1u8; 32];
        let built = CommitRevealParams::new_with_hash(1, expected);
        assert_eq!(built.commit_hash, expected);
        assert!(built.uids.is_empty());
    }
}