1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33 completion::CompletionModel,
34 embeddings::EmbeddingModel,
35 http_client::{self, Builder, HttpClientExt, LazyBody, Request, Response, make_auth_header},
36 prelude::TranscriptionClient,
37 transcription::TranscriptionModel,
38 wasm_compat::{WasmCompatSend, WasmCompatSync},
39};
40
/// Errors that can be reported while configuring a client.
///
/// Marked `#[non_exhaustive]`, so downstream `match` expressions need a
/// wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ClientBuilderError {
    /// An error produced by `reqwest`; `#[from]` provides the conversion and
    /// `#[source]` exposes it as the underlying cause.
    #[error("reqwest error: {0}")]
    HttpError(
        #[from]
        #[source]
        reqwest::Error,
    ),
    /// A property was rejected; the static string is interpolated into the
    /// rendered error message.
    #[error("invalid property: {0}")]
    InvalidProperty(&'static str),
}
53
/// Constructors shared by provider clients.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Builds a client without any explicit argument.
    ///
    /// NOTE(review): the name suggests the configuration is read from
    /// environment variables, and the infallible return type suggests that
    /// implementations panic on bad configuration — confirm against the
    /// implementors.
    fn from_env() -> Self;

    /// Builds a client from an explicit [`ProviderClient::Input`] value.
    fn from_val(input: Self::Input) -> Self;
}
65
66use crate::completion::{GetTokenUsage, Usage};
67
/// Serializable response payload that carries only optional token usage.
///
/// NOTE(review): presumably the terminal item of a streamed completion —
/// confirm against the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinalCompletionResponse {
    /// Token usage, if any was recorded.
    pub usage: Option<Usage>,
}
73
74impl GetTokenUsage for FinalCompletionResponse {
75 fn token_usage(&self) -> Option<Usage> {
76 self.usage
77 }
78}
79
/// A credential that can optionally be rendered as a single HTTP header.
pub trait ApiKey: Sized {
    /// Consumes the key and produces the header that represents it.
    ///
    /// Returns `None` when no header should be added (the default
    /// implementation), `Some(Ok(_))` with the header name/value pair, or
    /// `Some(Err(_))` when the header could not be constructed.
    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
        None
    }
}
87
/// An API key held as a string and turned into a header with
/// `make_auth_header` (see its [`ApiKey`] implementation).
pub struct BearerAuth(String);
90
91impl ApiKey for BearerAuth {
92 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
93 Some(make_auth_header(self.0))
94 }
95}
96
97impl<S> From<S> for BearerAuth
98where
99 S: Into<String>,
100{
101 fn from(value: S) -> Self {
102 Self(value.into())
103 }
104}
105
/// Unit placeholder used where a type parameter has nothing to carry.
///
/// In this file it serves as the default `Ext` of [`Client`], as an
/// [`ApiKey`] that produces no header, and as a [`Capability`] whose
/// `CAPABLE` constant is `false`.
#[derive(Debug, Default, Clone, Copy)]
pub struct Nothing;
110
// Relies on the trait's default `into_header`, which returns `None`: no
// authentication header is produced for `Nothing`.
impl ApiKey for Nothing {}
112
113impl TryFrom<String> for Nothing {
114 type Error = &'static str;
115
116 fn try_from(_: String) -> Result<Self, Self::Error> {
117 Err(
118 "Tried to create a Nothing from a string - this should not happen, please file an issue",
119 )
120 }
121}
122
/// A provider client, generic over a provider extension `Ext` and an HTTP
/// client `H` (defaulting to `reqwest::Client`).
///
/// `base_url` and `headers` sit behind `Arc`, so cloning the client does not
/// copy them.
#[derive(Clone)]
pub struct Client<Ext = Nothing, H = reqwest::Client> {
    // Base URL passed to `Provider::build_uri` when building request URIs.
    base_url: Arc<str>,
    // Default headers assembled by the builder (including the API-key header,
    // when the key produces one).
    headers: Arc<HeaderMap>,
    // Underlying HTTP client that actually sends requests.
    http_client: H,
    // Provider-specific extension value.
    ext: Ext,
}
130
/// Lets a provider extension contribute extra fields to the `Debug` output of
/// the client that owns it.
pub trait DebugExt: Debug {
    /// Yields `(name, value)` pairs to append to the debug representation.
    ///
    /// The default implementation contributes no fields.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        None.into_iter()
    }
}
136
137impl<Ext, H> std::fmt::Debug for Client<Ext, H>
138where
139 Ext: DebugExt,
140 H: std::fmt::Debug,
141{
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 let mut d = &mut f.debug_struct("Client");
144
145 d = d
146 .field("base_url", &self.base_url)
147 .field(
148 "headers",
149 &self.headers.iter().filter_map(|(k, v)| {
150 (!(k == http::header::AUTHORIZATION || k.as_str().contains("api-key")))
151 .then_some((k, v))
152 }),
153 )
154 .field("http_client", &self.http_client);
155
156 self.ext
157 .fields()
158 .fold(d, |d, (name, field)| d.field(name, field))
159 .finish()
160 }
161}
162
/// The kind of request a URI is being built for; passed to
/// [`Provider::build_uri`] so a provider can vary the URI by transport.
pub enum Transport {
    /// Used by `Client::post` and `Client::get`.
    Http,
    /// Used by `Client::post_sse`.
    Sse,
    /// Not constructed in this file.
    /// NOTE(review): presumably newline-delimited JSON streaming — confirm.
    NdJson,
}
168
169pub trait Provider: Sized {
173 const VERIFY_PATH: &'static str;
174
175 type Builder: ProviderBuilder;
176
177 fn build<H>(
178 builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
179 ) -> http_client::Result<Self>;
180
181 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
182 base_url.to_string() + "/" + path.trim_start_matches('/')
183 }
184
185 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
186 Ok(req)
187 }
188}
189
/// Type-level marker stating that a capability is available and backed by the
/// model type `M`. Holds no data (`PhantomData` only).
pub struct Capable<M>(PhantomData<M>);
192
/// Compile-time flag describing whether a capability is available.
pub trait Capability {
    /// `true` when the capability is supported.
    const CAPABLE: bool;
}
196
// `Capable<M>` always reports the capability as available.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
200
// `Nothing` reports the capability as unavailable.
impl Capability for Nothing {
    const CAPABLE: bool = false;
}
204
/// Declares, per feature area, whether a provider extension supports it.
///
/// Each associated type is a [`Capability`]. The client trait impls in this
/// file apply only when the matching associated type is `Capable<M>`, with
/// `M` as the model type. `H` is the HTTP client type of the owning `Client`.
pub trait Capabilities<H = reqwest::Client> {
    /// Capability marker for completion models.
    type Completion: Capability;
    /// Capability marker for embedding models.
    type Embeddings: Capability;
    /// Capability marker for transcription models.
    type Transcription: Capability;
    /// Capability marker for image generation (only with the `image` feature).
    #[cfg(feature = "image")]
    type ImageGeneration: Capability;
    /// Capability marker for audio generation (only with the `audio` feature).
    #[cfg(feature = "audio")]
    type AudioGeneration: Capability;
}
215
/// Provider-specific part of a [`ClientBuilder`].
pub trait ProviderBuilder: Sized {
    /// The provider value this builder ultimately produces.
    type Output: Provider;
    /// The API key type the provider expects.
    type ApiKey;

