// ai_lib/provider/generic.rs

1use super::config::ProviderConfig;
2use crate::api::{
3    ChatApi, ChatCompletionChunk, ChoiceDelta, MessageDelta, ModelInfo, ModelPermission,
4};
5use crate::metrics::{Metrics, NoopMetrics};
6use crate::transport::{DynHttpTransportRef, HttpTransport};
7use crate::types::{
8    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
9};
10use futures::stream::{Stream, StreamExt};
11use std::env;
12use std::sync::Arc;
13/// Configuration-driven generic adapter for OpenAI-compatible APIs
14pub struct GenericAdapter {
15    transport: DynHttpTransportRef,
16    config: ProviderConfig,
17    api_key: Option<String>,
18    metrics: Arc<dyn Metrics>,
19}
20
#[cfg(all(test, not(feature = "unified_sse")))]
mod legacy_sse_tests {
    use super::*;

    /// Two back-to-back SSE events whose deltas carry multi-byte UTF-8 text must be split by the
    /// legacy boundary finder and parsed in order, with the text intact.
    #[test]
    fn legacy_event_sequence_non_ascii() {
        let first_event = "data: {\"id\":\"1\",\"object\":\"chat.completion.chunk\",\"created\":0,\"model\":\"m\",\"choices\":[{\"delta\":{\"role\":\"assistant\",\"content\":\"你好,\"}}]}\n\n";
        let second_event = "data: {\"id\":\"2\",\"object\":\"chat.completion.chunk\",\"created\":0,\"model\":\"m\",\"choices\":[{\"delta\":{\"content\":\"世界!\"}}]}\n\n";

        // Simulate both events arriving in a single network read.
        let mut pending: Vec<u8> = first_event.bytes().chain(second_event.bytes()).collect();
        let mut contents: Vec<String> = Vec::new();

        while let Some(end) = GenericAdapter::find_event_boundary(&pending) {
            let raw: Vec<u8> = pending.drain(..end).collect();
            let parsed = std::str::from_utf8(&raw)
                .ok()
                .and_then(GenericAdapter::parse_sse_event);
            if let Some(result) = parsed {
                let chunk = result.expect("ok").expect("chunk");
                // Only deltas that actually carry text contribute to the output.
                contents.extend(chunk.choices[0].delta.content.clone());
            }
        }

        assert_eq!(contents, vec!["你好,".to_string(), "世界!".to_string()]);
    }
}
45
46impl GenericAdapter {
47    pub fn new(config: ProviderConfig) -> Result<Self, AiLibError> {
48        // Validate configuration
49        config.validate()?;
50
51        // For generic/config-driven providers we treat the API key as optional.
52        // Some deployments (e.g. local Ollama) don't require a key. If the env var
53        // is missing we continue with None and callers will simply omit auth headers.
54        let api_key = env::var(&config.api_key_env).ok();
55
56        Ok(Self {
57            transport: HttpTransport::new_without_proxy().boxed(),
58            config,
59            api_key,
60            metrics: Arc::new(NoopMetrics::new()),
61        })
62    }
63
64    /// Create adapter with an explicit API key override (takes precedence over env var).
65    pub fn new_with_api_key(
66        config: ProviderConfig,
67        api_key_override: Option<String>,
68    ) -> Result<Self, AiLibError> {
69        config.validate()?;
70        let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
71        Ok(Self {
72            transport: HttpTransport::new_without_proxy().boxed(),
73            config,
74            api_key,
75            metrics: Arc::new(NoopMetrics::new()),
76        })
77    }
78
79    /// Create adapter with custom transport layer (for testing)
80    pub fn with_transport(
81        config: ProviderConfig,
82        transport: HttpTransport,
83    ) -> Result<Self, AiLibError> {
84        // Validate configuration
85        config.validate()?;
86
87        let api_key = env::var(&config.api_key_env).ok();
88
89        Ok(Self {
90            transport: transport.boxed(),
91            config,
92            api_key,
93            metrics: Arc::new(NoopMetrics::new()),
94        })
95    }
96
97    /// Custom transport + API key override.
98    pub fn with_transport_api_key(
99        config: ProviderConfig,
100        transport: HttpTransport,
101        api_key_override: Option<String>,
102    ) -> Result<Self, AiLibError> {
103        config.validate()?;
104        let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
105        Ok(Self {
106            transport: transport.boxed(),
107            config,
108            api_key,
109            metrics: Arc::new(NoopMetrics::new()),
110        })
111    }
112
113    /// Accept an object-safe transport reference directly
114    pub fn with_transport_ref(
115        config: ProviderConfig,
116        transport: DynHttpTransportRef,
117    ) -> Result<Self, AiLibError> {
118        // Validate configuration
119        config.validate()?;
120
121        let api_key = env::var(&config.api_key_env).ok();
122        Ok(Self {
123            transport,
124            config,
125            api_key,
126            metrics: Arc::new(NoopMetrics::new()),
127        })
128    }
129
130    /// Object-safe transport + API key override.
131    pub fn with_transport_ref_api_key(
132        config: ProviderConfig,
133        transport: DynHttpTransportRef,
134        api_key_override: Option<String>,
135    ) -> Result<Self, AiLibError> {
136        config.validate()?;
137        let api_key = api_key_override.or_else(|| env::var(&config.api_key_env).ok());
138        Ok(Self {
139            transport,
140            config,
141            api_key,
142            metrics: Arc::new(NoopMetrics::new()),
143        })
144    }
145
146    /// Create adapter with custom transport and an injected metrics implementation
147    pub fn with_transport_ref_and_metrics(
148        config: ProviderConfig,
149        transport: DynHttpTransportRef,
150        metrics: Arc<dyn Metrics>,
151    ) -> Result<Self, AiLibError> {
152        // Validate configuration
153        config.validate()?;
154
155        let api_key = env::var(&config.api_key_env).ok();
156        Ok(Self {
157            transport,
158            config,
159            api_key,
160            metrics,
161        })
162    }
163
164    /// Create adapter with injected metrics (uses default HttpTransport)
165    pub fn with_metrics(
166        config: ProviderConfig,
167        metrics: Arc<dyn Metrics>,
168    ) -> Result<Self, AiLibError> {
169        // Validate configuration
170        config.validate()?;
171
172        let api_key = env::var(&config.api_key_env).ok();
173        Ok(Self {
174            transport: HttpTransport::new().boxed(),
175            config,
176            api_key,
177            metrics,
178        })
179    }
180
181    /// Convert generic request to provider-specific format (async: may upload local files)
182    async fn convert_request(
183        &self,
184        request: &ChatCompletionRequest,
185    ) -> Result<serde_json::Value, AiLibError> {
186        let default_role = "user".to_string();
187
188        // Build messages array; may perform uploads for local files
189        let mut messages: Vec<serde_json::Value> = Vec::with_capacity(request.messages.len());
190        for msg in request.messages.iter() {
191            let role_key = format!("{:?}", msg.role);
192            let mapped_role = self
193                .config
194                .field_mapping
195                .role_mapping
196                .get(&role_key)
197                .unwrap_or(&default_role)
198                .clone();
199
200            // Handle multimodal: if image has no url but has a name and upload endpoint configured, upload it
201            let content_val = match &msg.content {
202                crate::types::common::Content::Image {
203                    url,
204                    mime: _mime,
205                    name,
206                } => {
207                    if url.is_some() {
208                        crate::provider::utils::content_to_provider_value(&msg.content)
209                    } else if let Some(n) = name {
210                        if let Some(upload_ep) = &self.config.upload_endpoint {
211                            let upload_url = format!(
212                                "{}{}",
213                                self.config.base_url.trim_end_matches('/'),
214                                upload_ep
215                            );
216                            // Decide whether to upload or inline based on configured size limit.
217                            let should_upload = match self.config.upload_size_limit {
218                                Some(limit) => match std::fs::metadata(n) {
219                                    Ok(meta) => meta.len() > limit,
220                                    Err(_) => true, // if we can't stat the file, attempt upload
221                                },
222                                None => true, // default: upload if no limit configured (preserve prior behavior)
223                            };
224
225                            if should_upload {
226                                // Use the injected transport when available so tests can mock uploads.
227                                match crate::provider::utils::upload_file_with_transport(
228                                    Some(self.transport.clone()),
229                                    &upload_url,
230                                    n,
231                                    "file",
232                                )
233                                .await
234                                {
235                                    Ok(remote_url) => {
236                                        if remote_url.starts_with("http://")
237                                            || remote_url.starts_with("https://")
238                                            || remote_url.starts_with("data:")
239                                        {
240                                            serde_json::json!({"image": {"url": remote_url}})
241                                        } else {
242                                            serde_json::json!({"image": {"file_id": remote_url}})
243                                        }
244                                    }
245                                    Err(_) => crate::provider::utils::content_to_provider_value(
246                                        &msg.content,
247                                    ),
248                                }
249                            } else {
250                                // Inline small files as data URLs
251                                crate::provider::utils::content_to_provider_value(&msg.content)
252                            }
253                        } else {
254                            crate::provider::utils::content_to_provider_value(&msg.content)
255                        }
256                    } else {
257                        crate::provider::utils::content_to_provider_value(&msg.content)
258                    }
259                }
260                _ => crate::provider::utils::content_to_provider_value(&msg.content),
261            };
262
263            messages.push(serde_json::json!({"role": mapped_role, "content": content_val}));
264        }
265
266        // Use string literals as JSON keys
267        let mut provider_request = serde_json::json!({
268            "model": request.model,
269            "messages": messages
270        });
271
272        // Add optional parameters
273        if let Some(temp) = request.temperature {
274            provider_request["temperature"] =
275                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
276        }
277        if let Some(max_tokens) = request.max_tokens {
278            provider_request["max_tokens"] =
279                serde_json::Value::Number(serde_json::Number::from(max_tokens));
280        }
281        if let Some(top_p) = request.top_p {
282            provider_request["top_p"] =
283                serde_json::Value::Number(serde_json::Number::from_f64(top_p.into()).unwrap());
284        }
285        if let Some(freq_penalty) = request.frequency_penalty {
286            provider_request["frequency_penalty"] = serde_json::Value::Number(
287                serde_json::Number::from_f64(freq_penalty.into()).unwrap(),
288            );
289        }
290        if let Some(presence_penalty) = request.presence_penalty {
291            provider_request["presence_penalty"] = serde_json::Value::Number(
292                serde_json::Number::from_f64(presence_penalty.into()).unwrap(),
293            );
294        }
295
296        // Function calling (OpenAI-compatible). Many config-driven providers accept this schema.
297        if let Some(funcs) = &request.functions {
298            let mapped: Vec<serde_json::Value> = funcs
299                .iter()
300                .map(|t| {
301                    serde_json::json!({
302                        "name": t.name,
303                        "description": t.description,
304                        "parameters": t.parameters.clone().unwrap_or(serde_json::json!({}))
305                    })
306                })
307                .collect();
308            provider_request["functions"] = serde_json::Value::Array(mapped);
309        }
310
311        if let Some(policy) = &request.function_call {
312            match policy {
313                crate::types::FunctionCallPolicy::Auto(name) => {
314                    if name == "auto" {
315                        provider_request["function_call"] =
316                            serde_json::Value::String("auto".to_string());
317                    } else {
318                        provider_request["function_call"] = serde_json::json!({"name": name});
319                    }
320                }
321                crate::types::FunctionCallPolicy::None => {
322                    provider_request["function_call"] =
323                        serde_json::Value::String("none".to_string());
324                }
325            }
326        }
327
328        Ok(provider_request)
329    }
330
331    /// Find event boundary
332    /// Deprecated path (legacy SSE helper). Prefer `crate::sse::parser` when `unified_sse` is enabled.
333    #[cfg(not(feature = "unified_sse"))]
334    fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
335        let mut i = 0;
336        while i < buffer.len().saturating_sub(1) {
337            if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
338                return Some(i + 2);
339            }
340            if i < buffer.len().saturating_sub(3)
341                && buffer[i] == b'\r'
342                && buffer[i + 1] == b'\n'
343                && buffer[i + 2] == b'\r'
344                && buffer[i + 3] == b'\n'
345            {
346                return Some(i + 4);
347            }
348            i += 1;
349        }
350        None
351    }
352
353    /// Parse SSE event
354    /// Deprecated path (legacy SSE helper). Prefer `crate::sse::parser` when `unified_sse` is enabled.
355    #[cfg(not(feature = "unified_sse"))]
356    fn parse_sse_event(
357        event_text: &str,
358    ) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
359        for line in event_text.lines() {
360            let line = line.trim();
361            if let Some(stripped) = line.strip_prefix("data: ") {
362                let data = stripped;
363                if data == "[DONE]" {
364                    return Some(Ok(None));
365                }
366                return Some(Self::parse_chunk_data(data));
367            }
368        }
369        None
370    }
371
372    /// Parse chunk data
373    /// Deprecated path (legacy SSE helper). Prefer `crate::sse::parser` when `unified_sse` is enabled.
374    #[cfg(not(feature = "unified_sse"))]
375    fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
376        match serde_json::from_str::<serde_json::Value>(data) {
377            Ok(json) => {
378                let choices = json["choices"]
379                    .as_array()
380                    .map(|arr| {
381                        arr.iter()
382                            .enumerate()
383                            .map(|(index, choice)| {
384                                let delta = &choice["delta"];
385                                ChoiceDelta {
386                                    index: index as u32,
387                                    delta: MessageDelta {
388                                        role: delta["role"].as_str().map(|r| match r {
389                                            "assistant" => Role::Assistant,
390                                            "user" => Role::User,
391                                            "system" => Role::System,
392                                            _ => Role::Assistant,
393                                        }),
394                                        content: delta["content"].as_str().map(str::to_string),
395                                    },
396                                    finish_reason: choice["finish_reason"]
397                                        .as_str()
398                                        .map(str::to_string),
399                                }
400                            })
401                            .collect()
402                    })
403                    .unwrap_or_default();
404
405                Ok(Some(ChatCompletionChunk {
406                    id: json["id"].as_str().unwrap_or_default().to_string(),
407                    object: json["object"]
408                        .as_str()
409                        .unwrap_or("chat.completion.chunk")
410                        .to_string(),
411                    created: json["created"].as_u64().unwrap_or(0),
412                    model: json["model"].as_str().unwrap_or_default().to_string(),
413                    choices,
414                }))
415            }
416            Err(e) => Err(AiLibError::ProviderError(format!(
417                "JSON parse error: {}",
418                e
419            ))),
420        }
421    }
422
    // The legacy SSE parsing test lives in the `legacy_sse_tests` module at the top of this file.
