use crate::serde::{SerdeError, SerializableStruct, ShapeDeserializer};
use crate::{Schema, ShapeId};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::endpoint::Endpoint;
/// A client protocol whose request and response types are fixed associated types.
///
/// Implement this trait (rather than [`ClientProtocol`] directly) when a protocol
/// works with exactly one request type and one response type. A blanket
/// implementation in this module makes every implementor a
/// `ClientProtocol<Self::Request, Self::Response>`. That is the form stored behind
/// [`SharedClientProtocol`].
pub trait ClientProtocolInner: Send + Sync + std::fmt::Debug {
    /// The request type this protocol produces.
    type Request;
    /// The response type this protocol consumes.
    type Response;

    /// Returns the shape ID that identifies this protocol.
    fn protocol_id(&self) -> &ShapeId;

    /// Serializes `input`, described by `input_schema`, into a protocol request.
    ///
    /// `endpoint` is the endpoint string the request is built against.
    /// NOTE(review): whether this is a full URL or only a base is up to the
    /// implementation; `update_endpoint` can re-target the request later.
    ///
    /// # Errors
    /// Returns a [`SerdeError`] when the request cannot be produced.
    fn serialize_request(
        &self,
        input: &dyn SerializableStruct,
        input_schema: &Schema,
        endpoint: &str,
        cfg: &ConfigBag,
    ) -> Result<Self::Request, SerdeError>;

    /// Returns a deserializer that reads `response` according to `output_schema`.
    ///
    /// The returned deserializer is bounded by `'a`, so it may borrow from `response`.
    ///
    /// # Errors
    /// Returns a [`SerdeError`] when a deserializer cannot be created.
    fn deserialize_response<'a>(
        &self,
        response: &'a Self::Response,
        output_schema: &Schema,
        cfg: &ConfigBag,
    ) -> Result<Box<dyn ShapeDeserializer + 'a>, SerdeError>;

    /// Applies a resolved `endpoint` to an already-built `request`, in place.
    ///
    /// Protocols using the HTTP request type can delegate to [`apply_http_endpoint`].
    ///
    /// # Errors
    /// Returns a [`SerdeError`] when the endpoint cannot be applied.
    fn update_endpoint(
        &self,
        request: &mut Self::Request,
        endpoint: &Endpoint,
        cfg: &ConfigBag,
    ) -> Result<(), SerdeError>;

    /// Returns the payload codec exposed by this protocol, if any.
    ///
    /// Defaults to `None`.
    fn payload_codec(&self) -> Option<&dyn crate::codec::DynCodec> {
        None
    }
}
/// A client protocol generic over its request (`Req`) and response (`Res`) types.
///
/// This is the form stored as `dyn ClientProtocol<Req, Res>` inside
/// [`SharedClientProtocol`]. Both type parameters default to the
/// `aws_smithy_runtime_api::http` request and response types. Any
/// [`ClientProtocolInner`] implementor gets this trait automatically through a
/// blanket implementation in this module.
pub trait ClientProtocol<
    Req = aws_smithy_runtime_api::http::Request,
    Res = aws_smithy_runtime_api::http::Response,
