// devsper_providers/openai.rs
use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use tracing::debug;
7
/// LLM provider speaking the OpenAI chat-completions wire protocol.
///
/// Also used for OpenAI-compatible backends (see [`OpenAiProvider::zai`]);
/// the `name` field distinguishes which backend an instance targets.
pub struct OpenAiProvider {
    /// Shared HTTP client, reused across requests for connection pooling.
    client: Client,
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// API origin (e.g. `https://api.openai.com`); `/v1/chat/completions`
    /// is appended at request time.
    base_url: String,
    /// Provider label reported by `name()` and used in logs and spans.
    name: String,
}
15
16impl OpenAiProvider {
17 pub fn new(api_key: impl Into<String>) -> Self {
18 Self {
19 client: Client::new(),
20 api_key: api_key.into(),
21 base_url: "https://api.openai.com".to_string(),
22 name: "openai".to_string(),
23 }
24 }
25
26 pub fn zai(api_key: impl Into<String>) -> Self {
28 Self {
29 client: Client::new(),
30 api_key: api_key.into(),
31 base_url: "https://api.zai.ai".to_string(),
32 name: "zai".to_string(),
33 }
34 }
35
36 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
37 self.base_url = url.into();
38 self
39 }
40}
41
/// JSON body for a `/v1/chat/completions` request.
///
/// Borrows the model name and message contents from the caller's request
/// to avoid copies. Optional fields are omitted from the payload entirely
/// when unset, rather than being serialized as `null`.
#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
51
/// One chat message in the outgoing request: a wire-format role string
/// (see `role_str`) plus borrowed message text.
#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str,
    content: &'a str,
}
57
/// Top-level chat-completions response body. Only the fields this
/// provider consumes are deserialized; unknown fields are ignored by serde.
/// NOTE(review): `usage` and `model` are required here — an OpenAI-compatible
/// backend that omits them will fail deserialization; confirm all targeted
/// backends always send them.
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    usage: OaiUsage,
    model: String,
}
64
/// One completion choice; only the first is used by `generate`.
#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    /// Why the model stopped ("stop", "length", "tool_calls", …);
    /// optional because some compatible backends omit it.
    finish_reason: Option<String>,
}
70
/// Assistant message inside a choice. `content` can be null/absent
/// (e.g. when the model emits only tool calls), hence `Option`.
#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>,
}
75
/// Token accounting block from the response's `usage` object.
#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
81
82fn role_str(role: &LlmRole) -> &'static str {
83 match role {
84 LlmRole::System => "system",
85 LlmRole::User | LlmRole::Tool => "user",
86 LlmRole::Assistant => "assistant",
87 }
88}
89
90#[async_trait]
91impl LlmProvider for OpenAiProvider {
92 async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
93 use tracing::Instrument;
94
95 let span = tracing::info_span!(
96 "gen_ai.chat",
97 "gen_ai.system" = self.name(),
98 "gen_ai.operation.name" = "chat",
99 "gen_ai.request.model" = req.model.as_str(),
100 "gen_ai.request.max_tokens" = req.max_tokens,
101 "gen_ai.response.model" = tracing::field::Empty,
102 "gen_ai.usage.input_tokens" = tracing::field::Empty,
103 "gen_ai.usage.output_tokens" = tracing::field::Empty,
104 );
105
106 let messages: Vec<OaiMessage> = req
107 .messages
108 .iter()
109 .map(|m| OaiMessage {
110 role: role_str(&m.role),
111 content: &m.content,
112 })
113 .collect();
114
115 let body = OaiRequest {
116 model: &req.model,
117 messages,
118 max_tokens: req.max_tokens,
119 temperature: req.temperature,
120 };
121
122 debug!(model = %req.model, provider = %self.name, "OpenAI-compatible request");
123
124 let result = async {
125 let resp = self
126 .client
127 .post(format!("{}/v1/chat/completions", self.base_url))
128 .header("Authorization", format!("Bearer {}", self.api_key))
129 .header("Content-Type", "application/json")
130 .json(&body)
131 .send()
132 .await?;
133
134 if !resp.status().is_success() {
135 let status = resp.status();
136 let text = resp.text().await.unwrap_or_default();
137 return Err(anyhow!("{} API error {status}: {text}", self.name));
138 }
139
140 let data: OaiResponse = resp.json().await?;
141 let choice = data
142 .choices
143 .into_iter()
144 .next()
145 .ok_or_else(|| anyhow!("No choices in response"))?;
146
147 let stop_reason = match choice.finish_reason.as_deref() {
148 Some("tool_calls") => StopReason::ToolUse,
149 Some("length") => StopReason::MaxTokens,
150 Some("stop") | None => StopReason::EndTurn,
151 _ => StopReason::EndTurn,
152 };
153
154 Ok(LlmResponse {
155 content: choice.message.content.unwrap_or_default(),
156 tool_calls: vec![],
157 input_tokens: data.usage.prompt_tokens,
158 output_tokens: data.usage.completion_tokens,
159 model: data.model,
160 stop_reason,
161 })
162 }
163 .instrument(span.clone())
164 .await;
165
166 if let Ok(ref resp) = result {
167 span.record("gen_ai.response.model", resp.model.as_str());
168 span.record("gen_ai.usage.input_tokens", resp.input_tokens);
169 span.record("gen_ai.usage.output_tokens", resp.output_tokens);
170 }
171 result
172 }
173
174 fn name(&self) -> &str {
175 &self.name
176 }
177
178 fn supports_model(&self, model: &str) -> bool {
179 match self.name.as_str() {
180 "zai" => model.starts_with("zai:") || model.starts_with("glm-"),
181 _ => {
182 model.starts_with("gpt-")
183 || model.starts_with("o1")
184 || model.starts_with("o3")
185 }
186 }
187 }
188}