1use std::{cmp::Ordering, fmt::Write};
3
4use eyre::Result;
5use thiserror::Error;
6
7use super::{encryption::PASETO_V4, store::Store};
8use crate::{api_client::Client, settings::Settings};
9
10use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus};
11use indicatif::{ProgressBar, ProgressState, ProgressStyle};
12
13#[derive(Error, Debug)]
14pub enum SyncError {
15 #[error("the local store is ahead of the remote, but for another host. has remote lost data?")]
16 LocalAheadOtherHost,
17
18 #[error("an issue with the local database occurred: {msg:?}")]
19 LocalStoreError { msg: String },
20
21 #[error("something has gone wrong with the sync logic: {msg:?}")]
22 SyncLogicError { msg: String },
23
24 #[error("operational error: {msg:?}")]
25 OperationalError { msg: String },
26
27 #[error("a request to the sync server failed: {msg:?}")]
28 RemoteRequestError { msg: String },
29
30 #[error(
31 "the encryption key on this machine does not match the data on the server. \
32 this usually means a new machine was set up without copying the existing key. \
33 to fix: run `atuin key` on a machine that already syncs correctly, then run \
34 `atuin store rekey <key>` on this machine with the value from the other machine"
35 )]
36 WrongKey,
37}
38
39#[derive(Debug, Eq, PartialEq)]
40pub enum Operation {
41 Upload {
43 local: RecordIdx,
44 remote: Option<RecordIdx>,
45 host: HostId,
46 tag: String,
47 },
48 Download {
49 local: Option<RecordIdx>,
50 remote: RecordIdx,
51 host: HostId,
52 tag: String,
53 },
54 Noop {
55 host: HostId,
56 tag: String,
57 },
58}
59
60pub async fn build_client(settings: &Settings) -> Result<Client<'_>, SyncError> {
61 Client::new(
62 &settings.sync_address,
63 settings
64 .sync_auth_token()
65 .await
66 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?,
67 settings.network_connect_timeout,
68 settings.network_timeout,
69 )
70 .map_err(|e| SyncError::OperationalError { msg: e.to_string() })
71}
72
73pub async fn diff(
74 client: &Client<'_>,
75 store: &impl Store,
76) -> Result<(Vec<Diff>, RecordStatus), SyncError> {
77 let local_index = store
78 .status()
79 .await
80 .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
81
82 let remote_index = client
83 .record_status()
84 .await
85 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
86
87 let diff = local_index.diff(&remote_index);
88
89 Ok((diff, remote_index))
90}
91
92pub async fn operations(
97 diffs: Vec<Diff>,
98 _store: &impl Store,
99) -> Result<Vec<Operation>, SyncError> {
100 let mut operations = Vec::with_capacity(diffs.len());
101
102 for diff in diffs {
103 let op = match (diff.local, diff.remote) {
104 (Some(local), Some(remote)) => match local.cmp(&remote) {
106 Ordering::Equal => Operation::Noop {
107 host: diff.host,
108 tag: diff.tag,
109 },
110 Ordering::Greater => Operation::Upload {
111 local,
112 remote: Some(remote),
113 host: diff.host,
114 tag: diff.tag,
115 },
116 Ordering::Less => Operation::Download {
117 local: Some(local),
118 remote,
119 host: diff.host,
120 tag: diff.tag,
121 },
122 },
123
124 (None, Some(remote)) => Operation::Download {
126 local: None,
127 remote,
128 host: diff.host,
129 tag: diff.tag,
130 },
131
132 (Some(local), None) => Operation::Upload {
134 local,
135 remote: None,
136 host: diff.host,
137 tag: diff.tag,
138 },
139
140 (None, None) => {
142 return Err(SyncError::SyncLogicError {
143 msg: String::from(
144 "diff has nothing for local or remote - (host, tag) does not exist",
145 ),
146 });
147 }
148 };
149
150 operations.push(op);
151 }
152
153 operations.sort_by_key(|op| match op {
159 Operation::Noop { host, tag } => (0, *host, tag.clone()),
160
161 Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),
162
163 Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
164 });
165
166 Ok(operations)
167}
168
169async fn sync_upload(
170 store: &impl Store,
171 client: &Client<'_>,
172 host: HostId,
173 tag: String,
174 local: RecordIdx,
175 remote: Option<RecordIdx>,
176 page_size: u64,
177) -> Result<i64, SyncError> {
178 let remote = remote.unwrap_or(0);
179 let expected = local - remote;
180 let mut progress = 0;
181
182 let pb = ProgressBar::new(expected);
183 pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})")
184 .unwrap()
185 .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
186 .progress_chars("#>-"));
187
188 println!(
189 "Uploading {} records to {}/{}",
190 expected,
191 host.0.as_simple(),
192 tag
193 );
194
195 loop {
196 let page = store
197 .next(host, tag.as_str(), remote + progress, page_size)
198 .await
199 .map_err(|e| {
200 error!("failed to read upload page: {e:?}");
201
202 SyncError::LocalStoreError { msg: e.to_string() }
203 })?;
204
205 if page.is_empty() {
206 break;
207 }
208
209 client.post_records(&page).await.map_err(|e| {
210 error!("failed to post records: {e:?}");
211
212 SyncError::RemoteRequestError { msg: e.to_string() }
213 })?;
214
215 progress += page.len() as u64;
216 pb.set_position(progress);
217
218 if progress >= expected {
219 break;
220 }
221 }
222
223 pb.finish_with_message("Uploaded records");
224
225 Ok(progress as i64)
226}
227
228async fn sync_download(
229 store: &impl Store,
230 client: &Client<'_>,
231 host: HostId,
232 tag: String,
233 local: Option<RecordIdx>,
234 remote: RecordIdx,
235 page_size: u64,
236) -> Result<Vec<RecordId>, SyncError> {
237 let local = local.unwrap_or(0);
238 let expected = remote - local;
239 let mut progress = 0;
240 let mut ret = Vec::new();
241
242 println!(
243 "Downloading {} records from {}/{}",
244 expected,
245 host.0.as_simple(),
246 tag
247 );
248
249 let pb = ProgressBar::new(expected);
250 pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})")
251 .unwrap()
252 .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
253 .progress_chars("#>-"));
254
255 loop {
256 let page = client
257 .next_records(host, tag.clone(), local + progress, page_size)
258 .await
259 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
260
261 if page.is_empty() {
262 break;
263 }
264
265 store
266 .push_batch(page.iter())
267 .await
268 .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?;
269
270 ret.extend(page.iter().map(|f| f.id));
271
272 progress += page.len() as u64;
273 pb.set_position(progress);
274
275 if progress >= expected {
276 break;
277 }
278 }
279
280 pb.finish_with_message("Downloaded records");
281
282 Ok(ret)
283}
284
285pub async fn sync_remote(
286 client: &Client<'_>,
287 operations: Vec<Operation>,
288 local_store: &impl Store,
289 page_size: u64,
290) -> Result<(i64, Vec<RecordId>), SyncError> {
291 let mut uploaded = 0;
292 let mut downloaded = Vec::new();
293
294 for i in operations {
296 match i {
297 Operation::Upload {
298 host,
299 tag,
300 local,
301 remote,
302 } => {
303 uploaded +=
304 sync_upload(local_store, client, host, tag, local, remote, page_size).await?
305 }
306
307 Operation::Download {
308 host,
309 tag,
310 local,
311 remote,
312 } => {
313 let mut d =
314 sync_download(local_store, client, host, tag, local, remote, page_size).await?;
315 downloaded.append(&mut d)
316 }
317
318 Operation::Noop { .. } => continue,
319 }
320 }
321
322 Ok((uploaded, downloaded))
323}
324
325pub async fn check_encryption_key(
326 client: &Client<'_>,
327 remote_index: &RecordStatus,
328 encryption_key: &[u8; 32],
329) -> Result<(), SyncError> {
330 let sample = remote_index
331 .hosts
332 .iter()
333 .flat_map(|(host, tags)| tags.keys().map(move |tag| (*host, tag.clone())))
334 .next();
335
336 let Some((host, tag)) = sample else {
337 return Ok(());
338 };
339
340 let records = client
341 .next_records(host, tag, 0, 1)
342 .await
343 .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?;
344
345 let Some(record) = records.into_iter().next() else {
346 return Ok(());
347 };
348
349 record
350 .decrypt::<PASETO_V4>(encryption_key)
351 .map_err(|_| SyncError::WrongKey)?;
352
353 Ok(())
354}
355
356pub async fn sync(
357 settings: &Settings,
358 store: &impl Store,
359 encryption_key: &[u8; 32],
360) -> Result<(i64, Vec<RecordId>), SyncError> {
361 let client = build_client(settings).await?;
362 let (diff, remote_index) = diff(&client, store).await?;
363
364 check_encryption_key(&client, &remote_index, encryption_key).await?;
366
367 let operations = operations(diff, store).await?;
368 let (uploaded, downloaded) = sync_remote(&client, operations, store, 100).await?;
369
370 Ok((uploaded, downloaded))
371}
372
/// Tests for turning store diffs into sync operations.
///
/// These tests build two in-memory sqlite stores (one standing in for the server),
/// diff their statuses, and check the operations produced by `sync::operations`.
/// No network sync is performed.
#[cfg(test)]
mod tests {
    use atuin_common::record::{Diff, EncryptedData, HostId, Record};
    use pretty_assertions::assert_eq;

