1use async_trait::async_trait;
4use futures::{Stream, StreamExt};
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use std::pin::Pin;
8
9use crate::traits::{
10 ChunkType, CompletionRequest, CompletionResponse, ContentBlock, MessageContent, Provider,
11 ProviderError, Role, StopReason, StreamingChunk,
12};
13use openclaw_core::secrets::ApiKey;
14use openclaw_core::types::TokenUsage;
15
/// Default API origin used by [`OpenAIProvider::new`]; endpoint paths such as
/// `/v1/models` and `/v1/chat/completions` are appended to it.
const DEFAULT_BASE_URL: &str = "https://api.openai.com";
17
/// [`Provider`] implementation that talks to the OpenAI HTTP API, or to whatever
/// endpoint was supplied through [`OpenAIProvider::with_base_url`].
pub struct OpenAIProvider {
    /// HTTP client shared by every request this provider sends.
    client: Client,
    /// Secret sent as a `Bearer` token in the `Authorization` header.
    api_key: ApiKey,
    /// Origin that endpoint paths (`/v1/...`) are appended to.
    base_url: String,
    /// Optional organization id, sent as the `OpenAI-Organization` header when set.
    org_id: Option<String>,
}
25
26impl OpenAIProvider {
27 #[must_use]
29 pub fn new(api_key: ApiKey) -> Self {
30 Self {
31 client: Client::new(),
32 api_key,
33 base_url: DEFAULT_BASE_URL.to_string(),
34 org_id: None,
35 }
36 }
37
38 #[must_use]
40 pub fn with_base_url(api_key: ApiKey, base_url: impl Into<String>) -> Self {
41 Self {
42 client: Client::new(),
43 api_key,
44 base_url: base_url.into(),
45 org_id: None,
46 }
47 }
48
49 #[must_use]
51 pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
52 self.org_id = Some(org_id.into());
53 self
54 }
55
56 fn to_openai_request(&self, request: &CompletionRequest) -> OpenAIRequest {
58 let mut messages: Vec<OpenAIMessage> = Vec::new();
59
60 if let Some(system) = &request.system {
62 messages.push(OpenAIMessage {
63 role: "system".to_string(),
64 content: Some(OpenAIContent::Text(system.clone())),
65 tool_calls: None,
66 tool_call_id: None,
67 });
68 }
69
70 for msg in &request.messages {
72 let openai_msg = match msg.role {
73 Role::System => OpenAIMessage {
74 role: "system".to_string(),
75 content: Some(content_to_openai(&msg.content)),
76 tool_calls: None,
77 tool_call_id: None,
78 },
79 Role::User => OpenAIMessage {
80 role: "user".to_string(),
81 content: Some(content_to_openai(&msg.content)),
82 tool_calls: None,
83 tool_call_id: None,
84 },
85 Role::Assistant => {
86 let (content, tool_calls) = extract_tool_calls(&msg.content);
87 OpenAIMessage {
88 role: "assistant".to_string(),
89 content,
90 tool_calls,
91 tool_call_id: None,
92 }
93 }
94 Role::Tool => {
95 if let MessageContent::Blocks(blocks) = &msg.content {
96 for block in blocks {
97 if let ContentBlock::ToolResult {
98 tool_use_id,
99 content,
100 ..
101 } = block
102 {
103 messages.push(OpenAIMessage {
104 role: "tool".to_string(),
105 content: Some(OpenAIContent::Text(content.clone())),
106 tool_calls: None,
107 tool_call_id: Some(tool_use_id.clone()),
108 });
109 }
110 }
111 }
112 continue;
113 }
114 };
115 messages.push(openai_msg);
116 }
117
118 let tools = request.tools.as_ref().map(|tools| {
119 tools
120 .iter()
121 .map(|t| OpenAITool {
122 tool_type: "function".to_string(),
123 function: OpenAIFunction {
124 name: t.name.clone(),
125 description: t.description.clone(),
126 parameters: t.input_schema.clone(),
127 },
128 })
129 .collect()
130 });
131
132 OpenAIRequest {
133 model: request.model.clone(),
134 messages,
135 max_tokens: Some(request.max_tokens),
136 temperature: Some(request.temperature),
137 stop: request.stop.clone(),
138 tools,
139 stream: Some(false),
140 }
141 }
142}
143
144fn content_to_openai(content: &MessageContent) -> OpenAIContent {
145 match content {
146 MessageContent::Text(text) => OpenAIContent::Text(text.clone()),
147 MessageContent::Blocks(blocks) => {
148 let parts: Vec<OpenAIContentPart> = blocks
149 .iter()
150 .filter_map(|b| match b {
151 ContentBlock::Text { text } => {
152 Some(OpenAIContentPart::Text { text: text.clone() })
153 }
154 ContentBlock::Image { source } => Some(OpenAIContentPart::ImageUrl {
155 image_url: OpenAIImageUrl {
156 url: format!("data:{};base64,{}", source.media_type, source.data),
157 },
158 }),
159 _ => None,
160 })
161 .collect();
162 OpenAIContent::Parts(parts)
163 }
164 }
165}
166
167fn extract_tool_calls(
168 content: &MessageContent,
169) -> (Option<OpenAIContent>, Option<Vec<OpenAIToolCall>>) {
170 match content {
171 MessageContent::Text(text) => (Some(OpenAIContent::Text(text.clone())), None),
172 MessageContent::Blocks(blocks) => {
173 let mut text_parts = Vec::new();
174 let mut tool_calls = Vec::new();
175
176 for block in blocks {
177 match block {
178 ContentBlock::Text { text } => text_parts.push(text.clone()),
179 ContentBlock::ToolUse { id, name, input } => {
180 tool_calls.push(OpenAIToolCall {
181 id: id.clone(),
182 call_type: "function".to_string(),
183 function: OpenAIFunctionCall {
184 name: name.clone(),
185 arguments: serde_json::to_string(input).unwrap_or_default(),
186 },
187 });
188 }
189 _ => {}
190 }
191 }
192
193 let content = if text_parts.is_empty() {
194 None
195 } else {
196 Some(OpenAIContent::Text(text_parts.join("\n")))
197 };
198
199 let tool_calls = if tool_calls.is_empty() {
200 None
201 } else {
202 Some(tool_calls)
203 };
204
205 (content, tool_calls)
206 }
207 }
208}
209
210#[async_trait]
211impl Provider for OpenAIProvider {
212 fn name(&self) -> &'static str {
213 "openai"
214 }
215
216 async fn list_models(&self) -> Result<Vec<String>, ProviderError> {
217 let url = format!("{}/v1/models", self.base_url);
218
219 let mut req = self
220 .client
221 .get(&url)
222 .header("Authorization", format!("Bearer {}", self.api_key.expose()));
223
224 if let Some(org) = &self.org_id {
225 req = req.header("OpenAI-Organization", org);
226 }
227
228 let response = req.send().await?;
229
230 if !response.status().is_success() {
231 let status = response.status().as_u16();
232 let message = response.text().await.unwrap_or_default();
233 return Err(ProviderError::Api { status, message });
234 }
235
236 let result: OpenAIModelsResponse = response.json().await?;
237 Ok(result.data.into_iter().map(|m| m.id).collect())
238 }
239
240 async fn complete(
241 &self,
242 request: CompletionRequest,
243 ) -> Result<CompletionResponse, ProviderError> {
244 let url = format!("{}/v1/chat/completions", self.base_url);
245 let openai_request = self.to_openai_request(&request);
246
247 let mut req = self
248 .client
249 .post(&url)
250 .header("Authorization", format!("Bearer {}", self.api_key.expose()))
251 .header("Content-Type", "application/json");
252
253 if let Some(org) = &self.org_id {
254 req = req.header("OpenAI-Organization", org);
255 }
256
257 let response = req.json(&openai_request).send().await?;
258
259 if !response.status().is_success() {
260 let status = response.status().as_u16();
261
262 if status == 429 {
263 let retry_after = response
264 .headers()
265 .get("retry-after")
266 .and_then(|v| v.to_str().ok())
267 .and_then(|v| v.parse().ok())
268 .unwrap_or(60);
269 return Err(ProviderError::RateLimited {
270 retry_after_secs: retry_after,
271 });
272 }
273
274 let message = response.text().await.unwrap_or_default();
275 return Err(ProviderError::Api { status, message });
276 }
277
278 let result: OpenAIResponse = response.json().await?;
279 Ok(result.into())
280 }
281
282 async fn complete_stream(
283 &self,
284 request: CompletionRequest,
285 ) -> Result<
286 Pin<Box<dyn Stream<Item = Result<StreamingChunk, ProviderError>> + Send>>,
287 ProviderError,
288 > {
289 let url = format!("{}/v1/chat/completions", self.base_url);
290 let mut openai_request = self.to_openai_request(&request);
291 openai_request.stream = Some(true);
292
293 let mut req = self
294 .client
295 .post(&url)
296 .header("Authorization", format!("Bearer {}", self.api_key.expose()))
297 .header("Content-Type", "application/json");
298
299 if let Some(org) = &self.org_id {
300 req = req.header("OpenAI-Organization", org);
301 }
302
303 let response = req.json(&openai_request).send().await?;
304
305 if !response.status().is_success() {
306 let status = response.status().as_u16();
307 let message = response.text().await.unwrap_or_default();
308 return Err(ProviderError::Api { status, message });
309 }
310
311 let stream = response.bytes_stream().map(move |result| match result {
312 Ok(bytes) => {
313 let text = String::from_utf8_lossy(&bytes);
314 parse_sse_event(&text)
315 }
316 Err(e) => Err(ProviderError::Network(e)),
317 });
318
319 Ok(Box::pin(stream))
320 }
321}
322
323fn parse_sse_event(text: &str) -> Result<StreamingChunk, ProviderError> {
324 for line in text.lines() {
325 if let Some(data) = line.strip_prefix("data: ") {
326 if data == "[DONE]" {
327 return Ok(StreamingChunk {
328 chunk_type: ChunkType::MessageStop,
329 delta: None,
330 index: None,
331 });
332 }
333
334 if let Ok(event) = serde_json::from_str::<OpenAIStreamEvent>(data) {
335 if let Some(choice) = event.choices.first() {
336 return Ok(StreamingChunk {
337 chunk_type: if choice.finish_reason.is_some() {
338 ChunkType::MessageStop
339 } else {
340 ChunkType::ContentBlockDelta
341 },
342 delta: choice.delta.content.clone(),
343 index: Some(choice.index),
344 });
345 }
346 }
347 }
348 }
349
350 Ok(StreamingChunk {
351 chunk_type: ChunkType::ContentBlockDelta,
352 delta: None,
353 index: None,
354 })
355}
356
/// JSON body for `POST /v1/chat/completions`, built by `OpenAIProvider::to_openai_request`.
///
/// Fields set to `None` are omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    // Some(false) for `complete`, Some(true) for `complete_stream`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

