chromaframe-sdk 0.1.1

Deterministic, privacy-preserving color measurement and ranking SDK
Documentation
use crate::privacy::SecretString;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::fmt;
use std::str::FromStr;
use std::time::Duration;
use thiserror::Error;

// Model provider an adjudication client talks to; parse from text via `FromStr`.
// (Plain `//` comments on purpose: schemars copies `///` text into the derived schema.)
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
pub enum ProviderKind {
    // Google Gemini `generateContent` API.
    Gemini,
}

impl FromStr for ProviderKind {
    type Err = ProviderError;

    /// Parses a provider name; `"gemini"` is accepted in any ASCII letter case.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::InvalidConfig`] for any other input.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        value
            .eq_ignore_ascii_case("gemini")
            .then_some(Self::Gemini)
            .ok_or_else(|| ProviderError::InvalidConfig("unsupported provider".to_string()))
    }
}

/// Connection settings for a model provider; validate with [`ProviderConfig::parse`].
///
/// `Debug` is implemented by hand so the API key is never printed.
#[derive(Clone)]
pub struct ProviderConfig {
    /// Provider this configuration targets.
    pub kind: ProviderKind,
    /// API root URL; must be HTTPS unless `allow_insecure_test_base_url` is set.
    pub base_url: String,
    /// Model identifier placed into the request path; must not be blank.
    pub model: String,
    /// Provider API key; must not be blank.
    pub api_key: SecretString,
    /// Request timeout; a zero value is replaced by 20 seconds during validation.
    pub timeout: Duration,
    /// Disables the HTTPS requirement on `base_url` (as the name says, meant for tests).
    pub allow_insecure_test_base_url: bool,
}

impl fmt::Debug for ProviderConfig {
    /// Formats the config for diagnostics with the API key replaced by `[REDACTED]`.
    ///
    /// Every non-secret field is listed, including `allow_insecure_test_base_url`, so a
    /// config that bypasses the HTTPS check is visible wherever the config is logged.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter
            .debug_struct("ProviderConfig")
            .field("kind", &self.kind)
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("api_key", &"[REDACTED]")
            .field("timeout", &self.timeout)
            .field(
                "allow_insecure_test_base_url",
                &self.allow_insecure_test_base_url,
            )
            .finish()
    }
}

impl ProviderConfig {
    /// Validates the config and applies defaults, returning the normalized config.
    ///
    /// Checks, in order:
    /// 1. `model` is not blank.
    /// 2. `model` contains no `/`, `?`, `#`, whitespace or control characters. The model
    ///    is spliced verbatim into the request path, so such characters would change (or
    ///    break) the URL that is actually requested.
    /// 3. `api_key` is not blank.
    /// 4. Unless `allow_insecure_test_base_url` is set, `base_url` starts with `https://`.
    ///    The scheme is compared ASCII-case-insensitively, since URL schemes are
    ///    case-insensitive (RFC 3986 §3.1).
    ///
    /// A zero `timeout` is replaced with 20 seconds; every other field passes through.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::InvalidConfig`] describing the first failed check.
    pub fn parse(self) -> Result<Self, ProviderError> {
        const HTTPS_PREFIX: &str = "https://";
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20);

        if self.model.trim().is_empty() {
            return Err(ProviderError::InvalidConfig(
                "model is required".to_string(),
            ));
        }
        let model_alters_url = self
            .model
            .chars()
            .any(|c| matches!(c, '/' | '?' | '#') || c.is_whitespace() || c.is_control());
        if model_alters_url {
            return Err(ProviderError::InvalidConfig(
                "model contains characters not allowed in a URL path segment".to_string(),
            ));
        }
        if self.api_key.expose().trim().is_empty() {
            return Err(ProviderError::InvalidConfig(
                "api key is required".to_string(),
            ));
        }
        // `str::get` yields None for inputs shorter than the prefix (or a non-char
        // boundary), which correctly counts as "not HTTPS".
        let is_https = self
            .base_url
            .get(..HTTPS_PREFIX.len())
            .is_some_and(|scheme| scheme.eq_ignore_ascii_case(HTTPS_PREFIX));
        if !self.allow_insecure_test_base_url && !is_https {
            return Err(ProviderError::InvalidConfig(
                "base url must be HTTPS".to_string(),
            ));
        }
        let timeout = if self.timeout.is_zero() {
            DEFAULT_TIMEOUT
        } else {
            self.timeout
        };
        Ok(Self { timeout, ..self })
    }
}

