Skip to main content

atuin_client/record/
sync.rs

1// do a sync :O
2use std::{cmp::Ordering, fmt::Write};
3
4use eyre::Result;
5use thiserror::Error;
6
7use super::{encryption::PASETO_V4, store::Store};
8use crate::{api_client::Client, settings::Settings};
9
10use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus};
11use indicatif::{ProgressBar, ProgressState, ProgressStyle};
12
13#[derive(Error, Debug)]
14pub enum SyncError {
15    #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
16    LocalAheadOtherHost,
17
18    #[error("an issue with the local database occurred: {msg:?}")]
19    LocalStoreError { msg: String },
20
21    #[error("something has gone wrong with the sync logic: {msg:?}")]
22    SyncLogicError { msg: String },
23
24    #[error("operational error: {msg:?}")]
25    OperationalError { msg: String },
26
27    #[error("a request to the sync server failed: {msg:?}")]
28    RemoteRequestError { msg: String },
29
30    #[error(
31        "the encryption key on this machine does not match the data on the server. \
32         this usually means a new machine was set up without copying the existing key. \
33         to fix: run `atuin key` on a machine that already syncs correctly, then run \
34         `atuin store rekey <key>` on this machine with the value from the other machine"
35    )]
36    WrongKey,
37}
38
39#[derive(Debug, Eq, PartialEq)]
40pub enum Operation {
41    // Either upload or download until the states matches the below
42    Upload {
43        local: RecordIdx,
44        remote: Option<RecordIdx>,
45        host: HostId,
46        tag: String,
47    },
48    Download {
49        local: Option<RecordIdx>,
50        remote: RecordIdx,
51        host: HostId,
52        tag: String,
53    },
54    Noop {
55        host: HostId,
56        tag: String,
57    },
58}
59
60pub async fn build_client(settings: &Settings) -> Result<Client<'_>, SyncError> {
61    Client::new(
62        &settings.sync_address,
63        settings
64            .sync_auth_token()
65            .await
66            .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?,
67        settings.network_connect_timeout,
68        settings.network_timeout,
69    )
70    .map_err(|e| SyncError::OperationalError { msg: e.to_string() })
71}
72
73pub async fn diff(
74    client: &Client<'_>,
75    store: &impl Store,
76) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
77    let local_index = store
78        .status()
79        .await
80        .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
81
82    let remote_index = client
83        .record_status()
84        .await
85        .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
86
87    let diff = local_index.diff(&remote_index);
88
89    Ok((diff, remote_index))
90}
91
92// Take a diff, along with a local store, and resolve it into a set of operations.
93// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download.
94// In theory this could be done as a part of the diffing stage, but it's easier to reason
95// about and test this way
96pub async fn operations(
97    diffs: Vec<Diff>,
98    _store: &impl Store,
99) -> Result<Vec<Operation>, SyncError> {
100    let mut operations = Vec::with_capacity(diffs.len());
101
102    for diff in diffs {
103        let op = match (diff.local, diff.remote) {
104            // We both have it! Could be either. Compare.
105            (Some(local), Some(remote)) => match local.cmp(&remote) {
106                Ordering::Equal => Operation::Noop {
107                    host: diff.host,
108                    tag: diff.tag,
109                },
110                Ordering::Greater => Operation::Upload {
111                    local,
112                    remote: Some(remote),
113                    host: diff.host,
114                    tag: diff.tag,
115                },
116                Ordering::Less => Operation::Download {
117                    local: Some(local),
118                    remote,
119                    host: diff.host,
120                    tag: diff.tag,
121                },
122            },
123
124            // Remote has it, we don't. Gotta be download
125            (None, Some(remote)) => Operation::Download {
126                local: None,
127                remote,
128                host: diff.host,
129                tag: diff.tag,
130            },
131
132            // We have it, remote doesn't. Gotta be upload.
133            (Some(local), None) => Operation::Upload {
134                local,
135                remote: None,
136                host: diff.host,
137                tag: diff.tag,
138            },
139
140            // something is pretty fucked.
141            (None, None) => {
142                return Err(SyncError::SyncLogicError {
143                    msg: String::from(
144                        "diff has nothing for local or remote - (host, tag) does not exist",
145                    ),
146                });
147            }
148        };
149
150        operations.push(op);
151    }
152
153    // sort them - purely so we have a stable testing order, and can rely on
154    // same input = same output
155    // We can sort by ID so long as we continue to use UUIDv7 or something
156    // with the same properties
157
158    operations.sort_by_key(|op| match op {
159        Operation::Noop { host, tag } => (0, *host, tag.clone()),
160
161        Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
162
163        Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
164    });
165
166    Ok(operations)
167}
168
169async fn sync_upload(
170    store: &impl Store,
171    client: &Client<'_>,
172    host: HostId,
173    tag: String,
174    local: RecordIdx,
175    remote: Option<RecordIdx>,
176    page_size: u64,
177) -> Result<i64, SyncError> {
178    let remote = remote.unwrap_or(0);
179    let expected = local - remote;
180    let mut progress = 0;
181
182    let pb = ProgressBar::new(expected);
183    pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})")
184        .unwrap()
185        .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
186        .progress_chars("#>-"));
187
188    println!(
189        "Uploading {} records to {}/{}",
190        expected,
191        host.0.as_simple(),
192        tag
193    );
194
195    loop {
196        let page = store
197            .next(host, tag.as_str(), remote + progress, page_size)
198            .await
199            .map_err(|e| {
200                error!("failed to read upload page: {e:?}");
201
202                SyncError::LocalStoreError { msg: e.to_string() }
203            })?;
204
205        if page.is_empty() {
206            break;
207        }
208
209        client.post_records(&page).await.map_err(|e| {
210            error!("failed to post records: {e:?}");
211
212            SyncError::RemoteRequestError { msg: e.to_string() }
213        })?;
214
215        progress += page.len() as u64;
216        pb.set_position(progress);
217
218        if progress >= expected {
219            break;
220        }
221    }
222
223    pb.finish_with_message("Uploaded records");
224
225    Ok(progress as i64)
226}
227
228async fn sync_download(
229    store: &impl Store,
230    client: &Client<'_>,
231    host: HostId,
232    tag: String,
233    local: Option<RecordIdx>,
234    remote: RecordIdx,
235    page_size: u64,
236) -> Result<Vec<RecordId>, SyncError> {
237    let local = local.unwrap_or(0);
238    let expected = remote - local;
239    let mut progress = 0;
240    let mut ret = Vec::new();
241
242    println!(
243        "Downloading {} records from {}/{}",
244        expected,
245        host.0.as_simple(),
246        tag
247    );
248
249    let pb = ProgressBar::new(expected);
250    pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})")
251        .unwrap()
252        .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
253        .progress_chars("#>-"));
254
255    loop {
256        let page = client
257            .next_records(host, tag.clone(), local + progress, page_size)
258            .await
259            .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
260
261        if page.is_empty() {
262            break;
263        }
264
265        store
266            .push_batch(page.iter())
267            .await
268            .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
269
270        ret.extend(page.iter().map(|f| f.id));
271
272        progress += page.len() as u64;
273        pb.set_position(progress);
274
275        if progress >= expected {
276            break;
277        }
278    }
279
280    pb.finish_with_message("Downloaded records");
281
282    Ok(ret)
283}
284
285pub async fn sync_remote(
286    client: &Client<'_>,
287    operations: Vec<Operation>,
288    local_store: &impl Store,
289    page_size: u64,
290) -> Result<(i64, Vec<RecordId>), SyncError> {
291    let mut uploaded = 0;
292    let mut downloaded = Vec::new();
293
294    // this can totally run in parallel, but lets get it working first
295    for i in operations {
296        match i {
297            Operation::Upload {
298                host,
299                tag,
300                local,
301                remote,
302            } => {
303                uploaded +=
304                    sync_upload(local_store, client, host, tag, local, remote, page_size).await?
305            }
306
307            Operation::Download {
308                host,
309                tag,
310                local,
311                remote,
312            } => {
313                let mut d =
314                    sync_download(local_store, client, host, tag, local, remote, page_size).await?;
315                downloaded.append(&mut d)
316            }
317
318            Operation::Noop { .. } => continue,
319        }
320    }
321
322    Ok((uploaded, downloaded))
323}
324
325pub async fn check_encryption_key(
326    client: &Client<'_>,
327    remote_index: &RecordStatus,
328    encryption_key: &[u8; 32],
329) -> Result<(), SyncError> {
330    let sample = remote_index
331        .hosts
332        .iter()
333        .flat_map(|(host, tags)| tags.keys().map(move |tag| (*host, tag.clone())))
334        .next();
335
336    let Some((host, tag)) = sample else {
337        return Ok(());
338    };
339
340    let records = client
341        .next_records(host, tag, 0, 1)
342        .await
343        .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
344
345    let Some(record) = records.into_iter().next() else {
346        return Ok(());
347    };
348
349    record
350        .decrypt::<PASETO_V4>(encryption_key)
351        .map_err(|_| SyncError::WrongKey)?;
352
353    Ok(())
354}
355
356pub async fn sync(
357    settings: &Settings,
358    store: &impl Store,
359    encryption_key: &[u8; 32],
360) -> Result<(i64, Vec<RecordId>), SyncError> {
361    let client = build_client(settings).await?;
362    let (diff, remote_index) = diff(&client, store).await?;
363
364    // Bail before mutating either side if the local key can't read the remote.
365    check_encryption_key(&client, &remote_index, encryption_key).await?;
366
367    let operations = operations(diff, store).await?;
368    let (uploaded, downloaded) = sync_remote(&client, operations, store, 100).await?;
369
370    Ok((uploaded, downloaded))
371}
372
#[cfg(test)]
mod tests {
    use atuin_common::record::{Diff, EncryptedData, HostId, Record};
    use pretty_assertions::assert_eq;

