ai_lib/provider/
generic.rs

1use super::config::ProviderConfig;
2use crate::api::{
3    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
4};
5use crate::metrics::{Metrics, NoopMetrics};
6use crate::transport::{DynHttpTransportRef, HttpTransport};
7use crate::types::{
8    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
9};
10use futures::stream::{Stream, StreamExt};
11use std::env;
12use std::sync::Arc;
13/// Configuration-driven generic adapter for OpenAI-compatible APIs
14pub struct GenericAdapter {
15    transport: DynHttpTransportRef,
16    config: ProviderConfig,
17    api_key: Option<String>,
18    metrics: Arc<dyn Metrics>,
19}
20
21impl GenericAdapter {
22    pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
23        // Validate configuration
24        config.validate()?;
25
26        // For generic/config-driven providers we treat the API key as optional.
27        // Some deployments (e.g. local Ollama) don't require a key. If the env var
28        // is missing we continue with None and callers will simply omit auth headers.
29        let api_key = env::var(&config.api_key_env).ok();
30
31        Ok(Self {
32            transport: HttpTransport::new_without_proxy().boxed(),
33            config,
34            api_key,
35            metrics: Arc::new(NoopMetrics::new()),
36        })
37    }
38
39    /// Create adapter with custom transport layer (for testing)
40    pub fn with_transport(
41        config: ProviderConfig,
42        transport: HttpTransport,
43    ) -> Result<Self, AiLibError> {
44        // Validate configuration
45        config.validate()?;
46
47        let api_key = env::var(&config.api_key_env).ok();
48
49        Ok(Self {
50            transport: transport.boxed(),
51            config,
52            api_key,
53            metrics: Arc::new(NoopMetrics::new()),
54        })
55    }
56
57    /// Accept an object-safe transport reference directly
58    pub fn with_transport_ref(
59        config: ProviderConfig,
60        transport: DynHttpTransportRef,
61    ) -> Result<Self, AiLibError> {
62        // Validate configuration
63        config.validate()?;
64
65        let api_key = env::var(&config.api_key_env).ok();
66        Ok(Self {
67            transport,
68            config,
69            api_key,
70            metrics: Arc::new(NoopMetrics::new()),
71        })
72    }
73
74    /// Create adapter with custom transport and an injected metrics implementation
75    pub fn with_transport_ref_and_metrics(
76        config: ProviderConfig,
77        transport: DynHttpTransportRef,
78        metrics: Arc<dyn Metrics>,
79    ) -> Result<Self, AiLibError> {
80        // Validate configuration
81        config.validate()?;
82
83        let api_key = env::var(&config.api_key_env).ok();
84        Ok(Self {
85            transport,
86            config,
87            api_key,
88            metrics,
89        })
90    }
91
92    /// Create adapter with injected metrics (uses default HttpTransport)
93    pub fn with_metrics(
94        config: ProviderConfig,
95        metrics: Arc<dyn Metrics>,
96    ) -> Result<Self, AiLibError> {
97        // Validate configuration
98        config.validate()?;
99
100        let api_key = env::var(&config.api_key_env).ok();
101        Ok(Self {
102            transport: HttpTransport::new().boxed(),
103            config,
104            api_key,
105            metrics,
106        })
107    }
108
109    /// Convert generic request to provider-specific format (async: may upload local files)
110    async fn convert_request(
111        &self,
112        request: &ChatCompletionRequest,
113    ) -> Result<serde_json::Value, AiLibError> {
114        let default_role = "user".to_string();
115
116        // Build messages array; may perform uploads for local files
117        let mut messages: Vec<serde_json::Value> = Vec::with_capacity(request.messages.len());
118        for msg in request.messages.iter() {
119            let role_key = format!("{:?}", msg.role);
120            let mapped_role = self
121                .config
122                .field_mapping
123                .role_mapping
124                .get(&role_key)
125                .unwrap_or(&default_role)
126                .clone();
127
128            // Handle multimodal: if image has no url but has a name and upload endpoint configured, upload it
129            let content_val = match &msg.content {
130                crate::types::common::Content::Image {
131                    url,
132                    mime: _mime,
133                    name,
134                } => {
135                    if url.is_some() {
136                        crate::provider::utils::content_to_provider_value(&msg.content)
137                    } else if let Some(n) = name {
138                        if let Some(upload_ep) = &self.config.upload_endpoint {
139                            let upload_url = format!(
140                                "{}{}",
141                                self.config.base_url.trim_end_matches('/'),
142                                upload_ep
143                            );
144                            // Decide whether to upload or inline based on configured size limit.
145                            let should_upload = match self.config.upload_size_limit {
146                                Some(limit) => match std::fs::metadata(n) {
147                                    Ok(meta) => meta.len() > limit,
148                                    Err(_) => true, // if we can't stat the file, attempt upload
149                                },
150                                None => true, // default: upload if no limit configured (preserve prior behavior)
151                            };
152
153                            if should_upload {
154                                // Use the injected transport when available so tests can mock uploads.
155                                match crate::provider::utils::upload_file_with_transport(
156                                    Some(self.transport.clone()),
157                                    &upload_url,
158                                    n,
159                                    "file",
160                                )
161                                .await
162                                {
163                                    Ok(remote_url) => {
164                                        if remote_url.starts_with("http://")
165                                            || remote_url.starts_with("https://")
166                                            || remote_url.starts_with("data:")
167                                        {
168                                            serde_json::json!({"image": {"url": remote_url}})
169                                        } else {
170                                            serde_json::json!({"image": {"file_id": remote_url}})
171                                        }
172                                    }
173                                    Err(_) => crate::provider::utils::content_to_provider_value(
174                                        &msg.content,
175                                    ),
176                                }
177                            } else {
178                                // Inline small files as data URLs
179                                crate::provider::utils::content_to_provider_value(&msg.content)
180                            }
181                        } else {
182                            crate::provider::utils::content_to_provider_value(&msg.content)
183                        }
184                    } else {
185                        crate::provider::utils::content_to_provider_value(&msg.content)
186                    }
187                }
188                _ => crate::provider::utils::content_to_provider_value(&msg.content),
189            };
190
191            messages.push(serde_json::json!({"role": mapped_role, "content": content_val}));
192        }
193
194        // Use string literals as JSON keys
195        let mut provider_request = serde_json::json!({
196            "model": request.model,
197            "messages": messages
198        });
199
200        // Add optional parameters
201        if let Some(temp) = request.temperature {
202            provider_request["temperature"] =
203                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
204        }
205        if let Some(max_tokens) = request.max_tokens {
206            provider_request["max_tokens"] =
207                serde_json::Value::Number(serde_json::Number::from(max_tokens));
208        }
209        if let Some(top_p) = request.top_p {
210            provider_request["top_p"] =
211                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
212        }
213        if let Some(freq_penalty) = request.frequency_penalty {
214            provider_request["frequency_penalty"] = serde_json::Value::Number(
215                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
216            );
217        }
218        if let Some(presence_penalty) = request.presence_penalty {
219            provider_request["presence_penalty"] = serde_json::Value::Number(
220                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
221            );
222        }
223
224        Ok(provider_request)
225    }
226
227    /// Find event boundary
228    fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
229        let mut i = 0;
230        while i < buffer.len().saturating_sub(1) {
231            if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
232                return Some(i + 2);
233            }
234            if i < buffer.len().saturating_sub(3)
235                && buffer[i] == b'\r'
236                && buffer[i + 1] == b'\n'
237                && buffer[i + 2] == b'\r'
238                && buffer[i + 3] == b'\n'
239            {
240                return Some(i + 4);
241            }
242            i += 1;
243        }
244        None
245    }
246
247    /// Parse SSE event
248    fn parse_sse_event(
249        event_text: &str,
250    ) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
251        for line in event_text.lines() {
252            let line = line.trim();
253            if let Some(stripped) = line.strip_prefix("data: ") {
254                let data = stripped;
255                if data == "[DONE]" {
256                    return Some(Ok(None));
257                }
258                return Some(Self::parse_chunk_data(data));
259            }
260        }
261        None
262    }
263
264    /// Parse chunk data
265    fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
266        match serde_json::from_str::<serde_json::Value>(data) {
267            Ok(json) => {
268                let choices = json["choices"]
269                    .as_array()
270                    .map(|arr| {
271                        arr.iter()
272                            .enumerate()
273                            .map(|(index, choice)| {
274                                let delta = &choice["delta"];
275                                ChoiceDelta {
276                                    index: index as u32,
277                                    delta: MessageDelta {
278                                        role: delta["role"].as_str().map(|r| match r {
279                                            "assistant" => Role::Assistant,
280                                            "user" => Role::User,
281                                            "system" => Role::System,
282                                            _ => Role::Assistant,
283                                        }),
284                                        content: delta["content"].as_str().map(str::to_string),
285                                    },
286                                    finish_reason: choice["finish_reason"]
287                                        .as_str()
288                                        .map(str::to_string),
289                                }
290                            })
291                            .collect()
292                    })
293                    .unwrap_or_default();
294
295                Ok(Some(ChatCompletionChunk {
296                    id: json["id"].as_str().unwrap_or_default().to_string(),
297                    object: json["object"]
298                        .as_str()
299                        .unwrap_or("chat.completion.chunk")
300                        .to_string(),
301                    created: json["created"].as_u64().unwrap_or(0),
302                    model: json["model"].as_str().unwrap_or_default().to_string(),
303                    choices,
304                }))
305            }
306            Err(e) => Err(AiLibError::ProviderError(format!(
307                "JSON parse error: {}",
308                e
309            ))),
310        }
311    }
312
313    /// Parse response
314    fn parse_response(
315        &self,
316        response: serde_json::Value,
317    ) -> Result<ChatCompletionResponse, AiLibError> {
318        let choices = response["choices"]
319            .as_array()
320            .ok_or_else(|| {
321                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
322            })?
323            .iter()
324            .enumerate()
325            .map(|(index, choice)| {
326                let message = choice["message"].as_object().ok_or_else(|| {
327                    AiLibError::ProviderError("Invalid choice format".to_string())
328                })?;
329
330                let role = match message["role"].as_str().unwrap_or("user") {
331                    "system" => Role::System,
332                    "assistant" => Role::Assistant,
333                    _ => Role::User,
334                };
335
336                let content = message["content"].as_str().unwrap_or("").to_string();
337
338                // try to parse a function_call if present
339                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
340                if let Some(fc_val) = message.get("function_call") {
341                    // attempt full deserialization
342                    if let Ok(mut fc) = serde_json::from_value::<
343                        crate::types::function_call::FunctionCall,
344                    >(fc_val.clone())
345                    {
346                        // If the provider deserialized arguments as a JSON string, try to parse it into structured JSON.
347                        if let Some(arg_val) = &fc.arguments {
348                            if arg_val.is_string() {
349                                if let Some(s) = arg_val.as_str() {
350                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
351                                    {
352                                        fc.arguments = Some(parsed);
353                                    }
354                                }
355                            }
356                        }
357                        function_call = Some(fc);
358                    } else {
359                        // fallback: try to extract name + arguments (arguments may be a string)
360                        let name = fc_val
361                            .get("name")
362                            .and_then(|v| v.as_str())
363                            .map(|s| s.to_string());
364                        if let Some(name) = name {
365                            let args = fc_val.get("arguments").and_then(|a| {
366                                if a.is_string() {
367                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
368                                        .ok()
369                                } else {
370                                    Some(a.clone())
371                                }
372                            });
373
374                            function_call = Some(crate::types::function_call::FunctionCall {
375                                name,
376                                arguments: args,
377                            });
378                        }
379                    }
380                }
381
382                Ok(Choice {
383                    index: index as u32,
384                    message: Message {
385                        role,
386                        content: crate::types::common::Content::Text(content),
387                        function_call,
388                    },
389                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
390                })
391            })
392            .collect::<Result<Vec<_>, AiLibError>>()?;
393
394        let usage = response["usage"].as_object().ok_or_else(|| {
395            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
396        })?;
397
398        let usage = Usage {
399            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
400            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
401            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
402        };
403
404        Ok(ChatCompletionResponse {
405            id: response["id"].as_str().unwrap_or("").to_string(),
406            object: response["object"].as_str().unwrap_or("").to_string(),
407            created: response["created"].as_u64().unwrap_or(0),
408            model: response["model"].as_str().unwrap_or("").to_string(),
409            choices,
410            usage,
411        })
412    }
413}
414#[async_trait::async_trait]
415impl ChatApi for GenericAdapter {
416    async fn chat_completion(
417        &self,
418        request: ChatCompletionRequest,
419    ) -> Result<ChatCompletionResponse, AiLibError> {
420        // metrics: count requests (standardized key) and start timer
421        self.metrics.incr_counter("generic.requests", 1).await;
422        let timer = self
423            .metrics
424            .start_timer("generic.request_duration_ms")
425            .await;
426
427        let provider_request = self.convert_request(&request).await?;
428        let url = self.config.chat_url();
429
430        let mut headers = self.config.headers.clone();
431
432        // Set different authentication methods based on provider when an API key is present
433        if let Some(key) = &self.api_key {
434            if self.config.base_url.contains("anthropic.com") {
435                headers.insert("x-api-key".to_string(), key.clone());
436            } else {
437                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
438            }
439        }
440
441        let response = self
442            .transport
443            .post_json(&url, Some(headers.clone()), provider_request.clone())
444            .await?;
445
446        // Stop timer and record success
447        if let Some(t) = timer {
448            t.stop();
449        }
450
451        self.parse_response(response)
452    }
453
454    async fn chat_completion_stream(
455        &self,
456        request: ChatCompletionRequest,
457    ) -> Result<
458        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
459        AiLibError,
460    > {
461        let mut stream_request = self.convert_request(&request).await?;
462        stream_request["stream"] = serde_json::Value::Bool(true);
463
464        let url = self.config.chat_url();
465
466        // Create HTTP client
467        let mut client_builder = reqwest::Client::builder();
468        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
469            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
470                client_builder = client_builder.proxy(proxy);
471            }
472        }
473        let client = client_builder
474            .build()
475            .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
476
477        let mut headers = self.config.headers.clone();
478        headers.insert("Accept".to_string(), "text/event-stream".to_string());
479
480        // Set different authentication methods based on provider when an API key is present
481        if let Some(key) = &self.api_key {
482            if self.config.base_url.contains("anthropic.com") {
483                headers.insert("x-api-key".to_string(), key.clone());
484            } else {
485                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
486            }
487        }
488
489        let response = client.post(&url).json(&stream_request);
490
491        let mut req = response;
492        for (key, value) in headers {
493            req = req.header(key, value);
494        }
495
496        let response = req
497            .send()
498            .await
499            .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;
500
501        if !response.status().is_success() {
502            let error_text = response.text().await.unwrap_or_default();
503            return Err(AiLibError::ProviderError(format!(
504                "Stream error: {}",
505                error_text
506            )));
507        }
508
509        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
510
511        tokio::spawn(async move {
512            let mut buffer = Vec::new();
513            let mut stream = response.bytes_stream();
514
515            while let Some(result) = stream.next().await {
516                match result {
517                    Ok(bytes) => {
518                        buffer.extend_from_slice(&bytes);
519
520                        while let Some(event_end) = Self::find_event_boundary(&buffer) {
521                            let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();
522
523                            if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
524                                if let Some(chunk) = Self::parse_sse_event(event_text) {
525                                    match chunk {
526                                        Ok(Some(c)) => {
527                                            if tx.send(Ok(c)).is_err() {
528                                                return;
529                                            }
530                                        }
531                                        Ok(None) => return,
532                                        Err(e) => {
533                                            let _ = tx.send(Err(e));
534                                            return;
535                                        }
536                                    }
537                                }
538                            }
539                        }
540                    }
541                    Err(e) => {
542                        let _ = tx.send(Err(AiLibError::ProviderError(format!(
543                            "Stream error: {}",
544                            e
545                        ))));
546                        break;
547                    }
548                }
549            }
550        });
551
552        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
553        Ok(Box::new(Box::pin(stream)))
554    }
555
556    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
557        if let Some(models_endpoint) = &self.config.models_endpoint {
558            let url = format!("{}{}", self.config.base_url, models_endpoint);
559            let mut headers = self.config.headers.clone();
560
561            // Set different authentication methods based on provider when an API key is present
562            if let Some(key) = &self.api_key {
563                if self.config.base_url.contains("anthropic.com") {
564                    headers.insert("x-api-key".to_string(), key.clone());
565                } else {
566                    headers.insert("Authorization".to_string(), format!("Bearer {}", key));
567                }
568            }
569
570            let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
571
572            Ok(response["data"]
573                .as_array()
574                .unwrap_or(&vec![])
575                .iter()
576                .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
577                .collect())
578        } else {
579            Err(AiLibError::ProviderError(
580                "Models endpoint not configured".to_string(),
581            ))
582        }
583    }
584
585    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
586        Ok(ModelInfo {
587            id: model_id.to_string(),
588            object: "model".to_string(),
589            created: 0,
590            owned_by: "generic".to_string(),
591            permission: vec![ModelPermission {
592                id: "default".to_string(),
593                object: "model_permission".to_string(),
594                created: 0,
595                allow_create_engine: false,
596                allow_sampling: true,
597                allow_logprobs: false,
598                allow_search_indices: false,
599                allow_view: true,
600                allow_fine_tuning: false,
601                organization: "*".to_string(),
602                group: None,
603                is_blocking: false,
604            }],
605        })
606    }
607}