gcs_reader/
reader.rs

1use std::collections::HashMap;
2use std::io::{Read, Seek, SeekFrom};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use color_eyre::eyre::{ensure, eyre, Result};
8use gcp_auth::AuthenticationManager;
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, RANGE};
10use reqwest::{Client, ClientBuilder};
11use serde_json::Value;
12use tokio::io::{AsyncRead, ReadBuf};
13use tokio::runtime::Runtime;
14
15use crate::errors::GCSReaderError;
16use crate::uri::GCSObjectURI;
17
/// Expands to an owned `String` of the form `Bearer <token>`, suitable for an
/// HTTP `Authorization` header value.
///
/// `$token` must be an expression exposing `as_str()` (e.g. a `String`); it is
/// only borrowed, never moved.
macro_rules! bearer {
    ($token:expr) => {{
        let mut value = String::from("Bearer ");
        value.push_str($token.as_str());
        value
    }};
}
23
/// Strategy used by [`GCSReader`] to obtain its OAuth2 bearer token.
pub enum Auth {
    /// Ask `gcp_auth::AuthenticationManager` for a token scoped to read-only
    /// Cloud Storage access.
    Auto,
    /// Use the supplied access token as-is.
    Token(String),
}
28
29impl Default for Auth {
30    fn default() -> Self {
31        Self::Auto
32    }
33}
34
35impl Auth {
36    async fn gcp_auth_token() -> Result<String> {
37        let authentication_manager = AuthenticationManager::new().await?;
38        let scopes = &["https://www.googleapis.com/auth/devstorage.read_only"];
39        let token = authentication_manager.get_token(scopes).await?;
40        Ok(token.as_str().to_string())
41    }
42
43    async fn token(&self) -> Result<String> {
44        match self {
45            Self::Auto => Self::gcp_auth_token().await,
46            Self::Token(token) => Ok(token.to_string()),
47        }
48    }
49}
50
/// Reader over a single Google Cloud Storage object.
///
/// Reads are served by HTTP range requests against `uri.endpoint()`; the object
/// size is fetched once when the reader is opened.
#[derive(Debug)]
pub struct GCSReader {
    /// Async HTTP client whose default headers carry `Authorization: Bearer …`.
    client: Client,
    /// Location of the object being read.
    uri: GCSObjectURI,
    /// Current read offset in bytes; advanced by `read`, replaced by `seek`.
    pos: u64,
    /// Object size in bytes, taken from the metadata `size` field at open time.
    len: u64,
}
58
59impl GCSReader {
60    pub fn open(uri: GCSObjectURI, auth: Auth) -> Result<Self> {
61        let token = Runtime::new()?.block_on(auth.token())?;
62
63        let md_res = reqwest::blocking::Client::new()
64            .get(uri.endpoint())
65            .header(AUTHORIZATION, HeaderValue::from_str(&bearer!(token))?)
66            .send()?;
67
68        ensure!(md_res.status().is_success(), GCSReaderError::from_response(md_res)?);
69        let len = md_res
70            .json::<HashMap<String, Value>>()?
71            .get("size")
72            .ok_or(GCSReaderError::GetSizeError(uri.uri()))?
73            .as_str()
74            .unwrap()
75            .parse::<u64>()
76            .unwrap();
77
78        let mut header = HeaderMap::new();
79        header.insert(AUTHORIZATION, HeaderValue::from_str(&bearer!(token))?);
80        let client = ClientBuilder::new().default_headers(header).build()?;
81
82        Ok(Self {
83            client,
84            uri,
85            pos: 0,
86            len,
87        })
88    }
89
90    pub fn from_uri(uri: &str, auth: Auth) -> Result<Self> {
91        let uri = GCSObjectURI::new(uri)?;
92        Self::open(uri, auth)
93    }
94
95    pub async fn read_range(&self, start: u64, end: u64) -> Result<Bytes> {
96        let range = format!("bytes={}-{}", start, end - 1);
97        let mut header = HeaderMap::new();
98        header.insert(RANGE, HeaderValue::from_str(&range)?);
99
100        let mut params = HashMap::new();
101        params.insert("alt", "media");
102
103        let res = self
104            .client
105            .get(self.uri.endpoint())
106            .headers(header)
107            .query(&params)
108            .send()
109            .await?;
110
111        ensure!(
112            res.status().is_success(),
113            GCSReaderError::from_async_response(res).await?
114        );
115        res.bytes().await.map_err(|e| eyre!(e))
116    }
117}
118
119impl Read for GCSReader {
120    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
121        let start = self.pos;
122        let end = std::cmp::min(self.pos + (buf.len() as u64), self.len);
123        if start == end {
124            return Ok(0);
125        }
126        let bytes = Runtime::new()?
127            .block_on(self.read_range(start, end))
128            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
129        let len = bytes.len() as u64;
130        buf.clone_from_slice(&bytes);
131        self.pos += len;
132        Ok(len as usize)
133    }
134}
135
136impl AsyncRead for GCSReader {
137    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
138        todo!()
139    }
140}
141
142impl Seek for GCSReader {
143    fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
144        let new_pos = match pos {
145            SeekFrom::Start(pos) => pos as i64,
146            SeekFrom::End(pos) => self.len as i64 + pos,
147            SeekFrom::Current(pos) => self.pos as i64 + pos,
148        };
149        if new_pos < 0 && new_pos >= self.len as i64 {
150            return Err(std::io::Error::new(
151                std::io::ErrorKind::InvalidInput,
152                format!("Invalid seek position: {}", new_pos),
153            ));
154        }
155        self.pos = new_pos as u64;
156        Ok(self.pos)
157    }
158}