1use std::collections::HashMap;
21use std::sync::Arc;
22
23use anyllm::{
24 CapabilitySupport, ChatCapability, ChatCapabilityResolver, EmbeddingCapability, Error, Result,
25};
26
27mod chat;
28mod embedding;
29mod error;
30mod options;
31pub mod providers;
32mod streaming;
33mod wire;
34
/// HTTP client type held by the provider. With the `http-tracing` feature it is
/// a `reqwest_middleware` client so middleware (e.g. tracing) can be layered on.
#[cfg(feature = "http-tracing")]
type HttpClient = reqwest_middleware::ClientWithMiddleware;
/// HTTP client type held by the provider. Without the `http-tracing` feature it
/// is a plain `reqwest::Client`.
#[cfg(not(feature = "http-tracing"))]
type HttpClient = reqwest::Client;
39
40pub use embedding::{
41 EmbeddingData, EmbeddingRequestOptions, EmbeddingsRequest, EmbeddingsResponse, EmbeddingsUsage,
42 from_embeddings_response, send_embeddings_request, to_embeddings_request,
43};
44pub use error::{
45 map_http_error, map_response_deserialize_error, map_stream_error, map_transport_error,
46};
47pub use options::{OpenAIReasoningEffort, RequestOptions};
48pub use streaming::{SseState, process_sse_data, sse_to_stream};
49pub use wire::{
50 ChatCompletionRequest, ChatCompletionResponse, from_api_response, parse_finish_reason,
51 to_chat_completion_request,
52};
53
/// Provider for OpenAI-compatible HTTP APIs.
///
/// Construct one with `Provider::builder()`. All state lives behind a shared
/// `Arc<Inner>`, so cloning a `Provider` only bumps a reference count and the
/// clones share the same configuration and HTTP client.
#[derive(Clone)]
pub struct Provider {
    /// Shared configuration and HTTP client.
    pub(crate) inner: Arc<Inner>,
}
67
/// State shared by all clones of a [`Provider`].
pub(crate) struct Inner {
    /// HTTP client handle (see the `HttpClient` alias for the feature-dependent type).
    pub(crate) client: HttpClient,
    /// Resolved base URL, endpoint paths, auth header and header-name settings.
    pub(crate) transport: TransportConfig,
    /// Chat capability table configured on the builder.
    pub(crate) chat_capabilities: HashMap<ChatCapability, CapabilitySupport>,
    /// Optional per-model resolver installed via `Provider::with_chat_capabilities`.
    pub(crate) chat_capability_resolver: Option<Arc<dyn ChatCapabilityResolver>>,
    /// Embedding capability table configured on the builder.
    pub(crate) embedding_capabilities: HashMap<EmbeddingCapability, CapabilitySupport>,
    /// Provider label; the builder falls back to `"unknown"` when none is given.
    pub(crate) provider_name: &'static str,
}
76
/// Canonicalizes a base URL so endpoint paths can be appended directly.
///
/// Surrounding whitespace is removed first, then every trailing `/`.
fn normalize_base_url(url: impl Into<String>) -> String {
    let raw: String = url.into();
    let without_whitespace = raw.trim();
    String::from(without_whitespace.trim_end_matches('/'))
}
80
81fn required_builder_value(name: &'static str, value: Option<String>) -> Result<String> {
82 match value {
83 Some(value) if !value.trim().is_empty() => Ok(value),
84 _ => Err(Error::InvalidRequest(format!("{name} is required"))),
85 }
86}
87
88fn builder_base_url(base_url: Option<String>) -> Result<String> {
89 match base_url {
90 Some(url) if url.trim().is_empty() => {
91 Err(Error::InvalidRequest("base_url cannot be empty".into()))
92 }
93 Some(url) => Ok(normalize_base_url(url)),
94 None => Err(Error::InvalidRequest("base_url is required".into())),
95 }
96}
97
98impl Provider {
99 fn default_http_client() -> HttpClient {
100 let base = reqwest::Client::builder()
101 .build()
102 .expect("default OpenAI-compatible reqwest client config should be valid");
103 #[cfg(feature = "http-tracing")]
104 {
105 reqwest_middleware::ClientBuilder::new(base)
106 .with(reqwest_tracing::TracingMiddleware::<
107 reqwest_tracing::SpanBackendWithUrl,
108 >::new())
109 .build()
110 }
111 #[cfg(not(feature = "http-tracing"))]
112 {
113 base
114 }
115 }
116
117 fn plain_http_client(client: reqwest::Client) -> HttpClient {
118 #[cfg(feature = "http-tracing")]
119 {
120 reqwest_middleware::ClientBuilder::new(client)
121 .with(reqwest_tracing::TracingMiddleware::<
122 reqwest_tracing::SpanBackendWithUrl,
123 >::new())
124 .build()
125 }
126 #[cfg(not(feature = "http-tracing"))]
127 {
128 client
129 }
130 }
131
132 pub fn builder() -> ProviderBuilder {
134 ProviderBuilder {
135 base_url: None,
136 chat_completions_path: None,
137 embeddings_path: None,
138 auth_header_name: None,
139 auth_header_value: None,
140 organization_header: None,
141 project_header: None,
142 request_id_header_name: None,
143 retry_after_header_name: None,
144 chat_capabilities: HashMap::new(),
145 embedding_capabilities: HashMap::new(),
146 provider_name: None,
147 client: None,
148 }
149 }
150
151 pub(crate) fn transport_config(&self) -> &TransportConfig {
152 &self.inner.transport
153 }
154
155 #[must_use]
157 pub fn with_chat_capabilities(self, resolver: impl ChatCapabilityResolver) -> Self {
158 Self {
159 inner: Arc::new(Inner {
160 client: self.inner.client.clone(),
161 transport: self.inner.transport.clone(),
162 chat_capabilities: self.inner.chat_capabilities.clone(),
163 chat_capability_resolver: Some(Arc::new(resolver)),
164 embedding_capabilities: self.inner.embedding_capabilities.clone(),
165 provider_name: self.inner.provider_name,
166 }),
167 }
168 }
169
170 pub(crate) fn builtin_chat_capability(
171 &self,
172 _model: &str,
173 capability: ChatCapability,
174 ) -> CapabilitySupport {
175 self.inner
176 .chat_capabilities
177 .get(&capability)
178 .copied()
179 .unwrap_or(CapabilitySupport::Unknown)
180 }
181}
182
183pub struct ProviderBuilder {
185 base_url: Option<String>,
186 chat_completions_path: Option<String>,
187 embeddings_path: Option<String>,
188 auth_header_name: Option<String>,
189 auth_header_value: Option<String>,
190 organization_header: Option<(String, String)>,
191 project_header: Option<(String, String)>,
192 request_id_header_name: Option<String>,
193 retry_after_header_name: Option<String>,
194 chat_capabilities: HashMap<ChatCapability, CapabilitySupport>,
195 embedding_capabilities: HashMap<EmbeddingCapability, CapabilitySupport>,
196 provider_name: Option<&'static str>,
197 client: Option<HttpClient>,
198}
199
200impl ProviderBuilder {
201 pub fn base_url(mut self, url: impl Into<String>) -> Self {
203 self.base_url = Some(url.into());
204 self
205 }
206
207 pub fn chat_completions_path(mut self, path: impl Into<String>) -> Self {
210 self.chat_completions_path = Some(path.into());
211 self
212 }
213
214 pub fn embeddings_path(mut self, path: impl Into<String>) -> Self {
217 self.embeddings_path = Some(path.into());
218 self
219 }
220
221 pub fn bearer_token(mut self, token: impl Into<String>) -> Self {
224 self.auth_header_name = Some("authorization".into());
225 self.auth_header_value = Some(format!("Bearer {}", token.into()));
226 self
227 }
228
229 pub fn auth_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
231 self.auth_header_name = Some(name.into());
232 self.auth_header_value = Some(value.into());
233 self
234 }
235
236 pub fn organization_header(
238 mut self,
239 name: impl Into<String>,
240 value: impl Into<String>,
241 ) -> Self {
242 self.organization_header = Some((name.into(), value.into()));
243 self
244 }
245
246 pub fn project_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
248 self.project_header = Some((name.into(), value.into()));
249 self
250 }
251
252 pub fn request_id_header_name(mut self, name: impl Into<String>) -> Self {
255 self.request_id_header_name = Some(name.into());
256 self
257 }
258
259 pub fn retry_after_header_name(mut self, name: impl Into<String>) -> Self {
262 self.retry_after_header_name = Some(name.into());
263 self
264 }
265
266 pub fn chat_capability(
268 mut self,
269 capability: ChatCapability,
270 support: CapabilitySupport,
271 ) -> Self {
272 self.chat_capabilities.insert(capability, support);
273 self
274 }
275
276 pub fn chat_capabilities<I>(mut self, capabilities: I) -> Self
278 where
279 I: IntoIterator<Item = (ChatCapability, CapabilitySupport)>,
280 {
281 self.chat_capabilities.extend(capabilities);
282 self
283 }
284
285 pub fn embedding_capability(
287 mut self,
288 capability: EmbeddingCapability,
289 support: CapabilitySupport,
290 ) -> Self {
291 self.embedding_capabilities.insert(capability, support);
292 self
293 }
294
295 pub fn embedding_capabilities<I>(mut self, capabilities: I) -> Self
297 where
298 I: IntoIterator<Item = (EmbeddingCapability, CapabilitySupport)>,
299 {
300 self.embedding_capabilities.extend(capabilities);
301 self
302 }
303
304 pub fn provider_name(mut self, name: &'static str) -> Self {
306 self.provider_name = Some(name);
307 self
308 }
309
310 pub fn client(mut self, client: reqwest::Client) -> Self {
318 self.client = Some(Provider::plain_http_client(client));
319 self
320 }
321
322 #[cfg(feature = "http-tracing")]
329 pub fn client_with_middleware(
330 mut self,
331 client: reqwest_middleware::ClientWithMiddleware,
332 ) -> Self {
333 self.client = Some(client);
334 self
335 }
336
337 pub fn build(self) -> Result<Provider> {
339 let auth_header_value = self.auth_header_value.ok_or_else(|| {
340 Error::InvalidRequest("auth is required: use bearer_token() or auth_header()".into())
341 })?;
342
343 let transport = TransportConfig {
344 base_url: builder_base_url(self.base_url)?,
345 chat_completions_path: self
346 .chat_completions_path
347 .unwrap_or_else(|| "/chat/completions".into()),
348 embeddings_path: self.embeddings_path.unwrap_or_else(|| "/embeddings".into()),
349 auth_header_name: self
350 .auth_header_name
351 .unwrap_or_else(|| "authorization".into()),
352 auth_header_value: required_builder_value(
353 "auth_header_value",
354 Some(auth_header_value),
355 )?,
356 organization_header: self.organization_header,
357 project_header: self.project_header,
358 request_id_header_name: self
359 .request_id_header_name
360 .unwrap_or_else(|| "x-request-id".into()),
361 retry_after_header_name: self
362 .retry_after_header_name
363 .unwrap_or_else(|| "retry-after".into()),
364 };
365
366 Ok(Provider {
367 inner: Arc::new(Inner {
368 client: self.client.unwrap_or_else(Provider::default_http_client),
369 transport,
370 chat_capabilities: self.chat_capabilities,
371 chat_capability_resolver: None,
372 embedding_capabilities: self.embedding_capabilities,
373 provider_name: self.provider_name.unwrap_or("unknown"),
378 }),
379 })
380 }
381}
382
/// Resolved transport settings for an OpenAI-compatible endpoint.
///
/// `Debug` is implemented by hand so the credential in `auth_header_value` is
/// redacted. A derived impl would print it verbatim into logs and panic messages.
#[derive(Clone, PartialEq, Eq)]
pub struct TransportConfig {
    /// API root; the builder stores it without trailing `/`.
    pub base_url: String,
    /// Path appended verbatim to `base_url` for chat completions.
    pub chat_completions_path: String,
    /// Path appended verbatim to `base_url` for embeddings.
    pub embeddings_path: String,
    /// Name of the authentication header.
    pub auth_header_name: String,
    /// Value of the authentication header (a credential; redacted in `Debug`).
    pub auth_header_value: String,
    /// Optional organization header as `(name, value)`.
    pub organization_header: Option<(String, String)>,
    /// Optional project header as `(name, value)`.
    pub project_header: Option<(String, String)>,
    /// Header name used when reading request IDs.
    pub request_id_header_name: String,
    /// Header name used when reading retry-after hints.
    pub retry_after_header_name: String,
}