>: Send + Sync + std::fmt::Debug
{
    /// Returns the shape ID that identifies this protocol.
    fn protocol_id(&self) -> &ShapeId;

    /// Serializes `input`, described by `input_schema`, into a `Req` built
    /// against `endpoint`.
    ///
    /// # Errors
    /// Returns a [`SerdeError`] when the request cannot be produced.
    fn serialize_request(
        &self,
        input: &dyn SerializableStruct,
        input_schema: &Schema,
        endpoint: &str,
        cfg: &ConfigBag,
    ) -> Result<Req, SerdeError>;

    /// Returns a deserializer that reads `response` according to `output_schema`.
    ///
    /// The returned deserializer is bounded by `'a`, so it may borrow from `response`.
    ///
    /// # Errors
    /// Returns a [`SerdeError`] when a deserializer cannot be created.
    fn deserialize_response<'a>(
        &self,
        response: &'a Res,
        output_schema: &Schema,
        cfg: &ConfigBag,
    ) -> Result<Box<dyn ShapeDeserializer + 'a>, SerdeError>;

    /// Applies a resolved `endpoint` to an already-built `request`, in place.
    ///
    /// # Errors
    /// Returns a [`SerdeError`] when the endpoint cannot be applied.
    fn update_endpoint(
        &self,
        request: &mut Req,
        endpoint: &Endpoint,
        cfg: &ConfigBag,
    ) -> Result<(), SerdeError>;

    /// Returns the payload codec exposed by this protocol, if any.
    ///
    /// Unlike [`ClientProtocolInner::payload_codec`], this has no default, so
    /// direct implementors must provide it.
    fn payload_codec(&self) -> Option<&dyn crate::codec::DynCodec>;
}
impl<P> ClientProtocol<P::Request, P::Response> for P
where
P: ClientProtocolInner,
{
fn protocol_id(&self) -> &ShapeId {
<Self as ClientProtocolInner>::protocol_id(self)
}
fn serialize_request(
&self,
input: &dyn SerializableStruct,
input_schema: &Schema,
endpoint: &str,
cfg: &ConfigBag,
) -> Result<P::Request, SerdeError> {
<Self as ClientProtocolInner>::serialize_request(self, input, input_schema, endpoint, cfg)
}
fn deserialize_response<'a>(
&self,
response: &'a P::Response,
output_schema: &Schema,
cfg: &ConfigBag,
) -> Result<Box<dyn ShapeDeserializer + 'a>, SerdeError> {
<Self as ClientProtocolInner>::deserialize_response(self, response, output_schema, cfg)
}
fn update_endpoint(
&self,
request: &mut P::Request,
endpoint: &Endpoint,
cfg: &ConfigBag,
) -> Result<(), SerdeError> {
<Self as ClientProtocolInner>::update_endpoint(self, request, endpoint, cfg)
}
fn payload_codec(&self) -> Option<&dyn crate::codec::DynCodec> {
<Self as ClientProtocolInner>::payload_codec(self)
}
}
/// Applies a resolved [`Endpoint`] to an HTTP request, in place.
///
/// Steps:
/// 1. If `cfg` holds an `EndpointPrefix`, the prefix is inserted in front of the
///    endpoint's authority, e.g. `myprefix.` + `https://service.example.com`
///    gives `https://myprefix.service.example.com`.
/// 2. The resulting endpoint URL is set on the request's URI via `set_endpoint`.
///    The request's existing path is kept under the endpoint's path, e.g.
///    `https://example.com/base` + `/operation` gives
///    `https://example.com/base/operation`.
/// 3. For every header carried by the endpoint, any existing values on the request
///    are removed, then all of the endpoint's values are appended in order.
///
/// # Errors
/// Returns a [`SerdeError`] when any of the following happens:
/// - a prefix is configured and the endpoint URL does not parse as a URI;
/// - the endpoint cannot be applied to the request URI;
/// - an endpoint header name or value is not a valid HTTP header component.
///
/// The request may already be partially updated when an error is returned.
pub fn apply_http_endpoint(
    request: &mut aws_smithy_runtime_api::http::Request,
    endpoint: &Endpoint,
    cfg: &ConfigBag,
) -> Result<(), SerdeError> {
    use std::borrow::Cow;
    let endpoint_prefix = cfg.load::<aws_smithy_runtime_api::client::endpoint::EndpointPrefix>();
    let endpoint_url = match endpoint_prefix {
        // No prefix: use the endpoint URL as-is without allocating.
        None => Cow::Borrowed(endpoint.url()),
        Some(prefix) => {
            // Split the URL so the prefix can be spliced in front of the authority.
            let parsed: http::Uri = endpoint
                .url()
                .parse()
                .map_err(|e| SerdeError::custom(format!("invalid endpoint URI: {e}")))?;
            let scheme = parsed.scheme_str().unwrap_or_default();
            let prefix = prefix.as_str();
            let authority = parsed.authority().map(|a| a.as_str()).unwrap_or_default();
            let path_and_query = parsed
                .path_and_query()
                .map(|pq| pq.as_str())
                .unwrap_or_default();
            Cow::Owned(format!("{scheme}://{prefix}{authority}{path_and_query}"))
        }
    };
    request.uri_mut().set_endpoint(&endpoint_url).map_err(|e| {
        SerdeError::custom(format!("failed to apply endpoint `{endpoint_url}`: {e}"))
    })?;
    for (header_name, header_values) in endpoint.headers() {
        // Endpoint headers replace, rather than merge with, any existing values.
        request.headers_mut().remove(header_name);
        for value in header_values {
            // `try_append` reports invalid header names/values as an error;
            // `append` would panic on them.
            request
                .headers_mut()
                .try_append(header_name.to_owned(), value.to_owned())
                .map_err(|e| {
                    SerdeError::custom(format!("invalid endpoint header `{header_name}`: {e}"))
                })?;
        }
    }
    Ok(())
}
/// A cheaply clonable, type-erased handle to a [`ClientProtocol`].
///
/// The protocol is held behind an `Arc<dyn ClientProtocol<Req, Res>>`, so clones
/// share the same protocol instance. The handle dereferences to
/// `dyn ClientProtocol<Req, Res>`. `Req`/`Res` default to the
/// `aws_smithy_runtime_api::http` request and response types.
#[derive(Debug)]
pub struct SharedClientProtocol<
    Req = aws_smithy_runtime_api::http::Request,
    Res = aws_smithy_runtime_api::http::Response,
