reqsign_google/
sign_request.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use http::header;
use jsonwebtoken::{Algorithm, EncodingKey, Header as JwtHeader};
use log::debug;
use percent_encoding::{percent_decode_str, utf8_percent_encode};
use rand::thread_rng;
use rsa::pkcs1v15::SigningKey;
use rsa::pkcs8::DecodePrivateKey;
use rsa::signature::RandomizedSigner;
use rsa::signature::SignatureEncoding;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::time::Duration;

use reqsign_core::{
    Context, Result, SignRequest, SigningCredential, SigningMethod, SigningRequest,
    hash::hex_sha256, time::*,
};

use crate::constants::{DEFAULT_SCOPE, GOOG_QUERY_ENCODE_SET, GOOG_URI_ENCODE_SET, GOOGLE_SCOPE};
use crate::credential::{Credential, ServiceAccount, Token};
37
38/// Claims is used to build JWT for Google Cloud.
39#[derive(Debug, Serialize)]
40struct Claims {
41    iss: String,
42    scope: String,
43    aud: String,
44    exp: u64,
45    iat: u64,
46}
47
48impl Claims {
49    fn new(client_email: &str, scope: &str) -> Self {
50        let current = Timestamp::now().as_second() as u64;
51
52        Claims {
53            iss: client_email.to_string(),
54            scope: scope.to_string(),
55            aud: "https://oauth2.googleapis.com/token".to_string(),
56            exp: current + 3600,
57            iat: current,
58        }
59    }
60}
61
62/// OAuth2 token response.
63#[derive(Deserialize)]
64struct TokenResponse {
65    access_token: String,
66    #[serde(default)]
67    expires_in: Option<u64>,
68}
69
/// Signs HTTP requests for Google Cloud services.
///
/// Header signing attaches an OAuth2 bearer token, and query signing produces a
/// `GOOG4-RSA-SHA256` presigned query string.
#[derive(Debug)]
pub struct RequestSigner {
    /// Service name embedded in the V4 credential scope.
    service: String,
    /// Region embedded in the V4 credential scope.
    region: String,
    /// Explicit OAuth2 scope for token exchange; `None` defers to env/default.
    scope: Option<String>,
}

