Skip to main content

atuin_dotfiles/
store.rs

1use std::collections::BTreeMap;
2
3use atuin_client::record::sqlite_store::SqliteStore;
// Sync aliases
// This will be noticeably similar to the kv store, though I expect the two to diverge.
// While we will support a range of shell config, I'd rather have a larger number of small
// records + stores than one mega config store.
8use atuin_common::record::{DecryptedData, Host, HostId};
9use atuin_common::utils::unquote;
10use eyre::{Result, bail, ensure, eyre};
11
12use atuin_client::record::encryption::PASETO_V4;
13use atuin_client::record::store::Store;
14
15use crate::shell::Alias;
16
/// Version string stamped on every alias record this module writes, and the
/// only version its decoder accepts.
const CONFIG_SHELL_ALIAS_VERSION: &str = "v0";

/// Record-store tag under which all shell-alias records are filed.
const CONFIG_SHELL_ALIAS_TAG: &str = "config-shell-alias";

/// Size cap, in bytes, on an alias record's string fields (name + value when
/// setting, name alone when deleting). 20 KB is far more than should be needed.
const CONFIG_SHELL_ALIAS_FIELD_MAX_LEN: usize = 20_000;
20
21mod alias;
22pub mod var;
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub enum AliasRecord {
26    Create(Alias),  // create a full record
27    Delete(String), // delete by name
28}
29
30impl AliasRecord {
31    pub fn serialize(&self) -> Result<DecryptedData> {
32        use rmp::encode;
33
34        let mut output = vec![];
35
36        match self {
37            AliasRecord::Create(alias) => {
38                encode::write_u8(&mut output, 0)?; // create
39                encode::write_array_len(&mut output, 2)?; // 2 fields
40
41                encode::write_str(&mut output, alias.name.as_str())?;
42                encode::write_str(&mut output, alias.value.as_str())?;
43            }
44            AliasRecord::Delete(name) => {
45                encode::write_u8(&mut output, 1)?; // delete
46                encode::write_array_len(&mut output, 1)?; // 1 field
47
48                encode::write_str(&mut output, name.as_str())?;
49            }
50        }
51
52        Ok(DecryptedData(output))
53    }
54
55    pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
56        use rmp::decode;
57
58        fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
59            eyre!("{err:?}")
60        }
61
62        match version {
63            CONFIG_SHELL_ALIAS_VERSION => {
64                let mut bytes = decode::Bytes::new(&data.0);
65
66                let record_type = decode::read_u8(&mut bytes).map_err(error_report)?;
67
68                match record_type {
69                    // create
70                    0 => {
71                        let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
72                        ensure!(
73                            nfields == 2,
74                            "too many entries in v0 shell alias create record"
75                        );
76
77                        let bytes = bytes.remaining_slice();
78
79                        let (key, bytes) =
80                            decode::read_str_from_slice(bytes).map_err(error_report)?;
81                        let (value, bytes) =
82                            decode::read_str_from_slice(bytes).map_err(error_report)?;
83
84                        if !bytes.is_empty() {
85                            bail!("trailing bytes in encoded shell alias record. malformed")
86                        }
87
88                        Ok(AliasRecord::Create(Alias {
89                            name: key.to_owned(),
90                            value: value.to_owned(),
91                        }))
92                    }
93
94                    // delete
95                    1 => {
96                        let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
97                        ensure!(
98                            nfields == 1,
99                            "too many entries in v0 shell alias delete record"
100                        );
101
102                        let bytes = bytes.remaining_slice();
103
104                        let (key, bytes) =
105                            decode::read_str_from_slice(bytes).map_err(error_report)?;
106
107                        if !bytes.is_empty() {
108                            bail!("trailing bytes in encoded shell alias record. malformed")
109                        }
110
111                        Ok(AliasRecord::Delete(key.to_owned()))
112                    }
113
114                    n => {
115                        bail!("unknown AliasRecord type {n}")
116                    }
117                }
118            }
119            _ => {
120                bail!("unknown version {version:?}")
121            }
122        }
123    }
124}
125
126#[derive(Debug, Clone)]
127pub struct AliasStore {
128    pub store: SqliteStore,
129    pub host_id: HostId,
130    pub encryption_key: [u8; 32],
131}
132
133impl AliasStore {
134    // will want to init the actual kv store when that is done
135    pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> AliasStore {
136        AliasStore {
137            store,
138            host_id,
139            encryption_key,
140        }
141    }
142
143    pub async fn posix(&self) -> Result<String> {
144        let aliases = self.aliases().await?;
145        Ok(Self::format_posix(&aliases))
146    }
147
148    pub async fn xonsh(&self) -> Result<String> {
149        let aliases = self.aliases().await?;
150        Ok(Self::format_xonsh(&aliases))
151    }
152
153    pub async fn powershell(&self) -> Result<String> {
154        let aliases = self.aliases().await?;
155        Ok(Self::format_powershell(&aliases))
156    }
157
158    fn format_posix(aliases: &[Alias]) -> String {
159        let mut config = String::new();
160
161        for alias in aliases {
162            // If it's quoted, remove the quotes. If it's not quoted, do nothing.
163            let value = unquote(alias.value.as_str()).unwrap_or(alias.value.clone());
164
165            // we're about to quote it ourselves anyway!
166            config.push_str(&format!("alias {}='{}'\n", alias.name, value));
167        }
168
169        config
170    }
171
172    fn format_xonsh(aliases: &[Alias]) -> String {
173        let mut config = String::new();
174
175        for alias in aliases {
176            config.push_str(&format!("aliases['{}'] ='{}'\n", alias.name, alias.value));
177        }
178
179        config
180    }
181
182    fn format_powershell(aliases: &[Alias]) -> String {
183        let mut config = String::new();
184
185        for alias in aliases {
186            config.push_str(&crate::shell::powershell::format_alias(alias));
187        }
188
189        config
190    }
191
192    pub async fn build(&self) -> Result<()> {
193        let dir = atuin_common::utils::dotfiles_cache_dir();
194        tokio::fs::create_dir_all(dir.clone()).await?;
195
196        let aliases = self.aliases().await?;
197
198        // Build for all supported shells
199        let posix = Self::format_posix(&aliases);
200        let xonsh = Self::format_xonsh(&aliases);
201        let powershell = Self::format_powershell(&aliases);
202
203        // All the same contents, maybe optimize in the future or perhaps there will be quirks
204        // per-shell
205        // I'd prefer separation atm
206        let zsh = dir.join("aliases.zsh");
207        let bash = dir.join("aliases.bash");
208        let fish = dir.join("aliases.fish");
209        let xsh = dir.join("aliases.xsh");
210        let ps1 = dir.join("aliases.ps1");
211
212        tokio::fs::write(zsh, &posix).await?;
213        tokio::fs::write(bash, &posix).await?;
214        tokio::fs::write(fish, &posix).await?;
215        tokio::fs::write(xsh, &xonsh).await?;
216        tokio::fs::write(ps1, &powershell).await?;
217
218        Ok(())
219    }
220
221    pub async fn set(&self, name: &str, value: &str) -> Result<()> {
222        if name.len() + value.len() > CONFIG_SHELL_ALIAS_FIELD_MAX_LEN {
223            return Err(eyre!(
224                "alias record too large: max len {} bytes",
225                CONFIG_SHELL_ALIAS_FIELD_MAX_LEN
226            ));
227        }
228
229        let record = AliasRecord::Create(Alias {
230            name: name.to_string(),
231            value: value.to_string(),
232        });
233
234        let bytes = record.serialize()?;
235
236        let idx = self
237            .store
238            .last(self.host_id, CONFIG_SHELL_ALIAS_TAG)
239            .await?
240            .map_or(0, |entry| entry.idx + 1);
241
242        let record = atuin_common::record::Record::builder()
243            .host(Host::new(self.host_id))
244            .version(CONFIG_SHELL_ALIAS_VERSION.to_string())
245            .tag(CONFIG_SHELL_ALIAS_TAG.to_string())
246            .idx(idx)
247            .data(bytes)
248            .build();
249
250        self.store
251            .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
252            .await?;
253
254        // set mutates shell config, so build again
255        self.build().await?;
256
257        Ok(())
258    }
259
260    pub async fn delete(&self, name: &str) -> Result<()> {
261        if name.len() > CONFIG_SHELL_ALIAS_FIELD_MAX_LEN {
262            return Err(eyre!(
263                "alias record too large: max len {} bytes",
264                CONFIG_SHELL_ALIAS_FIELD_MAX_LEN
265            ));
266        }
267
268        let record = AliasRecord::Delete(name.to_string());
269
270        let bytes = record.serialize()?;
271
272        let idx = self
273            .store
274            .last(self.host_id, CONFIG_SHELL_ALIAS_TAG)
275            .await?
276            .map_or(0, |entry| entry.idx + 1);
277
278        let record = atuin_common::record::Record::builder()
279            .host(Host::new(self.host_id))
280            .version(CONFIG_SHELL_ALIAS_VERSION.to_string())
281            .tag(CONFIG_SHELL_ALIAS_TAG.to_string())
282            .idx(idx)
283            .data(bytes)
284            .build();
285
286        self.store
287            .push(&record.encrypt::<PASETO_V4>(&self.encryption_key))
288            .await?;
289
290        // delete mutates shell config, so build again
291        self.build().await?;
292
293        Ok(())
294    }
295
296    pub async fn aliases(&self) -> Result<Vec<Alias>> {
297        let mut build = BTreeMap::new();
298
299        // this is sorted, oldest to newest
300        let tagged = self.store.all_tagged(CONFIG_SHELL_ALIAS_TAG).await?;
301
302        for record in tagged {
303            let version = record.version.clone();
304
305            let decrypted = match version.as_str() {
306                CONFIG_SHELL_ALIAS_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?,
307                version => bail!("unknown version {version:?}"),
308            };
309
310            let ar = AliasRecord::deserialize(&decrypted.data, version.as_str())?;
311
312            match ar {
313                AliasRecord::Create(a) => {
314                    build.insert(a.name.clone(), a);
315                }
316                AliasRecord::Delete(d) => {
317                    build.remove(&d);
318                }
319            }
320        }
321
322        Ok(build.into_values().collect())
323    }
324}
325
/// Local timeout passed to `SqliteStore::new` in tests.
///
/// Taken from the `ATUIN_TEST_LOCAL_TIMEOUT` environment variable when it is
/// set and parses as an `f64`; otherwise falls back to `2.0`. Units are whatever
/// `SqliteStore::new` expects (presumably seconds — TODO confirm).
#[cfg(test)]
pub(crate) fn test_local_timeout() -> f64 {
    // This hardcoded fallback should be replaced by a simple way to get the
    // default local_timeout of Settings, if possible.
    const FALLBACK: f64 = 2.0;

    match std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") {
        Ok(raw) => raw.parse::<f64>().unwrap_or(FALLBACK),
        Err(_) => FALLBACK,
    }
}
335
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;

    use atuin_client::record::sqlite_store::SqliteStore;
    use atuin_common::record::DecryptedData;

    use crate::shell::Alias;

    use super::{AliasRecord, AliasStore, CONFIG_SHELL_ALIAS_VERSION, test_local_timeout};
    use crypto_secretbox::{KeyInit, XSalsa20Poly1305};

    /// Builds an `AliasStore` over a fresh in-memory SQLite store with a random
    /// key and host id.
    async fn new_alias_store() -> AliasStore {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into();
        let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());

        AliasStore::new(store, host_id, key)
    }

    #[test]
    fn encode_decode() {
        let record = Alias {
            name: "k".to_owned(),
            value: "kubectl".to_owned(),
        };
        let record = AliasRecord::Create(record);

        let snapshot = [204, 0, 146, 161, 107, 167, 107, 117, 98, 101, 99, 116, 108];

        let encoded = record.serialize().unwrap();
        let decoded = AliasRecord::deserialize(&encoded, CONFIG_SHELL_ALIAS_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, record);
    }

    #[test]
    fn encode_decode_delete() {
        let record = AliasRecord::Delete("k".to_owned());

        // u8 marker + type 1 (delete), fixarray of 1, fixstr "k"
        let snapshot = [204, 1, 145, 161, 107];

        let encoded = record.serialize().unwrap();
        let decoded = AliasRecord::deserialize(&encoded, CONFIG_SHELL_ALIAS_VERSION).unwrap();

        assert_eq!(encoded.0, &snapshot);
        assert_eq!(decoded, record);
    }

    #[test]
    fn decode_rejects_malformed_records() {
        let encoded = AliasRecord::Delete("k".to_owned()).serialize().unwrap();

        // unknown version
        assert!(AliasRecord::deserialize(&encoded, "v1").is_err());

        // trailing bytes after the last field
        let mut trailing = encoded.0.clone();
        trailing.push(0);
        assert!(
            AliasRecord::deserialize(&DecryptedData(trailing), CONFIG_SHELL_ALIAS_VERSION)
                .is_err()
        );

        // unknown record type tag
        let unknown_type = DecryptedData(vec![204, 2, 145, 161, 107]);
        assert!(AliasRecord::deserialize(&unknown_type, CONFIG_SHELL_ALIAS_VERSION).is_err());

        // create record with too few fields
        let short_create = DecryptedData(vec![204, 0, 145, 161, 107]);
        assert!(AliasRecord::deserialize(&short_create, CONFIG_SHELL_ALIAS_VERSION).is_err());
    }

    #[tokio::test]
    async fn build_aliases() {
        let alias = new_alias_store().await;

        alias.set("k", "kubectl").await.unwrap();
        alias.set("gp", "git push").await.unwrap();
        alias
            .set("kgap", "'kubectl get pods --all-namespaces'")
            .await
            .unwrap();

        let mut aliases = alias.aliases().await.unwrap();

        aliases.sort_by_key(|a| a.name.clone());

        assert_eq!(aliases.len(), 3);

        assert_eq!(
            aliases[0],
            Alias {
                name: String::from("gp"),
                value: String::from("git push")
            }
        );

        assert_eq!(
            aliases[1],
            Alias {
                name: String::from("k"),
                value: String::from("kubectl")
            }
        );

        assert_eq!(
            aliases[2],
            Alias {
                name: String::from("kgap"),
                value: String::from("'kubectl get pods --all-namespaces'")
            }
        );

        let build = alias.posix().await.expect("failed to build aliases");

        assert_eq!(
            build,
            "alias gp='git push'\nalias k='kubectl'\nalias kgap='kubectl get pods --all-namespaces'\n"
        )
    }

    #[tokio::test]
    async fn delete_removes_alias() {
        let alias = new_alias_store().await;

        alias.set("k", "kubectl").await.unwrap();
        alias.set("gp", "git push").await.unwrap();
        alias.delete("k").await.unwrap();

        assert_eq!(
            alias.aliases().await.unwrap(),
            vec![Alias {
                name: String::from("gp"),
                value: String::from("git push")
            }]
        );

        // Deleting a name that was never set is accepted and changes nothing.
        alias.delete("missing").await.unwrap();
        assert_eq!(alias.aliases().await.unwrap().len(), 1);

        // Setting an existing name again replaces its value.
        alias.set("gp", "git pull").await.unwrap();
        assert_eq!(
            alias.aliases().await.unwrap(),
            vec![Alias {
                name: String::from("gp"),
                value: String::from("git pull")
            }]
        );
    }
}