1use crate::api::{ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission};
2use crate::types::{ChatCompletionRequest, ChatCompletionResponse, AiLibError, Message, Role, Choice, Usage};
3use crate::transport::{HttpTransport, DynHttpTransportRef};
4use super::config::ProviderConfig;
5use std::env;
6use futures::stream::{Stream, StreamExt};
7
8pub struct GenericAdapter {
12 transport: DynHttpTransportRef,
13 config: ProviderConfig,
14 api_key: Option<String>,
15}
16
17impl GenericAdapter {
18 pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
19 let api_key = env::var(&config.api_key_env).ok();
23
24 Ok(Self {
25 transport: HttpTransport::new().boxed(),
26 config,
27 api_key,
28 })
29 }
30
31 pub fn with_transport(config: ProviderConfig, transport: HttpTransport) -> Result<Self, AiLibError> {
33 let api_key = env::var(&config.api_key_env).ok();
34
35 Ok(Self {
36 transport: transport.boxed(),
37 config,
38 api_key,
39 })
40 }
41
42 pub fn with_transport_ref(config: ProviderConfig, transport: DynHttpTransportRef) -> Result<Self, AiLibError> {
44 let api_key = env::var(&config.api_key_env).ok();
45 Ok(Self { transport, config, api_key })
46 }
47
48 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
50 let default_role = "user".to_string();
51
52 let messages: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
54 let role_key = format!("{:?}", msg.role);
55 let mapped_role = self.config.field_mapping.role_mapping
56 .get(&role_key)
57 .unwrap_or(&default_role);
58 serde_json::json!({
59 "role": mapped_role,
60 "content": msg.content
61 })
62 }).collect();
63
64 let mut provider_request = serde_json::json!({
66 "model": request.model,
67 "messages": messages
68 });
69
70 if let Some(temp) = request.temperature {
72 provider_request["temperature"] = serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
73 }
74 if let Some(max_tokens) = request.max_tokens {
75 provider_request["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(max_tokens));
76 }
77 if let Some(top_p) = request.top_p {
78 provider_request["top_p"] = serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
79 }
80 if let Some(freq_penalty) = request.frequency_penalty {
81 provider_request["frequency_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(freq_penalty.into()).unwrap());
82 }
83 if let Some(presence_penalty) = request.presence_penalty {
84 provider_request["presence_penalty"] = serde_json::Value::Number(serde_json::Number::from_f64(presence_penalty.into()).unwrap());
85 }
86
87 provider_request
88 }
89
/// Returns the offset just past the first blank line in `buffer`, i.e. the
/// length of the first complete server-sent event including its terminator.
///
/// A blank line is a line feed that is immediately followed by another line
/// terminator, either `\n` or `\r\n`. This covers `\n\n`, `\r\n\r\n` and the
/// mixed forms `\r\n\n` and `\n\r\n`.
///
/// Returns `None` when no complete boundary is buffered yet, including when
/// the buffer ends in the middle of a `\r\n` pair, so the caller can wait
/// for more bytes.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    buffer
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == b'\n')
        .find_map(|(lf, _)| match &buffer[lf + 1..] {
            // `lf < buffer.len()`, so the slice above is always in bounds
            // (empty when the line feed is the last byte).
            [b'\n', ..] => Some(lf + 2),
            [b'\r', b'\n', ..] => Some(lf + 3),
            _ => None,
        })
}
106
107 fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
109 for line in event_text.lines() {
110 let line = line.trim();
111 if line.starts_with("data: ") {
112 let data = &line[6..];
113 if data == "[DONE]" {
114 return Some(Ok(None));
115 }
116 return Some(Self::parse_chunk_data(data));
117 }
118 }
119 None
120 }
121
122 fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
124 match serde_json::from_str::<serde_json::Value>(data) {
125 Ok(json) => {
126 let choices = json["choices"].as_array()
127 .map(|arr| {
128 arr.iter()
129 .enumerate()
130 .map(|(index, choice)| {
131 let delta = &choice["delta"];
132 ChoiceDelta {
133 index: index as u32,
134 delta: MessageDelta {
135 role: delta["role"].as_str().map(|r| match r {
136 "assistant" => Role::Assistant,
137 "user" => Role::User,
138 "system" => Role::System,
139 _ => Role::Assistant,
140 }),
141 content: delta["content"].as_str().map(str::to_string),
142 },
143 finish_reason: choice["finish_reason"].as_str().map(str::to_string),
144 }
145 })
146 .collect()
147 })
148 .unwrap_or_default();
149
150 Ok(Some(ChatCompletionChunk {
151 id: json["id"].as_str().unwrap_or_default().to_string(),
152 object: json["object"].as_str().unwrap_or("chat.completion.chunk").to_string(),
153 created: json["created"].as_u64().unwrap_or(0),
154 model: json["model"].as_str().unwrap_or_default().to_string(),
155 choices,
156 }))
157 }
158 Err(e) => Err(AiLibError::ProviderError(format!("JSON parse error: {}", e)))
159 }
160 }
161
162 fn parse_response(&self, response: serde_json::Value) -> Result<ChatCompletionResponse, AiLibError> {
164 let choices = response["choices"]
165 .as_array()
166 .ok_or_else(|| AiLibError::ProviderError("Invalid response format: choices not found".to_string()))?
167 .iter()
168 .enumerate()
169 .map(|(index, choice)| {
170 let message = choice["message"].as_object()
171 .ok_or_else(|| AiLibError::ProviderError("Invalid choice format".to_string()))?;
172
173 let role = match message["role"].as_str().unwrap_or("user") {
174 "system" => Role::System,
175 "assistant" => Role::Assistant,
176 _ => Role::User,
177 };
178
179 let content = message["content"].as_str()
180 .unwrap_or("")
181 .to_string();
182
183 Ok(Choice {
184 index: index as u32,
185 message: Message { role, content },
186 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
187 })
188 })
189 .collect::<Result<Vec<_>, AiLibError>>()?;
190
191 let usage = response["usage"].as_object()
192 .ok_or_else(|| AiLibError::ProviderError("Invalid response format: usage not found".to_string()))?;
193
194 let usage = Usage {
195 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
196 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
197 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
198 };
199
200 Ok(ChatCompletionResponse {
201 id: response["id"].as_str().unwrap_or("").to_string(),
202 object: response["object"].as_str().unwrap_or("").to_string(),
203 created: response["created"].as_u64().unwrap_or(0),
204 model: response["model"].as_str().unwrap_or("").to_string(),
205 choices,
206 usage,
207 })
208 }
209}
210
211#[async_trait::async_trait]
212impl ChatApi for GenericAdapter {
213 async fn chat_completion(&self, request: ChatCompletionRequest) -> Result<ChatCompletionResponse, AiLibError> {
214 let provider_request = self.convert_request(&request);
215 let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);
216
217 let mut headers = self.config.headers.clone();
218
219 if let Some(key) = &self.api_key {
221 if self.config.base_url.contains("anthropic.com") {
222 headers.insert("x-api-key".to_string(), key.clone());
223 } else {
224 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
225 }
226 }
227
228 let response: serde_json::Value = self.transport
229 .post_json(&url, Some(headers), provider_request)
230 .await?;
231
232 self.parse_response(response)
233 }
234
235 async fn chat_completion_stream(&self, request: ChatCompletionRequest) -> Result<Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>, AiLibError> {
236 let mut stream_request = self.convert_request(&request);
237 stream_request["stream"] = serde_json::Value::Bool(true);
238
239 let url = format!("{}{}", self.config.base_url, self.config.chat_endpoint);
240
241 let mut client_builder = reqwest::Client::builder();
243 if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
244 if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
245 client_builder = client_builder.proxy(proxy);
246 }
247 }
248 let client = client_builder.build()
249 .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
250
251 let mut headers = self.config.headers.clone();
252 headers.insert("Accept".to_string(), "text/event-stream".to_string());
253
254 if let Some(key) = &self.api_key {
256 if self.config.base_url.contains("anthropic.com") {
257 headers.insert("x-api-key".to_string(), key.clone());
258 } else {
259 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
260 }
261 }
262
263 let response = client
264 .post(&url)
265 .json(&stream_request);
266
267 let mut req = response;
268 for (key, value) in headers {
269 req = req.header(key, value);
270 }
271
272 let response = req.send().await
273 .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;
274
275 if !response.status().is_success() {
276 let error_text = response.text().await.unwrap_or_default();
277 return Err(AiLibError::ProviderError(format!("Stream error: {}", error_text)));
278 }
279
280 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
281
282 tokio::spawn(async move {
283 let mut buffer = Vec::new();
284 let mut stream = response.bytes_stream();
285
286 while let Some(result) = stream.next().await {
287 match result {
288 Ok(bytes) => {
289 buffer.extend_from_slice(&bytes);
290
291 while let Some(event_end) = Self::find_event_boundary(&buffer) {
292 let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();
293
294 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
295 if let Some(chunk) = Self::parse_sse_event(event_text) {
296 match chunk {
297 Ok(Some(c)) => {
298 if tx.send(Ok(c)).is_err() {
299 return;
300 }
301 }
302 Ok(None) => return,
303 Err(e) => {
304 let _ = tx.send(Err(e));
305 return;
306 }
307 }
308 }
309 }
310 }
311 }
312 Err(e) => {
313 let _ = tx.send(Err(AiLibError::ProviderError(format!("Stream error: {}", e))));
314 break;
315 }
316 }
317 }
318 });
319
320 let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
321 Ok(Box::new(Box::pin(stream)))
322 }
323
324
325
326 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
327 if let Some(models_endpoint) = &self.config.models_endpoint {
328 let url = format!("{}{}", self.config.base_url, models_endpoint);
329 let mut headers = self.config.headers.clone();
330
331 if let Some(key) = &self.api_key {
333 if self.config.base_url.contains("anthropic.com") {
334 headers.insert("x-api-key".to_string(), key.clone());
335 } else {
336 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
337 }
338 }
339
340 let response: serde_json::Value = self.transport
341 .get_json(&url, Some(headers))
342 .await?;
343
344 Ok(response["data"].as_array()
345 .unwrap_or(&vec![])
346 .iter()
347 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
348 .collect())
349 } else {
350 Err(AiLibError::ProviderError("Models endpoint not configured".to_string()))
351 }
352 }
353
354 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
355 Ok(ModelInfo {
356 id: model_id.to_string(),
357 object: "model".to_string(),
358 created: 0,
359 owned_by: "generic".to_string(),
360 permission: vec![ModelPermission {
361 id: "default".to_string(),
362 object: "model_permission".to_string(),
363 created: 0,
364 allow_create_engine: false,
365 allow_sampling: true,
366 allow_logprobs: false,
367 allow_search_indices: false,
368 allow_view: true,
369 allow_fine_tuning: false,
370 organization: "*".to_string(),
371 group: None,
372 is_blocking: false,
373 }],
374 })
375 }
376}