424
425    fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
426        let mut chunks = Vec::new();
427        let mut start = 0;
428        let bytes = text.as_bytes();
429        while start < bytes.len() {
430            let end = std::cmp::min(start + max_len, bytes.len());
431            let mut cut = end;
432            if end < bytes.len() {
433                if let Some(pos) = text[start..end].rfind(' ') {
434                    cut = start + pos;
435                }
436            }
437            if cut == start {
438                cut = end;
439            }
440            chunks.push(String::from_utf8_lossy(&bytes[start..cut]).to_string());
441            start = cut;
442            if start < bytes.len() && bytes[start] == b' ' {
443                start += 1;
444            }
445        }
446        chunks
447    }
448
449    /// Parse response
450    fn parse_response(
451        &self,
452        response: serde_json::Value,
453    ) -> Result<ChatCompletionResponse, AiLibError> {
454        let choices = response["choices"]
455            .as_array()
456            .ok_or_else(|| {
457                AiLibError::ProviderError("Invalid response format: choices not found".to_string())
458            })?
459            .iter()
460            .enumerate()
461            .map(|(index, choice)| {
462                let message = choice["message"].as_object().ok_or_else(|| {
463                    AiLibError::ProviderError("Invalid choice format".to_string())
464                })?;
465
466                let role = match message["role"].as_str().unwrap_or("user") {
467                    "system" => Role::System,
468                    "assistant" => Role::Assistant,
469                    _ => Role::User,
470                };
471
472                let content = message["content"].as_str().unwrap_or("").to_string();
473
474                // try to parse a function_call if present
475                let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
476                if let Some(fc_val) = message.get("function_call") {
477                    // attempt full deserialization
478                    if let Ok(mut fc) = serde_json::from_value::<
479                        crate::types::function_call::FunctionCall,
480                    >(fc_val.clone())
481                    {
482                        // If the provider deserialized arguments as a JSON string, try to parse it into structured JSON.
483                        if let Some(arg_val) = &fc.arguments {
484                            if arg_val.is_string() {
485                                if let Some(s) = arg_val.as_str() {
486                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
487                                    {
488                                        fc.arguments = Some(parsed);
489                                    }
490                                }
491                            }
492                        }
493                        function_call = Some(fc);
494                    } else {
495                        // fallback: try to extract name + arguments (arguments may be a string)
496                        let name = fc_val
497                            .get("name")
498                            .and_then(|v| v.as_str())
499                            .map(|s| s.to_string());
500                        if let Some(name) = name {
501                            let args = fc_val.get("arguments").and_then(|a| {
502                                if a.is_string() {
503                                    serde_json::from_str::<serde_json::Value>(a.as_str().unwrap())
504                                        .ok()
505                                } else {
506                                    Some(a.clone())
507                                }
508                            });
509
510                            function_call = Some(crate::types::function_call::FunctionCall {
511                                name,
512                                arguments: args,
513                            });
514                        }
515                    }
516                } else if let Some(tool_calls) = message.get("tool_calls").and_then(|v| v.as_array()) {
517                    // OpenAI tool_calls format: [{"type":"function","function":{"name":...,"arguments":...}}]
518                    if let Some(first) = tool_calls.first() {
519                        if let Some(func) = first.get("function") {
520                            if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
521                                let mut args_opt = func.get("arguments").cloned();
522                                // If arguments is a string, attempt to parse JSON
523                                if let Some(args_val) = &args_opt {
524                                    if args_val.is_string() {
525                                        if let Some(s) = args_val.as_str() {
526                                            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
527                                                args_opt = Some(parsed);
528                                            }
529                                        }
530                                    }
531                                }
532                                function_call = Some(crate::types::function_call::FunctionCall {
533                                    name: name.to_string(),
534                                    arguments: args_opt,
535                                });
536                            }
537                        }
538                    }
539                }
540
541                Ok(Choice {
542                    index: index as u32,
543                    message: Message {
544                        role,
545                        content: crate::types::common::Content::Text(content),
546                        function_call,
547                    },
548                    finish_reason: choice["finish_reason"].as_str().map(|s| s.to_string()),
549                })
550            })
551            .collect::<Result<Vec<_>, AiLibError>>()?;
552
553        let usage = response["usage"].as_object().ok_or_else(|| {
554            AiLibError::ProviderError("Invalid response format: usage not found".to_string())
555        })?;
556
557        let usage = Usage {
558            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
559            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
560            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
561        };
562
563        Ok(ChatCompletionResponse {
564            id: response["id"].as_str().unwrap_or("").to_string(),
565            object: response["object"].as_str().unwrap_or("").to_string(),
566            created: response["created"].as_u64().unwrap_or(0),
567            model: response["model"].as_str().unwrap_or("").to_string(),
568            choices,
569            usage,
570            usage_status: UsageStatus::Finalized, // Generic adapter assumes OpenAI-compatible format
571        })
572    }
573}
574#[async_trait::async_trait]
575impl ChatApi for GenericAdapter {
576    async fn chat_completion(
577        &self,
578        request: ChatCompletionRequest,
579    ) -> Result<ChatCompletionResponse, AiLibError> {
580        // metrics: standardized keys
581        let provider_key = "generic";
582        self.metrics
583            .incr_counter(&crate::metrics::keys::requests(provider_key), 1)
584            .await;
585        let timer = self
586            .metrics
587            .start_timer(&crate::metrics::keys::request_duration_ms(provider_key))
588            .await;
589
590        // Build request body & headers
591        let url = self.config.chat_url();
592        let provider_request = self.convert_request(&request).await?;
593        let mut headers = self.config.headers.clone();
594        if let Some(key) = &self.api_key {
595            if self.config.base_url.contains("anthropic.com") {
596                headers.insert("x-api-key".to_string(), key.clone());
597                // Anthropic requires version header per API docs
598                headers.insert("anthropic-version".to_string(), "2023-06-01".to_string());
599            } else {
600                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
601            }
602        }
603
604        // Use transport to allow mocking in tests
605        let response_json = self
606            .transport
607            .post_json(&url, Some(headers), provider_request)
608            .await?;
609
610        // Stop timer
611        if let Some(t) = timer {
612            t.stop();
613        }
614
615        let parsed = self.parse_response(response_json)?;
616
617        // optional cost metrics
618        #[cfg(feature = "cost_metrics")]
619        {
620            let usd = crate::metrics::cost::estimate_usd(
621                parsed.usage.prompt_tokens,
622                parsed.usage.completion_tokens,
623            );
624            crate::metrics::cost::record_cost(
625                self.metrics.as_ref(),
626                provider_key,
627                &parsed.model,
628                usd,
629            )
630            .await;
631        }
632
633        Ok(parsed)
634    }
635
636    async fn chat_completion_stream(
637        &self,
638        request: ChatCompletionRequest,
639    ) -> Result<
640        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
641        AiLibError,
642    > {
643        let mut stream_request = self.convert_request(&request).await?;
644        stream_request["stream"] = serde_json::Value::Bool(true);
645        let url = self.config.chat_url();
646
647        let mut headers = self.config.headers.clone();
648        headers.insert("Accept".to_string(), "text/event-stream".to_string());
649        if let Some(key) = &self.api_key {
650            if self.config.base_url.contains("anthropic.com") {
651                headers.insert("x-api-key".to_string(), key.clone());
652                headers.insert("anthropic-version".to_string(), "2023-06-01".to_string());
653            } else {
654                headers.insert("Authorization".to_string(), format!("Bearer {}", key));
655            }
656        }
657
658        let byte_stream_res = self
659            .transport
660            .post_stream(&url, Some(headers), stream_request)
661            .await;
662
663        match byte_stream_res {
664            Ok(mut byte_stream) => {
665                let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
666                tokio::spawn(async move {
667                    let mut buffer = Vec::new();
668                    while let Some(result) = byte_stream.next().await {
669                        match result {
670                            Ok(bytes) => {
671                                buffer.extend_from_slice(&bytes);
672                                #[cfg(feature = "unified_sse")]
673                                {
674                                    while let Some(event_end) =
675                                        crate::sse::parser::find_event_boundary(&buffer)
676                                    {
677                                        let event_bytes =
678                                            buffer.drain(..event_end).collect::<Vec<_>>();
679                                        if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
680                                            if let Some(chunk) =
681                                                crate::sse::parser::parse_sse_event(event_text)
682                                            {
683                                                match chunk {
684                                                    Ok(Some(c)) => {
685                                                        if tx.send(Ok(c)).is_err() {
686                                                            return;
687                                                        }
688                                                    }
689                                                    Ok(None) => return,
690                                                    Err(e) => {
691                                                        let _ = tx.send(Err(e));
692                                                        return;
693                                                    }
694                                                }
695                                            }
696                                        }
697                                    }
698                                }
699                                #[cfg(not(feature = "unified_sse"))]
700                                {
701                                    while let Some(event_end) = Self::find_event_boundary(&buffer) {
702                                        let event_bytes =
703                                            buffer.drain(..event_end).collect::<Vec<_>>();
704                                        if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
705                                            if let Some(chunk) = Self::parse_sse_event(event_text) {
706                                                match chunk {
707                                                    Ok(Some(c)) => {
708                                                        if tx.send(Ok(c)).is_err() {
709                                                            return;
710                                                        }
711                                                    }
712                                                    Ok(None) => return,
713                                                    Err(e) => {
714                                                        let _ = tx.send(Err(e));
715                                                        return;
716                                                    }
717                                                }
718                                            }
719                                        }
720                                    }
721                                }
722                            }
723                            Err(e) => {
724                                let _ = tx.send(Err(AiLibError::ProviderError(format!(
725                                    "Stream error: {}",
726                                    e
727                                ))));
728                                break;
729                            }
730                        }
731                    }
732                });
733                let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
734                Ok(Box::new(Box::pin(stream)))
735            }
736            Err(_) => {
737                // Fallback to non-streaming + simulated chunks
738                let finished = self.chat_completion(request).await?;
739                let text = finished
740                    .choices
741                    .first()
742                    .map(|c| c.message.content.as_text())
743                    .unwrap_or_default();
744                let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
745                tokio::spawn(async move {
746                    let chunks = Self::split_text_into_chunks(&text, 80);
747                    for chunk in chunks {
748                        let delta = ChoiceDelta {
749                            index: 0,
750                            delta: MessageDelta {
751                                role: Some(Role::Assistant),
752                                content: Some(chunk.clone()),
753                            },
754                            finish_reason: None,
755                        };
756                        let chunk_obj = ChatCompletionChunk {
757                            id: "simulated".to_string(),
758                            object: "chat.completion.chunk".to_string(),
759                            created: 0,
760                            model: finished.model.clone(),
761                            choices: vec![delta],
762                        };
763                        if tx.send(Ok(chunk_obj)).is_err() {
764                            return;
765                        }
766                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
767                    }
768                });
769                let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
770                Ok(Box::new(Box::pin(stream)))
771            }
772        }
773    }
774
775    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
776        if let Some(models_endpoint) = &self.config.models_endpoint {
777            let url = format!("{}{}", self.config.base_url, models_endpoint);
778            let mut headers = self.config.headers.clone();
779            if let Some(key) = &self.api_key {
780                if self.config.base_url.contains("anthropic.com") {
781                    headers.insert("x-api-key".to_string(), key.clone());
782                } else {
783                    headers.insert("Authorization".to_string(), format!("Bearer {}", key));
784                }
785            }
786            let response = self.transport.get_json(&url, Some(headers)).await?;
787            Ok(response["data"]
788                .as_array()
789                .unwrap_or(&vec![])
790                .iter()
791                .filter_map(|model| model["id"].as_str().map(|s| s.to_string()))
792                .collect())
793        } else {
794            Err(AiLibError::ProviderError(
795                "Models endpoint not configured".to_string(),
796            ))
797        }
798    }
799
800    async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo, AiLibError> {
801        Ok(ModelInfo {
802            id: model_id.to_string(),
803            object: "model".to_string(),
804            created: 0,
805            owned_by: "generic".to_string(),
806            permission: vec![ModelPermission {
807                id: "default".to_string(),
808                object: "model_permission".to_string(),
809                created: 0,
810                allow_create_engine: false,
811                allow_sampling: true,
812                allow_logprobs: false,
813                allow_search_indices: false,
814                allow_view: true,
815                allow_fine_tuning: false,
816                organization: "*".to_string(),
817                group: None,
818                is_blocking: false,
819            }],
820        })
821    }
822}