Skip to main content

pi/providers/
mod.rs

1//! Provider implementations.
2//!
3//! This module contains concrete implementations of the Provider trait
4//! for various LLM APIs.
5
6use crate::error::{Error, Result};
7use crate::extensions::{ExtensionManager, ExtensionRuntimeHandle};
8use crate::http::client::Client;
9use crate::model::{
10    AssistantMessage, AssistantMessageEvent, ContentBlock, StopReason, TextContent, Usage,
11};
12use crate::models::ModelEntry;
13use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
14use crate::provider_metadata::{
15    PROVIDER_METADATA, canonical_provider_id, provider_routing_defaults,
16};
17use crate::vcr::{VCR_ENV_MODE, VcrRecorder};
18use async_trait::async_trait;
19use chrono::Utc;
20use futures::stream;
21use futures::stream::Stream;
22use serde_json::Value;
23use std::env;
24use std::pin::Pin;
25use std::sync::Arc;
26use url::Url;
27
28pub mod anthropic;
29pub mod azure;
30pub mod bedrock;
31pub mod cohere;
32pub mod copilot;
33pub mod gemini;
34pub mod gitlab;
35pub mod openai;
36pub mod openai_responses;
37pub mod vertex;
38
39fn vcr_client_if_enabled() -> Result<Option<Client>> {
40    if env::var(VCR_ENV_MODE).is_err() {
41        return Ok(None);
42    }
43
44    let test_name = env::var("PI_VCR_TEST_NAME").unwrap_or_else(|_| "pi_runtime".to_string());
45    let recorder = VcrRecorder::new(&test_name)?;
46    Ok(Some(Client::new().with_vcr(recorder)))
47}
48
/// A [`Provider`] backed by a JS extension's `streamSimple` implementation
/// instead of a native HTTP client.
struct ExtensionStreamSimpleProvider {
    // Model metadata forwarded to the extension on every stream request.
    model: crate::provider::Model,
    // Handle used to start/poll/cancel streams on the extension runtime.
    runtime: ExtensionRuntimeHandle,
}
53
/// Per-stream mutable state threaded through the `stream::unfold` loop in
/// [`ExtensionStreamSimpleProvider`]'s `stream` implementation.
struct ExtensionStreamSimpleState {
    runtime: ExtensionRuntimeHandle,
    /// Identifier of the live JS-side stream; `None` once the stream has
    /// finished, errored, or been cancelled.
    stream_id: Option<String>,
    model_id: String,
    provider: String,
    api: String,
    /// Concatenation of all string chunks received so far (string-chunk mode).
    accumulated_text: String,
    /// Most recent partial/final message observed; used to build
    /// `Start`/`Done` payloads when the JS side only yields string chunks.
    last_message: Option<AssistantMessage>,
    /// Whether `StreamEvent::Start` + `TextStart` have been emitted for string-chunk mode.
    string_chunk_started: bool,
    /// Buffered events to drain before polling the next JS chunk.
    pending_events: std::collections::VecDeque<StreamEvent>,
}
67
impl Drop for ExtensionStreamSimpleState {
    /// Best-effort cancellation: if the consumer drops the stream before it
    /// completes, tell the extension runtime to tear down the JS-side stream
    /// so it does not leak. `take()` ensures cancel is requested at most once.
    fn drop(&mut self) {
        if let Some(stream_id) = self.stream_id.take() {
            self.runtime
                .provider_stream_simple_cancel_best_effort(stream_id);
        }
    }
}
76
/// Routing decision for a model entry: either a first-class ("Native")
/// provider implementation selected by canonical provider id, or a generic
/// ("Api") implementation selected by wire-API name for unknown providers.
/// See `resolve_provider_route` for the selection rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProviderRouteKind {
    NativeAnthropic,
    NativeOpenAICompletions,
    NativeOpenAIResponses,
    NativeOpenAICodexResponses,
    NativeCohere,
    NativeGoogle,
    NativeGoogleGeminiCli,
    NativeGoogleVertex,
    NativeBedrock,
    NativeAzure,
    NativeCopilot,
    NativeGitlab,
    ApiAnthropicMessages,
    ApiOpenAICompletions,
    ApiOpenAIResponses,
    ApiOpenAICodexResponses,
    ApiCohereChat,
    ApiGoogleGenerativeAi,
    ApiGoogleGeminiCli,
}
99
impl ProviderRouteKind {
    /// Stable string tag for this route; the `native:`/`api:` prefix mirrors
    /// how the route was selected (by provider id vs. by API name).
    const fn as_str(self) -> &'static str {
        match self {
            Self::NativeAnthropic => "native:anthropic",
            Self::NativeOpenAICompletions => "native:openai-completions",
            Self::NativeOpenAIResponses => "native:openai-responses",
            Self::NativeOpenAICodexResponses => "native:openai-codex-responses",
            Self::NativeCohere => "native:cohere",
            Self::NativeGoogle => "native:google",
            Self::NativeGoogleGeminiCli => "native:google-gemini-cli",
            Self::NativeGoogleVertex => "native:google-vertex",
            Self::NativeBedrock => "native:amazon-bedrock",
            Self::NativeAzure => "native:azure-openai",
            Self::NativeCopilot => "native:github-copilot",
            Self::NativeGitlab => "native:gitlab",
            Self::ApiAnthropicMessages => "api:anthropic-messages",
            Self::ApiOpenAICompletions => "api:openai-completions",
            Self::ApiOpenAIResponses => "api:openai-responses",
            Self::ApiOpenAICodexResponses => "api:openai-codex-responses",
            Self::ApiCohereChat => "api:cohere-chat",
            Self::ApiGoogleGenerativeAi => "api:google-generative-ai",
            Self::ApiGoogleGeminiCli => "api:google-gemini-cli",
        }
    }
}
125
/// Decide which provider implementation serves `entry`.
///
/// Known canonical provider ids (and their aliases) map to native
/// implementations; anything else falls back to a generic implementation
/// keyed by the entry's effective API. The effective API is the entry's own
/// `api`, or — when that is empty — the schema routing default for the
/// provider, or the empty string.
///
/// Returns `(route, canonical_provider_id, effective_api)`.
///
/// # Errors
/// Returns a provider error (with fuzzy name suggestions when available)
/// if neither the provider id nor the effective API is recognized.
fn resolve_provider_route(entry: &ModelEntry) -> Result<(ProviderRouteKind, String, String)> {
    let canonical_provider =
        canonical_provider_id(&entry.model.provider).unwrap_or(entry.model.provider.as_str());
    let schema_api = provider_routing_defaults(&entry.model.provider).map(|defaults| defaults.api);
    let effective_api = if entry.model.api.is_empty() {
        schema_api.unwrap_or_default().to_string()
    } else {
        entry.model.api.clone()
    };

    // Provider id takes precedence over API; the inner match on API only
    // runs for providers we don't recognize.
    let route = match canonical_provider {
        "anthropic" => ProviderRouteKind::NativeAnthropic,
        "openai" => {
            // OpenAI is the one native provider split by API flavor.
            if effective_api == "openai-completions" {
                ProviderRouteKind::NativeOpenAICompletions
            } else {
                ProviderRouteKind::NativeOpenAIResponses
            }
        }
        "openai-codex" => ProviderRouteKind::NativeOpenAICodexResponses,
        "cohere" => ProviderRouteKind::NativeCohere,
        "google" => ProviderRouteKind::NativeGoogle,
        "google-gemini-cli" | "google-antigravity" => ProviderRouteKind::NativeGoogleGeminiCli,
        "google-vertex" | "vertexai" => ProviderRouteKind::NativeGoogleVertex,
        "amazon-bedrock" | "bedrock" => ProviderRouteKind::NativeBedrock,
        "azure-openai" | "azure" | "azure-cognitive-services" | "azure-openai-responses" => {
            ProviderRouteKind::NativeAzure
        }
        "github-copilot" | "copilot" => ProviderRouteKind::NativeCopilot,
        "gitlab" | "gitlab-duo" => ProviderRouteKind::NativeGitlab,
        _ => match effective_api.as_str() {
            "anthropic-messages" => ProviderRouteKind::ApiAnthropicMessages,
            "openai-completions" => ProviderRouteKind::ApiOpenAICompletions,
            "openai-responses" => ProviderRouteKind::ApiOpenAIResponses,
            "openai-codex-responses" => ProviderRouteKind::ApiOpenAICodexResponses,
            "cohere-chat" => ProviderRouteKind::ApiCohereChat,
            "google-generative-ai" => ProviderRouteKind::ApiGoogleGenerativeAi,
            "google-gemini-cli" => ProviderRouteKind::ApiGoogleGeminiCli,
            // These APIs are only implemented natively, so unknown providers
            // speaking them still take the native route.
            "google-vertex" => ProviderRouteKind::NativeGoogleVertex,
            "bedrock-converse-stream" => ProviderRouteKind::NativeBedrock,
            "azure-openai-responses" => ProviderRouteKind::NativeAzure,
            _ => {
                let suggestions = suggest_similar_providers(&entry.model.provider);
                let msg = if suggestions.is_empty() {
                    format!("Provider not implemented (api: {effective_api})")
                } else {
                    format!(
                        "Provider not implemented (api: {effective_api}). Did you mean: {}?",
                        suggestions.join(", ")
                    )
                };
                return Err(Error::provider(&entry.model.provider, msg));
            }
        },
    };

    Ok((route, canonical_provider.to_string(), effective_api))
}
184
/// Levenshtein edit distance between two byte slices, computed with a
/// rolling single-row DP buffer so memory stays O(min(|a|, |b|)).
fn edit_distance(a: &[u8], b: &[u8]) -> usize {
    // Iterate the longer slice in the outer loop so the DP row is as short
    // as possible.
    let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };

    // row[j] holds the distance between the processed prefix of `long`
    // and short[..j]; initialized to the distance from the empty prefix.
    let mut row: Vec<usize> = (0..=short.len()).collect();

    for (i, &long_byte) in long.iter().enumerate() {
        // `diag` is the classic DP matrix's diagonal cell: the value row[j]
        // held before this outer iteration overwrote it.
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &short_byte) in short.iter().enumerate() {
            let substitution = diag + usize::from(long_byte != short_byte);
            let deletion = row[j + 1] + 1;
            let insertion = row[j] + 1;
            diag = row[j + 1];
            row[j + 1] = substitution.min(deletion).min(insertion);
        }
    }

    row[short.len()]
}
202
/// Maximum edit distance allowed for a fuzzy suggestion, scaled by the
/// length of the input so very short inputs don't produce false positives.
const fn max_edit_distance(input_len: usize) -> usize {
    if input_len <= 2 {
        0
    } else if input_len <= 5 {
        1
    } else if input_len <= 9 {
        2
    } else {
        3
    }
}
213
214/// Suggest provider names similar to `input` by checking prefix matching,
215/// substring containment, and Levenshtein edit distance against all
216/// canonical IDs and aliases.
217fn suggest_similar_providers(input: &str) -> Vec<String> {
218    let needle = input.to_lowercase();
219    let needle_bytes = needle.as_bytes();
220    let threshold = max_edit_distance(needle.len());
221    let mut matches: Vec<(usize, String)> = Vec::new();
222
223    for meta in PROVIDER_METADATA {
224        let names: Vec<&str> = std::iter::once(meta.canonical_id)
225            .chain(meta.aliases.iter().copied())
226            .collect();
227        let mut matched = false;
228        for name in &names {
229            let haystack = name.to_lowercase();
230            // Tier 0: exact prefix match (highest quality)
231            if haystack.starts_with(&needle) || needle.starts_with(&haystack) {
232                matches.push((0, meta.canonical_id.to_string()));
233                matched = true;
234                break;
235            }
236            // Tier 1: substring containment
237            if haystack.contains(&needle) || needle.contains(&haystack) {
238                matches.push((1, meta.canonical_id.to_string()));
239                matched = true;
240                break;
241            }
242        }
243        if matched {
244            continue;
245        }
246        // Tier 2: edit distance (typo correction)
247        if threshold > 0 {
248            let mut best_dist = usize::MAX;
249            for name in &names {
250                let haystack = name.to_lowercase();
251                let dist = edit_distance(needle_bytes, haystack.as_bytes());
252                best_dist = best_dist.min(dist);
253            }
254            if best_dist <= threshold {
255                // Encode distance in the sort key so closer matches rank higher
256                matches.push((
257                    2_usize.wrapping_add(best_dist),
258                    meta.canonical_id.to_string(),
259                ));
260            }
261        }
262    }
263
264    matches.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
265    matches.dedup_by(|a, b| a.1 == b.1);
266    matches.truncate(3);
267    matches.into_iter().map(|(_, name)| name).collect()
268}
269
// Environment-variable overrides consulted when resolving the Azure OpenAI
// runtime configuration (see `resolve_azure_provider_runtime_with_env`).
const AZURE_OPENAI_RESOURCE_ENV: &str = "AZURE_OPENAI_RESOURCE";
const AZURE_OPENAI_DEPLOYMENT_ENV: &str = "AZURE_OPENAI_DEPLOYMENT";
const AZURE_OPENAI_API_VERSION_ENV: &str = "AZURE_OPENAI_API_VERSION";
273
/// Fully-resolved Azure OpenAI connection parameters, produced by
/// `resolve_azure_provider_runtime_with_env` from the model entry plus
/// environment-variable overrides.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AzureProviderRuntime {
    // Azure resource name (the `<resource>` in `<resource>.openai.azure.com`).
    resource: String,
    // Deployment name used in the `/deployments/<name>/...` path.
    deployment: String,
    // `api-version` query parameter value.
    api_version: String,
    // Complete chat-completions endpoint assembled from the fields above.
    endpoint_url: String,
}
281
/// Trim the wrapped string; yield `None` when nothing remains (or when the
/// input was already `None`).
fn trim_non_empty(value: Option<String>) -> Option<String> {
    value.and_then(|raw| {
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_string())
        }
    })
}
287
/// Extract the Azure resource name from a recognized Azure OpenAI host
/// (`<resource>.openai.azure.com` or `<resource>.cognitiveservices.azure.com`).
/// Returns `None` for any other host, or when the prefix is blank.
fn parse_azure_resource_from_host(host: &str) -> Option<String> {
    let stripped = host
        .strip_suffix(".openai.azure.com")
        .or_else(|| host.strip_suffix(".cognitiveservices.azure.com"))?;
    let resource = stripped.trim();
    if resource.is_empty() {
        None
    } else {
        Some(resource.to_string())
    }
}
295
296fn parse_azure_base_url_details(
297    base_url: &str,
298) -> Result<(String, Option<String>, Option<String>)> {
299    let url = Url::parse(base_url)
300        .map_err(|err| Error::config(format!("Invalid Azure base_url '{base_url}': {err}")))?;
301    let host = url.host_str().map(ToString::to_string).ok_or_else(|| {
302        Error::config(format!(
303            "Azure base_url is missing host information: '{base_url}'"
304        ))
305    })?;
306
307    let mut deployment = None;
308    if let Some(segments) = url.path_segments() {
309        let mut iter = segments;
310        while let Some(segment) = iter.next() {
311            if segment == "deployments" {
312                deployment = iter
313                    .next()
314                    .map(str::trim)
315                    .filter(|value| !value.is_empty())
316                    .map(ToString::to_string);
317                break;
318            }
319        }
320    }
321
322    let api_version = url
323        .query_pairs()
324        .find(|(key, _)| key == "api-version")
325        .map(|(_, value)| value.into_owned())
326        .filter(|value| !value.trim().is_empty());
327
328    Ok((host, deployment, api_version))
329}
330
/// Resolve the Azure runtime configuration for `entry` using the real
/// process environment (thin wrapper over the injectable-env variant).
fn resolve_azure_provider_runtime(entry: &ModelEntry) -> Result<AzureProviderRuntime> {
    resolve_azure_provider_runtime_with_env(entry, |name| env::var(name).ok())
}
334
/// Resolve Azure OpenAI connection parameters for `entry`, with an
/// injectable environment lookup (for testing).
///
/// Resolution precedence:
/// - resource: `AZURE_OPENAI_RESOURCE` env var, else the base_url host
///   (when it is a recognized Azure host);
/// - deployment: `AZURE_OPENAI_DEPLOYMENT` env var, else the model id,
///   else a `/deployments/<name>` path segment in base_url;
/// - api-version: `AZURE_OPENAI_API_VERSION` env var, else the base_url
///   query, else `azure::DEFAULT_API_VERSION`.
///
/// # Errors
/// Returns a config error when base_url is empty or unparseable, or when
/// the resource or deployment cannot be determined.
fn resolve_azure_provider_runtime_with_env<F>(
    entry: &ModelEntry,
    mut env_lookup: F,
) -> Result<AzureProviderRuntime>
where
    F: FnMut(&str) -> Option<String>,
{
    let base_url = entry.model.base_url.trim();
    if base_url.is_empty() {
        return Err(Error::config(format!(
            "Missing Azure base_url for provider '{}'; expected https://<resource>.openai.azure.com or https://<resource>.cognitiveservices.azure.com",
            entry.model.provider
        )));
    }

    let (host, base_deployment, base_api_version) = parse_azure_base_url_details(base_url)?;
    let host_resource = parse_azure_resource_from_host(&host);
    let env_resource = trim_non_empty(env_lookup(AZURE_OPENAI_RESOURCE_ENV));
    // Env var overrides whatever the host implies.
    let resource = env_resource.or(host_resource).ok_or_else(|| {
        Error::config(format!(
            "Unable to resolve Azure resource for provider '{}'; set {AZURE_OPENAI_RESOURCE_ENV} or use an Azure host in base_url ('{base_url}')",
            entry.model.provider
        ))
    })?;

    let env_deployment = trim_non_empty(env_lookup(AZURE_OPENAI_DEPLOYMENT_ENV));
    let model_deployment = {
        let model_id = entry.model.id.trim();
        (!model_id.is_empty()).then(|| model_id.to_string())
    };
    // Env var > model id > path segment from base_url.
    let deployment = env_deployment
        .or(model_deployment)
        .or(base_deployment)
        .ok_or_else(|| {
            Error::config(format!(
                "Unable to resolve Azure deployment for provider '{}'; set {AZURE_OPENAI_DEPLOYMENT_ENV}, provide a non-empty model id, or include '/deployments/<name>' in base_url ('{base_url}')",
                entry.model.provider
            ))
        })?;

    // Env var > base_url query > library default.
    let api_version = trim_non_empty(env_lookup(AZURE_OPENAI_API_VERSION_ENV))
        .or(base_api_version)
        .unwrap_or_else(|| azure::DEFAULT_API_VERSION.to_string());

    // Keep the configured host when it is already a recognized Azure host;
    // otherwise synthesize the default *.openai.azure.com host from the
    // resolved resource (which may have come from the env var).
    let endpoint_host = if parse_azure_resource_from_host(&host).is_some() {
        host
    } else {
        format!("{resource}.openai.azure.com")
    };
    let endpoint_url = format!(
        "https://{endpoint_host}/openai/deployments/{deployment}/chat/completions?api-version={api_version}"
    );

    Ok(AzureProviderRuntime {
        resource,
        deployment,
        api_version,
        endpoint_url,
    })
}
395
/// Resolve the GitHub Copilot token for `entry` using the real process
/// environment (thin wrapper over the injectable-env variant).
fn resolve_copilot_token(entry: &ModelEntry) -> Result<String> {
    resolve_copilot_token_with_env(entry, |name| env::var(name).ok())
}
399
400fn resolve_copilot_token_with_env<F>(entry: &ModelEntry, mut env_lookup: F) -> Result<String>
401where
402    F: FnMut(&str) -> Option<String>,
403{
404    let inline = entry
405        .api_key
406        .as_deref()
407        .map(str::trim)
408        .filter(|value| !value.is_empty())
409        .map(ToString::to_string);
410    let from_env = || {
411        env_lookup("GITHUB_COPILOT_API_KEY")
412            .or_else(|| env_lookup("GITHUB_TOKEN"))
413            .map(|value| value.trim().to_string())
414            .filter(|value| !value.is_empty())
415    };
416
417    inline.or_else(from_env).ok_or_else(|| {
418        Error::auth(
419            "GitHub Copilot requires login credentials or GITHUB_COPILOT_API_KEY/GITHUB_TOKEN",
420        )
421    })
422}
423
impl ExtensionStreamSimpleProvider {
    /// How long to wait for the JS side to produce the next chunk
    /// (600_000 ms = 10 minutes).
    const NEXT_TIMEOUT_MS: u64 = 600_000;

