aster/security/
classification_client.rs1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::time::Duration;
5use url::Url;
6
7#[derive(Debug, Serialize)]
9struct ClassificationRequest {
10 inputs: String,
11 #[serde(skip_serializing_if = "Option::is_none")]
12 parameters: Option<serde_json::Value>,
13}
14
15#[derive(Debug, Deserialize, Clone)]
16struct ClassificationLabel {
17 label: String,
18 score: f32,
19}
20
21type ClassificationResponse = Vec<Vec<ClassificationLabel>>;
22
23#[derive(Debug, Deserialize, Clone)]
24pub struct ModelEndpointInfo {
25 pub endpoint: String,
26 #[serde(flatten)]
27 pub extra_params: HashMap<String, serde_json::Value>,
28}
29
30#[derive(Debug, Deserialize, Clone)]
31pub struct ModelMappingConfig {
32 #[serde(flatten)]
33 pub models: HashMap<String, ModelEndpointInfo>,
34}
35
36#[derive(Debug)]
37pub struct ClassificationClient {
38 endpoint_url: String,
39 client: reqwest::Client,
40 auth_token: Option<String>,
41 extra_params: Option<HashMap<String, serde_json::Value>>,
42}
43
44impl ClassificationClient {
45 pub fn new(
46 endpoint_url: String,
47 timeout_ms: Option<u64>,
48 auth_token: Option<String>,
49 extra_params: Option<HashMap<String, serde_json::Value>>,
50 ) -> Result<Self> {
51 let timeout = Duration::from_millis(timeout_ms.unwrap_or(5000));
52
53 let client = reqwest::Client::builder()
54 .timeout(timeout)
55 .build()
56 .context("Failed to create HTTP client")?;
57
58 Ok(Self {
59 endpoint_url,
60 client,
61 auth_token,
62 extra_params,
63 })
64 }
65
66 pub fn from_model_name(model_name: &str, timeout_ms: Option<u64>) -> Result<Self> {
67 let mapping_json = std::env::var("SECURITY_ML_MODEL_MAPPING")
68 .context("SECURITY_ML_MODEL_MAPPING environment variable not set")?;
69
70 let mapping = serde_json::from_str::<ModelMappingConfig>(&mapping_json)
71 .context("Failed to parse SECURITY_ML_MODEL_MAPPING JSON")?;
72
73 let model_info = mapping.models.get(model_name).context(format!(
74 "Model '{}' not found in SECURITY_ML_MODEL_MAPPING",
75 model_name
76 ))?;
77
78 tracing::info!(
79 model_name = %model_name,
80 endpoint = %model_info.endpoint,
81 extra_params = ?model_info.extra_params,
82 "Creating classification client from model mapping"
83 );
84
85 Self::new(
86 model_info.endpoint.clone(),
87 timeout_ms,
88 None,
89 Some(model_info.extra_params.clone()),
90 )
91 }
92
93 pub fn from_endpoint(
94 endpoint_url: String,
95 timeout_ms: Option<u64>,
96 auth_token: Option<String>,
97 ) -> Result<Self> {
98 let endpoint_url = endpoint_url.trim().to_string();
99
100 Url::parse(&endpoint_url)
101 .context("Invalid endpoint URL format. Must be a valid HTTP/HTTPS URL")?;
102
103 let auth_token = auth_token
104 .map(|t| t.trim().to_string())
105 .filter(|t| !t.is_empty());
106
107 tracing::info!(
108 endpoint = %endpoint_url,
109 has_token = auth_token.is_some(),
110 "Creating classification client from endpoint"
111 );
112
113 Self::new(endpoint_url, timeout_ms, auth_token, None)
114 }
115
116 pub async fn classify(&self, text: &str) -> Result<f32> {
117 tracing::debug!(
118 endpoint = %self.endpoint_url,
119 text_length = text.len(),
120 "Sending classification request"
121 );
122
123 let parameters = self
124 .extra_params
125 .as_ref()
126 .map(serde_json::to_value)
127 .transpose()?;
128
129 let request = ClassificationRequest {
130 inputs: text.to_string(),
131 parameters,
132 };
133
134 let mut request_builder = self.client.post(&self.endpoint_url).json(&request);
135
136 if let Some(token) = &self.auth_token {
137 request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
138 }
139
140 let response = request_builder
141 .send()
142 .await
143 .context("Failed to send classification request")?;
144
145 let status = response.status();
146 let response = if !status.is_success() {
147 let error_body = response.text().await.unwrap_or_default();
148 return Err(anyhow::anyhow!(
149 "Classification API returned error status {}: {}",
150 status,
151 error_body
152 ));
153 } else {
154 response
155 };
156
157 let classification_response: ClassificationResponse = response
158 .json()
159 .await
160 .context("Failed to parse classification response")?;
161
162 let batch_result = classification_response
163 .first()
164 .context("Classification API returned empty response")?;
165
166 let sum: f32 = batch_result.iter().map(|l| l.score).sum();
167 let is_probabilities = batch_result
168 .iter()
169 .all(|label| label.score >= 0.0 && label.score <= 1.0)
170 && (sum - 1.0).abs() < 0.1;
171
172 let normalized_results: Vec<ClassificationLabel> = if is_probabilities {
173 batch_result.to_vec()
174 } else {
175 self.apply_softmax(batch_result)?
176 };
177
178 let top_label = normalized_results
179 .iter()
180 .max_by(|a, b| {
181 a.score
182 .partial_cmp(&b.score)
183 .unwrap_or(std::cmp::Ordering::Equal)
184 })
185 .context("Classification API returned no labels")?;
186
187 let injection_score = match top_label.label.as_str() {
188 "INJECTION" | "LABEL_1" => top_label.score,
189 "SAFE" | "LABEL_0" => 1.0 - top_label.score,
190 _ => {
191 tracing::warn!(
192 label = %top_label.label,
193 score = %top_label.score,
194 "Unknown classification label, defaulting to safe"
195 );
196 0.0
197 }
198 };
199
200 tracing::info!(
201 injection_score = %injection_score,
202 top_label = %top_label.label,
203 top_score = %top_label.score,
204 normalized = !is_probabilities,
205 "Classification complete"
206 );
207
208 Ok(injection_score)
209 }
210
211 fn apply_softmax(&self, labels: &[ClassificationLabel]) -> Result<Vec<ClassificationLabel>> {
212 if labels.is_empty() {
213 return Ok(Vec::new());
214 }
215
216 let max_score = labels
217 .iter()
218 .map(|l| l.score)
219 .fold(f32::NEG_INFINITY, f32::max);
220
221 let exp_scores: Vec<f32> = labels.iter().map(|l| (l.score - max_score).exp()).collect();
222
223 let sum_exp: f32 = exp_scores.iter().sum();
224
225 if sum_exp == 0.0 || !sum_exp.is_finite() {
226 anyhow::bail!("Softmax normalization failed: invalid sum");
227 }
228
229 let normalized: Vec<ClassificationLabel> = labels
230 .iter()
231 .zip(exp_scores.iter())
232 .map(|(label, &exp_score)| ClassificationLabel {
233 label: label.label.clone(),
234 score: exp_score / sum_exp,
235 })
236 .collect();
237
238 Ok(normalized)
239 }
240}