Skip to main content

highflame_shield/
client.rs

1//! [`ShieldClient`] — async HTTP client for highflame-shield.
2
3use std::{sync::Arc, time::Duration};
4
5use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION};
6use serde_json::json;
7use tokio::sync::Mutex;
8
9use crate::{
10    error::ShieldError,
11    stream::{parse_sse_response, ShieldStreamEvent},
12    types::{
13        HealthResponse, ListDetectorsResponse, ShieldRequest, ShieldResponse, ToolContext,
14        TokenResponse,
15    },
16};
17
18// ── Constants ─────────────────────────────────────────────────────────────────
19
/// Default guard service endpoint (Highflame SaaS).
const SAAS_BASE_URL: &str = "https://shield.api.highflame.ai";
/// Default endpoint that exchanges an `hf_sk` service key for a JWT.
const SAAS_TOKEN_URL: &str = "https://studio.api.highflame.ai/api/cli-auth/token";

/// Per-request timeout applied when the options do not set one.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Retry budget applied when the options do not set one.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Safety margin: a cached token is treated as expired this long before the
/// expiry reported by the token endpoint.
const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(60);
/// Response statuses that are considered transient and therefore retried.
const RETRY_STATUS_CODES: &[u16] = &[429, 500, 502, 503, 504];
29
30// ── Token cache ───────────────────────────────────────────────────────────────
31
32#[derive(Debug, Clone)]
33struct CachedToken {
34    access_token: String,
35    /// Uses `tokio::time::Instant` so `tokio::time::pause()` / `advance()` work
36    /// in tests.
37    expires_at: tokio::time::Instant,
38    account_id: String,
39    project_id: String,
40    #[allow(dead_code)]
41    gateway_id: String,
42}
43
44// ── Options ───────────────────────────────────────────────────────────────────
45
/// Configuration for [`ShieldClient`].
///
/// ```no_run
/// use highflame_shield::ShieldClientOptions;
/// use std::time::Duration;
///
/// let opts = ShieldClientOptions::new("hf_sk_my_key")
///     .base_url("https://shield.self-hosted.example.com")
///     .timeout(Duration::from_secs(10))
///     .max_retries(1);
/// ```
///
/// The `Debug` representation redacts the API key, so options can be logged
/// without leaking credentials.
#[derive(Clone)]
pub struct ShieldClientOptions {
    pub(crate) api_key: String,
    pub(crate) base_url: Option<String>,
    pub(crate) token_url: Option<String>,
    pub(crate) timeout: Option<Duration>,
    pub(crate) max_retries: Option<u32>,
    pub(crate) account_id: Option<String>,
    pub(crate) project_id: Option<String>,
}

