1use crate::Auth;
2use crate::Config;
3use crate::Error;
4use crate::RenewError;
5use crate::RenewPolicy;
6
7use serde::Serialize;
8use url::Url;
9
10pub struct Client<T>
13where
14 T: Auth,
15{
16 config: Config,
17 auth: T,
18 reauth_mutex: std::sync::Mutex<()>,
19}
20
21impl<T> Client<T>
22where
23 T: Auth,
24{
25 pub fn new(conf: Config, auth_opts: T) -> Result<Client<T>, Error> {
28 match auth_opts.auth(&conf.vault_url) {
29 Err(e) => return Err(e),
30 Ok(()) => {}
31 };
32
33 let client = Client::<T> {
34 config: conf,
35 auth: auth_opts,
36 reauth_mutex: std::sync::Mutex::new(()),
37 };
38
39 Ok(client)
40 }
41
42 pub fn renew_background(&self) -> Result<(), RenewError> {
46 let threshold = match self.config.renew_policy {
47 RenewPolicy::Renew(t) => t,
48 _ => return Err(RenewError::NotEnabled),
49 };
50
51 loop {
52 if !self.auth.is_renewable() {
53 return Err(RenewError::NotRenewable);
54 }
55
56 let total_duration = self.auth.get_total_duration();
57 let wait_percentage = 1.0 - threshold;
58
59 let wait_duration =
60 std::time::Duration::from_secs(((total_duration as f32) * wait_percentage) as u64);
61
62 std::thread::sleep(wait_duration);
63
64 match self.auth.renew(&self.config.vault_url) {
65 Ok(_) => {}
66 Err(e) => {
67 return Err(RenewError::from(e));
68 }
69 };
70 }
71 }
72
73 pub fn get_token(&self) -> String {
78 self.auth.get_token()
79 }
80
81 pub async fn check_session(&self) -> Result<(), Error> {
86 if !self.auth.is_expired() {
87 return Ok(());
88 }
89
90 let _data = self.reauth_mutex.lock().unwrap();
92 if !self.auth.is_expired() {
96 return Ok(());
97 }
98
99 let result = match self.config.renew_policy {
100 RenewPolicy::Reauth => self.auth.auth(&self.config.vault_url),
101 RenewPolicy::Nothing | RenewPolicy::Renew(_) => Err(Error::SessionExpired),
102 };
103
104 return result;
105 }
106
107 pub async fn vault_request<P: Serialize>(
111 &self,
112 method: reqwest::Method,
113 path: &str,
114 body: Option<&P>,
115 ) -> Result<reqwest::Response, Error> {
116 self.check_session().await?;
117
118 let mut url = match Url::parse(&self.config.vault_url) {
119 Err(e) => {
120 return Err(Error::from(e));
121 }
122 Ok(url) => url,
123 };
124 url = match url.join("v1/") {
125 Err(e) => {
126 return Err(Error::from(e));
127 }
128 Ok(u) => u,
129 };
130 url = match url.join(path) {
131 Err(e) => {
132 return Err(Error::from(e));
133 }
134 Ok(u) => u,
135 };
136
137 let token = self.auth.get_token();
138
139 let http_client = reqwest::Client::new();
140 let mut req = http_client
141 .request(method, url)
142 .header("X-Vault-Token", &token)
143 .header("X-Vault-Request", "true");
144
145 if body.is_some() {
146 req = req.json(body.unwrap());
147 }
148
149 let resp = match req.send().await {
150 Err(e) => return Err(Error::from(e)),
151 Ok(resp) => resp,
152 };
153
154 let status_code = resp.status().as_u16();
155
156 match status_code {
157 200 | 204 => Ok(resp),
158 _ => Err(Error::from(status_code)),
159 }
160 }
161}