// devsper_providers/openai.rs

1use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
/// OpenAI-compatible provider. Also used for ZAI (same API).
pub struct OpenAiProvider {
    // Shared reqwest HTTP client, reused across requests (connection pooling).
    client: Client,
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // API root, e.g. "https://api.openai.com"; "/v1/chat/completions" is appended.
    base_url: String,
    // Provider label ("openai" or "zai"): used in logs/spans and model routing.
    name: String,
}
15
16impl OpenAiProvider {
17    pub fn new(api_key: impl Into<String>) -> Self {
18        Self {
19            client: Client::new(),
20            api_key: api_key.into(),
21            base_url: "https://api.openai.com".to_string(),
22            name: "openai".to_string(),
23        }
24    }
25
26    /// ZAI uses OpenAI-compatible API at a different endpoint.
27    pub fn zai(api_key: impl Into<String>) -> Self {
28        Self {
29            client: Client::new(),
30            api_key: api_key.into(),
31            base_url: "https://api.zai.ai".to_string(),
32            name: "zai".to_string(),
33        }
34    }
35
36    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
37        self.base_url = url.into();
38        self
39    }
40}
41
/// Wire-format request body for POST /v1/chat/completions.
/// Borrows from the caller's `LlmRequest` to avoid cloning strings.
#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    // Omitted from the JSON entirely when None, so the server default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    // Omitted from the JSON entirely when None, so the server default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
51
/// One chat message in the request body; `role` is produced by `role_str`.
#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str,
    content: &'a str,
}
57
/// Top-level chat-completions response. Unrecognized JSON fields are ignored
/// by serde's default behavior; `usage` is required — deserialization fails if
/// the endpoint omits it.
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    usage: OaiUsage,
    // Model id as reported by the server (may differ from the requested alias).
    model: String,
}
64
/// One completion candidate; only the first is consumed by `generate`.
#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    // e.g. "stop", "length", "tool_calls"; None tolerated and treated as "stop".
    finish_reason: Option<String>,
}
70
/// Assistant message inside a choice. `content` can be absent/null
/// (e.g. on tool-call responses), hence Option.
#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}
75
/// Token accounting block from the response; mapped onto
/// `LlmResponse::{input_tokens, output_tokens}`.
#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
81
82fn role_str(role: &LlmRole) -> &'static str {
83    match role {
84        LlmRole::System => "system",
85        LlmRole::User | LlmRole::Tool => "user",
86        LlmRole::Assistant => "assistant",
87    }
88}
89
#[async_trait]
impl LlmProvider for OpenAiProvider {
    /// Execute one non-streaming chat completion against
    /// `{base_url}/v1/chat/completions` and map the first choice into an
    /// `LlmResponse`.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx HTTP status (the response body text
    /// is included in the error), a response with no choices, or a body that
    /// does not deserialize into `OaiResponse`.
    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
        use tracing::Instrument;

        // Span field names follow the OpenTelemetry GenAI semantic conventions.
        // Response-side fields start Empty and are backfilled after the call.
        let span = tracing::info_span!(
            "gen_ai.chat",
            "gen_ai.system" = self.name(),
            "gen_ai.operation.name" = "chat",
            "gen_ai.request.model" = req.model.as_str(),
            "gen_ai.request.max_tokens" = req.max_tokens,
            "gen_ai.response.model" = tracing::field::Empty,
            "gen_ai.usage.input_tokens" = tracing::field::Empty,
            "gen_ai.usage.output_tokens" = tracing::field::Empty,
        );

        // Borrow message contents into the wire structs — no cloning, since
        // `req` outlives `body`.
        let messages: Vec<OaiMessage> = req
            .messages
            .iter()
            .map(|m| OaiMessage {
                role: role_str(&m.role),
                content: &m.content,
            })
            .collect();

        let body = OaiRequest {
            model: &req.model,
            messages,
            max_tokens: req.max_tokens,
            temperature: req.temperature,
        };

        debug!(model = %req.model, provider = %self.name, "OpenAI-compatible request");

        // The whole HTTP round trip runs inside `span`; the handle is cloned
        // so fields can still be recorded after the instrumented future ends.
        let result = async {
            let resp = self
                .client
                .post(format!("{}/v1/chat/completions", self.base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .header("Content-Type", "application/json")
                .json(&body)
                .send()
                .await?;

            if !resp.status().is_success() {
                let status = resp.status();
                // Include the body text: OpenAI-style APIs put the useful
                // error detail there, not in the status line. Best-effort —
                // an unreadable body degrades to an empty string.
                let text = resp.text().await.unwrap_or_default();
                return Err(anyhow!("{} API error {status}: {text}", self.name));
            }

            let data: OaiResponse = resp.json().await?;
            // Only the first choice is consumed; any extra choices are dropped.
            let choice = data
                .choices
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("No choices in response"))?;

            // Unknown or missing finish reasons deliberately fall back to EndTurn.
            let stop_reason = match choice.finish_reason.as_deref() {
                Some("tool_calls") => StopReason::ToolUse,
                Some("length") => StopReason::MaxTokens,
                Some("stop") | None => StopReason::EndTurn,
                _ => StopReason::EndTurn,
            };

            Ok(LlmResponse {
                // Null/absent content (e.g. a pure tool-call reply) becomes "".
                content: choice.message.content.unwrap_or_default(),
                // NOTE(review): tool calls are never parsed out of the response
                // even though "tool_calls" maps to StopReason::ToolUse above —
                // callers see ToolUse with an empty tool_calls list. Confirm
                // whether that is intentional.
                tool_calls: vec![],
                input_tokens: data.usage.prompt_tokens,
                output_tokens: data.usage.completion_tokens,
                model: data.model,
                stop_reason,
            })
        }
        .instrument(span.clone())
        .await;

        // Backfill response-side span fields only on success; on error they
        // remain Empty.
        if let Ok(ref resp) = result {
            span.record("gen_ai.response.model", resp.model.as_str());
            span.record("gen_ai.usage.input_tokens", resp.input_tokens);
            span.record("gen_ai.usage.output_tokens", resp.output_tokens);
        }
        result
    }

    /// Provider label ("openai" or "zai").
    fn name(&self) -> &str {
        &self.name
    }

    /// Cheap prefix-based routing check for which model ids this provider claims.
    fn supports_model(&self, model: &str) -> bool {
        match self.name.as_str() {
            "zai" => model.starts_with("zai:") || model.starts_with("glm-"),
            _ => {
                model.starts_with("gpt-")
                    || model.starts_with("o1")
                    || model.starts_with("o3")
            }
        }
    }
}