// ai_lib/provider/openai.rs

1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpClient, HttpTransport};
4use std::env;
5use std::collections::HashMap;
6use futures::stream::{self, Stream};
7
/// OpenAI adapter supporting GPT series models.
///
/// Bundles the HTTP transport with the credentials and endpoint needed to
/// call the OpenAI chat-completion API.
pub struct OpenAiAdapter {
    // HTTP client used for all requests to the OpenAI API.
    transport: HttpTransport,
    // Bearer token, read from the `OPENAI_API_KEY` environment variable in `new()`.
    api_key: String,
    // API root; set to "https://api.openai.com/v1" by `new()`.
    base_url: String,
}
16
17impl OpenAiAdapter {
18    pub fn new() -> Result<Self, AiLibError> {
19        let api_key = env::var("OPENAI_API_KEY")
20            .map_err(|_| AiLibError::AuthenticationError(
21                "OPENAI_API_KEY environment variable not set".to_string()
22            ))?;
23        
24        Ok(Self {
25            transport: HttpTransport::new(),
26            api_key,
27            base_url: "https://api.openai.com/v1".to_string(),
28        })
29    }
30
31    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
32        let mut openai_request = serde_json::json!({
33            "model": request.model,
34            "messages": request.messages.iter().map(|msg| {
35                serde_json::json!({
36                    "role": match msg.role {
37                        Role::System => "system",
38                        Role::User => "user",
39                        Role::Assistant => "assistant",
40                    },
41                    "content": msg.content
42                })
43            }).collect::<Vec<_>>()
44        });
45
46        if let Some(temp) = request.temperature {
47            openai_request["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
48        }
49        if let Some(max_tokens) = request.max_tokens {
50            openai_request["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
51        }
52        if let Some(top_p) = request.top_p {
53            openai_request["top_p"] = serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
54        }
55        if let Some(freq_penalty) = request.frequency_penalty {
56            openai_request["frequency_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(freq_penalty.into()).unwrap());
57        }
58        if let Some(presence_penalty) = request.presence_penalty {
59            openai_request["presence_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(presence_penalty.into()).unwrap());
60        }
61
62        openai_request
63    }
64
65    fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
66        let choices = response["choices"]
67            .as_array()
68            .ok_or_else(|| AiLibError::ProviderError("Invalid response format: choices not found".to_string()))?
69            .iter()
70            .enumerate()
71            .map(|(index, choice)| {
72                let message = choice["message"].as_object()
73                    .ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;
74                
75                let role = match message["role"].as_str().unwrap_or("user") {
76                    "system" => Role::System,
77                    "assistant" => Role::Assistant,
78                    _ => Role::User,
79                };
80                
81                let content = message["content"].as_str()
82                    .unwrap_or("")
83                    .to_string();
84                
85                Ok(Choice {
86                    index: index as u32,
87                    message: Message { role, content },
88                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
89                })
90            })
91            .collect::<Result<Vec<_>, AiLibError>>()?;
92        
93        let usage = response["usage"].as_object()
94            .ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;
95        
96        let usage = Usage {
97            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
98            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
99            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
100        };
101        
102        Ok(ChatCompletionResponse {
103            id: response["id"].as_str().unwrap_or("").to_string(),
104            object: response["object"].as_str().unwrap_or("").to_string(),
105            created: response["created"].as_u64().unwrap_or(0),
106            model: response["model"].as_str().unwrap_or("").to_string(),
107            choices,
108            usage,
109        })
110    }
111}
112
113#[async_trait::async_trait]
114impl ChatApi for OpenAiAdapter {
115    async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
116        let openai_request = self.convert_request(&request);
117        let url = format!("{}/chat/completions", self.base_url);
118        
119
120        
121        let mut headers = HashMap::new();
122        headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
123        headers.insert("Content-Type".to_string(), "application/json".to_string());
124        
125        let response: serde_json::Value = self.transport
126            .post(&url, Some(headers), &openai_request)
127            .await?;
128        
129        self.parse_response(response)
130    }
131
132    async fn chat_completion_stream(&self, _request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
133        let stream = stream::empty();
134        Ok(Box::new(Box::pin(stream)))
135    }
136
137    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
138        let url = format!("{}/models", self.base_url);
139        let mut headers = HashMap::new();
140        headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
141        
142        let response: serde_json::Value = self.transport
143            .get(&url, Some(headers))
144            .await?;
145        
146        Ok(response["data"].as_array()
147            .unwrap_or(&vec![])
148            .iter()
149            .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
150            .collect())
151    }
152
153    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
154        Ok(ModelInfo {
155            id: model_id.to_string(),
156            object: "model".to_string(),
157            created: 0,
158            owned_by: "openai".to_string(),
159            permission: vec![ModelPermission {
160                id: "default".to_string(),
161                object: "model_permission".to_string(),
162                created: 0,
163                allow_create_engine: false,
164                allow_sampling: true,
165                allow_logprobs: false,
166                allow_search_indices: false,
167                allow_view: true,
168                allow_fine_tuning: false,
169                organization: "*".to_string(),
170                group: None,
171                is_blocking: false,
172            }],
173        })
174    }
175}