use crate::error::RsGuardError;
use async_trait::async_trait;
use reqwest::header::{self, HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
/// Request timeout (60 seconds) handed to `reqwest::ClientBuilder::timeout` in
/// `build_llm_client`.
const LLM_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
// Public submodules of this module. Their contents are not visible from this
// file; presumably one per LLM backend plus shared factory/provider plumbing
// (TODO confirm against the module files).
pub mod deepseek;
pub mod factory;
pub mod kimi;
pub mod openai;
pub mod openrouter;
pub mod providers;
pub mod qwen;
/// A single chat message (role plus text), used as an element of
/// `ChatRequest::messages`.
#[derive(Debug, Clone, Serialize)]
pub struct ChatMessage {
/// Role of the message author; `chat_messages` in this module produces
/// `"system"` and `"user"`.
pub role: String,
/// Text of the message.
pub content: String,
}
/// Serializable chat-completion request payload. The tests in this file pass it
/// to `send_chat_request`, which sends it as a JSON body.
#[derive(Debug, Serialize)]
pub struct ChatRequest {
/// Model identifier to request.
pub model: String,
/// Messages to send, in order.
pub messages: Vec<ChatMessage>,
/// Sampling temperature, serialized as given.
pub temperature: f32,
/// Optional token limit; the field is left out of the serialized output
/// entirely when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
}
/// One element of `ChatResponse::choices`; wraps the message returned for that
/// choice.
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
/// The message carried by this choice.
pub message: ChatMessageResponse,
}
/// Message part of a deserialized chat-completion response.
#[derive(Debug, Deserialize)]
pub struct ChatMessageResponse {
/// Response text; this is the value `send_chat_request` returns to its caller.
/// The field is required: deserialization fails if it is missing or not a string.
pub content: String,
/// Optional extra field; deserializes to `None` when absent from the response.
/// `send_chat_request` only logs its size and never returns or logs its text.
#[serde(default)]
pub reasoning_content: Option<String>,
}
/// Top-level shape deserialized from a successful chat-completion response body.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
/// Returned choices; `send_chat_request` uses only the first element and
/// reports an error when the list is empty.
pub choices: Vec<ChatChoice>,
}
/// Common interface for LLM backends. Implementors must be `Send + Sync + Debug`,
/// and the trait is used as a trait object (see the `Provider` alias).
#[async_trait]
pub trait LlmProvider: Send + Sync + std::fmt::Debug {
/// Returns a static name identifying this provider.
fn name(&self) -> &'static str;
/// Performs one chat completion for the given system prompt, user message and
/// temperature.
///
/// Returns a `String` on success (presumably the model's reply text; confirm
/// in the implementations) or an `RsGuardError` on failure.
async fn chat_completion(
&self,
system_prompt: &str,
user_message: &str,
temperature: f32,
) -> Result<String, RsGuardError>;
}
/// Owned, dynamically dispatched `LlmProvider` trait object.
pub type Provider = Box<dyn LlmProvider>;
/// Configuration values for a provider. How each field is consumed is decided by
/// code outside this file, so the per-field notes below are assumptions to confirm.
#[derive(Debug, Clone, Default)]
pub struct ProviderConfig {
/// Optional base URL; presumably overrides a provider's default endpoint
/// (TODO confirm in the provider modules).
pub base_url: Option<String>,
/// Optional referer value; presumably sent as an `HTTP-Referer` header, as the
/// "openrouter" test in this file does (TODO confirm).
pub http_referer: Option<String>,
/// Optional token limit; presumably forwarded as `ChatRequest::max_tokens`
/// (TODO confirm).
pub max_tokens: Option<u32>,
/// Model identifier; the derived `Default` yields an empty string.
pub model: String,
}
/// POSTs `request` as JSON to `url` and returns the `content` of the first
/// choice in the response.
///
/// `provider_name` is used only to label log lines and errors.
///
/// When debug logging is enabled, the response status and headers are logged.
/// Headers named `authorization` or `set-cookie`, or whose name contains
/// `token` or `key`, are omitted, and values longer than 80 characters are
/// truncated. If the response carries `reasoning_content`, only its character
/// count is logged, never its text.
///
/// # Errors
///
/// Every failure is an `LlmError` converted into `RsGuardError`:
/// - the request could not be sent: `status` is the status attached to the
///   reqwest error, or `0` if there is none;
/// - the response status is not a success: `status` is the HTTP status and
///   `message` is the response body (or a placeholder describing why the body
///   could not be read);
/// - the body is not a valid `ChatResponse`, or `choices` is empty: `status`
///   is `0`.
pub(crate) async fn send_chat_request<B: Serialize + Send>(
    client: &reqwest::Client,
    url: &str,
    request: &B,
    provider_name: &str,
) -> Result<String, RsGuardError> {
    log::debug!(
        "[{}] POST {} (effective params logged at debug level)",
        provider_name,
        url
    );

    let response = client.post(url).json(request).send().await.map_err(|e| {
        let status = e.status().map(|s| s.as_u16()).unwrap_or(0);
        LlmError {
            provider: provider_name.to_string(),
            status,
            message: e.to_string(),
        }
    })?;
    let status = response.status();

    // Build the header summary only when it will actually be emitted.
    if log::log_enabled!(log::Level::Debug) {
        let headers = response.headers();
        let safe_headers: Vec<String> = headers
            .iter()
            .filter_map(|(name, value)| {
                let name_str = name.as_str();
                // Drop headers that may carry credentials or session data.
                if name_str == "authorization"
                    || name_str == "set-cookie"
                    || name_str.contains("token")
                    || name_str.contains("key")
                {
                    return None;
                }
                let val = value.to_str().unwrap_or("<binary>");
                let val_display = if val.len() > 80 {
                    let truncated: String = val.chars().take(80).collect();
                    format!("{}...", truncated)
                } else {
                    val.to_string()
                };
                Some(format!("{}: {}", name_str, val_display))
            })
            .collect();
        log::debug!(
            "[{}] Response status: {} — headers: [{}]",
            provider_name,
            status.as_u16(),
            safe_headers.join(", ")
        );
    }

    if !status.is_success() {
        // Keep the HTTP status even if the body cannot be read, and say why the
        // body is missing instead of silently reporting an empty message.
        let body = match response.text().await {
            Ok(text) => text,
            Err(e) => format!("<failed to read error response body: {}>", e),
        };
        return Err(LlmError {
            provider: provider_name.to_string(),
            status: status.as_u16(),
            message: body,
        }
        .into());
    }

    let chat_response: ChatResponse = response.json().await.map_err(|e| LlmError {
        provider: provider_name.to_string(),
        status: 0,
        message: format!("Failed to parse response: {}", e),
    })?;

    let choice = chat_response
        .choices
        .into_iter()
        .next()
        .ok_or_else(|| LlmError {
            provider: provider_name.to_string(),
            status: 0,
            message: "Empty response from LLM".to_string(),
        })?;

    if let Some(reasoning) = choice.message.reasoning_content.as_deref() {
        // `str::len` counts bytes; the log line reports characters, so count them.
        log::debug!(
            "[{}] reasoning_content present ({} chars, content not logged)",
            provider_name,
            reasoning.chars().count()
        );
    }

    Ok(choice.message.content)
}
/// Details of a failed LLM call; converted into `RsGuardError::LlmApi` through
/// the `From` impl in this module.
#[derive(Debug, Clone)]
pub struct LlmError {
/// Name of the provider the call was made for.
pub provider: String,
/// HTTP status code of the failure. `send_chat_request` uses `0` when no HTTP
/// status applies (send error without a status, parse failure, empty choices).
pub status: u16,
/// Human-readable description, or the response body for non-success statuses.
pub message: String,
}
impl From<LlmError> for RsGuardError {
fn from(err: LlmError) -> Self {
RsGuardError::LlmApi {
provider: err.provider,
status: err.status,
message: err.message,
}
}
}
/// Builds the two-message conversation for a request: a `"system"` message
/// holding `system_prompt`, followed by a `"user"` message holding
/// `user_message`.
pub(crate) fn chat_messages(system_prompt: &str, user_message: &str) -> Vec<ChatMessage> {
    // (role, text) pairs in the order they must appear in the request.
    let entries = [("system", system_prompt), ("user", user_message)];
    entries
        .iter()
        .map(|&(role, text)| ChatMessage {
            role: String::from(role),
            content: String::from(text),
        })
        .collect()
}
/// Builds a `reqwest::Client` preconfigured for one LLM provider.
///
/// The client sends these default headers on every request:
/// - `Authorization: Bearer <api_key>`, marked sensitive so that `Debug`
///   output of the header value does not reveal the key;
/// - `Content-Type: application/json`;
/// - each `(name, value)` pair from `extra_headers`, inserted in order (a later
///   entry with the same name replaces an earlier one).
///
/// The request timeout is `LLM_REQUEST_TIMEOUT`. `provider_name` is used only
/// in error messages.
///
/// # Errors
///
/// Returns `RsGuardError::Config` when the API key or an extra header value is
/// not a valid header value, when an extra header name is not a valid header
/// name, or when the HTTP client itself cannot be built.
pub(crate) fn build_llm_client(
    provider_name: &str,
    api_key: &str,
    extra_headers: &[(&str, &str)],
) -> Result<reqwest::Client, RsGuardError> {
    let mut headers = HeaderMap::new();

    let mut auth_value =
        HeaderValue::from_str(&format!("Bearer {}", api_key)).map_err(|e| {
            RsGuardError::Config(format!("Invalid {} API key format: {}", provider_name, e))
        })?;
    // The value embeds the API key: flag it so it is redacted from Debug output.
    auth_value.set_sensitive(true);
    headers.insert(header::AUTHORIZATION, auth_value);

    headers.insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("application/json"),
    );

    for &(name, value) in extra_headers {
        let h_name = header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
            RsGuardError::Config(format!(
                "Invalid header name '{}' for {}: {}",
                name, provider_name, e
            ))
        })?;
        let h_value = HeaderValue::from_str(value).map_err(|e| {
            RsGuardError::Config(format!(
                "Invalid header '{}' value for {}: {}",
                name, provider_name, e
            ))
        })?;
        headers.insert(h_name, h_value);
    }

    reqwest::Client::builder()
        .default_headers(headers)
        .timeout(LLM_REQUEST_TIMEOUT)
        .build()
        .map_err(|e| RsGuardError::Config(format!("Failed to build HTTP client: {}", e)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock server that answers every POST with `template`, sends one
    /// fixed chat request to it through `send_chat_request`, and returns the
    /// outcome. Shared by all HTTP-level tests so each one states only the
    /// response it cares about.
    async fn send_to_mock(template: ResponseTemplate) -> Result<String, crate::error::RsGuardError> {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(template)
            .mount(&mock_server)
            .await;
        let client = build_llm_client("testprov", "key", &[]).unwrap();
        let request = ChatRequest {
            model: "test-model".to_string(),
            messages: chat_messages("system", "user"),
            temperature: 0.1,
            max_tokens: None,
        };
        send_chat_request(
            &client,
            &format!("{}/chat/completions", mock_server.uri()),
            &request,
            "testprov",
        )
        .await
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_api_key() {
        let result = build_llm_client("deepseek", "key\x00with\x01control", &[]);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Invalid deepseek API key format"),
            "Expected API key format error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_extra_header_name() {
        let result = build_llm_client("testprov", "valid-key", &[("inv@lid header name", "value")]);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Invalid header name"),
            "Expected header name error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_extra_header_value() {
        let result = build_llm_client("testprov", "valid-key", &[("X-Custom", "val\x00ue")]);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        // Match the value-specific message: a bare "Invalid header" would also
        // accept the header-name error and hide a regression.
        assert!(
            err.contains("Invalid header 'X-Custom' value"),
            "Expected header value error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_succeeds_with_valid_inputs() {
        let result = build_llm_client("deepseek", "valid-key-123", &[]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_build_llm_client_succeeds_with_extra_headers() {
        let result = build_llm_client(
            "openrouter",
            "valid-key",
            &[("HTTP-Referer", "https://example.com"), ("X-Title", "test")],
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_chat_messages_ordering() {
        let messages = chat_messages("system prompt", "user diff");
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "system");
        assert_eq!(messages[0].content, "system prompt");
        assert_eq!(messages[1].role, "user");
        assert_eq!(messages[1].content, "user diff");
    }

    #[tokio::test]
    async fn test_send_chat_request_empty_choices() {
        let template = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": []
        }));
        let result = send_to_mock(template).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Empty response from LLM"),
            "Expected empty choices error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_send_chat_request_malformed_json() {
        let template = ResponseTemplate::new(200).set_body_string("this is not json");
        let result = send_to_mock(template).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Failed to parse response"),
            "Expected parse error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_send_chat_request_http_error() {
        let template = ResponseTemplate::new(500).set_body_string("Internal Server Error");
        let result = send_to_mock(template).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("500"), "Expected 500 error, got: {}", err);
    }

    #[tokio::test]
    async fn test_send_chat_request_reasoning_content_ignored() {
        let template = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Review text",
                    "reasoning_content": "Internal reasoning that should not appear in output"
                }
            }]
        }));
        let result = send_to_mock(template).await;
        assert!(result.is_ok());
        let content = result.unwrap();
        assert_eq!(content, "Review text");
        assert!(!content.contains("Internal reasoning"));
    }
}