// safe_rs/wallet.rs

1//! Generic Wallet type for Safe and EOA accounts
2//!
3//! This module provides a `Wallet<A>` type that wraps any account implementing
4//! the `Account` trait, enabling generic code that works with both Safe and EOA.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use safe_rs::{WalletBuilder, WalletConfig, Account};
10//!
11//! // Connect to a Safe wallet using the fluent builder API
12//! let safe_wallet = WalletBuilder::new(provider, signer)
13//!     .connect(safe_address)
14//!     .await?;
15//!
16//! // Connect to an EOA wallet
17//! let eoa_wallet = WalletBuilder::new(provider, signer)
//!     .connect_eoa(rpc_url)
19//!     .await?;
20//!
21//! // Deploy a new Safe and connect to it
22//! let builder = WalletBuilder::new(provider, signer);
23//! let address = builder.deploy(rpc_url, config.clone()).await?;
24//! let wallet = builder.connect(address).await?;
25//!
26//! // Generic function that works with any account type
27//! async fn do_something<A: Account>(wallet: &Wallet<A>) -> Result<()> {
28//!     wallet.batch()
29//!         .add_typed(token, IERC20::transferCall { to: recipient, amount })
30//!         .simulate().await?
31//!         .execute().await?;
32//!     Ok(())
33//! }
34//! ```
35
36use alloy::network::{AnyNetwork, EthereumWallet};
37use alloy::primitives::{Address, Bytes, U256};
38use alloy::providers::{Provider, ProviderBuilder};
39use alloy::signers::local::PrivateKeySigner;
40use url::Url;
41
42use crate::account::Account;
43use crate::chain::{ChainAddresses, ChainConfig};
44use crate::create2::{compute_create2_address, encode_setup_call};
45use crate::eoa::Eoa;
46use crate::error::{Error, Result};
47use crate::safe::{is_safe, ExecutionResult, Safe};
48use crate::types::Operation;
49use crate::ISafeProxyFactory;
50
51/// Configuration for Safe address computation and deployment
52#[derive(Debug, Clone)]
53pub struct WalletConfig {
54    /// Salt nonce for CREATE2 address computation (default: 0)
55    pub salt_nonce: U256,
56    /// Additional owners beyond the signer (default: empty)
57    pub additional_owners: Vec<Address>,
58    /// Threshold for the Safe (default: 1)
59    pub threshold: u64,
60    /// Fallback handler address (default: v1.4.1 fallback handler)
61    pub fallback_handler: Option<Address>,
62}
63
64impl Default for WalletConfig {
65    fn default() -> Self {
66        Self {
67            salt_nonce: U256::ZERO,
68            additional_owners: Vec::new(),
69            threshold: 1,
70            fallback_handler: None,
71        }
72    }
73}
74
75impl WalletConfig {
76    /// Creates a new WalletConfig with default values
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Sets the salt nonce for CREATE2 address computation
82    pub fn with_salt_nonce(mut self, salt_nonce: U256) -> Self {
83        self.salt_nonce = salt_nonce;
84        self
85    }
86
87    /// Sets additional owners beyond the signer
88    pub fn with_additional_owners(mut self, owners: Vec<Address>) -> Self {
89        self.additional_owners = owners;
90        self
91    }
92
93    /// Sets the threshold for the Safe
94    pub fn with_threshold(mut self, threshold: u64) -> Self {
95        self.threshold = threshold;
96        self
97    }
98
99    /// Sets a custom fallback handler
100    pub fn with_fallback_handler(mut self, handler: Address) -> Self {
101        self.fallback_handler = Some(handler);
102        self
103    }
104
105    /// Builds the owners array (signer + additional owners)
106    fn build_owners(&self, signer_address: Address) -> Vec<Address> {
107        let mut owners = vec![signer_address];
108        for owner in &self.additional_owners {
109            if !owners.contains(owner) {
110                owners.push(*owner);
111            }
112        }
113        owners
114    }
115
116    /// Gets the fallback handler, using the v1.4.1 default if not specified
117    fn get_fallback_handler(&self) -> Address {
118        self.fallback_handler
119            .unwrap_or_else(|| ChainAddresses::v1_4_1().fallback_handler)
120    }
121}
122
123// =============================================================================
124// WalletBuilder
125// =============================================================================
126
127/// Builder for creating wallets with a fluent API.
128///
129/// This builder holds the provider and signer, allowing you to:
130/// - Connect to an existing Safe at a known address
131/// - Connect to a Safe at a computed CREATE2 address
132/// - Connect as an EOA (no Safe)
133/// - Deploy a new Safe and then connect to it
134///
135/// # Example
136///
137/// ```rust,ignore
138/// // Connect to existing Safe
139/// let wallet = WalletBuilder::new(provider, signer)
140///     .connect(address)
141///     .await?;
142///
143/// // Deploy then connect (builder not consumed by deploy)
144/// let builder = WalletBuilder::new(provider, signer);
145/// let address = builder.deploy(rpc_url, config.clone()).await?;
146/// let wallet = builder.connect(address).await?;
147/// ```
148pub struct WalletBuilder<P> {
149    provider: P,
150    signer: PrivateKeySigner,
151}
152
153impl<P> WalletBuilder<P> {
154    /// Creates a new WalletBuilder with the given provider and signer.
155    pub fn new(provider: P, signer: PrivateKeySigner) -> Self {
156        Self { provider, signer }
157    }
158
159    /// Returns a reference to the signer.
160    pub fn signer(&self) -> &PrivateKeySigner {
161        &self.signer
162    }
163
164    /// Returns the signer's address.
165    pub fn signer_address(&self) -> Address {
166        self.signer.address()
167    }
168
169    /// Returns a reference to the provider.
170    pub fn provider(&self) -> &P {
171        &self.provider
172    }
173}
174
175impl<P> WalletBuilder<P>
176where
177    P: Provider<AnyNetwork> + Clone + 'static,
178{
179    /// Connects to an existing Safe at the given address.
180    ///
181    /// # Arguments
182    /// * `address` - The Safe contract address
183    ///
184    /// # Example
185    ///
186    /// ```rust,ignore
187    /// let wallet = WalletBuilder::new(provider, signer)
188    ///     .connect(safe_address)
189    ///     .await?;
190    /// ```
191    pub async fn connect(self, address: Address) -> Result<Wallet<Safe<P>>> {
192        let safe = Safe::connect(self.provider, self.signer, address).await?;
193        Ok(Wallet::from_account(safe))
194    }
195
196    /// Connects to a Safe at the computed CREATE2 address for the given config.
197    ///
198    /// This computes the deterministic Safe address based on the signer and config,
199    /// then connects to it. Returns an error if no Safe is deployed at that address.
200    ///
201    /// # Arguments
202    /// * `config` - Configuration for Safe address computation
203    ///
204    /// # Example
205    ///
206    /// ```rust,ignore
207    /// let config = WalletConfig::new().with_salt_nonce(U256::from(42));
208    /// let wallet = WalletBuilder::new(provider, signer)
209    ///     .connect_with_config(config)
210    ///     .await?;
211    /// ```
212    pub async fn connect_with_config(self, config: WalletConfig) -> Result<Wallet<Safe<P>>> {
213        let safe_address = self.compute_address(&config).await?;
214
215        // Check if Safe is deployed
216        if !is_safe(&self.provider, safe_address).await? {
217            return Err(Error::InvalidConfig(format!(
218                "No Safe deployed at computed address {}",
219                safe_address
220            )));
221        }
222
223        let safe = Safe::connect(self.provider, self.signer, safe_address).await?;
224        Ok(Wallet::from_account(safe))
225    }
226
227    /// Connects as an EOA (no Safe).
228    ///
229    /// # Arguments
230    /// * `rpc_url` - The RPC URL for sending signed transactions
231    ///
232    /// # Example
233    ///
234    /// ```rust,ignore
235    /// let wallet = WalletBuilder::new(provider, signer)
236    ///     .connect_eoa(rpc_url)
237    ///     .await?;
238    /// ```
239    pub async fn connect_eoa(self, rpc_url: Url) -> Result<Wallet<Eoa<P>>> {
240        let eoa = Eoa::connect(self.provider, self.signer, rpc_url).await?;
241        Ok(Wallet::from_account(eoa))
242    }
243
244    /// Computes the Safe address that would be used for the given config.
245    ///
246    /// This is useful for checking what Safe address would be computed without
247    /// actually connecting or deploying.
248    ///
249    /// # Arguments
250    /// * `config` - Configuration for Safe address computation
251    ///
252    /// # Example
253    ///
254    /// ```rust,ignore
255    /// let builder = WalletBuilder::new(provider, signer);
256    /// let config = WalletConfig::new().with_salt_nonce(U256::from(42));
257    /// let address = builder.compute_address(&config).await?;
258    /// ```
259    pub async fn compute_address(&self, config: &WalletConfig) -> Result<Address> {
260        let addresses = ChainAddresses::v1_4_1();
261        let signer_address = self.signer.address();
262
263        // Build owners array
264        let owners = config.build_owners(signer_address);
265
266        // Get fallback handler
267        let fallback_handler = config.get_fallback_handler();
268
269        // Encode initializer
270        let initializer = encode_setup_call(&owners, config.threshold, fallback_handler);
271
272        // Get proxy creation code
273        let factory = ISafeProxyFactory::new(addresses.proxy_factory, &self.provider);
274        let creation_code = factory
275            .proxyCreationCode()
276            .call()
277            .await
278            .map_err(|e| Error::Fetch {
279                what: "proxy creation code",
280                reason: e.to_string(),
281            })?;
282
283        // Compute deterministic address
284        let safe_address = compute_create2_address(
285            addresses.proxy_factory,
286            addresses.safe_singleton,
287            &initializer,
288            config.salt_nonce,
289            &creation_code,
290        );
291
292        Ok(safe_address)
293    }
294
295    /// Deploys a Safe with the given configuration. Idempotent.
296    ///
297    /// If a Safe already exists at the computed address, returns that address
298    /// without deploying. Otherwise, deploys a new Safe.
299    ///
300    /// Uses `&self` so the builder can be reused for `connect()` afterward.
301    ///
302    /// # Arguments
303    /// * `rpc_url` - The RPC URL for sending the deployment transaction
304    /// * `config` - Configuration for Safe deployment
305    ///
306    /// # Example
307    ///
308    /// ```rust,ignore
309    /// let builder = WalletBuilder::new(provider, signer);
310    /// let config = WalletConfig::new().with_salt_nonce(U256::from(42));
311    /// let address = builder.deploy(rpc_url, config.clone()).await?;
312    /// let wallet = builder.connect(address).await?;
313    /// ```
314    pub async fn deploy(&self, rpc_url: Url, config: WalletConfig) -> Result<Address> {
315        let addresses = ChainAddresses::v1_4_1();
316        let signer_address = self.signer.address();
317
318        // Build owners array
319        let owners = config.build_owners(signer_address);
320
321        // Validate threshold
322        if config.threshold == 0 || config.threshold as usize > owners.len() {
323            return Err(Error::InvalidConfig(format!(
324                "Invalid threshold: {} (must be 1-{})",
325                config.threshold,
326                owners.len()
327            )));
328        }
329
330        // Get fallback handler
331        let fallback_handler = config.get_fallback_handler();
332
333        // Encode initializer
334        let initializer = encode_setup_call(&owners, config.threshold, fallback_handler);
335
336        // Get proxy creation code
337        let factory = ISafeProxyFactory::new(addresses.proxy_factory, &self.provider);
338        let creation_code = factory
339            .proxyCreationCode()
340            .call()
341            .await
342            .map_err(|e| Error::Fetch {
343                what: "proxy creation code",
344                reason: e.to_string(),
345            })?;
346
347        // Compute deterministic address
348        let safe_address = compute_create2_address(
349            addresses.proxy_factory,
350            addresses.safe_singleton,
351            &initializer,
352            config.salt_nonce,
353            &creation_code,
354        );
355
356        // Check if Safe is already deployed
357        if is_safe(&self.provider, safe_address).await? {
358            return Ok(safe_address);
359        }
360
361        // Deploy the Safe
362        let wallet_provider = ProviderBuilder::new()
363            .network::<AnyNetwork>()
364            .wallet(EthereumWallet::from(self.signer.clone()))
365            .connect_http(rpc_url);
366
367        let factory_with_wallet = ISafeProxyFactory::new(addresses.proxy_factory, &wallet_provider);
368
369        let pending_tx = factory_with_wallet
370            .createProxyWithNonce(addresses.safe_singleton, initializer, config.salt_nonce)
371            .send()
372            .await
373            .map_err(|e| Error::ExecutionFailed {
374                reason: format!("Failed to send deployment transaction: {}", e),
375            })?;
376
377        let _receipt = pending_tx.get_receipt().await.map_err(|e| Error::ExecutionFailed {
378            reason: format!("Failed to get deployment receipt: {}", e),
379        })?;
380
381        // Verify deployment
382        if !is_safe(&self.provider, safe_address).await? {
383            return Err(Error::ExecutionFailed {
384                reason: format!("Deployment failed: no Safe at expected address {}", safe_address),
385            });
386        }
387
388        Ok(safe_address)
389    }
390}
391
392/// A wallet that wraps any account type implementing the `Account` trait.
393///
394/// This provides a unified interface for both Safe and EOA wallets with
395/// compile-time polymorphism.
396///
397/// # Type Parameters
398///
399/// * `A` - The account type (e.g., `Safe<P>` or `Eoa<P>`)
400///
401/// # Example
402///
403/// ```rust,ignore
404/// // Connect to a Safe using the fluent builder API
405/// let wallet = WalletBuilder::new(provider, signer)
406///     .connect(safe_address)
407///     .await?;
408///
409/// // Use the unified batch API
410/// wallet.batch()
411///     .add_typed(token, call)
412///     .execute().await?;
413/// ```
414pub struct Wallet<A: Account> {
415    account: A,
416}
417
418impl<A: Account> Wallet<A> {
419    /// Creates a new wallet wrapping the given account.
420    pub fn from_account(account: A) -> Self {
421        Self { account }
422    }
423
424    /// Returns the wallet's address.
425    ///
426    /// For Safe wallets, returns the Safe contract address.
427    /// For EOA wallets, returns the signer address.
428    pub fn address(&self) -> Address {
429        self.account.address()
430    }
431
432    /// Returns the underlying signer address.
433    ///
434    /// For Safe wallets, this is the owner/signer address.
435    /// For EOA wallets, this is the same as `address()`.
436    pub fn signer_address(&self) -> Address {
437        self.account.signer_address()
438    }
439
440    /// Returns a reference to the provider.
441    pub fn provider(&self) -> &A::Provider {
442        self.account.provider()
443    }
444
445    /// Returns the chain configuration.
446    pub fn config(&self) -> &ChainConfig {
447        self.account.config()
448    }
449
450    /// Gets the current nonce for the account.
451    ///
452    /// For Safe wallets, this is the Safe's internal nonce.
453    /// For EOA wallets, this is the account's transaction count.
454    pub async fn nonce(&self) -> Result<U256> {
455        self.account.nonce().await
456    }
457
458    /// Creates a new builder for batching transactions.
459    ///
460    /// Returns `A::Builder<'_>` which implements `CallBuilder`.
461    ///
462    /// # Example
463    ///
464    /// ```rust,ignore
465    /// wallet.batch()
466    ///     .add_typed(token, IERC20::transferCall { to: recipient, amount })
467    ///     .simulate().await?
468    ///     .execute().await?;
469    /// ```
470    pub fn batch(&self) -> A::Builder<'_> {
471        self.account.batch()
472    }
473
474    /// Executes a single transaction.
475    ///
476    /// This is a convenience method for executing a single call without
477    /// the batch builder. For multiple calls, use `batch()` instead.
478    ///
479    /// # Errors
480    /// Returns `Error::UnsupportedEoaOperation` if `operation` is `DelegateCall`
481    /// and the wallet is an EOA.
482    pub async fn execute_single(
483        &self,
484        to: Address,
485        value: U256,
486        data: Bytes,
487        operation: Operation,
488    ) -> Result<ExecutionResult> {
489        self.account.execute_single(to, value, data, operation).await
490    }
491
492    /// Returns a reference to the underlying account.
493    pub fn inner(&self) -> &A {
494        &self.account
495    }
496
497    /// Consumes the wallet and returns the underlying account.
498    pub fn into_inner(self) -> A {
499        self.account
500    }
501}
502
503// =============================================================================
504// Safe-specific implementation
505// =============================================================================
506
507impl<P> Wallet<Safe<P>>
508where
509    P: Provider<AnyNetwork> + Clone + 'static,
510{
511    /// Returns true (this is a Safe wallet).
512    pub fn is_safe(&self) -> bool {
513        true
514    }
515
516    /// Returns false (this is not an EOA wallet).
517    pub fn is_eoa(&self) -> bool {
518        false
519    }
520
521    /// Returns a reference to the underlying Safe.
522    pub fn safe(&self) -> &Safe<P> {
523        self.inner()
524    }
525}
526
527// =============================================================================
528// EOA-specific implementation
529// =============================================================================
530
531impl<P> Wallet<Eoa<P>>
532where
533    P: Provider<AnyNetwork> + Clone + 'static,
534{
535    /// Returns false (this is not a Safe wallet).
536    pub fn is_safe(&self) -> bool {
537        false
538    }
539
540    /// Returns true (this is an EOA wallet).
541    pub fn is_eoa(&self) -> bool {
542        true
543    }
544
545    /// Returns a reference to the underlying Eoa.
546    pub fn eoa(&self) -> &Eoa<P> {
547        self.inner()
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use alloy::primitives::address;

    #[test]
    fn test_wallet_config_default() {
        let WalletConfig {
            salt_nonce,
            additional_owners,
            threshold,
            fallback_handler,
        } = WalletConfig::default();

        assert_eq!(salt_nonce, U256::ZERO);
        assert!(additional_owners.is_empty());
        assert_eq!(threshold, 1);
        assert!(fallback_handler.is_none());
    }

    #[test]
    fn test_wallet_config_builder() {
        let second_owner = address!("2222222222222222222222222222222222222222");
        let handler = address!("fd0732Dc9E303f09fCEf3a7388Ad10A83459Ec99");

        let cfg = WalletConfig::new()
            .with_salt_nonce(U256::from(42))
            .with_additional_owners(vec![second_owner])
            .with_threshold(2)
            .with_fallback_handler(handler);

        assert_eq!(cfg.salt_nonce, U256::from(42));
        assert_eq!(cfg.additional_owners, vec![second_owner]);
        assert_eq!(cfg.threshold, 2);
        assert_eq!(cfg.fallback_handler, Some(handler));
    }

    #[test]
    fn test_wallet_config_build_owners() {
        let signer = address!("1111111111111111111111111111111111111111");
        let second_owner = address!("2222222222222222222222222222222222222222");
        let third_owner = address!("3333333333333333333333333333333333333333");

        let owners = WalletConfig::new()
            .with_additional_owners(vec![second_owner, third_owner])
            .build_owners(signer);

        // Signer first, then additional owners in the order given.
        assert_eq!(owners, vec![signer, second_owner, third_owner]);
    }

    #[test]
    fn test_wallet_config_build_owners_no_duplicates() {
        let signer = address!("1111111111111111111111111111111111111111");

        // Listing the signer again as an additional owner must not duplicate it.
        let owners = WalletConfig::new()
            .with_additional_owners(vec![signer])
            .build_owners(signer);

        assert_eq!(owners, vec![signer]);
    }

    #[test]
    fn test_wallet_config_get_fallback_handler_default() {
        let resolved = WalletConfig::default().get_fallback_handler();
        assert_eq!(resolved, ChainAddresses::v1_4_1().fallback_handler);
    }

    #[test]
    fn test_wallet_config_get_fallback_handler_custom() {
        let custom_handler = address!("dead000000000000000000000000000000000000");
        let resolved = WalletConfig::new()
            .with_fallback_handler(custom_handler)
            .get_fallback_handler();
        assert_eq!(resolved, custom_handler);
    }
}