gcloud_auth/token_source/
service_account_token_source.rs1use std::collections::HashMap;
2use std::fmt::Debug;
3
4use async_trait::async_trait;
5use reqwest::Response;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
9use crate::credentials;
10use crate::error::{Error, TokenErrorResponse};
11use crate::misc::UnwrapOrEmpty;
12use crate::token::{Token, TOKEN_URL};
13use crate::token_source::{default_http_client, InternalIdToken, InternalToken, TokenSource};
14
15#[derive(Clone, Serialize)]
16struct Claims<'a> {
17 iss: &'a str,
18 sub: Option<&'a str>,
19 scope: Option<&'a str>,
20 aud: &'a str,
21 exp: i64,
22 iat: i64,
23 #[serde(flatten)]
24 private_claims: &'a HashMap<String, serde_json::Value>,
25}
26
27impl Claims<'_> {
28 fn token(&self, pk: &jsonwebtoken::EncodingKey, pk_id: &str) -> Result<String, Error> {
29 let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
30 header.kid = Some(pk_id.to_string());
31 let v = jsonwebtoken::encode(&header, self, pk)?;
32 Ok(v)
33 }
34}
35
36#[allow(dead_code)]
40pub struct ServiceAccountTokenSource {
41 email: String,
42 pk: jsonwebtoken::EncodingKey,
43 pk_id: String,
44 audience: String,
45 use_id_token: bool,
46 private_claims: HashMap<String, serde_json::Value>,
47}
48
49impl Debug for ServiceAccountTokenSource {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("ServiceAccountTokenSource")
53 .field("email", &self.email)
54 .field("pk_id", &self.pk_id)
55 .field("audience", &self.audience)
56 .finish()
57 }
58}
59
60impl ServiceAccountTokenSource {
61 pub(crate) fn new(cred: &credentials::CredentialsFile, audience: &str) -> Result<ServiceAccountTokenSource, Error> {
62 Ok(ServiceAccountTokenSource {
63 email: cred.client_email.unwrap_or_empty(),
64 pk: cred.try_to_private_key()?,
65 pk_id: cred.private_key_id.unwrap_or_empty(),
66 audience: match &cred.audience {
67 None => audience.to_string(),
68 Some(s) => s.to_string(),
69 },
70 use_id_token: false,
71 private_claims: HashMap::new(),
72 })
73 }
74}
75
76#[async_trait]
77impl TokenSource for ServiceAccountTokenSource {
78 async fn token(&self) -> Result<Token, Error> {
79 let iat = OffsetDateTime::now_utc();
80 let exp = iat + time::Duration::hours(1);
81
82 let token = Claims {
83 iss: self.email.as_ref(),
84 sub: Some(self.email.as_ref()),
85 scope: None,
86 aud: self.audience.as_ref(),
87 exp: exp.unix_timestamp(),
88 iat: iat.unix_timestamp(),
89 private_claims: &HashMap::new(),
90 }
91 .token(&self.pk, &self.pk_id)?;
92
93 return Ok(Token {
94 access_token: token,
95 token_type: "Bearer".to_string(),
96 expiry: Some(exp),
97 });
98 }
99}
100
101#[allow(dead_code)]
102#[derive(Clone, Deserialize)]
103struct OAuth2Token {
104 pub access_token: String,
105 pub token_type: String,
106 pub id_token: Option<String>,
107 pub expires_in: Option<i64>,
108}
109
110pub struct OAuth2ServiceAccountTokenSource {
112 pub email: String,
113 pub pk: jsonwebtoken::EncodingKey,
114 pub pk_id: String,
115 pub scopes: String,
116 pub token_url: String,
117 pub sub: Option<String>,
118
119 pub client: reqwest::Client,
120
121 use_id_token: bool,
122 private_claims: HashMap<String, serde_json::Value>,
123}
124
125impl Debug for OAuth2ServiceAccountTokenSource {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("OAuth2ServiceAccountTokenSource")
129 .field("email", &self.email)
130 .field("pk_id", &self.pk_id)
131 .field("scopes", &self.scopes)
132 .field("token_url", &self.token_url)
133 .field("sub", &self.sub)
134 .field("client", &self.client)
135 .field("use_id_token", &self.use_id_token)
136 .field("private_claims", &self.private_claims)
137 .finish()
138 }
139}
140
141impl OAuth2ServiceAccountTokenSource {
142 pub(crate) fn new(
143 cred: &credentials::CredentialsFile,
144 scopes: &str,
145 sub: Option<&str>,
146 ) -> Result<OAuth2ServiceAccountTokenSource, Error> {
147 Ok(OAuth2ServiceAccountTokenSource {
148 email: cred.client_email.unwrap_or_empty(),
149 pk: cred.try_to_private_key()?,
150 pk_id: cred.private_key_id.unwrap_or_empty(),
151 scopes: scopes.to_string(),
152 token_url: match &cred.token_uri {
153 None => TOKEN_URL.to_string(),
154 Some(s) => s.to_string(),
155 },
156 client: default_http_client(),
157 sub: sub.map(|s| s.to_string()),
158 use_id_token: false,
159 private_claims: HashMap::new(),
160 })
161 }
162
163 pub(crate) fn with_use_id_token(mut self) -> Self {
164 self.use_id_token = true;
165 self
166 }
167
168 pub(crate) fn with_private_claims(mut self, claims: HashMap<String, serde_json::Value>) -> Self {
169 self.private_claims = claims;
170 self
171 }
172
173 async fn check_response_status(response: Response) -> Result<Response, Error> {
175 let error = match response.error_for_status_ref() {
177 Ok(_) => return Ok(response),
178 Err(error) => error,
179 };
180
181 let status = response.status();
183 Err(response
184 .json::<TokenErrorResponse>()
185 .await
186 .map(|response| Error::TokenErrorResponse {
187 status: status.as_u16(),
188 error: response.error,
189 error_description: response.error_description,
190 })
191 .unwrap_or(Error::HttpError(error)))
192 }
193}
194
195#[async_trait]
196impl TokenSource for OAuth2ServiceAccountTokenSource {
197 async fn token(&self) -> Result<Token, Error> {
198 let iat = OffsetDateTime::now_utc();
199 let exp = iat + time::Duration::hours(1);
200
201 let claims = Claims {
202 iss: self.email.as_ref(),
203 sub: self.sub.as_ref().map(|s| s.as_ref()),
204 scope: Some(self.scopes.as_ref()),
205 aud: self.token_url.as_ref(),
206 exp: exp.unix_timestamp(),
207 iat: iat.unix_timestamp(),
208 private_claims: &self.private_claims,
209 };
210 let request_token = claims.token(&self.pk, &self.pk_id)?;
211
212 let form = [
213 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
214 ("assertion", request_token.as_str()),
215 ];
216
217 match self.use_id_token {
218 true => {
219 let audience = claims
220 .private_claims
221 .get("target_audience")
222 .ok_or(Error::NoTargetAudienceFound)?
223 .as_str()
224 .ok_or(Error::NoTargetAudienceFound)?;
225 let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
226 Ok(Self::check_response_status(response)
227 .await?
228 .json::<InternalIdToken>()
229 .await?
230 .to_token(audience)?)
231 }
232 false => {
233 let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
234 Ok(Self::check_response_status(response)
235 .await?
236 .json::<InternalToken>()
237 .await?
238 .to_token(iat))
239 }
240 }
241 }
242}