use crate::privacy::SecretString;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::fmt;
use std::str::FromStr;
use std::time::Duration;
use thiserror::Error;
/// Provider backend that a [`ProviderConfig`] targets.
///
/// Gemini is currently the only variant; the `FromStr` impl accepts `"gemini"` in any ASCII case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub enum ProviderKind {
/// Google Gemini, handled by [`GeminiClient`].
Gemini,
}
impl FromStr for ProviderKind {
    type Err = ProviderError;

    /// Parses a provider name, ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::InvalidConfig`] for any name other than `gemini`.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        match value {
            name if name.eq_ignore_ascii_case("gemini") => Ok(Self::Gemini),
            _ => Err(ProviderError::InvalidConfig(String::from(
                "unsupported provider",
            ))),
        }
    }
}
/// Connection settings for an adjudication provider.
///
/// Run [`ProviderConfig::parse`] to validate it; [`GeminiClient::new`] does so itself.
/// `Debug` is implemented by hand so the API key is printed as `[REDACTED]`.
#[derive(Clone)]
pub struct ProviderConfig {
/// Which provider backend to use.
pub kind: ProviderKind,
/// API root; `/models/{model}:generateContent` is appended to it (trailing `/` trimmed).
/// Must be HTTPS unless `allow_insecure_test_base_url` is set.
pub base_url: String,
/// Model identifier interpolated into the request path; must not be blank.
pub model: String,
/// Credential sent in the `x-goog-api-key` header; must not be blank.
pub api_key: SecretString,
/// Timeout handed to the HTTP client builder; `parse` replaces zero with 20 seconds.
pub timeout: Duration,
/// Disables the HTTPS requirement on `base_url` (as the name suggests, meant for test servers).
pub allow_insecure_test_base_url: bool,
}
impl fmt::Debug for ProviderConfig {
    /// Formats every field, replacing the API key with `[REDACTED]`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter
            .debug_struct("ProviderConfig")
            .field("kind", &self.kind)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("api_key", &"[REDACTED]")
            .field("timeout", &self.timeout)
            // Previously omitted; it records whether the HTTPS check on `base_url` is bypassed,
            // which is exactly what someone debugging a connection problem needs to see.
            .field(
                "allow_insecure_test_base_url",
                &self.allow_insecure_test_base_url,
            )
            .finish()
    }
}
impl ProviderConfig {
    /// Validates the configuration and fills in defaults, returning the checked config.
    ///
    /// A zero `timeout` is replaced by 20 seconds.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::InvalidConfig`] when:
    /// - `model` or the API key is blank;
    /// - `allow_insecure_test_base_url` is `false` and `base_url` does not start with
    ///   `https://` (scheme compared case-insensitively) or names no host after the scheme;
    /// - `allow_insecure_test_base_url` is `true` and `base_url` is blank.
    pub fn parse(self) -> Result<Self, ProviderError> {
        if self.model.trim().is_empty() {
            return Err(ProviderError::InvalidConfig(
                "model is required".to_string(),
            ));
        }
        if self.api_key.expose().trim().is_empty() {
            return Err(ProviderError::InvalidConfig(
                "api key is required".to_string(),
            ));
        }
        if self.allow_insecure_test_base_url {
            // Plain HTTP is tolerated in test mode, but an empty URL can never form an endpoint.
            if self.base_url.trim().is_empty() {
                return Err(ProviderError::InvalidConfig(
                    "base url is required".to_string(),
                ));
            }
        } else {
            let Some(after_scheme) = Self::strip_https_scheme(&self.base_url) else {
                return Err(ProviderError::InvalidConfig(
                    "base url must be HTTPS".to_string(),
                ));
            };
            // `https://` alone (or `https:///path`) used to pass here and only fail later at
            // request time; require a non-blank authority up front.
            let host = after_scheme
                .split(['/', '?', '#'])
                .next()
                .unwrap_or_default();
            if host.trim().is_empty() {
                return Err(ProviderError::InvalidConfig(
                    "base url must include a host".to_string(),
                ));
            }
        }
        let timeout = if self.timeout.is_zero() {
            Duration::from_secs(20)
        } else {
            self.timeout
        };
        Ok(Self { timeout, ..self })
    }

