1use crate::error::{Error, Result};
7use crate::extensions::{ExtensionManager, ExtensionRuntimeHandle};
8use crate::http::client::Client;
9use crate::model::{
10 AssistantMessage, AssistantMessageEvent, ContentBlock, StopReason, TextContent, Usage,
11};
12use crate::models::ModelEntry;
13use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
14use crate::provider_metadata::{
15 PROVIDER_METADATA, canonical_provider_id, provider_routing_defaults,
16};
17use crate::vcr::{VCR_ENV_MODE, VcrRecorder};
18use async_trait::async_trait;
19use chrono::Utc;
20use futures::stream;
21use futures::stream::Stream;
22use serde_json::Value;
23use std::env;
24use std::pin::Pin;
25use std::sync::Arc;
26use url::Url;
27
// Native provider implementations, one submodule per upstream API family.
// `create_provider` below routes each model entry to one of these.
pub mod anthropic;
pub mod azure;
pub mod bedrock;
pub mod cohere;
pub mod copilot;
pub mod gemini;
pub mod gitlab;
pub mod openai;
pub mod openai_responses;
pub mod vertex;
38
39fn vcr_client_if_enabled() -> Result<Option<Client>> {
40 if env::var(VCR_ENV_MODE).is_err() {
41 return Ok(None);
42 }
43
44 let test_name = env::var("PI_VCR_TEST_NAME").unwrap_or_else(|_| "pi_runtime".to_string());
45 let recorder = VcrRecorder::new(&test_name)?;
46 Ok(Some(Client::new().with_vcr(recorder)))
47}
48
/// `Provider` implementation backed by a JS extension that registered a
/// `streamSimple` generator (selected in `create_provider` when the extension
/// manager reports one for this provider id).
struct ExtensionStreamSimpleProvider {
    // Model definition serialized and handed to the extension on each request.
    model: crate::provider::Model,
    // Handle used to start, advance, and cancel extension-side streams.
    runtime: ExtensionRuntimeHandle,
}
53
/// Mutable state threaded through the `stream::unfold` loop in
/// `ExtensionStreamSimpleProvider::stream`.
struct ExtensionStreamSimpleState {
    // Handle for advancing/cancelling the extension-side stream.
    runtime: ExtensionRuntimeHandle,
    // `Some` while the stream is live; cleared on completion, error, or cancel.
    stream_id: Option<String>,
    // Identity fields used to synthesize partial `AssistantMessage`s.
    model_id: String,
    provider: String,
    api: String,
    // Concatenation of all raw string chunks yielded so far.
    accumulated_text: String,
    // Most recent partial/final message observed; reused for Start/Done events.
    last_message: Option<AssistantMessage>,
    // True once the first raw string chunk has emitted TextStart/TextDelta.
    string_chunk_started: bool,
    // Events queued for emission on subsequent unfold iterations.
    pending_events: std::collections::VecDeque<StreamEvent>,
}
67
/// Best-effort cleanup: if the state is dropped while the stream is still live
/// (the id has not been cleared), ask the extension runtime to cancel it.
impl Drop for ExtensionStreamSimpleState {
    fn drop(&mut self) {
        if let Some(stream_id) = self.stream_id.take() {
            self.runtime
                .provider_stream_simple_cancel_best_effort(stream_id);
        }
    }
}
76
/// Concrete provider implementation selected for a model entry.
///
/// `Native*` variants are chosen from a recognized (canonical) provider id;
/// `Api*` variants are the fallback chosen from the entry's effective API name
/// when the provider id is not recognized. Each Native/Api pair maps to the
/// same implementation in `create_provider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProviderRouteKind {
    NativeAnthropic,
    NativeOpenAICompletions,
    NativeOpenAIResponses,
    NativeOpenAICodexResponses,
    NativeCohere,
    NativeGoogle,
    NativeGoogleGeminiCli,
    NativeGoogleVertex,
    NativeBedrock,
    NativeAzure,
    NativeCopilot,
    NativeGitlab,
    ApiAnthropicMessages,
    ApiOpenAICompletions,
    ApiOpenAIResponses,
    ApiOpenAICodexResponses,
    ApiCohereChat,
    ApiGoogleGenerativeAi,
    ApiGoogleGeminiCli,
}
99
impl ProviderRouteKind {
    /// Stable string form of the route, used for tracing output (see the
    /// `pi.provider.factory.select` debug event in `create_provider`).
    const fn as_str(self) -> &'static str {
        match self {
            Self::NativeAnthropic => "native:anthropic",
            Self::NativeOpenAICompletions => "native:openai-completions",
            Self::NativeOpenAIResponses => "native:openai-responses",
            Self::NativeOpenAICodexResponses => "native:openai-codex-responses",
            Self::NativeCohere => "native:cohere",
            Self::NativeGoogle => "native:google",
            Self::NativeGoogleGeminiCli => "native:google-gemini-cli",
            Self::NativeGoogleVertex => "native:google-vertex",
            Self::NativeBedrock => "native:amazon-bedrock",
            Self::NativeAzure => "native:azure-openai",
            Self::NativeCopilot => "native:github-copilot",
            Self::NativeGitlab => "native:gitlab",
            Self::ApiAnthropicMessages => "api:anthropic-messages",
            Self::ApiOpenAICompletions => "api:openai-completions",
            Self::ApiOpenAIResponses => "api:openai-responses",
            Self::ApiOpenAICodexResponses => "api:openai-codex-responses",
            Self::ApiCohereChat => "api:cohere-chat",
            Self::ApiGoogleGenerativeAi => "api:google-generative-ai",
            Self::ApiGoogleGeminiCli => "api:google-gemini-cli",
        }
    }
}
125
/// Map a model entry to a provider route.
///
/// The canonical provider id is matched first; when the provider is unknown,
/// routing falls back to the entry's effective API name (the entry's own `api`
/// field, or the schema default for that provider when `api` is empty).
/// Returns the route plus the canonical provider id and effective API string,
/// or an error with "did you mean" suggestions when neither matches.
fn resolve_provider_route(entry: &ModelEntry) -> Result<(ProviderRouteKind, String, String)> {
    let canonical_provider =
        canonical_provider_id(&entry.model.provider).unwrap_or(entry.model.provider.as_str());
    let schema_api = provider_routing_defaults(&entry.model.provider).map(|defaults| defaults.api);
    // Prefer the explicit api on the entry; otherwise use the schema default.
    let effective_api = if entry.model.api.is_empty() {
        schema_api.unwrap_or_default().to_string()
    } else {
        entry.model.api.clone()
    };

    let route = match canonical_provider {
        "anthropic" => ProviderRouteKind::NativeAnthropic,
        "openai" => {
            // OpenAI defaults to the Responses API unless completions is asked for.
            if effective_api == "openai-completions" {
                ProviderRouteKind::NativeOpenAICompletions
            } else {
                ProviderRouteKind::NativeOpenAIResponses
            }
        }
        "openai-codex" => ProviderRouteKind::NativeOpenAICodexResponses,
        "cohere" => ProviderRouteKind::NativeCohere,
        "google" => ProviderRouteKind::NativeGoogle,
        "google-gemini-cli" | "google-antigravity" => ProviderRouteKind::NativeGoogleGeminiCli,
        "google-vertex" | "vertexai" => ProviderRouteKind::NativeGoogleVertex,
        "amazon-bedrock" | "bedrock" => ProviderRouteKind::NativeBedrock,
        "azure-openai" | "azure" | "azure-cognitive-services" | "azure-openai-responses" => {
            ProviderRouteKind::NativeAzure
        }
        "github-copilot" | "copilot" => ProviderRouteKind::NativeCopilot,
        "gitlab" | "gitlab-duo" => ProviderRouteKind::NativeGitlab,
        // Unknown provider id: route by the API name instead.
        _ => match effective_api.as_str() {
            "anthropic-messages" => ProviderRouteKind::ApiAnthropicMessages,
            "openai-completions" => ProviderRouteKind::ApiOpenAICompletions,
            "openai-responses" => ProviderRouteKind::ApiOpenAIResponses,
            "openai-codex-responses" => ProviderRouteKind::ApiOpenAICodexResponses,
            "cohere-chat" => ProviderRouteKind::ApiCohereChat,
            "google-generative-ai" => ProviderRouteKind::ApiGoogleGenerativeAi,
            "google-gemini-cli" => ProviderRouteKind::ApiGoogleGeminiCli,
            "google-vertex" => ProviderRouteKind::NativeGoogleVertex,
            "bedrock-converse-stream" => ProviderRouteKind::NativeBedrock,
            "azure-openai-responses" => ProviderRouteKind::NativeAzure,
            _ => {
                // Nothing matched: fail with fuzzy suggestions when available.
                let suggestions = suggest_similar_providers(&entry.model.provider);
                let msg = if suggestions.is_empty() {
                    format!("Provider not implemented (api: {effective_api})")
                } else {
                    format!(
                        "Provider not implemented (api: {effective_api}). Did you mean: {}?",
                        suggestions.join(", ")
                    )
                };
                return Err(Error::provider(&entry.model.provider, msg));
            }
        },
    };

    Ok((route, canonical_provider.to_string(), effective_api))
}
184
/// Levenshtein edit distance between two byte strings (unit costs for insert,
/// delete, and substitute), computed with a single rolling row.
///
/// The row is sized by the shorter input so memory stays `O(min(|a|, |b|))`.
fn edit_distance(a: &[u8], b: &[u8]) -> usize {
    let (cols, rows) = if a.len() <= b.len() { (a, b) } else { (b, a) };
    // current[c] holds the distance for the prefix pair processed so far.
    let mut current: Vec<usize> = (0..=cols.len()).collect();
    for (r, &row_byte) in rows.iter().enumerate() {
        // `diagonal` is the value that was at current[c] in the previous row.
        let mut diagonal = r;
        current[0] = r + 1;
        for (c, &col_byte) in cols.iter().enumerate() {
            let substitution = diagonal + usize::from(row_byte != col_byte);
            let deletion = current[c + 1] + 1;
            let insertion = current[c] + 1;
            diagonal = current[c + 1];
            current[c + 1] = substitution.min(deletion).min(insertion);
        }
    }
    current[cols.len()]
}
202
/// Maximum fuzzy-match slack allowed for an input of the given length.
///
/// Very short inputs get no slack (too many false positives); the allowance
/// grows with length and is capped at 3.
const fn max_edit_distance(input_len: usize) -> usize {
    if input_len <= 2 {
        0
    } else if input_len <= 5 {
        1
    } else if input_len <= 9 {
        2
    } else {
        3
    }
}
213
/// Suggest up to three known provider ids similar to `input`.
///
/// Candidates are ranked in tiers: prefix overlap in either direction
/// (rank 0), substring overlap (rank 1), then edit distance within
/// `max_edit_distance` of the input length (rank 2 + distance). Results are
/// sorted by (rank, name), adjacent same-name entries are deduplicated, and
/// the list is truncated to three.
fn suggest_similar_providers(input: &str) -> Vec<String> {
    let needle = input.to_lowercase();
    let needle_bytes = needle.as_bytes();
    let threshold = max_edit_distance(needle.len());
    let mut matches: Vec<(usize, String)> = Vec::new();

    for meta in PROVIDER_METADATA {
        // Compare against the canonical id and every alias.
        let names: Vec<&str> = std::iter::once(meta.canonical_id)
            .chain(meta.aliases.iter().copied())
            .collect();
        let mut matched = false;
        for name in &names {
            let haystack = name.to_lowercase();
            // Tier 0: one string is a prefix of the other.
            if haystack.starts_with(&needle) || needle.starts_with(&haystack) {
                matches.push((0, meta.canonical_id.to_string()));
                matched = true;
                break;
            }
            // Tier 1: one string contains the other.
            if haystack.contains(&needle) || needle.contains(&haystack) {
                matches.push((1, meta.canonical_id.to_string()));
                matched = true;
                break;
            }
        }
        if matched {
            continue;
        }
        // Tier 2+: fuzzy match by edit distance, only when the input is long
        // enough to allow any slack at all.
        if threshold > 0 {
            let mut best_dist = usize::MAX;
            for name in &names {
                let haystack = name.to_lowercase();
                let dist = edit_distance(needle_bytes, haystack.as_bytes());
                best_dist = best_dist.min(dist);
            }
            if best_dist <= threshold {
                matches.push((
                    2_usize.wrapping_add(best_dist),
                    meta.canonical_id.to_string(),
                ));
            }
        }
    }

    // Sort by (rank, name); dedup_by only removes equal names that end up
    // adjacent after the sort.
    matches.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
    matches.dedup_by(|a, b| a.1 == b.1);
    matches.truncate(3);
    matches.into_iter().map(|(_, name)| name).collect()
}
269
// Environment-variable overrides for Azure OpenAI routing (resource name,
// deployment name, API version); consulted by
// `resolve_azure_provider_runtime_with_env` ahead of values derived from
// the base URL / model id.
const AZURE_OPENAI_RESOURCE_ENV: &str = "AZURE_OPENAI_RESOURCE";
const AZURE_OPENAI_DEPLOYMENT_ENV: &str = "AZURE_OPENAI_DEPLOYMENT";
const AZURE_OPENAI_API_VERSION_ENV: &str = "AZURE_OPENAI_API_VERSION";
273
/// Fully-resolved Azure OpenAI connection settings, produced by
/// `resolve_azure_provider_runtime_with_env` and consumed when constructing
/// the Azure provider in `create_provider`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AzureProviderRuntime {
    // Azure resource name (the `<resource>` in `<resource>.openai.azure.com`).
    resource: String,
    // Deployment name used in the `/deployments/<name>` URL path.
    deployment: String,
    // `api-version` query parameter value.
    api_version: String,
    // Complete chat-completions endpoint URL assembled from the above.
    endpoint_url: String,
}
281
/// Normalize an optional setting: trim surrounding whitespace and treat a
/// blank (or absent) value as `None`.
fn trim_non_empty(value: Option<String>) -> Option<String> {
    let raw = value?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
287
/// Extract the Azure resource name from a host like
/// `<resource>.openai.azure.com` or `<resource>.cognitiveservices.azure.com`.
/// Returns `None` for unrecognized hosts or an empty resource segment.
fn parse_azure_resource_from_host(host: &str) -> Option<String> {
    let stripped = host
        .strip_suffix(".openai.azure.com")
        .or_else(|| host.strip_suffix(".cognitiveservices.azure.com"))?;
    let resource = stripped.trim();
    if resource.is_empty() {
        None
    } else {
        Some(resource.to_string())
    }
}
295
296fn parse_azure_base_url_details(
297 base_url: &str,
298) -> Result<(String, Option<String>, Option<String>)> {
299 let url = Url::parse(base_url)
300 .map_err(|err| Error::config(format!("Invalid Azure base_url '{base_url}': {err}")))?;
301 let host = url.host_str().map(ToString::to_string).ok_or_else(|| {
302 Error::config(format!(
303 "Azure base_url is missing host information: '{base_url}'"
304 ))
305 })?;
306
307 let mut deployment = None;
308 if let Some(segments) = url.path_segments() {
309 let mut iter = segments;
310 while let Some(segment) = iter.next() {
311 if segment == "deployments" {
312 deployment = iter
313 .next()
314 .map(str::trim)
315 .filter(|value| !value.is_empty())
316 .map(ToString::to_string);
317 break;
318 }
319 }
320 }
321
322 let api_version = url
323 .query_pairs()
324 .find(|(key, _)| key == "api-version")
325 .map(|(_, value)| value.into_owned())
326 .filter(|value| !value.trim().is_empty());
327
328 Ok((host, deployment, api_version))
329}
330
/// Resolve Azure runtime settings for `entry` using the process environment
/// for the `AZURE_OPENAI_*` override variables.
fn resolve_azure_provider_runtime(entry: &ModelEntry) -> Result<AzureProviderRuntime> {
    resolve_azure_provider_runtime_with_env(entry, |name| env::var(name).ok())
}
334
/// Resolve Azure connection settings from the model entry plus an injectable
/// environment lookup (injected for testability).
///
/// Precedence:
/// - resource: `AZURE_OPENAI_RESOURCE` env, then the base_url host;
/// - deployment: `AZURE_OPENAI_DEPLOYMENT` env, then the model id, then a
///   `/deployments/<name>` path segment in base_url;
/// - api-version: `AZURE_OPENAI_API_VERSION` env, then the base_url query,
///   then `azure::DEFAULT_API_VERSION`.
///
/// Errors when the base_url is empty, unparseable, or when no resource or
/// deployment can be determined.
fn resolve_azure_provider_runtime_with_env<F>(
    entry: &ModelEntry,
    mut env_lookup: F,
) -> Result<AzureProviderRuntime>
where
    F: FnMut(&str) -> Option<String>,
{
    let base_url = entry.model.base_url.trim();
    if base_url.is_empty() {
        return Err(Error::config(format!(
            "Missing Azure base_url for provider '{}'; expected https://<resource>.openai.azure.com or https://<resource>.cognitiveservices.azure.com",
            entry.model.provider
        )));
    }

    let (host, base_deployment, base_api_version) = parse_azure_base_url_details(base_url)?;
    let host_resource = parse_azure_resource_from_host(&host);
    let env_resource = trim_non_empty(env_lookup(AZURE_OPENAI_RESOURCE_ENV));
    // Env override wins over the resource derived from the host.
    let resource = env_resource.or(host_resource).ok_or_else(|| {
        Error::config(format!(
            "Unable to resolve Azure resource for provider '{}'; set {AZURE_OPENAI_RESOURCE_ENV} or use an Azure host in base_url ('{base_url}')",
            entry.model.provider
        ))
    })?;

    let env_deployment = trim_non_empty(env_lookup(AZURE_OPENAI_DEPLOYMENT_ENV));
    let model_deployment = {
        let model_id = entry.model.id.trim();
        (!model_id.is_empty()).then(|| model_id.to_string())
    };
    // Env override, then model id, then the base_url path segment.
    let deployment = env_deployment
        .or(model_deployment)
        .or(base_deployment)
        .ok_or_else(|| {
            Error::config(format!(
                "Unable to resolve Azure deployment for provider '{}'; set {AZURE_OPENAI_DEPLOYMENT_ENV}, provide a non-empty model id, or include '/deployments/<name>' in base_url ('{base_url}')",
                entry.model.provider
            ))
        })?;

    // Env override, then the base_url query parameter, then the default.
    let api_version = trim_non_empty(env_lookup(AZURE_OPENAI_API_VERSION_ENV))
        .or(base_api_version)
        .unwrap_or_else(|| azure::DEFAULT_API_VERSION.to_string());

    // Keep an already-Azure host verbatim; otherwise synthesize the standard
    // openai.azure.com host from the resolved resource.
    let endpoint_host = if parse_azure_resource_from_host(&host).is_some() {
        host
    } else {
        format!("{resource}.openai.azure.com")
    };
    let endpoint_url = format!(
        "https://{endpoint_host}/openai/deployments/{deployment}/chat/completions?api-version={api_version}"
    );

    Ok(AzureProviderRuntime {
        resource,
        deployment,
        api_version,
        endpoint_url,
    })
}
395
/// Resolve the GitHub Copilot token for `entry` using the process environment
/// (see `resolve_copilot_token_with_env` for the lookup order).
fn resolve_copilot_token(entry: &ModelEntry) -> Result<String> {
    resolve_copilot_token_with_env(entry, |name| env::var(name).ok())
}
399
400fn resolve_copilot_token_with_env<F>(entry: &ModelEntry, mut env_lookup: F) -> Result<String>
401where
402 F: FnMut(&str) -> Option<String>,
403{
404 let inline = entry
405 .api_key
406 .as_deref()
407 .map(str::trim)
408 .filter(|value| !value.is_empty())
409 .map(ToString::to_string);
410 let from_env = || {
411 env_lookup("GITHUB_COPILOT_API_KEY")
412 .or_else(|| env_lookup("GITHUB_TOKEN"))
413 .map(|value| value.trim().to_string())
414 .filter(|value| !value.is_empty())
415 };
416
417 inline.or_else(from_env).ok_or_else(|| {
418 Error::auth(
419 "GitHub Copilot requires login credentials or GITHUB_COPILOT_API_KEY/GITHUB_TOKEN",
420 )
421 })
422}
423
impl ExtensionStreamSimpleProvider {
    // Timeout applied both to stream start and to each `next` poll (10 min).
    const NEXT_TIMEOUT_MS: u64 = 600_000;

    /// Bundle the model definition with the extension runtime handle.
    const fn new(model: crate::provider::Model, runtime: ExtensionRuntimeHandle) -> Self {
        Self { model, runtime }
    }

    /// Serialize the model to the camelCase JSON shape the JS extension's
    /// `streamSimple(model, ...)` parameter expects.
    fn build_js_model(model: &crate::provider::Model) -> Value {
        serde_json::json!({
            "id": &model.id,
            "name": &model.name,
            "api": &model.api,
            "provider": &model.provider,
            "baseUrl": &model.base_url,
            "reasoning": model.reasoning,
            "input": &model.input,
            "cost": &model.cost,
            "contextWindow": model.context_window,
            "maxTokens": model.max_tokens,
            "headers": &model.headers,
        })
    }

    /// Serialize the request context (system prompt, messages, tools) for the
    /// extension. Optional pieces are omitted rather than sent as null.
    fn build_js_context(context: &Context<'_>) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(system_prompt) = &context.system_prompt {
            map.insert(
                "systemPrompt".to_string(),
                Value::String(system_prompt.to_string()),
            );
        }
        map.insert(
            "messages".to_string(),
            // Fall back to an empty array if serialization fails.
            serde_json::to_value(&context.messages).unwrap_or(Value::Array(Vec::new())),
        );
        if !context.tools.is_empty() {
            let tools = context
                .tools
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters,
                    })
                })
                .collect::<Vec<_>>();
            map.insert("tools".to_string(), Value::Array(tools));
        }
        Value::Object(map)
    }

    /// Serialize stream options for the extension. Unset options are omitted;
    /// `cacheRetention` is always present; `reasoning` is only sent when
    /// thinking is enabled (level != Off).
    fn build_js_options(options: &StreamOptions) -> Value {
        let mut map = serde_json::Map::new();
        if let Some(temp) = options.temperature {
            map.insert("temperature".to_string(), serde_json::json!(temp));
        }
        if let Some(max_tokens) = options.max_tokens {
            map.insert("maxTokens".to_string(), serde_json::json!(max_tokens));
        }
        if let Some(api_key) = &options.api_key {
            map.insert("apiKey".to_string(), Value::String(api_key.clone()));
        }
        if let Some(session_id) = &options.session_id {
            map.insert("sessionId".to_string(), Value::String(session_id.clone()));
        }
        if !options.headers.is_empty() {
            map.insert(
                "headers".to_string(),
                serde_json::to_value(&options.headers)
                    .unwrap_or_else(|_| Value::Object(serde_json::Map::new())),
            );
        }
        let cache_retention = match options.cache_retention {
            crate::provider::CacheRetention::None => "none",
            crate::provider::CacheRetention::Short => "short",
            crate::provider::CacheRetention::Long => "long",
        };
        map.insert(
            "cacheRetention".to_string(),
            Value::String(cache_retention.to_string()),
        );
        if let Some(level) = options.thinking_level {
            if level != crate::model::ThinkingLevel::Off {
                map.insert("reasoning".to_string(), Value::String(level.to_string()));
            }
        }
        if let Some(budgets) = &options.thinking_budgets {
            map.insert(
                "thinkingBudgets".to_string(),
                serde_json::json!({
                    "minimal": budgets.minimal,
                    "low": budgets.low,
                    "medium": budgets.medium,
                    "high": budgets.high,
                    "xhigh": budgets.xhigh,
                }),
            );
        }
        Value::Object(map)
    }

    /// 1:1 translation of an extension-produced `AssistantMessageEvent` into
    /// the provider-layer `StreamEvent` (dropping the `partial` payloads that
    /// the stream loop tracks separately).
    fn assistant_event_to_stream_event(event: AssistantMessageEvent) -> StreamEvent {
        match event {
            AssistantMessageEvent::Start { partial } => StreamEvent::Start {
                partial: partial.as_ref().clone(),
            },
            AssistantMessageEvent::TextStart { content_index, .. } => {
                StreamEvent::TextStart { content_index }
            }
            AssistantMessageEvent::TextDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::TextDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::TextEnd {
                content_index,
                content,
                ..
            } => StreamEvent::TextEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ThinkingStart { content_index, .. } => {
                StreamEvent::ThinkingStart { content_index }
            }
            AssistantMessageEvent::ThinkingDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ThinkingDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ThinkingEnd {
                content_index,
                content,
                ..
            } => StreamEvent::ThinkingEnd {
                content_index,
                content,
            },
            AssistantMessageEvent::ToolCallStart { content_index, .. } => {
                StreamEvent::ToolCallStart { content_index }
            }
            AssistantMessageEvent::ToolCallDelta {
                content_index,
                delta,
                ..
            } => StreamEvent::ToolCallDelta {
                content_index,
                delta,
            },
            AssistantMessageEvent::ToolCallEnd {
                content_index,
                tool_call,
                ..
            } => StreamEvent::ToolCallEnd {
                content_index,
                tool_call,
            },
            AssistantMessageEvent::Done { reason, message } => StreamEvent::Done {
                reason,
                message: message.as_ref().clone(),
            },
            AssistantMessageEvent::Error { reason, error } => StreamEvent::Error {
                reason,
                error: error.as_ref().clone(),
            },
        }
    }

    /// Synthesize a partial assistant message holding `text` as its single
    /// text block; used when the extension yields raw string chunks instead
    /// of structured events.
    fn make_partial(model_id: &str, provider: &str, api: &str, text: &str) -> AssistantMessage {
        AssistantMessage {
            model: model_id.to_string(),
            api: api.to_string(),
            provider: provider.to_string(),
            content: vec![ContentBlock::Text(TextContent {
                text: text.to_string(),
                text_signature: None,
            })],
            stop_reason: StopReason::default(),
            usage: Usage::default(),
            error_message: None,
            timestamp: Utc::now().timestamp_millis(),
        }
    }
}
615
#[allow(clippy::too_many_lines)]
#[async_trait]
impl Provider for ExtensionStreamSimpleProvider {
    // Returns the provider id from the model rather than a field named
    // `name`, hence the misnamed-getter allowance.
    #[allow(clippy::misnamed_getters)]
    fn name(&self) -> &str {
        &self.model.provider
    }