impl Default for RequestSigner {
    /// An empty service name, region `"auto"`, and no explicit scope.
    fn default() -> Self {
        let region = String::from("auto");
        Self {
            service: String::default(),
            region,
            scope: None,
        }
    }
}
87
88impl RequestSigner {
89    /// Create a new builder with the specified service.
90    pub fn new(service: impl Into<String>) -> Self {
91        Self {
92            service: service.into(),
93            region: "auto".to_string(),
94            scope: None,
95        }
96    }
97
98    /// Set the OAuth2 scope.
99    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
100        self.scope = Some(scope.into());
101        self
102    }
103
104    /// Set the region for the builder.
105    pub fn with_region(mut self, region: impl Into<String>) -> Self {
106        self.region = region.into();
107        self
108    }
109
110    /// Exchange a service account for an access token.
111    ///
112    /// This method is used internally when a token is needed but only a service account
113    /// is available. It creates a JWT and exchanges it for an OAuth2 access token.
114    async fn exchange_token(&self, ctx: &Context, sa: &ServiceAccount) -> Result<Token> {
115        let scope = self
116            .scope
117            .clone()
118            .or_else(|| ctx.env_var(GOOGLE_SCOPE))
119            .unwrap_or_else(|| DEFAULT_SCOPE.to_string());
120
121        debug!("exchanging service account for token with scope: {scope}");
122
123        // Create JWT
124        let jwt = jsonwebtoken::encode(
125            &JwtHeader::new(Algorithm::RS256),
126            &Claims::new(&sa.client_email, &scope),
127            &EncodingKey::from_rsa_pem(sa.private_key.as_bytes()).map_err(|e| {
128                reqsign_core::Error::unexpected("failed to parse RSA private key").with_source(e)
129            })?,
130        )
131        .map_err(|e| reqsign_core::Error::unexpected("failed to encode JWT").with_source(e))?;
132
133        // Exchange JWT for access token
134        let body =
135            format!("grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion={jwt}");
136        let req = http::Request::builder()
137            .method(http::Method::POST)
138            .uri("https://oauth2.googleapis.com/token")
139            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
140            .body(body.into_bytes().into())
141            .map_err(|e| {
142                reqsign_core::Error::unexpected("failed to build HTTP request").with_source(e)
143            })?;
144
145        let resp = ctx.http_send(req).await?;
146
147        if resp.status() != http::StatusCode::OK {
148            let body = String::from_utf8_lossy(resp.body());
149            return Err(reqsign_core::Error::unexpected(format!(
150                "exchange token failed: {body}"
151            )));
152        }
153
154        let token_resp: TokenResponse = serde_json::from_slice(resp.body()).map_err(|e| {
155            reqsign_core::Error::unexpected("failed to parse token response").with_source(e)
156        })?;
157
158        let expires_at = token_resp
159            .expires_in
160            .map(|expires_in| Timestamp::now() + Duration::from_secs(expires_in));
161
162        Ok(Token {
163            access_token: token_resp.access_token,
164            expires_at,
165        })
166    }
167
168    fn build_token_auth(
169        &self,
170        parts: &mut http::request::Parts,
171        token: &Token,
172    ) -> Result<SigningRequest> {
173        let mut req = SigningRequest::build(parts)?;
174
175        req.headers.insert(header::AUTHORIZATION, {
176            let mut value: http::HeaderValue = format!("Bearer {}", &token.access_token)
177                .parse()
178                .map_err(|e| {
179                    reqsign_core::Error::unexpected("failed to parse header value").with_source(e)
180                })?;
181            value.set_sensitive(true);
182            value
183        });
184
185        Ok(req)
186    }
187
188    fn build_signed_query(
189        &self,
190        _ctx: &Context,
191        parts: &mut http::request::Parts,
192        service_account: &ServiceAccount,
193        expires_in: Duration,
194    ) -> Result<SigningRequest> {
195        let mut req = SigningRequest::build(parts)?;
196        let now = Timestamp::now();
197
198        // Canonicalize headers
199        canonicalize_header(&mut req)?;
200
201        // Canonicalize query
202        canonicalize_query(
203            &mut req,
204            SigningMethod::Query(expires_in),
205            service_account,
206            now,
207            &self.service,
208            &self.region,
209        )?;
210
211        // Build canonical request string
212        let creq = canonical_request_string(&mut req)?;
213        let encoded_req = hex_sha256(creq.as_bytes());
214
215        // Build scope
216        let scope = format!(
217            "{}/{}/{}/goog4_request",
218            now.format_date(),
219            self.region,
220            self.service
221        );
222        debug!("calculated scope: {scope}");
223
224        // Build string to sign
225        let string_to_sign = {
226            let mut f = String::new();
227            f.push_str("GOOG4-RSA-SHA256");
228            f.push('\n');
229            f.push_str(&now.format_iso8601());
230            f.push('\n');
231            f.push_str(&scope);
232            f.push('\n');
233            f.push_str(&encoded_req);
234            f
235        };
236        debug!("calculated string to sign: {string_to_sign}");
237
238        // Sign the string
239        let mut rng = thread_rng();
240        let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(&service_account.private_key)
241            .map_err(|e| {
242                reqsign_core::Error::unexpected("failed to parse private key").with_source(e)
243            })?;
244        let signing_key = SigningKey::<sha2::Sha256>::new(private_key);
245        let signature = signing_key.sign_with_rng(&mut rng, string_to_sign.as_bytes());
246
247        req.query
248            .push(("X-Goog-Signature".to_string(), signature.to_string()));
249
250        Ok(req)
251    }
252}
253
254#[async_trait::async_trait]
255impl SignRequest for RequestSigner {
256    type Credential = Credential;
257
258    async fn sign_request(
259        &self,
260        ctx: &Context,
261        req: &mut http::request::Parts,
262        credential: Option<&Self::Credential>,
263        expires_in: Option<Duration>,
264    ) -> Result<()> {
265        let Some(cred) = credential else {
266            return Ok(());
267        };
268
269        let signing_req = match expires_in {
270            // Query signing - must use ServiceAccount
271            Some(expires) => {
272                let sa = cred.service_account.as_ref().ok_or_else(|| {
273                    reqsign_core::Error::credential_invalid(
274                        "service account required for query signing",
275                    )
276                })?;
277                self.build_signed_query(ctx, req, sa, expires)?
278            }
279            // Header authentication - prefer valid token, otherwise exchange from SA
280            None => {
281                // Check if we have a valid token
282                if let Some(token) = &cred.token {
283                    if token.is_valid() {
284                        self.build_token_auth(req, token)?
285                    } else if let Some(sa) = &cred.service_account {
286                        // Token expired, but we have SA, exchange for new token
287                        debug!("token expired, exchanging service account for new token");
288                        let new_token = self.exchange_token(ctx, sa).await?;
289                        self.build_token_auth(req, &new_token)?
290                    } else {
291                        return Err(reqsign_core::Error::credential_invalid(
292                            "token expired and no service account available",
293                        ));
294                    }
295                } else if let Some(sa) = &cred.service_account {
296                    // No token but have SA, exchange for token
297                    debug!("no token available, exchanging service account for token");
298                    let token = self.exchange_token(ctx, sa).await?;
299                    self.build_token_auth(req, &token)?
300                } else {
301                    return Err(reqsign_core::Error::credential_invalid(
302                        "no valid credential available",
303                    ));
304                }
305            }
306        };
307
308        signing_req.apply(req).map_err(|e| {
309            reqsign_core::Error::unexpected("failed to apply signing request").with_source(e)
310        })
311    }
312}
313
314fn canonical_request_string(req: &mut SigningRequest) -> Result<String> {
315    // 256 is specially chosen to avoid reallocation for most requests.
316    let mut f = String::with_capacity(256);
317
318    // Insert method
319    f.push_str(req.method.as_str());
320    f.push('\n');
321
322    // Insert encoded path
323    let path = percent_decode_str(&req.path)
324        .decode_utf8()
325        .map_err(|e| reqsign_core::Error::unexpected("failed to decode path").with_source(e))?;
326    f.push_str(&Cow::from(utf8_percent_encode(&path, &GOOG_URI_ENCODE_SET)));
327    f.push('\n');
328
329    // Insert query
330    f.push_str(&SigningRequest::query_to_string(
331        req.query.clone(),
332        "=",
333        "&",
334    ));
335    f.push('\n');
336
337    // Insert signed headers
338    let signed_headers = req.header_name_to_vec_sorted();
339    for header in signed_headers.iter() {
340        let value = &req.headers[*header];
341        f.push_str(header);
342        f.push(':');
343        f.push_str(value.to_str().expect("header value must be valid"));
344        f.push('\n');
345    }
346    f.push('\n');
347    f.push_str(&signed_headers.join(";"));
348    f.push('\n');
349    f.push_str("UNSIGNED-PAYLOAD");
350
351    debug!("canonical request string: {f}");
352    Ok(f)
353}
354
355fn canonicalize_header(req: &mut SigningRequest) -> Result<()> {
356    for (_, value) in req.headers.iter_mut() {
357        SigningRequest::header_value_normalize(value)
358    }
359
360    // Insert HOST header if not present.
361    if req.headers.get(header::HOST).is_none() {
362        req.headers.insert(
363            header::HOST,
364            req.authority.as_str().parse().map_err(|e| {
365                reqsign_core::Error::unexpected("failed to parse host header").with_source(e)
366            })?,
367        );
368    }
369
370    Ok(())
371}
372
373fn canonicalize_query(
374    req: &mut SigningRequest,
375    method: SigningMethod,
376    cred: &ServiceAccount,
377    now: Timestamp,
378    service: &str,
379    region: &str,
380) -> Result<()> {
381    if let SigningMethod::Query(expire) = method {
382        req.query
383            .push(("X-Goog-Algorithm".into(), "GOOG4-RSA-SHA256".into()));
384        req.query.push((
385            "X-Goog-Credential".into(),
386            format!(
387                "{}/{}/{}/{}/goog4_request",
388                &cred.client_email,
389                now.format_date(),
390                region,
391                service
392            ),
393        ));
394        req.query.push(("X-Goog-Date".into(), now.format_iso8601()));
395        req.query
396            .push(("X-Goog-Expires".into(), expire.as_secs().to_string()));
397        req.query.push((
398            "X-Goog-SignedHeaders".into(),
399            req.header_name_to_vec_sorted().join(";"),
400        ));
401    }
402
403    // Return if query is empty.
404    if req.query.is_empty() {
405        return Ok(());
406    }
407
408    // Sort by param name
409    req.query.sort();
410
411    req.query = req
412        .query
413        .iter()
414        .map(|(k, v)| {
415            (
416                utf8_percent_encode(k, &GOOG_QUERY_ENCODE_SET).to_string(),
417                utf8_percent_encode(v, &GOOG_QUERY_ENCODE_SET).to_string(),
418            )
419        })
420        .collect();
421
422    Ok(())
423}