1use crate::api::{ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpClient, HttpTransport};
4use super::config::ProviderConfig;
5use std::env;
6use futures::stream::{Stream, StreamExt};
7
8pub struct GenericAdapter {
12 transport: HttpTransport,
13 config: ProviderConfig,
14 api_key: String,
15}
16
17impl GenericAdapter {
18 pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
19 let api_key = env::var(&config.api_key_env)
20 .map_err(|_| AiLibError::AuthenticationError(
21 format!("{} environment variable not set", config.api_key_env)
22 ))?;
23
24 Ok(Self {
25 transport: HttpTransport::new(),
26 config,
27 api_key,
28 })
29 }
30
31 pub fn with_transport(config: ProviderConfig, transport: HttpTransport) -> Result<Self, AiLibError> {
33 let api_key = env::var(&config.api_key_env)
34 .map_err(|_| AiLibError::AuthenticationError(
35 format!("{} environment variable not set", config.api_key_env)
36 ))?;
37
38 Ok(Self {
39 transport,
40 config,
41 api_key,
42 })
43 }
44
45 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
47 let default_role = "user".to_string();
48
49 let messages: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
51 let role_key = format!("{:?}", msg.role);
52 let mapped_role = self.config.field_mapping.role_mapping
53 .get(&role_key)
54 .unwrap_or(&default_role);
55 serde_json::json!({
56 "role": mapped_role,
57 "content": msg.content
58 })
59 }).collect();
60
61 let mut provider_request = serde_json::json!({
63 "model": request.model,
64 "messages": messages
65 });
66
67 if let Some(temp) = request.temperature {
69 provider_request["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
70 }
71 if let Some(max_tokens) = request.max_tokens {
72 provider_request["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
73 }
74 if let Some(top_p) = request.top_p {
75 provider_request["top_p"] = serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
76 }
77 if let Some(freq_penalty) = request.frequency_penalty {
78 provider_request["frequency_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(freq_penalty.into()).unwrap());
79 }
80 if let Some(presence_penalty) = request.presence_penalty {
81 provider_request["presence_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(presence_penalty.into()).unwrap());
82 }
83
84 provider_request
85 }
86
87 fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
89 let mut i = 0;
90 while i < buffer.len().saturating_sub(1) {
91 if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
92 return Some(i + 2);
93 }
94 if i < buffer.len().saturating_sub(3)
95 && buffer[i] == b'\r' && buffer[i + 1] == b'\n'
96 && buffer[i + 2] == b'\r' && buffer[i + 3] == b'\n' {
97 return Some(i + 4);
98 }
99 i += 1;
100 }
101 None
102 }
103
104 fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
106 for line in event_text.lines() {
107 let line = line.trim();
108 if line.starts_with("data: ") {
109 let data = &line[6..];
110 if data == "[DONE]" {
111 return Some(Ok(None));
112 }
113 return Some(Self::parse_chunk_data(data));
114 }
115 }
116 None
117 }
118
119 fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
121 match serde_json::from_str::<serde_json::Value>(data) {
122 Ok(json) => {
123 let choices = json["choices"].as_array()
124 .map(|arr| {
125 arr.iter()
126 .enumerate()
127 .map(|(index, choice)| {
128 let delta = &choice["delta"];
129 ChoiceDelta {
130 index: index as u32,
131 delta: MessageDelta {
132 role: delta["role"].as_str().map(|r| match r {
133 "assistant" => Role::Assistant,
134 "user" => Role::User,
135 "system" => Role::System,
136 _ => Role::Assistant,
137 }),
138 content: delta["content"].as_str().map(str::to_string),
139 },
140 finish_reason: choice["finish_reason"].as_str().map(str::to_string),
141 }
142 })
143 .collect()
144 })
145 .unwrap_or_default();
146
147 Ok(Some(ChatCompletionChunk {
148 id: json["id"].as_str().unwrap_or_default().to_string(),
149 object: json["object"].as_str().unwrap_or("chat.completion.chunk").to_string(),
150 created: json["created"].as_u64().unwrap_or(0),
151 model: json["model"].as_str().unwrap_or_default().to_string(),
152 choices,
153 }))
154 }
155 Err(e) => Err(AiLibError::ProviderError(format!("JSON parse error: {}", e)))
156 }
157 }
158
159 fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
161 let choices = response["choices"]
162 .as_array()
163 .ok_or_else(|| AiLibError::ProviderError("Invalid response format: choices not found".to_string()))?
164 .iter()
165 .enumerate()
166 .map(|(index, choice)| {
167 let message = choice["message"].as_object()
168 .ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;
169
170 let role = match message["role"].as_str().unwrap_or("user") {
171 "system" => Role::System,
172 "assistant" => Role::Assistant,
173 _ => Role::User,
174 };
175
176 let content = message["content"].as_str()
177 .unwrap_or("")
178 .to_string();
179
180 Ok(Choice {
181 index: index as u32,
182 message: Message { role, content },
183 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
184 })
185 })
186 .collect::<Result<Vec<_>, AiLibError>>()?;
187
188 let usage = response["usage"].as_object()
189 .ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;
190
191 let usage = Usage {
192 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
193 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
194 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
195 };
196
197 Ok(ChatCompletionResponse {
198 id: response["id"].as_str().unwrap_or("").to_string(),
199 object: response["object"].as_str().unwrap_or("").to_string(),
200 created: response["created"].as_u64().unwrap_or(0),
201 model: response["model"].as_str().unwrap_or("").to_string(),
202 choices,
203 usage,
204 })
205 }
206}
207
208#[async_trait::async_trait]
209impl ChatApi for GenericAdapter {
210 async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
211 let provider_request = self.convert_request(&request);
212 let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);
213
214 let mut headers = self.config.headers.clone();
215
216 if self.config.base_url.contains("anthropic.com") {
218 headers.insert("x-api-key".to_string(), self.api_key.clone());
219 } else {
220 headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
221 }
222
223 let response: serde_json::Value = self.transport
224 .post(&url, Some(headers), &provider_request)
225 .await?;
226
227 self.parse_response(response)
228 }
229
230 async fn chat_completion_stream(&self, request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
231 let mut stream_request = self.convert_request(&request);
232 stream_request["stream"] = serde_json::Value::Bool(true);
233
234 let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);
235
236 let mut client_builder = reqwest::Client::builder();
238 if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
239 if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
240 client_builder = client_builder.proxy(proxy);
241 }
242 }
243 let client = client_builder.build()
244 .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
245
246 let mut headers = self.config.headers.clone();
247 headers.insert("Accept".to_string(), "text/event-stream".to_string());
248
249 if self.config.base_url.contains("anthropic.com") {
251 headers.insert("x-api-key".to_string(), self.api_key.clone());
252 } else {
253 headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
254 }
255
256 let response = client
257 .post(&url)
258 .json(&stream_request);
259
260 let mut req = response;
261 for (key, value) in headers {
262 req = req.header(key, value);
263 }
264
265 let response = req.send().await
266 .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;
267
268 if !response.status().is_success() {
269 let error_text = response.text().await.unwrap_or_default();
270 return Err(AiLibError::ProviderError(format!("Stream error: {}", error_text)));
271 }
272
273 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
274
275 tokio::spawn(async move {
276 let mut buffer = Vec::new();
277 let mut stream = response.bytes_stream();
278
279 while let Some(result) = stream.next().await {
280 match result {
281 Ok(bytes) => {
282 buffer.extend_from_slice(&bytes);
283
284 while let Some(event_end) = Self::find_event_boundary(&buffer) {
285 let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();
286
287 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
288 if let Some(chunk) = Self::parse_sse_event(event_text) {
289 match chunk {
290 Ok(Some(c)) => {
291 if tx.send(Ok(c)).is_err() {
292 return;
293 }
294 }
295 Ok(None) => return,
296 Err(e) => {
297 let _ = tx.send(Err(e));
298 return;
299 }
300 }
301 }
302 }
303 }
304 }
305 Err(e) => {
306 let _ = tx.send(Err(AiLibError::ProviderError(format!("Stream error: {}", e))));
307 break;
308 }
309 }
310 }
311 });
312
313 let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
314 Ok(Box::new(Box::pin(stream)))
315 }
316
317
318
319 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
320 if let Some(models_endpoint) = &self.config.models_endpoint {
321 let url = format!("{}{}", self.config.base_url, models_endpoint);
322 let mut headers = self.config.headers.clone();
323
324 if self.config.base_url.contains("anthropic.com") {
326 headers.insert("x-api-key".to_string(), self.api_key.clone());
327 } else {
328 headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key));
329 }
330
331 let response: serde_json::Value = self.transport
332 .get(&url, Some(headers))
333 .await?;
334
335 Ok(response["data"].as_array()
336 .unwrap_or(&vec![])
337 .iter()
338 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
339 .collect())
340 } else {
341 Err(AiLibError::ProviderError("Models endpoint not configured".to_string()))
342 }
343 }
344
345 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
346 Ok(ModelInfo {
347 id: model_id.to_string(),
348 object: "model".to_string(),
349 created: 0,
350 owned_by: "generic".to_string(),
351 permission: vec![ModelPermission {
352 id: "default".to_string(),
353 object: "model_permission".to_string(),
354 created: 0,
355 allow_create_engine: false,
356 allow_sampling: true,
357 allow_logprobs: false,
358 allow_search_indices: false,
359 allow_view: true,
360 allow_fine_tuning: false,
361 organization: "*".to_string(),
362 group: None,
363 is_blocking: false,
364 }],
365 })
366 }
367}