use std::{
collections::HashMap,
str::FromStr,
sync::{Arc, RwLock},
};
use async_trait::async_trait;
use ed25519_dalek::SigningKey as Keypair;
use rand::thread_rng;
#[cfg(feature = "hd-wallet")]
use crate::{constants::*, keyring::traits::DerivedKeypair};
use crate::{
error::{Result, RialoError},
keyring::{
provider_base::BaseKeyringProvider,
traits::{Keyring, KeyringProvider},
},
rpc::types::Pubkey,
};
/// Keyring provider that keeps every keyring in process memory.
///
/// Each entry maps a keyring name to the keyring together with the password it was stored
/// under. Nothing is written to disk; the table lives only as long as this provider.
pub struct InMemoryKeyringProvider {
    /// Name -> (keyring, password) table, guarded by a reader/writer lock.
    keyrings: Arc<RwLock<HashMap<String, (Keyring, String)>>>,
}

impl InMemoryKeyringProvider {
    /// Creates a provider with an empty keyring table.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for InMemoryKeyringProvider {
    /// Equivalent to [`InMemoryKeyringProvider::new`]: starts with no keyrings.
    fn default() -> Self {
        Self {
            keyrings: Arc::default(),
        }
    }
}
impl InMemoryKeyringProvider {
    /// Acquires the shared read lock on the keyring table.
    ///
    /// # Errors
    ///
    /// Returns [`RialoError::Keyring`] if the lock is poisoned (a previous holder panicked)
    /// rather than propagating the panic into the caller.
    fn read_keyrings(
        &self,
    ) -> Result<std::sync::RwLockReadGuard<'_, HashMap<String, (Keyring, String)>>> {
        self.keyrings
            .read()
            .map_err(|_| RialoError::Keyring("Keyring store lock poisoned".to_string()))
    }

    /// Acquires the exclusive write lock on the keyring table.
    ///
    /// # Errors
    ///
    /// Returns [`RialoError::Keyring`] if the lock is poisoned.
    fn write_keyrings(
        &self,
    ) -> Result<std::sync::RwLockWriteGuard<'_, HashMap<String, (Keyring, String)>>> {
        self.keyrings
            .write()
            .map_err(|_| RialoError::Keyring("Keyring store lock poisoned".to_string()))
    }

    /// Stores `keyring` under `name` with `password` and returns it.
    ///
    /// The "already exists" check and the insert run under a single write lock, so two
    /// concurrent creations with the same name cannot both succeed and overwrite each other.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if `name` is already taken or the lock is poisoned.
    fn insert_new(&self, name: &str, keyring: Keyring, password: &str) -> Result<Keyring> {
        let mut keyrings = self.write_keyrings()?;
        if keyrings.contains_key(name) {
            return Err(RialoError::Keyring(format!(
                "Keyring already exists: {name}"
            )));
        }
        keyrings.insert(name.to_string(), (keyring.clone(), password.to_string()));
        Ok(keyring)
    }
}

#[async_trait]
impl KeyringProvider for InMemoryKeyringProvider {
    /// Creates a keyring holding a freshly generated random keypair and no mnemonic.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if a keyring named `name` already exists.
    async fn create(&self, name: &str, password: &str) -> Result<Keyring> {
        let keypair = Keypair::generate(&mut thread_rng());
        let keyring = Keyring::new(name.to_string(), keypair, None, None);
        self.insert_new(name, keyring, password)
    }

    /// Creates a keyring that carries a mnemonic string.
    ///
    /// NOTE(review): the mnemonic stored here is a fixed placeholder, not a real generated
    /// phrase, and the keypair is random rather than derived from it — adequate only for a
    /// test/mock provider.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if a keyring named `name` already exists.
    #[cfg(feature = "mnemonic")]
    async fn create_with_mnemonic(&self, name: &str, password: &str) -> Result<Keyring> {
        let keypair = Keypair::generate(&mut thread_rng());
        let mnemonic = "generated mnemonic".to_string();
        let keyring = Keyring::new(name.to_string(), keypair, Some(mnemonic), None);
        self.insert_new(name, keyring, password)
    }

    /// Stores a keyring that records the supplied `mnemonic`.
    ///
    /// NOTE(review): the keypair is generated randomly and is NOT derived from `mnemonic`, so
    /// "recovering" the same phrase twice yields different keys.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if a keyring named `name` already exists.
    #[cfg(feature = "mnemonic")]
    async fn recover_from_mnemonic(
        &self,
        name: &str,
        mnemonic: &str,
        password: &str,
    ) -> Result<Keyring> {
        let keypair = Keypair::generate(&mut thread_rng());
        let keyring = Keyring::new(name.to_string(), keypair, Some(mnemonic.to_string()), None);
        self.insert_new(name, keyring, password)
    }

