1use std::collections::HashSet;
2
3use atrium_api::types::{
4 string::{Did, RecordKey, Tid},
5 Collection, LimitedU32,
6};
7use futures::TryStreamExt;
8use ipld_core::cid::Cid;
9use serde::{de::DeserializeOwned, Deserialize, Serialize};
10
11use crate::{
12 blockstore::{AsyncBlockStoreRead, AsyncBlockStoreWrite, DAG_CBOR, SHA2_256},
13 mst,
14};
15
/// On-disk (DAG-CBOR) shapes of repository commits.
mod schema {
    use super::*;

    /// Unsigned commit. Its DAG-CBOR encoding is the payload callers sign
    /// (see `CommitBuilder::bytes` / `RepoBuilder::bytes` and `Commit::bytes`).
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
    pub struct Commit {
        /// DID the repository belongs to (the `did` passed to `Repository::create`).
        pub did: Did,
        /// Commit format version; every commit built in this file uses `3`.
        pub version: i64,
        /// CID of the MST root this commit points at.
        pub data: Cid,
        /// Revision TID; builders default it to `Tid::now(LimitedU32::MIN)`.
        pub rev: Tid,
        /// CID of the previous commit, if any; builders leave it `None` unless
        /// `CommitBuilder::prev` is called.
        pub prev: Option<Cid>,
    }