// Selects the Gemini safety-filter threshold applied to each request.
// (Plain `//` comments on purpose: schemars copies `///` text into the derived schema.)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct SafetyPolicy {
    // `true` -> BLOCK_MEDIUM_AND_ABOVE, `false` -> BLOCK_ONLY_HIGH.
    pub block_medium_and_above: bool,
}
impl Default for SafetyPolicy {
    /// The default is the stricter policy: block medium-and-above harm.
    fn default() -> Self {
        let block_medium_and_above = true;
        Self {
            block_medium_and_above,
        }
    }
}

// One adjudication call: a prompt plus optional generation controls.
// `Debug` is implemented manually so the prompt text is never printed.
// (Plain `//` comments on purpose: schemars copies `///` text into the derived schema.)
#[derive(Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AdjudicationRequest {
    // Text sent to the model as the single content part.
    pub prompt: String,
    // Optional JSON Schema: sent as `responseJsonSchema` and used to validate the reply.
    pub schema: Option<Value>,
    // Sampling temperature, forwarded only when set.
    pub temperature: Option<f32>,
    // Output-token limit, forwarded as `maxOutputTokens` only when set.
    pub max_output_tokens: Option<u32>,
    // Safety-filter threshold selection.
    pub safety_policy: SafetyPolicy,
}

impl fmt::Debug for AdjudicationRequest {
    /// Formats the request without exposing its content.
    ///
    /// The prompt always prints as `[REDACTED]` and the schema only as
    /// `Some("[PRESENT]")` or `None`; numeric settings and the safety policy print as-is.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let schema_marker = self.schema.is_some().then_some("[PRESENT]");
        let mut builder = f.debug_struct("AdjudicationRequest");
        builder.field("prompt", &"[REDACTED]");
        builder.field("schema", &schema_marker);
        builder.field("temperature", &self.temperature);
        builder.field("max_output_tokens", &self.max_output_tokens);
        builder.field("safety_policy", &self.safety_policy);
        builder.finish()
    }
}

// Token counts for one call; each count is `None` when it is not known.
// (Plain `//` comments on purpose: schemars copies `///` text into the derived schema.)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct TokenUsage {
    // Tokens consumed by the prompt.
    pub input_tokens: Option<u32>,
    // Tokens produced in the response.
    pub output_tokens: Option<u32>,
}
// Provenance attached to every adjudication response.
// (Plain `//` comments on purpose: schemars copies `///` text into the derived schema.)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ProviderMeta {
    // Provider that produced the response.
    pub provider: ProviderKind,
    // Model name taken from the client configuration.
    pub model: String,
    // Token accounting for the call.
    pub token_usage: TokenUsage,
}
// Parsed model answer plus its provenance.
// (Plain `//` comments on purpose: schemars copies `///` text into the derived schema.)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AdjudicationResponse {
    // The model's reply text, parsed as JSON (schema-checked when a schema was given).
    pub json: Value,
    // Which provider/model answered and how many tokens it used.
    pub meta: ProviderMeta,
}

/// Failures reported by provider configuration and adjudication.
///
/// Underlying library errors (HTTP, JSON, schema) are not wrapped; a failure is identified
/// by its variant plus, where present, a short reason or tag.
#[derive(Debug, Error)]
pub enum ProviderError {
    /// A configuration value or provider name was rejected; carries the reason.
    #[error("invalid provider config: {0}")]
    InvalidConfig(String),
    /// The HTTP client or its root certificates could not be built, or the request could
    /// not be sent.
    #[error("provider request failed")]
    Transport,
    /// The provider replied with a non-2xx HTTP status.
    #[error("provider returned non-success status {status}")]
    HttpStatus {
        /// The numeric HTTP status code.
        status: u16,
    },
    /// The provider's safety filtering blocked the prompt or the candidate; carries a tag
    /// saying which.
    #[error("provider blocked request: {0}")]
    Blocked(String),
    /// No text part was found in the first candidate.
    #[error("provider response did not contain JSON text")]
    MissingJsonText,
    /// The HTTP body or the model's text was not valid JSON.
    #[error("provider JSON parse failed")]
    JsonParse,
    /// The supplied schema could not be compiled, or the model output did not conform.
    #[error("provider response failed schema validation")]
    SchemaValidation,
}

