agent_runtime/llm/provider/openai.rs

use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};

use super::super::{ChatClient, ChatRequest, ChatResponse, LlmError, LlmResult, TextStream};

const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
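/// Chat client for the OpenAI chat-completions API.
///
/// A minimal usage sketch (marked `ignore` because the exact module path and
/// the shape of `ChatRequest` depend on the surrounding crate):
///
/// ```ignore
/// let client = OpenAIClient::with_model(std::env::var("OPENAI_API_KEY")?, "gpt-4");
/// let response = client.chat(request).await?;
/// println!("{}", response.content);
/// ```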
pub struct OpenAIClient {
    api_key: String,
    model: String,
    http_client: HttpClient,
}

impl OpenAIClient {
    /// Creates a client that uses the default `gpt-4` model.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::with_model(api_key, "gpt-4")
    }

    /// Creates a client for a specific model name.
    pub fn with_model(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            http_client: HttpClient::new(),
        }
    }
}
#[async_trait]
impl ChatClient for OpenAIClient {
    async fn chat(&self, request: ChatRequest) -> LlmResult<ChatResponse> {
        let openai_request = OpenAIChatRequest {
            model: self.model.clone(),
            messages: request.messages,
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            top_p: request.top_p,
        };

        let response = self
            .http_client
            .post(OPENAI_API_URL)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| LlmError::NetworkError(e.to_string()))?;

        // Map non-success HTTP statuses to typed errors before trying to parse a body.
        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(match status.as_u16() {
                401 => LlmError::AuthenticationFailed(error_text),
                429 => LlmError::RateLimitExceeded,
                _ => LlmError::ApiError(format!("Status {}: {}", status, error_text)),
            });
        }

        let openai_response: OpenAIChatResponse = response
            .json()
            .await
            .map_err(|e| LlmError::ParseError(e.to_string()))?;

        // The completions endpoint returns a list of choices; only the first is used.
        let choice = openai_response
            .choices
            .first()
            .ok_or_else(|| LlmError::ParseError("No choices in response".to_string()))?;

        Ok(ChatResponse {
            content: choice.message.content.clone(),
            model: openai_response.model,
            usage: openai_response.usage.map(|u| super::super::types::Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
            finish_reason: choice.finish_reason.clone(),
        })
    }

    async fn chat_stream(&self, _request: ChatRequest) -> LlmResult<TextStream> {
        Err(LlmError::ApiError(
            "Streaming not yet implemented for OpenAI - use LlamaClient".to_string(),
        ))
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &str {
        "openai"
    }
}
// Request body in the OpenAI chat-completions wire format. Optional sampling
// parameters are omitted entirely when unset so the API applies its defaults.
#[derive(Debug, Serialize)]
struct OpenAIChatRequest {
    model: String,
    messages: Vec<super::super::types::ChatMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
}

// Subset of the chat-completions response that this client actually consumes.
#[derive(Debug, Deserialize)]
struct OpenAIChatResponse {
    model: String,
    choices: Vec<Choice>,
    usage: Option<UsageInfo>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: Message,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Message {
    content: String,
}

#[derive(Debug, Deserialize)]
struct UsageInfo {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
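#[cfg(test)]
mod tests {
    use super::*;

    // A small serialization sketch, assuming `serde_json` is available as a
    // dev-dependency: when the optional sampling parameters are unset, the
    // `skip_serializing_if` attributes should keep them out of the request
    // body so the API falls back to its own defaults.
    #[test]
    fn optional_fields_are_skipped_when_none() {
        let request = OpenAIChatRequest {
            model: "gpt-4".to_string(),
            messages: Vec::new(),
            temperature: None,
            max_tokens: None,
            top_p: None,
        };

        let json = serde_json::to_string(&request).expect("request serializes");
        assert!(!json.contains("temperature"));
        assert!(!json.contains("max_tokens"));
        assert!(!json.contains("top_p"));
    }
}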