/// One entry of `OpenAIRequest::messages`.
///
/// `role` is one of `"system"`, `"user"`, `"assistant"` or `"tool"` as produced by
/// `to_openai_request`. `content` is `None` for assistant messages that only carry
/// tool calls. `tool_call_id` is set only on `"tool"` messages.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAIContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
385
/// Message content: serialized (untagged) either as a plain JSON string or as an
/// array of typed parts.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum OpenAIContent {
    Text(String),
    Parts(Vec<OpenAIContentPart>),
}

/// One element of a multi-part content array, tagged by a `"type"` field:
/// `{"type":"text","text":...}` or `{"type":"image_url","image_url":{"url":...}}`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OpenAIContentPart {
    Text { text: String },
    ImageUrl { image_url: OpenAIImageUrl },
}

/// Image reference; `content_to_openai` fills `url` with an inline base64 `data:` URL.
#[derive(Debug, Serialize)]
struct OpenAIImageUrl {
    url: String,
}
404
/// A tool invocation. It is serialized in assistant messages that are sent, and
/// deserialized from `OpenAIResponseMessage::tool_calls` in responses.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    // Wire name is "type"; this file always sends "function".
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAIFunctionCall,
}

/// Function name plus its arguments as a JSON-encoded *string* (not a JSON object).
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    arguments: String,
}

