1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod model_listing;
11pub mod transcription;
12pub mod verify;
13
14use bytes::Bytes;
15pub use completion::CompletionClient;
16pub use embeddings::EmbeddingsClient;
17use http::{HeaderMap, HeaderName, HeaderValue};
18pub use model_listing::{ModelLister, ModelListingClient};
19use serde::{Deserialize, Serialize};
20use std::{fmt::Debug, marker::PhantomData, sync::Arc};
21use thiserror::Error;
22pub use verify::{VerifyClient, VerifyError};
23
24#[cfg(feature = "image")]
25use crate::image_generation::ImageGenerationModel;
26#[cfg(feature = "image")]
27use image_generation::ImageGenerationClient;
28
29#[cfg(feature = "audio")]
30use crate::audio_generation::*;
31#[cfg(feature = "audio")]
32use audio_generation::*;
33
34use crate::{
35 completion::CompletionModel,
36 embeddings::EmbeddingModel,
37 http_client::{
38 self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
39 },
40 prelude::TranscriptionClient,
41 transcription::TranscriptionModel,
42 wasm_compat::{WasmCompatSend, WasmCompatSync},
43};
44
45#[derive(Debug, Error)]
46#[non_exhaustive]
47pub enum ClientBuilderError {
48 #[error("reqwest error: {0}")]
49 HttpError(
50 #[from]
51 #[source]
52 reqwest::Error,
53 ),
54 #[error("invalid property: {0}")]
55 InvalidProperty(&'static str),
56}
57
/// A client that can be constructed either from the environment or from an explicit value.
pub trait ProviderClient {
    /// The value type accepted by [`ProviderClient::from_val`].
    type Input;

    /// Constructs the client from environment configuration.
    ///
    /// NOTE(review): this signature cannot report failure, so implementations presumably
    /// panic when required configuration is missing — confirm per provider.
    fn from_env() -> Self;

    /// Constructs the client from the explicitly supplied `input`.
    fn from_val(input: Self::Input) -> Self;
}
69
70use crate::completion::{GetTokenUsage, Usage};
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct FinalCompletionResponse {
75 pub usage: Option<Usage>,
76}
77
78impl GetTokenUsage for FinalCompletionResponse {
79 fn token_usage(&self) -> Option<Usage> {
80 self.usage
81 }
82}
83
84pub trait ApiKey: Sized {
87 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
88 None
89 }
90}
91
92pub struct BearerAuth(String);
94
95impl ApiKey for BearerAuth {
96 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
97 Some(make_auth_header(self.0))
98 }
99}
100
101impl<S> From<S> for BearerAuth
102where
103 S: Into<String>,
104{
105 fn from(value: S) -> Self {
106 Self(value.into())
107 }
108}
109
110#[derive(Debug, Default, Clone, Copy)]
113pub struct Nothing;
114
115impl ApiKey for Nothing {}
116
117impl TryFrom<String> for Nothing {
118 type Error = &'static str;
119
120 fn try_from(_: String) -> Result<Self, Self::Error> {
121 Err(
122 "Tried to create a Nothing from a string - this should not happen, please file an issue",
123 )
124 }
125}
126
127#[derive(Clone)]
128pub struct Client<Ext = Nothing, H = reqwest::Client> {
129 base_url: Arc<str>,
130 headers: Arc<HeaderMap>,
131 http_client: H,
132 ext: Ext,
133}
134
/// Lets a provider extension add extra named fields to a client's `Debug` output.
pub trait DebugExt: Debug {
    /// The `(name, value)` pairs to append; no extra fields by default.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
140
141impl<Ext, H> std::fmt::Debug for Client<Ext, H>
142where
143 Ext: DebugExt,
144 H: std::fmt::Debug,
145{
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 let mut d = &mut f.debug_struct("Client");
148
149 d = d
150 .field("base_url", &self.base_url)
151 .field(
152 "headers",
153 &self
154 .headers
155 .iter()
156 .filter_map(|(k, v)| {
157 if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
158 None
159 } else {
160 Some((k, v))
161 }
162 })
163 .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
164 )
165 .field("http_client", &self.http_client);
166
167 self.ext
168 .fields()
169 .fold(d, |d, (name, field)| d.field(name, field))
170 .finish()
171 }
172}
173
/// How the response to a request is delivered.
///
/// Passed to `Provider::build_uri`. The default implementation ignores it, but a provider
/// can override `build_uri` to pick a different endpoint per transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
    /// An ordinary request/response exchange.
    Http,
    /// A server-sent-events stream.
    Sse,
    /// A newline-delimited JSON stream.
    NdJson,
}
179
180pub trait Provider: Sized {
184 const VERIFY_PATH: &'static str;
185
186 type Builder: ProviderBuilder;
187
188 fn build<H>(
189 builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
190 ) -> http_client::Result<Self>;
191
192 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
193 let base_url = if base_url.is_empty() {
195 base_url.to_string()
196 } else {
197 base_url.to_string() + "/"
198 };
199
200 base_url.to_string() + path.trim_start_matches('/')
201 }
202
203 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
204 Ok(req)
205 }
206}
207
/// Marker meaning "this capability is supported". `M` is the model type the matching client
/// impl constructs.
pub struct Capable<M>(PhantomData<M>);

