// hugging_face_client/api/download_parquet.rs

1use std::{
2  pin::Pin,
3  task::{Context, Poll},
4};
5
6use bytes::Bytes;
7use futures_core::Stream;
8use pin_project_lite::pin_project;
9use serde::Serialize;
10use snafu::ResultExt;
11
12use crate::errors::{ReqwestClientSnafu, Result};
13
14/// Request of [`crate::client::Client::download_parquet`]
15#[derive(Debug, Serialize)]
16pub struct DownloadParquetReq<'a> {
17  #[serde(rename = "repo_id")]
18  pub(crate) repo_name: &'a str,
19
20  pub(crate) subset: &'a str,
21
22  pub(crate) split: &'a str,
23
24  pub(crate) nth: usize,
25}
26
27impl<'a> DownloadParquetReq<'a> {
28  pub fn new(repo_name: &'a str, subset: &'a str, split: &'a str) -> DownloadParquetReq<'a> {
29    DownloadParquetReq {
30      repo_name,
31      subset,
32      split,
33      nth: 0,
34    }
35  }
36
37  pub fn nth(mut self, nth: usize) -> DownloadParquetReq<'a> {
38    self.nth = nth;
39    self
40  }
41}
42
43pin_project! {
44  pub struct DownloadParquetRes<T> {
45    #[pin]
46    stream: T,
47  }
48}
49
50impl<T> DownloadParquetRes<T> {
51  pub(crate) fn new(stream: T) -> DownloadParquetRes<T> {
52    DownloadParquetRes { stream }
53  }
54
55  pub fn get_stream(self) -> T {
56    self.stream
57  }
58}
59
60impl<T> Stream for DownloadParquetRes<T>
61where
62  T: Stream<Item = reqwest::Result<Bytes>>,
63{
64  type Item = Result<Bytes>;
65
66  fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67    let project = self.project();
68    let stream: Pin<&mut _> = project.stream;
69    stream
70      .poll_next(cx)
71      .map(|a| a.map(|b| b.context(ReqwestClientSnafu)))
72  }
73}