    /// Returns a clone of the keyring named `name` if `password` matches the stored one.
    ///
    /// NOTE(review): passwords are held in plaintext and compared with `==` (not constant
    /// time); acceptable only for an in-memory/test provider.
    ///
    /// # Errors
    ///
    /// [`RialoError::Password`] on a wrong password, [`RialoError::Keyring`] if the keyring
    /// does not exist.
    async fn load(&self, name: &str, password: &str) -> Result<Keyring> {
        let keyrings = self.read_keyrings()?;
        match keyrings.get(name) {
            Some((keyring, stored_password)) if stored_password == password => Ok(keyring.clone()),
            Some(_) => Err(RialoError::Password("Invalid password".to_string())),
            None => Err(RialoError::Keyring(format!("Keyring not found: {name}"))),
        }
    }

    /// Lists the names of all stored keyrings, sorted for a stable ordering.
    async fn list(&self) -> Result<Vec<String>> {
        let keyrings = self.read_keyrings()?;
        let mut names: Vec<String> = keyrings.keys().cloned().collect();
        // HashMap iteration order is unspecified; sort so repeated calls agree.
        names.sort_unstable();
        Ok(names)
    }

    /// Reports whether a keyring named `name` is stored.
    async fn exists(&self, name: &str) -> Result<bool> {
        let keyrings = self.read_keyrings()?;
        Ok(keyrings.contains_key(name))
    }

    /// Stores a new keyring that shares the source keyring's mnemonic and records an HD path.
    ///
    /// The new keyring is stored under the same `password` used to unlock the source.
    /// NOTE(review): the keypair is random and not actually derived from the mnemonic/path.
    ///
    /// # Errors
    ///
    /// - [`RialoError::Keyring`] if the source is missing, the target name is taken, or the
    ///   source has no mnemonic.
    /// - [`RialoError::Password`] if `password` does not match the source keyring.
    #[cfg(feature = "hd-wallet")]
    async fn derive_keyring(
        &self,
        source_keyring_name: &str,
        new_keyring_name: &str,
        keypair_index: u32,
        password: &str,
    ) -> Result<Keyring> {
        // Every check and the final insert share one write lock so the target name cannot be
        // claimed by a concurrent call between the "already exists" check and the insert.
        let mut keyrings = self.write_keyrings()?;
        let (source_keyring, stored_password) =
            keyrings.get(source_keyring_name).ok_or_else(|| {
                RialoError::Keyring(format!(
                    "Source keyring '{source_keyring_name}' not found"
                ))
            })?;
        if keyrings.contains_key(new_keyring_name) {
            return Err(RialoError::Keyring(format!(
                "Keyring '{new_keyring_name}' already exists"
            )));
        }
        if stored_password.as_str() != password {
            return Err(RialoError::Password("Invalid password".to_string()));
        }
        // Take an owned copy so the borrow of the source entry ends before the insert below.
        let mnemonic = source_keyring
            .mnemonic()
            .ok_or_else(|| {
                RialoError::Keyring("Source keyring does not have a mnemonic phrase".to_string())
            })?
            .to_string();
        let path = format!("{}{}'/{}'", BASE_DERIVATION_PATH, keypair_index, 0);
        let keypair = Keypair::generate(&mut thread_rng());
        let keyring = Keyring::new(
            new_keyring_name.to_string(),
            keypair,
            Some(mnemonic),
            Some(path),
        );
        keyrings.insert(
            new_keyring_name.to_string(),
            (keyring.clone(), password.to_string()),
        );
        Ok(keyring)
    }

    /// Lists `(name, public key)` for every stored keyring, sorted by name.
    ///
    /// # Errors
    ///
    /// Propagates any failure parsing a keyring's public-key string into a [`Pubkey`].
    async fn list_public_keys(&self) -> Result<Vec<(String, Pubkey)>> {
        let keyrings = self.read_keyrings()?;
        let mut results = keyrings
            .iter()
            .map(|(name, (keyring, _))| -> Result<(String, Pubkey)> {
                Ok((name.clone(), Pubkey::from_str(&keyring.pubkey_string())?))
            })
            .collect::<Result<Vec<_>>>()?;
        // Names are unique map keys, so an unstable sort is deterministic here.
        results.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        Ok(results)
    }

    /// Lists `(index, public key)` for the keypairs of `keyring_name`.
    ///
    /// NOTE(review): this probes indices `0..keypairs.len()` through `get_keypair`. If
    /// `get_keypair` looks keypairs up by their stored derivation index (and `derive_keypair`
    /// accepts arbitrary indices), keypairs at indices >= len would be skipped — confirm
    /// `Keyring::get_keypair` semantics.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if the keyring does not exist; propagates public-key parse errors.
    async fn list_keypairs(&self, keyring_name: &str) -> Result<Vec<(u32, Pubkey)>> {
        let keyrings = self.read_keyrings()?;
        let (keyring, _) = keyrings.get(keyring_name).ok_or_else(|| {
            RialoError::Keyring(format!(
                "Keyring not found: {keyring_name}"
            ))
        })?;
        let mut keypairs = Vec::new();
        for i in 0..keyring.keypairs.len() as u32 {
            if let Some(keypair) = keyring.get_keypair(i) {
                keypairs.push((i, Pubkey::from_str(&keypair.pubkey_string())?));
            }
        }
        Ok(keypairs)
    }

