Skip to main content

relay_actions/
storage.rs

1use anyhow::{Context, Result};
2use base64::{Engine as _, prelude::BASE64_STANDARD};
3use serde::{Deserialize, Serialize};
4use std::{collections::HashMap, path::PathBuf};
5
6use crate::sinks::{progress::ProgressSink, result::ResultSink};
7use relay_lib::{
8    crypto::{SigningKey, StaticSecret, hex},
9    prelude::{Address, InboxId, KeyRecord},
10};
11
12#[derive(Clone, Serialize, Deserialize)]
13pub struct Identity {
14    pub pub_record: KeyRecord,
15    pub inboxes: Vec<InboxId>,
16    #[serde(with = "hex::serde")]
17    pub signing_key: [u8; 32],
18    #[serde(with = "hex::serde")]
19    pub static_secret: [u8; 32],
20}
21
22impl Identity {
23    pub fn signing_key(&self) -> SigningKey {
24        SigningKey::from_bytes(&self.signing_key)
25    }
26
27    pub fn static_secret(&self) -> StaticSecret {
28        StaticSecret::from(self.static_secret)
29    }
30
31    pub fn print(&self, level: u32, result: &dyn ResultSink) {
32        result.section(level, &self.pub_record.id);
33        result.bullet_label(
34            level + 1,
35            "Created",
36            &self.pub_record.created_at.to_string(),
37        );
38
39        if let Some(expires_at) = self.pub_record.expires_at {
40            result.bullet_label(level + 1, "Expires", &expires_at.to_string());
41        } else {
42            result.bullet(level + 1, "Never Expires");
43        }
44
45        if !self.inboxes.is_empty() {
46            let mut inboxes = self.inboxes.clone();
47            inboxes.sort_by(|a, b| a.canonical().cmp(b.canonical()));
48
49            result.section(level + 1, "Inboxes");
50            for inbox in inboxes.iter() {
51                result.bullet(level + 2, &inbox.to_string());
52            }
53        }
54    }
55}
56
57#[derive(Clone, Serialize, Deserialize)]
58pub struct StorageRoot {
59    pub identities: HashMap<Address, Identity>,
60}
61
62impl StorageRoot {
63    pub fn get_identity(&self, address: &Address) -> Option<&Identity> {
64        let mut address = address.clone();
65        address.inbox = None;
66        self.identities.get(&address)
67    }
68
69    pub fn export(&self, progress: &mut dyn ProgressSink, addresses: &[Address]) -> Result<String> {
70        progress.step("Validating addresses", "Validated addresses");
71        for address in addresses.iter() {
72            let mut addr = address.clone();
73            addr.inbox = None;
74            if !self.identities.contains_key(&addr) {
75                progress.abort(&format!(
76                    "No identity found for address {}",
77                    address.canonical()
78                ));
79            }
80            if address.inbox().is_some() {
81                progress.abort("Please specify addresses without inbox IDs");
82            }
83        }
84
85        progress.step("Exporting identities", "Exported identities");
86        let identities = self
87            .identities
88            .iter()
89            .filter(|(addr, _)| addresses.contains(addr))
90            .map(|(_, identity)| identity.clone())
91            .collect::<Vec<_>>();
92        let json = serde_json::to_vec(&identities)
93            .unwrap_or_else(|_| progress.abort("Failed to serialize identities"));
94        Ok(BASE64_STANDARD.encode(json))
95    }
96
97    pub fn import(
98        &mut self,
99        progress: &mut dyn ProgressSink,
100        data: &str,
101        replace: bool,
102    ) -> Result<Vec<Address>> {
103        progress.step("Decoding identity data", "Decoded identity data");
104        let json = BASE64_STANDARD
105            .decode(data)
106            .unwrap_or_else(|_| progress.abort("Failed to decode base64 identity data"));
107        let identities: Vec<Identity> = serde_json::from_slice(&json)
108            .unwrap_or_else(|_| progress.abort("Failed to deserialize identity data"));
109        let identities = identities
110            .into_iter()
111            .map(|identity| {
112                let mut address = Address::parse(&identity.pub_record.id).unwrap_or_else(|_| {
113                    progress.abort(&format!(
114                        "Failed to parse address from identity with ID {}",
115                        identity.pub_record.id
116                    ));
117                });
118                address.inbox = None;
119                (address, identity)
120            })
121            .collect::<HashMap<_, _>>();
122
123        progress.step(
124            "Checking for existing identities",
125            "Checked for existing identities",
126        );
127        for (addr, _) in identities.iter() {
128            if self.identities.contains_key(addr) {
129                if replace {
130                    progress.warn(&format!("Replacing existing identity for address {}", addr));
131                } else {
132                    progress.abort(&format!(
133                        "Identity for address {} already exists. Use --replace to overwrite.",
134                        addr
135                    ));
136                }
137            }
138        }
139
140        progress.step("Importing identities", "Imported identities");
141        let addresses: Vec<Address> = identities.keys().cloned().collect();
142        self.identities.extend(identities);
143
144        Ok(addresses)
145    }
146}
147
148#[derive(Clone)]
149pub struct Storage {
150    pub root: StorageRoot,
151}
152
153impl Default for Storage {
154    fn default() -> Self {
155        Self {
156            root: StorageRoot {
157                identities: HashMap::new(),
158            },
159        }
160    }
161}
162
163impl Storage {
164    pub fn path() -> Result<PathBuf> {
165        let mut path = dirs::data_dir().context("Could not find data directory")?;
166        path.push("relay_actions");
167        std::fs::create_dir_all(&path).context("Could not create data directory")?;
168        path.push("storage.ron");
169        Ok(path)
170    }
171
172    pub fn init() -> Result<Self> {
173        if !Self::path()?.exists() {
174            Self::default().save()?;
175        }
176
177        let ron = std::fs::read_to_string(Self::path()?).context("Could not read storage file")?;
178        let root: StorageRoot =
179            ron::de::from_str(&ron).context("Could not deserialize storage file")?;
180        Ok(Self { root })
181    }
182
183    pub fn save(&self) -> Result<()> {
184        let ron = ron::ser::to_string_pretty(&self.root, ron::ser::PrettyConfig::default())
185            .context("Could not serialize storage")?;
186        std::fs::write(Self::path()?, ron).context("Could not write storage file")?;
187        Ok(())
188    }
189
190    pub fn delete(&self) -> Result<()> {
191        std::fs::remove_file(Self::path()?).context("Could not delete storage file")?;
192
193        let path = Self::path()?;
194        let parent = path.parent().context("Could not get parent directory")?;
195        std::fs::remove_dir_all(parent).context("Could not delete storage directory")?;
196
197        Ok(())
198    }
199
200    pub fn reset(&mut self) {
201        *self = Self::default();
202    }
203}