impl std::fmt::Debug for TransportConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TransportConfig")
            .field("base_url", &self.base_url)
            .field("chat_completions_path", &self.chat_completions_path)
            .field("embeddings_path", &self.embeddings_path)
            .field("auth_header_name", &self.auth_header_name)
            // Never render the secret itself.
            .field("auth_header_value", &format_args!("<redacted>"))
            .field("organization_header", &self.organization_header)
            .field("project_header", &self.project_header)
            .field("request_id_header_name", &self.request_id_header_name)
            .field("retry_after_header_name", &self.retry_after_header_name)
            .finish()
    }
}
395
396impl TransportConfig {
397 pub fn chat_completions_url(&self) -> String {
398 format!("{}{}", self.base_url, self.chat_completions_path)
399 }
400
401 pub fn embeddings_url(&self) -> String {
402 format!("{}{}", self.base_url, self.embeddings_path)
403 }
404}
405
406pub fn extract_request_id(
407 headers: &reqwest::header::HeaderMap,
408 header_name: &str,
409) -> Option<String> {
410 headers
411 .get(header_name)
412 .and_then(|value| value.to_str().ok())
413 .map(String::from)
414}
415
416pub fn extract_retry_after(
417 headers: &reqwest::header::HeaderMap,
418 header_name: &str,
419) -> Option<std::time::Duration> {
420 headers
421 .get(header_name)
422 .and_then(|value| value.to_str().ok())
423 .and_then(parse_retry_after_value)
424}
425
/// Parses a retry delay expressed as a (possibly fractional) number of seconds.
///
/// Returns `None` when the text is not a finite, non-negative number that fits
/// in a `Duration`. HTTP-date values do not parse as `f64` and therefore also
/// yield `None`.
fn parse_retry_after_value(value: &str) -> Option<std::time::Duration> {
    let seconds = value.parse::<f64>().ok()?;
    // Reject NaN, infinities and negatives (including `-0`) outright.
    if !seconds.is_finite() || seconds.is_sign_negative() {
        return None;
    }

    // The header comes from the server. A finite but huge value such as `1e300`
    // overflows `Duration`, and `from_secs_f64` would panic on it. The fallible
    // constructor uses the same conversion but reports overflow instead.
    std::time::Duration::try_from_secs_f64(seconds).ok()
}
434
435pub async fn send_chat_completion_request<E, Fut, F, M>(
436 api_request: &ChatCompletionRequest,
437 send: F,
438 map_transport_error: M,
439) -> Result<reqwest::Response>
440where
441 F: FnOnce(String) -> Fut,
442 Fut: std::future::Future<Output = std::result::Result<reqwest::Response, E>>,
443 M: Fn(E) -> Error,
444{
445 let body = serde_json::to_string(api_request).map_err(Error::from)?;
446
447 send(body).await.map_err(map_transport_error)
448}
449
// Unit tests covering:
// - transport URL building and header extraction;
// - builder validation and capability lookup;
// - an embeddings round trip against a local mock HTTP server.
#[cfg(test)]
mod tests {
    use super::*;
    use anyllm::ChatProvider;
    use serde_json::json;

