agent_runtime/llm/provider/openai.rs

use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult};

const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

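/// Chat client backed by the OpenAI Chat Completions API.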
pub struct OpenAIClient {
    api_key: String,
    model: String,
    http_client: HttpClient,
}

impl OpenAIClient {
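    /// Creates a client using the default "gpt-4" model.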
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "gpt-4")
    }

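    /// Creates a client targeting a specific model.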
    pub fn with_model(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }

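    /// Returns the configured model name.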
    pub fn model(&self) -> &str {
        &self.model
    }

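    /// Returns the provider identifier.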
    pub fn provider(&self) -> &str {
        "openai"
    }
}

#[async_trait]
impl ChatClient for OpenAIClient {
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
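        // Translate the provider-agnostic request into OpenAI's wire format.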
        let openai_request = OpenAIChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
            tools: request.tools,
        };

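        // POST with bearer auth; transport failures surface as NetworkError.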
        let response = self
            .http_client
            .post(OPENAI_API_URL)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

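        // Map HTTP error statuses to typed errors before parsing the body.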
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(match status.as_u16() {
                401 => LlmError::AuthenticationFailed(error_text),
                429 => LlmError::RateLimitExceeded,
                _ => LlmError::ApiError(format!("Status {}: {}", status, error_text)),
            });
        }

        let openai_response: OpenAIChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        let choice = openai_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

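        // Convert any tool calls into the runtime's shared ToolCall type.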
        let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
            calls
                .iter()
                .map(|tc| super::super::types::ToolCall {
                    id: tc.id.clone(),
                    r#type: tc.r#type.clone(),
                    function: super::super::types::FunctionCall {
                        name: tc.function.name.clone(),
                        arguments: tc.function.arguments.clone(),
                    },
                })
                .collect()
        });

        Ok(ChatResponse {
            // `content` is null on tool-call responses; fall back to an empty string.
            content: choice.message.content.clone().unwrap_or_default(),
            model: openai_response.model,
            usage: openai_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
            tool_calls,
        })
    }

    async fn chat_stream(
        &self,
        _request: ChatRequest,
        _tx: mpsc::Sender<String>,
    ) -> LlmResult<ChatResponse> {
        Err(LlmError::ApiError(
            "Streaming not yet implemented for OpenAI - use LlamaClient".to_string(),
        ))
    }
}

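/// Request body for the chat completions endpoint. Optional fields are
/// omitted from the serialized JSON when unset.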
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
}

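/// Response body returned by the chat completions endpoint.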
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: String,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    // OpenAI sends `"content": null` alongside tool calls, so this must be an
    // Option rather than a bare String: `#[serde(default)]` only covers a
    // missing field, not an explicit null.
    content: Option<String>,

    tool_calls: Option<Vec<OpenAIToolCall>>,
}

#[derive(Debug, Deserialize)]
struct OpenAIToolCall {
    id: String,
    r#type: String,
    function: OpenAIFunction,
}

#[derive(Debug, Deserialize)]
struct OpenAIFunction {
    name: String,
    arguments: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}