    fn api(&self) -> &str {
        &self.model.api
    }

    fn model_id(&self) -> &str {
        &self.model.id
    }

    /// Start a `streamSimple` stream on the extension runtime and adapt its
    /// output into provider `StreamEvent`s.
    ///
    /// The extension may yield either structured `AssistantMessageEvent`
    /// JSON objects or raw string chunks; raw chunks are wrapped in a
    /// synthesized Start/TextStart/TextDelta/TextEnd/Done sequence. On a
    /// terminal event, a decode failure, or a poll error, the extension-side
    /// stream is cancelled best-effort and the id cleared so the `Drop` impl
    /// does not cancel twice.
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let model = Self::build_js_model(&self.model);
        let ctx = Self::build_js_context(context);
        let opts = Self::build_js_options(options);

        let stream_id = self
            .runtime
            .provider_stream_simple_start(
                self.model.provider.clone(),
                model,
                ctx,
                opts,
                Self::NEXT_TIMEOUT_MS,
            )
            .await?;

        let state = ExtensionStreamSimpleState {
            runtime: self.runtime.clone(),
            stream_id: Some(stream_id),
            model_id: self.model.id.clone(),
            provider: self.model.provider.clone(),
            api: self.model.api.clone(),
            accumulated_text: String::new(),
            last_message: None,
            string_chunk_started: false,
            pending_events: std::collections::VecDeque::new(),
        };

