1use crate::error::{Error, Result};
7use crate::extensions::{ExtensionManager, ExtensionRuntimeHandle};
8use crate::http::client::{Client, RequestBuilder};
9use crate::model::{
10 AssistantMessage, AssistantMessageEvent, ContentBlock, StopReason, TextContent, Usage,
11};
12use crate::models::ModelEntry;
13use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
14use crate::provider_metadata::{
15 PROVIDER_METADATA, canonical_provider_id, provider_routing_defaults,
16};
17use crate::vcr::{VCR_ENV_MODE, VcrRecorder};
18use async_trait::async_trait;
19use chrono::Utc;
20use futures::stream;
21use futures::stream::Stream;
22use serde_json::Value;
23use std::collections::HashMap;
24use std::env;
25use std::pin::Pin;
26use std::sync::Arc;
27use url::Url;
28
// Concrete provider backends — one module per upstream API family.
pub mod anthropic;
pub mod azure;
pub mod bedrock;
pub mod cohere;
pub mod copilot;
pub mod gemini;
pub mod gitlab;
pub mod openai;
pub mod openai_responses;
pub mod vertex;
39
40pub(super) fn first_non_empty_header_value_case_insensitive(
41 headers: &HashMap<String, String>,
42 names: &[&str],
43) -> Option<String> {
44 headers.iter().find_map(|(key, value)| {
45 names
46 .iter()
47 .any(|name| key.eq_ignore_ascii_case(name))
48 .then_some(value.trim())
49 .filter(|value| !value.is_empty())
50 .map(ToString::to_string)
51 })
52}
53
54pub(super) fn apply_headers_ignoring_blank_auth_overrides<'a>(
55 mut request: RequestBuilder<'a>,
56 headers: &HashMap<String, String>,
57 auth_names: &[&str],
58) -> RequestBuilder<'a> {
59 for (key, value) in headers {
60 let is_blank_auth_override =
61 auth_names.iter().any(|name| key.eq_ignore_ascii_case(name)) && value.trim().is_empty();
62 if is_blank_auth_override {
63 continue;
64 }
65 request = request.header(key, value);
66 }
67 request
68}
69
70fn vcr_client_if_enabled() -> Result<Option<Client>> {
71 if env::var(VCR_ENV_MODE).is_err() {
72 return Ok(None);
73 }
74
75 let test_name = env::var("PI_VCR_TEST_NAME").unwrap_or_else(|_| "pi_runtime".to_string());
76 let recorder = VcrRecorder::new(&test_name)?;
77 Ok(Some(Client::new().with_vcr(recorder)))
78}
79
/// Provider backed by an extension's `streamSimple` implementation executing
/// inside the extension runtime.
struct ExtensionStreamSimpleProvider {
    // Static model description forwarded to the extension on every stream.
    model: crate::provider::Model,
    // Handle used to start, poll, and cancel streams in the extension runtime.
    runtime: ExtensionRuntimeHandle,
}
84
/// Mutable per-stream state threaded through the `stream::unfold` loop in
/// `ExtensionStreamSimpleProvider::stream`.
struct ExtensionStreamSimpleState {
    // Runtime handle used to poll for the next event and to cancel.
    runtime: ExtensionRuntimeHandle,
    // `Some` while the extension-side stream is alive; cleared once the stream
    // finishes, errors, or is cancelled (also the guard checked by `Drop`).
    stream_id: Option<String>,
    // Identity fields copied from the model so partial messages can be built.
    model_id: String,
    provider: String,
    api: String,
    // Concatenation of all plain-string chunks received so far.
    accumulated_text: String,
    // Most recent partial/full assistant message; used as the `Done` payload.
    last_message: Option<AssistantMessage>,
    // True once a synthetic TextStart has been emitted for string chunks.
    string_chunk_started: bool,
    // Events queued to emit before polling the runtime again.
    pending_events: std::collections::VecDeque<StreamEvent>,
}
98
impl Drop for ExtensionStreamSimpleState {
    /// Best-effort cancellation: if the consumer drops the stream before it
    /// completed, tell the extension runtime to stop producing events.
    fn drop(&mut self) {
        if let Some(stream_id) = self.stream_id.take() {
            self.runtime
                .provider_stream_simple_cancel_best_effort(stream_id);
        }
    }
}
107
/// Resolved routing decision for a model entry — which concrete provider
/// implementation to construct.
///
/// `Native*` variants are selected from a recognized provider id; `Api*`
/// variants are fallbacks selected from the entry's `api` string when the
/// provider id is unknown. Several Native/Api pairs map to the same
/// implementation in `create_provider`; the split exists for route labelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProviderRouteKind {
    NativeAnthropic,
    NativeOpenAICompletions,
    NativeOpenAIResponses,
    NativeOpenAICodexResponses,
    NativeCohere,
    NativeGoogle,
    NativeGoogleGeminiCli,
    NativeGoogleVertex,
    NativeBedrock,
    NativeAzure,
    NativeCopilot,
    NativeGitlab,
    ApiAnthropicMessages,
    ApiOpenAICompletions,
    ApiOpenAIResponses,
    ApiOpenAICodexResponses,
    ApiCohereChat,
    ApiGoogleGenerativeAi,
    ApiGoogleGeminiCli,
}
130
impl ProviderRouteKind {
    /// Stable `native:...` / `api:...` label used in tracing output.
    const fn as_str(self) -> &'static str {
        match self {
            Self::NativeAnthropic => "native:anthropic",
            Self::NativeOpenAICompletions => "native:openai-completions",
            Self::NativeOpenAIResponses => "native:openai-responses",
            Self::NativeOpenAICodexResponses => "native:openai-codex-responses",
            Self::NativeCohere => "native:cohere",
            Self::NativeGoogle => "native:google",
            Self::NativeGoogleGeminiCli => "native:google-gemini-cli",
            Self::NativeGoogleVertex => "native:google-vertex",
            Self::NativeBedrock => "native:amazon-bedrock",
            Self::NativeAzure => "native:azure-openai",
            Self::NativeCopilot => "native:github-copilot",
            Self::NativeGitlab => "native:gitlab",
            Self::ApiAnthropicMessages => "api:anthropic-messages",
            Self::ApiOpenAICompletions => "api:openai-completions",
            Self::ApiOpenAIResponses => "api:openai-responses",
            Self::ApiOpenAICodexResponses => "api:openai-codex-responses",
            Self::ApiCohereChat => "api:cohere-chat",
            Self::ApiGoogleGenerativeAi => "api:google-generative-ai",
            Self::ApiGoogleGeminiCli => "api:google-gemini-cli",
        }
    }
}
156
/// Decide which concrete implementation serves `entry`.
///
/// Returns `(route, canonical provider id, effective api)`. Recognized
/// provider ids (and their aliases) route natively; anything else falls back
/// to routing by the entry's `api` field, defaulted from the provider schema
/// when the entry leaves it empty.
///
/// # Errors
/// Returns a provider error — including "did you mean" suggestions — when
/// neither the provider id nor the api is recognized.
fn resolve_provider_route(entry: &ModelEntry) -> Result<(ProviderRouteKind, String, String)> {
    // Collapse aliases (e.g. "bedrock") onto the canonical id; unknown ids
    // pass through unchanged so the api-based fallback can still fire.
    let canonical_provider =
        canonical_provider_id(&entry.model.provider).unwrap_or(entry.model.provider.as_str());
    let schema_api = provider_routing_defaults(&entry.model.provider).map(|defaults| defaults.api);
    // An explicit api on the entry wins over the schema default.
    let effective_api = if entry.model.api.is_empty() {
        schema_api.unwrap_or_default().to_string()
    } else {
        entry.model.api.clone()
    };

    let route = match canonical_provider {
        "anthropic" => ProviderRouteKind::NativeAnthropic,
        "openai" => {
            // OpenAI defaults to the Responses API unless completions was
            // requested explicitly.
            if effective_api == "openai-completions" {
                ProviderRouteKind::NativeOpenAICompletions
            } else {
                ProviderRouteKind::NativeOpenAIResponses
            }
        }
        "openai-codex" => ProviderRouteKind::NativeOpenAICodexResponses,
        "cohere" => ProviderRouteKind::NativeCohere,
        "google" => ProviderRouteKind::NativeGoogle,
        "google-gemini-cli" | "google-antigravity" => ProviderRouteKind::NativeGoogleGeminiCli,
        "google-vertex" | "vertexai" => ProviderRouteKind::NativeGoogleVertex,
        "amazon-bedrock" | "bedrock" => ProviderRouteKind::NativeBedrock,
        "azure-openai" | "azure" | "azure-cognitive-services" | "azure-openai-responses" => {
            ProviderRouteKind::NativeAzure
        }
        "github-copilot" | "copilot" => ProviderRouteKind::NativeCopilot,
        "gitlab" | "gitlab-duo" => ProviderRouteKind::NativeGitlab,
        // Unknown provider id: route purely on the api shape.
        _ => match effective_api.as_str() {
            "anthropic-messages" => ProviderRouteKind::ApiAnthropicMessages,
            "openai-completions" => ProviderRouteKind::ApiOpenAICompletions,
            "openai-responses" => ProviderRouteKind::ApiOpenAIResponses,
            "openai-codex-responses" => ProviderRouteKind::ApiOpenAICodexResponses,
            "cohere-chat" => ProviderRouteKind::ApiCohereChat,
            "google-generative-ai" => ProviderRouteKind::ApiGoogleGenerativeAi,
            "google-gemini-cli" => ProviderRouteKind::ApiGoogleGeminiCli,
            "google-vertex" => ProviderRouteKind::NativeGoogleVertex,
            "bedrock-converse-stream" => ProviderRouteKind::NativeBedrock,
            "azure-openai-responses" => ProviderRouteKind::NativeAzure,
            _ => {
                let suggestions = suggest_similar_providers(&entry.model.provider);
                let msg = if suggestions.is_empty() {
                    format!("Provider not implemented (api: {effective_api})")
                } else {
                    format!(
                        "Provider not implemented (api: {effective_api}). Did you mean: {}?",
                        suggestions.join(", ")
                    )
                };
                return Err(Error::provider(&entry.model.provider, msg));
            }
        },
    };

    Ok((route, canonical_provider.to_string(), effective_api))
}
215
/// Levenshtein distance between two byte strings, computed with a single
/// working row sized by the shorter input to keep memory minimal.
fn edit_distance(a: &[u8], b: &[u8]) -> usize {
    // Iterate over the longer string so the row tracks the shorter one.
    let (short, long) = if a.len() <= b.len() { (a, b) } else { (b, a) };

    // row[j] holds the distance from the current `long` prefix to short[..j].
    let mut row: Vec<usize> = (0..=short.len()).collect();

    for (i, &long_byte) in long.iter().enumerate() {
        // `diagonal` is the previous row's value one column to the left.
        let mut diagonal = i;
        row[0] = i + 1;
        for (j, &short_byte) in short.iter().enumerate() {
            let substitution = diagonal + usize::from(long_byte != short_byte);
            let deletion = row[j + 1] + 1;
            let insertion = row[j] + 1;
            diagonal = row[j + 1];
            row[j + 1] = substitution.min(deletion).min(insertion);
        }
    }

    row[short.len()]
}
233
/// Fuzzy-match budget for an input of `input_len` bytes: longer inputs
/// tolerate more typos, while very short inputs get no fuzzy matching at all.
const fn max_edit_distance(input_len: usize) -> usize {
    if input_len <= 2 {
        0
    } else if input_len <= 5 {
        1
    } else if input_len <= 9 {
        2
    } else {
        3
    }
}
244
/// Suggest up to three known provider ids that resemble `input`.
///
/// Match quality (lower score sorts first): 0 = prefix overlap in either
/// direction, 1 = substring overlap, 2 + d = edit distance `d` within the
/// length-dependent budget. Ties break alphabetically on the canonical id.
fn suggest_similar_providers(input: &str) -> Vec<String> {
    let needle = input.to_lowercase();
    let needle_bytes = needle.as_bytes();
    let threshold = max_edit_distance(needle.len());
    let mut matches: Vec<(usize, String)> = Vec::new();

    for meta in PROVIDER_METADATA {
        // Compare against the canonical id and all of its aliases.
        let names: Vec<&str> = std::iter::once(meta.canonical_id)
            .chain(meta.aliases.iter().copied())
            .collect();
        let mut matched = false;
        for name in &names {
            let haystack = name.to_lowercase();
            // Prefix overlap (either direction) is the strongest signal.
            if haystack.starts_with(&needle) || needle.starts_with(&haystack) {
                matches.push((0, meta.canonical_id.to_string()));
                matched = true;
                break;
            }
            // Substring overlap ranks below prefix overlap.
            if haystack.contains(&needle) || needle.contains(&haystack) {
                matches.push((1, meta.canonical_id.to_string()));
                matched = true;
                break;
            }
        }
        if matched {
            continue;
        }
        // Fuzzy matching only when the input is long enough for a nonzero
        // edit-distance budget.
        if threshold > 0 {
            let mut best_dist = usize::MAX;
            for name in &names {
                let haystack = name.to_lowercase();
                let dist = edit_distance(needle_bytes, haystack.as_bytes());
                best_dist = best_dist.min(dist);
            }
            if best_dist <= threshold {
                matches.push((
                    // Offset by 2 so fuzzy scores always rank below the
                    // prefix (0) and substring (1) tiers.
                    2_usize.wrapping_add(best_dist),
                    meta.canonical_id.to_string(),
                ));
            }
        }
    }

    // Best score first, then alphabetical. Each provider pushes at most one
    // entry, so removing adjacent duplicates after the sort suffices.
    matches.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
    matches.dedup_by(|a, b| a.1 == b.1);
    matches.truncate(3);
    matches.into_iter().map(|(_, name)| name).collect()
}
300
// Environment overrides for Azure OpenAI routing; each takes precedence over
// the value derived from the model entry's base_url.
const AZURE_OPENAI_RESOURCE_ENV: &str = "AZURE_OPENAI_RESOURCE";
const AZURE_OPENAI_DEPLOYMENT_ENV: &str = "AZURE_OPENAI_DEPLOYMENT";
const AZURE_OPENAI_API_VERSION_ENV: &str = "AZURE_OPENAI_API_VERSION";
304
/// Fully-resolved Azure OpenAI connection settings for one model entry.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AzureProviderRuntime {
    // Azure resource name (the `<resource>` in <resource>.openai.azure.com).
    resource: String,
    // Deployment name addressed under /openai/deployments/<deployment>.
    deployment: String,
    // api-version query parameter value.
    api_version: String,
    // Complete chat-completions endpoint assembled from the fields above.
    endpoint_url: String,
}
312
/// Trim an optional string, collapsing whitespace-only values to `None`.
fn trim_non_empty(value: Option<String>) -> Option<String> {
    let raw = value?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
318
/// Extract the Azure resource name from a known Azure OpenAI host, e.g.
/// `"myres.openai.azure.com"` -> `Some("myres")`. Non-Azure hosts and hosts
/// with a blank resource part yield `None`.
fn parse_azure_resource_from_host(host: &str) -> Option<String> {
    const AZURE_HOST_SUFFIXES: [&str; 2] =
        [".openai.azure.com", ".cognitiveservices.azure.com"];

    AZURE_HOST_SUFFIXES.iter().find_map(|suffix| {
        let resource = host.strip_suffix(suffix)?.trim();
        (!resource.is_empty()).then(|| resource.to_string())
    })
}
326
327fn parse_azure_base_url_details(
328 base_url: &str,
329) -> Result<(String, Option<String>, Option<String>)> {
330 let url = Url::parse(base_url)
331 .map_err(|err| Error::config(format!("Invalid Azure base_url '{base_url}': {err}")))?;
332 let host = url.host_str().map(ToString::to_string).ok_or_else(|| {
333 Error::config(format!(
334 "Azure base_url is missing host information: '{base_url}'"
335 ))
336 })?;
337
338 let mut deployment = None;
339 if let Some(segments) = url.path_segments() {
340 let mut iter = segments;
341 while let Some(segment) = iter.next() {
342 if segment == "deployments" {
343 deployment = iter
344 .next()
345 .map(str::trim)
346 .filter(|value| !value.is_empty())
347 .map(ToString::to_string);
348 break;
349 }
350 }
351 }
352
353 let api_version = url
354 .query_pairs()
355 .find(|(key, _)| key == "api-version")
356 .map(|(_, value)| value.into_owned())
357 .filter(|value| !value.trim().is_empty());
358
359 Ok((host, deployment, api_version))
360}
361
/// Resolve Azure runtime settings for `entry`, reading overrides from the
/// process environment.
fn resolve_azure_provider_runtime(entry: &ModelEntry) -> Result<AzureProviderRuntime> {
    resolve_azure_provider_runtime_with_env(entry, |name| env::var(name).ok())
}
365
366fn resolve_azure_provider_runtime_with_env<F>(
367 entry: &ModelEntry,
368 mut env_lookup: F,
369) -> Result<AzureProviderRuntime>
370where
371 F: FnMut(&str) -> Option<String>,
372{
373 let base_url = entry.model.base_url.trim();
374 if base_url.is_empty() {
375 return Err(Error::config(format!(
376 "Missing Azure base_url for provider '{}'; expected https://<resource>.openai.azure.com or https://<resource>.cognitiveservices.azure.com",
377 entry.model.provider
378 )));
379 }
380
381 let (host, base_deployment, base_api_version) = parse_azure_base_url_details(base_url)?;
382 let host_resource = parse_azure_resource_from_host(&host);
383 let env_resource = trim_non_empty(env_lookup(AZURE_OPENAI_RESOURCE_ENV));
384 let resource = env_resource.or(host_resource).ok_or_else(|| {
385 Error::config(format!(
386 "Unable to resolve Azure resource for provider '{}'; set {AZURE_OPENAI_RESOURCE_ENV} or use an Azure host in base_url ('{base_url}')",
387 entry.model.provider
388 ))
389 })?;
390
391 let env_deployment = trim_non_empty(env_lookup(AZURE_OPENAI_DEPLOYMENT_ENV));
392 let model_deployment = {
393 let model_id = entry.model.id.trim();
394 (!model_id.is_empty()).then(|| model_id.to_string())
395 };
396 let deployment = env_deployment
397 .or(base_deployment)
398 .or(model_deployment)
399 .ok_or_else(|| {
400 Error::config(format!(
401 "Unable to resolve Azure deployment for provider '{}'; set {AZURE_OPENAI_DEPLOYMENT_ENV}, provide a non-empty model id, or include '/deployments/<name>' in base_url ('{base_url}')",
402 entry.model.provider
403 ))
404 })?;
405
406 let api_version = trim_non_empty(env_lookup(AZURE_OPENAI_API_VERSION_ENV))
407 .or(base_api_version)
408 .unwrap_or_else(azure::azure_api_version);
409
410 let endpoint_host = if parse_azure_resource_from_host(&host).is_some() {
411 host
412 } else {
413 format!("{resource}.openai.azure.com")
414 };
415 let endpoint_url = format!(
416 "https://{endpoint_host}/openai/deployments/{deployment}/chat/completions?api-version={api_version}"
417 );
418
419 Ok(AzureProviderRuntime {
420 resource,
421 deployment,
422 api_version,
423 endpoint_url,
424 })
425}
426
/// Resolve a GitHub Copilot token for `entry`, reading fallbacks from the
/// process environment.
fn resolve_copilot_token(entry: &ModelEntry) -> Result<String> {
    resolve_copilot_token_with_env(entry, |name| env::var(name).ok())
}
430
431fn resolve_copilot_token_with_env<F>(entry: &ModelEntry, mut env_lookup: F) -> Result<String>
432where
433 F: FnMut(&str) -> Option<String>,
434{
435 let inline = entry
436 .api_key
437 .as_deref()
438 .map(str::trim)
439 .filter(|value| !value.is_empty())
440 .map(ToString::to_string);
441 let from_env = || {
442 env_lookup("GITHUB_COPILOT_API_KEY")
443 .or_else(|| env_lookup("GITHUB_TOKEN"))
444 .map(|value| value.trim().to_string())
445 .filter(|value| !value.is_empty())
446 };
447
448 inline.or_else(from_env).ok_or_else(|| {
449 Error::auth(
450 "GitHub Copilot requires login credentials or GITHUB_COPILOT_API_KEY/GITHUB_TOKEN",
451 )
452 })
453}
454
impl ExtensionStreamSimpleProvider {
    /// Maximum time to wait for the extension to yield the next stream event
    /// before a poll errors out (10 minutes, in milliseconds).
    const NEXT_TIMEOUT_MS: u64 = 600_000;

    const fn new(model: crate::provider::Model, runtime: ExtensionRuntimeHandle) -> Self {
        Self { model, runtime }
    }

    /// Serialize the model description into the camelCase JSON shape the
    /// extension-side `streamSimple` implementation expects.
    fn build_js_model(model: &crate::provider::Model) -> Value {
        serde_json::json!({
            "id": &model.id,
            "name": &model.name,
            "api": &model.api,
            "provider": &model.provider,
            "baseUrl": &model.base_url,
            "reasoning": model.reasoning,
            "input": &model.input,
            "cost": &model.cost,
            "contextWindow": model.context_window,
            "maxTokens": model.max_tokens,
            "headers": &model.headers,
        })
    }

    /// Serialize the request context (system prompt, messages, tools) to JSON
    /// for the extension. Absent pieces are omitted rather than sent as null.
    fn build_js_context(context: &Context<'_>) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(system_prompt) = &context.system_prompt {
            map.insert(
                "systemPrompt".to_string(),
                Value::String(system_prompt.to_string()),
            );
        }
        map.insert(
            "messages".to_string(),
            // Best-effort: unserializable messages degrade to an empty list.
            serde_json::to_value(&context.messages).unwrap_or(Value::Array(Vec::new())),
        );
        if !context.tools.is_empty() {
            let tools = context
                .tools
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters,
                    })
                })
                .collect::<Vec<_>>();
            map.insert("tools".to_string(), Value::Array(tools));
        }
        Value::Object(map)
    }

    /// Serialize stream options to JSON, skipping unset fields so the
    /// extension only sees what the caller actually provided.
    fn build_js_options(options: &StreamOptions) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(temp) = options.temperature {
            map.insert("temperature".to_string(), serde_json::json!(temp));
        }
        if let Some(max_tokens) = options.max_tokens {
            map.insert("maxTokens".to_string(), serde_json::json!(max_tokens));
        }
        if let Some(api_key) = &options.api_key {
            map.insert("apiKey".to_string(), Value::String(api_key.clone()));
        }
        if let Some(session_id) = &options.session_id {
            map.insert("sessionId".to_string(), Value::String(session_id.clone()));
        }
        if !options.headers.is_empty() {
            map.insert(
                "headers".to_string(),
                // Best-effort: unserializable headers degrade to an empty map.
                serde_json::to_value(&options.headers)
                    .unwrap_or_else(|_| Value::Object(serde_json::Map::new())),
            );
        }
        let cache_retention = match options.cache_retention {
            crate::provider::CacheRetention::None => "none",
            crate::provider::CacheRetention::Short => "short",
            crate::provider::CacheRetention::Long => "long",
        };
        map.insert(
            "cacheRetention".to_string(),
            Value::String(cache_retention.to_string()),
        );
        if let Some(level) = options.thinking_level {
            // `Off` is represented by omitting the field entirely.
            if level != crate::model::ThinkingLevel::Off {
                map.insert("reasoning".to_string(), Value::String(level.to_string()));
            }
        }
        if let Some(budgets) = &options.thinking_budgets {
            map.insert(
                "thinkingBudgets".to_string(),
                serde_json::json!({
                    "minimal": budgets.minimal,
                    "low": budgets.low,
                    "medium": budgets.medium,
                    "high": budgets.high,
                    "xhigh": budgets.xhigh,
                }),
            );
        }
        Value::Object(map)
    }

    /// Translate an extension-side `AssistantMessageEvent` into the crate's
    /// `StreamEvent`, cloning shared payloads out of their wrappers.
    fn assistant_event_to_stream_event(event: AssistantMessageEvent) -> StreamEvent {
        match event {
            AssistantMessageEvent::Start { partial } => StreamEvent::Start {
                partial: partial.as_ref().clone(),
            },
            AssistantMessageEvent::TextStart { content_index, .. } => {
                StreamEvent::TextStart { content_index }
            }
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::TextDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::TextEnd {
                content_index,
                content,
                ..
            } => StreamEvent::TextEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ThinkingStart { content_index, .. } => {
                StreamEvent::ThinkingStart { content_index }
            }
            AssistantMessageEvent::ThinkingDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ThinkingDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index,
                content,
                ..
            } => StreamEvent::ThinkingEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ToolCallStart { content_index, .. } => {
                StreamEvent::ToolCallStart { content_index }
            }
            AssistantMessageEvent::ToolCallDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ToolCallDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ToolCallEnd {
                content_index,
                tool_call,
                ..
            } => StreamEvent::ToolCallEnd {
                content_index,
                tool_call,
            },
            AssistantMessageEvent::Done { reason, message } => StreamEvent::Done {
                reason,
                message: message.as_ref().clone(),
            },
            AssistantMessageEvent::Error { reason, error } => StreamEvent::Error {
                reason,
                error: error.as_ref().clone(),
            },
        }
    }

    /// Build a minimal assistant message holding `text` as its single text
    /// block; used as the running partial for plain-string chunk streams.
    fn make_partial(model_id: &str, provider: &str, api: &str, text: &str) -> AssistantMessage {
        AssistantMessage {
            model: model_id.to_string(),
            api: api.to_string(),
            provider: provider.to_string(),
            content: vec![ContentBlock::Text(TextContent {
                text: text.to_string(),
                text_signature: None,
            })],
            stop_reason: StopReason::default(),
            usage: Usage::default(),
            error_message: None,
            timestamp: Utc::now().timestamp_millis(),
        }
    }
}
646
#[allow(clippy::too_many_lines)]
#[async_trait]
impl Provider for ExtensionStreamSimpleProvider {
    // The trait's `name()` is the provider id, not the model name.
    #[allow(clippy::misnamed_getters)]
    fn name(&self) -> &str {
        &self.model.provider
    }

