1use super::config::ProviderConfig;
2use crate::api::{
3 ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
4};
5use crate::metrics::{Metrics, NoopMetrics};
6use crate::transport::{DynHttpTransportRef, HttpTransport};
7use crate::types::{
8 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
9};
10use futures::stream::{Stream, StreamExt};
11use std::env;
12use std::sync::Arc;
13pub struct GenericAdapter {
15 transport: DynHttpTransportRef,
16 config: ProviderConfig,
17 api_key: Option<String>,
18 metrics: Arc<dyn Metrics>,
19}
20
21impl GenericAdapter {
22 pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
23 config.validate()?;
25
26 let api_key = env::var(&config.api_key_env).ok();
30
31 Ok(Self {
32 transport: HttpTransport::new_without_proxy().boxed(),
33 config,
34 api_key,
35 metrics: Arc::new(NoopMetrics::new()),
36 })
37 }
38
39 pub fn new_with_api_key(
41 config: ProviderConfig,
42 api_key_override: Option<String>,
43 ) -> Result<Self, AiLibError> {
44 config.validate()?;
45 let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
46 Ok(Self {
47 transport: HttpTransport::new_without_proxy().boxed(),
48 config,
49 api_key,
50 metrics: Arc::new(NoopMetrics::new()),
51 })
52 }
53
54 pub fn with_transport(
56 config: ProviderConfig,
57 transport: HttpTransport,
58 ) -> Result<Self, AiLibError> {
59 config.validate()?;
61
62 let api_key = env::var(&config.api_key_env).ok();
63
64 Ok(Self {
65 transport: transport.boxed(),
66 config,
67 api_key,
68 metrics: Arc::new(NoopMetrics::new()),
69 })
70 }
71
72 pub fn with_transport_api_key(
74 config: ProviderConfig,
75 transport: HttpTransport,
76 api_key_override: Option<String>,
77 ) -> Result<Self, AiLibError> {
78 config.validate()?;
79 let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
80 Ok(Self {
81 transport: transport.boxed(),
82 config,
83 api_key,
84 metrics: Arc::new(NoopMetrics::new()),
85 })
86 }
87
88 pub fn with_transport_ref(
90 config: ProviderConfig,
91 transport: DynHttpTransportRef,
92 ) -> Result<Self, AiLibError> {
93 config.validate()?;
95
96 let api_key = env::var(&config.api_key_env).ok();
97 Ok(Self {
98 transport,
99 config,
100 api_key,
101 metrics: Arc::new(NoopMetrics::new()),
102 })
103 }
104
105 pub fn with_transport_ref_api_key(
107 config: ProviderConfig,
108 transport: DynHttpTransportRef,
109 api_key_override: Option<String>,
110 ) -> Result<Self, AiLibError> {
111 config.validate()?;
112 let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
113 Ok(Self {
114 transport,
115 config,
116 api_key,
117 metrics: Arc::new(NoopMetrics::new()),
118 })
119 }
120
121 pub fn with_transport_ref_and_metrics(
123 config: ProviderConfig,
124 transport: DynHttpTransportRef,
125 metrics: Arc<dyn Metrics>,
126 ) -> Result<Self, AiLibError> {
127 config.validate()?;
129
130 let api_key = env::var(&config.api_key_env).ok();
131 Ok(Self {
132 transport,
133 config,
134 api_key,
135 metrics,
136 })
137 }
138
139 pub fn with_metrics(
141 config: ProviderConfig,
142 metrics: Arc<dyn Metrics>,
143 ) -> Result<Self, AiLibError> {
144 config.validate()?;
146
147 let api_key = env::var(&config.api_key_env).ok();
148 Ok(Self {
149 transport: HttpTransport::new().boxed(),
150 config,
151 api_key,
152 metrics,
153 })
154 }
155
156 async fn convert_request(
158 &self,
159 request: &ChatCompletionRequest,
160 ) -> Result<serde_json::Value, AiLibError> {
161 let default_role = "user".to_string();
162
163 let mut messages: Vec<serde_json::Value> = Vec::with_capacity(request.messages.len());
165 for msg in request.messages.iter() {
166 let role_key = format!("{:?}", msg.role);
167 let mapped_role = self
168 .config
169 .field_mapping
170 .role_mapping
171 .get(&role_key)
172 .unwrap_or(&default_role)
173 .clone();
174
175 let content_val = match &msg.content {
177 crate::types::common::Content::Image {
178 url,
179 mime: _mime,
180 name,
181 } => {
182 if url.is_some() {
183 crate::provider::utils::content_to_provider_value(&msg.content)
184 } else if let Some(n) = name {
185 if let Some(upload_ep) = &self.config.upload_endpoint {
186 let upload_url = format!(
187 "{}{}",
188 self.config.base_url.trim_end_matches('/'),
189 upload_ep
190 );
191 let should_upload = match self.config.upload_size_limit {
193 Some(limit) => match std::fs::metadata(n) {
194 Ok(meta) => meta.len() > limit,
195 Err(_) => true, },
197 None => true, };
199
200 if should_upload {
201 match crate::provider::utils::upload_file_with_transport(
203 Some(self.transport.clone()),
204 &upload_url,
205 n,
206 "file",
207 )
208 .await
209 {
210 Ok(remote_url) => {
211 if remote_url.starts_with("http://")
212 || remote_url.starts_with("https://")
213 || remote_url.starts_with("data:")
214 {
215 serde_json::json!({"image": {"url": remote_url}})
216 } else {
217 serde_json::json!({"image": {"file_id": remote_url}})
218 }
219 }
220 Err(_) => crate::provider::utils::content_to_provider_value(
221 &msg.content,
222 ),
223 }
224 } else {
225 crate::provider::utils::content_to_provider_value(&msg.content)
227 }
228 } else {
229 crate::provider::utils::content_to_provider_value(&msg.content)
230 }
231 } else {
232 crate::provider::utils::content_to_provider_value(&msg.content)
233 }
234 }
235 _ => crate::provider::utils::content_to_provider_value(&msg.content),
236 };
237
238 messages.push(serde_json::json!({"role": mapped_role, "content": content_val}));
239 }
240
241 let mut provider_request = serde_json::json!({
243 "model": request.model,
244 "messages": messages
245 });
246
247 if let Some(temp) = request.temperature {
249 provider_request["temperature"] =
250 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
251 }
252 if let Some(max_tokens) = request.max_tokens {
253 provider_request["max_tokens"] =
254 serde_json::Value::Number(serde_json::Number::from(max_tokens));
255 }
256 if let Some(top_p) = request.top_p {
257 provider_request["top_p"] =
258 serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
259 }
260 if let Some(freq_penalty) = request.frequency_penalty {
261 provider_request["frequency_penalty"] = serde_json::Value::Number(
262 serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
263 );
264 }
265 if let Some(presence_penalty) = request.presence_penalty {
266 provider_request["presence_penalty"] = serde_json::Value::Number(
267 serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
268 );
269 }
270
271 Ok(provider_request)
272 }
273
/// Returns the byte offset just past the first SSE event terminator in
/// `buffer`, or `None` if no complete event has arrived yet.
///
/// An event ends with a blank line, i.e. two consecutive line terminators.
/// `\n` and `\r\n` terminators are accepted in any combination: `\n\n`,
/// `\r\n\r\n`, `\r\n\n` and `\n\r\n`. Bare-`\r` line endings are not
/// supported. A terminator split across the end of `buffer` is not matched,
/// so a later call with more data finds it.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    const TERMINATORS: [&[u8]; 4] = [b"\r\n\r\n", b"\r\n\n", b"\n\r\n", b"\n\n"];
    (0..buffer.len()).find_map(|start| {
        let rest = &buffer[start..];
        TERMINATORS
            .iter()
            .copied()
            .find(|terminator| rest.starts_with(*terminator))
            .map(|terminator| start + terminator.len())
    })
}
293
294 fn parse_sse_event(
296 event_text: &str,
297 ) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
298 for line in event_text.lines() {
299 let line = line.trim();
300 if let Some(stripped) = line.strip_prefix("data: ") {
301 let data = stripped;
302 if data == "[DONE]" {
303 return Some(Ok(None));
304 }
305 return Some(Self::parse_chunk_data(data));
306 }
307 }
308 None
309 }
310
311 fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
313 match serde_json::from_str::<serde_json::Value>(data) {
314 Ok(json) => {
315 let choices = json["choices"]
316 .as_array()
317 .map(|arr| {
318 arr.iter()
319 .enumerate()
320 .map(|(index, choice)| {
321 let delta = &choice["delta"];
322 ChoiceDelta {
323 index: index as u32,
324 delta: MessageDelta {
325 role: delta["role"].as_str().map(|r| match r {
326 "assistant" => Role::Assistant,
327 "user" => Role::User,
328 "system" => Role::System,
329 _ => Role::Assistant,
330 }),
331 content: delta["content"].as_str().map(str::to_string),
332 },
333 finish_reason: choice["finish_reason"]
334 .as_str()
335 .map(str::to_string),
336 }
337 })
338 .collect()
339 })
340 .unwrap_or_default();
341
342 Ok(Some(ChatCompletionChunk {
343 id: json["id"].as_str().unwrap_or_default().to_string(),
344 object: json["object"]
345 .as_str()
346 .unwrap_or("chat.completion.chunk")
347 .to_string(),
348 created: json["created"].as_u64().unwrap_or(0),
349 model: json["model"].as_str().unwrap_or_default().to_string(),
350 choices,
351 }))
352 }
353 Err(e) => Err(AiLibError::ProviderError(format!(
354 "JSON parse error: {}",
355 e
356 ))),
357 }
358 }
359
360 fn parse_response(
362 &self,
363 response: serde_json::Value,
364 ) -> Result<ChatCompletionResponse, AiLibError> {
365 let choices = response["choices"]
366 .as_array()
367 .ok_or_else(|| {
368 AiLibError::ProviderError("Invalid response format: choices not found".to_string())
369 })?
370 .iter()
371 .enumerate()
372 .map(|(index, choice)| {
373 let message = choice["message"].as_object().ok_or_else(|| {
374 AiLibError::ProviderError("Invalid choice format".to_string())
375 })?;
376
377 let role = match message["role"].as_str().unwrap_or("user") {
378 "system" => Role::System,
379 "assistant" => Role::Assistant,
380 _ => Role::User,
381 };
382
383 let content = message["content"].as_str().unwrap_or("").to_string();
384
385 let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
387 if let Some(fc_val) = message.get("function_call") {
388 if let Ok(mut fc) = serde_json::from_value::<
390 crate::types::function_call::FunctionCall,
391 >(fc_val.clone())
392 {
393 if let Some(arg_val) = &fc.arguments {
395 if arg_val.is_string() {
396 if let Some(s) = arg_val.as_str() {
397 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
398 {
399 fc.arguments = Some(parsed);
400 }
401 }
402 }
403 }
404 function_call = Some(fc);
405 } else {
406 let name = fc_val
408 .get("name")
409 .and_then(|v| v.as_str())
410 .map(|s| s.to_string());
411 if let Some(name) = name {
412 let args = fc_val.get("arguments").and_then(|a| {
413 if a.is_string() {
414 serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
415 .ok()
416 } else {
417 Some(a.clone())
418 }
419 });
420
421 function_call = Some(crate::types::function_call::FunctionCall {
422 name,
423 arguments: args,
424 });
425 }
426 }
427 }
428
429 Ok(Choice {
430 index: index as u32,
431 message: Message {
432 role,
433 content: crate::types::common::Content::Text(content),
434 function_call,
435 },
436 finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
437 })
438 })
439 .collect::<Result<Vec<_>, AiLibError>>()?;
440
441 let usage = response["usage"].as_object().ok_or_else(|| {
442 AiLibError::ProviderError("Invalid response format: usage not found".to_string())
443 })?;
444
445 let usage = Usage {
446 prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
447 completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
448 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
449 };
450
451 Ok(ChatCompletionResponse {
452 id: response["id"].as_str().unwrap_or("").to_string(),
453 object: response["object"].as_str().unwrap_or("").to_string(),
454 created: response["created"].as_u64().unwrap_or(0),
455 model: response["model"].as_str().unwrap_or("").to_string(),
456 choices,
457 usage,
458 })
459 }
460}
461#[async_trait::async_trait]
462impl ChatApi for GenericAdapter {
463 async fn chat_completion(
464 &self,
465 request: ChatCompletionRequest,
466 ) -> Result<ChatCompletionResponse, AiLibError> {
467 self.metrics.incr_counter("generic.requests", 1).await;
469 let timer = self
470 .metrics
471 .start_timer("generic.request_duration_ms")
472 .await;
473
474 let provider_request = self.convert_request(&request).await?;
475 let url = self.config.chat_url();
476
477 let mut headers = self.config.headers.clone();
478
479 if let Some(key) = &self.api_key {
481 if self.config.base_url.contains("anthropic.com") {
482 headers.insert("x-api-key".to_string(), key.clone());
483 } else {
484 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
485 }
486 }
487
488 let response = self
489 .transport
490 .post_json(&url, Some(headers.clone()), provider_request.clone())
491 .await?;
492
493 if let Some(t) = timer {
495 t.stop();
496 }
497
498 self.parse_response(response)
499 }
500
501 async fn chat_completion_stream(
502 &self,
503 request: ChatCompletionRequest,
504 ) -> Result<
505 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
506 AiLibError,
507 > {
508 let mut stream_request = self.convert_request(&request).await?;
509 stream_request["stream"] = serde_json::Value::Bool(true);
510
511 let url = self.config.chat_url();
512
513 let mut client_builder = reqwest::Client::builder();
515 if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
516 if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
517 client_builder = client_builder.proxy(proxy);
518 }
519 }
520 let client = client_builder
521 .build()
522 .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
523
524 let mut headers = self.config.headers.clone();
525 headers.insert("Accept".to_string(), "text/event-stream".to_string());
526
527 if let Some(key) = &self.api_key {
529 if self.config.base_url.contains("anthropic.com") {
530 headers.insert("x-api-key".to_string(), key.clone());
531 } else {
532 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
533 }
534 }
535
536 let response = client.post(&url).json(&stream_request);
537
538 let mut req = response;
539 for (key, value) in headers {
540 req = req.header(key, value);
541 }
542
543 let response = req
544 .send()
545 .await
546 .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;
547
548 if !response.status().is_success() {
549 let error_text = response.text().await.unwrap_or_default();
550 return Err(AiLibError::ProviderError(format!(
551 "Stream error: {}",
552 error_text
553 )));
554 }
555
556 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
557
558 tokio::spawn(async move {
559 let mut buffer = Vec::new();
560 let mut stream = response.bytes_stream();
561
562 while let Some(result) = stream.next().await {
563 match result {
564 Ok(bytes) => {
565 buffer.extend_from_slice(&bytes);
566
567 while let Some(event_end) = Self::find_event_boundary(&buffer) {
568 let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();
569
570 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
571 if let Some(chunk) = Self::parse_sse_event(event_text) {
572 match chunk {
573 Ok(Some(c)) => {
574 if tx.send(Ok(c)).is_err() {
575 return;
576 }
577 }
578 Ok(None) => return,
579 Err(e) => {
580 let _ = tx.send(Err(e));
581 return;
582 }
583 }
584 }
585 }
586 }
587 }
588 Err(e) => {
589 let _ = tx.send(Err(AiLibError::ProviderError(format!(
590 "Stream error: {}",
591 e
592 ))));
593 break;
594 }
595 }
596 }
597 });
598
599 let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
600 Ok(Box::new(Box::pin(stream)))
601 }
602
603 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
604 if let Some(models_endpoint) = &self.config.models_endpoint {
605 let url = format!("{}{}", self.config.base_url, models_endpoint);
606 let mut headers = self.config.headers.clone();
607
608 if let Some(key) = &self.api_key {
610 if self.config.base_url.contains("anthropic.com") {
611 headers.insert("x-api-key".to_string(), key.clone());
612 } else {
613 headers.insert("Authorization".to_string(), format!("Bearer {}", key));
614 }
615 }
616
617 let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
618
619 Ok(response["data"]
620 .as_array()
621 .unwrap_or(&vec![])
622 .iter()
623 .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
624 .collect())
625 } else {
626 Err(AiLibError::ProviderError(
627 "Models endpoint not configured".to_string(),
628 ))
629 }
630 }
631
632 async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
633 Ok(ModelInfo {
634 id: model_id.to_string(),
635 object: "model".to_string(),
636 created: 0,
637 owned_by: "generic".to_string(),
638 permission: vec![ModelPermission {
639 id: "default".to_string(),
640 object: "model_permission".to_string(),
641 created: 0,
642 allow_create_engine: false,
643 allow_sampling: true,
644 allow_logprobs: false,
645 allow_search_indices: false,
646 allow_view: true,
647 allow_fine_tuning: false,
648 organization: "*".to_string(),
649 group: None,
650 is_blocking: false,
651 }],
652 })
653 }
654}