        let stream = stream::unfold(state, |mut state| async move {
            // Drain any events queued by a previous iteration first.
            if let Some(event) = state.pending_events.pop_front() {
                return Some((Ok(event), state));
            }

            // A cleared stream id means the stream already terminated.
            let stream_id = state.stream_id.clone()?;
            let stream_id_for_cancel = stream_id.clone();

            match state
                .runtime
                .provider_stream_simple_next(stream_id, Self::NEXT_TIMEOUT_MS)
                .await
            {
                Ok(Some(value)) => {
                    // Raw string chunk: accumulate it into a single text block.
                    if let Some(chunk) = value.as_str() {
                        let chunk = chunk.to_string();
                        state.accumulated_text.push_str(&chunk);
                        match &mut state.last_message {
                            Some(msg) => {
                                if let Some(ContentBlock::Text(t)) = msg.content.first_mut() {
                                    t.text.clone_from(&state.accumulated_text);
                                }
                            }
                            None => {
                                state.last_message = Some(Self::make_partial(
                                    &state.model_id,
                                    &state.provider,
                                    &state.api,
                                    &state.accumulated_text,
                                ));
                            }
                        }

                        // First chunk: emit Start now and queue TextStart +
                        // TextDelta for the next iterations.
                        if !state.string_chunk_started {
                            state.string_chunk_started = true;
                            state
                                .pending_events
                                .push_back(StreamEvent::TextStart { content_index: 0 });
                            state.pending_events.push_back(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            });
                            return Some((
                                Ok(StreamEvent::Start {
                                    partial: state.last_message.clone().unwrap_or_else(|| {
                                        Self::make_partial(
                                            &state.model_id,
                                            &state.provider,
                                            &state.api,
                                            &state.accumulated_text,
                                        )
                                    }),
                                }),
                                state,
                            ));
                        }
                        // Subsequent chunks map directly to TextDelta.
                        return Some((
                            Ok(StreamEvent::TextDelta {
                                content_index: 0,
                                delta: chunk,
                            }),
                            state,
                        ));
                    }

                    // Otherwise the value must be a structured event object.
                    let event: AssistantMessageEvent = match serde_json::from_value(value) {
                        Ok(event) => event,
                        Err(err) => {
                            // Undecodable event: cancel upstream and surface
                            // the decode failure to the consumer.
                            state
                                .runtime
                                .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                            state.stream_id = None;
                            return Some((
                                Err(Error::extension(format!(
                                    "streamSimple yielded invalid event: {err}"
                                ))),
                                state,
                            ));
                        }
                    };

                    // Remember the latest partial/final message for the
                    // end-of-stream fallback path below.
                    match &event {
                        AssistantMessageEvent::Start { partial }
                        | AssistantMessageEvent::TextStart { partial, .. }
                        | AssistantMessageEvent::TextDelta { partial, .. }
                        | AssistantMessageEvent::TextEnd { partial, .. }
                        | AssistantMessageEvent::ThinkingStart { partial, .. }
                        | AssistantMessageEvent::ThinkingDelta { partial, .. }
                        | AssistantMessageEvent::ThinkingEnd { partial, .. }
                        | AssistantMessageEvent::ToolCallStart { partial, .. }
                        | AssistantMessageEvent::ToolCallDelta { partial, .. }
                        | AssistantMessageEvent::ToolCallEnd { partial, .. } => {
                            state.last_message = Some(partial.as_ref().clone());
                        }
                        AssistantMessageEvent::Done { message, .. } => {
                            state.last_message = Some(message.as_ref().clone());
                        }
                        AssistantMessageEvent::Error { error, .. } => {
                            state.last_message = Some(error.as_ref().clone());
                        }
                    }

                    let stream_event = Self::assistant_event_to_stream_event(event);
                    // Terminal events close the upstream stream eagerly.
                    if matches!(
                        stream_event,
                        StreamEvent::Done { .. } | StreamEvent::Error { .. }
                    ) {
                        state
                            .runtime
                            .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                        state.stream_id = None;
                    }
                    Some((Ok(stream_event), state))
                }
                Ok(None) => {
                    // Generator finished without an explicit Done event.
                    state.stream_id = None;
                    let message = state.last_message.clone().unwrap_or_else(|| {
                        Self::make_partial(
                            &state.model_id,
                            &state.provider,
                            &state.api,
                            &state.accumulated_text,
                        )
                    });

                    if state.string_chunk_started {
                        // Close the synthesized text block, then emit Done on
                        // the next iteration via the pending queue.
                        state.pending_events.push_back(StreamEvent::Done {
                            reason: StopReason::Stop,
                            message,
                        });
                        Some((
                            Ok(StreamEvent::TextEnd {
                                content_index: 0,
                                content: state.accumulated_text.clone(),
                            }),
                            state,
                        ))
                    } else {
                        Some((
                            Ok(StreamEvent::Done {
                                reason: StopReason::Stop,
                                message,
                            }),
                            state,
                        ))
                    }
                }
                Err(err) => {
                    // Poll failure: cancel upstream and propagate the error.
                    state
                        .runtime
                        .provider_stream_simple_cancel_best_effort(stream_id_for_cancel);
                    state.stream_id = None;
                    Some((Err(err), state))
                }
            }
        });

        Ok(Box::pin(stream))
    }
}
832
#[allow(clippy::too_many_lines)]
/// Construct the provider implementation for a model entry.
///
/// An extension that registered a `streamSimple` generator for this provider
/// id takes precedence over every native implementation. Otherwise the entry
/// is routed via `resolve_provider_route` and the matching native provider is
/// built, sharing a single HTTP client (VCR-recording when enabled).
pub fn create_provider(
    entry: &ModelEntry,
    extensions: Option<&ExtensionManager>,
) -> Result<Arc<dyn Provider>> {
    if let Some(manager) = extensions {
        if manager.provider_has_stream_simple(&entry.model.provider) {
            let runtime = manager.runtime().ok_or_else(|| {
                Error::provider(
                    &entry.model.provider,
                    "Extension runtime not configured for streamSimple provider",
                )
            })?;
            return Ok(Arc::new(ExtensionStreamSimpleProvider::new(
                entry.model.clone(),
                runtime,
            )));
        }
    }

    let vcr_client = vcr_client_if_enabled()?;
    let client = vcr_client.unwrap_or_else(Client::new);
    let (route, canonical_provider, effective_api) = resolve_provider_route(entry)?;
    tracing::debug!(
        event = "pi.provider.factory.select",
        provider = %entry.model.provider,
        canonical_provider = %canonical_provider,
        api = %effective_api,
        base_url = %entry.model.base_url,
        route = %route.as_str(),
        "Selecting provider implementation"
    );

    match route {
        ProviderRouteKind::NativeAnthropic | ProviderRouteKind::ApiAnthropicMessages => {
            Ok(Arc::new(
                anthropic::AnthropicProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_anthropic_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAICompletions | ProviderRouteKind::ApiOpenAICompletions => {
            Ok(Arc::new(
                openai::OpenAIProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeOpenAIResponses | ProviderRouteKind::ApiOpenAIResponses => {
            Ok(Arc::new(
                openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_base_url(normalize_openai_responses_base(&entry.model.base_url))
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        // Codex reuses the Responses provider with codex mode switched on.
        ProviderRouteKind::NativeOpenAICodexResponses
        | ProviderRouteKind::ApiOpenAICodexResponses => Ok(Arc::new(
            openai_responses::OpenAIResponsesProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("openai-codex-responses")
                .with_codex_mode(true)
                .with_base_url(normalize_openai_codex_responses_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeCohere | ProviderRouteKind::ApiCohereChat => Ok(Arc::new(
            cohere::CohereProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_base_url(normalize_cohere_base(&entry.model.base_url))
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeGoogle | ProviderRouteKind::ApiGoogleGenerativeAi => Ok(Arc::new(
            gemini::GeminiProvider::new(entry.model.id.clone())
                .with_provider_name(entry.model.provider.clone())
                .with_api_name("google-generative-ai")
                .with_base_url(entry.model.base_url.clone())
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        // Gemini CLI reuses the Gemini provider in CLI mode.
        ProviderRouteKind::NativeGoogleGeminiCli | ProviderRouteKind::ApiGoogleGeminiCli => {
            Ok(Arc::new(
                gemini::GeminiProvider::new(entry.model.id.clone())
                    .with_provider_name(entry.model.provider.clone())
                    .with_api_name("google-gemini-cli")
                    .with_google_cli_mode(true)
                    .with_base_url(entry.model.base_url.clone())
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeGoogleVertex => {
            // Vertex needs project/location/publisher resolved from the entry.
            let runtime = vertex::resolve_vertex_provider_runtime(entry)?;
            Ok(Arc::new(
                vertex::VertexProvider::new(runtime.model)
                    .with_project(runtime.project)
                    .with_location(runtime.location)
                    .with_publisher(runtime.publisher)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeBedrock => Ok(Arc::new(
            bedrock::BedrockProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
        ProviderRouteKind::NativeAzure => {
            // Azure needs resource/deployment/api-version resolved first.
            let runtime = resolve_azure_provider_runtime(entry)?;
            Ok(Arc::new(
                azure::AzureOpenAIProvider::new(runtime.resource, runtime.deployment)
                    .with_api_version(runtime.api_version)
                    .with_endpoint_url(runtime.endpoint_url)
                    .with_compat(entry.compat.clone())
                    .with_client(client),
            ))
        }
        ProviderRouteKind::NativeCopilot => {
            // Copilot requires a GitHub token up front.
            let github_token = resolve_copilot_token(entry)?;
            let mut provider = copilot::CopilotProvider::new(&entry.model.id, github_token)
                .with_provider_name(&entry.model.provider)
                .with_compat(entry.compat.clone())
                .with_client(client);
            if !entry.model.base_url.is_empty() {
                provider = provider.with_github_api_base(&entry.model.base_url);
            }
            Ok(Arc::new(provider))
        }
        ProviderRouteKind::NativeGitlab => Ok(Arc::new(
            gitlab::GitLabProvider::new(&entry.model.id)
                .with_provider_name(&entry.model.provider)
                .with_base_url(&entry.model.base_url)
                .with_compat(entry.compat.clone())
                .with_client(client),
        )),
    }
}
978
/// Normalize an Anthropic base URL so it always targets `/v1/messages`.
///
/// Blank input yields the official endpoint; otherwise trailing slashes are
/// stripped and the messages path is appended unless already present.
pub fn normalize_anthropic_base(base_url: &str) -> String {
    const DEFAULT_ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
    const MESSAGES_PATH: &str = "/v1/messages";

    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return DEFAULT_ENDPOINT.to_string();
    }
    let stripped = trimmed.trim_end_matches('/');
    if stripped.ends_with(MESSAGES_PATH) {
        stripped.to_string()
    } else {
        format!("{stripped}{MESSAGES_PATH}")
    }
}
990
/// Normalize an OpenAI-compatible base URL so it targets `/chat/completions`.
///
/// Blank input yields the official endpoint. Otherwise trailing slashes are
/// stripped, a URL already ending in `/chat/completions` is kept as-is, and a
/// Responses-style URL is converted by swapping the trailing `/responses`
/// segment.
///
/// Fix: the original had an `ends_with("/v1")` branch whose body was
/// byte-identical to the fall-through (`format!("{base_url}/chat/completions")`),
/// i.e. dead code; it has been removed with no behavior change.
pub fn normalize_openai_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.openai.com/v1/chat/completions".to_string();
    }
    let base_url = trimmed.trim_end_matches('/');
    if base_url.ends_with("/chat/completions") {
        return base_url.to_string();
    }
    // A Responses endpoint becomes a chat-completions endpoint.
    let base_url = base_url.strip_suffix("/responses").unwrap_or(base_url);
    format!("{base_url}/chat/completions")
}
1006
/// Normalize an OpenAI-compatible base URL so it targets `/responses`.
///
/// Blank input yields the official endpoint. Otherwise trailing slashes are
/// stripped, a URL already ending in `/responses` is kept as-is, and a
/// chat-completions URL is converted by swapping the trailing segment.
///
/// Fix: the original had an `ends_with("/v1")` branch whose body was
/// byte-identical to the fall-through (`format!("{base_url}/responses")`),
/// i.e. dead code; it has been removed with no behavior change.
pub fn normalize_openai_responses_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.openai.com/v1/responses".to_string();
    }
    let base_url = trimmed.trim_end_matches('/');
    if base_url.ends_with("/responses") {
        return base_url.to_string();
    }
    // A chat-completions endpoint becomes a Responses endpoint.
    let base_url = base_url
        .strip_suffix("/chat/completions")
        .unwrap_or(base_url);
    format!("{base_url}/responses")
}
1024
1025pub fn normalize_openai_codex_responses_base(base_url: &str) -> String {
1026 let trimmed = base_url.trim();
1027 if trimmed.is_empty() {
1028 return openai_responses::CODEX_RESPONSES_API_URL.to_string();
1029 }
1030 let base = trimmed.trim_end_matches('/');
1031 if base.ends_with("/backend-api/codex/responses") {
1032 return base.to_string();
1033 }
1034 if base.ends_with("/backend-api") {
1038 return format!("{base}/codex/responses");
1039 }
1040 if base.ends_with("/responses") {
1041 return base.to_string();
1042 }
1043 format!("{base}/backend-api/codex/responses")
1044}
1045
/// Normalize a Cohere base URL so it targets the `/chat` endpoint.
///
/// Blank input yields the official v2 endpoint; otherwise trailing slashes
/// are stripped and `/chat` is appended unless already present.
///
/// Fix: the original had an `ends_with("/v2")` branch whose body was
/// byte-identical to the fall-through (`format!("{base_url}/chat")`),
/// i.e. dead code; it has been removed with no behavior change.
pub fn normalize_cohere_base(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.is_empty() {
        return "https://api.cohere.com/v2/chat".to_string();
    }
    let base_url = trimmed.trim_end_matches('/');
    if base_url.ends_with("/chat") {
        return base_url.to_string();
    }
    format!("{base_url}/chat")
}
1060
1061#[cfg(test)]
1062mod tests {
1063 use super::*;
1064 use crate::extensions::{ExtensionManager, JsExtensionLoadSpec, JsExtensionRuntimeHandle};
1065 use crate::extensions_js::PiJsRuntimeConfig;
1066 use crate::model::{ContentBlock, Message, UserContent, UserMessage};
1067 use crate::tools::ToolRegistry;
1068 use asupersync::runtime::RuntimeBuilder;
1069 use asupersync::time::{sleep, wall_now};
1070 use futures::StreamExt;
1071 use std::sync::Arc;
1072 use std::time::Duration;
1073 use tempfile::tempdir;
1074
    // JS fixture: registers a provider whose `streamSimple` first asserts the
    // model/context/options shapes it is handed, then emits a minimal
    // start -> text_start -> text_delta("hi") -> done event sequence.
    const STREAM_SIMPLE_EXTENSION: &str = r#"
export default function init(pi) {
  pi.registerProvider("stream-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "stream-model", name: "Stream Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      if (!model || !model.baseUrl || !model.maxTokens || !model.contextWindow) {
        throw new Error("bad model shape");
      }
      if (!context || !Array.isArray(context.messages)) {
        throw new Error("bad context shape");
      }
      if (!options || !options.signal) {
        throw new Error("missing abort signal");
      }

      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };
      partial.content[0].text += "hi";
      yield { type: "text_delta", contentIndex: 0, delta: "hi", partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;

    // JS fixture: after yielding its first event, this provider parks on the
    // abort signal; when the Rust side drops/cancels the stream, the `finally`
    // block writes `cancelled.txt` via the `write` tool so the test can
    // observe that JS-side cleanup ran.
    const STREAM_SIMPLE_CANCEL_EXTENSION: &str = r#"
export default function init(pi) {
  pi.registerProvider("cancel-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "cancel-model", name: "Cancel Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      try {
        yield { type: "start", partial };
        await new Promise((resolve) => {
          if (options && options.signal && options.signal.aborted) return resolve();
          if (options && options.signal && typeof options.signal.addEventListener === "function") {
            options.signal.addEventListener("abort", () => resolve());
          }
        });
      } finally {
        await pi.tool("write", { path: "cancelled.txt", content: "ok" });
      }
    }
  });
}
"#;
1152
    // Test helper: writes `source` to a temp dir as `ext.mjs`, starts a JS
    // extension runtime rooted at that dir, and loads the extension into a
    // fresh `ExtensionManager`. Returns the temp-dir guard (keep it alive for
    // the test's duration) together with the manager. `allow_write` controls
    // whether the `write` tool is exposed to the extension.
    async fn load_extension(
        source: &str,
        allow_write: bool,
    ) -> (tempfile::TempDir, ExtensionManager) {
        let dir = tempdir().expect("tempdir");
        let entry_path = dir.path().join("ext.mjs");
        std::fs::write(&entry_path, source).expect("write extension");

        let manager = ExtensionManager::new();
        // Only the cancellation fixture needs the `write` tool.
        let tools = if allow_write {
            Arc::new(ToolRegistry::new(&["write"], dir.path(), None))
        } else {
            Arc::new(ToolRegistry::new(&[], dir.path(), None))
        };

        let js_runtime = JsExtensionRuntimeHandle::start(
            PiJsRuntimeConfig {
                cwd: dir.path().display().to_string(),
                ..Default::default()
            },
            Arc::clone(&tools),
            manager.clone(),
        )
        .await
        .expect("start js runtime");
        // The runtime must be attached before extensions are loaded.
        manager.set_js_runtime(js_runtime);

        let spec = JsExtensionLoadSpec::from_entry_path(&entry_path).expect("load spec");
        manager
            .load_js_extensions(vec![spec])
            .await
            .expect("load extension");

        (dir, manager)
    }

    // Minimal request context: a system prompt plus one user message, no tools.
    fn basic_context() -> Context<'static> {
        Context {
            system_prompt: Some("system".to_string().into()),
            messages: vec![Message::User(UserMessage {
                content: UserContent::Text("hello".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        }
    }

    // Default stream options with a dummy API key.
    fn basic_options() -> StreamOptions {
        StreamOptions {
            api_key: Some("sk-test".to_string()),
            ..Default::default()
        }
    }
1207
    // Happy path: an extension-registered provider streams Start, TextDelta,
    // and Done events, and the final message carries the accumulated text.
    #[test]
    fn extension_stream_simple_provider_emits_assistant_events() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            assert_eq!(entries.len(), 1);
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            assert_eq!(provider.name(), "stream-provider");

            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_start = false;
            let mut saw_text_delta = false;
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::Start { .. } => {
                        saw_start = true;
                    }
                    StreamEvent::TextDelta { delta, .. } => {
                        assert_eq!(delta, "hi");
                        saw_text_delta = true;
                    }
                    StreamEvent::Done { reason, message } => {
                        assert_eq!(reason, StopReason::Stop);
                        let text = match &message.content[0] {
                            ContentBlock::Text(text) => text,
                            other => unreachable!("expected text content block, got {other:?}"),
                        };
                        assert_eq!(text.text, "hi");
                        break;
                    }
                    _ => {}
                }
            }

            assert!(saw_start, "expected a Start event");
            assert!(saw_text_delta, "expected a TextDelta event");
        });
    }

    // Dropping the Rust stream must cancel the JS generator; the fixture's
    // `finally` block then writes `cancelled.txt`, which this test polls for.
    #[test]
    fn extension_stream_simple_provider_drop_cancels_js_stream() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (dir, manager) = load_extension(STREAM_SIMPLE_CANCEL_EXTENSION, true).await;
            let entries = manager.extension_model_entries();
            assert_eq!(entries.len(), 1);
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "cancel-provider")
                .expect("cancel-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let first = stream.next().await.expect("first event");
            let _ = first.expect("first event ok");
            drop(stream);

            // Cancellation cleanup is asynchronous; poll briefly (up to ~1s).
            let out_path = dir.path().join("cancelled.txt");
            for _ in 0..200 {
                if out_path.exists() {
                    let contents = std::fs::read_to_string(&out_path).expect("read cancelled.txt");
                    assert_eq!(contents, "ok");
                    return;
                }
                sleep(wall_now(), Duration::from_millis(5)).await;
            }

            assert!(
                out_path.exists(),
                "expected cancelled.txt to be created after stream drop/cancel"
            );
        });
    }
1300
    // JS fixture: emits four text deltas ("Hello", ", ", "world", "!") in
    // order, then text_end and done — used to verify delta ordering.
    const STREAM_SIMPLE_MULTI_CHUNK: &str = r#"
export default function init(pi) {
  pi.registerProvider("multi-chunk-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "multi-model", name: "Multi Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };

      const chunks = ["Hello", ", ", "world", "!"];
      for (const chunk of chunks) {
        partial.content[0].text += chunk;
        yield { type: "text_delta", contentIndex: 0, delta: chunk, partial };
      }

      yield { type: "text_end", contentIndex: 0, content: partial.content[0].text, partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;

    // JS fixture: yields a start event then throws, to verify that a JS
    // exception surfaces to the Rust stream consumer.
    const STREAM_SIMPLE_ERROR: &str = r#"
export default function init(pi) {
  pi.registerProvider("error-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "error-model", name: "Error Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      throw new Error("simulated JS error during streaming");
    }
  });
}
"#;

    // JS fixture: emits a single multi-byte (CJK + emoji) delta to verify
    // UTF-8 text survives the JS/Rust boundary intact.
    const STREAM_SIMPLE_UNICODE: &str = r#"
export default function init(pi) {
  pi.registerProvider("unicode-provider", {
    baseUrl: "https://api.example.test",
    apiKey: "EXAMPLE_KEY",
    api: "custom-api",
    models: [
      { id: "unicode-model", name: "Unicode Model", contextWindow: 100, maxTokens: 10, input: ["text"] }
    ],
    streamSimple: async function* (model, context, options) {
      const partial = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
        stopReason: "stop",
        timestamp: 0
      };

      yield { type: "start", partial };
      yield { type: "text_start", contentIndex: 0, partial };
      partial.content[0].text = "日本語テスト 🦀";
      yield { type: "text_delta", contentIndex: 0, delta: "日本語テスト 🦀", partial };
      yield { type: "done", reason: "stop", message: partial };
    }
  });
}
"#;
1400
    // Deltas must arrive in emission order and the Done message must carry
    // the fully accumulated text.
    #[test]
    fn extension_stream_simple_multiple_chunks_in_order() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_MULTI_CHUNK, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "multi-chunk-provider")
                .expect("multi-chunk-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut deltas = Vec::new();
            let mut final_text = String::new();
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::TextDelta { delta, .. } => {
                        deltas.push(delta);
                    }
                    StreamEvent::Done { message, .. } => {
                        let text = match &message.content[0] {
                            ContentBlock::Text(text) => text,
                            other => unreachable!("expected text content block, got {other:?}"),
                        };
                        final_text = text.text.clone();
                        break;
                    }
                    _ => {}
                }
            }

            assert_eq!(deltas, vec!["Hello", ", ", "world", "!"]);
            assert_eq!(final_text, "Hello, world!");
        });
    }

    // A JS exception thrown mid-stream must surface to the consumer, either
    // as an Err item or as a StreamEvent::Error.
    #[test]
    fn extension_stream_simple_js_error_propagates() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_ERROR, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "error-provider")
                .expect("error-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_start = false;
            let mut saw_error = false;
            while let Some(item) = stream.next().await {
                match item {
                    Ok(StreamEvent::Start { .. }) => {
                        saw_start = true;
                    }
                    Err(err) => {
                        let msg = err.to_string();
                        assert!(
                            msg.contains("simulated JS error") || msg.contains("error"),
                            "expected JS error message, got: {msg}"
                        );
                        saw_error = true;
                        break;
                    }
                    Ok(StreamEvent::Error { .. }) => {
                        saw_error = true;
                        break;
                    }
                    _ => {}
                }
            }

            assert!(saw_start, "expected a Start event before error");
            assert!(saw_error, "expected JS error to propagate");
        });
    }

    // Multi-byte UTF-8 text (CJK + emoji) must cross the JS/Rust boundary
    // without corruption.
    #[test]
    fn extension_stream_simple_unicode_content() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_UNICODE, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "unicode-provider")
                .expect("unicode-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            let ctx = basic_context();
            let opts = basic_options();
            let mut stream = provider.stream(&ctx, &opts).await.expect("stream");

            let mut saw_unicode = false;
            while let Some(item) = stream.next().await {
                let event = item.expect("stream event");
                match event {
                    StreamEvent::TextDelta { delta, .. } => {
                        assert_eq!(delta, "日本語テスト 🦀");
                        saw_unicode = true;
                    }
                    StreamEvent::Done { .. } => break,
                    _ => {}
                }
            }

            assert!(saw_unicode, "expected unicode text delta");
        });
    }

    // The extension-backed provider must report the name/model/api that the
    // JS registration declared.
    #[test]
    fn extension_stream_simple_provider_name_and_model() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager)).expect("create provider");
            assert_eq!(provider.name(), "stream-provider");
            assert_eq!(provider.model_id(), "stream-model");
            assert_eq!(provider.api(), "custom-api");
        });
    }

    // create_provider needs the ExtensionManager to build an extension-backed
    // provider; without it the same entry must fail.
    #[test]
    fn create_provider_returns_extension_provider_for_stream_simple() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let (_dir, manager) = load_extension(STREAM_SIMPLE_EXTENSION, false).await;
            let entries = manager.extension_model_entries();
            let entry = entries
                .iter()
                .find(|e| e.model.provider == "stream-provider")
                .expect("stream-provider entry");

            let provider = create_provider(entry, Some(&manager));
            assert!(provider.is_ok());

            let provider_no_ext = create_provider(entry, None);
            assert!(provider_no_ext.is_err());
        });
    }
