Skip to main content

chio_external_guards/external/
vertex_safety.rs

//! Google Vertex AI safety-classifier adapter (phase 13.2).
//!
//! Vertex AI exposes multiple safety surfaces. For phase 13.2 we use the
//! generative-language `generateContent` safety-classification response:
//! when a request is submitted, the response carries `safetyRatings[]`
//! per category, each with a `probability` enum
//! (`NEGLIGIBLE|LOW|MEDIUM|HIGH`). We deny when any rating meets or
//! exceeds the configured probability threshold *or* when Vertex reports
//! a top-level `promptFeedback.blockReason`.
//!
//! See: <https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/safety-filters>
//!
//! Authentication uses a bearer token (typically an OAuth access token
//! minted from a service account). Like the Bedrock adapter, we store the
//! token in a [`Zeroizing<String>`] so it is wiped from memory when dropped.
16
17use std::time::Duration;
18
19use async_trait::async_trait;
20use chio_core_types::GuardEvidence;
21use chio_kernel::Verdict;
22use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25use sha2::{Digest, Sha256};
26use zeroize::Zeroizing;
27
28use super::bedrock::{classify_reqwest_error, classify_status_error};
29use super::{ExternalGuard, ExternalGuardError, GuardCallContext};
30
/// Identifier this guard reports from [`VertexSafetyGuard::name`].
pub const GUARD_NAME: &str = "vertex-safety";

