use crate::ApiError;
use crate::Result;
use futures_core::future::BoxFuture;
use http_body_util::{BodyExt, Full};
use hyper::{http, Method, Version};
use minicbor::{CborLen, Decode, Encode};
use ockam::identity::SecureClient;
use ockam_core::api::Request;
use ockam_core::{cbor_encode_preallocate, Decodable, Encodable, Encoded, Message, TryClone};
use ockam_node::Context;
use std::str::FromStr;
use std::task::Poll;
use tonic::body::BoxBody;
use tonic::codegen::Service;
/// Forwards HTTP requests (with tonic [`BoxBody`] bodies) to a remote service through a
/// [`SecureClient`]; see the [`Service`] implementation below.
pub struct SecureClientService {
/// Client used to send the encoded requests over the secure route.
secure_client: SecureClient,
/// Node context used when sending. `None` when `Context::try_clone` failed; in that case
/// requests are not sent.
ctx: Option<Context>,
/// Address of the remote service that receives the encoded requests.
service_address: String,
}
impl Clone for SecureClientService {
fn clone(&self) -> Self {
let ctx_clone = if let Some(ctx) = &self.ctx {
ctx.try_clone().ok()
} else {
None
};
SecureClientService {
secure_client: self.secure_client.clone(),
ctx: ctx_clone,
service_address: self.service_address.clone(),
}
}
}
impl SecureClientService {
pub fn new(
secure_client: SecureClient,
ctx: &Context,
service_address: &str,
) -> SecureClientService {
SecureClientService {
secure_client,
ctx: ctx.try_clone().ok(),
service_address: service_address.to_string(),
}
}
}
impl Service<http::Request<BoxBody>> for SecureClientService {
    type Response = http::Response<BoxBody>;
    type Error = ApiError;
    type Future = BoxFuture<'static, Result<Self::Response>>;

