1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
use crate::OkaeriSdkError;
use hyper::client::HttpConnector;
use hyper::{Body, Client, Method, Request};
use hyper_timeout::TimeoutConnector;
use hyper_tls::HttpsConnector;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::env;
use std::time::Duration;
use url::Url;

type Result<T> = std::result::Result<T, OkaeriSdkError>;

/// Minimal JSON-over-HTTP(S) client shared by the SDK's service wrappers.
///
/// Every request is sent to `base_url` joined with a caller-supplied path, carries the
/// configured `headers`, and has its JSON response body deserialized into the caller's type.
pub(crate) struct OkaeriClient {
    /// Prefix that request paths are appended to.
    base_url: Url,
    /// HTTPS-capable hyper client whose connector enforces connect/read/write timeouts.
    hyper: Client<TimeoutConnector<HttpsConnector<HttpConnector>>>,
    /// Headers attached to every request (always includes the SDK `User-Agent`).
    headers: HashMap<String, String>,
}

impl OkaeriClient {
    /// Creates a client for `base_url`.
    ///
    /// `timeout` is applied separately to the connect, read and write phases of each
    /// connection. `headers` are sent with every request. A `User-Agent` entry is always
    /// inserted and replaces any value stored under that exact key.
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` return type is kept for API stability.
    pub fn new(
        base_url: Url,
        timeout: Duration,
        mut headers: HashMap<String, String>,
    ) -> Result<Self> {
        let mut connector = TimeoutConnector::new(HttpsConnector::new());
        connector.set_connect_timeout(Some(timeout));
        connector.set_read_timeout(Some(timeout));
        connector.set_write_timeout(Some(timeout));
        let hyper = Client::builder().build::<_, Body>(connector);

        headers.insert(
            String::from("User-Agent"),
            String::from("okaeri-sdk/1 (rust)"),
        );

        Ok(OkaeriClient {
            base_url,
            hyper,
            headers,
        })
    }

    /// Resolves the base URL to use.
    ///
    /// Precedence order:
    /// 1. the environment variable `env_name`;
    /// 2. `provided`;
    /// 3. `def`.
    ///
    /// An unset variable and one that is not valid Unicode both fall through silently.
    ///
    /// # Errors
    ///
    /// Returns [`OkaeriSdkError::InvalidUrl`] (carrying the offending string) when the chosen
    /// value does not parse as a URL.
    pub(crate) fn read_base_url(provided: Option<&str>, def: &str, env_name: &str) -> Result<Url> {
        let base_url =
            env::var(env_name).unwrap_or_else(|_| String::from(provided.unwrap_or(def)));
        // The borrow taken by `parse` ends before the closure moves `base_url` into the error.
        Url::parse(base_url.as_str()).map_err(|source| OkaeriSdkError::InvalidUrl {
            url: base_url,
            source,
        })
    }

    /// Resolves the request timeout to use.
    ///
    /// If the environment variable `env_name` is readable, its value is parsed as a whole
    /// number of **milliseconds**. Otherwise `provided` is used, falling back to `def`.
    ///
    /// # Errors
    ///
    /// Returns [`OkaeriSdkError::InvalidInt`] (carrying the raw value) when the environment
    /// variable is set but is not a valid `u64`.
    pub(crate) fn read_timeout(
        provided: Option<Duration>,
        def: Duration,
        env_name: &str,
    ) -> Result<Duration> {
        match env::var(env_name) {
            Ok(from) => from
                .parse::<u64>()
                .map(Duration::from_millis)
                .map_err(|_| OkaeriSdkError::InvalidInt { from }),
            Err(_) => Ok(provided.unwrap_or(def)),
        }
    }

    /// Sends `body` as a `POST` to `path` and deserializes the JSON response into `T`.
    ///
    /// Consumes the client. See `request` for the error cases.
    pub(crate) async fn post<T>(self, path: impl AsRef<str>, body: impl Into<String>) -> Result<T>
    where
        T: DeserializeOwned,
    {
        self.request(path, body, Method::POST).await
    }

    /// Sends a `GET` (with an empty body) to `path` and deserializes the JSON response into `T`.
    ///
    /// Consumes the client. See `request` for the error cases.
    pub(crate) async fn get<T>(self, path: impl AsRef<str>) -> Result<T>
    where
        T: DeserializeOwned,
    {
        self.request(path, "", Method::GET).await
    }

    /// Performs a single HTTP request and deserializes the JSON response body into `T`.
    ///
    /// # Errors
    ///
    /// - [`OkaeriSdkError::ResponseError`] with group `REQUEST_ERROR` when:
    ///   - the request cannot be built;
    ///   - dispatch fails (including timeouts);
    ///   - the status is not 2xx;
    ///   - the body cannot be read;
    ///   - the body is not UTF-8.
    /// - [`OkaeriSdkError::ResponseParseError`] (carrying the raw body) when the body is not
    ///   valid JSON for `T`.
    async fn request<T>(
        self,
        path: impl AsRef<str>,
        body: impl Into<String>,
        method: impl Into<Method>,
    ) -> Result<T>
    where
        T: DeserializeOwned,
    {
        // Concrete bindings are required: `Request::builder().method` and `Body::from` are
        // generic, so an inline `.into()` would leave the target type ambiguous.
        let method: Method = method.into();
        let body: String = body.into();

        // NOTE(review): this is plain string concatenation of the normalized `Url` display form
        // (a bare host such as `https://host` renders as `https://host/`). Callers must pass
        // paths that join cleanly with the configured base — confirm at the call sites.
        let url = format!("{}{}", self.base_url, path.as_ref());

        let mut builder = Request::builder().method(method).uri(url);
        for (key, value) in &self.headers {
            builder = builder.header(key.as_str(), value.as_str());
        }
        let req = builder
            .body(Body::from(body))
            .map_err(|err| request_error(format!("failed to create request: {}", err)))?;

        let res = self
            .hyper
            .request(req)
            .await
            .map_err(|err| request_error(format!("failed to dispatch request: {}", err)))?;

        let status = res.status();
        if !status.is_success() {
            return Err(request_error(format!(
                "received invalid status code {}",
                status
            )));
        }

        let bytes = hyper::body::to_bytes(res)
            .await
            .map_err(|err| request_error(format!("failed to process request: {}", err)))?;

        // Validate UTF-8 in place rather than copying the whole body into a new `String`.
        let body_str = std::str::from_utf8(&bytes).map_err(|err| {
            request_error(format!("failed to convert body to string: {}", err))
        })?;

        // An owned copy of the body is made only on the failure path.
        serde_json::from_str(body_str).map_err(|_| OkaeriSdkError::ResponseParseError {
            body: body_str.to_owned(),
        })
    }
}

/// Builds the `REQUEST_ERROR`-group `ResponseError`, which `OkaeriClient::request` uses for
/// every transport-level failure.
fn request_error(message: String) -> OkaeriSdkError {
    OkaeriSdkError::ResponseError {
        group: String::from("REQUEST_ERROR"),
        message,
    }
}