    use crate::{
        record::{
            encryption::PASETO_V4,
            sqlite_store::SqliteStore,
            store::Store,
            sync::{self, Operation},
        },
        settings::test_local_timeout,
    };

    /// Key used to encrypt fixture records; its value does not matter here.
    const KEY: [u8; 32] = [0; 32];

    /// The first record (idx 0) of a brand-new store with a random host and tag.
    fn test_record() -> Record<EncryptedData> {
        let host = atuin_common::record::Host::new(HostId(atuin_common::utils::uuid_v7()));
        let tag = atuin_common::utils::uuid_v7().simple().to_string();

        Record::builder()
            .host(host)
            .version("v1".into())
            .tag(tag)
            .data(EncryptedData {
                data: String::new(),
                content_encryption_key: String::new(),
            })
            .idx(0)
            .build()
    }

    /// `count` records appended one after another to `base`'s store, each
    /// carrying `payload`, in append order.
    fn successors(
        base: &Record<EncryptedData>,
        payload: &[u8],
        count: usize,
    ) -> Vec<Record<EncryptedData>> {
        let mut chain = Vec::with_capacity(count);
        let mut tip = base.clone();

        for _ in 0..count {
            tip = tip.append(payload.to_vec()).encrypt::<PASETO_V4>(&KEY);
            chain.push(tip.clone());
        }

        chain
    }