    /// Same fields as [`Commit`] plus the signature; this is what is stored as
    /// the repository's root block.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
    pub struct SignedCommit {
        /// DID the repository belongs to.
        pub did: Did,
        /// Commit format version.
        pub version: i64,
        /// CID of the MST root this commit points at.
        pub data: Cid,
        /// Revision TID.
        pub rev: Tid,
        /// CID of the previous commit, if any.
        pub prev: Option<Cid>,
        /// Signature over the encoding of the unsigned [`Commit`]; `serde_bytes`
        /// makes it serialize as a byte string rather than a sequence of integers.
        #[serde(with = "serde_bytes")]
        pub sig: Vec<u8>,
    }
}
61
62async fn read_record<T: DeserializeOwned>(
63 mut db: impl AsyncBlockStoreRead,
64 cid: Cid,
65) -> Result<T, Error> {
66 assert_eq!(cid.codec(), crate::blockstore::DAG_CBOR);
67
68 let data = db.read_block(cid).await?;
69 let parsed: T = serde_ipld_dagcbor::from_reader(&data[..])?;
70 Ok(parsed)
71}
72
73pub struct CommitBuilder<'r, S: AsyncBlockStoreWrite> {
75 repo: &'r mut Repository<S>,
76 inner: schema::Commit,
77}
78
79impl<'r, S: AsyncBlockStoreWrite> CommitBuilder<'r, S> {
80 fn new(repo: &'r mut Repository<S>, did: Did, root: Cid) -> Self {
81 CommitBuilder {
82 inner: schema::Commit {
83 did,
84 version: 3,
85 data: root,
86 rev: Tid::now(LimitedU32::MIN),
87 prev: None,
88 },
89 repo,
90 }
91 }
92
93 pub fn prev(&mut self, prev: Cid) -> &mut Self {
95 self.inner.prev = Some(prev);
96 self
97 }
98
99 pub fn rev(&mut self, time: Tid) -> &mut Self {
102 self.inner.rev = time;
103 self
104 }
105
106 pub fn bytes(&self) -> Vec<u8> {
108 serde_ipld_dagcbor::to_vec(&self.inner).unwrap() }
110
111 pub async fn finalize(self, sig: Vec<u8>) -> Result<Cid, Error> {
116 let s = schema::SignedCommit {
117 did: self.inner.did.clone(),
118 version: self.inner.version,
119 data: self.inner.data,
120 rev: self.inner.rev.clone(),
121 prev: self.inner.prev,
122 sig,
123 };
124 let b = serde_ipld_dagcbor::to_vec(&s).unwrap();
125 let c = self.repo.db.write_block(DAG_CBOR, SHA2_256, &b).await?;
126
127 self.repo.root = c;
128 self.repo.latest_commit = s.clone();
129 Ok(c)
130 }
131}
132
133#[derive(Debug, Clone, PartialEq, Eq)]
135pub struct RepoBuilder<S: AsyncBlockStoreRead + AsyncBlockStoreWrite> {
136 db: S,
137 commit: schema::Commit,
138}
139
140impl<S: AsyncBlockStoreRead + AsyncBlockStoreWrite> RepoBuilder<S> {
141 pub fn bytes(&self) -> Vec<u8> {
143 serde_ipld_dagcbor::to_vec(&self.commit).unwrap() }
145
146 pub async fn finalize(mut self, sig: Vec<u8>) -> Result<Repository<S>, Error> {
148 let s = schema::SignedCommit {
150 did: self.commit.did.clone(),
151 version: self.commit.version,
152 data: self.commit.data,
153 rev: self.commit.rev.clone(),
154 prev: self.commit.prev,
155 sig,
156 };
157 let b = serde_ipld_dagcbor::to_vec(&s).unwrap();
158 let c = self.db.write_block(DAG_CBOR, SHA2_256, &b).await?;
159
160 Ok(Repository { db: self.db, root: c, latest_commit: s })
161 }
162}
163
/// Read-only snapshot of a repository's signed commit, as returned by
/// [`Repository::commit`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Commit {
    /// The decoded signed commit.
    inner: schema::SignedCommit,
}
169
170impl Commit {
171 pub fn data(&self) -> Cid {
173 self.inner.data
174 }
175
176 pub fn rev(&self) -> Tid {
178 self.inner.rev.clone()
179 }
180
181 pub fn bytes(&self) -> Vec<u8> {
183 serde_ipld_dagcbor::to_vec(&schema::Commit {
184 did: self.inner.did.clone(),
185 version: self.inner.version,
186 data: self.inner.data,
187 rev: self.inner.rev.clone(),
188 prev: self.inner.prev,
189 })
190 .unwrap() }
192
193 pub fn sig(&self) -> &[u8] {
195 self.inner.sig.as_slice()
196 }
197}
198
/// A repository backed by block store `S`: an MST of records plus a signed
/// commit pointing at the MST root.
#[derive(Debug)]
pub struct Repository<S> {
    /// Block store holding the commit, MST node and record blocks.
    db: S,
    /// CID of the latest signed commit block.
    root: Cid,
    /// Decoded contents of the block at `root`.
    latest_commit: schema::SignedCommit,
}
229
230impl<R: AsyncBlockStoreRead> Repository<R> {
231 pub async fn open(mut db: R, root: Cid) -> Result<Self, Error> {
235 let commit_block = db.read_block(root).await?;
236 let latest_commit: schema::SignedCommit =
237 serde_ipld_dagcbor::from_reader(&commit_block[..])?;
238
239 Ok(Self { db, root, latest_commit })
240 }
241
242 pub fn root(&self) -> Cid {
244 self.root
245 }
246
247 pub fn commit(&self) -> Commit {
249 Commit { inner: self.latest_commit.clone() }
250 }
251
252 pub fn tree(&mut self) -> mst::Tree<&mut R> {
258 mst::Tree::open(&mut self.db, self.latest_commit.data)
259 }
260
261 pub async fn get<C: Collection>(
263 &mut self,
264 rkey: RecordKey,
265 ) -> Result<Option<C::Record>, Error> {
266 let path = C::repo_path(&rkey);
267 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
268
269 if let Some(cid) = mst.get(&path).await? {
270 Ok(Some(read_record::<C::Record>(&mut self.db, cid).await?))
271 } else {
272 Ok(None)
273 }
274 }
275
276 pub async fn get_raw<T: DeserializeOwned>(&mut self, key: &str) -> Result<Option<T>, Error> {
278 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
279
280 if let Some(cid) = mst.get(key).await? {
281 Ok(Some(read_record::<T>(&mut self.db, cid).await?))
282 } else {
283 Ok(None)
284 }
285 }
286
287 pub async fn get_raw_cid<T: DeserializeOwned>(&mut self, cid: Cid) -> Result<Option<T>, Error> {
300 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
301
302 let mut ocid = None;
303
304 let mut it = Box::pin(mst.entries());
305 while let Some((_rkey, rcid)) = it.try_next().await? {
306 if rcid == cid {
307 ocid = Some(rcid);
308 break;
309 }
310 }
311
312 drop(it);
314
315 if let Some(ocid) = ocid {
316 Ok(Some(read_record::<T>(&mut self.db, ocid).await?))
317 } else {
318 Ok(None)
319 }
320 }
321
322 pub async fn export(&mut self) -> Result<impl Iterator<Item = Cid>, Error> {
324 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
325
326 let mut r = vec![self.root];
327 r.extend(mst.export().try_collect::<Vec<_>>().await?);
328 Ok(r.into_iter())
329 }
330
331 pub async fn export_into(&mut self, mut bs: impl AsyncBlockStoreWrite) -> Result<(), Error> {
333 let cids = self.export().await?.collect::<HashSet<_>>();
334
335 for cid in cids {
336 bs.write_block(cid.codec(), SHA2_256, self.db.read_block(cid).await?.as_slice())
337 .await?;
338 }
339
340 Ok(())
341 }
342
343 pub async fn extract<C: Collection>(
351 &mut self,
352 rkey: RecordKey,
353 ) -> Result<impl Iterator<Item = Cid>, Error> {
354 let path = C::repo_path(&rkey);
355 self.extract_raw(&path).await
356 }
357
358 pub async fn extract_into<C: Collection>(
360 &mut self,
361 rkey: RecordKey,
362 bs: impl AsyncBlockStoreWrite,
363 ) -> Result<(), Error> {
364 let path = C::repo_path(&rkey);
365 self.extract_raw_into(&path, bs).await
366 }
367
368 pub async fn extract_raw(&mut self, key: &str) -> Result<impl Iterator<Item = Cid>, Error> {
370 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
371
372 let mut r = vec![self.root];
373 r.extend(mst.extract_path(key).await?);
374 Ok(r.into_iter())
375 }
376
377 pub async fn extract_raw_into(
379 &mut self,
380 key: &str,
381 mut bs: impl AsyncBlockStoreWrite,
382 ) -> Result<(), Error> {
383 let cids = self.extract_raw(key).await?.collect::<HashSet<_>>();
384
385 for cid in cids {
386 bs.write_block(cid.codec(), SHA2_256, self.db.read_block(cid).await?.as_slice())
387 .await?;
388 }
389
390 Ok(())
391 }
392}
393
394impl<S: AsyncBlockStoreRead + AsyncBlockStoreWrite> Repository<S> {
395 pub async fn create(mut db: S, did: Did) -> Result<RepoBuilder<S>, Error> {
397 let tree = mst::Tree::create(&mut db).await?;
398 let root = tree.root();
399
400 Ok(RepoBuilder {
401 db,
402 commit: schema::Commit {
403 did,
404 version: 3,
405 data: root,
406 rev: Tid::now(LimitedU32::MIN),
407 prev: None,
408 },
409 })
410 }
411
412 pub async fn add<C: Collection>(
414 &mut self,
415 rkey: RecordKey,
416 record: C::Record,
417 ) -> Result<(CommitBuilder<'_, S>, Cid), Error> {
418 let path = C::repo_path(&rkey);
419 self.add_raw(&path, record).await
420 }
421
422 pub async fn add_raw<'a, T: Serialize>(
424 &'a mut self,
425 key: &str,
426 data: T,
427 ) -> Result<(CommitBuilder<'a, S>, Cid), Error> {
428 let data = serde_ipld_dagcbor::to_vec(&data).unwrap();
429 let cid = self.db.write_block(DAG_CBOR, SHA2_256, &data).await?;
430
431 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
432 mst.add(key, cid).await?;
433 let root = mst.root();
434
435 Ok((CommitBuilder::new(self, self.latest_commit.did.clone(), root), cid))
436 }
437
438 pub async fn update<C: Collection>(
440 &mut self,
441 rkey: RecordKey,
442 record: C::Record,
443 ) -> Result<(CommitBuilder<'_, S>, Cid), Error> {
444 let path = C::repo_path(&rkey);
445 self.update_raw(&path, record).await
446 }
447
448 pub async fn update_raw<'a, T: Serialize>(
450 &'a mut self,
451 key: &str,
452 data: T,
453 ) -> Result<(CommitBuilder<'a, S>, Cid), Error> {
454 let data = serde_ipld_dagcbor::to_vec(&data).unwrap();
455 let cid = self.db.write_block(DAG_CBOR, SHA2_256, &data).await?;
456
457 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
458 mst.update(key, cid).await?;
459 let root = mst.root();
460
461 Ok((CommitBuilder::new(self, self.latest_commit.did.clone(), root), cid))
462 }
463
464 pub async fn delete<C: Collection>(
466 &mut self,
467 rkey: RecordKey,
468 ) -> Result<CommitBuilder<'_, S>, Error> {
469 let path = C::repo_path(&rkey);
470 self.delete_raw(&path).await
471 }
472
473 pub async fn delete_raw<'a>(&'a mut self, key: &str) -> Result<CommitBuilder<'a, S>, Error> {
475 let mut mst = mst::Tree::open(&mut self.db, self.latest_commit.data);
476 mst.delete(key).await?;
477 let root = mst.root();
478
479 Ok(CommitBuilder::new(self, self.latest_commit.did.clone(), root))
480 }
481}
482
/// Errors returned by [`Repository`], [`RepoBuilder`] and [`CommitBuilder`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Key bytes that failed UTF-8 validation (converted from [`std::str::Utf8Error`]).
    /// NOTE(review): no code in this file constructs this variant.
    #[error("Invalid key: {0}")]
    InvalidKey(#[from] std::str::Utf8Error),
    /// A rejected record key; the payload is a static description.
    /// NOTE(review): no code in this file constructs this variant.
    #[error("Invalid RecordKey: {0}")]
    InvalidRecordKey(&'static str),
    /// A block-store read or write failed.
    #[error("Blockstore error: {0}")]
    BlockStore(#[from] crate::blockstore::Error),
    /// An MST operation failed.
    #[error("MST error: {0}")]
    Mst(#[from] mst::Error),
    /// A block (commit or record) could not be decoded from DAG-CBOR.
    #[error("serde_ipld_dagcbor decoding error: {0}")]
    Parse(#[from] serde_ipld_dagcbor::DecodeError<std::io::Error>),
}
497
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use crate::blockstore::MemoryBlockStore;
    use atrium_api::{app::bsky, types::string::Datetime};
    use atrium_crypto::{
        did::parse_did_key,
        keypair::{Did as _, P256Keypair},
        verify::Verifier,
    };

    use super::*;

    /// CID of the MST root of a freshly created (or emptied) repository.
    const EMPTY_TREE: &str = "bafyreie5737gdxlw5i64vzichcalba3z2v5n6icifvx5xytvske7mr3hpm";

    /// Create a repository on `bs` for `did` and sign its initial commit with `keypair`.
    async fn create_repo<S: AsyncBlockStoreRead + AsyncBlockStoreWrite>(
        bs: S,
        did: Did,
        keypair: &P256Keypair,
    ) -> Repository<S> {
        let builder = Repository::create(bs, did).await.unwrap();
        let sig = keypair.sign(&builder.bytes()).unwrap();
        builder.finalize(sig).await.unwrap()
    }

    /// Minimal post record with the given text and creation time.
    fn post_data(text: String, created_at: Datetime) -> bsky::feed::post::RecordData {
        bsky::feed::post::RecordData {
            created_at,
            embed: None,
            entities: None,
            facets: None,
            labels: None,
            langs: None,
            reply: None,
            tags: None,
            text,
        }
    }

    #[tokio::test]
    async fn test_create_repo() {
        let mut bs = MemoryBlockStore::new();

        let keypair = P256Keypair::create(&mut rand::thread_rng());
        let dkey = keypair.did();

        let mut repo =
            create_repo(&mut bs, Did::new("did:web:pds.abc.com".to_string()).unwrap(), &keypair)
                .await;

        let commit = repo.commit();

        let (alg, pub_key) = parse_did_key(&dkey).unwrap();
        Verifier::default().verify(alg, &pub_key, &commit.bytes(), commit.sig()).unwrap();

        assert_eq!(commit.data(), Cid::from_str(EMPTY_TREE).unwrap());

        let (cb, _) = repo
            .add::<bsky::feed::Post>(
                RecordKey::new(Tid::now(LimitedU32::MIN).to_string()).unwrap(),
                post_data("Hello world".to_string(), Datetime::now()).into(),
            )
            .await
            .unwrap();

        let sig = keypair.sign(&cb.bytes()).unwrap();
        let _cid = cb.finalize(sig).await.unwrap();

        let commit = repo.commit();
        Verifier::default().verify(alg, &pub_key, &commit.bytes(), commit.sig()).unwrap();
    }

    #[tokio::test]
    async fn test_extract() {
        let mut bs = MemoryBlockStore::new();

        let keypair = P256Keypair::create(&mut rand::thread_rng());

        let mut repo =
            create_repo(&mut bs, Did::new("did:web:pds.abc.com".to_string()).unwrap(), &keypair)
                .await;

        let rkey = RecordKey::new("2222222222222".to_string()).unwrap();
        let (cb, _) = repo
            .add::<bsky::feed::Post>(
                rkey.clone(),
                post_data(
                    "Hello world".to_string(),
                    Datetime::from_str("2025-02-01T00:00:00.000Z").unwrap(),
                )
                .into(),
            )
            .await
            .unwrap();

        let sig = keypair.sign(&cb.bytes()).unwrap();
        let cid = cb.finalize(sig).await.unwrap();

        let commit = repo.commit();

        let mut bs2 = MemoryBlockStore::new();
        repo.extract_into::<bsky::feed::Post>(rkey.clone(), &mut bs2).await.unwrap();

        assert!(bs2.read_block(cid).await.is_ok());
        assert!(bs2.read_block(commit.data()).await.is_ok());

        // The extracted blocks must be enough to actually resolve the record, not
        // merely to answer "not found" without error.
        let mut repo2 = Repository::open(&mut bs2, repo.root()).await.unwrap();
        assert!(repo2.get::<bsky::feed::Post>(rkey.clone()).await.unwrap().is_some());

        let cb = repo.delete::<bsky::feed::Post>(rkey.clone()).await.unwrap();
        let sig = keypair.sign(&cb.bytes()).unwrap();
        let cid = cb.finalize(sig).await.unwrap();

        let cids =
            repo.extract::<bsky::feed::Post>(rkey.clone()).await.unwrap().collect::<HashSet<_>>();

        assert!(cids.contains(&cid));
        assert!(cids.contains(&Cid::from_str(EMPTY_TREE).unwrap()));
    }

    #[tokio::test]
    async fn test_extract_complex() {
        let mut bs = MemoryBlockStore::new();

        let keypair = P256Keypair::create(&mut rand::thread_rng());

        let mut repo =
            create_repo(&mut bs, Did::new("did:web:pds.abc.com".to_string()).unwrap(), &keypair)
                .await;

        let mut records = Vec::new();

        for i in 0..10 {
            // Tid::now may repeat within the same tick; retry until the key is unique.
            let rkey = loop {
                let rkey = RecordKey::new(Tid::now(LimitedU32::MIN).to_string()).unwrap();
                if !records.contains(&rkey) {
                    break rkey;
                }
            };

            let (cb, _) = repo
                .add::<bsky::feed::Post>(
                    rkey.clone(),
                    post_data(
                        format!("Hello world, post {i}"),
                        Datetime::from_str("2025-02-01T00:00:00.000Z").unwrap(),
                    )
                    .into(),
                )
                .await
                .unwrap();

            let sig = keypair.sign(&cb.bytes()).unwrap();
            cb.finalize(sig).await.unwrap();

            records.push(rkey);
        }

        for record in records {
            let mut bs2 = MemoryBlockStore::new();
            repo.extract_into::<bsky::feed::Post>(record.clone(), &mut bs2).await.unwrap();

            assert!(bs2.contains(repo.root()));
            assert!(bs2.contains(repo.commit().data()));

            let mut repo2 = Repository::open(&mut bs2, repo.root()).await.unwrap();
            assert!(repo2.get::<bsky::feed::Post>(record.clone()).await.unwrap().is_some());
        }
    }
}