    const fn new(model: crate::provider::Model, runtime: ExtensionRuntimeHandle) -> Self {
        Self { model, runtime }
    }

    /// Serialize model metadata into the camelCase JSON shape the extension's
    /// `streamSimple` entry point expects.
    fn build_js_model(model: &crate::provider::Model) -> Value {
        serde_json::json!({
            "id": &model.id,
            "name": &model.name,
            "api": &model.api,
            "provider": &model.provider,
            "baseUrl": &model.base_url,
            "reasoning": model.reasoning,
            "input": &model.input,
            "cost": &model.cost,
            "contextWindow": model.context_window,
            "maxTokens": model.max_tokens,
            "headers": &model.headers,
        })
    }

    /// Serialize the request context (system prompt, messages, tools) for the
    /// JS side. Optional pieces are omitted rather than sent as null/empty.
    fn build_js_context(context: &Context<'_>) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(system_prompt) = &context.system_prompt {
            map.insert(
                "systemPrompt".to_string(),
                Value::String(system_prompt.to_string()),
            );
        }
        map.insert(
            "messages".to_string(),
            // Fall back to an empty array on serialization failure so the JS
            // side always sees a well-formed context.
            serde_json::to_value(&context.messages).unwrap_or(Value::Array(Vec::new())),
        );
        if !context.tools.is_empty() {
            let tools = context
                .tools
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters,
                    })
                })
                .collect::<Vec<_>>();
            map.insert("tools".to_string(), Value::Array(tools));
        }
        Value::Object(map)
    }

    /// Serialize stream options for the JS side. Unset options are omitted;
    /// `cacheRetention` is always present, and `reasoning` is only sent when
    /// the thinking level is set and not `Off`.
    fn build_js_options(options: &StreamOptions) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(temp) = options.temperature {
            map.insert("temperature".to_string(), serde_json::json!(temp));
        }
        if let Some(max_tokens) = options.max_tokens {
            map.insert("maxTokens".to_string(), serde_json::json!(max_tokens));
        }
        if let Some(api_key) = &options.api_key {
            map.insert("apiKey".to_string(), Value::String(api_key.clone()));
        }
        if let Some(session_id) = &options.session_id {
            map.insert("sessionId".to_string(), Value::String(session_id.clone()));
        }
        if !options.headers.is_empty() {
            map.insert(
                "headers".to_string(),
                // Degrade to an empty object if headers fail to serialize.
                serde_json::to_value(&options.headers)
                    .unwrap_or_else(|_| Value::Object(serde_json::Map::new())),
            );
        }
        let cache_retention = match options.cache_retention {
            crate::provider::CacheRetention::None => "none",
            crate::provider::CacheRetention::Short => "short",
            crate::provider::CacheRetention::Long => "long",
        };
        map.insert(
            "cacheRetention".to_string(),
            Value::String(cache_retention.to_string()),
        );
        if let Some(level) = options.thinking_level {
            if level != crate::model::ThinkingLevel::Off {
                map.insert("reasoning".to_string(), Value::String(level.to_string()));
            }
        }
        if let Some(budgets) = &options.thinking_budgets {
            map.insert(
                "thinkingBudgets".to_string(),
                serde_json::json!({
                    "minimal": budgets.minimal,
                    "low": budgets.low,
                    "medium": budgets.medium,
                    "high": budgets.high,
                    "xhigh": budgets.xhigh,
                }),
            );
        }
        Value::Object(map)
    }

    /// Convert a typed `AssistantMessageEvent` (an object chunk from JS) into
    /// the corresponding `StreamEvent`. The per-event `partial` snapshots are
    /// dropped here; the stream state tracks the latest one separately.
    fn assistant_event_to_stream_event(event: AssistantMessageEvent) -> StreamEvent {
        match event {
            AssistantMessageEvent::Start { partial } => StreamEvent::Start {
                partial: partial.as_ref().clone(),
            },
            AssistantMessageEvent::TextStart { content_index, .. } => {
                StreamEvent::TextStart { content_index }
            }
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::TextDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::TextEnd {
                content_index,
                content,
                ..
            } => StreamEvent::TextEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ThinkingStart { content_index, .. } => {
                StreamEvent::ThinkingStart { content_index }
            }
            AssistantMessageEvent::ThinkingDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ThinkingDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index,
                content,
                ..
            } => StreamEvent::ThinkingEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ToolCallStart { content_index, .. } => {
                StreamEvent::ToolCallStart { content_index }
            }
            AssistantMessageEvent::ToolCallDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ToolCallDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ToolCallEnd {
                content_index,
                tool_call,
                ..
            } => StreamEvent::ToolCallEnd {
                content_index,
                tool_call,
            },
            AssistantMessageEvent::Done { reason, message } => StreamEvent::Done {
                reason,
                message: message.as_ref().clone(),
            },
            AssistantMessageEvent::Error { reason, error } => StreamEvent::Error {
                reason,
                error: error.as_ref().clone(),
            },
        }
    }

    /// Build a minimal assistant message whose only content block holds
    /// `text`; used for synthetic Start/Done payloads in string-chunk mode.
    fn make_partial(model_id: &str, provider: &str, api: &str, text: &str) -> AssistantMessage {
        AssistantMessage {
            model: model_id.to_string(),
            api: api.to_string(),
            provider: provider.to_string(),
            content: vec![ContentBlock::Text(TextContent {
                text: text.to_string(),
                text_signature: None,
            })],
            stop_reason: StopReason::default(),
            usage: Usage::default(),
            error_message: None,
            timestamp: Utc::now().timestamp_millis(),
        }
    }
}
615
#[allow(clippy::too_many_lines)]
#[async_trait]
impl Provider for ExtensionStreamSimpleProvider {
    // The trait's `name()` returns the provider id, which lives on the model
    // field — hence the misnamed_getters allow.
    #[allow(clippy::misnamed_getters)]
    fn name(&self) -> &str {
        &self.model.provider
    }