    // The chat URL is the base URL and the chat path concatenated verbatim.
    #[test]
    fn transport_config_builds_chat_completions_url() {
        let config = TransportConfig {
            base_url: "https://example.com/v1".into(),
            chat_completions_path: "/chat/completions".into(),
            embeddings_path: "/embeddings".into(),
            auth_header_name: "authorization".into(),
            auth_header_value: "Bearer sk-test".into(),
            organization_header: None,
            project_header: None,
            request_id_header_name: "x-request-id".into(),
            retry_after_header_name: "retry-after".into(),
        };

        assert_eq!(
            config.chat_completions_url(),
            "https://example.com/v1/chat/completions"
        );
    }

    // The embeddings URL uses the same concatenation with the embeddings path.
    #[test]
    fn transport_config_builds_embeddings_url() {
        let config = TransportConfig {
            base_url: "https://example.com/v1".into(),
            chat_completions_path: "/chat/completions".into(),
            embeddings_path: "/embeddings".into(),
            auth_header_name: "authorization".into(),
            auth_header_value: "Bearer sk-test".into(),
            organization_header: None,
            project_header: None,
            request_id_header_name: "x-request-id".into(),
            retry_after_header_name: "retry-after".into(),
        };
        assert_eq!(config.embeddings_url(), "https://example.com/v1/embeddings");
    }

