use async_trait::async_trait;
use serde_json::{json, Value};
use super::provider::{ChainFamily, ChainProvider, TxStatus};
use crate::core::types::ExchangeError;
/// Address of the ETH ERC-20 token contract whose `balanceOf` view is used
/// by `get_native_balance`.
const ETH_TOKEN_ADDRESS: &str =
    "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7";
/// Entry-point selector of the token contract's `balanceOf` view.
const BALANCE_OF_SELECTOR: &str =
    "0x2e4263afad30923c891518314c3c95dbe830a16874e8abc5777a9a20b54c76e";
/// StarkNet-specific operations layered on top of the generic [`ChainProvider`].
///
/// Felt arguments and results are plain strings; the implementation in this
/// module forwards them verbatim to the node's JSON-RPC API.
#[async_trait]
pub trait StarkNetChain: ChainProvider {
    /// Performs a read-only call of `selector` on `contract_address` with
    /// `calldata`, returning the felts produced by the call.
    async fn call_contract(
        &self,
        contract_address: &str,
        selector: &str,
        calldata: &[String],
    ) -> Result<Vec<String>, ExchangeError>;

    /// Returns the nonce of `address` as the raw felt string reported by the node.
    async fn get_starknet_nonce(&self, address: &str) -> Result<String, ExchangeError>;

    /// Fetches the receipt for `tx_hash` and classifies it as a [`TxStatus`].
    async fn get_receipt(&self, tx_hash: &str) -> Result<TxStatus, ExchangeError>;

    /// Submits an invoke transaction targeting `contract_address` / `selector`
    /// with `calldata`, returning the resulting transaction hash.
    async fn invoke(
        &self,
        contract_address: &str,
        selector: &str,
        calldata: &[String],
    ) -> Result<String, ExchangeError>;
}
/// JSON-RPC client for a StarkNet node.
///
/// Implements both [`ChainProvider`] and [`StarkNetChain`]; every request is a
/// JSON-RPC 2.0 POST to `rpc_url`.
#[derive(Debug)]
pub struct StarkNetProvider {
    /// Endpoint every JSON-RPC request is POSTed to.
    rpc_url: String,
    /// HTTP client reused across all requests made by this provider.
    client: reqwest::Client,
    /// Chain identifier supplied at construction, e.g. `"SN_MAIN"` or `"SN_SEPOLIA"`.
    chain_id: String,
    /// Counter supplying the JSON-RPC `id` of each request (starts at 1).
    request_id: std::sync::atomic::AtomicU64,
}
impl StarkNetProvider {
    /// Creates a provider that sends JSON-RPC requests to `rpc_url`.
    ///
    /// `chain_id` is stored on the provider; none of the request builders in
    /// this module currently send it to the node.
    pub fn new(rpc_url: impl Into<String>, chain_id: impl Into<String>) -> Self {
        Self {
            rpc_url: rpc_url.into(),
            client: reqwest::Client::new(),
            chain_id: chain_id.into(),
            request_id: std::sync::atomic::AtomicU64::new(1),
        }
    }

    /// Provider preconfigured for StarkNet mainnet (`SN_MAIN`).
    // NOTE(review): this host looks like the legacy sequencer gateway rather than
    // a JSON-RPC endpoint — confirm it actually serves `starknet_*` RPC methods.
    pub fn mainnet() -> Self {
        Self::new("https://alpha-mainnet.starknet.io", "SN_MAIN")
    }

    /// Provider preconfigured for the StarkNet Sepolia testnet (`SN_SEPOLIA`).
    // NOTE(review): same concern as `mainnet` — confirm this is a JSON-RPC host.
    pub fn sepolia() -> Self {
        Self::new("https://alpha-sepolia.starknet.io", "SN_SEPOLIA")
    }

    /// Returns the next JSON-RPC request id.
    ///
    /// `Relaxed` suffices: ids only need to be distinct, not ordered with
    /// respect to other memory operations.
    fn next_id(&self) -> u64 {
        self.request_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }

    /// Sends one JSON-RPC 2.0 request and returns its `result` member.
    ///
    /// # Errors
    /// - [`ExchangeError::Network`] if the HTTP request fails, or the server
    ///   answers with a non-success HTTP status and no JSON-RPC error object.
    /// - [`ExchangeError::Parse`] if the body is not JSON or has no `result`.
    /// - [`ExchangeError::InvalidRequest`] for a JSON-RPC error object, with the
    ///   message formatted as `"{method}: RPC error {code}: {message}"`.
    async fn rpc_call(&self, method: &str, params: Value) -> Result<Value, ExchangeError> {
        let body = json!({
            "jsonrpc": "2.0",
            "id": self.next_id(),
            "method": method,
            "params": params,
        });
        let response = self
            .client
            .post(&self.rpc_url)
            .json(&body)
            .send()
            .await
            .map_err(|e| ExchangeError::Network(format!("{}: {}", method, e)))?;

        // Keep the status: proxy/gateway failures usually return a non-JSON body,
        // and the status is then the only useful diagnostic.
        let status = response.status();
        let raw: Value = response.json().await.map_err(|e| {
            ExchangeError::Parse(format!(
                "{}: failed to parse response (HTTP {}): {}",
                method, status, e
            ))
        })?;

        // JSON-RPC 2.0 forbids `error` on success, but some servers still send
        // `"error": null` next to a result; only a non-null value is a failure.
        if let Some(err_obj) = raw.get("error").filter(|v| !v.is_null()) {
            let code = err_obj.get("code").and_then(Value::as_i64).unwrap_or(-1);
            let msg = err_obj
                .get("message")
                .and_then(Value::as_str)
                .unwrap_or("unknown RPC error");
            return Err(ExchangeError::InvalidRequest(format!(
                "{}: RPC error {}: {}",
                method, code, msg
            )));
        }

        if !status.is_success() {
            return Err(ExchangeError::Network(format!(
                "{}: unexpected HTTP status {}",
                method, status
            )));
        }

        raw.get("result")
            .cloned()
            .ok_or_else(|| ExchangeError::Parse(format!("{}: missing 'result' in response", method)))
    }

    /// Strips an optional `0x` / `0X` prefix from a hex-encoded felt.
    fn strip_hex_prefix(felt: &str) -> &str {
        felt.strip_prefix("0x")
            .or_else(|| felt.strip_prefix("0X"))
            .unwrap_or(felt)
    }

    /// Parses a hex felt (with or without `0x`) into a `u64`.
    ///
    /// # Errors
    /// [`ExchangeError::Parse`] if the text is not hex or does not fit in 64 bits.
    fn felt_to_u64(felt: &str) -> Result<u64, ExchangeError> {
        u64::from_str_radix(Self::strip_hex_prefix(felt), 16).map_err(|e| {
            ExchangeError::Parse(format!("failed to parse felt '{}' as u64: {}", felt, e))
        })
    }

    /// Parses a hex felt (with or without `0x`) into a `u128`.
    ///
    /// # Errors
    /// [`ExchangeError::Parse`] if the text is not hex or does not fit in 128 bits.
    fn felt_to_u128(felt: &str) -> Result<u128, ExchangeError> {
        u128::from_str_radix(Self::strip_hex_prefix(felt), 16).map_err(|e| {
            ExchangeError::Parse(format!("failed to parse felt '{}' as u128: {}", felt, e))
        })
    }
}
/// Formats an unsigned 256-bit integer, given as its `(low, high)` 128-bit
/// halves, as an exact base-10 string.
fn uint256_to_decimal(low: u128, high: u128) -> String {
    if high == 0 {
        return low.to_string();
    }
    // Big-endian 64-bit limbs (the `as u64` casts intentionally keep the lower
    // 64 bits). Long division by 10 keeps every intermediate value below
    // 10 * 2^64, which fits comfortably in a u128.
    let mut limbs = [(high >> 64) as u64, high as u64, (low >> 64) as u64, low as u64];
    // 2^256 has 78 decimal digits.
    let mut digits = Vec::with_capacity(78);
    while limbs.iter().any(|&limb| limb != 0) {
        let mut rem: u128 = 0;
        for limb in limbs.iter_mut() {
            let cur = (rem << 64) | u128::from(*limb);
            *limb = (cur / 10) as u64;
            rem = cur % 10;
        }
        digits.push(b'0' + rem as u8);
    }
    digits.reverse();
    String::from_utf8(digits).expect("only ASCII digits were pushed")
}

#[async_trait]
impl ChainProvider for StarkNetProvider {
    fn chain_family(&self) -> ChainFamily {
        ChainFamily::StarkNet
    }

