Skip to main content

rs_guard/llm/
mod.rs

1//! LLM provider abstraction and shared types.
2//!
3//! Defines the [`LlmProvider`] async trait for dispatching chat completion
4//! requests to supported LLM backends, along with shared request/response types
5//! and a common HTTP helper for provider implementations.
6
7use crate::error::RsGuardError;
8use async_trait::async_trait;
9use reqwest::header::{self, HeaderMap, HeaderValue};
10use serde::{Deserialize, Serialize};
11
/// Upper bound (60 seconds) on how long a single LLM API call may take;
/// applied as the client-wide timeout when the HTTP client is built.
const LLM_REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::new(60, 0);
14
15pub mod deepseek;
16pub mod factory;
17pub mod kimi;
18pub mod openai;
19pub mod openrouter;
20pub mod providers;
21pub mod qwen;
22
23/// A single message in a chat conversation.
24#[derive(Debug, Clone, Serialize)]
25pub struct ChatMessage {
26    /// The role of the message sender (e.g. `"system"`, `"user"`).
27    pub role: String,
28    /// The message content.
29    pub content: String,
30}
31
32/// Request body for a chat completion API call.
33#[derive(Debug, Serialize)]
34pub struct ChatRequest {
35    /// Model identifier to use for completion.
36    pub model: String,
37    /// Conversation messages.
38    pub messages: Vec<ChatMessage>,
39    /// Sampling temperature (0.0 to 2.0).
40    pub temperature: f32,
41    /// Maximum tokens in the response (provider-agnostic).
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub max_tokens: Option<u32>,
44}
45
46/// A single choice in a chat completion response.
47#[derive(Debug, Deserialize)]
48pub struct ChatChoice {
49    /// The message content of this choice.
50    pub message: ChatMessageResponse,
51}
52
53/// Message content within a chat completion response choice.
54#[derive(Debug, Deserialize)]
55pub struct ChatMessageResponse {
56    /// The generated text content.
57    pub content: String,
58    /// Optional reasoning content (e.g. Kimi/Moonshot AI chain-of-thought).
59    #[serde(default)]
60    pub reasoning_content: Option<String>,
61}
62
63/// Parsed response from a chat completion API call.
64#[derive(Debug, Deserialize)]
65pub struct ChatResponse {
66    /// List of completion choices returned by the model.
67    pub choices: Vec<ChatChoice>,
68}
69
70/// Async trait for LLM provider implementations.
71///
72/// All providers must implement this trait to participate in the rs-guard
73/// pipeline. Implementations are expected to handle HTTP communication,
74/// authentication, and response parsing.
75#[async_trait]
76pub trait LlmProvider: Send + Sync + std::fmt::Debug {
77    /// Returns the provider's display name (e.g. `"deepseek"`).
78    fn name(&self) -> &'static str;
79
80    /// Sends a chat completion request to the provider.
81    ///
82    /// # Arguments
83    ///
84    /// * `system_prompt` — The system instruction for the model.
85    /// * `user_message` — The user message (typically the diff content).
86    /// * `temperature` — Sampling temperature.
87    async fn chat_completion(
88        &self,
89        system_prompt: &str,
90        user_message: &str,
91        temperature: f32,
92    ) -> Result<String, RsGuardError>;
93}
94
95/// Dynamic-dispatch handle for an LLM provider.
96///
97/// Uses a trait object so the factory can return heterogeneous providers
98/// without enum match arms at every call site.
99pub type Provider = Box<dyn LlmProvider>;
100
/// Per-provider settings sourced from `.reviewer.toml`.
///
/// Resolved by [`crate::config::Config`] and handed to the provider factory
/// to customise base URLs, model, attribution headers, etc.
#[derive(Clone, Debug, Default)]
pub struct ProviderConfig {
    /// Replacement for the provider's API base URL, if any.
    pub base_url: Option<String>,
    /// HTTP referer sent for attribution (OpenRouter only).
    pub http_referer: Option<String>,
    /// Cap on tokens for LLM completions, if any.
    pub max_tokens: Option<u32>,
    /// Model identifier to use (overrides provider default).
    pub model: String,
}
116
117/// Sends a chat completion HTTP request and parses the response.
118///
119/// Shared implementation used by all provider modules to avoid duplication
120/// in HTTP error handling, response deserialization, and content extraction.
121///
122/// # Arguments
123///
124/// * `client` — Pre-configured reqwest client with auth headers.
125/// * `url` — Full endpoint URL.
126/// * `request` — Serializable request body.
127/// * `provider_name` — Provider name for error reporting.
128///
129/// # Errors
130///
131/// Returns [`RsGuardError::LlmApi`] on network errors, non-success HTTP
132/// status codes, or response parsing failures.
133pub(crate) async fn send_chat_request<B: Serialize + Send>(
134    client: &reqwest::Client,
135    url: &str,
136    request: &B,
137    provider_name: &str,
138) -> Result<String, RsGuardError> {
139    log::debug!(
140        "[{}] POST {} (effective params logged at debug level)",
141        provider_name,
142        url
143    );
144
145    let response = client.post(url).json(request).send().await.map_err(|e| {
146        let status = e.status().map(|s| s.as_u16()).unwrap_or(0);
147        LlmError {
148            provider: provider_name.to_string(),
149            status,
150            message: e.to_string(),
151        }
152    })?;
153
154    let status = response.status();
155
156    // Log sanitized response headers at debug level for observability.
157    // Only safe, non-sensitive headers are logged.
158    if log::log_enabled!(log::Level::Debug) {
159        let headers = response.headers();
160        let safe_headers: Vec<String> = headers
161            .iter()
162            .filter_map(|(name, value)| {
163                let name_str = name.as_str();
164                // Skip potentially sensitive headers
165                if name_str == "authorization"
166                    || name_str == "set-cookie"
167                    || name_str.contains("token")
168                    || name_str.contains("key")
169                {
170                    return None;
171                }
172                let val = value.to_str().unwrap_or("<binary>");
173                // Truncate long values (use char-aware truncation to avoid panics on multi-byte UTF-8)
174                let val_display = if val.len() > 80 {
175                    let truncated: String = val.chars().take(80).collect();
176                    format!("{}...", truncated)
177                } else {
178                    val.to_string()
179                };
180                Some(format!("{}: {}", name_str, val_display))
181            })
182            .collect();
183        log::debug!(
184            "[{}] Response status: {} — headers: [{}]",
185            provider_name,
186            status.as_u16(),
187            safe_headers.join(", ")
188        );
189    }
190
191    if !status.is_success() {
192        let body = response.text().await.unwrap_or_default();
193        return Err(LlmError {
194            provider: provider_name.to_string(),
195            status: status.as_u16(),
196            message: body,
197        }
198        .into());
199    }
200
201    let chat_response: ChatResponse = response.json().await.map_err(|e| LlmError {
202        provider: provider_name.to_string(),
203        status: 0,
204        message: format!("Failed to parse response: {}", e),
205    })?;
206
207    let choice = chat_response
208        .choices
209        .into_iter()
210        .next()
211        .ok_or_else(|| LlmError {
212            provider: provider_name.to_string(),
213            status: 0,
214            message: "Empty response from LLM".to_string(),
215        })?;
216
217    if let Some(ref reasoning) = choice.message.reasoning_content {
218        log::debug!(
219            "[{}] reasoning_content present ({} chars, content not logged)",
220            provider_name,
221            reasoning.len()
222        );
223    }
224
225    Ok(choice.message.content)
226}
227
/// Details of a failed LLM provider call.
///
/// Converts into [`RsGuardError::LlmApi`] through the `From` implementation.
#[derive(Clone, Debug)]
pub struct LlmError {
    /// Provider that produced the error.
    pub provider: String,
    /// HTTP status code, or 0 for non-HTTP failures.
    pub status: u16,
    /// Human-readable description of what went wrong.
    pub message: String,
}
238
239impl From<LlmError> for RsGuardError {
240    fn from(err: LlmError) -> Self {
241        RsGuardError::LlmApi {
242            provider: err.provider,
243            status: err.status,
244            message: err.message,
245        }
246    }
247}
248
249/// Creates a system + user message pair for a chat completion request.
250///
251/// Shared helper to avoid duplicating message construction across providers.
252pub(crate) fn chat_messages(system_prompt: &str, user_message: &str) -> Vec<ChatMessage> {
253    vec![
254        ChatMessage {
255            role: "system".to_string(),
256            content: system_prompt.to_string(),
257        },
258        ChatMessage {
259            role: "user".to_string(),
260            content: user_message.to_string(),
261        },
262    ]
263}
264
265/// Builds a [`reqwest::Client`] with standard LLM provider headers.
266///
267/// Sets `Authorization: Bearer {api_key}`, `Content-Type: application/json`,
268/// and any additional headers. Uses [`LLM_REQUEST_TIMEOUT`].
269///
270/// # Arguments
271///
272/// * `provider_name` — Provider name for error messages.
273/// * `api_key` — API key for Bearer authentication.
274/// * `extra_headers` — Additional headers to include (e.g. `HTTP-Referer`).
275///
276/// # Errors
277///
278/// Returns [`RsGuardError::Config`] if the API key or extra header values
279/// contain invalid HTTP header characters.
280pub(crate) fn build_llm_client(
281    provider_name: &str,
282    api_key: &str,
283    extra_headers: &[(&str, &str)],
284) -> Result<reqwest::Client, RsGuardError> {
285    let mut headers = HeaderMap::new();
286    headers.insert(
287        header::AUTHORIZATION,
288        HeaderValue::from_str(&format!("Bearer {}", api_key)).map_err(|e| {
289            RsGuardError::Config(format!("Invalid {} API key format: {}", provider_name, e))
290        })?,
291    );
292    headers.insert(
293        header::CONTENT_TYPE,
294        HeaderValue::from_static("application/json"),
295    );
296    for &(name, value) in extra_headers {
297        let h_name = header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
298            RsGuardError::Config(format!(
299                "Invalid header name '{}' for {}: {}",
300                name, provider_name, e
301            ))
302        })?;
303        headers.insert(
304            h_name,
305            HeaderValue::from_str(value).map_err(|e| {
306                RsGuardError::Config(format!(
307                    "Invalid header '{}' value for {}: {}",
308                    name, provider_name, e
309                ))
310            })?,
311        );
312    }
313
314    reqwest::Client::builder()
315        .default_headers(headers)
316        .timeout(LLM_REQUEST_TIMEOUT)
317        .build()
318        .map_err(|e| RsGuardError::Config(format!("Failed to build HTTP client: {}", e)))
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Returns the display text of the error in `result`, failing the test if
    /// the call unexpectedly succeeded.
    fn error_text<T>(result: Result<T, RsGuardError>) -> String {
        match result {
            Ok(_) => panic!("expected an error, but the call succeeded"),
            Err(e) => e.to_string(),
        }
    }

    /// Starts a mock server that answers every POST with `template`, sends a
    /// minimal chat request to it, and returns what `send_chat_request` yields.
    async fn send_to_mock(template: ResponseTemplate) -> Result<String, RsGuardError> {
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(template)
            .mount(&mock_server)
            .await;

        let client = build_llm_client("testprov", "key", &[]).unwrap();
        let request = ChatRequest {
            model: "test-model".to_string(),
            messages: chat_messages("system", "user"),
            temperature: 0.1,
            max_tokens: None,
        };
        send_chat_request(
            &client,
            &format!("{}/chat/completions", mock_server.uri()),
            &request,
            "testprov",
        )
        .await
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_api_key() {
        let err = error_text(build_llm_client(
            "deepseek",
            "key\x00with\x01control",
            &[],
        ));
        assert!(
            err.contains("Invalid deepseek API key format"),
            "Expected API key format error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_extra_header_name() {
        let err = error_text(build_llm_client(
            "testprov",
            "valid-key",
            &[("inv@lid header name", "value")],
        ));
        assert!(
            err.contains("Invalid header name"),
            "Expected header name error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_rejects_invalid_extra_header_value() {
        let err = error_text(build_llm_client(
            "testprov",
            "valid-key",
            &[("X-Custom", "val\x00ue")],
        ));
        // Match the value-specific message: a bare "Invalid header" would also
        // be satisfied by the header-name error.
        assert!(
            err.contains("Invalid header 'X-Custom' value"),
            "Expected header value error, got: {}",
            err
        );
    }

    #[test]
    fn test_build_llm_client_succeeds_with_valid_inputs() {
        let result = build_llm_client("deepseek", "valid-key-123", &[]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_build_llm_client_succeeds_with_extra_headers() {
        let result = build_llm_client(
            "openrouter",
            "valid-key",
            &[("HTTP-Referer", "https://example.com"), ("X-Title", "test")],
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_chat_messages_ordering() {
        let messages = chat_messages("system prompt", "user diff");
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].role, "system");
        assert_eq!(messages[0].content, "system prompt");
        assert_eq!(messages[1].role, "user");
        assert_eq!(messages[1].content, "user diff");
    }

    #[tokio::test]
    async fn test_send_chat_request_empty_choices() {
        let template = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": []
        }));

        let err = error_text(send_to_mock(template).await);
        assert!(
            err.contains("Empty response from LLM"),
            "Expected empty choices error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_send_chat_request_malformed_json() {
        let template = ResponseTemplate::new(200).set_body_string("this is not json");

        let err = error_text(send_to_mock(template).await);
        assert!(
            err.contains("Failed to parse response"),
            "Expected parse error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_send_chat_request_http_error() {
        let template = ResponseTemplate::new(500).set_body_string("Internal Server Error");

        let err = error_text(send_to_mock(template).await);
        assert!(err.contains("500"), "Expected 500 error, got: {}", err);
    }

    #[tokio::test]
    async fn test_send_chat_request_reasoning_content_ignored() {
        let template = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Review text",
                    "reasoning_content": "Internal reasoning that should not appear in output"
                }
            }]
        }));

        let result = send_to_mock(template).await;
        assert!(result.is_ok());
        let content = result.unwrap();
        assert_eq!(content, "Review text");
        assert!(!content.contains("Internal reasoning"));
    }
}