use crate::{ffi::http_client_v0 as ffi, ErrorCode};
use bytes::Buf;
use http::{Method, Request};
use http_body::combinators::UnsyncBoxBody;
use http_body::Body;
use http_body::Empty;
use std::error::Error as StdError;
use std::future::Future;
use std::pin::Pin;
use std::string::FromUtf8Error;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
mod pollable;
use http::header::HeaderName;
use http::header::HeaderValue;
use pollable::PollableResponse;
mod error;
pub use error::Error;
#[doc(hidden)]
pub use ffi::API as FFI_API;
/// Builder for a native channel to a single host.
///
/// Finish it with [`ConnectionBuilder::build_grpc`] or
/// [`ConnectionBuilder::build_http`].
#[derive(Debug, Clone)]
pub struct ConnectionBuilder {
    /// Host the channel is created for.
    host: String,
    /// Whether authentication is requested; passed to the host as `0`/`1`.
    auth: bool,
    /// Optional user agent; an empty string is passed to the host when unset.
    user_agent: Option<String>,
}
impl ConnectionBuilder {
    /// Starts a builder for `host` with authentication disabled and no
    /// custom user agent.
    pub fn for_host(host: impl ToString) -> Self {
        let host = host.to_string();
        Self {
            host,
            auth: false,
            user_agent: None,
        }
    }

    /// Requests that the channel be created with authentication enabled.
    pub fn with_authentication(self) -> Self {
        Self { auth: true, ..self }
    }

    /// Sets the user agent handed to the host when the channel is created.
    pub fn with_user_agent(self, user_agent: impl ToString) -> Self {
        Self {
            user_agent: Some(user_agent.to_string()),
            ..self
        }
    }

    /// Asks the host to create the channel and waits until it is ready.
    ///
    /// `ErrorCode::Unavailable` from the creation poll means "not ready yet";
    /// any other error code is returned to the caller.
    async fn build_internal(self) -> Result<ffi::ChannelHandle, Error> {
        let agent = self.user_agent.as_deref().unwrap_or("");
        let pending = ffi::channel_create(&self.host, agent, u32::from(self.auth))?;
        let channel = PollableResponse::<_, _> {
            handle: pending,
            poll_fn: |h: ffi::ChannelBuildHandle| match ffi::channel_create_poll(h) {
                Err(code) if code == ErrorCode::Unavailable => Poll::Pending,
                settled => Poll::Ready(settled),
            },
            drop_fn: ffi::channel_create_drop,
        }
        .await?;
        Ok(channel)
    }

    /// Builds the channel and wraps it as a [`GrpcChannel`].
    pub async fn build_grpc(self) -> Result<GrpcChannel, Error> {
        self.build_internal().await.map(GrpcChannel::new)
    }

    /// Builds the channel and wraps it as an [`HttpClient`].
    pub async fn build_http(self) -> Result<HttpClient, Error> {
        self.build_internal().await.map(HttpClient::new)
    }
}
/// Owning wrapper around a native channel handle; dropping it hands the
/// handle to `ffi::channel_drop`.
struct Channel(ffi::ChannelHandle);

impl Drop for Channel {
    fn drop(&mut self) {
        let Self(handle) = *self;
        ffi::channel_drop(handle);
    }
}
/// A cloneable gRPC transport; every clone shares one native channel.
#[derive(Clone)]
pub struct GrpcChannel {
    native: Arc<Channel>,
}

impl GrpcChannel {
    /// Takes ownership of `native`; the channel is released once the last
    /// clone (and thus the last `Arc<Channel>`) is dropped.
    fn new(native: ffi::ChannelHandle) -> Self {
        let shared = Arc::new(Channel(native));
        Self { native: shared }
    }
}
/// An HTTP client for the host configured on [`ConnectionBuilder`], obtained
/// from [`ConnectionBuilder::build_http`]. Clones share one native channel.
#[derive(Clone)]
pub struct HttpClient {
    // Shared ownership of the native channel; `Channel`'s `Drop` releases it
    // when the last clone goes away.
    native: Arc<Channel>,
}
/// A fully buffered HTTP response, as returned by [`HttpClient::get`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpResponse {
    /// Status code of the response.
    pub code: http::StatusCode,
    /// The complete response body as raw bytes.
    pub body: Vec<u8>,
}

