1use std::{sync::Arc, time::Duration};
4
5use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION};
6use serde_json::json;
7use tokio::sync::Mutex;
8
9use crate::{
10 error::ShieldError,
11 stream::{parse_sse_response, ShieldStreamEvent},
12 types::{
13 HealthResponse, ListDetectorsResponse, ShieldRequest, ShieldResponse, ToolContext,
14 TokenResponse,
15 },
16};
17
/// Per-request timeout used when `ShieldClientOptions::timeout` is not set.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Number of retries (on top of the initial attempt) used when
/// `ShieldClientOptions::max_retries` is not set.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Margin subtracted from a token's reported lifetime when computing the
/// instant at which the cached token is considered stale.
const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(60);
/// HTTP status codes that make `send_with_retry` retry the request.
const RETRY_STATUS_CODES: &[u16] = &[429, 500, 502, 503, 504];

/// Default base URL for API requests when no `base_url` is configured.
const SAAS_BASE_URL: &str = "https://shield.api.highflame.ai";
/// Default token-exchange endpoint, used for `hf_sk`-prefixed API keys when no
/// `token_url` is configured.
const SAAS_TOKEN_URL: &str = "https://studio.api.highflame.ai/api/cli-auth/token";
29
/// Result of a successful token exchange, kept in the client's token cache.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Bearer token placed in the `Authorization` header of API requests.
    access_token: String,
    /// Instant from which the token is treated as stale. Computed at exchange
    /// time as the reported lifetime minus `TOKEN_REFRESH_BUFFER`.
    expires_at: tokio::time::Instant,
    /// Account id returned with the token; used for `x-account-id` unless an
    /// override is configured or the value is empty.
    account_id: String,
    /// Project id returned with the token; used for `x-project-id` unless an
    /// override is configured or the value is empty.
    project_id: String,
    // Copied from the token response but not read anywhere in the visible
    // code, hence the lint suppression.
    #[allow(dead_code)]
    gateway_id: String,
}
43
/// Builder-style configuration consumed by `ShieldClient::new`.
///
/// Start from [`ShieldClientOptions::new`] and chain the setters; each setter
/// takes the options by value and returns the updated options.
#[derive(Debug, Clone)]
pub struct ShieldClientOptions {
    /// Credential used for authentication.
    pub(crate) api_key: String,
    /// Base URL of the API; `None` means "use the built-in default".
    pub(crate) base_url: Option<String>,
    /// Token-exchange endpoint; `None` means "use the built-in default".
    pub(crate) token_url: Option<String>,
    /// Per-request timeout; `None` means "use the built-in default".
    pub(crate) timeout: Option<Duration>,
    /// Retry count; `None` means "use the built-in default".
    pub(crate) max_retries: Option<u32>,
    /// Explicit account id; when set it takes precedence over the id carried
    /// by an exchanged token.
    pub(crate) account_id: Option<String>,
    /// Explicit project id; when set it takes precedence over the id carried
    /// by an exchanged token.
    pub(crate) project_id: Option<String>,
}

impl ShieldClientOptions {
    /// Creates options holding only the API key; every other setting is unset.
    pub fn new(api_key: impl Into<String>) -> Self {
        let api_key = api_key.into();
        Self {
            api_key,
            base_url: None,
            token_url: None,
            timeout: None,
            max_retries: None,
            account_id: None,
            project_id: None,
        }
    }

