chio_external_guards/external/
vertex_safety.rs1use std::time::Duration;
18
19use async_trait::async_trait;
20use chio_core_types::GuardEvidence;
21use chio_kernel::Verdict;
22use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
23use reqwest::Client;
24use serde::{Deserialize, Serialize};
25use sha2::{Digest, Sha256};
26use zeroize::Zeroizing;
27
28use super::bedrock::{classify_reqwest_error, classify_status_error};
29use super::{ExternalGuard, ExternalGuardError, GuardCallContext};
30
/// Name this guard reports from `ExternalGuard::name` and attaches to its
/// `tracing` events.
pub const GUARD_NAME: &str = "vertex-safety";

/// Request timeout that `VertexSafetyConfig::new` installs by default.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Default)]
39#[serde(rename_all = "UPPERCASE")]
40pub enum VertexProbability {
41 Negligible,
42 Low,
43 #[default]
44 Medium,
45 High,
46 #[serde(other)]
49 Unknown,
50}
51
52impl VertexProbability {
53 fn rank(self) -> u8 {
54 match self {
55 Self::Negligible => 0,
56 Self::Unknown | Self::Low => 1,
57 Self::Medium => 2,
58 Self::High => 3,
59 }
60 }
61}
62
/// Settings for the Vertex AI safety guard.
///
/// `Debug` is implemented by hand elsewhere in this file so that `api_key`
/// is never printed.
#[derive(Clone)]
pub struct VertexSafetyConfig {
    /// Credential sent as `Authorization: Bearer <api_key>` on every request.
    /// Wrapped in `Zeroizing` so the stored copy is wiped on drop.
    pub api_key: Zeroizing<String>,
    /// Project id interpolated into the request path (`/projects/{project}`).
    pub project: String,
    /// Location interpolated into the request path and, when `endpoint` is
    /// `None`, into the default host name.
    pub location: String,
    /// Model id interpolated into the request path
    /// (`/publishers/google/models/{model}:generateContent`).
    pub model: String,
    /// Optional base-URL override; trailing `/` characters are trimmed before
    /// the path is appended.
    pub endpoint: Option<String>,
    /// Request timeout handed to the HTTP client built by
    /// `VertexSafetyGuard::new`.
    pub timeout: Duration,
    /// A call is denied when the highest rating rank in the response is at or
    /// above the rank of this threshold.
    pub probability_threshold: VertexProbability,
}
82
83impl std::fmt::Debug for VertexSafetyConfig {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("VertexSafetyConfig")
86 .field("api_key", &"***redacted***")
87 .field("project", &self.project)
88 .field("location", &self.location)
89 .field("model", &self.model)
90 .field("endpoint", &self.endpoint)
91 .field("timeout", &self.timeout)
92 .field("probability_threshold", &self.probability_threshold)
93 .finish()
94 }
95}
96
97impl VertexSafetyConfig {
98 pub fn new(
100 api_key: impl Into<String>,
101 project: impl Into<String>,
102 location: impl Into<String>,
103 model: impl Into<String>,
104 ) -> Self {
105 Self {
106 api_key: Zeroizing::new(api_key.into()),
107 project: project.into(),
108 location: location.into(),
109 model: model.into(),
110 endpoint: None,
111 timeout: DEFAULT_TIMEOUT,
112 probability_threshold: VertexProbability::Medium,
113 }
114 }
115
116 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
118 self.endpoint = Some(endpoint.into());
119 self
120 }
121
122 pub fn with_threshold(mut self, threshold: VertexProbability) -> Self {
124 self.probability_threshold = threshold;
125 self
126 }
127
128 fn resolved_endpoint(&self) -> String {
129 match self.endpoint.as_deref() {
130 Some(ep) => ep.trim_end_matches('/').to_string(),
131 None => format!("https://{}-aiplatform.googleapis.com", self.location),
132 }
133 }
134
135 fn generate_url(&self) -> String {
136 format!(
137 "{}/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent",
138 self.resolved_endpoint(),
139 self.project,
140 self.location,
141 self.model
142 )
143 }
144}
145
/// JSON body posted to the `generateContent` URL.
#[derive(Debug, Serialize)]
struct GenerateRequest<'a> {
    /// Conversation turns; `eval` sends exactly one.
    contents: Vec<GenerateContent<'a>>,
}

