Skip to main content

sui_graphql/client/
coins.rs

1//! Coin-related convenience methods.
2
3use futures::Stream;
4use sui_graphql_macros::Response;
5use sui_sdk_types::Address;
6use sui_sdk_types::StructTag;
7
8use super::Client;
9use crate::error::Error;
10use crate::pagination::Page;
11use crate::pagination::PageInfo;
12use crate::pagination::paginate;
13use crate::scalars::BigInt;
14
15/// Balance information for a coin type.
16#[derive(Debug, Clone)]
17pub struct Balance {
18    /// The coin type (e.g., `0x2::sui::SUI`).
19    pub coin_type: StructTag,
20    /// The total balance in base units.
21    pub total_balance: u64,
22}
23
24impl Client {
25    /// Get the balance for a specific coin type owned by an address.
26    ///
27    /// # Example
28    ///
29    /// ```no_run
30    /// use sui_graphql::Client;
31    /// use sui_sdk_types::Address;
32    /// use sui_sdk_types::StructTag;
33    ///
34    /// #[tokio::main]
35    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
36    ///     let client = Client::new(Client::MAINNET)?;
37    ///     let owner: Address = "0x123...".parse()?;
38    ///
39    ///     // Get SUI balance using the helper
40    ///     let sui_balance = client.get_balance(owner, &StructTag::sui()).await?;
41    ///     if let Some(bal) = sui_balance {
42    ///         println!("SUI balance: {}", bal.total_balance);
43    ///     }
44    ///
45    ///     // Get balance for other coin types by parsing the type string
46    ///     let usdc_type: StructTag = "0xdba...::usdc::USDC".parse()?;
47    ///     let usdc_balance = client.get_balance(owner, &usdc_type).await?;
48    ///     if let Some(bal) = usdc_balance {
49    ///         println!("USDC balance: {}", bal.total_balance);
50    ///     }
51    ///
52    ///     Ok(())
53    /// }
54    /// ```
55    pub async fn get_balance(
56        &self,
57        owner: Address,
58        coin_type: &StructTag,
59    ) -> Result<Option<Balance>, Error> {
60        #[derive(Response)]
61        struct Response {
62            #[field(path = "address?.balance?.coinType?.repr?")]
63            coin_type: Option<StructTag>,
64            #[field(path = "address?.balance?.totalBalance?")]
65            total_balance: Option<BigInt>,
66        }
67
68        const QUERY: &str = r#"
69            query($owner: SuiAddress!, $coinType: String!) {
70                address(address: $owner) {
71                    balance(coinType: $coinType) {
72                        coinType {
73                            repr
74                        }
75                        totalBalance
76                    }
77                }
78            }
79        "#;
80
81        let variables = serde_json::json!({
82            "owner": owner,
83            "coinType": coin_type.to_string(),
84        });
85
86        let response = self.query::<Response>(QUERY, variables).await?;
87
88        let Some(data) = response.into_data() else {
89            return Ok(None);
90        };
91
92        match (data.coin_type, data.total_balance) {
93            (Some(coin_type), Some(total_balance)) => Ok(Some(Balance {
94                coin_type,
95                total_balance: total_balance.0,
96            })),
97            _ => Ok(None),
98        }
99    }
100
101    /// Stream all coin balances owned by an address.
102    ///
103    /// # Example
104    ///
105    /// ```no_run
106    /// use futures::StreamExt;
107    /// use std::pin::pin;
108    /// use sui_graphql::Client;
109    /// use sui_sdk_types::Address;
110    ///
111    /// #[tokio::main]
112    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
113    ///     let client = Client::new(Client::MAINNET)?;
114    ///     let owner: Address = "0x123...".parse()?;
115    ///
116    ///     let mut stream = pin!(client.list_balances(owner));
117    ///     while let Some(result) = stream.next().await {
118    ///         let balance = result?;
119    ///         println!("{}: {}", balance.coin_type, balance.total_balance);
120    ///     }
121    ///     Ok(())
122    /// }
123    /// ```
124    pub fn list_balances(&self, owner: Address) -> impl Stream<Item = Result<Balance, Error>> + '_ {
125        let client = self.clone();
126        paginate(move |cursor| {
127            let client = client.clone();
128            async move { client.fetch_balances_page(owner, cursor.as_deref()).await }
129        })
130    }
131
132    /// Fetch a single page of balances.
133    async fn fetch_balances_page(
134        &self,
135        owner: Address,
136        cursor: Option<&str>,
137    ) -> Result<Page<Balance>, Error> {
138        #[derive(Response)]
139        struct Response {
140            #[field(path = "address?.balances?.pageInfo?")]
141            page_info: Option<PageInfo>,
142            #[field(path = "address?.balances?.nodes?[].coinType?.repr?")]
143            coin_types: Option<Vec<Option<StructTag>>>,
144            #[field(path = "address?.balances?.nodes?[].totalBalance?")]
145            total_balances: Option<Vec<Option<BigInt>>>,
146        }
147
148        const QUERY: &str = r#"
149            query($owner: SuiAddress!, $after: String) {
150                address(address: $owner) {
151                    balances(after: $after) {
152                        pageInfo {
153                            hasNextPage
154                            endCursor
155                        }
156                        nodes {
157                            coinType {
158                                repr
159                            }
160                            totalBalance
161                        }
162                    }
163                }
164            }
165        "#;
166
167        let variables = serde_json::json!({
168            "owner": owner,
169            "after": cursor,
170        });
171
172        let response = self.query::<Response>(QUERY, variables).await?;
173
174        let data = response.into_data();
175        let page_info = data
176            .as_ref()
177            .and_then(|d| d.page_info.clone())
178            .unwrap_or_default();
179
180        let (coin_types, total_balances) = data
181            .map(|d| {
182                (
183                    d.coin_types.unwrap_or_default(),
184                    d.total_balances.unwrap_or_default(),
185                )
186            })
187            .unwrap_or_default();
188
189        // Zip coin_types and total_balances together
190        let balances: Vec<Balance> = coin_types
191            .into_iter()
192            .zip(total_balances)
193            .filter_map(|(ct, tb)| match (ct, tb) {
194                (Some(coin_type), Some(total_balance)) => Some((coin_type, total_balance)),
195                _ => None,
196            })
197            .map(|(coin_type, total_balance)| Balance {
198                coin_type,
199                total_balance: total_balance.0,
200            })
201            .collect();
202
203        Ok(Page {
204            items: balances,
205            has_next_page: page_info.has_next_page,
206            end_cursor: page_info.end_cursor,
207            ..Default::default()
208        })
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    /// Start a mock server that answers every `POST /` with `body` as a JSON payload.
    async fn mock_server_with(body: serde_json::Value) -> MockServer {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(&mock_server)
            .await;
        mock_server
    }

    #[tokio::test]
    async fn test_get_balance_found() {
        let mock_server = mock_server_with(serde_json::json!({
            "data": {
                "address": {
                    "balance": {
                        "coinType": {
                            "repr": "0x2::sui::SUI"
                        },
                        "totalBalance": "1000000000"
                    }
                }
            }
        }))
        .await;

        let client = Client::new(&mock_server.uri()).unwrap();
        let owner: Address = "0x1".parse().unwrap();

        let balance = client
            .get_balance(owner, &StructTag::sui())
            .await
            .expect("request should succeed")
            .expect("balance should be present");

        assert_eq!(balance.coin_type, StructTag::sui());
        assert_eq!(balance.total_balance, 1_000_000_000);
    }

    #[tokio::test]
    async fn test_get_balance_not_found() {
        let mock_server = mock_server_with(serde_json::json!({
            "data": {
                "address": {
                    "balance": null
                }
            }
        }))
        .await;

        let client = Client::new(&mock_server.uri()).unwrap();
        let owner: Address = "0x1".parse().unwrap();

        let balance = client
            .get_balance(owner, &StructTag::sui())
            .await
            .expect("request should succeed");
        assert!(balance.is_none());
    }

    #[tokio::test]
    async fn test_get_balance_invalid_number() {
        let mock_server = mock_server_with(serde_json::json!({
            "data": {
                "address": {
                    "balance": {
                        "coinType": {
                            "repr": "0x2::sui::SUI"
                        },
                        "totalBalance": "not_a_number"
                    }
                }
            }
        }))
        .await;

        let client = Client::new(&mock_server.uri()).unwrap();
        let owner: Address = "0x1".parse().unwrap();

        let result = client.get_balance(owner, &StructTag::sui()).await;
        assert!(matches!(result, Err(Error::Request(e)) if e.is_decode()));
    }

    #[tokio::test]
    async fn test_list_balances_empty() {
        let mock_server = mock_server_with(serde_json::json!({
            "data": {
                "address": {
                    "balances": {
                        "pageInfo": {
                            "hasNextPage": false,
                            "endCursor": null
                        },
                        "nodes": []
                    }
                }
            }
        }))
        .await;

        let client = Client::new(&mock_server.uri()).unwrap();
        let owner: Address = "0x1".parse().unwrap();

        let balances: Vec<_> = client.list_balances(owner).collect().await;

        assert!(balances.is_empty());
    }

    #[tokio::test]
    async fn test_list_balances_multiple() {
        let mock_server = mock_server_with(serde_json::json!({
            "data": {
                "address": {
                    "balances": {
                        "pageInfo": {
                            "hasNextPage": false,
                            "endCursor": null
                        },
                        "nodes": [
                            {
                                "coinType": { "repr": "0x2::sui::SUI" },
                                "totalBalance": "1000000000"
                            },
                            {
                                "coinType": { "repr": "0xabc::token::USDC" },
                                "totalBalance": "500000"
                            }
                        ]
                    }
                }
            }
        }))
        .await;

        let client = Client::new(&mock_server.uri()).unwrap();
        let owner: Address = "0x1".parse().unwrap();

        let balances = client
            .list_balances(owner)
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .expect("every item should decode");

        assert_eq!(balances.len(), 2);

        assert_eq!(balances[0].coin_type, StructTag::sui());
        assert_eq!(balances[0].total_balance, 1_000_000_000);

        let usdc: StructTag = "0xabc::token::USDC".parse().unwrap();
        assert_eq!(balances[1].coin_type, usdc);
        assert_eq!(balances[1].total_balance, 500_000);
    }

    #[tokio::test]
    async fn test_list_balances_with_pagination() {
        let mock_server = MockServer::start().await;
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        // The `after` variable of every request, so cursor forwarding can be verified.
        let seen_cursors = Arc::new(Mutex::new(Vec::<serde_json::Value>::new()));
        let seen_cursors_clone = seen_cursors.clone();

        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(move |req: &wiremock::Request| {
                let body: serde_json::Value = req.body_json().unwrap_or_default();
                seen_cursors_clone
                    .lock()
                    .unwrap()
                    .push(body["variables"]["after"].clone());

                let count = call_count_clone.fetch_add(1, Ordering::SeqCst);
                match count {
                    0 => ResponseTemplate::new(200).set_body_json(serde_json::json!({
                        "data": {
                            "address": {
                                "balances": {
                                    "pageInfo": {
                                        "hasNextPage": true,
                                        "endCursor": "cursor1"
                                    },
                                    "nodes": [
                                        {
                                            "coinType": { "repr": "0x2::sui::SUI" },
                                            "totalBalance": "1000000000"
                                        }
                                    ]
                                }
                            }
                        }
                    })),
                    1 => ResponseTemplate::new(200).set_body_json(serde_json::json!({
                        "data": {
                            "address": {
                                "balances": {
                                    "pageInfo": {
                                        "hasNextPage": false,
                                        "endCursor": null
                                    },
                                    "nodes": [
                                        {
                                            "coinType": { "repr": "0xabc::token::USDC" },
                                            "totalBalance": "500000"
                                        }
                                    ]
                                }
                            }
                        }
                    })),
                    _ => ResponseTemplate::new(200).set_body_json(serde_json::json!({
                        "data": {
                            "address": {
                                "balances": {
                                    "pageInfo": { "hasNextPage": false, "endCursor": null },
                                    "nodes": []
                                }
                            }
                        }
                    })),
                }
            })
            .mount(&mock_server)
            .await;

        let client = Client::new(&mock_server.uri()).unwrap();
        let owner: Address = "0x1".parse().unwrap();

        let balances = client
            .list_balances(owner)
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .expect("every item should decode");

        assert_eq!(balances.len(), 2);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        // The first request starts without a cursor; the second must forward the first
        // page's `endCursor`.
        assert_eq!(
            *seen_cursors.lock().unwrap(),
            vec![serde_json::Value::Null, serde_json::json!("cursor1")]
        );

        assert_eq!(balances[0].coin_type, StructTag::sui());
        assert_eq!(balances[0].total_balance, 1_000_000_000);

        let usdc: StructTag = "0xabc::token::USDC".parse().unwrap();
        assert_eq!(balances[1].coin_type, usdc);
        assert_eq!(balances[1].total_balance, 500_000);
    }
}