reqsign_google/
sign_request.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use http::header;
19use jsonwebtoken::{Algorithm, EncodingKey, Header as JwtHeader};
20use log::debug;
21use percent_encoding::{percent_decode_str, utf8_percent_encode};
22use rand::thread_rng;
23use rsa::pkcs1v15::SigningKey;
24use rsa::pkcs8::DecodePrivateKey;
25use rsa::signature::RandomizedSigner;
26use serde::{Deserialize, Serialize};
27use std::borrow::Cow;
28use std::time::Duration;
29
30use reqsign_core::{
31    Context, Result, SignRequest, SigningCredential, SigningMethod, SigningRequest,
32    hash::hex_sha256, time::*,
33};
34
35use crate::constants::{DEFAULT_SCOPE, GOOG_QUERY_ENCODE_SET, GOOG_URI_ENCODE_SET, GOOGLE_SCOPE};
36use crate::credential::{Credential, ServiceAccount, Token};
37
/// Claims is used to build JWT for Google Cloud.
///
/// Serialized (in field order) as the payload of the RS256-signed JWT that
/// `RequestSigner::exchange_token` sends to the OAuth2 token endpoint.
#[derive(Debug, Serialize)]
struct Claims {
    /// Issuer: the service account's client email.
    iss: String,
    /// OAuth2 scope requested for the access token.
    scope: String,
    /// Audience: the token endpoint URL the JWT is exchanged at.
    aud: String,
    /// Expiry time in seconds; `Claims::new` sets it to `iat + 3600`.
    exp: u64,
    /// Issued-at time in seconds, taken from `Timestamp::now()`.
    iat: u64,
}
47
48impl Claims {
49    fn new(client_email: &str, scope: &str) -> Self {
50        let current = Timestamp::now().as_second() as u64;
51
52        Claims {
53            iss: client_email.to_string(),
54            scope: scope.to_string(),
55            aud: "https://oauth2.googleapis.com/token".to_string(),
56            exp: current + 3600,
57            iat: current,
58        }
59    }
60}
61
/// OAuth2 token response.
///
/// Only the fields the signer consumes are modelled; any other fields in the
/// JSON body are ignored during deserialization.
#[derive(Deserialize)]
struct TokenResponse {
    /// Access token to send as `Authorization: Bearer <token>`.
    access_token: String,
    /// Token lifetime in seconds. When absent, the resulting `Token` gets no
    /// `expires_at`.
    #[serde(default)]
    expires_in: Option<u64>,
}
69
/// RequestSigner for Google service requests.
///
/// Signs either by attaching a bearer token to the `Authorization` header or,
/// when an expiry is requested, by building a GOOG4-RSA-SHA256 presigned query
/// string with a service account key.
#[derive(Debug)]
pub struct RequestSigner {
    /// Service name embedded in the GOOG4 credential scope.
    service: String,
    /// Region embedded in the GOOG4 credential scope; defaults to `"auto"`.
    region: String,
    /// OAuth2 scope for token exchange. When `None`, the `GOOGLE_SCOPE`
    /// environment variable is used, then `DEFAULT_SCOPE`.
    scope: Option<String>,
}
77
78impl Default for RequestSigner {
79    fn default() -> Self {
80        Self {
81            service: String::new(),
82            region: "auto".to_string(),
83            scope: None,
84        }
85    }
86}
87
88impl RequestSigner {
89    /// Create a new builder with the specified service.
90    pub fn new(service: impl Into<String>) -> Self {
91        Self {
92            service: service.into(),
93            region: "auto".to_string(),
94            scope: None,
95        }
96    }
97
98    /// Set the OAuth2 scope.
99    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
100        self.scope = Some(scope.into());
101        self
102    }
103
104    /// Set the region for the builder.
105    pub fn with_region(mut self, region: impl Into<String>) -> Self {
106        self.region = region.into();
107        self
108    }
109
110    /// Exchange a service account for an access token.
111    ///
112    /// This method is used internally when a token is needed but only a service account
113    /// is available. It creates a JWT and exchanges it for an OAuth2 access token.
114    async fn exchange_token(&self, ctx: &Context, sa: &ServiceAccount) -> Result<Token> {
115        let scope = self
116            .scope
117            .clone()
118            .or_else(|| ctx.env_var(GOOGLE_SCOPE))
119            .unwrap_or_else(|| DEFAULT_SCOPE.to_string());
120
121        debug!("exchanging service account for token with scope: {scope}");
122
123        // Create JWT
124        let jwt = jsonwebtoken::encode(
125            &JwtHeader::new(Algorithm::RS256),
126            &Claims::new(&sa.client_email, &scope),
127            &EncodingKey::from_rsa_pem(sa.private_key.as_bytes()).map_err(|e| {
128                reqsign_core::Error::unexpected("failed to parse RSA private key").with_source(e)
129            })?,
130        )
131        .map_err(|e| reqsign_core::Error::unexpected("failed to encode JWT").with_source(e))?;
132
133        // Exchange JWT for access token
134        let body =
135            format!("grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion={jwt}");
136        let req = http::Request::builder()
137            .method(http::Method::POST)
138            .uri("https://oauth2.googleapis.com/token")
139            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
140            .body(body.into_bytes().into())
141            .map_err(|e| {
142                reqsign_core::Error::unexpected("failed to build HTTP request").with_source(e)
143            })?;
144
145        let resp = ctx.http_send(req).await?;
146
147        if resp.status() != http::StatusCode::OK {
148            let body = String::from_utf8_lossy(resp.body());
149            return Err(reqsign_core::Error::unexpected(format!(
150                "exchange token failed: {body}"
151            )));
152        }
153
154        let token_resp: TokenResponse = serde_json::from_slice(resp.body()).map_err(|e| {
155            reqsign_core::Error::unexpected("failed to parse token response").with_source(e)
156        })?;
157
158        let expires_at = token_resp
159            .expires_in
160            .map(|expires_in| Timestamp::now() + Duration::from_secs(expires_in));
161
162        Ok(Token {
163            access_token: token_resp.access_token,
164            expires_at,
165        })
166    }
167
168    fn build_token_auth(
169        &self,
170        parts: &mut http::request::Parts,
171        token: &Token,
172    ) -> Result<SigningRequest> {
173        let mut req = SigningRequest::build(parts)?;
174
175        req.headers.insert(header::AUTHORIZATION, {
176            let mut value: http::HeaderValue = format!("Bearer {}", &token.access_token)
177                .parse()
178                .map_err(|e| {
179                    reqsign_core::Error::unexpected("failed to parse header value").with_source(e)
180                })?;
181            value.set_sensitive(true);
182            value
183        });
184
185        Ok(req)
186    }
187
188    fn build_signed_query(
189        &self,
190        _ctx: &Context,
191        parts: &mut http::request::Parts,
192        service_account: &ServiceAccount,
193        expires_in: Duration,
194    ) -> Result<SigningRequest> {
195        let mut req = SigningRequest::build(parts)?;
196        let now = Timestamp::now();
197
198        // Canonicalize headers
199        canonicalize_header(&mut req)?;
200
201        // Canonicalize query
202        canonicalize_query(
203            &mut req,
204            SigningMethod::Query(expires_in),
205            service_account,
206            now,
207            &self.service,
208            &self.region,
209        )?;
210
211        // Build canonical request string
212        let creq = canonical_request_string(&mut req)?;
213        let encoded_req = hex_sha256(creq.as_bytes());
214
215        // Build scope
216        let scope = format!(
217            "{}/{}/{}/goog4_request",
218            now.format_date(),
219            self.region,
220            self.service
221        );
222        debug!("calculated scope: {scope}");
223
224        // Build string to sign
225        let string_to_sign = {
226            let mut f = String::new();
227            f.push_str("GOOG4-RSA-SHA256");
228            f.push('\n');
229            f.push_str(&now.format_iso8601());
230            f.push('\n');
231            f.push_str(&scope);
232            f.push('\n');
233            f.push_str(&encoded_req);
234            f
235        };
236        debug!("calculated string to sign: {string_to_sign}");
237
238        // Sign the string
239        let mut rng = thread_rng();
240        let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(&service_account.private_key)
241            .map_err(|e| {
242                reqsign_core::Error::unexpected("failed to parse private key").with_source(e)
243            })?;
244        let signing_key = SigningKey::<sha2::Sha256>::new(private_key);
245        let signature = signing_key.sign_with_rng(&mut rng, string_to_sign.as_bytes());
246
247        req.query
248            .push(("X-Goog-Signature".to_string(), signature.to_string()));
249
250        Ok(req)
251    }
252}
253
254#[async_trait::async_trait]
255impl SignRequest for RequestSigner {
256    type Credential = Credential;
257
258    async fn sign_request(
259        &self,
260        ctx: &Context,
261        req: &mut http::request::Parts,
262        credential: Option<&Self::Credential>,
263        expires_in: Option<Duration>,
264    ) -> Result<()> {
265        let cred = credential
266            .ok_or_else(|| reqsign_core::Error::credential_invalid("missing credential"))?;
267
268        let signing_req = match expires_in {
269            // Query signing - must use ServiceAccount
270            Some(expires) => {
271                let sa = cred.service_account.as_ref().ok_or_else(|| {
272                    reqsign_core::Error::credential_invalid(
273                        "service account required for query signing",
274                    )
275                })?;
276                self.build_signed_query(ctx, req, sa, expires)?
277            }
278            // Header authentication - prefer valid token, otherwise exchange from SA
279            None => {
280                // Check if we have a valid token
281                if let Some(token) = &cred.token {
282                    if token.is_valid() {
283                        self.build_token_auth(req, token)?
284                    } else if let Some(sa) = &cred.service_account {
285                        // Token expired, but we have SA, exchange for new token
286                        debug!("token expired, exchanging service account for new token");
287                        let new_token = self.exchange_token(ctx, sa).await?;
288                        self.build_token_auth(req, &new_token)?
289                    } else {
290                        return Err(reqsign_core::Error::credential_invalid(
291                            "token expired and no service account available",
292                        ));
293                    }
294                } else if let Some(sa) = &cred.service_account {
295                    // No token but have SA, exchange for token
296                    debug!("no token available, exchanging service account for token");
297                    let token = self.exchange_token(ctx, sa).await?;
298                    self.build_token_auth(req, &token)?
299                } else {
300                    return Err(reqsign_core::Error::credential_invalid(
301                        "no valid credential available",
302                    ));
303                }
304            }
305        };
306
307        signing_req.apply(req).map_err(|e| {
308            reqsign_core::Error::unexpected("failed to apply signing request").with_source(e)
309        })
310    }
311}
312
313fn canonical_request_string(req: &mut SigningRequest) -> Result<String> {
314    // 256 is specially chosen to avoid reallocation for most requests.
315    let mut f = String::with_capacity(256);
316
317    // Insert method
318    f.push_str(req.method.as_str());
319    f.push('\n');
320
321    // Insert encoded path
322    let path = percent_decode_str(&req.path)
323        .decode_utf8()
324        .map_err(|e| reqsign_core::Error::unexpected("failed to decode path").with_source(e))?;
325    f.push_str(&Cow::from(utf8_percent_encode(&path, &GOOG_URI_ENCODE_SET)));
326    f.push('\n');
327
328    // Insert query
329    f.push_str(&SigningRequest::query_to_string(
330        req.query.clone(),
331        "=",
332        "&",
333    ));
334    f.push('\n');
335
336    // Insert signed headers
337    let signed_headers = req.header_name_to_vec_sorted();
338    for header in signed_headers.iter() {
339        let value = &req.headers[*header];
340        f.push_str(header);
341        f.push(':');
342        f.push_str(value.to_str().expect("header value must be valid"));
343        f.push('\n');
344    }
345    f.push('\n');
346    f.push_str(&signed_headers.join(";"));
347    f.push('\n');
348    f.push_str("UNSIGNED-PAYLOAD");
349
350    debug!("canonical request string: {f}");
351    Ok(f)
352}
353
354fn canonicalize_header(req: &mut SigningRequest) -> Result<()> {
355    for (_, value) in req.headers.iter_mut() {
356        SigningRequest::header_value_normalize(value)
357    }
358
359    // Insert HOST header if not present.
360    if req.headers.get(header::HOST).is_none() {
361        req.headers.insert(
362            header::HOST,
363            req.authority.as_str().parse().map_err(|e| {
364                reqsign_core::Error::unexpected("failed to parse host header").with_source(e)
365            })?,
366        );
367    }
368
369    Ok(())
370}
371
372fn canonicalize_query(
373    req: &mut SigningRequest,
374    method: SigningMethod,
375    cred: &ServiceAccount,
376    now: Timestamp,
377    service: &str,
378    region: &str,
379) -> Result<()> {
380    if let SigningMethod::Query(expire) = method {
381        req.query
382            .push(("X-Goog-Algorithm".into(), "GOOG4-RSA-SHA256".into()));
383        req.query.push((
384            "X-Goog-Credential".into(),
385            format!(
386                "{}/{}/{}/{}/goog4_request",
387                &cred.client_email,
388                now.format_date(),
389                region,
390                service
391            ),
392        ));
393        req.query.push(("X-Goog-Date".into(), now.format_iso8601()));
394        req.query
395            .push(("X-Goog-Expires".into(), expire.as_secs().to_string()));
396        req.query.push((
397            "X-Goog-SignedHeaders".into(),
398            req.header_name_to_vec_sorted().join(";"),
399        ));
400    }
401
402    // Return if query is empty.
403    if req.query.is_empty() {
404        return Ok(());
405    }
406
407    // Sort by param name
408    req.query.sort();
409
410    req.query = req
411        .query
412        .iter()
413        .map(|(k, v)| {
414            (
415                utf8_percent_encode(k, &GOOG_QUERY_ENCODE_SET).to_string(),
416                utf8_percent_encode(v, &GOOG_QUERY_ENCODE_SET).to_string(),
417            )
418        })
419        .collect();
420
421    Ok(())
422}