balance_alerts_layer/
lib.rs1use alloy::network::Ethereum;
16use alloy::primitives::{Address, U256};
17use alloy::providers::{PendingTransactionBuilder, Provider, ProviderLayer, RootProvider};
18use alloy::transports::TransportResult;
19
20#[derive(Debug, Clone, Default)]
22pub struct BalanceAlertConfig {
23 pub watch_address: Address,
25 pub warn_threshold: Option<U256>,
27 pub error_threshold: Option<U256>,
29}
30
31#[derive(Debug, Clone, Default)]
32pub struct BalanceAlertLayer {
33 config: BalanceAlertConfig,
34}
35
36impl BalanceAlertLayer {
53 pub fn new(config: BalanceAlertConfig) -> Self {
54 Self { config }
55 }
56}
57
58impl<P> ProviderLayer<P> for BalanceAlertLayer
59where
60 P: Provider,
61{
62 type Provider = BalanceAlertProvider<P>;
63
64 fn layer(&self, inner: P) -> Self::Provider {
65 BalanceAlertProvider::new(inner, self.config.clone())
66 }
67}
68
69#[derive(Clone, Debug)]
70pub struct BalanceAlertProvider<P> {
71 inner: P,
72 config: BalanceAlertConfig,
73}
74
75impl<P> BalanceAlertProvider<P>
76where
77 P: Provider,
78{
79 #[allow(clippy::missing_const_for_fn)]
80 fn new(inner: P, config: BalanceAlertConfig) -> Self {
81 Self { inner, config }
82 }
83}
84
85#[async_trait::async_trait]
86impl<P> Provider for BalanceAlertProvider<P>
87where
88 P: Provider,
89{
90 #[inline(always)]
91 fn root(&self) -> &RootProvider {
92 self.inner.root()
93 }
94
95 async fn send_raw_transaction(
102 &self,
103 encoded_tx: &[u8],
104 ) -> TransportResult<PendingTransactionBuilder<Ethereum>> {
105 let res = self.inner.send_raw_transaction(encoded_tx).await;
106 let balance = self.inner.get_balance(self.config.watch_address).await?;
107
108 if balance < self.config.error_threshold.unwrap_or(U256::ZERO) {
109 tracing::error!(
110 "balance of {} < error threshold: {}",
111 self.config.watch_address,
112 balance
113 );
114 } else if balance < self.config.warn_threshold.unwrap_or(U256::ZERO) {
115 tracing::warn!(
116 "balance of {} < warning threshold: {}",
117 self.config.watch_address,
118 balance
119 );
120 } else {
121 tracing::trace!("balance of {} is: {}", self.config.watch_address, balance);
122 }
123 res
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use alloy::{
        network::{EthereumWallet, TransactionBuilder},
        node_bindings::Anvil,
        primitives::utils::parse_ether,
        providers::ProviderBuilder,
        rpc::{client::RpcClient, types::TransactionRequest},
        signers::local::LocalSigner,
    };

    /// Sends `amount` wei to the zero address and waits on the pending
    /// transaction via `watch()`.
    async fn burn_eth(provider: impl Provider, amount: U256) -> anyhow::Result<()> {
        let tx = TransactionRequest::default().with_to(Address::ZERO).with_value(amount);
        provider.send_transaction(tx).await?.watch().await?;
        Ok(())
    }

    /// Starting from 10 ETH, burning ETH should first cross the 9 ETH warning
    /// threshold and then the 5 ETH error threshold.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_balance_alert_layer() -> anyhow::Result<()> {
        let anvil = Anvil::default().args(["--balance", "10"]).spawn();
        let wallet = EthereumWallet::from(LocalSigner::from(anvil.keys()[0].clone()));
        let client = RpcClient::builder().http(anvil.endpoint_url());

        let balance_alerts_layer = BalanceAlertLayer::new(BalanceAlertConfig {
            watch_address: wallet.default_signer().address(),
            warn_threshold: Some(parse_ether("9").unwrap()),
            error_threshold: Some(parse_ether("5").unwrap()),
        });

        let provider =
            ProviderBuilder::new().layer(balance_alerts_layer).wallet(wallet).on_client(client);

        // ~9.5 ETH left (minus gas): above both thresholds.
        burn_eth(&provider, parse_ether("0.5").unwrap()).await?;
        assert!(!logs_contain("< warning threshold"));

        // ~8.9 ETH left: below the warning threshold, still above the error threshold.
        burn_eth(&provider, parse_ether("0.6").unwrap()).await?;
        assert!(logs_contain("< warning threshold"));
        assert!(!logs_contain("< error threshold"));

        // ~2.9 ETH left: below the error threshold.
        burn_eth(&provider, parse_ether("6").unwrap()).await?;
        assert!(logs_contain("< error threshold"));

        Ok(())
    }

    /// With both thresholds unset, no warning or error alert is ever logged.
    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_balance_alert_layer_no_config() -> anyhow::Result<()> {
        let anvil = Anvil::default().args(["--balance", "10"]).spawn();
        let wallet = EthereumWallet::from(LocalSigner::from(anvil.keys()[0].clone()));
        let client = RpcClient::builder().http(anvil.endpoint_url());

        let balance_alerts_layer = BalanceAlertLayer::new(BalanceAlertConfig {
            watch_address: wallet.default_signer().address(),
            warn_threshold: None,
            error_threshold: None,
        });

        let provider =
            ProviderBuilder::new().layer(balance_alerts_layer).wallet(wallet).on_client(client);

        burn_eth(&provider, parse_ether("0.5").unwrap()).await?;
        assert!(!logs_contain("< warning threshold"));

        burn_eth(&provider, parse_ether("0.6").unwrap()).await?;
        assert!(!logs_contain("< warning threshold"));

        burn_eth(&provider, parse_ether("6").unwrap()).await?;
        assert!(!logs_contain("< error threshold"));

        Ok(())
    }
}