    /// Base URL a default-constructed [`ClientBuilder`] starts with.
    const BASE_URL: &'static str;

    /// Hook run by `ClientBuilder::build` before `Provider::build`, giving the
    /// provider a chance to adjust the builder.
    ///
    /// The default returns the builder unchanged.
    fn finish<H>(
        &self,
        builder: ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
        Ok(builder)
    }
}
235
236impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
237where
238 ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
239 Ext: Provider<Builder = ExtBuilder>,
240 H: Default + HttpClientExt,
241 Key: ApiKey,
242{
243 pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
244 Self::builder().api_key(api_key).build()
245 }
246}
247
248impl<Ext, H> Client<Ext, H> {
249 pub fn base_url(&self) -> &str {
250 &self.base_url
251 }
252
253 pub fn headers(&self) -> &HeaderMap {
254 &self.headers
255 }
256
257 pub fn ext(&self) -> &Ext {
258 &self.ext
259 }
260
261 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
262 Client {
263 base_url: self.base_url,
264 headers: self.headers,
265 http_client: self.http_client,
266 ext: new_ext,
267 }
268 }
269}
270
271impl<Ext, H> HttpClientExt for Client<Ext, H>
272where
273 H: HttpClientExt + 'static,
274 Ext: WasmCompatSend + WasmCompatSync + 'static,
275{
276 fn send<T, U>(
277 &self,
278 mut req: Request<T>,
279 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
280 where
281 T: Into<Bytes> + WasmCompatSend,
282 U: From<Bytes>,
283 U: WasmCompatSend + 'static,
284 {
285 req.headers_mut().insert(
286 http::header::CONTENT_TYPE,
287 http::HeaderValue::from_static("application/json"),
288 );
289
290 self.http_client.send(req)
291 }
292
293 fn send_multipart<U>(
294 &self,
295 req: Request<reqwest::multipart::Form>,
296 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
297 where
298 U: From<Bytes>,
299 U: WasmCompatSend + 'static,
300 {
301 self.http_client.send_multipart(req)
302 }
303
304 fn send_streaming<T>(
305 &self,
306 req: Request<T>,
307 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
308 where
309 T: Into<Bytes>,
310 {
311 self.http_client.send_streaming(req)
312 }
313}
314
315impl<Ext, Builder, H> Client<Ext, H>
316where
317 H: Default + HttpClientExt,
318 Ext: Provider<Builder = Builder>,
319 Builder: Default + ProviderBuilder,
320{
321 pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
322 ClientBuilder::default()
323 }
324}
325
326impl<Ext, H> Client<Ext, H>
327where
328 Ext: Provider,
329{
330 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
331 where
332 S: AsRef<str>,
333 {
334 let uri = self
335 .ext
336 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
337
338 let mut req = Request::post(uri);
339
340 if let Some(hs) = req.headers_mut() {
341 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
342 }
343
344 self.ext.with_custom(req)
345 }
346
347 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
348 where
349 S: AsRef<str>,
350 {
351 let uri = self
352 .ext
353 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
354
355 let mut req = Request::post(uri);
356
357 if let Some(hs) = req.headers_mut() {
358 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
359 }
360
361 self.ext.with_custom(req)
362 }
363
364 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
365 where
366 S: AsRef<str>,
367 {
368 let uri = self
369 .ext
370 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
371
372 self.ext.with_custom(Request::get(uri))
373 }
374}
375
376impl<Ext, H> VerifyClient for Client<Ext, H>
377where
378 H: HttpClientExt,
379 Ext: DebugExt + Provider + WasmCompatSync,
380{
381 async fn verify(&self) -> Result<(), VerifyError> {
382 use http::StatusCode;
383
384 let req = self
385 .get(Ext::VERIFY_PATH)?
386 .body(http_client::NoBody)
387 .map_err(http_client::Error::from)?;
388
389 let response = self.http_client.send(req).await?;
390
391 match response.status() {
392 StatusCode::OK => Ok(()),
393 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
394 Err(VerifyError::InvalidAuthentication)
395 }
396 StatusCode::INTERNAL_SERVER_ERROR => {
397 let text = http_client::text(response).await?;
398 Err(VerifyError::ProviderError(text))
399 }
400 status if status.as_u16() == 529 => {
401 let text = http_client::text(response).await?;
402 Err(VerifyError::ProviderError(text))
403 }
404 _ => {
405 let status = response.status();
406
407 if status.is_success() {
408 Ok(())
409 } else {
410 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
411 Err(VerifyError::HttpError(http_client::Error::Instance(
412 format!("Failed with '{status}': {text}").into(),
413 )))
414 }
415 }
416 }
417 }
418}
419
/// Type-state marker for a [`ClientBuilder`] that has not been given an API
/// key yet; `ClientBuilder::api_key` replaces it with the real key type.
#[derive(Debug, Default, Clone, Copy)]
pub struct NeedsApiKey;
422
/// Builder for [`Client`].
///
/// `Ext` is the provider builder, `ApiKey` the key type (starting out as
/// [`NeedsApiKey`]) and `H` the HTTP client type.
#[derive(Clone)]
pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
    // Base URL for the client; defaults to the provider builder's `BASE_URL`.
    base_url: String,
    // The API key, or the `NeedsApiKey` marker while none has been set.
    api_key: ApiKey,
    // Headers that will become the client's default headers.
    headers: HeaderMap,
    // Explicit HTTP client; when `None`, `build` falls back to `H::default()`.
    http_client: Option<H>,
    // Provider-specific builder state.
    ext: Ext,
}
432
433impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
434where
435 H: Default,
436 ExtBuilder: ProviderBuilder + Default,
437{
438 fn default() -> Self {
439 Self {
440 api_key: NeedsApiKey,
441 headers: Default::default(),
442 base_url: ExtBuilder::BASE_URL.into(),
443 http_client: None,
444 ext: Default::default(),
445 }
446 }
447}
448
449impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
450 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
453 ClientBuilder {
454 api_key: api_key.into(),
455 base_url: self.base_url,
456 headers: self.headers,
457 http_client: self.http_client,
458 ext: self.ext,
459 }
460 }
461}
462
463impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
464where
465 Ext: Clone,
466{
467 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
469 where
470 F: FnOnce(Ext) -> NewExt,
471 {
472 let ClientBuilder {
473 base_url,
474 api_key,
475 headers,
476 http_client,
477 ext,
478 } = self;
479
480 let new_ext = f(ext.clone());
481
482 ClientBuilder {
483 base_url,
484 api_key,
485 headers,
486 http_client,
487 ext: new_ext,
488 }
489 }
490
491 pub fn base_url<S>(self, base_url: S) -> Self
493 where
494 S: AsRef<str>,
495 {
496 Self {
497 base_url: base_url.as_ref().to_string(),
498 ..self
499 }
500 }
501
502 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
504 ClientBuilder {
505 http_client: Some(http_client),
506 base_url: self.base_url,
507 api_key: self.api_key,
508 headers: self.headers,
509 ext: self.ext,
510 }
511 }
512
513 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
514 &mut self.headers
515 }
516
517 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
518 &mut self.ext
519 }
520}
521
522impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
523 pub(crate) fn get_api_key(&self) -> &ApiKey {
524 &self.api_key
525 }
526}
527
528impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
529 pub fn ext(&self) -> &Ext {
530 &self.ext
531 }
532}
533
534impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
535where
536 ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
537 Ext: Provider<Builder = ExtBuilder>,
538 Key: ApiKey,
539 H: Default,
540{
541 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
542 let ext = self.ext.clone();
543
544 self = ext.finish(self)?;
545 let ext = Ext::build(&self)?;
546
547 let ClientBuilder {
548 http_client,
549 base_url,
550 mut headers,
551 api_key,
552 ..
553 } = self;
554
555 if let Some((k, v)) = api_key.into_header().transpose()? {
556 headers.insert(k, v);
557 }
558
559 let http_client = http_client.unwrap_or_default();
560
561 Ok(Client {
562 http_client,
563 base_url: Arc::from(base_url.as_str()),
564 headers: Arc::new(headers),
565 ext,
566 })
567 }
568}
569
570impl<M, Ext, H> CompletionClient for Client<Ext, H>
571where
572 Ext: Capabilities<H, Completion = Capable<M>>,
573 M: CompletionModel<Client = Self>,
574{
575 type CompletionModel = M;
576
577 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
578 M::make(self, model)
579 }
580}
581
582impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
583where
584 Ext: Capabilities<H, Embeddings = Capable<M>>,
585 M: EmbeddingModel<Client = Self>,
586{
587 type EmbeddingModel = M;
588
589 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
590 M::make(self, model, None)
591 }
592
593 fn embedding_model_with_ndims(
594 &self,
595 model: impl Into<String>,
596 ndims: usize,
597 ) -> Self::EmbeddingModel {
598 M::make(self, model, Some(ndims))
599 }
600}
601
602impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
603where
604 Ext: Capabilities<H, Transcription = Capable<M>>,
605 M: TranscriptionModel<Client = Self> + WasmCompatSend,
606{
607 type TranscriptionModel = M;
608
609 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
610 M::make(self, model)
611 }
612}
613
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates the provider's image generation model for `model` by
    /// delegating to `ImageGenerationModel::make`.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        <M as ImageGenerationModel>::make(self, model)
    }
}
626
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates the provider's audio generation model for `model` by
    /// delegating to `AudioGenerationModel::make`.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        <M as AudioGenerationModel>::make(self, model)
    }
}