hugging_face_client/api/
download_parquet.rs1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use bytes::Bytes;
7use futures_core::Stream;
8use pin_project_lite::pin_project;
9use serde::Serialize;
10use snafu::ResultExt;
11
12use crate::errors::{ReqwestClientSnafu, Result};
13
14#[derive(Debug, Serialize)]
16pub struct DownloadParquetReq<'a> {
17 #[serde(rename = "repo_id")]
18 pub(crate) repo_name: &'a str,
19
20 pub(crate) subset: &'a str,
21
22 pub(crate) split: &'a str,
23
24 pub(crate) nth: usize,
25}
26
27impl<'a> DownloadParquetReq<'a> {
28 pub fn new(repo_name: &'a str, subset: &'a str, split: &'a str) -> DownloadParquetReq<'a> {
29 DownloadParquetReq {
30 repo_name,
31 subset,
32 split,
33 nth: 0,
34 }
35 }
36
37 pub fn nth(mut self, nth: usize) -> DownloadParquetReq<'a> {
38 self.nth = nth;
39 self
40 }
41}
42
43pin_project! {
44 pub struct DownloadParquetRes<T> {
45 #[pin]
46 stream: T,
47 }
48}
49
50impl<T> DownloadParquetRes<T> {
51 pub(crate) fn new(stream: T) -> DownloadParquetRes<T> {
52 DownloadParquetRes { stream }
53 }
54
55 pub fn get_stream(self) -> T {
56 self.stream
57 }
58}
59
60impl<T> Stream for DownloadParquetRes<T>
61where
62 T: Stream<Item = reqwest::Result<Bytes>>,
63{
64 type Item = Result<Bytes>;
65
66 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67 let project = self.project();
68 let stream: Pin<&mut _> = project.stream;
69 stream
70 .poll_next(cx)
71 .map(|a| a.map(|b| b.context(ReqwestClientSnafu)))
72 }
73}