iroh_blobs/
rpc.rs

1//! Provides a rpc protocol as well as a client for the protocol
2
3use std::{
4    io,
5    ops::Deref,
6    sync::{Arc, Mutex},
7};
8
9use anyhow::anyhow;
10use client::{
11    blobs::{self, BlobInfo, BlobStatus, DownloadMode, IncompleteBlobInfo, MemClient, WrapOption},
12    tags::TagInfo,
13    MemConnector,
14};
15use futures_buffered::BufferedStreamExt;
16use futures_lite::StreamExt;
17use futures_util::{FutureExt, Stream};
18use genawaiter::sync::{Co, Gen};
19use iroh::{Endpoint, NodeAddr};
20use iroh_io::AsyncSliceReader;
21use proto::{
22    blobs::{
23        AddPathRequest, AddPathResponse, AddStreamRequest, AddStreamResponse, AddStreamUpdate,
24        BatchAddPathRequest, BatchAddPathResponse, BatchAddStreamRequest, BatchAddStreamResponse,
25        BatchAddStreamUpdate, BatchCreateRequest, BatchCreateResponse, BatchCreateTempTagRequest,
26        BatchUpdate, BlobStatusRequest, BlobStatusResponse, ConsistencyCheckRequest,
27        CreateCollectionRequest, CreateCollectionResponse, DeleteRequest, DownloadResponse,
28        ExportRequest, ExportResponse, ListIncompleteRequest, ListRequest, ReadAtRequest,
29        ReadAtResponse, ValidateRequest,
30    },
31    tags::{
32        CreateRequest as TagsCreateRequest, DeleteRequest as TagDeleteRequest,
33        ListRequest as TagListRequest, SetRequest as TagsSetRequest, SyncMode,
34    },
35    Request, RpcError, RpcResult, RpcService,
36};
37use quic_rpc::{
38    server::{ChannelTypes, RpcChannel, RpcServerError},
39    RpcClient, RpcServer,
40};
41use tokio_util::task::AbortOnDropHandle;
42use tracing::{debug, warn};
43
44use crate::{
45    downloader::{DownloadRequest, Downloader},
46    export::ExportProgress,
47    format::collection::Collection,
48    get::{
49        db::{DownloadProgress, GetState},
50        Stats,
51    },
52    net_protocol::{BlobDownloadRequest, Blobs, BlobsInner},
53    provider::{AddProgress, BatchAddPathProgress},
54    store::{ConsistencyCheckProgress, ImportProgress, MapEntry, ValidateProgress},
55    util::{
56        local_pool::LocalPoolHandle,
57        progress::{AsyncChannelProgressSender, ProgressSender},
58        SetTagOption,
59    },
60    BlobFormat, HashAndFormat, Tag,
61};
62pub mod client;
63pub mod proto;
64
/// Upper bound, in bytes, for a single data chunk sent while reading a blob over RPC (64 KiB).
const RPC_BLOB_GET_CHUNK_SIZE: usize = 64 * 1024;
/// Capacity of the bounded channel that carries read responses for a blob read over RPC.
const RPC_BLOB_GET_CHANNEL_CAP: usize = 2;
69
70impl<D: crate::store::Store> Blobs<D> {
71    /// Get a client for the blobs protocol
72    pub fn client(&self) -> &blobs::MemClient {
73        &self
74            .rpc_handler
75            .get_or_init(|| RpcHandler::new(&self.inner))
76            .client
77    }
78
79    /// Handle an RPC request
80    pub async fn handle_rpc_request<C>(
81        self,
82        msg: Request,
83        chan: RpcChannel<RpcService, C>,
84    ) -> std::result::Result<(), RpcServerError<C>>
85    where
86        C: ChannelTypes<RpcService>,
87    {
88        Handler(self.inner.clone())
89            .handle_rpc_request(msg, chan)
90            .await
91    }
92}
93
94/// This is just an internal helper so I don't have to
95/// define all the rpc methods on `self: Arc<BlobsInner<S>>`
96#[derive(Clone)]
97struct Handler<S>(Arc<BlobsInner<S>>);
98
99impl<S> Deref for Handler<S> {
100    type Target = BlobsInner<S>;
101
102    fn deref(&self) -> &Self::Target {
103        &self.0
104    }
105}
106
107impl<D: crate::store::Store> Handler<D> {
108    fn store(&self) -> &D {
109        &self.0.store
110    }
111
112    fn rt(&self) -> &LocalPoolHandle {
113        self.0.rt()
114    }
115
116    fn endpoint(&self) -> &Endpoint {
117        &self.0.endpoint
118    }
119
120    fn downloader(&self) -> &Downloader {
121        &self.0.downloader
122    }
123
124    #[cfg(feature = "rpc")]
125    pub(crate) async fn batches(
126        &self,
127    ) -> tokio::sync::MutexGuard<'_, crate::net_protocol::BlobBatches> {
128        self.0.batches.lock().await
129    }
130
131    /// Handle an RPC request
132    pub async fn handle_rpc_request<C>(
133        self,
134        msg: Request,
135        chan: RpcChannel<RpcService, C>,
136    ) -> std::result::Result<(), RpcServerError<C>>
137    where
138        C: ChannelTypes<RpcService>,
139    {
140        match msg {
141            Request::Blobs(msg) => self.handle_blobs_request(msg, chan).await,
142            Request::Tags(msg) => self.handle_tags_request(msg, chan).await,
143        }
144    }
145
146    /// Handle a tags request
147    pub async fn handle_tags_request<C>(
148        self,
149        msg: proto::tags::Request,
150        chan: RpcChannel<proto::RpcService, C>,
151    ) -> std::result::Result<(), RpcServerError<C>>
152    where
153        C: ChannelTypes<proto::RpcService>,
154    {
155        use proto::tags::Request::*;
156        match msg {
157            Create(msg) => chan.rpc(msg, self, Self::tags_create).await,
158            Set(msg) => chan.rpc(msg, self, Self::tags_set).await,
159            DeleteTag(msg) => chan.rpc(msg, self, Self::blob_delete_tag).await,
160            ListTags(msg) => chan.server_streaming(msg, self, Self::blob_list_tags).await,
161        }
162    }
163
164    /// Handle a blobs request
165    pub async fn handle_blobs_request<C>(
166        self,
167        msg: proto::blobs::Request,
168        chan: RpcChannel<proto::RpcService, C>,
169    ) -> std::result::Result<(), RpcServerError<C>>
170    where
171        C: ChannelTypes<proto::RpcService>,
172    {
173        use proto::blobs::Request::*;
174        match msg {
175            List(msg) => chan.server_streaming(msg, self, Self::blob_list).await,
176            ListIncomplete(msg) => {
177                chan.server_streaming(msg, self, Self::blob_list_incomplete)
178                    .await
179            }
180            CreateCollection(msg) => chan.rpc(msg, self, Self::create_collection).await,
181            Delete(msg) => chan.rpc(msg, self, Self::blob_delete_blob).await,
182            AddPath(msg) => {
183                chan.server_streaming(msg, self, Self::blob_add_from_path)
184                    .await
185            }
186            Download(msg) => chan.server_streaming(msg, self, Self::blob_download).await,
187            Export(msg) => chan.server_streaming(msg, self, Self::blob_export).await,
188            Validate(msg) => chan.server_streaming(msg, self, Self::blob_validate).await,
189            Fsck(msg) => {
190                chan.server_streaming(msg, self, Self::blob_consistency_check)
191                    .await
192            }
193            ReadAt(msg) => chan.server_streaming(msg, self, Self::blob_read_at).await,
194            AddStream(msg) => chan.bidi_streaming(msg, self, Self::blob_add_stream).await,
195            AddStreamUpdate(_msg) => Err(RpcServerError::UnexpectedUpdateMessage),
196            BlobStatus(msg) => chan.rpc(msg, self, Self::blob_status).await,
197            BatchCreate(msg) => chan.bidi_streaming(msg, self, Self::batch_create).await,
198            BatchUpdate(_) => Err(RpcServerError::UnexpectedStartMessage),
199            BatchAddStream(msg) => chan.bidi_streaming(msg, self, Self::batch_add_stream).await,
200            BatchAddStreamUpdate(_) => Err(RpcServerError::UnexpectedStartMessage),
201            BatchAddPath(msg) => {
202                chan.server_streaming(msg, self, Self::batch_add_from_path)
203                    .await
204            }
205            BatchCreateTempTag(msg) => chan.rpc(msg, self, Self::batch_create_temp_tag).await,
206        }
207    }
208
209    async fn blob_status(self, msg: BlobStatusRequest) -> RpcResult<BlobStatusResponse> {
210        let blobs = self;
211        let entry = blobs
212            .store()
213            .get(&msg.hash)
214            .await
215            .map_err(|e| RpcError::new(&e))?;
216        Ok(BlobStatusResponse(match entry {
217            Some(entry) => {
218                if entry.is_complete() {
219                    BlobStatus::Complete {
220                        size: entry.size().value(),
221                    }
222                } else {
223                    BlobStatus::Partial { size: entry.size() }
224                }
225            }
226            None => BlobStatus::NotFound,
227        }))
228    }
229
230    async fn blob_list_impl(self, co: &Co<RpcResult<BlobInfo>>) -> io::Result<()> {
231        use bao_tree::io::fsm::Outboard;
232
233        let blobs = self;
234        let db = blobs.store();
235        for blob in db.blobs().await? {
236            let blob = blob?;
237            let Some(entry) = db.get(&blob).await? else {
238                continue;
239            };
240            let hash = entry.hash();
241            let size = entry.outboard().await?.tree().size();
242            let path = "".to_owned();
243            co.yield_(Ok(BlobInfo { hash, size, path })).await;
244        }
245        Ok(())
246    }
247
248    async fn blob_list_incomplete_impl(
249        self,
250        co: &Co<RpcResult<IncompleteBlobInfo>>,
251    ) -> io::Result<()> {
252        let blobs = self;
253        let db = blobs.store();
254        for hash in db.partial_blobs().await? {
255            let hash = hash?;
256            let Ok(Some(entry)) = db.get_mut(&hash).await else {
257                continue;
258            };
259            if entry.is_complete() {
260                continue;
261            }
262            let size = 0;
263            let expected_size = entry.size().value();
264            co.yield_(Ok(IncompleteBlobInfo {
265                hash,
266                size,
267                expected_size,
268            }))
269            .await;
270        }
271        Ok(())
272    }
273
274    fn blob_list(
275        self,
276        _msg: ListRequest,
277    ) -> impl Stream<Item = RpcResult<BlobInfo>> + Send + 'static {
278        Gen::new(|co| async move {
279            if let Err(e) = self.blob_list_impl(&co).await {
280                co.yield_(Err(RpcError::new(&e))).await;
281            }
282        })
283    }
284
285    fn blob_list_incomplete(
286        self,
287        _msg: ListIncompleteRequest,
288    ) -> impl Stream<Item = RpcResult<IncompleteBlobInfo>> + Send + 'static {
289        Gen::new(move |co| async move {
290            if let Err(e) = self.blob_list_incomplete_impl(&co).await {
291                co.yield_(Err(RpcError::new(&e))).await;
292            }
293        })
294    }
295
296    async fn blob_delete_tag(self, msg: TagDeleteRequest) -> RpcResult<()> {
297        self.store()
298            .set_tag(msg.name, None)
299            .await
300            .map_err(|e| RpcError::new(&e))?;
301        Ok(())
302    }
303
304    async fn blob_delete_blob(self, msg: DeleteRequest) -> RpcResult<()> {
305        self.store()
306            .delete(vec![msg.hash])
307            .await
308            .map_err(|e| RpcError::new(&e))?;
309        Ok(())
310    }
311
312    fn blob_list_tags(self, msg: TagListRequest) -> impl Stream<Item = TagInfo> + Send + 'static {
313        tracing::info!("blob_list_tags");
314        let blobs = self;
315        Gen::new(|co| async move {
316            let tags = blobs.store().tags().await.unwrap();
317            #[allow(clippy::manual_flatten)]
318            for item in tags {
319                if let Ok((name, HashAndFormat { hash, format })) = item {
320                    if (format.is_raw() && msg.raw) || (format.is_hash_seq() && msg.hash_seq) {
321                        co.yield_(TagInfo { name, hash, format }).await;
322                    }
323                }
324            }
325        })
326    }
327
328    /// Invoke validate on the database and stream out the result
329    fn blob_validate(
330        self,
331        msg: ValidateRequest,
332    ) -> impl Stream<Item = ValidateProgress> + Send + 'static {
333        let (tx, rx) = async_channel::bounded(1);
334        let tx2 = tx.clone();
335        let blobs = self;
336        tokio::task::spawn(async move {
337            if let Err(e) = blobs
338                .store()
339                .validate(msg.repair, AsyncChannelProgressSender::new(tx).boxed())
340                .await
341            {
342                tx2.send(ValidateProgress::Abort(RpcError::new(&e)))
343                    .await
344                    .ok();
345            }
346        });
347        rx
348    }
349
350    /// Invoke validate on the database and stream out the result
351    fn blob_consistency_check(
352        self,
353        msg: ConsistencyCheckRequest,
354    ) -> impl Stream<Item = ConsistencyCheckProgress> + Send + 'static {
355        let (tx, rx) = async_channel::bounded(1);
356        let tx2 = tx.clone();
357        let blobs = self;
358        tokio::task::spawn(async move {
359            if let Err(e) = blobs
360                .store()
361                .consistency_check(msg.repair, AsyncChannelProgressSender::new(tx).boxed())
362                .await
363            {
364                tx2.send(ConsistencyCheckProgress::Abort(RpcError::new(&e)))
365                    .await
366                    .ok();
367            }
368        });
369        rx
370    }
371
372    fn blob_add_from_path(self, msg: AddPathRequest) -> impl Stream<Item = AddPathResponse> {
373        // provide a little buffer so that we don't slow down the sender
374        let (tx, rx) = async_channel::bounded(32);
375        let tx2 = tx.clone();
376        let rt = self.rt().clone();
377        rt.spawn_detached(|| async move {
378            if let Err(e) = self.blob_add_from_path0(msg, tx).await {
379                tx2.send(AddProgress::Abort(RpcError::new(&*e))).await.ok();
380            }
381        });
382        rx.map(AddPathResponse)
383    }
384
385    async fn tags_set(self, msg: TagsSetRequest) -> RpcResult<()> {
386        let blobs = self;
387        blobs
388            .store()
389            .set_tag(msg.name, msg.value)
390            .await
391            .map_err(|e| RpcError::new(&e))?;
392        if let SyncMode::Full = msg.sync {
393            blobs.store().sync().await.map_err(|e| RpcError::new(&e))?;
394        }
395        if let Some(batch) = msg.batch {
396            if let Some(content) = msg.value.as_ref() {
397                blobs
398                    .batches()
399                    .await
400                    .remove_one(batch, content)
401                    .map_err(|e| RpcError::new(&*e))?;
402            }
403        }
404        Ok(())
405    }
406
407    async fn tags_create(self, msg: TagsCreateRequest) -> RpcResult<Tag> {
408        let blobs = self;
409        let tag = blobs
410            .store()
411            .create_tag(msg.value)
412            .await
413            .map_err(|e| RpcError::new(&e))?;
414        if let SyncMode::Full = msg.sync {
415            blobs.store().sync().await.map_err(|e| RpcError::new(&e))?;
416        }
417        if let Some(batch) = msg.batch {
418            blobs
419                .batches()
420                .await
421                .remove_one(batch, &msg.value)
422                .map_err(|e| RpcError::new(&*e))?;
423        }
424        Ok(tag)
425    }
426
427    fn blob_download(self, msg: BlobDownloadRequest) -> impl Stream<Item = DownloadResponse> {
428        let (sender, receiver) = async_channel::bounded(1024);
429        let endpoint = self.endpoint().clone();
430        let progress = AsyncChannelProgressSender::new(sender);
431
432        let blobs_protocol = self.clone();
433
434        self.rt().spawn_detached(move || async move {
435            if let Err(err) = blobs_protocol
436                .download(endpoint, msg, progress.clone())
437                .await
438            {
439                progress
440                    .send(DownloadProgress::Abort(RpcError::new(&*err)))
441                    .await
442                    .ok();
443            }
444        });
445
446        receiver.map(DownloadResponse)
447    }
448
449    fn blob_export(self, msg: ExportRequest) -> impl Stream<Item = ExportResponse> {
450        let (tx, rx) = async_channel::bounded(1024);
451        let progress = AsyncChannelProgressSender::new(tx);
452        let rt = self.rt().clone();
453        rt.spawn_detached(move || async move {
454            let res = crate::export::export(
455                self.store(),
456                msg.hash,
457                msg.path,
458                msg.format,
459                msg.mode,
460                progress.clone(),
461            )
462            .await;
463            match res {
464                Ok(()) => progress.send(ExportProgress::AllDone).await.ok(),
465                Err(err) => progress
466                    .send(ExportProgress::Abort(RpcError::new(&*err)))
467                    .await
468                    .ok(),
469            };
470        });
471        rx.map(ExportResponse)
472    }
473
474    async fn blob_add_from_path0(
475        self,
476        msg: AddPathRequest,
477        progress: async_channel::Sender<AddProgress>,
478    ) -> anyhow::Result<()> {
479        use std::collections::BTreeMap;
480
481        use crate::store::ImportMode;
482
483        let blobs = self.clone();
484        let progress = AsyncChannelProgressSender::new(progress);
485        let names = Arc::new(Mutex::new(BTreeMap::new()));
486        // convert import progress to provide progress
487        let import_progress = progress.clone().with_filter_map(move |x| match x {
488            ImportProgress::Found { id, name } => {
489                names.lock().unwrap().insert(id, name);
490                None
491            }
492            ImportProgress::Size { id, size } => {
493                let name = names.lock().unwrap().remove(&id)?;
494                Some(AddProgress::Found { id, name, size })
495            }
496            ImportProgress::OutboardProgress { id, offset } => {
497                Some(AddProgress::Progress { id, offset })
498            }
499            ImportProgress::OutboardDone { hash, id } => Some(AddProgress::Done { hash, id }),
500            _ => None,
501        });
502        let AddPathRequest {
503            wrap,
504            path: root,
505            in_place,
506            tag,
507        } = msg;
508        // Check that the path is absolute and exists.
509        anyhow::ensure!(root.is_absolute(), "path must be absolute");
510        anyhow::ensure!(
511            root.exists(),
512            "trying to add missing path: {}",
513            root.display()
514        );
515
516        let import_mode = match in_place {
517            true => ImportMode::TryReference,
518            false => ImportMode::Copy,
519        };
520
521        let create_collection = match wrap {
522            WrapOption::Wrap { .. } => true,
523            WrapOption::NoWrap => root.is_dir(),
524        };
525
526        let temp_tag = if create_collection {
527            // import all files below root recursively
528            let data_sources = crate::util::fs::scan_path(root, wrap)?;
529            let blobs = self;
530
531            const IO_PARALLELISM: usize = 4;
532            let result: Vec<_> = futures_lite::stream::iter(data_sources)
533                .map(|source| {
534                    let import_progress = import_progress.clone();
535                    let blobs = blobs.clone();
536                    async move {
537                        let name = source.name().to_string();
538                        let (tag, size) = blobs
539                            .store()
540                            .import_file(
541                                source.path().to_owned(),
542                                import_mode,
543                                BlobFormat::Raw,
544                                import_progress,
545                            )
546                            .await?;
547                        let hash = *tag.hash();
548                        io::Result::Ok((name, hash, size, tag))
549                    }
550                })
551                .buffered_ordered(IO_PARALLELISM)
552                .try_collect()
553                .await?;
554
555            // create a collection
556            let (collection, _child_tags): (Collection, Vec<_>) = result
557                .into_iter()
558                .map(|(name, hash, _, tag)| ((name, hash), tag))
559                .unzip();
560
561            collection.store(blobs.store()).await?
562        } else {
563            // import a single file
564            let (tag, _size) = blobs
565                .store()
566                .import_file(root, import_mode, BlobFormat::Raw, import_progress)
567                .await?;
568            tag
569        };
570
571        let hash_and_format = temp_tag.inner();
572        let HashAndFormat { hash, format } = *hash_and_format;
573        let tag = match tag {
574            SetTagOption::Named(tag) => {
575                blobs
576                    .store()
577                    .set_tag(tag.clone(), Some(*hash_and_format))
578                    .await?;
579                tag
580            }
581            SetTagOption::Auto => blobs.store().create_tag(*hash_and_format).await?,
582        };
583        progress
584            .send(AddProgress::AllDone {
585                hash,
586                format,
587                tag: tag.clone(),
588            })
589            .await?;
590        Ok(())
591    }
592
593    async fn batch_create_temp_tag(self, msg: BatchCreateTempTagRequest) -> RpcResult<()> {
594        let blobs = self;
595        let tag = blobs.store().temp_tag(msg.content);
596        blobs.batches().await.store(msg.batch, tag);
597        Ok(())
598    }
599
600    fn batch_add_stream(
601        self,
602        msg: BatchAddStreamRequest,
603        stream: impl Stream<Item = BatchAddStreamUpdate> + Send + Unpin + 'static,
604    ) -> impl Stream<Item = BatchAddStreamResponse> {
605        let (tx, rx) = async_channel::bounded(32);
606        let this = self.clone();
607
608        self.rt().spawn_detached(|| async move {
609            if let Err(err) = this.batch_add_stream0(msg, stream, tx.clone()).await {
610                tx.send(BatchAddStreamResponse::Abort(RpcError::new(&*err)))
611                    .await
612                    .ok();
613            }
614        });
615        rx
616    }
617
618    fn batch_add_from_path(
619        self,
620        msg: BatchAddPathRequest,
621    ) -> impl Stream<Item = BatchAddPathResponse> {
622        // provide a little buffer so that we don't slow down the sender
623        let (tx, rx) = async_channel::bounded(32);
624        let tx2 = tx.clone();
625        let this = self.clone();
626        self.rt().spawn_detached(|| async move {
627            if let Err(e) = this.batch_add_from_path0(msg, tx).await {
628                tx2.send(BatchAddPathProgress::Abort(RpcError::new(&*e)))
629                    .await
630                    .ok();
631            }
632        });
633        rx.map(BatchAddPathResponse)
634    }
635
636    async fn batch_add_stream0(
637        self,
638        msg: BatchAddStreamRequest,
639        stream: impl Stream<Item = BatchAddStreamUpdate> + Send + Unpin + 'static,
640        progress: async_channel::Sender<BatchAddStreamResponse>,
641    ) -> anyhow::Result<()> {
642        let blobs = self;
643        let progress = AsyncChannelProgressSender::new(progress);
644
645        let stream = stream.map(|item| match item {
646            BatchAddStreamUpdate::Chunk(chunk) => Ok(chunk),
647            BatchAddStreamUpdate::Abort => {
648                Err(io::Error::new(io::ErrorKind::Interrupted, "Remote abort"))
649            }
650        });
651
652        let import_progress = progress.clone().with_filter_map(move |x| match x {
653            ImportProgress::OutboardProgress { offset, .. } => {
654                Some(BatchAddStreamResponse::OutboardProgress { offset })
655            }
656            _ => None,
657        });
658        let (temp_tag, _len) = blobs
659            .store()
660            .import_stream(stream, msg.format, import_progress)
661            .await?;
662        let hash = temp_tag.inner().hash;
663        blobs.batches().await.store(msg.batch, temp_tag);
664        progress
665            .send(BatchAddStreamResponse::Result { hash })
666            .await?;
667        Ok(())
668    }
669
670    async fn batch_add_from_path0(
671        self,
672        msg: BatchAddPathRequest,
673        progress: async_channel::Sender<BatchAddPathProgress>,
674    ) -> anyhow::Result<()> {
675        let progress = AsyncChannelProgressSender::new(progress);
676        // convert import progress to provide progress
677        let import_progress = progress.clone().with_filter_map(move |x| match x {
678            ImportProgress::Size { size, .. } => Some(BatchAddPathProgress::Found { size }),
679            ImportProgress::OutboardProgress { offset, .. } => {
680                Some(BatchAddPathProgress::Progress { offset })
681            }
682            ImportProgress::OutboardDone { hash, .. } => Some(BatchAddPathProgress::Done { hash }),
683            _ => None,
684        });
685        let BatchAddPathRequest {
686            path: root,
687            import_mode,
688            format,
689            batch,
690        } = msg;
691        // Check that the path is absolute and exists.
692        anyhow::ensure!(root.is_absolute(), "path must be absolute");
693        anyhow::ensure!(
694            root.exists(),
695            "trying to add missing path: {}",
696            root.display()
697        );
698        let blobs = self;
699        let (tag, _) = blobs
700            .store()
701            .import_file(root, import_mode, format, import_progress)
702            .await?;
703        let hash = *tag.hash();
704        blobs.batches().await.store(batch, tag);
705
706        progress.send(BatchAddPathProgress::Done { hash }).await?;
707        Ok(())
708    }
709
710    fn blob_add_stream(
711        self,
712        msg: AddStreamRequest,
713        stream: impl Stream<Item = AddStreamUpdate> + Send + Unpin + 'static,
714    ) -> impl Stream<Item = AddStreamResponse> {
715        let (tx, rx) = async_channel::bounded(32);
716        let this = self.clone();
717
718        self.rt().spawn_detached(|| async move {
719            if let Err(err) = this.blob_add_stream0(msg, stream, tx.clone()).await {
720                tx.send(AddProgress::Abort(RpcError::new(&*err))).await.ok();
721            }
722        });
723
724        rx.map(AddStreamResponse)
725    }
726
727    async fn blob_add_stream0(
728        self,
729        msg: AddStreamRequest,
730        stream: impl Stream<Item = AddStreamUpdate> + Send + Unpin + 'static,
731        progress: async_channel::Sender<AddProgress>,
732    ) -> anyhow::Result<()> {
733        let progress = AsyncChannelProgressSender::new(progress);
734
735        let stream = stream.map(|item| match item {
736            AddStreamUpdate::Chunk(chunk) => Ok(chunk),
737            AddStreamUpdate::Abort => {
738                Err(io::Error::new(io::ErrorKind::Interrupted, "Remote abort"))
739            }
740        });
741
742        let name_cache = Arc::new(Mutex::new(None));
743        let import_progress = progress.clone().with_filter_map(move |x| match x {
744            ImportProgress::Found { id: _, name } => {
745                let _ = name_cache.lock().unwrap().insert(name);
746                None
747            }
748            ImportProgress::Size { id, size } => {
749                let name = name_cache.lock().unwrap().take()?;
750                Some(AddProgress::Found { id, name, size })
751            }
752            ImportProgress::OutboardProgress { id, offset } => {
753                Some(AddProgress::Progress { id, offset })
754            }
755            ImportProgress::OutboardDone { hash, id } => Some(AddProgress::Done { hash, id }),
756            _ => None,
757        });
758        let blobs = self;
759        let (temp_tag, _len) = blobs
760            .store()
761            .import_stream(stream, BlobFormat::Raw, import_progress)
762            .await?;
763        let hash_and_format = *temp_tag.inner();
764        let HashAndFormat { hash, format } = hash_and_format;
765        let tag = match msg.tag {
766            SetTagOption::Named(tag) => {
767                blobs
768                    .store()
769                    .set_tag(tag.clone(), Some(hash_and_format))
770                    .await?;
771                tag
772            }
773            SetTagOption::Auto => blobs.store().create_tag(hash_and_format).await?,
774        };
775        progress
776            .send(AddProgress::AllDone { hash, tag, format })
777            .await?;
778        Ok(())
779    }
780
781    fn blob_read_at(
782        self,
783        req: ReadAtRequest,
784    ) -> impl Stream<Item = RpcResult<ReadAtResponse>> + Send + 'static {
785        let (tx, rx) = async_channel::bounded(RPC_BLOB_GET_CHANNEL_CAP);
786        let db = self.store().clone();
787        self.rt().spawn_detached(move || async move {
788            if let Err(err) = read_loop(req, db, tx.clone(), RPC_BLOB_GET_CHUNK_SIZE).await {
789                tx.send(RpcResult::Err(RpcError::new(&*err))).await.ok();
790            }
791        });
792
793        async fn read_loop<D: crate::store::Store>(
794            req: ReadAtRequest,
795            db: D,
796            tx: async_channel::Sender<RpcResult<ReadAtResponse>>,
797            max_chunk_size: usize,
798        ) -> anyhow::Result<()> {
799            let entry = db.get(&req.hash).await?;
800            let entry = entry.ok_or_else(|| anyhow!("Blob not found"))?;
801            let size = entry.size();
802
803            anyhow::ensure!(
804                req.offset <= size.value(),
805                "requested offset is out of range: {} > {:?}",
806                req.offset,
807                size
808            );
809
810            let len: usize = req
811                .len
812                .as_result_len(size.value() - req.offset)
813                .try_into()?;
814
815            anyhow::ensure!(
816                req.offset + len as u64 <= size.value(),
817                "requested range is out of bounds: offset: {}, len: {} > {:?}",
818                req.offset,
819                len,
820                size
821            );
822
823            tx.send(Ok(ReadAtResponse::Entry {
824                size,
825                is_complete: entry.is_complete(),
826            }))
827            .await?;
828            let mut reader = entry.data_reader().await?;
829
830            let (num_chunks, chunk_size) = if len <= max_chunk_size {
831                (1, len)
832            } else {
833                let num_chunks = len / max_chunk_size + (len % max_chunk_size != 0) as usize;
834                (num_chunks, max_chunk_size)
835            };
836
837            let mut read = 0u64;
838            for i in 0..num_chunks {
839                let chunk_size = if i == num_chunks - 1 {
840                    // last chunk might be smaller
841                    len - read as usize
842                } else {
843                    chunk_size
844                };
845                let chunk = reader.read_at(req.offset + read, chunk_size).await?;
846                let chunk_len = chunk.len();
847                if !chunk.is_empty() {
848                    tx.send(Ok(ReadAtResponse::Data { chunk })).await?;
849                }
850                if chunk_len < chunk_size {
851                    break;
852                } else {
853                    read += chunk_len as u64;
854                }
855            }
856            Ok(())
857        }
858
859        rx
860    }
861
862    fn batch_create(
863        self,
864        _: BatchCreateRequest,
865        mut updates: impl Stream<Item = BatchUpdate> + Send + Unpin + 'static,
866    ) -> impl Stream<Item = BatchCreateResponse> {
867        let blobs = self;
868        async move {
869            let batch = blobs.batches().await.create();
870            tokio::spawn(async move {
871                while let Some(item) = updates.next().await {
872                    match item {
873                        BatchUpdate::Drop(content) => {
874                            // this can not fail, since we keep the batch alive.
875                            // therefore it is safe to ignore the result.
876                            let _ = blobs.batches().await.remove_one(batch, &content);
877                        }
878                        BatchUpdate::Ping => {}
879                    }
880                }
881                blobs.batches().await.remove(batch);
882            });
883            BatchCreateResponse::Id(batch)
884        }
885        .into_stream()
886    }
887
888    async fn create_collection(
889        self,
890        req: CreateCollectionRequest,
891    ) -> RpcResult<CreateCollectionResponse> {
892        let CreateCollectionRequest {
893            collection,
894            tag,
895            tags_to_delete,
896        } = req;
897
898        let blobs = self;
899
900        let temp_tag = collection
901            .store(blobs.store())
902            .await
903            .map_err(|e| RpcError::new(&*e))?;
904        let hash_and_format = temp_tag.inner();
905        let HashAndFormat { hash, .. } = *hash_and_format;
906        let tag = match tag {
907            SetTagOption::Named(tag) => {
908                blobs
909                    .store()
910                    .set_tag(tag.clone(), Some(*hash_and_format))
911                    .await
912                    .map_err(|e| RpcError::new(&e))?;
913                tag
914            }
915            SetTagOption::Auto => blobs
916                .store()
917                .create_tag(*hash_and_format)
918                .await
919                .map_err(|e| RpcError::new(&e))?,
920        };
921
922        for tag in tags_to_delete {
923            blobs
924                .store()
925                .set_tag(tag, None)
926                .await
927                .map_err(|e| RpcError::new(&e))?;
928        }
929
930        Ok(CreateCollectionResponse { hash, tag })
931    }
932
933    pub(crate) async fn download(
934        &self,
935        endpoint: Endpoint,
936        req: BlobDownloadRequest,
937        progress: AsyncChannelProgressSender<DownloadProgress>,
938    ) -> anyhow::Result<()> {
939        let BlobDownloadRequest {
940            hash,
941            format,
942            nodes,
943            tag,
944            mode,
945        } = req;
946        let hash_and_format = HashAndFormat { hash, format };
947        let temp_tag = self.store().temp_tag(hash_and_format);
948        let stats = match mode {
949            DownloadMode::Queued => {
950                self.download_queued(endpoint, hash_and_format, nodes, progress.clone())
951                    .await?
952            }
953            DownloadMode::Direct => {
954                self.download_direct_from_nodes(endpoint, hash_and_format, nodes, progress.clone())
955                    .await?
956            }
957        };
958
959        progress.send(DownloadProgress::AllDone(stats)).await.ok();
960        match tag {
961            SetTagOption::Named(tag) => {
962                self.store().set_tag(tag, Some(hash_and_format)).await?;
963            }
964            SetTagOption::Auto => {
965                self.store().create_tag(hash_and_format).await?;
966            }
967        }
968        drop(temp_tag);
969
970        Ok(())
971    }
972
973    async fn download_queued(
974        &self,
975        endpoint: Endpoint,
976        hash_and_format: HashAndFormat,
977        nodes: Vec<NodeAddr>,
978        progress: AsyncChannelProgressSender<DownloadProgress>,
979    ) -> anyhow::Result<Stats> {
980        /// Name used for logging when new node addresses are added from gossip.
981        const BLOB_DOWNLOAD_SOURCE_NAME: &str = "blob_download";
982
983        let mut node_ids = Vec::with_capacity(nodes.len());
984        let mut any_added = false;
985        for node in nodes {
986            node_ids.push(node.node_id);
987            if !node.is_empty() {
988                endpoint.add_node_addr_with_source(node, BLOB_DOWNLOAD_SOURCE_NAME)?;
989                any_added = true;
990            }
991        }
992        let can_download = !node_ids.is_empty() && (any_added || endpoint.discovery().is_some());
993        anyhow::ensure!(can_download, "no way to reach a node for download");
994        let req = DownloadRequest::new(hash_and_format, node_ids).progress_sender(progress);
995        let handle = self.downloader().queue(req).await;
996        let stats = handle.await?;
997        Ok(stats)
998    }
999
1000    #[tracing::instrument("download_direct", skip_all, fields(hash=%hash_and_format.hash.fmt_short()))]
1001    async fn download_direct_from_nodes(
1002        &self,
1003        endpoint: Endpoint,
1004        hash_and_format: HashAndFormat,
1005        nodes: Vec<NodeAddr>,
1006        progress: AsyncChannelProgressSender<DownloadProgress>,
1007    ) -> anyhow::Result<Stats> {
1008        let mut last_err = None;
1009        let mut remaining_nodes = nodes.len();
1010        let mut nodes_iter = nodes.into_iter();
1011        'outer: loop {
1012            match crate::get::db::get_to_db_in_steps(
1013                self.store().clone(),
1014                hash_and_format,
1015                progress.clone(),
1016            )
1017            .await?
1018            {
1019                GetState::Complete(stats) => return Ok(stats),
1020                GetState::NeedsConn(needs_conn) => {
1021                    let (conn, node_id) = 'inner: loop {
1022                        match nodes_iter.next() {
1023                            None => break 'outer,
1024                            Some(node) => {
1025                                remaining_nodes -= 1;
1026                                let node_id = node.node_id;
1027                                if node_id == endpoint.node_id() {
1028                                    debug!(
1029                                        ?remaining_nodes,
1030                                        "skip node {} (it is the node id of ourselves)",
1031                                        node_id.fmt_short()
1032                                    );
1033                                    continue 'inner;
1034                                }
1035                                match endpoint.connect(node, crate::protocol::ALPN).await {
1036                                    Ok(conn) => break 'inner (conn, node_id),
1037                                    Err(err) => {
1038                                        debug!(
1039                                            ?remaining_nodes,
1040                                            "failed to connect to {}: {err}",
1041                                            node_id.fmt_short()
1042                                        );
1043                                        continue 'inner;
1044                                    }
1045                                }
1046                            }
1047                        }
1048                    };
1049                    match needs_conn.proceed(conn).await {
1050                        Ok(stats) => return Ok(stats),
1051                        Err(err) => {
1052                            warn!(
1053                                ?remaining_nodes,
1054                                "failed to download from {}: {err}",
1055                                node_id.fmt_short()
1056                            );
1057                            last_err = Some(err);
1058                        }
1059                    }
1060                }
1061            }
1062        }
1063        match last_err {
1064            Some(err) => Err(err.into()),
1065            None => Err(anyhow!("No nodes to download from provided")),
1066        }
1067    }
1068}
1069
/// An in-memory RPC handler for the blobs RPC protocol.
///
/// This struct bundles a task that handles RPC requests with a client that
/// can be used to send RPC requests to that task.
///
/// Dropping it will stop the handler task, so you need to put it somewhere
/// where it will be kept alive. The handler task holds a clone of the
/// `Arc<`[`BlobsInner`]`>` the handler was created from, keeping it alive for
/// as long as the task exists.
#[derive(Debug)]
pub(crate) struct RpcHandler {
    /// Client to hand out; connected to the handler task over an in-memory channel.
    client: MemClient,
    /// Handler task running the accept loop. Never read; it is kept only so
    /// that the task is aborted when this struct is dropped.
    _handler: AbortOnDropHandle<()>,
}
1085
1086impl Deref for RpcHandler {
1087    type Target = MemClient;
1088
1089    fn deref(&self) -> &Self::Target {
1090        &self.client
1091    }
1092}
1093
1094impl RpcHandler {
1095    fn new<D: crate::store::Store>(blobs: &Arc<BlobsInner<D>>) -> Self {
1096        let blobs = blobs.clone();
1097        let (listener, connector) = quic_rpc::transport::flume::channel(1);
1098        let listener = RpcServer::new(listener);
1099        let client = RpcClient::new(connector);
1100        let client = MemClient::new(client);
1101        let _handler = listener.spawn_accept_loop(move |req, chan| {
1102            Handler(blobs.clone()).handle_rpc_request(req, chan)
1103        });
1104        Self { client, _handler }
1105    }
1106}