    fn api(&self) -> &str {
        &self.model.api
    }

    fn model_id(&self) -> &str {
        &self.model.id
    }

    /// Start a `streamSimple` stream on the extension runtime and adapt its
    /// chunks into `StreamEvent`s.
    ///
    /// The JS side may yield either plain string chunks (accumulated into a
    /// single text block, with synthetic Start/TextStart/TextEnd/Done events)
    /// or full `AssistantMessageEvent` objects (converted one-to-one).
    ///
    /// # Errors
    /// Fails if the stream cannot be started; individual stream items may
    /// carry errors for invalid events or runtime failures.
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let model = Self::build_js_model(&self.model);
        let ctx = Self::build_js_context(context);
        let opts = Self::build_js_options(options);

        let stream_id = self
            .runtime
            .provider_stream_simple_start(
                self.model.provider.clone(),
                model,
                ctx,
                opts,
                Self::NEXT_TIMEOUT_MS,
            )
            .await?;

        let state = ExtensionStreamSimpleState {
            runtime: self.runtime.clone(),
            stream_id: Some(stream_id),
            model_id: self.model.id.clone(),
            provider: self.model.provider.clone(),
            api: self.model.api.clone(),
            accumulated_text: String::new(),
            last_message: None,
            string_chunk_started: false,
            pending_events: std::collections::VecDeque::new(),
        };