    use crate::{
        record::{
            encryption::PASETO_V4,
            sqlite_store::SqliteStore,
            store::Store,
            sync::{self, Operation},
        },
        settings::test_local_timeout,
    };

    /// Build an index-0 record with a freshly generated host ID, a freshly generated
    /// tag, and empty encrypted data, so each call starts its own `(host, tag)` pair.
    fn test_record() -> Record<EncryptedData> {
        Record::builder()
            .host(atuin_common::record::Host::new(HostId(
                atuin_common::utils::uuid_v7(),
            )))
            .version("v1".into())
            .tag(atuin_common::utils::uuid_v7().simple().to_string())
            .data(EncryptedData {
                data: String::new(),
                content_encryption_key: String::new(),
            })
            .idx(0)
            .build()
    }

    /// Push `local_records` into one in-memory store and `remote_records` into
    /// another, then diff their statuses.
    ///
    /// Returns the local store (needed by `sync::operations`) and the diff.
    async fn build_test_diff(
        local_records: Vec<Record<EncryptedData>>,
        remote_records: Vec<Record<EncryptedData>>,
    ) -> (SqliteStore, Vec<Diff>) {
        let local_store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .expect("failed to open in memory sqlite");
        let remote_store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .expect("failed to open in memory sqlite");

        for i in local_records {
            local_store.push(&i).await.unwrap();
        }

        for i in remote_records {
            remote_store.push(&i).await.unwrap();
        }

        let local_index = local_store.status().await.unwrap();
        let remote_index = remote_store.status().await.unwrap();

        let diff = local_index.diff(&remote_index);

        (local_store, diff)
    }

