#![cfg(feature = "reqwest")]
pub use reqwest;
use std::sync::Arc;
use futures_util::StreamExt;
use hyper_rustls::ConfigBuilderExt;
use reqwest::{
header::{HeaderMap, AUTHORIZATION, CONTENT_TYPE},
Body, Client, Method, Request, StatusCode, Url,
};
use crate::{
agent::{
agent_error::HttpErrorPayload,
http_transport::{IC0_DOMAIN, IC0_SUB_DOMAIN},
AgentFuture, ReplicaV2Transport,
},
export::Principal,
AgentError, RequestId,
};
/// Supplies HTTP Basic credentials for requests sent by the transport.
///
/// Both methods receive the full request URL as a string and yield a
/// `(username, password)` pair; the transport joins the pair as
/// `username:password` when building the `Authorization` header. An `Err`
/// message is surfaced to the caller as `AgentError::AuthenticationError`.
pub trait PasswordManager: Send + Sync {
/// Looks up credentials without requiring them: this is what the transport
/// consults before sending a request for the first time. `Ok(None)` means
/// no `Authorization` header is attached.
fn cached(&self, url: &str) -> Result<Option<(String, String)>, String>;
/// Produces credentials that are now mandatory: the transport calls this
/// after the server answered `401 Unauthorized`, and retries the request
/// with the returned pair.
fn required(&self, url: &str) -> Result<(String, String), String>;
}
impl dyn PasswordManager {
    /// Fetches credentials for `url`, choosing the lookup strategy from `cached`.
    ///
    /// With `cached == true` the optional [`PasswordManager::cached`] lookup is
    /// used; otherwise [`PasswordManager::required`] is called and its pair is
    /// wrapped in `Some`. A failure message from either call becomes
    /// `AgentError::AuthenticationError`.
    fn get(&self, cached: bool, url: &str) -> Result<Option<(String, String)>, AgentError> {
        let lookup = if cached {
            self.cached(url)
        } else {
            self.required(url).map(Some)
        };
        lookup.map_err(AgentError::AuthenticationError)
    }
}
// `ReqwestHttpReplicaV2Transport` derives `Debug` while holding an
// `Arc<dyn PasswordManager>`, so the trait object itself needs a `Debug` impl.
// NOTE(review): `impl_debug_empty!` is defined elsewhere; presumably it
// generates that impl without printing any contents — confirm.
impl_debug_empty!(dyn PasswordManager);
/// A [`ReplicaV2Transport`] implementation that sends requests over HTTP
/// through a `reqwest` [`Client`].
#[derive(Debug)]
pub struct ReqwestHttpReplicaV2Transport {
/// Base URL that every endpoint is joined onto; the constructor builds it
/// by joining `api/v2/` onto the parsed replica URL.
url: Url,
/// HTTP client used to execute every request.
client: Client,
/// Optional source of HTTP Basic credentials; `None` until one of the
/// `with_*password_manager` builders is called.
password_manager: Option<Arc<dyn PasswordManager>>,
/// Optional upper bound, in bytes, on an accepted response body; `None`
/// means no limit is enforced.
max_response_body_size: Option<usize>,
}
impl ReqwestHttpReplicaV2Transport {
pub fn create<U: Into<String>>(url: U) -> Result<Self, AgentError> {
let mut tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_webpki_roots()
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Self::create_with_client(
url,
Client::builder()
.use_preconfigured_tls(tls_config)
.build()
.expect("Could not create HTTP client."),
)
}
/// Creates a transport for the replica at `url` that sends requests through
/// the supplied `client`.
///
/// When the URL's domain ends with `IC0_SUB_DOMAIN`, the host is replaced by
/// `IC0_DOMAIN`. The relative path `api/v2/` is then joined onto the URL to
/// form the base for all endpoints. No password manager or response size
/// limit is configured.
///
/// # Errors
///
/// Returns `AgentError::InvalidReplicaUrl` carrying the original string when
/// parsing, host replacement, or joining fails.
pub fn create_with_client<U: Into<String>>(url: U, client: Client) -> Result<Self, AgentError> {
    let raw = url.into();

    let api_root = Url::parse(&raw).and_then(|mut base| {
        let on_ic0_subdomain = base
            .domain()
            .map_or(false, |domain| domain.ends_with(IC0_SUB_DOMAIN));
        if on_ic0_subdomain {
            base.set_host(Some(IC0_DOMAIN))?;
        }
        base.join("api/v2/")
    });

    match api_root {
        Ok(url) => Ok(Self {
            url,
            client,
            password_manager: None,
            max_response_body_size: None,
        }),
        Err(_) => Err(AgentError::InvalidReplicaUrl(raw)),
    }
}
/// Returns this transport configured to obtain credentials from
/// `password_manager`, which is placed behind an `Arc` first.
pub fn with_password_manager<P: 'static + PasswordManager>(self, password_manager: P) -> Self {
    let shared: Arc<dyn PasswordManager> = Arc::new(password_manager);
    self.with_arc_password_manager(shared)
}
/// Returns this transport configured to obtain credentials from an
/// already shared `password_manager`, replacing any previous one.
pub fn with_arc_password_manager(mut self, password_manager: Arc<dyn PasswordManager>) -> Self {
    self.password_manager = Some(password_manager);
    self
}
pub fn with_max_response_body_size(self, max_response_body_size: usize) -> Self {
ReqwestHttpReplicaV2Transport {
max_response_body_size: Some(max_response_body_size),
..self
}
}
/// Borrows the configured password manager, or `None` when none was set.
pub fn password_manager(&self) -> Option<&dyn PasswordManager> {
    match &self.password_manager {
        Some(shared) => Some(&**shared),
        None => None,
    }
}
/// Sets an HTTP Basic `Authorization` header on `http_request` when the
/// password manager yields credentials for the request's URL.
///
/// `cached` selects which password manager lookup is used. Nothing is
/// changed when no password manager is configured or when the lookup
/// returns no credentials.
///
/// # Errors
///
/// Propagates the `AgentError` produced by the credential lookup.
fn maybe_add_authorization(
    &self,
    http_request: &mut Request,
    cached: bool,
) -> Result<(), AgentError> {
    let manager = match &self.password_manager {
        Some(manager) => manager,
        None => return Ok(()),
    };

    let credentials = manager.get(cached, http_request.url().as_str())?;
    if let Some((username, password)) = credentials {
        let encoded = base64::encode(format!("{}:{}", username, password));
        let header_value = format!("Basic {}", encoded).parse().unwrap();
        http_request
            .headers_mut()
            .insert(AUTHORIZATION, header_value);
    }
    Ok(())
}
/// Sends `http_request` and buffers the whole response body in memory.
///
/// Returns the response status, a copy of the response headers, and the
/// body bytes.
///
/// # Errors
///
/// * `AgentError::TransportError` if sending the request or reading a body
///   chunk fails.
/// * `AgentError::ResponseSizeExceededLimit` if a size limit is configured
///   and either the advertised content length or the bytes actually received
///   exceed it.
async fn request(
    &self,
    http_request: Request,
) -> Result<(StatusCode, HeaderMap, Vec<u8>), AgentError> {
    /// Largest buffer reserved up front when no size limit is configured.
    /// The buffer still grows as needed while chunks arrive.
    const MAX_PREALLOCATION: usize = 1024 * 1024;

    let response = self
        .client
        .execute(http_request)
        .await
        .map_err(|x| AgentError::TransportError(Box::new(x)))?;

    let http_status = response.status();
    let response_headers = response.headers().clone();
    let size_limit = self.max_response_body_size;
    let content_length = response.content_length();

    // Reject early when the advertised length is already over the limit.
    // The comparison is done in `u64` (widening `usize` is lossless) so a
    // large length cannot be truncated into an acceptable one on 32-bit targets.
    if let (Some(limit), Some(advertised)) = (size_limit, content_length) {
        if advertised > limit as u64 {
            return Err(AgentError::ResponseSizeExceededLimit());
        }
    }

    // The advertised length is controlled by the server, so never reserve
    // more than the configured limit (or a fixed cap when there is none).
    // Reserving it unchecked could panic with a capacity overflow or exhaust
    // memory before a single byte is received.
    let preallocation_cap = size_limit.unwrap_or(MAX_PREALLOCATION) as u64;
    // Casting back is lossless: the value is bounded by a `usize` above.
    let capacity = content_length.map_or(0, |n| n.min(preallocation_cap) as usize);
    let mut body: Vec<u8> = Vec::with_capacity(capacity);

    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|x| AgentError::TransportError(Box::new(x)))?;
        // The advertised length may be absent or wrong, so the limit is also
        // enforced against the bytes actually received.
        if let Some(limit) = size_limit {
            if body.len().saturating_add(chunk.len()) > limit {
                return Err(AgentError::ResponseSizeExceededLimit());
            }
        }
        body.extend_from_slice(chunk.as_ref());
    }

    Ok((http_status, response_headers, body))
}
/// Sends a CBOR request to `endpoint` (resolved against the transport's
/// base URL) and returns the response body.
///
/// Cached credentials, if any, are attached before the first attempt. When
/// the server answers `401 Unauthorized`, the password manager is asked for
/// required credentials and the request is sent again. This repeats until
/// the server returns another status or the password manager returns an
/// error. Without a password manager the `401` is returned as an HTTP error
/// instead of being retried.
///
/// # Errors
///
/// * The error produced by joining `endpoint` onto the base URL.
/// * `AgentError::AuthenticationError` if the password manager fails.
/// * `AgentError::CannotUseAuthenticationOnNonSecureUrl` if the server
///   answers `401` and the base URL is neither `https` nor on `localhost`.
/// * `AgentError::HttpError` for any final 4xx or 5xx status; it carries the
///   status, the response `Content-Type` (when readable) and the body.
/// * Any error returned while sending the request or reading the response.
async fn execute(
    &self,
    method: Method,
    endpoint: &str,
    body: Option<Vec<u8>>,
) -> Result<Vec<u8>, AgentError> {
    let url = self.url.join(endpoint)?;
    let mut http_request = Request::new(method, url);
    http_request.headers_mut().insert(
        CONTENT_TYPE,
        "application/cbor"
            .parse()
            .expect("static content type is a valid header value"),
    );
    self.maybe_add_authorization(&mut http_request, true)?;
    *http_request.body_mut() = body.map(Body::from);

    let (status, headers, response_body) = loop {
        let attempt = http_request
            .try_clone()
            .expect("request body is in-memory bytes, so the request is always cloneable");
        let (status, headers, response_body) = self.request(attempt).await?;

        if status != StatusCode::UNAUTHORIZED {
            break (status, headers, response_body);
        }
        // Credentials are only ever sent over https or to localhost.
        if self.url.scheme() != "https" && self.url.host_str() != Some("localhost") {
            return Err(AgentError::CannotUseAuthenticationOnNonSecureUrl());
        }
        // Without a password manager the request cannot change, so sending it
        // again would loop forever; report the 401 to the caller instead.
        if self.password_manager.is_none() {
            break (status, headers, response_body);
        }
        self.maybe_add_authorization(&mut http_request, false)?;
    };

    if status.is_client_error() || status.is_server_error() {
        Err(AgentError::HttpError(HttpErrorPayload {
            status: status.into(),
            content_type: headers
                .get(CONTENT_TYPE)
                .and_then(|value| value.to_str().ok())
                .map(str::to_owned),
            content: response_body,
        }))
    } else {
        Ok(response_body)
    }
}
}
impl ReplicaV2Transport for ReqwestHttpReplicaV2Transport {
    /// POSTs `envelope` to `canister/<id>/call`. The response body is
    /// discarded and the request id is not used.
    fn call(
        &self,
        effective_canister_id: Principal,
        envelope: Vec<u8>,
        _request_id: RequestId,
    ) -> AgentFuture<()> {
        Box::pin(async move {
            let path = format!("canister/{}/call", effective_canister_id.to_text());
            self.execute(Method::POST, &path, Some(envelope))
                .await
                .map(|_| ())
        })
    }