        let stream = stream::unfold(state, |mut state| async move {
            // Drain any buffered events before polling JS.
            if let Some(event) = state.pending_events.pop_front() {
                return Some((Ok(event), state));
            }

            // `None` stream_id means the stream already finished: end here.
            let stream_id = state.stream_id.clone()?;
            let stream_id_for_cancel = stream_id.clone();

            match state
                .runtime
                .provider_stream_simple_next(stream_id, Self::NEXT_TIMEOUT_MS)
                .await
            {
                Ok(Some(value)) => {
                    // String chunk: accumulate text and synthesize events.
                    if let Some(chunk) = value.as_str() {
                        let chunk = chunk.to_string();
                        state.accumulated_text.push_str(&chunk);
                        // Update last_message in-place: mutate existing text
                        // content instead of rebuilding the entire
                        // AssistantMessage (avoids 3 String + Vec allocs per
                        // chunk).
                        match &mut state.last_message {
                            Some(msg) => {
                                if let Some(ContentBlock::Text(t)) = msg.content.first_mut() {
                                    t.text.clone_from(&state.accumulated_text);
                                }
                            }
                            None => {
                                state.last_message = Some(Self::make_partial(
                                    &state.model_id,
                                    &state.provider,
                                    &state.api,
                                    &state.accumulated_text,
                                ));
                            }
                        }

                        // Emit Start + TextStart before first string-chunk TextDelta.
                        if !state.string_chunk_started {
                            state.string_chunk_started = true;
                            state
                                .pending_events
                                .push_back(StreamEvent::TextStart { content_index: 0 });
                            state.pending_events.push_back(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            });
                            return Some((
                                Ok(StreamEvent::Start {
                                    partial: state.last_message.clone().unwrap_or_else(|| {
                                        Self::make_partial(
                                            &state.model_id,
                                            &state.provider,
                                            &state.api,
                                            &state.accumulated_text,
                                        )
                                    }),
                                }),
                                state,
                            ));
                        }
                        return Some((
                            Ok(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            }),
                            state,
                        ));
                    }

                    // Object chunk: must deserialize to AssistantMessageEvent;
                    // otherwise cancel the JS stream and surface an error.
                    let event: AssistantMessageEvent = match serde_json::from_value(value) {
                        Ok(event) => event,
                        Err(err) => {
                            state
                                .runtime
                                .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                            state.stream_id = None;
                            return Some((
                                Err(Error::extension(format!(
                                    "streamSimple yielded invalid event: {err}"
                                ))),
                                state,
                            ));
                        }
                    };

                    // Track the latest message snapshot from every event kind
                    // so the final Done/Error payload is always available.
                    match &event {
                        AssistantMessageEvent::Start { partial }
                        | AssistantMessageEvent::TextStart { partial, .. }
                        | AssistantMessageEvent::TextDelta { partial, .. }
                        | AssistantMessageEvent::TextEnd { partial, .. }
                        | AssistantMessageEvent::ThinkingStart { partial, .. }
                        | AssistantMessageEvent::ThinkingDelta { partial, .. }
                        | AssistantMessageEvent::ThinkingEnd { partial, .. }
                        | AssistantMessageEvent::ToolCallStart { partial, .. }
                        | AssistantMessageEvent::ToolCallDelta { partial, .. }
                        | AssistantMessageEvent::ToolCallEnd { partial, .. } => {
                            state.last_message = Some(partial.as_ref().clone());
                        }
                        AssistantMessageEvent::Done { message, .. } => {
                            state.last_message = Some(message.as_ref().clone());
                        }
                        AssistantMessageEvent::Error { error, .. } => {
                            state.last_message = Some(error.as_ref().clone());
                        }
                    }

                    let stream_event = Self::assistant_event_to_stream_event(event);
                    // Terminal events: release the JS-side stream eagerly so
                    // Drop doesn't have to.
                    if matches!(
                        stream_event,
                        StreamEvent::Done { .. } | StreamEvent::Error { .. }
                    ) {
                        state
                            .runtime
                            .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                        state.stream_id = None;
                    }
                    Some((Ok(stream_event), state))
                }
                Ok(None) => {
                    // Stream ended — emit TextEnd (if string chunks were used) then Done.
                    state.stream_id = None;
                    let message = state.last_message.clone().unwrap_or_else(|| {
                        Self::make_partial(
                            &state.model_id,
                            &state.provider,
                            &state.api,
                            &state.accumulated_text,
                        )
                    });

                    if state.string_chunk_started {
                        // Emit TextEnd before Done.
                        state.pending_events.push_back(StreamEvent::Done {
                            reason: StopReason::Stop,
                            message,
                        });
                        Some((
                            Ok(StreamEvent::TextEnd {
                                content_index: 0,
                                content: state.accumulated_text.clone(),
                            }),
                            state,
                        ))
                    } else {
                        Some((
                            Ok(StreamEvent::Done {
                                reason: StopReason::Stop,
                                message,
                            }),
                            state,
                        ))
                    }
                }
                Err(err) => {
                    // Runtime failure: tear down the JS stream and forward the error.
                    state
                        .runtime
                        .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                    state.stream_id = None;
                    Some((Err(err), state))
                }
            }
        });

        Ok(Box::pin(stream))
    }
}
832
/// Build a concrete [`Provider`] for a model registry entry.
///
/// Resolution order:
/// 1. If an extension registers a `streamSimple` handler for this model's
///    provider name, stream through the JS extension runtime instead of a
///    native HTTP provider.
/// 2. Otherwise resolve a built-in route via `resolve_provider_route` and
///    instantiate the matching native implementation, sharing one HTTP
///    client (VCR-wrapped when record/replay mode is enabled via env).
///
/// # Errors
/// Fails when the extension runtime is not configured for a `streamSimple`
/// provider, when VCR client setup fails, or when route resolution (or a
/// provider-specific runtime such as Vertex, Azure, or the Copilot token)
/// cannot be resolved.
#[allow(clippy::too_many_lines)]
pub fn create_provider(
    entry: &ModelEntry,
    extensions: Option<&ExtensionManager>,
) -> Result<Arc<dyn Provider>> {
    // Extension-registered streamSimple providers take precedence over
    // every built-in route.
    if let Some(manager) = extensions {
        if manager.provider_has_stream_simple(&entry.model.provider) {
            let runtime = manager.runtime().ok_or_else(|| {
                Error::provider(
                    &entry.model.provider,
                    "Extension runtime not configured for streamSimple provider",
                )
            })?;
            return Ok(Arc::new(ExtensionStreamSimpleProvider::new(
                entry.model.clone(),
                runtime,
            )));
        }
    }

    // One shared HTTP client; VCR-wrapped when recording/replay is active.
    let vcr_client = vcr_client_if_enabled()?;
    let client = vcr_client.unwrap_or_else(Client::new);
    let (route, canonical_provider, effective_api) = resolve_provider_route(entry)?;
    tracing::debug!(
        event = "pi.provider.factory.select",
        provider = %entry.model.provider,
        canonical_provider = %canonical_provider,
        api = %effective_api,
        base_url = %entry.model.base_url,
        route = %route.as_str(),
        "Selecting provider implementation"
    );

    match route {
        ProviderRouteKind::NativeAnthropic | ProviderRouteKind::ApiAnthropicMessages => {
            Ok(Arc::new(
                anthropic::AnthropicProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_anthropic_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAICompletions | ProviderRouteKind::ApiOpenAICompletions => {
            Ok(Arc::new(
                openai::OpenAIProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAIResponses | ProviderRouteKind::ApiOpenAIResponses => {
            Ok(Arc::new(
                openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_responses_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        // Codex reuses the Responses provider with codex mode + its own
        // api name and ChatGPT backend endpoint.
        ProviderRouteKind::NativeOpenAICodexResponses
        | ProviderRouteKind::ApiOpenAICodexResponses => Ok(Arc::new(
            openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("openai-codex-responses")
                .with_codex_mode(true)
                .with_base_url(normalize_openai_codex_responses_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeCohere | ProviderRouteKind::ApiCohereChat => Ok(Arc::new(
            cohere::CohereProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_base_url(normalize_cohere_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeGoogle | ProviderRouteKind::ApiGoogleGenerativeAi => Ok(Arc::new(
            gemini::GeminiProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("google-generative-ai")
                .with_base_url(entry.model.base_url.clone())
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        // Gemini CLI shares the Gemini provider, toggled into CLI mode.
        ProviderRouteKind::NativeGoogleGeminiCli | ProviderRouteKind::ApiGoogleGeminiCli => {
            Ok(Arc::new(
                gemini::GeminiProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_api_name("google-gemini-cli")
                    .with_google_cli_mode(true)
                    .with_base_url(entry.model.base_url.clone())
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        // Vertex needs project/location/publisher resolved from the entry.
        ProviderRouteKind::NativeGoogleVertex => {
            let runtime = vertex::resolve_vertex_provider_runtime(entry)?;
            Ok(Arc::new(
                vertex::VertexProvider::new(runtime.model)
                    .with_project(runtime.project)
                    .with_location(runtime.location)
                    .with_publisher(runtime.publisher)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeBedrock => Ok(Arc::new(
            bedrock::BedrockProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        // Azure needs resource/deployment/api-version resolved first.
        ProviderRouteKind::NativeAzure => {
            let runtime = resolve_azure_provider_runtime(entry)?;
            Ok(Arc::new(
                azure::AzureOpenAIProvider::new(runtime.resource, runtime.deployment)
                    .with_api_version(runtime.api_version)
                    .with_endpoint_url(runtime.endpoint_url)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        // Copilot exchanges a GitHub token; base_url override is optional.
        ProviderRouteKind::NativeCopilot => {
            let github_token = resolve_copilot_token(entry)?;
            let mut provider = copilot::CopilotProvider::new(&entry.model.id, github_token)
                .with_provider_name(&entry.model.provider)
                .with_compat(entry.compat.clone())
                .with_client(client);
            if !entry.model.base_url.is_empty() {
                provider = provider.with_github_api_base(&entry.model.base_url);
            }
            Ok(Arc::new(provider))
        }
        ProviderRouteKind::NativeGitlab => Ok(Arc::new(
            gitlab::GitLabProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
    }
}
978
/// Normalize a configured Anthropic base URL into the full Messages API
/// endpoint.
///
/// Empty/whitespace input falls back to the official endpoint. Trailing
/// slashes are stripped; a URL already ending in `/v1/messages` is kept
/// verbatim, otherwise `/v1/messages` is appended.
pub fn normalize_anthropic_base(base_url: &str) -> String {
    let cleaned = base_url.trim();
    if cleaned.is_empty() {
        return "https://api.anthropic.com/v1/messages".to_string();
    }
    let stem = cleaned.trim_end_matches('/');
    if stem.ends_with("/v1/messages") {
        stem.to_string()
    } else {
        format!("{stem}/v1/messages")
    }
}
990
/// Normalize a configured OpenAI-compatible base URL into the full
/// Chat Completions endpoint.
///
/// - Empty/whitespace input falls back to the official endpoint.
/// - Trailing slashes are stripped before inspection.
/// - A URL already ending in `/chat/completions` is returned unchanged.
/// - A stored Responses endpoint (`…/responses`) is rebased onto
///   `/chat/completions`; anything else gets `/chat/completions` appended.
pub fn normalize_openai_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.openai.com/v1/chat/completions".to_string();
    }
    let base_url = trimmed.trim_end_matches('/');
    if base_url.ends_with("/chat/completions") {
        return base_url.to_string();
    }
    // If the registry stored the Responses endpoint, rebase it onto chat.
    let base_url = base_url.strip_suffix("/responses").unwrap_or(base_url);
    // NOTE: the previous `ends_with("/v1")` special case returned exactly
    // this same string, so it was dead code and has been removed.
    format!("{base_url}/chat/completions")
}
1006
/// Normalize a configured OpenAI-compatible base URL into the full
/// Responses API endpoint.
///
/// - Empty/whitespace input falls back to the official endpoint.
/// - Trailing slashes are stripped before inspection.
/// - A URL already ending in `/responses` is returned unchanged.
/// - A stored Chat Completions endpoint (`…/chat/completions`) is rebased
///   onto `/responses`; anything else gets `/responses` appended.
pub fn normalize_openai_responses_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.openai.com/v1/responses".to_string();
    }
    let base_url = trimmed.trim_end_matches('/');
    if base_url.ends_with("/responses") {
        return base_url.to_string();
    }
    // Rebase a stored Chat Completions endpoint onto Responses.
    let base_url = base_url
        .strip_suffix("/chat/completions")
        .unwrap_or(base_url);
    // NOTE: the previous `ends_with("/v1")` special case returned exactly
    // this same string, so it was dead code and has been removed.
    format!("{base_url}/responses")
}
1024
1025pub fn normalize_openai_codex_responses_base(base_url: &str) -> String {
1026    let trimmed = base_url.trim();
1027    if trimmed.is_empty() {
1028        return openai_responses::CODEX_RESPONSES_API_URL.to_string();
1029    }
1030    let base = trimmed.trim_end_matches('/');
1031    if base.ends_with("/backend-api/codex/responses") {
1032        return base.to_string();
1033    }
1034    // Some registries (including legacy Pi) store the ChatGPT base as
1035    // `https://chatgpt.com/backend-api`. In that case we only want to append
1036    // `/codex/responses`, not `/backend-api/codex/responses` again.
1037    if base.ends_with("/backend-api") {
1038        return format!("{base}/codex/responses");
1039    }
1040    if base.ends_with("/responses") {
1041        return base.to_string();
1042    }
1043    format!("{base}/backend-api/codex/responses")
1044}
1045
/// Normalize a configured Cohere base URL into the full v2 Chat endpoint.
///
/// - Empty/whitespace input falls back to the official endpoint.
/// - Trailing slashes are stripped before inspection.
/// - A URL already ending in `/chat` is returned unchanged; anything else
///   gets `/chat` appended.
pub fn normalize_cohere_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.cohere.com/v2/chat".to_string();
    }
    let base_url = trimmed.trim_end_matches('/');
    if base_url.ends_with("/chat") {
        return base_url.to_string();
    }
    // NOTE: the previous `ends_with("/v2")` special case returned exactly
    // this same string, so it was dead code and has been removed.
    format!("{base_url}/chat")
}
1060
1061#[cfg(test)]
1062mod tests {
1063    use super::*;
1064    use crate::extensions::{ExtensionManager, JsExtensionLoadSpec, JsExtensionRuntimeHandle};
1065    use crate::extensions_js::PiJsRuntimeConfig;
1066    use crate::model::{ContentBlock, Message, UserContent, UserMessage};
1067    use crate::tools::ToolRegistry;
1068    use asupersync::runtime::RuntimeBuilder;
1069    use asupersync::time::{sleep, wall_now};
1070    use futures::StreamExt;
1071    use std::sync::Arc;
1072    use std::time::Duration;
1073    use tempfile::tempdir;
1074
    /// JS extension registering a `streamSimple` provider that first
    /// validates the model/context/options shapes it receives, then emits
    /// a short `start` → `text_start` → `text_delta("hi")` → `done`
    /// event sequence.
    const STREAM_SIMPLE_EXTENSION: &str = r#"
export default function init(pi) {
  pi.registerProvider("stream-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "stream-model", name: "Stream Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      if (!model || !model.baseUrl || !model.maxTokens || !model.contextWindow) {
        throw new Error("bad model shape");
      }
      if (!context || !Array.isArray(context.messages)) {
        throw new Error("bad context shape");
      }
      if (!options || !options.signal) {
        throw new Error("missing abort signal");
      }

      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };
      partial.content[0].text += "hi";
      yield { type: "text_delta", contentIndex: 0, delta: "hi", partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;
1115
    /// JS extension whose `streamSimple` generator blocks until its abort
    /// signal fires; the `finally` block then writes `cancelled.txt` via
    /// the `write` tool so the Rust side can observe cancellation.
    const STREAM_SIMPLE_CANCEL_EXTENSION: &str = r#"
export default function init(pi) {
  pi.registerProvider("cancel-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "cancel-model", name: "Cancel Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      try {
        yield { type: "start", partial };
        await new Promise((resolve) => {
          if (options && options.signal && options.signal.aborted) return resolve();
          if (options && options.signal && typeof options.signal.addEventListener === "function") {
            options.signal.addEventListener("abort", () => resolve());
          }
        });
      } finally {
        await pi.tool("write", { path: "cancelled.txt", content: "ok" });
      }
    }
  });
}
"#;
1152
    /// Write `source` to a temp dir as `ext.mjs`, boot a JS extension
    /// runtime rooted there, and load the extension into a fresh
    /// [`ExtensionManager`].
    ///
    /// `allow_write` controls whether the tool registry exposes the
    /// `write` tool (needed by extensions that persist files on cancel).
    /// Returns the temp dir (caller must keep it alive) and the manager.
    async fn load_extension(
        source: &str,
        allow_write: bool,
    ) -> (tempfile::TempDir, ExtensionManager) {
        let dir = tempdir().expect("tempdir");
        let entry_path = dir.path().join("ext.mjs");
        std::fs::write(&entry_path, source).expect("write extension");

        let manager = ExtensionManager::new();
        let tools = if allow_write {
            Arc::new(ToolRegistry::new(&["write"], dir.path(), None))
        } else {
            Arc::new(ToolRegistry::new(&[], dir.path(), None))
        };

        // The JS runtime must be started and attached to the manager
        // before any extensions can be loaded.
        let js_runtime = JsExtensionRuntimeHandle::start(
            PiJsRuntimeConfig {
                cwd: dir.path().display().to_string(),
                ..Default::default()
            },
            Arc::clone(&tools),
            manager.clone(),
        )
        .await
        .expect("start js runtime");
        manager.set_js_runtime(js_runtime);

        let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("load spec");
        manager
            .load_js_extensions(vec![spec])
            .await
            .expect("load extension");

        (dir, manager)
    }
1188
1189    fn basic_context() -> Context<'static> {
1190        Context {
1191            system_prompt: Some("system".to_string().into()),
1192            messages: vec![Message::User(UserMessage {
1193                content: UserContent::Text("hello".to_string()),
1194                timestamp: 0,
1195            })]
1196            .into(),
1197            tools: Vec::new().into(),
1198        }
1199    }
1200
1201    fn basic_options() -> StreamOptions {
1202        StreamOptions {
1203            api_key: Some("sk-test".to_string()),
1204            ..Default::default()
1205        }
1206    }
1207
    /// End-to-end: a JS `streamSimple` provider surfaces Start, TextDelta,
    /// and Done stream events, with the final message text matching the
    /// accumulated "hi" delta.
    #[test]
    fn extension_stream_simple_provider_emits_assistant_events() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            assert_eq!(entries.len(), 1);
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            assert_eq!(provider.name(), "stream-provider");

            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_start = false;
            let mut saw_text_delta = false;
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::Start { .. } => {
                        saw_start = true;
                    }
                    StreamEvent::TextDelta { delta, .. } => {
                        assert_eq!(delta, "hi");
                        saw_text_delta = true;
                    }
                    StreamEvent::Done { reason, message } => {
                        assert_eq!(reason, StopReason::Stop);
                        let text = match &message.content[0] {
                            ContentBlock::Text(text) => text,
                            other => unreachable!("expected text content block, got {other:?}"),
                        };
                        assert_eq!(text.text, "hi");
                        break;
                    }
                    _ => {}
                }
            }

            assert!(saw_start, "expected a Start event");
            assert!(saw_text_delta, "expected a TextDelta event");
        });
    }
1259
    /// Dropping the Rust stream must cancel the JS generator: the
    /// extension's `finally` block writes `cancelled.txt`, which this test
    /// polls for after the drop.
    #[test]
    fn extension_stream_simple_provider_drop_cancels_js_stream() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (dir, manager) = load_extension(STREAM_SIMPLE_CANCEL_EXTENSION, true).await;
            let entries = manager.extension_model_entries();
            assert_eq!(entries.len(), 1);
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "cancel-provider")
                .expect("cancel-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            // Consume the Start event, then drop mid-stream to trigger
            // best-effort cancellation on the JS side.
            let first = stream.next().await.expect("first event");
            let _ = first.expect("first event ok");
            drop(stream);

            // Poll (up to ~1s) for the file the JS `finally` block writes.
            let out_path = dir.path().join("cancelled.txt");
            for _ in 0..200 {
                if out_path.exists() {
                    let contents = std::fs::read_to_string(&out_path).expect("read cancelled.txt");
                    assert_eq!(contents, "ok");
                    return;
                }
                sleep(wall_now(), Duration::from_millis(5)).await;
            }

            assert!(
                out_path.exists(),
                "expected cancelled.txt to be created after stream drop/cancel"
            );
        });
    }
1300
1301    // ========================================================================
1302    // Additional tests for bd-izzp
1303    // ========================================================================
1304
    /// JS extension whose `streamSimple` generator emits four text deltas
    /// ("Hello", ", ", "world", "!") followed by `text_end` and `done`,
    /// for verifying chunk ordering and accumulation.
    const STREAM_SIMPLE_MULTI_CHUNK: &str = r#"
export default function init(pi) {
  pi.registerProvider("multi-chunk-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "multi-model", name: "Multi Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };

      const chunks = ["Hello", ", ", "world", "!"];
      for (const chunk of chunks) {
        partial.content[0].text += chunk;
        yield { type: "text_delta", contentIndex: 0, delta: chunk, partial };
      }

      yield { type: "text_end", contentIndex: 0, content: partial.content[0].text, partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;
1341
    /// JS extension whose `streamSimple` generator yields `start` and then
    /// throws, for verifying that a JS exception mid-stream propagates to
    /// the Rust consumer.
    const STREAM_SIMPLE_ERROR: &str = r#"
export default function init(pi) {
  pi.registerProvider("error-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "error-model", name: "Error Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      throw new Error("simulated JS error during streaming");
    }
  });
}
"#;
1369
    /// JS extension whose `streamSimple` generator emits a single delta of
    /// multi-byte text (Japanese + emoji), for verifying UTF-8 round-trips
    /// across the JS/Rust boundary.
    const STREAM_SIMPLE_UNICODE: &str = r#"
export default function init(pi) {
  pi.registerProvider("unicode-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "unicode-model", name: "Unicode Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };
      partial.content[0].text = "日本語テスト 🦀";
      yield { type: "text_delta", contentIndex: 0, delta: "日本語テスト 🦀", partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;
1400
    /// Text deltas must arrive in the exact order the JS generator yields
    /// them, and the final message must contain the concatenation.
    #[test]
    fn extension_stream_simple_multiple_chunks_in_order() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_MULTI_CHUNK, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "multi-chunk-provider")
                .expect("multi-chunk-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut deltas = Vec::new();
            let mut final_text = String::new();
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::TextDelta { delta, .. } => {
                        deltas.push(delta);
                    }
                    StreamEvent::Done { message, .. } => {
                        let text = match &message.content[0] {
                            ContentBlock::Text(text) => text,
                            other => unreachable!("expected text content block, got {other:?}"),
                        };
                        final_text = text.text.clone();
                        break;
                    }
                    _ => {}
                }
            }

            assert_eq!(deltas, vec!["Hello", ", ", "world", "!"]);
            assert_eq!(final_text, "Hello, world!");
        });
    }
1444
    /// A JS exception thrown mid-stream must surface to the Rust consumer,
    /// either as an `Err` item or as a `StreamEvent::Error`, after the
    /// initial Start event.
    #[test]
    fn extension_stream_simple_js_error_propagates() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_ERROR, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "error-provider")
                .expect("error-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_start = false;
            let mut saw_error = false;
            while let Some(item) = stream.next().await {
                match item {
                    Ok(StreamEvent::Start { .. }) => {
                        saw_start = true;
                    }
                    Err(err) => {
                        // JS error should propagate as an extension error.
                        let msg = err.to_string();
                        assert!(
                            msg.contains("simulated JS error") || msg.contains("error"),
                            "expected JS error message, got: {msg}"
                        );
                        saw_error = true;
                        break;
                    }
                    Ok(StreamEvent::Error { .. }) => {
                        saw_error = true;
                        break;
                    }
                    _ => {}
                }
            }

            assert!(saw_start, "expected a Start event before error");
            assert!(saw_error, "expected JS error to propagate");
        });
    }
1493
    /// Multi-byte UTF-8 text must round-trip unchanged through the JS
    /// extension stream into `TextDelta` events.
    #[test]
    fn extension_stream_simple_unicode_content() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_UNICODE, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "unicode-provider")
                .expect("unicode-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_unicode = false;
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::TextDelta { delta, .. } => {
                        assert_eq!(delta, "日本語テスト 🦀");
                        saw_unicode = true;
                    }
                    StreamEvent::Done { .. } => break,
                    _ => {}
                }
            }

            assert!(saw_unicode, "expected unicode text delta");
        });
    }
1529
    /// The extension-backed provider must report the provider name, model
    /// id, and api string declared in the JS registration.
    #[test]
    fn extension_stream_simple_provider_name_and_model() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            assert_eq!(provider.name(), "stream-provider");
            assert_eq!(provider.model_id(), "stream-model");
            assert_eq!(provider.api(), "custom-api");
        });
    }
1550
    /// `create_provider` must route a streamSimple model through the
    /// extension provider when the manager is supplied, and fail for the
    /// same entry when no extensions are available.
    #[test]
    fn create_provider_returns_extension_provider_for_stream_simple() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            // With extensions, should create ExtensionStreamSimpleProvider.
            let provider = create_provider(entry, Some(&manager));
            assert!(provider.is_ok());

            // Without extensions, should fail (unknown provider).
            let provider_no_ext = create_provider(entry, None);
            assert!(provider_no_ext.is_err());
        });
    }
