use crate::core::providers::unified_provider::ProviderError;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GuardrailConfig {
pub guardrail_id: String,
pub guardrail_version: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub trace: Option<bool>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GuardrailApplyRequest {
pub content: Vec<GuardrailContent>,
pub source: GuardrailSource,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GuardrailContent {
pub text: GuardrailText,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct GuardrailText {
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub qualifiers: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum GuardrailSource {
Input,
Output,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GuardrailResponse {
pub usage: GuardrailUsage,
pub outputs: Vec<GuardrailOutput>,
pub action: GuardrailAction,
#[serde(skip_serializing_if = "Option::is_none")]
pub assessments: Option<Vec<GuardrailAssessment>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub guardrail_coverage: Option<GuardrailCoverage>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GuardrailUsage {
pub topic_policy_units: u32,
pub content_policy_units: u32,
pub word_policy_units: u32,
pub sensitive_information_policy_units: u32,
pub sensitive_information_policy_free_units: u32,
pub contextual_grounding_policy_units: u32,
}
#[derive(Debug, Deserialize)]
pub struct GuardrailOutput {
pub text: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum GuardrailAction {
None,
Guardrail,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GuardrailAssessment {
pub topic_policy: Option<TopicPolicyAssessment>,
pub content_policy: Option<ContentPolicyAssessment>,
pub word_policy: Option<WordPolicyAssessment>,
pub sensitive_information_policy: Option<SensitiveInformationPolicyAssessment>,
}
#[derive(Debug, Deserialize)]
pub struct TopicPolicyAssessment {
pub topics: Vec<Topic>,
}
#[derive(Debug, Deserialize)]
pub struct Topic {
pub name: String,
#[serde(rename = "type")]
pub topic_type: String,
pub action: String,
}
#[derive(Debug, Deserialize)]
pub struct ContentPolicyAssessment {
pub filters: Vec<ContentFilter>,
}
#[derive(Debug, Deserialize)]
pub struct ContentFilter {
#[serde(rename = "type")]
pub filter_type: String,
pub confidence: String,
pub action: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct WordPolicyAssessment {
pub custom_words: Vec<String>,
pub managed_word_lists: Vec<String>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SensitiveInformationPolicyAssessment {
pub pii_entities: Vec<PiiEntity>,
pub regexes: Vec<RegexMatch>,
}
#[derive(Debug, Deserialize)]
pub struct PiiEntity {
#[serde(rename = "type")]
pub entity_type: String,
pub action: String,
}
#[derive(Debug, Deserialize)]
pub struct RegexMatch {
pub name: String,
pub action: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GuardrailCoverage {
pub text_characters: TextCharactersCoverage,
}
#[derive(Debug, Deserialize)]
pub struct TextCharactersCoverage {
pub guarded: u32,
pub total: u32,
}
pub struct GuardrailClient<'a> {
client: &'a crate::core::providers::bedrock::client::BedrockClient,
}
impl<'a> GuardrailClient<'a> {
pub fn new(client: &'a crate::core::providers::bedrock::client::BedrockClient) -> Self {
Self { client }
}
pub async fn apply(
&self,
guardrail_id: &str,
guardrail_version: &str,
content: &str,
source: GuardrailSource,
) -> Result<GuardrailResponse, ProviderError> {
let request = GuardrailApplyRequest {
content: vec![GuardrailContent {
text: GuardrailText {
text: content.to_string(),
qualifiers: None,
},
}],
source,
};
let url = format!(
"guardrail/{}/version/{}/apply",
guardrail_id, guardrail_version
);
let response = self
.client
.send_request("", &url, &serde_json::to_value(request)?)
.await?;
let guardrail_response: GuardrailResponse = response
.json()
.await
.map_err(|e| ProviderError::response_parsing("bedrock", e.to_string()))?;
Ok(guardrail_response)
}
pub async fn check_content(
&self,
guardrail_id: &str,
guardrail_version: &str,
content: &str,
source: GuardrailSource,
) -> Result<bool, ProviderError> {
let response = self
.apply(guardrail_id, guardrail_version, content, source)
.await?;
Ok(matches!(response.action, GuardrailAction::None))
}
}