/// A tool definition offered to the model in `OpenAIRequest::tools`.
#[derive(Debug, Serialize)]
struct OpenAITool {
    // Wire name is "type"; this file always sends "function".
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}

/// Function description; `parameters` is copied from the tool's `input_schema`.
#[derive(Debug, Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
432
/// Response body of `GET /v1/models`; only the model ids are read, and other fields
/// are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OpenAIModelsResponse {
    data: Vec<OpenAIModel>,
}

/// One entry of `OpenAIModelsResponse::data`.
#[derive(Debug, Deserialize)]
struct OpenAIModel {
    id: String,
}
442
/// Non-streaming chat-completions response, converted into `CompletionResponse`
/// via `From`. `usage` has no default, so deserialization fails if it is absent.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    id: String,
    model: String,
    choices: Vec<OpenAIChoice>,
    usage: OpenAIUsage,
}

/// One completion choice; only the first choice is used by the conversion.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIResponseMessage,
    finish_reason: Option<String>,
}

/// The assistant message of a choice: optional text and optional tool calls.
#[derive(Debug, Deserialize)]
struct OpenAIResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<OpenAIToolCall>>,
}

/// Token counts, mapped to `TokenUsage::input_tokens` / `output_tokens`.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}
468
/// JSON payload of one streaming `data:` line; `parse_sse_event` reads only the first choice.
#[derive(Debug, Deserialize)]
struct OpenAIStreamEvent {
    choices: Vec<OpenAIStreamChoice>,
}

