pi/providers/mod.rs

//! Provider implementations.
//!
//! This module contains concrete implementations of the Provider trait
//! for various LLM APIs.
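//!
//! The factory entry point is [`create_provider`], which inspects a
//! [`ModelEntry`], resolves a routing kind via `resolve_provider_route`, and
//! constructs either one of the native clients in the submodules below or a
//! generic API-keyed route; extension-registered providers that implement
//! `streamSimple` are wrapped in an adapter that drives the extension runtime.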

use crate::error::{Error, Result};
use crate::extensions::{ExtensionManager, ExtensionRuntimeHandle};
use crate::http::client::{Client, RequestBuilder};
use crate::model::{
    AssistantMessage, AssistantMessageEvent, ContentBlock, StopReason, TextContent, Usage,
};
use crate::models::ModelEntry;
use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
use crate::provider_metadata::{
    PROVIDER_METADATA, canonical_provider_id, provider_routing_defaults,
};
use crate::vcr::{VCR_ENV_MODE, VcrRecorder};
use async_trait::async_trait;
use chrono::Utc;
use futures::stream;
use futures::stream::Stream;
use serde_json::Value;
use std::collections::HashMap;
use std::env;
use std::pin::Pin;
use std::sync::Arc;
use url::Url;

pub mod anthropic;
pub mod azure;
pub mod bedrock;
pub mod cohere;
pub mod copilot;
pub mod gemini;
pub mod gitlab;
pub mod openai;
pub mod openai_responses;
pub mod vertex;

pub(super) fn first_non_empty_header_value_case_insensitive(
    headers: &HashMap<String, String>,
    names: &[&str],
) -> Option<String> {
    headers.iter().find_map(|(key, value)| {
        names
            .iter()
            .any(|name| key.eq_ignore_ascii_case(name))
            .then_some(value.trim())
            .filter(|value| !value.is_empty())
            .map(ToString::to_string)
    })
}

pub(super) fn apply_headers_ignoring_blank_auth_overrides<'a>(
    mut request: RequestBuilder<'a>,
    headers: &HashMap<String, String>,
    auth_names: &[&str],
) -> RequestBuilder<'a> {
    for (key, value) in headers {
        let is_blank_auth_override =
            auth_names.iter().any(|name| key.eq_ignore_ascii_case(name)) && value.trim().is_empty();
        if is_blank_auth_override {
            continue;
        }
        request = request.header(key, value);
    }
    request
}
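// Illustrative behaviour of the two header helpers above (worked examples, not
// doc tests): given headers {"X-Api-Key": "  abc  "} and names ["x-api-key"],
// the lookup returns Some("abc"); an override such as {"Authorization": "   "}
// whose name appears in `auth_names` is skipped by
// `apply_headers_ignoring_blank_auth_overrides` so it cannot blank out the
// provider's own auth header, while every other header is forwarded as-is.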

fn vcr_client_if_enabled() -> Result<Option<Client>> {
    if env::var(VCR_ENV_MODE).is_err() {
        return Ok(None);
    }

    let test_name = env::var("PI_VCR_TEST_NAME").unwrap_or_else(|_| "pi_runtime".to_string());
    let recorder = VcrRecorder::new(&test_name)?;
    Ok(Some(Client::new().with_vcr(recorder)))
}

struct ExtensionStreamSimpleProvider {
    model: crate::provider::Model,
    runtime: ExtensionRuntimeHandle,
}

struct ExtensionStreamSimpleState {
    runtime: ExtensionRuntimeHandle,
    stream_id: Option<String>,
    model_id: String,
    provider: String,
    api: String,
    accumulated_text: String,
    last_message: Option<AssistantMessage>,
    /// Whether `StreamEvent::Start` + `TextStart` have been emitted for string-chunk mode.
    string_chunk_started: bool,
    /// Buffered events to drain before polling the next JS chunk.
    pending_events: std::collections::VecDeque<StreamEvent>,
}

impl Drop for ExtensionStreamSimpleState {
    fn drop(&mut self) {
        if let Some(stream_id) = self.stream_id.take() {
            self.runtime
                .provider_stream_simple_cancel_best_effort(stream_id);
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProviderRouteKind {
    NativeAnthropic,
    NativeOpenAICompletions,
    NativeOpenAIResponses,
    NativeOpenAICodexResponses,
    NativeCohere,
    NativeGoogle,
    NativeGoogleGeminiCli,
    NativeGoogleVertex,
    NativeBedrock,
    NativeAzure,
    NativeCopilot,
    NativeGitlab,
    ApiAnthropicMessages,
    ApiOpenAICompletions,
    ApiOpenAIResponses,
    ApiOpenAICodexResponses,
    ApiCohereChat,
    ApiGoogleGenerativeAi,
    ApiGoogleGeminiCli,
}

impl ProviderRouteKind {
    const fn as_str(self) -> &'static str {
        match self {
            Self::NativeAnthropic => "native:anthropic",
            Self::NativeOpenAICompletions => "native:openai-completions",
            Self::NativeOpenAIResponses => "native:openai-responses",
            Self::NativeOpenAICodexResponses => "native:openai-codex-responses",
            Self::NativeCohere => "native:cohere",
            Self::NativeGoogle => "native:google",
            Self::NativeGoogleGeminiCli => "native:google-gemini-cli",
            Self::NativeGoogleVertex => "native:google-vertex",
            Self::NativeBedrock => "native:amazon-bedrock",
            Self::NativeAzure => "native:azure-openai",
            Self::NativeCopilot => "native:github-copilot",
            Self::NativeGitlab => "native:gitlab",
            Self::ApiAnthropicMessages => "api:anthropic-messages",
            Self::ApiOpenAICompletions => "api:openai-completions",
            Self::ApiOpenAIResponses => "api:openai-responses",
            Self::ApiOpenAICodexResponses => "api:openai-codex-responses",
            Self::ApiCohereChat => "api:cohere-chat",
            Self::ApiGoogleGenerativeAi => "api:google-generative-ai",
            Self::ApiGoogleGeminiCli => "api:google-gemini-cli",
        }
    }
}

fn resolve_provider_route(entry: &ModelEntry) -> Result<(ProviderRouteKind, String, String)> {
    let canonical_provider =
        canonical_provider_id(&entry.model.provider).unwrap_or(entry.model.provider.as_str());
    let schema_api = provider_routing_defaults(&entry.model.provider).map(|defaults| defaults.api);
    let effective_api = if entry.model.api.is_empty() {
        schema_api.unwrap_or_default().to_string()
    } else {
        entry.model.api.clone()
    };

    let route = match canonical_provider {
        "anthropic" => ProviderRouteKind::NativeAnthropic,
        "openai" => {
            if effective_api == "openai-completions" {
                ProviderRouteKind::NativeOpenAICompletions
            } else {
                ProviderRouteKind::NativeOpenAIResponses
            }
        }
        "openai-codex" => ProviderRouteKind::NativeOpenAICodexResponses,
        "cohere" => ProviderRouteKind::NativeCohere,
        "google" => ProviderRouteKind::NativeGoogle,
        "google-gemini-cli" | "google-antigravity" => ProviderRouteKind::NativeGoogleGeminiCli,
        "google-vertex" | "vertexai" => ProviderRouteKind::NativeGoogleVertex,
        "amazon-bedrock" | "bedrock" => ProviderRouteKind::NativeBedrock,
        "azure-openai" | "azure" | "azure-cognitive-services" | "azure-openai-responses" => {
            ProviderRouteKind::NativeAzure
        }
        "github-copilot" | "copilot" => ProviderRouteKind::NativeCopilot,
        "gitlab" | "gitlab-duo" => ProviderRouteKind::NativeGitlab,
        _ => match effective_api.as_str() {
            "anthropic-messages" => ProviderRouteKind::ApiAnthropicMessages,
            "openai-completions" => ProviderRouteKind::ApiOpenAICompletions,
            "openai-responses" => ProviderRouteKind::ApiOpenAIResponses,
            "openai-codex-responses" => ProviderRouteKind::ApiOpenAICodexResponses,
            "cohere-chat" => ProviderRouteKind::ApiCohereChat,
            "google-generative-ai" => ProviderRouteKind::ApiGoogleGenerativeAi,
            "google-gemini-cli" => ProviderRouteKind::ApiGoogleGeminiCli,
            "google-vertex" => ProviderRouteKind::NativeGoogleVertex,
            "bedrock-converse-stream" => ProviderRouteKind::NativeBedrock,
            "azure-openai-responses" => ProviderRouteKind::NativeAzure,
            _ => {
                let suggestions = suggest_similar_providers(&entry.model.provider);
                let msg = if suggestions.is_empty() {
                    format!("Provider not implemented (api: {effective_api})")
                } else {
                    format!(
                        "Provider not implemented (api: {effective_api}). Did you mean: {}?",
                        suggestions.join(", ")
                    )
                };
                return Err(Error::provider(&entry.model.provider, msg));
            }
        },
    };

    Ok((route, canonical_provider.to_string(), effective_api))
}
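// Routing examples (derived from the match above): provider "anthropic" always
// takes the native Anthropic route; provider "openai" picks the Completions or
// Responses client based on the effective api; an unknown provider falls back
// to its api field, so e.g. api "openai-completions" yields the generic
// OpenAI-compatible route, and anything unrecognised produces the
// "Provider not implemented" error with fuzzy suggestions.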

/// Levenshtein edit distance between two byte slices. Uses a single-row
/// buffer so memory is O(min(a,b)).
fn edit_distance(a: &[u8], b: &[u8]) -> usize {
    let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };
    let mut row: Vec<usize> = (0..=short.len()).collect();
    for (i, &lb) in long.iter().enumerate() {
        let mut prev = i;
        row[0] = i + 1;
        for (j, &sb) in short.iter().enumerate() {
            let cost = usize::from(lb != sb);
            let val = (row[j + 1] + 1).min(row[j] + 1).min(prev + cost);
            prev = row[j + 1];
            row[j + 1] = val;
        }
    }
    row[short.len()]
}
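// Worked examples: edit_distance(b"anthropic", b"antropic") == 1 (one missing
// 'h'), edit_distance(b"gemini", b"gemeni") == 1 (one substitution), and the
// distance to an empty slice is simply the other slice's length.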

/// Maximum edit distance allowed for a fuzzy suggestion, scaled by the
/// length of the input so very short inputs don't produce false positives.
const fn max_edit_distance(input_len: usize) -> usize {
    match input_len {
        0..=2 => 0,
        3..=5 => 1,
        6..=9 => 2,
        _ => 3,
    }
}

/// Suggest provider names similar to `input` by checking prefix matching,
/// substring containment, and Levenshtein edit distance against all
/// canonical IDs and aliases.
fn suggest_similar_providers(input: &str) -> Vec<String> {
    let needle = input.to_lowercase();
    let needle_bytes = needle.as_bytes();
    let threshold = max_edit_distance(needle.len());
    let mut matches: Vec<(usize, String)> = Vec::new();

    for meta in PROVIDER_METADATA {
        let names: Vec<&str> = std::iter::once(meta.canonical_id)
            .chain(meta.aliases.iter().copied())
            .collect();
        let mut matched = false;
        for name in &names {
            let haystack = name.to_lowercase();
            // Tier 0: exact prefix match (highest quality)
            if haystack.starts_with(&needle) || needle.starts_with(&haystack) {
                matches.push((0, meta.canonical_id.to_string()));
                matched = true;
                break;
            }
            // Tier 1: substring containment
            if haystack.contains(&needle) || needle.contains(&haystack) {
                matches.push((1, meta.canonical_id.to_string()));
                matched = true;
                break;
            }
        }
        if matched {
            continue;
        }
        // Tier 2: edit distance (typo correction)
        if threshold > 0 {
            let mut best_dist = usize::MAX;
            for name in &names {
                let haystack = name.to_lowercase();
                let dist = edit_distance(needle_bytes, haystack.as_bytes());
                best_dist = best_dist.min(dist);
            }
            if best_dist <= threshold {
                // Encode distance in the sort key so closer matches rank higher
                matches.push((
                    2_usize.wrapping_add(best_dist),
                    meta.canonical_id.to_string(),
                ));
            }
        }
    }

    matches.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
    matches.dedup_by(|a, b| a.1 == b.1);
    matches.truncate(3);
    matches.into_iter().map(|(_, name)| name).collect()
}
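// Illustrative outcomes (the exact results depend on the entries in
// PROVIDER_METADATA): an input like "goog" prefix-matches and should suggest
// "google", while a typo such as "opnai" is within edit distance 1 of "openai"
// and is picked up by tier 2; at most three suggestions are returned, best
// tier first.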

const AZURE_OPENAI_RESOURCE_ENV: &str = "AZURE_OPENAI_RESOURCE";
const AZURE_OPENAI_DEPLOYMENT_ENV: &str = "AZURE_OPENAI_DEPLOYMENT";
const AZURE_OPENAI_API_VERSION_ENV: &str = "AZURE_OPENAI_API_VERSION";

#[derive(Debug, Clone, PartialEq, Eq)]
struct AzureProviderRuntime {
    resource: String,
    deployment: String,
    api_version: String,
    endpoint_url: String,
}

fn trim_non_empty(value: Option<String>) -> Option<String> {
    value
        .map(|v| v.trim().to_string())
        .filter(|v| !v.is_empty())
}

fn parse_azure_resource_from_host(host: &str) -> Option<String> {
    host.strip_suffix(".openai.azure.com")
        .or_else(|| host.strip_suffix(".cognitiveservices.azure.com"))
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(ToString::to_string)
}

fn parse_azure_base_url_details(
    base_url: &str,
) -> Result<(String, Option<String>, Option<String>)> {
    let url = Url::parse(base_url)
        .map_err(|err| Error::config(format!("Invalid Azure base_url '{base_url}': {err}")))?;
    let host = url.host_str().map(ToString::to_string).ok_or_else(|| {
        Error::config(format!(
            "Azure base_url is missing host information: '{base_url}'"
        ))
    })?;

    let mut deployment = None;
    if let Some(segments) = url.path_segments() {
        let mut iter = segments;
        while let Some(segment) = iter.next() {
            if segment == "deployments" {
                deployment = iter
                    .next()
                    .map(str::trim)
                    .filter(|value| !value.is_empty())
                    .map(ToString::to_string);
                break;
            }
        }
    }

    let api_version = url
        .query_pairs()
        .find(|(key, _)| key == "api-version")
        .map(|(_, value)| value.into_owned())
        .filter(|value| !value.trim().is_empty());

    Ok((host, deployment, api_version))
}

fn resolve_azure_provider_runtime(entry: &ModelEntry) -> Result<AzureProviderRuntime> {
    resolve_azure_provider_runtime_with_env(entry, |name| env::var(name).ok())
}

fn resolve_azure_provider_runtime_with_env<F>(
    entry: &ModelEntry,
    mut env_lookup: F,
) -> Result<AzureProviderRuntime>
where
    F: FnMut(&str) -> Option<String>,
{
    let base_url = entry.model.base_url.trim();
    if base_url.is_empty() {
        return Err(Error::config(format!(
            "Missing Azure base_url for provider '{}'; expected https://<resource>.openai.azure.com or https://<resource>.cognitiveservices.azure.com",
            entry.model.provider
        )));
    }

    let (host, base_deployment, base_api_version) = parse_azure_base_url_details(base_url)?;
    let host_resource = parse_azure_resource_from_host(&host);
    let env_resource = trim_non_empty(env_lookup(AZURE_OPENAI_RESOURCE_ENV));
    let resource = env_resource.or(host_resource).ok_or_else(|| {
        Error::config(format!(
            "Unable to resolve Azure resource for provider '{}'; set {AZURE_OPENAI_RESOURCE_ENV} or use an Azure host in base_url ('{base_url}')",
            entry.model.provider
        ))
    })?;

    let env_deployment = trim_non_empty(env_lookup(AZURE_OPENAI_DEPLOYMENT_ENV));
    let model_deployment = {
        let model_id = entry.model.id.trim();
        (!model_id.is_empty()).then(|| model_id.to_string())
    };
    let deployment = env_deployment
        .or(base_deployment)
        .or(model_deployment)
        .ok_or_else(|| {
            Error::config(format!(
                "Unable to resolve Azure deployment for provider '{}'; set {AZURE_OPENAI_DEPLOYMENT_ENV}, provide a non-empty model id, or include '/deployments/<name>' in base_url ('{base_url}')",
                entry.model.provider
            ))
        })?;

    let api_version = trim_non_empty(env_lookup(AZURE_OPENAI_API_VERSION_ENV))
        .or(base_api_version)
        .unwrap_or_else(azure::azure_api_version);

    let endpoint_host = if parse_azure_resource_from_host(&host).is_some() {
        host
    } else {
        format!("{resource}.openai.azure.com")
    };
    let endpoint_url = format!(
        "https://{endpoint_host}/openai/deployments/{deployment}/chat/completions?api-version={api_version}"
    );

    Ok(AzureProviderRuntime {
        resource,
        deployment,
        api_version,
        endpoint_url,
    })
}
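// Resolution example (hypothetical values): with no Azure env vars set and a
// base_url of
// "https://my-res.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01",
// the runtime resolves resource "my-res", deployment "gpt-4o" and api-version
// "2024-06-01". AZURE_OPENAI_RESOURCE / AZURE_OPENAI_DEPLOYMENT /
// AZURE_OPENAI_API_VERSION take precedence over values parsed from the URL,
// and the model id is the deployment fallback of last resort.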

fn resolve_copilot_token(entry: &ModelEntry) -> Result<String> {
    resolve_copilot_token_with_env(entry, |name| env::var(name).ok())
}

fn resolve_copilot_token_with_env<F>(entry: &ModelEntry, mut env_lookup: F) -> Result<String>
where
    F: FnMut(&str) -> Option<String>,
{
    let inline = entry
        .api_key
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(ToString::to_string);
    let from_env = || {
        env_lookup("GITHUB_COPILOT_API_KEY")
            .or_else(|| env_lookup("GITHUB_TOKEN"))
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    };

    inline.or_else(from_env).ok_or_else(|| {
        Error::auth(
            "GitHub Copilot requires login credentials or GITHUB_COPILOT_API_KEY/GITHUB_TOKEN",
        )
    })
}
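// Token precedence: a non-blank inline api_key on the model entry wins, then
// GITHUB_COPILOT_API_KEY, then GITHUB_TOKEN; all values are trimmed and blank
// strings are treated as absent.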

impl ExtensionStreamSimpleProvider {
    const NEXT_TIMEOUT_MS: u64 = 600_000;

    const fn new(model: crate::provider::Model, runtime: ExtensionRuntimeHandle) -> Self {
        Self { model, runtime }
    }

    fn build_js_model(model: &crate::provider::Model) -> Value {
        serde_json::json!({
            "id": &model.id,
            "name": &model.name,
            "api": &model.api,
            "provider": &model.provider,
            "baseUrl": &model.base_url,
            "reasoning": model.reasoning,
            "input": &model.input,
            "cost": &model.cost,
            "contextWindow": model.context_window,
            "maxTokens": model.max_tokens,
            "headers": &model.headers,
        })
    }

    fn build_js_context(context: &Context<'_>) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(system_prompt) = &context.system_prompt {
            map.insert(
                "systemPrompt".to_string(),
                Value::String(system_prompt.to_string()),
            );
        }
        map.insert(
            "messages".to_string(),
            serde_json::to_value(&context.messages).unwrap_or(Value::Array(Vec::new())),
        );
        if !context.tools.is_empty() {
            let tools = context
                .tools
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters,
                    })
                })
                .collect::<Vec<_>>();
            map.insert("tools".to_string(), Value::Array(tools));
        }
        Value::Object(map)
    }

    fn build_js_options(options: &StreamOptions) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(temp) = options.temperature {
            map.insert("temperature".to_string(), serde_json::json!(temp));
        }
        if let Some(max_tokens) = options.max_tokens {
            map.insert("maxTokens".to_string(), serde_json::json!(max_tokens));
        }
        if let Some(api_key) = &options.api_key {
            map.insert("apiKey".to_string(), Value::String(api_key.clone()));
        }
        if let Some(session_id) = &options.session_id {
            map.insert("sessionId".to_string(), Value::String(session_id.clone()));
        }
        if !options.headers.is_empty() {
            map.insert(
                "headers".to_string(),
                serde_json::to_value(&options.headers)
                    .unwrap_or_else(|_| Value::Object(serde_json::Map::new())),
            );
        }
        let cache_retention = match options.cache_retention {
            crate::provider::CacheRetention::None => "none",
            crate::provider::CacheRetention::Short => "short",
            crate::provider::CacheRetention::Long => "long",
        };
        map.insert(
            "cacheRetention".to_string(),
            Value::String(cache_retention.to_string()),
        );
        if let Some(level) = options.thinking_level {
            if level != crate::model::ThinkingLevel::Off {
                map.insert("reasoning".to_string(), Value::String(level.to_string()));
            }
        }
        if let Some(budgets) = &options.thinking_budgets {
            map.insert(
                "thinkingBudgets".to_string(),
                serde_json::json!({
                    "minimal": budgets.minimal,
                    "low": budgets.low,
                    "medium": budgets.medium,
                    "high": budgets.high,
                    "xhigh": budgets.xhigh,
                }),
            );
        }
        Value::Object(map)
    }

    fn assistant_event_to_stream_event(event: AssistantMessageEvent) -> StreamEvent {
        match event {
            AssistantMessageEvent::Start { partial } => StreamEvent::Start {
                partial: partial.as_ref().clone(),
            },
            AssistantMessageEvent::TextStart { content_index, .. } => {
                StreamEvent::TextStart { content_index }
            }
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::TextDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::TextEnd {
                content_index,
                content,
                ..
            } => StreamEvent::TextEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ThinkingStart { content_index, .. } => {
                StreamEvent::ThinkingStart { content_index }
            }
            AssistantMessageEvent::ThinkingDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ThinkingDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index,
                content,
                ..
            } => StreamEvent::ThinkingEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ToolCallStart { content_index, .. } => {
                StreamEvent::ToolCallStart { content_index }
            }
            AssistantMessageEvent::ToolCallDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ToolCallDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ToolCallEnd {
                content_index,
                tool_call,
                ..
            } => StreamEvent::ToolCallEnd {
                content_index,
                tool_call,
            },
            AssistantMessageEvent::Done { reason, message } => StreamEvent::Done {
                reason,
                message: message.as_ref().clone(),
            },
            AssistantMessageEvent::Error { reason, error } => StreamEvent::Error {
                reason,
                error: error.as_ref().clone(),
            },
        }
    }

    fn make_partial(model_id: &str, provider: &str, api: &str, text: &str) -> AssistantMessage {
        AssistantMessage {
            model: model_id.to_string(),
            api: api.to_string(),
            provider: provider.to_string(),
            content: vec![ContentBlock::Text(TextContent {
                text: text.to_string(),
                text_signature: None,
            })],
            stop_reason: StopReason::default(),
            usage: Usage::default(),
            error_message: None,
            timestamp: Utc::now().timestamp_millis(),
        }
    }
}
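// Stream shape produced by the adapter below, for orientation: when the JS
// generator yields structured events they are mapped one-to-one through
// `assistant_event_to_stream_event`; when it yields raw string chunks the
// adapter synthesizes Start (with an empty partial), TextStart, one TextDelta
// per chunk, then TextEnd and Done once the generator is exhausted.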

#[allow(clippy::too_many_lines)]
#[async_trait]
impl Provider for ExtensionStreamSimpleProvider {
    #[allow(clippy::misnamed_getters)]
    fn name(&self) -> &str {
        &self.model.provider
    }

    fn api(&self) -> &str {
        &self.model.api
    }

    fn model_id(&self) -> &str {
        &self.model.id
    }

    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let model = Self::build_js_model(&self.model);
        let ctx = Self::build_js_context(context);
        let opts = Self::build_js_options(options);

        let stream_id = self
            .runtime
            .provider_stream_simple_start(
                self.model.provider.clone(),
                model,
                ctx,
                opts,
                Self::NEXT_TIMEOUT_MS,
            )
            .await?;

        let state = ExtensionStreamSimpleState {
            runtime: self.runtime.clone(),
            stream_id: Some(stream_id),
            model_id: self.model.id.clone(),
            provider: self.model.provider.clone(),
            api: self.model.api.clone(),
            accumulated_text: String::new(),
            last_message: None,
            string_chunk_started: false,
            pending_events: std::collections::VecDeque::new(),
        };

        let stream = stream::unfold(state, |mut state| async move {
            // Drain any buffered events before polling JS.
            if let Some(event) = state.pending_events.pop_front() {
                return Some((Ok(event), state));
            }

            let stream_id = state.stream_id.clone()?;
            let stream_id_for_cancel = stream_id.clone();

            match state
                .runtime
                .provider_stream_simple_next(stream_id, Self::NEXT_TIMEOUT_MS)
                .await
            {
                Ok(Some(value)) => {
                    if let Some(chunk) = value.as_str() {
                        let chunk = chunk.to_string();
                        state.accumulated_text.push_str(&chunk);
                        // Update last_message in-place: mutate existing text
                        // content instead of rebuilding the entire
                        // AssistantMessage (avoids 3 String + Vec allocs per
                        // chunk).
                        match &mut state.last_message {
                            Some(msg) => {
                                if let Some(ContentBlock::Text(t)) = msg.content.first_mut() {
                                    t.text.clone_from(&state.accumulated_text);
                                }
                            }
                            None => {
                                state.last_message = Some(Self::make_partial(
                                    &state.model_id,
                                    &state.provider,
                                    &state.api,
                                    &state.accumulated_text,
                                ));
                            }
                        }

                        // Emit Start + TextStart before first string-chunk TextDelta.
                        if !state.string_chunk_started {
                            state.string_chunk_started = true;
                            state
                                .pending_events
                                .push_back(StreamEvent::TextStart { content_index: 0 });
                            state.pending_events.push_back(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            });
                            // Raw string mode still streams deltas chunk-by-chunk, so the
                            // synthetic Start event must begin empty. Otherwise the agent
                            // seeds the partial with the first chunk and then appends that
                            // same first delta again.
                            return Some((
                                Ok(StreamEvent::Start {
                                    partial: Self::make_partial(
                                        &state.model_id,
                                        &state.provider,
                                        &state.api,
                                        "",
                                    ),
                                }),
                                state,
                            ));
                        }
                        return Some((
                            Ok(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            }),
                            state,
                        ));
                    }

                    let event: AssistantMessageEvent = match serde_json::from_value(value) {
                        Ok(event) => event,
                        Err(err) => {
                            state
                                .runtime
                                .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                            state.stream_id = None;
                            return Some((
                                Err(Error::extension(format!(
                                    "streamSimple yielded invalid event: {err}"
                                ))),
                                state,
                            ));
                        }
                    };

                    match &event {
                        AssistantMessageEvent::Start { partial }
                        | AssistantMessageEvent::TextStart { partial, .. }
                        | AssistantMessageEvent::TextDelta { partial, .. }
                        | AssistantMessageEvent::TextEnd { partial, .. }
                        | AssistantMessageEvent::ThinkingStart { partial, .. }
                        | AssistantMessageEvent::ThinkingDelta { partial, .. }
                        | AssistantMessageEvent::ThinkingEnd { partial, .. }
                        | AssistantMessageEvent::ToolCallStart { partial, .. }
                        | AssistantMessageEvent::ToolCallDelta { partial, .. }
                        | AssistantMessageEvent::ToolCallEnd { partial, .. } => {
                            state.last_message = Some(partial.as_ref().clone());
                        }
                        AssistantMessageEvent::Done { message, .. } => {
                            state.last_message = Some(message.as_ref().clone());
                        }
                        AssistantMessageEvent::Error { error, .. } => {
                            state.last_message = Some(error.as_ref().clone());
                        }
                    }

                    let stream_event = Self::assistant_event_to_stream_event(event);
                    if matches!(
                        stream_event,
                        StreamEvent::Done { .. } | StreamEvent::Error { .. }
                    ) {
                        state
                            .runtime
                            .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                        state.stream_id = None;
                    }
                    Some((Ok(stream_event), state))
                }
                Ok(None) => {
                    // Stream ended — emit TextEnd (if string chunks were used) then Done.
                    state.stream_id = None;
                    let message = state.last_message.clone().unwrap_or_else(|| {
                        Self::make_partial(
                            &state.model_id,
                            &state.provider,
                            &state.api,
                            &state.accumulated_text,
                        )
                    });

                    if state.string_chunk_started {
                        // Emit TextEnd before Done.
                        state.pending_events.push_back(StreamEvent::Done {
                            reason: StopReason::Stop,
                            message,
                        });
                        Some((
                            Ok(StreamEvent::TextEnd {
                                content_index: 0,
                                content: state.accumulated_text.clone(),
                            }),
                            state,
                        ))
                    } else {
                        Some((
                            Ok(StreamEvent::Done {
                                reason: StopReason::Stop,
                                message,
                            }),
                            state,
                        ))
                    }
                }
                Err(err) => {
                    state
                        .runtime
                        .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                    state.stream_id = None;
                    Some((Err(err), state))
                }
            }
        });

        Ok(Box::pin(stream))
    }
}

#[allow(clippy::too_many_lines)]
pub fn create_provider(
    entry: &ModelEntry,
    extensions: Option<&ExtensionManager>,
) -> Result<Arc<dyn Provider>> {
    if let Some(manager) = extensions {
        if manager.provider_has_stream_simple(&entry.model.provider) {
            let runtime = manager.runtime().ok_or_else(|| {
                Error::provider(
                    &entry.model.provider,
                    "Extension runtime not configured for streamSimple provider",
                )
            })?;
            return Ok(Arc::new(ExtensionStreamSimpleProvider::new(
                entry.model.clone(),
                runtime,
            )));
        }
    }

    let vcr_client = vcr_client_if_enabled()?;
    let client = vcr_client.unwrap_or_else(Client::new);
    let (route, canonical_provider, effective_api) = resolve_provider_route(entry)?;
    tracing::debug!(
        event = "pi.provider.factory.select",
        provider = %entry.model.provider,
        canonical_provider = %canonical_provider,
        api = %effective_api,
        base_url = %entry.model.base_url,
        route = %route.as_str(),
        "Selecting provider implementation"
    );

    match route {
        ProviderRouteKind::NativeAnthropic | ProviderRouteKind::ApiAnthropicMessages => {
            Ok(Arc::new(
                anthropic::AnthropicProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_anthropic_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAICompletions | ProviderRouteKind::ApiOpenAICompletions => {
            Ok(Arc::new(
                openai::OpenAIProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAIResponses | ProviderRouteKind::ApiOpenAIResponses => {
            Ok(Arc::new(
                openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_responses_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAICodexResponses
        | ProviderRouteKind::ApiOpenAICodexResponses => Ok(Arc::new(
            openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("openai-codex-responses")
                .with_codex_mode(true)
                .with_base_url(normalize_openai_codex_responses_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeCohere | ProviderRouteKind::ApiCohereChat => Ok(Arc::new(
            cohere::CohereProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_base_url(normalize_cohere_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeGoogle | ProviderRouteKind::ApiGoogleGenerativeAi => Ok(Arc::new(
            gemini::GeminiProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("google-generative-ai")
                .with_base_url(entry.model.base_url.clone())
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeGoogleGeminiCli | ProviderRouteKind::ApiGoogleGeminiCli => {
            Ok(Arc::new(
                gemini::GeminiProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_api_name("google-gemini-cli")
                    .with_google_cli_mode(true)
                    .with_base_url(entry.model.base_url.clone())
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeGoogleVertex => {
            let runtime = vertex::resolve_vertex_provider_runtime(entry)?;
            Ok(Arc::new(
                vertex::VertexProvider::new(runtime.model)
                    .with_project(runtime.project)
                    .with_location(runtime.location)
                    .with_publisher(runtime.publisher)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeBedrock => Ok(Arc::new(
            bedrock::BedrockProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeAzure => {
            let runtime = resolve_azure_provider_runtime(entry)?;
            Ok(Arc::new(
                azure::AzureOpenAIProvider::new(runtime.resource, runtime.deployment)
                    .with_provider_name(&entry.model.provider)
                    .with_api_version(runtime.api_version)
                    .with_endpoint_url(runtime.endpoint_url)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeCopilot => {
            let github_token = resolve_copilot_token(entry)?;
            let mut provider = copilot::CopilotProvider::new(&entry.model.id, github_token)
                .with_provider_name(&entry.model.provider)
                .with_compat(entry.compat.clone())
                .with_client(client);
            if !entry.model.base_url.is_empty() {
                provider = provider.with_github_api_base(&entry.model.base_url);
            }
            Ok(Arc::new(provider))
        }
        ProviderRouteKind::NativeGitlab => Ok(Arc::new(
            gitlab::GitLabProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
    }
}

pub fn normalize_anthropic_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.anthropic.com/v1/messages".to_string();
    }

    let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();

    if let Ok(url) = Url::parse(trimmed) {
        if url.cannot_be_a_base() {
            base_for_fallback = url.as_str().trim_end_matches('/').to_string();
        } else {
            if trimmed_url_path(&url).ends_with("/v1/messages") {
                return canonicalize_url_path(&url);
            }
            return append_url_path(&url, "v1/messages");
        }
    }

    let base_url = base_for_fallback;
    if base_url.ends_with("/v1/messages") {
        return base_url;
    }
    format!("{base_url}/v1/messages")
}
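// Illustrative mappings (hypothetical hosts): "" falls back to the official
// "https://api.anthropic.com/v1/messages"; "https://proxy.example" becomes
// "https://proxy.example/v1/messages"; a URL already ending in "/v1/messages"
// is kept, with any trailing slash trimmed.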

fn trimmed_url_path(url: &Url) -> &str {
    match url.path().trim_end_matches('/') {
        "" => "/",
        trimmed => trimmed,
    }
}

fn canonicalize_url_path(url: &Url) -> String {
    let mut canonical = url.clone();
    canonical.set_path(trimmed_url_path(url));
    canonical.to_string()
}

fn replace_url_path(url: &Url, path: &str) -> String {
    let mut updated = url.clone();
    updated.set_path(path);
    updated.to_string()
}

fn append_url_path(url: &Url, suffix: &str) -> String {
    let base_path = trimmed_url_path(url);
    let path = if base_path == "/" {
        format!("/{suffix}")
    } else {
        format!("{base_path}/{suffix}")
    };
    replace_url_path(url, &path)
}

fn strip_url_path_suffix(url: &Url, suffix: &str) -> Option<Url> {
    let base_path = trimmed_url_path(url);
    let prefix = base_path.strip_suffix(suffix)?;
    let mut stripped = url.clone();
    stripped.set_path(if prefix.is_empty() { "/" } else { prefix });
    Some(stripped)
}

fn is_official_https_origin(url: &Url, host: &str, default_port: u16) -> bool {
    url.scheme().eq_ignore_ascii_case("https")
        && url
            .host_str()
            .is_some_and(|candidate| candidate.eq_ignore_ascii_case(host))
        && url.port_or_known_default() == Some(default_port)
        && trimmed_url_path(url) == "/"
}

pub fn normalize_openai_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.openai.com/v1/chat/completions".to_string();
    }

    let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();

    if let Ok(url) = Url::parse(trimmed) {
        if url.cannot_be_a_base() {
            base_for_fallback = url.as_str().trim_end_matches('/').to_string();
        } else {
            if trimmed_url_path(&url).ends_with("/chat/completions") {
                return canonicalize_url_path(&url);
            }
            let url = strip_url_path_suffix(&url, "/responses").unwrap_or(url);
            if is_official_https_origin(&url, "api.openai.com", 443) {
                return replace_url_path(&url, "/v1/chat/completions");
            }
            return append_url_path(&url, "chat/completions");
        }
    }

    let base_url = base_for_fallback;
    if base_url.ends_with("/chat/completions") {
        return base_url;
    }
    let base_url = base_url
        .strip_suffix("/responses")
        .unwrap_or(base_url.as_str());
    format!("{base_url}/chat/completions")
}
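// Illustrative mappings (hypothetical proxy host): "https://api.openai.com"
// normalizes to "https://api.openai.com/v1/chat/completions";
// "https://proxy.local/v1" becomes "https://proxy.local/v1/chat/completions";
// a Responses-style base such as "https://proxy.local/v1/responses" has its
// "/responses" suffix swapped for "/chat/completions".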

pub fn normalize_openai_responses_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.openai.com/v1/responses".to_string();
    }

    let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();

    if let Ok(url) = Url::parse(trimmed) {
        if url.cannot_be_a_base() {
            base_for_fallback = url.as_str().trim_end_matches('/').to_string();
        } else {
            if trimmed_url_path(&url).ends_with("/responses") {
                return canonicalize_url_path(&url);
            }
            let url = strip_url_path_suffix(&url, "/chat/completions").unwrap_or(url);
            if is_official_https_origin(&url, "api.openai.com", 443) {
                return replace_url_path(&url, "/v1/responses");
            }
            return append_url_path(&url, "responses");
        }
    }

    let base_url = base_for_fallback;
    if base_url.ends_with("/responses") {
        return base_url;
    }
    let base_url = base_url
        .strip_suffix("/chat/completions")
        .unwrap_or(base_url.as_str());
    format!("{base_url}/responses")
}

pub fn normalize_openai_codex_responses_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return openai_responses::CODEX_RESPONSES_API_URL.to_string();
    }

    let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();

    if let Ok(url) = Url::parse(trimmed) {
        if url.cannot_be_a_base() {
            base_for_fallback = url.as_str().trim_end_matches('/').to_string();
        } else {
            let path = trimmed_url_path(&url);
            if path.ends_with("/backend-api/codex/responses") || path.ends_with("/responses") {
                return canonicalize_url_path(&url);
            }
            if path.ends_with("/backend-api") {
                return append_url_path(&url, "codex/responses");
            }
            return append_url_path(&url, "backend-api/codex/responses");
        }
    }

    let base = base_for_fallback;
    if base.ends_with("/backend-api/codex/responses") {
        return base;
    }
    // Some registries (including legacy Pi) store the ChatGPT base as
    // `https://chatgpt.com/backend-api`. In that case we only want to append
    // `/codex/responses`, not `/backend-api/codex/responses` again.
    if base.ends_with("/backend-api") {
        return format!("{base}/codex/responses");
    }
    if base.ends_with("/responses") {
        return base;
    }
    format!("{base}/backend-api/codex/responses")
}
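// Illustrative mappings: "https://chatgpt.com/backend-api" gains only
// "/codex/responses", while a bare origin such as "https://chatgpt.com" gains
// the full "/backend-api/codex/responses" path; URLs already ending in
// "/responses" are left untouched apart from trailing-slash trimming.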

pub fn normalize_cohere_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.cohere.com/v2/chat".to_string();
    }

    let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();

    if let Ok(url) = Url::parse(trimmed) {
        if url.cannot_be_a_base() {
            base_for_fallback = url.as_str().trim_end_matches('/').to_string();
        } else {
            if trimmed_url_path(&url).ends_with("/chat") {
                return canonicalize_url_path(&url);
            }
            if is_official_https_origin(&url, "api.cohere.com", 443) {
                return replace_url_path(&url, "/v2/chat");
            }
            return append_url_path(&url, "chat");
        }
    }

    let base_url = base_for_fallback;
    if base_url.ends_with("/chat") {
        return base_url;
    }
    format!("{base_url}/chat")
}
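// Illustrative mappings: "https://api.cohere.com" normalizes to
// "https://api.cohere.com/v2/chat", and any other base simply gains a "/chat"
// suffix unless it already ends with one.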
1218
1219#[cfg(test)]
1220mod tests {
1221    use super::*;
1222    use crate::extensions::{ExtensionManager, JsExtensionLoadSpec, JsExtensionRuntimeHandle};
1223    use crate::extensions_js::PiJsRuntimeConfig;
1224    use crate::model::{ContentBlock, Message, UserContent, UserMessage};
1225    use crate::tools::ToolRegistry;
1226    use asupersync::runtime::RuntimeBuilder;
1227    use asupersync::time::{sleep, wall_now};
1228    use futures::StreamExt;
1229    use std::sync::Arc;
1230    use std::time::Duration;
1231    use tempfile::tempdir;
1232
1233    const STREAM_SIMPLE_EXTENSION: &str = r#"
1234export default function init(pi) {
1235  pi.registerProvider("stream-provider", {
1236    baseUrl: "https://api.example.test",
1237    apiKey: "EXAMPLE_KEY",
1238    api: "custom-api",
1239    models: [
1240      { id: "stream-model", name: "Stream Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
1241    ],
1242    streamSimple: async function* (model, context, options) {
1243      if (!model || !model.baseUrl || !model.maxTokens || !model.contextWindow) {
1244        throw new Error("bad model shape");
1245      }
1246      if (!context || !Array.isArray(context.messages)) {
1247        throw new Error("bad context shape");
1248      }
1249      if (!options || !options.signal) {
1250        throw new Error("missing abort signal");
1251      }
1252
1253      const partial = {
1254        role: "assistant",
1255        content: [{ type: "text", text: "" }],
1256        api: model.api,
1257        provider: model.provider,
1258        model: model.id,
1259        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
1260        stopReason: "stop",
1261        timestamp: 0
1262      };
1263
1264      yield { type: "start", partial };
1265      yield { type: "text_start", contentIndex: 0, partial };
1266      partial.content[0].text += "hi";
1267      yield { type: "text_delta", contentIndex: 0, delta: "hi", partial };
1268      yield { type: "done", reason: "stop", message: partial };
1269    }
1270  });
1271}
1272"#;
1273
1274    const STREAM_SIMPLE_CANCEL_EXTENSION: &str = r#"
1275export default function init(pi) {
1276  pi.registerProvider("cancel-provider", {
1277    baseUrl: "https://api.example.test",
1278    apiKey: "EXAMPLE_KEY",
1279    api: "custom-api",
1280    models: [
1281      { id: "cancel-model", name: "Cancel Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
1282    ],
1283    streamSimple: async function* (model, context, options) {
1284      const partial = {
1285        role: "assistant",
1286        content: [{ type: "text", text: "" }],
1287        api: model.api,
1288        provider: model.provider,
1289        model: model.id,
1290        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
1291        stopReason: "stop",
1292        timestamp: 0
1293      };
1294
1295      try {
1296        yield { type: "start", partial };
1297        await new Promise((resolve) => {
1298          if (options && options.signal && options.signal.aborted) return resolve();
1299          if (options && options.signal && typeof options.signal.addEventListener === "function") {
1300            options.signal.addEventListener("abort", () => resolve());
1301          }
1302        });
1303      } finally {
1304        await pi.tool("write", { path: "cancelled.txt", content: "ok" });
1305      }
1306    }
1307  });
1308}
1309"#;
1310
1311    async fn load_extension(
1312        source: &str,
1313        allow_write: bool,
1314    ) -> (tempfile::TempDir, ExtensionManager) {
1315        let dir = tempdir().expect("tempdir");
1316        let entry_path = dir.path().join("ext.mjs");
1317        std::fs::write(&entry_path, source).expect("write extension");
1318
1319        let manager = ExtensionManager::new();
1320        let tools = if allow_write {
1321            Arc::new(ToolRegistry::new(&["write"], dir.path(), None))
1322        } else {
1323            Arc::new(ToolRegistry::new(&[], dir.path(), None))
1324        };
1325
1326        let js_runtime = JsExtensionRuntimeHandle::start(
1327            PiJsRuntimeConfig {
1328                cwd: dir.path().display().to_string(),
1329                ..Default::default()
1330            },
1331            Arc::clone(&tools),
1332            manager.clone(),
1333        )
1334        .await
1335        .expect("start js runtime");
1336        manager.set_js_runtime(js_runtime);
1337
1338        let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("load spec");
1339        manager
1340            .load_js_extensions(vec![spec])
1341            .await
1342            .expect("load extension");
1343
1344        (dir, manager)
1345    }
1346
1347    fn basic_context() -> Context<'static> {
1348        Context {
1349            system_prompt: Some("system".to_string().into()),
1350            messages: vec![Message::User(UserMessage {
1351                content: UserContent::Text("hello".to_string()),
1352                timestamp: 0,
1353            })]
1354            .into(),
1355            tools: Vec::new().into(),
1356        }
1357    }
1358
1359    fn basic_options() -> StreamOptions {
1360        StreamOptions {
1361            api_key: Some("sk-test".to_string()),
1362            ..Default::default()
1363        }
1364    }
1365
1366    #[test]
1367    fn extension_stream_simple_provider_emits_assistant_events() {
1368        let runtime = RuntimeBuilder::current_thread()
1369            .build()
1370            .expect("runtime build");
1371
1372        runtime.block_on(async move {
1373            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
1374            let entries = manager.extension_model_entries();
1375            assert_eq!(entries.len(), 1);
1376            let entry = entries
1377                .iter()
1378                .find(|e| e.model.provider == "stream-provider")
1379                .expect("stream-provider entry");
1380
1381            let provider = create_provider(entry, Some(&manager)).expect("create provider");
1382            assert_eq!(provider.name(), "stream-provider");
1383
1384            let ctx = basic_context();
1385            let opts = basic_options();
1386            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");
1387
1388            let mut saw_start = false;
1389            let mut saw_text_delta = false;
1390            while let Some(item) = stream.next().await {
1391                let event = item.expect("stream event");
1392                match event {
1393                    StreamEvent::Start { .. } => {
1394                        saw_start = true;
1395                    }
1396                    StreamEvent::TextDelta { delta, .. } => {
1397                        assert_eq!(delta, "hi");
1398                        saw_text_delta = true;
1399                    }
1400                    StreamEvent::Done { reason, message } => {
1401                        assert_eq!(reason, StopReason::Stop);
1402                        let text = match &message.content[0] {
1403                            ContentBlock::Text(text) => text,
1404                            other => unreachable!("expected text content block, got {other:?}"),
1405                        };
1406                        assert_eq!(text.text, "hi");
1407                        break;
1408                    }
1409                    _ => {}
1410                }
1411            }
1412
1413            assert!(saw_start, "expected a Start event");
1414            assert!(saw_text_delta, "expected a TextDelta event");
1415        });
1416    }
1417
1418    #[test]
1419    fn extension_stream_simple_provider_drop_cancels_js_stream() {
1420        let runtime = RuntimeBuilder::current_thread()
1421            .build()
1422            .expect("runtime build");
1423
1424        runtime.block_on(async move {
1425            let (dir, manager) = load_extension(STREAM_SIMPLE_CANCEL_EXTENSION, true).await;
1426            let entries = manager.extension_model_entries();
1427            assert_eq!(entries.len(), 1);
1428            let entry = entries
1429                .iter()
1430                .find(|e| e.model.provider == "cancel-provider")
1431                .expect("cancel-provider entry");
1432
1433            let provider = create_provider(entry, Some(&manager)).expect("create provider");
1434            let ctx = basic_context();
1435            let opts = basic_options();
1436            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");
1437
1438            let first = stream.next().await.expect("first event");
1439            let _ = first.expect("first event ok");
1440            drop(stream);
1441
1442            let out_path = dir.path().join("cancelled.txt");
1443            for _ in 0..200 {
1444                if out_path.exists() {
1445                    let contents = std::fs::read_to_string(&out_path).expect("read cancelled.txt");
1446                    assert_eq!(contents, "ok");
1447                    return;
1448                }
1449                sleep(wall_now(), Duration::from_millis(5)).await;
1450            }
1451
1452            assert!(
1453                out_path.exists(),
1454                "expected cancelled.txt to be created after stream drop/cancel"
1455            );
1456        });
1457    }
1458
1459    // ========================================================================
1460    // Additional tests for bd-izzp
1461    // ========================================================================
1462
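    // The fixtures below exercise the `streamSimple` chunk protocol end to end:
    // a JS async generator yields `start`, `text_start`, `text_delta` events,
    // optionally `text_end`, and finally `done`, each event carrying a
    // `partial` assistant-message snapshot. The multi-chunk fixture checks
    // delta ordering and accumulation, the error fixture checks that a JS
    // throw mid-stream surfaces on the Rust side, and the unicode fixture
    // checks that non-ASCII deltas cross the JS/Rust boundary intact.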
1463    const STREAM_SIMPLE_MULTI_CHUNK: &str = r#"
1464export default function init(pi) {
1465  pi.registerProvider("multi-chunk-provider", {
1466    baseUrl: "https://api.example.test",
1467    apiKey: "EXAMPLE_KEY",
1468    api: "custom-api",
1469    models: [
1470      { id: "multi-model", name: "Multi Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
1471    ],
1472    streamSimple: async function* (model, context, options) {
1473      const partial = {
1474        role: "assistant",
1475        content: [{ type: "text", text: "" }],
1476        api: model.api,
1477        provider: model.provider,
1478        model: model.id,
1479        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
1480        stopReason: "stop",
1481        timestamp: 0
1482      };
1483
1484      yield { type: "start", partial };
1485      yield { type: "text_start", contentIndex: 0, partial };
1486
1487      const chunks = ["Hello", ", ", "world", "!"];
1488      for (const chunk of chunks) {
1489        partial.content[0].text += chunk;
1490        yield { type: "text_delta", contentIndex: 0, delta: chunk, partial };
1491      }
1492
1493      yield { type: "text_end", contentIndex: 0, content: partial.content[0].text, partial };
1494      yield { type: "done", reason: "stop", message: partial };
1495    }
1496  });
1497}
1498"#;
1499
1500    const STREAM_SIMPLE_ERROR: &str = r#"
1501export default function init(pi) {
1502  pi.registerProvider("error-provider", {
1503    baseUrl: "https://api.example.test",
1504    apiKey: "EXAMPLE_KEY",
1505    api: "custom-api",
1506    models: [
1507      { id: "error-model", name: "Error Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
1508    ],
1509    streamSimple: async function* (model, context, options) {
1510      const partial = {
1511        role: "assistant",
1512        content: [{ type: "text", text: "" }],
1513        api: model.api,
1514        provider: model.provider,
1515        model: model.id,
1516        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
1517        stopReason: "stop",
1518        timestamp: 0
1519      };
1520
1521      yield { type: "start", partial };
1522      throw new Error("simulated JS error during streaming");
1523    }
1524  });
1525}
1526"#;
1527
1528    const STREAM_SIMPLE_UNICODE: &str = r#"
1529export default function init(pi) {
1530  pi.registerProvider("unicode-provider", {
1531    baseUrl: "https://api.example.test",
1532    apiKey: "EXAMPLE_KEY",
1533    api: "custom-api",
1534    models: [
1535      { id: "unicode-model", name: "Unicode Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
1536    ],
1537    streamSimple: async function* (model, context, options) {
1538      const partial = {
1539        role: "assistant",
1540        content: [{ type: "text", text: "" }],
1541        api: model.api,
1542        provider: model.provider,
1543        model: model.id,
1544        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
1545        stopReason: "stop",
1546        timestamp: 0
1547      };
1548
1549      yield { type: "start", partial };
1550      yield { type: "text_start", contentIndex: 0, partial };
1551      partial.content[0].text = "日本語テスト 🦀";
1552      yield { type: "text_delta", contentIndex: 0, delta: "日本語テスト 🦀", partial };
1553      yield { type: "done", reason: "stop", message: partial };
1554    }
1555  });
1556}
1557"#;
1558
1559    #[test]
1560    fn extension_stream_simple_multiple_chunks_in_order() {
1561        let runtime = RuntimeBuilder::current_thread()
1562            .build()
1563            .expect("runtime build");
1564
1565        runtime.block_on(async move {
1566            let (_dir, manager) = load_extension(STREAM_SIMPLE_MULTI_CHUNK, false).await;
1567            let entries = manager.extension_model_entries();
1568            let entry = entries
1569                .iter()
1570                .find(|e| e.model.provider == "multi-chunk-provider")
1571                .expect("multi-chunk-provider entry");
1572
1573            let provider = create_provider(entry, Some(&manager)).expect("create provider");
1574            let ctx = basic_context();
1575            let opts = basic_options();
1576            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");
1577
1578            let mut deltas = Vec::new();
1579            let mut final_text = String::new();
1580            while let Some(item) = stream.next().await {
1581                let event = item.expect("stream event");
1582                match event {
1583                    StreamEvent::TextDelta { delta, .. } => {
1584                        deltas.push(delta);
1585                    }
1586                    StreamEvent::Done { message, .. } => {
1587                        let text = match &message.content[0] {
1588                            ContentBlock::Text(text) => text,
1589                            other => unreachable!("expected text content block, got {other:?}"),
1590                        };
1591                        final_text = text.text.clone();
1592                        break;
1593                    }
1594                    _ => {}
1595                }
1596            }
1597
1598            assert_eq!(deltas, vec!["Hello", ", ", "world", "!"]);
1599            assert_eq!(final_text, "Hello, world!");
1600        });
1601    }
1602
1603    #[test]
1604    fn extension_stream_simple_js_error_propagates() {
1605        let runtime = RuntimeBuilder::current_thread()
1606            .build()
1607            .expect("runtime build");
1608
1609        runtime.block_on(async move {
1610            let (_dir, manager) = load_extension(STREAM_SIMPLE_ERROR, false).await;
1611            let entries = manager.extension_model_entries();
1612            let entry = entries
1613                .iter()
1614                .find(|e| e.model.provider == "error-provider")
1615                .expect("error-provider entry");
1616
1617            let provider = create_provider(entry, Some(&manager)).expect("create provider");
1618            let ctx = basic_context();
1619            let opts = basic_options();
1620            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");
1621
1622            let mut saw_start = false;
1623            let mut saw_error = false;
1624            while let Some(item) = stream.next().await {
1625                match item {
1626                    Ok(StreamEvent::Start { .. }) => {
1627                        saw_start = true;
1628                    }
1629                    Err(err) => {
1630                        // JS error should propagate as an extension error.
1631                        let msg = err.to_string();
1632                        assert!(
1633                            msg.contains("simulated JS error") || msg.contains("error"),
1634                            "expected JS error message, got: {msg}"
1635                        );
1636                        saw_error = true;
1637                        break;
1638                    }
1639                    Ok(StreamEvent::Error { .. }) => {
1640                        saw_error = true;
1641                        break;
1642                    }
1643                    _ => {}
1644                }
1645            }
1646
1647            assert!(saw_start, "expected a Start event before error");
1648            assert!(saw_error, "expected JS error to propagate");
1649        });
1650    }
1651
1652    #[test]
1653    fn extension_stream_simple_unicode_content() {
1654        let runtime = RuntimeBuilder::current_thread()
1655            .build()
1656            .expect("runtime build");
1657
1658        runtime.block_on(async move {
1659            let (_dir, manager) = load_extension(STREAM_SIMPLE_UNICODE, false).await;
1660            let entries = manager.extension_model_entries();
1661            let entry = entries
1662                .iter()
1663                .find(|e| e.model.provider == "unicode-provider")
1664                .expect("unicode-provider entry");
1665
1666            let provider = create_provider(entry, Some(&manager)).expect("create provider");
1667            let ctx = basic_context();
1668            let opts = basic_options();
1669            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");
1670
1671            let mut saw_unicode = false;
1672            while let Some(item) = stream.next().await {
1673                let event = item.expect("stream event");
1674                match event {
1675                    StreamEvent::TextDelta { delta, .. } => {
1676                        assert_eq!(delta, "日本語テスト 🦀");
1677                        saw_unicode = true;
1678                    }
1679                    StreamEvent::Done { .. } => break,
1680                    _ => {}
1681                }
1682            }
1683
1684            assert!(saw_unicode, "expected unicode text delta");
1685        });
1686    }
1687
1688    #[test]
1689    fn extension_stream_simple_provider_name_and_model() {
1690        let runtime = RuntimeBuilder::current_thread()
1691            .build()
1692            .expect("runtime build");
1693
1694        runtime.block_on(async move {
1695            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
1696            let entries = manager.extension_model_entries();
1697            let entry = entries
1698                .iter()
1699                .find(|e| e.model.provider == "stream-provider")
1700                .expect("stream-provider entry");
1701
1702            let provider = create_provider(entry, Some(&manager)).expect("create provider");
1703            assert_eq!(provider.name(), "stream-provider");
1704            assert_eq!(provider.model_id(), "stream-model");
1705            assert_eq!(provider.api(), "custom-api");
1706        });
1707    }
1708
1709    #[test]
1710    fn create_provider_returns_extension_provider_for_stream_simple() {
1711        let runtime = RuntimeBuilder::current_thread()
1712            .build()
1713            .expect("runtime build");
1714
1715        runtime.block_on(async move {
1716            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
1717            let entries = manager.extension_model_entries();
1718            let entry = entries
1719                .iter()
1720                .find(|e| e.model.provider == "stream-provider")
1721                .expect("stream-provider entry");
1722
1723            // With extensions, should create ExtensionStreamSimpleProvider.
1724            let provider = create_provider(entry, Some(&manager));
1725            assert!(provider.is_ok());
1726
1727            // Without extensions, should fail (unknown provider).
1728            let provider_no_ext = create_provider(entry, None);
1729            assert!(provider_no_ext.is_err());
1730        });
1731    }
1732
1733    // ========================================================================
1734    // bd-g1nx: Provider factory + URL normalization tests
1735    // ========================================================================
1736
1737    use crate::models::ModelEntry;
1738    use crate::provider::{InputType, Model, ModelCost};
1739    use std::collections::HashMap;
1740
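    /// Builds a minimal `ModelEntry` fixture for the given provider, api, model
    /// id, and base URL; cost, context window, and max-token values are
    /// arbitrary placeholders.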
1741    fn model_entry(provider: &str, api: &str, model_id: &str, base_url: &str) -> ModelEntry {
1742        ModelEntry {
1743            model: Model {
1744                id: model_id.to_string(),
1745                name: model_id.to_string(),
1746                api: api.to_string(),
1747                provider: provider.to_string(),
1748                base_url: base_url.to_string(),
1749                reasoning: false,
1750                input: vec![InputType::Text],
1751                cost: ModelCost {
1752                    input: 3.0,
1753                    output: 15.0,
1754                    cache_read: 0.3,
1755                    cache_write: 3.75,
1756                },
1757                context_window: 200_000,
1758                max_tokens: 8192,
1759                headers: HashMap::new(),
1760            },
1761            api_key: Some("sk-test-key".to_string()),
1762            headers: HashMap::new(),
1763            auth_header: true,
1764            compat: None,
1765            oauth_config: None,
1766        }
1767    }
1768
1769    #[test]
1770    fn resolve_provider_route_uses_metadata_for_alias_provider() {
1771        let entry = model_entry(
1772            "kimi",
1773            "openai-completions",
1774            "kimi-k2-instruct",
1775            "https://api.moonshot.ai/v1",
1776        );
1777        let (route, canonical_provider, effective_api) =
1778            resolve_provider_route(&entry).expect("resolve alias route");
1779        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
1780        assert_eq!(canonical_provider, "moonshotai");
1781        assert_eq!(effective_api, "openai-completions");
1782    }
1783
1784    #[test]
1785    fn resolve_provider_route_openai_unknown_api_defaults_to_native_responses() {
1786        let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
1787        let (route, canonical_provider, effective_api) =
1788            resolve_provider_route(&entry).expect("resolve openai route");
1789        assert_eq!(route, ProviderRouteKind::NativeOpenAIResponses);
1790        assert_eq!(canonical_provider, "openai");
1791        assert_eq!(effective_api, "openai");
1792    }
1793
1794    #[test]
1795    fn resolve_provider_route_cloudflare_workers_defaults_to_openai_completions() {
1796        let entry = model_entry(
1797            "cloudflare-workers-ai",
1798            "",
1799            "@cf/meta/llama-3.1-8b-instruct",
1800            "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
1801        );
1802        let (route, canonical_provider, effective_api) =
1803            resolve_provider_route(&entry).expect("resolve cloudflare workers route");
1804        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
1805        assert_eq!(canonical_provider, "cloudflare-workers-ai");
1806        assert_eq!(effective_api, "openai-completions");
1807    }
1808
1809    #[test]
1810    fn resolve_provider_route_cloudflare_gateway_defaults_to_openai_completions() {
1811        let entry = model_entry(
1812            "cloudflare-ai-gateway",
1813            "",
1814            "gpt-4o-mini",
1815            "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
1816        );
1817        let (route, canonical_provider, effective_api) =
1818            resolve_provider_route(&entry).expect("resolve cloudflare gateway route");
1819        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
1820        assert_eq!(canonical_provider, "cloudflare-ai-gateway");
1821        assert_eq!(effective_api, "openai-completions");
1822    }
1823
1824    #[test]
1825    fn resolve_provider_route_uses_native_azure_route_for_cognitive_alias() {
1826        let entry = model_entry(
1827            "azure-cognitive-services",
1828            "openai-completions",
1829            "gpt-4o-mini",
1830            "https://myresource.cognitiveservices.azure.com",
1831        );
1832        let (route, canonical_provider, effective_api) =
1833            resolve_provider_route(&entry).expect("resolve azure cognitive route");
1834        assert_eq!(route, ProviderRouteKind::NativeAzure);
1835        assert_eq!(canonical_provider, "azure-openai");
1836        assert_eq!(effective_api, "openai-completions");
1837    }
1838
1839    #[test]
1840    fn resolve_provider_route_uses_native_azure_route_for_legacy_provider_alias() {
1841        let entry = model_entry(
1842            "azure-openai-responses",
1843            "azure-openai-responses",
1844            "gpt-4o-mini",
1845            "https://myresource.openai.azure.com",
1846        );
1847        let (route, canonical_provider, effective_api) =
1848            resolve_provider_route(&entry).expect("resolve azure legacy alias route");
1849        assert_eq!(route, ProviderRouteKind::NativeAzure);
1850        assert_eq!(canonical_provider, "azure-openai");
1851        assert_eq!(effective_api, "azure-openai-responses");
1852    }
1853
1854    #[test]
1855    fn resolve_provider_route_accepts_azure_legacy_api_for_custom_provider_id() {
1856        let entry = model_entry(
1857            "my-azure",
1858            "azure-openai-responses",
1859            "gpt-4o-mini",
1860            "https://example.invalid",
1861        );
1862        let (route, canonical_provider, effective_api) =
1863            resolve_provider_route(&entry).expect("resolve azure legacy api fallback");
1864        assert_eq!(route, ProviderRouteKind::NativeAzure);
1865        assert_eq!(canonical_provider, "my-azure");
1866        assert_eq!(effective_api, "azure-openai-responses");
1867    }
1868
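    // Copilot token resolution precedence, as asserted below: an inline
    // `api_key` on the model entry wins, then the `GITHUB_COPILOT_API_KEY`
    // environment variable; with neither present, resolution fails.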
1869    #[test]
1870    fn resolve_copilot_token_prefers_inline_model_api_key() {
1871        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
1872        entry.api_key = Some("inline-copilot-token".to_string());
1873
1874        let token = resolve_copilot_token_with_env(&entry, |_| None)
1875            .expect("inline token should be accepted");
1876        assert_eq!(token, "inline-copilot-token");
1877    }
1878
1879    #[test]
1880    fn resolve_copilot_token_falls_back_to_env() {
1881        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
1882        entry.api_key = None;
1883
1884        let token = resolve_copilot_token_with_env(&entry, |name| match name {
1885            "GITHUB_COPILOT_API_KEY" => Some("env-copilot-token".to_string()),
1886            _ => None,
1887        })
1888        .expect("env token should be accepted");
1889        assert_eq!(token, "env-copilot-token");
1890    }
1891
1892    #[test]
1893    fn resolve_copilot_token_errors_when_missing_everywhere() {
1894        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
1895        entry.api_key = None;
1896
1897        let err = resolve_copilot_token_with_env(&entry, |_| None).expect_err("expected error");
1898        assert!(
1899            err.to_string().contains("GitHub Copilot requires"),
1900            "unexpected error: {err}"
1901        );
1902    }
1903
1904    #[test]
1905    fn suggest_similar_providers_finds_prefix_match() {
1906        let suggestions = suggest_similar_providers("deep");
1907        assert!(
1908            suggestions.contains(&"deepinfra".to_string())
1909                || suggestions.contains(&"deepseek".to_string()),
1910            "expected deepinfra or deepseek in suggestions: {suggestions:?}"
1911        );
1912    }
1913
1914    #[test]
1915    fn suggest_similar_providers_finds_substring_match() {
1916        let suggestions = suggest_similar_providers("flow");
1917        assert!(
1918            suggestions.contains(&"siliconflow".to_string()),
1919            "expected siliconflow in suggestions: {suggestions:?}"
1920        );
1921    }
1922
1923    #[test]
1924    fn suggest_similar_providers_returns_empty_for_gibberish() {
1925        let suggestions = suggest_similar_providers("xyzzzabc123");
1926        assert!(
1927            suggestions.is_empty(),
1928            "expected no suggestions for gibberish: {suggestions:?}"
1929        );
1930    }
1931
1932    #[test]
1933    fn suggest_similar_providers_caps_at_three() {
1934        let suggestions = suggest_similar_providers("a");
1935        assert!(
1936            suggestions.len() <= 3,
1937            "expected at most 3 suggestions: {suggestions:?}"
1938        );
1939    }
1940
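    /// Levenshtein distance sanity checks. For example, kitten → sitten →
    /// sittin → sitting takes three single-character edits, hence distance 3.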
1941    #[test]
1942    fn edit_distance_basic_cases() {
1943        assert_eq!(edit_distance(b"", b""), 0);
1944        assert_eq!(edit_distance(b"abc", b"abc"), 0);
1945        assert_eq!(edit_distance(b"abc", b"ab"), 1);
1946        assert_eq!(edit_distance(b"abc", b"axc"), 1);
1947        assert_eq!(edit_distance(b"abc", b"abcd"), 1);
1948        assert_eq!(edit_distance(b"kitten", b"sitting"), 3);
1949        assert_eq!(edit_distance(b"", b"hello"), 5);
1950    }
1951
1952    #[test]
1953    fn suggest_similar_providers_finds_typo_with_edit_distance() {
1954        // "anthropick" is edit distance 1 from "anthropic"
1955        let suggestions = suggest_similar_providers("anthropick");
1956        assert!(
1957            suggestions.contains(&"anthropic".to_string()),
1958            "expected anthropic for typo 'anthropick': {suggestions:?}"
1959        );
1960    }
1961
1962    #[test]
1963    fn suggest_similar_providers_finds_typo_missing_char() {
1964        // "openai" with missing letter: "opnai" → edit distance 1
1965        let suggestions = suggest_similar_providers("opnai");
1966        assert!(
1967            suggestions.contains(&"openai".to_string()),
1968            "expected openai for typo 'opnai': {suggestions:?}"
1969        );
1970    }
1971
1972    #[test]
1973    fn suggest_similar_providers_finds_typo_missing_char_in_google() {
1974        // "gogle" → "google" is edit distance 1 (a dropped 'o', not a transposition)
1975        let suggestions = suggest_similar_providers("gogle");
1976        assert!(
1977            suggestions.contains(&"google".to_string()),
1978            "expected google for typo 'gogle': {suggestions:?}"
1979        );
1980    }
1981
1982    #[test]
1983    fn suggest_similar_providers_no_false_positives_for_short_input() {
1984        // Very short input should not match via edit distance (threshold=0)
1985        let suggestions = suggest_similar_providers("xy");
1986        assert!(
1987            suggestions.is_empty(),
1988            "expected no suggestions for 'xy': {suggestions:?}"
1989        );
1990    }
1991
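    // Azure runtime resolution, as covered by the tests below: the resource
    // name is derived from the host, the deployment is chosen with the
    // precedence env override (`AZURE_OPENAI_DEPLOYMENT_ENV`) > deployment
    // segment embedded in the base URL > model id, and the api-version falls
    // back to 2024-12-01-preview when the base URL does not carry one.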
1992    #[test]
1993    fn resolve_azure_provider_runtime_supports_openai_host() {
1994        let entry = model_entry(
1995            "azure-openai",
1996            "openai-completions",
1997            "gpt-4o",
1998            "https://myresource.openai.azure.com",
1999        );
2000        let runtime =
2001            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
2002        assert_eq!(runtime.resource, "myresource");
2003        assert_eq!(runtime.deployment, "gpt-4o");
2004        assert_eq!(runtime.api_version, "2024-12-01-preview");
2005        assert_eq!(
2006            runtime.endpoint_url,
2007            "https://myresource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-12-01-preview"
2008        );
2009    }
2010
2011    #[test]
2012    fn resolve_azure_provider_runtime_supports_cognitive_services_host() {
2013        let entry = model_entry(
2014            "azure-cognitive-services",
2015            "openai-completions",
2016            "gpt-4o-mini",
2017            "https://myresource.cognitiveservices.azure.com/openai/deployments/custom/chat/completions?api-version=2024-10-21",
2018        );
2019        let runtime =
2020            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
2021        assert_eq!(runtime.resource, "myresource");
2022        assert_eq!(runtime.deployment, "custom");
2023        assert_eq!(runtime.api_version, "2024-10-21");
2024        assert_eq!(
2025            runtime.endpoint_url,
2026            "https://myresource.cognitiveservices.azure.com/openai/deployments/custom/chat/completions?api-version=2024-10-21"
2027        );
2028    }
2029
2030    #[test]
2031    fn resolve_azure_provider_runtime_prefers_base_url_deployment_over_model_id() {
2032        let entry = model_entry(
2033            "azure-openai",
2034            "openai-completions",
2035            "model-fallback",
2036            "https://myresource.openai.azure.com/openai/deployments/base-deploy/chat/completions?api-version=2024-10-21",
2037        );
2038        let runtime =
2039            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
2040        assert_eq!(runtime.resource, "myresource");
2041        assert_eq!(runtime.deployment, "base-deploy");
2042        assert_eq!(runtime.api_version, "2024-10-21");
2043        assert_eq!(
2044            runtime.endpoint_url,
2045            "https://myresource.openai.azure.com/openai/deployments/base-deploy/chat/completions?api-version=2024-10-21"
2046        );
2047    }
2048
2049    #[test]
2050    fn resolve_azure_provider_runtime_env_deployment_overrides_base_url_and_model_id() {
2051        let entry = model_entry(
2052            "azure-openai",
2053            "openai-completions",
2054            "model-fallback",
2055            "https://myresource.openai.azure.com/openai/deployments/base-deploy/chat/completions?api-version=2024-10-21",
2056        );
2057        let runtime = resolve_azure_provider_runtime_with_env(&entry, |name| match name {
2058            AZURE_OPENAI_DEPLOYMENT_ENV => Some("env-deploy".to_string()),
2059            _ => None,
2060        })
2061        .expect("resolve runtime");
2062        assert_eq!(runtime.resource, "myresource");
2063        assert_eq!(runtime.deployment, "env-deploy");
2064        assert_eq!(runtime.api_version, "2024-10-21");
2065        assert_eq!(
2066            runtime.endpoint_url,
2067            "https://myresource.openai.azure.com/openai/deployments/env-deploy/chat/completions?api-version=2024-10-21"
2068        );
2069    }
2070
2071    // ── create_provider: built-in provider selection ─────────────────
2072
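    // create_provider resolution order, as exercised by the two sections below:
    // a known provider name (or alias) selects the matching built-in provider
    // first; otherwise the entry's `api` field picks a generic implementation,
    // and an entry matching neither is rejected with an error.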
2073    #[test]
2074    fn create_provider_anthropic_by_name() {
2075        let entry = model_entry(
2076            "anthropic",
2077            "anthropic-messages",
2078            "claude-sonnet-4-5",
2079            "https://api.anthropic.com",
2080        );
2081        let provider = create_provider(&entry, None).expect("anthropic provider");
2082        assert_eq!(provider.name(), "anthropic");
2083        assert_eq!(provider.model_id(), "claude-sonnet-4-5");
2084        assert_eq!(provider.api(), "anthropic-messages");
2085    }
2086
2087    #[test]
2088    fn create_provider_openai_completions_by_name() {
2089        let entry = model_entry(
2090            "openai",
2091            "openai-completions",
2092            "gpt-4o",
2093            "https://api.openai.com/v1",
2094        );
2095        let provider = create_provider(&entry, None).expect("openai completions provider");
2096        assert_eq!(provider.name(), "openai");
2097        assert_eq!(provider.model_id(), "gpt-4o");
2098    }
2099
2100    #[test]
2101    fn create_provider_openai_responses_by_name() {
2102        let entry = model_entry(
2103            "openai",
2104            "openai-responses",
2105            "gpt-4o",
2106            "https://api.openai.com/v1",
2107        );
2108        let provider = create_provider(&entry, None).expect("openai responses provider");
2109        assert_eq!(provider.name(), "openai");
2110        assert_eq!(provider.model_id(), "gpt-4o");
2111    }
2112
2113    #[test]
2114    fn create_provider_openai_defaults_to_responses() {
2115        // When api is not "openai-completions", OpenAI defaults to Responses API
2116        let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
2117        let provider = create_provider(&entry, None).expect("openai default responses provider");
2118        assert_eq!(provider.name(), "openai");
2119    }
2120
2121    #[test]
2122    fn create_provider_google_by_name() {
2123        let entry = model_entry(
2124            "google",
2125            "google-generative-ai",
2126            "gemini-2.0-flash",
2127            "https://generativelanguage.googleapis.com",
2128        );
2129        let provider = create_provider(&entry, None).expect("google provider");
2130        assert_eq!(provider.name(), "google");
2131        assert_eq!(provider.model_id(), "gemini-2.0-flash");
2132    }
2133
2134    #[test]
2135    fn create_provider_cohere_by_name() {
2136        let entry = model_entry(
2137            "cohere",
2138            "cohere-chat",
2139            "command-r-plus",
2140            "https://api.cohere.com/v2",
2141        );
2142        let provider = create_provider(&entry, None).expect("cohere provider");
2143        assert_eq!(provider.name(), "cohere");
2144        assert_eq!(provider.model_id(), "command-r-plus");
2145    }
2146
2147    #[test]
2148    fn create_provider_azure_openai_by_name() {
2149        let entry = model_entry(
2150            "azure-openai",
2151            "openai-completions",
2152            "gpt-4o",
2153            "https://myresource.openai.azure.com",
2154        );
2155        let provider = create_provider(&entry, None).expect("azure provider");
2156        assert_eq!(provider.name(), "azure-openai");
2157        assert_eq!(provider.api(), "azure-openai");
2158        assert!(!provider.model_id().is_empty());
2159    }
2160
2161    #[test]
2162    fn create_provider_azure_cognitive_services_alias_by_name() {
2163        let entry = model_entry(
2164            "azure-cognitive-services",
2165            "openai-completions",
2166            "gpt-4o-mini",
2167            "https://myresource.cognitiveservices.azure.com",
2168        );
2169        let provider = create_provider(&entry, None).expect("azure cognitive provider");
2170        assert_eq!(provider.name(), "azure-cognitive-services");
2171        assert_eq!(provider.api(), "azure-openai");
2172        assert!(!provider.model_id().is_empty());
2173    }
2174
2175    #[test]
2176    fn create_provider_cloudflare_workers_ai_by_name() {
2177        let entry = model_entry(
2178            "cloudflare-workers-ai",
2179            "",
2180            "@cf/meta/llama-3.1-8b-instruct",
2181            "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
2182        );
2183        let provider = create_provider(&entry, None).expect("cloudflare workers provider");
2184        assert_eq!(provider.name(), "cloudflare-workers-ai");
2185        assert_eq!(provider.api(), "openai-completions");
2186        assert_eq!(provider.model_id(), "@cf/meta/llama-3.1-8b-instruct");
2187    }
2188
2189    #[test]
2190    fn create_provider_cloudflare_ai_gateway_by_name() {
2191        let entry = model_entry(
2192            "cloudflare-ai-gateway",
2193            "",
2194            "gpt-4o-mini",
2195            "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
2196        );
2197        let provider = create_provider(&entry, None).expect("cloudflare gateway provider");
2198        assert_eq!(provider.name(), "cloudflare-ai-gateway");
2199        assert_eq!(provider.api(), "openai-completions");
2200        assert_eq!(provider.model_id(), "gpt-4o-mini");
2201    }
2202
2203    // ── create_provider: API fallback path ──────────────────────────
2204
2205    #[test]
2206    fn create_provider_falls_back_to_api_anthropic_messages() {
2207        let entry = model_entry(
2208            "custom-anthropic",
2209            "anthropic-messages",
2210            "my-model",
2211            "https://custom.api.com",
2212        );
2213        let provider = create_provider(&entry, None).expect("fallback anthropic provider");
2214        // Anthropic fallback uses the standard anthropic provider
2215        assert_eq!(provider.model_id(), "my-model");
2216    }
2217
2218    #[test]
2219    fn create_provider_falls_back_to_api_openai_completions() {
2220        let entry = model_entry(
2221            "my-openai-compat",
2222            "openai-completions",
2223            "local-model",
2224            "http://localhost:8080/v1",
2225        );
2226        let provider = create_provider(&entry, None).expect("fallback openai completions");
2227        assert_eq!(provider.model_id(), "local-model");
2228    }
2229
2230    #[test]
2231    fn create_provider_falls_back_to_api_openai_responses() {
2232        let entry = model_entry(
2233            "my-openai-compat",
2234            "openai-responses",
2235            "local-model",
2236            "http://localhost:8080/v1",
2237        );
2238        let provider = create_provider(&entry, None).expect("fallback openai responses");
2239        assert_eq!(provider.model_id(), "local-model");
2240    }
2241
2242    #[test]
2243    fn create_provider_falls_back_to_api_cohere_chat() {
2244        let entry = model_entry(
2245            "custom-cohere",
2246            "cohere-chat",
2247            "custom-r",
2248            "https://custom-cohere.api.com/v2",
2249        );
2250        let provider = create_provider(&entry, None).expect("fallback cohere provider");
2251        assert_eq!(provider.model_id(), "custom-r");
2252    }
2253
2254    #[test]
2255    fn create_provider_falls_back_to_api_google() {
2256        let entry = model_entry(
2257            "custom-google",
2258            "google-generative-ai",
2259            "custom-gemini",
2260            "https://custom.google.com",
2261        );
2262        let provider = create_provider(&entry, None).expect("fallback google provider");
2263        assert_eq!(provider.model_id(), "custom-gemini");
2264    }
2265
2266    #[test]
2267    fn resolve_provider_route_copilot_routes_correctly() {
2268        let entry = model_entry("github-copilot", "", "gpt-4o", "");
2269        let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot route");
2270        assert_eq!(route, ProviderRouteKind::NativeCopilot);
2271        assert_eq!(canonical, "github-copilot");
2272    }
2273
2274    #[test]
2275    fn resolve_provider_route_copilot_alias_routes_correctly() {
2276        let entry = model_entry("copilot", "", "gpt-4o", "");
2277        let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot alias route");
2278        assert_eq!(route, ProviderRouteKind::NativeCopilot);
2279        assert_eq!(canonical, "github-copilot");
2280    }
2281
2282    #[test]
2283    fn create_provider_unknown_provider_and_api_returns_error() {
2284        let entry = model_entry(
2285            "totally-unknown",
2286            "unknown-api",
2287            "some-model",
2288            "https://example.com",
2289        );
2290        let Err(err) = create_provider(&entry, None) else {
1291            panic!("expected create_provider to fail for unknown provider and api");
2292        };
2293        let msg = err.to_string();
2294        assert!(
2295            msg.contains("not implemented"),
2296            "expected 'not implemented' message, got: {msg}"
2297        );
2298    }
2299
2300    // ── normalize_anthropic_base ───────────────────────────────────
2301
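    // The normalize_* helpers tested below share one contract, also covered by
    // the property tests further down: append the canonical endpoint path when
    // it is missing, leave an already-normalized URL untouched (idempotence),
    // fall back to the provider's default URL for blank input, preserve any
    // query string and fragment, and degrade to plain string concatenation for
    // opaque, non-hierarchical URLs. For example:
    //   normalize_anthropic_base("https://api.anthropic.com") -> ".../v1/messages"
    //   normalize_openai_base("https://api.openai.com/v1")    -> ".../v1/chat/completions"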
2302    #[test]
2303    fn normalize_anthropic_base_appends_v1_messages() {
2304        assert_eq!(
2305            normalize_anthropic_base("https://api.anthropic.com"),
2306            "https://api.anthropic.com/v1/messages"
2307        );
2308    }
2309
2310    #[test]
2311    fn normalize_anthropic_base_keeps_existing_v1_messages() {
2312        assert_eq!(
2313            normalize_anthropic_base("https://api.anthropic.com/v1/messages"),
2314            "https://api.anthropic.com/v1/messages"
2315        );
2316    }
2317
2318    #[test]
2319    fn normalize_anthropic_base_strips_trailing_slash() {
2320        assert_eq!(
2321            normalize_anthropic_base("https://api.anthropic.com/"),
2322            "https://api.anthropic.com/v1/messages"
2323        );
2324    }
2325
2326    #[test]
2327    fn normalize_anthropic_base_empty_uses_default() {
2328        assert_eq!(
2329            normalize_anthropic_base("   "),
2330            "https://api.anthropic.com/v1/messages"
2331        );
2332    }
2333
2334    #[test]
2335    fn normalize_anthropic_base_preserves_query_and_fragment() {
2336        assert_eq!(
2337            normalize_anthropic_base("https://api.anthropic.com/?via=proxy#frag"),
2338            "https://api.anthropic.com/v1/messages?via=proxy#frag"
2339        );
2340    }
2341
2342    #[test]
2343    fn normalize_anthropic_base_handles_opaque_url_fallback() {
2344        assert_eq!(
2345            normalize_anthropic_base("data:text/plain,hello"),
2346            "data:text/plain,hello/v1/messages"
2347        );
2348    }
2349
2350    // ── normalize_openai_base ───────────────────────────────────────
2351
2352    #[test]
2353    fn normalize_openai_base_appends_chat_completions_to_v1() {
2354        assert_eq!(
2355            normalize_openai_base("https://api.openai.com/v1"),
2356            "https://api.openai.com/v1/chat/completions"
2357        );
2358    }
2359
2360    #[test]
2361    fn normalize_openai_base_keeps_existing_chat_completions() {
2362        assert_eq!(
2363            normalize_openai_base("https://api.openai.com/v1/chat/completions"),
2364            "https://api.openai.com/v1/chat/completions"
2365        );
2366    }
2367
2368    #[test]
2369    fn normalize_openai_base_strips_trailing_slash() {
2370        assert_eq!(
2371            normalize_openai_base("https://api.openai.com/v1/"),
2372            "https://api.openai.com/v1/chat/completions"
2373        );
2374    }
2375
2376    #[test]
2377    fn normalize_openai_base_strips_responses_suffix() {
2378        assert_eq!(
2379            normalize_openai_base("https://api.openai.com/v1/responses"),
2380            "https://api.openai.com/v1/chat/completions"
2381        );
2382    }
2383
2384    #[test]
2385    fn normalize_openai_base_official_bare_url_gets_v1_chat_completions() {
2386        assert_eq!(
2387            normalize_openai_base("https://api.openai.com"),
2388            "https://api.openai.com/v1/chat/completions"
2389        );
2390    }
2391
2392    #[test]
2393    fn normalize_openai_base_official_default_port_gets_v1_chat_completions() {
2394        assert_eq!(
2395            normalize_openai_base("https://api.openai.com:443"),
2396            "https://api.openai.com/v1/chat/completions"
2397        );
2398    }
2399
2400    #[test]
2401    fn normalize_openai_base_strips_non_v1_official_responses_suffix() {
2402        assert_eq!(
2403            normalize_openai_base("https://api.openai.com/responses"),
2404            "https://api.openai.com/v1/chat/completions"
2405        );
2406    }
2407
2408    #[test]
2409    fn normalize_openai_base_custom_bare_url_gets_chat_completions() {
2410        assert_eq!(
2411            normalize_openai_base("https://my-llm-proxy.com"),
2412            "https://my-llm-proxy.com/chat/completions"
2413        );
2414    }
2415
2416    #[test]
2417    fn normalize_openai_base_preserves_query_and_fragment_on_official_origin() {
2418        assert_eq!(
2419            normalize_openai_base("https://api.openai.com:443/?via=proxy#frag"),
2420            "https://api.openai.com/v1/chat/completions?via=proxy#frag"
2421        );
2422    }
2423
2424    #[test]
2425    fn normalize_openai_base_empty_uses_default() {
2426        assert_eq!(
2427            normalize_openai_base(""),
2428            "https://api.openai.com/v1/chat/completions"
2429        );
2430    }
2431
2432    #[test]
2433    fn normalize_openai_base_handles_opaque_url_fallback() {
2434        assert_eq!(
2435            normalize_openai_base("data:text/plain,hello"),
2436            "data:text/plain,hello/chat/completions"
2437        );
2438    }
2439
2440    // ── normalize_openai_responses_base ─────────────────────────────
2441
2442    #[test]
2443    fn normalize_responses_appends_responses_to_v1() {
2444        assert_eq!(
2445            normalize_openai_responses_base("https://api.openai.com/v1"),
2446            "https://api.openai.com/v1/responses"
2447        );
2448    }
2449
2450    #[test]
2451    fn normalize_responses_keeps_existing_responses() {
2452        assert_eq!(
2453            normalize_openai_responses_base("https://api.openai.com/v1/responses"),
2454            "https://api.openai.com/v1/responses"
2455        );
2456    }
2457
2458    #[test]
2459    fn normalize_responses_strips_trailing_slash() {
2460        assert_eq!(
2461            normalize_openai_responses_base("https://api.openai.com/v1/"),
2462            "https://api.openai.com/v1/responses"
2463        );
2464    }
2465
2466    #[test]
2467    fn normalize_responses_strips_chat_completions_suffix() {
2468        assert_eq!(
2469            normalize_openai_responses_base("https://api.openai.com/v1/chat/completions"),
2470            "https://api.openai.com/v1/responses"
2471        );
2472    }
2473
2474    #[test]
2475    fn normalize_responses_official_bare_url_gets_v1_responses() {
2476        assert_eq!(
2477            normalize_openai_responses_base("https://api.openai.com"),
2478            "https://api.openai.com/v1/responses"
2479        );
2480    }
2481
2482    #[test]
2483    fn normalize_responses_official_default_port_gets_v1_responses() {
2484        assert_eq!(
2485            normalize_openai_responses_base("https://api.openai.com:443"),
2486            "https://api.openai.com/v1/responses"
2487        );
2488    }
2489
2490    #[test]
2491    fn normalize_responses_strips_non_v1_official_chat_completions_suffix() {
2492        assert_eq!(
2493            normalize_openai_responses_base("https://api.openai.com/chat/completions"),
2494            "https://api.openai.com/v1/responses"
2495        );
2496    }
2497
2498    #[test]
2499    fn normalize_responses_custom_bare_url_gets_responses() {
2500        assert_eq!(
2501            normalize_openai_responses_base("https://my-llm-proxy.com"),
2502            "https://my-llm-proxy.com/responses"
2503        );
2504    }
2505
2506    #[test]
2507    fn normalize_responses_preserves_query_and_fragment() {
2508        assert_eq!(
2509            normalize_openai_responses_base("https://my-llm-proxy.com/api?via=proxy#frag"),
2510            "https://my-llm-proxy.com/api/responses?via=proxy#frag"
2511        );
2512    }
2513
2514    #[test]
2515    fn normalize_responses_preserves_query_and_fragment_on_official_origin() {
2516        assert_eq!(
2517            normalize_openai_responses_base("https://api.openai.com:443/?via=proxy#frag"),
2518            "https://api.openai.com/v1/responses?via=proxy#frag"
2519        );
2520    }
2521
2522    #[test]
2523    fn normalize_responses_base_empty_uses_default() {
2524        assert_eq!(
2525            normalize_openai_responses_base("  "),
2526            "https://api.openai.com/v1/responses"
2527        );
2528    }
2529
2530    #[test]
2531    fn normalize_responses_base_handles_opaque_url_fallback() {
2532        assert_eq!(
2533            normalize_openai_responses_base("data:text/plain,hello"),
2534            "data:text/plain,hello/responses"
2535        );
2536    }
2537
2538    // ── normalize_openai_codex_responses_base ──────────────────────
2539
2540    #[test]
2541    fn normalize_codex_responses_base_empty_uses_default() {
2542        assert_eq!(
2543            normalize_openai_codex_responses_base(""),
2544            openai_responses::CODEX_RESPONSES_API_URL
2545        );
2546    }
2547
2548    #[test]
2549    fn normalize_codex_responses_base_keeps_existing_suffix() {
2550        assert_eq!(
2551            normalize_openai_codex_responses_base(
2552                "https://chatgpt.com/backend-api/codex/responses"
2553            ),
2554            "https://chatgpt.com/backend-api/codex/responses"
2555        );
2556    }
2557
2558    #[test]
2559    fn normalize_codex_responses_base_appends_suffix_from_backend_api() {
2560        assert_eq!(
2561            normalize_openai_codex_responses_base("https://chatgpt.com/backend-api"),
2562            "https://chatgpt.com/backend-api/codex/responses"
2563        );
2564    }
2565
2566    #[test]
2567    fn normalize_codex_responses_base_preserves_query_and_fragment() {
2568        assert_eq!(
2569            normalize_openai_codex_responses_base("https://chatgpt.com/backend-api?via=proxy#frag"),
2570            "https://chatgpt.com/backend-api/codex/responses?via=proxy#frag"
2571        );
2572    }
2573
2574    #[test]
2575    fn normalize_codex_responses_base_handles_opaque_url_fallback() {
2576        assert_eq!(
2577            normalize_openai_codex_responses_base("data:text/plain,hello"),
2578            "data:text/plain,hello/backend-api/codex/responses"
2579        );
2580    }
2581
2582    // ── normalize_cohere_base ───────────────────────────────────────
2583
2584    #[test]
2585    fn normalize_cohere_appends_chat_to_v2() {
2586        assert_eq!(
2587            normalize_cohere_base("https://api.cohere.com/v2"),
2588            "https://api.cohere.com/v2/chat"
2589        );
2590    }
2591
2592    #[test]
2593    fn normalize_cohere_keeps_existing_chat() {
2594        assert_eq!(
2595            normalize_cohere_base("https://api.cohere.com/v2/chat"),
2596            "https://api.cohere.com/v2/chat"
2597        );
2598    }
2599
2600    #[test]
2601    fn normalize_cohere_strips_trailing_slash() {
2602        assert_eq!(
2603            normalize_cohere_base("https://api.cohere.com/v2/"),
2604            "https://api.cohere.com/v2/chat"
2605        );
2606    }
2607
2608    #[test]
2609    fn normalize_cohere_official_bare_url_gets_v2_chat() {
2610        assert_eq!(
2611            normalize_cohere_base("https://api.cohere.com"),
2612            "https://api.cohere.com/v2/chat"
2613        );
2614    }
2615
2616    #[test]
2617    fn normalize_cohere_official_default_port_gets_v2_chat() {
2618        assert_eq!(
2619            normalize_cohere_base("https://api.cohere.com:443"),
2620            "https://api.cohere.com/v2/chat"
2621        );
2622    }
2623
2624    #[test]
2625    fn normalize_cohere_custom_bare_url_gets_chat() {
2626        assert_eq!(
2627            normalize_cohere_base("https://custom-cohere.example.com"),
2628            "https://custom-cohere.example.com/chat"
2629        );
2630    }
2631
2632    #[test]
2633    fn normalize_cohere_preserves_query_and_fragment() {
2634        assert_eq!(
2635            normalize_cohere_base("https://custom-cohere.example.com/v2?tenant=test#frag"),
2636            "https://custom-cohere.example.com/v2/chat?tenant=test#frag"
2637        );
2638    }
2639
2640    #[test]
2641    fn normalize_cohere_preserves_query_and_fragment_on_official_origin() {
2642        assert_eq!(
2643            normalize_cohere_base("https://api.cohere.com:443/?tenant=test#frag"),
2644            "https://api.cohere.com/v2/chat?tenant=test#frag"
2645        );
2646    }
2647
2648    #[test]
2649    fn normalize_cohere_base_empty_uses_default() {
2650        assert_eq!(normalize_cohere_base(""), "https://api.cohere.com/v2/chat");
2651    }
2652
2653    #[test]
2654    fn normalize_cohere_base_handles_opaque_url_fallback() {
2655        assert_eq!(
2656            normalize_cohere_base("data:text/plain,hello"),
2657            "data:text/plain,hello/chat"
2658        );
2659    }
2660
2661    mod proptests {
2662        use super::*;
2663        use proptest::prelude::*;
2664
2665        proptest! {
2666            #[test]
2667            fn normalize_anthropic_base_is_idempotent_and_targets_v1_messages(
2668                base in "[A-Za-z0-9:/._-]{1,96}"
2669            ) {
2670                let normalized = normalize_anthropic_base(&base);
2671                prop_assert!(normalized.ends_with("/v1/messages"));
2672                prop_assert_eq!(normalize_anthropic_base(&normalized), normalized);
2673            }
2674
2675            #[test]
2676            fn normalize_openai_base_is_idempotent_and_targets_chat_completions(
2677                base in "[A-Za-z0-9:/._-]{1,96}"
2678            ) {
2679                let normalized = normalize_openai_base(&base);
2680                prop_assert!(normalized.ends_with("/chat/completions"));
2681                prop_assert_eq!(normalize_openai_base(&normalized), normalized);
2682            }
2683
2684            #[test]
2685            fn normalize_openai_responses_base_is_idempotent_and_targets_responses(
2686                base in "[A-Za-z0-9:/._-]{1,96}"
2687            ) {
2688                let normalized = normalize_openai_responses_base(&base);
2689                prop_assert!(normalized.ends_with("/responses"));
2690                prop_assert_eq!(normalize_openai_responses_base(&normalized), normalized);
2691            }
2692
2693            #[test]
2694            fn normalize_cohere_base_is_idempotent_and_targets_chat(
2695                base in "[A-Za-z0-9:/._-]{1,96}"
2696            ) {
2697                let normalized = normalize_cohere_base(&base);
2698                prop_assert!(normalized.ends_with("/chat"));
2699                prop_assert_eq!(normalize_cohere_base(&normalized), normalized);
2700            }
2701
2702            #[test]
2703            fn normalize_openai_base_rewrites_responses_suffix(
2704                host in "[a-z0-9-]{1,32}",
2705                trailing_slashes in 0usize..4
2706            ) {
2707                let base = format!(
2708                    "https://{host}.example/v1/responses{}",
2709                    "/".repeat(trailing_slashes)
2710                );
2711                prop_assert_eq!(
2712                    normalize_openai_base(&base),
2713                    format!("https://{host}.example/v1/chat/completions")
2714                );
2715            }
2716
2717            #[test]
2718            fn normalize_openai_responses_base_rewrites_chat_completions_suffix(
2719                host in "[a-z0-9-]{1,32}",
2720                trailing_slashes in 0usize..4
2721            ) {
2722                let base = format!(
2723                    "https://{host}.example/v1/chat/completions{}",
2724                    "/".repeat(trailing_slashes)
2725                );
2726                prop_assert_eq!(
2727                    normalize_openai_responses_base(&base),
2728                    format!("https://{host}.example/v1/responses")
2729                );
2730            }
2731        }
2732    }
2733
2734    // ── bd-3uqg.2.4: Compat override propagation ─────────────────────
2735
2736    use crate::models::CompatConfig;
2737
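    /// `CompatConfig` fixture carrying two custom headers. The compat tests
    /// below only assert that each provider constructs successfully with these
    /// overrides attached.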
2738    fn compat_with_custom_headers() -> CompatConfig {
2739        let mut custom = HashMap::new();
2740        custom.insert("X-Custom-Header".to_string(), "test-value".to_string());
2741        custom.insert("X-Provider-Tag".to_string(), "override".to_string());
2742        CompatConfig {
2743            custom_headers: Some(custom),
2744            ..Default::default()
2745        }
2746    }
2747
2748    fn model_entry_with_compat(
2749        provider: &str,
2750        api: &str,
2751        model_id: &str,
2752        base_url: &str,
2753        compat: CompatConfig,
2754    ) -> ModelEntry {
2755        let mut entry = model_entry(provider, api, model_id, base_url);
2756        entry.compat = Some(compat);
2757        entry
2758    }
2759
2760    #[test]
2761    fn create_provider_anthropic_accepts_compat_config() {
2762        let entry = model_entry_with_compat(
2763            "anthropic",
2764            "anthropic-messages",
2765            "claude-sonnet-4-5",
2766            "https://api.anthropic.com",
2767            compat_with_custom_headers(),
2768        );
2769        let provider = create_provider(&entry, None).expect("anthropic with compat");
2770        assert_eq!(provider.name(), "anthropic");
2771    }
2772
2773    #[test]
2774    fn create_provider_openai_completions_accepts_compat_config() {
2775        let entry = model_entry_with_compat(
2776            "openai",
2777            "openai-completions",
2778            "gpt-4o",
2779            "https://api.openai.com/v1",
2780            CompatConfig {
2781                max_tokens_field: Some("max_completion_tokens".to_string()),
2782                system_role_name: Some("developer".to_string()),
2783                supports_tools: Some(false),
2784                ..Default::default()
2785            },
2786        );
2787        let provider = create_provider(&entry, None).expect("openai completions with compat");
2788        assert_eq!(provider.name(), "openai");
2789    }
2790
2791    #[test]
2792    fn create_provider_openai_responses_accepts_compat_config() {
2793        let entry = model_entry_with_compat(
2794            "openai",
2795            "openai-responses",
2796            "gpt-4o",
2797            "https://api.openai.com/v1",
2798            compat_with_custom_headers(),
2799        );
2800        let provider = create_provider(&entry, None).expect("openai responses with compat");
2801        assert_eq!(provider.name(), "openai");
2802    }
2803
2804    #[test]
2805    fn create_provider_cohere_accepts_compat_config() {
2806        let entry = model_entry_with_compat(
2807            "cohere",
2808            "cohere-chat",
2809            "command-r-plus",
2810            "https://api.cohere.com/v2",
2811            compat_with_custom_headers(),
2812        );
2813        let provider = create_provider(&entry, None).expect("cohere with compat");
2814        assert_eq!(provider.name(), "cohere");
2815    }
2816
2817    #[test]
2818    fn create_provider_google_accepts_compat_config() {
2819        let entry = model_entry_with_compat(
2820            "google",
2821            "google-generative-ai",
2822            "gemini-2.0-flash",
2823            "https://generativelanguage.googleapis.com",
2824            compat_with_custom_headers(),
2825        );
2826        let provider = create_provider(&entry, None).expect("google with compat");
2827        assert_eq!(provider.name(), "google");
2828    }
2829
2830    #[test]
2831    fn create_provider_fallback_api_routes_accept_compat_config() {
2832        // Custom provider using anthropic-messages API fallback
2833        let entry = model_entry_with_compat(
2834            "custom-anthropic",
2835            "anthropic-messages",
2836            "my-model",
2837            "https://custom.api.com",
2838            compat_with_custom_headers(),
2839        );
2840        let provider = create_provider(&entry, None).expect("fallback anthropic with compat");
2841        assert_eq!(provider.model_id(), "my-model");
2842
2843        // Custom provider using openai-completions API fallback
2844        let entry = model_entry_with_compat(
2845            "my-groq-clone",
2846            "openai-completions",
2847            "llama-3.1",
2848            "http://localhost:8080/v1",
2849            compat_with_custom_headers(),
2850        );
2851        let provider = create_provider(&entry, None).expect("fallback openai with compat");
2852        assert_eq!(provider.model_id(), "llama-3.1");
2853
2854        // Custom provider using cohere-chat API fallback
2855        let entry = model_entry_with_compat(
2856            "custom-cohere",
2857            "cohere-chat",
2858            "custom-r",
2859            "https://custom-cohere.api.com/v2",
2860            compat_with_custom_headers(),
2861        );
2862        let provider = create_provider(&entry, None).expect("fallback cohere with compat");
2863        assert_eq!(provider.model_id(), "custom-r");
2864
2865        // Custom provider using google-generative-ai API fallback
2866        let entry = model_entry_with_compat(
2867            "custom-google",
2868            "google-generative-ai",
2869            "custom-gemini",
2870            "https://custom.google.com",
2871            compat_with_custom_headers(),
2872        );
2873        let provider = create_provider(&entry, None).expect("fallback google with compat");
2874        assert_eq!(provider.model_id(), "custom-gemini");
2875    }
2876
2877    // ── bd-3uqg.3.1: Google Vertex AI provider routing ──────────────
2878
2879    #[test]
2880    fn resolve_provider_route_google_vertex_routes_to_native() {
2881        let entry = model_entry(
2882            "google-vertex",
2883            "google-vertex",
2884            "gemini-2.0-flash",
2885            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2886        );
2887        let (route, canonical_provider, effective_api) =
2888            resolve_provider_route(&entry).expect("resolve google-vertex route");
2889        assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2890        assert_eq!(canonical_provider, "google-vertex");
2891        assert_eq!(effective_api, "google-vertex");
2892    }
2893
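    // "vertexai" is expected to be an alias that canonicalizes to "google-vertex".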
2894    #[test]
2895    fn resolve_provider_route_vertexai_alias_routes_to_native() {
2896        let entry = model_entry(
2897            "vertexai",
2898            "google-vertex",
2899            "gemini-2.0-flash",
2900            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2901        );
2902        let (route, canonical_provider, effective_api) =
2903            resolve_provider_route(&entry).expect("resolve vertexai alias route");
2904        assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2905        assert_eq!(canonical_provider, "google-vertex");
2906        assert_eq!(effective_api, "google-vertex");
2907    }
2908
2909    #[test]
2910    fn resolve_provider_route_google_vertex_api_fallback() {
2911        // An unknown provider paired with the google-vertex API should still route correctly
2912        let entry = model_entry(
2913            "custom-vertex",
2914            "google-vertex",
2915            "gemini-2.0-flash",
2916            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2917        );
2918        let (route, _canonical_provider, effective_api) =
2919            resolve_provider_route(&entry).expect("resolve google-vertex fallback");
2920        assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2921        assert_eq!(effective_api, "google-vertex");
2922    }
2923
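    // Construction from a full Vertex model URL should preserve the provider name,
    // api, and model id reported by the resulting provider.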
2924    #[test]
2925    fn create_provider_google_vertex_from_full_url() {
2926        let entry = model_entry(
2927            "google-vertex",
2928            "google-vertex",
2929            "gemini-2.0-flash",
2930            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2931        );
2932        let provider = create_provider(&entry, None).expect("google-vertex from full URL");
2933        assert_eq!(provider.name(), "google-vertex");
2934        assert_eq!(provider.api(), "google-vertex");
2935        assert_eq!(provider.model_id(), "gemini-2.0-flash");
2936    }
2937
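    // Vertex base URLs embed a publisher segment; a URL with the `anthropic`
    // publisher (a Claude model here) should also be accepted.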
2938    #[test]
2939    fn create_provider_google_vertex_anthropic_publisher() {
2940        let entry = model_entry(
2941            "google-vertex",
2942            "google-vertex",
2943            "claude-sonnet-4-5",
2944            "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5",
2945        );
2946        let provider =
2947            create_provider(&entry, None).expect("google-vertex with anthropic publisher");
2948        assert_eq!(provider.name(), "google-vertex");
2949        assert_eq!(provider.model_id(), "claude-sonnet-4-5");
2950    }
2951
2952    #[test]
2953    fn create_provider_google_vertex_accepts_compat_config() {
2954        let entry = model_entry_with_compat(
2955            "google-vertex",
2956            "google-vertex",
2957            "gemini-2.0-flash",
2958            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2959            compat_with_custom_headers(),
2960        );
2961        let provider = create_provider(&entry, None).expect("google-vertex with compat");
2962        assert_eq!(provider.name(), "google-vertex");
2963    }
2964
2965    #[test]
2966    fn create_provider_compat_none_accepted_by_all_routes() {
2967        // Verify that entries without a compat config (None) are accepted by every route (regression guard)
2968        let routes = [
2969            (
2970                "anthropic",
2971                "anthropic-messages",
2972                "https://api.anthropic.com",
2973            ),
2974            ("openai", "openai-completions", "https://api.openai.com/v1"),
2975            ("openai", "openai-responses", "https://api.openai.com/v1"),
2976            ("cohere", "cohere-chat", "https://api.cohere.com/v2"),
2977            (
2978                "google",
2979                "google-generative-ai",
2980                "https://generativelanguage.googleapis.com",
2981            ),
2982            (
2983                "google-vertex",
2984                "google-vertex",
2985                "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/test-model",
2986            ),
2987        ];
2988        for (provider, api, base_url) in routes {
2989            let entry = model_entry(provider, api, "test-model", base_url);
2990            assert!(
2991                entry.compat.is_none(),
2992                "expected None compat for {provider}"
2993            );
2994            let result = create_provider(&entry, None);
2995            assert!(
2996                result.is_ok(),
2997                "create_provider failed for {provider} with None compat: {:?}",
2998                result.err()
2999            );
3000        }
3001    }
3002}