ai_lib/provider/generic.rs

1use super::config::ProviderConfig;
2use crate::api::{
3    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
4};
5use crate::metrics::{Metrics, NoopMetrics};
6use crate::transport::{DynHttpTransportRef, HttpTransport};
7use crate::types::{
8    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
9};
10use futures::stream::{Stream, StreamExt};
11use std::env;
12use std::sync::Arc;
13/// Configuration-driven generic adapter for OpenAI-compatible APIs
14pub struct GenericAdapter {
15    transport: DynHttpTransportRef,
16    config: ProviderConfig,
17    api_key: Option<String>,
18    metrics: Arc<dyn Metrics>,
19}
20
21impl GenericAdapter {
22    pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
23        // Validate configuration
24        config.validate()?;
25
26        // For generic/config-driven providers we treat the API key as optional.
27        // Some deployments (e.g. local Ollama) don't require a key. If the env var
28        // is missing we continue with None and callers will simply omit auth headers.
29        let api_key = env::var(&config.api_key_env).ok();
30
31        Ok(Self {
32            transport: HttpTransport::new_without_proxy().boxed(),
33            config,
34            api_key,
35            metrics: Arc::new(NoopMetrics::new()),
36        })
37    }
38
39    /// Create adapter with an explicit API key override (takes precedence over env var).
40    pub fn new_with_api_key(
41        config: ProviderConfig,
42        api_key_override: Option<String>,
43    ) -> Result<Self, AiLibError> {
44        config.validate()?;
45        let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
46        Ok(Self {
47            transport: HttpTransport::new_without_proxy().boxed(),
48            config,
49            api_key,
50            metrics: Arc::new(NoopMetrics::new()),
51        })
52    }
53
54    /// Create adapter with custom transport layer (for testing)
55    pub fn with_transport(
56        config: ProviderConfig,
57        transport: HttpTransport,
58    ) -> Result<Self, AiLibError> {
59        // Validate configuration
60        config.validate()?;
61
62        let api_key = env::var(&config.api_key_env).ok();
63
64        Ok(Self {
65            transport: transport.boxed(),
66            config,
67            api_key,
68            metrics: Arc::new(NoopMetrics::new()),
69        })
70    }
71
72    /// Custom transport + API key override.
73    pub fn with_transport_api_key(
74        config: ProviderConfig,
75        transport: HttpTransport,
76        api_key_override: Option<String>,
77    ) -> Result<Self, AiLibError> {
78        config.validate()?;
79        let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
80        Ok(Self {
81            transport: transport.boxed(),
82            config,
83            api_key,
84            metrics: Arc::new(NoopMetrics::new()),
85        })
86    }
87
88    /// Accept an object-safe transport reference directly
89    pub fn with_transport_ref(
90        config: ProviderConfig,
91        transport: DynHttpTransportRef,
92    ) -> Result<Self, AiLibError> {
93        // Validate configuration
94        config.validate()?;
95
96        let api_key = env::var(&config.api_key_env).ok();
97        Ok(Self {
98            transport,
99            config,
100            api_key,
101            metrics: Arc::new(NoopMetrics::new()),
102        })
103    }
104
105    /// Object-safe transport + API key override.
106    pub fn with_transport_ref_api_key(
107        config: ProviderConfig,
108        transport: DynHttpTransportRef,
109        api_key_override: Option<String>,
110    ) -> Result<Self, AiLibError> {
111        config.validate()?;
112        let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
113        Ok(Self {
114            transport,
115            config,
116            api_key,
117            metrics: Arc::new(NoopMetrics::new()),
118        })
119    }
120
121    /// Create adapter with custom transport and an injected metrics implementation
122    pub fn with_transport_ref_and_metrics(
123        config: ProviderConfig,
124        transport: DynHttpTransportRef,
125        metrics: Arc<dyn Metrics>,
126    ) -> Result<Self, AiLibError> {
127        // Validate configuration
128        config.validate()?;
129
130        let api_key = env::var(&config.api_key_env).ok();
131        Ok(Self {
132            transport,
133            config,
134            api_key,
135            metrics,
136        })
137    }
138
139    /// Create adapter with injected metrics (uses default HttpTransport)
140    pub fn with_metrics(
141        config: ProviderConfig,
142        metrics: Arc<dyn Metrics>,
143    ) -> Result<Self, AiLibError> {
144        // Validate configuration
145        config.validate()?;
146
147        let api_key = env::var(&config.api_key_env).ok();
148        Ok(Self {
149            transport: HttpTransport::new().boxed(),
150            config,
151            api_key,
152            metrics,
153        })
154    }
155
156    /// Convert generic request to provider-specific format (async: may upload local files)
157    async fn convert_request(
158        &self,
159        request: &ChatCompletionRequest,
160    ) -> Result<serde_json::Value, AiLibError> {
161        let default_role = "user".to_string();
162
163        // Build messages array; may perform uploads for local files
164        let mut messages: Vec<serde_json::Value> = Vec::with_capacity(request.messages.len());
165        for msg in request.messages.iter() {
166            let role_key = format!("{:?}", msg.role);
167            let mapped_role = self
168                .config
169                .field_mapping
170                .role_mapping
171                .get(&role_key)
172                .unwrap_or(&default_role)
173                .clone();
174
175            // Handle multimodal: if image has no url but has a name and upload endpoint configured, upload it
176            let content_val = match &msg.content {
177                crate::types::common::Content::Image {
178                    url,
179                    mime: _mime,
180                    name,
181                } => {
182                    if url.is_some() {
183                        crate::provider::utils::content_to_provider_value(&msg.content)
184                    } else if let Some(n) = name {
185                        if let Some(upload_ep) = &self.config.upload_endpoint {
186                            let upload_url = format!(
187                                "{}{}",
188                                self.config.base_url.trim_end_matches('/'),
189                                upload_ep
190                            );
191                            // Decide whether to upload or inline based on configured size limit.
192                            let should_upload = match self.config.upload_size_limit {
193                                Some(limit) => match std::fs::metadata(n) {
194                                    Ok(meta) => meta.len() > limit,
195                                    Err(_) => true, // if we can't stat the file, attempt upload
196                                },
197                                None => true, // default: upload if no limit configured (preserve prior behavior)
198                            };
199
200                            if should_upload {
201                                // Use the injected transport when available so tests can mock uploads.
202                                match crate::provider::utils::upload_file_with_transport(
203                                    Some(self.transport.clone()),
204                                    &upload_url,
205                                    n,
206                                    "file",
207                                )
208                                .await
209                                {
210                                    Ok(remote_url) => {
211                                        if remote_url.starts_with("http://")
212                                            || remote_url.starts_with("https://")
213                                            || remote_url.starts_with("data:")
214                                        {
215                                            serde_json::json!({"image": {"url": remote_url}})
216                                        } else {
217                                            serde_json::json!({"image": {"file_id": remote_url}})
218                                        }
219                                    }
220                                    Err(_) => crate::provider::utils::content_to_provider_value(
221                                        &msg.content,
222                                    ),
223                                }
224                            } else {
225                                // Inline small files as data URLs
226                                crate::provider::utils::content_to_provider_value(&msg.content)
227                            }
228                        } else {
229                            crate::provider::utils::content_to_provider_value(&msg.content)
230                        }
231                    } else {
232                        crate::provider::utils::content_to_provider_value(&msg.content)
233                    }
234                }
235                _ => crate::provider::utils::content_to_provider_value(&msg.content),
236            };
237
238            messages.push(serde_json::json!({"role": mapped_role, "content": content_val}));
239        }
240
241        // Use string literals as JSON keys
242        let mut provider_request = serde_json::json!({
243            "model": request.model,
244            "messages": messages
245        });
246
247        // Add optional parameters
248        if let Some(temp) = request.temperature {
249            provider_request["temperature"] =
250                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
251        }
252        if let Some(max_tokens) = request.max_tokens {
253            provider_request["max_tokens"] =
254                serde_json::Value::Number(serde_json::Number::from(max_tokens));
255        }
256        if let Some(top_p) = request.top_p {
257            provider_request["top_p"] =
258                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
259        }
260        if let Some(freq_penalty) = request.frequency_penalty {
261            provider_request["frequency_penalty"] = serde_json::Value::Number(
262                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
263            );
264        }
265        if let Some(presence_penalty) = request.presence_penalty {
266            provider_request["presence_penalty"] = serde_json::Value::Number(
267                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
268            );
269        }
270
271        Ok(provider_request)
272    }
273
274    /// Find event boundary
275    fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
276        let mut i = 0;
277        while i < buffer.len().saturating_sub(1) {
278            if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
279                return Some(i + 2);
280            }
281            if i < buffer.len().saturating_sub(3)
282                && buffer[i] == b'\r'
283                && buffer[i + 1] == b'\n'
284                && buffer[i + 2] == b'\r'
285                && buffer[i + 3] == b'\n'
286            {
287                return Some(i + 4);
288            }
289            i += 1;
290        }
291        None
292    }
293
294    /// Parse SSE event
295    fn parse_sse_event(
296        event_text: &str,
297    ) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
298        for line in event_text.lines() {
299            let line = line.trim();
300            if let Some(stripped) = line.strip_prefix("data: ") {
301                let data = stripped;
302                if data == "[DONE]" {
303                    return Some(Ok(None));
304                }
305                return Some(Self::parse_chunk_data(data));
306            }
307        }
308        None
309    }
310
311    /// Parse chunk data
312    fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
313        match serde_json::from_str::<serde_json::Value>(data) {
314            Ok(json) => {
315                let choices = json["choices"]
316                    .as_array()
317                    .map(|arr| {
318                        arr.iter()
319                            .enumerate()
320                            .map(|(index, choice)| {
321                                let delta = &choice["delta"];
322                                ChoiceDelta {
323                                    index: index as u32,
324                                    delta: MessageDelta {
325                                        role: delta["role"].as_str().map(|r| match r {
326                                            "assistant" => Role::Assistant,
327                                            "user" => Role::User,
328                                            "system" => Role::System,
329                                            _ => Role::Assistant,
330                                        }),
331                                        content: delta["content"].as_str().map(str::to_string),
332                                    },
333                                    finish_reason: choice["finish_reason"]
334                                        .as_str()
335                                        .map(str::to_string),
336                                }
337                            })
338                            .collect()
339                    })
340                    .unwrap_or_default();
341
342                Ok(Some(ChatCompletionChunk {
343                    id: json["id"].as_str().unwrap_or_default().to_string(),
344                    object: json["object"]
345                        .as_str()
346                        .unwrap_or("chat.completion.chunk")
347                        .to_string(),
348                    created: json["created"].as_u64().unwrap_or(0),
349                    model: json["model"].as_str().unwrap_or_default().to_string(),
350                    choices,
351                }))
352            }
353            Err(e) => Err(AiLibError::ProviderError(format!(
354                "JSON parse error: {}",
355                e
356            ))),
357        }
358    }
359
360    /// Parse response
361    fn parse_response(
362        &self,
363        response: serde_json::Value,
364    ) -> Result<ChatCompletionResponse, AiLibError> {
365        let choices = response["choices"]
366            .as_array()
367            .ok_or_else(|| {
368                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
369            })?
370            .iter()
371            .enumerate()
372            .map(|(index, choice)| {
373                let message = choice["message"].as_object().ok_or_else(|| {
374                    AiLibError::ProviderError("Invalid choice format".to_string())
375                })?;
376
377                let role = match message["role"].as_str().unwrap_or("user") {
378                    "system" => Role::System,
379                    "assistant" => Role::Assistant,
380                    _ => Role::User,
381                };
382
383                let content = message["content"].as_str().unwrap_or("").to_string();
384
385                // try to parse a function_call if present
386                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
387                if let Some(fc_val) = message.get("function_call") {
388                    // attempt full deserialization
389                    if let Ok(mut fc) = serde_json::from_value::<
390                        crate::types::function_call::FunctionCall,
391                    >(fc_val.clone())
392                    {
393                        // If the provider deserialized arguments as a JSON string, try to parse it into structured JSON.
394                        if let Some(arg_val) = &fc.arguments {
395                            if arg_val.is_string() {
396                                if let Some(s) = arg_val.as_str() {
397                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
398                                    {
399                                        fc.arguments = Some(parsed);
400                                    }
401                                }
402                            }
403                        }
404                        function_call = Some(fc);
405                    } else {
406                        // fallback: try to extract name + arguments (arguments may be a string)
407                        let name = fc_val
408                            .get("name")
409                            .and_then(|v| v.as_str())
410                            .map(|s| s.to_string());
411                        if let Some(name) = name {
412                            let args = fc_val.get("arguments").and_then(|a| {
413                                if a.is_string() {
414                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
415                                        .ok()
416                                } else {
417                                    Some(a.clone())
418                                }
419                            });
420
421                            function_call = Some(crate::types::function_call::FunctionCall {
422                                name,
423                                arguments: args,
424                            });
425                        }
426                    }
427                }
428
429                Ok(Choice {
430                    index: index as u32,
431                    message: Message {
432                        role,
433                        content: crate::types::common::Content::Text(content),
434                        function_call,
435                    },
436                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
437                })
438            })
439            .collect::<Result<Vec<_>, AiLibError>>()?;
440
441        let usage = response["usage"].as_object().ok_or_else(|| {
442            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
443        })?;
444
445        let usage = Usage {
446            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
447            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
448            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
449        };
450
451        Ok(ChatCompletionResponse {
452            id: response["id"].as_str().unwrap_or("").to_string(),
453            object: response["object"].as_str().unwrap_or("").to_string(),
454            created: response["created"].as_u64().unwrap_or(0),
455            model: response["model"].as_str().unwrap_or("").to_string(),
456            choices,
457            usage,
458        })
459    }
460}
461#[async_trait::async_trait]
462impl ChatApi for GenericAdapter {
463    async fn chat_completion(
464        &self,
465        request: ChatCompletionRequest,
466    ) -> Result<ChatCompletionResponse, AiLibError> {
467        // metrics: count requests (standardized key) and start timer
468        self.metrics.incr_counter("generic.requests", 1).await;
469        let timer = self
470            .metrics
471            .start_timer("generic.request_duration_ms")
472            .await;
473
474        let provider_request = self.convert_request(&request).await?;
475        let url = self.config.chat_url();
476
477        let mut headers = self.config.headers.clone();
478
479        // Set different authentication methods based on provider when an API key is present
480        if let Some(key) = &self.api_key {
481            if self.config.base_url.contains("anthropic.com") {
482                headers.insert("x-api-key".to_string(), key.clone());
483            } else {
484                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
485            }
486        }
487
488        let response = self
489            .transport
490            .post_json(&url, Some(headers.clone()), provider_request.clone())
491            .await?;
492
493        // Stop timer and record success
494        if let Some(t) = timer {
495            t.stop();
496        }
497
498        self.parse_response(response)
499    }
500
501    async fn chat_completion_stream(
502        &self,
503        request: ChatCompletionRequest,
504    ) -> Result<
505        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
506        AiLibError,
507    > {
508        let mut stream_request = self.convert_request(&request).await?;
509        stream_request["stream"] = serde_json::Value::Bool(true);
510
511        let url = self.config.chat_url();
512
513        // Create HTTP client
514        let mut client_builder = reqwest::Client::builder();
515        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
516            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
517                client_builder = client_builder.proxy(proxy);
518            }
519        }
520        let client = client_builder
521            .build()
522            .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
523
524        let mut headers = self.config.headers.clone();
525        headers.insert("Accept".to_string(), "text/event-stream".to_string());
526
527        // Set different authentication methods based on provider when an API key is present
528        if let Some(key) = &self.api_key {
529            if self.config.base_url.contains("anthropic.com") {
530                headers.insert("x-api-key".to_string(), key.clone());
531            } else {
532                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
533            }
534        }
535
536        let response = client.post(&url).json(&stream_request);
537
538        let mut req = response;
539        for (key, value) in headers {
540            req = req.header(key, value);
541        }
542
543        let response = req
544            .send()
545            .await
546            .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;
547
548        if !response.status().is_success() {
549            let error_text = response.text().await.unwrap_or_default();
550            return Err(AiLibError::ProviderError(format!(
551                "Stream error: {}",
552                error_text
553            )));
554        }
555
556        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
557
558        tokio::spawn(async move {
559            let mut buffer = Vec::new();
560            let mut stream = response.bytes_stream();
561
562            while let Some(result) = stream.next().await {
563                match result {
564                    Ok(bytes) => {
565                        buffer.extend_from_slice(&bytes);
566
567                        while let Some(event_end) = Self::find_event_boundary(&buffer) {
568                            let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();
569
570                            if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
571                                if let Some(chunk) = Self::parse_sse_event(event_text) {
572                                    match chunk {
573                                        Ok(Some(c)) => {
574                                            if tx.send(Ok(c)).is_err() {
575                                                return;
576                                            }
577                                        }
578                                        Ok(None) => return,
579                                        Err(e) => {
580                                            let _ = tx.send(Err(e));
581                                            return;
582                                        }
583                                    }
584                                }
585                            }
586                        }
587                    }
588                    Err(e) => {
589                        let _ = tx.send(Err(AiLibError::ProviderError(format!(
590                            "Stream error: {}",
591                            e
592                        ))));
593                        break;
594                    }
595                }
596            }
597        });
598
599        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
600        Ok(Box::new(Box::pin(stream)))
601    }
602
603    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
604        if let Some(models_endpoint) = &self.config.models_endpoint {
605            let url = format!("{}{}", self.config.base_url, models_endpoint);
606            let mut headers = self.config.headers.clone();
607
608            // Set different authentication methods based on provider when an API key is present
609            if let Some(key) = &self.api_key {
610                if self.config.base_url.contains("anthropic.com") {
611                    headers.insert("x-api-key".to_string(), key.clone());
612                } else {
613                    headers.insert("Authorization".to_string(), format!("Bearer {}", key));
614                }
615            }
616
617            let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
618
619            Ok(response["data"]
620                .as_array()
621                .unwrap_or(&vec![])
622                .iter()
623                .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
624                .collect())
625        } else {
626            Err(AiLibError::ProviderError(
627                "Models endpoint not configured".to_string(),
628            ))
629        }
630    }
631
632    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
633        Ok(ModelInfo {
634            id: model_id.to_string(),
635            object: "model".to_string(),
636            created: 0,
637            owned_by: "generic".to_string(),
638            permission: vec![ModelPermission {
639                id: "default".to_string(),
640                object: "model_permission".to_string(),
641                created: 0,
642                allow_create_engine: false,
643                allow_sampling: true,
644                allow_logprobs: false,
645                allow_search_indices: false,
646                allow_view: true,
647                allow_fine_tuning: false,
648                organization: "*".to_string(),
649                group: None,
650                is_blocking: false,
651            }],
652        })
653    }
654}