iroh_blobs/
rpc.rs

//! Provides an RPC protocol as well as a client for the protocol.
2
3use std::{
4    io,
5    ops::Deref,
6    sync::{Arc, Mutex},
7};
8
9use anyhow::anyhow;
10use client::{
11    blobs::{self, BlobInfo, BlobStatus, DownloadMode, IncompleteBlobInfo, MemClient, WrapOption},
12    tags::TagInfo,
13    MemConnector,
14};
15use futures_buffered::BufferedStreamExt;
16use futures_lite::StreamExt;
17use futures_util::{FutureExt, Stream};
18use genawaiter::sync::{Co, Gen};
19use iroh::{Endpoint, NodeAddr};
20use iroh_io::AsyncSliceReader;
21use proto::{
22    blobs::{
23        AddPathRequest, AddPathResponse, AddStreamRequest, AddStreamResponse, AddStreamUpdate,
24        BatchAddPathRequest, BatchAddPathResponse, BatchAddStreamRequest, BatchAddStreamResponse,
25        BatchAddStreamUpdate, BatchCreateRequest, BatchCreateResponse, BatchCreateTempTagRequest,
26        BatchUpdate, BlobStatusRequest, BlobStatusResponse, ConsistencyCheckRequest,
27        CreateCollectionRequest, CreateCollectionResponse, DeleteRequest, DownloadResponse,
28        ExportRequest, ExportResponse, ListIncompleteRequest, ListRequest, ReadAtRequest,
29        ReadAtResponse, ValidateRequest,
30    },
31    tags::{
32        CreateRequest as TagsCreateRequest, DeleteRequest as TagDeleteRequest,
33        ListRequest as TagListRequest, RenameRequest, SetRequest as TagsSetRequest, SyncMode,
34    },
35    Request, RpcError, RpcResult, RpcService,
36};
37use quic_rpc::{
38    server::{ChannelTypes, RpcChannel, RpcServerError},
39    RpcClient, RpcServer,
40};
41use tokio_util::task::AbortOnDropHandle;
42use tracing::{debug, warn};
43
44use crate::{
45    downloader::{DownloadRequest, Downloader},
46    export::ExportProgress,
47    format::collection::Collection,
48    get::{
49        db::{DownloadProgress, GetState},
50        Stats,
51    },
52    net_protocol::{BlobDownloadRequest, Blobs, BlobsInner},
53    provider::{AddProgress, BatchAddPathProgress},
54    store::{ConsistencyCheckProgress, ImportProgress, MapEntry, ValidateProgress},
55    util::{
56        local_pool::LocalPoolHandle,
57        progress::{AsyncChannelProgressSender, ProgressSender},
58        SetTagOption,
59    },
60    BlobFormat, HashAndFormat, Tag,
61};
62pub mod client;
63pub mod proto;
64
/// Largest number of bytes sent in a single data message when a blob range is
/// streamed back over RPC (64 KiB).
const RPC_BLOB_GET_CHUNK_SIZE: usize = 64 * 1024;
/// Capacity of the channel that buffers read-at responses between the reader
/// task and the RPC response stream.
const RPC_BLOB_GET_CHANNEL_CAP: usize = 2;
69
70impl<D: crate::store::Store> Blobs<D> {
71    /// Get a client for the blobs protocol
72    pub fn client(&self) -> &blobs::MemClient {
73        &self
74            .rpc_handler
75            .get_or_init(|| RpcHandler::new(&self.inner))
76            .client
77    }
78
79    /// Handle an RPC request
80    pub async fn handle_rpc_request<C>(
81        self,
82        msg: Request,
83        chan: RpcChannel<RpcService, C>,
84    ) -> std::result::Result<(), RpcServerError<C>>
85    where
86        C: ChannelTypes<RpcService>,
87    {
88        Handler(self.inner.clone())
89            .handle_rpc_request(msg, chan)
90            .await
91    }
92}
93
94/// This is just an internal helper so I don't have to
95/// define all the rpc methods on `self: Arc<BlobsInner<S>>`
96#[derive(Clone)]
97struct Handler<S>(Arc<BlobsInner<S>>);
98
99impl<S> Deref for Handler<S> {
100    type Target = BlobsInner<S>;
101
102    fn deref(&self) -> &Self::Target {
103        &self.0
104    }
105}
106
107impl<D: crate::store::Store> Handler<D> {
108    fn store(&self) -> &D {
109        &self.0.store
110    }
111
112    fn rt(&self) -> &LocalPoolHandle {
113        self.0.rt()
114    }
115
116    fn endpoint(&self) -> &Endpoint {
117        &self.0.endpoint
118    }
119
120    fn downloader(&self) -> &Downloader {
121        &self.0.downloader
122    }
123
124    #[cfg(feature = "rpc")]
125    pub(crate) async fn batches(
126        &self,
127    ) -> tokio::sync::MutexGuard<'_, crate::net_protocol::BlobBatches> {
128        self.0.batches.lock().await
129    }
130
131    /// Handle an RPC request
132    pub async fn handle_rpc_request<C>(
133        self,
134        msg: Request,
135        chan: RpcChannel<RpcService, C>,
136    ) -> std::result::Result<(), RpcServerError<C>>
137    where
138        C: ChannelTypes<RpcService>,
139    {
140        match msg {
141            Request::Blobs(msg) => self.handle_blobs_request(msg, chan).await,
142            Request::Tags(msg) => self.handle_tags_request(msg, chan).await,
143        }
144    }
145
146    /// Handle a tags request
147    pub async fn handle_tags_request<C>(
148        self,
149        msg: proto::tags::Request,
150        chan: RpcChannel<proto::RpcService, C>,
151    ) -> std::result::Result<(), RpcServerError<C>>
152    where
153        C: ChannelTypes<proto::RpcService>,
154    {
155        use proto::tags::Request::*;
156        match msg {
157            Create(msg) => chan.rpc(msg, self, Self::tags_create).await,
158            Set(msg) => chan.rpc(msg, self, Self::tags_set).await,
159            DeleteTag(msg) => chan.rpc(msg, self, Self::blob_delete_tag).await,
160            ListTags(msg) => chan.server_streaming(msg, self, Self::blob_list_tags).await,
161            Rename(msg) => chan.rpc(msg, self, Self::tags_rename).await,
162        }
163    }
164
165    /// Handle a blobs request
166    pub async fn handle_blobs_request<C>(
167        self,
168        msg: proto::blobs::Request,
169        chan: RpcChannel<proto::RpcService, C>,
170    ) -> std::result::Result<(), RpcServerError<C>>
171    where
172        C: ChannelTypes<proto::RpcService>,
173    {
174        use proto::blobs::Request::*;
175        match msg {
176            List(msg) => chan.server_streaming(msg, self, Self::blob_list).await,
177            ListIncomplete(msg) => {
178                chan.server_streaming(msg, self, Self::blob_list_incomplete)
179                    .await
180            }
181            CreateCollection(msg) => chan.rpc(msg, self, Self::create_collection).await,
182            Delete(msg) => chan.rpc(msg, self, Self::blob_delete_blob).await,
183            AddPath(msg) => {
184                chan.server_streaming(msg, self, Self::blob_add_from_path)
185                    .await
186            }
187            Download(msg) => chan.server_streaming(msg, self, Self::blob_download).await,
188            Export(msg) => chan.server_streaming(msg, self, Self::blob_export).await,
189            Validate(msg) => chan.server_streaming(msg, self, Self::blob_validate).await,
190            Fsck(msg) => {
191                chan.server_streaming(msg, self, Self::blob_consistency_check)
192                    .await
193            }
194            ReadAt(msg) => chan.server_streaming(msg, self, Self::blob_read_at).await,
195            AddStream(msg) => chan.bidi_streaming(msg, self, Self::blob_add_stream).await,
196            AddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage),
197            BlobStatus(msg) => chan.rpc(msg, self, Self::blob_status).await,
198            BatchCreate(msg) => chan.bidi_streaming(msg, self, Self::batch_create).await,
199            BatchUpdate(_) => Err(RpcServerError::UnexpectedStartMessage),
200            BatchAddStream(msg) => chan.bidi_streaming(msg, self, Self::batch_add_stream).await,
201            BatchAddStreamUpdate(_) => Err(RpcServerError::UnexpectedStartMessage),
202            BatchAddPath(msg) => {
203                chan.server_streaming(msg, self, Self::batch_add_from_path)
204                    .await
205            }
206            BatchCreateTempTag(msg) => chan.rpc(msg, self, Self::batch_create_temp_tag).await,
207        }
208    }
209
210    async fn blob_status(self, msg: BlobStatusRequest) -> RpcResult<BlobStatusResponse> {
211        let blobs = self;
212        let entry = blobs
213            .store()
214            .get(&msg.hash)
215            .await
216            .map_err(|e| RpcError::new(&e))?;
217        Ok(BlobStatusResponse(match entry {
218            Some(entry) => {
219                if entry.is_complete() {
220                    BlobStatus::Complete {
221                        size: entry.size().value(),
222                    }
223                } else {
224                    BlobStatus::Partial { size: entry.size() }
225                }
226            }
227            None => BlobStatus::NotFound,
228        }))
229    }
230
231    async fn blob_list_impl(self, co: &Co<RpcResult<BlobInfo>>) -> io::Result<()> {
232        use bao_tree::io::fsm::Outboard;
233
234        let blobs = self;
235        let db = blobs.store();
236        for blob in db.blobs().await? {
237            let blob = blob?;
238            let Some(entry) = db.get(&blob).await? else {
239                continue;
240            };
241            let hash = entry.hash();
242            let size = entry.outboard().await?.tree().size();
243            let path = "".to_owned();
244            co.yield_(Ok(BlobInfo { hash, size, path })).await;
245        }
246        Ok(())
247    }
248
249    async fn blob_list_incomplete_impl(
250        self,
251        co: &Co<RpcResult<IncompleteBlobInfo>>,
252    ) -> io::Result<()> {
253        let blobs = self;
254        let db = blobs.store();
255        for hash in db.partial_blobs().await? {
256            let hash = hash?;
257            let Ok(Some(entry)) = db.get_mut(&hash).await else {
258                continue;
259            };
260            if entry.is_complete() {
261                continue;
262            }
263            let size = 0;
264            let expected_size = entry.size().value();
265            co.yield_(Ok(IncompleteBlobInfo {
266                hash,
267                size,
268                expected_size,
269            }))
270            .await;
271        }
272        Ok(())
273    }
274
275    fn blob_list(
276        self,
277        _msg: ListRequest,
278    ) -> impl Stream<Item = RpcResult<BlobInfo>> + Send + 'static {
279        Gen::new(|co| async move {
280            if let Err(e) = self.blob_list_impl(&co).await {
281                co.yield_(Err(RpcError::new(&e))).await;
282            }
283        })
284    }
285
286    fn blob_list_incomplete(
287        self,
288        _msg: ListIncompleteRequest,
289    ) -> impl Stream<Item = RpcResult<IncompleteBlobInfo>> + Send + 'static {
290        Gen::new(move |co| async move {
291            if let Err(e) = self.blob_list_incomplete_impl(&co).await {
292                co.yield_(Err(RpcError::new(&e))).await;
293            }
294        })
295    }
296
297    async fn blob_delete_tag(self, msg: TagDeleteRequest) -> RpcResult<()> {
298        self.store()
299            .delete_tags(msg.from, msg.to)
300            .await
301            .map_err(|e| RpcError::new(&e))?;
302        Ok(())
303    }
304
305    async fn blob_delete_blob(self, msg: DeleteRequest) -> RpcResult<()> {
306        self.store()
307            .delete(vec![msg.hash])
308            .await
309            .map_err(|e| RpcError::new(&e))?;
310        Ok(())
311    }
312
313    fn blob_list_tags(self, msg: TagListRequest) -> impl Stream<Item = TagInfo> + Send + 'static {
314        tracing::info!("blob_list_tags");
315        let blobs = self;
316        Gen::new(|co| async move {
317            let tags = blobs.store().tags(msg.from, msg.to).await.unwrap();
318            #[allow(clippy::manual_flatten)]
319            for item in tags {
320                if let Ok((name, HashAndFormat { hash, format })) = item {
321                    if (format.is_raw() && msg.raw) || (format.is_hash_seq() && msg.hash_seq) {
322                        co.yield_(TagInfo { name, hash, format }).await;
323                    }
324                }
325            }
326        })
327    }
328
329    /// Invoke validate on the database and stream out the result
330    fn blob_validate(
331        self,
332        msg: ValidateRequest,
333    ) -> impl Stream<Item = ValidateProgress> + Send + 'static {
334        let (tx, rx) = async_channel::bounded(1);
335        let tx2 = tx.clone();
336        let blobs = self;
337        tokio::task::spawn(async move {
338            if let Err(e) = blobs
339                .store()
340                .validate(msg.repair, AsyncChannelProgressSender::new(tx).boxed())
341                .await
342            {
343                tx2.send(ValidateProgress::Abort(RpcError::new(&e)))
344                    .await
345                    .ok();
346            }
347        });
348        rx
349    }
350
351    /// Invoke validate on the database and stream out the result
352    fn blob_consistency_check(
353        self,
354        msg: ConsistencyCheckRequest,
355    ) -> impl Stream<Item = ConsistencyCheckProgress> + Send + 'static {
356        let (tx, rx) = async_channel::bounded(1);
357        let tx2 = tx.clone();
358        let blobs = self;
359        tokio::task::spawn(async move {
360            if let Err(e) = blobs
361                .store()
362                .consistency_check(msg.repair, AsyncChannelProgressSender::new(tx).boxed())
363                .await
364            {
365                tx2.send(ConsistencyCheckProgress::Abort(RpcError::new(&e)))
366                    .await
367                    .ok();
368            }
369        });
370        rx
371    }
372
373    fn blob_add_from_path(self, msg: AddPathRequest) -> impl Stream<Item = AddPathResponse> {
374        // provide a little buffer so that we don't slow down the sender
375        let (tx, rx) = async_channel::bounded(32);
376        let tx2 = tx.clone();
377        let rt = self.rt().clone();
378        rt.spawn_detached(|| async move {
379            if let Err(e) = self.blob_add_from_path0(msg, tx).await {
380                tx2.send(AddProgress::Abort(RpcError::new(&*e))).await.ok();
381            }
382        });
383        rx.map(AddPathResponse)
384    }
385
386    async fn tags_rename(self, msg: RenameRequest) -> RpcResult<()> {
387        let blobs = self;
388        blobs
389            .store()
390            .rename_tag(msg.from, msg.to)
391            .await
392            .map_err(|e| RpcError::new(&e))?;
393        Ok(())
394    }
395
396    async fn tags_set(self, msg: TagsSetRequest) -> RpcResult<()> {
397        let blobs = self;
398        blobs
399            .store()
400            .set_tag(msg.name, msg.value)
401            .await
402            .map_err(|e| RpcError::new(&e))?;
403        if let SyncMode::Full = msg.sync {
404            blobs.store().sync().await.map_err(|e| RpcError::new(&e))?;
405        }
406        if let Some(batch) = msg.batch {
407            blobs
408                .batches()
409                .await
410                .remove_one(batch, &msg.value)
411                .map_err(|e| RpcError::new(&*e))?;
412        }
413        Ok(())
414    }
415
416    async fn tags_create(self, msg: TagsCreateRequest) -> RpcResult<Tag> {
417        let blobs = self;
418        let tag = blobs
419            .store()
420            .create_tag(msg.value)
421            .await
422            .map_err(|e| RpcError::new(&e))?;
423        if let SyncMode::Full = msg.sync {
424            blobs.store().sync().await.map_err(|e| RpcError::new(&e))?;
425        }
426        if let Some(batch) = msg.batch {
427            blobs
428                .batches()
429                .await
430                .remove_one(batch, &msg.value)
431                .map_err(|e| RpcError::new(&*e))?;
432        }
433        Ok(tag)
434    }
435
436    fn blob_download(self, msg: BlobDownloadRequest) -> impl Stream<Item = DownloadResponse> {
437        let (sender, receiver) = async_channel::bounded(1024);
438        let endpoint = self.endpoint().clone();
439        let progress = AsyncChannelProgressSender::new(sender);
440
441        let blobs_protocol = self.clone();
442
443        self.rt().spawn_detached(move || async move {
444            if let Err(err) = blobs_protocol
445                .download(endpoint, msg, progress.clone())
446                .await
447            {
448                progress
449                    .send(DownloadProgress::Abort(RpcError::new(&*err)))
450                    .await
451                    .ok();
452            }
453        });
454
455        receiver.map(DownloadResponse)
456    }
457
458    fn blob_export(self, msg: ExportRequest) -> impl Stream<Item = ExportResponse> {
459        let (tx, rx) = async_channel::bounded(1024);
460        let progress = AsyncChannelProgressSender::new(tx);
461        let rt = self.rt().clone();
462        rt.spawn_detached(move || async move {
463            let res = crate::export::export(
464                self.store(),
465                msg.hash,
466                msg.path,
467                msg.format,
468                msg.mode,
469                progress.clone(),
470            )
471            .await;
472            match res {
473                Ok(()) => progress.send(ExportProgress::AllDone).await.ok(),
474                Err(err) => progress
475                    .send(ExportProgress::Abort(RpcError::new(&*err)))
476                    .await
477                    .ok(),
478            };
479        });
480        rx.map(ExportResponse)
481    }
482
483    async fn blob_add_from_path0(
484        self,
485        msg: AddPathRequest,
486        progress: async_channel::Sender<AddProgress>,
487    ) -> anyhow::Result<()> {
488        use std::collections::BTreeMap;
489
490        use crate::store::ImportMode;
491
492        let blobs = self.clone();
493        let progress = AsyncChannelProgressSender::new(progress);
494        let names = Arc::new(Mutex::new(BTreeMap::new()));
495        // convert import progress to provide progress
496        let import_progress = progress.clone().with_filter_map(move |x| match x {
497            ImportProgress::Found { id, name } => {
498                names.lock().unwrap().insert(id, name);
499                None
500            }
501            ImportProgress::Size { id, size } => {
502                let name = names.lock().unwrap().remove(&id)?;
503                Some(AddProgress::Found { id, name, size })
504            }
505            ImportProgress::OutboardProgress { id, offset } => {
506                Some(AddProgress::Progress { id, offset })
507            }
508            ImportProgress::OutboardDone { hash, id } => Some(AddProgress::Done { hash, id }),
509            _ => None,
510        });
511        let AddPathRequest {
512            wrap,
513            path: root,
514            in_place,
515            tag,
516        } = msg;
517        // Check that the path is absolute and exists.
518        anyhow::ensure!(root.is_absolute(), "path must be absolute");
519        anyhow::ensure!(
520            root.exists(),
521            "trying to add missing path: {}",
522            root.display()
523        );
524
525        let import_mode = match in_place {
526            true => ImportMode::TryReference,
527            false => ImportMode::Copy,
528        };
529
530        let create_collection = match wrap {
531            WrapOption::Wrap { .. } => true,
532            WrapOption::NoWrap => root.is_dir(),
533        };
534
535        let temp_tag = if create_collection {
536            // import all files below root recursively
537            let data_sources = crate::util::fs::scan_path(root, wrap)?;
538            let blobs = self;
539
540            const IO_PARALLELISM: usize = 4;
541            let result: Vec<_> = futures_lite::stream::iter(data_sources)
542                .map(|source| {
543                    let import_progress = import_progress.clone();
544                    let blobs = blobs.clone();
545                    async move {
546                        let name = source.name().to_string();
547                        let (tag, size) = blobs
548                            .store()
549                            .import_file(
550                                source.path().to_owned(),
551                                import_mode,
552                                BlobFormat::Raw,
553                                import_progress,
554                            )
555                            .await?;
556                        let hash = *tag.hash();
557                        io::Result::Ok((name, hash, size, tag))
558                    }
559                })
560                .buffered_ordered(IO_PARALLELISM)
561                .try_collect()
562                .await?;
563
564            // create a collection
565            let (collection, _child_tags): (Collection, Vec<_>) = result
566                .into_iter()
567                .map(|(name, hash, _, tag)| ((name, hash), tag))
568                .unzip();
569
570            collection.store(blobs.store()).await?
571        } else {
572            // import a single file
573            let (tag, _size) = blobs
574                .store()
575                .import_file(root, import_mode, BlobFormat::Raw, import_progress)
576                .await?;
577            tag
578        };
579
580        let hash_and_format = temp_tag.inner();
581        let HashAndFormat { hash, format } = *hash_and_format;
582        let tag = match tag {
583            SetTagOption::Named(tag) => {
584                blobs.store().set_tag(tag.clone(), *hash_and_format).await?;
585                tag
586            }
587            SetTagOption::Auto => blobs.store().create_tag(*hash_and_format).await?,
588        };
589        progress
590            .send(AddProgress::AllDone {
591                hash,
592                format,
593                tag: tag.clone(),
594            })
595            .await?;
596        Ok(())
597    }
598
599    async fn batch_create_temp_tag(self, msg: BatchCreateTempTagRequest) -> RpcResult<()> {
600        let blobs = self;
601        let tag = blobs.store().temp_tag(msg.content);
602        blobs.batches().await.store(msg.batch, tag);
603        Ok(())
604    }
605
606    fn batch_add_stream(
607        self,
608        msg: BatchAddStreamRequest,
609        stream: impl Stream<Item = BatchAddStreamUpdate> + Send + Unpin + 'static,
610    ) -> impl Stream<Item = BatchAddStreamResponse> {
611        let (tx, rx) = async_channel::bounded(32);
612        let this = self.clone();
613
614        self.rt().spawn_detached(|| async move {
615            if let Err(err) = this.batch_add_stream0(msg, stream, tx.clone()).await {
616                tx.send(BatchAddStreamResponse::Abort(RpcError::new(&*err)))
617                    .await
618                    .ok();
619            }
620        });
621        rx
622    }
623
624    fn batch_add_from_path(
625        self,
626        msg: BatchAddPathRequest,
627    ) -> impl Stream<Item = BatchAddPathResponse> {
628        // provide a little buffer so that we don't slow down the sender
629        let (tx, rx) = async_channel::bounded(32);
630        let tx2 = tx.clone();
631        let this = self.clone();
632        self.rt().spawn_detached(|| async move {
633            if let Err(e) = this.batch_add_from_path0(msg, tx).await {
634                tx2.send(BatchAddPathProgress::Abort(RpcError::new(&*e)))
635                    .await
636                    .ok();
637            }
638        });
639        rx.map(BatchAddPathResponse)
640    }
641
642    async fn batch_add_stream0(
643        self,
644        msg: BatchAddStreamRequest,
645        stream: impl Stream<Item = BatchAddStreamUpdate> + Send + Unpin + 'static,
646        progress: async_channel::Sender<BatchAddStreamResponse>,
647    ) -> anyhow::Result<()> {
648        let blobs = self;
649        let progress = AsyncChannelProgressSender::new(progress);
650
651        let stream = stream.map(|item| match item {
652            BatchAddStreamUpdate::Chunk(chunk) => Ok(chunk),
653            BatchAddStreamUpdate::Abort => {
654                Err(io::Error::new(io::ErrorKind::Interrupted, "Remote abort"))
655            }
656        });
657
658        let import_progress = progress.clone().with_filter_map(move |x| match x {
659            ImportProgress::OutboardProgress { offset, .. } => {
660                Some(BatchAddStreamResponse::OutboardProgress { offset })
661            }
662            _ => None,
663        });
664        let (temp_tag, _len) = blobs
665            .store()
666            .import_stream(stream, msg.format, import_progress)
667            .await?;
668        let hash = temp_tag.inner().hash;
669        blobs.batches().await.store(msg.batch, temp_tag);
670        progress
671            .send(BatchAddStreamResponse::Result { hash })
672            .await?;
673        Ok(())
674    }
675
676    async fn batch_add_from_path0(
677        self,
678        msg: BatchAddPathRequest,
679        progress: async_channel::Sender<BatchAddPathProgress>,
680    ) -> anyhow::Result<()> {
681        let progress = AsyncChannelProgressSender::new(progress);
682        // convert import progress to provide progress
683        let import_progress = progress.clone().with_filter_map(move |x| match x {
684            ImportProgress::Size { size, .. } => Some(BatchAddPathProgress::Found { size }),
685            ImportProgress::OutboardProgress { offset, .. } => {
686                Some(BatchAddPathProgress::Progress { offset })
687            }
688            ImportProgress::OutboardDone { hash, .. } => Some(BatchAddPathProgress::Done { hash }),
689            _ => None,
690        });
691        let BatchAddPathRequest {
692            path: root,
693            import_mode,
694            format,
695            batch,
696        } = msg;
697        // Check that the path is absolute and exists.
698        anyhow::ensure!(root.is_absolute(), "path must be absolute");
699        anyhow::ensure!(
700            root.exists(),
701            "trying to add missing path: {}",
702            root.display()
703        );
704        let blobs = self;
705        let (tag, _) = blobs
706            .store()
707            .import_file(root, import_mode, format, import_progress)
708            .await?;
709        let hash = *tag.hash();
710        blobs.batches().await.store(batch, tag);
711
712        progress.send(BatchAddPathProgress::Done { hash }).await?;
713        Ok(())
714    }
715
716    fn blob_add_stream(
717        self,
718        msg: AddStreamRequest,
719        stream: impl Stream<Item = AddStreamUpdate> + Send + Unpin + 'static,
720    ) -> impl Stream<Item = AddStreamResponse> {
721        let (tx, rx) = async_channel::bounded(32);
722        let this = self.clone();
723
724        self.rt().spawn_detached(|| async move {
725            if let Err(err) = this.blob_add_stream0(msg, stream, tx.clone()).await {
726                tx.send(AddProgress::Abort(RpcError::new(&*err))).await.ok();
727            }
728        });
729
730        rx.map(AddStreamResponse)
731    }
732
733    async fn blob_add_stream0(
734        self,
735        msg: AddStreamRequest,
736        stream: impl Stream<Item = AddStreamUpdate> + Send + Unpin + 'static,
737        progress: async_channel::Sender<AddProgress>,
738    ) -> anyhow::Result<()> {
739        let progress = AsyncChannelProgressSender::new(progress);
740
741        let stream = stream.map(|item| match item {
742            AddStreamUpdate::Chunk(chunk) => Ok(chunk),
743            AddStreamUpdate::Abort => {
744                Err(io::Error::new(io::ErrorKind::Interrupted, "Remote abort"))
745            }
746        });
747
748        let name_cache = Arc::new(Mutex::new(None));
749        let import_progress = progress.clone().with_filter_map(move |x| match x {
750            ImportProgress::Found { id: _, name } => {
751                let _ = name_cache.lock().unwrap().insert(name);
752                None
753            }
754            ImportProgress::Size { id, size } => {
755                let name = name_cache.lock().unwrap().take()?;
756                Some(AddProgress::Found { id, name, size })
757            }
758            ImportProgress::OutboardProgress { id, offset } => {
759                Some(AddProgress::Progress { id, offset })
760            }
761            ImportProgress::OutboardDone { hash, id } => Some(AddProgress::Done { hash, id }),
762            _ => None,
763        });
764        let blobs = self;
765        let (temp_tag, _len) = blobs
766            .store()
767            .import_stream(stream, BlobFormat::Raw, import_progress)
768            .await?;
769        let hash_and_format = *temp_tag.inner();
770        let HashAndFormat { hash, format } = hash_and_format;
771        let tag = match msg.tag {
772            SetTagOption::Named(tag) => {
773                blobs.store().set_tag(tag.clone(), hash_and_format).await?;
774                tag
775            }
776            SetTagOption::Auto => blobs.store().create_tag(hash_and_format).await?,
777        };
778        progress
779            .send(AddProgress::AllDone { hash, tag, format })
780            .await?;
781        Ok(())
782    }
783
784    fn blob_read_at(
785        self,
786        req: ReadAtRequest,
787    ) -> impl Stream<Item = RpcResult<ReadAtResponse>> + Send + 'static {
788        let (tx, rx) = async_channel::bounded(RPC_BLOB_GET_CHANNEL_CAP);
789        let db = self.store().clone();
790        self.rt().spawn_detached(move || async move {
791            if let Err(err) = read_loop(req, db, tx.clone(), RPC_BLOB_GET_CHUNK_SIZE).await {
792                tx.send(RpcResult::Err(RpcError::new(&*err))).await.ok();
793            }
794        });
795
796        async fn read_loop<D: crate::store::Store>(
797            req: ReadAtRequest,
798            db: D,
799            tx: async_channel::Sender<RpcResult<ReadAtResponse>>,
800            max_chunk_size: usize,
801        ) -> anyhow::Result<()> {
802            let entry = db.get(&req.hash).await?;
803            let entry = entry.ok_or_else(|| anyhow!("Blob not found"))?;
804            let size = entry.size();
805
806            anyhow::ensure!(
807                req.offset <= size.value(),
808                "requested offset is out of range: {} > {:?}",
809                req.offset,
810                size
811            );
812
813            let len: usize = req
814                .len
815                .as_result_len(size.value() - req.offset)
816                .try_into()?;
817
818            anyhow::ensure!(
819                req.offset + len as u64 <= size.value(),
820                "requested range is out of bounds: offset: {}, len: {} > {:?}",
821                req.offset,
822                len,
823                size
824            );
825
826            tx.send(Ok(ReadAtResponse::Entry {
827                size,
828                is_complete: entry.is_complete(),
829            }))
830            .await?;
831            let mut reader = entry.data_reader().await?;
832
833            let (num_chunks, chunk_size) = if len <= max_chunk_size {
834                (1, len)
835            } else {
836                let num_chunks = len / max_chunk_size + (len % max_chunk_size != 0) as usize;
837                (num_chunks, max_chunk_size)
838            };
839
840            let mut read = 0u64;
841            for i in 0..num_chunks {
842                let chunk_size = if i == num_chunks - 1 {
843                    // last chunk might be smaller
844                    len - read as usize
845                } else {
846                    chunk_size
847                };
848                let chunk = reader.read_at(req.offset + read, chunk_size).await?;
849                let chunk_len = chunk.len();
850                if !chunk.is_empty() {
851                    tx.send(Ok(ReadAtResponse::Data { chunk })).await?;
852                }
853                if chunk_len < chunk_size {
854                    break;
855                } else {
856                    read += chunk_len as u64;
857                }
858            }
859            Ok(())
860        }
861
862        rx
863    }
864
865    fn batch_create(
866        self,
867        _: BatchCreateRequest,
868        mut updates: impl Stream<Item = BatchUpdate> + Send + Unpin + 'static,
869    ) -> impl Stream<Item = BatchCreateResponse> {
870        let blobs = self;
871        async move {
872            let batch = blobs.batches().await.create();
873            tokio::spawn(async move {
874                while let Some(item) = updates.next().await {
875                    match item {
876                        BatchUpdate::Drop(content) => {
877                            // this can not fail, since we keep the batch alive.
878                            // therefore it is safe to ignore the result.
879                            let _ = blobs.batches().await.remove_one(batch, &content);
880                        }
881                        BatchUpdate::Ping => {}
882                    }
883                }
884                blobs.batches().await.remove(batch);
885            });
886            BatchCreateResponse::Id(batch)
887        }
888        .into_stream()
889    }
890
891    async fn create_collection(
892        self,
893        req: CreateCollectionRequest,
894    ) -> RpcResult<CreateCollectionResponse> {
895        let CreateCollectionRequest {
896            collection,
897            tag,
898            tags_to_delete,
899        } = req;
900
901        let blobs = self;
902
903        let temp_tag = collection
904            .store(blobs.store())
905            .await
906            .map_err(|e| RpcError::new(&*e))?;
907        let hash_and_format = temp_tag.inner();
908        let HashAndFormat { hash, .. } = *hash_and_format;
909        let tag = match tag {
910            SetTagOption::Named(tag) => {
911                blobs
912                    .store()
913                    .set_tag(tag.clone(), *hash_and_format)
914                    .await
915                    .map_err(|e| RpcError::new(&e))?;
916                tag
917            }
918            SetTagOption::Auto => blobs
919                .store()
920                .create_tag(*hash_and_format)
921                .await
922                .map_err(|e| RpcError::new(&e))?,
923        };
924
925        for tag in tags_to_delete {
926            blobs
927                .store()
928                .delete_tags(Some(tag.clone()), Some(tag.successor()))
929                .await
930                .map_err(|e| RpcError::new(&e))?;
931        }
932
933        Ok(CreateCollectionResponse { hash, tag })
934    }
935
936    pub(crate) async fn download(
937        &self,
938        endpoint: Endpoint,
939        req: BlobDownloadRequest,
940        progress: AsyncChannelProgressSender<DownloadProgress>,
941    ) -> anyhow::Result<()> {
942        let BlobDownloadRequest {
943            hash,
944            format,
945            nodes,
946            tag,
947            mode,
948        } = req;
949        let hash_and_format = HashAndFormat { hash, format };
950        let temp_tag = self.store().temp_tag(hash_and_format);
951        let stats = match mode {
952            DownloadMode::Queued => {
953                self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
954                    .await?
955            }
956            DownloadMode::Direct => {
957                self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone())
958                    .await?
959            }
960        };
961
962        progress.send(DownloadProgress::AllDone(stats)).await.ok();
963        match tag {
964            SetTagOption::Named(tag) => {
965                self.store().set_tag(tag, hash_and_format).await?;
966            }
967            SetTagOption::Auto => {
968                self.store().create_tag(hash_and_format).await?;
969            }
970        }
971        drop(temp_tag);
972
973        Ok(())
974    }
975
976    async fn download_queued(
977        &self,
978        endpoint: Endpoint,
979        hash_and_format: HashAndFormat,
980        nodes: Vec<NodeAddr>,
981        progress: AsyncChannelProgressSender<DownloadProgress>,
982    ) -> anyhow::Result<Stats> {
983        /// Name used for logging when new node addresses are added from gossip.
984        const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download";
985
986        let mut node_ids = Vec::with_capacity(nodes.len());
987        let mut any_added = false;
988        for node in nodes {
989            node_ids.push(node.node_id);
990            if !node.is_empty() {
991                endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?;
992                any_added = true;
993            }
994        }
995        let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
996        anyhow::ensure!(can_download, "no way to reach a node for download");
997        let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
998        let handle = self.downloader().queue(req).await;
999        let stats = handle.await?;
1000        Ok(stats)
1001    }
1002
1003    #[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))]
1004    async fn download_direct_from_nodes(
1005        &self,
1006        endpoint: Endpoint,
1007        hash_and_format: HashAndFormat,
1008        nodes: Vec<NodeAddr>,
1009        progress: AsyncChannelProgressSender<DownloadProgress>,
1010    ) -> anyhow::Result<Stats> {
1011        let mut last_err = None;
1012        let mut remaining_nodes = nodes.len();
1013        let mut nodes_iter = nodes.into_iter();
1014        'outer: loop {
1015            match crate::get::db::get_to_db_in_steps(
1016                self.store().clone(),
1017                hash_and_format,
1018                progress.clone(),
1019            )
1020            .await?
1021            {
1022                GetState::Complete(stats) => return Ok(stats),
1023                GetState::NeedsConn(needs_conn) => {
1024                    let (conn, node_id) = 'inner: loop {
1025                        match nodes_iter.next() {
1026                            None => break 'outer,
1027                            Some(node) => {
1028                                remaining_nodes -= 1;
1029                                let node_id = node.node_id;
1030                                if node_id == endpoint.node_id() {
1031                                    debug!(
1032                                        ?remaining_nodes,
1033                                        "skip node {} (it is the node id of ourselves)",
1034                                        node_id.fmt_short()
1035                                    );
1036                                    continue 'inner;
1037                                }
1038                                match endpoint.connect(node, crate::protocol::ALPN).await {
1039                                    Ok(conn) => break 'inner (conn, node_id),
1040                                    Err(err) => {
1041                                        debug!(
1042                                            ?remaining_nodes,
1043                                            "failed to connect to {}: {err}",
1044                                            node_id.fmt_short()
1045                                        );
1046                                        continue 'inner;
1047                                    }
1048                                }
1049                            }
1050                        }
1051                    };
1052                    match needs_conn.proceed(conn).await {
1053                        Ok(stats) => return Ok(stats),
1054                        Err(err) => {
1055                            warn!(
1056                                ?remaining_nodes,
1057                                "failed to download from {}: {err}",
1058                                node_id.fmt_short()
1059                            );
1060                            last_err = Some(err);
1061                        }
1062                    }
1063                }
1064            }
1065        }
1066        match last_err {
1067            Some(err) => Err(err.into()),
1068            None => Err(anyhow!("No nodes to download from provided")),
1069        }
1070    }
1071}
1072
1073/// An in memory rpc handler for the blobs rpc protocol
1074///
1075/// This struct contains both a task that handles rpc requests and a client
1076/// that can be used to send rpc requests.
1077///
1078/// Dropping it will stop the handler task, so you need to put it somewhere
1079/// where it will be kept alive. This struct will capture a copy of
1080/// [`crate::net_protocol::Blobs`] and keep it alive.
1081#[derive(Debug)]
1082pub(crate) struct RpcHandler {
1083    /// Client to hand out
1084    client: MemClient,
1085    /// Handler task
1086    _handler: AbortOnDropHandle<()>,
1087}
1088
1089impl Deref for RpcHandler {
1090    type Target = MemClient;
1091
1092    fn deref(&self) -> &Self::Target {
1093        &self.client
1094    }
1095}
1096
1097impl RpcHandler {
1098    fn new<D: crate::store::Store>(blobs: &Arc<BlobsInner<D>>) -> Self {
1099        let blobs = blobs.clone();
1100        let (listener, connector) = quic_rpc::transport::flume::channel(1);
1101        let listener = RpcServer::new(listener);
1102        let client = RpcClient::new(connector);
1103        let client = MemClient::new(client);
1104        let _handler = listener.spawn_accept_loop(move |req, chan| {
1105            Handler(blobs.clone()).handle_rpc_request(req, chan)
1106        });
1107        Self { client, _handler }
1108    }
1109}