1use std::{
2 io,
3 path::PathBuf,
4 sync::{Arc, Mutex},
5};
6
7use anyhow::{anyhow, Context, Result};
8use bytes::Bytes;
9use futures_buffered::BufferedStreamExt;
10use futures_lite::StreamExt;
11use futures_util::{sink::Buffer, FutureExt, SinkExt, Stream};
12use quic_rpc::{client::UpdateSink, Connector, RpcClient};
13use tokio::io::AsyncRead;
14use tokio_util::io::ReaderStream;
15use tracing::{debug, warn};
16
17use super::WrapOption;
18use crate::{
19 format::collection::Collection,
20 net_protocol::BatchId,
21 provider::BatchAddPathProgress,
22 rpc::proto::{
23 blobs::{
24 BatchAddPathRequest, BatchAddStreamRequest, BatchAddStreamResponse,
25 BatchAddStreamUpdate, BatchCreateTempTagRequest, BatchUpdate,
26 },
27 tags::{self, SyncMode},
28 RpcService,
29 },
30 store::ImportMode,
31 util::{SetTagOption, TagDrop},
32 BlobFormat, HashAndFormat, Tag, TempTag,
33};
34
35#[derive(derive_more::Debug)]
37struct BatchInner<C>
38where
39 C: Connector<RpcService>,
40{
41 batch: BatchId,
43 rpc: RpcClient<RpcService, C>,
45 #[debug(skip)]
47 updates: Mutex<Buffer<UpdateSink<C, BatchUpdate>, BatchUpdate>>,
48}
49
50#[derive(derive_more::Debug)]
57pub struct Batch<C>(Arc<BatchInner<C>>)
58where
59 C: Connector<RpcService>;
60
61impl<C> TagDrop for BatchInner<C>
62where
63 C: Connector<RpcService>,
64{
65 fn on_drop(&self, content: &HashAndFormat) {
66 let mut updates = self.updates.lock().unwrap();
67 updates.feed(BatchUpdate::Drop(*content)).now_or_never();
79 updates.flush().now_or_never();
80 }
81}
82
83#[derive(Debug, Clone, Copy, Default)]
85pub struct AddFileOpts {
86 pub import_mode: ImportMode,
88 pub format: BlobFormat,
90}
91
92#[derive(Debug, Clone)]
94pub struct AddDirOpts {
95 pub import_mode: ImportMode,
97 pub wrap: WrapOption,
99 pub io_parallelism: usize,
101}
102
103impl Default for AddDirOpts {
104 fn default() -> Self {
105 Self {
106 import_mode: ImportMode::TryReference,
107 wrap: WrapOption::NoWrap,
108 io_parallelism: 4,
109 }
110 }
111}
112
113#[derive(Debug, Clone)]
115pub struct AddReaderOpts {
116 pub format: BlobFormat,
118 pub chunk_size: usize,
120}
121
122impl Default for AddReaderOpts {
123 fn default() -> Self {
124 Self {
125 format: BlobFormat::Raw,
126 chunk_size: 1024 * 64,
127 }
128 }
129}
130
131impl<C> Batch<C>
132where
133 C: Connector<RpcService>,
134{
135 pub(super) fn new(
136 batch: BatchId,
137 rpc: RpcClient<RpcService, C>,
138 updates: UpdateSink<C, BatchUpdate>,
139 buffer_size: usize,
140 ) -> Self {
141 let updates = updates.buffer(buffer_size);
142 Self(Arc::new(BatchInner {
143 batch,
144 rpc,
145 updates: updates.into(),
146 }))
147 }
148
149 pub async fn add_bytes(&self, bytes: impl Into<Bytes>) -> Result<TempTag> {
151 self.add_bytes_with_opts(bytes, Default::default()).await
152 }
153
154 pub async fn add_file(&self, path: PathBuf) -> Result<(TempTag, u64)> {
158 self.add_file_with_opts(path, AddFileOpts::default()).await
159 }
160
161 pub async fn add_dir(&self, root: PathBuf) -> Result<TempTag> {
163 self.add_dir_with_opts(root, Default::default()).await
164 }
165
166 pub async fn add_reader(
172 &self,
173 reader: impl AsyncRead + Unpin + Send + 'static,
174 ) -> anyhow::Result<TempTag> {
175 self.add_reader_with_opts(reader, Default::default()).await
176 }
177
178 pub async fn add_stream(
180 &self,
181 input: impl Stream<Item = io::Result<Bytes>> + Send + Unpin + 'static,
182 ) -> Result<TempTag> {
183 self.add_stream_with_opts(input, Default::default()).await
184 }
185
186 pub async fn temp_tag(&self, content: HashAndFormat) -> Result<TempTag> {
193 self.0
195 .rpc
196 .rpc(BatchCreateTempTagRequest {
197 batch: self.0.batch,
198 content,
199 })
200 .await??;
201 Ok(self.local_temp_tag(content, None))
203 }
204
205 pub async fn add_reader_with_opts(
209 &self,
210 reader: impl AsyncRead + Unpin + Send + 'static,
211 opts: AddReaderOpts,
212 ) -> anyhow::Result<TempTag> {
213 let AddReaderOpts { format, chunk_size } = opts;
214 let input = ReaderStream::with_capacity(reader, chunk_size);
215 self.add_stream_with_opts(input, format).await
216 }
217
218 pub async fn add_bytes_with_opts(
220 &self,
221 bytes: impl Into<Bytes>,
222 format: BlobFormat,
223 ) -> Result<TempTag> {
224 let input = futures_lite::stream::once(Ok(bytes.into()));
225 self.add_stream_with_opts(input, format).await
226 }
227
228 pub async fn add_file_with_opts(
241 &self,
242 path: PathBuf,
243 opts: AddFileOpts,
244 ) -> Result<(TempTag, u64)> {
245 let AddFileOpts {
246 import_mode,
247 format,
248 } = opts;
249 anyhow::ensure!(
250 path.is_absolute(),
251 "Path must be absolute, but got: {:?}",
252 path
253 );
254 anyhow::ensure!(path.is_file(), "Path does not refer to a file: {:?}", path);
255 let mut stream = self
256 .0
257 .rpc
258 .server_streaming(BatchAddPathRequest {
259 path,
260 import_mode,
261 format,
262 batch: self.0.batch,
263 })
264 .await?;
265 let mut res_hash = None;
266 let mut res_size = None;
267 while let Some(item) = stream.next().await {
268 match item?.0 {
269 BatchAddPathProgress::Abort(cause) => {
270 Err(cause)?;
271 }
272 BatchAddPathProgress::Done { hash } => {
273 res_hash = Some(hash);
274 }
275 BatchAddPathProgress::Found { size } => {
276 res_size = Some(size);
277 }
278 _ => {}
279 }
280 }
281 let hash = res_hash.context("Missing hash")?;
282 let size = res_size.context("Missing size")?;
283 Ok((
284 self.local_temp_tag(HashAndFormat { hash, format }, Some(size)),
285 size,
286 ))
287 }
288
289 pub async fn add_dir_with_opts(&self, root: PathBuf, opts: AddDirOpts) -> Result<TempTag> {
296 let AddDirOpts {
297 import_mode,
298 wrap,
299 io_parallelism,
300 } = opts;
301 anyhow::ensure!(root.is_absolute(), "Path must be absolute");
302
303 let data_sources = crate::util::fs::scan_path(root, wrap)?;
308 let opts = AddFileOpts {
309 import_mode,
310 format: BlobFormat::Raw,
311 };
312 let result: Vec<_> = futures_lite::stream::iter(data_sources)
313 .map(|source| {
314 async move {
316 let name = source.name().to_string();
317 let (tag, size) = self
318 .add_file_with_opts(source.path().to_owned(), opts)
319 .await?;
320 let hash = *tag.hash();
321 anyhow::Ok((name, hash, size, tag))
322 }
323 })
324 .buffered_ordered(io_parallelism)
325 .try_collect()
326 .await?;
327
328 let (collection, child_tags): (Collection, Vec<_>) = result
330 .into_iter()
331 .map(|(name, hash, _, tag)| ((name, hash), tag))
332 .unzip();
333
334 let tag = self.add_collection(collection).await?;
335 drop(child_tags);
336 Ok(tag)
337 }
338
339 pub async fn add_stream_with_opts(
344 &self,
345 mut input: impl Stream<Item = io::Result<Bytes>> + Send + Unpin + 'static,
346 format: BlobFormat,
347 ) -> Result<TempTag> {
348 let (mut sink, mut stream) = self
349 .0
350 .rpc
351 .bidi(BatchAddStreamRequest {
352 batch: self.0.batch,
353 format,
354 })
355 .await?;
356 let mut size = 0u64;
357 while let Some(item) = input.next().await {
358 match item {
359 Ok(chunk) => {
360 size += chunk.len() as u64;
361 sink.send(BatchAddStreamUpdate::Chunk(chunk))
362 .await
363 .map_err(|err| anyhow!("Failed to send input stream to remote: {err:?}"))?;
364 }
365 Err(err) => {
366 warn!("Abort send, reason: failed to read from source stream: {err:?}");
367 sink.send(BatchAddStreamUpdate::Abort)
368 .await
369 .map_err(|err| anyhow!("Failed to send input stream to remote: {err:?}"))?;
370 break;
371 }
372 }
373 }
374 drop(sink);
376 let mut res = None;
377 while let Some(item) = stream.next().await {
378 match item? {
379 BatchAddStreamResponse::Abort(cause) => {
380 Err(cause)?;
381 }
382 BatchAddStreamResponse::Result { hash } => {
383 res = Some(hash);
384 }
385 _ => {}
386 }
387 }
388 let hash = res.context("Missing answer")?;
389 Ok(self.local_temp_tag(HashAndFormat { hash, format }, Some(size)))
390 }
391
392 pub async fn add_collection(&self, collection: Collection) -> Result<TempTag> {
402 self.add_blob_seq(collection.to_blobs()).await
403 }
404
405 pub async fn add_blob_seq(&self, iter: impl Iterator<Item = Bytes>) -> Result<TempTag> {
411 let mut blobs = iter.peekable();
412 let mut tags = vec![];
414 loop {
415 let blob = blobs.next().context("Failed to get next blob")?;
416 if blobs.peek().is_none() {
417 return self.add_bytes_with_opts(blob, BlobFormat::HashSeq).await;
418 } else {
419 tags.push(self.add_bytes(blob).await?);
420 }
421 }
422 }
423
424 pub async fn persist(&self, tt: TempTag) -> Result<Tag> {
426 let tag = self
427 .0
428 .rpc
429 .rpc(tags::CreateRequest {
430 value: tt.hash_and_format(),
431 batch: Some(self.0.batch),
432 sync: SyncMode::Full,
433 })
434 .await??;
435 Ok(tag)
436 }
437
438 pub async fn persist_to(&self, tt: TempTag, tag: Tag) -> Result<()> {
440 self.0
441 .rpc
442 .rpc(tags::SetRequest {
443 name: tag,
444 value: tt.hash_and_format(),
445 batch: Some(self.0.batch),
446 sync: SyncMode::Full,
447 })
448 .await??;
449 Ok(())
450 }
451
452 pub async fn persist_with_opts(&self, tt: TempTag, opts: SetTagOption) -> Result<Tag> {
455 match opts {
456 SetTagOption::Auto => self.persist(tt).await,
457 SetTagOption::Named(tag) => {
458 self.persist_to(tt, tag.clone()).await?;
459 Ok(tag)
460 }
461 }
462 }
463
464 fn local_temp_tag(&self, inner: HashAndFormat, _size: Option<u64>) -> TempTag {
468 let on_drop: Arc<dyn TagDrop> = self.0.clone();
469 let on_drop = Some(Arc::downgrade(&on_drop));
470 TempTag::new(inner, on_drop)
471 }
472}