balance_alerts_layer/
lib.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use alloy::network::Ethereum;
16use alloy::primitives::{Address, U256};
17use alloy::providers::{PendingTransactionBuilder, Provider, ProviderLayer, RootProvider};
18use alloy::transports::TransportResult;
19
20/// Configuration for the BalanceAlertLayer
21#[derive(Debug, Clone, Default)]
22pub struct BalanceAlertConfig {
23    /// Address to periodically check the balance of
24    pub watch_address: Address,
25    /// Threshold at which to log a warning
26    pub warn_threshold: Option<U256>,
27    /// Threshold at which to log an error
28    pub error_threshold: Option<U256>,
29}
30
31#[derive(Debug, Clone, Default)]
32pub struct BalanceAlertLayer {
33    config: BalanceAlertConfig,
34}
35
36/// A ProviderLayer that can be added to an alloy Provider
37/// to log warnings and errors when the balance of a given address
38/// falls below certain thresholds.
39///
40/// This checks the balance after every transaction sent via send_transaction
41/// and errors, warns or trace logs accordingly
42///
43/// # Examples
44/// ```ignore
45/// let provider = ProviderBuilder::new()
46///     .layer(BalanceAlertLayer::new(BalanceAlertConfig {
47///         watch_address: wallet.default_signer().address(),
48///         warn_threshold: parse_ether("0.1")?,
49///         error_threshold: parse_ether("0.1")?,
50///     }));
51/// ```
52impl BalanceAlertLayer {
53    pub fn new(config: BalanceAlertConfig) -> Self {
54        Self { config }
55    }
56}
57
58impl<P> ProviderLayer<P> for BalanceAlertLayer
59where
60    P: Provider,
61{
62    type Provider = BalanceAlertProvider<P>;
63
64    fn layer(&self, inner: P) -> Self::Provider {
65        BalanceAlertProvider::new(inner, self.config.clone())
66    }
67}
68
69#[derive(Clone, Debug)]
70pub struct BalanceAlertProvider<P> {
71    inner: P,
72    config: BalanceAlertConfig,
73}
74
75impl<P> BalanceAlertProvider<P>
76where
77    P: Provider,
78{
79    #[allow(clippy::missing_const_for_fn)]
80    fn new(inner: P, config: BalanceAlertConfig) -> Self {
81        Self { inner, config }
82    }
83}
84
85#[async_trait::async_trait]
86impl<P> Provider for BalanceAlertProvider<P>
87where
88    P: Provider,
89{
90    #[inline(always)]
91    fn root(&self) -> &RootProvider {
92        self.inner.root()
93    }
94
95    /// Broadcasts a raw transaction RLP bytes to the network.
96    ///
97    /// This override checks the watched address after sending the transaction and
98    /// logs a warning or error if the balance falls below the configured thresholds.
99    ///
100    /// See [`send_transaction`](Self::send_transaction) for more details.
101    async fn send_raw_transaction(
102        &self,
103        encoded_tx: &[u8],
104    ) -> TransportResult<PendingTransactionBuilder<Ethereum>> {
105        let res = self.inner.send_raw_transaction(encoded_tx).await;
106        let balance = self.inner.get_balance(self.config.watch_address).await?;
107
108        if balance < self.config.error_threshold.unwrap_or(U256::ZERO) {
109            tracing::error!(
110                "balance of {} < error threshold: {}",
111                self.config.watch_address,
112                balance
113            );
114        } else if balance < self.config.warn_threshold.unwrap_or(U256::ZERO) {
115            tracing::warn!(
116                "balance of {} < warning threshold: {}",
117                self.config.watch_address,
118                balance
119            );
120        } else {
121            tracing::trace!("balance of {} is: {}", self.config.watch_address, balance);
122        }
123        res
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use alloy::{
        network::{EthereumWallet, TransactionBuilder},
        node_bindings::Anvil,
        primitives::utils::parse_ether,
        providers::ProviderBuilder,
        rpc::{client::RpcClient, types::TransactionRequest},
        signers::local::LocalSigner,
    };

    /// Sends `amount` wei to the zero address and waits for the transaction to be confirmed.
    async fn burn_eth(provider: impl Provider, amount: U256) -> anyhow::Result<()> {
        let tx = TransactionRequest::default().with_to(Address::ZERO).with_value(amount);
        provider.send_transaction(tx).await?.watch().await?;
        Ok(())
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_balance_alert_layer() -> anyhow::Result<()> {
        // Initial wallet balance is 10 eth, set up to warn if < 9 and error if < 5
        let anvil = Anvil::default().args(["--balance", "10"]).spawn();
        let wallet = EthereumWallet::from(LocalSigner::from(anvil.keys()[0].clone()));
        let client = RpcClient::builder().http(anvil.endpoint_url());

        let balance_alerts_layer = BalanceAlertLayer::new(BalanceAlertConfig {
            watch_address: wallet.default_signer().address(),
            warn_threshold: Some(parse_ether("9").unwrap()),
            error_threshold: Some(parse_ether("5").unwrap()),
        });

        let provider =
            ProviderBuilder::new().layer(balance_alerts_layer).wallet(wallet).on_client(client);

        // ~9.5 eth left (minus gas): above both thresholds, so nothing is logged yet.
        burn_eth(&provider, parse_ether("0.5").unwrap()).await?;
        assert!(!logs_contain("< warning threshold"));
        assert!(!logs_contain("< error threshold"));

        // ~8.9 eth left: below the warning threshold only.
        burn_eth(&provider, parse_ether("0.6").unwrap()).await?;
        assert!(logs_contain("< warning threshold"));
        assert!(!logs_contain("< error threshold"));

        // ~2.9 eth left: below the error threshold.
        burn_eth(&provider, parse_ether("6").unwrap()).await?;
        assert!(logs_contain("< error threshold"));

        Ok(())
    }

    #[tokio::test]
    #[tracing_test::traced_test]
    async fn test_balance_alert_layer_no_config() -> anyhow::Result<()> {
        // Initial wallet balance is 10 eth and no thresholds are configured, so no warning or
        // error should ever be logged, however low the balance falls.
        let anvil = Anvil::default().args(["--balance", "10"]).spawn();
        let wallet = EthereumWallet::from(LocalSigner::from(anvil.keys()[0].clone()));
        let client = RpcClient::builder().http(anvil.endpoint_url());

        let balance_alerts_layer = BalanceAlertLayer::new(BalanceAlertConfig {
            watch_address: wallet.default_signer().address(),
            warn_threshold: None,
            error_threshold: None,
        });

        let provider =
            ProviderBuilder::new().layer(balance_alerts_layer).wallet(wallet).on_client(client);

        burn_eth(&provider, parse_ether("0.5").unwrap()).await?;
        assert!(!logs_contain("< warning threshold"));
        assert!(!logs_contain("< error threshold"));

        burn_eth(&provider, parse_ether("0.6").unwrap()).await?;
        assert!(!logs_contain("< warning threshold"));
        assert!(!logs_contain("< error threshold"));

        burn_eth(&provider, parse_ether("6").unwrap()).await?;
        assert!(!logs_contain("< warning threshold"));
        assert!(!logs_contain("< error threshold"));

        Ok(())
    }
}