1use serde::de::DeserializeOwned;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicI64, Ordering};
4use tokio::sync::Semaphore;
5
6use crate::config::ClientConfig;
7use crate::credential::{ChainProvider, Credential, CredentialProvider};
8use crate::error::{Result, StsError};
9use crate::exec::{calculate_smoothed_offset, extract_server_time, handle_response};
10use crate::request::build_signed_request;
11
/// Parameters for an `AssumeRole` call.
///
/// Build one with [`AssumeRoleRequest::builder`] or fill in the public fields directly.
#[derive(Clone, Debug)]
pub struct AssumeRoleRequest {
    /// Sent as the `RoleArn` request parameter.
    pub role_arn: String,
    /// Sent as the `RoleSessionName` request parameter.
    pub role_session_name: String,
    /// Sent as the `Policy` request parameter when set; omitted otherwise.
    pub policy: Option<String>,
    /// Sent as the `DurationSeconds` request parameter when set; omitted otherwise.
    pub duration_seconds: Option<u64>,
    /// Sent as the `ExternalId` request parameter when set; omitted otherwise.
    pub external_id: Option<String>,
}
38
39impl AssumeRoleRequest {
40 pub fn builder() -> AssumeRoleRequestBuilder {
42 AssumeRoleRequestBuilder::default()
43 }
44
45 pub(crate) fn to_params(&self) -> Vec<(&str, String)> {
46 let mut params = vec![
47 ("RoleArn", self.role_arn.clone()),
48 ("RoleSessionName", self.role_session_name.clone()),
49 ];
50 if let Some(ref policy) = self.policy {
51 params.push(("Policy", policy.clone()));
52 }
53 if let Some(duration) = self.duration_seconds {
54 params.push(("DurationSeconds", duration.to_string()));
55 }
56 if let Some(ref external_id) = self.external_id {
57 params.push(("ExternalId", external_id.clone()));
58 }
59 params
60 }
61}
62
/// Step-by-step constructor for [`AssumeRoleRequest`].
///
/// Every field starts out as `None`; the setter methods fill them in one at a time.
#[derive(Default)]
pub struct AssumeRoleRequestBuilder {
    /// Required by `try_build`.
    role_arn: Option<String>,
    /// Required by `try_build`.
    role_session_name: Option<String>,
    /// Optional; copied to the request unchanged.
    policy: Option<String>,
    /// Optional; copied to the request unchanged.
    duration_seconds: Option<u64>,
    /// Optional; copied to the request unchanged.
    external_id: Option<String>,
}
84
85impl AssumeRoleRequestBuilder {
86 pub fn role_arn(mut self, arn: impl Into<String>) -> Self {
88 self.role_arn = Some(arn.into());
89 self
90 }
91
92 pub fn role_session_name(mut self, name: impl Into<String>) -> Self {
94 self.role_session_name = Some(name.into());
95 self
96 }
97
98 pub fn policy(mut self, policy: impl Into<String>) -> Self {
100 self.policy = Some(policy.into());
101 self
102 }
103
104 pub fn duration_seconds(mut self, seconds: u64) -> Self {
106 self.duration_seconds = Some(seconds);
107 self
108 }
109
110 pub fn external_id(mut self, id: impl Into<String>) -> Self {
112 self.external_id = Some(id.into());
113 self
114 }
115
116 pub fn build(self) -> AssumeRoleRequest {
123 self.try_build()
124 .expect("AssumeRoleRequest requires role_arn and role_session_name")
125 }
126
127 pub fn try_build(self) -> Result<AssumeRoleRequest> {
135 let role_arn = self.role_arn.ok_or_else(|| {
136 StsError::Validation("role_arn is required for AssumeRoleRequest".into())
137 })?;
138 let role_session_name = self.role_session_name.ok_or_else(|| {
139 StsError::Validation("role_session_name is required for AssumeRoleRequest".into())
140 })?;
141 Ok(AssumeRoleRequest {
142 role_arn,
143 role_session_name,
144 policy: self.policy,
145 duration_seconds: self.duration_seconds,
146 external_id: self.external_id,
147 })
148 }
149}
150
/// Parameters for an `AssumeRoleWithSAML` call.
///
/// Build one with [`AssumeRoleWithSamlRequest::builder`] or fill in the public fields directly.
///
/// `Debug` is implemented by hand so that the SAML assertion, which acts as a credential for
/// the call, never ends up in logs or panic messages.
#[derive(Clone)]
pub struct AssumeRoleWithSamlRequest {
    /// Sent as the `SAMLProviderArn` request parameter.
    pub saml_provider_arn: String,
    /// Sent as the `RoleArn` request parameter.
    pub role_arn: String,
    /// Sent as the `SAMLAssertion` request parameter. Redacted in `Debug` output.
    pub saml_assertion: String,
    /// Sent as the `Policy` request parameter when set; omitted otherwise.
    pub policy: Option<String>,
    /// Sent as the `DurationSeconds` request parameter when set; omitted otherwise.
    pub duration_seconds: Option<u64>,
}