/// One conversation turn inside a [`GenerateRequest`].
#[derive(Debug, Serialize)]
struct GenerateContent<'a> {
    /// Speaker role; `eval` always sends `"user"`.
    role: &'a str,
    /// Parts making up the turn; `eval` sends a single text part.
    parts: Vec<GeneratePart<'a>>,
}

/// A text part of a [`GenerateContent`] turn.
#[derive(Debug, Serialize)]
struct GeneratePart<'a> {
    /// Borrowed text payload; `eval` passes the call's `arguments_json`.
    text: &'a str,
}
161
162#[derive(Debug, Clone, Deserialize)]
163struct GenerateResponse {
164 #[serde(default)]
165 candidates: Vec<Candidate>,
166 #[serde(default, rename = "promptFeedback")]
167 prompt_feedback: Option<PromptFeedback>,
168}
169
170#[derive(Debug, Clone, Deserialize)]
171struct Candidate {
172 #[serde(default, rename = "safetyRatings")]
173 safety_ratings: Vec<SafetyRating>,
174 #[serde(default, rename = "finishReason")]
175 #[allow(dead_code)]
176 finish_reason: Option<String>,
177}
178
/// A single category/probability pair from a `safetyRatings` array.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct SafetyRating {
    /// Harm category label; empty string when the key is absent.
    #[serde(default)]
    category: String,
    /// Probability bucket. When the key is absent the wrapper's `Default` is
    /// used, which resolves to `VertexProbability::Medium`.
    #[serde(default)]
    probability: VertexProbabilityDefault,
}
186
/// Newtype around [`VertexProbability`] used as the field type of
/// `SafetyRating::probability`. `#[serde(transparent)]` makes it
/// (de)serialize exactly like the inner enum.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(transparent)]
struct VertexProbabilityDefault(VertexProbability);
190
191impl serde::Serialize for VertexProbability {
192 fn serialize<S: serde::Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
193 let s = match self {
194 VertexProbability::Negligible => "NEGLIGIBLE",
195 VertexProbability::Low => "LOW",
196 VertexProbability::Medium => "MEDIUM",
197 VertexProbability::High => "HIGH",
198 VertexProbability::Unknown => "UNKNOWN",
199 };
200 ser.serialize_str(s)
201 }
202}
203
204#[derive(Debug, Clone, Deserialize)]
205struct PromptFeedback {
206 #[serde(default, rename = "blockReason")]
207 block_reason: Option<String>,
208 #[serde(default, rename = "safetyRatings")]
209 safety_ratings: Vec<SafetyRating>,
210}
211
/// External guard that posts a call's arguments to a Vertex AI
/// `generateContent` endpoint and turns the returned safety ratings into an
/// allow/deny verdict.
pub struct VertexSafetyGuard {
    /// Endpoint, credential and threshold settings.
    cfg: VertexSafetyConfig,
    /// HTTP client used for every request.
    http: Client,
}
217
218impl VertexSafetyGuard {
219 pub fn new(cfg: VertexSafetyConfig) -> Result<Self, ExternalGuardError> {
221 let http = Client::builder()
222 .timeout(cfg.timeout)
223 .build()
224 .map_err(|e| ExternalGuardError::Permanent(format!("reqwest build: {e}")))?;
225 Ok(Self { cfg, http })
226 }
227
228 pub fn with_client(cfg: VertexSafetyConfig, http: Client) -> Self {
230 Self { cfg, http }
231 }
232
233 pub fn evidence_from_decision(
235 &self,
236 verdict: Verdict,
237 details: Option<&VertexDecisionDetails>,
238 ) -> GuardEvidence {
239 GuardEvidence {
240 guard_name: self.name().to_string(),
241 verdict: matches!(verdict, Verdict::Allow),
242 details: details.and_then(|d| d.as_details_string()),
243 }
244 }
245}
246
/// Serializable explanation of a decision, accepted by
/// `VertexSafetyGuard::evidence_from_decision`.
#[derive(Debug, Clone, Serialize)]
pub struct VertexDecisionDetails {
    /// Block reason, if there was one.
    pub block_reason: Option<String>,
    /// Threshold in effect, as text.
    pub threshold: String,
    /// Per-category ratings that informed the decision.
    pub safety_ratings: Vec<VertexRatingBreakdown>,
}
257
/// One category/probability pair inside [`VertexDecisionDetails`], both held
/// as plain strings.
#[derive(Debug, Clone, Serialize)]
pub struct VertexRatingBreakdown {
    /// Harm category label.
    pub category: String,
    /// Probability bucket, as text.
    pub probability: String,
}
263
264impl VertexDecisionDetails {
265 fn as_details_string(&self) -> Option<String> {
266 serde_json::to_string(self).ok()
267 }
268}
269
270#[async_trait]
271impl ExternalGuard for VertexSafetyGuard {
272 fn name(&self) -> &str {
273 GUARD_NAME
274 }
275
276 fn cache_key(&self, ctx: &GuardCallContext) -> Option<String> {
277 let mut hasher = Sha256::new();
278 hasher.update(self.cfg.project.as_bytes());
279 hasher.update(b":");
280 hasher.update(self.cfg.model.as_bytes());
281 hasher.update(b":");
282 hasher.update(ctx.tool_name.as_bytes());
283 hasher.update(b":");
284 hasher.update(ctx.arguments_json.as_bytes());
285 let digest = hasher.finalize();
286 let mut hex = String::with_capacity(digest.len() * 2);
287 for b in digest {
288 hex.push_str(&format!("{b:02x}"));
289 }
290 Some(format!("vertex:{hex}"))
291 }
292
293 async fn eval(&self, ctx: &GuardCallContext) -> Result<Verdict, ExternalGuardError> {
294 let url = self.cfg.generate_url();
295 super::endpoint_security::validate_external_guard_url("vertex-safety endpoint", &url)?;
296
297 let body = GenerateRequest {
298 contents: vec![GenerateContent {
299 role: "user",
300 parts: vec![GeneratePart {
301 text: &ctx.arguments_json,
302 }],
303 }],
304 };
305
306 let mut headers = HeaderMap::new();
307 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
308 let auth = format!("Bearer {}", self.cfg.api_key.as_str());
309 headers.insert(
310 AUTHORIZATION,
311 HeaderValue::from_str(&auth)
312 .map_err(|e| ExternalGuardError::Permanent(format!("invalid api key: {e}")))?,
313 );
314
315 let resp = self
316 .http
317 .post(&url)
318 .headers(headers)
319 .json(&body)
320 .send()
321 .await
322 .map_err(classify_reqwest_error)?;
323
324 let status = resp.status();
325 let text = resp
326 .text()
327 .await
328 .map_err(|e| ExternalGuardError::Transient(format!("read body: {e}")))?;
329
330 if !status.is_success() {
331 return Err(classify_status_error("vertex-safety", status, &text));
332 }
333
334 let parsed: GenerateResponse = serde_json::from_str(&text)
335 .map_err(|e| ExternalGuardError::Transient(format!("parse vertex response: {e}")))?;
336
337 if let Some(pf) = parsed.prompt_feedback.as_ref() {
338 if pf.block_reason.is_some() {
339 tracing::info!(
340 guard = GUARD_NAME,
341 block_reason = ?pf.block_reason,
342 "vertex safety promptFeedback blocked"
343 );
344 return Ok(Verdict::Deny);
345 }
346 }
347
348 let threshold_rank = self.cfg.probability_threshold.rank();
349 let mut max_rank = 0_u8;
350 let candidate_ratings = parsed
351 .candidates
352 .iter()
353 .flat_map(|c| c.safety_ratings.iter());
354 let pf_ratings = parsed
355 .prompt_feedback
356 .as_ref()
357 .map(|p| p.safety_ratings.as_slice())
358 .unwrap_or(&[])
359 .iter();
360 for rating in candidate_ratings.chain(pf_ratings) {
361 let rank = rating.probability.0.rank();
362 if rank > max_rank {
363 max_rank = rank;
364 }
365 }
366
367 tracing::info!(
368 guard = GUARD_NAME,
369 max_rank,
370 threshold_rank,
371 "vertex safety response"
372 );
373
374 Ok(if max_rank >= threshold_rank {
375 Verdict::Deny
376 } else {
377 Verdict::Allow
378 })
379 }
380}