mod query;
mod request;
mod routing;
mod streaming;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
#[cfg(not(feature = "tls-rustls"))]
use http_body_util::Full;
#[cfg(not(feature = "tls-rustls"))]
use hyper::body::Bytes;
#[cfg(not(feature = "tls-rustls"))]
use hyper_util::client::legacy::connect::HttpConnector;
#[cfg(not(feature = "tls-rustls"))]
use hyper_util::client::legacy::Client;
#[cfg(not(feature = "tls-rustls"))]
use hyper_util::rt::TokioExecutor;
use crate::error::{ClientError, ClientResult};
use crate::streaming::EventStream;
use crate::transport::Transport;
/// Client type used when the `tls-rustls` feature is disabled: a hyper-util
/// legacy client over a plain `HttpConnector` that sends `Full<Bytes>` bodies.
#[cfg(not(feature = "tls-rustls"))]
type HttpClient = Client<HttpConnector, Full<Bytes>>;
/// Client type used when the `tls-rustls` feature is enabled; the concrete
/// type lives in `crate::tls` (presumably HTTPS-capable — confirm there).
#[cfg(feature = "tls-rustls")]
type HttpClient = crate::tls::HttpsClient;
/// HTTP/JSON implementation of [`Transport`].
///
/// Cloning is cheap: every clone points at the same shared state (HTTP client,
/// base URL and timeouts) through an [`Arc`].
#[derive(Debug, Clone)]
pub struct RestTransport {
    inner: Arc<Inner>,
}
/// State shared by all clones of a `RestTransport`; it is held behind an `Arc`,
/// so cloning the transport reuses the same client instead of building a new one.
#[derive(Debug)]
struct Inner {
/// HTTP client built once at construction (its type depends on the `tls-rustls` feature).
client: HttpClient,
/// Base URL as given to the constructor, with trailing `/` characters removed.
base_url: String,
/// Timeout for ordinary requests. NOTE(review): consumed by code outside this file — confirm.
request_timeout: Duration,
/// Timeout for setting up streaming requests. NOTE(review): consumed outside this file — confirm.
stream_connect_timeout: Duration,
}
impl RestTransport {
pub fn new(base_url: impl Into<String>) -> ClientResult<Self> {
Self::with_timeout(base_url, Duration::from_secs(30))
}
pub fn with_timeout(
base_url: impl Into<String>,
request_timeout: Duration,
) -> ClientResult<Self> {
Self::with_timeouts(base_url, request_timeout, request_timeout)
}
pub fn with_timeouts(
base_url: impl Into<String>,
request_timeout: Duration,
stream_connect_timeout: Duration,
) -> ClientResult<Self> {
Self::with_all_timeouts(
base_url,
request_timeout,
stream_connect_timeout,
Duration::from_secs(10),
)
}
pub fn with_all_timeouts(
base_url: impl Into<String>,
request_timeout: Duration,
stream_connect_timeout: Duration,
connection_timeout: Duration,
) -> ClientResult<Self> {
let base_url = base_url.into();
if base_url.is_empty()
|| (!base_url.starts_with("http://") && !base_url.starts_with("https://"))
{
return Err(ClientError::InvalidEndpoint(format!(
"invalid base URL: {base_url}"
)));
}
#[cfg(not(feature = "tls-rustls"))]
let client = {
let mut connector = HttpConnector::new();
connector.set_connect_timeout(Some(connection_timeout));
connector.set_nodelay(true);
Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(90))
.build(connector)
};
#[cfg(feature = "tls-rustls")]
let client = crate::tls::build_https_client_with_connect_timeout(
crate::tls::default_tls_config(),
connection_timeout,
);
Ok(Self {
inner: Arc::new(Inner {
client,
base_url: base_url.trim_end_matches('/').to_owned(),
request_timeout,
stream_connect_timeout,
}),
})
}
#[must_use]
pub fn base_url(&self) -> &str {
&self.inner.base_url
}
}
/// Object-safe [`Transport`] adapter: each method boxes the future returned by
/// the matching inherent `execute_*` method of [`RestTransport`].
impl Transport for RestTransport {
    /// Sends a non-streaming request and resolves to the JSON value produced
    /// by `execute_request`.
    fn send_request<'a>(
        &'a self,
        method: &'a str,
        params: serde_json::Value,
        extra_headers: &'a HashMap<String, String>,
    ) -> Pin<Box<dyn Future<Output = ClientResult<serde_json::Value>> + Send + 'a>> {
        let pending = self.execute_request(method, params, extra_headers);
        Box::pin(pending)
    }

    /// Sends a streaming request and resolves to the [`EventStream`] produced
    /// by `execute_streaming_request`.
    fn send_streaming_request<'a>(
        &'a self,
        method: &'a str,
        params: serde_json::Value,
        extra_headers: &'a HashMap<String, String>,
    ) -> Pin<Box<dyn Future<Output = ClientResult<EventStream>> + Send + 'a>> {
        let pending = self.execute_streaming_request(method, params, extra_headers);
        Box::pin(pending)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::Full;
    use hyper::body::Bytes;

    /// Spawns a local HTTP server that answers every request with status 200,
    /// the given `content-type` header and a fixed `body`.
    ///
    /// Returns the address the server is listening on.
    async fn spawn_static_server(
        content_type: &'static str,
        body: &'static str,
    ) -> std::net::SocketAddr {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            // Stop serving on accept failure instead of panicking inside a
            // detached task, where the panic would go unnoticed.
            while let Ok((stream, _)) = listener.accept().await {
                let io = hyper_util::rt::TokioIo::new(stream);
                tokio::spawn(async move {
                    let service = hyper::service::service_fn(move |_req| async move {
                        Ok::<_, hyper::Error>(
                            hyper::Response::builder()
                                .status(200)
                                .header("content-type", content_type)
                                .body(Full::new(Bytes::from(body)))
                                .unwrap(),
                        )
                    });
                    let _ = hyper_util::server::conn::auto::Builder::new(
                        hyper_util::rt::TokioExecutor::new(),
                    )
                    .serve_connection(io, service)
                    .await;
                });
            }
        });
        addr
    }

    #[test]
    fn rest_transport_rejects_invalid_url() {
        assert!(matches!(
            RestTransport::new("not-a-url"),
            Err(ClientError::InvalidEndpoint(_))
        ));
        assert!(RestTransport::new("").is_err());
        assert!(RestTransport::new("ftp://localhost:9090").is_err());
    }

    #[test]
    fn rest_transport_stores_base_url() {
        let t = RestTransport::new("http://localhost:9090").unwrap();
        assert_eq!(t.base_url(), "http://localhost:9090");
    }

    #[test]
    fn rest_transport_trims_trailing_slashes() {
        let t = RestTransport::new("http://localhost:9090//").unwrap();
        assert_eq!(t.base_url(), "http://localhost:9090");
    }

    #[tokio::test]
    async fn send_request_via_trait_delegation() {
        let addr = spawn_static_server("application/json", r#"{"status":"ok","data":42}"#).await;
        let transport = RestTransport::new(format!("http://{addr}")).unwrap();
        let dyn_transport: &dyn crate::transport::Transport = &transport;
        // `expect` (unlike `assert!(is_ok())`) reports the underlying error.
        dyn_transport
            .send_request("SendMessage", serde_json::json!({}), &HashMap::new())
            .await
            .expect("send_request via trait should succeed");
    }

    #[tokio::test]
    async fn send_streaming_request_via_trait_delegation() {
        let addr = spawn_static_server("text/event-stream", "data: {\"hello\":\"world\"}\n\n").await;
        let transport = RestTransport::new(format!("http://{addr}")).unwrap();
        let dyn_transport: &dyn crate::transport::Transport = &transport;
        dyn_transport
            .send_streaming_request(
                "SendStreamingMessage",
                serde_json::json!({}),
                &HashMap::new(),
            )
            .await
            .expect("send_streaming_request via trait should succeed");
    }
}