1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33 completion::CompletionModel,
34 embeddings::EmbeddingModel,
35 http_client::{
36 self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
37 },
38 prelude::TranscriptionClient,
39 transcription::TranscriptionModel,
40 wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46 #[error("reqwest error: {0}")]
47 HttpError(
48 #[from]
49 #[source]
50 reqwest::Error,
51 ),
52 #[error("invalid property: {0}")]
53 InvalidProperty(&'static str),
54}
55
/// Construction hooks for a provider client.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Builds a client from environment configuration.
    ///
    /// Which variables are read is implementation-defined. The signature returns `Self`
    /// directly, so failures cannot be reported through a `Result`.
    fn from_env() -> Self;

    /// Builds a client from an explicitly supplied input value.
    fn from_val(input: Self::Input) -> Self;
}
67
68use crate::completion::{GetTokenUsage, Usage};
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct FinalCompletionResponse {
73 pub usage: Option<Usage>,
74}
75
76impl GetTokenUsage for FinalCompletionResponse {
77 fn token_usage(&self) -> Option<Usage> {
78 self.usage
79 }
80}
81
82pub trait ApiKey: Sized {
85 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
86 None
87 }
88}
89
90pub struct BearerAuth(String);
92
93impl ApiKey for BearerAuth {
94 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
95 Some(make_auth_header(self.0))
96 }
97}
98
99impl<S> From<S> for BearerAuth
100where
101 S: Into<String>,
102{
103 fn from(value: S) -> Self {
104 Self(value.into())
105 }
106}
107
/// Zero-sized placeholder.
///
/// Used as the default client extension, as the "no header" API key, and as the
/// "unsupported" marker for a capability.
#[derive(Clone, Copy, Debug, Default)]
pub struct Nothing;
112
113impl ApiKey for Nothing {}
114
115impl TryFrom<String> for Nothing {
116 type Error = &'static str;
117
118 fn try_from(_: String) -> Result<Self, Self::Error> {
119 Err(
120 "Tried to create a Nothing from a string - this should not happen, please file an issue",
121 )
122 }
123}
124
125#[derive(Clone)]
126pub struct Client<Ext = Nothing, H = reqwest::Client> {
127 base_url: Arc<str>,
128 headers: Arc<HeaderMap>,
129 http_client: H,
130 ext: Ext,
131}
132
/// Lets a client extension append its own fields to `Client`'s `Debug` output.
pub trait DebugExt: Debug {
    /// Returns `(field name, value)` pairs to append to the debug output.
    ///
    /// The default contributes no fields.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty::<(&'static str, &dyn Debug)>()
    }
}
138
139impl<Ext, H> std::fmt::Debug for Client<Ext, H>
140where
141 Ext: DebugExt,
142 H: std::fmt::Debug,
143{
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 let mut d = &mut f.debug_struct("Client");
146
147 d = d
148 .field("base_url", &self.base_url)
149 .field(
150 "headers",
151 &self
152 .headers
153 .iter()
154 .filter_map(|(k, v)| {
155 if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
156 None
157 } else {
158 Some((k, v))
159 }
160 })
161 .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
162 )
163 .field("http_client", &self.http_client);
164
165 self.ext
166 .fields()
167 .fold(d, |d, (name, field)| d.field(name, field))
168 .finish()
169 }
170}
171
/// How a request's response is expected to be delivered.
///
/// The value is passed to `Provider::build_uri`, which may use it to choose an endpoint.
pub enum Transport {
    /// A plain request/response exchange.
    Http,
    /// A server-sent-events stream (used by `Client::post_sse`).
    Sse,
    /// NOTE(review): presumably a newline-delimited JSON stream — confirm with callers.
    NdJson,
}
177
178pub trait Provider: Sized {
182 const VERIFY_PATH: &'static str;
183
184 type Builder: ProviderBuilder;
185
186 fn build<H>(
187 builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
188 ) -> http_client::Result<Self>;
189
190 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191 let base_url = if base_url.is_empty() {
193 base_url.to_string()
194 } else {
195 base_url.to_string() + "/"
196 };
197
198 base_url.to_string() + path.trim_start_matches('/')
199 }
200
201 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
202 Ok(req)
203 }
204}
205
/// Type-level marker saying that a capability is supported, with model type `M`.
pub struct Capable<M>(PhantomData<M>);

/// A type-level yes/no flag for whether a provider supports a model kind.
pub trait Capability {
    /// `true` if the capability is supported.
    const CAPABLE: bool;
}

impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
216
217impl Capability for Nothing {
218 const CAPABLE: bool = false;
219}
220
221pub trait Capabilities<H = reqwest::Client> {
223 type Completion: Capability;
224 type Embeddings: Capability;
225 type Transcription: Capability;
226 #[cfg(feature = "image")]
227 type ImageGeneration: Capability;
228 #[cfg(feature = "audio")]
229 type AudioGeneration: Capability;
230}
231
232pub trait ProviderBuilder: Sized {
237 type Output: Provider;
238 type ApiKey;
239
240 const BASE_URL: &'static str;
241
242 fn finish<H>(
245 &self,
246 builder: ClientBuilder<Self, Self::ApiKey, H>,
247 ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
248 Ok(builder)
249 }
250}
251
252impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
253where
254 ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
255 Ext: Provider<Builder = ExtBuilder>,
256 H: Default + HttpClientExt,
257 Key: ApiKey,
258{
259 pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
260 Self::builder().api_key(api_key).build()
261 }
262}
263
264impl<Ext, H> Client<Ext, H> {
265 pub fn base_url(&self) -> &str {
266 &self.base_url
267 }
268
269 pub fn headers(&self) -> &HeaderMap {
270 &self.headers
271 }
272
273 pub fn ext(&self) -> &Ext {
274 &self.ext
275 }
276
277 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
278 Client {
279 base_url: self.base_url,
280 headers: self.headers,
281 http_client: self.http_client,
282 ext: new_ext,
283 }
284 }
285}
286
287impl<Ext, H> HttpClientExt for Client<Ext, H>
288where
289 H: HttpClientExt + 'static,
290 Ext: WasmCompatSend + WasmCompatSync + 'static,
291{
292 fn send<T, U>(
293 &self,
294 mut req: Request<T>,
295 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
296 where
297 T: Into<Bytes> + WasmCompatSend,
298 U: From<Bytes>,
299 U: WasmCompatSend + 'static,
300 {
301 req.headers_mut().insert(
302 http::header::CONTENT_TYPE,
303 http::HeaderValue::from_static("application/json"),
304 );
305
306 self.http_client.send(req)
307 }
308
309 fn send_multipart<U>(
310 &self,
311 req: Request<MultipartForm>,
312 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
313 where
314 U: From<Bytes>,
315 U: WasmCompatSend + 'static,
316 {
317 self.http_client.send_multipart(req)
318 }
319
320 fn send_streaming<T>(
321 &self,
322 mut req: Request<T>,
323 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
324 where
325 T: Into<Bytes>,
326 {
327 req.headers_mut().insert(
328 http::header::CONTENT_TYPE,
329 http::HeaderValue::from_static("application/json"),
330 );
331
332 self.http_client.send_streaming(req)
333 }
334}
335
336impl<Ext, Builder, H> Client<Ext, H>
337where
338 H: Default + HttpClientExt,
339 Ext: Provider<Builder = Builder>,
340 Builder: Default + ProviderBuilder,
341{
342 pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
343 ClientBuilder::default()
344 }
345}
346
347impl<Ext, H> Client<Ext, H>
348where
349 Ext: Provider,
350{
351 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
352 where
353 S: AsRef<str>,
354 {
355 let uri = self
356 .ext
357 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
358
359 let mut req = Request::post(uri);
360
361 if let Some(hs) = req.headers_mut() {
362 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
363 }
364
365 self.ext.with_custom(req)
366 }
367
368 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
369 where
370 S: AsRef<str>,
371 {
372 let uri = self
373 .ext
374 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
375
376 let mut req = Request::post(uri);
377
378 if let Some(hs) = req.headers_mut() {
379 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
380 }
381
382 self.ext.with_custom(req)
383 }
384
385 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
386 where
387 S: AsRef<str>,
388 {
389 let uri = self
390 .ext
391 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
392
393 let mut req = Request::get(uri);
394
395 if let Some(hs) = req.headers_mut() {
396 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
397 }
398
399 self.ext.with_custom(req)
400 }
401}
402
403impl<Ext, H> VerifyClient for Client<Ext, H>
404where
405 H: HttpClientExt,
406 Ext: DebugExt + Provider + WasmCompatSync,
407{
408 async fn verify(&self) -> Result<(), VerifyError> {
409 use http::StatusCode;
410
411 let req = self
412 .get(Ext::VERIFY_PATH)?
413 .body(http_client::NoBody)
414 .map_err(http_client::Error::from)?;
415
416 let response = self.http_client.send(req).await?;
417
418 match response.status() {
419 StatusCode::OK => Ok(()),
420 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
421 Err(VerifyError::InvalidAuthentication)
422 }
423 StatusCode::INTERNAL_SERVER_ERROR => {
424 let text = http_client::text(response).await?;
425 Err(VerifyError::ProviderError(text))
426 }
427 status if status.as_u16() == 529 => {
428 let text = http_client::text(response).await?;
429 Err(VerifyError::ProviderError(text))
430 }
431 _ => {
432 let status = response.status();
433
434 if status.is_success() {
435 Ok(())
436 } else {
437 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
438 Err(VerifyError::HttpError(http_client::Error::Instance(
439 format!("Failed with '{status}': {text}").into(),
440 )))
441 }
442 }
443 }
444 }
445}
446
/// Type-state marker: the `ClientBuilder` has not been given an API key yet.
#[derive(Clone, Copy, Debug, Default)]
pub struct NeedsApiKey;
449
450#[derive(Clone)]
452pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
453 base_url: String,
454 api_key: ApiKey,
455 headers: HeaderMap,
456 http_client: Option<H>,
457 ext: Ext,
458}
459
460impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
461where
462 H: Default,
463 ExtBuilder: ProviderBuilder + Default,
464{
465 fn default() -> Self {
466 Self {
467 api_key: NeedsApiKey,
468 headers: Default::default(),
469 base_url: ExtBuilder::BASE_URL.into(),
470 http_client: None,
471 ext: Default::default(),
472 }
473 }
474}
475
476impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
477 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
480 ClientBuilder {
481 api_key: api_key.into(),
482 base_url: self.base_url,
483 headers: self.headers,
484 http_client: self.http_client,
485 ext: self.ext,
486 }
487 }
488}
489
490impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
491where
492 Ext: Clone,
493{
494 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
496 where
497 F: FnOnce(Ext) -> NewExt,
498 {
499 let ClientBuilder {
500 base_url,
501 api_key,
502 headers,
503 http_client,
504 ext,
505 } = self;
506
507 let new_ext = f(ext.clone());
508
509 ClientBuilder {
510 base_url,
511 api_key,
512 headers,
513 http_client,
514 ext: new_ext,
515 }
516 }
517
518 pub fn base_url<S>(self, base_url: S) -> Self
520 where
521 S: AsRef<str>,
522 {
523 Self {
524 base_url: base_url.as_ref().to_string(),
525 ..self
526 }
527 }
528
529 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
531 ClientBuilder {
532 http_client: Some(http_client),
533 base_url: self.base_url,
534 api_key: self.api_key,
535 headers: self.headers,
536 ext: self.ext,
537 }
538 }
539
540 pub fn http_headers(self, headers: HeaderMap) -> Self {
542 Self { headers, ..self }
543 }
544
545 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
546 &mut self.headers
547 }
548
549 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
550 &mut self.ext
551 }
552}
553
554impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
555 pub(crate) fn get_api_key(&self) -> &ApiKey {
556 &self.api_key
557 }
558}
559
560impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
561 pub fn ext(&self) -> &Ext {
562 &self.ext
563 }
564}
565
566impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
567where
568 ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
569 Ext: Provider<Builder = ExtBuilder>,
570 Key: ApiKey,
571 H: Default,
572{
573 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
574 let ext = self.ext.clone();
575
576 self = ext.finish(self)?;
577 let ext = Ext::build(&self)?;
578
579 let ClientBuilder {
580 http_client,
581 base_url,
582 mut headers,
583 api_key,
584 ..
585 } = self;
586
587 if let Some((k, v)) = api_key.into_header().transpose()? {
588 headers.insert(k, v);
589 }
590
591 let http_client = http_client.unwrap_or_default();
592
593 Ok(Client {
594 http_client,
595 base_url: Arc::from(base_url.as_str()),
596 headers: Arc::new(headers),
597 ext,
598 })
599 }
600}
601
602impl<M, Ext, H> CompletionClient for Client<Ext, H>
603where
604 Ext: Capabilities<H, Completion = Capable<M>>,
605 M: CompletionModel<Client = Self>,
606{
607 type CompletionModel = M;
608
609 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
610 M::make(self, model)
611 }
612}
613
614impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
615where
616 Ext: Capabilities<H, Embeddings = Capable<M>>,
617 M: EmbeddingModel<Client = Self>,
618{
619 type EmbeddingModel = M;
620
621 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
622 M::make(self, model, None)
623 }
624
625 fn embedding_model_with_ndims(
626 &self,
627 model: impl Into<String>,
628 ndims: usize,
629 ) -> Self::EmbeddingModel {
630 M::make(self, model, Some(ndims))
631 }
632}
633
634impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
635where
636 Ext: Capabilities<H, Transcription = Capable<M>>,
637 M: TranscriptionModel<Client = Self> + WasmCompatSend,
638{
639 type TranscriptionModel = M;
640
641 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
642 M::make(self, model)
643 }
644}
645
/// Image-generation support (feature `image`), available when the provider declares
/// `ImageGeneration = Capable<Model>`.
#[cfg(feature = "image")]
impl<Model, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<Model>>,
    Model: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = Model;

    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        Model::make(self, model)
    }
}
658
/// Audio-generation support (feature `audio`), available when the provider declares
/// `AudioGeneration = Capable<Model>`.
#[cfg(feature = "audio")]
impl<Model, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<Model>>,
    Model: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = Model;

    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        Model::make(self, model)
    }
}
670}