> {
    inner: std::sync::Arc<dyn ClientProtocol<Req, Res>>,
}
impl<Req, Res> Clone for SharedClientProtocol<Req, Res> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<Req: 'static, Res: 'static> SharedClientProtocol<Req, Res> {
    /// Wraps `protocol` in a new shared, type-erased handle.
    pub fn new<P: ClientProtocol<Req, Res> + 'static>(protocol: P) -> Self {
        // Unsize the concrete protocol into the trait object stored by the handle.
        let inner: std::sync::Arc<dyn ClientProtocol<Req, Res>> = std::sync::Arc::new(protocol);
        Self { inner }
    }
}
/// Lets protocol methods be called directly on the handle.
impl<Req, Res> std::ops::Deref for SharedClientProtocol<Req, Res> {
    type Target = dyn ClientProtocol<Req, Res>;

    /// Borrows the wrapped protocol trait object.
    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}
/// Allows the HTTP-typed handle (the struct's default type parameters) to be
/// stored in a [`ConfigBag`] with `StoreReplace` semantics.
impl aws_smithy_types::config_bag::Storable for SharedClientProtocol {
    type Storer = aws_smithy_types::config_bag::StoreReplace<SharedClientProtocol>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde::{SerdeError, SerializableStruct, ShapeDeserializer};
    use crate::{Schema, ShapeId};
    use aws_smithy_runtime_api::client::endpoint::EndpointPrefix;
    use aws_smithy_runtime_api::http::{Request, Response};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::config_bag::{ConfigBag, Layer};
    use aws_smithy_types::endpoint::Endpoint;

    /// Minimal protocol whose only real behavior is `update_endpoint`.
    #[derive(Debug)]
    struct StubProtocol;

    static STUB_ID: ShapeId = ShapeId::from_static("test#StubProtocol", "test", "StubProtocol");

    impl ClientProtocolInner for StubProtocol {
        type Request = Request;
        type Response = Response;

        fn protocol_id(&self) -> &ShapeId {
            &STUB_ID
        }

        fn serialize_request(
            &self,
            _input: &dyn SerializableStruct,
            _input_schema: &Schema,
            _endpoint: &str,
            _cfg: &ConfigBag,
        ) -> Result<Request, SerdeError> {
            unimplemented!()
        }

        fn deserialize_response<'a>(
            &self,
            _response: &'a Response,
            _output_schema: &Schema,
            _cfg: &ConfigBag,
        ) -> Result<Box<dyn ShapeDeserializer + 'a>, SerdeError> {
            unimplemented!()
        }

        fn update_endpoint(
            &self,
            request: &mut Request,
            endpoint: &Endpoint,
            cfg: &ConfigBag,
        ) -> Result<(), SerdeError> {
            apply_http_endpoint(request, endpoint, cfg)
        }
    }

    /// Builds an empty-bodied request whose URI is `uri`.
    fn request_with_uri(uri: &str) -> Request {
        let mut req = Request::new(SdkBody::empty());
        req.set_uri(uri).unwrap();
        req
    }

    /// Runs `StubProtocol::update_endpoint` on a fresh request for `uri` and
    /// returns the updated request.
    fn apply(uri: &str, endpoint: &Endpoint, cfg: &ConfigBag) -> Request {
        let mut req = request_with_uri(uri);
        ClientProtocolInner::update_endpoint(&StubProtocol, &mut req, endpoint, cfg).unwrap();
        req
    }

    #[test]
    fn basic_endpoint() {
        let endpoint = Endpoint::builder()
            .url("https://service.us-east-1.amazonaws.com")
            .build();
        let req = apply("/original/path", &endpoint, &ConfigBag::base());
        assert_eq!(
            req.uri(),
            "https://service.us-east-1.amazonaws.com/original/path"
        );
    }

    #[test]
    fn endpoint_with_prefix() {
        let endpoint = Endpoint::builder()
            .url("https://service.us-east-1.amazonaws.com")
            .build();
        let mut layer = Layer::new("test");
        layer.store_put(EndpointPrefix::new("myprefix.").unwrap());
        let mut cfg = ConfigBag::base();
        cfg.push_shared_layer(layer.freeze());

        let req = apply("/path", &endpoint, &cfg);
        assert_eq!(
            req.uri(),
            "https://myprefix.service.us-east-1.amazonaws.com/path"
        );
    }

    #[test]
    fn endpoint_with_headers() {
        let endpoint = Endpoint::builder()
            .url("https://example.com")
            .header("x-custom", "value1")
            .header("x-custom", "value2")
            .build();
        let req = apply("/path", &endpoint, &ConfigBag::base());
        assert_eq!(req.uri(), "https://example.com/path");
        let values: Vec<&str> = req.headers().get_all("x-custom").collect();
        assert_eq!(values, vec!["value1", "value2"]);
    }

    #[test]
    fn endpoint_with_path() {
        let endpoint = Endpoint::builder().url("https://example.com/base").build();
        let req = apply("/operation", &endpoint, &ConfigBag::base());
        assert_eq!(req.uri(), "https://example.com/base/operation");
    }
}