// devsper_providers/openai.rs
use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;
7
/// LLM provider speaking the OpenAI chat-completions wire protocol.
/// Also used for other OpenAI-compatible backends (see [`OpenAiProvider::zai`]).
pub struct OpenAiProvider {
    client: Client,
    api_key: String,
    base_url: String, // no trailing slash; "/v1/chat/completions" is appended in `generate`
    name: String,     // provider identifier reported by `name()`, e.g. "openai" or "zai"
}
15
16impl OpenAiProvider {
17 pub fn new(api_key: impl Into<String>) -> Self {
18 Self {
19 client: Client::new(),
20 api_key: api_key.into(),
21 base_url: "https://api.openai.com".to_string(),
22 name: "openai".to_string(),
23 }
24 }
25
26 pub fn zai(api_key: impl Into<String>) -> Self {
28 Self {
29 client: Client::new(),
30 api_key: api_key.into(),
31 base_url: "https://api.zai.ai".to_string(),
32 name: "zai".to_string(),
33 }
34 }
35
36 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
37 self.base_url = url.into();
38 self
39 }
40}
41
/// Wire-format request body for `POST /v1/chat/completions`.
/// Borrows from the incoming `LlmRequest` to avoid copying message text.
#[derive(Serialize)]
struct OaiRequest<'a> {
    model: &'a str,
    messages: Vec<OaiMessage<'a>>,
    // Omitted from the JSON entirely when unset, so the API default applies.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
51
/// One chat message in OpenAI wire format.
#[derive(Serialize)]
struct OaiMessage<'a> {
    role: &'a str, // "system" | "user" | "assistant" — produced by `role_str`
    content: &'a str,
}
57
/// Subset of the chat-completions response body that this provider consumes.
#[derive(Deserialize)]
struct OaiResponse {
    choices: Vec<OaiChoice>,
    // NOTE(review): `usage` is a required field here — deserialization fails if
    // a compatible backend omits it. Confirm all targeted backends send usage.
    usage: OaiUsage,
    model: String,
}
64
/// One completion candidate from the response.
#[derive(Deserialize)]
struct OaiChoice {
    message: OaiChoiceMessage,
    finish_reason: Option<String>, // mapped to `StopReason` in `generate`
}
70
/// Assistant message inside a choice.
#[derive(Deserialize)]
struct OaiChoiceMessage {
    content: Option<String>, // may be JSON null; treated as empty string downstream
}
75
/// Token accounting reported by the API.
#[derive(Deserialize)]
struct OaiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
81
82fn role_str(role: &LlmRole) -> &'static str {
83 match role {
84 LlmRole::System => "system",
85 LlmRole::User | LlmRole::Tool => "user",
86 LlmRole::Assistant => "assistant",
87 }
88}
89
90#[async_trait]
91impl LlmProvider for OpenAiProvider {
92 async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
93 let messages: Vec<OaiMessage> = req
94 .messages
95 .iter()
96 .map(|m| OaiMessage {
97 role: role_str(&m.role),
98 content: &m.content,
99 })
100 .collect();
101
102 let body = OaiRequest {
103 model: &req.model,
104 messages,
105 max_tokens: req.max_tokens,
106 temperature: req.temperature,
107 };
108
109 debug!(model = %req.model, provider = %self.name, "OpenAI-compatible request");
110
111 let resp = self
112 .client
113 .post(format!("{}/v1/chat/completions", self.base_url))
114 .header("Authorization", format!("Bearer {}", self.api_key))
115 .header("Content-Type", "application/json")
116 .json(&body)
117 .send()
118 .await?;
119
120 if !resp.status().is_success() {
121 let status = resp.status();
122 let text = resp.text().await.unwrap_or_default();
123 return Err(anyhow!("{} API error {status}: {text}", self.name));
124 }
125
126 let data: OaiResponse = resp.json().await?;
127 let choice = data
128 .choices
129 .into_iter()
130 .next()
131 .ok_or_else(|| anyhow!("No choices in response"))?;
132
133 let stop_reason = match choice.finish_reason.as_deref() {
134 Some("tool_calls") => StopReason::ToolUse,
135 Some("length") => StopReason::MaxTokens,
136 Some("stop") | None => StopReason::EndTurn,
137 _ => StopReason::EndTurn,
138 };
139
140 Ok(LlmResponse {
141 content: choice.message.content.unwrap_or_default(),
142 tool_calls: vec![],
143 input_tokens: data.usage.prompt_tokens,
144 output_tokens: data.usage.completion_tokens,
145 model: data.model,
146 stop_reason,
147 })
148 }
149
150 fn name(&self) -> &str {
151 &self.name
152 }
153
154 fn supports_model(&self, model: &str) -> bool {
155 match self.name.as_str() {
156 "zai" => model.starts_with("zai:") || model.starts_with("glm-"),
157 _ => {
158 model.starts_with("gpt-")
159 || model.starts_with("o1")
160 || model.starts_with("o3")
161 }
162 }
163 }
164}