    /// Reports readiness unconditionally; no back-pressure is applied.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Forwards `request` using a clone of this service, so the returned future owns
    /// everything it needs and can be `'static`.
    fn call(&mut self, request: http::Request<BoxBody>) -> Self::Future {
        let mut forwarder = self.clone();
        let pending = async move { forwarder.send_request(request).await };
        Box::pin(pending)
    }
}
impl SecureClientService {
async fn send_request(
&mut self,
request: http::Request<BoxBody>,
) -> Result<http::Response<BoxBody>> {
if let Some(ctx) = &self.ctx {
let ockam_request_body = Self::make_ockam_request_body(request).await?;
trace!(
"Sending a request to {} => 0#{}",
self.secure_client.secure_route(),
self.service_address
);
let r = self
.secure_client
.tell(
ctx,
&self.service_address,
Request::post("/").body(ockam_request_body),
)
.await?;
if let Some(e) = r.error()? {
trace!("Sending a request - received an error {e}");
}
};
http::Response::builder()
.body(BoxBody::default())
.map_err(ApiError::message)
}
async fn make_ockam_request_body(request: http::Request<BoxBody>) -> Result<OckamGrpcRequest> {
let mut bytes: Vec<u8> = Vec::new();
let (head, mut body) = request.into_parts();
while let Some(frame) = body.frame().await {
if let Ok(f) = frame {
if let Some(chunk) = f.data_ref() {
bytes.extend_from_slice(chunk);
}
}
}
Ok(OckamGrpcRequest::from(http::Request::from_parts(
head, bytes,
)))
}
}
/// Transport representation of an HTTP request, CBOR-encoded as a map so it can be carried
/// as an Ockam [`Message`].
///
/// The `#[n(..)]` indices are the CBOR map keys; changing them changes the encoding.
#[derive(Debug, Clone, Encode, Decode, CborLen, Message, PartialEq, Eq)]
#[cbor(map)]
pub struct OckamGrpcRequest {
/// HTTP method as text (e.g. `GET`, `POST`).
#[n(1)]
method: String,
/// Request URI as text.
#[n(2)]
uri: String,
/// HTTP protocol version.
#[n(3)]
version: HttpVersion,
/// Header `(name, value)` pairs, in the order they were read.
#[n(4)]
headers: Vec<(String, String)>,
/// Raw request body, encoded as a CBOR byte string.
#[cbor(with = "minicbor::bytes")]
#[n(5)]
body: Vec<u8>,
}
/// HTTP protocol versions supported by [`OckamGrpcRequest`].
///
/// Encoded as the bare variant index (`index_only`); the `#[n(..)]` values are that index.
#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode, CborLen)]
#[cbor(index_only)]
enum HttpVersion {
/// HTTP/0.9
#[n(0)]
Http09,
/// HTTP/1.0
#[n(1)]
Http10,
/// HTTP/1.1
#[n(2)]
Http11,
/// HTTP/2
#[n(3)]
Http2,
/// HTTP/3
#[n(4)]
Http3,
}
impl From<http::Request<Vec<u8>>> for OckamGrpcRequest {
    /// Converts an HTTP request with a fully-buffered body into its transport form.
    ///
    /// Header values that are not valid UTF-8 are decoded lossily (invalid sequences become
    /// U+FFFD). Values made of visible ASCII are kept exactly as-is.
    fn from(req: http::Request<Vec<u8>>) -> Self {
        let version = match req.version() {
            Version::HTTP_09 => HttpVersion::Http09,
            Version::HTTP_10 => HttpVersion::Http10,
            Version::HTTP_11 => HttpVersion::Http11,
            Version::HTTP_2 => HttpVersion::Http2,
            Version::HTTP_3 => HttpVersion::Http3,
            // `http::Version` is non-exhaustive; unknown versions fall back to HTTP/3.
            _ => HttpVersion::Http3,
        };
        let headers = req
            .headers()
            .iter()
            .map(|(name, value)| {
                // `HeaderValue::to_str` rejects legal header bytes outside visible ASCII
                // (obs-text), so unwrapping it could panic on valid input.
                let value = String::from_utf8_lossy(value.as_bytes()).into_owned();
                (name.to_string(), value)
            })
            .collect();
        Self {
            method: req.method().to_string(),
            uri: req.uri().to_string(),
            version,
            headers,
            body: req.into_body(),
        }
    }
}
impl OckamGrpcRequest {
    /// Rebuilds an HTTP request with a [`BoxBody`] body from this transport representation.
    ///
    /// This is the inverse of the `From<http::Request<Vec<u8>>>` conversion.
    ///
    /// # Errors
    ///
    /// Returns an [`ApiError`] in either of these cases:
    /// - the method cannot be parsed;
    /// - the URI or a header name/value is rejected by the request builder.
    pub fn make_http_request(self) -> Result<http::Request<BoxBody>> {
        let method = Method::from_str(&self.method).map_err(ApiError::message)?;
        let version = match self.version {
            HttpVersion::Http09 => Version::HTTP_09,
            HttpVersion::Http10 => Version::HTTP_10,
            HttpVersion::Http11 => Version::HTTP_11,
            HttpVersion::Http2 => Version::HTTP_2,
            HttpVersion::Http3 => Version::HTTP_3,
        };
        let mut req = http::Request::builder()
            .method(method)
            .version(version)
            .uri(self.uri);
        // The pairs are already owned, so move them into the builder instead of copying.
        for (name, value) in self.headers {
            req = req.header(name, value);
        }
        let body = Full::new(bytes::Bytes::from(self.body))
            .map_err(|never| match never {})
            .boxed_unsync();
        req.body(body).map_err(ApiError::message)
    }
}
impl Encodable for OckamGrpcRequest {
    /// Serializes the request into its CBOR encoding.
    fn encode(self) -> ockam_core::Result<Encoded> {
        let encoded = cbor_encode_preallocate(self)?;
        Ok(encoded)
    }
}
impl Decodable for OckamGrpcRequest {
    /// Deserializes a request from its CBOR encoding.
    fn decode(e: &[u8]) -> ockam_core::Result<Self> {
        minicbor::decode(e).map_err(Into::into)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ockam_core::Decodable;

    /// Converting an HTTP request and round-tripping it through CBOR must be lossless.
    #[test]
    fn test_make_ockam_request_body() {
        let original = OckamGrpcRequest::from(make_http_request());
        let encoded = minicbor::to_vec(original.clone()).unwrap();
        let decoded = <OckamGrpcRequest as Decodable>::decode(encoded.as_slice()).unwrap();
        assert_eq!(decoded, original);
    }

    /// Builds a small `GET` request with a textual body.
    fn make_http_request() -> http::Request<Vec<u8>> {
        let payload = b"hello".to_vec();
        http::Request::builder()
            .method(Method::GET)
            .uri("http://localhost:8080/")
            .body(payload)
            .unwrap()
    }
}