graft_client/
net.rs

1use bytes::{BufMut, BytesMut};
2use culprit::{Culprit, ResultExt};
3use graft_core::byte_unit::ByteUnit;
4use graft_proto::common::v1::GraftErr;
5use http::{
6    HeaderName, HeaderValue, Uri,
7    header::AUTHORIZATION,
8    uri::{Builder, PathAndQuery},
9};
10use std::{any::type_name, sync::Arc, time::Duration};
11use tracing::field;
12use url::Url;
13
14use ureq::{Agent, Proxy, config::AutoHeaderValue};
15
16use crate::{USER_AGENT, error::ClientErr};
17
18use prost::Message;
19
20const CONTENT_TYPE: HeaderName = HeaderName::from_static("content-type");
21const APPLICATION_PROTOBUF: HeaderValue = HeaderValue::from_static("application/x-protobuf");
22const MAX_READ_SIZE: ByteUnit = ByteUnit::from_mb(8);
23
24#[derive(Debug, Clone)]
25pub(crate) struct EndpointBuilder {
26    endpoint: Uri,
27}
28
29impl From<Url> for EndpointBuilder {
30    fn from(endpoint: Url) -> Self {
31        let endpoint: Uri = endpoint.as_str().parse().expect("url is valid uri");
32        assert!(
33            endpoint.path_and_query().is_none_or(|p| p.path() == "/"),
34            "endpoint can not include a path {endpoint}"
35        );
36        Self { endpoint }
37    }
38}
39
40impl EndpointBuilder {
41    pub(crate) fn build(&self, path: &'static str) -> Result<Uri, http::Error> {
42        assert!(path.starts_with("/"), "path must begin with /");
43        let path = PathAndQuery::from_static(path);
44        let uri = Builder::from(self.endpoint.clone())
45            .path_and_query(path)
46            .build()?;
47        Ok(uri)
48    }
49}
50
51#[derive(Debug, Clone)]
52pub struct NetClient {
53    api_token: Option<String>,
54    agent: Agent,
55}
56
57impl NetClient {
58    pub fn new(api_token: Option<String>) -> Self {
59        Self::new_with_proxy(api_token, Proxy::try_from_env())
60    }
61
62    pub fn new_with_proxy(api_token: Option<String>, proxy: Option<Proxy>) -> Self {
63        Self {
64            api_token,
65            agent: Agent::config_builder()
66                .user_agent(AutoHeaderValue::Provided(Arc::new(USER_AGENT.to_string())))
67                .proxy(proxy)
68                .http_status_as_error(false)
69                .max_idle_age(Duration::from_secs(300))
70                .timeout_connect(Some(Duration::from_secs(60)))
71                .timeout_recv_response(Some(Duration::from_secs(60)))
72                .timeout_global(Some(Duration::from_secs(300)))
73                .build()
74                .new_agent(),
75        }
76    }
77
78    pub(crate) fn send<Msg: Message, Resp: Message + Default>(
79        &self,
80        uri: Uri,
81        msg: Msg,
82    ) -> Result<Resp, Culprit<ClientErr>> {
83        let span = tracing::trace_span!(
84            "NetClient::send",
85            path = uri.path(),
86            status = field::Empty,
87            err = field::Empty
88        )
89        .entered();
90
91        let req = self
92            .agent
93            .post(uri)
94            .header(CONTENT_TYPE, APPLICATION_PROTOBUF);
95
96        let req = if let Some(token) = &self.api_token {
97            req.header(AUTHORIZATION, format!("Bearer {token}"))
98        } else {
99            req
100        };
101
102        let resp = match req.send(&msg.encode_to_vec()) {
103            Ok(resp) => resp,
104            Err(err) => {
105                span.record("err", err.to_string());
106                return Err(err.into());
107            }
108        };
109
110        let status = resp.status();
111        span.record("status", status.as_u16());
112
113        let content_type = resp.headers().get(CONTENT_TYPE);
114        if content_type != Some(&APPLICATION_PROTOBUF) {
115            return Err(
116                Culprit::new(ClientErr::ProtobufDecodeErr).with_note(format!(
117                    "expected content type '{}' but received {:?}",
118                    APPLICATION_PROTOBUF.to_str().unwrap(),
119                    content_type
120                )),
121            );
122        }
123
124        let success = (200..300).contains(&status.as_u16());
125
126        // read the response into a Bytes object
127        let reader = resp
128            .into_body()
129            .into_with_config()
130            .limit(MAX_READ_SIZE.as_u64());
131        let mut writer = BytesMut::new().writer();
132        std::io::copy(&mut reader.reader(), &mut writer).or_into_ctx()?;
133        let body = writer.into_inner().freeze();
134        let body_size = ByteUnit::new(body.len() as u64);
135
136        if success {
137            Ok(Resp::decode(body).map_err(|err| {
138                let note = format!(
139                    "failed to decode response body into {} from buffer of size {}",
140                    type_name::<Resp>(),
141                    body_size
142                );
143                Culprit::from_err(err).with_note(note)
144            })?)
145        } else {
146            let err = GraftErr::decode(body).map_err(|err| {
147                let note = format!(
148                    "failed to decode response body into GraftErr from buffer of size {body_size}"
149                );
150                Culprit::from_err(err).with_note(note)
151            })?;
152
153            // 5xx errors are not expected from client requests unless the graft
154            // error signals that the service is temporarily unavailable
155            precept::expect_always_or_unreachable!(
156                !(500..600).contains(&status.as_u16()) || err.code() == graft_proto::GraftErrCode::ServiceUnavailable,
157                "client requests should not return 5xx errors",
158                {
159                    "status": status.as_u16(),
160                    "code": err.code().as_str_name(),
161                    "message": err.message
162                }
163            );
164            Err(err.into())
165        }
166    }
167}