1use std::{
10 future::{Future, IntoFuture},
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14};
15
16use bao_tree::{io::BaoContentItem, ChunkNum, ChunkRanges};
17use bytes::Bytes;
18use genawaiter::sync::{Co, Gen};
19use iroh::endpoint::Connection;
20use n0_future::{Stream, StreamExt};
21use nested_enum_utils::enum_conversions;
22use rand::Rng;
23use snafu::IntoError;
24use tokio::sync::mpsc;
25
26use super::{fsm, GetError, GetResult, Stats};
27use crate::{
28 get::error::{BadRequestSnafu, LocalFailureSnafu},
29 hashseq::HashSeq,
30 protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
31 Hash, HashAndFormat,
32};
33
/// Result of [`get_blob`]: a stream of [`GetBlobItem`]s describing the download of a
/// single blob.
///
/// It can be consumed item by item as a [`Stream`], collected into memory with
/// [`GetBlobResult::bytes`] / [`GetBlobResult::bytes_and_stats`], or simply `.await`ed
/// (via its [`IntoFuture`] impl) to obtain the complete content.
pub struct GetBlobResult {
    /// Boxed stream of items produced by the download.
    rx: n0_future::stream::Boxed<GetBlobItem>,
}
41
42impl IntoFuture for GetBlobResult {
43 type Output = GetResult<Bytes>;
44 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
45
46 fn into_future(self) -> Self::IntoFuture {
47 Box::pin(self.bytes())
48 }
49}
50
51impl GetBlobResult {
52 pub async fn bytes(self) -> GetResult<Bytes> {
53 let (bytes, _) = self.bytes_and_stats().await?;
54 Ok(bytes)
55 }
56
57 pub async fn bytes_and_stats(mut self) -> GetResult<(Bytes, Stats)> {
58 let mut parts = Vec::new();
59 let stats = loop {
60 let Some(item) = self.next().await else {
61 return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end").into()));
62 };
63 match item {
64 GetBlobItem::Item(item) => {
65 if let BaoContentItem::Leaf(leaf) = item {
66 parts.push(leaf.data);
67 }
68 }
69 GetBlobItem::Done(stats) => {
70 break stats;
71 }
72 GetBlobItem::Error(cause) => {
73 return Err(cause);
74 }
75 }
76 };
77 let bytes = if parts.len() == 1 {
78 parts.pop().unwrap()
79 } else {
80 let mut bytes = Vec::new();
81 for part in parts {
82 bytes.extend_from_slice(&part);
83 }
84 bytes.into()
85 };
86 Ok((bytes, stats))
87 }
88}
89
90impl Stream for GetBlobResult {
91 type Item = GetBlobItem;
92
93 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
94 self.rx.poll_next(cx)
95 }
96}
97
/// A single item produced by a [`GetBlobResult`] stream.
///
/// As produced by [`get_blob`], the stream yields zero or more [`GetBlobItem::Item`]s
/// followed by exactly one terminal item: [`GetBlobItem::Done`] on success or
/// [`GetBlobItem::Error`] on failure.
#[derive(Debug)]
// NOTE(review): presumably generates `From<payload>` impls for each variant; the
// `.into()` calls in `get_blob_impl` rely on such conversions — confirm.
#[enum_conversions()]
pub enum GetBlobItem {
    /// A content item received from the remote side; leaf items carry blob data.
    Item(BaoContentItem),
    /// The transfer finished successfully; carries the transfer statistics.
    Done(Stats),
    /// The transfer failed with the given error.
    Error(GetError),
}
109
110pub fn get_blob(connection: Connection, hash: Hash) -> GetBlobResult {
111 let generator = Gen::new(|co| async move {
112 if let Err(cause) = get_blob_impl(&connection, &hash, &co).await {
113 co.yield_(GetBlobItem::Error(cause)).await;
114 }
115 });
116 GetBlobResult {
117 rx: Box::pin(generator),
118 }
119}
120
121async fn get_blob_impl(
122 connection: &Connection,
123 hash: &Hash,
124 co: &Co<GetBlobItem>,
125) -> GetResult<()> {
126 let request = GetRequest::blob(*hash);
127 let request = fsm::start(connection.clone(), request, Default::default());
128 let connected = request.next().await?;
129 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
130 unreachable!("expected start root");
131 };
132 let header = start.next();
133 let (mut curr, _size) = header.next().await?;
134 let end = loop {
135 match curr.next().await {
136 fsm::BlobContentNext::More((next, res)) => {
137 co.yield_(res?.into()).await;
138 curr = next;
139 }
140 fsm::BlobContentNext::Done(end) => {
141 break end;
142 }
143 }
144 };
145 let fsm::EndBlobNext::Closing(closing) = end.next() else {
146 unreachable!("expected closing");
147 };
148 let stats = closing.next().await?;
149 co.yield_(stats.into()).await;
150 Ok(())
151}
152
153pub async fn get_unverified_size(connection: &Connection, hash: &Hash) -> GetResult<(u64, Stats)> {
158 let request = GetRequest::new(
159 *hash,
160 ChunkRangesSeq::from_ranges(vec![ChunkRanges::last_chunk()]),
161 );
162 let request = fsm::start(connection.clone(), request, Default::default());
163 let connected = request.next().await?;
164 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
165 unreachable!("expected start root");
166 };
167 let at_blob_header = start.next();
168 let (curr, size) = at_blob_header.next().await?;
169 let stats = curr.finish().next().await?;
170 Ok((size, stats))
171}
172
173pub async fn get_verified_size(connection: &Connection, hash: &Hash) -> GetResult<(u64, Stats)> {
178 tracing::trace!("Getting verified size of {}", hash.to_hex());
179 let request = GetRequest::new(
180 *hash,
181 ChunkRangesSeq::from_ranges(vec![ChunkRanges::last_chunk()]),
182 );
183 let request = fsm::start(connection.clone(), request, Default::default());
184 let connected = request.next().await?;
185 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
186 unreachable!("expected start root");
187 };
188 let header = start.next();
189 let (mut curr, size) = header.next().await?;
190 let end = loop {
191 match curr.next().await {
192 fsm::BlobContentNext::More((next, res)) => {
193 let _ = res?;
194 curr = next;
195 }
196 fsm::BlobContentNext::Done(end) => {
197 break end;
198 }
199 }
200 };
201 let fsm::EndBlobNext::Closing(closing) = end.next() else {
202 unreachable!("expected closing");
203 };
204 let stats = closing.next().await?;
205 tracing::trace!(
206 "Got verified size of {}, {:.6}s",
207 hash.to_hex(),
208 stats.elapsed.as_secs_f64()
209 );
210 Ok((size, stats))
211}
212
213pub async fn get_hash_seq_and_sizes(
221 connection: &Connection,
222 hash: &Hash,
223 max_size: u64,
224 _progress: Option<mpsc::Sender<u64>>,
225) -> GetResult<(HashSeq, Arc<[u64]>)> {
226 let content = HashAndFormat::hash_seq(*hash);
227 tracing::debug!("Getting hash seq and children sizes of {}", content);
228 let request = GetRequest::new(
229 *hash,
230 ChunkRangesSeq::from_ranges_infinite([ChunkRanges::all(), ChunkRanges::last_chunk()]),
231 );
232 let at_start = fsm::start(connection.clone(), request, Default::default());
233 let at_connected = at_start.next().await?;
234 let fsm::ConnectedNext::StartRoot(start) = at_connected.next().await? else {
235 unreachable!("query includes root");
236 };
237 let at_start_root = start.next();
238 let (at_blob_content, size) = at_start_root.next().await?;
239 if size > max_size {
241 return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large").into()));
242 }
243 let (mut curr, hash_seq) = at_blob_content.concatenate_into_vec().await?;
244 let hash_seq = HashSeq::try_from(Bytes::from(hash_seq))
245 .map_err(|e| BadRequestSnafu.into_error(e.into()))?;
246 let mut sizes = Vec::with_capacity(hash_seq.len());
247 let closing = loop {
248 match curr.next() {
249 fsm::EndBlobNext::MoreChildren(more) => {
250 let hash = match hash_seq.get(sizes.len()) {
251 Some(hash) => hash,
252 None => break more.finish(),
253 };
254 let at_header = more.next(hash);
255 let (at_content, size) = at_header.next().await?;
256 let next = at_content.drain().await?;
257 sizes.push(size);
258 curr = next;
259 }
260 fsm::EndBlobNext::Closing(closing) => break closing,
261 }
262 };
263 let _stats = closing.next().await?;
264 tracing::debug!(
265 "Got hash seq and children sizes of {}: {:?}",
266 content,
267 sizes
268 );
269 Ok((hash_seq, sizes.into()))
270}
271
272pub async fn get_chunk_probe(
284 connection: &Connection,
285 hash: &Hash,
286 chunk: ChunkNum,
287) -> GetResult<Stats> {
288 let ranges = ChunkRanges::from(chunk..chunk + 1);
289 let ranges = ChunkRangesSeq::from_ranges([ranges]);
290 let request = GetRequest::new(*hash, ranges);
291 let request = fsm::start(connection.clone(), request, Default::default());
292 let connected = request.next().await?;
293 let fsm::ConnectedNext::StartRoot(start) = connected.next().await? else {
294 unreachable!("query includes root");
295 };
296 let header = start.next();
297 let (mut curr, _size) = header.next().await?;
298 let end = loop {
299 match curr.next().await {
300 fsm::BlobContentNext::More((next, res)) => {
301 res?;
302 curr = next;
303 }
304 fsm::BlobContentNext::Done(end) => {
305 break end;
306 }
307 }
308 };
309 let fsm::EndBlobNext::Closing(closing) = end.next() else {
310 unreachable!("query contains only one blob");
311 };
312 let stats = closing.next().await?;
313 Ok(stats)
314}
315
316pub fn random_hash_seq_ranges(sizes: &[u64], mut rng: impl Rng) -> ChunkRangesSeq {
322 let total_chunks = sizes
323 .iter()
324 .map(|size| ChunkNum::full_chunks(*size).0)
325 .sum::<u64>();
326 let random_chunk = rng.gen_range(0..total_chunks);
327 let mut remaining = random_chunk;
328 let mut ranges = vec![];
329 ranges.push(ChunkRanges::empty());
330 for size in sizes.iter() {
331 let chunks = ChunkNum::full_chunks(*size).0;
332 if remaining < chunks {
333 ranges.push(ChunkRanges::from(
334 ChunkNum(remaining)..ChunkNum(remaining + 1),
335 ));
336 break;
337 } else {
338 remaining -= chunks;
339 ranges.push(ChunkRanges::empty());
340 }
341 }
342 ChunkRangesSeq::from_ranges(ranges)
343}