agent_runtime/llm/provider/openai.rs

use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult};

const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

/// OpenAI chat client
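///
/// A minimal usage sketch; the key handling and `request` construction are
/// illustrative (`ChatRequest` is built elsewhere in the crate), and the
/// model id is only an example.
///
/// ```ignore
/// let client = OpenAIClient::with_model(api_key, "gpt-4o");
/// let response = client.chat(request).await?;
/// println!("{}", response.content);
/// ```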
pub struct OpenAIClient {
    api_key: String,
    model: String,
    http_client: HttpClient,
}

impl OpenAIClient {
    /// Create a new OpenAI client using the default "gpt-4" model
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "gpt-4")
    }

    /// Create a new OpenAI client with a specific model
    pub fn with_model(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }

    /// Get the model name
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Get the provider name
    pub fn provider(&self) -> &str {
        "openai"
    }
}

#[async_trait]
impl ChatClient for OpenAIClient {
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
        // Build the OpenAI API request
        let openai_request = OpenAIChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
            tools: request.tools,
        };

        // Send the request
        let response = self
            .http_client
            .post(OPENAI_API_URL)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        // Map HTTP error statuses onto the crate's error variants
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(match status.as_u16() {
                401 => LlmError::AuthenticationFailed(error_text),
                429 => LlmError::RateLimitExceeded,
                _ => LlmError::ApiError(format!("Status {}: {}", status, error_text)),
            });
        }

        // Parse the response body
        let openai_response: OpenAIChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        // Extract the first choice
        let choice = openai_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

        // Convert OpenAI tool_calls to our ToolCall type
        let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
            calls
                .iter()
                .map(|tc| super::super::types::ToolCall {
                    id: tc.id.clone(),
                    r#type: tc.r#type.clone(),
                    function: super::super::types::FunctionCall {
                        name: tc.function.name.clone(),
                        arguments: tc.function.arguments.clone(),
                    },
                })
                .collect()
        });

        Ok(ChatResponse {
            // `content` is null when the model responds only with tool calls,
            // so fall back to an empty string in that case.
            content: choice.message.content.clone().unwrap_or_default(),
            model: openai_response.model,
            usage: openai_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
            tool_calls,
        })
    }

    async fn chat_stream(
        &self,
        _request: ChatRequest,
        _tx: mpsc::Sender<String>,
    ) -> LlmResult<ChatResponse> {
        // Streaming is not implemented for OpenAI yet; a full implementation
        // would set `stream: true` and consume the SSE response. Until then,
        // return an error pointing callers at LlamaClient for streaming.
        Err(LlmError::ApiError(
            "Streaming not yet implemented for OpenAI - use LlamaClient".to_string(),
        ))
    }
}
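// The status mapping above lets callers branch on error variants rather than
// raw HTTP codes. A caller-side sketch (illustrative; retry policy is not
// part of this module):
//
//     match client.chat(request).await {
//         Err(LlmError::RateLimitExceeded) => { /* back off and retry */ }
//         Err(LlmError::AuthenticationFailed(msg)) => { /* fail fast: bad key */ }
//         other => { /* handle the ChatResponse or remaining errors */ }
//     }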

// OpenAI-specific request/response types

#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: String,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    // OpenAI returns `content: null` when the assistant message carries only
    // tool calls, so this must be an Option rather than a defaulted String.
    content: Option<String>,

    // Missing and null both deserialize to None; no serde attribute is needed
    // on a Deserialize-only struct.
    tool_calls: Option<Vec<OpenAIToolCall>>,
}

#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
    id: String,
    r#type: String,
    function: OpenAIFunction,
}

#[derive(Debug, Deserialize)]
struct OpenAIFunction {
    name: String,
    arguments: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
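
// A minimal test sketch for the pieces that need no network access. The model
// ids and JSON payload below are illustrative, shaped to match the structs in
// this file rather than copied from a live API response.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_defaults_to_gpt_4() {
        let client = OpenAIClient::new("test-key");
        assert_eq!(client.model(), "gpt-4");
        assert_eq!(client.provider(), "openai");
    }

    #[test]
    fn optional_request_fields_are_omitted_when_none() {
        let request = OpenAIChatRequest {
            model: "gpt-4".to_string(),
            messages: Vec::new(),
            temperature: None,
            max_tokens: None,
            top_p: None,
            tools: None,
        };
        let json = serde_json::to_value(&request).unwrap();
        assert!(json.get("temperature").is_none());
        assert!(json.get("tools").is_none());
    }

    #[test]
    fn tool_call_response_with_null_content_parses() {
        let body = r#"{
            "model": "gpt-4",
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "lookup", "arguments": "{}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        }"#;
        let parsed: OpenAIChatResponse = serde_json::from_str(body).unwrap();
        let message = &parsed.choices[0].message;
        assert!(message.content.is_none());
        assert_eq!(message.tool_calls.as_ref().unwrap().len(), 1);
    }
}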