impl std::fmt::Debug for AssumeRoleWithSamlRequest {
    /// Formats like the derived implementation, but prints a placeholder for `saml_assertion`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AssumeRoleWithSamlRequest")
            .field("saml_provider_arn", &self.saml_provider_arn)
            .field("role_arn", &self.role_arn)
            .field("saml_assertion", &"<redacted>")
            .field("policy", &self.policy)
            .field("duration_seconds", &self.duration_seconds)
            .finish()
    }
}
165
166impl AssumeRoleWithSamlRequest {
167 pub fn builder() -> AssumeRoleWithSamlRequestBuilder {
169 AssumeRoleWithSamlRequestBuilder::default()
170 }
171
172 pub(crate) fn to_params(&self) -> Vec<(&str, String)> {
173 let mut params = vec![
174 ("SAMLProviderArn", self.saml_provider_arn.clone()),
175 ("RoleArn", self.role_arn.clone()),
176 ("SAMLAssertion", self.saml_assertion.clone()),
177 ];
178 if let Some(ref policy) = self.policy {
179 params.push(("Policy", policy.clone()));
180 }
181 if let Some(duration) = self.duration_seconds {
182 params.push(("DurationSeconds", duration.to_string()));
183 }
184 params
185 }
186}
187
/// Step-by-step constructor for [`AssumeRoleWithSamlRequest`].
///
/// Every field starts out as `None`; the setter methods fill them in one at a time.
#[derive(Default)]
pub struct AssumeRoleWithSamlRequestBuilder {
    /// Required by `try_build`.
    saml_provider_arn: Option<String>,
    /// Required by `try_build`.
    role_arn: Option<String>,
    /// Required by `try_build`.
    saml_assertion: Option<String>,
    /// Optional; copied to the request unchanged.
    policy: Option<String>,
    /// Optional; copied to the request unchanged.
    duration_seconds: Option<u64>,
}
197
198impl AssumeRoleWithSamlRequestBuilder {
199 pub fn saml_provider_arn(mut self, arn: impl Into<String>) -> Self {
201 self.saml_provider_arn = Some(arn.into());
202 self
203 }
204
205 pub fn role_arn(mut self, arn: impl Into<String>) -> Self {
207 self.role_arn = Some(arn.into());
208 self
209 }
210
211 pub fn saml_assertion(mut self, assertion: impl Into<String>) -> Self {
213 self.saml_assertion = Some(assertion.into());
214 self
215 }
216
217 pub fn policy(mut self, policy: impl Into<String>) -> Self {
219 self.policy = Some(policy.into());
220 self
221 }
222
223 pub fn duration_seconds(mut self, seconds: u64) -> Self {
225 self.duration_seconds = Some(seconds);
226 self
227 }
228
229 pub fn build(self) -> AssumeRoleWithSamlRequest {
236 self.try_build().expect(
237 "AssumeRoleWithSamlRequest requires saml_provider_arn, role_arn, and saml_assertion",
238 )
239 }
240
241 pub fn try_build(self) -> Result<AssumeRoleWithSamlRequest> {
247 let saml_provider_arn = self.saml_provider_arn.ok_or_else(|| {
248 StsError::Validation(
249 "saml_provider_arn is required for AssumeRoleWithSamlRequest".into(),
250 )
251 })?;
252 let role_arn = self.role_arn.ok_or_else(|| {
253 StsError::Validation("role_arn is required for AssumeRoleWithSamlRequest".into())
254 })?;
255 let saml_assertion = self.saml_assertion.ok_or_else(|| {
256 StsError::Validation("saml_assertion is required for AssumeRoleWithSamlRequest".into())
257 })?;
258 Ok(AssumeRoleWithSamlRequest {
259 saml_provider_arn,
260 role_arn,
261 saml_assertion,
262 policy: self.policy,
263 duration_seconds: self.duration_seconds,
264 })
265 }
266}
267
/// Parameters for an `AssumeRoleWithOIDC` call.
///
/// Build one with [`AssumeRoleWithOidcRequest::builder`] or fill in the public fields directly.
///
/// `Debug` is implemented by hand so that the OIDC token, which acts as a credential for the
/// call, never ends up in logs or panic messages.
#[derive(Clone)]
pub struct AssumeRoleWithOidcRequest {
    /// Sent as the `OIDCProviderArn` request parameter.
    pub oidc_provider_arn: String,
    /// Sent as the `RoleArn` request parameter.
    pub role_arn: String,
    /// Sent as the `OIDCToken` request parameter. Redacted in `Debug` output.
    pub oidc_token: String,
    /// Sent as the `Policy` request parameter when set; omitted otherwise.
    pub policy: Option<String>,
    /// Sent as the `DurationSeconds` request parameter when set; omitted otherwise.
    pub duration_seconds: Option<u64>,
    /// Sent as the `RoleSessionName` request parameter when set; omitted otherwise.
    pub role_session_name: Option<String>,
}