/// A compile-time flag telling whether a capability marker means "supported".
pub trait Capability {
    /// `true` when the capability is available.
    const CAPABLE: bool;
}

impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
218
219impl Capability for Nothing {
220 const CAPABLE: bool = false;
221}
222
223pub trait Capabilities<H = reqwest::Client> {
225 type Completion: Capability;
226 type Embeddings: Capability;
227 type Transcription: Capability;
228 type ModelListing: Capability;
229 #[cfg(feature = "image")]
230 type ImageGeneration: Capability;
231 #[cfg(feature = "audio")]
232 type AudioGeneration: Capability;
233}
234
235pub trait ProviderBuilder: Sized {
240 type Output: Provider;
241 type ApiKey;
242
243 const BASE_URL: &'static str;
244
245 fn finish<H>(
248 &self,
249 builder: ClientBuilder<Self, Self::ApiKey, H>,
250 ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
251 Ok(builder)
252 }
253}
254
255impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
256where
257 ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
258 Ext: Provider<Builder = ExtBuilder>,
259 H: Default + HttpClientExt,
260 Key: ApiKey,
261{
262 pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
263 Self::builder().api_key(api_key).build()
264 }
265}
266
267impl<Ext, H> Client<Ext, H> {
268 pub fn base_url(&self) -> &str {
269 &self.base_url
270 }
271
272 pub fn headers(&self) -> &HeaderMap {
273 &self.headers
274 }
275
276 pub fn ext(&self) -> &Ext {
277 &self.ext
278 }
279
280 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
281 Client {
282 base_url: self.base_url,
283 headers: self.headers,
284 http_client: self.http_client,
285 ext: new_ext,
286 }
287 }
288}
289
290impl<Ext, H> HttpClientExt for Client<Ext, H>
291where
292 H: HttpClientExt + 'static,
293 Ext: WasmCompatSend + WasmCompatSync + 'static,
294{
295 fn send<T, U>(
296 &self,
297 mut req: Request<T>,
298 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
299 where
300 T: Into<Bytes> + WasmCompatSend,
301 U: From<Bytes>,
302 U: WasmCompatSend + 'static,
303 {
304 req.headers_mut().insert(
305 http::header::CONTENT_TYPE,
306 http::HeaderValue::from_static("application/json"),
307 );
308
309 self.http_client.send(req)
310 }
311
312 fn send_multipart<U>(
313 &self,
314 req: Request<MultipartForm>,
315 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
316 where
317 U: From<Bytes>,
318 U: WasmCompatSend + 'static,
319 {
320 self.http_client.send_multipart(req)
321 }
322
323 fn send_streaming<T>(
324 &self,
325 mut req: Request<T>,
326 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
327 where
328 T: Into<Bytes>,
329 {
330 req.headers_mut().insert(
331 http::header::CONTENT_TYPE,
332 http::HeaderValue::from_static("application/json"),
333 );
334
335 self.http_client.send_streaming(req)
336 }
337}
338
339impl<Ext, Builder, H> Client<Ext, H>
340where
341 H: Default + HttpClientExt,
342 Ext: Provider<Builder = Builder>,
343 Builder: Default + ProviderBuilder,
344{
345 pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
346 ClientBuilder::default()
347 }
348}
349
350impl<Ext, H> Client<Ext, H>
351where
352 Ext: Provider,
353{
354 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
355 where
356 S: AsRef<str>,
357 {
358 let uri = self
359 .ext
360 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
361
362 let mut req = Request::post(uri);
363
364 if let Some(hs) = req.headers_mut() {
365 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
366 }
367
368 self.ext.with_custom(req)
369 }
370
371 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
372 where
373 S: AsRef<str>,
374 {
375 let uri = self
376 .ext
377 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
378
379 let mut req = Request::post(uri);
380
381 if let Some(hs) = req.headers_mut() {
382 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
383 }
384
385 self.ext.with_custom(req)
386 }
387
388 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
389 where
390 S: AsRef<str>,
391 {
392 let uri = self
393 .ext
394 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
395
396 let mut req = Request::get(uri);
397
398 if let Some(hs) = req.headers_mut() {
399 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
400 }
401
402 self.ext.with_custom(req)
403 }
404}
405
406impl<Ext, H> VerifyClient for Client<Ext, H>
407where
408 H: HttpClientExt,
409 Ext: DebugExt + Provider + WasmCompatSync,
410{
411 async fn verify(&self) -> Result<(), VerifyError> {
412 use http::StatusCode;
413
414 let req = self
415 .get(Ext::VERIFY_PATH)?
416 .body(http_client::NoBody)
417 .map_err(http_client::Error::from)?;
418
419 let response = self.http_client.send(req).await?;
420
421 match response.status() {
422 StatusCode::OK => Ok(()),
423 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
424 Err(VerifyError::InvalidAuthentication)
425 }
426 StatusCode::INTERNAL_SERVER_ERROR => {
427 let text = http_client::text(response).await?;
428 Err(VerifyError::ProviderError(text))
429 }
430 status if status.as_u16() == 529 => {
431 let text = http_client::text(response).await?;
432 Err(VerifyError::ProviderError(text))
433 }
434 _ => {
435 let status = response.status();
436
437 if status.is_success() {
438 Ok(())
439 } else {
440 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
441 Err(VerifyError::HttpError(http_client::Error::Instance(
442 format!("Failed with '{status}': {text}").into(),
443 )))
444 }
445 }
446 }
447 }
448}
449
/// Type-state marker: the [`ClientBuilder`] has not been given an API key yet.
///
/// [`ClientBuilder::api_key`] is only callable in this state.
#[derive(Debug, Default, Clone, Copy)]
pub struct NeedsApiKey;
452
453#[derive(Clone)]
455pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
456 base_url: String,
457 api_key: ApiKey,
458 headers: HeaderMap,
459 http_client: Option<H>,
460 ext: Ext,
461}
462
463impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
464where
465 H: Default,
466 ExtBuilder: ProviderBuilder + Default,
467{
468 fn default() -> Self {
469 Self {
470 api_key: NeedsApiKey,
471 headers: Default::default(),
472 base_url: ExtBuilder::BASE_URL.into(),
473 http_client: None,
474 ext: Default::default(),
475 }
476 }
477}
478
479impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
480 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
483 ClientBuilder {
484 api_key: api_key.into(),
485 base_url: self.base_url,
486 headers: self.headers,
487 http_client: self.http_client,
488 ext: self.ext,
489 }
490 }
491}
492
493impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
494where
495 Ext: Clone,
496{
497 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
499 where
500 F: FnOnce(Ext) -> NewExt,
501 {
502 let ClientBuilder {
503 base_url,
504 api_key,
505 headers,
506 http_client,
507 ext,
508 } = self;
509
510 let new_ext = f(ext.clone());
511
512 ClientBuilder {
513 base_url,
514 api_key,
515 headers,
516 http_client,
517 ext: new_ext,
518 }
519 }
520
521 pub fn base_url<S>(self, base_url: S) -> Self
523 where
524 S: AsRef<str>,
525 {
526 Self {
527 base_url: base_url.as_ref().to_string(),
528 ..self
529 }
530 }
531
532 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
534 ClientBuilder {
535 http_client: Some(http_client),
536 base_url: self.base_url,
537 api_key: self.api_key,
538 headers: self.headers,
539 ext: self.ext,
540 }
541 }
542
543 pub fn http_headers(self, headers: HeaderMap) -> Self {
545 Self { headers, ..self }
546 }
547
548 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
549 &mut self.headers
550 }
551
552 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
553 &mut self.ext
554 }
555}
556
557impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
558 pub(crate) fn get_api_key(&self) -> &ApiKey {
559 &self.api_key
560 }
561}
562
563impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
564 pub fn ext(&self) -> &Ext {
565 &self.ext
566 }
567}
568
569impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
570where
571 ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
572 Ext: Provider<Builder = ExtBuilder>,
573 Key: ApiKey,
574 H: Default,
575{
576 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
577 let ext = self.ext.clone();
578
579 self = ext.finish(self)?;
580 let ext = Ext::build(&self)?;
581
582 let ClientBuilder {
583 http_client,
584 base_url,
585 mut headers,
586 api_key,
587 ..
588 } = self;
589
590 if let Some((k, v)) = api_key.into_header().transpose()? {
591 headers.insert(k, v);
592 }
593
594 let http_client = http_client.unwrap_or_default();
595
596 Ok(Client {
597 http_client,
598 base_url: Arc::from(base_url.as_str()),
599 headers: Arc::new(headers),
600 ext,
601 })
602 }
603}
604
605impl<M, Ext, H> CompletionClient for Client<Ext, H>
606where
607 Ext: Capabilities<H, Completion = Capable<M>>,
608 M: CompletionModel<Client = Self>,
609{
610 type CompletionModel = M;
611
612 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
613 M::make(self, model)
614 }
615}
616
617impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
618where
619 Ext: Capabilities<H, Embeddings = Capable<M>>,
620 M: EmbeddingModel<Client = Self>,
621{
622 type EmbeddingModel = M;
623
624 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
625 M::make(self, model, None)
626 }
627
628 fn embedding_model_with_ndims(
629 &self,
630 model: impl Into<String>,
631 ndims: usize,
632 ) -> Self::EmbeddingModel {
633 M::make(self, model, Some(ndims))
634 }
635}
636
637impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
638where
639 Ext: Capabilities<H, Transcription = Capable<M>>,
640 M: TranscriptionModel<Client = Self> + WasmCompatSend,
641{
642 type TranscriptionModel = M;
643
644 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
645 M::make(self, model)
646 }
647}
648
/// Image-generation models are available (with the `image` feature) whenever the provider
/// declares `ImageGeneration = Capable<M>`.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates image-generation model `M` for the given model name, bound to this client.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        let client = self;
        M::make(client, model)
    }
}
661
/// Audio-generation models are available (with the `audio` feature) whenever the provider
/// declares `AudioGeneration = Capable<M>`.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates audio-generation model `M` for the given model name, bound to this client.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        let client = self;
        M::make(client, model)
    }
}
674
675impl<M, Ext, H> ModelListingClient for Client<Ext, H>
676where
677 Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
678 M: ModelLister<H, Client = Self> + Send + Sync + Clone + 'static,
679 H: Send + Sync + Clone,
680{
681 fn list_models(
682 &self,
683 ) -> impl std::future::Future<
684 Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
685 > + WasmCompatSend {
686 let lister = M::new(self.clone());
687 async move { lister.list_all().await }
688 }
689}