oracle_nosql_rust_sdk/
handle.rs

1//
2// Copyright (c) 2024, 2025 Oracle and/or its affiliates. All rights reserved.
3//
4// Licensed under the Universal Permissive License v 1.0 as shown at
5//  https://oss.oracle.com/licenses/upl/
6//
7use crate::auth_common::authentication_provider::AuthenticationProvider;
8use crate::auth_common::instance_principal_auth_provider::InstancePrincipalAuthProvider;
9use crate::auth_common::resource_principal_auth_provider::ResourcePrincipalAuthProvider;
10use crate::auth_common::signer;
11use crate::handle_builder::AuthConfig;
12use crate::handle_builder::AuthType;
13use reqwest::header::{HeaderMap, HeaderValue};
14
15use crate::error::NoSQLErrorCode::InternalRetry;
16use crate::error::{ia_err, user_agent};
17use crate::error::{NoSQLError, NoSQLErrorCode};
18use crate::handle_builder::AuthProvider;
19use crate::handle_builder::HandleBuilder;
20use crate::handle_builder::HandleMode;
21use crate::nson::MapWalker;
22use crate::reader::Reader;
23use crate::writer::Writer;
24
25use std::collections::HashMap;
26use std::result::Result;
27use std::sync::atomic::AtomicUsize;
28use std::sync::atomic::Ordering;
29use std::sync::Arc;
30use std::time::Duration;
31use tracing::{debug, trace};
32use url::Url;
33
34/// **The main database handle**.
35///
36/// This should be created once and used
37/// throughout the application lifetime, across all threads.
38///
39/// Note: there is no need to enclose this struct in an `Rc` or [`Arc`], as it uses an
40/// [`Arc`] internally, so calling `.clone()` on this struct will always return the
41/// same underlying handle.
42#[derive(Clone, Debug)]
43pub struct Handle {
44    // Use an inner Arc so cloning keeps the same contents
45    pub(crate) inner: Arc<HandleRef>,
46}
47
48#[derive(Debug)]
49pub(crate) struct HandleRef {
50    pub(crate) client: reqwest::Client,
51    pub(crate) endpoint: String,
52    pub(crate) serial_version: i16,
53    pub(crate) builder: HandleBuilder,
54    // session doesn't require a tokio Mutex because it's never held across awaits
55    session: std::sync::Mutex<String>,
56    request_id: AtomicUsize,
57    timeout: Duration,
58}
59
60impl Handle {
61    /// Create a new [`HandleBuilder`].
62    pub fn builder() -> HandleBuilder {
63        HandleBuilder::new()
64    }
65
66    // Create the new Handle based on builder configuration
67    pub(crate) async fn new(b: &HandleBuilder) -> Result<Handle, NoSQLError> {
68        if b.auth_type == AuthType::None {
69            if b.from_environment {
70                return ia_err!("cannot build handle: no auth type specified. set ORACLE_NOSQL_AUTH environment.");
71            }
72            return ia_err!("cannot build handle: no auth type specified");
73            //trace!("defaulting auth config to user-based OCI auth from ~/.oci/config");
74            //builder = builder.cloud_auth_from_file("~/.oci/config")?;
75        }
76
77        let mut builder = b.clone();
78        // default timeout to 30 seconds
79        // TODO: connection timeout vs request timeout
80        let timeout = {
81            if let Some(t) = builder.timeout {
82                t.clone()
83            } else {
84                Duration::new(30, 0)
85            }
86        };
87        let c = {
88            if let Some(c) = &builder.client {
89                c.clone()
90            } else {
91                let mut cb = reqwest::Client::builder()
92                    .timeout(timeout)
93                    .connect_timeout(timeout)
94                    //.pool_idle_timeout(timeout)
95                    .connection_verbose(true);
96                if let Some(cert) = &builder.add_cert {
97                    cb = cb.add_root_certificate(cert.clone());
98                }
99                if builder.accept_invalid_certs {
100                    cb = cb.danger_accept_invalid_certs(true);
101                }
102                cb.build()?
103            }
104        };
105        // create auth provider if not already created
106        match builder.auth_type {
107            AuthType::Instance => {
108                let ifp = InstancePrincipalAuthProvider::new_with_client(&c).await?;
109                if builder.region.is_none() {
110                    builder = builder.cloud_region(ifp.region_id())?;
111                }
112                let ap = AuthProvider::Instance {
113                    provider: Box::new(ifp),
114                };
115                builder.auth = Arc::new(tokio::sync::Mutex::new(AuthConfig { provider: ap }));
116            }
117            AuthType::Resource => {
118                let rfp = ResourcePrincipalAuthProvider::new()?;
119                if builder.region.is_none() {
120                    builder = builder.cloud_region(rfp.region_id())?;
121                }
122                let ap = AuthProvider::Resource {
123                    provider: Box::new(rfp),
124                };
125                builder.auth = Arc::new(tokio::sync::Mutex::new(AuthConfig { provider: ap }));
126            }
127            _ => {}
128        }
129        if builder.endpoint.is_empty() {
130            if builder.from_environment {
131                return ia_err!("can't determine NoSQL endpoint: set ORACLE_NOSQL_ENDPOINT or ORACLE_NOSQL_REGION");
132            } else {
133                return ia_err!("can't determine NoSQL endpoint: call HandleBuilder::endpoint() or HandleBuilder::cloud_region()");
134            }
135        }
136        // normalize endpoint to "http[s]://{endpoint}/V2/nosql/data"
137        let mut ep = String::from("http");
138        if builder.use_https {
139            ep.push('s');
140        }
141        ep.push_str("://");
142        ep.push_str(&builder.endpoint);
143        ep.push_str("/V2/nosql/data");
144        debug!(
145            "Creating new Handle: {:?}, {:?}, endpoint={}",
146            builder.mode, builder.auth, ep
147        );
148        Ok(Handle {
149            inner: Arc::new(HandleRef {
150                client: c,
151                endpoint: ep,
152                serial_version: 4,
153                builder: builder,
154                timeout: timeout.clone(),
155                session: std::sync::Mutex::new("".to_string()),
156                request_id: AtomicUsize::new(1),
157            }),
158        })
159    }
160
161    // geeez, all this to get a stupid usize from an http header....
162    fn get_usize_header(headers: &HeaderMap, field: &str) -> Result<usize, NoSQLError> {
163        let val = headers.get(field);
164        if val.is_none() {
165            return ia_err!("missing \"{}\" value in return headers", field);
166        }
167        let valstr = val.unwrap().to_str();
168        if let Err(_) = valstr {
169            return ia_err!(
170                "\"{}\" value in return headers is not a valid string",
171                field
172            );
173        }
174        match valstr.unwrap().parse::<usize>() {
175            Ok(v) => {
176                return Ok(v);
177            }
178            Err(_) => {
179                return ia_err!("\"{}\" value in return headers is not an integer", field);
180            }
181        }
182    }
183
184    async fn post_data(
185        &self,
186        data: &Vec<u8>,
187        send_options: &mut SendOptions,
188    ) -> Result<Vec<u8>, NoSQLError> {
189        let request_id = self.inner.request_id.fetch_add(1, Ordering::Relaxed);
190        let mut headers = HeaderMap::new();
191        headers.insert("x-nosql-request-id", HeaderValue::from(request_id));
192
193        // If there is an oci auth provider, use that to set up required headers
194        let mut oci_provider: Option<&Box<dyn AuthenticationProvider>> = None;
195
196        // We need to lock the auth config because it may be asynchronously refreshed elsewhere
197        let pguard = self.inner.builder.auth.lock().await;
198        match &pguard.provider {
199            AuthProvider::Instance { provider } => {
200                oci_provider = Some(provider);
201            }
202            AuthProvider::Resource { provider } => {
203                oci_provider = Some(provider);
204            }
205            AuthProvider::External { provider } => {
206                oci_provider = Some(provider);
207            }
208            AuthProvider::File { provider } => {
209                oci_provider = Some(provider);
210            }
211            AuthProvider::Onprem { provider } => {
212                if let Some(p) = provider {
213                    p.add_required_headers(&self.inner.client, &mut headers)
214                        .await?;
215                }
216            }
217            AuthProvider::None => {}
218        }
219
220        if let Some(sp) = oci_provider {
221            if !self.inner.builder.default_compartment_id.is_empty() {
222                headers.insert(
223                    "x-nosql-compartment-id",
224                    HeaderValue::from_str(&self.inner.builder.default_compartment_id)?,
225                );
226            } else {
227                headers.insert(
228                    "x-nosql-compartment-id",
229                    HeaderValue::from_str(sp.tenancy_id())?,
230                );
231            }
232            {
233                // If there's a session cookie value, set it into the headers.
234                // The lock is needed because another async operation might try to
235                // update the session value while we're trying to read it.
236                // This is in its own code block so the lock will be released directly afterwards.
237                let sguard = self.inner.session.lock().unwrap();
238                if sguard.len() > 0 {
239                    let s = format!("session={}", sguard.as_str());
240                    headers.insert("Cookie", HeaderValue::from_str(s.as_str())?);
241                }
242            }
243            trace!("Adding required headers");
244            headers = signer::get_required_headers(
245                reqwest::Method::POST,
246                "",
247                headers,
248                Url::parse(&self.inner.endpoint)?,
249                sp,
250                HashMap::new(),
251                true,
252            )?;
253        } else if self.inner.builder.mode == HandleMode::Onprem {
254            // headers added above if necessary
255        } else if self.inner.builder.mode == HandleMode::Cloudsim {
256            headers.insert("Authorization", HeaderValue::from_str("Bearer rust")?);
257        }
258        // this will unlock the auth mutex
259        core::mem::drop(pguard);
260
261        // let send_options.compartment_id override compartment header
262        if !send_options.compartment_id.is_empty() {
263            headers.insert(
264                "x-nosql-compartment-id",
265                HeaderValue::from_str(&send_options.compartment_id)?,
266            );
267        }
268
269        // let send_options.namespace override namespace header
270        if !send_options.namespace.is_empty() {
271            headers.insert(
272                "x-nosql-default-ns",
273                HeaderValue::from_str(&send_options.namespace)?,
274            );
275        }
276
277        // Set User-Agent
278        headers.insert("User-Agent", HeaderValue::from_str(user_agent())?);
279
280        let resp = self
281            .inner
282            .client
283            .post(&self.inner.endpoint)
284            // TODO: resolve this clone... Hmmm
285            .body(data.clone())
286            .timeout(send_options.timeout.clone())
287            .headers(headers)
288            .send()
289            .await?;
290        // check resp status for 200, err on others
291        if !resp.status().is_success() {
292            let status = resp.status().clone();
293            let content = resp.text().await?;
294            return ia_err!(
295                "got unexpected http status: {}, response text: {}",
296                status,
297                content
298            );
299        }
300
301        // read request id in return, validate
302        match Self::get_usize_header(resp.headers(), "x-nosql-request-id") {
303            Ok(rid) => {
304                if request_id != rid {
305                    // TODO: if rid is less, loop again to read next response
306                    // In theory, this should never happen with http 1.1...
307                    return ia_err!("expected request_id {}, found {}", request_id, rid);
308                }
309            }
310            Err(e) => {
311                return ia_err!("can't get request_id from response: {}", e.to_string());
312            }
313        }
314        //println!("Response status={} headers:", resp.status());
315        //for (key, value) in resp.headers().iter() {
316        //println!("  {:?}: {:?}", key, value);
317        //}
318        // get session cookie, if available
319        for i in resp.cookies() {
320            if i.name() == "session" {
321                let mut sguard = self.inner.session.lock().unwrap();
322                *sguard = i.value().to_string();
323                trace!("Setting session={}", i.value());
324            }
325        }
326        let result = resp.bytes().await?;
327        // TODO: some way to avoid this copy
328        Ok(result.to_vec())
329    }
330
331    // TODO: opCode
332    pub(crate) async fn send_and_receive(
333        &self,
334        w: Writer,
335        send_options: &mut SendOptions,
336    ) -> Result<Reader, NoSQLError> {
337        send_options.retries = 0;
338        loop {
339            match self.send_and_receive_once(&w, send_options).await {
340                Ok(r) => return Ok(r),
341                Err(e) => {
342                    if e.code == InternalRetry {
343                        send_options.retries += 1;
344                        //tokio::time::sleep(Duration::from_millis(30)).await;
345                        continue;
346                    }
347                    return Err(e);
348                }
349            }
350        }
351    }
352
353    pub(crate) async fn send_and_receive_once(
354        &self,
355        w: &Writer,
356        send_options: &mut SendOptions,
357    ) -> Result<Reader, NoSQLError> {
358        let bytes = self.post_data(&w.buf, send_options).await?;
359
360        //println!("returned data: len={}", bytes.len());
361        let mut r = Reader::new().from_bytes(&bytes);
362        let m = MapWalker::check_reader_for_error(&mut r);
363        if m.is_ok() {
364            return Ok(r);
365        }
366        let err = m.unwrap_err();
367        // this is very specific: If we get a SIU error, and it has a specific string,
368        // it's likely that the service should have retried internally but did not for
369        // some reason. In this case, delay a bit and retry with the same auth header.
370        // allow for up to 4 retries, in case the routing to the service is doing round-robin
371        // across instances (typically 3 in NoSQL cloud).
372        // TODO: check current nano versus timeout at start of request
373        if send_options.retries < 40 && err.code == NoSQLErrorCode::SecurityInfoUnavailable {
374            // Note space at end of this message
375            if err.message == "NotAuthenticated. " {
376                // TODO: check remaining time for request based on timeout
377                tokio::time::sleep(Duration::from_millis(30)).await;
378                trace!("waited 30ms, now retrying SIU error");
379                return Err(NoSQLError::new(InternalRetry, ""));
380            }
381        }
382        // For other auth errors, try refreshing the auth provider. It may have
383        // expired credentials.
384        if send_options.retries < 4
385            && (err.code == NoSQLErrorCode::SecurityInfoUnavailable
386                || err.code == NoSQLErrorCode::RetryAuthentication
387                || err.code == NoSQLErrorCode::InvalidAuthorization)
388        {
389            let refreshed = self
390                .inner
391                .builder
392                .refresh_auth(&self.inner.client)
393                .await
394                .map_err(|e| {
395                    NoSQLError::new(
396                        err.code,
397                        format!(
398                            "error trying to refresh authentication provider: {}",
399                            e.to_string()
400                        )
401                        .as_str(),
402                    )
403                })?;
404            if refreshed {
405                trace!("Refreshed auth provider: retrying");
406                return Err(NoSQLError::new(InternalRetry, ""));
407            }
408            trace!("attempt to refresh generated no error but did not refresh auth");
409        }
410        Err(err)
411    }
412
413    pub(crate) fn get_timeout(&self, t: &Option<Duration>) -> Duration {
414        // if t is given, use that. If not, use handle's timeout
415        if let Some(d) = t {
416            return d.clone();
417        }
418        self.inner.timeout.clone()
419    }
420}
421
/// Per-operation options passed to `Handle::send_and_receive`.
#[derive(Default, Debug)]
pub(crate) struct SendOptions {
    /// Not read by the request path in this module; kept for future use,
    /// which is why the `dead_code` lint is silenced.
    #[allow(dead_code)]
    pub(crate) retryable: bool,
    /// Number of internal retries performed so far for the current operation.
    pub(crate) retries: u16,
    /// HTTP timeout applied to each individual request attempt.
    pub(crate) timeout: Duration,
    /// When non-empty, sent as `x-nosql-compartment-id`, taking precedence
    /// over the handle's default compartment.
    pub(crate) compartment_id: String,
    /// When non-empty, sent as the `x-nosql-default-ns` header.
    pub(crate) namespace: String,
}