1use crate::error::RsGuardError;
8use async_trait::async_trait;
9use reqwest::header::{self, HeaderMap, HeaderValue};
10use serde::{Deserialize, Serialize};
11
/// Total per-request timeout (60 seconds) applied to every HTTP client built
/// by `build_llm_client`.
const LLM_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
14
// Public submodules of this module. The vendor-named ones (deepseek, kimi,
// openai, openrouter, qwen) presumably hold the per-provider implementations,
// with `factory` and `providers` holding shared wiring — confirm against
// those files.
pub mod deepseek;
pub mod factory;
pub mod kimi;
pub mod openai;
pub mod openrouter;
pub mod providers;
pub mod qwen;
22
23#[derive(Debug, Clone, Serialize)]
25pub struct ChatMessage {
26 pub role: String,
28 pub content: String,
30}
31
32#[derive(Debug, Serialize)]
34pub struct ChatRequest {
35 pub model: String,
37 pub messages: Vec<ChatMessage>,
39 pub temperature: f32,
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub max_tokens: Option<u32>,
44}
45
46#[derive(Debug, Deserialize)]
48pub struct ChatChoice {
49 pub message: ChatMessageResponse,
51}
52
53#[derive(Debug, Deserialize)]
55pub struct ChatMessageResponse {
56 pub content: String,
58 #[serde(default)]
60 pub reasoning_content: Option<String>,
61}
62
63#[derive(Debug, Deserialize)]
65pub struct ChatResponse {
66 pub choices: Vec<ChatChoice>,
68}
69
70#[async_trait]
76pub trait LlmProvider: Send + Sync + std::fmt::Debug {
77 fn name(&self) -> &'static str;
79
80 async fn chat_completion(
88 &self,
89 system_prompt: &str,
90 user_message: &str,
91 temperature: f32,
92 ) -> Result<String, RsGuardError>;
93}
94
/// Owned, dynamically dispatched [`LlmProvider`] implementation.
pub type Provider = Box<dyn LlmProvider>;
100
/// Settings used to configure a provider.
///
/// `Default` yields `None` for every optional field and an empty `model`.
#[derive(Clone, Debug, Default)]
pub struct ProviderConfig {
    /// Optional base URL. NOTE(review): presumably overrides the provider's
    /// built-in endpoint — confirm where this is consumed.
    pub base_url: Option<String>,
    /// Optional referer value. NOTE(review): presumably sent as an HTTP
    /// header by some providers — confirm where this is consumed.
    pub http_referer: Option<String>,
    /// Optional token limit.
    pub max_tokens: Option<u32>,
    /// Model identifier.
    pub model: String,
}
116
117pub(crate) async fn send_chat_request<B: Serialize + Send>(
134 client: &reqwest::Client,
135 url: &str,
136 request: &B,
137 provider_name: &str,
138) -> Result<String, RsGuardError> {
139 log::debug!(
140 "[{}] POST {} (effective params logged at debug level)",
141 provider_name,
142 url
143 );
144
145 let response = client.post(url).json(request).send().await.map_err(|e| {
146 let status = e.status().map(|s| s.as_u16()).unwrap_or(0);
147 LlmError {
148 provider: provider_name.to_string(),
149 status,
150 message: e.to_string(),
151 }
152 })?;
153
154 let status = response.status();
155
156 if log::log_enabled!(log::Level::Debug) {
159 let headers = response.headers();
160 let safe_headers: Vec<String> = headers
161 .iter()
162 .filter_map(|(name, value)| {
163 let name_str = name.as_str();
164 if name_str == "authorization"
166 || name_str == "set-cookie"
167 || name_str.contains("token")
168 || name_str.contains("key")
169 {
170 return None;
171 }
172 let val = value.to_str().unwrap_or("<binary>");
173 let val_display = if val.len() > 80 {
175 let truncated: String = val.chars().take(80).collect();
176 format!("{}...", truncated)
177 } else {
178 val.to_string()
179 };
180 Some(format!("{}: {}", name_str, val_display))
181 })
182 .collect();
183 log::debug!(
184 "[{}] Response status: {} — headers: [{}]",
185 provider_name,
186 status.as_u16(),
187 safe_headers.join(", ")
188 );
189 }
190
191 if !status.is_success() {
192 let body = response.text().await.unwrap_or_default();
193 return Err(LlmError {
194 provider: provider_name.to_string(),
195 status: status.as_u16(),
196 message: body,
197 }
198 .into());
199 }
200
201 let chat_response: ChatResponse = response.json().await.map_err(|e| LlmError {
202 provider: provider_name.to_string(),
203 status: 0,
204 message: format!("Failed to parse response: {}", e),
205 })?;
206
207 let choice = chat_response
208 .choices
209 .into_iter()
210 .next()
211 .ok_or_else(|| LlmError {
212 provider: provider_name.to_string(),
213 status: 0,
214 message: "Empty response from LLM".to_string(),
215 })?;
216
217 if let Some(ref reasoning) = choice.message.reasoning_content {
218 log::debug!(
219 "[{}] reasoning_content present ({} chars, content not logged)",
220 provider_name,
221 reasoning.len()
222 );
223 }
224
225 Ok(choice.message.content)
226}
227
/// Failure details for a single LLM API call.
///
/// Converted into `RsGuardError::LlmApi` through the `From` implementation
/// in this file.
#[derive(Clone, Debug)]
pub struct LlmError {
    /// Name of the provider the call was made to.
    pub provider: String,
    /// HTTP status code; `send_chat_request` uses `0` when no HTTP status
    /// applies (parse failure, empty choices, transport error without one).
    pub status: u16,
    /// Error text: a response body, a transport error or a parse error.
    pub message: String,
}
238
239impl From<LlmError> for RsGuardError {
240 fn from(err: LlmError) -> Self {
241 RsGuardError::LlmApi {
242 provider: err.provider,
243 status: err.status,
244 message: err.message,
245 }
246 }
247}
248
249pub(crate) fn chat_messages(system_prompt: &str, user_message: &str) -> Vec<ChatMessage> {
253 vec![
254 ChatMessage {
255 role: "system".to_string(),
256 content: system_prompt.to_string(),
257 },
258 ChatMessage {
259 role: "user".to_string(),
260 content: user_message.to_string(),
261 },
262 ]
263}
264
265pub(crate) fn build_llm_client(
281 provider_name: &str,
282 api_key: &str,
283 extra_headers: &[(&str, &str)],
284) -> Result<reqwest::Client, RsGuardError> {
285 let mut headers = HeaderMap::new();
286 headers.insert(
287 header::AUTHORIZATION,
288 HeaderValue::from_str(&format!("Bearer {}", api_key)).map_err(|e| {
289 RsGuardError::Config(format!("Invalid {} API key format: {}", provider_name, e))
290 })?,
291 );
292 headers.insert(
293 header::CONTENT_TYPE,
294 HeaderValue::from_static("application/json"),
295 );
296 for &(name, value) in extra_headers {
297 let h_name = header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
298 RsGuardError::Config(format!(
299 "Invalid header name '{}' for {}: {}",
300 name, provider_name, e
301 ))
302 })?;
303 headers.insert(
304 h_name,
305 HeaderValue::from_str(value).map_err(|e| {
306 RsGuardError::Config(format!(
307 "Invalid header '{}' value for {}: {}",
308 name, provider_name, e
309 ))
310 })?,
311 );
312 }
313
314 reqwest::Client::builder()
315 .default_headers(headers)
316 .timeout(LLM_REQUEST_TIMEOUT)
317 .build()
318 .map_err(|e| RsGuardError::Config(format!("Failed to build HTTP client: {}", e)))
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is an error and returns its display text.
    fn error_text<T: std::fmt::Debug>(result: Result<T, RsGuardError>) -> String {
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    /// Request payload shared by the `send_chat_request` tests.
    fn sample_request() -> ChatRequest {
        ChatRequest {
            model: "test-model".to_string(),
            messages: chat_messages("system", "user"),
            temperature: 0.1,
            max_tokens: None,
        }
    }

    /// Starts a mock server that answers every POST with `template`, sends
    /// the sample request to it and returns the outcome.
    async fn send_to_mock(template: wiremock::ResponseTemplate) -> Result<String, RsGuardError> {
        use wiremock::matchers::method;
        use wiremock::{Mock, MockServer};

        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(template)
            .mount(&mock_server)
            .await;

        let client = build_llm_client("testprov", "key", &[]).unwrap();
        let endpoint = format!("{}/chat/completions", mock_server.uri());
        let request = sample_request();
        send_chat_request(&client, &endpoint, &request, "testprov").await
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_api_key() {
        let err = error_text(build_llm_client(
            "deepseek",
            "key\x00with\x01control",
            &[],
        ));
        assert!(
            err.contains("Invalid deepseek API key format"),
            "Expected API key format error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_extra_header_name() {
        let err = error_text(build_llm_client(
            "testprov",
            "valid-key",
            &[("inv@lid header name", "value")],
        ));
        assert!(
            err.contains("Invalid header name"),
            "Expected header name error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_extra_header_value() {
        let err = error_text(build_llm_client(
            "testprov",
            "valid-key",
            &[("X-Custom", "val\x00ue")],
        ));
        assert!(
            err.contains("Invalid header"),
            "Expected header value error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_succeeds_with_valid_inputs() {
        assert!(build_llm_client("deepseek", "valid-key-123", &[]).is_ok());
    }

    #[test]
    fn test_build_llm_client_succeeds_with_extra_headers() {
        let extra = [("HTTP-Referer", "https://example.com"), ("X-Title", "test")];
        assert!(build_llm_client("openrouter", "valid-key", &extra).is_ok());
    }

    #[test]
    fn test_chat_messages_ordering() {
        let messages = chat_messages("system prompt", "user diff");
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "system");
        assert_eq!(messages[0].content, "system prompt");
        assert_eq!(messages[1].role, "user");
        assert_eq!(messages[1].content, "user diff");
    }

    #[tokio::test]
    async fn test_send_chat_request_empty_choices() {
        let template = wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": []
        }));

        let err = error_text(send_to_mock(template).await);
        assert!(
            err.contains("Empty response from LLM"),
            "Expected empty choices error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_send_chat_request_malformed_json() {
        let template = wiremock::ResponseTemplate::new(200).set_body_string("this is not json");

        let err = error_text(send_to_mock(template).await);
        assert!(
            err.contains("Failed to parse response"),
            "Expected parse error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_send_chat_request_http_error() {
        let template =
            wiremock::ResponseTemplate::new(500).set_body_string("Internal Server Error");

        let err = error_text(send_to_mock(template).await);
        assert!(err.contains("500"), "Expected 500 error, got: {}", err);
    }

    #[tokio::test]
    async fn test_send_chat_request_reasoning_content_ignored() {
        let template = wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Review text",
                    "reasoning_content": "Internal reasoning that should not appear in output"
                }
            }]
        }));

        let result = send_to_mock(template).await;
        assert!(result.is_ok());
        let content = result.unwrap();
        assert_eq!(content, "Review text");
        assert!(!content.contains("Internal reasoning"));
    }
}