    /// POSTs `envelope` to `canister/<id>/read_state` and returns the
    /// response body.
    fn read_state(
        &self,
        effective_canister_id: Principal,
        envelope: Vec<u8>,
    ) -> AgentFuture<Vec<u8>> {
        Box::pin(async move {
            let path = format!("canister/{}/read_state", effective_canister_id);
            self.execute(Method::POST, &path, Some(envelope)).await
        })
    }

    /// POSTs `envelope` to `canister/<id>/query` and returns the response
    /// body.
    fn query(&self, effective_canister_id: Principal, envelope: Vec<u8>) -> AgentFuture<Vec<u8>> {
        Box::pin(async move {
            let path = format!("canister/{}/query", effective_canister_id);
            self.execute(Method::POST, &path, Some(envelope)).await
        })
    }

    /// GETs the `status` endpoint and returns the response body.
    fn status(&self) -> AgentFuture<Vec<u8>> {
        Box::pin(async move {
            let no_body = None;
            self.execute(Method::GET, "status", no_body).await
        })
    }
}
#[cfg(test)]
mod test {
    use super::ReqwestHttpReplicaV2Transport;

    /// Checks how replica URLs are rewritten into the transport's base URL:
    /// each `(input, expected)` pair is fed to `create` and compared with
    /// the stored URL.
    #[test]
    fn redirect() {
        const CASES: &[(&str, &str)] = &[
            ("https://ic0.app", "https://ic0.app/api/v2/"),
            ("https://IC0.app", "https://ic0.app/api/v2/"),
            ("https://foo.ic0.app", "https://ic0.app/api/v2/"),
            ("https://foo.IC0.app", "https://ic0.app/api/v2/"),
            ("https://foo.Ic0.app", "https://ic0.app/api/v2/"),
            ("https://foo.iC0.app", "https://ic0.app/api/v2/"),
            ("https://foo.bar.ic0.app", "https://ic0.app/api/v2/"),
            ("https://ic0.app/foo/", "https://ic0.app/foo/api/v2/"),
            ("https://foo.ic0.app/foo/", "https://ic0.app/foo/api/v2/"),
            ("https://ic1.app", "https://ic1.app/api/v2/"),
            ("https://foo.ic1.app", "https://foo.ic1.app/api/v2/"),
            ("https://ic0.app.ic1.app", "https://ic0.app.ic1.app/api/v2/"),
            ("https://fooic0.app", "https://fooic0.app/api/v2/"),
            ("https://fooic0.app.ic0.app", "https://ic0.app/api/v2/"),
        ];

        for &(base, expected) in CASES {
            let transport = ReqwestHttpReplicaV2Transport::create(base).unwrap();
            assert_eq!(transport.url.as_str(), expected, "{}", base);
        }
    }
}