Skip to main content

chromaframe_sdk/
provider.rs

1use crate::privacy::SecretString;
2use async_trait::async_trait;
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use serde_json::{Value, json};
6use std::fmt;
7use std::str::FromStr;
8use std::time::Duration;
9use thiserror::Error;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
12pub enum ProviderKind {
13    Gemini,
14}
15
16impl FromStr for ProviderKind {
17    type Err = ProviderError;
18    fn from_str(value: &str) -> Result<Self, Self::Err> {
19        if value.eq_ignore_ascii_case("gemini") {
20            return Ok(Self::Gemini);
21        }
22        Err(ProviderError::InvalidConfig(
23            "unsupported provider".to_string(),
24        ))
25    }
26}
27
28#[derive(Clone)]
29pub struct ProviderConfig {
30    pub kind: ProviderKind,
31    pub base_url: String,
32    pub model: String,
33    pub api_key: SecretString,
34    pub timeout: Duration,
35    pub allow_insecure_test_base_url: bool,
36}
37
38impl fmt::Debug for ProviderConfig {
39    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
40        formatter
41            .debug_struct("ProviderConfig")
42            .field("kind", &self.kind)
43            .field("base_url", &self.base_url)
44            .field("model", &self.model)
45            .field("api_key", &"[REDACTED]")
46            .field("timeout", &self.timeout)
47            .finish()
48    }
49}
50
51impl ProviderConfig {
52    pub fn parse(self) -> Result<Self, ProviderError> {
53        if self.model.trim().is_empty() {
54            return Err(ProviderError::InvalidConfig(
55                "model is required".to_string(),
56            ));
57        }
58        if self.api_key.expose().trim().is_empty() {
59            return Err(ProviderError::InvalidConfig(
60                "api key is required".to_string(),
61            ));
62        }
63        if !self.allow_insecure_test_base_url && !self.base_url.starts_with("https://") {
64            return Err(ProviderError::InvalidConfig(
65                "base url must be HTTPS".to_string(),
66            ));
67        }
68        let timeout = if self.timeout.is_zero() {
69            Duration::from_secs(20)
70        } else {
71            self.timeout
72        };
73        Ok(Self { timeout, ..self })
74    }
75}
76
77#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
78pub struct SafetyPolicy {
79    pub block_medium_and_above: bool,
80}
81impl Default for SafetyPolicy {
82    fn default() -> Self {
83        Self {
84            block_medium_and_above: true,
85        }
86    }
87}
88
89#[derive(Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
90pub struct AdjudicationRequest {
91    pub prompt: String,
92    pub schema: Option<Value>,
93    pub temperature: Option<f32>,
94    pub max_output_tokens: Option<u32>,
95    pub safety_policy: SafetyPolicy,
96}
97
98impl fmt::Debug for AdjudicationRequest {
99    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
100        formatter
101            .debug_struct("AdjudicationRequest")
102            .field("prompt", &"[REDACTED]")
103            .field("schema", &self.schema.as_ref().map(|_| "[PRESENT]"))
104            .field("temperature", &self.temperature)
105            .field("max_output_tokens", &self.max_output_tokens)
106            .field("safety_policy", &self.safety_policy)
107            .finish()
108    }
109}
110
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
112pub struct TokenUsage {
113    pub input_tokens: Option<u32>,
114    pub output_tokens: Option<u32>,
115}
116#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
117pub struct ProviderMeta {
118    pub provider: ProviderKind,
119    pub model: String,
120    pub token_usage: TokenUsage,
121}
122#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
123pub struct AdjudicationResponse {
124    pub json: Value,
125    pub meta: ProviderMeta,
126}
127
128#[derive(Debug, Error)]
129pub enum ProviderError {
130    #[error("invalid provider config: {0}")]
131    InvalidConfig(String),
132    #[error("provider request failed")]
133    Transport,
134    #[error("provider returned non-success status {status}")]
135    HttpStatus { status: u16 },
136    #[error("provider blocked request: {0}")]
137    Blocked(String),
138    #[error("provider response did not contain JSON text")]
139    MissingJsonText,
140    #[error("provider JSON parse failed")]
141    JsonParse,
142    #[error("provider response failed schema validation")]
143    SchemaValidation,
144}
145
146#[async_trait]
147pub trait AdjudicatorClient: Send + Sync {
148    async fn adjudicate(
149        &self,
150        request: AdjudicationRequest,
151    ) -> Result<AdjudicationResponse, ProviderError>;
152}
153
154#[derive(Clone)]
155pub struct GeminiClient {
156    config: ProviderConfig,
157    http: reqwest::Client,
158}
159
160impl GeminiClient {
161    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
162        let config = config.parse()?;
163        install_ring_crypto_provider();
164        let root_certificates = mozilla_root_certificates()?;
165        let http = reqwest::Client::builder()
166            .timeout(config.timeout)
167            .user_agent("chromaframe-sdk/0.1")
168            .http1_only()
169            .tls_certs_only(root_certificates)
170            .build()
171            .map_err(|_| ProviderError::Transport)?;
172        Ok(Self { config, http })
173    }
174
175    #[must_use]
176    pub fn request_body(request: &AdjudicationRequest) -> Value {
177        let mut generation_config = serde_json::Map::new();
178        generation_config.insert("responseMimeType".to_string(), json!("application/json"));
179        if let Some(schema) = &request.schema {
180            generation_config.insert("responseJsonSchema".to_string(), schema.clone());
181        }
182        if let Some(temperature) = request.temperature {
183            generation_config.insert("temperature".to_string(), json!(temperature));
184        }
185        if let Some(max_tokens) = request.max_output_tokens {
186            generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
187        }
188        json!({
189            "contents": [{ "parts": [{ "text": request.prompt }] }],
190            "generationConfig": generation_config,
191            "safetySettings": default_safety_settings(request.safety_policy.block_medium_and_above),
192        })
193    }
194
195    fn endpoint(&self) -> String {
196        format!(
197            "{}/models/{}:generateContent",
198            self.config.base_url.trim_end_matches('/'),
199            self.config.model
200        )
201    }
202}
203
204#[async_trait]
205impl AdjudicatorClient for GeminiClient {
206    async fn adjudicate(
207        &self,
208        request: AdjudicationRequest,
209    ) -> Result<AdjudicationResponse, ProviderError> {
210        let schema = request.schema.clone();
211        let response = self
212            .http
213            .post(self.endpoint())
214            .header("x-goog-api-key", self.config.api_key.expose())
215            .json(&Self::request_body(&request))
216            .send()
217            .await
218            .map_err(|_| ProviderError::Transport)?;
219        if !response.status().is_success() {
220            return Err(ProviderError::HttpStatus {
221                status: response.status().as_u16(),
222            });
223        }
224        let value: Value = response
225            .json()
226            .await
227            .map_err(|_| ProviderError::JsonParse)?;
228        parse_gemini_response(value, schema, self.config.model.clone())
229    }
230}
231
232pub fn parse_gemini_response(
233    value: Value,
234    schema: Option<Value>,
235    model: String,
236) -> Result<AdjudicationResponse, ProviderError> {
237    if value
238        .get("promptFeedback")
239        .and_then(|feedback| feedback.get("blockReason"))
240        .is_some()
241    {
242        return Err(ProviderError::Blocked("prompt_feedback".to_string()));
243    }
244    if value
245        .pointer("/candidates/0/finishReason")
246        .and_then(Value::as_str)
247        .is_some_and(|reason| reason == "SAFETY")
248    {
249        return Err(ProviderError::Blocked("candidate_safety".to_string()));
250    }
251    let text = value
252        .pointer("/candidates/0/content/parts")
253        .and_then(Value::as_array)
254        .and_then(|parts| {
255            parts
256                .iter()
257                .filter_map(|part| part.get("text").and_then(Value::as_str))
258                .next()
259        })
260        .ok_or(ProviderError::MissingJsonText)?;
261    let parsed: Value = serde_json::from_str(text).map_err(|_| ProviderError::JsonParse)?;
262    if let Some(schema) = schema {
263        let validator =
264            jsonschema::validator_for(&schema).map_err(|_| ProviderError::SchemaValidation)?;
265        if !validator.is_valid(&parsed) {
266            return Err(ProviderError::SchemaValidation);
267        }
268    }
269    Ok(AdjudicationResponse {
270        json: parsed,
271        meta: ProviderMeta {
272            provider: ProviderKind::Gemini,
273            model,
274            token_usage: TokenUsage {
275                input_tokens: None,
276                output_tokens: None,
277            },
278        },
279    })
280}
281
282fn install_ring_crypto_provider() {
283    if rustls::crypto::CryptoProvider::get_default().is_some() {
284        return;
285    }
286
287    let _ = rustls::crypto::ring::default_provider().install_default();
288}
289
290fn mozilla_root_certificates() -> Result<Vec<reqwest::Certificate>, ProviderError> {
291    webpki_root_certs::TLS_SERVER_ROOT_CERTS
292        .iter()
293        .map(|cert| reqwest::Certificate::from_der(cert.as_ref()))
294        .collect::<Result<Vec<_>, _>>()
295        .map_err(|_| ProviderError::Transport)
296}
297
298fn default_safety_settings(block: bool) -> Value {
299    let threshold = if block {
300        "BLOCK_MEDIUM_AND_ABOVE"
301    } else {
302        "BLOCK_ONLY_HIGH"
303    };
304    json!([
305        {"category":"HARM_CATEGORY_HARASSMENT","threshold":threshold},
306        {"category":"HARM_CATEGORY_HATE_SPEECH","threshold":threshold},
307        {"category":"HARM_CATEGORY_SEXUALLY_EXPLICIT","threshold":threshold},
308        {"category":"HARM_CATEGORY_DANGEROUS_CONTENT","threshold":threshold}
309    ])
310}