    /// Sets the API base URL.
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: Some(url.into()),
            ..self
        }
    }

    /// Sets the token-exchange endpoint.
    pub fn token_url(self, url: impl Into<String>) -> Self {
        Self {
            token_url: Some(url.into()),
            ..self
        }
    }

    /// Sets the per-request timeout.
    pub fn timeout(self, t: Duration) -> Self {
        Self {
            timeout: Some(t),
            ..self
        }
    }

    /// Sets how many times a request is retried after the first attempt.
    pub fn max_retries(self, n: u32) -> Self {
        Self {
            max_retries: Some(n),
            ..self
        }
    }

    /// Sets an explicit account id.
    pub fn account_id(self, id: impl Into<String>) -> Self {
        Self {
            account_id: Some(id.into()),
            ..self
        }
    }

    /// Sets an explicit project id.
    pub fn project_id(self, id: impl Into<String>) -> Self {
        Self {
            project_id: Some(id.into()),
            ..self
        }
    }
}
121
/// State shared by every clone of a `ShieldClient`.
struct Inner {
    /// API base URL with trailing `/` characters removed.
    base_url: String,
    /// API key supplied by the caller.
    api_key: String,
    /// Token-exchange endpoint. `Some` only when the API key starts with
    /// `hf_sk`; `None` means requests authenticate with `static_headers`.
    token_url: Option<String>,
    /// Timeout applied to each outgoing request.
    timeout: Duration,
    /// Number of retries after the first attempt in `send_with_retry`.
    max_retries: u32,
    /// Caller-supplied account id; takes precedence over the token's value.
    override_account_id: Option<String>,
    /// Caller-supplied project id; takes precedence over the token's value.
    override_project_id: Option<String>,
    /// Headers built once at construction and returned as-is when no token
    /// exchange is configured.
    static_headers: HeaderMap,
    /// HTTP client used for all requests.
    http: reqwest::Client,
    /// Most recently exchanged token, if any, behind an async mutex.
    token_cache: Mutex<Option<CachedToken>>,
}
146
/// Async client for the Shield HTTP API.
///
/// Cloning is cheap: clones share the same `Inner` (HTTP client, settings and
/// token cache) through an `Arc`.
#[derive(Clone)]
pub struct ShieldClient {
    /// Shared client state.
    inner: Arc<Inner>,
}
169
170impl ShieldClient {
171 pub fn new(options: ShieldClientOptions) -> Self {
173 let base_url = options
174 .base_url
175 .unwrap_or_else(|| SAAS_BASE_URL.to_string());
176 let base_url = base_url.trim_end_matches('/').to_string();
177
178 let token_url = if options.api_key.starts_with("hf_sk") {
179 Some(
180 options
181 .token_url
182 .unwrap_or_else(|| SAAS_TOKEN_URL.to_string()),
183 )
184 } else {
185 None
186 };
187
188 let timeout = options.timeout.unwrap_or(DEFAULT_TIMEOUT);
189 let max_retries = options.max_retries.unwrap_or(DEFAULT_MAX_RETRIES);
190
191 let mut static_headers = HeaderMap::new();
193 static_headers.insert(
194 AUTHORIZATION,
195 HeaderValue::from_str(&format!("Bearer {}", options.api_key))
196 .expect("api_key contains invalid header characters"),
197 );
198 static_headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
199 if let Some(ref id) = options.account_id {
200 static_headers.insert(
201 reqwest::header::HeaderName::from_static("x-account-id"),
202 HeaderValue::from_str(id).expect("account_id contains invalid header characters"),
203 );
204 }
205 if let Some(ref id) = options.project_id {
206 static_headers.insert(
207 reqwest::header::HeaderName::from_static("x-project-id"),
208 HeaderValue::from_str(id).expect("project_id contains invalid header characters"),
209 );
210 }
211
212 let http = reqwest::Client::builder()
213 .build()
214 .expect("failed to build reqwest client");
215
216 Self {
217 inner: Arc::new(Inner {
218 base_url,
219 api_key: options.api_key,
220 token_url,
221 timeout,
222 max_retries,
223 override_account_id: options.account_id,
224 override_project_id: options.project_id,
225 static_headers,
226 http,
227 token_cache: Mutex::new(None),
228 }),
229 }
230 }
231
232 pub async fn account_id(&self) -> String {
237 if let Some(ref id) = self.inner.override_account_id {
238 return id.clone();
239 }
240 self.inner
241 .token_cache
242 .lock()
243 .await
244 .as_ref()
245 .map(|t| t.account_id.clone())
246 .unwrap_or_default()
247 }
248
249 pub async fn project_id(&self) -> String {
252 if let Some(ref id) = self.inner.override_project_id {
253 return id.clone();
254 }
255 self.inner
256 .token_cache
257 .lock()
258 .await
259 .as_ref()
260 .map(|t| t.project_id.clone())
261 .unwrap_or_default()
262 }
263
264 fn build_token_headers(&self, token: &CachedToken) -> Result<HeaderMap, ShieldError> {
267 let mut headers = HeaderMap::new();
268 headers.insert(
269 AUTHORIZATION,
270 HeaderValue::from_str(&format!("Bearer {}", token.access_token))
271 .map_err(|e| ShieldError::Connection(e.to_string()))?,
272 );
273 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
274
275 let account_id = self
276 .inner
277 .override_account_id
278 .as_deref()
279 .or_else(|| (!token.account_id.is_empty()).then_some(token.account_id.as_str()));
280 let project_id = self
281 .inner
282 .override_project_id
283 .as_deref()
284 .or_else(|| (!token.project_id.is_empty()).then_some(token.project_id.as_str()));
285
286 if let Some(id) = account_id {
287 headers.insert(
288 reqwest::header::HeaderName::from_static("x-account-id"),
289 HeaderValue::from_str(id).map_err(|e| ShieldError::Connection(e.to_string()))?,
290 );
291 }
292 if let Some(id) = project_id {
293 headers.insert(
294 reqwest::header::HeaderName::from_static("x-project-id"),
295 HeaderValue::from_str(id).map_err(|e| ShieldError::Connection(e.to_string()))?,
296 );
297 }
298 Ok(headers)
299 }
300
301 async fn exchange_token(&self) -> Result<CachedToken, ShieldError> {
302 let url = self
303 .inner
304 .token_url
305 .as_ref()
306 .expect("exchange_token called without a token_url");
307
308 let resp = self
309 .inner
310 .http
311 .post(url)
312 .json(&json!({ "grant_type": "api_key", "api_key": self.inner.api_key }))
313 .timeout(self.inner.timeout)
314 .send()
315 .await
316 .map_err(|e| ShieldError::Connection(e.to_string()))?;
317
318 if !resp.status().is_success() {
319 return Err(parse_api_error(resp).await);
320 }
321
322 let tok: TokenResponse = resp.json().await?;
323 let ttl = Duration::from_secs(tok.expires_in).saturating_sub(TOKEN_REFRESH_BUFFER);
324 Ok(CachedToken {
325 access_token: tok.access_token,
326 expires_at: tokio::time::Instant::now() + ttl,
327 account_id: tok.account_id,
328 project_id: tok.project_id,
329 gateway_id: tok.gateway_id,
330 })
331 }
332
333 pub async fn get_auth_headers(&self) -> Result<HeaderMap, ShieldError> {
345 if self.inner.token_url.is_none() {
346 return Ok(self.inner.static_headers.clone());
347 }
348
349 let mut cache = self.inner.token_cache.lock().await;
350
351 if let Some(ref token) = *cache {
353 if tokio::time::Instant::now() < token.expires_at {
354 return self.build_token_headers(token);
355 }
356 }
357
358 let new_token = self.exchange_token().await?;
360 let headers = self.build_token_headers(&new_token)?;
361 *cache = Some(new_token);
362 Ok(headers)
363 }
364
365 async fn send_with_retry(
368 &self,
369 method: reqwest::Method,
370 path: &str,
371 json_body: Option<&serde_json::Value>,
372 ) -> Result<reqwest::Response, ShieldError> {
373 let url = format!("{}{}", self.inner.base_url, path);
374 let mut last_err =
375 ShieldError::Connection("request failed after all retries".to_string());
376
377 for attempt in 0..=self.inner.max_retries {
378 if attempt > 0 {
379 let delay = Duration::from_millis(1_000 * 2u64.pow(attempt - 1));
380 tokio::time::sleep(delay).await;
381 }
382
383 let headers = self.get_auth_headers().await?;
384 let mut req = self
385 .inner
386 .http
387 .request(method.clone(), &url)
388 .headers(headers)
389 .timeout(self.inner.timeout);
390
391 if let Some(body) = json_body {
392 req = req.json(body);
393 }
394
395 let resp = match req.send().await {
396 Ok(r) => r,
397 Err(e) if e.is_timeout() => {
398 last_err = ShieldError::Connection(format!("request timed out: {e}"));
399 continue;
400 }
401 Err(e) => return Err(ShieldError::Connection(e.to_string())),
402 };
403
404 if !RETRY_STATUS_CODES.contains(&resp.status().as_u16()) {
405 return Ok(resp);
406 }
407
408 last_err = parse_api_error(resp).await;
409 }
410
411 Err(last_err)
412 }
413
414 pub async fn guard(&self, request: &ShieldRequest) -> Result<ShieldResponse, ShieldError> {
420 let body = serde_json::to_value(request)?;
421 let resp = self
422 .send_with_retry(reqwest::Method::POST, "/v1/guard", Some(&body))
423 .await?;
424 if !resp.status().is_success() {
425 return Err(parse_api_error(resp).await);
426 }
427 Ok(resp.json().await?)
428 }
429
430 pub async fn guard_prompt(
435 &self,
436 content: &str,
437 mode: Option<&str>,
438 session_id: Option<&str>,
439 ) -> Result<ShieldResponse, ShieldError> {
440 self.guard(&ShieldRequest {
441 content: content.to_string(),
442 content_type: "prompt".to_string(),
443 action: "process_prompt".to_string(),
444 mode: mode.map(str::to_string),
445 session_id: session_id.map(str::to_string),
446 ..Default::default()
447 })
448 .await
449 }
450
451 pub async fn guard_tool_call(
456 &self,
457 tool_name: &str,
458 arguments: Option<std::collections::HashMap<String, serde_json::Value>>,
459 mode: Option<&str>,
460 session_id: Option<&str>,
461 ) -> Result<ShieldResponse, ShieldError> {
462 self.guard(&ShieldRequest {
463 content: format!("Tool call: {tool_name}"),
464 content_type: "tool_call".to_string(),
465 action: "call_tool".to_string(),
466 mode: mode.map(str::to_string),
467 session_id: session_id.map(str::to_string),
468 tool: Some(ToolContext {
469 name: tool_name.to_string(),
470 arguments,
471 ..Default::default()
472 }),
473 ..Default::default()
474 })
475 .await
476 }
477
478 pub async fn stream(
483 &self,
484 request: &ShieldRequest,
485 ) -> Result<
486 impl futures_util::Stream<Item = Result<ShieldStreamEvent, ShieldError>>,
487 ShieldError,
488 > {
489 let url = format!("{}/v1/guard/stream", self.inner.base_url);
490 let mut headers = self.get_auth_headers().await?;
491 headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));
492
493 let resp = self
494 .inner
495 .http
496 .post(&url)
497 .headers(headers)
498 .json(request)
499 .timeout(self.inner.timeout)
500 .send()
501 .await
502 .map_err(|e| ShieldError::Connection(e.to_string()))?;
503
504 if !resp.status().is_success() {
505 return Err(parse_api_error(resp).await);
506 }
507
508 Ok(parse_sse_response(resp))
509 }
510
511 pub async fn health(&self) -> Result<HealthResponse, ShieldError> {
515 let resp = self
516 .send_with_retry(reqwest::Method::GET, "/v1/health", None)
517 .await?;
518 if !resp.status().is_success() {
519 return Err(parse_api_error(resp).await);
520 }
521 Ok(resp.json().await?)
522 }
523
524 pub async fn list_detectors(&self) -> Result<ListDetectorsResponse, ShieldError> {
528 let resp = self
529 .send_with_retry(reqwest::Method::GET, "/v1/detectors", None)
530 .await?;
531 if !resp.status().is_success() {
532 return Err(parse_api_error(resp).await);
533 }
534 Ok(resp.json().await?)
535 }
536}
537
538async fn parse_api_error(resp: reqwest::Response) -> ShieldError {
542 let status = resp.status().as_u16();
543 match resp.json::<serde_json::Value>().await {
544 Ok(body) => ShieldError::Api {
545 status,
546 title: body["title"].as_str().unwrap_or("Error").to_string(),
547 detail: body["detail"].as_str().unwrap_or("").to_string(),
548 },
549 Err(_) => ShieldError::Api {
550 status,
551 title: "Error".to_string(),
552 detail: String::new(),
553 },
554 }
555}