1use reqwest::dns::Resolve;
2use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
3use serde::de::DeserializeOwned;
4use serde::{Deserialize, Serialize};
6use std::io::Read;
7use std::sync::{Arc, RwLock};
8use std::time::Duration;
9use uuid::Uuid;
10
11use crate::{Error, Fact};
12
13static APP_USER_AGENT: &str = "Oso Cloud (rust)";
14
15const MAX_BODY_SIZE: usize = 10 * 1024 * 1024; #[derive(Clone, Debug, Serialize, Deserialize)]
18pub struct ApiError {
19 message: Option<String>,
20}
21
22#[derive(Clone)]
23pub(crate) struct Client {
24 client: reqwest::Client,
25 pub(crate) url: Arc<String>,
26 last_offset: Arc<RwLock<Option<String>>>,
27}
28
29#[derive(Clone)]
31pub struct ConnectOptions<R: Resolve + 'static> {
32 pub dns_resolver: Option<Arc<R>>,
35
36 pub ca_path: Option<String>,
40}
41
42pub(crate) struct ClientBuilder {
43 client_builder: reqwest::ClientBuilder,
44 url: Arc<String>,
45}
46
47impl ClientBuilder {
48 pub(crate) fn new(url: &str, api_key: &str) -> Result<Self, Error> {
49 let mut headers = HeaderMap::new();
50 let mut auth_value = HeaderValue::from_str(&format!("Bearer {api_key}"))
51 .map_err(|e| Error::Input(format!("invalid auth token: {e}")))?;
52 auth_value.set_sensitive(true);
53 headers.insert(AUTHORIZATION, auth_value);
54 headers.insert("X-Oso-Client-Id", HeaderValue::from_static("rust"));
55 headers.insert("Accept", HeaderValue::from_static("application/json"));
56 headers.insert(
57 "X-Oso-Instance-Id",
58 HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
59 );
60 let client_builder = reqwest::Client::builder()
61 .user_agent(APP_USER_AGENT)
62 .default_headers(headers)
63 .http2_keep_alive_while_idle(true)
64 .http2_keep_alive_interval(Duration::from_secs(30))
65 .http2_keep_alive_timeout(Duration::from_secs(1));
66
67 Ok(Self {
68 client_builder,
69 url: Arc::new(url.to_string()),
70 })
71 }
72
73 pub fn dns_resolver<R: Resolve + 'static>(mut self, resolver: Arc<R>) -> ClientBuilder {
75 self.client_builder = self.client_builder.dns_resolver(resolver);
76 self
77 }
78
79 pub fn ca_path(mut self, ca_path: &str) -> Result<ClientBuilder, Error> {
81 let mut buf = Vec::new();
82 std::fs::File::open(ca_path)
83 .map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?
84 .read_to_end(&mut buf)
85 .map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?;
86 let cert = reqwest::Certificate::from_pem(&buf)?;
87 self.client_builder = self.client_builder.add_root_certificate(cert);
88 Ok(self)
89 }
90
91 pub fn build(self) -> Result<Client, Error> {
92 let client = self.client_builder.build()?;
93 Ok(Client {
94 client,
95 url: self.url.clone(),
96 last_offset: Default::default(),
97 })
98 }
99}
100
101impl Client {
102 async fn handle_error<T>(response: reqwest::Response) -> Result<T, Error>
103 where
104 T: DeserializeOwned,
105 {
106 if !response.status().is_success() {
107 let status = response.status();
108 let request_id = response
109 .headers()
110 .get("X-Request-ID")
111 .and_then(|h| h.to_str().ok())
112 .map(|s| s.to_string());
113 let message = match response.json::<ApiError>().await {
114 Ok(err) => err.message.unwrap_or_else(|| status.to_string()),
115 Err(err) => {
116 tracing::warn!("failed to parse error response: {:#?}", err);
117 status.to_string()
118 }
119 };
120 return Err(Error::Server { message, request_id });
121 }
122
123 Ok(response.json().await?)
124 }
125
126 fn set_last_offset(&self, response: &reqwest::Response) {
127 let offset = response.headers().get("OsoOffset").and_then(|h| h.to_str().ok());
128 if let Some(offset) = offset {
129 *self.last_offset.write().unwrap() = Some(offset.to_string());
130 }
131 }
132
133 #[tracing::instrument(skip(self), level = "trace", err)]
134 pub async fn get<Params, Response>(&self, path: &str, params: Params) -> Result<Response, Error>
135 where
136 Params: std::fmt::Debug + Serialize,
137 Response: DeserializeOwned,
138 {
139 let url = format!("{}/api/{path}", self.url, path = path);
140 let mut request = self.client.get(url).query(¶ms);
141
142 if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
143 request = request.header("OsoOffset", offset);
144 }
145 request = request.header(
146 "X-Request-ID",
147 HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
148 );
149 let response = request.send().await?;
150 Self::handle_error(response).await
151 }
152
153 #[tracing::instrument(skip(self), level = "trace", err)]
154 pub async fn post<Body, Response>(&self, path: &str, body: Body, is_mutation: bool) -> Result<Response, Error>
155 where
156 Body: std::fmt::Debug + Serialize,
157 Response: DeserializeOwned,
158 {
159 let url = format!("{}/api/{path}", self.url);
160
161 let body_vec = serde_json::to_vec(&body).unwrap();
162 if body_vec.len() > MAX_BODY_SIZE {
163 return Err(Error::Input("Request payload too large".to_owned()));
164 }
165
166 let mut request = self.client.post(url).json(&body);
167 if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
168 request = request.header("OsoOffset", offset);
169 }
170 request = request.header(
171 "X-Request-ID",
172 HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
173 );
174 let response = request.send().await?;
175 if is_mutation {
176 self.set_last_offset(&response);
177 }
178 Self::handle_error(response).await
179 }
180
181 pub async fn bulk(&self, delete: &[Fact<'_>], tell: &[Fact<'_>]) -> Result<(), Error> {
182 #[derive(Debug, Serialize)]
183 struct BulkRequest<'a> {
184 delete: &'a [Fact<'a>],
185 tell: &'a [Fact<'a>],
186 }
187
188 let _: crate::ApiResult = self.post("bulk", BulkRequest { delete, tell }, true).await?;
189 Ok(())
190 }
191}