1574
1575    // ========================================================================
1576    // bd-g1nx: Provider factory + URL normalization tests
1577    // ========================================================================
1578
1579    use crate::models::ModelEntry;
1580    use crate::provider::{InputType, Model, ModelCost};
1581    use std::collections::HashMap;
1582
1583    fn model_entry(provider: &str, api: &str, model_id: &str, base_url: &str) -> ModelEntry {
1584        ModelEntry {
1585            model: Model {
1586                id: model_id.to_string(),
1587                name: model_id.to_string(),
1588                api: api.to_string(),
1589                provider: provider.to_string(),
1590                base_url: base_url.to_string(),
1591                reasoning: false,
1592                input: vec![InputType::Text],
1593                cost: ModelCost {
1594                    input: 3.0,
1595                    output: 15.0,
1596                    cache_read: 0.3,
1597                    cache_write: 3.75,
1598                },
1599                context_window: 200_000,
1600                max_tokens: 8192,
1601                headers: HashMap::new(),
1602            },
1603            api_key: Some("sk-test-key".to_string()),
1604            headers: HashMap::new(),
1605            auth_header: true,
1606            compat: None,
1607            oauth_config: None,
1608        }
1609    }
1610
1611    #[test]
1612    fn resolve_provider_route_uses_metadata_for_alias_provider() {
1613        let entry = model_entry(
1614            "kimi",
1615            "openai-completions",
1616            "kimi-k2-instruct",
1617            "https://api.moonshot.ai/v1",
1618        );
1619        let (route, canonical_provider, effective_api) =
1620            resolve_provider_route(&entry).expect("resolve alias route");
1621        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
1622        assert_eq!(canonical_provider, "moonshotai");
1623        assert_eq!(effective_api, "openai-completions");
1624    }
1625
1626    #[test]
1627    fn resolve_provider_route_openai_unknown_api_defaults_to_native_responses() {
1628        let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
1629        let (route, canonical_provider, effective_api) =
1630            resolve_provider_route(&entry).expect("resolve openai route");
1631        assert_eq!(route, ProviderRouteKind::NativeOpenAIResponses);
1632        assert_eq!(canonical_provider, "openai");
1633        assert_eq!(effective_api, "openai");
1634    }
1635
1636    #[test]
1637    fn resolve_provider_route_cloudflare_workers_defaults_to_openai_completions() {
1638        let entry = model_entry(
1639            "cloudflare-workers-ai",
1640            "",
1641            "@cf/meta/llama-3.1-8b-instruct",
1642            "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
1643        );
1644        let (route, canonical_provider, effective_api) =
1645            resolve_provider_route(&entry).expect("resolve cloudflare workers route");
1646        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
1647        assert_eq!(canonical_provider, "cloudflare-workers-ai");
1648        assert_eq!(effective_api, "openai-completions");
1649    }
1650
1651    #[test]
1652    fn resolve_provider_route_cloudflare_gateway_defaults_to_openai_completions() {
1653        let entry = model_entry(
1654            "cloudflare-ai-gateway",
1655            "",
1656            "gpt-4o-mini",
1657            "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
1658        );
1659        let (route, canonical_provider, effective_api) =
1660            resolve_provider_route(&entry).expect("resolve cloudflare gateway route");
1661        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
1662        assert_eq!(canonical_provider, "cloudflare-ai-gateway");
1663        assert_eq!(effective_api, "openai-completions");
1664    }
1665
1666    #[test]
1667    fn resolve_provider_route_uses_native_azure_route_for_cognitive_alias() {
1668        let entry = model_entry(
1669            "azure-cognitive-services",
1670            "openai-completions",
1671            "gpt-4o-mini",
1672            "https://myresource.cognitiveservices.azure.com",
1673        );
1674        let (route, canonical_provider, effective_api) =
1675            resolve_provider_route(&entry).expect("resolve azure cognitive route");
1676        assert_eq!(route, ProviderRouteKind::NativeAzure);
1677        assert_eq!(canonical_provider, "azure-openai");
1678        assert_eq!(effective_api, "openai-completions");
1679    }
1680
1681    #[test]
1682    fn resolve_provider_route_uses_native_azure_route_for_legacy_provider_alias() {
1683        let entry = model_entry(
1684            "azure-openai-responses",
1685            "azure-openai-responses",
1686            "gpt-4o-mini",
1687            "https://myresource.openai.azure.com",
1688        );
1689        let (route, canonical_provider, effective_api) =
1690            resolve_provider_route(&entry).expect("resolve azure legacy alias route");
1691        assert_eq!(route, ProviderRouteKind::NativeAzure);
1692        assert_eq!(canonical_provider, "azure-openai");
1693        assert_eq!(effective_api, "azure-openai-responses");
1694    }
1695
1696    #[test]
1697    fn resolve_provider_route_accepts_azure_legacy_api_for_custom_provider_id() {
1698        let entry = model_entry(
1699            "my-azure",
1700            "azure-openai-responses",
1701            "gpt-4o-mini",
1702            "https://example.invalid",
1703        );
1704        let (route, canonical_provider, effective_api) =
1705            resolve_provider_route(&entry).expect("resolve azure legacy api fallback");
1706        assert_eq!(route, ProviderRouteKind::NativeAzure);
1707        assert_eq!(canonical_provider, "my-azure");
1708        assert_eq!(effective_api, "azure-openai-responses");
1709    }
1710
1711    #[test]
1712    fn resolve_copilot_token_prefers_inline_model_api_key() {
1713        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
1714        entry.api_key = Some("inline-copilot-token".to_string());
1715
1716        let token = resolve_copilot_token_with_env(&entry, |_| None)
1717            .expect("inline token should be accepted");
1718        assert_eq!(token, "inline-copilot-token");
1719    }
1720
1721    #[test]
1722    fn resolve_copilot_token_falls_back_to_env() {
1723        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
1724        entry.api_key = None;
1725
1726        let token = resolve_copilot_token_with_env(&entry, |name| match name {
1727            "GITHUB_COPILOT_API_KEY" => Some("env-copilot-token".to_string()),
1728            _ => None,
1729        })
1730        .expect("env token should be accepted");
1731        assert_eq!(token, "env-copilot-token");
1732    }
1733
1734    #[test]
1735    fn resolve_copilot_token_errors_when_missing_everywhere() {
1736        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
1737        entry.api_key = None;
1738
1739        let err = resolve_copilot_token_with_env(&entry, |_| None).expect_err("expected error");
1740        assert!(
1741            err.to_string().contains("GitHub Copilot requires"),
1742            "unexpected error: {err}"
1743        );
1744    }
1745
1746    #[test]
1747    fn suggest_similar_providers_finds_prefix_match() {
1748        let suggestions = suggest_similar_providers("deep");
1749        assert!(
1750            suggestions.contains(&"deepinfra".to_string())
1751                || suggestions.contains(&"deepseek".to_string()),
1752            "expected deepinfra or deepseek in suggestions: {suggestions:?}"
1753        );
1754    }
1755
1756    #[test]
1757    fn suggest_similar_providers_finds_substring_match() {
1758        let suggestions = suggest_similar_providers("flow");
1759        assert!(
1760            suggestions.contains(&"siliconflow".to_string()),
1761            "expected siliconflow in suggestions: {suggestions:?}"
1762        );
1763    }
1764
1765    #[test]
1766    fn suggest_similar_providers_returns_empty_for_gibberish() {
1767        let suggestions = suggest_similar_providers("xyzzzabc123");
1768        assert!(
1769            suggestions.is_empty(),
1770            "expected no suggestions for gibberish: {suggestions:?}"
1771        );
1772    }
1773
1774    #[test]
1775    fn suggest_similar_providers_caps_at_three() {
1776        let suggestions = suggest_similar_providers("a");
1777        assert!(
1778            suggestions.len() <= 3,
1779            "expected at most 3 suggestions: {suggestions:?}"
1780        );
1781    }
1782
    #[test]
    fn edit_distance_basic_cases() {
        // Identical inputs (including both empty) cost zero.
        assert_eq!(edit_distance(b"", b""), 0);
        assert_eq!(edit_distance(b"abc", b"abc"), 0);
        // A single deletion, substitution, or insertion each cost one.
        assert_eq!(edit_distance(b"abc", b"ab"), 1);
        assert_eq!(edit_distance(b"abc", b"axc"), 1);
        assert_eq!(edit_distance(b"abc", b"abcd"), 1);
        // Classic textbook pair, and building a word from nothing.
        assert_eq!(edit_distance(b"kitten", b"sitting"), 3);
        assert_eq!(edit_distance(b"", b"hello"), 5);
    }
