1use bytes::{BufMut, BytesMut};
2use culprit::{Culprit, ResultExt};
3use graft_core::byte_unit::ByteUnit;
4use graft_proto::common::v1::GraftErr;
5use http::{
6 HeaderName, HeaderValue, Uri,
7 header::AUTHORIZATION,
8 uri::{Builder, PathAndQuery},
9};
10use std::{any::type_name, sync::Arc, time::Duration};
11use tracing::field;
12use url::Url;
13
14use ureq::{Agent, Proxy, config::AutoHeaderValue};
15
16use crate::{USER_AGENT, error::ClientErr};
17
18use prost::Message;
19
20const CONTENT_TYPE: HeaderName = HeaderName::from_static("content-type");
21const APPLICATION_PROTOBUF: HeaderValue = HeaderValue::from_static("application/x-protobuf");
22const MAX_READ_SIZE: ByteUnit = ByteUnit::from_mb(8);
23
24#[derive(Debug, Clone)]
25pub(crate) struct EndpointBuilder {
26 endpoint: Uri,
27}
28
29impl From<Url> for EndpointBuilder {
30 fn from(endpoint: Url) -> Self {
31 let endpoint: Uri = endpoint.as_str().parse().expect("url is valid uri");
32 assert!(
33 endpoint.path_and_query().is_none_or(|p| p.path() == "/"),
34 "endpoint can not include a path {endpoint}"
35 );
36 Self { endpoint }
37 }
38}
39
40impl EndpointBuilder {
41 pub(crate) fn build(&self, path: &'static str) -> Result<Uri, http::Error> {
42 assert!(path.starts_with("/"), "path must begin with /");
43 let path = PathAndQuery::from_static(path);
44 let uri = Builder::from(self.endpoint.clone())
45 .path_and_query(path)
46 .build()?;
47 Ok(uri)
48 }
49}
50
51#[derive(Debug, Clone)]
52pub struct NetClient {
53 api_token: Option<String>,
54 agent: Agent,
55}
56
57impl NetClient {
58 pub fn new(api_token: Option<String>) -> Self {
59 Self::new_with_proxy(api_token, Proxy::try_from_env())
60 }
61
62 pub fn new_with_proxy(api_token: Option<String>, proxy: Option<Proxy>) -> Self {
63 Self {
64 api_token,
65 agent: Agent::config_builder()
66 .user_agent(AutoHeaderValue::Provided(Arc::new(USER_AGENT.to_string())))
67 .proxy(proxy)
68 .http_status_as_error(false)
69 .max_idle_age(Duration::from_secs(300))
70 .timeout_connect(Some(Duration::from_secs(60)))
71 .timeout_recv_response(Some(Duration::from_secs(60)))
72 .timeout_global(Some(Duration::from_secs(300)))
73 .build()
74 .new_agent(),
75 }
76 }
77
78 pub(crate) fn send<Msg: Message, Resp: Message + Default>(
79 &self,
80 uri: Uri,
81 msg: Msg,
82 ) -> Result<Resp, Culprit<ClientErr>> {
83 let span = tracing::trace_span!(
84 "NetClient::send",
85 path = uri.path(),
86 status = field::Empty,
87 err = field::Empty
88 )
89 .entered();
90
91 let req = self
92 .agent
93 .post(uri)
94 .header(CONTENT_TYPE, APPLICATION_PROTOBUF);
95
96 let req = if let Some(token) = &self.api_token {
97 req.header(AUTHORIZATION, format!("Bearer {token}"))
98 } else {
99 req
100 };
101
102 let resp = match req.send(&msg.encode_to_vec()) {
103 Ok(resp) => resp,
104 Err(err) => {
105 span.record("err", err.to_string());
106 return Err(err.into());
107 }
108 };
109
110 let status = resp.status();
111 span.record("status", status.as_u16());
112
113 let content_type = resp.headers().get(CONTENT_TYPE);
114 if content_type != Some(&APPLICATION_PROTOBUF) {
115 return Err(
116 Culprit::new(ClientErr::ProtobufDecodeErr).with_note(format!(
117 "expected content type '{}' but received {:?}",
118 APPLICATION_PROTOBUF.to_str().unwrap(),
119 content_type
120 )),
121 );
122 }
123
124 let success = (200..300).contains(&status.as_u16());
125
126 let reader = resp
128 .into_body()
129 .into_with_config()
130 .limit(MAX_READ_SIZE.as_u64());
131 let mut writer = BytesMut::new().writer();
132 std::io::copy(&mut reader.reader(), &mut writer).or_into_ctx()?;
133 let body = writer.into_inner().freeze();
134 let body_size = ByteUnit::new(body.len() as u64);
135
136 if success {
137 Ok(Resp::decode(body).map_err(|err| {
138 let note = format!(
139 "failed to decode response body into {} from buffer of size {}",
140 type_name::<Resp>(),
141 body_size
142 );
143 Culprit::from_err(err).with_note(note)
144 })?)
145 } else {
146 let err = GraftErr::decode(body).map_err(|err| {
147 let note = format!(
148 "failed to decode response body into GraftErr from buffer of size {body_size}"
149 );
150 Culprit::from_err(err).with_note(note)
151 })?;
152
153 precept::expect_always_or_unreachable!(
156 !(500..600).contains(&status.as_u16()) || err.code() == graft_proto::GraftErrCode::ServiceUnavailable,
157 "client requests should not return 5xx errors",
158 {
159 "status": status.as_u16(),
160 "code": err.code().as_str_name(),
161 "message": err.message
162 }
163 );
164 Err(err.into())
165 }
166 }
167}