impl std::fmt::Debug for AssumeRoleWithOidcRequest {
    /// Formats like the derived implementation, but prints a placeholder for `oidc_token`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AssumeRoleWithOidcRequest")
            .field("oidc_provider_arn", &self.oidc_provider_arn)
            .field("role_arn", &self.role_arn)
            .field("oidc_token", &"<redacted>")
            .field("policy", &self.policy)
            .field("duration_seconds", &self.duration_seconds)
            .field("role_session_name", &self.role_session_name)
            .finish()
    }
}
284
285impl AssumeRoleWithOidcRequest {
286 pub fn builder() -> AssumeRoleWithOidcRequestBuilder {
288 AssumeRoleWithOidcRequestBuilder::default()
289 }
290
291 pub(crate) fn to_params(&self) -> Vec<(&str, String)> {
292 let mut params = vec![
293 ("OIDCProviderArn", self.oidc_provider_arn.clone()),
294 ("RoleArn", self.role_arn.clone()),
295 ("OIDCToken", self.oidc_token.clone()),
296 ];
297 if let Some(ref policy) = self.policy {
298 params.push(("Policy", policy.clone()));
299 }
300 if let Some(duration) = self.duration_seconds {
301 params.push(("DurationSeconds", duration.to_string()));
302 }
303 if let Some(ref session) = self.role_session_name {
304 params.push(("RoleSessionName", session.clone()));
305 }
306 params
307 }
308}
309
/// Step-by-step constructor for [`AssumeRoleWithOidcRequest`].
///
/// Every field starts out as `None`; the setter methods fill them in one at a time.
#[derive(Default)]
pub struct AssumeRoleWithOidcRequestBuilder {
    /// Required by `try_build`.
    oidc_provider_arn: Option<String>,
    /// Required by `try_build`.
    role_arn: Option<String>,
    /// Required by `try_build`.
    oidc_token: Option<String>,
    /// Optional; copied to the request unchanged.
    policy: Option<String>,
    /// Optional; copied to the request unchanged.
    duration_seconds: Option<u64>,
    /// Optional; copied to the request unchanged.
    role_session_name: Option<String>,
}
320
321impl AssumeRoleWithOidcRequestBuilder {
322 pub fn oidc_provider_arn(mut self, arn: impl Into<String>) -> Self {
324 self.oidc_provider_arn = Some(arn.into());
325 self
326 }
327
328 pub fn role_arn(mut self, arn: impl Into<String>) -> Self {
330 self.role_arn = Some(arn.into());
331 self
332 }
333
334 pub fn oidc_token(mut self, token: impl Into<String>) -> Self {
336 self.oidc_token = Some(token.into());
337 self
338 }
339
340 pub fn policy(mut self, policy: impl Into<String>) -> Self {
342 self.policy = Some(policy.into());
343 self
344 }
345
346 pub fn duration_seconds(mut self, seconds: u64) -> Self {
348 self.duration_seconds = Some(seconds);
349 self
350 }
351
352 pub fn role_session_name(mut self, name: impl Into<String>) -> Self {
354 self.role_session_name = Some(name.into());
355 self
356 }
357
358 pub fn build(self) -> AssumeRoleWithOidcRequest {
365 self.try_build().expect(
366 "AssumeRoleWithOidcRequest requires oidc_provider_arn, role_arn, and oidc_token",
367 )
368 }
369
370 pub fn try_build(self) -> Result<AssumeRoleWithOidcRequest> {
376 let oidc_provider_arn = self.oidc_provider_arn.ok_or_else(|| {
377 StsError::Validation(
378 "oidc_provider_arn is required for AssumeRoleWithOidcRequest".into(),
379 )
380 })?;
381 let role_arn = self.role_arn.ok_or_else(|| {
382 StsError::Validation("role_arn is required for AssumeRoleWithOidcRequest".into())
383 })?;
384 let oidc_token = self.oidc_token.ok_or_else(|| {
385 StsError::Validation("oidc_token is required for AssumeRoleWithOidcRequest".into())
386 })?;
387 Ok(AssumeRoleWithOidcRequest {
388 oidc_provider_arn,
389 role_arn,
390 oidc_token,
391 policy: self.policy,
392 duration_seconds: self.duration_seconds,
393 role_session_name: self.role_session_name,
394 })
395 }
396}
397
398pub struct Client {
400 http: reqwest::Client,
401 config: ClientConfig,
402 credential: Credential,
403 time_offset: Arc<AtomicI64>,
406 semaphore: Arc<Semaphore>,
408}
409
410impl Client {
411 pub fn new(credential: Credential) -> Result<Self> {
413 Self::with_config(credential, ClientConfig::default())
414 }
415
416 pub fn with_config(credential: Credential, config: ClientConfig) -> Result<Self> {
418 let mut builder = reqwest::Client::builder()
419 .timeout(config.timeout)
420 .connect_timeout(config.connect_timeout)
421 .pool_idle_timeout(config.pool_idle_timeout)
422 .pool_max_idle_per_host(config.pool_max_idle_per_host);
423
424 if let Some(keepalive) = config.tcp_keepalive {
426 builder = builder.tcp_keepalive(keepalive);
427 }
428
429 let http = builder.build().map_err(StsError::HttpClient)?;
430 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_requests));
431 Ok(Self {
432 http,
433 config,
434 credential,
435 time_offset: Arc::new(AtomicI64::new(0)),
436 semaphore,
437 })
438 }
439
440 pub fn from_env() -> Result<Self> {
442 let credential = ChainProvider::default_chain().resolve()?;
443 Self::new(credential)
444 }
445
446 pub async fn assume_role(
448 &self,
449 request: AssumeRoleRequest,
450 ) -> Result<crate::response::AssumeRoleResponse> {
451 let owned = request.to_params();
452 let params: Vec<(&str, &str)> = owned.iter().map(|(k, v)| (*k, v.as_str())).collect();
453 self.execute("AssumeRole", ¶ms).await
454 }
455
456 pub async fn assume_role_with_saml(
458 &self,
459 request: AssumeRoleWithSamlRequest,
460 ) -> Result<crate::response::AssumeRoleWithSamlResponse> {
461 let owned = request.to_params();
462 let params: Vec<(&str, &str)> = owned.iter().map(|(k, v)| (*k, v.as_str())).collect();
463 self.execute("AssumeRoleWithSAML", ¶ms).await
464 }
465
466 pub async fn assume_role_with_oidc(
468 &self,
469 request: AssumeRoleWithOidcRequest,
470 ) -> Result<crate::response::AssumeRoleWithOidcResponse> {
471 let owned = request.to_params();
472 let params: Vec<(&str, &str)> = owned.iter().map(|(k, v)| (*k, v.as_str())).collect();
473 self.execute("AssumeRoleWithOIDC", ¶ms).await
474 }
475
476 pub async fn get_caller_identity(&self) -> Result<crate::response::GetCallerIdentityResponse> {
478 self.execute("GetCallerIdentity", &[]).await
479 }
480
481 pub fn time_offset(&self) -> i64 {
486 self.time_offset.load(Ordering::Relaxed)
487 }
488
489 fn update_time_offset(&self, server_time: i64) {
494 let local_time = chrono::Utc::now().timestamp();
495 let new_offset = server_time - local_time;
496 let current_offset = self.time_offset.load(Ordering::Relaxed);
497 let smoothed = calculate_smoothed_offset(current_offset, new_offset);
498 self.time_offset.store(smoothed, Ordering::Relaxed);
499 }
500
501 async fn execute<T: DeserializeOwned>(
502 &self,
503 action: &str,
504 params: &[(&str, &str)],
505 ) -> Result<T> {
506 let _permit = self
508 .semaphore
509 .acquire()
510 .await
511 .map_err(|e| crate::error::StsError::Config(format!("Semaphore closed: {}", e)))?;
512
513 let time_offset = self.time_offset.load(Ordering::Relaxed);
514 let body =
515 build_signed_request(action, params, &self.credential, &self.config, time_offset)?;
516
517 let response = self
518 .http
519 .post(&self.config.endpoint)
520 .header("Content-Type", "application/x-www-form-urlencoded")
521 .body(body)
522 .send()
523 .await?;
524
525 if let Some(server_time) = extract_server_time(response.headers()) {
527 self.update_time_offset(server_time);
528 }
529
530 let status = response.status();
531 let text = response.text().await?;
532
533 handle_response(status, text)
534 }
535}