impl std::fmt::Debug for ShieldClientOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ShieldClientOptions")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("token_url", &self.token_url)
            .field("timeout", &self.timeout)
            .field("max_retries", &self.max_retries)
            .field("account_id", &self.account_id)
            .field("project_id", &self.project_id)
            .finish()
    }
}
67
68impl ShieldClientOptions {
69    /// Create options from just an API key; all other fields take defaults.
70    pub fn new(api_key: impl Into<String>) -> Self {
71        Self {
72            api_key: api_key.into(),
73            base_url: None,
74            token_url: None,
75            timeout: None,
76            max_retries: None,
77            account_id: None,
78            project_id: None,
79        }
80    }
81
82    /// Override the guard service base URL (default: Highflame SaaS endpoint).
83    pub fn base_url(mut self, url: impl Into<String>) -> Self {
84        self.base_url = Some(url.into());
85        self
86    }
87
88    /// Override the token exchange URL (default: Highflame SaaS endpoint).
89    /// Only used when `api_key` starts with `hf_sk`.
90    pub fn token_url(mut self, url: impl Into<String>) -> Self {
91        self.token_url = Some(url.into());
92        self
93    }
94
95    /// Request timeout (default: 30 s).
96    pub fn timeout(mut self, t: Duration) -> Self {
97        self.timeout = Some(t);
98        self
99    }
100
101    /// Number of retries on 429 / 5xx (default: 2).
102    pub fn max_retries(mut self, n: u32) -> Self {
103        self.max_retries = Some(n);
104        self
105    }
106
107    /// Override `X-Account-ID` header; takes precedence over the token exchange
108    /// value.
109    pub fn account_id(mut self, id: impl Into<String>) -> Self {
110        self.account_id = Some(id.into());
111        self
112    }
113
114    /// Override `X-Project-ID` header; takes precedence over the token exchange
115    /// value.
116    pub fn project_id(mut self, id: impl Into<String>) -> Self {
117        self.project_id = Some(id.into());
118        self
119    }
120}
121
122// ── Inner shared state ────────────────────────────────────────────────────────
123
124struct Inner {
125    base_url: String,
126    api_key: String,
127    /// `Some` only when the service-key exchange flow is active.
128    token_url: Option<String>,
129    timeout: Duration,
130    max_retries: u32,
131    override_account_id: Option<String>,
132    override_project_id: Option<String>,
133    /// Pre-built headers for the direct-JWT path (no token exchange).
134    static_headers: HeaderMap,
135    /// Shared, connection-pooled HTTP client.
136    http: reqwest::Client,
137    /// Token cache for the service-key flow.
138    ///
139    /// `tokio::sync::Mutex` is held across the `await` in `exchange_token` so
140    /// that concurrent callers block behind the mutex rather than all issuing
141    /// their own exchange request.  Once the first caller refreshes the token
142    /// and releases the lock, all queued callers find a fresh token on the
143    /// fast-path check and return immediately.
144    token_cache: Mutex<Option<CachedToken>>,
145}
146
147// ── Client ────────────────────────────────────────────────────────────────────
148
149/// Async HTTP client for highflame-shield.
150///
151/// Cheap to clone — all clones share the underlying connection pool and token
152/// cache via [`Arc`].
153///
154/// # Service key flow
155///
156/// Keys starting with `hf_sk` are automatically exchanged for a short-lived
157/// RS256 JWT via `POST <token_url>`.  The JWT is cached and auto-refreshed 60 s
158/// before expiry.  Concurrent callers share a single [`tokio::sync::Mutex`];
159/// only one exchange is ever in-flight at a time.
160///
161/// # Direct JWT flow
162///
163/// Any other `api_key` value is sent directly as `Authorization: Bearer <key>`,
164/// useful when you already hold a short-lived token.
165#[derive(Clone)]
166pub struct ShieldClient {
167    inner: Arc<Inner>,
168}
169
170impl ShieldClient {
171    /// Create a new [`ShieldClient`].
172    pub fn new(options: ShieldClientOptions) -> Self {
173        let base_url = options
174            .base_url
175            .unwrap_or_else(|| SAAS_BASE_URL.to_string());
176        let base_url = base_url.trim_end_matches('/').to_string();
177
178        let token_url = if options.api_key.starts_with("hf_sk") {
179            Some(
180                options
181                    .token_url
182                    .unwrap_or_else(|| SAAS_TOKEN_URL.to_string()),
183            )
184        } else {
185            None
186        };
187
188        let timeout = options.timeout.unwrap_or(DEFAULT_TIMEOUT);
189        let max_retries = options.max_retries.unwrap_or(DEFAULT_MAX_RETRIES);
190
191        // Pre-build headers for the direct-JWT (non-exchange) path.
192        let mut static_headers = HeaderMap::new();
193        static_headers.insert(
194            AUTHORIZATION,
195            HeaderValue::from_str(&format!("Bearer {}", options.api_key))
196                .expect("api_key contains invalid header characters"),
197        );
198        static_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
199        if let Some(ref id) = options.account_id {
200            static_headers.insert(
201                reqwest::header::HeaderName::from_static("x-account-id"),
202                HeaderValue::from_str(id).expect("account_id contains invalid header characters"),
203            );
204        }
205        if let Some(ref id) = options.project_id {
206            static_headers.insert(
207                reqwest::header::HeaderName::from_static("x-project-id"),
208                HeaderValue::from_str(id).expect("project_id contains invalid header characters"),
209            );
210        }
211
212        let http = reqwest::Client::builder()
213            .build()
214            .expect("failed to build reqwest client");
215
216        Self {
217            inner: Arc::new(Inner {
218                base_url,
219                api_key: options.api_key,
220                token_url,
221                timeout,
222                max_retries,
223                override_account_id: options.account_id,
224                override_project_id: options.project_id,
225                static_headers,
226                http,
227                token_cache: Mutex::new(None),
228            }),
229        }
230    }
231
232    // ── Public accessors ──────────────────────────────────────────────────────
233
234    /// Account ID from the last token exchange, or the constructor override.
235    /// Empty when using the direct-JWT flow or before the first exchange.
236    pub async fn account_id(&self) -> String {
237        if let Some(ref id) = self.inner.override_account_id {
238            return id.clone();
239        }
240        self.inner
241            .token_cache
242            .lock()
243            .await
244            .as_ref()
245            .map(|t| t.account_id.clone())
246            .unwrap_or_default()
247    }
248
249    /// Project ID from the last token exchange, or the constructor override.
250    /// Empty when using the direct-JWT flow or before the first exchange.
251    pub async fn project_id(&self) -> String {
252        if let Some(ref id) = self.inner.override_project_id {
253            return id.clone();
254        }
255        self.inner
256            .token_cache
257            .lock()
258            .await
259            .as_ref()
260            .map(|t| t.project_id.clone())
261            .unwrap_or_default()
262    }
263
264    // ── Token exchange ────────────────────────────────────────────────────────
265
266    fn build_token_headers(&self, token: &CachedToken) -> Result<HeaderMap, ShieldError> {
267        let mut headers = HeaderMap::new();
268        headers.insert(
269            AUTHORIZATION,
270            HeaderValue::from_str(&format!("Bearer {}", token.access_token))
271                .map_err(|e| ShieldError::Connection(e.to_string()))?,
272        );
273        headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
274
275        let account_id = self
276            .inner
277            .override_account_id
278            .as_deref()
279            .or_else(|| (!token.account_id.is_empty()).then_some(token.account_id.as_str()));
280        let project_id = self
281            .inner
282            .override_project_id
283            .as_deref()
284            .or_else(|| (!token.project_id.is_empty()).then_some(token.project_id.as_str()));
285
286        if let Some(id) = account_id {
287            headers.insert(
288                reqwest::header::HeaderName::from_static("x-account-id"),
289                HeaderValue::from_str(id).map_err(|e| ShieldError::Connection(e.to_string()))?,
290            );
291        }
292        if let Some(id) = project_id {
293            headers.insert(
294                reqwest::header::HeaderName::from_static("x-project-id"),
295                HeaderValue::from_str(id).map_err(|e| ShieldError::Connection(e.to_string()))?,
296            );
297        }
298        Ok(headers)
299    }
300
301    async fn exchange_token(&self) -> Result<CachedToken, ShieldError> {
302        let url = self
303            .inner
304            .token_url
305            .as_ref()
306            .expect("exchange_token called without a token_url");
307
308        let resp = self
309            .inner
310            .http
311            .post(url)
312            .json(&json!({ "grant_type": "api_key", "api_key": self.inner.api_key }))
313            .timeout(self.inner.timeout)
314            .send()
315            .await
316            .map_err(|e| ShieldError::Connection(e.to_string()))?;
317
318        if !resp.status().is_success() {
319            return Err(parse_api_error(resp).await);
320        }
321
322        let tok: TokenResponse = resp.json().await?;
323        let ttl = Duration::from_secs(tok.expires_in).saturating_sub(TOKEN_REFRESH_BUFFER);
324        Ok(CachedToken {
325            access_token: tok.access_token,
326            expires_at: tokio::time::Instant::now() + ttl,
327            account_id: tok.account_id,
328            project_id: tok.project_id,
329            gateway_id: tok.gateway_id,
330        })
331    }
332
333    /// Return auth headers, exchanging or refreshing the JWT as needed.
334    ///
335    /// The `tokio::sync::Mutex` on the token cache is held across the `await`
336    /// in [`exchange_token`] so that:
337    /// - Only one exchange is ever in-flight at a time.
338    /// - Concurrent callers that see an expired token block behind the mutex;
339    ///   once the first completes, the rest find a fresh token on the fast-path
340    ///   check and return immediately without re-exchanging.
341    ///
342    /// Exposed publicly so callers can forward Shield credentials to their own
343    /// HTTP clients (e.g. a reverse proxy).
344    pub async fn get_auth_headers(&self) -> Result<HeaderMap, ShieldError> {
345        if self.inner.token_url.is_none() {
346            return Ok(self.inner.static_headers.clone());
347        }
348
349        let mut cache = self.inner.token_cache.lock().await;
350
351        // Fast path: token is still valid.
352        if let Some(ref token) = *cache {
353            if tokio::time::Instant::now() < token.expires_at {
354                return self.build_token_headers(token);
355            }
356        }
357
358        // Token absent or expired — exchange while holding the lock.
359        let new_token = self.exchange_token().await?;
360        let headers = self.build_token_headers(&new_token)?;
361        *cache = Some(new_token);
362        Ok(headers)
363    }
364
365    // ── Request helpers ───────────────────────────────────────────────────────
366
367    async fn send_with_retry(
368        &self,
369        method: reqwest::Method,
370        path: &str,
371        json_body: Option<&serde_json::Value>,
372    ) -> Result<reqwest::Response, ShieldError> {
373        let url = format!("{}{}", self.inner.base_url, path);
374        let mut last_err =
375            ShieldError::Connection("request failed after all retries".to_string());
376
377        for attempt in 0..=self.inner.max_retries {
378            if attempt > 0 {
379                let delay = Duration::from_millis(1_000 * 2u64.pow(attempt - 1));
380                tokio::time::sleep(delay).await;
381            }
382
383            let headers = self.get_auth_headers().await?;
384            let mut req = self
385                .inner
386                .http
387                .request(method.clone(), &url)
388                .headers(headers)
389                .timeout(self.inner.timeout);
390
391            if let Some(body) = json_body {
392                req = req.json(body);
393            }
394
395            let resp = match req.send().await {
396                Ok(r) => r,
397                Err(e) if e.is_timeout() => {
398                    last_err = ShieldError::Connection(format!("request timed out: {e}"));
399                    continue;
400                }
401                Err(e) => return Err(ShieldError::Connection(e.to_string())),
402            };
403
404            if !RETRY_STATUS_CODES.contains(&resp.status().as_u16()) {
405                return Ok(resp);
406            }
407
408            last_err = parse_api_error(resp).await;
409        }
410
411        Err(last_err)
412    }
413
414    // ── Public API ────────────────────────────────────────────────────────────
415
416    /// Evaluate content against guard policies.
417    ///
418    /// `POST /v1/guard` — full detect + Cedar evaluate.
419    pub async fn guard(&self, request: &ShieldRequest) -> Result<ShieldResponse, ShieldError> {
420        let body = serde_json::to_value(request)?;
421        let resp = self
422            .send_with_retry(reqwest::Method::POST, "/v1/guard", Some(&body))
423            .await?;
424        if !resp.status().is_success() {
425            return Err(parse_api_error(resp).await);
426        }
427        Ok(resp.json().await?)
428    }
429
430    /// Shorthand: evaluate a user prompt.
431    ///
432    /// Equivalent to [`guard`] with `content_type: "prompt"` and
433    /// `action: "process_prompt"`.
434    pub async fn guard_prompt(
435        &self,
436        content: &str,
437        mode: Option<&str>,
438        session_id: Option<&str>,
439    ) -> Result<ShieldResponse, ShieldError> {
440        self.guard(&ShieldRequest {
441            content: content.to_string(),
442            content_type: "prompt".to_string(),
443            action: "process_prompt".to_string(),
444            mode: mode.map(str::to_string),
445            session_id: session_id.map(str::to_string),
446            ..Default::default()
447        })
448        .await
449    }
450
451    /// Shorthand: evaluate a tool call.
452    ///
453    /// Equivalent to [`guard`] with `content_type: "tool_call"` and
454    /// `action: "call_tool"`.
455    pub async fn guard_tool_call(
456        &self,
457        tool_name: &str,
458        arguments: Option<std::collections::HashMap<String, serde_json::Value>>,
459        mode: Option<&str>,
460        session_id: Option<&str>,
461    ) -> Result<ShieldResponse, ShieldError> {
462        self.guard(&ShieldRequest {
463            content: format!("Tool call: {tool_name}"),
464            content_type: "tool_call".to_string(),
465            action: "call_tool".to_string(),
466            mode: mode.map(str::to_string),
467            session_id: session_id.map(str::to_string),
468            tool: Some(ToolContext {
469                name: tool_name.to_string(),
470                arguments,
471                ..Default::default()
472            }),
473            ..Default::default()
474        })
475        .await
476    }
477
478    /// Evaluate content with SSE streaming.
479    ///
480    /// `POST /v1/guard/stream` — returns a `Stream` of [`ShieldStreamEvent`]s.
481    /// The first `Err` in the stream terminates it.
482    pub async fn stream(
483        &self,
484        request: &ShieldRequest,
485    ) -> Result<
486        impl futures_util::Stream<Item = Result<ShieldStreamEvent, ShieldError>>,
487        ShieldError,
488    > {
489        let url = format!("{}/v1/guard/stream", self.inner.base_url);
490        let mut headers = self.get_auth_headers().await?;
491        headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
492
493        let resp = self
494            .inner
495            .http
496            .post(&url)
497            .headers(headers)
498            .json(request)
499            .timeout(self.inner.timeout)
500            .send()
501            .await
502            .map_err(|e| ShieldError::Connection(e.to_string()))?;
503
504        if !resp.status().is_success() {
505            return Err(parse_api_error(resp).await);
506        }
507
508        Ok(parse_sse_response(resp))
509    }
510
511    /// Get detailed service health.
512    ///
513    /// `GET /v1/health` — returns detector statuses and evaluator state.
514    pub async fn health(&self) -> Result<HealthResponse, ShieldError> {
515        let resp = self
516            .send_with_retry(reqwest::Method::GET, "/v1/health", None)
517            .await?;
518        if !resp.status().is_success() {
519            return Err(parse_api_error(resp).await);
520        }
521        Ok(resp.json().await?)
522    }
523
524    /// List available detectors.
525    ///
526    /// `GET /v1/detectors` — returns detector names, tiers, and health.
527    pub async fn list_detectors(&self) -> Result<ListDetectorsResponse, ShieldError> {
528        let resp = self
529            .send_with_retry(reqwest::Method::GET, "/v1/detectors", None)
530            .await?;
531        if !resp.status().is_success() {
532            return Err(parse_api_error(resp).await);
533        }
534        Ok(resp.json().await?)
535    }
536}
537
538// ── Helpers ───────────────────────────────────────────────────────────────────
539
540/// Parse an RFC 9457 Problem Details body from a non-2xx response.
541async fn parse_api_error(resp: reqwest::Response) -> ShieldError {
542    let status = resp.status().as_u16();
543    match resp.json::<serde_json::Value>().await {
544        Ok(body) => ShieldError::Api {
545            status,
546            title: body["title"].as_str().unwrap_or("Error").to_string(),
547            detail: body["detail"].as_str().unwrap_or("").to_string(),
548        },
549        Err(_) => ShieldError::Api {
550            status,
551            title: "Error".to_string(),
552            detail: String::new(),
553        },
554    }
555}