1793
1794    #[test]
1795    fn suggest_similar_providers_finds_typo_with_edit_distance() {
1796        // "anthropick" is edit distance 1 from "anthropic"
1797        let suggestions = suggest_similar_providers("anthropick");
1798        assert!(
1799            suggestions.contains(&"anthropic".to_string()),
1800            "expected anthropic for typo 'anthropick': {suggestions:?}"
1801        );
1802    }
1803
1804    #[test]
1805    fn suggest_similar_providers_finds_typo_missing_char() {
1806        // "openai" with missing letter: "opnai" → edit distance 1
1807        let suggestions = suggest_similar_providers("opnai");
1808        assert!(
1809            suggestions.contains(&"openai".to_string()),
1810            "expected openai for typo 'opnai': {suggestions:?}"
1811        );
1812    }
1813
    #[test]
    fn suggest_similar_providers_finds_transposed_chars() {
        // NOTE(review): despite the test name, "gogle" → "google" is a single
        // missing character (edit distance 1), not a transposition; a true
        // transposition like "googel" costs 2 under plain Levenshtein and may
        // not match. Consider renaming the test or changing the input.
        let suggestions = suggest_similar_providers("gogle");
        assert!(
            suggestions.contains(&"google".to_string()),
            "expected google for typo 'gogle': {suggestions:?}"
        );
    }
