Skip to main content

claude_api/
vertex.rs

1//! Google Cloud Vertex AI support: a [`RequestSigner`] that attaches an
2//! `OAuth2` bearer token so requests are authenticated against the Vertex AI
3//! Anthropic endpoint.
4//!
5//! The URL shape for Vertex AI is:
6//! ```text
7//! https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/anthropic/models/{model}:rawPredict
8//! ```
9//!
10//! Auth is a standard `Authorization: Bearer {token}` header where the token
11//! is a Google `OAuth2` access token.
12//!
13//! # Credential sources
14//!
15//! Two credential sources are supported:
16//!
17//! - **Static token** (`VertexCredentials::from_token`): supply a token
18//!   string directly. Useful for tests and short-lived scripts where you
19//!   already have a token (e.g., from `gcloud auth print-access-token`).
20//!
21//! - **Application Default Credentials** (`VertexCredentials::from_adc`):
22//!   uses [`gcp_auth`] to obtain a token via Application Default Credentials
23//!   (service-account key file, GCE metadata server, `gcloud` CLI, or the
//!   `GOOGLE_APPLICATION_CREDENTIALS` environment variable). This path
//!   requires an active Tokio runtime: token refresh is async but `sign()`
//!   is synchronous, so `sign()` blocks the calling thread until the token
//!   is available.
27//!
//! [`VertexCredentials::from_env`] checks `VERTEX_ACCESS_TOKEN` first (static
//! token), then falls back to `GOOGLE_APPLICATION_CREDENTIALS` (ADC), and
//! returns `None` if neither yields credentials.
30//!
31//! # Set up the client
32//!
33//! ```no_run
34//! use std::sync::Arc;
35//! use claude_api::{Client, vertex::{VertexCredentials, VertexSigner}};
36//! # fn run() -> Result<(), claude_api::Error> {
37//! let creds = VertexCredentials::from_env()
38//!     .expect("VERTEX_ACCESS_TOKEN or GOOGLE_APPLICATION_CREDENTIALS must be set");
39//! let region = std::env::var("VERTEX_REGION").unwrap_or_else(|_| "us-east5".into());
40//! let project = std::env::var("VERTEX_PROJECT").expect("VERTEX_PROJECT must be set");
41//! let client = Client::builder()
42//!     .signer(Arc::new(VertexSigner::new(creds)))
43//!     .base_url(format!(
44//!         "https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/anthropic"
45//!     ))
46//!     .build()?;
47//! # Ok(())
48//! # }
49//! ```
50//!
51//! Gated on the `vertex` feature.
52
53#![cfg(feature = "vertex")]
54#![cfg_attr(docsrs, doc(cfg(feature = "vertex")))]
55
56use std::sync::Arc;
57
58use gcp_auth::TokenProvider;
59
60use crate::auth::{RequestSigner, SignerResult};
61
/// `OAuth2` scope requested when fetching a token for Vertex AI calls.
const VERTEX_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";

