// gcloud_auth_patch/token_source/service_account_token_source.rs

use std::collections::HashMap;
use std::fmt::Debug;

use async_trait::async_trait;
use reqwest::Response;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;

use crate::credentials;
use crate::error::{Error, TokenErrorResponse};
use crate::misc::UnwrapOrEmpty;
use crate::token::{Token, TOKEN_URL};
use crate::token_source::{default_http_client, InternalIdToken, InternalToken, TokenSource};

15#[derive(Clone, Serialize)]
16struct Claims<'a> {
17    iss: &'a str,
18    sub: Option<&'a str>,
19    scope: Option<&'a str>,
20    aud: &'a str,
21    exp: i64,
22    iat: i64,
23    #[serde(flatten)]
24    private_claims: &'a HashMap<String, serde_json::Value>,
25}
26
27impl Claims<'_> {
28    fn token(&self, pk: &jsonwebtoken::EncodingKey, pk_id: &str) -> Result<String, Error> {
29        let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
30        header.kid = Some(pk_id.to_string());
31        let v = jsonwebtoken::encode(&header, self, pk)?;
32        Ok(v)
33    }
34}
35
36// Does not use any OAuth2 flow but instead creates a JWT and sends that as the access token.
37// The audience is typically a URL that specifies the scope of the credentials.
38// see golang.org/x/oauth2/gen/jwt.go
39#[allow(dead_code)]
40pub struct ServiceAccountTokenSource {
41    email: String,
42    pk: jsonwebtoken::EncodingKey,
43    pk_id: String,
44    audience: String,
45    use_id_token: bool,
46    private_claims: HashMap<String, serde_json::Value>,
47}
48
49impl Debug for ServiceAccountTokenSource {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        // jwt::EncodingKey does not implement Debug
52        f.debug_struct("ServiceAccountTokenSource")
53            .field("email", &self.email)
54            .field("pk_id", &self.pk_id)
55            .field("audience", &self.audience)
56            .finish()
57    }
58}
59
60impl ServiceAccountTokenSource {
61    pub(crate) fn new(cred: &credentials::CredentialsFile, audience: &str) -> Result<ServiceAccountTokenSource, Error> {
62        Ok(ServiceAccountTokenSource {
63            email: cred.client_email.unwrap_or_empty(),
64            pk: cred.try_to_private_key()?,
65            pk_id: cred.private_key_id.unwrap_or_empty(),
66            audience: match &cred.audience {
67                None => audience.to_string(),
68                Some(s) => s.to_string(),
69            },
70            use_id_token: false,
71            private_claims: HashMap::new(),
72        })
73    }
74}
75
76#[async_trait]
77impl TokenSource for ServiceAccountTokenSource {
78    async fn token(&self) -> Result<Token, Error> {
79        let iat = OffsetDateTime::now_utc();
80        let exp = iat + time::Duration::hours(1);
81
82        let token = Claims {
83            iss: self.email.as_ref(),
84            sub: Some(self.email.as_ref()),
85            scope: None,
86            aud: self.audience.as_ref(),
87            exp: exp.unix_timestamp(),
88            iat: iat.unix_timestamp(),
89            private_claims: &HashMap::new(),
90        }
91        .token(&self.pk, &self.pk_id)?;
92
93        return Ok(Token {
94            access_token: token,
95            token_type: "Bearer".to_string(),
96            expiry: Some(exp),
97        });
98    }
99}
100
101#[allow(dead_code)]
102#[derive(Clone, Deserialize)]
103struct OAuth2Token {
104    pub access_token: String,
105    pub token_type: String,
106    pub id_token: Option<String>,
107    pub expires_in: Option<i64>,
108}
109
110//jwt implements the OAuth 2.0 JSON Web Token flow
111pub struct OAuth2ServiceAccountTokenSource {
112    pub email: String,
113    pub pk: jsonwebtoken::EncodingKey,
114    pub pk_id: String,
115    pub scopes: String,
116    pub token_url: String,
117    pub sub: Option<String>,
118
119    pub client: reqwest::Client,
120
121    use_id_token: bool,
122    private_claims: HashMap<String, serde_json::Value>,
123}
124
125impl Debug for OAuth2ServiceAccountTokenSource {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        // jwt::EncodingKey does not implement Debug
128        f.debug_struct("OAuth2ServiceAccountTokenSource")
129            .field("email", &self.email)
130            .field("pk_id", &self.pk_id)
131            .field("scopes", &self.scopes)
132            .field("token_url", &self.token_url)
133            .field("sub", &self.sub)
134            .field("client", &self.client)
135            .field("use_id_token", &self.use_id_token)
136            .field("private_claims", &self.private_claims)
137            .finish()
138    }
139}
140
141impl OAuth2ServiceAccountTokenSource {
142    pub(crate) fn new(
143        cred: &credentials::CredentialsFile,
144        scopes: &str,
145        sub: Option<&str>,
146    ) -> Result<OAuth2ServiceAccountTokenSource, Error> {
147        Ok(OAuth2ServiceAccountTokenSource {
148            email: cred.client_email.unwrap_or_empty(),
149            pk: cred.try_to_private_key()?,
150            pk_id: cred.private_key_id.unwrap_or_empty(),
151            scopes: scopes.to_string(),
152            token_url: match &cred.token_uri {
153                None => TOKEN_URL.to_string(),
154                Some(s) => s.to_string(),
155            },
156            client: default_http_client(),
157            sub: sub.map(|s| s.to_string()),
158            use_id_token: false,
159            private_claims: HashMap::new(),
160        })
161    }
162
163    pub(crate) fn with_use_id_token(mut self) -> Self {
164        self.use_id_token = true;
165        self
166    }
167
168    pub(crate) fn with_private_claims(mut self, claims: HashMap<String, serde_json::Value>) -> Self {
169        self.private_claims = claims;
170        self
171    }
172
173    /// Checks whether an HTTP response is successful and returns it, or returns an error.
174    async fn check_response_status(response: Response) -> Result<Response, Error> {
175        // Check the status code, returning the response if it is not an error.
176        let error = match response.error_for_status_ref() {
177            Ok(_) => return Ok(response),
178            Err(error) => error,
179        };
180
181        // try to extract a response error, falling back to the status error if it can not be parsed.
182        let status = response.status();
183
184        // Read the body as text first so we can both attempt JSON parse and log the raw body if parsing fails.
185        let body_text = response.text().await.unwrap_or_else(|_| "".to_string());
186
187        match serde_json::from_str::<TokenErrorResponse>(&body_text) {
188            Ok(err_response) => Err(Error::TokenErrorResponse {
189                status: status.as_u16(),
190                error: err_response.error,
191                error_description: err_response.error_description,
192            }),
193            Err(_) => {
194                // Parsing the body as TokenErrorResponse failed; log the raw body to stdout for debugging
195                // and return the original HTTP error.
196                println!("HTTP error response body (status {}): {}", status.as_u16(), body_text);
197                Err(Error::HttpError(error))
198            }
199        }
200    }
201}
202
203#[async_trait]
204impl TokenSource for OAuth2ServiceAccountTokenSource {
205    async fn token(&self) -> Result<Token, Error> {
206        let iat = OffsetDateTime::now_utc();
207        let exp = iat + time::Duration::hours(1);
208
209        let claims = Claims {
210            iss: self.email.as_ref(),
211            sub: self.sub.as_ref().map(|s| s.as_ref()),
212            scope: Some(self.scopes.as_ref()),
213            aud: self.token_url.as_ref(),
214            exp: exp.unix_timestamp(),
215            iat: iat.unix_timestamp(),
216            private_claims: &self.private_claims,
217        };
218        let request_token = claims.token(&self.pk, &self.pk_id)?;
219
220        let form = [
221            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
222            ("assertion", request_token.as_str()),
223        ];
224
225        match self.use_id_token {
226            true => {
227                let audience = claims
228                    .private_claims
229                    .get("target_audience")
230                    .ok_or(Error::NoTargetAudienceFound)?
231                    .as_str()
232                    .ok_or(Error::NoTargetAudienceFound)?;
233                let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
234                Ok(Self::check_response_status(response)
235                    .await?
236                    .json::<InternalIdToken>()
237                    .await?
238                    .to_token(audience)?)
239            }
240            false => {
241                let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
242                Ok(Self::check_response_status(response)
243                    .await?
244                    .json::<InternalToken>()
245                    .await?
246                    .to_token(iat))
247            }
248        }
249    }
250}