1823
1824    #[test]
1825    fn suggest_similar_providers_no_false_positives_for_short_input() {
1826        // Very short input should not match via edit distance (threshold=0)
1827        let suggestions = suggest_similar_providers("xy");
1828        assert!(
1829            suggestions.is_empty(),
1830            "expected no suggestions for 'xy': {suggestions:?}"
1831        );
1832    }
1833
1834    #[test]
1835    fn resolve_azure_provider_runtime_supports_openai_host() {
1836        let entry = model_entry(
1837            "azure-openai",
1838            "openai-completions",
1839            "gpt-4o",
1840            "https://myresource.openai.azure.com",
1841        );
1842        let runtime =
1843            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
1844        assert_eq!(runtime.resource, "myresource");
1845        assert_eq!(runtime.deployment, "gpt-4o");
1846        assert_eq!(runtime.api_version, "2024-02-15-preview");
1847        assert_eq!(
1848            runtime.endpoint_url,
1849            "https://myresource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-15-preview"
1850        );
1851    }
1852
1853    #[test]
1854    fn resolve_azure_provider_runtime_supports_cognitive_services_host() {
1855        let entry = model_entry(
1856            "azure-cognitive-services",
1857            "openai-completions",
1858            "gpt-4o-mini",
1859            "https://myresource.cognitiveservices.azure.com/openai/deployments/custom/chat/completions?api-version=2024-10-21",
1860        );
1861        let runtime =
1862            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
1863        assert_eq!(runtime.resource, "myresource");
1864        assert_eq!(runtime.deployment, "gpt-4o-mini");
1865        assert_eq!(runtime.api_version, "2024-10-21");
1866        assert_eq!(
1867            runtime.endpoint_url,
1868            "https://myresource.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-10-21"
1869        );
1870    }
1871
1872    // ── create_provider: built-in provider selection ─────────────────
1873
1874    #[test]
1875    fn create_provider_anthropic_by_name() {
1876        let entry = model_entry(
1877            "anthropic",
1878            "anthropic-messages",
1879            "claude-sonnet-4-5",
1880            "https://api.anthropic.com",
1881        );
1882        let provider = create_provider(&entry, None).expect("anthropic provider");
1883        assert_eq!(provider.name(), "anthropic");
1884        assert_eq!(provider.model_id(), "claude-sonnet-4-5");
1885        assert_eq!(provider.api(), "anthropic-messages");
1886    }
1887
1888    #[test]
1889    fn create_provider_openai_completions_by_name() {
1890        let entry = model_entry(
1891            "openai",
1892            "openai-completions",
1893            "gpt-4o",
1894            "https://api.openai.com/v1",
1895        );
1896        let provider = create_provider(&entry, None).expect("openai completions provider");
1897        assert_eq!(provider.name(), "openai");
1898        assert_eq!(provider.model_id(), "gpt-4o");
1899    }
1900
1901    #[test]
1902    fn create_provider_openai_responses_by_name() {
1903        let entry = model_entry(
1904            "openai",
1905            "openai-responses",
1906            "gpt-4o",
1907            "https://api.openai.com/v1",
1908        );
1909        let provider = create_provider(&entry, None).expect("openai responses provider");
1910        assert_eq!(provider.name(), "openai");
1911        assert_eq!(provider.model_id(), "gpt-4o");
1912    }
1913
1914    #[test]
1915    fn create_provider_openai_defaults_to_responses() {
1916        // When api is not "openai-completions", OpenAI defaults to Responses API
1917        let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
1918        let provider = create_provider(&entry, None).expect("openai default responses provider");
1919        assert_eq!(provider.name(), "openai");
1920    }
1921
1922    #[test]
1923    fn create_provider_google_by_name() {
1924        let entry = model_entry(
1925            "google",
1926            "google-generative-ai",
1927            "gemini-2.0-flash",
1928            "https://generativelanguage.googleapis.com",
1929        );
1930        let provider = create_provider(&entry, None).expect("google provider");
1931        assert_eq!(provider.name(), "google");
1932        assert_eq!(provider.model_id(), "gemini-2.0-flash");
1933    }
1934
1935    #[test]
1936    fn create_provider_cohere_by_name() {
1937        let entry = model_entry(
1938            "cohere",
1939            "cohere-chat",
1940            "command-r-plus",
1941            "https://api.cohere.com/v2",
1942        );
1943        let provider = create_provider(&entry, None).expect("cohere provider");
1944        assert_eq!(provider.name(), "cohere");
1945        assert_eq!(provider.model_id(), "command-r-plus");
1946    }
1947
1948    #[test]
1949    fn create_provider_azure_openai_by_name() {
1950        let entry = model_entry(
1951            "azure-openai",
1952            "openai-completions",
1953            "gpt-4o",
1954            "https://myresource.openai.azure.com",
1955        );
1956        let provider = create_provider(&entry, None).expect("azure provider");
1957        assert_eq!(provider.name(), "azure");
1958        assert_eq!(provider.api(), "azure-openai");
1959        assert!(!provider.model_id().is_empty());
1960    }
1961
1962    #[test]
1963    fn create_provider_azure_cognitive_services_alias_by_name() {
1964        let entry = model_entry(
1965            "azure-cognitive-services",
1966            "openai-completions",
1967            "gpt-4o-mini",
1968            "https://myresource.cognitiveservices.azure.com",
1969        );
1970        let provider = create_provider(&entry, None).expect("azure cognitive provider");
1971        assert_eq!(provider.name(), "azure");
1972        assert_eq!(provider.api(), "azure-openai");
1973        assert!(!provider.model_id().is_empty());
1974    }
1975
1976    #[test]
1977    fn create_provider_cloudflare_workers_ai_by_name() {
1978        let entry = model_entry(
1979            "cloudflare-workers-ai",
1980            "",
1981            "@cf/meta/llama-3.1-8b-instruct",
1982            "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
1983        );
1984        let provider = create_provider(&entry, None).expect("cloudflare workers provider");
1985        assert_eq!(provider.name(), "cloudflare-workers-ai");
1986        assert_eq!(provider.api(), "openai-completions");
1987        assert_eq!(provider.model_id(), "@cf/meta/llama-3.1-8b-instruct");
1988    }
1989
1990    #[test]
1991    fn create_provider_cloudflare_ai_gateway_by_name() {
1992        let entry = model_entry(
1993            "cloudflare-ai-gateway",
1994            "",
1995            "gpt-4o-mini",
1996            "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
1997        );
1998        let provider = create_provider(&entry, None).expect("cloudflare gateway provider");
1999        assert_eq!(provider.name(), "cloudflare-ai-gateway");
2000        assert_eq!(provider.api(), "openai-completions");
2001        assert_eq!(provider.model_id(), "gpt-4o-mini");
2002    }
2003
2004    // ── create_provider: API fallback path ──────────────────────────
2005
2006    #[test]
2007    fn create_provider_falls_back_to_api_anthropic_messages() {
2008        let entry = model_entry(
2009            "custom-anthropic",
2010            "anthropic-messages",
2011            "my-model",
2012            "https://custom.api.com",
2013        );
2014        let provider = create_provider(&entry, None).expect("fallback anthropic provider");
2015        // Anthropic fallback uses the standard anthropic provider
2016        assert_eq!(provider.model_id(), "my-model");
2017    }
2018
2019    #[test]
2020    fn create_provider_falls_back_to_api_openai_completions() {
2021        let entry = model_entry(
2022            "my-openai-compat",
2023            "openai-completions",
2024            "local-model",
2025            "http://localhost:8080/v1",
2026        );
2027        let provider = create_provider(&entry, None).expect("fallback openai completions");
2028        assert_eq!(provider.model_id(), "local-model");
2029    }
2030
2031    #[test]
2032    fn create_provider_falls_back_to_api_openai_responses() {
2033        let entry = model_entry(
2034            "my-openai-compat",
2035            "openai-responses",
2036            "local-model",
2037            "http://localhost:8080/v1",
2038        );
2039        let provider = create_provider(&entry, None).expect("fallback openai responses");
2040        assert_eq!(provider.model_id(), "local-model");
2041    }
2042
2043    #[test]
2044    fn create_provider_falls_back_to_api_cohere_chat() {
2045        let entry = model_entry(
2046            "custom-cohere",
2047            "cohere-chat",
2048            "custom-r",
2049            "https://custom-cohere.api.com/v2",
2050        );
2051        let provider = create_provider(&entry, None).expect("fallback cohere provider");
2052        assert_eq!(provider.model_id(), "custom-r");
2053    }
2054
2055    #[test]
2056    fn create_provider_falls_back_to_api_google() {
2057        let entry = model_entry(
2058            "custom-google",
2059            "google-generative-ai",
2060            "custom-gemini",
2061            "https://custom.google.com",
2062        );
2063        let provider = create_provider(&entry, None).expect("fallback google provider");
2064        assert_eq!(provider.model_id(), "custom-gemini");
2065    }
2066
2067    #[test]
2068    fn resolve_provider_route_copilot_routes_correctly() {
2069        let entry = model_entry("github-copilot", "", "gpt-4o", "");
2070        let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot route");
2071        assert_eq!(route, ProviderRouteKind::NativeCopilot);
2072        assert_eq!(canonical, "github-copilot");
2073    }
2074
2075    #[test]
2076    fn resolve_provider_route_copilot_alias_routes_correctly() {
2077        let entry = model_entry("copilot", "", "gpt-4o", "");
2078        let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot alias route");
2079        assert_eq!(route, ProviderRouteKind::NativeCopilot);
2080        assert_eq!(canonical, "github-copilot");
2081    }
2082
2083    #[test]
2084    fn create_provider_unknown_provider_and_api_returns_error() {
2085        let entry = model_entry(
2086            "totally-unknown",
2087            "unknown-api",
2088            "some-model",
2089            "https://example.com",
2090        );
2091        let Err(err) = create_provider(&entry, None) else {
2092            panic!("unknown should fail");
2093        };
2094        let msg = err.to_string();
2095        assert!(
2096            msg.contains("not implemented"),
2097            "expected 'not implemented' message, got: {msg}"
2098        );
2099    }
2100
2101    // ── normalize_anthropic_base ───────────────────────────────────
2102
2103    #[test]
2104    fn normalize_anthropic_base_appends_v1_messages() {
2105        assert_eq!(
2106            normalize_anthropic_base("https://api.anthropic.com"),
2107            "https://api.anthropic.com/v1/messages"
2108        );
2109    }
2110
2111    #[test]
2112    fn normalize_anthropic_base_keeps_existing_v1_messages() {
2113        assert_eq!(
2114            normalize_anthropic_base("https://api.anthropic.com/v1/messages"),
2115            "https://api.anthropic.com/v1/messages"
2116        );
2117    }
2118
2119    #[test]
2120    fn normalize_anthropic_base_strips_trailing_slash() {
2121        assert_eq!(
2122            normalize_anthropic_base("https://api.anthropic.com/"),
2123            "https://api.anthropic.com/v1/messages"
2124        );
2125    }
2126
2127    #[test]
2128    fn normalize_anthropic_base_empty_uses_default() {
2129        assert_eq!(
2130            normalize_anthropic_base("   "),
2131            "https://api.anthropic.com/v1/messages"
2132        );
2133    }
2134
2135    // ── normalize_openai_base ───────────────────────────────────────
2136
2137    #[test]
2138    fn normalize_openai_base_appends_chat_completions_to_v1() {
2139        assert_eq!(
2140            normalize_openai_base("https://api.openai.com/v1"),
2141            "https://api.openai.com/v1/chat/completions"
2142        );
2143    }
2144
2145    #[test]
2146    fn normalize_openai_base_keeps_existing_chat_completions() {
2147        assert_eq!(
2148            normalize_openai_base("https://api.openai.com/v1/chat/completions"),
2149            "https://api.openai.com/v1/chat/completions"
2150        );
2151    }
2152
2153    #[test]
2154    fn normalize_openai_base_strips_trailing_slash() {
2155        assert_eq!(
2156            normalize_openai_base("https://api.openai.com/v1/"),
2157            "https://api.openai.com/v1/chat/completions"
2158        );
2159    }
2160
2161    #[test]
2162    fn normalize_openai_base_strips_responses_suffix() {
2163        assert_eq!(
2164            normalize_openai_base("https://api.openai.com/v1/responses"),
2165            "https://api.openai.com/v1/chat/completions"
2166        );
2167    }
2168
2169    #[test]
2170    fn normalize_openai_base_bare_url_gets_chat_completions() {
2171        assert_eq!(
2172            normalize_openai_base("https://my-llm-proxy.com"),
2173            "https://my-llm-proxy.com/chat/completions"
2174        );
2175    }
2176
2177    #[test]
2178    fn normalize_openai_base_empty_uses_default() {
2179        assert_eq!(
2180            normalize_openai_base(""),
2181            "https://api.openai.com/v1/chat/completions"
2182        );
2183    }
2184
2185    // ── normalize_openai_responses_base ─────────────────────────────
2186
2187    #[test]
2188    fn normalize_responses_appends_responses_to_v1() {
2189        assert_eq!(
2190            normalize_openai_responses_base("https://api.openai.com/v1"),
2191            "https://api.openai.com/v1/responses"
2192        );
2193    }
2194
2195    #[test]
2196    fn normalize_responses_keeps_existing_responses() {
2197        assert_eq!(
2198            normalize_openai_responses_base("https://api.openai.com/v1/responses"),
2199            "https://api.openai.com/v1/responses"
2200        );
2201    }
2202
2203    #[test]
2204    fn normalize_responses_strips_trailing_slash() {
2205        assert_eq!(
2206            normalize_openai_responses_base("https://api.openai.com/v1/"),
2207            "https://api.openai.com/v1/responses"
2208        );
2209    }
2210
2211    #[test]
2212    fn normalize_responses_strips_chat_completions_suffix() {
2213        assert_eq!(
2214            normalize_openai_responses_base("https://api.openai.com/v1/chat/completions"),
2215            "https://api.openai.com/v1/responses"
2216        );
2217    }
2218
2219    #[test]
2220    fn normalize_responses_bare_url_gets_responses() {
2221        assert_eq!(
2222            normalize_openai_responses_base("https://my-llm-proxy.com"),
2223            "https://my-llm-proxy.com/responses"
2224        );
2225    }
2226
2227    #[test]
2228    fn normalize_responses_base_empty_uses_default() {
2229        assert_eq!(
2230            normalize_openai_responses_base("  "),
2231            "https://api.openai.com/v1/responses"
2232        );
2233    }
2234
2235    // ── normalize_cohere_base ───────────────────────────────────────
2236
2237    #[test]
2238    fn normalize_cohere_appends_chat_to_v2() {
2239        assert_eq!(
2240            normalize_cohere_base("https://api.cohere.com/v2"),
2241            "https://api.cohere.com/v2/chat"
2242        );
2243    }
2244
2245    #[test]
2246    fn normalize_cohere_keeps_existing_chat() {
2247        assert_eq!(
2248            normalize_cohere_base("https://api.cohere.com/v2/chat"),
2249            "https://api.cohere.com/v2/chat"
2250        );
2251    }
2252
2253    #[test]
2254    fn normalize_cohere_strips_trailing_slash() {
2255        assert_eq!(
2256            normalize_cohere_base("https://api.cohere.com/v2/"),
2257            "https://api.cohere.com/v2/chat"
2258        );
2259    }
2260
2261    #[test]
2262    fn normalize_cohere_bare_url_gets_chat() {
2263        assert_eq!(
2264            normalize_cohere_base("https://custom-cohere.example.com"),
2265            "https://custom-cohere.example.com/chat"
2266        );
2267    }
2268
2269    #[test]
2270    fn normalize_cohere_base_empty_uses_default() {
2271        assert_eq!(normalize_cohere_base(""), "https://api.cohere.com/v2/chat");
2272    }
2273
    // Property tests for the URL normalizers: every normalizer must be
    // idempotent and always land on its canonical endpoint path, regardless
    // of (URL-ish) input.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Arbitrary URL-like strings: output ends with /v1/messages and
            // re-normalizing is a no-op.
            #[test]
            fn normalize_anthropic_base_is_idempotent_and_targets_v1_messages(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_anthropic_base(&base);
                prop_assert!(normalized.ends_with("/v1/messages"));
                prop_assert_eq!(normalize_anthropic_base(&normalized), normalized);
            }

            // Same contract for the OpenAI chat-completions normalizer.
            #[test]
            fn normalize_openai_base_is_idempotent_and_targets_chat_completions(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_openai_base(&base);
                prop_assert!(normalized.ends_with("/chat/completions"));
                prop_assert_eq!(normalize_openai_base(&normalized), normalized);
            }

            // Same contract for the OpenAI Responses normalizer.
            #[test]
            fn normalize_openai_responses_base_is_idempotent_and_targets_responses(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_openai_responses_base(&base);
                prop_assert!(normalized.ends_with("/responses"));
                prop_assert_eq!(normalize_openai_responses_base(&normalized), normalized);
            }

            // Same contract for the Cohere chat normalizer.
            #[test]
            fn normalize_cohere_base_is_idempotent_and_targets_chat(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_cohere_base(&base);
                prop_assert!(normalized.ends_with("/chat"));
                prop_assert_eq!(normalize_cohere_base(&normalized), normalized);
            }

            // A /v1/responses URL (with any number of trailing slashes) is
            // rewritten to the chat/completions endpoint.
            #[test]
            fn normalize_openai_base_rewrites_responses_suffix(
                host in "[a-z0-9-]{1,32}",
                trailing_slashes in 0usize..4
            ) {
                let base = format!(
                    "https://{host}.example/v1/responses{}",
                    "/".repeat(trailing_slashes)
                );
                prop_assert_eq!(
                    normalize_openai_base(&base),
                    format!("https://{host}.example/v1/chat/completions")
                );
            }

            // And the inverse: chat/completions URLs are rewritten to /responses.
            #[test]
            fn normalize_openai_responses_base_rewrites_chat_completions_suffix(
                host in "[a-z0-9-]{1,32}",
                trailing_slashes in 0usize..4
            ) {
                let base = format!(
                    "https://{host}.example/v1/chat/completions{}",
                    "/".repeat(trailing_slashes)
                );
                prop_assert_eq!(
                    normalize_openai_responses_base(&base),
                    format!("https://{host}.example/v1/responses")
                );
            }
        }
    }
