ai_lib/provider/
openai.rs1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpClient, HttpTransport};
4use std::env;
5use std::collections::HashMap;
6use futures::stream::{self, Stream};
7
8pub struct OpenAiAdapter {
12 transport: HttpTransport,
13 api_key: String,
14 base_url: String,
15}
16
17impl OpenAiAdapter {
18 pub fn new() -> Result<Self, AiLibError> {
19 let api_key = env::var("OPENAI_API_KEY")
20 .map_err(|_| AiLibError::AuthenticationError(
21 "OPENAI_API_KEY environment variable not set".to_string()
22 ))?;
23
24 Ok(Self {
25 transport: HttpTransport::new(),
26 api_key,
27 base_url: "https://api.openai.com/v1".to_string(),
28 })
29 }
30
31 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
32 let mut openai_request = serde_json::json!({
33 "model": request.model,
34 "messages": request.messages.iter().map(|msg| {
35 serde_json::json!({
36 "role": match msg.role {
37 Role::System => "system",
38 Role::User => "user",
39 Role::Assistant => "assistant",
40 },
41 "content": msg.content
42 })
43 }).collect::<Vec<_>>()
44 });
45
46 if let Some(temp) = request.temperature {
47 openai_request["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
48 }
49 if let Some(max_tokens) = request.max_tokens {
50 openai_request["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
51 }
52 if let Some(top_p) = request.top_p {
53 openai_request["top_p"] = serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
54 }
55 if let Some(freq_penalty) = request.frequency_penalty {
56 openai_request["frequency_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(freq_penalty.into()).unwrap());
57 }
58 if let Some(presence_penalty) = request.presence_penalty {
59 openai_request["presence_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(presence_penalty.into()).unwrap());
60 }
61
62 openai_request
63 }
64
65 fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
66 let choices = response["choices"]
67 .as_array()
68 .ok_or_else(|| AiLibError::ProviderError("Invalid response format: choices not found".to_string()))?
69 .iter()
70 .enumerate()
71 .map(|(index, choice)| {
72 let message = choice["message"].as_object()
73 .ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;
74
75 let role = match message["role"].as_str().unwrap_or("user") {
76 "system" => Role::System,
77 "assistant" => Role::Assistant,
78 _ => Role::User,
79 };
80
81 let content = message["content"].as_str()
82 .unwrap_or("")
83 .to_string();
84
85 Ok(Choice {
86 index: index as u32,
87 message: Message { role, content },
88 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
89 })
90 })
91 .collect::<Result<Vec<_>, AiLibError>>()?;
92
93 let usage = response["usage"].as_object()
94 .ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;
95
96 let usage = Usage {
97 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
98 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
99 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
100 };
101
102 Ok(ChatCompletionResponse {
103 id: response["id"].as_str().unwrap_or("").to_string(),
104 object: response["object"].as_str().unwrap_or("").to_string(),
105 created: response["created"].as_u64().unwrap_or(0),
106 model: response["model"].as_str().unwrap_or("").to_string(),
107 choices,
108 usage,
109 })
110 }
111}
112
113#[async_trait::async_trait]
114impl ChatApi for OpenAiAdapter {
115 async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
116 let openai_request = self.convert_request(&request);
117 let url = format!("{}/chat/completions", self.base_url);
118
119
120
121 let mut headers = HashMap::new();
122 headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
123 headers.insert("Content-Type".to_string(), "application/json".to_string());
124
125 let response: serde_json::Value = self.transport
126 .post(&url, Some(headers), &openai_request)
127 .await?;
128
129 self.parse_response(response)
130 }
131
132 async fn chat_completion_stream(&self, _request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
133 let stream = stream::empty();
134 Ok(Box::new(Box::pin(stream)))
135 }
136
137 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
138 let url = format!("{}/models", self.base_url);
139 let mut headers = HashMap::new();
140 headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
141
142 let response: serde_json::Value = self.transport
143 .get(&url, Some(headers))
144 .await?;
145
146 Ok(response["data"].as_array()
147 .unwrap_or(&vec![])
148 .iter()
149 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
150 .collect())
151 }
152
153 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
154 Ok(ModelInfo {
155 id: model_id.to_string(),
156 object: "model".to_string(),
157 created: 0,
158 owned_by: "openai".to_string(),
159 permission: vec![ModelPermission {
160 id: "default".to_string(),
161 object: "model_permission".to_string(),
162 created: 0,
163 allow_create_engine: false,
164 allow_sampling: true,
165 allow_logprobs: false,
166 allow_search_indices: false,
167 allow_view: true,
168 allow_fine_tuning: false,
169 organization: "*".to_string(),
170 group: None,
171 is_blocking: false,
172 }],
173 })
174 }
175}