/// Per-request HTTP timeout that [`VertexSafetyConfig::new`] starts with (10 s).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);
36
37/// Vertex safety probability levels.
38#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Default)]
39#[serde(rename_all = "UPPERCASE")]
40pub enum VertexProbability {
41    Negligible,
42    Low,
43    #[default]
44    Medium,
45    High,
46    /// Anything not recognized (e.g. `PROBABILITY_UNSPECIFIED`) is
47    /// treated as [`VertexProbability::Low`] for threshold purposes.
48    #[serde(other)]
49    Unknown,
50}
51
52impl VertexProbability {
53    fn rank(self) -> u8 {
54        match self {
55            Self::Negligible => 0,
56            Self::Unknown | Self::Low => 1,
57            Self::Medium => 2,
58            Self::High => 3,
59        }
60    }
61}
62
63/// Configuration for [`VertexSafetyGuard`].
64#[derive(Clone)]
65pub struct VertexSafetyConfig {
66    /// Bearer token (OAuth access token).
67    pub api_key: Zeroizing<String>,
68    /// GCP project ID.
69    pub project: String,
70    /// Region, e.g. `us-central1`.
71    pub location: String,
72    /// Publisher model, e.g. `gemini-1.5-pro`.
73    pub model: String,
74    /// Endpoint override. When `None` we use
75    /// `https://{location}-aiplatform.googleapis.com`.
76    pub endpoint: Option<String>,
77    /// Per-request HTTP timeout.
78    pub timeout: Duration,
79    /// Threshold at or above which the guard denies.
80    pub probability_threshold: VertexProbability,
81}
82
83impl std::fmt::Debug for VertexSafetyConfig {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.debug_struct("VertexSafetyConfig")
86            .field("api_key", &"***redacted***")
87            .field("project", &self.project)
88            .field("location", &self.location)
89            .field("model", &self.model)
90            .field("endpoint", &self.endpoint)
91            .field("timeout", &self.timeout)
92            .field("probability_threshold", &self.probability_threshold)
93            .finish()
94    }
95}
96
97impl VertexSafetyConfig {
98    /// Construct a config with defaults.
99    pub fn new(
100        api_key: impl Into<String>,
101        project: impl Into<String>,
102        location: impl Into<String>,
103        model: impl Into<String>,
104    ) -> Self {
105        Self {
106            api_key: Zeroizing::new(api_key.into()),
107            project: project.into(),
108            location: location.into(),
109            model: model.into(),
110            endpoint: None,
111            timeout: DEFAULT_TIMEOUT,
112            probability_threshold: VertexProbability::Medium,
113        }
114    }
115
116    /// Override the endpoint (primarily for tests).
117    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
118        self.endpoint = Some(endpoint.into());
119        self
120    }
121
122    /// Override the probability threshold.
123    pub fn with_threshold(mut self, threshold: VertexProbability) -> Self {
124        self.probability_threshold = threshold;
125        self
126    }
127
128    fn resolved_endpoint(&self) -> String {
129        match self.endpoint.as_deref() {
130            Some(ep) => ep.trim_end_matches('/').to_string(),
131            None => format!("https://{}-aiplatform.googleapis.com", self.location),
132        }
133    }
134
135    fn generate_url(&self) -> String {
136        format!(
137            "{}/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent",
138            self.resolved_endpoint(),
139            self.project,
140            self.location,
141            self.model
142        )
143    }
144}
145
146#[derive(Debug, Serialize)]
147struct GenerateRequest<'a> {
148    contents: Vec<GenerateContent<'a>>,
149}
150
151#[derive(Debug, Serialize)]
152struct GenerateContent<'a> {
153    role: &'a str,
154    parts: Vec<GeneratePart<'a>>,
155}
156
157#[derive(Debug, Serialize)]
158struct GeneratePart<'a> {
159    text: &'a str,
160}
161
162#[derive(Debug, Clone, Deserialize)]
163struct GenerateResponse {
164    #[serde(default)]
165    candidates: Vec<Candidate>,
166    #[serde(default, rename = "promptFeedback")]
167    prompt_feedback: Option<PromptFeedback>,
168}
169
170#[derive(Debug, Clone, Deserialize)]
171struct Candidate {
172    #[serde(default, rename = "safetyRatings")]
173    safety_ratings: Vec<SafetyRating>,
174    #[serde(default, rename = "finishReason")]
175    #[allow(dead_code)]
176    finish_reason: Option<String>,
177}
178
179#[derive(Debug, Clone, Deserialize, Serialize)]
180struct SafetyRating {
181    #[serde(default)]
182    category: String,
183    #[serde(default)]
184    probability: VertexProbabilityDefault,
185}
186
187#[derive(Debug, Clone, Deserialize, Serialize, Default)]
188#[serde(transparent)]
189struct VertexProbabilityDefault(VertexProbability);
190
191impl serde::Serialize for VertexProbability {
192    fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
193        let s = match self {
194            VertexProbability::Negligible => "NEGLIGIBLE",
195            VertexProbability::Low => "LOW",
196            VertexProbability::Medium => "MEDIUM",
197            VertexProbability::High => "HIGH",
198            VertexProbability::Unknown => "UNKNOWN",
199        };
200        ser.serialize_str(s)
201    }
202}
203
204#[derive(Debug, Clone, Deserialize)]
205struct PromptFeedback {
206    #[serde(default, rename = "blockReason")]
207    block_reason: Option<String>,
208    #[serde(default, rename = "safetyRatings")]
209    safety_ratings: Vec<SafetyRating>,
210}
211
212/// Guard wrapping Vertex AI safety ratings.
213pub struct VertexSafetyGuard {
214    cfg: VertexSafetyConfig,
215    http: Client,
216}
217
218impl VertexSafetyGuard {
219    /// Build a guard with an internally-owned [`reqwest::Client`].
220    pub fn new(cfg: VertexSafetyConfig) -> Result<Self, ExternalGuardError> {
221        let http = Client::builder()
222            .timeout(cfg.timeout)
223            .build()
224            .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
225        Ok(Self { cfg, http })
226    }
227
228    /// Build a guard with a caller-supplied client (for tests).
229    pub fn with_client(cfg: VertexSafetyConfig, http: Client) -> Self {
230        Self { cfg, http }
231    }
232
233    /// Build a [`GuardEvidence`] record for a prior decision.
234    pub fn evidence_from_decision(
235        &self,
236        verdict: Verdict,
237        details: Option<&VertexDecisionDetails>,
238    ) -> GuardEvidence {
239        GuardEvidence {
240            guard_name: self.name().to_string(),
241            verdict: matches!(verdict, Verdict::Allow),
242            details: details.and_then(|d| d.as_details_string()),
243        }
244    }
245}
246
247/// Structured details for receipt evidence.
248#[derive(Debug, Clone, Serialize)]
249pub struct VertexDecisionDetails {
250    /// The model's `promptFeedback.blockReason` if any.
251    pub block_reason: Option<String>,
252    /// Threshold used for the decision.
253    pub threshold: String,
254    /// Safety ratings observed.
255    pub safety_ratings: Vec<VertexRatingBreakdown>,
256}
257
258#[derive(Debug, Clone, Serialize)]
259pub struct VertexRatingBreakdown {
260    pub category: String,
261    pub probability: String,
262}
263
264impl VertexDecisionDetails {
265    fn as_details_string(&self) -> Option<String> {
266        serde_json::to_string(self).ok()
267    }
268}
269
270#[async_trait]
271impl ExternalGuard for VertexSafetyGuard {
272    fn name(&self) -> &str {
273        GUARD_NAME
274    }
275
276    fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
277        let mut hasher = Sha256::new();
278        hasher.update(self.cfg.project.as_bytes());
279        hasher.update(b":");
280        hasher.update(self.cfg.model.as_bytes());
281        hasher.update(b":");
282        hasher.update(ctx.tool_name.as_bytes());
283        hasher.update(b":");
284        hasher.update(ctx.arguments_json.as_bytes());
285        let digest = hasher.finalize();
286        let mut hex = String::with_capacity(digest.len() * 2);
287        for b in digest {
288            hex.push_str(&format!("{b:02x}"));
289        }
290        Some(format!("vertex:{hex}"))
291    }
292
293    async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
294        let url = self.cfg.generate_url();
295        super::endpoint_security::validate_external_guard_url("vertex-safety endpoint", &url)?;
296
297        let body = GenerateRequest {
298            contents: vec![GenerateContent {
299                role: "user",
300                parts: vec![GeneratePart {
301                    text: &ctx.arguments_json,
302                }],
303            }],
304        };
305
306        let mut headers = HeaderMap::new();
307        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
308        let auth = format!("Bearer {}", self.cfg.api_key.as_str());
309        headers.insert(
310            AUTHORIZATION,
311            HeaderValue::from_str(&auth)
312                .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
313        );
314
315        let resp = self
316            .http
317            .post(&url)
318            .headers(headers)
319            .json(&body)
320            .send()
321            .await
322            .map_err(classify_reqwest_error)?;
323
324        let status = resp.status();
325        let text = resp
326            .text()
327            .await
328            .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
329
330        if !status.is_success() {
331            return Err(classify_status_error("vertex-safety", status, &text));
332        }
333
334        let parsed: GenerateResponse = serde_json::from_str(&text)
335            .map_err(|e| ExternalGuardError::Transient(format!("parse vertex response: {e}")))?;
336
337        if let Some(pf) = parsed.prompt_feedback.as_ref() {
338            if pf.block_reason.is_some() {
339                tracing::info!(
340                    guard = GUARD_NAME,
341                    block_reason = ?pf.block_reason,
342                    "vertex safety promptFeedback blocked"
343                );
344                return Ok(Verdict::Deny);
345            }
346        }
347
348        let threshold_rank = self.cfg.probability_threshold.rank();
349        let mut max_rank = 0_u8;
350        let candidate_ratings = parsed
351            .candidates
352            .iter()
353            .flat_map(|c| c.safety_ratings.iter());
354        let pf_ratings = parsed
355            .prompt_feedback
356            .as_ref()
357            .map(|p| p.safety_ratings.as_slice())
358            .unwrap_or(&[])
359            .iter();
360        for rating in candidate_ratings.chain(pf_ratings) {
361            let rank = rating.probability.0.rank();
362            if rank > max_rank {
363                max_rank = rank;
364            }
365        }
366
367        tracing::info!(
368            guard = GUARD_NAME,
369            max_rank,
370            threshold_rank,
371            "vertex safety response"
372        );
373
374        Ok(if max_rank >= threshold_rank {
375            Verdict::Deny
376        } else {
377            Verdict::Allow
378        })
379    }
380}