gcloud_auth_patch/token_source/
service_account_token_source.rs1use std::collections::HashMap;
2use std::fmt::Debug;
3
4use async_trait::async_trait;
5use reqwest::Response;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
9use crate::credentials;
10use crate::error::{Error, TokenErrorResponse};
11use crate::misc::UnwrapOrEmpty;
12use crate::token::{Token, TOKEN_URL};
13use crate::token_source::{default_http_client, InternalIdToken, InternalToken, TokenSource};
14
15#[derive(Clone, Serialize)]
16struct Claims<'a> {
17 iss: &'a str,
18 sub: Option<&'a str>,
19 scope: Option<&'a str>,
20 aud: &'a str,
21 exp: i64,
22 iat: i64,
23 #[serde(flatten)]
24 private_claims: &'a HashMap<String, serde_json::Value>,
25}
26
27impl Claims<'_> {
28 fn token(&self, pk: &jsonwebtoken::EncodingKey, pk_id: &str) -> Result<String, Error> {
29 let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
30 header.kid = Some(pk_id.to_string());
31 let v = jsonwebtoken::encode(&header, self, pk)?;
32 Ok(v)
33 }
34}
35
36#[allow(dead_code)]
40pub struct ServiceAccountTokenSource {
41 email: String,
42 pk: jsonwebtoken::EncodingKey,
43 pk_id: String,
44 audience: String,
45 use_id_token: bool,
46 private_claims: HashMap<String, serde_json::Value>,
47}
48
49impl Debug for ServiceAccountTokenSource {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("ServiceAccountTokenSource")
53 .field("email", &self.email)
54 .field("pk_id", &self.pk_id)
55 .field("audience", &self.audience)
56 .finish()
57 }
58}
59
60impl ServiceAccountTokenSource {
61 pub(crate) fn new(cred: &credentials::CredentialsFile, audience: &str) -> Result<ServiceAccountTokenSource, Error> {
62 Ok(ServiceAccountTokenSource {
63 email: cred.client_email.unwrap_or_empty(),
64 pk: cred.try_to_private_key()?,
65 pk_id: cred.private_key_id.unwrap_or_empty(),
66 audience: match &cred.audience {
67 None => audience.to_string(),
68 Some(s) => s.to_string(),
69 },
70 use_id_token: false,
71 private_claims: HashMap::new(),
72 })
73 }
74}
75
76#[async_trait]
77impl TokenSource for ServiceAccountTokenSource {
78 async fn token(&self) -> Result<Token, Error> {
79 let iat = OffsetDateTime::now_utc();
80 let exp = iat + time::Duration::hours(1);
81
82 let token = Claims {
83 iss: self.email.as_ref(),
84 sub: Some(self.email.as_ref()),
85 scope: None,
86 aud: self.audience.as_ref(),
87 exp: exp.unix_timestamp(),
88 iat: iat.unix_timestamp(),
89 private_claims: &HashMap::new(),
90 }
91 .token(&self.pk, &self.pk_id)?;
92
93 return Ok(Token {
94 access_token: token,
95 token_type: "Bearer".to_string(),
96 expiry: Some(exp),
97 });
98 }
99}
100
/// Appears to model a JSON token-endpoint response (access token, type, optional ID token,
/// optional lifetime).
///
/// NOTE(review): this type is not referenced anywhere in this file — token responses are
/// parsed into `InternalToken` / `InternalIdToken` instead — which is why `dead_code` is
/// allowed. Candidate for removal; confirm nothing else depends on it.
#[allow(dead_code)]
#[derive(Clone, Deserialize)]
struct OAuth2Token {
    pub access_token: String,
    pub token_type: String,
    // Optional: presumably only present when an ID token was requested — TODO confirm.
    pub id_token: Option<String>,
    // Optional lifetime; assumed to be seconds per the usual `expires_in` convention — confirm.
    pub expires_in: Option<i64>,
}
109
110pub struct OAuth2ServiceAccountTokenSource {
112 pub email: String,
113 pub pk: jsonwebtoken::EncodingKey,
114 pub pk_id: String,
115 pub scopes: String,
116 pub token_url: String,
117 pub sub: Option<String>,
118
119 pub client: reqwest::Client,
120
121 use_id_token: bool,
122 private_claims: HashMap<String, serde_json::Value>,
123}
124
125impl Debug for OAuth2ServiceAccountTokenSource {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("OAuth2ServiceAccountTokenSource")
129 .field("email", &self.email)
130 .field("pk_id", &self.pk_id)
131 .field("scopes", &self.scopes)
132 .field("token_url", &self.token_url)
133 .field("sub", &self.sub)
134 .field("client", &self.client)
135 .field("use_id_token", &self.use_id_token)
136 .field("private_claims", &self.private_claims)
137 .finish()
138 }
139}
140
141impl OAuth2ServiceAccountTokenSource {
142 pub(crate) fn new(
143 cred: &credentials::CredentialsFile,
144 scopes: &str,
145 sub: Option<&str>,
146 ) -> Result<OAuth2ServiceAccountTokenSource, Error> {
147 Ok(OAuth2ServiceAccountTokenSource {
148 email: cred.client_email.unwrap_or_empty(),
149 pk: cred.try_to_private_key()?,
150 pk_id: cred.private_key_id.unwrap_or_empty(),
151 scopes: scopes.to_string(),
152 token_url: match &cred.token_uri {
153 None => TOKEN_URL.to_string(),
154 Some(s) => s.to_string(),
155 },
156 client: default_http_client(),
157 sub: sub.map(|s| s.to_string()),
158 use_id_token: false,
159 private_claims: HashMap::new(),
160 })
161 }
162
163 pub(crate) fn with_use_id_token(mut self) -> Self {
164 self.use_id_token = true;
165 self
166 }
167
168 pub(crate) fn with_private_claims(mut self, claims: HashMap<String, serde_json::Value>) -> Self {
169 self.private_claims = claims;
170 self
171 }
172
173 async fn check_response_status(response: Response) -> Result<Response, Error> {
175 let error = match response.error_for_status_ref() {
177 Ok(_) => return Ok(response),
178 Err(error) => error,
179 };
180
181 let status = response.status();
183
184 let body_text = response.text().await.unwrap_or_else(|_| "".to_string());
186
187 match serde_json::from_str::<TokenErrorResponse>(&body_text) {
188 Ok(err_response) => Err(Error::TokenErrorResponse {
189 status: status.as_u16(),
190 error: err_response.error,
191 error_description: err_response.error_description,
192 }),
193 Err(_) => {
194 println!("HTTP error response body (status {}): {}", status.as_u16(), body_text);
197 Err(Error::HttpError(error))
198 }
199 }
200 }
201}
202
203#[async_trait]
204impl TokenSource for OAuth2ServiceAccountTokenSource {
205 async fn token(&self) -> Result<Token, Error> {
206 let iat = OffsetDateTime::now_utc();
207 let exp = iat + time::Duration::hours(1);
208
209 let claims = Claims {
210 iss: self.email.as_ref(),
211 sub: self.sub.as_ref().map(|s| s.as_ref()),
212 scope: Some(self.scopes.as_ref()),
213 aud: self.token_url.as_ref(),
214 exp: exp.unix_timestamp(),
215 iat: iat.unix_timestamp(),
216 private_claims: &self.private_claims,
217 };
218 let request_token = claims.token(&self.pk, &self.pk_id)?;
219
220 let form = [
221 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
222 ("assertion", request_token.as_str()),
223 ];
224
225 match self.use_id_token {
226 true => {
227 let audience = claims
228 .private_claims
229 .get("target_audience")
230 .ok_or(Error::NoTargetAudienceFound)?
231 .as_str()
232 .ok_or(Error::NoTargetAudienceFound)?;
233 let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
234 Ok(Self::check_response_status(response)
235 .await?
236 .json::<InternalIdToken>()
237 .await?
238 .to_token(audience)?)
239 }
240 false => {
241 let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
242 Ok(Self::check_response_status(response)
243 .await?
244 .json::<InternalToken>()
245 .await?
246 .to_token(iat))
247 }
248 }
249 }
250}