Skip to main content

devsper_providers/
openai.rs

1use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
/// OpenAI-compatible provider. Also used for ZAI (same API).
pub struct OpenAiProvider {
    /// HTTP client used to send every chat-completions request.
    client: Client,
    /// Secret sent as `Bearer <api_key>` in the `Authorization` header.
    api_key: String,
    /// URL prefix; `/v1/chat/completions` is appended to it when a request is built.
    base_url: String,
    /// Provider label (`"openai"` or `"zai"`). Returned by `name()`, used in error
    /// messages, and used by `supports_model` to pick the accepted model prefixes.
    name: String,
}
15
16impl OpenAiProvider {
17    pub fn new(api_key: impl Into<String>) -> Self {
18        Self {
19            client: Client::new(),
20            api_key: api_key.into(),
21            base_url: "https://api.openai.com".to_string(),
22            name: "openai".to_string(),
23        }
24    }
25
26    /// ZAI uses OpenAI-compatible API at a different endpoint.
27    pub fn zai(api_key: impl Into<String>) -> Self {
28        Self {
29            client: Client::new(),
30            api_key: api_key.into(),
31            base_url: "https://api.zai.ai".to_string(),
32            name: "zai".to_string(),
33        }
34    }
35
36    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
37        self.base_url = url.into();
38        self
39    }
40}
41
/// JSON body posted to the `/v1/chat/completions` endpoint.
///
/// Borrows the model name and message text from the caller's request instead
/// of copying them.
#[derive(Serialize)]
struct OaiRequest<'a> {
    /// Model identifier, forwarded unchanged from the incoming request.
    model: &'a str,
    /// Conversation messages, in the same order as in the incoming request.
    messages: Vec<OaiMessage<'a>>,
    /// Left out of the serialized body entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Left out of the serialized body entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
51
/// One entry of the `messages` array in an outgoing request.
#[derive(Serialize)]
struct OaiMessage<'a> {
    /// Role label produced by `role_str`: `"system"`, `"user"` or `"assistant"`.
    role: &'a str,
    /// Message text, borrowed from the incoming request.
    content: &'a str,
}
57
/// The parts of a chat-completions response body that this provider reads.
#[derive(Deserialize)]
struct OaiResponse {
    /// Returned completions; only the first one is used by `generate`.
    choices: Vec<OaiChoice>,
    /// Token accounting. Declared as a required field (no `Option`/default), so a
    /// body that omits it is expected to fail deserialization.
    usage: OaiUsage,
    /// Model name reported by the server; copied into the returned response.
    model: String,
}
64
/// A single completion candidate from the response's `choices` array.
#[derive(Deserialize)]
struct OaiChoice {
    /// The generated message.
    message: OaiChoiceMessage,
    /// Why generation stopped. `generate` maps `"tool_calls"` and `"length"` to
    /// dedicated stop reasons and treats everything else, including `None`, as
    /// end of turn.
    finish_reason: Option<String>,
}
70
/// Message payload inside a response choice.
#[derive(Deserialize)]
struct OaiChoiceMessage {
    /// Generated text. `generate` substitutes an empty string when this is `None`.
    content: Option<String>,
}
75
/// Token counts reported in the response's `usage` object.
#[derive(Deserialize)]
struct OaiUsage {
    /// Reported by `generate` as the response's `input_tokens`.
    prompt_tokens: u32,
    /// Reported by `generate` as the response's `output_tokens`.
    completion_tokens: u32,
}
81
82fn role_str(role: &LlmRole) -> &'static str {
83    match role {
84        LlmRole::System => "system",
85        LlmRole::User | LlmRole::Tool => "user",
86        LlmRole::Assistant => "assistant",
87    }
88}
89
90#[async_trait]
91impl LlmProvider for OpenAiProvider {
92    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
93        let messages: Vec<OaiMessage> = req
94            .messages
95            .iter()
96            .map(|m| OaiMessage {
97                role: role_str(&m.role),
98                content: &m.content,
99            })
100            .collect();
101
102        let body = OaiRequest {
103            model: &req.model,
104            messages,
105            max_tokens: req.max_tokens,
106            temperature: req.temperature,
107        };
108
109        debug!(model = %req.model, provider = %self.name, "OpenAI-compatible request");
110
111        let resp = self
112            .client
113            .post(format!("{}/v1/chat/completions", self.base_url))
114            .header("Authorization", format!("Bearer {}", self.api_key))
115            .header("Content-Type", "application/json")
116            .json(&body)
117            .send()
118            .await?;
119
120        if !resp.status().is_success() {
121            let status = resp.status();
122            let text = resp.text().await.unwrap_or_default();
123            return Err(anyhow!("{} API error {status}: {text}", self.name));
124        }
125
126        let data: OaiResponse = resp.json().await?;
127        let choice = data
128            .choices
129            .into_iter()
130            .next()
131            .ok_or_else(|| anyhow!("No choices in response"))?;
132
133        let stop_reason = match choice.finish_reason.as_deref() {
134            Some("tool_calls") => StopReason::ToolUse,
135            Some("length") => StopReason::MaxTokens,
136            Some("stop") | None => StopReason::EndTurn,
137            _ => StopReason::EndTurn,
138        };
139
140        Ok(LlmResponse {
141            content: choice.message.content.unwrap_or_default(),
142            tool_calls: vec![],
143            input_tokens: data.usage.prompt_tokens,
144            output_tokens: data.usage.completion_tokens,
145            model: data.model,
146            stop_reason,
147        })
148    }
149
150    fn name(&self) -> &str {
151        &self.name
152    }
153
154    fn supports_model(&self, model: &str) -> bool {
155        match self.name.as_str() {
156            "zai" => model.starts_with("zai:") || model.starts_with("glm-"),
157            _ => {
158                model.starts_with("gpt-")
159                    || model.starts_with("o1")
160                    || model.starts_with("o3")
161            }
162        }
163    }
164}