// gcloud_auth/token_source/service_account_token_source.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3
4use async_trait::async_trait;
5use reqwest::Response;
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
9use crate::credentials;
10use crate::error::{Error, TokenErrorResponse};
11use crate::misc::UnwrapOrEmpty;
12use crate::token::{Token, TOKEN_URL};
13use crate::token_source::{default_http_client, InternalIdToken, InternalToken, TokenSource};
14
15#[derive(Clone, Serialize)]
16struct Claims<'a> {
17    iss: &'a str,
18    sub: Option<&'a str>,
19    scope: Option<&'a str>,
20    aud: &'a str,
21    exp: i64,
22    iat: i64,
23    #[serde(flatten)]
24    private_claims: &'a HashMap<String, serde_json::Value>,
25}
26
27impl Claims<'_> {
28    fn token(&self, pk: &jsonwebtoken::EncodingKey, pk_id: &str) -> Result<String, Error> {
29        let mut header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
30        header.kid = Some(pk_id.to_string());
31        let v = jsonwebtoken::encode(&header, self, pk)?;
32        Ok(v)
33    }
34}
35
36// Does not use any OAuth2 flow but instead creates a JWT and sends that as the access token.
37// The audience is typically a URL that specifies the scope of the credentials.
38// see golang.org/x/oauth2/gen/jwt.go
39#[allow(dead_code)]
40pub struct ServiceAccountTokenSource {
41    email: String,
42    pk: jsonwebtoken::EncodingKey,
43    pk_id: String,
44    audience: String,
45    use_id_token: bool,
46    private_claims: HashMap<String, serde_json::Value>,
47}
48
49impl Debug for ServiceAccountTokenSource {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        // jwt::EncodingKey does not implement Debug
52        f.debug_struct("ServiceAccountTokenSource")
53            .field("email", &self.email)
54            .field("pk_id", &self.pk_id)
55            .field("audience", &self.audience)
56            .finish()
57    }
58}
59
60impl ServiceAccountTokenSource {
61    pub(crate) fn new(cred: &credentials::CredentialsFile, audience: &str) -> Result<ServiceAccountTokenSource, Error> {
62        Ok(ServiceAccountTokenSource {
63            email: cred.client_email.unwrap_or_empty(),
64            pk: cred.try_to_private_key()?,
65            pk_id: cred.private_key_id.unwrap_or_empty(),
66            audience: match &cred.audience {
67                None => audience.to_string(),
68                Some(s) => s.to_string(),
69            },
70            use_id_token: false,
71            private_claims: HashMap::new(),
72        })
73    }
74}
75
76#[async_trait]
77impl TokenSource for ServiceAccountTokenSource {
78    async fn token(&self) -> Result<Token, Error> {
79        let iat = OffsetDateTime::now_utc();
80        let exp = iat + time::Duration::hours(1);
81
82        let token = Claims {
83            iss: self.email.as_ref(),
84            sub: Some(self.email.as_ref()),
85            scope: None,
86            aud: self.audience.as_ref(),
87            exp: exp.unix_timestamp(),
88            iat: iat.unix_timestamp(),
89            private_claims: &HashMap::new(),
90        }
91        .token(&self.pk, &self.pk_id)?;
92
93        return Ok(Token {
94            access_token: token,
95            token_type: "Bearer".to_string(),
96            expiry: Some(exp),
97        });
98    }
99}
100
101#[allow(dead_code)]
102#[derive(Clone, Deserialize)]
103struct OAuth2Token {
104    pub access_token: String,
105    pub token_type: String,
106    pub id_token: Option<String>,
107    pub expires_in: Option<i64>,
108}
109
110//jwt implements the OAuth 2.0 JSON Web Token flow
111pub struct OAuth2ServiceAccountTokenSource {
112    pub email: String,
113    pub pk: jsonwebtoken::EncodingKey,
114    pub pk_id: String,
115    pub scopes: String,
116    pub token_url: String,
117    pub sub: Option<String>,
118
119    pub client: reqwest::Client,
120
121    use_id_token: bool,
122    private_claims: HashMap<String, serde_json::Value>,
123}
124
125impl Debug for OAuth2ServiceAccountTokenSource {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        // jwt::EncodingKey does not implement Debug
128        f.debug_struct("OAuth2ServiceAccountTokenSource")
129            .field("email", &self.email)
130            .field("pk_id", &self.pk_id)
131            .field("scopes", &self.scopes)
132            .field("token_url", &self.token_url)
133            .field("sub", &self.sub)
134            .field("client", &self.client)
135            .field("use_id_token", &self.use_id_token)
136            .field("private_claims", &self.private_claims)
137            .finish()
138    }
139}
140
141impl OAuth2ServiceAccountTokenSource {
142    pub(crate) fn new(
143        cred: &credentials::CredentialsFile,
144        scopes: &str,
145        sub: Option<&str>,
146    ) -> Result<OAuth2ServiceAccountTokenSource, Error> {
147        Ok(OAuth2ServiceAccountTokenSource {
148            email: cred.client_email.unwrap_or_empty(),
149            pk: cred.try_to_private_key()?,
150            pk_id: cred.private_key_id.unwrap_or_empty(),
151            scopes: scopes.to_string(),
152            token_url: match &cred.token_uri {
153                None => TOKEN_URL.to_string(),
154                Some(s) => s.to_string(),
155            },
156            client: default_http_client(),
157            sub: sub.map(|s| s.to_string()),
158            use_id_token: false,
159            private_claims: HashMap::new(),
160        })
161    }
162
163    pub(crate) fn with_use_id_token(mut self) -> Self {
164        self.use_id_token = true;
165        self
166    }
167
168    pub(crate) fn with_private_claims(mut self, claims: HashMap<String, serde_json::Value>) -> Self {
169        self.private_claims = claims;
170        self
171    }
172
173    /// Checks whether an HTTP response is successful and returns it, or returns an error.
174    async fn check_response_status(response: Response) -> Result<Response, Error> {
175        // Check the status code, returning the response if it is not an error.
176        let error = match response.error_for_status_ref() {
177            Ok(_) => return Ok(response),
178            Err(error) => error,
179        };
180
181        // try to extract a response error, falling back to the status error if it can not be parsed.
182        let status = response.status();
183        Err(response
184            .json::<TokenErrorResponse>()
185            .await
186            .map(|response| Error::TokenErrorResponse {
187                status: status.as_u16(),
188                error: response.error,
189                error_description: response.error_description,
190            })
191            .unwrap_or(Error::HttpError(error)))
192    }
193}
194
195#[async_trait]
196impl TokenSource for OAuth2ServiceAccountTokenSource {
197    async fn token(&self) -> Result<Token, Error> {
198        let iat = OffsetDateTime::now_utc();
199        let exp = iat + time::Duration::hours(1);
200
201        let claims = Claims {
202            iss: self.email.as_ref(),
203            sub: self.sub.as_ref().map(|s| s.as_ref()),
204            scope: Some(self.scopes.as_ref()),
205            aud: self.token_url.as_ref(),
206            exp: exp.unix_timestamp(),
207            iat: iat.unix_timestamp(),
208            private_claims: &self.private_claims,
209        };
210        let request_token = claims.token(&self.pk, &self.pk_id)?;
211
212        let form = [
213            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
214            ("assertion", request_token.as_str()),
215        ];
216
217        match self.use_id_token {
218            true => {
219                let audience = claims
220                    .private_claims
221                    .get("target_audience")
222                    .ok_or(Error::NoTargetAudienceFound)?
223                    .as_str()
224                    .ok_or(Error::NoTargetAudienceFound)?;
225                let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
226                Ok(Self::check_response_status(response)
227                    .await?
228                    .json::<InternalIdToken>()
229                    .await?
230                    .to_token(audience)?)
231            }
232            false => {
233                let response = self.client.post(self.token_url.as_str()).form(&form).send().await?;
234                Ok(Self::check_response_status(response)
235                    .await?
236                    .json::<InternalToken>()
237                    .await?
238                    .to_token(iat))
239            }
240        }
241    }
242}