use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::io::AsyncWriteExt;
use crate::error::{Error, Result};
use crate::token::Token;
/// HTTP client for the WaveKat platform API.
///
/// Every request made through the methods on this type carries the bearer
/// token and user agent configured in [`Client::new`].
#[derive(Clone)]
pub struct Client {
    /// Underlying `reqwest` client, preconfigured with the auth header and
    /// user agent.
    inner: reqwest::Client,
    /// API root with trailing `/` characters removed; request paths are
    /// appended to it.
    base_url: String,
}
impl Client {
    /// Creates a client rooted at `base_url` that authenticates every request
    /// with `Authorization: Bearer <token>`.
    ///
    /// Trailing `/` characters are stripped from `base_url`. Request paths
    /// passed to the other methods are appended to it verbatim.
    ///
    /// # Errors
    ///
    /// Returns `Error::BadRequest` if the token contains bytes that are not
    /// valid in an HTTP header value. Returns the converted `reqwest` error if
    /// the underlying HTTP client cannot be built.
    pub fn new(base_url: impl Into<String>, token: Token) -> Result<Self> {
        let mut header = HeaderValue::from_str(&format!("Bearer {}", token.as_str()))
            .map_err(|_| Error::BadRequest("token contained invalid bytes".into()))?;
        // The bearer token is a credential: flag it as sensitive so that
        // `Debug` output of the header (and anything that contains it) redacts it.
        header.set_sensitive(true);

        let mut headers = HeaderMap::new();
        headers.insert(AUTHORIZATION, header);

        let inner = reqwest::Client::builder()
            .default_headers(headers)
            .user_agent(concat!(
                "wavekat-platform-client/",
                env!("CARGO_PKG_VERSION")
            ))
            .build()?;
        Ok(Self {
            inner,
            base_url: base_url.into().trim_end_matches('/').to_string(),
        })
    }

    /// Returns the API root this client was created with, without any
    /// trailing `/`.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Joins `path` onto the base URL by plain string concatenation.
    ///
    /// No separator is inserted, so callers are expected to pass paths that
    /// start with `/`.
    fn url(&self, path: &str) -> String {
        format!("{}{}", self.base_url, path)
    }

