Skip to main content

deepseek/
reqwest_client.rs

1use async_trait::async_trait;
2use anyhow::{Context, Result};
3use tracing::{debug, warn};
4use crate::client::{HttpClient, DeepSeekClient, DEFAULT_BASE_URL};
5use crate::error::DeepSeekError;
6use crate::types::{ChatRequest, ChatResponse, ReasonerOutput};
7
8/// Native reqwest-based transport.
9#[derive(Clone)]
10pub struct ReqwestClient {
11    client: reqwest::Client,
12}
13
14impl ReqwestClient {
15    pub fn new() -> Self {
16        Self { client: reqwest::Client::new() }
17    }
18
19    pub fn with_client(client: reqwest::Client) -> Self {
20        Self { client }
21    }
22}
23
24impl Default for ReqwestClient {
25    fn default() -> Self { Self::new() }
26}
27
28#[async_trait]
29impl HttpClient for ReqwestClient {
30    async fn post_json(&self, url: &str, bearer_token: &str, body: &ChatRequest) -> crate::error::Result<ChatResponse> {
31        let resp = self
32            .client
33            .post(url)
34            .bearer_auth(bearer_token)
35            .json(body)
36            .send()
37            .await?;
38
39        let status = resp.status();
40        if !status.is_success() {
41            let text = resp.text().await.unwrap_or_default();
42            return Err(DeepSeekError::Api { status: status.as_u16(), body: text });
43        }
44
45        let chat_resp: ChatResponse = resp.json().await?;
46        Ok(chat_resp)
47    }
48}
49
/// Model name placed in the request that [`reason`] builds.
const REASONER_MODEL: &str = "deepseek-reasoner";
51
52/// Production constructor — reads `DEEPSEEK_API_KEY` (and optionally
53/// `DEEPSEEK_BASE_URL`) from the environment.
54pub fn client_from_env() -> Result<DeepSeekClient<ReqwestClient>> {
55    let api_key = std::env::var("DEEPSEEK_API_KEY")
56        .context("DEEPSEEK_API_KEY environment variable not set")?;
57    let base_url = std::env::var("DEEPSEEK_BASE_URL")
58        .unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
59    Ok(DeepSeekClient::new(ReqwestClient::new(), api_key).with_base_url(base_url))
60}
61
62/// Thin wrapper that builds a ChatRequest for deepseek-reasoner and calls `client.chat()`.
63pub async fn reason(
64    client: &DeepSeekClient<ReqwestClient>,
65    system: &str,
66    user: &str,
67) -> Result<ReasonerOutput> {
68    debug!("deepseek-reasoner call: system={:.80}…", system);
69
70    let request = ChatRequest {
71        model: REASONER_MODEL.to_string(),
72        messages: vec![
73            crate::types::system_msg(system),
74            crate::types::user_msg(user),
75        ],
76        tools: None,
77        tool_choice: None,
78        temperature: Some(0.6),
79        max_tokens: Some(8192),
80        stream: Some(false),
81        reasoning_effort: Some("high".to_string()),
82        thinking: Some(serde_json::json!({"type": "enabled"})),
83    };
84
85    let resp = client
86        .chat(&request)
87        .await
88        .map_err(|e| anyhow::anyhow!("{e}"))?;
89
90    let choice = resp
91        .choices
92        .into_iter()
93        .next()
94        .context("No choices in DeepSeek response")?;
95
96    Ok(ReasonerOutput {
97        reasoning: choice.message.reasoning_content.unwrap_or_default(),
98        content: choice.message.content.as_str().to_string(),
99    })
100}
101
102/// Returns `true` for errors worth retrying (5xx, network). 4xx errors are not retried.
103fn is_retryable(err: &anyhow::Error) -> bool {
104    let msg = err.to_string();
105    // DeepSeekError::Api includes "API error (STATUS): …"
106    if let Some(rest) = msg.strip_prefix("API error (") {
107        if let Some(code_str) = rest.split(')').next() {
108            if let Ok(code) = code_str.parse::<u16>() {
109                return code >= 500;
110            }
111        }
112    }
113    // Network / reqwest errors are retryable
114    msg.contains("HTTP error:") || msg.contains("connection") || msg.contains("timed out")
115}
116
117/// Wrapper around [`reason`] with exponential backoff: 3 attempts, 1s → 2s → 4s.
118/// Only retries on 5xx / network errors, not 4xx.
119pub async fn reason_with_retry(
120    client: &DeepSeekClient<ReqwestClient>,
121    system: &str,
122    user: &str,
123) -> Result<ReasonerOutput> {
124    let delays = [1, 2, 4]; // seconds
125    let mut last_err = None;
126    for (attempt, &delay_secs) in std::iter::once(&0).chain(delays.iter()).enumerate() {
127        if attempt > 0 {
128            warn!("Retry attempt {attempt}/3 after {delay_secs}s backoff");
129            tokio::time::sleep(std::time::Duration::from_secs(delay_secs)).await;
130        }
131        match reason(client, system, user).await {
132            Ok(output) => return Ok(output),
133            Err(e) => {
134                if !is_retryable(&e) || attempt == 3 {
135                    return Err(e);
136                }
137                warn!("Retryable error: {e}");
138                last_err = Some(e);
139            }
140        }
141    }
142    Err(last_err.unwrap())
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;

    /// Removes an environment variable for the lifetime of the guard.
    ///
    /// The previous value is restored on drop, including when the test
    /// panics, so a failed assertion cannot leak a modified environment into
    /// other tests in the same process.
    struct UnsetEnvVar {
        key: &'static str,
        // Captured with `var_os` rather than `var`, so a non-UTF-8 value is
        // restored too.
        original: Option<OsString>,
    }

    impl UnsetEnvVar {
        fn new(key: &'static str) -> Self {
            let original = std::env::var_os(key);
            std::env::remove_var(key);
            Self { key, original }
        }
    }

    impl Drop for UnsetEnvVar {
        fn drop(&mut self) {
            if let Some(value) = self.original.take() {
                std::env::set_var(self.key, value);
            }
        }
    }

    #[test]
    fn test_client_from_env_missing_key() {
        let _unset = UnsetEnvVar::new("DEEPSEEK_API_KEY");
        let err = client_from_env()
            .err()
            .expect("should error when DEEPSEEK_API_KEY is unset");
        assert!(
            err.to_string().contains("DEEPSEEK_API_KEY"),
            "error should mention the missing env var"
        );
    }

    #[test]
    fn test_is_retryable_5xx() {
        let err = anyhow::anyhow!("API error (500): Internal Server Error");
        assert!(is_retryable(&err));
    }

    #[test]
    fn test_is_retryable_other_5xx() {
        assert!(is_retryable(&anyhow::anyhow!("API error (502): Bad Gateway")));
        assert!(is_retryable(&anyhow::anyhow!("API error (503): Service Unavailable")));
    }

    #[test]
    fn test_is_not_retryable_4xx() {
        let err = anyhow::anyhow!("API error (400): Bad Request");
        assert!(!is_retryable(&err));
    }

    #[test]
    fn test_is_not_retryable_other_4xx() {
        assert!(!is_retryable(&anyhow::anyhow!("API error (401): Unauthorized")));
        assert!(!is_retryable(&anyhow::anyhow!("API error (404): Not Found")));
    }

    #[test]
    fn test_is_retryable_network() {
        let err = anyhow::anyhow!("HTTP error: connection refused");
        assert!(is_retryable(&err));
    }

    #[test]
    fn test_is_retryable_timeout() {
        assert!(is_retryable(&anyhow::anyhow!("request timed out")));
    }

    #[test]
    fn test_is_not_retryable_missing_choices() {
        assert!(!is_retryable(&anyhow::anyhow!("No choices in DeepSeek response")));
    }
}