/// Scope list handed to the `gcp_auth` token provider (just [`VERTEX_SCOPE`]).
const VERTEX_SCOPES: &[&str] = &[VERTEX_SCOPE];
67
68/// Credential source for Vertex AI authentication.
69///
70/// Carries either a static bearer token (no async refresh) or an ADC-backed
71/// [`TokenProvider`] that fetches and caches tokens via `gcp_auth`.
72///
73/// Construct with [`VertexCredentials::from_token`],
74/// [`VertexCredentials::from_adc`], or [`VertexCredentials::from_env`].
75#[derive(Clone)]
76pub struct VertexCredentials {
77    inner: CredentialInner,
78}
79
80#[derive(Clone)]
81enum CredentialInner {
82    /// A static bearer token -- no async refresh.
83    Static(String),
84    /// An ADC-backed token provider from `gcp_auth`.
85    Adc(Arc<dyn TokenProvider>),
86}
87
88impl VertexCredentials {
89    /// Use a pre-obtained `OAuth2` bearer token directly.
90    ///
91    /// The token is used verbatim; no refresh is performed. Suitable for
92    /// short-lived scripts or tests where you already have a token (e.g.,
93    /// `gcloud auth print-access-token`).
94    #[must_use]
95    pub fn from_token(token: impl Into<String>) -> Self {
96        Self {
97            inner: CredentialInner::Static(token.into()),
98        }
99    }
100
101    /// Use Application Default Credentials via [`gcp_auth`].
102    ///
103    /// Tries, in order:
104    /// 1. `GOOGLE_APPLICATION_CREDENTIALS` env var (service-account key file)
105    /// 2. `~/.config/gcloud/application_default_credentials.json`
106    /// 3. GCE instance-metadata server
107    /// 4. `gcloud auth print-access-token`
108    ///
109    /// Tokens are cached and refreshed automatically by `gcp_auth`. This
110    /// constructor is async because provider discovery may involve network
111    /// I/O (metadata server probe).
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if no credential source is found or if the initial
116    /// provider discovery fails.
117    pub async fn from_adc() -> Result<Self, gcp_auth::Error> {
118        let provider = gcp_auth::provider().await?;
119        Ok(Self {
120            inner: CredentialInner::Adc(provider),
121        })
122    }
123
124    /// Read credentials from environment variables.
125    ///
126    /// Checks, in order:
127    /// 1. `VERTEX_ACCESS_TOKEN` -- if set, constructs a
128    ///    [`from_token`](Self::from_token) credential.
129    /// 2. `GOOGLE_APPLICATION_CREDENTIALS` -- if set, returns an
130    ///    [`from_adc`](Self::from_adc) credential.
131    ///
132    /// Returns `None` when neither variable is set.
133    ///
134    /// # Panics
135    ///
136    /// This method calls `from_adc()` synchronously by blocking on the
137    /// current Tokio runtime when `GOOGLE_APPLICATION_CREDENTIALS` is set.
138    /// It will panic if called outside a Tokio runtime context.
139    #[must_use]
140    pub fn from_env() -> Option<Self> {
141        // Prefer a static token if supplied.
142        if let Ok(token) = std::env::var("VERTEX_ACCESS_TOKEN") {
143            return Some(Self::from_token(token));
144        }
145
146        // Fall back to ADC when GOOGLE_APPLICATION_CREDENTIALS is set.
147        if std::env::var_os("GOOGLE_APPLICATION_CREDENTIALS").is_some() {
148            let handle = tokio::runtime::Handle::current();
149            return match handle.block_on(gcp_auth::provider()) {
150                Ok(provider) => Some(Self {
151                    inner: CredentialInner::Adc(provider),
152                }),
153                Err(_) => None,
154            };
155        }
156
157        None
158    }
159
160    /// Resolve the current bearer token.
161    ///
162    /// For static credentials this is infallible. For ADC credentials,
163    /// requires a Tokio runtime handle (will block the current thread).
164    fn resolve_token(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
165        match &self.inner {
166            CredentialInner::Static(t) => Ok(t.clone()),
167            CredentialInner::Adc(provider) => {
168                let handle = tokio::runtime::Handle::current();
169                let token = handle
170                    .block_on(provider.token(VERTEX_SCOPES))
171                    .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
172                Ok(token.as_str().to_owned())
173            }
174        }
175    }
176}
177
178impl std::fmt::Debug for VertexCredentials {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match &self.inner {
181            CredentialInner::Static(t) => f
182                .debug_struct("VertexCredentials")
183                .field("kind", &"static-token")
184                .field("token", &format!("<redacted, {} chars>", t.len()))
185                .finish(),
186            CredentialInner::Adc(_) => f
187                .debug_struct("VertexCredentials")
188                .field("kind", &"adc")
189                .finish(),
190        }
191    }
192}
193
194/// Vertex AI bearer-token signer.
195///
196/// Attaches `Authorization: Bearer {token}` to every outbound request and
197/// removes the `x-api-key` header (Vertex AI does not use it).
198///
199/// Install on a [`Client`](crate::Client) via
200/// [`ClientBuilder::signer`](crate::ClientBuilder::signer).
201#[derive(Debug, Clone)]
202pub struct VertexSigner {
203    credentials: VertexCredentials,
204}
205
206impl VertexSigner {
207    /// Build a signer from `credentials`.
208    #[must_use]
209    pub fn new(credentials: VertexCredentials) -> Self {
210        Self { credentials }
211    }
212}
213
214impl RequestSigner for VertexSigner {
215    fn sign(&self, request: &mut reqwest::Request) -> SignerResult {
216        // Remove the Anthropic API-key header -- Vertex does not use it.
217        request.headers_mut().remove("x-api-key");
218
219        let token = self.credentials.resolve_token()?;
220        let bearer = format!("Bearer {token}");
221        request
222            .headers_mut()
223            .insert("authorization", bearer.parse()?);
224
225        Ok(())
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Vertex AI `rawPredict` endpoint used by the request fixtures.
    const RAW_PREDICT_URL: &str = "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-6:rawPredict";

    /// Minimal Messages API body; signing does not inspect it.
    const REQUEST_BODY: &str = r#"{"messages":[{"role":"user","content":"hi"}]}"#;

    /// Build a POST to [`RAW_PREDICT_URL`], optionally carrying `x-api-key`.
    fn make_request(api_key: Option<&str>) -> reqwest::Request {
        let mut builder = reqwest::Client::new()
            .post(RAW_PREDICT_URL)
            .body(REQUEST_BODY);
        if let Some(key) = api_key {
            builder = builder.header("x-api-key", key);
        }
        builder.build().unwrap()
    }

    fn make_request_with_api_key() -> reqwest::Request {
        make_request(Some("sk-ant-test-key"))
    }

    fn make_request_without_api_key() -> reqwest::Request {
        make_request(None)
    }

    fn static_signer(token: &str) -> VertexSigner {
        VertexSigner::new(VertexCredentials::from_token(token))
    }

    #[test]
    fn sign_adds_authorization_bearer_header() {
        let signer = static_signer("ya29.test-token");
        let mut req = make_request_without_api_key();
        signer.sign(&mut req).expect("sign succeeds");

        let auth = req
            .headers()
            .get("authorization")
            .expect("authorization header set by signer");
        let auth_str = auth.to_str().expect("authorization is ASCII");
        assert_eq!(
            auth_str, "Bearer ya29.test-token",
            "expected bearer prefix: {auth_str}"
        );
    }

    #[test]
    fn sign_removes_x_api_key_header() {
        let signer = static_signer("ya29.test-token");
        let mut req = make_request_with_api_key();

        // Verify the header is present before signing.
        assert!(
            req.headers().get("x-api-key").is_some(),
            "x-api-key must be present before sign()"
        );

        signer.sign(&mut req).expect("sign succeeds");

        assert!(
            req.headers().get("x-api-key").is_none(),
            "x-api-key must be removed after sign()"
        );
    }

    #[test]
    fn sign_sets_correct_bearer_format() {
        let token = "ya29.c.long-token-value-here";
        let signer = static_signer(token);
        let mut req = make_request_without_api_key();
        signer.sign(&mut req).expect("sign succeeds");

        let auth = req
            .headers()
            .get("authorization")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(auth.starts_with("Bearer "), "must start with 'Bearer '");
        assert!(auth.contains(token), "must contain the token");
    }

    #[test]
    fn credentials_redact_token_in_debug() {
        let creds = VertexCredentials::from_token("ya29.very-secret-token");
        let dbg = format!("{creds:?}");
        assert!(!dbg.contains("very-secret-token"), "{dbg}");
        assert!(dbg.contains("redacted"), "{dbg}");
    }

    #[test]
    fn credentials_debug_shows_adc_kind_without_token() {
        // Build an ADC variant manually by wrapping a fake provider.
        struct FakeProvider;

        #[async_trait::async_trait]
        impl TokenProvider for FakeProvider {
            async fn token(
                &self,
                _scopes: &[&str],
            ) -> Result<Arc<gcp_auth::Token>, gcp_auth::Error> {
                unimplemented!()
            }

            async fn project_id(&self) -> Result<Arc<str>, gcp_auth::Error> {
                unimplemented!()
            }
        }

        let creds = VertexCredentials {
            inner: CredentialInner::Adc(Arc::new(FakeProvider)),
        };
        let dbg = format!("{creds:?}");
        assert!(dbg.contains("adc"), "{dbg}");
    }

    #[test]
    fn from_env_returns_none_when_no_vars_set() {
        // std::env is process-global and cannot be mutated here, so this test
        // only runs its assertion when neither variable is already set by the
        // outer environment; otherwise it returns early.
        let has_token = std::env::var("VERTEX_ACCESS_TOKEN").is_ok();
        let has_adc = std::env::var_os("GOOGLE_APPLICATION_CREDENTIALS").is_some();
        if has_token || has_adc {
            // Environment is pre-configured; skip.
            return;
        }
        // Neither var is set, so the ADC branch (which needs a Tokio runtime)
        // is never reached and calling from_env() here cannot panic.
        let result = VertexCredentials::from_env();
        assert!(
            result.is_none(),
            "expected None when env vars are absent: {result:?}"
        );
    }

    #[test]
    fn from_token_yields_static_credential() {
        // from_env()'s VERTEX_ACCESS_TOKEN path is NOT exercised here: that
        // would require mutating process env vars, which needs unsafe code
        // (unsafe_code = "forbid"). This test only pins that from_token --
        // the constructor from_env() delegates to on that path -- produces
        // the CredentialInner::Static variant.
        let creds = VertexCredentials::from_token("ya29.env-test-token");
        assert!(
            matches!(creds.inner, CredentialInner::Static(_)),
            "from_token must yield a static credential"
        );
    }
}