1use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use reqwest::Client;
9use serde::{Deserialize, Serialize};
10
/// [`LlmProvider`] implementation backed by the OpenAI chat-completions API.
pub struct OpenAiProvider {
    // Shared HTTP client; reqwest pools connections internally, so one
    // instance is reused across all requests.
    client: Client,
    // Bearer token sent in the `Authorization` header on every request.
    api_key: String,
    // API root (e.g. `https://api.openai.com/v1`); endpoint paths are
    // appended to this when building request URLs.
    base_url: String,
}
16
17impl OpenAiProvider {
18 pub fn new(api_key: String) -> Self {
19 Self {
20 client: Client::new(),
21 api_key,
22 base_url: "https://api.openai.com/v1".to_string(),
23 }
24 }
25}
26
27impl Default for OpenAiProvider {
28 fn default() -> Self {
29 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
30 Self::new(api_key)
31 }
32}
33
34#[async_trait]
35impl LlmProvider for OpenAiProvider {
36 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
37 let url = format!("{}/chat/completions", self.base_url);
38
39 let openai_request = OpenAiChatRequest {
40 model: request.model.clone(),
41 messages: request.messages.into_iter().map(Into::into).collect(),
42 temperature: request.temperature,
43 max_tokens: request.max_tokens,
44 stream: request.stream,
45 };
46
47 let response = self
48 .client
49 .post(&url)
50 .header("Authorization", format!("Bearer {}", self.api_key))
51 .json(&openai_request)
52 .send()
53 .await
54 .context("Failed to send request to OpenAI")?;
55
56 if !response.status().is_success() {
57 let error_text = response.text().await?;
58 return Err(anyhow::anyhow!("OpenAI API error: {}", error_text));
59 }
60
61 let openai_response: OpenAiChatResponse = response.json().await?;
62
63 let content = openai_response
64 .choices
65 .first()
66 .map(|c| c.message.content.clone())
67 .unwrap_or_default();
68
69 Ok(LlmResponse {
70 content,
71 model: openai_response.model,
72 usage: openai_response.usage.map(Into::into),
73 })
74 }
75
76 fn name(&self) -> &'static str {
77 "OpenAI"
78 }
79}
80
/// JSON request body for OpenAI's `/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct OpenAiChatRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    // Optional knobs are omitted from the JSON entirely when `None`,
    // letting the API apply its own defaults.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    stream: bool,
}
91
/// One chat message in OpenAI wire format; used both in requests and
/// inside response choices (hence Serialize + Deserialize).
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    // Lowercase role string: "system", "user", or "assistant".
    role: String,
    content: String,
}
97
98impl From<LlmMessage> for OpenAiMessage {
99 fn from(msg: LlmMessage) -> Self {
100 Self {
101 role: match msg.role {
102 LlmRole::System => "system".to_string(),
103 LlmRole::User => "user".to_string(),
104 LlmRole::Assistant => "assistant".to_string(),
105 },
106 content: msg.content,
107 }
108 }
109}
110
/// Subset of OpenAI's chat-completion response that this provider
/// actually consumes; unknown fields are ignored by serde.
#[derive(Debug, Deserialize)]
struct OpenAiChatResponse {
    model: String,
    choices: Vec<OpenAiChoice>,
    // Token accounting; optional because not every response carries it.
    usage: Option<OpenAiUsage>,
}
117
/// A single completion candidate; only the message is extracted here.
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiMessage,
}
122
/// Token-usage accounting as reported by the OpenAI API.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}
129
130impl From<OpenAiUsage> for LlmUsage {
131 fn from(usage: OpenAiUsage) -> Self {
132 Self {
133 prompt_tokens: usage.prompt_tokens,
134 completion_tokens: usage.completion_tokens,
135 total_tokens: usage.total_tokens,
136 }
137 }
138}