1574
1575 use crate::models::ModelEntry;
1580 use crate::provider::{InputType, Model, ModelCost};
1581 use std::collections::HashMap;
1582
    // Test helper: builds a fully-populated `ModelEntry`; only the provider,
    // api, model id, and base URL vary per test. Cost/window/token values are
    // arbitrary but fixed.
    fn model_entry(provider: &str, api: &str, model_id: &str, base_url: &str) -> ModelEntry {
        ModelEntry {
            model: Model {
                id: model_id.to_string(),
                name: model_id.to_string(),
                api: api.to_string(),
                provider: provider.to_string(),
                base_url: base_url.to_string(),
                reasoning: false,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 3.0,
                    output: 15.0,
                    cache_read: 0.3,
                    cache_write: 3.75,
                },
                context_window: 200_000,
                max_tokens: 8192,
                headers: HashMap::new(),
            },
            api_key: Some("sk-test-key".to_string()),
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }
1610
    // Alias providers (e.g. "kimi") resolve to their canonical id via metadata.
    #[test]
    fn resolve_provider_route_uses_metadata_for_alias_provider() {
        let entry = model_entry(
            "kimi",
            "openai-completions",
            "kimi-k2-instruct",
            "https://api.moonshot.ai/v1",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve alias route");
        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
        assert_eq!(canonical_provider, "moonshotai");
        assert_eq!(effective_api, "openai-completions");
    }

    // An unrecognized api string on the "openai" provider falls back to the
    // native Responses route.
    #[test]
    fn resolve_provider_route_openai_unknown_api_defaults_to_native_responses() {
        let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve openai route");
        assert_eq!(route, ProviderRouteKind::NativeOpenAIResponses);
        assert_eq!(canonical_provider, "openai");
        assert_eq!(effective_api, "openai");
    }

    // Cloudflare Workers AI with an empty api defaults to openai-completions.
    #[test]
    fn resolve_provider_route_cloudflare_workers_defaults_to_openai_completions() {
        let entry = model_entry(
            "cloudflare-workers-ai",
            "",
            "@cf/meta/llama-3.1-8b-instruct",
            "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve cloudflare workers route");
        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
        assert_eq!(canonical_provider, "cloudflare-workers-ai");
        assert_eq!(effective_api, "openai-completions");
    }

    // Cloudflare AI Gateway with an empty api also defaults to openai-completions.
    #[test]
    fn resolve_provider_route_cloudflare_gateway_defaults_to_openai_completions() {
        let entry = model_entry(
            "cloudflare-ai-gateway",
            "",
            "gpt-4o-mini",
            "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve cloudflare gateway route");
        assert_eq!(route, ProviderRouteKind::ApiOpenAICompletions);
        assert_eq!(canonical_provider, "cloudflare-ai-gateway");
        assert_eq!(effective_api, "openai-completions");
    }

    // The "azure-cognitive-services" alias routes natively to Azure under the
    // canonical "azure-openai" id.
    #[test]
    fn resolve_provider_route_uses_native_azure_route_for_cognitive_alias() {
        let entry = model_entry(
            "azure-cognitive-services",
            "openai-completions",
            "gpt-4o-mini",
            "https://myresource.cognitiveservices.azure.com",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve azure cognitive route");
        assert_eq!(route, ProviderRouteKind::NativeAzure);
        assert_eq!(canonical_provider, "azure-openai");
        assert_eq!(effective_api, "openai-completions");
    }

    // Legacy provider id "azure-openai-responses" is still accepted and
    // canonicalized to "azure-openai".
    #[test]
    fn resolve_provider_route_uses_native_azure_route_for_legacy_provider_alias() {
        let entry = model_entry(
            "azure-openai-responses",
            "azure-openai-responses",
            "gpt-4o-mini",
            "https://myresource.openai.azure.com",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve azure legacy alias route");
        assert_eq!(route, ProviderRouteKind::NativeAzure);
        assert_eq!(canonical_provider, "azure-openai");
        assert_eq!(effective_api, "azure-openai-responses");
    }

    // Even a custom provider id routes to Azure when the legacy api string is
    // used; the provider id is passed through unchanged.
    #[test]
    fn resolve_provider_route_accepts_azure_legacy_api_for_custom_provider_id() {
        let entry = model_entry(
            "my-azure",
            "azure-openai-responses",
            "gpt-4o-mini",
            "https://example.invalid",
        );
        let (route, canonical_provider, effective_api) =
            resolve_provider_route(&entry).expect("resolve azure legacy api fallback");
        assert_eq!(route, ProviderRouteKind::NativeAzure);
        assert_eq!(canonical_provider, "my-azure");
        assert_eq!(effective_api, "azure-openai-responses");
    }
1710
    // An api_key set directly on the model entry wins over the environment.
    #[test]
    fn resolve_copilot_token_prefers_inline_model_api_key() {
        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
        entry.api_key = Some("inline-copilot-token".to_string());

        let token = resolve_copilot_token_with_env(&entry, |_| None)
            .expect("inline token should be accepted");
        assert_eq!(token, "inline-copilot-token");
    }

    // Without an inline key, GITHUB_COPILOT_API_KEY from the (injected)
    // environment is used.
    #[test]
    fn resolve_copilot_token_falls_back_to_env() {
        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
        entry.api_key = None;

        let token = resolve_copilot_token_with_env(&entry, |name| match name {
            "GITHUB_COPILOT_API_KEY" => Some("env-copilot-token".to_string()),
            _ => None,
        })
        .expect("env token should be accepted");
        assert_eq!(token, "env-copilot-token");
    }

    // No inline key and no env var yields a descriptive error.
    #[test]
    fn resolve_copilot_token_errors_when_missing_everywhere() {
        let mut entry = model_entry("github-copilot", "", "gpt-4o", "");
        entry.api_key = None;

        let err = resolve_copilot_token_with_env(&entry, |_| None).expect_err("expected error");
        assert!(
            err.to_string().contains("GitHub Copilot requires"),
            "unexpected error: {err}"
        );
    }
1745
    // Prefix matches ("deep" -> deepinfra/deepseek) are suggested.
    #[test]
    fn suggest_similar_providers_finds_prefix_match() {
        let suggestions = suggest_similar_providers("deep");
        assert!(
            suggestions.contains(&"deepinfra".to_string())
                || suggestions.contains(&"deepseek".to_string()),
            "expected deepinfra or deepseek in suggestions: {suggestions:?}"
        );
    }

    // Substring matches ("flow" inside "siliconflow") are suggested.
    #[test]
    fn suggest_similar_providers_finds_substring_match() {
        let suggestions = suggest_similar_providers("flow");
        assert!(
            suggestions.contains(&"siliconflow".to_string()),
            "expected siliconflow in suggestions: {suggestions:?}"
        );
    }

    // Nonsense input must produce no suggestions.
    #[test]
    fn suggest_similar_providers_returns_empty_for_gibberish() {
        let suggestions = suggest_similar_providers("xyzzzabc123");
        assert!(
            suggestions.is_empty(),
            "expected no suggestions for gibberish: {suggestions:?}"
        );
    }

    // Even a one-letter query that matches many providers returns at most 3.
    #[test]
    fn suggest_similar_providers_caps_at_three() {
        let suggestions = suggest_similar_providers("a");
        assert!(
            suggestions.len() <= 3,
            "expected at most 3 suggestions: {suggestions:?}"
        );
    }

    // Levenshtein distance sanity checks, including the classic
    // kitten/sitting example.
    #[test]
    fn edit_distance_basic_cases() {
        assert_eq!(edit_distance(b"", b""), 0);
        assert_eq!(edit_distance(b"abc", b"abc"), 0);
        assert_eq!(edit_distance(b"abc", b"ab"), 1);
        assert_eq!(edit_distance(b"abc", b"axc"), 1);
        assert_eq!(edit_distance(b"abc", b"abcd"), 1);
        assert_eq!(edit_distance(b"kitten", b"sitting"), 3);
        assert_eq!(edit_distance(b"", b"hello"), 5);
    }

    // A trailing-character typo is caught by the edit-distance fallback.
    #[test]
    fn suggest_similar_providers_finds_typo_with_edit_distance() {
        let suggestions = suggest_similar_providers("anthropick");
        assert!(
            suggestions.contains(&"anthropic".to_string()),
            "expected anthropic for typo 'anthropick': {suggestions:?}"
        );
    }

    // A missing-character typo ("opnai") still matches "openai".
    #[test]
    fn suggest_similar_providers_finds_typo_missing_char() {
        let suggestions = suggest_similar_providers("opnai");
        assert!(
            suggestions.contains(&"openai".to_string()),
            "expected openai for typo 'opnai': {suggestions:?}"
        );
    }

    // A dropped repeated character ("gogle") still matches "google".
    #[test]
    fn suggest_similar_providers_finds_transposed_chars() {
        let suggestions = suggest_similar_providers("gogle");
        assert!(
            suggestions.contains(&"google".to_string()),
            "expected google for typo 'gogle': {suggestions:?}"
        );
    }

    // Very short inputs must not fuzzy-match anything.
    #[test]
    fn suggest_similar_providers_no_false_positives_for_short_input() {
        let suggestions = suggest_similar_providers("xy");
        assert!(
            suggestions.is_empty(),
            "expected no suggestions for 'xy': {suggestions:?}"
        );
    }
1833
    // A bare *.openai.azure.com base yields the default api-version and a
    // deployment derived from the model id.
    #[test]
    fn resolve_azure_provider_runtime_supports_openai_host() {
        let entry = model_entry(
            "azure-openai",
            "openai-completions",
            "gpt-4o",
            "https://myresource.openai.azure.com",
        );
        let runtime =
            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
        assert_eq!(runtime.resource, "myresource");
        assert_eq!(runtime.deployment, "gpt-4o");
        assert_eq!(runtime.api_version, "2024-02-15-preview");
        assert_eq!(
            runtime.endpoint_url,
            "https://myresource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-02-15-preview"
        );
    }

    // A full cognitiveservices URL keeps its embedded api-version, but the
    // deployment is taken from the model id (not the URL's path segment).
    #[test]
    fn resolve_azure_provider_runtime_supports_cognitive_services_host() {
        let entry = model_entry(
            "azure-cognitive-services",
            "openai-completions",
            "gpt-4o-mini",
            "https://myresource.cognitiveservices.azure.com/openai/deployments/custom/chat/completions?api-version=2024-10-21",
        );
        let runtime =
            resolve_azure_provider_runtime_with_env(&entry, |_| None).expect("resolve runtime");
        assert_eq!(runtime.resource, "myresource");
        assert_eq!(runtime.deployment, "gpt-4o-mini");
        assert_eq!(runtime.api_version, "2024-10-21");
        assert_eq!(
            runtime.endpoint_url,
            "https://myresource.cognitiveservices.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-10-21"
        );
    }
1871
    // Each test below constructs a provider purely from a well-known provider
    // name (no ExtensionManager) and checks its reported identity.

    #[test]
    fn create_provider_anthropic_by_name() {
        let entry = model_entry(
            "anthropic",
            "anthropic-messages",
            "claude-sonnet-4-5",
            "https://api.anthropic.com",
        );
        let provider = create_provider(&entry, None).expect("anthropic provider");
        assert_eq!(provider.name(), "anthropic");
        assert_eq!(provider.model_id(), "claude-sonnet-4-5");
        assert_eq!(provider.api(), "anthropic-messages");
    }

    #[test]
    fn create_provider_openai_completions_by_name() {
        let entry = model_entry(
            "openai",
            "openai-completions",
            "gpt-4o",
            "https://api.openai.com/v1",
        );
        let provider = create_provider(&entry, None).expect("openai completions provider");
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.model_id(), "gpt-4o");
    }

    #[test]
    fn create_provider_openai_responses_by_name() {
        let entry = model_entry(
            "openai",
            "openai-responses",
            "gpt-4o",
            "https://api.openai.com/v1",
        );
        let provider = create_provider(&entry, None).expect("openai responses provider");
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.model_id(), "gpt-4o");
    }

    // An unrecognized "openai" api string still produces a working provider
    // (defaults to the Responses API).
    #[test]
    fn create_provider_openai_defaults_to_responses() {
        let entry = model_entry("openai", "openai", "gpt-4o", "https://api.openai.com/v1");
        let provider = create_provider(&entry, None).expect("openai default responses provider");
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn create_provider_google_by_name() {
        let entry = model_entry(
            "google",
            "google-generative-ai",
            "gemini-2.0-flash",
            "https://generativelanguage.googleapis.com",
        );
        let provider = create_provider(&entry, None).expect("google provider");
        assert_eq!(provider.name(), "google");
        assert_eq!(provider.model_id(), "gemini-2.0-flash");
    }

    #[test]
    fn create_provider_cohere_by_name() {
        let entry = model_entry(
            "cohere",
            "cohere-chat",
            "command-r-plus",
            "https://api.cohere.com/v2",
        );
        let provider = create_provider(&entry, None).expect("cohere provider");
        assert_eq!(provider.name(), "cohere");
        assert_eq!(provider.model_id(), "command-r-plus");
    }

    // Azure providers normalize their reported name to "azure" and api to
    // "azure-openai" regardless of the entry's api string.
    #[test]
    fn create_provider_azure_openai_by_name() {
        let entry = model_entry(
            "azure-openai",
            "openai-completions",
            "gpt-4o",
            "https://myresource.openai.azure.com",
        );
        let provider = create_provider(&entry, None).expect("azure provider");
        assert_eq!(provider.name(), "azure");
        assert_eq!(provider.api(), "azure-openai");
        assert!(!provider.model_id().is_empty());
    }

    #[test]
    fn create_provider_azure_cognitive_services_alias_by_name() {
        let entry = model_entry(
            "azure-cognitive-services",
            "openai-completions",
            "gpt-4o-mini",
            "https://myresource.cognitiveservices.azure.com",
        );
        let provider = create_provider(&entry, None).expect("azure cognitive provider");
        assert_eq!(provider.name(), "azure");
        assert_eq!(provider.api(), "azure-openai");
        assert!(!provider.model_id().is_empty());
    }

    #[test]
    fn create_provider_cloudflare_workers_ai_by_name() {
        let entry = model_entry(
            "cloudflare-workers-ai",
            "",
            "@cf/meta/llama-3.1-8b-instruct",
            "https://api.cloudflare.com/client/v4/accounts/test-account/ai/v1",
        );
        let provider = create_provider(&entry, None).expect("cloudflare workers provider");
        assert_eq!(provider.name(), "cloudflare-workers-ai");
        assert_eq!(provider.api(), "openai-completions");
        assert_eq!(provider.model_id(), "@cf/meta/llama-3.1-8b-instruct");
    }

    #[test]
    fn create_provider_cloudflare_ai_gateway_by_name() {
        let entry = model_entry(
            "cloudflare-ai-gateway",
            "",
            "gpt-4o-mini",
            "https://gateway.ai.cloudflare.com/v1/account-id/gateway-id/openai",
        );
        let provider = create_provider(&entry, None).expect("cloudflare gateway provider");
        assert_eq!(provider.name(), "cloudflare-ai-gateway");
        assert_eq!(provider.api(), "openai-completions");
        assert_eq!(provider.model_id(), "gpt-4o-mini");
    }
2003
    // Each test below uses an unknown provider name, so routing must fall
    // back to the entry's api string.

    #[test]
    fn create_provider_falls_back_to_api_anthropic_messages() {
        let entry = model_entry(
            "custom-anthropic",
            "anthropic-messages",
            "my-model",
            "https://custom.api.com",
        );
        let provider = create_provider(&entry, None).expect("fallback anthropic provider");
        assert_eq!(provider.model_id(), "my-model");
    }

    #[test]
    fn create_provider_falls_back_to_api_openai_completions() {
        let entry = model_entry(
            "my-openai-compat",
            "openai-completions",
            "local-model",
            "http://localhost:8080/v1",
        );
        let provider = create_provider(&entry, None).expect("fallback openai completions");
        assert_eq!(provider.model_id(), "local-model");
    }

    #[test]
    fn create_provider_falls_back_to_api_openai_responses() {
        let entry = model_entry(
            "my-openai-compat",
            "openai-responses",
            "local-model",
            "http://localhost:8080/v1",
        );
        let provider = create_provider(&entry, None).expect("fallback openai responses");
        assert_eq!(provider.model_id(), "local-model");
    }

    #[test]
    fn create_provider_falls_back_to_api_cohere_chat() {
        let entry = model_entry(
            "custom-cohere",
            "cohere-chat",
            "custom-r",
            "https://custom-cohere.api.com/v2",
        );
        let provider = create_provider(&entry, None).expect("fallback cohere provider");
        assert_eq!(provider.model_id(), "custom-r");
    }

    #[test]
    fn create_provider_falls_back_to_api_google() {
        let entry = model_entry(
            "custom-google",
            "google-generative-ai",
            "custom-gemini",
            "https://custom.google.com",
        );
        let provider = create_provider(&entry, None).expect("fallback google provider");
        assert_eq!(provider.model_id(), "custom-gemini");
    }
2066
2067 #[test]
2068 fn resolve_provider_route_copilot_routes_correctly() {
2069 let entry = model_entry("github-copilot", "", "gpt-4o", "");
2070 let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot route");
2071 assert_eq!(route, ProviderRouteKind::NativeCopilot);
2072 assert_eq!(canonical, "github-copilot");
2073 }
2074
2075 #[test]
2076 fn resolve_provider_route_copilot_alias_routes_correctly() {
2077 let entry = model_entry("copilot", "", "gpt-4o", "");
2078 let (route, canonical, _api) = resolve_provider_route(&entry).expect("copilot alias route");
2079 assert_eq!(route, ProviderRouteKind::NativeCopilot);
2080 assert_eq!(canonical, "github-copilot");
2081 }
2082
2083 #[test]
2084 fn create_provider_unknown_provider_and_api_returns_error() {
2085 let entry = model_entry(
2086 "totally-unknown",
2087 "unknown-api",
2088 "some-model",
2089 "https://example.com",
2090 );
2091 let Err(err) = create_provider(&entry, None) else {
2092 panic!("unknown should fail");
2093 };
2094 let msg = err.to_string();
2095 assert!(
2096 msg.contains("not implemented"),
2097 "expected 'not implemented' message, got: {msg}"
2098 );
2099 }
2100
2101 #[test]
2104 fn normalize_anthropic_base_appends_v1_messages() {
2105 assert_eq!(
2106 normalize_anthropic_base("https://api.anthropic.com"),
2107 "https://api.anthropic.com/v1/messages"
2108 );
2109 }
2110
2111 #[test]
2112 fn normalize_anthropic_base_keeps_existing_v1_messages() {
2113 assert_eq!(
2114 normalize_anthropic_base("https://api.anthropic.com/v1/messages"),
2115 "https://api.anthropic.com/v1/messages"
2116 );
2117 }
2118
2119 #[test]
2120 fn normalize_anthropic_base_strips_trailing_slash() {
2121 assert_eq!(
2122 normalize_anthropic_base("https://api.anthropic.com/"),
2123 "https://api.anthropic.com/v1/messages"
2124 );
2125 }
2126
2127 #[test]
2128 fn normalize_anthropic_base_empty_uses_default() {
2129 assert_eq!(
2130 normalize_anthropic_base(" "),
2131 "https://api.anthropic.com/v1/messages"
2132 );
2133 }
2134
2135 #[test]
2138 fn normalize_openai_base_appends_chat_completions_to_v1() {
2139 assert_eq!(
2140 normalize_openai_base("https://api.openai.com/v1"),
2141 "https://api.openai.com/v1/chat/completions"
2142 );
2143 }
2144
2145 #[test]
2146 fn normalize_openai_base_keeps_existing_chat_completions() {
2147 assert_eq!(
2148 normalize_openai_base("https://api.openai.com/v1/chat/completions"),
2149 "https://api.openai.com/v1/chat/completions"
2150 );
2151 }
2152
2153 #[test]
2154 fn normalize_openai_base_strips_trailing_slash() {
2155 assert_eq!(
2156 normalize_openai_base("https://api.openai.com/v1/"),
2157 "https://api.openai.com/v1/chat/completions"
2158 );
2159 }
2160
2161 #[test]
2162 fn normalize_openai_base_strips_responses_suffix() {
2163 assert_eq!(
2164 normalize_openai_base("https://api.openai.com/v1/responses"),
2165 "https://api.openai.com/v1/chat/completions"
2166 );
2167 }
2168
2169 #[test]
2170 fn normalize_openai_base_bare_url_gets_chat_completions() {
2171 assert_eq!(
2172 normalize_openai_base("https://my-llm-proxy.com"),
2173 "https://my-llm-proxy.com/chat/completions"
2174 );
2175 }
2176
2177 #[test]
2178 fn normalize_openai_base_empty_uses_default() {
2179 assert_eq!(
2180 normalize_openai_base(""),
2181 "https://api.openai.com/v1/chat/completions"
2182 );
2183 }
2184
2185 #[test]
2188 fn normalize_responses_appends_responses_to_v1() {
2189 assert_eq!(
2190 normalize_openai_responses_base("https://api.openai.com/v1"),
2191 "https://api.openai.com/v1/responses"
2192 );
2193 }
2194
2195 #[test]
2196 fn normalize_responses_keeps_existing_responses() {
2197 assert_eq!(
2198 normalize_openai_responses_base("https://api.openai.com/v1/responses"),
2199 "https://api.openai.com/v1/responses"
2200 );
2201 }
2202
2203 #[test]
2204 fn normalize_responses_strips_trailing_slash() {
2205 assert_eq!(
2206 normalize_openai_responses_base("https://api.openai.com/v1/"),
2207 "https://api.openai.com/v1/responses"
2208 );
2209 }
2210
2211 #[test]
2212 fn normalize_responses_strips_chat_completions_suffix() {
2213 assert_eq!(
2214 normalize_openai_responses_base("https://api.openai.com/v1/chat/completions"),
2215 "https://api.openai.com/v1/responses"
2216 );
2217 }
2218
2219 #[test]
2220 fn normalize_responses_bare_url_gets_responses() {
2221 assert_eq!(
2222 normalize_openai_responses_base("https://my-llm-proxy.com"),
2223 "https://my-llm-proxy.com/responses"
2224 );
2225 }
2226
2227 #[test]
2228 fn normalize_responses_base_empty_uses_default() {
2229 assert_eq!(
2230 normalize_openai_responses_base(" "),
2231 "https://api.openai.com/v1/responses"
2232 );
2233 }
2234
2235 #[test]
2238 fn normalize_cohere_appends_chat_to_v2() {
2239 assert_eq!(
2240 normalize_cohere_base("https://api.cohere.com/v2"),
2241 "https://api.cohere.com/v2/chat"
2242 );
2243 }
2244
2245 #[test]
2246 fn normalize_cohere_keeps_existing_chat() {
2247 assert_eq!(
2248 normalize_cohere_base("https://api.cohere.com/v2/chat"),
2249 "https://api.cohere.com/v2/chat"
2250 );
2251 }
2252
2253 #[test]
2254 fn normalize_cohere_strips_trailing_slash() {
2255 assert_eq!(
2256 normalize_cohere_base("https://api.cohere.com/v2/"),
2257 "https://api.cohere.com/v2/chat"
2258 );
2259 }
2260
2261 #[test]
2262 fn normalize_cohere_bare_url_gets_chat() {
2263 assert_eq!(
2264 normalize_cohere_base("https://custom-cohere.example.com"),
2265 "https://custom-cohere.example.com/chat"
2266 );
2267 }
2268
2269 #[test]
2270 fn normalize_cohere_base_empty_uses_default() {
2271 assert_eq!(normalize_cohere_base(""), "https://api.cohere.com/v2/chat");
2272 }
2273
    // Property-based checks for the base-URL normalizers. Each property feeds
    // randomly generated URL-ish strings through a normalizer and asserts two
    // invariants: the result ends with the canonical endpoint path, and
    // re-normalizing the result is a no-op (idempotence).
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            // Anthropic: output always targets /v1/messages and is a fixed point.
            #[test]
            fn normalize_anthropic_base_is_idempotent_and_targets_v1_messages(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_anthropic_base(&base);
                prop_assert!(normalized.ends_with("/v1/messages"));
                prop_assert_eq!(normalize_anthropic_base(&normalized), normalized);
            }

            // OpenAI completions: output always targets /chat/completions and is a fixed point.
            #[test]
            fn normalize_openai_base_is_idempotent_and_targets_chat_completions(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_openai_base(&base);
                prop_assert!(normalized.ends_with("/chat/completions"));
                prop_assert_eq!(normalize_openai_base(&normalized), normalized);
            }

            // OpenAI responses: output always targets /responses and is a fixed point.
            #[test]
            fn normalize_openai_responses_base_is_idempotent_and_targets_responses(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_openai_responses_base(&base);
                prop_assert!(normalized.ends_with("/responses"));
                prop_assert_eq!(normalize_openai_responses_base(&normalized), normalized);
            }

            // Cohere: output always targets /chat and is a fixed point.
            #[test]
            fn normalize_cohere_base_is_idempotent_and_targets_chat(
                base in "[A-Za-z0-9:/._-]{1,96}"
            ) {
                let normalized = normalize_cohere_base(&base);
                prop_assert!(normalized.ends_with("/chat"));
                prop_assert_eq!(normalize_cohere_base(&normalized), normalized);
            }

            // A /v1/responses suffix (with any number of trailing slashes) is
            // rewritten to /v1/chat/completions for the completions API.
            #[test]
            fn normalize_openai_base_rewrites_responses_suffix(
                host in "[a-z0-9-]{1,32}",
                trailing_slashes in 0usize..4
            ) {
                let base = format!(
                    "https://{host}.example/v1/responses{}",
                    "/".repeat(trailing_slashes)
                );
                prop_assert_eq!(
                    normalize_openai_base(&base),
                    format!("https://{host}.example/v1/chat/completions")
                );
            }

            // The reverse rewrite: /v1/chat/completions (with any number of
            // trailing slashes) becomes /v1/responses for the responses API.
            #[test]
            fn normalize_openai_responses_base_rewrites_chat_completions_suffix(
                host in "[a-z0-9-]{1,32}",
                trailing_slashes in 0usize..4
            ) {
                let base = format!(
                    "https://{host}.example/v1/chat/completions{}",
                    "/".repeat(trailing_slashes)
                );
                prop_assert_eq!(
                    normalize_openai_responses_base(&base),
                    format!("https://{host}.example/v1/responses")
                );
            }
        }
    }