    /// Returns the remainder of `url` after an `https://` prefix, or `None` if there is no such
    /// prefix. The scheme is compared case-insensitively, as URI schemes are (RFC 3986 §3.1).
    fn strip_https_scheme(url: &str) -> Option<&str> {
        const SCHEME: &str = "https://";
        url.get(..SCHEME.len())
            .filter(|prefix| prefix.eq_ignore_ascii_case(SCHEME))
            .and_then(|_| url.get(SCHEME.len()..))
    }
}
/// Harm-blocking strictness requested from the provider.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct SafetyPolicy {
/// `true` (the default) sends `BLOCK_MEDIUM_AND_ABOVE` for each harm category;
/// `false` sends `BLOCK_ONLY_HIGH`.
pub block_medium_and_above: bool,
}
impl Default for SafetyPolicy {
    /// Starts from the stricter setting: block content rated medium risk or higher.
    fn default() -> Self {
        let block_medium_and_above = true;
        Self {
            block_medium_and_above,
        }
    }
}
/// One structured-output request handed to an [`AdjudicatorClient`].
///
/// `Debug` is implemented by hand so the prompt text and schema contents are not printed.
#[derive(Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AdjudicationRequest {
/// Prompt text, sent as the single content part.
pub prompt: String,
/// Optional JSON Schema: sent as `responseJsonSchema` and also used to validate the reply.
pub schema: Option<Value>,
/// Sampling temperature; omitted from the request body when `None`.
pub temperature: Option<f32>,
/// Maps to `maxOutputTokens`; omitted from the request body when `None`.
pub max_output_tokens: Option<u32>,
/// Safety threshold selection for this request.
pub safety_policy: SafetyPolicy,
}
impl fmt::Debug for AdjudicationRequest {
    /// Formats the request without its prompt text or schema body; for the schema only its
    /// presence is shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let schema_marker = self.schema.is_some().then_some("[PRESENT]");
        let mut debug = formatter.debug_struct("AdjudicationRequest");
        debug.field("prompt", &"[REDACTED]");
        debug.field("schema", &schema_marker);
        debug.field("temperature", &self.temperature);
        debug.field("max_output_tokens", &self.max_output_tokens);
        debug.field("safety_policy", &self.safety_policy);
        debug.finish()
    }
}
/// Token counts attached to a response; each is `None` when no count is available.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct TokenUsage {
/// Prompt (input) token count, if known.
pub input_tokens: Option<u32>,
/// Generated (output) token count, if known.
pub output_tokens: Option<u32>,
}
/// Provenance information returned alongside each [`AdjudicationResponse`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ProviderMeta {
/// Backend that produced the response.
pub provider: ProviderKind,
/// Model name the response is attributed to (the client passes its configured model).
pub model: String,
/// Token accounting for the call.
pub token_usage: TokenUsage,
}
/// Parsed model output plus metadata about how it was produced.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AdjudicationResponse {
/// JSON parsed from the model's text output; already validated against the request schema
/// when one was supplied.
pub json: Value,
/// Provider, model and token usage for this response.
pub meta: ProviderMeta,
}
/// Failures from configuring or calling an adjudication provider.
///
/// No variant wraps an underlying library error; call sites discard it with `map_err(|_| ...)`.
#[derive(Debug, Error)]
pub enum ProviderError {
/// A configuration value or provider name was rejected; the payload explains which.
#[error("invalid provider config: {0}")]
InvalidConfig(String),
/// Building the root store or HTTP client, or performing the HTTP exchange, failed.
#[error("provider request failed")]
Transport,
/// The provider answered with a non-2xx status; the response body is not kept.
#[error("provider returned non-success status {status}")]
HttpStatus { status: u16 },
/// The provider blocked the prompt (`"prompt_feedback"`) or stopped the first candidate for
/// safety (`"candidate_safety"`).
#[error("provider blocked request: {0}")]
Blocked(String),
/// The first candidate contained no part with a `text` field.
#[error("provider response did not contain JSON text")]
MissingJsonText,
/// The HTTP body or the candidate's text could not be parsed as JSON.
#[error("provider JSON parse failed")]
JsonParse,
/// The request schema could not be compiled, or the model output did not satisfy it.
#[error("provider response failed schema validation")]
SchemaValidation,
}
/// Provider-agnostic interface for turning an [`AdjudicationRequest`] into structured JSON.
///
/// The `Send + Sync` bound lets implementations be shared across threads (e.g. behind an `Arc`).
#[async_trait]
pub trait AdjudicatorClient: Send + Sync {
/// Sends one request to the provider and returns its parsed output with metadata.
///
/// # Errors
///
/// Returns a [`ProviderError`] identifying the stage that failed.
async fn adjudicate(
&self,
request: AdjudicationRequest,
) -> Result<AdjudicationResponse, ProviderError>;
}
/// [`AdjudicatorClient`] that calls Gemini's `generateContent` REST endpoint.
///
/// Construct it with [`GeminiClient::new`], which validates the config and builds the HTTP client.
#[derive(Clone)]
pub struct GeminiClient {
/// Configuration that has already passed [`ProviderConfig::parse`].
config: ProviderConfig,
/// HTTP client built with the configured timeout and the bundled root certificates.
http: reqwest::Client,
}
impl GeminiClient {
pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
let config = config.parse()?;
install_ring_crypto_provider();
let root_certificates = mozilla_root_certificates()?;
let http = reqwest::Client::builder()
.timeout(config.timeout)
.user_agent("chromaframe-sdk/0.1")
.http1_only()
.tls_certs_only(root_certificates)
.build()
.map_err(|_| ProviderError::Transport)?;
Ok(Self { config, http })
}
#[must_use]
pub fn request_body(request: &AdjudicationRequest) -> Value {
let mut generation_config = serde_json::Map::new();
generation_config.insert("responseMimeType".to_string(), json!("application/json"));
if let Some(schema) = &request.schema {
generation_config.insert("responseJsonSchema".to_string(), schema.clone());
}
if let Some(temperature) = request.temperature {
generation_config.insert("temperature".to_string(), json!(temperature));
}
if let Some(max_tokens) = request.max_output_tokens {
generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
}
json!({
"contents": [{ "parts": [{ "text": request.prompt }] }],
"generationConfig": generation_config,
"safetySettings": default_safety_settings(request.safety_policy.block_medium_and_above),
})
}
fn endpoint(&self) -> String {
format!(
"{}/models/{}:generateContent",
self.config.base_url.trim_end_matches('/'),
self.config.model
)
}
}
#[async_trait]
impl AdjudicatorClient for GeminiClient {
async fn adjudicate(
&self,
request: AdjudicationRequest,
) -> Result<AdjudicationResponse, ProviderError> {
let schema = request.schema.clone();
let response = self
.http
.post(self.endpoint())
.header("x-goog-api-key", self.config.api_key.expose())
.json(&Self::request_body(&request))
.send()
.await
.map_err(|_| ProviderError::Transport)?;
if !response.status().is_success() {
return Err(ProviderError::HttpStatus {
status: response.status().as_u16(),
});
}
let value: Value = response
.json()
.await
.map_err(|_| ProviderError::JsonParse)?;
parse_gemini_response(value, schema, self.config.model.clone())
}
}
pub fn parse_gemini_response(
value: Value,
schema: Option<Value>,
model: String,
) -> Result<AdjudicationResponse, ProviderError> {
if value
.get("promptFeedback")
.and_then(|feedback| feedback.get("blockReason"))
.is_some()
{
return Err(ProviderError::Blocked("prompt_feedback".to_string()));
}
if value
.pointer("/candidates/0/finishReason")
.and_then(Value::as_str)
.is_some_and(|reason| reason == "SAFETY")
{
return Err(ProviderError::Blocked("candidate_safety".to_string()));
}
let text = value
.pointer("/candidates/0/content/parts")
.and_then(Value::as_array)
.and_then(|parts| {
parts
.iter()
.filter_map(|part| part.get("text").and_then(Value::as_str))
.next()
})
.ok_or(ProviderError::MissingJsonText)?;
let parsed: Value = serde_json::from_str(text).map_err(|_| ProviderError::JsonParse)?;
if let Some(schema) = schema {
let validator =
jsonschema::validator_for(&schema).map_err(|_| ProviderError::SchemaValidation)?;
if !validator.is_valid(&parsed) {
return Err(ProviderError::SchemaValidation);
}
}
Ok(AdjudicationResponse {
json: parsed,
meta: ProviderMeta {
provider: ProviderKind::Gemini,
model,
token_usage: TokenUsage {
input_tokens: None,
output_tokens: None,
},
},
})
}
/// Installs rustls' `ring` crypto provider as the process default unless one is already set.
fn install_ring_crypto_provider() {
    if rustls::crypto::CryptoProvider::get_default().is_none() {
        // `install_default` errors when another provider got installed first; that is
        // acceptable here, so the result is deliberately discarded.
        let _ = rustls::crypto::ring::default_provider().install_default();
    }
}
/// Converts the bundled webpki (Mozilla-derived) root certificates into reqwest certificates.
///
/// # Errors
///
/// Returns [`ProviderError::Transport`] if any DER blob fails to load; conversion stops at the
/// first failure.
fn mozilla_root_certificates() -> Result<Vec<reqwest::Certificate>, ProviderError> {
    let roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
    let mut certificates = Vec::with_capacity(roots.len());
    for der in roots {
        let certificate =
            reqwest::Certificate::from_der(der.as_ref()).map_err(|_| ProviderError::Transport)?;
        certificates.push(certificate);
    }
    Ok(certificates)
}
fn default_safety_settings(block: bool) -> Value {
let threshold = if block {
"BLOCK_MEDIUM_AND_ABOVE"
} else {
"BLOCK_ONLY_HIGH"
};
json!([
{"category":"HARM_CATEGORY_HARASSMENT","threshold":threshold},
{"category":"HARM_CATEGORY_HATE_SPEECH","threshold":threshold},
{"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":threshold},
{"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":threshold}
])
}