/// A streamed choice. `index` becomes `StreamingChunk::index`, and a present
/// `finish_reason` turns the chunk into `MessageStop`.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
    index: usize,
    delta: OpenAIStreamDelta,
    finish_reason: Option<String>,
}

/// Incremental message data; only the text `content` is read, and any other delta
/// fields are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct OpenAIStreamDelta {
    content: Option<String>,
}
485
486impl From<OpenAIResponse> for CompletionResponse {
487 fn from(resp: OpenAIResponse) -> Self {
488 let choice = resp.choices.into_iter().next();
489 let (content, stop_reason) = match choice {
490 Some(c) => {
491 let mut blocks = Vec::new();
492
493 if let Some(text) = c.message.content {
494 blocks.push(ContentBlock::Text { text });
495 }
496
497 if let Some(tool_calls) = c.message.tool_calls {
498 for tc in tool_calls {
499 let input: serde_json::Value =
500 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
501 blocks.push(ContentBlock::ToolUse {
502 id: tc.id,
503 name: tc.function.name,
504 input,
505 });
506 }
507 }
508
509 let stop = c.finish_reason.and_then(|r| match r.as_str() {
510 "stop" => Some(StopReason::EndTurn),
511 "length" => Some(StopReason::MaxTokens),
512 "tool_calls" => Some(StopReason::ToolUse),
513 _ => None,
514 });
515
516 (blocks, stop)
517 }
518 None => (vec![], None),
519 };
520
521 Self {
522 id: resp.id,
523 model: resp.model,
524 content,
525 stop_reason,
526 usage: TokenUsage {
527 input_tokens: resp.usage.prompt_tokens,
528 output_tokens: resp.usage.completion_tokens,
529 cache_read_tokens: None,
530 cache_write_tokens: None,
531 },
532 }
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Message;

    /// Builds a `gpt-4o` request around `messages` with fixed sampling settings.
    fn request_with(messages: Vec<Message>, system: Option<&str>) -> CompletionRequest {
        CompletionRequest {
            model: "gpt-4o".to_string(),
            messages,
            system: system.map(str::to_string),
            max_tokens: 1024,
            temperature: 0.7,
            stop: None,
            tools: None,
        }
    }

    /// Unwraps a parse result without requiring `ProviderError: Debug`.
    fn parsed_chunk(text: &str) -> StreamingChunk {
        match parse_sse_event(text) {
            Ok(chunk) => chunk,
            Err(_) => panic!("parse_sse_event failed for {:?}", text),
        }
    }

    #[test]
    fn test_provider_name() {
        let provider = OpenAIProvider::new(ApiKey::new("test".to_string()));
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_request_conversion() {
        let provider = OpenAIProvider::new(ApiKey::new("test".to_string()));
        let request = request_with(
            vec![Message {
                role: Role::User,
                content: MessageContent::Text("Hello".to_string()),
            }],
            Some("You are helpful"),
        );

        let openai_req = provider.to_openai_request(&request);
        assert_eq!(openai_req.model, "gpt-4o");
        assert_eq!(openai_req.messages.len(), 2);
        // The system prompt must come first.
        assert_eq!(openai_req.messages[0].role, "system");
        assert_eq!(openai_req.messages[1].role, "user");
    }

    #[test]
    fn test_request_serialization() {
        let provider = OpenAIProvider::new(ApiKey::new("test".to_string()));
        let request = request_with(
            vec![Message {
                role: Role::User,
                content: MessageContent::Text("Hello".to_string()),
            }],
            Some("You are helpful"),
        );

        let value = serde_json::to_value(provider.to_openai_request(&request))
            .expect("request serializes to JSON");
        assert_eq!(value["stream"], false);
        assert_eq!(value["max_tokens"], 1024);
        assert_eq!(value["messages"][0]["content"], "You are helpful");
        assert_eq!(value["messages"][1]["content"], "Hello");
        // `None` fields must be omitted, not sent as null.
        assert!(value.get("stop").is_none());
        assert!(value.get("tools").is_none());
        assert!(value["messages"][1].get("tool_calls").is_none());
    }

    #[test]
    fn test_assistant_tool_call_conversion() {
        let provider = OpenAIProvider::new(ApiKey::new("test".to_string()));
        let request = request_with(
            vec![Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![
                    ContentBlock::Text {
                        text: "Checking".to_string(),
                    },
                    ContentBlock::ToolUse {
                        id: "call_1".to_string(),
                        name: "lookup".to_string(),
                        input: serde_json::json!({ "q": "rust" }),
                    },
                ]),
            }],
            None,
        );

        let openai_req = provider.to_openai_request(&request);
        assert_eq!(openai_req.messages.len(), 1);
        let msg = &openai_req.messages[0];
        assert_eq!(msg.role, "assistant");
        assert!(matches!(&msg.content, Some(OpenAIContent::Text(t)) if t == "Checking"));
        let calls = msg.tool_calls.as_ref().expect("tool calls present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_1");
        assert_eq!(calls[0].call_type, "function");
        assert_eq!(calls[0].function.name, "lookup");
        assert_eq!(calls[0].function.arguments, r#"{"q":"rust"}"#);
    }

    #[test]
    fn test_response_conversion_with_tool_call() {
        let raw = r#"{
            "id": "chatcmpl-1",
            "model": "gpt-4o",
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5}
        }"#;

        let parsed: OpenAIResponse = serde_json::from_str(raw).expect("valid response JSON");
        let response = CompletionResponse::from(parsed);
        assert_eq!(response.id, "chatcmpl-1");
        assert_eq!(response.model, "gpt-4o");
        assert!(matches!(response.stop_reason, Some(StopReason::ToolUse)));
        assert_eq!(response.usage.input_tokens, 10);
        assert_eq!(response.usage.output_tokens, 5);
        assert_eq!(response.content.len(), 1);
        match &response.content[0] {
            ContentBlock::ToolUse { id, name, input } => {
                assert_eq!(id, "call_1");
                assert_eq!(name, "get_weather");
                assert_eq!(input["city"], "Paris");
            }
            _ => panic!("expected a ToolUse block"),
        }
    }

    #[test]
    fn test_parse_sse_event() {
        let done = parsed_chunk("data: [DONE]");
        assert!(matches!(done.chunk_type, ChunkType::MessageStop));
        assert!(done.delta.is_none());

        let delta = parsed_chunk(
            r#"data: {"choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#,
        );
        assert!(matches!(delta.chunk_type, ChunkType::ContentBlockDelta));
        assert_eq!(delta.delta.as_deref(), Some("Hi"));
        assert_eq!(delta.index, Some(0));

        let stop = parsed_chunk(
            r#"data: {"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#,
        );
        assert!(matches!(stop.chunk_type, ChunkType::MessageStop));
        assert!(stop.delta.is_none());
    }
}