1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33 completion::CompletionModel,
34 embeddings::EmbeddingModel,
35 http_client::{
36 self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
37 },
38 prelude::TranscriptionClient,
39 transcription::TranscriptionModel,
40 wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46 #[error("reqwest error: {0}")]
47 HttpError(
48 #[from]
49 #[source]
50 reqwest::Error,
51 ),
52 #[error("invalid property: {0}")]
53 InvalidProperty(&'static str),
54}
55
/// A provider client that can be created either from an explicit input value or from
/// ambient configuration.
pub trait ProviderClient {
    /// Value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Builds the client from an explicit input value.
    fn from_val(input: Self::Input) -> Self;

    /// Builds the client from ambient configuration (presumably environment variables;
    /// see each implementor).
    fn from_env() -> Self;
}
67
68use crate::completion::{GetTokenUsage, Usage};
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct FinalCompletionResponse {
73 pub usage: Option<Usage>,
74}
75
76impl GetTokenUsage for FinalCompletionResponse {
77 fn token_usage(&self) -> Option<Usage> {
78 self.usage
79 }
80}
81
82pub trait ApiKey: Sized {
85 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
86 None
87 }
88}
89
90pub struct BearerAuth(String);
92
93impl ApiKey for BearerAuth {
94 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
95 Some(make_auth_header(self.0))
96 }
97}
98
99impl<S> From<S> for BearerAuth
100where
101 S: Into<String>,
102{
103 fn from(value: S) -> Self {
104 Self(value.into())
105 }
106}
107
/// Zero-sized placeholder type.
///
/// Within this module it serves as the header-less [`ApiKey`] and as the "unsupported"
/// [`Capability`] marker.
#[derive(Clone, Copy, Debug, Default)]
pub struct Nothing;
112
113impl ApiKey for Nothing {}
114
115impl TryFrom<String> for Nothing {
116 type Error = &'static str;
117
118 fn try_from(_: String) -> Result<Self, Self::Error> {
119 Err(
120 "Tried to create a Nothing from a string - this should not happen, please file an issue",
121 )
122 }
123}
124
125#[derive(Clone)]
126pub struct Client<Ext = Nothing, H = reqwest::Client> {
127 base_url: Arc<str>,
128 headers: Arc<HeaderMap>,
129 http_client: H,
130 ext: Ext,
131}
132
/// Hook that lets a provider extension add its own fields to `Client`'s `Debug` output.
pub trait DebugExt: Debug {
    /// Extra `(field_name, value)` pairs to print after the built-in fields.
    /// The default implementation yields nothing.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        core::iter::empty()
    }
}
138
139impl<Ext, H> std::fmt::Debug for Client<Ext, H>
140where
141 Ext: DebugExt,
142 H: std::fmt::Debug,
143{
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 let mut d = &mut f.debug_struct("Client");
146
147 d = d
148 .field("base_url", &self.base_url)
149 .field(
150 "headers",
151 &self
152 .headers
153 .iter()
154 .filter_map(|(k, v)| {
155 if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
156 None
157 } else {
158 Some((k, v))
159 }
160 })
161 .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
162 )
163 .field("http_client", &self.http_client);
164
165 self.ext
166 .fields()
167 .fold(d, |d, (name, field)| d.field(name, field))
168 .finish()
169 }
170}
171
/// Wire protocol a request is about to use.
///
/// Passed to [`Provider::build_uri`] so providers can vary the endpoint per transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// Plain request/response HTTP.
    Http,
    /// Server-sent events.
    Sse,
    /// Newline-delimited JSON.
    NdJson,
}
177
178pub trait Provider: Sized {
182 const VERIFY_PATH: &'static str;
183
184 type Builder: ProviderBuilder;
185
186 fn build<H>(
187 builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
188 ) -> http_client::Result<Self>;
189
190 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191 base_url.to_string() + "/" + path.trim_start_matches('/')
192 }
193
194 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
195 Ok(req)
196 }
197}
198
/// Type-level flag stating whether a provider supports a given capability.
pub trait Capability {
    /// `true` when the capability is available.
    const CAPABLE: bool;
}

/// Marks a capability as supported, where `M` is the model type that provides it.
pub struct Capable<M>(PhantomData<M>);

impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
209
210impl Capability for Nothing {
211 const CAPABLE: bool = false;
212}
213
214pub trait Capabilities<H = reqwest::Client> {
216 type Completion: Capability;
217 type Embeddings: Capability;
218 type Transcription: Capability;
219 #[cfg(feature = "image")]
220 type ImageGeneration: Capability;
221 #[cfg(feature = "audio")]
222 type AudioGeneration: Capability;
223}
224
225pub trait ProviderBuilder: Sized {
230 type Output: Provider;
231 type ApiKey;
232
233 const BASE_URL: &'static str;
234
235 fn finish<H>(
238 &self,
239 builder: ClientBuilder<Self, Self::ApiKey, H>,
240 ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
241 Ok(builder)
242 }
243}
244
245impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
246where
247 ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
248 Ext: Provider<Builder = ExtBuilder>,
249 H: Default + HttpClientExt,
250 Key: ApiKey,
251{
252 pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
253 Self::builder().api_key(api_key).build()
254 }
255}
256
257impl<Ext, H> Client<Ext, H> {
258 pub fn base_url(&self) -> &str {
259 &self.base_url
260 }
261
262 pub fn headers(&self) -> &HeaderMap {
263 &self.headers
264 }
265
266 pub fn ext(&self) -> &Ext {
267 &self.ext
268 }
269
270 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
271 Client {
272 base_url: self.base_url,
273 headers: self.headers,
274 http_client: self.http_client,
275 ext: new_ext,
276 }
277 }
278}
279
280impl<Ext, H> HttpClientExt for Client<Ext, H>
281where
282 H: HttpClientExt + 'static,
283 Ext: WasmCompatSend + WasmCompatSync + 'static,
284{
285 fn send<T, U>(
286 &self,
287 mut req: Request<T>,
288 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
289 where
290 T: Into<Bytes> + WasmCompatSend,
291 U: From<Bytes>,
292 U: WasmCompatSend + 'static,
293 {
294 req.headers_mut().insert(
295 http::header::CONTENT_TYPE,
296 http::HeaderValue::from_static("application/json"),
297 );
298
299 self.http_client.send(req)
300 }
301
302 fn send_multipart<U>(
303 &self,
304 req: Request<MultipartForm>,
305 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
306 where
307 U: From<Bytes>,
308 U: WasmCompatSend + 'static,
309 {
310 self.http_client.send_multipart(req)
311 }
312
313 fn send_streaming<T>(
314 &self,
315 mut req: Request<T>,
316 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
317 where
318 T: Into<Bytes>,
319 {
320 req.headers_mut().insert(
321 http::header::CONTENT_TYPE,
322 http::HeaderValue::from_static("application/json"),
323 );
324
325 self.http_client.send_streaming(req)
326 }
327}
328
329impl<Ext, Builder, H> Client<Ext, H>
330where
331 H: Default + HttpClientExt,
332 Ext: Provider<Builder = Builder>,
333 Builder: Default + ProviderBuilder,
334{
335 pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
336 ClientBuilder::default()
337 }
338}
339
340impl<Ext, H> Client<Ext, H>
341where
342 Ext: Provider,
343{
344 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
345 where
346 S: AsRef<str>,
347 {
348 let uri = self
349 .ext
350 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
351
352 let mut req = Request::post(uri);
353
354 if let Some(hs) = req.headers_mut() {
355 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
356 }
357
358 self.ext.with_custom(req)
359 }
360
361 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
362 where
363 S: AsRef<str>,
364 {
365 let uri = self
366 .ext
367 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
368
369 let mut req = Request::post(uri);
370
371 if let Some(hs) = req.headers_mut() {
372 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
373 }
374
375 self.ext.with_custom(req)
376 }
377
378 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
379 where
380 S: AsRef<str>,
381 {
382 let uri = self
383 .ext
384 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
385
386 let mut req = Request::get(uri);
387
388 if let Some(hs) = req.headers_mut() {
389 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
390 }
391
392 self.ext.with_custom(req)
393 }
394}
395
396impl<Ext, H> VerifyClient for Client<Ext, H>
397where
398 H: HttpClientExt,
399 Ext: DebugExt + Provider + WasmCompatSync,
400{
401 async fn verify(&self) -> Result<(), VerifyError> {
402 use http::StatusCode;
403
404 let req = self
405 .get(Ext::VERIFY_PATH)?
406 .body(http_client::NoBody)
407 .map_err(http_client::Error::from)?;
408
409 let response = self.http_client.send(req).await?;
410
411 match response.status() {
412 StatusCode::OK => Ok(()),
413 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
414 Err(VerifyError::InvalidAuthentication)
415 }
416 StatusCode::INTERNAL_SERVER_ERROR => {
417 let text = http_client::text(response).await?;
418 Err(VerifyError::ProviderError(text))
419 }
420 status if status.as_u16() == 529 => {
421 let text = http_client::text(response).await?;
422 Err(VerifyError::ProviderError(text))
423 }
424 _ => {
425 let status = response.status();
426
427 if status.is_success() {
428 Ok(())
429 } else {
430 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
431 Err(VerifyError::HttpError(http_client::Error::Instance(
432 format!("Failed with '{status}': {text}").into(),
433 )))
434 }
435 }
436 }
437 }
438}
439
/// Type-state marker for a [`ClientBuilder`] that has not been given an API key yet.
#[derive(Clone, Copy, Debug, Default)]
pub struct NeedsApiKey;
442
443#[derive(Clone)]
445pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
446 base_url: String,
447 api_key: ApiKey,
448 headers: HeaderMap,
449 http_client: Option<H>,
450 ext: Ext,
451}
452
453impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
454where
455 H: Default,
456 ExtBuilder: ProviderBuilder + Default,
457{
458 fn default() -> Self {
459 Self {
460 api_key: NeedsApiKey,
461 headers: Default::default(),
462 base_url: ExtBuilder::BASE_URL.into(),
463 http_client: None,
464 ext: Default::default(),
465 }
466 }
467}
468
469impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
470 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
473 ClientBuilder {
474 api_key: api_key.into(),
475 base_url: self.base_url,
476 headers: self.headers,
477 http_client: self.http_client,
478 ext: self.ext,
479 }
480 }
481}
482
483impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
484where
485 Ext: Clone,
486{
487 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
489 where
490 F: FnOnce(Ext) -> NewExt,
491 {
492 let ClientBuilder {
493 base_url,
494 api_key,
495 headers,
496 http_client,
497 ext,
498 } = self;
499
500 let new_ext = f(ext.clone());
501
502 ClientBuilder {
503 base_url,
504 api_key,
505 headers,
506 http_client,
507 ext: new_ext,
508 }
509 }
510
511 pub fn base_url<S>(self, base_url: S) -> Self
513 where
514 S: AsRef<str>,
515 {
516 Self {
517 base_url: base_url.as_ref().to_string(),
518 ..self
519 }
520 }
521
522 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
524 ClientBuilder {
525 http_client: Some(http_client),
526 base_url: self.base_url,
527 api_key: self.api_key,
528 headers: self.headers,
529 ext: self.ext,
530 }
531 }
532
533 pub fn http_headers(self, headers: HeaderMap) -> Self {
535 Self { headers, ..self }
536 }
537
538 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
539 &mut self.headers
540 }
541
542 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
543 &mut self.ext
544 }
545}
546
547impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
548 pub(crate) fn get_api_key(&self) -> &ApiKey {
549 &self.api_key
550 }
551}
552
553impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
554 pub fn ext(&self) -> &Ext {
555 &self.ext
556 }
557}
558
559impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
560where
561 ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
562 Ext: Provider<Builder = ExtBuilder>,
563 Key: ApiKey,
564 H: Default,
565{
566 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
567 let ext = self.ext.clone();
568
569 self = ext.finish(self)?;
570 let ext = Ext::build(&self)?;
571
572 let ClientBuilder {
573 http_client,
574 base_url,
575 mut headers,
576 api_key,
577 ..
578 } = self;
579
580 if let Some((k, v)) = api_key.into_header().transpose()? {
581 headers.insert(k, v);
582 }
583
584 let http_client = http_client.unwrap_or_default();
585
586 Ok(Client {
587 http_client,
588 base_url: Arc::from(base_url.as_str()),
589 headers: Arc::new(headers),
590 ext,
591 })
592 }
593}
594
595impl<M, Ext, H> CompletionClient for Client<Ext, H>
596where
597 Ext: Capabilities<H, Completion = Capable<M>>,
598 M: CompletionModel<Client = Self>,
599{
600 type CompletionModel = M;
601
602 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
603 M::make(self, model)
604 }
605}
606
607impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
608where
609 Ext: Capabilities<H, Embeddings = Capable<M>>,
610 M: EmbeddingModel<Client = Self>,
611{
612 type EmbeddingModel = M;
613
614 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
615 M::make(self, model, None)
616 }
617
618 fn embedding_model_with_ndims(
619 &self,
620 model: impl Into<String>,
621 ndims: usize,
622 ) -> Self::EmbeddingModel {
623 M::make(self, model, Some(ndims))
624 }
625}
626
627impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
628where
629 Ext: Capabilities<H, Transcription = Capable<M>>,
630 M: TranscriptionModel<Client = Self> + WasmCompatSend,
631{
632 type TranscriptionModel = M;
633
634 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
635 M::make(self, model)
636 }
637}
638
/// A client whose provider declares `ImageGeneration = Capable<Model>` hands out image
/// generation models of type `Model`.
#[cfg(feature = "image")]
impl<Model, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<Model>>,
    Model: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = Model;

    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        let client: &Self = self;
        Model::make(client, model)
    }
}
651
/// A client whose provider declares `AudioGeneration = Capable<Model>` hands out audio
/// generation models of type `Model`.
#[cfg(feature = "audio")]
impl<Model, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<Model>>,
    Model: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = Model;

    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        let client: &Self = self;
        Model::make(client, model)
    }
}