1use std::collections::HashMap;
2use std::io::{Read, Seek, SeekFrom};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::Bytes;
7use color_eyre::eyre::{ensure, eyre, Result};
8use gcp_auth::AuthenticationManager;
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, RANGE};
10use reqwest::{Client, ClientBuilder};
11use serde_json::Value;
12use tokio::io::{AsyncRead, ReadBuf};
13use tokio::runtime::Runtime;
14
15use crate::errors::GCSReaderError;
16use crate::uri::GCSObjectURI;
17
/// Formats `$token` (anything exposing an `as_str()` method, e.g. a `String`)
/// as an HTTP bearer credential of the form `"Bearer <token>"`.
macro_rules! bearer {
    ($token:expr) => {
        ["Bearer ", $token.as_str()].concat()
    };
}
23
/// Where the reader's access token comes from.
pub enum Auth {
    /// Discover credentials automatically through `gcp_auth` (the default).
    Auto,
    /// Use the supplied access token as-is.
    Token(String),
}

impl Default for Auth {
    /// Falls back to automatic credential discovery.
    fn default() -> Auth {
        Auth::Auto
    }
}
34
35impl Auth {
36 async fn gcp_auth_token() -> Result<String> {
37 let authentication_manager = AuthenticationManager::new().await?;
38 let scopes = &["https://www.googleapis.com/auth/devstorage.read_only"];
39 let token = authentication_manager.get_token(scopes).await?;
40 Ok(token.as_str().to_string())
41 }
42
43 async fn token(&self) -> Result<String> {
44 match self {
45 Self::Auto => Self::gcp_auth_token().await,
46 Self::Token(token) => Ok(token.to_string()),
47 }
48 }
49}
50
/// Reader over a single Cloud Storage object, fetching data with ranged HTTP GETs.
#[derive(Debug)]
pub struct GCSReader {
    /// Async HTTP client whose default headers carry the bearer token (set in `open`).
    client: Client,
    /// The object being read; its `endpoint()` is used as the request URL.
    uri: GCSObjectURI,
    /// Current read offset in bytes; advanced by `read`, set by `seek`.
    pos: u64,
    /// Object size in bytes, taken from the metadata `size` field in `open`.
    len: u64,
}
58
59impl GCSReader {
60 pub fn open(uri: GCSObjectURI, auth: Auth) -> Result<Self> {
61 let token = Runtime::new()?.block_on(auth.token())?;
62
63 let md_res = reqwest::blocking::Client::new()
64 .get(uri.endpoint())
65 .header(AUTHORIZATION, HeaderValue::from_str(&bearer!(token))?)
66 .send()?;
67
68 ensure!(md_res.status().is_success(), GCSReaderError::from_response(md_res)?);
69 let len = md_res
70 .json::<HashMap<String, Value>>()?
71 .get("size")
72 .ok_or(GCSReaderError::GetSizeError(uri.uri()))?
73 .as_str()
74 .unwrap()
75 .parse::<u64>()
76 .unwrap();
77
78 let mut header = HeaderMap::new();
79 header.insert(AUTHORIZATION, HeaderValue::from_str(&bearer!(token))?);
80 let client = ClientBuilder::new().default_headers(header).build()?;
81
82 Ok(Self {
83 client,
84 uri,
85 pos: 0,
86 len,
87 })
88 }
89
90 pub fn from_uri(uri: &str, auth: Auth) -> Result<Self> {
91 let uri = GCSObjectURI::new(uri)?;
92 Self::open(uri, auth)
93 }
94
95 pub async fn read_range(&self, start: u64, end: u64) -> Result<Bytes> {
96 let range = format!("bytes={}-{}", start, end - 1);
97 let mut header = HeaderMap::new();
98 header.insert(RANGE, HeaderValue::from_str(&range)?);
99
100 let mut params = HashMap::new();
101 params.insert("alt", "media");
102
103 let res = self
104 .client
105 .get(self.uri.endpoint())
106 .headers(header)
107 .query(¶ms)
108 .send()
109 .await?;
110
111 ensure!(
112 res.status().is_success(),
113 GCSReaderError::from_async_response(res).await?
114 );
115 res.bytes().await.map_err(|e| eyre!(e))
116 }
117}
118
119impl Read for GCSReader {
120 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
121 let start = self.pos;
122 let end = std::cmp::min(self.pos + (buf.len() as u64), self.len);
123 if start == end {
124 return Ok(0);
125 }
126 let bytes = Runtime::new()?
127 .block_on(self.read_range(start, end))
128 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
129 let len = bytes.len() as u64;
130 buf.clone_from_slice(&bytes);
131 self.pos += len;
132 Ok(len as usize)
133 }
134}
135
136impl AsyncRead for GCSReader {
137 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
138 todo!()
139 }
140}
141
142impl Seek for GCSReader {
143 fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
144 let new_pos = match pos {
145 SeekFrom::Start(pos) => pos as i64,
146 SeekFrom::End(pos) => self.len as i64 + pos,
147 SeekFrom::Current(pos) => self.pos as i64 + pos,
148 };
149 if new_pos < 0 && new_pos >= self.len as i64 {
150 return Err(std::io::Error::new(
151 std::io::ErrorKind::InvalidInput,
152 format!("Invalid seek position: {}", new_pos),
153 ));
154 }
155 self.pos = new_pos as u64;
156 Ok(self.pos)
157 }
158}