2346
2347    // ── bd-3uqg.2.4: Compat override propagation ─────────────────────
2348
2349    use crate::models::CompatConfig;
2350
2351    fn compat_with_custom_headers() -> CompatConfig {
2352        let mut custom = HashMap::new();
2353        custom.insert("X-Custom-Header".to_string(), "test-value".to_string());
2354        custom.insert("X-Provider-Tag".to_string(), "override".to_string());
2355        CompatConfig {
2356            custom_headers: Some(custom),
2357            ..Default::default()
2358        }
2359    }
2360
2361    fn model_entry_with_compat(
2362        provider: &str,
2363        api: &str,
2364        model_id: &str,
2365        base_url: &str,
2366        compat: CompatConfig,
2367    ) -> ModelEntry {
2368        let mut entry = model_entry(provider, api, model_id, base_url);
2369        entry.compat = Some(compat);
2370        entry
2371    }
2372
2373    #[test]
2374    fn create_provider_anthropic_accepts_compat_config() {
2375        let entry = model_entry_with_compat(
2376            "anthropic",
2377            "anthropic-messages",
2378            "claude-sonnet-4-5",
2379            "https://api.anthropic.com",
2380            compat_with_custom_headers(),
2381        );
2382        let provider = create_provider(&entry, None).expect("anthropic with compat");
2383        assert_eq!(provider.name(), "anthropic");
2384    }
2385
2386    #[test]
2387    fn create_provider_openai_completions_accepts_compat_config() {
2388        let entry = model_entry_with_compat(
2389            "openai",
2390            "openai-completions",
2391            "gpt-4o",
2392            "https://api.openai.com/v1",
2393            CompatConfig {
2394                max_tokens_field: Some("max_completion_tokens".to_string()),
2395                system_role_name: Some("developer".to_string()),
2396                supports_tools: Some(false),
2397                ..Default::default()
2398            },
2399        );
2400        let provider = create_provider(&entry, None).expect("openai completions with compat");
2401        assert_eq!(provider.name(), "openai");
2402    }
2403
2404    #[test]
2405    fn create_provider_openai_responses_accepts_compat_config() {
2406        let entry = model_entry_with_compat(
2407            "openai",
2408            "openai-responses",
2409            "gpt-4o",
2410            "https://api.openai.com/v1",
2411            compat_with_custom_headers(),
2412        );
2413        let provider = create_provider(&entry, None).expect("openai responses with compat");
2414        assert_eq!(provider.name(), "openai");
2415    }
2416
2417    #[test]
2418    fn create_provider_cohere_accepts_compat_config() {
2419        let entry = model_entry_with_compat(
2420            "cohere",
2421            "cohere-chat",
2422            "command-r-plus",
2423            "https://api.cohere.com/v2",
2424            compat_with_custom_headers(),
2425        );
2426        let provider = create_provider(&entry, None).expect("cohere with compat");
2427        assert_eq!(provider.name(), "cohere");
2428    }
2429
2430    #[test]
2431    fn create_provider_google_accepts_compat_config() {
2432        let entry = model_entry_with_compat(
2433            "google",
2434            "google-generative-ai",
2435            "gemini-2.0-flash",
2436            "https://generativelanguage.googleapis.com",
2437            compat_with_custom_headers(),
2438        );
2439        let provider = create_provider(&entry, None).expect("google with compat");
2440        assert_eq!(provider.name(), "google");
2441    }
2442
2443    #[test]
2444    fn create_provider_fallback_api_routes_accept_compat_config() {
2445        // Custom provider using anthropic-messages API fallback
2446        let entry = model_entry_with_compat(
2447            "custom-anthropic",
2448            "anthropic-messages",
2449            "my-model",
2450            "https://custom.api.com",
2451            compat_with_custom_headers(),
2452        );
2453        let provider = create_provider(&entry, None).expect("fallback anthropic with compat");
2454        assert_eq!(provider.model_id(), "my-model");
2455
2456        // Custom provider using openai-completions API fallback
2457        let entry = model_entry_with_compat(
2458            "my-groq-clone",
2459            "openai-completions",
2460            "llama-3.1",
2461            "http://localhost:8080/v1",
2462            compat_with_custom_headers(),
2463        );
2464        let provider = create_provider(&entry, None).expect("fallback openai with compat");
2465        assert_eq!(provider.model_id(), "llama-3.1");
2466
2467        // Custom provider using cohere-chat API fallback
2468        let entry = model_entry_with_compat(
2469            "custom-cohere",
2470            "cohere-chat",
2471            "custom-r",
2472            "https://custom-cohere.api.com/v2",
2473            compat_with_custom_headers(),
2474        );
2475        let provider = create_provider(&entry, None).expect("fallback cohere with compat");
2476        assert_eq!(provider.model_id(), "custom-r");
2477
2478        // Custom provider using google-generative-ai API fallback
2479        let entry = model_entry_with_compat(
2480            "custom-google",
2481            "google-generative-ai",
2482            "custom-gemini",
2483            "https://custom.google.com",
2484            compat_with_custom_headers(),
2485        );
2486        let provider = create_provider(&entry, None).expect("fallback google with compat");
2487        assert_eq!(provider.model_id(), "custom-gemini");
2488    }
2489
2490    // ── bd-3uqg.3.1: Google Vertex AI provider routing ──────────────
2491
2492    #[test]
2493    fn resolve_provider_route_google_vertex_routes_to_native() {
2494        let entry = model_entry(
2495            "google-vertex",
2496            "google-vertex",
2497            "gemini-2.0-flash",
2498            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2499        );
2500        let (route, canonical_provider, effective_api) =
2501            resolve_provider_route(&entry).expect("resolve google-vertex route");
2502        assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2503        assert_eq!(canonical_provider, "google-vertex");
2504        assert_eq!(effective_api, "google-vertex");
2505    }
2506
2507    #[test]
2508    fn resolve_provider_route_vertexai_alias_routes_to_native() {
2509        let entry = model_entry(
2510            "vertexai",
2511            "google-vertex",
2512            "gemini-2.0-flash",
2513            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2514        );
2515        let (route, canonical_provider, effective_api) =
2516            resolve_provider_route(&entry).expect("resolve vertexai alias route");
2517        assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2518        assert_eq!(canonical_provider, "google-vertex");
2519        assert_eq!(effective_api, "google-vertex");
2520    }
2521
2522    #[test]
2523    fn resolve_provider_route_google_vertex_api_fallback() {
2524        // Unknown provider but google-vertex API should still route correctly
2525        let entry = model_entry(
2526            "custom-vertex",
2527            "google-vertex",
2528            "gemini-2.0-flash",
2529            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2530        );
2531        let (route, _canonical_provider, effective_api) =
2532            resolve_provider_route(&entry).expect("resolve google-vertex fallback");
2533        assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2534        assert_eq!(effective_api, "google-vertex");
2535    }
2536
2537    #[test]
2538    fn create_provider_google_vertex_from_full_url() {
2539        let entry = model_entry(
2540            "google-vertex",
2541            "google-vertex",
2542            "gemini-2.0-flash",
2543            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2544        );
2545        let provider = create_provider(&entry, None).expect("google-vertex from full URL");
2546        assert_eq!(provider.name(), "google-vertex");
2547        assert_eq!(provider.api(), "google-vertex");
2548        assert_eq!(provider.model_id(), "gemini-2.0-flash");
2549    }
2550
2551    #[test]
2552    fn create_provider_google_vertex_anthropic_publisher() {
2553        let entry = model_entry(
2554            "google-vertex",
2555            "google-vertex",
2556            "claude-sonnet-4-5",
2557            "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5",
2558        );
2559        let provider =
2560            create_provider(&entry, None).expect("google-vertex with anthropic publisher");
2561        assert_eq!(provider.name(), "google-vertex");
2562        assert_eq!(provider.model_id(), "claude-sonnet-4-5");
2563    }
2564
2565    #[test]
2566    fn create_provider_google_vertex_accepts_compat_config() {
2567        let entry = model_entry_with_compat(
2568            "google-vertex",
2569            "google-vertex",
2570            "gemini-2.0-flash",
2571            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2572            compat_with_custom_headers(),
2573        );
2574        let provider = create_provider(&entry, None).expect("google-vertex with compat");
2575        assert_eq!(provider.name(), "google-vertex");
2576    }
2577
2578    #[test]
2579    fn create_provider_compat_none_accepted_by_all_routes() {
2580        // Verify None compat doesn't break anything (regression guard)
2581        let routes = [
2582            (
2583                "anthropic",
2584                "anthropic-messages",
2585                "https://api.anthropic.com",
2586            ),
2587            ("openai", "openai-completions", "https://api.openai.com/v1"),
2588            ("openai", "openai-responses", "https://api.openai.com/v1"),
2589            ("cohere", "cohere-chat", "https://api.cohere.com/v2"),
2590            (
2591                "google",
2592                "google-generative-ai",
2593                "https://generativelanguage.googleapis.com",
2594            ),
2595            (
2596                "google-vertex",
2597                "google-vertex",
2598                "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/test-model",
2599            ),
2600        ];
2601        for (provider, api, base_url) in routes {
2602            let entry = model_entry(provider, api, "test-model", base_url);
2603            assert!(
2604                entry.compat.is_none(),
2605                "expected None compat for {provider}"
2606            );
2607            let result = create_provider(&entry, None);
2608            assert!(
2609                result.is_ok(),
2610                "create_provider failed for {provider} with None compat: {:?}",
2611                result.err()
2612            );
2613        }
2614    }
2615}