    /// Submits a pre-built invoke transaction.
    ///
    /// `tx_bytes` must be the JSON encoding of the transaction object; it is
    /// forwarded unchanged as the single positional parameter of
    /// `starknet_addInvokeTransaction`. Returns the `transaction_hash` from the
    /// node's result.
    async fn broadcast_tx(&self, tx_bytes: &[u8]) -> Result<String, ExchangeError> {
        let tx_json: Value = serde_json::from_slice(tx_bytes).map_err(|e| {
            ExchangeError::InvalidRequest(format!(
                "broadcast_tx: tx_bytes must be a JSON-encoded invoke transaction: {}",
                e
            ))
        })?;
        let result = self
            .rpc_call("starknet_addInvokeTransaction", json!([tx_json]))
            .await?;
        result
            .get("transaction_hash")
            .and_then(Value::as_str)
            .map(str::to_string)
            .ok_or_else(|| {
                ExchangeError::Parse(
                    "starknet_addInvokeTransaction: missing 'transaction_hash' in result"
                        .to_string(),
                )
            })
    }

    /// Latest block number, as returned by `starknet_getBlockNumber`.
    async fn get_height(&self) -> Result<u64, ExchangeError> {
        let result = self.rpc_call("starknet_getBlockNumber", json!([])).await?;
        result
            .as_u64()
            .ok_or_else(|| ExchangeError::Parse("starknet_getBlockNumber: non-integer result".to_string()))
    }

    /// Nonce of `address`, parsed from its hex felt into a `u64`.
    async fn get_nonce(&self, address: &str) -> Result<u64, ExchangeError> {
        let nonce_felt = self.get_starknet_nonce(address).await?;
        Self::felt_to_u64(&nonce_felt)
    }

    /// ETH balance of `address` as an exact decimal string.
    ///
    /// The value is the raw Uint256 returned by the token's `balanceOf` view
    /// (returned as `low`, `high` felts); no decimal scaling is applied.
    async fn get_native_balance(&self, address: &str) -> Result<String, ExchangeError> {
        let ret = self
            .call_contract(ETH_TOKEN_ADDRESS, BALANCE_OF_SELECTOR, &[address.to_string()])
            .await?;
        if ret.len() < 2 {
            return Err(ExchangeError::Parse(format!(
                "get_native_balance: expected 2 felt return values (Uint256), got {}",
                ret.len()
            )));
        }
        let low = Self::felt_to_u128(&ret[0])?;
        let high = Self::felt_to_u128(&ret[1])?;
        // Exact arithmetic: an f64 approximation silently loses precision once
        // the high half is non-zero.
        Ok(uint256_to_decimal(low, high))
    }

