1use std::{
10 future::{Future, IntoFuture},
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14};
15
16use bao_tree::{io::BaoContentItem, ChunkNum, ChunkRanges};
17use bytes::Bytes;
18use genawaiter::sync::{Co, Gen};
19use iroh::endpoint::Connection;
20use n0_future::{Stream, StreamExt};
21use nested_enum_utils::enum_conversions;
22use rand::Rng;
23use snafu::IntoError;
24use tokio::sync::mpsc;
25
26use super::{fsm, GetError, GetResult, Stats};
27use crate::{
28 get::error::{BadRequestSnafu, LocalFailureSnafu},
29 hashseq::HashSeq,
30 protocol::{ChunkRangesSeq, GetRequest},
31 util::ChunkRangesExt,
32 Hash, HashAndFormat,
33};
34
35pub struct GetBlobResult {
40 rx: n0_future::stream::Boxed<GetBlobItem>,
41}
42
43impl IntoFuture for GetBlobResult {
44 type Output = GetResult<Bytes>;
45 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
46
47 fn into_future(self) -> Self::IntoFuture {
48 Box::pin(self.bytes())
49 }
50}
51
52impl GetBlobResult {
53 pub async fn bytes(self) -> GetResult<Bytes> {
54 let (bytes, _) = self.bytes_and_stats().await?;
55 Ok(bytes)
56 }
57
58 pub async fn bytes_and_stats(mut self) -> GetResult<(Bytes, Stats)> {
59 let mut parts = Vec::new();
60 let stats = loop {
61 let Some(item) = self.next().await else {
62 return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end").into()));
63 };
64 match item {
65 GetBlobItem::Item(item) => {
66 if let BaoContentItem::Leaf(leaf) = item {
67 parts.push(leaf.data);
68 }
69 }
70 GetBlobItem::Done(stats) => {
71 break stats;
72 }
73 GetBlobItem::Error(cause) => {
74 return Err(cause);
75 }
76 }
77 };
78 let bytes = if parts.len() == 1 {
79 parts.pop().unwrap()
80 } else {
81 let mut bytes = Vec::new();
82 for part in parts {
83 bytes.extend_from_slice(&part);
84 }
85 bytes.into()
86 };
87 Ok((bytes, stats))
88 }
89}
90
91impl Stream for GetBlobResult {
92 type Item = GetBlobItem;
93
94 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
95 self.rx.poll_next(cx)
96 }
97}
98
99#[derive(Debug)]
101#[enum_conversions()]
102pub enum GetBlobItem {
103 Item(BaoContentItem),
105 Done(Stats),
107 Error(GetError),
109}
110
111pub fn get_blob(connection: Connection, hash: Hash) -> GetBlobResult {
112 let generator = Gen::new(|co| async move {
113 if let Err(cause) = get_blob_impl(&connection, &hash, &co).await {
114 co.yield_(GetBlobItem::Error(cause)).await;
115 }
116 });
117 GetBlobResult {
118 rx: Box::pin(generator),
119 }
120}
121
122async fn get_blob_impl(
123 connection: &Connection,
124 hash: &Hash,
125 co: &Co<GetBlobItem>,
126) -> GetResult<()> {
127 let request = GetRequest::blob(*hash);
128 let request = fsm::start(connection.clone(), request, Default::default());
129 let connected = request.next().await?;
130 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
131 unreachable!("expected start root");
132 };
133 let header = start.next();
134 let (mut curr, _size) = header.next().await?;
135 let end = loop {
136 match curr.next().await {
137 fsm::BlobContentNext::More((next, res)) => {
138 co.yield_(res?.into()).await;
139 curr = next;
140 }
141 fsm::BlobContentNext::Done(end) => {
142 break end;
143 }
144 }
145 };
146 let fsm::EndBlobNext::Closing(closing) = end.next() else {
147 unreachable!("expected closing");
148 };
149 let stats = closing.next().await?;
150 co.yield_(stats.into()).await;
151 Ok(())
152}
153
154pub async fn get_unverified_size(connection: &Connection, hash: &Hash) -> GetResult<(u64, Stats)> {
159 let request = GetRequest::new(
160 *hash,
161 ChunkRangesSeq::from_ranges(vec![ChunkRanges::last_chunk()]),
162 );
163 let request = fsm::start(connection.clone(), request, Default::default());
164 let connected = request.next().await?;
165 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
166 unreachable!("expected start root");
167 };
168 let at_blob_header = start.next();
169 let (curr, size) = at_blob_header.next().await?;
170 let stats = curr.finish().next().await?;
171 Ok((size, stats))
172}
173
174pub async fn get_verified_size(connection: &Connection, hash: &Hash) -> GetResult<(u64, Stats)> {
179 tracing::trace!("Getting verified size of {}", hash.to_hex());
180 let request = GetRequest::new(
181 *hash,
182 ChunkRangesSeq::from_ranges(vec![ChunkRanges::last_chunk()]),
183 );
184 let request = fsm::start(connection.clone(), request, Default::default());
185 let connected = request.next().await?;
186 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
187 unreachable!("expected start root");
188 };
189 let header = start.next();
190 let (mut curr, size) = header.next().await?;
191 let end = loop {
192 match curr.next().await {
193 fsm::BlobContentNext::More((next, res)) => {
194 let _ = res?;
195 curr = next;
196 }
197 fsm::BlobContentNext::Done(end) => {
198 break end;
199 }
200 }
201 };
202 let fsm::EndBlobNext::Closing(closing) = end.next() else {
203 unreachable!("expected closing");
204 };
205 let stats = closing.next().await?;
206 tracing::trace!(
207 "Got verified size of {}, {:.6}s",
208 hash.to_hex(),
209 stats.elapsed.as_secs_f64()
210 );
211 Ok((size, stats))
212}
213
214pub async fn get_hash_seq_and_sizes(
222 connection: &Connection,
223 hash: &Hash,
224 max_size: u64,
225 _progress: Option<mpsc::Sender<u64>>,
226) -> GetResult<(HashSeq, Arc<[u64]>)> {
227 let content = HashAndFormat::hash_seq(*hash);
228 tracing::debug!("Getting hash seq and children sizes of {}", content);
229 let request = GetRequest::new(
230 *hash,
231 ChunkRangesSeq::from_ranges_infinite([ChunkRanges::all(), ChunkRanges::last_chunk()]),
232 );
233 let at_start = fsm::start(connection.clone(), request, Default::default());
234 let at_connected = at_start.next().await?;
235 let fsm::ConnectedNext::StartRoot(start) = at_connected.next().await? else {
236 unreachable!("query includes root");
237 };
238 let at_start_root = start.next();
239 let (at_blob_content, size) = at_start_root.next().await?;
240 if size > max_size {
242 return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large").into()));
243 }
244 let (mut curr, hash_seq) = at_blob_content.concatenate_into_vec().await?;
245 let hash_seq = HashSeq::try_from(Bytes::from(hash_seq))
246 .map_err(|e| BadRequestSnafu.into_error(e.into()))?;
247 let mut sizes = Vec::with_capacity(hash_seq.len());
248 let closing = loop {
249 match curr.next() {
250 fsm::EndBlobNext::MoreChildren(more) => {
251 let hash = match hash_seq.get(sizes.len()) {
252 Some(hash) => hash,
253 None => break more.finish(),
254 };
255 let at_header = more.next(hash);
256 let (at_content, size) = at_header.next().await?;
257 let next = at_content.drain().await?;
258 sizes.push(size);
259 curr = next;
260 }
261 fsm::EndBlobNext::Closing(closing) => break closing,
262 }
263 };
264 let _stats = closing.next().await?;
265 tracing::debug!(
266 "Got hash seq and children sizes of {}: {:?}",
267 content,
268 sizes
269 );
270 Ok((hash_seq, sizes.into()))
271}
272
273pub async fn get_chunk_probe(
285 connection: &Connection,
286 hash: &Hash,
287 chunk: ChunkNum,
288) -> GetResult<Stats> {
289 let ranges = ChunkRanges::from(chunk..chunk + 1);
290 let ranges = ChunkRangesSeq::from_ranges([ranges]);
291 let request = GetRequest::new(*hash, ranges);
292 let request = fsm::start(connection.clone(), request, Default::default());
293 let connected = request.next().await?;
294 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
295 unreachable!("query includes root");
296 };
297 let header = start.next();
298 let (mut curr, _size) = header.next().await?;
299 let end = loop {
300 match curr.next().await {
301 fsm::BlobContentNext::More((next, res)) => {
302 res?;
303 curr = next;
304 }
305 fsm::BlobContentNext::Done(end) => {
306 break end;
307 }
308 }
309 };
310 let fsm::EndBlobNext::Closing(closing) = end.next() else {
311 unreachable!("query contains only one blob");
312 };
313 let stats = closing.next().await?;
314 Ok(stats)
315}
316
317pub fn random_hash_seq_ranges(sizes: &[u64], mut rng: impl Rng) -> ChunkRangesSeq {
323 let total_chunks = sizes
324 .iter()
325 .map(|size| ChunkNum::full_chunks(*size).0)
326 .sum::<u64>();
327 let random_chunk = rng.gen_range(0..total_chunks);
328 let mut remaining = random_chunk;
329 let mut ranges = vec![];
330 ranges.push(ChunkRanges::empty());
331 for size in sizes.iter() {
332 let chunks = ChunkNum::full_chunks(*size).0;
333 if remaining < chunks {
334 ranges.push(ChunkRanges::from(
335 ChunkNum(remaining)..ChunkNum(remaining + 1),
336 ));
337 break;
338 } else {
339 remaining -= chunks;
340 ranges.push(ChunkRanges::empty());
341 }
342 }
343 ChunkRangesSeq::from_ranges(ranges)
344}