    /// Sends a `GET` to `path` and deserializes the JSON response into `T`.
    ///
    /// # Errors
    ///
    /// Transport failures, non-2xx statuses (401 becomes
    /// `Error::Unauthorized`), and JSON decode failures (`Error::Decode`).
    pub async fn get_json<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
        let url = self.url(path);
        let resp = self.inner.get(&url).send().await?;
        decode(url, resp).await
    }

    /// Like [`Client::get_json`], but serializes `query` into the URL query
    /// string first.
    ///
    /// # Errors
    ///
    /// Same as [`Client::get_json`].
    pub async fn get_json_query<T: DeserializeOwned, Q: Serialize + ?Sized>(
        &self,
        path: &str,
        query: &Q,
    ) -> Result<T> {
        let url = self.url(path);
        let resp = self.inner.get(&url).query(query).send().await?;
        decode(url, resp).await
    }

    /// Sends a `POST` with `body` serialized as JSON and deserializes the JSON
    /// response into `T`.
    ///
    /// # Errors
    ///
    /// Same as [`Client::get_json`].
    pub async fn post_json<T: DeserializeOwned, B: Serialize + ?Sized>(
        &self,
        path: &str,
        body: &B,
    ) -> Result<T> {
        let url = self.url(path);
        let resp = self.inner.post(&url).json(body).send().await?;
        decode(url, resp).await
    }

    /// Sends a `POST` with no body. Succeeds on any 2xx status and ignores
    /// the response body.
    ///
    /// # Errors
    ///
    /// Transport failures and non-2xx statuses.
    pub async fn post_empty(&self, path: &str) -> Result<()> {
        let url = self.url(path);
        let resp = self.inner.post(&url).send().await?;
        ensure_success(url, resp).await
    }

    /// Sends a `POST` with no body and deserializes the JSON response into `T`.
    ///
    /// # Errors
    ///
    /// Same as [`Client::get_json`].
    pub async fn post_empty_returning_json<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
        let url = self.url(path);
        let resp = self.inner.post(&url).send().await?;
        decode(url, resp).await
    }

    /// Sends a `DELETE`. Succeeds on any 2xx status and ignores the
    /// response body.
    ///
    /// # Errors
    ///
    /// Transport failures and non-2xx statuses.
    pub async fn delete(&self, path: &str) -> Result<()> {
        let url = self.url(path);
        let resp = self.inner.delete(&url).send().await?;
        ensure_success(url, resp).await
    }

    /// Uploads raw bytes with an authenticated `PUT` to `path`, using
    /// `Content-Type: application/octet-stream`.
    ///
    /// # Errors
    ///
    /// Transport failures and non-2xx statuses.
    pub async fn put_proxy_bytes(&self, path: &str, body: Vec<u8>) -> Result<()> {
        let url = self.url(path);
        let resp = self
            .inner
            .put(&url)
            .header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
            .body(body)
            .send()
            .await?;
        ensure_success(url, resp).await
    }

    /// Uploads raw bytes with a `PUT` to an absolute pre-signed URL.
    ///
    /// This is an associated function. It builds a standalone client with no
    /// default headers, so the bearer token configured in [`Client::new`] is
    /// not sent to the pre-signed URL's host.
    ///
    /// # Errors
    ///
    /// Client construction failures, transport failures, and non-2xx statuses.
    pub async fn put_presigned_bytes(presigned_url: &str, body: Vec<u8>) -> Result<()> {
        // `builder().build()` reports client setup failures as an error.
        // `reqwest::Client::new()` would panic on them instead.
        let client = reqwest::Client::builder().build()?;
        let resp = client.put(presigned_url).body(body).send().await?;
        ensure_success(presigned_url.to_string(), resp).await
    }

    /// Streams the body of a `GET` response for `path` into `sink` and
    /// returns the number of bytes written.
    ///
    /// The sink is flushed after the last chunk. On a non-2xx status nothing
    /// is written to `sink`.
    ///
    /// # Errors
    ///
    /// Transport failures, non-2xx statuses, errors while reading a body
    /// chunk, and I/O errors from writing to or flushing `sink`.
    pub async fn get_stream_to<W: AsyncWriteExt + Unpin>(
        &self,
        path: &str,
        sink: &mut W,
    ) -> Result<u64> {
        let url = self.url(path);
        let resp = self.inner.get(&url).send().await?;
        let status = resp.status();
        if !status.is_success() {
            // Best-effort: the body is only diagnostic, so a read failure
            // must not mask the status code.
            let body = resp.text().await.unwrap_or_default();
            return Err(http_error(status.as_u16(), url, body));
        }
        let mut stream = resp.bytes_stream();
        let mut written: u64 = 0;
        while let Some(chunk) = stream.next().await {
            let bytes = chunk?;
            sink.write_all(&bytes).await?;
            written += bytes.len() as u64;
        }
        sink.flush().await?;
        Ok(written)
    }
}
/// Deserializes the JSON body of a successful response into `T`.
///
/// A non-2xx status is converted into an error by `http_error` (401 becomes
/// `Error::Unauthorized`). The status is checked *before* the body is read,
/// so a failure while reading an error body still surfaces the HTTP status.
/// This matches how `ensure_success` treats error bodies.
///
/// # Errors
///
/// - the HTTP error for a non-2xx status;
/// - a transport error if a successful response's body cannot be read;
/// - `Error::Decode` if the body is not valid JSON for `T`.
async fn decode<T: DeserializeOwned>(url: String, resp: reqwest::Response) -> Result<T> {
    let status = resp.status();
    if !status.is_success() {
        // The error body is only diagnostic; read it best-effort.
        let body = resp.text().await.unwrap_or_default();
        return Err(http_error(status.as_u16(), url, body));
    }
    let text = resp.text().await?;
    serde_json::from_str(&text).map_err(|source| Error::Decode { url, source })
}
/// Resolves to `Ok(())` for any 2xx response, without reading the body.
///
/// For any other status, the body is read best-effort (an unreadable body
/// becomes an empty string) and converted into an error by `http_error`.
async fn ensure_success(url: String, resp: reqwest::Response) -> Result<()> {
    let status = resp.status();
    if !status.is_success() {
        let detail = match resp.text().await {
            Ok(text) => text,
            Err(_) => String::new(),
        };
        return Err(http_error(status.as_u16(), url, detail));
    }
    Ok(())
}
fn http_error(status: u16, url: String, body: String) -> Error {
let body = truncate(&body, 500).to_string();
if status == 401 {
Error::Unauthorized { url, body }
} else {
Error::Http { status, url, body }
}
}
/// Returns the longest prefix of `s` that is at most `n` bytes long and ends
/// on a UTF-8 character boundary.
///
/// If `s` already fits in `n` bytes, it is returned unchanged.
fn truncate(s: &str, n: usize) -> &str {
    if s.len() <= n {
        return s;
    }
    // Step back from byte `n` to the nearest character boundary so the slice
    // never splits a multi-byte character. Byte 0 is always a boundary.
    let cut = (0..=n)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn http_error_format_matches_cli_shape() {
        let e = Error::Http {
            status: 500,
            url: "https://platform.wavekat.com/api/me".into(),
            body: "boom".into(),
        };
        let s = e.to_string();
        assert!(s.contains("500"), "{s}");
        assert!(s.contains("https://platform.wavekat.com/api/me"), "{s}");
        assert!(s.contains("boom"), "{s}");
    }

    #[test]
    fn http_error_splits_401_into_unauthorized() {
        let e = http_error(
            401,
            "https://platform.wavekat.com/api/me".into(),
            "{\"error\":\"unauthenticated\"}".into(),
        );
        assert!(
            matches!(e, Error::Unauthorized { .. }),
            "expected Unauthorized, got {e:?}"
        );
        let s = e.to_string();
        assert!(s.contains("401"), "{s}");
        assert!(s.contains("https://platform.wavekat.com/api/me"), "{s}");
    }

    #[test]
    fn http_error_keeps_non_401_in_http_variant() {
        let e = http_error(
            500,
            "https://platform.wavekat.com/api/me".into(),
            "boom".into(),
        );
        assert!(
            matches!(e, Error::Http { status: 500, .. }),
            "expected Http {{ status: 500 }}, got {e:?}"
        );
    }

    #[test]
    fn http_error_truncates_long_bodies_to_500_bytes() {
        let e = http_error(
            500,
            "https://platform.wavekat.com/api/me".into(),
            "x".repeat(600),
        );
        match e {
            Error::Http { body, .. } => assert_eq!(body.len(), 500),
            other => panic!("expected Http, got {other:?}"),
        }
    }

    #[test]
    fn truncate_respects_char_boundaries() {
        // 498 ASCII bytes followed by the 2-byte "é". Byte 499 falls inside
        // "é", so the cut must back off to byte 498 and drop "é" entirely.
        let s = "a".repeat(498) + "é";
        let t = truncate(&s, 499);
        assert!(s.starts_with(t));
        assert_eq!(t.len(), 498);
    }

    #[test]
    fn truncate_leaves_short_strings_untouched() {
        assert_eq!(truncate("boom", 500), "boom");
        assert_eq!(truncate("", 0), "");
    }
}