use super::config::ProviderConfig;
use crate::api::{
    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::{Stream, StreamExt};
use std::env;
use std::sync::Arc;
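
/// Generic, configuration-driven chat adapter for OpenAI-compatible providers.
///
/// The adapter holds a pluggable HTTP transport, the provider configuration,
/// an optional API key resolved from the environment variable named in
/// `ProviderConfig::api_key_env`, and a metrics sink used to record request
/// counters and timings.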
pub struct GenericAdapter {
    transport: DynHttpTransportRef,
    config: ProviderConfig,
    api_key: Option<String>,
    metrics: Arc<dyn Metrics>,
}

impl GenericAdapter {
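    /// Creates an adapter with the default HTTP transport and no-op metrics.
    ///
    /// A minimal usage sketch; `build_provider_config()` is a hypothetical
    /// helper standing in for however the `ProviderConfig` is constructed in
    /// your integration:
    ///
    /// ```ignore
    /// let config: ProviderConfig = build_provider_config(); // hypothetical helper
    /// let adapter = GenericAdapter::new(config)?;
    /// let response = adapter.chat_completion(request).await?;
    /// ```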
    pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
        config.validate()?;

        let api_key = env::var(&config.api_key_env).ok();

        Ok(Self {
            transport: HttpTransport::new().boxed(),
            config,
            api_key,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

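    /// Creates an adapter that uses a caller-supplied `HttpTransport` instead
    /// of the default one; metrics default to `NoopMetrics`.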
    pub fn with_transport(
        config: ProviderConfig,
        transport: HttpTransport,
    ) -> Result<Self, AiLibError> {
        config.validate()?;

        let api_key = env::var(&config.api_key_env).ok();

        Ok(Self {
            transport: transport.boxed(),
            config,
            api_key,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

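    /// Creates an adapter from an already-boxed transport reference, which is
    /// convenient when a shared or mock transport is injected.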
    pub fn with_transport_ref(
        config: ProviderConfig,
        transport: DynHttpTransportRef,
    ) -> Result<Self, AiLibError> {
        config.validate()?;

        let api_key = env::var(&config.api_key_env).ok();
        Ok(Self {
            transport,
            config,
            api_key,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

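    /// Creates an adapter with both an injected transport reference and a
    /// custom metrics implementation.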
    pub fn with_transport_ref_and_metrics(
        config: ProviderConfig,
        transport: DynHttpTransportRef,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        config.validate()?;

        let api_key = env::var(&config.api_key_env).ok();
        Ok(Self {
            transport,
            config,
            api_key,
            metrics,
        })
    }

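    /// Creates an adapter with the default HTTP transport and a custom
    /// metrics implementation.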
    pub fn with_metrics(
        config: ProviderConfig,
        metrics: Arc<dyn Metrics>,
    ) -> Result<Self, AiLibError> {
        config.validate()?;

        let api_key = env::var(&config.api_key_env).ok();
        Ok(Self {
            transport: HttpTransport::new().boxed(),
            config,
            api_key,
            metrics,
        })
    }

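    /// Converts a `ChatCompletionRequest` into the provider-specific JSON
    /// payload: roles are mapped through `field_mapping.role_mapping`
    /// (falling back to `user`), local image files may be uploaded via the
    /// configured `upload_endpoint`, and optional sampling parameters are
    /// copied over only when present.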
    async fn convert_request(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<serde_json::Value, AiLibError> {
        let default_role = "user".to_string();

        let mut messages: Vec<serde_json::Value> = Vec::with_capacity(request.messages.len());
        for msg in request.messages.iter() {
            // Map the internal role to the provider-specific role name,
            // falling back to "user" when no mapping is configured.
            let role_key = format!("{:?}", msg.role);
            let mapped_role = self
                .config
                .field_mapping
                .role_mapping
                .get(&role_key)
                .unwrap_or(&default_role)
                .clone();

            let content_val = match &msg.content {
                crate::types::common::Content::Image {
                    url,
                    mime: _mime,
                    name,
                } => {
                    if url.is_some() {
                        // A remote URL is already available; pass the content through.
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    } else if let Some(n) = name {
                        if let Some(upload_ep) = &self.config.upload_endpoint {
                            let upload_url = format!(
                                "{}{}",
                                self.config.base_url.trim_end_matches('/'),
                                upload_ep
                            );
                            // Upload when the file exceeds the configured size limit,
                            // when its size cannot be determined, or when no limit is set.
                            let should_upload = match self.config.upload_size_limit {
                                Some(limit) => match std::fs::metadata(n) {
                                    Ok(meta) => meta.len() > limit,
                                    Err(_) => true,
                                },
                                None => true,
                            };

                            if should_upload {
                                match crate::provider::utils::upload_file_with_transport(
                                    Some(self.transport.clone()),
                                    &upload_url,
                                    n,
                                    "file",
                                )
                                .await
                                {
                                    Ok(remote_url) => {
                                        if remote_url.starts_with("http://")
                                            || remote_url.starts_with("https://")
                                            || remote_url.starts_with("data:")
                                        {
                                            serde_json::json!({"image": {"url": remote_url}})
                                        } else {
                                            serde_json::json!({"image": {"file_id": remote_url}})
                                        }
                                    }
                                    // Fall back to inline content if the upload fails.
                                    Err(_) => crate::provider::utils::content_to_provider_value(
                                        &msg.content,
                                    ),
                                }
                            } else {
                                crate::provider::utils::content_to_provider_value(&msg.content)
                            }
                        } else {
                            crate::provider::utils::content_to_provider_value(&msg.content)
                        }
                    } else {
                        crate::provider::utils::content_to_provider_value(&msg.content)
                    }
                }
                _ => crate::provider::utils::content_to_provider_value(&msg.content),
            };

            messages.push(serde_json::json!({"role": mapped_role, "content": content_val}));
        }

        let mut provider_request = serde_json::json!({
            "model": request.model,
            "messages": messages
        });

        // Copy optional sampling parameters only when they are set.
        if let Some(temp) = request.temperature {
            provider_request["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            provider_request["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }
        if let Some(top_p) = request.top_p {
            provider_request["top_p"] =
                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
        }
        if let Some(freq_penalty) = request.frequency_penalty {
            provider_request["frequency_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
            );
        }
        if let Some(presence_penalty) = request.presence_penalty {
            provider_request["presence_penalty"] = serde_json::Value::Number(
                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
            );
        }

        Ok(provider_request)
    }

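    /// Returns the byte offset just past the next SSE event terminator
    /// (`\n\n` or `\r\n\r\n`) in `buffer`, or `None` if no complete event has
    /// been buffered yet.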
    fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
        let mut i = 0;
        while i < buffer.len().saturating_sub(1) {
            if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
                return Some(i + 2);
            }
            if i < buffer.len().saturating_sub(3)
                && buffer[i] == b'\r'
                && buffer[i + 1] == b'\n'
                && buffer[i + 2] == b'\r'
                && buffer[i + 3] == b'\n'
            {
                return Some(i + 4);
            }
            i += 1;
        }
        None
    }

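    /// Parses a single SSE event. Returns `None` when the event has no
    /// `data:` line, `Some(Ok(None))` for the `[DONE]` sentinel, and the
    /// parsed chunk (or a parse error) otherwise.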
    fn parse_sse_event(
        event_text: &str,
    ) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
        for line in event_text.lines() {
            let line = line.trim();
            if let Some(data) = line.strip_prefix("data: ") {
                if data == "[DONE]" {
                    return Some(Ok(None));
                }
                return Some(Self::parse_chunk_data(data));
            }
        }
        None
    }

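    /// Deserializes one streaming `data:` payload into a `ChatCompletionChunk`,
    /// returning a `ProviderError` if the payload is not valid JSON.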
    fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
        match serde_json::from_str::<serde_json::Value>(data) {
            Ok(json) => {
                let choices = json["choices"]
                    .as_array()
                    .map(|arr| {
                        arr.iter()
                            .enumerate()
                            .map(|(index, choice)| {
                                let delta = &choice["delta"];
                                ChoiceDelta {
                                    index: index as u32,
                                    delta: MessageDelta {
                                        role: delta["role"].as_str().map(|r| match r {
                                            "assistant" => Role::Assistant,
                                            "user" => Role::User,
                                            "system" => Role::System,
                                            _ => Role::Assistant,
                                        }),
                                        content: delta["content"].as_str().map(str::to_string),
                                    },
                                    finish_reason: choice["finish_reason"]
                                        .as_str()
                                        .map(str::to_string),
                                }
                            })
                            .collect()
                    })
                    .unwrap_or_default();

                Ok(Some(ChatCompletionChunk {
                    id: json["id"].as_str().unwrap_or_default().to_string(),
                    object: json["object"]
                        .as_str()
                        .unwrap_or("chat.completion.chunk")
                        .to_string(),
                    created: json["created"].as_u64().unwrap_or(0),
                    model: json["model"].as_str().unwrap_or_default().to_string(),
                    choices,
                }))
            }
            Err(e) => Err(AiLibError::ProviderError(format!(
                "JSON parse error: {}",
                e
            ))),
        }
    }

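    /// Converts a non-streaming provider response into a
    /// `ChatCompletionResponse`, including best-effort decoding of any
    /// `function_call` payload attached to a choice.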
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let choices = response["choices"]
            .as_array()
            .ok_or_else(|| {
                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
            })?
            .iter()
            .enumerate()
            .map(|(index, choice)| {
                let message = choice["message"].as_object().ok_or_else(|| {
                    AiLibError::ProviderError("Invalid choice format".to_string())
                })?;

                let role = match message["role"].as_str().unwrap_or("user") {
                    "system" => Role::System,
                    "assistant" => Role::Assistant,
                    _ => Role::User,
                };

                let content = message["content"].as_str().unwrap_or("").to_string();

                // Decode an optional function call. Providers may return the
                // arguments either as a JSON object or as a JSON-encoded string,
                // so try strict deserialization first and fall back to manual
                // extraction of `name` and `arguments`.
                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
                if let Some(fc_val) = message.get("function_call") {
                    if let Ok(mut fc) = serde_json::from_value::<
                        crate::types::function_call::FunctionCall,
                    >(fc_val.clone())
                    {
                        if let Some(arg_val) = &fc.arguments {
                            if arg_val.is_string() {
                                if let Some(s) = arg_val.as_str() {
                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
                                    {
                                        fc.arguments = Some(parsed);
                                    }
                                }
                            }
                        }
                        function_call = Some(fc);
                    } else {
                        let name = fc_val
                            .get("name")
                            .and_then(|v| v.as_str())
                            .map(|s| s.to_string());
                        if let Some(name) = name {
                            let args = fc_val.get("arguments").and_then(|a| {
                                if a.is_string() {
                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
                                        .ok()
                                } else {
                                    Some(a.clone())
                                }
                            });

                            function_call = Some(crate::types::function_call::FunctionCall {
                                name,
                                arguments: args,
                            });
                        }
                    }
                }

                Ok(Choice {
                    index: index as u32,
                    message: Message {
                        role,
                        content: crate::types::common::Content::Text(content),
                        function_call,
                    },
                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
                })
            })
            .collect::<Result<Vec<_>, AiLibError>>()?;

        let usage = response["usage"].as_object().ok_or_else(|| {
            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
        })?;

        let usage = Usage {
            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or("").to_string(),
            object: response["object"].as_str().unwrap_or("").to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or("").to_string(),
            choices,
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for GenericAdapter {
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("generic.requests", 1).await;
        let timer = self
            .metrics
            .start_timer("generic.request_duration_ms")
            .await;

        let provider_request = self.convert_request(&request).await?;
        let url = self.config.chat_url();

        let mut headers = self.config.headers.clone();

        // Anthropic expects the key in `x-api-key`; other providers take a
        // standard bearer token.
        if let Some(key) = &self.api_key {
            if self.config.base_url.contains("anthropic.com") {
                headers.insert("x-api-key".to_string(), key.clone());
            } else {
                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
            }
        }

        let response = match self
            .transport
            .post_json(&url, Some(headers), provider_request)
            .await
        {
            Ok(v) => {
                if let Some(t) = timer {
                    t.stop();
                    let _ = self
                        .metrics
                        .incr_counter("generic.request_timer_recorded", 1)
                        .await;
                }
                v
            }
            Err(e) => {
                if let Some(t) = timer {
                    t.stop();
                    let _ = self
                        .metrics
                        .incr_counter("generic.request_timer_recorded", 1)
                        .await;
                }
                return Err(e);
            }
        };

        self.parse_response(response)
    }

    async fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let mut stream_request = self.convert_request(&request).await?;
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = self.config.chat_url();

        // Streaming bypasses the injected transport and uses a dedicated
        // reqwest client so the SSE body can be consumed incrementally.
        let mut client_builder = reqwest::Client::builder();
        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
                client_builder = client_builder.proxy(proxy);
            }
        }
        let client = client_builder
            .build()
            .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;

        let mut headers = self.config.headers.clone();
        headers.insert("Accept".to_string(), "text/event-stream".to_string());

        if let Some(key) = &self.api_key {
            if self.config.base_url.contains("anthropic.com") {
                headers.insert("x-api-key".to_string(), key.clone());
            } else {
                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
            }
        }

        let mut req = client.post(&url).json(&stream_request);
        for (key, value) in headers {
            req = req.header(key, value);
        }

        let response = req
            .send()
            .await
            .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(AiLibError::ProviderError(format!(
                "Stream error: {}",
                error_text
            )));
        }

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

        // Drain the SSE byte stream in a background task, splitting it into
        // events and forwarding parsed chunks through the channel.
        tokio::spawn(async move {
            let mut buffer = Vec::new();
            let mut stream = response.bytes_stream();

            while let Some(result) = stream.next().await {
                match result {
                    Ok(bytes) => {
                        buffer.extend_from_slice(&bytes);

                        while let Some(event_end) = Self::find_event_boundary(&buffer) {
                            let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();

                            if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                if let Some(chunk) = Self::parse_sse_event(event_text) {
                                    match chunk {
                                        Ok(Some(c)) => {
                                            if tx.send(Ok(c)).is_err() {
                                                return;
                                            }
                                        }
                                        // `[DONE]` sentinel: end the stream.
                                        Ok(None) => return,
                                        Err(e) => {
                                            let _ = tx.send(Err(e));
                                            return;
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx.send(Err(AiLibError::ProviderError(format!(
                            "Stream error: {}",
                            e
                        ))));
                        break;
                    }
                }
            }
        });

        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        if let Some(models_endpoint) = &self.config.models_endpoint {
            let url = format!("{}{}", self.config.base_url, models_endpoint);
            let mut headers = self.config.headers.clone();

            if let Some(key) = &self.api_key {
                if self.config.base_url.contains("anthropic.com") {
                    headers.insert("x-api-key".to_string(), key.clone());
                } else {
                    headers.insert("Authorization".to_string(), format!("Bearer {}", key));
                }
            }

            let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;

            Ok(response["data"]
                .as_array()
                .map(|models| {
                    models
                        .iter()
                        .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
                        .collect()
                })
                .unwrap_or_default())
        } else {
            Err(AiLibError::ProviderError(
                "Models endpoint not configured".to_string(),
            ))
        }
    }

    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "generic".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
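
// A minimal unit-test sketch for the pure SSE helpers above. It exercises
// only `find_event_boundary`, `parse_sse_event`, and `parse_chunk_data`,
// which have no I/O dependencies; the field names follow the structs used
// in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn finds_lf_and_crlf_event_boundaries() {
        // LF LF terminator: boundary is the offset just past the blank line.
        assert_eq!(GenericAdapter::find_event_boundary(b"data: x\n\n"), Some(9));
        // CRLF CRLF terminator.
        assert_eq!(
            GenericAdapter::find_event_boundary(b"data: x\r\n\r\n"),
            Some(11)
        );
        // No terminator yet: keep buffering.
        assert_eq!(GenericAdapter::find_event_boundary(b"data: x"), None);
    }

    #[test]
    fn parses_done_sentinel_and_chunk_data() {
        // `[DONE]` terminates the stream without producing a chunk.
        assert!(matches!(
            GenericAdapter::parse_sse_event("data: [DONE]\n\n"),
            Some(Ok(None))
        ));

        let data = r#"{"id":"c1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"delta":{"role":"assistant","content":"hi"},"finish_reason":null}]}"#;
        let chunk = match GenericAdapter::parse_chunk_data(data) {
            Ok(Some(c)) => c,
            _ => panic!("expected a parsed chunk"),
        };
        assert_eq!(chunk.id, "c1");
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("hi"));
    }
}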