/// A model backend that answers an [`AdjudicationRequest`] with parsed JSON.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait AdjudicatorClient: Send + Sync {
    /// Performs one request/response round-trip with the provider.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the call cannot be completed or its result is
    /// unusable.
    async fn adjudicate(
        &self,
        request: AdjudicationRequest,
    ) -> Result<AdjudicationResponse, ProviderError>;
}

/// [`AdjudicatorClient`] backed by Gemini's `generateContent` REST endpoint.
///
/// Build it with [`GeminiClient::new`], which validates the config and configures HTTP.
#[derive(Clone)]
pub struct GeminiClient {
    /// Configuration already validated by [`ProviderConfig::parse`].
    config: ProviderConfig,
    /// HTTP client configured in [`GeminiClient::new`].
    http: reqwest::Client,
}

impl GeminiClient {
    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
        let config = config.parse()?;
        install_ring_crypto_provider();
        let root_certificates = mozilla_root_certificates()?;
        let http = reqwest::Client::builder()
            .timeout(config.timeout)
            .user_agent("chromaframe-sdk/0.1")
            .http1_only()
            .tls_certs_only(root_certificates)
            .build()
            .map_err(|_| ProviderError::Transport)?;
        Ok(Self { config, http })
    }

    #[must_use]
    pub fn request_body(request: &AdjudicationRequest) -> Value {
        let mut generation_config = serde_json::Map::new();
        generation_config.insert("responseMimeType".to_string(), json!("application/json"));
        if let Some(schema) = &request.schema {
            generation_config.insert("responseJsonSchema".to_string(), schema.clone());
        }
        if let Some(temperature) = request.temperature {
            generation_config.insert("temperature".to_string(), json!(temperature));
        }
        if let Some(max_tokens) = request.max_output_tokens {
            generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
        }
        json!({
            "contents": [{ "parts": [{ "text": request.prompt }] }],
            "generationConfig": generation_config,
            "safetySettings": default_safety_settings(request.safety_policy.block_medium_and_above),
        })
    }

    fn endpoint(&self) -> String {
        format!(
            "{}/models/{}:generateContent",
            self.config.base_url.trim_end_matches('/'),
            self.config.model
        )
    }
}

#[async_trait]
impl AdjudicatorClient for GeminiClient {
    /// POSTs the request to Gemini and interprets the reply via [`parse_gemini_response`].
    ///
    /// The API key travels in the `x-goog-api-key` header. Send failures become
    /// [`ProviderError::Transport`]. A non-2xx status becomes [`ProviderError::HttpStatus`]
    /// without the body being read. An undecodable body becomes [`ProviderError::JsonParse`].
    async fn adjudicate(
        &self,
        request: AdjudicationRequest,
    ) -> Result<AdjudicationResponse, ProviderError> {
        let body = Self::request_body(&request);
        let sent = self
            .http
            .post(self.endpoint())
            .header("x-goog-api-key", self.config.api_key.expose())
            .json(&body)
            .send()
            .await;
        let Ok(response) = sent else {
            return Err(ProviderError::Transport);
        };

        let status = response.status();
        if !status.is_success() {
            return Err(ProviderError::HttpStatus {
                status: status.as_u16(),
            });
        }

        let Ok(payload) = response.json::<Value>().await else {
            return Err(ProviderError::JsonParse);
        };
        // `request` is owned and no longer borrowed, so its schema can be moved.
        parse_gemini_response(payload, request.schema, self.config.model.clone())
    }
}