    fn api(&self) -> &str {
        &self.model.api
    }

    fn model_id(&self) -> &str {
        &self.model.id
    }

    /// Start an extension-side `streamSimple` stream and adapt its output to
    /// `StreamEvent`s.
    ///
    /// The extension may yield either plain text chunks (strings) or full
    /// `AssistantMessageEvent` JSON objects; string chunks are wrapped in
    /// synthetic Start/TextStart/TextDelta/TextEnd/Done events here.
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let model = Self::build_js_model(&self.model);
        let ctx = Self::build_js_context(context);
        let opts = Self::build_js_options(options);

        let stream_id = self
            .runtime
            .provider_stream_simple_start(
                self.model.provider.clone(),
                model,
                ctx,
                opts,
                Self::NEXT_TIMEOUT_MS,
            )
            .await?;

        let state = ExtensionStreamSimpleState {
            runtime: self.runtime.clone(),
            stream_id: Some(stream_id),
            model_id: self.model.id.clone(),
            provider: self.model.provider.clone(),
            api: self.model.api.clone(),
            accumulated_text: String::new(),
            last_message: None,
            string_chunk_started: false,
            pending_events: std::collections::VecDeque::new(),
        };

        let stream = stream::unfold(state, |mut state| async move {
            // Drain locally queued events before polling the runtime again.
            if let Some(event) = state.pending_events.pop_front() {
                return Some((Ok(event), state));
            }

            // `None` stream_id means the stream already ended: stop unfolding.
            let stream_id = state.stream_id.clone()?;
            let stream_id_for_cancel = stream_id.clone();

            match state
                .runtime
                .provider_stream_simple_next(stream_id, Self::NEXT_TIMEOUT_MS)
                .await
            {
                Ok(Some(value)) => {
                    // Plain string chunks: accumulate text and synthesize the
                    // Start/TextStart/TextDelta event sequence ourselves.
                    if let Some(chunk) = value.as_str() {
                        let chunk = chunk.to_string();
                        state.accumulated_text.push_str(&chunk);
                        // Keep the running partial message's text block in
                        // sync with the accumulated text.
                        match &mut state.last_message {
                            Some(msg) => {
                                if let Some(ContentBlock::Text(t)) = msg.content.first_mut() {
                                    t.text.clone_from(&state.accumulated_text);
                                }
                            }
                            None => {
                                state.last_message = Some(Self::make_partial(
                                    &state.model_id,
                                    &state.provider,
                                    &state.api,
                                    &state.accumulated_text,
                                ));
                            }
                        }

                        if !state.string_chunk_started {
                            // First chunk: emit Start now and queue
                            // TextStart + the first TextDelta behind it.
                            state.string_chunk_started = true;
                            state
                                .pending_events
                                .push_back(StreamEvent::TextStart { content_index: 0 });
                            state.pending_events.push_back(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            });
                            return Some((
                                Ok(StreamEvent::Start {
                                    partial: Self::make_partial(
                                        &state.model_id,
                                        &state.provider,
                                        &state.api,
                                        "",
                                    ),
                                }),
                                state,
                            ));
                        }
                        return Some((
                            Ok(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            }),
                            state,
                        ));
                    }

                    // Otherwise the value must deserialize to a full event;
                    // anything else terminates the stream with an error.
                    let event: AssistantMessageEvent = match serde_json::from_value(value) {
                        Ok(event) => event,
                        Err(err) => {
                            state
                                .runtime
                                .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                            state.stream_id = None;
                            return Some((
                                Err(Error::extension(format!(
                                    "streamSimple yielded invalid event: {err}"
                                ))),
                                state,
                            ));
                        }
                    };

                    // Track the latest message snapshot so `Ok(None)` below
                    // can close the stream with a meaningful payload.
                    match &event {
                        AssistantMessageEvent::Start { partial }
                        | AssistantMessageEvent::TextStart { partial, .. }
                        | AssistantMessageEvent::TextDelta { partial, .. }
                        | AssistantMessageEvent::TextEnd { partial, .. }
                        | AssistantMessageEvent::ThinkingStart { partial, .. }
                        | AssistantMessageEvent::ThinkingDelta { partial, .. }
                        | AssistantMessageEvent::ThinkingEnd { partial, .. }
                        | AssistantMessageEvent::ToolCallStart { partial, .. }
                        | AssistantMessageEvent::ToolCallDelta { partial, .. }
                        | AssistantMessageEvent::ToolCallEnd { partial, .. } => {
                            state.last_message = Some(partial.as_ref().clone());
                        }
                        AssistantMessageEvent::Done { message, .. } => {
                            state.last_message = Some(message.as_ref().clone());
                        }
                        AssistantMessageEvent::Error { error, .. } => {
                            state.last_message = Some(error.as_ref().clone());
                        }
                    }

                    let stream_event = Self::assistant_event_to_stream_event(event);
                    // Terminal events: release the extension-side stream.
                    if matches!(
                        stream_event,
                        StreamEvent::Done { .. } | StreamEvent::Error { .. }
                    ) {
                        state
                            .runtime
                            .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                        state.stream_id = None;
                    }
                    Some((Ok(stream_event), state))
                }
                // Runtime reports end-of-stream: close out with a Done event
                // (preceded by TextEnd when we synthesized string chunks).
                Ok(None) => {
                    state.stream_id = None;
                    let message = state.last_message.clone().unwrap_or_else(|| {
                        Self::make_partial(
                            &state.model_id,
                            &state.provider,
                            &state.api,
                            &state.accumulated_text,
                        )
                    });

                    if state.string_chunk_started {
                        state.pending_events.push_back(StreamEvent::Done {
                            reason: StopReason::Stop,
                            message,
                        });
                        Some((
                            Ok(StreamEvent::TextEnd {
                                content_index: 0,
                                content: state.accumulated_text.clone(),
                            }),
                            state,
                        ))
                    } else {
                        Some((
                            Ok(StreamEvent::Done {
                                reason: StopReason::Stop,
                                message,
                            }),
                            state,
                        ))
                    }
                }
                Err(err) => {
                    state
                        .runtime
                        .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                    state.stream_id = None;
                    Some((Err(err), state))
                }
            }
        });

        Ok(Box::pin(stream))
    }
}
865
/// Construct the provider implementation for `entry`.
///
/// Extension-defined `streamSimple` providers take precedence over built-ins.
/// Otherwise the entry is routed via `resolve_provider_route` and the matching
/// native client is built, honoring VCR record/replay when enabled.
///
/// # Errors
/// Fails when the route cannot be resolved, when required runtime settings
/// (Azure/Vertex/Copilot credentials) are missing, or when an extension
/// provider is selected but the extension runtime is not configured.
#[allow(clippy::too_many_lines)]
pub fn create_provider(
    entry: &ModelEntry,
    extensions: Option<&ExtensionManager>,
) -> Result<Arc<dyn Provider>> {
    // Extension-provided streamSimple implementations shadow built-ins.
    if let Some(manager) = extensions {
        if manager.provider_has_stream_simple(&entry.model.provider) {
            let runtime = manager.runtime().ok_or_else(|| {
                Error::provider(
                    &entry.model.provider,
                    "Extension runtime not configured for streamSimple provider",
                )
            })?;
            return Ok(Arc::new(ExtensionStreamSimpleProvider::new(
                entry.model.clone(),
                runtime,
            )));
        }
    }

    let vcr_client = vcr_client_if_enabled()?;
    let client = vcr_client.unwrap_or_else(Client::new);
    let (route, canonical_provider, effective_api) = resolve_provider_route(entry)?;
    tracing::debug!(
        event = "pi.provider.factory.select",
        provider = %entry.model.provider,
        canonical_provider = %canonical_provider,
        api = %effective_api,
        base_url = %entry.model.base_url,
        route = %route.as_str(),
        "Selecting provider implementation"
    );

    match route {
        ProviderRouteKind::NativeAnthropic | ProviderRouteKind::ApiAnthropicMessages => {
            Ok(Arc::new(
                anthropic::AnthropicProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_anthropic_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAICompletions | ProviderRouteKind::ApiOpenAICompletions => {
            Ok(Arc::new(
                openai::OpenAIProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAIResponses | ProviderRouteKind::ApiOpenAIResponses => {
            Ok(Arc::new(
                openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_responses_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        // Codex reuses the Responses provider with codex mode enabled.
        ProviderRouteKind::NativeOpenAICodexResponses
        | ProviderRouteKind::ApiOpenAICodexResponses => Ok(Arc::new(
            openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("openai-codex-responses")
                .with_codex_mode(true)
                .with_base_url(normalize_openai_codex_responses_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeCohere | ProviderRouteKind::ApiCohereChat => Ok(Arc::new(
            cohere::CohereProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_base_url(normalize_cohere_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeGoogle | ProviderRouteKind::ApiGoogleGenerativeAi => Ok(Arc::new(
            gemini::GeminiProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("google-generative-ai")
                .with_base_url(entry.model.base_url.clone())
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        // Gemini CLI reuses the Gemini provider with CLI mode enabled.
        ProviderRouteKind::NativeGoogleGeminiCli | ProviderRouteKind::ApiGoogleGeminiCli => {
            Ok(Arc::new(
                gemini::GeminiProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_api_name("google-gemini-cli")
                    .with_google_cli_mode(true)
                    .with_base_url(entry.model.base_url.clone())
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeGoogleVertex => {
            // Vertex needs project/location/publisher resolved up front.
            let runtime = vertex::resolve_vertex_provider_runtime(entry)?;
            Ok(Arc::new(
                vertex::VertexProvider::new(runtime.model)
                    .with_project(runtime.project)
                    .with_location(runtime.location)
                    .with_publisher(runtime.publisher)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeBedrock => Ok(Arc::new(
            bedrock::BedrockProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeAzure => {
            // Azure needs resource/deployment/api-version resolved up front.
            let runtime = resolve_azure_provider_runtime(entry)?;
            Ok(Arc::new(
                azure::AzureOpenAIProvider::new(runtime.resource, runtime.deployment)
                    .with_provider_name(&entry.model.provider)
                    .with_api_version(runtime.api_version)
                    .with_endpoint_url(runtime.endpoint_url)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeCopilot => {
            // Copilot requires a GitHub token before anything else.
            let github_token = resolve_copilot_token(entry)?;
            let mut provider = copilot::CopilotProvider::new(&entry.model.id, github_token)
                .with_provider_name(&entry.model.provider)
                .with_compat(entry.compat.clone())
                .with_client(client);
            if !entry.model.base_url.is_empty() {
                provider = provider.with_github_api_base(&entry.model.base_url);
            }
            Ok(Arc::new(provider))
        }
        ProviderRouteKind::NativeGitlab => Ok(Arc::new(
            gitlab::GitLabProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
    }
}
1012
/// Normalize a configured Anthropic base URL into a full messages endpoint.
///
/// - Empty input yields the official `https://api.anthropic.com/v1/messages`.
/// - Parseable absolute URLs get `/v1/messages` appended unless the path
///   already ends with it.
/// - Inputs the URL parser rejects fall back to string concatenation.
pub fn normalize_anthropic_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.anthropic.com/v1/messages".to_string();
    }

    let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();

    if let Ok(url) = Url::parse(trimmed) {
        if url.cannot_be_a_base() {
            // Non-hierarchical URLs: keep the parser's canonical form for the
            // string fallback below.
            base_for_fallback = url.as_str().trim_end_matches('/').to_string();
        } else {
            if trimmed_url_path(&url).ends_with("/v1/messages") {
                return canonicalize_url_path(&url);
            }
            return append_url_path(&url, "v1/messages");
        }
    }

    // String fallback for inputs the URL parser could not handle as a base.
    let base_url = base_for_fallback;
    if base_url.ends_with("/v1/messages") {
        return base_url;
    }
    format!("{base_url}/v1/messages")
}
1038
1039fn trimmed_url_path(url: &Url) -> &str {
1040 match url.path().trim_end_matches('/') {
1041 "" => "/",
1042 trimmed => trimmed,
1043 }
1044}
1045
1046fn canonicalize_url_path(url: &Url) -> String {
1047 let mut canonical = url.clone();
1048 canonical.set_path(trimmed_url_path(url));
1049 canonical.to_string()
1050}
1051
1052fn replace_url_path(url: &Url, path: &str) -> String {
1053 let mut updated = url.clone();
1054 updated.set_path(path);
1055 updated.to_string()
1056}
1057
1058fn append_url_path(url: &Url, suffix: &str) -> String {
1059 let base_path = trimmed_url_path(url);
1060 let path = if base_path == "/" {
1061 format!("/{suffix}")
1062 } else {
1063 format!("{base_path}/{suffix}")
1064 };
1065 replace_url_path(url, &path)
1066}
1067
1068fn strip_url_path_suffix(url: &Url, suffix: &str) -> Option<Url> {
1069 let base_path = trimmed_url_path(url);
1070 let prefix = base_path.strip_suffix(suffix)?;
1071 let mut stripped = url.clone();
1072 stripped.set_path(if prefix.is_empty() { "/" } else { prefix });
1073 Some(stripped)
1074}
1075
1076fn is_official_https_origin(url: &Url, host: &str, default_port: u16) -> bool {
1077 url.scheme().eq_ignore_ascii_case("https")
1078 && url
1079 .host_str()
1080 .is_some_and(|candidate| candidate.eq_ignore_ascii_case(host))
1081 && url.port_or_known_default() == Some(default_port)
1082 && trimmed_url_path(url) == "/"
1083}
1084
1085pub fn normalize_openai_base(base_url: &str) -> String {
1086 let trimmed = base_url.trim();
1087 if trimmed.is_empty() {
1088 return "https://api.openai.com/v1/chat/completions".to_string();
1089 }
1090
1091 let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();
1092
1093 if let Ok(url) = Url::parse(trimmed) {
1094 if url.cannot_be_a_base() {
1095 base_for_fallback = url.as_str().trim_end_matches('/').to_string();
1096 } else {
1097 if trimmed_url_path(&url).ends_with("/chat/completions") {
1098 return canonicalize_url_path(&url);
1099 }
1100 let url = strip_url_path_suffix(&url, "/responses").unwrap_or(url);
1101 if is_official_https_origin(&url, "api.openai.com", 443) {
1102 return replace_url_path(&url, "/v1/chat/completions");
1103 }
1104 return append_url_path(&url, "chat/completions");
1105 }
1106 }
1107
1108 let base_url = base_for_fallback;
1109 if base_url.ends_with("/chat/completions") {
1110 return base_url;
1111 }
1112 let base_url = base_url
1113 .strip_suffix("/responses")
1114 .unwrap_or(base_url.as_str());
1115 format!("{base_url}/chat/completions")
1116}
1117
1118pub fn normalize_openai_responses_base(base_url: &str) -> String {
1119 let trimmed = base_url.trim();
1120 if trimmed.is_empty() {
1121 return "https://api.openai.com/v1/responses".to_string();
1122 }
1123
1124 let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();
1125
1126 if let Ok(url) = Url::parse(trimmed) {
1127 if url.cannot_be_a_base() {
1128 base_for_fallback = url.as_str().trim_end_matches('/').to_string();
1129 } else {
1130 if trimmed_url_path(&url).ends_with("/responses") {
1131 return canonicalize_url_path(&url);
1132 }
1133 let url = strip_url_path_suffix(&url, "/chat/completions").unwrap_or(url);
1134 if is_official_https_origin(&url, "api.openai.com", 443) {
1135 return replace_url_path(&url, "/v1/responses");
1136 }
1137 return append_url_path(&url, "responses");
1138 }
1139 }
1140
1141 let base_url = base_for_fallback;
1142 if base_url.ends_with("/responses") {
1143 return base_url;
1144 }
1145 let base_url = base_url
1146 .strip_suffix("/chat/completions")
1147 .unwrap_or(base_url.as_str());
1148 format!("{base_url}/responses")
1149}
1150
1151pub fn normalize_openai_codex_responses_base(base_url: &str) -> String {
1152 let trimmed = base_url.trim();
1153 if trimmed.is_empty() {
1154 return openai_responses::CODEX_RESPONSES_API_URL.to_string();
1155 }
1156
1157 let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();
1158
1159 if let Ok(url) = Url::parse(trimmed) {
1160 if url.cannot_be_a_base() {
1161 base_for_fallback = url.as_str().trim_end_matches('/').to_string();
1162 } else {
1163 let path = trimmed_url_path(&url);
1164 if path.ends_with("/backend-api/codex/responses") || path.ends_with("/responses") {
1165 return canonicalize_url_path(&url);
1166 }
1167 if path.ends_with("/backend-api") {
1168 return append_url_path(&url, "codex/responses");
1169 }
1170 return append_url_path(&url, "backend-api/codex/responses");
1171 }
1172 }
1173
1174 let base = base_for_fallback;
1175 if base.ends_with("/backend-api/codex/responses") {
1176 return base;
1177 }
1178 if base.ends_with("/backend-api") {
1182 return format!("{base}/codex/responses");
1183 }
1184 if base.ends_with("/responses") {
1185 return base;
1186 }
1187 format!("{base}/backend-api/codex/responses")
1188}
1189
1190pub fn normalize_cohere_base(base_url: &str) -> String {
1191 let trimmed = base_url.trim();
1192 if trimmed.is_empty() {
1193 return "https://api.cohere.com/v2/chat".to_string();
1194 }
1195
1196 let mut base_for_fallback = trimmed.trim_end_matches('/').to_string();
1197
1198 if let Ok(url) = Url::parse(trimmed) {
1199 if url.cannot_be_a_base() {
1200 base_for_fallback = url.as_str().trim_end_matches('/').to_string();
1201 } else {
1202 if trimmed_url_path(&url).ends_with("/chat") {
1203 return canonicalize_url_path(&url);
1204 }
1205 if is_official_https_origin(&url, "api.cohere.com", 443) {
1206 return replace_url_path(&url, "/v2/chat");
1207 }
1208 return append_url_path(&url, "chat");
1209 }
1210 }
1211
1212 let base_url = base_for_fallback;
1213 if base_url.ends_with("/chat") {
1214 return base_url;
1215 }
1216 format!("{base_url}/chat")
1217}
1218
1219#[cfg(test)]
1220mod tests {
1221 use super::*;
1222 use crate::extensions::{ExtensionManager, JsExtensionLoadSpec, JsExtensionRuntimeHandle};
1223 use crate::extensions_js::PiJsRuntimeConfig;
1224 use crate::model::{ContentBlock, Message, UserContent, UserMessage};
1225 use crate::tools::ToolRegistry;
1226 use asupersync::runtime::RuntimeBuilder;
1227 use asupersync::time::{sleep, wall_now};
1228 use futures::StreamExt;
1229 use std::sync::Arc;
1230 use std::time::Duration;
1231 use tempfile::tempdir;
1232
    // JS extension fixture: registers "stream-provider" with one model. Its
    // streamSimple generator first validates the shapes it is handed (model,
    // context, options with an abort signal), then yields a minimal event
    // sequence: start -> text_start -> text_delta("hi") -> done.
    const STREAM_SIMPLE_EXTENSION: &str = r#"
export default function init(pi) {
  pi.registerProvider("stream-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "stream-model", name: "Stream Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      if (!model || !model.baseUrl || !model.maxTokens || !model.contextWindow) {
        throw new Error("bad model shape");
      }
      if (!context || !Array.isArray(context.messages)) {
        throw new Error("bad context shape");
      }
      if (!options || !options.signal) {
        throw new Error("missing abort signal");
      }

      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };
      partial.content[0].text += "hi";
      yield { type: "text_delta", contentIndex: 0, delta: "hi", partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;

    // JS extension fixture: registers "cancel-provider". Its streamSimple
    // yields one start event, then awaits the abort signal; the finally block
    // writes cancelled.txt via the `write` tool so tests can observe that
    // dropping the Rust stream triggers JS-side cancellation cleanup.
    const STREAM_SIMPLE_CANCEL_EXTENSION: &str = r#"
export default function init(pi) {
  pi.registerProvider("cancel-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "cancel-model", name: "Cancel Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      try {
        yield { type: "start", partial };
        await new Promise((resolve) => {
          if (options && options.signal && options.signal.aborted) return resolve();
          if (options && options.signal && typeof options.signal.addEventListener === "function") {
            options.signal.addEventListener("abort", () => resolve());
          }
        });
      } finally {
        await pi.tool("write", { path: "cancelled.txt", content: "ok" });
      }
    }
  });
}
"#;
1310
    /// Writes `source` as `ext.mjs` into a fresh temp dir, boots a JS
    /// extension runtime rooted there, and loads the extension into a new
    /// `ExtensionManager`. `allow_write` controls whether the tool registry
    /// exposes the `write` tool (needed by the cancellation fixture).
    /// Returns the temp dir (caller must keep it alive) and the manager.
    async fn load_extension(
        source: &str,
        allow_write: bool,
    ) -> (tempfile::TempDir, ExtensionManager) {
        let dir = tempdir().expect("tempdir");
        let entry_path = dir.path().join("ext.mjs");
        std::fs::write(&entry_path, source).expect("write extension");

        let manager = ExtensionManager::new();
        let tools = if allow_write {
            Arc::new(ToolRegistry::new(&["write"], dir.path(), None))
        } else {
            Arc::new(ToolRegistry::new(&[], dir.path(), None))
        };

        // Run the JS runtime with cwd inside the temp dir so tool side
        // effects (e.g. cancelled.txt) land somewhere we can inspect.
        let js_runtime = JsExtensionRuntimeHandle::start(
            PiJsRuntimeConfig {
                cwd: dir.path().display().to_string(),
                ..Default::default()
            },
            Arc::clone(&tools),
            manager.clone(),
        )
        .await
        .expect("start js runtime");
        manager.set_js_runtime(js_runtime);

        let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("load spec");
        manager
            .load_js_extensions(vec![spec])
            .await
            .expect("load extension");

        (dir, manager)
    }

    /// Minimal streaming context: one system prompt and one user message.
    fn basic_context() -> Context<'static> {
        Context {
            system_prompt: Some("system".to_string().into()),
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("hello".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        }
    }

    /// Default stream options with a dummy API key set.
    fn basic_options() -> StreamOptions {
        StreamOptions {
            api_key: Some("sk-test".to_string()),
            ..Default::default()
        }
    }
1365
    // Happy path: a streamSimple JS provider surfaces Start, TextDelta, and
    // Done events through the Rust Provider interface with the expected text.
    #[test]
    fn extension_stream_simple_provider_emits_assistant_events() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            assert_eq!(entries.len(), 1);
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            assert_eq!(provider.name(), "stream-provider");

            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_start = false;
            let mut saw_text_delta = false;
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::Start { .. } => {
                        saw_start = true;
                    }
                    StreamEvent::TextDelta { delta, .. } => {
                        assert_eq!(delta, "hi");
                        saw_text_delta = true;
                    }
                    StreamEvent::Done { reason, message } => {
                        assert_eq!(reason, StopReason::Stop);
                        let text = match &message.content[0] {
                            ContentBlock::Text(text) => text,
                            other => unreachable!("expected text content block, got {other:?}"),
                        };
                        assert_eq!(text.text, "hi");
                        break;
                    }
                    _ => {}
                }
            }

            assert!(saw_start, "expected a Start event");
            assert!(saw_text_delta, "expected a TextDelta event");
        });
    }

    // Cancellation path: dropping the Rust stream must abort the JS generator,
    // whose finally block writes cancelled.txt; we poll for that file.
    #[test]
    fn extension_stream_simple_provider_drop_cancels_js_stream() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (dir, manager) = load_extension(STREAM_SIMPLE_CANCEL_EXTENSION, true).await;
            let entries = manager.extension_model_entries();
            assert_eq!(entries.len(), 1);
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "cancel-provider")
                .expect("cancel-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            // Consume one event so the JS generator is definitely running,
            // then drop the stream to trigger cancellation.
            let first = stream.next().await.expect("first event");
            let _ = first.expect("first event ok");
            drop(stream);

            // Poll (up to ~1s) for the JS finally-block side effect.
            let out_path = dir.path().join("cancelled.txt");
            for _ in 0..200 {
                if out_path.exists() {
                    let contents = std::fs::read_to_string(&out_path).expect("read cancelled.txt");
                    assert_eq!(contents, "ok");
                    return;
                }
                sleep(wall_now(), Duration::from_millis(5)).await;
            }

            assert!(
                out_path.exists(),
                "expected cancelled.txt to be created after stream drop/cancel"
            );
        });
    }
1458
    // JS extension fixture: "multi-chunk-provider" streams "Hello, world!"
    // as four text_delta chunks, followed by text_end and done events.
    const STREAM_SIMPLE_MULTI_CHUNK: &str = r#"
export default function init(pi) {
  pi.registerProvider("multi-chunk-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "multi-model", name: "Multi Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };

      const chunks = ["Hello", ", ", "world", "!"];
      for (const chunk of chunks) {
        partial.content[0].text += chunk;
        yield { type: "text_delta", contentIndex: 0, delta: chunk, partial };
      }

      yield { type: "text_end", contentIndex: 0, content: partial.content[0].text, partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;

    // JS extension fixture: "error-provider" yields a start event and then
    // throws, so tests can check that JS exceptions propagate to the stream.
    const STREAM_SIMPLE_ERROR: &str = r#"
export default function init(pi) {
  pi.registerProvider("error-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "error-model", name: "Error Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      throw new Error("simulated JS error during streaming");
    }
  });
}
"#;

    // JS extension fixture: "unicode-provider" emits a single text_delta
    // containing multi-byte characters (Japanese text plus an emoji).
    const STREAM_SIMPLE_UNICODE: &str = r#"
export default function init(pi) {
  pi.registerProvider("unicode-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "unicode-model", name: "Unicode Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };
      partial.content[0].text = "日本語テスト 🦀";
      yield { type: "text_delta", contentIndex: 0, delta: "日本語テスト 🦀", partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;
1558
    // Deltas must arrive in order and the final message must contain their
    // concatenation.
    #[test]
    fn extension_stream_simple_multiple_chunks_in_order() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_MULTI_CHUNK, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "multi-chunk-provider")
                .expect("multi-chunk-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut deltas = Vec::new();
            let mut final_text = String::new();
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::TextDelta { delta, .. } => {
                        deltas.push(delta);
                    }
                    StreamEvent::Done { message, .. } => {
                        let text = match &message.content[0] {
                            ContentBlock::Text(text) => text,
                            other => unreachable!("expected text content block, got {other:?}"),
                        };
                        final_text = text.text.clone();
                        break;
                    }
                    _ => {}
                }
            }

            assert_eq!(deltas, vec!["Hello", ", ", "world", "!"]);
            assert_eq!(final_text, "Hello, world!");
        });
    }

    // A throw inside the JS generator must surface either as an Err item or
    // as a StreamEvent::Error, after the initial Start event.
    #[test]
    fn extension_stream_simple_js_error_propagates() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_ERROR, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "error-provider")
                .expect("error-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_start = false;
            let mut saw_error = false;
            while let Some(item) = stream.next().await {
                match item {
                    Ok(StreamEvent::Start { .. }) => {
                        saw_start = true;
                    }
                    Err(err) => {
                        let msg = err.to_string();
                        assert!(
                            msg.contains("simulated JS error") || msg.contains("error"),
                            "expected JS error message, got: {msg}"
                        );
                        saw_error = true;
                        break;
                    }
                    Ok(StreamEvent::Error { .. }) => {
                        saw_error = true;
                        break;
                    }
                    _ => {}
                }
            }

            assert!(saw_start, "expected a Start event before error");
            assert!(saw_error, "expected JS error to propagate");
        });
    }
1651
    // Multi-byte text (Japanese + emoji) must cross the JS/Rust boundary intact.
    #[test]
    fn extension_stream_simple_unicode_content() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_UNICODE, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "unicode-provider")
                .expect("unicode-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_unicode = false;
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::TextDelta { delta, .. } => {
                        assert_eq!(delta, "日本語テスト 🦀");
                        saw_unicode = true;
                    }
                    StreamEvent::Done { .. } => break,
                    _ => {}
                }
            }

            assert!(saw_unicode, "expected unicode text delta");
        });
    }

    // Provider metadata (name, model id, api) must come from the JS registration.
    #[test]
    fn extension_stream_simple_provider_name_and_model() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            assert_eq!(provider.name(), "stream-provider");
            assert_eq!(provider.model_id(), "stream-model");
            assert_eq!(provider.api(), "custom-api");
        });
    }

    // create_provider needs the ExtensionManager for extension-backed models:
    // Some(&manager) succeeds, None must fail.
    #[test]
    fn create_provider_returns_extension_provider_for_stream_simple() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager));
            assert!(provider.is_ok());

            let provider_no_ext = create_provider(entry, None);
            assert!(provider_no_ext.is_err());
        });
    }
1732
1733 use crate::models::ModelEntry;
1738 use crate::provider::{InputType, Model, ModelCost};
1739 use std::collections::HashMap;
1740
    /// Builds a `ModelEntry` test fixture for the given provider/api/model/base
    /// URL, with fixed cost figures, a 200k context window, an inline test API
    /// key, and no extra headers or compat/oauth configuration.
    fn model_entry(provider: &str, api: &str, model_id: &str, base_url: &str) -> ModelEntry {
        ModelEntry {
            model: Model {
                id: model_id.to_string(),
                name: model_id.to_string(),
                api: api.to_string(),
                provider: provider.to_string(),
                base_url: base_url.to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 3.0,
                    output: 15.0,
                    cache_read: 0.3,
                    cache_write: 3.75,
                },
                context_window: 200_000,
                max_tokens: 8192,
                headers: HashMap::new(),
            },
            api_key: Some("sk-test-key".to_string()),
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }
1768
    // Alias provider ids (e.g. "kimi") must resolve via metadata to their
    // canonical provider and routed API.
    #[test]
    fn resolve_provider_route_uses_metadata_for_alias_provider() {
        let entry = model_entry(
            "kimi",
            "openai-completions",
            "kimi-k2-instruct",
            "https://api.moonshot.ai/v1",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve alias route");
        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
        assert_eq!(canonical_provider, "moonshotai");
        assert_eq!(effective_api, "openai-completions");
    }

    // An unrecognized api string on the openai provider falls back to the
    // native Responses route.
    #[test]
    fn resolve_provider_route_openai_unknown_api_defaults_to_native_responses() {
        let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve openai route");
        assert_eq!(route, ProviderRouteKind::NativeOpenAIResponses);
        assert_eq!(canonical_provider, "openai");
        assert_eq!(effective_api, "openai");
    }

    // Empty api on cloudflare-workers-ai defaults to openai-completions.
    #[test]
    fn resolve_provider_route_cloudflare_workers_defaults_to_openai_completions() {
        let entry = model_entry(
            "cloudflare-workers-ai",
            "",
            "@cf/meta/llama-3.1-8b-instruct",
            "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve cloudflare workers route");
        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
        assert_eq!(canonical_provider, "cloudflare-workers-ai");
        assert_eq!(effective_api, "openai-completions");
    }

    // Empty api on cloudflare-ai-gateway also defaults to openai-completions.
    #[test]
    fn resolve_provider_route_cloudflare_gateway_defaults_to_openai_completions() {
        let entry = model_entry(
            "cloudflare-ai-gateway",
            "",
            "gpt-4o-mini",
            "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve cloudflare gateway route");
        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
        assert_eq!(canonical_provider, "cloudflare-ai-gateway");
        assert_eq!(effective_api, "openai-completions");
    }

    // The azure-cognitive-services alias routes natively to Azure and
    // canonicalizes to "azure-openai".
    #[test]
    fn resolve_provider_route_uses_native_azure_route_for_cognitive_alias() {
        let entry = model_entry(
            "azure-cognitive-services",
            "openai-completions",
            "gpt-4o-mini",
            "https://myresource.cognitiveservices.azure.com",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve azure cognitive route");
        assert_eq!(route, ProviderRouteKind::NativeAzure);
        assert_eq!(canonical_provider, "azure-openai");
        assert_eq!(effective_api, "openai-completions");
    }

    // The legacy "azure-openai-responses" provider id keeps its api string
    // but canonicalizes to "azure-openai" with the native Azure route.
    #[test]
    fn resolve_provider_route_uses_native_azure_route_for_legacy_provider_alias() {
        let entry = model_entry(
            "azure-openai-responses",
            "azure-openai-responses",
            "gpt-4o-mini",
            "https://myresource.openai.azure.com",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve azure legacy alias route");
        assert_eq!(route, ProviderRouteKind::NativeAzure);
        assert_eq!(canonical_provider, "azure-openai");
        assert_eq!(effective_api, "azure-openai-responses");
    }

    // A custom provider id with the legacy azure api string still gets the
    // native Azure route, keeping its own provider id.
    #[test]
    fn resolve_provider_route_accepts_azure_legacy_api_for_custom_provider_id() {
        let entry = model_entry(
            "my-azure",
            "azure-openai-responses",
            "gpt-4o-mini",
            "https://example.invalid",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve azure legacy api fallback");
        assert_eq!(route, ProviderRouteKind::NativeAzure);
        assert_eq!(canonical_provider, "my-azure");
        assert_eq!(effective_api, "azure-openai-responses");
    }
1868
    // Token precedence: an inline api_key on the entry wins over any env var.
    #[test]
    fn resolve_copilot_token_prefers_inline_model_api_key() {
        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
        entry.api_key = Some("inline-copilot-token".to_string());

        let token = resolve_copilot_token_with_env(&entry, |_| None)
            .expect("inline token should be accepted");
        assert_eq!(token, "inline-copilot-token");
    }

    // Without an inline key, GITHUB_COPILOT_API_KEY from the env is used.
    #[test]
    fn resolve_copilot_token_falls_back_to_env() {
        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
        entry.api_key = None;

        let token = resolve_copilot_token_with_env(&entry, |name| match name {
            "GITHUB_COPILOT_API_KEY" => Some("env-copilot-token".to_string()),
            _ => None,
        })
        .expect("env token should be accepted");
        assert_eq!(token, "env-copilot-token");
    }

    // No inline key and no env var must produce a descriptive error.
    #[test]
    fn resolve_copilot_token_errors_when_missing_everywhere() {
        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
        entry.api_key = None;

        let err = resolve_copilot_token_with_env(&entry, |_| None).expect_err("expected error");
        assert!(
            err.to_string().contains("GitHub Copilot requires"),
            "unexpected error: {err}"
        );
    }
1903
    // Prefix matches ("deep" -> deepinfra/deepseek) should be suggested.
    #[test]
    fn suggest_similar_providers_finds_prefix_match() {
        let suggestions = suggest_similar_providers("deep");
        assert!(
            suggestions.contains(&"deepinfra".to_string())
                || suggestions.contains(&"deepseek".to_string()),
            "expected deepinfra or deepseek in suggestions: {suggestions:?}"
        );
    }

    // Substring matches ("flow" inside siliconflow) should be suggested.
    #[test]
    fn suggest_similar_providers_finds_substring_match() {
        let suggestions = suggest_similar_providers("flow");
        assert!(
            suggestions.contains(&"siliconflow".to_string()),
            "expected siliconflow in suggestions: {suggestions:?}"
        );
    }

    // Nonsense input should produce no suggestions at all.
    #[test]
    fn suggest_similar_providers_returns_empty_for_gibberish() {
        let suggestions = suggest_similar_providers("xyzzzabc123");
        assert!(
            suggestions.is_empty(),
            "expected no suggestions for gibberish: {suggestions:?}"
        );
    }

    // Even very broad input ("a") must never yield more than three suggestions.
    #[test]
    fn suggest_similar_providers_caps_at_three() {
        let suggestions = suggest_similar_providers("a");
        assert!(
            suggestions.len() <= 3,
            "expected at most 3 suggestions: {suggestions:?}"
        );
    }

    // Sanity checks on the Levenshtein helper, including the classic
    // kitten/sitting distance of 3.
    #[test]
    fn edit_distance_basic_cases() {
        assert_eq!(edit_distance(b"", b""), 0);
        assert_eq!(edit_distance(b"abc", b"abc"), 0);
        assert_eq!(edit_distance(b"abc", b"ab"), 1);
        assert_eq!(edit_distance(b"abc", b"axc"), 1);
        assert_eq!(edit_distance(b"abc", b"abcd"), 1);
        assert_eq!(edit_distance(b"kitten", b"sitting"), 3);
        assert_eq!(edit_distance(b"", b"hello"), 5);
    }

    // Trailing-character typo: "anthropick" should suggest anthropic.
    #[test]
    fn suggest_similar_providers_finds_typo_with_edit_distance() {
        let suggestions = suggest_similar_providers("anthropick");
        assert!(
            suggestions.contains(&"anthropic".to_string()),
            "expected anthropic for typo 'anthropick': {suggestions:?}"
        );
    }

    // Missing-character typo: "opnai" should suggest openai.
    #[test]
    fn suggest_similar_providers_finds_typo_missing_char() {
        let suggestions = suggest_similar_providers("opnai");
        assert!(
            suggestions.contains(&"openai".to_string()),
            "expected openai for typo 'opnai': {suggestions:?}"
        );
    }

    // Dropped/transposed character: "gogle" should suggest google.
    #[test]
    fn suggest_similar_providers_finds_transposed_chars() {
        let suggestions = suggest_similar_providers("gogle");
        assert!(
            suggestions.contains(&"google".to_string()),
            "expected google for typo 'gogle': {suggestions:?}"
        );
    }

    // Very short input must not fuzzily match anything.
    #[test]
    fn suggest_similar_providers_no_false_positives_for_short_input() {
        let suggestions = suggest_similar_providers("xy");
        assert!(
            suggestions.is_empty(),
            "expected no suggestions for 'xy': {suggestions:?}"
        );
    }
1991
    // A bare *.openai.azure.com base URL: resource comes from the host, the
    // deployment from the model id, and the default preview api-version is used.
    #[test]
    fn resolve_azure_provider_runtime_supports_openai_host() {
        let entry = model_entry(
            "azure-openai",
            "openai-completions",
            "gpt-4o",
            "https://myresource.openai.azure.com",
        );
        let runtime =
            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
        assert_eq!(runtime.resource, "myresource");
        assert_eq!(runtime.deployment, "gpt-4o");
        assert_eq!(runtime.api_version, "2024-12-01-preview");
        assert_eq!(
            runtime.endpoint_url,
            "https://myresource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-12-01-preview"
        );
    }

    // A full cognitiveservices.azure.com URL: deployment and api-version are
    // parsed out of the base URL itself.
    #[test]
    fn resolve_azure_provider_runtime_supports_cognitive_services_host() {
        let entry = model_entry(
            "azure-cognitive-services",
            "openai-completions",
            "gpt-4o-mini",
            "https://myresource.cognitiveservices.azure.com/openai/deployments/custom/chat/completions?api-version=2024-10-21",
        );
        let runtime =
            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
        assert_eq!(runtime.resource, "myresource");
        assert_eq!(runtime.deployment, "custom");
        assert_eq!(runtime.api_version, "2024-10-21");
        assert_eq!(
            runtime.endpoint_url,
            "https://myresource.cognitiveservices.azure.com/openai/deployments/custom/chat/completions?api-version=2024-10-21"
        );
    }

    // Deployment precedence: a deployment embedded in the base URL beats the
    // model id fallback.
    #[test]
    fn resolve_azure_provider_runtime_prefers_base_url_deployment_over_model_id() {
        let entry = model_entry(
            "azure-openai",
            "openai-completions",
            "model-fallback",
            "https://myresource.openai.azure.com/openai/deployments/base-deploy/chat/completions?api-version=2024-10-21",
        );
        let runtime =
            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
        assert_eq!(runtime.resource, "myresource");
        assert_eq!(runtime.deployment, "base-deploy");
        assert_eq!(runtime.api_version, "2024-10-21");
        assert_eq!(
            runtime.endpoint_url,
            "https://myresource.openai.azure.com/openai/deployments/base-deploy/chat/completions?api-version=2024-10-21"
        );
    }

    // Deployment precedence: the AZURE_OPENAI_DEPLOYMENT_ENV variable beats
    // both the base-URL deployment and the model id.
    #[test]
    fn resolve_azure_provider_runtime_env_deployment_overrides_base_url_and_model_id() {
        let entry = model_entry(
            "azure-openai",
            "openai-completions",
            "model-fallback",
            "https://myresource.openai.azure.com/openai/deployments/base-deploy/chat/completions?api-version=2024-10-21",
        );
        let runtime = resolve_azure_provider_runtime_with_env(&entry, |name| match name {
            AZURE_OPENAI_DEPLOYMENT_ENV => Some("env-deploy".to_string()),
            _ => None,
        })
        .expect("resolve runtime");
        assert_eq!(runtime.resource, "myresource");
        assert_eq!(runtime.deployment, "env-deploy");
        assert_eq!(runtime.api_version, "2024-10-21");
        assert_eq!(
            runtime.endpoint_url,
            "https://myresource.openai.azure.com/openai/deployments/env-deploy/chat/completions?api-version=2024-10-21"
        );
    }
2070
2071 #[test]
2074 fn create_provider_anthropic_by_name() {
2075 let entry = model_entry(
2076 "anthropic",
2077 "anthropic-messages",
2078 "claude-sonnet-4-5",
2079 "https://api.anthropic.com",
2080 );
2081 let provider = create_provider(&entry, None).expect("anthropic provider");
2082 assert_eq!(provider.name(), "anthropic");
2083 assert_eq!(provider.model_id(), "claude-sonnet-4-5");
2084 assert_eq!(provider.api(), "anthropic-messages");
2085 }
2086
2087 #[test]
2088 fn create_provider_openai_completions_by_name() {
2089 let entry = model_entry(
2090 "openai",
2091 "openai-completions",
2092 "gpt-4o",
2093 "https://api.openai.com/v1",
2094 );
2095 let provider = create_provider(&entry, None).expect("openai completions provider");
2096 assert_eq!(provider.name(), "openai");
2097 assert_eq!(provider.model_id(), "gpt-4o");
2098 }
2099
2100 #[test]
2101 fn create_provider_openai_responses_by_name() {
2102 let entry = model_entry(
2103 "openai",
2104 "openai-responses",
2105 "gpt-4o",
2106 "https://api.openai.com/v1",
2107 );
2108 let provider = create_provider(&entry, None).expect("openai responses provider");
2109 assert_eq!(provider.name(), "openai");
2110 assert_eq!(provider.model_id(), "gpt-4o");
2111 }
2112
2113 #[test]
2114 fn create_provider_openai_defaults_to_responses() {
2115 let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
2117 let provider = create_provider(&entry, None).expect("openai default responses provider");
2118 assert_eq!(provider.name(), "openai");
2119 }
2120
2121 #[test]
2122 fn create_provider_google_by_name() {
2123 let entry = model_entry(
2124 "google",
2125 "google-generative-ai",
2126 "gemini-2.0-flash",
2127 "https://generativelanguage.googleapis.com",
2128 );
2129 let provider = create_provider(&entry, None).expect("google provider");
2130 assert_eq!(provider.name(), "google");
2131 assert_eq!(provider.model_id(), "gemini-2.0-flash");
2132 }
2133
2134 #[test]
2135 fn create_provider_cohere_by_name() {
2136 let entry = model_entry(
2137 "cohere",
2138 "cohere-chat",
2139 "command-r-plus",
2140 "https://api.cohere.com/v2",
2141 );
2142 let provider = create_provider(&entry, None).expect("cohere provider");
2143 assert_eq!(provider.name(), "cohere");
2144 assert_eq!(provider.model_id(), "command-r-plus");
2145 }
2146
2147 #[test]
2148 fn create_provider_azure_openai_by_name() {
2149 let entry = model_entry(
2150 "azure-openai",
2151 "openai-completions",
2152 "gpt-4o",
2153 "https://myresource.openai.azure.com",
2154 );
2155 let provider = create_provider(&entry, None).expect("azure provider");
2156 assert_eq!(provider.name(), "azure-openai");
2157 assert_eq!(provider.api(), "azure-openai");
2158 assert!(!provider.model_id().is_empty());
2159 }
2160
2161 #[test]
2162 fn create_provider_azure_cognitive_services_alias_by_name() {
2163 let entry = model_entry(
2164 "azure-cognitive-services",
2165 "openai-completions",
2166 "gpt-4o-mini",
2167 "https://myresource.cognitiveservices.azure.com",
2168 );
2169 let provider = create_provider(&entry, None).expect("azure cognitive provider");
2170 assert_eq!(provider.name(), "azure-cognitive-services");
2171 assert_eq!(provider.api(), "azure-openai");
2172 assert!(!provider.model_id().is_empty());
2173 }
2174
2175 #[test]
2176 fn create_provider_cloudflare_workers_ai_by_name() {
2177 let entry = model_entry(
2178 "cloudflare-workers-ai",
2179 "",
2180 "@cf/meta/llama-3.1-8b-instruct",
2181 "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
2182 );
2183 let provider = create_provider(&entry, None).expect("cloudflare workers provider");
2184 assert_eq!(provider.name(), "cloudflare-workers-ai");
2185 assert_eq!(provider.api(), "openai-completions");
2186 assert_eq!(provider.model_id(), "@cf/meta/llama-3.1-8b-instruct");
2187 }
2188
2189 #[test]
2190 fn create_provider_cloudflare_ai_gateway_by_name() {
2191 let entry = model_entry(
2192 "cloudflare-ai-gateway",
2193 "",
2194 "gpt-4o-mini",
2195 "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
2196 );
2197 let provider = create_provider(&entry, None).expect("cloudflare gateway provider");
2198 assert_eq!(provider.name(), "cloudflare-ai-gateway");
2199 assert_eq!(provider.api(), "openai-completions");
2200 assert_eq!(provider.model_id(), "gpt-4o-mini");
2201 }
2202
2203 #[test]
2206 fn create_provider_falls_back_to_api_anthropic_messages() {
2207 let entry = model_entry(
2208 "custom-anthropic",
2209 "anthropic-messages",
2210 "my-model",
2211 "https://custom.api.com",
2212 );
2213 let provider = create_provider(&entry, None).expect("fallback anthropic provider");
2214 assert_eq!(provider.model_id(), "my-model");
2216 }
2217
2218 #[test]
2219 fn create_provider_falls_back_to_api_openai_completions() {
2220 let entry = model_entry(
2221 "my-openai-compat",
2222 "openai-completions",
2223 "local-model",
2224 "http://localhost:8080/v1",
2225 );
2226 let provider = create_provider(&entry, None).expect("fallback openai completions");
2227 assert_eq!(provider.model_id(), "local-model");
2228 }
2229
2230 #[test]
2231 fn create_provider_falls_back_to_api_openai_responses() {
2232 let entry = model_entry(
2233 "my-openai-compat",
2234 "openai-responses",
2235 "local-model",
2236 "http://localhost:8080/v1",
2237 );
2238 let provider = create_provider(&entry, None).expect("fallback openai responses");
2239 assert_eq!(provider.model_id(), "local-model");
2240 }
2241
2242 #[test]
2243 fn create_provider_falls_back_to_api_cohere_chat() {
2244 let entry = model_entry(
2245 "custom-cohere",
2246 "cohere-chat",
2247 "custom-r",
2248 "https://custom-cohere.api.com/v2",
2249 );
2250 let provider = create_provider(&entry, None).expect("fallback cohere provider");
2251 assert_eq!(provider.model_id(), "custom-r");
2252 }
2253
2254 #[test]
2255 fn create_provider_falls_back_to_api_google() {
2256 let entry = model_entry(
2257 "custom-google",
2258 "google-generative-ai",
2259 "custom-gemini",
2260 "https://custom.google.com",
2261 );
2262 let provider = create_provider(&entry, None).expect("fallback google provider");
2263 assert_eq!(provider.model_id(), "custom-gemini");
2264 }
2265
2266 #[test]
2267 fn resolve_provider_route_copilot_routes_correctly() {
2268 let entry = model_entry("github-copilot", "", "gpt-4o", "");
2269 let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot route");
2270 assert_eq!(route, ProviderRouteKind::NativeCopilot);
2271 assert_eq!(canonical, "github-copilot");
2272 }
2273
2274 #[test]
2275 fn resolve_provider_route_copilot_alias_routes_correctly() {
2276 let entry = model_entry("copilot", "", "gpt-4o", "");
2277 let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot alias route");
2278 assert_eq!(route, ProviderRouteKind::NativeCopilot);
2279 assert_eq!(canonical, "github-copilot");
2280 }
2281
2282 #[test]
2283 fn create_provider_unknown_provider_and_api_returns_error() {
2284 let entry = model_entry(
2285 "totally-unknown",
2286 "unknown-api",
2287 "some-model",
2288 "https://example.com",
2289 );
2290 let Err(err) = create_provider(&entry, None) else {
2291 panic!();
2292 };
2293 let msg = err.to_string();
2294 assert!(
2295 msg.contains("not implemented"),
2296 "expected 'not implemented' message, got: {msg}"
2297 );
2298 }
2299
2300 #[test]
2303 fn normalize_anthropic_base_appends_v1_messages() {
2304 assert_eq!(
2305 normalize_anthropic_base("https://api.anthropic.com"),
2306 "https://api.anthropic.com/v1/messages"
2307 );
2308 }
2309
2310 #[test]
2311 fn normalize_anthropic_base_keeps_existing_v1_messages() {
2312 assert_eq!(
2313 normalize_anthropic_base("https://api.anthropic.com/v1/messages"),
2314 "https://api.anthropic.com/v1/messages"
2315 );
2316 }
2317
2318 #[test]
2319 fn normalize_anthropic_base_strips_trailing_slash() {
2320 assert_eq!(
2321 normalize_anthropic_base("https://api.anthropic.com/"),
2322 "https://api.anthropic.com/v1/messages"
2323 );
2324 }
2325
2326 #[test]
2327 fn normalize_anthropic_base_empty_uses_default() {
2328 assert_eq!(
2329 normalize_anthropic_base(" "),
2330 "https://api.anthropic.com/v1/messages"
2331 );
2332 }
2333
2334 #[test]
2335 fn normalize_anthropic_base_preserves_query_and_fragment() {
2336 assert_eq!(
2337 normalize_anthropic_base("https://api.anthropic.com/?via=proxy#frag"),
2338 "https://api.anthropic.com/v1/messages?via=proxy#frag"
2339 );
2340 }
2341
2342 #[test]
2343 fn normalize_anthropic_base_handles_opaque_url_fallback() {
2344 assert_eq!(
2345 normalize_anthropic_base("data:text/plain,hello"),
2346 "data:text/plain,hello/v1/messages"
2347 );
2348 }
2349
2350 #[test]
2353 fn normalize_openai_base_appends_chat_completions_to_v1() {
2354 assert_eq!(
2355 normalize_openai_base("https://api.openai.com/v1"),
2356 "https://api.openai.com/v1/chat/completions"
2357 );
2358 }
2359
2360 #[test]
2361 fn normalize_openai_base_keeps_existing_chat_completions() {
2362 assert_eq!(
2363 normalize_openai_base("https://api.openai.com/v1/chat/completions"),
2364 "https://api.openai.com/v1/chat/completions"
2365 );
2366 }
2367
2368 #[test]
2369 fn normalize_openai_base_strips_trailing_slash() {
2370 assert_eq!(
2371 normalize_openai_base("https://api.openai.com/v1/"),
2372 "https://api.openai.com/v1/chat/completions"
2373 );
2374 }
2375
2376 #[test]
2377 fn normalize_openai_base_strips_responses_suffix() {
2378 assert_eq!(
2379 normalize_openai_base("https://api.openai.com/v1/responses"),
2380 "https://api.openai.com/v1/chat/completions"
2381 );
2382 }
2383
2384 #[test]
2385 fn normalize_openai_base_official_bare_url_gets_v1_chat_completions() {
2386 assert_eq!(
2387 normalize_openai_base("https://api.openai.com"),
2388 "https://api.openai.com/v1/chat/completions"
2389 );
2390 }
2391
2392 #[test]
2393 fn normalize_openai_base_official_default_port_gets_v1_chat_completions() {
2394 assert_eq!(
2395 normalize_openai_base("https://api.openai.com:443"),
2396 "https://api.openai.com/v1/chat/completions"
2397 );
2398 }
2399
2400 #[test]
2401 fn normalize_openai_base_strips_non_v1_official_responses_suffix() {
2402 assert_eq!(
2403 normalize_openai_base("https://api.openai.com/responses"),
2404 "https://api.openai.com/v1/chat/completions"
2405 );
2406 }
2407
2408 #[test]
2409 fn normalize_openai_base_custom_bare_url_gets_chat_completions() {
2410 assert_eq!(
2411 normalize_openai_base("https://my-llm-proxy.com"),
2412 "https://my-llm-proxy.com/chat/completions"
2413 );
2414 }
2415
2416 #[test]
2417 fn normalize_openai_base_preserves_query_and_fragment_on_official_origin() {
2418 assert_eq!(
2419 normalize_openai_base("https://api.openai.com:443/?via=proxy#frag"),
2420 "https://api.openai.com/v1/chat/completions?via=proxy#frag"
2421 );
2422 }
2423
2424 #[test]
2425 fn normalize_openai_base_empty_uses_default() {
2426 assert_eq!(
2427 normalize_openai_base(""),
2428 "https://api.openai.com/v1/chat/completions"
2429 );
2430 }
2431
2432 #[test]
2433 fn normalize_openai_base_handles_opaque_url_fallback() {
2434 assert_eq!(
2435 normalize_openai_base("data:text/plain,hello"),
2436 "data:text/plain,hello/chat/completions"
2437 );
2438 }
2439
2440 #[test]
2443 fn normalize_responses_appends_responses_to_v1() {
2444 assert_eq!(
2445 normalize_openai_responses_base("https://api.openai.com/v1"),
2446 "https://api.openai.com/v1/responses"
2447 );
2448 }
2449
2450 #[test]
2451 fn normalize_responses_keeps_existing_responses() {
2452 assert_eq!(
2453 normalize_openai_responses_base("https://api.openai.com/v1/responses"),
2454 "https://api.openai.com/v1/responses"
2455 );
2456 }
2457
2458 #[test]
2459 fn normalize_responses_strips_trailing_slash() {
2460 assert_eq!(
2461 normalize_openai_responses_base("https://api.openai.com/v1/"),
2462 "https://api.openai.com/v1/responses"
2463 );
2464 }
2465
2466 #[test]
2467 fn normalize_responses_strips_chat_completions_suffix() {
2468 assert_eq!(
2469 normalize_openai_responses_base("https://api.openai.com/v1/chat/completions"),
2470 "https://api.openai.com/v1/responses"
2471 );
2472 }
2473
2474 #[test]
2475 fn normalize_responses_official_bare_url_gets_v1_responses() {
2476 assert_eq!(
2477 normalize_openai_responses_base("https://api.openai.com"),
2478 "https://api.openai.com/v1/responses"
2479 );
2480 }
2481
2482 #[test]
2483 fn normalize_responses_official_default_port_gets_v1_responses() {
2484 assert_eq!(
2485 normalize_openai_responses_base("https://api.openai.com:443"),
2486 "https://api.openai.com/v1/responses"
2487 );
2488 }
2489
2490 #[test]
2491 fn normalize_responses_strips_non_v1_official_chat_completions_suffix() {
2492 assert_eq!(
2493 normalize_openai_responses_base("https://api.openai.com/chat/completions"),
2494 "https://api.openai.com/v1/responses"
2495 );
2496 }
2497
2498 #[test]
2499 fn normalize_responses_custom_bare_url_gets_responses() {
2500 assert_eq!(
2501 normalize_openai_responses_base("https://my-llm-proxy.com"),
2502 "https://my-llm-proxy.com/responses"
2503 );
2504 }
2505
2506 #[test]
2507 fn normalize_responses_preserves_query_and_fragment() {
2508 assert_eq!(
2509 normalize_openai_responses_base("https://my-llm-proxy.com/api?via=proxy#frag"),
2510 "https://my-llm-proxy.com/api/responses?via=proxy#frag"
2511 );
2512 }
2513
2514 #[test]
2515 fn normalize_responses_preserves_query_and_fragment_on_official_origin() {
2516 assert_eq!(
2517 normalize_openai_responses_base("https://api.openai.com:443/?via=proxy#frag"),
2518 "https://api.openai.com/v1/responses?via=proxy#frag"
2519 );
2520 }
2521
2522 #[test]
2523 fn normalize_responses_base_empty_uses_default() {
2524 assert_eq!(
2525 normalize_openai_responses_base(" "),
2526 "https://api.openai.com/v1/responses"
2527 );
2528 }
2529
2530 #[test]
2531 fn normalize_responses_base_handles_opaque_url_fallback() {
2532 assert_eq!(
2533 normalize_openai_responses_base("data:text/plain,hello"),
2534 "data:text/plain,hello/responses"
2535 );
2536 }
2537
2538 #[test]
2541 fn normalize_codex_responses_base_empty_uses_default() {
2542 assert_eq!(
2543 normalize_openai_codex_responses_base(""),
2544 openai_responses::CODEX_RESPONSES_API_URL
2545 );
2546 }
2547
2548 #[test]
2549 fn normalize_codex_responses_base_keeps_existing_suffix() {
2550 assert_eq!(
2551 normalize_openai_codex_responses_base(
2552 "https://chatgpt.com/backend-api/codex/responses"
2553 ),
2554 "https://chatgpt.com/backend-api/codex/responses"
2555 );
2556 }
2557
2558 #[test]
2559 fn normalize_codex_responses_base_appends_suffix_from_backend_api() {
2560 assert_eq!(
2561 normalize_openai_codex_responses_base("https://chatgpt.com/backend-api"),
2562 "https://chatgpt.com/backend-api/codex/responses"
2563 );
2564 }
2565
2566 #[test]
2567 fn normalize_codex_responses_base_preserves_query_and_fragment() {
2568 assert_eq!(
2569 normalize_openai_codex_responses_base("https://chatgpt.com/backend-api?via=proxy#frag"),
2570 "https://chatgpt.com/backend-api/codex/responses?via=proxy#frag"
2571 );
2572 }
2573
2574 #[test]
2575 fn normalize_codex_responses_base_handles_opaque_url_fallback() {
2576 assert_eq!(
2577 normalize_openai_codex_responses_base("data:text/plain,hello"),
2578 "data:text/plain,hello/backend-api/codex/responses"
2579 );
2580 }
2581
2582 #[test]
2585 fn normalize_cohere_appends_chat_to_v2() {
2586 assert_eq!(
2587 normalize_cohere_base("https://api.cohere.com/v2"),
2588 "https://api.cohere.com/v2/chat"
2589 );
2590 }
2591
2592 #[test]
2593 fn normalize_cohere_keeps_existing_chat() {
2594 assert_eq!(
2595 normalize_cohere_base("https://api.cohere.com/v2/chat"),
2596 "https://api.cohere.com/v2/chat"
2597 );
2598 }
2599
2600 #[test]
2601 fn normalize_cohere_strips_trailing_slash() {
2602 assert_eq!(
2603 normalize_cohere_base("https://api.cohere.com/v2/"),
2604 "https://api.cohere.com/v2/chat"
2605 );
2606 }
2607
2608 #[test]
2609 fn normalize_cohere_official_bare_url_gets_v2_chat() {
2610 assert_eq!(
2611 normalize_cohere_base("https://api.cohere.com"),
2612 "https://api.cohere.com/v2/chat"
2613 );
2614 }
2615
2616 #[test]
2617 fn normalize_cohere_official_default_port_gets_v2_chat() {
2618 assert_eq!(
2619 normalize_cohere_base("https://api.cohere.com:443"),
2620 "https://api.cohere.com/v2/chat"
2621 );
2622 }
2623
2624 #[test]
2625 fn normalize_cohere_custom_bare_url_gets_chat() {
2626 assert_eq!(
2627 normalize_cohere_base("https://custom-cohere.example.com"),
2628 "https://custom-cohere.example.com/chat"
2629 );
2630 }
2631
2632 #[test]
2633 fn normalize_cohere_preserves_query_and_fragment() {
2634 assert_eq!(
2635 normalize_cohere_base("https://custom-cohere.example.com/v2?tenant=test#frag"),
2636 "https://custom-cohere.example.com/v2/chat?tenant=test#frag"
2637 );
2638 }
2639
2640 #[test]
2641 fn normalize_cohere_preserves_query_and_fragment_on_official_origin() {
2642 assert_eq!(
2643 normalize_cohere_base("https://api.cohere.com:443/?tenant=test#frag"),
2644 "https://api.cohere.com/v2/chat?tenant=test#frag"
2645 );
2646 }
2647
2648 #[test]
2649 fn normalize_cohere_base_empty_uses_default() {
2650 assert_eq!(normalize_cohere_base(""), "https://api.cohere.com/v2/chat");
2651 }
2652
2653 #[test]
2654 fn normalize_cohere_base_handles_opaque_url_fallback() {
2655 assert_eq!(
2656 normalize_cohere_base("data:text/plain,hello"),
2657 "data:text/plain,hello/chat"
2658 );
2659 }
2660
    // Property-based tests over the URL normalizers. Each strategy draws an
    // arbitrary URL-ish string; the properties check that every normalizer
    // (a) always produces its canonical suffix and (b) is idempotent.
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn normalize_anthropic_base_is_idempotent_and_targets_v1_messages(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_anthropic_base(&base);
                prop_assert!(normalized.ends_with("/v1/messages"));
                // Normalizing an already-normalized URL must be a no-op.
                prop_assert_eq!(normalize_anthropic_base(&normalized), normalized);
            }

            #[test]
            fn normalize_openai_base_is_idempotent_and_targets_chat_completions(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_openai_base(&base);
                prop_assert!(normalized.ends_with("/chat/completions"));
                prop_assert_eq!(normalize_openai_base(&normalized), normalized);
            }

            #[test]
            fn normalize_openai_responses_base_is_idempotent_and_targets_responses(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_openai_responses_base(&base);
                prop_assert!(normalized.ends_with("/responses"));
                prop_assert_eq!(normalize_openai_responses_base(&normalized), normalized);
            }

            #[test]
            fn normalize_cohere_base_is_idempotent_and_targets_chat(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_cohere_base(&base);
                prop_assert!(normalized.ends_with("/chat"));
                prop_assert_eq!(normalize_cohere_base(&normalized), normalized);
            }

            // A /v1/responses suffix (with any number of trailing slashes)
            // must be rewritten to /v1/chat/completions on arbitrary hosts.
            #[test]
            fn normalize_openai_base_rewrites_responses_suffix(
                host in "[a-z0-9-]{1,32}",
                trailing_slashes in 0usize..4
            ) {
                let base = format!(
                    "https://{host}.example/v1/responses{}",
                    "/".repeat(trailing_slashes)
                );
                prop_assert_eq!(
                    normalize_openai_base(&base),
                    format!("https://{host}.example/v1/chat/completions")
                );
            }

            // The inverse rewrite: /v1/chat/completions becomes /v1/responses.
            #[test]
            fn normalize_openai_responses_base_rewrites_chat_completions_suffix(
                host in "[a-z0-9-]{1,32}",
                trailing_slashes in 0usize..4
            ) {
                let base = format!(
                    "https://{host}.example/v1/chat/completions{}",
                    "/".repeat(trailing_slashes)
                );
                prop_assert_eq!(
                    normalize_openai_responses_base(&base),
                    format!("https://{host}.example/v1/responses")
                );
            }
        }
    }
2733
2734 use crate::models::CompatConfig;
2737
2738 fn compat_with_custom_headers() -> CompatConfig {
2739 let mut custom = HashMap::new();
2740 custom.insert("X-Custom-Header".to_string(), "test-value".to_string());
2741 custom.insert("X-Provider-Tag".to_string(), "override".to_string());
2742 CompatConfig {
2743 custom_headers: Some(custom),
2744 ..Default::default()
2745 }
2746 }
2747
2748 fn model_entry_with_compat(
2749 provider: &str,
2750 api: &str,
2751 model_id: &str,
2752 base_url: &str,
2753 compat: CompatConfig,
2754 ) -> ModelEntry {
2755 let mut entry = model_entry(provider, api, model_id, base_url);
2756 entry.compat = Some(compat);
2757 entry
2758 }
2759
2760 #[test]
2761 fn create_provider_anthropic_accepts_compat_config() {
2762 let entry = model_entry_with_compat(
2763 "anthropic",
2764 "anthropic-messages",
2765 "claude-sonnet-4-5",
2766 "https://api.anthropic.com",
2767 compat_with_custom_headers(),
2768 );
2769 let provider = create_provider(&entry, None).expect("anthropic with compat");
2770 assert_eq!(provider.name(), "anthropic");
2771 }
2772
2773 #[test]
2774 fn create_provider_openai_completions_accepts_compat_config() {
2775 let entry = model_entry_with_compat(
2776 "openai",
2777 "openai-completions",
2778 "gpt-4o",
2779 "https://api.openai.com/v1",
2780 CompatConfig {
2781 max_tokens_field: Some("max_completion_tokens".to_string()),
2782 system_role_name: Some("developer".to_string()),
2783 supports_tools: Some(false),
2784 ..Default::default()
2785 },
2786 );
2787 let provider = create_provider(&entry, None).expect("openai completions with compat");
2788 assert_eq!(provider.name(), "openai");
2789 }
2790
2791 #[test]
2792 fn create_provider_openai_responses_accepts_compat_config() {
2793 let entry = model_entry_with_compat(
2794 "openai",
2795 "openai-responses",
2796 "gpt-4o",
2797 "https://api.openai.com/v1",
2798 compat_with_custom_headers(),
2799 );
2800 let provider = create_provider(&entry, None).expect("openai responses with compat");
2801 assert_eq!(provider.name(), "openai");
2802 }
2803
2804 #[test]
2805 fn create_provider_cohere_accepts_compat_config() {
2806 let entry = model_entry_with_compat(
2807 "cohere",
2808 "cohere-chat",
2809 "command-r-plus",
2810 "https://api.cohere.com/v2",
2811 compat_with_custom_headers(),
2812 );
2813 let provider = create_provider(&entry, None).expect("cohere with compat");
2814 assert_eq!(provider.name(), "cohere");
2815 }
2816
2817 #[test]
2818 fn create_provider_google_accepts_compat_config() {
2819 let entry = model_entry_with_compat(
2820 "google",
2821 "google-generative-ai",
2822 "gemini-2.0-flash",
2823 "https://generativelanguage.googleapis.com",
2824 compat_with_custom_headers(),
2825 );
2826 let provider = create_provider(&entry, None).expect("google with compat");
2827 assert_eq!(provider.name(), "google");
2828 }
2829
2830 #[test]
2831 fn create_provider_fallback_api_routes_accept_compat_config() {
2832 let entry = model_entry_with_compat(
2834 "custom-anthropic",
2835 "anthropic-messages",
2836 "my-model",
2837 "https://custom.api.com",
2838 compat_with_custom_headers(),
2839 );
2840 let provider = create_provider(&entry, None).expect("fallback anthropic with compat");
2841 assert_eq!(provider.model_id(), "my-model");
2842
2843 let entry = model_entry_with_compat(
2845 "my-groq-clone",
2846 "openai-completions",
2847 "llama-3.1",
2848 "http://localhost:8080/v1",
2849 compat_with_custom_headers(),
2850 );
2851 let provider = create_provider(&entry, None).expect("fallback openai with compat");
2852 assert_eq!(provider.model_id(), "llama-3.1");
2853
2854 let entry = model_entry_with_compat(
2856 "custom-cohere",
2857 "cohere-chat",
2858 "custom-r",
2859 "https://custom-cohere.api.com/v2",
2860 compat_with_custom_headers(),
2861 );
2862 let provider = create_provider(&entry, None).expect("fallback cohere with compat");
2863 assert_eq!(provider.model_id(), "custom-r");
2864
2865 let entry = model_entry_with_compat(
2867 "custom-google",
2868 "google-generative-ai",
2869 "custom-gemini",
2870 "https://custom.google.com",
2871 compat_with_custom_headers(),
2872 );
2873 let provider = create_provider(&entry, None).expect("fallback google with compat");
2874 assert_eq!(provider.model_id(), "custom-gemini");
2875 }
2876
2877 #[test]
2880 fn resolve_provider_route_google_vertex_routes_to_native() {
2881 let entry = model_entry(
2882 "google-vertex",
2883 "google-vertex",
2884 "gemini-2.0-flash",
2885 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2886 );
2887 let (route, canonical_provider, effective_api) =
2888 resolve_provider_route(&entry).expect("resolve google-vertex route");
2889 assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2890 assert_eq!(canonical_provider, "google-vertex");
2891 assert_eq!(effective_api, "google-vertex");
2892 }
2893
2894 #[test]
2895 fn resolve_provider_route_vertexai_alias_routes_to_native() {
2896 let entry = model_entry(
2897 "vertexai",
2898 "google-vertex",
2899 "gemini-2.0-flash",
2900 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2901 );
2902 let (route, canonical_provider, effective_api) =
2903 resolve_provider_route(&entry).expect("resolve vertexai alias route");
2904 assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2905 assert_eq!(canonical_provider, "google-vertex");
2906 assert_eq!(effective_api, "google-vertex");
2907 }
2908
2909 #[test]
2910 fn resolve_provider_route_google_vertex_api_fallback() {
2911 let entry = model_entry(
2913 "custom-vertex",
2914 "google-vertex",
2915 "gemini-2.0-flash",
2916 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2917 );
2918 let (route, _canonical_provider, effective_api) =
2919 resolve_provider_route(&entry).expect("resolve google-vertex fallback");
2920 assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2921 assert_eq!(effective_api, "google-vertex");
2922 }
2923
2924 #[test]
2925 fn create_provider_google_vertex_from_full_url() {
2926 let entry = model_entry(
2927 "google-vertex",
2928 "google-vertex",
2929 "gemini-2.0-flash",
2930 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2931 );
2932 let provider = create_provider(&entry, None).expect("google-vertex from full URL");
2933 assert_eq!(provider.name(), "google-vertex");
2934 assert_eq!(provider.api(), "google-vertex");
2935 assert_eq!(provider.model_id(), "gemini-2.0-flash");
2936 }
2937
2938 #[test]
2939 fn create_provider_google_vertex_anthropic_publisher() {
2940 let entry = model_entry(
2941 "google-vertex",
2942 "google-vertex",
2943 "claude-sonnet-4-5",
2944 "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5",
2945 );
2946 let provider =
2947 create_provider(&entry, None).expect("google-vertex with anthropic publisher");
2948 assert_eq!(provider.name(), "google-vertex");
2949 assert_eq!(provider.model_id(), "claude-sonnet-4-5");
2950 }
2951
2952 #[test]
2953 fn create_provider_google_vertex_accepts_compat_config() {
2954 let entry = model_entry_with_compat(
2955 "google-vertex",
2956 "google-vertex",
2957 "gemini-2.0-flash",
2958 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2959 compat_with_custom_headers(),
2960 );
2961 let provider = create_provider(&entry, None).expect("google-vertex with compat");
2962 assert_eq!(provider.name(), "google-vertex");
2963 }
2964
2965 #[test]
2966 fn create_provider_compat_none_accepted_by_all_routes() {
2967 let routes = [
2969 (
2970 "anthropic",
2971 "anthropic-messages",
2972 "https://api.anthropic.com",
2973 ),
2974 ("openai", "openai-completions", "https://api.openai.com/v1"),
2975 ("openai", "openai-responses", "https://api.openai.com/v1"),
2976 ("cohere", "cohere-chat", "https://api.cohere.com/v2"),
2977 (
2978 "google",
2979 "google-generative-ai",
2980 "https://generativelanguage.googleapis.com",
2981 ),
2982 (
2983 "google-vertex",
2984 "google-vertex",
2985 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/test-model",
2986 ),
2987 ];
2988 for (provider, api, base_url) in routes {
2989 let entry = model_entry(provider, api, "test-model", base_url);
2990 assert!(
2991 entry.compat.is_none(),
2992 "expected None compat for {provider}"
2993 );
2994 let result = create_provider(&entry, None);
2995 assert!(
2996 result.is_ok(),
2997 "create_provider failed for {provider} with None compat: {:?}",
2998 result.err()
2999 );
3000 }
3001 }
3002}