2346
2347 use crate::models::CompatConfig;
2350
2351 fn compat_with_custom_headers() -> CompatConfig {
2352 let mut custom = HashMap::new();
2353 custom.insert("X-Custom-Header".to_string(), "test-value".to_string());
2354 custom.insert("X-Provider-Tag".to_string(), "override".to_string());
2355 CompatConfig {
2356 custom_headers: Some(custom),
2357 ..Default::default()
2358 }
2359 }
2360
2361 fn model_entry_with_compat(
2362 provider: &str,
2363 api: &str,
2364 model_id: &str,
2365 base_url: &str,
2366 compat: CompatConfig,
2367 ) -> ModelEntry {
2368 let mut entry = model_entry(provider, api, model_id, base_url);
2369 entry.compat = Some(compat);
2370 entry
2371 }
2372
2373 #[test]
2374 fn create_provider_anthropic_accepts_compat_config() {
2375 let entry = model_entry_with_compat(
2376 "anthropic",
2377 "anthropic-messages",
2378 "claude-sonnet-4-5",
2379 "https://api.anthropic.com",
2380 compat_with_custom_headers(),
2381 );
2382 let provider = create_provider(&entry, None).expect("anthropic with compat");
2383 assert_eq!(provider.name(), "anthropic");
2384 }
2385
2386 #[test]
2387 fn create_provider_openai_completions_accepts_compat_config() {
2388 let entry = model_entry_with_compat(
2389 "openai",
2390 "openai-completions",
2391 "gpt-4o",
2392 "https://api.openai.com/v1",
2393 CompatConfig {
2394 max_tokens_field: Some("max_completion_tokens".to_string()),
2395 system_role_name: Some("developer".to_string()),
2396 supports_tools: Some(false),
2397 ..Default::default()
2398 },
2399 );
2400 let provider = create_provider(&entry, None).expect("openai completions with compat");
2401 assert_eq!(provider.name(), "openai");
2402 }
2403
2404 #[test]
2405 fn create_provider_openai_responses_accepts_compat_config() {
2406 let entry = model_entry_with_compat(
2407 "openai",
2408 "openai-responses",
2409 "gpt-4o",
2410 "https://api.openai.com/v1",
2411 compat_with_custom_headers(),
2412 );
2413 let provider = create_provider(&entry, None).expect("openai responses with compat");
2414 assert_eq!(provider.name(), "openai");
2415 }
2416
2417 #[test]
2418 fn create_provider_cohere_accepts_compat_config() {
2419 let entry = model_entry_with_compat(
2420 "cohere",
2421 "cohere-chat",
2422 "command-r-plus",
2423 "https://api.cohere.com/v2",
2424 compat_with_custom_headers(),
2425 );
2426 let provider = create_provider(&entry, None).expect("cohere with compat");
2427 assert_eq!(provider.name(), "cohere");
2428 }
2429
2430 #[test]
2431 fn create_provider_google_accepts_compat_config() {
2432 let entry = model_entry_with_compat(
2433 "google",
2434 "google-generative-ai",
2435 "gemini-2.0-flash",
2436 "https://generativelanguage.googleapis.com",
2437 compat_with_custom_headers(),
2438 );
2439 let provider = create_provider(&entry, None).expect("google with compat");
2440 assert_eq!(provider.name(), "google");
2441 }
2442
2443 #[test]
2444 fn create_provider_fallback_api_routes_accept_compat_config() {
2445 let entry = model_entry_with_compat(
2447 "custom-anthropic",
2448 "anthropic-messages",
2449 "my-model",
2450 "https://custom.api.com",
2451 compat_with_custom_headers(),
2452 );
2453 let provider = create_provider(&entry, None).expect("fallback anthropic with compat");
2454 assert_eq!(provider.model_id(), "my-model");
2455
2456 let entry = model_entry_with_compat(
2458 "my-groq-clone",
2459 "openai-completions",
2460 "llama-3.1",
2461 "http://localhost:8080/v1",
2462 compat_with_custom_headers(),
2463 );
2464 let provider = create_provider(&entry, None).expect("fallback openai with compat");
2465 assert_eq!(provider.model_id(), "llama-3.1");
2466
2467 let entry = model_entry_with_compat(
2469 "custom-cohere",
2470 "cohere-chat",
2471 "custom-r",
2472 "https://custom-cohere.api.com/v2",
2473 compat_with_custom_headers(),
2474 );
2475 let provider = create_provider(&entry, None).expect("fallback cohere with compat");
2476 assert_eq!(provider.model_id(), "custom-r");
2477
2478 let entry = model_entry_with_compat(
2480 "custom-google",
2481 "google-generative-ai",
2482 "custom-gemini",
2483 "https://custom.google.com",
2484 compat_with_custom_headers(),
2485 );
2486 let provider = create_provider(&entry, None).expect("fallback google with compat");
2487 assert_eq!(provider.model_id(), "custom-gemini");
2488 }
2489
2490 #[test]
2493 fn resolve_provider_route_google_vertex_routes_to_native() {
2494 let entry = model_entry(
2495 "google-vertex",
2496 "google-vertex",
2497 "gemini-2.0-flash",
2498 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2499 );
2500 let (route, canonical_provider, effective_api) =
2501 resolve_provider_route(&entry).expect("resolve google-vertex route");
2502 assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2503 assert_eq!(canonical_provider, "google-vertex");
2504 assert_eq!(effective_api, "google-vertex");
2505 }
2506
2507 #[test]
2508 fn resolve_provider_route_vertexai_alias_routes_to_native() {
2509 let entry = model_entry(
2510 "vertexai",
2511 "google-vertex",
2512 "gemini-2.0-flash",
2513 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2514 );
2515 let (route, canonical_provider, effective_api) =
2516 resolve_provider_route(&entry).expect("resolve vertexai alias route");
2517 assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2518 assert_eq!(canonical_provider, "google-vertex");
2519 assert_eq!(effective_api, "google-vertex");
2520 }
2521
2522 #[test]
2523 fn resolve_provider_route_google_vertex_api_fallback() {
2524 let entry = model_entry(
2526 "custom-vertex",
2527 "google-vertex",
2528 "gemini-2.0-flash",
2529 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2530 );
2531 let (route, _canonical_provider, effective_api) =
2532 resolve_provider_route(&entry).expect("resolve google-vertex fallback");
2533 assert_eq!(route, ProviderRouteKind::NativeGoogleVertex);
2534 assert_eq!(effective_api, "google-vertex");
2535 }
2536
2537 #[test]
2538 fn create_provider_google_vertex_from_full_url() {
2539 let entry = model_entry(
2540 "google-vertex",
2541 "google-vertex",
2542 "gemini-2.0-flash",
2543 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2544 );
2545 let provider = create_provider(&entry, None).expect("google-vertex from full URL");
2546 assert_eq!(provider.name(), "google-vertex");
2547 assert_eq!(provider.api(), "google-vertex");
2548 assert_eq!(provider.model_id(), "gemini-2.0-flash");
2549 }
2550
2551 #[test]
2552 fn create_provider_google_vertex_anthropic_publisher() {
2553 let entry = model_entry(
2554 "google-vertex",
2555 "google-vertex",
2556 "claude-sonnet-4-5",
2557 "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-sonnet-4-5",
2558 );
2559 let provider =
2560 create_provider(&entry, None).expect("google-vertex with anthropic publisher");
2561 assert_eq!(provider.name(), "google-vertex");
2562 assert_eq!(provider.model_id(), "claude-sonnet-4-5");
2563 }
2564
2565 #[test]
2566 fn create_provider_google_vertex_accepts_compat_config() {
2567 let entry = model_entry_with_compat(
2568 "google-vertex",
2569 "google-vertex",
2570 "gemini-2.0-flash",
2571 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash",
2572 compat_with_custom_headers(),
2573 );
2574 let provider = create_provider(&entry, None).expect("google-vertex with compat");
2575 assert_eq!(provider.name(), "google-vertex");
2576 }
2577
2578 #[test]
2579 fn create_provider_compat_none_accepted_by_all_routes() {
2580 let routes = [
2582 (
2583 "anthropic",
2584 "anthropic-messages",
2585 "https://api.anthropic.com",
2586 ),
2587 ("openai", "openai-completions", "https://api.openai.com/v1"),
2588 ("openai", "openai-responses", "https://api.openai.com/v1"),
2589 ("cohere", "cohere-chat", "https://api.cohere.com/v2"),
2590 (
2591 "google",
2592 "google-generative-ai",
2593 "https://generativelanguage.googleapis.com",
2594 ),
2595 (
2596 "google-vertex",
2597 "google-vertex",
2598 "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/test-model",
2599 ),
2600 ];
2601 for (provider, api, base_url) in routes {
2602 let entry = model_entry(provider, api, "test-model", base_url);
2603 assert!(
2604 entry.compat.is_none(),
2605 "expected None compat for {provider}"
2606 );
2607 let result = create_provider(&entry, None);
2608 assert!(
2609 result.is_ok(),
2610 "create_provider failed for {provider} with None compat: {:?}",
2611 result.err()
2612 );
2613 }
2614 }
2615}