ai_lib/provider/
generic.rs

1use super::config::ProviderConfig;
2use crate::api::{
3    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
4};
5use crate::metrics::{Metrics, NoopMetrics};
6use crate::transport::{DynHttpTransportRef, HttpTransport};
7use crate::types::{
8    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
9};
10use futures::stream::{Stream, StreamExt};
11use std::env;
12use std::sync::Arc;
/// Configuration-driven generic adapter for OpenAI-compatible APIs.
///
/// Endpoints, headers and role names are taken from the supplied
/// `ProviderConfig` rather than being hard-coded in the adapter.
pub struct GenericAdapter {
    /// HTTP transport used for JSON requests and file uploads.
    transport: DynHttpTransportRef,
    /// Provider description; validated by every constructor before use.
    config: ProviderConfig,
    /// Value of the environment variable named by `config.api_key_env`;
    /// `None` when it cannot be read, in which case no auth header is sent.
    api_key: Option<String>,
    /// Metrics sink used by `chat_completion` for counters and timers.
    metrics: Arc<dyn Metrics>,
}
20
21impl GenericAdapter {
22    pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
23        // 验证配置
24        config.validate()?;
25        
26        // For generic/config-driven providers we treat the API key as optional.
27        // Some deployments (e.g. local Ollama) don't require a key. If the env var
28        // is missing we continue with None and callers will simply omit auth headers.
29        let api_key = env::var(&config.api_key_env).ok();
30
31        Ok(Self {
32            transport: HttpTransport::new().boxed(),
33            config,
34            api_key,
35            metrics: Arc::new(NoopMetrics::new()),
36        })
37    }
38
39    /// Create adapter with custom transport layer (for testing)
40    pub fn with_transport(
41        config: ProviderConfig,
42        transport: HttpTransport,
43    ) -> Result<Self, AiLibError> {
44        // 验证配置
45        config.validate()?;
46        
47        let api_key = env::var(&config.api_key_env).ok();
48
49        Ok(Self {
50            transport: transport.boxed(),
51            config,
52            api_key,
53            metrics: Arc::new(NoopMetrics::new()),
54        })
55    }
56
57    /// Accept an object-safe transport reference directly
58    pub fn with_transport_ref(
59        config: ProviderConfig,
60        transport: DynHttpTransportRef,
61    ) -> Result<Self, AiLibError> {
62        // 验证配置
63        config.validate()?;
64        
65        let api_key = env::var(&config.api_key_env).ok();
66        Ok(Self {
67            transport,
68            config,
69            api_key,
70            metrics: Arc::new(NoopMetrics::new()),
71        })
72    }
73
74    /// Create adapter with custom transport and an injected metrics implementation
75    pub fn with_transport_ref_and_metrics(
76        config: ProviderConfig,
77        transport: DynHttpTransportRef,
78        metrics: Arc<dyn Metrics>,
79    ) -> Result<Self, AiLibError> {
80        // 验证配置
81        config.validate()?;
82        
83        let api_key = env::var(&config.api_key_env).ok();
84        Ok(Self {
85            transport,
86            config,
87            api_key,
88            metrics,
89        })
90    }
91
92    /// Create adapter with injected metrics (uses default HttpTransport)
93    pub fn with_metrics(
94        config: ProviderConfig,
95        metrics: Arc<dyn Metrics>,
96    ) -> Result<Self, AiLibError> {
97        // 验证配置
98        config.validate()?;
99        
100        let api_key = env::var(&config.api_key_env).ok();
101        Ok(Self {
102            transport: HttpTransport::new().boxed(),
103            config,
104            api_key,
105            metrics,
106        })
107    }
108
109    /// Convert generic request to provider-specific format (async: may upload local files)
110    async fn convert_request(
111        &self,
112        request: &ChatCompletionRequest,
113    ) -> Result<serde_json::Value, AiLibError> {
114        let default_role = "user".to_string();
115
116        // Build messages array; may perform uploads for local files
117        let mut messages: Vec<serde_json::Value> = Vec::with_capacity(request.messages.len());
118        for msg in request.messages.iter() {
119            let role_key = format!("{:?}", msg.role);
120            let mapped_role = self
121                .config
122                .field_mapping
123                .role_mapping
124                .get(&role_key)
125                .unwrap_or(&default_role)
126                .clone();
127
128            // Handle multimodal: if image has no url but has a name and upload endpoint configured, upload it
129            let content_val = match &msg.content {
130                crate::types::common::Content::Image {
131                    url,
132                    mime: _mime,
133                    name,
134                } => {
135                    if url.is_some() {
136                        crate::provider::utils::content_to_provider_value(&msg.content)
137                    } else if let Some(n) = name {
138                        if let Some(upload_ep) = &self.config.upload_endpoint {
139                            let upload_url = format!(
140                                "{}{}",
141                                self.config.base_url.trim_end_matches('/'),
142                                upload_ep
143                            );
144                            // Decide whether to upload or inline based on configured size limit.
145                            let should_upload = match self.config.upload_size_limit {
146                                Some(limit) => match std::fs::metadata(n) {
147                                    Ok(meta) => meta.len() > limit,
148                                    Err(_) => true, // if we can't stat the file, attempt upload
149                                },
150                                None => true, // default: upload if no limit configured (preserve prior behavior)
151                            };
152
153                            if should_upload {
154                                // Use the injected transport when available so tests can mock uploads.
155                                match crate::provider::utils::upload_file_with_transport(
156                                    Some(self.transport.clone()),
157                                    &upload_url,
158                                    n,
159                                    "file",
160                                )
161                                .await
162                                {
163                                    Ok(remote_url) => {
164                                        if remote_url.starts_with("http://")
165                                            || remote_url.starts_with("https://")
166                                            || remote_url.starts_with("data:")
167                                        {
168                                            serde_json::json!({"image": {"url": remote_url}})
169                                        } else {
170                                            serde_json::json!({"image": {"file_id": remote_url}})
171                                        }
172                                    }
173                                    Err(_) => crate::provider::utils::content_to_provider_value(
174                                        &msg.content,
175                                    ),
176                                }
177                            } else {
178                                // Inline small files as data URLs
179                                crate::provider::utils::content_to_provider_value(&msg.content)
180                            }
181                        } else {
182                            crate::provider::utils::content_to_provider_value(&msg.content)
183                        }
184                    } else {
185                        crate::provider::utils::content_to_provider_value(&msg.content)
186                    }
187                }
188                _ => crate::provider::utils::content_to_provider_value(&msg.content),
189            };
190
191            messages.push(serde_json::json!({"role": mapped_role, "content": content_val}));
192        }
193
194        // Use string literals as JSON keys
195        let mut provider_request = serde_json::json!({
196            "model": request.model,
197            "messages": messages
198        });
199
200        // Add optional parameters
201        if let Some(temp) = request.temperature {
202            provider_request["temperature"] =
203                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
204        }
205        if let Some(max_tokens) = request.max_tokens {
206            provider_request["max_tokens"] =
207                serde_json::Value::Number(serde_json::Number::from(max_tokens));
208        }
209        if let Some(top_p) = request.top_p {
210            provider_request["top_p"] =
211                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
212        }
213        if let Some(freq_penalty) = request.frequency_penalty {
214            provider_request["frequency_penalty"] = serde_json::Value::Number(
215                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
216            );
217        }
218        if let Some(presence_penalty) = request.presence_penalty {
219            provider_request["presence_penalty"] = serde_json::Value::Number(
220                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
221            );
222        }
223
224        Ok(provider_request)
225    }
226
227    /// Find event boundary
228    fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
229        let mut i = 0;
230        while i < buffer.len().saturating_sub(1) {
231            if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
232                return Some(i + 2);
233            }
234            if i < buffer.len().saturating_sub(3)
235                && buffer[i] == b'\r'
236                && buffer[i + 1] == b'\n'
237                && buffer[i + 2] == b'\r'
238                && buffer[i + 3] == b'\n'
239            {
240                return Some(i + 4);
241            }
242            i += 1;
243        }
244        None
245    }
246
247    /// Parse SSE event
248    fn parse_sse_event(
249        event_text: &str,
250    ) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
251        for line in event_text.lines() {
252            let line = line.trim();
253            if let Some(stripped) = line.strip_prefix("data: ") {
254                let data = stripped;
255                if data == "[DONE]" {
256                    return Some(Ok(None));
257                }
258                return Some(Self::parse_chunk_data(data));
259            }
260        }
261        None
262    }
263
264    /// Parse chunk data
265    fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
266        match serde_json::from_str::<serde_json::Value>(data) {
267            Ok(json) => {
268                let choices = json["choices"]
269                    .as_array()
270                    .map(|arr| {
271                        arr.iter()
272                            .enumerate()
273                            .map(|(index, choice)| {
274                                let delta = &choice["delta"];
275                                ChoiceDelta {
276                                    index: index as u32,
277                                    delta: MessageDelta {
278                                        role: delta["role"].as_str().map(|r| match r {
279                                            "assistant" => Role::Assistant,
280                                            "user" => Role::User,
281                                            "system" => Role::System,
282                                            _ => Role::Assistant,
283                                        }),
284                                        content: delta["content"].as_str().map(str::to_string),
285                                    },
286                                    finish_reason: choice["finish_reason"]
287                                        .as_str()
288                                        .map(str::to_string),
289                                }
290                            })
291                            .collect()
292                    })
293                    .unwrap_or_default();
294
295                Ok(Some(ChatCompletionChunk {
296                    id: json["id"].as_str().unwrap_or_default().to_string(),
297                    object: json["object"]
298                        .as_str()
299                        .unwrap_or("chat.completion.chunk")
300                        .to_string(),
301                    created: json["created"].as_u64().unwrap_or(0),
302                    model: json["model"].as_str().unwrap_or_default().to_string(),
303                    choices,
304                }))
305            }
306            Err(e) => Err(AiLibError::ProviderError(format!(
307                "JSON parse error: {}",
308                e
309            ))),
310        }
311    }
312
313    /// Parse response
314    fn parse_response(
315        &self,
316        response: serde_json::Value,
317    ) -> Result<ChatCompletionResponse, AiLibError> {
318        let choices = response["choices"]
319            .as_array()
320            .ok_or_else(|| {
321                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
322            })?
323            .iter()
324            .enumerate()
325            .map(|(index, choice)| {
326                let message = choice["message"].as_object().ok_or_else(|| {
327                    AiLibError::ProviderError("Invalid choice format".to_string())
328                })?;
329
330                let role = match message["role"].as_str().unwrap_or("user") {
331                    "system" => Role::System,
332                    "assistant" => Role::Assistant,
333                    _ => Role::User,
334                };
335
336                let content = message["content"].as_str().unwrap_or("").to_string();
337
338                // try to parse a function_call if present
339                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
340                if let Some(fc_val) = message.get("function_call") {
341                    // attempt full deserialization
342                    if let Ok(mut fc) = serde_json::from_value::<
343                        crate::types::function_call::FunctionCall,
344                    >(fc_val.clone())
345                    {
346                        // If the provider deserialized arguments as a JSON string, try to parse it into structured JSON.
347                        if let Some(arg_val) = &fc.arguments {
348                            if arg_val.is_string() {
349                                if let Some(s) = arg_val.as_str() {
350                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
351                                    {
352                                        fc.arguments = Some(parsed);
353                                    }
354                                }
355                            }
356                        }
357                        function_call = Some(fc);
358                    } else {
359                        // fallback: try to extract name + arguments (arguments may be a string)
360                        let name = fc_val
361                            .get("name")
362                            .and_then(|v| v.as_str())
363                            .map(|s| s.to_string());
364                        if let Some(name) = name {
365                            let args = fc_val.get("arguments").and_then(|a| {
366                                if a.is_string() {
367                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
368                                        .ok()
369                                } else {
370                                    Some(a.clone())
371                                }
372                            });
373
374                            function_call = Some(crate::types::function_call::FunctionCall {
375                                name,
376                                arguments: args,
377                            });
378                        }
379                    }
380                }
381
382                Ok(Choice {
383                    index: index as u32,
384                    message: Message {
385                        role,
386                        content: crate::types::common::Content::Text(content),
387                        function_call,
388                    },
389                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
390                })
391            })
392            .collect::<Result<Vec<_>, AiLibError>>()?;
393
394        let usage = response["usage"].as_object().ok_or_else(|| {
395            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
396        })?;
397
398        let usage = Usage {
399            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
400            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
401            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
402        };
403
404        Ok(ChatCompletionResponse {
405            id: response["id"].as_str().unwrap_or("").to_string(),
406            object: response["object"].as_str().unwrap_or("").to_string(),
407            created: response["created"].as_u64().unwrap_or(0),
408            model: response["model"].as_str().unwrap_or("").to_string(),
409            choices,
410            usage,
411        })
412    }
413}
414#[async_trait::async_trait]
415impl ChatApi for GenericAdapter {
416    async fn chat_completion(
417        &self,
418        request: ChatCompletionRequest,
419    ) -> Result<ChatCompletionResponse, AiLibError> {
420        // metrics: count requests (standardized key) and start timer
421        self.metrics.incr_counter("generic.requests", 1).await;
422        let timer = self
423            .metrics
424            .start_timer("generic.request_duration_ms")
425            .await;
426
427        let provider_request = self.convert_request(&request).await?;
428        let url = self.config.chat_url();
429
430        let mut headers = self.config.headers.clone();
431
432        // Set different authentication methods based on provider when an API key is present
433        if let Some(key) = &self.api_key {
434            if self.config.base_url.contains("anthropic.com") {
435                headers.insert("x-api-key".to_string(), key.clone());
436            } else {
437                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
438            }
439        }
440
441        let response = match self
442            .transport
443            .post_json(&url, Some(headers), provider_request)
444            .await
445        {
446            Ok(v) => {
447                if let Some(t) = timer {
448                    t.stop();
449                    // record that we stopped the timer in case test-inspection needs a metric
450                    let _ = self.metrics.incr_counter("generic.request_timer_recorded", 1).await;
451                }
452                v
453            }
454            Err(e) => {
455                if let Some(t) = timer {
456                    t.stop();
457                    let _ = self.metrics.incr_counter("generic.request_timer_recorded", 1).await;
458                }
459                return Err(e);
460            }
461        };
462
463        self.parse_response(response)
464    }
465
466    async fn chat_completion_stream(
467        &self,
468        request: ChatCompletionRequest,
469    ) -> Result<
470        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
471        AiLibError,
472    > {
473        let mut stream_request = self.convert_request(&request).await?;
474        stream_request["stream"] = serde_json::Value::Bool(true);
475
476        let url = self.config.chat_url();
477
478        // Create HTTP client
479        let mut client_builder = reqwest::Client::builder();
480        if let Ok(proxy_url) = std::env::var("AI_PROXY_URL") {
481            if let Ok(proxy) = reqwest::Proxy::all(&proxy_url) {
482                client_builder = client_builder.proxy(proxy);
483            }
484        }
485        let client = client_builder
486            .build()
487            .map_err(|e| AiLibError::ProviderError(format!("Client error: {}", e)))?;
488
489        let mut headers = self.config.headers.clone();
490        headers.insert("Accept".to_string(), "text/event-stream".to_string());
491
492        // Set different authentication methods based on provider when an API key is present
493        if let Some(key) = &self.api_key {
494            if self.config.base_url.contains("anthropic.com") {
495                headers.insert("x-api-key".to_string(), key.clone());
496            } else {
497                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
498            }
499        }
500
501        let response = client.post(&url).json(&stream_request);
502
503        let mut req = response;
504        for (key, value) in headers {
505            req = req.header(key, value);
506        }
507
508        let response = req
509            .send()
510            .await
511            .map_err(|e| AiLibError::ProviderError(format!("Stream request failed: {}", e)))?;
512
513        if !response.status().is_success() {
514            let error_text = response.text().await.unwrap_or_default();
515            return Err(AiLibError::ProviderError(format!(
516                "Stream error: {}",
517                error_text
518            )));
519        }
520
521        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
522
523        tokio::spawn(async move {
524            let mut buffer = Vec::new();
525            let mut stream = response.bytes_stream();
526
527            while let Some(result) = stream.next().await {
528                match result {
529                    Ok(bytes) => {
530                        buffer.extend_from_slice(&bytes);
531
532                        while let Some(event_end) = Self::find_event_boundary(&buffer) {
533                            let event_bytes = buffer.drain(..event_end).collect::<Vec<_>>();
534
535                            if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
536                                if let Some(chunk) = Self::parse_sse_event(event_text) {
537                                    match chunk {
538                                        Ok(Some(c)) => {
539                                            if tx.send(Ok(c)).is_err() {
540                                                return;
541                                            }
542                                        }
543                                        Ok(None) => return,
544                                        Err(e) => {
545                                            let _ = tx.send(Err(e));
546                                            return;
547                                        }
548                                    }
549                                }
550                            }
551                        }
552                    }
553                    Err(e) => {
554                        let _ = tx.send(Err(AiLibError::ProviderError(format!(
555                            "Stream error: {}",
556                            e
557                        ))));
558                        break;
559                    }
560                }
561            }
562        });
563
564        let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
565        Ok(Box::new(Box::pin(stream)))
566    }
567
568    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
569        if let Some(models_endpoint) = &self.config.models_endpoint {
570            let url = format!("{}{}", self.config.base_url, models_endpoint);
571            let mut headers = self.config.headers.clone();
572
573            // Set different authentication methods based on provider when an API key is present
574            if let Some(key) = &self.api_key {
575                if self.config.base_url.contains("anthropic.com") {
576                    headers.insert("x-api-key".to_string(), key.clone());
577                } else {
578                    headers.insert("Authorization".to_string(), format!("Bearer {}", key));
579                }
580            }
581
582            let response: serde_json::Value = self.transport.get_json(&url, Some(headers)).await?;
583
584            Ok(response["data"]
585                .as_array()
586                .unwrap_or(&vec![])
587                .iter()
588                .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
589                .collect())
590        } else {
591            Err(AiLibError::ProviderError(
592                "Models endpoint not configured".to_string(),
593            ))
594        }
595    }
596
597    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
598        Ok(ModelInfo {
599            id: model_id.to_string(),
600            object: "model".to_string(),
601            created: 0,
602            owned_by: "generic".to_string(),
603            permission: vec![ModelPermission {
604                id: "default".to_string(),
605                object: "model_permission".to_string(),
606                created: 0,
607                allow_create_engine: false,
608                allow_sampling: true,
609                allow_logprobs: false,
610                allow_search_indices: false,
611                allow_view: true,
612                allow_fine_tuning: false,
613                organization: "*".to_string(),
614                group: None,
615                is_blocking: false,
616            }],
617        })
618    }
619}