    /// Sort operations the way `sync::operations` does: no-ops, then uploads,
    /// then downloads, each by host and tag.
    fn sort_like_operations(ops: &mut [Operation]) {
        ops.sort_by_key(|op| match op {
            Operation::Noop { host, tag } => (0, *host, tag.clone()),
            Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
            Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
        });
    }

    /// A fresh in-memory SQLite record store.
    async fn memory_store() -> SqliteStore {
        SqliteStore::new(":memory:", test_local_timeout())
            .await
            .expect("failed to open in memory sqlite")
    }

    /// Load the given records into a local and a stand-in "remote" store and
    /// diff their status. Returns the local store with the diff, ready to be
    /// resolved into operations.
    async fn build_test_diff(
        local_records: Vec<Record<EncryptedData>>,
        remote_records: Vec<Record<EncryptedData>>,
    ) -> (SqliteStore, Vec<Diff>) {
        let local_store = memory_store().await;
        let remote_store = memory_store().await;

        for record in &local_records {
            local_store.push(record).await.unwrap();
        }

        for record in &remote_records {
            remote_store.push(record).await.unwrap();
        }

        let local_status = local_store.status().await.unwrap();
        let remote_status = remote_store.status().await.unwrap();

        let diff = local_status.diff(&remote_status);

        (local_store, diff)
    }

    #[tokio::test]
    async fn test_basic_diff() {
        // Local holds a single record in a store the remote has never seen.
        let record = test_record();
        let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await;

        assert_eq!(diff.len(), 1);

        let ops = sync::operations(diff, &store).await.unwrap();

        assert_eq!(ops.len(), 1);
        assert_eq!(
            ops[0],
            Operation::Upload {
                host: record.host.id,
                tag: record.tag,
                local: record.idx,
                remote: None,
            }
        );
    }