    // Request IDs are read from whichever header name the caller passes in.
    #[test]
    fn extracts_request_id_from_configured_header() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-custom-request-id", "req_123".parse().unwrap());

        assert_eq!(
            extract_request_id(&headers, "x-custom-request-id").as_deref(),
            Some("req_123")
        );
    }

    // Fractional seconds are accepted for retry-after values.
    #[test]
    fn extracts_retry_after_from_configured_header() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-retry-after", "2.5".parse().unwrap());

        assert_eq!(
            extract_retry_after(&headers, "x-retry-after"),
            Some(std::time::Duration::from_secs_f64(2.5))
        );
    }

    // Negative delays are discarded rather than turned into a Duration.
    #[test]
    fn ignores_negative_retry_after_from_configured_header() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-retry-after", "-1".parse().unwrap());

        assert_eq!(extract_retry_after(&headers, "x-retry-after"), None);
    }

    // "NaN" parses as an f64 but is rejected as non-finite.
    #[test]
    fn ignores_non_finite_retry_after_from_configured_header() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-retry-after", "NaN".parse().unwrap());

        assert_eq!(extract_retry_after(&headers, "x-retry-after"), None);
    }

    // The hook passed to `from_api_response` can attach typed metadata.
    // The metadata serializes under its `ResponseMetadataType::KEY`.
    #[test]
    fn response_conversion_supports_metadata_hook() {
        let response: ChatCompletionResponse = serde_json::from_value(json!({
            "id": "chatcmpl-1",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello"
                },
                "finish_reason": "stop"
            }],
            "model": "gpt-4o",
            "system_fingerprint": "fp_test"
        }))
        .unwrap();

        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
        struct DemoMetadata {
            fingerprint: String,
        }

        impl anyllm::ResponseMetadataType for DemoMetadata {
            const KEY: &'static str = "demo";
        }

        let converted = from_api_response(response, |response, metadata| {
            if let Some(fp) = &response.system_fingerprint {
                metadata.insert(DemoMetadata {
                    fingerprint: fp.clone(),
                });
            }
        })
        .unwrap();

        assert_eq!(converted.text().as_deref(), Some("hello"));
        assert_eq!(
            serde_json::to_value(&converted.metadata).unwrap(),
            json!({
                "demo": {"fingerprint": "fp_test"}
            })
        );
    }

    // `build` fails when no base URL was configured.
    #[test]
    fn builder_requires_base_url() {
        let result = Provider::builder().bearer_token("token").build();
        assert!(result.is_err());
    }

    // A whitespace-only base URL gets the dedicated "cannot be empty" error.
    #[test]
    fn builder_rejects_empty_base_url() {
        let result = Provider::builder()
            .base_url("   ")
            .bearer_token("token")
            .build();
        assert!(
            matches!(result, Err(Error::InvalidRequest(message)) if message == "base_url cannot be empty")
        );
    }

    // Surrounding whitespace and the trailing slash are stripped from the base URL.
    #[test]
    fn builder_normalizes_base_url() {
        let provider = Provider::builder()
            .base_url("  https://api.example.com/v1/  ")
            .bearer_token("token")
            .build()
            .unwrap();
        assert_eq!(
            provider.transport_config().base_url,
            "https://api.example.com/v1"
        );
    }

    // `build` fails when neither `bearer_token` nor `auth_header` was called.
    #[test]
    fn builder_requires_auth() {
        let result = Provider::builder()
            .base_url("https://example.com/v1")
            .build();
        assert!(result.is_err());
    }

    // `auth_header` stores the name and value verbatim (no "Bearer " prefix).
    #[test]
    fn custom_auth_header() {
        let provider = Provider::builder()
            .base_url("https://api.example.com/v1")
            .auth_header("x-api-key", "my-secret")
            .build()
            .unwrap();

        let config = provider.transport_config();
        assert_eq!(config.auth_header_name, "x-api-key");
        assert_eq!(config.auth_header_value, "my-secret");
    }

    // A resolver answer (`Some`) overrides the builder table.
    // A `None` answer falls back to the configured value.
    #[test]
    fn chat_capability_resolver_takes_precedence_over_configured_capabilities() {
        let provider = Provider::builder()
            .base_url("https://api.example.com/v1")
            .bearer_token("token")
            .chat_capability(
                ChatCapability::StructuredOutput,
                CapabilitySupport::Supported,
            )
            .build()
            .unwrap()
            .with_chat_capabilities(|model: &str, capability| {
                if model == "legacy" && capability == ChatCapability::StructuredOutput {
                    Some(CapabilitySupport::Unknown)
                } else {
                    None
                }
            });

        assert_eq!(
            provider.chat_capability("legacy", ChatCapability::StructuredOutput),
            CapabilitySupport::Unknown
        );
        assert_eq!(
            provider.chat_capability("modern", ChatCapability::StructuredOutput),
            CapabilitySupport::Supported
        );
    }

    // End-to-end embeddings call against a mock server. Checks:
    // - response parsing;
    // - the serialized request body;
    // - the request path (base URL + `/embeddings`);
    // - the bearer auth header.
    #[tokio::test]
    async fn embed_posts_expected_request_and_parses_response() {
        use anyllm::{EmbeddingCapability, EmbeddingProvider, EmbeddingRequest};
        use anyllm_conformance::{MockHttpResponse, TestHttpServer};

        let server = TestHttpServer::spawn([MockHttpResponse::json(
            200,
            &serde_json::json!({
                "data": [
                    {"embedding": [0.1, 0.2], "index": 0},
                    {"embedding": [0.3, 0.4], "index": 1}
                ],
                "model": "text-embedding-3-small",
                "usage": {"prompt_tokens": 4, "total_tokens": 4}
            }),
        )])
        .await;

        let provider = Provider::builder()
            .base_url(format!("{}/v1", server.url()))
            .bearer_token("sk-test")
            .provider_name("test-compat")
            .embedding_capability(
                EmbeddingCapability::BatchInput,
                CapabilitySupport::Supported,
            )
            .build()
            .unwrap();

        let request = EmbeddingRequest::new("text-embedding-3-small")
            .inputs(["a", "b"])
            .dimensions(32);
        let response = provider.embed(&request).await.unwrap();

        assert_eq!(response.embeddings, vec![vec![0.1, 0.2], vec![0.3, 0.4]]);
        assert_eq!(response.model.as_deref(), Some("text-embedding-3-small"));
        assert_eq!(
            response.usage.as_ref().and_then(|u| u.input_tokens),
            Some(4)
        );

        let recorded = server.recorded_requests().await;
        assert_eq!(recorded.len(), 1);
        let body = recorded[0].body_json();
        assert_eq!(body["model"], "text-embedding-3-small");
        assert_eq!(body["input"], serde_json::json!(["a", "b"]));
        assert_eq!(body["dimensions"], 32);
        assert_eq!(recorded[0].path, "/v1/embeddings");
        assert_eq!(recorded[0].header("authorization"), Some("Bearer sk-test"));
    }

    // Embedding capabilities set via the bulk builder method are reported back.
    #[test]
    fn embedding_capability_reads_builder_config() {
        use anyllm::{EmbeddingCapability, EmbeddingProvider};
        let provider = Provider::builder()
            .base_url("https://example.com/v1")
            .bearer_token("token")
            .embedding_capabilities([
                (
                    EmbeddingCapability::BatchInput,
                    CapabilitySupport::Supported,
                ),
                (
                    EmbeddingCapability::OutputDimensions,
                    CapabilitySupport::Supported,
                ),
            ])
            .build()
            .unwrap();

        assert_eq!(
            provider.embedding_capability("m", EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            provider.embedding_capability("m", EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Supported
        );
    }
}