impl HttpResponse {
    /// Consumes the response and interprets its body as UTF-8 text.
    ///
    /// # Errors
    ///
    /// Returns [`FromUtf8Error`] if the body is not valid UTF-8; the raw
    /// bytes can be recovered with [`FromUtf8Error::into_bytes`].
    pub fn body_string(self) -> Result<String, FromUtf8Error> {
        String::from_utf8(self.body)
    }
}
impl HttpClient {
fn new(native: ffi::ChannelHandle) -> Self {
Self {
native: Arc::new(Channel(native)),
}
}
async fn internal_call<D, E, B>(
&mut self,
request: http::Request<B>,
) -> Result<http::Response<UnsyncBoxBody<D, E>>, Error>
where
D: Buf + IntoIterator<Item = u8> + Send + From<Vec<u8>> + Unpin + 'static,
E: StdError + Send + Sync + 'static,
B: http_body::Body<Data = D, Error = E> + Send + Unpin + 'static,
http::Error: From<E>,
{
let channel = self.native.0;
let handle = initiate_call(channel, request).await?;
let (handle, status, version) = poll_call(handle).await?;
build_response(handle, status, version)
}
pub async fn get<E>(
&mut self,
path: impl TryInto<http::uri::PathAndQuery, Error = E>,
) -> Result<HttpResponse, Error>
where
E: StdError + Send + Sync + 'static,
http::Error: From<E>,
{
let path: http::uri::PathAndQuery = path.try_into().map_err(Error::from_specific)?;
let req = Request::builder()
.method(Method::GET)
.uri(path)
.body(Empty::<bytes::Bytes>::new().map_err(|e| match e {}))
.map_err(Error::Error)?;
let res = self.internal_call(req).await?;
let (head, body) = res.into_parts();
let response_bytes = to_bytes(body).await.map_err(Error::from_specific)?;
Ok(HttpResponse {
code: head.status,
body: response_bytes,
})
}
}
impl<D, E, B> tower::Service<http::Request<B>> for GrpcChannel
where
D: Buf + IntoIterator<Item = u8> + Send + From<Vec<u8>> + Unpin + 'static,
E: Into<Box<dyn StdError + Send + Sync>> + 'static,
B: http_body::Body<Data = D, Error = E> + Send + Unpin + 'static,
{
type Response = http::Response<UnsyncBoxBody<D, E>>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, request: http::Request<B>) -> Self::Future {
let channel = self.native.0;
let f = async move {
let handle = initiate_call(channel, request).await?;
let (handle, status, version) = poll_call(handle).await?;
build_response(handle, status, version)
};
Box::pin(f)
}
}
async fn poll_call(
handle: ffi::AsyncCallHandle,
) -> Result<(ffi::HttpResponseHandle, http::StatusCode, http::Version), Error> {
fn poll_call(
handle: ffi::AsyncCallHandle,
) -> Poll<Result<(ffi::HttpResponseHandle, u16, ffi::Version), ErrorCode>> {
let ready = ffi::channel_call_poll(handle);
match ready {
Ok(_) => {
let mut status = 0;
let mut version = ffi::Version::V1_0;
let response_handle =
ffi::channel_call_retrieve_response(handle, &mut status, &mut version)?;
Poll::Ready(Ok((response_handle, status, version)))
}
Err(e) if e == ErrorCode::Unavailable => Poll::Pending,
Err(e) => Poll::Ready(Err(e)),
}
}
let (handle, status, version) = PollableResponse {
handle,
poll_fn: poll_call,
drop_fn: ffi::channel_call_drop,
}
.await?;
let status = http::StatusCode::from_u16(status).map_err(|e| Error::Other(e.into()))?;
let version = match version {
ffi::Version::V0_9 => http::Version::HTTP_09,
ffi::Version::V1_0 => http::Version::HTTP_10,
ffi::Version::V1_1 => http::Version::HTTP_11,
ffi::Version::V2_0 => http::Version::HTTP_2,
ffi::Version::V3_0 => http::Version::HTTP_3,
version => return Err(Error::InvalidHttpVersion(version as u32)),
};
Ok((handle, status, version))
}
/// Translates `request` to the FFI representation and starts it on
/// `channel`, returning the in-flight call handle.
///
/// The request body is buffered in full before the call is issued.
///
/// # Errors
///
/// * [`Error::Other`] if the HTTP version or method has no FFI equivalent
///   (e.g. an extension method), or if reading the request body fails.
/// * [`Error::FFIError`] if the host rejects the call.
async fn initiate_call<D, E, B>(
    channel: ffi::ChannelHandle,
    request: http::Request<B>,
) -> Result<ffi::AsyncCallHandle, Error>
where
    D: Buf + IntoIterator<Item = u8>,
    E: Into<Box<dyn StdError + Send + Sync>> + 'static,
    B: http_body::Body<Data = D, Error = E> + Send + Unpin,
{
    let (head, body) = request.into_parts();

    // Translate version and method first, so an unsupported request is
    // rejected before its body is read.
    let ffi_version = match head.version {
        http::Version::HTTP_09 => ffi::Version::V0_9,
        http::Version::HTTP_10 => ffi::Version::V1_0,
        http::Version::HTTP_11 => ffi::Version::V1_1,
        http::Version::HTTP_2 => ffi::Version::V2_0,
        http::Version::HTTP_3 => ffi::Version::V3_0,
        unhandled => {
            return Err(Error::Other(
                format!("unhandled HTTP version: {unhandled:?}").into(),
            ));
        }
    };
    let ffi_method = match head.method {
        http::Method::GET => ffi::Method::Get,
        http::Method::POST => ffi::Method::Post,
        http::Method::PUT => ffi::Method::Put,
        http::Method::DELETE => ffi::Method::Delete,
        http::Method::PATCH => ffi::Method::Patch,
        http::Method::HEAD => ffi::Method::Head,
        http::Method::OPTIONS => ffi::Method::Options,
        http::Method::TRACE => ffi::Method::Trace,
        http::Method::CONNECT => ffi::Method::Connect,
        // Extension methods (`Method::from_bytes(b"PURGE")`, ...) come from
        // callers, so they must yield an error rather than a panic.
        unhandled => {
            return Err(Error::Other(
                format!("unhandled HTTP method: {unhandled:?}").into(),
            ));
        }
    };

    let body = to_bytes(body).await.map_err(|e| Error::Other(e.into()))?;

    #[cfg(not(target_pointer_width = "32"))]
    compile_error!("the below code is only valid with 32-bit pointers and usize");
    // The raw pointers borrow from `head.headers`, which stays alive until
    // after `ffi::channel_call` returns.
    let headers: Vec<ffi::HttpHeaderKeyValue> = head
        .headers
        .iter()
        .map(|(key, value)| ffi::HttpHeaderKeyValue {
            key_str_ptr: key.as_str().as_ptr() as u32,
            key_str_len: key.as_str().len() as u32,
            value_ptr: value.as_ref().as_ptr() as u32,
            value_len: value.as_ref().len() as u32,
        })
        .collect();

    let handle = ffi::channel_call(
        channel,
        &head.uri.to_string(),
        ffi_version,
        ffi_method,
        &headers,
        &body,
    )
    .map_err(|e| Error::FFIError(e.into()))?;
    Ok(handle)
}
fn build_response<D, E>(
handle: ffi::HttpResponseHandle,
status: http::StatusCode,
version: http::Version,
) -> Result<http::Response<UnsyncBoxBody<D, E>>, Error>
where
D: Buf + Unpin + Send + From<Vec<u8>> + 'static,
E: Into<Box<dyn StdError + Send + Sync>> + 'static,
{
let raw_body = ffi::http_response_body(handle);
let body = http_body::Full::from(raw_body)
.map_err(|e| match e {})
.boxed_unsync();
let header_count = ffi::http_response_header_count(handle);
let mut headers = http::HeaderMap::<HeaderValue>::default();
for index in 0..header_count {
let key = ffi::http_response_header_key(handle, index);
let value = ffi::http_response_header_value(handle, index);
headers.insert(
HeaderName::from_lowercase(key.as_bytes()).map_err(Error::from_specific)?,
value.try_into().map_err(Error::from_specific)?,
);
}
ffi::http_response_drop(handle);
let mut response = http::Response::new(body);
*response.status_mut() = status;
*response.version_mut() = version;
*response.headers_mut() = headers;
Ok(response)
}
/// Drains `body` into one contiguous byte buffer.
///
/// Chunks are copied slice-by-slice through [`Buf`]. The `IntoIterator`
/// bound is no longer needed but is kept so existing call sites are unaffected.
///
/// # Errors
///
/// Returns the body's own error if reading any data frame fails.
///
/// # Panics
///
/// Panics with "too long message" if the accumulated size would exceed
/// 2^31 bytes. The check runs before copying, so the oversized data is never
/// buffered.
async fn to_bytes<
    D: Buf + IntoIterator<Item = u8>,
    E: Into<Box<dyn StdError + Send + Sync>> + 'static,
    B: http_body::Body<Data = D, Error = E> + Unpin,
>(
    mut body: B,
) -> Result<Vec<u8>, E> {
    const MAX_LEN: usize = 1usize << 31;
    let mut accum = Vec::new();
    while let Some(frame) = body.data().await {
        let mut data = frame?;
        let incoming = data.remaining();
        // Same limit as before (total length after this frame), checked up
        // front; `saturating_add` avoids wrap-around on 32-bit targets.
        if accum.len().saturating_add(incoming) > MAX_LEN {
            panic!("too long message");
        }
        accum.reserve(incoming);
        while data.has_remaining() {
            let slice = data.chunk();
            let copied = slice.len();
            accum.extend_from_slice(slice);
            data.advance(copied);
        }
    }
    Ok(accum)
}