    #[tokio::test]
    async fn build_two_way_diff() {
        // Local is ahead in one store, and the remote knows a store local has
        // never seen: one upload and one download.
        let shared = test_record();
        let remote_ahead = test_record();

        let local_ahead = shared.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&KEY);
        assert_eq!(local_ahead.idx, 1);

        // Both know the already-synced record; local also has a newer one in that store.
        let local = vec![shared.clone(), local_ahead.clone()];
        // The remote additionally holds one record in a store local lacks.
        let remote = vec![shared.clone(), remote_ahead.clone()];

        let (store, diff) = build_test_diff(local, remote).await;
        let ops = sync::operations(diff, &store).await.unwrap();

        assert_eq!(ops.len(), 2);

        let expected = vec![
            // Local is ahead by one record.
            Operation::Upload {
                host: local_ahead.host.id,
                tag: local_ahead.tag,
                local: 1,
                remote: Some(0),
            },
            // The remote knows a record in an entirely new store (tag).
            Operation::Download {
                host: remote_ahead.host.id,
                tag: remote_ahead.tag,
                local: None,
                remote: 0,
            },
        ];

        assert_eq!(ops, expected);
    }

    #[tokio::test]
    async fn build_complex_diff() {
        let shared = test_record();

        // Stores only local knows: one with a single record, one with four.
        let local_only = test_record();
        let local_only_long = test_record();
        let local_only_long_rest = successors(&local_only_long, &[1, 2, 3], 3);

        // Stores only the remote knows: one with a single record, one with five.
        let remote_only = test_record();
        let remote_only_long = test_record();
        let remote_only_long_rest = successors(&remote_only_long, &[2, 3, 2], 4);

        // Shared store where the remote holds two newer records.
        let second_shared = test_record();
        let second_shared_ahead = successors(&second_shared, &[1, 2, 3], 2);

        // Shared store where local holds two newer records.
        let third_shared = test_record();
        let third_shared_ahead = successors(&third_shared, &[1, 2, 3], 2);

        // Shared store where local has the first newer record and the remote both.
        let fourth_shared = test_record();
        let fourth_shared_ahead = successors(&fourth_shared, &[1, 2, 3], 2);

        let mut local = vec![
            shared.clone(),
            second_shared.clone(),
            third_shared.clone(),
            fourth_shared.clone(),
            fourth_shared_ahead[0].clone(),
            local_only.clone(),
            local_only_long.clone(),
        ];
        local.extend(local_only_long_rest);
        local.extend(third_shared_ahead);

        let mut remote = vec![remote_only.clone(), remote_only_long.clone()];
        remote.extend(remote_only_long_rest);
        remote.extend([shared.clone(), second_shared.clone(), third_shared.clone()]);
        remote.extend(second_shared_ahead);
        remote.push(fourth_shared.clone());
        remote.extend(fourth_shared_ahead);

        let (store, diff) = build_test_diff(local, remote).await;
        let ops = sync::operations(diff, &store).await.unwrap();

        assert_eq!(ops.len(), 7);

        let mut expected = vec![
            // Local holds only the first record; the remote has two more.
            Operation::Download {
                local: Some(0),
                remote: 2,
                host: second_shared.host.id,
                tag: second_shared.tag,
            },
            // Local knows the first two records but not the last.
            Operation::Download {
                local: Some(1),
                remote: 2,
                host: fourth_shared.host.id,
                tag: fourth_shared.tag,
            },
            // The remote has a single-record store local lacks.
            Operation::Download {
                local: None,
                remote: 0,
                host: remote_only.host.id,
                tag: remote_only.tag,
            },
            // The remote has a five-record store local lacks.
            Operation::Download {
                local: None,
                remote: 4,
                host: remote_only_long.host.id,
                tag: remote_only_long.tag,
            },
            // Local has a single-record store the remote lacks.
            Operation::Upload {
                local: 0,
                remote: None,
                host: local_only.host.id,
                tag: local_only.tag,
            },
            // Local has a four-record store the remote lacks.
            Operation::Upload {
                local: 3,
                remote: None,
                host: local_only_long.host.id,
                tag: local_only_long.tag,
            },
            // Local has two more records in a store the remote holds one of.
            Operation::Upload {
                local: 2,
                remote: Some(0),
                host: third_shared.host.id,
                tag: third_shared.tag,
            },
        ];

        sort_like_operations(&mut expected);

        assert_eq!(expected, ops);
    }
}