    /// Delegates to [`StarkNetChain::get_receipt`].
    async fn get_tx_status(&self, tx_hash: &str) -> Result<TxStatus, ExchangeError> {
        self.get_receipt(tx_hash).await
    }
}
#[async_trait]
impl StarkNetChain for StarkNetProvider {
async fn invoke(
&self,
contract_address: &str,
selector: &str,
calldata: &[String],
) -> Result<String, ExchangeError> {
let tx = json!({
"type": "INVOKE",
"sender_address": contract_address,
"calldata": calldata,
"entry_point_selector": selector,
});
let result = self
.rpc_call("starknet_addInvokeTransaction", json!([tx]))
.await?;
result
.get("transaction_hash")
.and_then(Value::as_str)
.map(str::to_string)
.ok_or_else(|| {
ExchangeError::Parse(
"starknet_addInvokeTransaction (invoke): missing 'transaction_hash'"
.to_string(),
)
})
}
async fn call_contract(
&self,
contract_address: &str,
selector: &str,
calldata: &[String],
) -> Result<Vec<String>, ExchangeError> {
let request = json!({
"contract_address": contract_address,
"entry_point_selector": selector,
"calldata": calldata,
});
let result = self
.rpc_call("starknet_call", json!([request, "latest"]))
.await?;
let arr = result.as_array().ok_or_else(|| {
ExchangeError::Parse("starknet_call: expected array result".to_string())
})?;
arr.iter()
.map(|v| {
v.as_str()
.map(str::to_string)
.ok_or_else(|| ExchangeError::Parse("starknet_call: non-string felt in result".to_string()))
})
.collect()
}
async fn get_starknet_nonce(&self, address: &str) -> Result<String, ExchangeError> {
let result = self
.rpc_call("starknet_getNonce", json!(["latest", address]))
.await?;
result
.as_str()
.map(str::to_string)
.ok_or_else(|| {
ExchangeError::Parse("starknet_getNonce: expected string (felt) result".to_string())
})
}
async fn get_receipt(&self, tx_hash: &str) -> Result<TxStatus, ExchangeError> {
let result = self
.rpc_call("starknet_getTransactionReceipt", json!([tx_hash]))
.await;
match result {
Err(ExchangeError::InvalidRequest(msg)) if msg.contains("29") || msg.contains("not found") => {
return Ok(TxStatus::NotFound);
}
Err(e) => return Err(e),
Ok(receipt) => {
let finality = receipt
.get("finality_status")
.and_then(Value::as_str)
.unwrap_or("");
let execution = receipt
.get("execution_status")
.and_then(Value::as_str)
.unwrap_or("");
let legacy_status = receipt
.get("status")
.and_then(Value::as_str)
.unwrap_or("");
match (finality, execution, legacy_status) {
(_, "REVERTED", _) => {
let reason = receipt
.get("revert_reason")
.and_then(Value::as_str)
.unwrap_or("transaction reverted")
.to_string();
Ok(TxStatus::Failed { reason })
}
(_, _, "REJECTED") => Ok(TxStatus::Failed {
reason: "transaction rejected".to_string(),
}),
("ACCEPTED_ON_L1", _, _) | ("ACCEPTED_ON_L2", _, _) => {
let block = receipt
.get("block_number")
.and_then(Value::as_u64)
.unwrap_or(0);
Ok(TxStatus::Confirmed { block })
}
(_, _, "ACCEPTED_ON_L1") | (_, _, "ACCEPTED_ON_L2") => {
let block = receipt
.get("block_number")
.and_then(Value::as_u64)
.unwrap_or(0);
Ok(TxStatus::Confirmed { block })
}
_ => Ok(TxStatus::Pending),
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chain_family() {
        let mainnet = StarkNetProvider::mainnet();
        assert_eq!(mainnet.chain_family(), ChainFamily::StarkNet);
    }

    #[test]
    fn test_chain_family_sepolia() {
        let sepolia = StarkNetProvider::sepolia();
        assert_eq!(sepolia.chain_family(), ChainFamily::StarkNet);
        assert_eq!(sepolia.chain_id, "SN_SEPOLIA");
    }

    #[test]
    fn test_felt_to_u64_hex_prefix() {
        let cases: &[(&str, u64)] = &[("0x3", 3), ("0xff", 255), ("0x0", 0)];
        for &(felt, expected) in cases {
            assert_eq!(StarkNetProvider::felt_to_u64(felt).unwrap(), expected);
        }
    }

    #[test]
    fn test_felt_to_u64_no_prefix() {
        let cases: &[(&str, u64)] = &[("3", 3), ("ff", 255)];
        for &(felt, expected) in cases {
            assert_eq!(StarkNetProvider::felt_to_u64(felt).unwrap(), expected);
        }
    }

    #[test]
    fn test_felt_to_u64_invalid() {
        for &bad in &["0xzzzz", "not_a_felt"] {
            assert!(StarkNetProvider::felt_to_u64(bad).is_err());
        }
    }

    #[test]
    fn test_felt_to_u128() {
        let parse = StarkNetProvider::felt_to_u128;
        assert_eq!(parse("0x1").unwrap(), 1u128);
        assert_eq!(
            parse("0xffffffffffffffffffffffffffffffff").unwrap(),
            u128::MAX
        );
    }

    #[test]
    fn test_mainnet_rpc_url() {
        let mainnet = StarkNetProvider::mainnet();
        assert_eq!(mainnet.rpc_url, "https://alpha-mainnet.starknet.io");
        assert_eq!(mainnet.chain_id, "SN_MAIN");
    }

    #[test]
    fn test_sepolia_rpc_url() {
        let sepolia = StarkNetProvider::sepolia();
        assert_eq!(sepolia.rpc_url, "https://alpha-sepolia.starknet.io");
    }

    #[test]
    fn test_chain_family_name() {
        let family = StarkNetProvider::mainnet().chain_family();
        assert_eq!(family.name(), "starknet");
    }

    #[test]
    fn test_request_id_increments() {
        let provider = StarkNetProvider::mainnet();
        let first = provider.next_id();
        assert_eq!(provider.next_id(), first + 1);
    }
}