    /// A single record that exists only locally yields exactly one upload with no
    /// remote index.
    #[tokio::test]
    async fn test_basic_diff() {
        let record = test_record();
        let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await;

        assert_eq!(diff.len(), 1);

        let operations = sync::operations(diff, &store).await.unwrap();

        assert_eq!(operations.len(), 1);

        assert_eq!(
            operations[0],
            Operation::Upload {
                host: record.host.id,
                tag: record.tag,
                local: record.idx,
                remote: None,
            }
        );
    }

    /// One pair is ahead locally (needs upload) and an unrelated pair exists only
    /// remotely (needs download).
    #[tokio::test]
    async fn build_two_way_diff() {
        let shared_record = test_record();
        // Independent pair (its own host and tag) that only the remote will hold.
        let remote_ahead = test_record();

        // Appending to the shared record continues its pair at idx 1.
        let local_ahead = shared_record
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        assert_eq!(local_ahead.idx, 1);

        let local = vec![shared_record.clone(), local_ahead.clone()];
        let remote = vec![shared_record.clone(), remote_ahead.clone()];
        let (store, diff) = build_test_diff(local, remote).await;
        let operations = sync::operations(diff, &store).await.unwrap();

        assert_eq!(operations.len(), 2);

        // `sync::operations` sorts uploads ahead of downloads.
        assert_eq!(
            operations,
            vec![
                Operation::Upload {
                    host: local_ahead.host.id,
                    tag: local_ahead.tag,
                    local: 1,
                    remote: Some(0),
                },
                Operation::Download {
                    host: remote_ahead.host.id,
                    tag: remote_ahead.tag,
                    local: None,
                    remote: 0,
                },
            ]
        );
    }

