// hc_vault/client.rs

use std::sync::Mutex;
use std::time::Duration;

use serde::Serialize;
use url::Url;

use crate::Auth;
use crate::Config;
use crate::Error;
use crate::RenewError;
use crate::RenewPolicy;

10/// The Client struct represents a single Vault-Connection/Session that can be used for any
11/// further requests to vault
12pub struct Client<T>
13where
14    T: Auth,
15{
16    config: Config,
17    auth: T,
18    reauth_mutex: std::sync::Mutex<()>,
19}
20
21impl<T> Client<T>
22where
23    T: Auth,
24{
25    /// This function is used to obtain a new vault session with the given config and
26    /// auth settings
27    pub fn new(conf: Config, auth_opts: T) -> Result<Client<T>, Error> {
28        match auth_opts.auth(&conf.vault_url) {
29            Err(e) => return Err(e),
30            Ok(()) => {}
31        };
32
33        let client = Client::<T> {
34            config: conf,
35            auth: auth_opts,
36            reauth_mutex: std::sync::Mutex::new(()),
37        };
38
39        Ok(client)
40    }
41
42    /// This function will enter an infitive Loop and blocks the current thread.
43    /// It will do everything related to renewing the token/session. This will
44    /// idealy run inside it's own thread as to not block anything else important
45    pub fn renew_background(&self) -> Result<(), RenewError> {
46        let threshold = match self.config.renew_policy {
47            RenewPolicy::Renew(t) => t,
48            _ => return Err(RenewError::NotEnabled),
49        };
50
51        loop {
52            if !self.auth.is_renewable() {
53                return Err(RenewError::NotRenewable);
54            }
55
56            let total_duration = self.auth.get_total_duration();
57            let wait_percentage = 1.0 - threshold;
58
59            let wait_duration =
60                std::time::Duration::from_secs(((total_duration as f32) * wait_percentage) as u64);
61
62            std::thread::sleep(wait_duration);
63
64            match self.auth.renew(&self.config.vault_url) {
65                Ok(_) => {}
66                Err(e) => {
67                    return Err(RenewError::from(e));
68                }
69            };
70        }
71    }
72
73    /// A simple method to get the underlying vault session/client token
74    /// for the current vault session.
75    /// It is not recommended to use this function, but rather stick to other
76    /// more integrated parts, like the vault_request function
77    pub fn get_token(&self) -> String {
78        self.auth.get_token()
79    }
80
81    /// This function is used to check if the current
82    /// session is still valid and if not to renew
83    /// the session/obtain a new one and update
84    /// all data related to it
85    pub async fn check_session(&self) -> Result<(), Error> {
86        if !self.auth.is_expired() {
87            return Ok(());
88        }
89
90        // Take mutex to ensure only one thread can try to reauth at a time
91        let _data = self.reauth_mutex.lock().unwrap();
92        // If the mutex is acquired, check if the session still needs to be renewed or if another
93        // thread has already done this, in which case this one should just return as its all fine
94        // now
95        if !self.auth.is_expired() {
96            return Ok(());
97        }
98
99        let result = match self.config.renew_policy {
100            RenewPolicy::Reauth => self.auth.auth(&self.config.vault_url),
101            RenewPolicy::Nothing | RenewPolicy::Renew(_) => Err(Error::SessionExpired),
102        };
103
104        return result;
105    }
106
107    /// This function is a general way to directly make requests to vault using
108    /// the current session. This can be used to make custom requests or to make requests
109    /// to mounts that are not directly covered by this crate.
110    pub async fn vault_request<P: Serialize>(
111        &self,
112        method: reqwest::Method,
113        path: &str,
114        body: Option<&P>,
115    ) -> Result<reqwest::Response, Error> {
116        self.check_session().await?;
117
118        let mut url = match Url::parse(&self.config.vault_url) {
119            Err(e) => {
120                return Err(Error::from(e));
121            }
122            Ok(url) => url,
123        };
124        url = match url.join("v1/") {
125            Err(e) => {
126                return Err(Error::from(e));
127            }
128            Ok(u) => u,
129        };
130        url = match url.join(path) {
131            Err(e) => {
132                return Err(Error::from(e));
133            }
134            Ok(u) => u,
135        };
136
137        let token = self.auth.get_token();
138
139        let http_client = reqwest::Client::new();
140        let mut req = http_client
141            .request(method, url)
142            .header("X-Vault-Token", &token)
143            .header("X-Vault-Request", "true");
144
145        if body.is_some() {
146            req = req.json(body.unwrap());
147        }
148
149        let resp = match req.send().await {
150            Err(e) => return Err(Error::from(e)),
151            Ok(resp) => resp,
152        };
153
154        let status_code = resp.status().as_u16();
155
156        match status_code {
157            200 | 204 => Ok(resp),
158            _ => Err(Error::from(status_code)),
159        }
160    }
161}