1use futures::Stream;
4use sui_graphql_macros::Response;
5use sui_sdk_types::Address;
6use sui_sdk_types::StructTag;
7
8use super::Client;
9use crate::error::Error;
10use crate::pagination::Page;
11use crate::pagination::PageInfo;
12use crate::pagination::paginate;
13use crate::scalars::BigInt;
14
15#[derive(Debug, Clone)]
17pub struct Balance {
18 pub coin_type: StructTag,
20 pub total_balance: u64,
22}
23
24impl Client {
25 pub async fn get_balance(
56 &self,
57 owner: Address,
58 coin_type: &StructTag,
59 ) -> Result<Option<Balance>, Error> {
60 #[derive(Response)]
61 struct Response {
62 #[field(path = "address?.balance?.coinType?.repr?")]
63 coin_type: Option<StructTag>,
64 #[field(path = "address?.balance?.totalBalance?")]
65 total_balance: Option<BigInt>,
66 }
67
68 const QUERY: &str = r#"
69 query($owner: SuiAddress!, $coinType: String!) {
70 address(address: $owner) {
71 balance(coinType: $coinType) {
72 coinType {
73 repr
74 }
75 totalBalance
76 }
77 }
78 }
79 "#;
80
81 let variables = serde_json::json!({
82 "owner": owner,
83 "coinType": coin_type.to_string(),
84 });
85
86 let response = self.query::<Response>(QUERY, variables).await?;
87
88 let Some(data) = response.into_data() else {
89 return Ok(None);
90 };
91
92 match (data.coin_type, data.total_balance) {
93 (Some(coin_type), Some(total_balance)) => Ok(Some(Balance {
94 coin_type,
95 total_balance: total_balance.0,
96 })),
97 _ => Ok(None),
98 }
99 }
100
101 pub fn list_balances(&self, owner: Address) -> impl Stream<Item = Result<Balance, Error>> + '_ {
125 let client = self.clone();
126 paginate(move |cursor| {
127 let client = client.clone();
128 async move { client.fetch_balances_page(owner, cursor.as_deref()).await }
129 })
130 }
131
132 async fn fetch_balances_page(
134 &self,
135 owner: Address,
136 cursor: Option<&str>,
137 ) -> Result<Page<Balance>, Error> {
138 #[derive(Response)]
139 struct Response {
140 #[field(path = "address?.balances?.pageInfo?")]
141 page_info: Option<PageInfo>,
142 #[field(path = "address?.balances?.nodes?[].coinType?.repr?")]
143 coin_types: Option<Vec<Option<StructTag>>>,
144 #[field(path = "address?.balances?.nodes?[].totalBalance?")]
145 total_balances: Option<Vec<Option<BigInt>>>,
146 }
147
148 const QUERY: &str = r#"
149 query($owner: SuiAddress!, $after: String) {
150 address(address: $owner) {
151 balances(after: $after) {
152 pageInfo {
153 hasNextPage
154 endCursor
155 }
156 nodes {
157 coinType {
158 repr
159 }
160 totalBalance
161 }
162 }
163 }
164 }
165 "#;
166
167 let variables = serde_json::json!({
168 "owner": owner,
169 "after": cursor,
170 });
171
172 let response = self.query::<Response>(QUERY, variables).await?;
173
174 let data = response.into_data();
175 let page_info = data
176 .as_ref()
177 .and_then(|d| d.page_info.clone())
178 .unwrap_or_default();
179
180 let (coin_types, total_balances) = data
181 .map(|d| {
182 (
183 d.coin_types.unwrap_or_default(),
184 d.total_balances.unwrap_or_default(),
185 )
186 })
187 .unwrap_or_default();
188
189 let balances: Vec<Balance> = coin_types
191 .into_iter()
192 .zip(total_balances)
193 .filter_map(|(ct, tb)| match (ct, tb) {
194 (Some(coin_type), Some(total_balance)) => Some((coin_type, total_balance)),
195 _ => None,
196 })
197 .map(|(coin_type, total_balance)| Balance {
198 coin_type,
199 total_balance: total_balance.0,
200 })
201 .collect();
202
203 Ok(Page {
204 items: balances,
205 has_next_page: page_info.has_next_page,
206 end_cursor: page_info.end_cursor,
207 ..Default::default()
208 })
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    /// Starts a mock GraphQL endpoint that answers every `POST /` with `body`.
    async fn mock_graphql(body: serde_json::Value) -> MockServer {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;
        mock_server
    }

    /// Builds a client that talks to `server`.
    fn client_for(server: &MockServer) -> Client {
        Client::new(&server.uri()).unwrap()
    }

    /// Owner address used by every test. The mocks only match on method and
    /// path, so its value does not affect the responses.
    fn owner() -> Address {
        "0x1".parse().unwrap()
    }

    /// The non-SUI coin type used in the listing tests.
    fn usdc() -> StructTag {
        "0xabc::token::USDC".parse().unwrap()
    }

    /// Wraps `nodes` in a `balances` connection response with the given page info.
    fn balances_page(
        has_next_page: bool,
        end_cursor: Option<&str>,
        nodes: serde_json::Value,
    ) -> serde_json::Value {
        serde_json::json!({
            "data": {
                "address": {
                    "balances": {
                        "pageInfo": {
                            "hasNextPage": has_next_page,
                            "endCursor": end_cursor
                        },
                        "nodes": nodes
                    }
                }
            }
        })
    }

    /// Drains `list_balances` for the test owner and fails the test on any error item.
    async fn collect_balances(client: &Client) -> Vec<Balance> {
        client
            .list_balances(owner())
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .expect("every page request should succeed")
    }

    #[tokio::test]
    async fn test_get_balance_found() {
        let mock_server = mock_graphql(serde_json::json!({
            "data": {
                "address": {
                    "balance": {
                        "coinType": {
                            "repr": "0x2::sui::SUI"
                        },
                        "totalBalance": "1000000000"
                    }
                }
            }
        }))
        .await;
        let client = client_for(&mock_server);

        let balance = client
            .get_balance(owner(), &StructTag::sui())
            .await
            .expect("query should succeed")
            .expect("balance should be present");

        assert_eq!(balance.coin_type, StructTag::sui());
        assert_eq!(balance.total_balance, 1000000000);
    }

    #[tokio::test]
    async fn test_get_balance_not_found() {
        let mock_server = mock_graphql(serde_json::json!({
            "data": {
                "address": {
                    "balance": null
                }
            }
        }))
        .await;
        let client = client_for(&mock_server);

        let balance = client
            .get_balance(owner(), &StructTag::sui())
            .await
            .expect("query should succeed");

        assert!(balance.is_none(), "unexpected balance: {balance:?}");
    }

    #[tokio::test]
    async fn test_get_balance_invalid_number() {
        let mock_server = mock_graphql(serde_json::json!({
            "data": {
                "address": {
                    "balance": {
                        "coinType": {
                            "repr": "0x2::sui::SUI"
                        },
                        "totalBalance": "not_a_number"
                    }
                }
            }
        }))
        .await;
        let client = client_for(&mock_server);

        let result = client.get_balance(owner(), &StructTag::sui()).await;
        assert!(matches!(result, Err(Error::Request(e)) if e.is_decode()));
    }

    #[tokio::test]
    async fn test_list_balances_empty() {
        let mock_server = mock_graphql(balances_page(false, None, serde_json::json!([]))).await;
        let client = client_for(&mock_server);

        assert!(collect_balances(&client).await.is_empty());
    }

    #[tokio::test]
    async fn test_list_balances_multiple() {
        let mock_server = mock_graphql(balances_page(
            false,
            None,
            serde_json::json!([
                {
                    "coinType": { "repr": "0x2::sui::SUI" },
                    "totalBalance": "1000000000"
                },
                {
                    "coinType": { "repr": "0xabc::token::USDC" },
                    "totalBalance": "500000"
                }
            ]),
        ))
        .await;
        let client = client_for(&mock_server);

        let balances = collect_balances(&client).await;

        assert_eq!(balances.len(), 2);
        assert_eq!(balances[0].coin_type, StructTag::sui());
        assert_eq!(balances[0].total_balance, 1000000000);
        assert_eq!(balances[1].coin_type, usdc());
        assert_eq!(balances[1].total_balance, 500000);
    }

    #[tokio::test]
    async fn test_list_balances_skips_incomplete_nodes() {
        let mock_server = mock_graphql(balances_page(
            false,
            None,
            serde_json::json!([
                {
                    "coinType": null,
                    "totalBalance": "1"
                },
                {
                    "coinType": { "repr": "0x2::sui::SUI" },
                    "totalBalance": null
                },
                {
                    "coinType": { "repr": "0xabc::token::USDC" },
                    "totalBalance": "500000"
                }
            ]),
        ))
        .await;
        let client = client_for(&mock_server);

        let balances = collect_balances(&client).await;

        // Only the node with both fields populated survives.
        assert_eq!(balances.len(), 1);
        assert_eq!(balances[0].coin_type, usdc());
        assert_eq!(balances[0].total_balance, 500000);
    }

    #[tokio::test]
    async fn test_list_balances_with_pagination() {
        let mock_server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let responder_count = Arc::clone(&call_count);

        // Request 0 returns the SUI page (more to come), request 1 the final
        // USDC page. Any later request gets an empty final page.
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(move |_req: &wiremock::Request| {
                let body = match responder_count.fetch_add(1, Ordering::SeqCst) {
                    0 => balances_page(
                        true,
                        Some("cursor1"),
                        serde_json::json!([{
                            "coinType": { "repr": "0x2::sui::SUI" },
                            "totalBalance": "1000000000"
                        }]),
                    ),
                    1 => balances_page(
                        false,
                        None,
                        serde_json::json!([{
                            "coinType": { "repr": "0xabc::token::USDC" },
                            "totalBalance": "500000"
                        }]),
                    ),
                    _ => balances_page(false, None, serde_json::json!([])),
                };
                ResponseTemplate::new(200).set_body_json(body)
            })
            .mount(&mock_server)
            .await;

        let client = client_for(&mock_server);
        let balances = collect_balances(&client).await;

        assert_eq!(balances.len(), 2);
        // The second page reports no further pages, so exactly two requests are made.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        assert_eq!(balances[0].coin_type, StructTag::sui());
        assert_eq!(balances[0].total_balance, 1000000000);
        assert_eq!(balances[1].coin_type, usdc());
        assert_eq!(balances[1].total_balance, 500000);
    }
}