Skip to main content

deepseek/
reqwest_client.rs

1use crate::client::{DeepSeekClient, HttpClient, DEFAULT_BASE_URL};
2use crate::error::DeepSeekError;
3use crate::types::{ChatRequest, ChatResponse, ReasonerOutput};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use tracing::{debug, warn};
7
8/// Native reqwest-based transport.
9#[derive(Clone)]
10pub struct ReqwestClient {
11    client: reqwest::Client,
12}
13
14impl ReqwestClient {
15    pub fn new() -> Self {
16        Self {
17            client: reqwest::Client::new(),
18        }
19    }
20
21    pub fn with_client(client: reqwest::Client) -> Self {
22        Self { client }
23    }
24}
25
26impl Default for ReqwestClient {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32#[async_trait]
33impl HttpClient for ReqwestClient {
34    async fn post_json(
35        &self,
36        url: &str,
37        bearer_token: &str,
38        body: &ChatRequest,
39    ) -> crate::error::Result<ChatResponse> {
40        let resp = self
41            .client
42            .post(url)
43            .bearer_auth(bearer_token)
44            .json(body)
45            .send()
46            .await?;
47
48        let status = resp.status();
49        if !status.is_success() {
50            let text = resp.text().await.unwrap_or_default();
51            return Err(DeepSeekError::Api {
52                status: status.as_u16(),
53                body: text,
54            });
55        }
56
57        let chat_resp: ChatResponse = resp.json().await?;
58        Ok(chat_resp)
59    }
60}
61
62const REASONER_MODEL: &str = "deepseek-reasoner";
63
64/// Production constructor — reads `DEEPSEEK_API_KEY` (and optionally
65/// `DEEPSEEK_BASE_URL`) from the environment.
66pub fn client_from_env() -> Result<DeepSeekClient<ReqwestClient>> {
67    let api_key = std::env::var("DEEPSEEK_API_KEY")
68        .context("DEEPSEEK_API_KEY environment variable not set")?;
69    let base_url =
70        std::env::var("DEEPSEEK_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
71    Ok(DeepSeekClient::new(ReqwestClient::new(), api_key).with_base_url(base_url))
72}
73
74/// Thin wrapper that builds a ChatRequest for deepseek-reasoner and calls `client.chat()`.
75pub async fn reason(
76    client: &DeepSeekClient<ReqwestClient>,
77    system: &str,
78    user: &str,
79) -> Result<ReasonerOutput> {
80    debug!("deepseek-reasoner call: system={:.80}…", system);
81
82    let request = ChatRequest {
83        model: REASONER_MODEL.to_string(),
84        messages: vec![
85            crate::types::system_msg(system),
86            crate::types::user_msg(user),
87        ],
88        tools: None,
89        tool_choice: None,
90        temperature: Some(0.6),
91        max_tokens: Some(8192),
92        stream: Some(false),
93        reasoning_effort: Some("high".to_string()),
94        thinking: Some(serde_json::json!({"type": "enabled"})),
95    };
96
97    let resp = client
98        .chat(&request)
99        .await
100        .map_err(|e| anyhow::anyhow!("{e}"))?;
101
102    let choice = resp
103        .choices
104        .into_iter()
105        .next()
106        .context("No choices in DeepSeek response")?;
107
108    Ok(ReasonerOutput {
109        reasoning: choice.message.reasoning_content.unwrap_or_default(),
110        content: choice.message.content.as_str().to_string(),
111    })
112}
113
114/// Returns `true` for errors worth retrying (5xx, network). 4xx errors are not retried.
115fn is_retryable(err: &anyhow::Error) -> bool {
116    let msg = err.to_string();
117    // DeepSeekError::Api includes "API error (STATUS): …"
118    if let Some(rest) = msg.strip_prefix("API error (") {
119        if let Some(code_str) = rest.split(')').next() {
120            if let Ok(code) = code_str.parse::<u16>() {
121                return code >= 500;
122            }
123        }
124    }
125    // Network / reqwest errors are retryable
126    msg.contains("HTTP error:") || msg.contains("connection") || msg.contains("timed out")
127}
128
129/// Wrapper around [`reason`] with exponential backoff: 3 attempts, 1s → 2s → 4s.
130/// Only retries on 5xx / network errors, not 4xx.
131pub async fn reason_with_retry(
132    client: &DeepSeekClient<ReqwestClient>,
133    system: &str,
134    user: &str,
135) -> Result<ReasonerOutput> {
136    let delays = [1, 2, 4]; // seconds
137    let mut last_err = None;
138    for (attempt, &delay_secs) in std::iter::once(&0).chain(delays.iter()).enumerate() {
139        if attempt > 0 {
140            warn!("Retry attempt {attempt}/3 after {delay_secs}s backoff");
141            tokio::time::sleep(std::time::Duration::from_secs(delay_secs)).await;
142        }
143        match reason(client, system, user).await {
144            Ok(output) => return Ok(output),
145            Err(e) => {
146                if !is_retryable(&e) || attempt == 3 {
147                    return Err(e);
148                }
149                warn!("Retryable error: {e}");
150                last_err = Some(e);
151            }
152        }
153    }
154    Err(last_err.unwrap())
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes an environment variable for the duration of a test and restores
    /// its original value (including non-Unicode values) on drop, so the
    /// process environment is put back even if an assertion panics.
    ///
    /// NOTE(review): the environment is process-global and tests run in
    /// parallel by default; any other test reading the same variable may race.
    struct EnvVarGuard {
        name: &'static str,
        original: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Snapshots the current value of `name`, then removes it.
        fn remove(name: &'static str) -> Self {
            let original = std::env::var_os(name);
            std::env::remove_var(name);
            Self { name, original }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.original.take() {
                Some(value) => std::env::set_var(self.name, value),
                None => std::env::remove_var(self.name),
            }
        }
    }

    #[test]
    fn test_client_from_env_missing_key() {
        let _guard = EnvVarGuard::remove("DEEPSEEK_API_KEY");
        let err = client_from_env()
            .err()
            .expect("should error when DEEPSEEK_API_KEY is unset");
        assert!(
            err.to_string().contains("DEEPSEEK_API_KEY"),
            "error should mention the missing env var"
        );
    }

    #[test]
    fn test_is_retryable_5xx() {
        let err = anyhow::anyhow!("API error (500): Internal Server Error");
        assert!(is_retryable(&err));
    }

    #[test]
    fn test_is_retryable_503() {
        let err = anyhow::anyhow!("API error (503): Service Unavailable");
        assert!(is_retryable(&err));
    }

    #[test]
    fn test_is_not_retryable_4xx() {
        let err = anyhow::anyhow!("API error (400): Bad Request");
        assert!(!is_retryable(&err));
    }

    #[test]
    fn test_is_not_retryable_4xx_even_if_body_mentions_network() {
        // The parsed status code takes precedence over keyword matching.
        let err = anyhow::anyhow!("API error (401): connection not authorized");
        assert!(!is_retryable(&err));
    }

    #[test]
    fn test_is_retryable_network() {
        let err = anyhow::anyhow!("HTTP error: connection refused");
        assert!(is_retryable(&err));
    }

    #[test]
    fn test_is_retryable_timeout() {
        let err = anyhow::anyhow!("request timed out");
        assert!(is_retryable(&err));
    }

    #[test]
    fn test_is_not_retryable_unrelated_error() {
        let err = anyhow::anyhow!("No choices in DeepSeek response");
        assert!(!is_retryable(&err));
    }
}