oso_cloud/
api.rs

1use reqwest::dns::Resolve;
2use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
3use serde::de::DeserializeOwned;
4/// Internal API functionality not intended for public use.
5use serde::{Deserialize, Serialize};
6use std::io::Read;
7use std::sync::{Arc, RwLock};
8use std::time::Duration;
9use uuid::Uuid;
10
11use crate::{Error, Fact};
12
/// Value of the `User-Agent` header sent by every client built in this module.
static APP_USER_AGENT: &str = "Oso Cloud (rust)";

/// Largest serialized request body, in bytes, that `post` will send (10 MiB).
const MAX_BODY_SIZE: usize = 10 << 20;
16
17#[derive(Clone, Debug, Serialize, Deserialize)]
18pub struct ApiError {
19    message: Option<String>,
20}
21
22#[derive(Clone)]
23pub(crate) struct Client {
24    client: reqwest::Client,
25    pub(crate) url: Arc<String>,
26    last_offset: Arc<RwLock<Option<String>>>,
27}
28
29/// Options for connecting to the Oso Service
30#[derive(Clone)]
31pub struct ConnectOptions<R: Resolve + 'static> {
32    /// A custom DNS resolver. Can be useful for adding custom DNS servers.
33    /// Must implement the `reqwest::dns::Resolve` trait.
34    pub dns_resolver: Option<Arc<R>>,
35
36    /// Path to another root CA certificate to trust when doing certificate
37    /// validation. Useful for trusting certificates signed with an internal
38    /// certificate authority.
39    pub ca_path: Option<String>,
40}
41
42pub(crate) struct ClientBuilder {
43    client_builder: reqwest::ClientBuilder,
44    url: Arc<String>,
45}
46
47impl ClientBuilder {
48    pub(crate) fn new(url: &str, api_key: &str) -> Result<Self, Error> {
49        let mut headers = HeaderMap::new();
50        let mut auth_value = HeaderValue::from_str(&format!("Bearer {api_key}"))
51            .map_err(|e| Error::Input(format!("invalid auth token: {e}")))?;
52        auth_value.set_sensitive(true);
53        headers.insert(AUTHORIZATION, auth_value);
54        headers.insert("X-Oso-Client-Id", HeaderValue::from_static("rust"));
55        headers.insert("Accept", HeaderValue::from_static("application/json"));
56        headers.insert(
57            "X-Oso-Instance-Id",
58            HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
59        );
60        let client_builder = reqwest::Client::builder()
61            .user_agent(APP_USER_AGENT)
62            .default_headers(headers)
63            .http2_keep_alive_while_idle(true)
64            .http2_keep_alive_interval(Duration::from_secs(30))
65            .http2_keep_alive_timeout(Duration::from_secs(1));
66
67        Ok(Self {
68            client_builder,
69            url: Arc::new(url.to_string()),
70        })
71    }
72
73    /// Override the DNS resolver implementation.
74    pub fn dns_resolver<R: Resolve + 'static>(mut self, resolver: Arc<R>) -> ClientBuilder {
75        self.client_builder = self.client_builder.dns_resolver(resolver);
76        self
77    }
78
79    /// Add another CA certificate to trust
80    pub fn ca_path(mut self, ca_path: &str) -> Result<ClientBuilder, Error> {
81        let mut buf = Vec::new();
82        std::fs::File::open(ca_path)
83            .map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?
84            .read_to_end(&mut buf)
85            .map_err(|e| Error::Input(format!("Failed to read CA file at path {}: {}", ca_path, e)))?;
86        let cert = reqwest::Certificate::from_pem(&buf)?;
87        self.client_builder = self.client_builder.add_root_certificate(cert);
88        Ok(self)
89    }
90
91    pub fn build(self) -> Result<Client, Error> {
92        let client = self.client_builder.build()?;
93        Ok(Client {
94            client,
95            url: self.url.clone(),
96            last_offset: Default::default(),
97        })
98    }
99}
100
101impl Client {
102    async fn handle_error<T>(response: reqwest::Response) -> Result<T, Error>
103    where
104        T: DeserializeOwned,
105    {
106        if !response.status().is_success() {
107            let status = response.status();
108            let request_id = response
109                .headers()
110                .get("X-Request-ID")
111                .and_then(|h| h.to_str().ok())
112                .map(|s| s.to_string());
113            let message = match response.json::<ApiError>().await {
114                Ok(err) => err.message.unwrap_or_else(|| status.to_string()),
115                Err(err) => {
116                    tracing::warn!("failed to parse error response: {:#?}", err);
117                    status.to_string()
118                }
119            };
120            return Err(Error::Server { message, request_id });
121        }
122
123        Ok(response.json().await?)
124    }
125
126    fn set_last_offset(&self, response: &reqwest::Response) {
127        let offset = response.headers().get("OsoOffset").and_then(|h| h.to_str().ok());
128        if let Some(offset) = offset {
129            *self.last_offset.write().unwrap() = Some(offset.to_string());
130        }
131    }
132
133    #[tracing::instrument(skip(self), level = "trace", err)]
134    pub async fn get<Params, Response>(&self, path: &str, params: Params) -> Result<Response, Error>
135    where
136        Params: std::fmt::Debug + Serialize,
137        Response: DeserializeOwned,
138    {
139        let url = format!("{}/api/{path}", self.url, path = path);
140        let mut request = self.client.get(url).query(&params);
141
142        if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
143            request = request.header("OsoOffset", offset);
144        }
145        request = request.header(
146            "X-Request-ID",
147            HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
148        );
149        let response = request.send().await?;
150        Self::handle_error(response).await
151    }
152
153    #[tracing::instrument(skip(self), level = "trace", err)]
154    pub async fn post<Body, Response>(&self, path: &str, body: Body, is_mutation: bool) -> Result<Response, Error>
155    where
156        Body: std::fmt::Debug + Serialize,
157        Response: DeserializeOwned,
158    {
159        let url = format!("{}/api/{path}", self.url);
160
161        let body_vec = serde_json::to_vec(&body).unwrap();
162        if body_vec.len() > MAX_BODY_SIZE {
163            return Err(Error::Input("Request payload too large".to_owned()));
164        }
165
166        let mut request = self.client.post(url).json(&body);
167        if let Some(offset) = self.last_offset.read().unwrap().as_ref() {
168            request = request.header("OsoOffset", offset);
169        }
170        request = request.header(
171            "X-Request-ID",
172            HeaderValue::from_str(&Uuid::new_v4().to_string()).unwrap(),
173        );
174        let response = request.send().await?;
175        if is_mutation {
176            self.set_last_offset(&response);
177        }
178        Self::handle_error(response).await
179    }
180
181    pub async fn bulk(&self, delete: &[Fact<'_>], tell: &[Fact<'_>]) -> Result<(), Error> {
182        #[derive(Debug, Serialize)]
183        struct BulkRequest<'a> {
184            delete: &'a [Fact<'a>],
185            tell: &'a [Fact<'a>],
186        }
187
188        let _: crate::ApiResult = self.post("bulk", BulkRequest { delete, tell }, true).await?;
189        Ok(())
190    }
191}