Skip to main content

aster/security/
classification_client.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::time::Duration;
5use url::Url;
6
/// Request format following HuggingFace Inference Text Classification API specification
///
/// Serialized as `{"inputs": "...", "parameters": {...}}`; the `parameters` key is
/// omitted from the JSON entirely when it is `None`.
#[derive(Debug, Serialize)]
struct ClassificationRequest {
    /// The text to classify (one input per request).
    inputs: String,
    /// Optional task parameters forwarded verbatim to the endpoint.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<serde_json::Value>,
}
14
/// One `{ "label": ..., "score": ... }` entry from a classification response.
///
/// Depending on the endpoint, `score` may be a probability or a raw logit; the client
/// inspects the full label set to decide whether a softmax is needed.
#[derive(Debug, Deserialize, Clone)]
struct ClassificationLabel {
    /// Class name as reported by the model (e.g. `INJECTION`, `SAFE`, `LABEL_0`).
    label: String,
    /// Probability or logit for `label`.
    score: f32,
}
20
/// Batched response shape: one inner list of labels per input. Each request carries a
/// single input, so only the first inner list is consumed.
type ClassificationResponse = Vec<Vec<ClassificationLabel>>;
22
/// A single model's entry in the `SECURITY_ML_MODEL_MAPPING` JSON.
#[derive(Debug, Deserialize, Clone)]
pub struct ModelEndpointInfo {
    /// URL that classification requests are POSTed to.
    pub endpoint: String,
    /// Every key of the entry other than `endpoint`, collected by `#[serde(flatten)]`.
    /// These are sent to the API as the request's `parameters` object.
    #[serde(flatten)]
    pub extra_params: HashMap<String, serde_json::Value>,
}
29
/// Parsed `SECURITY_ML_MODEL_MAPPING` value: a top-level JSON object whose keys are
/// model names and whose values are [`ModelEndpointInfo`] entries.
#[derive(Debug, Deserialize, Clone)]
pub struct ModelMappingConfig {
    /// Model name -> endpoint configuration; `flatten` makes the whole object the map.
    #[serde(flatten)]
    pub models: HashMap<String, ModelEndpointInfo>,
}
35
36#[derive(Debug)]
37pub struct ClassificationClient {
38    endpoint_url: String,
39    client: reqwest::Client,
40    auth_token: Option<String>,
41    extra_params: Option<HashMap<String, serde_json::Value>>,
42}
43
44impl ClassificationClient {
45    pub fn new(
46        endpoint_url: String,
47        timeout_ms: Option<u64>,
48        auth_token: Option<String>,
49        extra_params: Option<HashMap<String, serde_json::Value>>,
50    ) -> Result<Self> {
51        let timeout = Duration::from_millis(timeout_ms.unwrap_or(5000));
52
53        let client = reqwest::Client::builder()
54            .timeout(timeout)
55            .build()
56            .context("Failed to create HTTP client")?;
57
58        Ok(Self {
59            endpoint_url,
60            client,
61            auth_token,
62            extra_params,
63        })
64    }
65
66    pub fn from_model_name(model_name: &str, timeout_ms: Option<u64>) -> Result<Self> {
67        let mapping_json = std::env::var("SECURITY_ML_MODEL_MAPPING")
68            .context("SECURITY_ML_MODEL_MAPPING environment variable not set")?;
69
70        let mapping = serde_json::from_str::<ModelMappingConfig>(&mapping_json)
71            .context("Failed to parse SECURITY_ML_MODEL_MAPPING JSON")?;
72
73        let model_info = mapping.models.get(model_name).context(format!(
74            "Model '{}' not found in SECURITY_ML_MODEL_MAPPING",
75            model_name
76        ))?;
77
78        tracing::info!(
79            model_name = %model_name,
80            endpoint = %model_info.endpoint,
81            extra_params = ?model_info.extra_params,
82            "Creating classification client from model mapping"
83        );
84
85        Self::new(
86            model_info.endpoint.clone(),
87            timeout_ms,
88            None,
89            Some(model_info.extra_params.clone()),
90        )
91    }
92
93    pub fn from_endpoint(
94        endpoint_url: String,
95        timeout_ms: Option<u64>,
96        auth_token: Option<String>,
97    ) -> Result<Self> {
98        let endpoint_url = endpoint_url.trim().to_string();
99
100        Url::parse(&endpoint_url)
101            .context("Invalid endpoint URL format. Must be a valid HTTP/HTTPS URL")?;
102
103        let auth_token = auth_token
104            .map(|t| t.trim().to_string())
105            .filter(|t| !t.is_empty());
106
107        tracing::info!(
108            endpoint = %endpoint_url,
109            has_token = auth_token.is_some(),
110            "Creating classification client from endpoint"
111        );
112
113        Self::new(endpoint_url, timeout_ms, auth_token, None)
114    }
115
116    pub async fn classify(&self, text: &str) -> Result<f32> {
117        tracing::debug!(
118            endpoint = %self.endpoint_url,
119            text_length = text.len(),
120            "Sending classification request"
121        );
122
123        let parameters = self
124            .extra_params
125            .as_ref()
126            .map(serde_json::to_value)
127            .transpose()?;
128
129        let request = ClassificationRequest {
130            inputs: text.to_string(),
131            parameters,
132        };
133
134        let mut request_builder = self.client.post(&self.endpoint_url).json(&request);
135
136        if let Some(token) = &self.auth_token {
137            request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
138        }
139
140        let response = request_builder
141            .send()
142            .await
143            .context("Failed to send classification request")?;
144
145        let status = response.status();
146        let response = if !status.is_success() {
147            let error_body = response.text().await.unwrap_or_default();
148            return Err(anyhow::anyhow!(
149                "Classification API returned error status {}: {}",
150                status,
151                error_body
152            ));
153        } else {
154            response
155        };
156
157        let classification_response: ClassificationResponse = response
158            .json()
159            .await
160            .context("Failed to parse classification response")?;
161
162        let batch_result = classification_response
163            .first()
164            .context("Classification API returned empty response")?;
165
166        let sum: f32 = batch_result.iter().map(|l| l.score).sum();
167        let is_probabilities = batch_result
168            .iter()
169            .all(|label| label.score >= 0.0 && label.score <= 1.0)
170            && (sum - 1.0).abs() < 0.1;
171
172        let normalized_results: Vec<ClassificationLabel> = if is_probabilities {
173            batch_result.to_vec()
174        } else {
175            self.apply_softmax(batch_result)?
176        };
177
178        let top_label = normalized_results
179            .iter()
180            .max_by(|a, b| {
181                a.score
182                    .partial_cmp(&b.score)
183                    .unwrap_or(std::cmp::Ordering::Equal)
184            })
185            .context("Classification API returned no labels")?;
186
187        let injection_score = match top_label.label.as_str() {
188            "INJECTION" | "LABEL_1" => top_label.score,
189            "SAFE" | "LABEL_0" => 1.0 - top_label.score,
190            _ => {
191                tracing::warn!(
192                    label = %top_label.label,
193                    score = %top_label.score,
194                    "Unknown classification label, defaulting to safe"
195                );
196                0.0
197            }
198        };
199
200        tracing::info!(
201            injection_score = %injection_score,
202            top_label = %top_label.label,
203            top_score = %top_label.score,
204            normalized = !is_probabilities,
205            "Classification complete"
206        );
207
208        Ok(injection_score)
209    }
210
211    fn apply_softmax(&self, labels: &[ClassificationLabel]) -> Result<Vec<ClassificationLabel>> {
212        if labels.is_empty() {
213            return Ok(Vec::new());
214        }
215
216        let max_score = labels
217            .iter()
218            .map(|l| l.score)
219            .fold(f32::NEG_INFINITY, f32::max);
220
221        let exp_scores: Vec<f32> = labels.iter().map(|l| (l.score - max_score).exp()).collect();
222
223        let sum_exp: f32 = exp_scores.iter().sum();
224
225        if sum_exp == 0.0 || !sum_exp.is_finite() {
226            anyhow::bail!("Softmax normalization failed: invalid sum");
227        }
228
229        let normalized: Vec<ClassificationLabel> = labels
230            .iter()
231            .zip(exp_scores.iter())
232            .map(|(label, &exp_score)| ClassificationLabel {
233                label: label.label.clone(),
234                score: exp_score / sum_exp,
235            })
236            .collect();
237
238        Ok(normalized)
239    }
240}