    /// Stub balance lookup: reports 0 for any existing keyring.
    ///
    /// `_keypair_index` is not validated.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if the keyring does not exist.
    async fn get_keypair_balance(&self, keyring_name: &str, _keypair_index: u32) -> Result<u64> {
        if self.exists(keyring_name).await? {
            Ok(0)
        } else {
            Err(RialoError::Keyring(format!(
                "Keyring not found: {keyring_name}"
            )))
        }
    }

    /// Returns the primary public key of the keyring named `name`.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if the keyring does not exist; propagates public-key parse errors.
    async fn get_public_key(&self, name: &str) -> Result<Pubkey> {
        let keyrings = self.read_keyrings()?;
        let (keyring, _) = keyrings
            .get(name)
            .ok_or_else(|| RialoError::Keyring(format!("Keyring not found: {name}")))?;
        Ok(Pubkey::from_str(&keyring.pubkey_string())?)
    }

    /// Same as [`KeyringProvider::list_keypairs`].
    async fn get_keypairs_info(&self, name: &str) -> Result<Vec<(u32, Pubkey)>> {
        self.list_keypairs(name).await
    }

    /// Returns the public key of keypair `keypair_index` inside keyring `name`.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if the keyring or the keypair does not exist; propagates
    /// public-key parse errors.
    async fn get_keypair_public_key(&self, name: &str, keypair_index: u32) -> Result<Pubkey> {
        let keyrings = self.read_keyrings()?;
        let (keyring, _) = keyrings
            .get(name)
            .ok_or_else(|| RialoError::Keyring(format!("Keyring not found: {name}")))?;
        let keypair = keyring.get_keypair(keypair_index).ok_or_else(|| {
            RialoError::Keyring(format!(
                "Keypair {keypair_index} not found in keyring {name}"
            ))
        })?;
        Ok(Pubkey::from_str(&keypair.pubkey_string())?)
    }

    /// Returns the number of keypairs in keyring `name`, used as the next index.
    ///
    /// # Errors
    ///
    /// [`RialoError::Keyring`] if the keyring does not exist.
    async fn next_keypair_index(&self, name: &str) -> Result<u32> {
        let keyrings = self.read_keyrings()?;
        let (keyring, _) = keyrings
            .get(name)
            .ok_or_else(|| RialoError::Keyring(format!("Keyring not found: {name}")))?;
        Ok(keyring.keypairs.len() as u32)
    }

    /// Adds a new keypair with index `keypair_index` to keyring `keyring_name`.
    ///
    /// NOTE(review): the keypair is random rather than HD-derived, and no check is made here
    /// that `keypair_index` is unused — confirm how `Keyring::add_keypair` treats duplicates.
    ///
    /// # Errors
    ///
    /// - [`RialoError::Keyring`] if the keyring does not exist.
    /// - [`RialoError::Password`] if `password` is wrong.
    /// - Propagates public-key parse errors (in which case the keyring is left unchanged).
    #[cfg(feature = "hd-wallet")]
    async fn derive_keypair(
        &self,
        keyring_name: &str,
        keypair_index: u32,
        password: &str,
    ) -> Result<(u32, Pubkey)> {
        let mut keyrings = self.write_keyrings()?;
        let (keyring, stored_password) = keyrings
            .get_mut(keyring_name)
            .ok_or_else(|| RialoError::Keyring(format!("Keyring not found: {keyring_name}")))?;
        if stored_password.as_str() != password {
            return Err(RialoError::Password("Invalid password".to_string()));
        }
        let keypair = Keypair::generate(&mut thread_rng());
        let derived = DerivedKeypair::new(keypair_index, keypair, None);
        // Parse first so a parse failure does not leave a half-applied mutation behind.
        let pubkey = Pubkey::from_str(&derived.pubkey_string())?;
        keyring.add_keypair(derived);
        Ok((keypair_index, pubkey))
    }
}
// Empty impl: the provider relies entirely on whatever provided (default) methods
// `BaseKeyringProvider` defines; nothing is overridden here.
#[async_trait]
impl BaseKeyringProvider for InMemoryKeyringProvider {}
/// Deprecated alias for [`InMemoryKeyringProvider`], kept so code using the old
/// "wallet" name keeps compiling.
#[deprecated(since = "0.2.0", note = "Use InMemoryKeyringProvider instead")]
pub type InMemoryWalletProvider = InMemoryKeyringProvider;