/// Interprets a raw Gemini `generateContent` response body.
///
/// Steps, in order:
/// 1. If `promptFeedback.blockReason` is present, returns
///    [`ProviderError::Blocked`]`("prompt_feedback")`.
/// 2. If the first candidate's `finishReason` is `"SAFETY"`, returns
///    [`ProviderError::Blocked`]`("candidate_safety")`.
/// 3. Takes the first text part of the first candidate that is not marked
///    `"thought": true` and parses it as JSON.
/// 4. If `schema` is given, validates the parsed value against it.
///
/// Token usage is read from `usageMetadata.promptTokenCount` (input) and
/// `usageMetadata.candidatesTokenCount` (output). Each count is `None` if absent or not
/// representable as `u32`.
///
/// # Errors
///
/// - [`ProviderError::Blocked`] for the safety blocks above.
/// - [`ProviderError::MissingJsonText`] if no usable text part exists.
/// - [`ProviderError::JsonParse`] if that text is not valid JSON.
/// - [`ProviderError::SchemaValidation`] if `schema` cannot be compiled or the parsed value
///   does not conform to it.
pub fn parse_gemini_response(
    value: Value,
    schema: Option<Value>,
    model: String,
) -> Result<AdjudicationResponse, ProviderError> {
    if value
        .get("promptFeedback")
        .and_then(|feedback| feedback.get("blockReason"))
        .is_some()
    {
        return Err(ProviderError::Blocked("prompt_feedback".to_string()));
    }
    if value
        .pointer("/candidates/0/finishReason")
        .and_then(Value::as_str)
        .is_some_and(|reason| reason == "SAFETY")
    {
        return Err(ProviderError::Blocked("candidate_safety".to_string()));
    }
    let text = value
        .pointer("/candidates/0/content/parts")
        .and_then(Value::as_array)
        .and_then(|parts| {
            parts
                .iter()
                // Thought-summary parts are reasoning, not the JSON answer.
                .filter(|part| part.get("thought").and_then(Value::as_bool) != Some(true))
                .find_map(|part| part.get("text").and_then(Value::as_str))
        })
        .ok_or(ProviderError::MissingJsonText)?;
    let parsed: Value = serde_json::from_str(text).map_err(|_| ProviderError::JsonParse)?;
    if let Some(schema) = schema {
        let validator =
            jsonschema::validator_for(&schema).map_err(|_| ProviderError::SchemaValidation)?;
        if !validator.is_valid(&parsed) {
            return Err(ProviderError::SchemaValidation);
        }
    }
    let usage = value.get("usageMetadata");
    Ok(AdjudicationResponse {
        json: parsed,
        meta: ProviderMeta {
            provider: ProviderKind::Gemini,
            model,
            token_usage: TokenUsage {
                input_tokens: gemini_usage_count(usage, "promptTokenCount"),
                output_tokens: gemini_usage_count(usage, "candidatesTokenCount"),
            },
        },
    })
}

/// Reads one token counter from Gemini's `usageMetadata` object, if present and it fits in `u32`.
fn gemini_usage_count(usage: Option<&Value>, key: &str) -> Option<u32> {
    usage?
        .get(key)?
        .as_u64()
        .and_then(|count| u32::try_from(count).ok())
}

/// Installs rustls's `ring` crypto provider as the process-wide default, but only when no
/// default provider is installed yet.
fn install_ring_crypto_provider() {
    let already_installed = rustls::crypto::CryptoProvider::get_default().is_some();
    if !already_installed {
        // The result is discarded: per the rustls docs, `install_default` only fails when a
        // default provider is already present, which is acceptable here.
        let _ = rustls::crypto::ring::default_provider().install_default();
    }
}

/// Converts the bundled Mozilla root store (`webpki_root_certs::TLS_SERVER_ROOT_CERTS`)
/// into `reqwest` certificates.
///
/// # Errors
///
/// Returns [`ProviderError::Transport`] as soon as any DER certificate fails to load.
fn mozilla_root_certificates() -> Result<Vec<reqwest::Certificate>, ProviderError> {
    let roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
    let mut certificates = Vec::with_capacity(roots.len());
    for der in roots {
        let certificate =
            reqwest::Certificate::from_der(der.as_ref()).map_err(|_| ProviderError::Transport)?;
        certificates.push(certificate);
    }
    Ok(certificates)
}

fn default_safety_settings(block: bool) -> Value {
    let threshold = if block {
        "BLOCK_MEDIUM_AND_ABOVE"
    } else {
        "BLOCK_ONLY_HIGH"
    };
    json!([
        {"category":"HARM_CATEGORY_HARASSMENT","threshold":threshold},
        {"category":"HARM_CATEGORY_HATE_SPEECH","threshold":threshold},
        {"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":threshold},
        {"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":threshold}
    ])
}