1use crate::privacy::SecretString;
2use async_trait::async_trait;
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use serde_json::{Value, json};
6use std::fmt;
7use std::str::FromStr;
8use std::time::Duration;
9use thiserror::Error;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
12pub enum ProviderKind {
13 Gemini,
14}
15
16impl FromStr for ProviderKind {
17 type Err = ProviderError;
18 fn from_str(value: &str) -> Result<Self, Self::Err> {
19 if value.eq_ignore_ascii_case("gemini") {
20 return Ok(Self::Gemini);
21 }
22 Err(ProviderError::InvalidConfig(
23 "unsupported provider".to_string(),
24 ))
25 }
26}
27
28#[derive(Clone)]
29pub struct ProviderConfig {
30 pub kind: ProviderKind,
31 pub base_url: String,
32 pub model: String,
33 pub api_key: SecretString,
34 pub timeout: Duration,
35 pub allow_insecure_test_base_url: bool,
36}
37
38impl fmt::Debug for ProviderConfig {
39 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
40 formatter
41 .debug_struct("ProviderConfig")
42 .field("kind", &self.kind)
43 .field("base_url", &self.base_url)
44 .field("model", &self.model)
45 .field("api_key", &"[REDACTED]")
46 .field("timeout", &self.timeout)
47 .finish()
48 }
49}
50
51impl ProviderConfig {
52 pub fn parse(self) -> Result<Self, ProviderError> {
53 if self.model.trim().is_empty() {
54 return Err(ProviderError::InvalidConfig(
55 "model is required".to_string(),
56 ));
57 }
58 if self.api_key.expose().trim().is_empty() {
59 return Err(ProviderError::InvalidConfig(
60 "api key is required".to_string(),
61 ));
62 }
63 if !self.allow_insecure_test_base_url && !self.base_url.starts_with("https://") {
64 return Err(ProviderError::InvalidConfig(
65 "base url must be HTTPS".to_string(),
66 ));
67 }
68 let timeout = if self.timeout.is_zero() {
69 Duration::from_secs(20)
70 } else {
71 self.timeout
72 };
73 Ok(Self { timeout, ..self })
74 }
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
78pub struct SafetyPolicy {
79 pub block_medium_and_above: bool,
80}
81impl Default for SafetyPolicy {
82 fn default() -> Self {
83 Self {
84 block_medium_and_above: true,
85 }
86 }
87}
88
89#[derive(Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
90pub struct AdjudicationRequest {
91 pub prompt: String,
92 pub schema: Option<Value>,
93 pub temperature: Option<f32>,
94 pub max_output_tokens: Option<u32>,
95 pub safety_policy: SafetyPolicy,
96}
97
98impl fmt::Debug for AdjudicationRequest {
99 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
100 formatter
101 .debug_struct("AdjudicationRequest")
102 .field("prompt", &"[REDACTED]")
103 .field("schema", &self.schema.as_ref().map(|_| "[PRESENT]"))
104 .field("temperature", &self.temperature)
105 .field("max_output_tokens", &self.max_output_tokens)
106 .field("safety_policy", &self.safety_policy)
107 .finish()
108 }
109}
110
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
112pub struct TokenUsage {
113 pub input_tokens: Option<u32>,
114 pub output_tokens: Option<u32>,
115}
116#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
117pub struct ProviderMeta {
118 pub provider: ProviderKind,
119 pub model: String,
120 pub token_usage: TokenUsage,
121}
122#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
123pub struct AdjudicationResponse {
124 pub json: Value,
125 pub meta: ProviderMeta,
126}
127
128#[derive(Debug, Error)]
129pub enum ProviderError {
130 #[error("invalid provider config: {0}")]
131 InvalidConfig(String),
132 #[error("provider request failed")]
133 Transport,
134 #[error("provider returned non-success status {status}")]
135 HttpStatus { status: u16 },
136 #[error("provider blocked request: {0}")]
137 Blocked(String),
138 #[error("provider response did not contain JSON text")]
139 MissingJsonText,
140 #[error("provider JSON parse failed")]
141 JsonParse,
142 #[error("provider response failed schema validation")]
143 SchemaValidation,
144}
145
146#[async_trait]
147pub trait AdjudicatorClient: Send + Sync {
148 async fn adjudicate(
149 &self,
150 request: AdjudicationRequest,
151 ) -> Result<AdjudicationResponse, ProviderError>;
152}
153
154#[derive(Clone)]
155pub struct GeminiClient {
156 config: ProviderConfig,
157 http: reqwest::Client,
158}
159
160impl GeminiClient {
161 pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
162 let config = config.parse()?;
163 install_ring_crypto_provider();
164 let root_certificates = mozilla_root_certificates()?;
165 let http = reqwest::Client::builder()
166 .timeout(config.timeout)
167 .user_agent("chromaframe-sdk/0.1")
168 .http1_only()
169 .tls_certs_only(root_certificates)
170 .build()
171 .map_err(|_| ProviderError::Transport)?;
172 Ok(Self { config, http })
173 }
174
175 #[must_use]
176 pub fn request_body(request: &AdjudicationRequest) -> Value {
177 let mut generation_config = serde_json::Map::new();
178 generation_config.insert("responseMimeType".to_string(), json!("application/json"));
179 if let Some(schema) = &request.schema {
180 generation_config.insert("responseJsonSchema".to_string(), schema.clone());
181 }
182 if let Some(temperature) = request.temperature {
183 generation_config.insert("temperature".to_string(), json!(temperature));
184 }
185 if let Some(max_tokens) = request.max_output_tokens {
186 generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
187 }
188 json!({
189 "contents": [{ "parts": [{ "text": request.prompt }] }],
190 "generationConfig": generation_config,
191 "safetySettings": default_safety_settings(request.safety_policy.block_medium_and_above),
192 })
193 }
194
195 fn endpoint(&self) -> String {
196 format!(
197 "{}/models/{}:generateContent",
198 self.config.base_url.trim_end_matches('/'),
199 self.config.model
200 )
201 }
202}
203
204#[async_trait]
205impl AdjudicatorClient for GeminiClient {
206 async fn adjudicate(
207 &self,
208 request: AdjudicationRequest,
209 ) -> Result<AdjudicationResponse, ProviderError> {
210 let schema = request.schema.clone();
211 let response = self
212 .http
213 .post(self.endpoint())
214 .header("x-goog-api-key", self.config.api_key.expose())
215 .json(&Self::request_body(&request))
216 .send()
217 .await
218 .map_err(|_| ProviderError::Transport)?;
219 if !response.status().is_success() {
220 return Err(ProviderError::HttpStatus {
221 status: response.status().as_u16(),
222 });
223 }
224 let value: Value = response
225 .json()
226 .await
227 .map_err(|_| ProviderError::JsonParse)?;
228 parse_gemini_response(value, schema, self.config.model.clone())
229 }
230}
231
232pub fn parse_gemini_response(
233 value: Value,
234 schema: Option<Value>,
235 model: String,
236) -> Result<AdjudicationResponse, ProviderError> {
237 if value
238 .get("promptFeedback")
239 .and_then(|feedback| feedback.get("blockReason"))
240 .is_some()
241 {
242 return Err(ProviderError::Blocked("prompt_feedback".to_string()));
243 }
244 if value
245 .pointer("/candidates/0/finishReason")
246 .and_then(Value::as_str)
247 .is_some_and(|reason| reason == "SAFETY")
248 {
249 return Err(ProviderError::Blocked("candidate_safety".to_string()));
250 }
251 let text = value
252 .pointer("/candidates/0/content/parts")
253 .and_then(Value::as_array)
254 .and_then(|parts| {
255 parts
256 .iter()
257 .filter_map(|part| part.get("text").and_then(Value::as_str))
258 .next()
259 })
260 .ok_or(ProviderError::MissingJsonText)?;
261 let parsed: Value = serde_json::from_str(text).map_err(|_| ProviderError::JsonParse)?;
262 if let Some(schema) = schema {
263 let validator =
264 jsonschema::validator_for(&schema).map_err(|_| ProviderError::SchemaValidation)?;
265 if !validator.is_valid(&parsed) {
266 return Err(ProviderError::SchemaValidation);
267 }
268 }
269 Ok(AdjudicationResponse {
270 json: parsed,
271 meta: ProviderMeta {
272 provider: ProviderKind::Gemini,
273 model,
274 token_usage: TokenUsage {
275 input_tokens: None,
276 output_tokens: None,
277 },
278 },
279 })
280}
281
282fn install_ring_crypto_provider() {
283 if rustls::crypto::CryptoProvider::get_default().is_some() {
284 return;
285 }
286
287 let _ = rustls::crypto::ring::default_provider().install_default();
288}
289
290fn mozilla_root_certificates() -> Result<Vec<reqwest::Certificate>, ProviderError> {
291 webpki_root_certs::TLS_SERVER_ROOT_CERTS
292 .iter()
293 .map(|cert| reqwest::Certificate::from_der(cert.as_ref()))
294 .collect::<Result<Vec<_>, _>>()
295 .map_err(|_| ProviderError::Transport)
296}
297
298fn default_safety_settings(block: bool) -> Value {
299 let threshold = if block {
300 "BLOCK_MEDIUM_AND_ABOVE"
301 } else {
302 "BLOCK_ONLY_HIGH"
303 };
304 json!([
305 {"category":"HARM_CATEGORY_HARASSMENT","threshold":threshold},
306 {"category":"HARM_CATEGORY_HATE_SPEECH","threshold":threshold},
307 {"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":threshold},
308 {"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":threshold}
309 ])
310}