// openclaw_providers/openai.rs

//! `OpenAI` API provider.
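//!
//! A minimal usage sketch (marked `ignore`; the exact paths under which the
//! crate re-exports these types are an assumption here):
//!
//! ```ignore
//! let provider = OpenAIProvider::new(ApiKey::new("sk-...".to_string()));
//! let models = provider.list_models().await?;
//! let response = provider.complete(request).await?;
//! ```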

use async_trait::async_trait;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;

use crate::traits::{
    ChunkType, CompletionRequest, CompletionResponse, ContentBlock, MessageContent, Provider,
    ProviderError, Role, StopReason, StreamingChunk,
};
use openclaw_core::secrets::ApiKey;
use openclaw_core::types::TokenUsage;

const DEFAULT_BASE_URL: &str = "https://api.openai.com";

/// `OpenAI` API provider.
pub struct OpenAIProvider {
    client: Client,
    api_key: ApiKey,
    base_url: String,
    org_id: Option<String>,
}

impl OpenAIProvider {
    /// Create a new `OpenAI` provider.
    #[must_use]
    pub fn new(api_key: ApiKey) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: DEFAULT_BASE_URL.to_string(),
            org_id: None,
        }
    }

    /// Create with custom base URL (for Azure or compatible APIs).
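    ///
    /// A hypothetical example against a local OpenAI-compatible server
    /// (the URL is illustrative only):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::with_base_url(key, "http://localhost:8080");
    /// ```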
    #[must_use]
    pub fn with_base_url(api_key: ApiKey, base_url: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            api_key,
            base_url: base_url.into(),
            org_id: None,
        }
    }

    /// Set organization ID.
    #[must_use]
    pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
        self.org_id = Some(org_id.into());
        self
    }

    /// Convert our request format to `OpenAI`'s API format.
    fn to_openai_request(&self, request: &CompletionRequest) -> OpenAIRequest {
        let mut messages: Vec<OpenAIMessage> = Vec::new();

        // Add the system message first if present.
        if let Some(system) = &request.system {
            messages.push(OpenAIMessage {
                role: "system".to_string(),
                content: Some(OpenAIContent::Text(system.clone())),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        // Convert conversation messages.
        for msg in &request.messages {
            let openai_msg = match msg.role {
                Role::System => OpenAIMessage {
                    role: "system".to_string(),
                    content: Some(content_to_openai(&msg.content)),
                    tool_calls: None,
                    tool_call_id: None,
                },
                Role::User => OpenAIMessage {
                    role: "user".to_string(),
                    content: Some(content_to_openai(&msg.content)),
                    tool_calls: None,
                    tool_call_id: None,
                },
                Role::Assistant => {
                    let (content, tool_calls) = extract_tool_calls(&msg.content);
                    OpenAIMessage {
                        role: "assistant".to_string(),
                        content,
                        tool_calls,
                        tool_call_id: None,
                    }
                }
                Role::Tool => {
                    // OpenAI expects one `tool` message per result, keyed by
                    // the tool_call_id it answers, so push those directly and
                    // skip the shared push below.
                    if let MessageContent::Blocks(blocks) = &msg.content {
                        for block in blocks {
                            if let ContentBlock::ToolResult {
                                tool_use_id,
                                content,
                                ..
                            } = block
                            {
                                messages.push(OpenAIMessage {
                                    role: "tool".to_string(),
                                    content: Some(OpenAIContent::Text(content.clone())),
                                    tool_calls: None,
                                    tool_call_id: Some(tool_use_id.clone()),
                                });
                            }
                        }
                    }
                    continue;
                }
            };
            messages.push(openai_msg);
        }

        let tools = request.tools.as_ref().map(|tools| {
            tools
                .iter()
                .map(|t| OpenAITool {
                    tool_type: "function".to_string(),
                    function: OpenAIFunction {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        parameters: t.input_schema.clone(),
                    },
                })
                .collect()
        });

        OpenAIRequest {
            model: request.model.clone(),
            messages,
            max_tokens: Some(request.max_tokens),
            temperature: Some(request.temperature),
            stop: request.stop.clone(),
            tools,
            stream: Some(false),
        }
    }
}

fn content_to_openai(content: &MessageContent) -> OpenAIContent {
    match content {
        MessageContent::Text(text) => OpenAIContent::Text(text.clone()),
        MessageContent::Blocks(blocks) => {
            let parts: Vec<OpenAIContentPart> = blocks
                .iter()
                .filter_map(|b| match b {
                    ContentBlock::Text { text } => {
                        Some(OpenAIContentPart::Text { text: text.clone() })
                    }
                    // Images travel as base64 data URLs in OpenAI's parts format.
                    ContentBlock::Image { source } => Some(OpenAIContentPart::ImageUrl {
                        image_url: OpenAIImageUrl {
                            url: format!("data:{};base64,{}", source.media_type, source.data),
                        },
                    }),
                    // Tool blocks are handled by `extract_tool_calls`; drop the rest.
                    _ => None,
                })
                .collect();
            OpenAIContent::Parts(parts)
        }
    }
}

/// Split assistant content into plain text and OpenAI-style tool calls,
/// which the API carries in a separate `tool_calls` field.
fn extract_tool_calls(
    content: &MessageContent,
) -> (Option<OpenAIContent>, Option<Vec<OpenAIToolCall>>) {
    match content {
        MessageContent::Text(text) => (Some(OpenAIContent::Text(text.clone())), None),
        MessageContent::Blocks(blocks) => {
            let mut text_parts = Vec::new();
            let mut tool_calls = Vec::new();

            for block in blocks {
                match block {
                    ContentBlock::Text { text } => text_parts.push(text.clone()),
                    ContentBlock::ToolUse { id, name, input } => {
                        tool_calls.push(OpenAIToolCall {
                            id: id.clone(),
                            call_type: "function".to_string(),
                            function: OpenAIFunctionCall {
                                name: name.clone(),
                                // OpenAI expects arguments as a JSON string.
                                arguments: serde_json::to_string(input).unwrap_or_default(),
                            },
                        });
                    }
                    _ => {}
                }
            }

            let content = if text_parts.is_empty() {
                None
            } else {
                Some(OpenAIContent::Text(text_parts.join("\n")))
            };

            let tool_calls = if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            };

            (content, tool_calls)
        }
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &'static str {
        "openai"
    }

    async fn list_models(&self) -> Result<Vec<String>, ProviderError> {
        let url = format!("{}/v1/models", self.base_url);

        let mut req = self
            .client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.api_key.expose()));

        if let Some(org) = &self.org_id {
            req = req.header("OpenAI-Organization", org);
        }

        let response = req.send().await?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let message = response.text().await.unwrap_or_default();
            return Err(ProviderError::Api { status, message });
        }

        let result: OpenAIModelsResponse = response.json().await?;
        Ok(result.data.into_iter().map(|m| m.id).collect())
    }

    async fn complete(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse, ProviderError> {
        let url = format!("{}/v1/chat/completions", self.base_url);
        let openai_request = self.to_openai_request(&request);

        let mut req = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key.expose()))
            .header("Content-Type", "application/json");

        if let Some(org) = &self.org_id {
            req = req.header("OpenAI-Organization", org);
        }

        let response = req.json(&openai_request).send().await?;

        if !response.status().is_success() {
            let status = response.status().as_u16();

            // Surface rate limits with the server's suggested backoff.
            if status == 429 {
                let retry_after = response
                    .headers()
                    .get("retry-after")
                    .and_then(|v| v.to_str().ok())
                    .and_then(|v| v.parse().ok())
                    .unwrap_or(60);
                return Err(ProviderError::RateLimited {
                    retry_after_secs: retry_after,
                });
            }

            let message = response.text().await.unwrap_or_default();
            return Err(ProviderError::Api { status, message });
        }

        let result: OpenAIResponse = response.json().await?;
        Ok(result.into())
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<StreamingChunk, ProviderError>> + Send>>,
        ProviderError,
    > {
        let url = format!("{}/v1/chat/completions", self.base_url);
        let mut openai_request = self.to_openai_request(&request);
        openai_request.stream = Some(true);

        let mut req = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key.expose()))
            .header("Content-Type", "application/json");

        if let Some(org) = &self.org_id {
            req = req.header("OpenAI-Organization", org);
        }

        let response = req.json(&openai_request).send().await?;

        if !response.status().is_success() {
            let status = response.status().as_u16();
            let message = response.text().await.unwrap_or_default();
            return Err(ProviderError::Api { status, message });
        }

        // Map each HTTP chunk to at most one streaming chunk. This assumes
        // chunks arrive aligned to complete SSE lines; events split across
        // chunk boundaries (or several events in one chunk) are not
        // reassembled here.
        let stream = response.bytes_stream().map(move |result| match result {
            Ok(bytes) => {
                let text = String::from_utf8_lossy(&bytes);
                parse_sse_event(&text)
            }
            Err(e) => Err(ProviderError::Network(e)),
        });

        Ok(Box::pin(stream))
    }
}
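
/// Parse one HTTP chunk of the SSE stream into a single [`StreamingChunk`].
///
/// `OpenAI` streams `data: {json}` lines and finishes with a literal
/// `data: [DONE]`. This is a deliberately small parser: it returns the first
/// event it can decode from the chunk and assumes the chunk holds complete
/// lines, rather than implementing full SSE framing.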
fn parse_sse_event(text: &str) -> Result<StreamingChunk, ProviderError> {
    for line in text.lines() {
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                return Ok(StreamingChunk {
                    chunk_type: ChunkType::MessageStop,
                    delta: None,
                    index: None,
                });
            }

            if let Ok(event) = serde_json::from_str::<OpenAIStreamEvent>(data) {
                if let Some(choice) = event.choices.first() {
                    return Ok(StreamingChunk {
                        chunk_type: if choice.finish_reason.is_some() {
                            ChunkType::MessageStop
                        } else {
                            ChunkType::ContentBlockDelta
                        },
                        delta: choice.delta.content.clone(),
                        index: Some(choice.index),
                    });
                }
            }
        }
    }

    // Keep-alives and undecodable chunks degrade to an empty delta.
    Ok(StreamingChunk {
        chunk_type: ChunkType::ContentBlockDelta,
        delta: None,
        index: None,
    })
}

// OpenAI API types

#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAIContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Debug, Serialize)]
#[serde(untagged)]
enum OpenAIContent {
    Text(String),
    Parts(Vec<OpenAIContentPart>),
}

#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OpenAIContentPart {
    Text { text: String },
    ImageUrl { image_url: OpenAIImageUrl },
}

#[derive(Debug, Serialize)]
struct OpenAIImageUrl {
    url: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAIFunctionCall,
}

#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Debug, Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}

#[derive(Debug, Serialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Debug, Deserialize)]
struct OpenAIModelsResponse {
    data: Vec<OpenAIModel>,
}

#[derive(Debug, Deserialize)]
struct OpenAIModel {
    id: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    id: String,
    model: String,
    choices: Vec<OpenAIChoice>,
    usage: OpenAIUsage,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIResponseMessage,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<OpenAIToolCall>>,
}

#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

#[derive(Debug, Deserialize)]
struct OpenAIStreamEvent {
    choices: Vec<OpenAIStreamChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
    index: usize,
    delta: OpenAIStreamDelta,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIStreamDelta {
    content: Option<String>,
}

impl From<OpenAIResponse> for CompletionResponse {
    fn from(resp: OpenAIResponse) -> Self {
        let choice = resp.choices.into_iter().next();
        let (content, stop_reason) = match choice {
            Some(c) => {
                let mut blocks = Vec::new();

                if let Some(text) = c.message.content {
                    blocks.push(ContentBlock::Text { text });
                }

                if let Some(tool_calls) = c.message.tool_calls {
                    for tc in tool_calls {
                        // Malformed argument strings degrade to JSON null
                        // rather than failing the whole response.
                        let input: serde_json::Value =
                            serde_json::from_str(&tc.function.arguments).unwrap_or_default();
                        blocks.push(ContentBlock::ToolUse {
                            id: tc.id,
                            name: tc.function.name,
                            input,
                        });
                    }
                }

                let stop = c.finish_reason.and_then(|r| match r.as_str() {
                    "stop" => Some(StopReason::EndTurn),
                    "length" => Some(StopReason::MaxTokens),
                    "tool_calls" => Some(StopReason::ToolUse),
                    _ => None,
                });

                (blocks, stop)
            }
            None => (vec![], None),
        };

        Self {
            id: resp.id,
            model: resp.model,
            content,
            stop_reason,
            usage: TokenUsage {
                input_tokens: resp.usage.prompt_tokens,
                output_tokens: resp.usage.completion_tokens,
                // Cache token stats aren't mapped from OpenAI's usage block.
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Message;

    #[test]
    fn test_provider_name() {
        let provider = OpenAIProvider::new(ApiKey::new("test".to_string()));
        assert_eq!(provider.name(), "openai");
    }
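
    // Hedged checks for the SSE parser; the payloads mirror the stream
    // structs above rather than captured API traffic.
    #[test]
    fn test_parse_sse_event() {
        // `data: [DONE]` terminates the stream.
        let done = parse_sse_event("data: [DONE]\n").unwrap();
        assert!(matches!(done.chunk_type, ChunkType::MessageStop));

        // A delta event carries its text fragment and choice index through.
        let event =
            r#"data: {"choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let chunk = parse_sse_event(event).unwrap();
        assert!(matches!(chunk.chunk_type, ChunkType::ContentBlockDelta));
        assert_eq!(chunk.delta.as_deref(), Some("Hi"));
        assert_eq!(chunk.index, Some(0));
    }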

    #[test]
    fn test_request_conversion() {
        let provider = OpenAIProvider::new(ApiKey::new("test".to_string()));
        let request = CompletionRequest {
            model: "gpt-4o".to_string(),
            messages: vec![Message {
                role: Role::User,
                content: MessageContent::Text("Hello".to_string()),
            }],
            system: Some("You are helpful".to_string()),
            max_tokens: 1024,
            temperature: 0.7,
            stop: None,
            tools: None,
        };

        let openai_req = provider.to_openai_request(&request);
        assert_eq!(openai_req.model, "gpt-4o");
        assert_eq!(openai_req.messages.len(), 2); // system + user
    }
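
    // Sketch-level checks for the converters; the JSON fixture mirrors the
    // serde structs above rather than captured API traffic.
    #[test]
    fn test_extract_tool_calls() {
        let content = MessageContent::Blocks(vec![
            ContentBlock::Text {
                text: "thinking".to_string(),
            },
            ContentBlock::ToolUse {
                id: "call_1".to_string(),
                name: "lookup".to_string(),
                input: serde_json::json!({"q": "rust"}),
            },
        ]);

        let (text, calls) = extract_tool_calls(&content);
        match text {
            Some(OpenAIContent::Text(t)) => assert_eq!(t, "thinking"),
            other => panic!("unexpected content: {other:?}"),
        }
        let calls = calls.expect("tool call should be extracted");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "lookup");
    }

    #[test]
    fn test_response_conversion() {
        let json = r#"{
            "id": "resp_1",
            "model": "gpt-4o",
            "choices": [
                {
                    "message": { "content": "Hello", "tool_calls": null },
                    "finish_reason": "stop"
                }
            ],
            "usage": { "prompt_tokens": 3, "completion_tokens": 5 }
        }"#;

        let resp: OpenAIResponse = serde_json::from_str(json).expect("fixture should parse");
        let completion: CompletionResponse = resp.into();
        assert_eq!(completion.id, "resp_1");
        assert!(matches!(completion.stop_reason, Some(StopReason::EndTurn)));
        assert_eq!(completion.usage.input_tokens, 3);
        assert_eq!(completion.usage.output_tokens, 5);
    }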
}