    /// Mix of pairs that are local-only, remote-only, ahead on either side, and
    /// identical. Identical pairs (`shared_record`) produce no operation, which is
    /// why only 7 operations are expected.
    #[tokio::test]
    async fn build_complex_diff() {
        let shared_record = test_record();
        let local_only = test_record();

        // Local-only pair with four records: idx 0..=3.
        let local_only_20 = test_record();
        let local_only_21 = local_only_20
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let local_only_22 = local_only_21
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let local_only_23 = local_only_22
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        let remote_only = test_record();

        // Remote-only pair with five records: idx 0..=4.
        let remote_only_20 = test_record();
        let remote_only_21 = remote_only_20
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let remote_only_22 = remote_only_21
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let remote_only_23 = remote_only_22
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let remote_only_24 = remote_only_23
            .append(vec![2, 3, 2])
            .encrypt::<PASETO_V4>(&[0; 32]);

        // Shared at idx 0; remote continues to idx 2.
        let second_shared = test_record();
        let second_shared_remote_ahead = second_shared
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let second_shared_remote_ahead2 = second_shared_remote_ahead
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        // Shared at idx 0; local continues to idx 2.
        let third_shared = test_record();
        let third_shared_local_ahead = third_shared
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let third_shared_local_ahead2 = third_shared_local_ahead
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        // Shared up to idx 1; remote continues to idx 2.
        let fourth_shared = test_record();
        let fourth_shared_remote_ahead = fourth_shared
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);
        let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead
            .append(vec![1, 2, 3])
            .encrypt::<PASETO_V4>(&[0; 32]);

        let local = vec![
            shared_record.clone(),
            second_shared.clone(),
            third_shared.clone(),
            fourth_shared.clone(),
            fourth_shared_remote_ahead.clone(),
            local_only.clone(),
            local_only_20.clone(),
            local_only_21.clone(),
            local_only_22.clone(),
            local_only_23.clone(),
            third_shared_local_ahead.clone(),
            third_shared_local_ahead2.clone(),
        ];

        let remote = vec![
            remote_only.clone(),
            remote_only_20.clone(),
            remote_only_21.clone(),
            remote_only_22.clone(),
            remote_only_23.clone(),
            remote_only_24.clone(),
            shared_record.clone(),
            second_shared.clone(),
            third_shared.clone(),
            second_shared_remote_ahead.clone(),
            second_shared_remote_ahead2.clone(),
            fourth_shared.clone(),
            fourth_shared_remote_ahead.clone(),
            fourth_shared_remote_ahead2.clone(),
        ];
        let (store, diff) = build_test_diff(local, remote).await;
        let operations = sync::operations(diff, &store).await.unwrap();

        assert_eq!(operations.len(), 7);

        let mut result_ops = vec![
            Operation::Download {
                local: Some(0),
                remote: 2,
                host: second_shared_remote_ahead.host.id,
                tag: second_shared_remote_ahead.tag,
            },
            Operation::Download {
                local: Some(1),
                remote: 2,
                host: fourth_shared_remote_ahead2.host.id,
                tag: fourth_shared_remote_ahead2.tag,
            },
            Operation::Download {
                local: None,
                remote: 0,
                host: remote_only.host.id,
                tag: remote_only.tag,
            },
            Operation::Download {
                local: None,
                remote: 4,
                host: remote_only_20.host.id,
                tag: remote_only_20.tag,
            },
            Operation::Upload {
                local: 0,
                remote: None,
                host: local_only.host.id,
                tag: local_only.tag,
            },
            Operation::Upload {
                local: 3,
                remote: None,
                host: local_only_20.host.id,
                tag: local_only_20.tag,
            },
            Operation::Upload {
                local: 2,
                remote: Some(0),
                host: third_shared.host.id,
                tag: third_shared.tag,
            },
        ];

        // Host and tag are random, so order the expectation with the same
        // (kind, host, tag) ordering that `sync::operations` applies.
        result_ops.sort_by_key(|op| match op {
            Operation::Noop { host, tag } => (0, *host, tag.clone()),

            Operation::Upload { host, tag, .. } => (1, *host, tag.clone()),

            Operation::Download { host, tag, .. } => (2, *host, tag.clone()),
        });

        assert_eq!(result_ops, operations);
    }
}