1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33 completion::CompletionModel,
34 embeddings::EmbeddingModel,
35 http_client::{
36 self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
37 },
38 prelude::TranscriptionClient,
39 transcription::TranscriptionModel,
40 wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46 #[error("reqwest error: {0}")]
47 HttpError(
48 #[from]
49 #[source]
50 reqwest::Error,
51 ),
52 #[error("invalid property: {0}")]
53 InvalidProperty(&'static str),
54}
55
/// A provider client that can be constructed directly, without a builder.
pub trait ProviderClient {
    /// The value consumed by [`ProviderClient::from_val`].
    type Input;

    /// Construct the client from the environment; which variables are read is up to
    /// the implementor.
    fn from_env() -> Self;

    /// Construct the client from an explicit input value.
    fn from_val(input: Self::Input) -> Self;
}
67
68use crate::completion::{GetTokenUsage, Usage};
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct FinalCompletionResponse {
73 pub usage: Option<Usage>,
74}
75
76impl GetTokenUsage for FinalCompletionResponse {
77 fn token_usage(&self) -> Option<Usage> {
78 self.usage
79 }
80}
81
82pub trait ApiKey: Sized {
85 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
86 None
87 }
88}
89
90pub struct BearerAuth(String);
92
93impl ApiKey for BearerAuth {
94 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
95 Some(make_auth_header(self.0))
96 }
97}
98
99impl<S> From<S> for BearerAuth
100where
101 S: Into<String>,
102{
103 fn from(value: S) -> Self {
104 Self(value.into())
105 }
106}
107
108#[derive(Debug, Default, Clone, Copy)]
111pub struct Nothing;
112
113impl ApiKey for Nothing {}
114
115impl TryFrom<String> for Nothing {
116 type Error = &'static str;
117
118 fn try_from(_: String) -> Result<Self, Self::Error> {
119 Err(
120 "Tried to create a Nothing from a string - this should not happen, please file an issue",
121 )
122 }
123}
124
125#[derive(Clone)]
126pub struct Client<Ext = Nothing, H = reqwest::Client> {
127 base_url: Arc<str>,
128 headers: Arc<HeaderMap>,
129 http_client: H,
130 ext: Ext,
131}
132
/// Lets a provider extension append its own fields to a client's `Debug` output.
pub trait DebugExt: Debug {
    /// Extra `(field name, value)` pairs to print; the default yields none.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty::<(&'static str, &dyn Debug)>()
    }
}
138
139impl<Ext, H> std::fmt::Debug for Client<Ext, H>
140where
141 Ext: DebugExt,
142 H: std::fmt::Debug,
143{
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 let mut d = &mut f.debug_struct("Client");
146
147 d = d
148 .field("base_url", &self.base_url)
149 .field(
150 "headers",
151 &self
152 .headers
153 .iter()
154 .filter_map(|(k, v)| {
155 if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
156 None
157 } else {
158 Some((k, v))
159 }
160 })
161 .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
162 )
163 .field("http_client", &self.http_client);
164
165 self.ext
166 .fields()
167 .fold(d, |d, (name, field)| d.field(name, field))
168 .finish()
169 }
170}
171
/// The wire protocol a request will use. It is passed to `Provider::build_uri`.
///
/// Implementations may use it to choose a different URL per protocol; the default
/// implementation ignores it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// Ordinary request/response HTTP.
    Http,
    /// Server-sent events.
    Sse,
    /// Newline-delimited JSON.
    NdJson,
}
177
178pub trait Provider: Sized {
182 const VERIFY_PATH: &'static str;
183
184 type Builder: ProviderBuilder;
185
186 fn build<H>(
187 builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
188 ) -> http_client::Result<Self>;
189
190 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191 base_url.to_string() + "/" + path.trim_start_matches('/')
192 }
193
194 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
195 Ok(req)
196 }
197}
198
/// Capability marker: the provider supports this capability through model type `M`.
pub struct Capable<M>(PhantomData<M>);

/// Compile-time flag recording whether a capability is available.
pub trait Capability {
    /// `true` when the capability is supported.
    const CAPABLE: bool;
}

/// Every `Capable<M>` marks a supported capability.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
209
210impl Capability for Nothing {
211 const CAPABLE: bool = false;
212}
213
214pub trait Capabilities<H = reqwest::Client> {
216 type Completion: Capability;
217 type Embeddings: Capability;
218 type Transcription: Capability;
219 #[cfg(feature = "image")]
220 type ImageGeneration: Capability;
221 #[cfg(feature = "audio")]
222 type AudioGeneration: Capability;
223}
224
225pub trait ProviderBuilder: Sized {
230 type Output: Provider;
231 type ApiKey;
232
233 const BASE_URL: &'static str;
234
235 fn finish<H>(
238 &self,
239 builder: ClientBuilder<Self, Self::ApiKey, H>,
240 ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
241 Ok(builder)
242 }
243}
244
245impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
246where
247 ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
248 Ext: Provider<Builder = ExtBuilder>,
249 H: Default + HttpClientExt,
250 Key: ApiKey,
251{
252 pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
253 Self::builder().api_key(api_key).build()
254 }
255}
256
257impl<Ext, H> Client<Ext, H> {
258 pub fn base_url(&self) -> &str {
259 &self.base_url
260 }
261
262 pub fn headers(&self) -> &HeaderMap {
263 &self.headers
264 }
265
266 pub fn ext(&self) -> &Ext {
267 &self.ext
268 }
269
270 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
271 Client {
272 base_url: self.base_url,
273 headers: self.headers,
274 http_client: self.http_client,
275 ext: new_ext,
276 }
277 }
278}
279
280impl<Ext, H> HttpClientExt for Client<Ext, H>
281where
282 H: HttpClientExt + 'static,
283 Ext: WasmCompatSend + WasmCompatSync + 'static,
284{
285 fn send<T, U>(
286 &self,
287 mut req: Request<T>,
288 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
289 where
290 T: Into<Bytes> + WasmCompatSend,
291 U: From<Bytes>,
292 U: WasmCompatSend + 'static,
293 {
294 req.headers_mut().insert(
295 http::header::CONTENT_TYPE,
296 http::HeaderValue::from_static("application/json"),
297 );
298
299 self.http_client.send(req)
300 }
301
302 fn send_multipart<U>(
303 &self,
304 req: Request<MultipartForm>,
305 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
306 where
307 U: From<Bytes>,
308 U: WasmCompatSend + 'static,
309 {
310 self.http_client.send_multipart(req)
311 }
312
313 fn send_streaming<T>(
314 &self,
315 req: Request<T>,
316 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
317 where
318 T: Into<Bytes>,
319 {
320 self.http_client.send_streaming(req)
321 }
322}
323
324impl<Ext, Builder, H> Client<Ext, H>
325where
326 H: Default + HttpClientExt,
327 Ext: Provider<Builder = Builder>,
328 Builder: Default + ProviderBuilder,
329{
330 pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
331 ClientBuilder::default()
332 }
333}
334
335impl<Ext, H> Client<Ext, H>
336where
337 Ext: Provider,
338{
339 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
340 where
341 S: AsRef<str>,
342 {
343 let uri = self
344 .ext
345 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
346
347 let mut req = Request::post(uri);
348
349 if let Some(hs) = req.headers_mut() {
350 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
351 }
352
353 self.ext.with_custom(req)
354 }
355
356 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
357 where
358 S: AsRef<str>,
359 {
360 let uri = self
361 .ext
362 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
363
364 let mut req = Request::post(uri);
365
366 if let Some(hs) = req.headers_mut() {
367 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
368 }
369
370 self.ext.with_custom(req)
371 }
372
373 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
374 where
375 S: AsRef<str>,
376 {
377 let uri = self
378 .ext
379 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
380
381 self.ext.with_custom(Request::get(uri))
382 }
383}
384
385impl<Ext, H> VerifyClient for Client<Ext, H>
386where
387 H: HttpClientExt,
388 Ext: DebugExt + Provider + WasmCompatSync,
389{
390 async fn verify(&self) -> Result<(), VerifyError> {
391 use http::StatusCode;
392
393 let req = self
394 .get(Ext::VERIFY_PATH)?
395 .body(http_client::NoBody)
396 .map_err(http_client::Error::from)?;
397
398 let response = self.http_client.send(req).await?;
399
400 match response.status() {
401 StatusCode::OK => Ok(()),
402 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
403 Err(VerifyError::InvalidAuthentication)
404 }
405 StatusCode::INTERNAL_SERVER_ERROR => {
406 let text = http_client::text(response).await?;
407 Err(VerifyError::ProviderError(text))
408 }
409 status if status.as_u16() == 529 => {
410 let text = http_client::text(response).await?;
411 Err(VerifyError::ProviderError(text))
412 }
413 _ => {
414 let status = response.status();
415
416 if status.is_success() {
417 Ok(())
418 } else {
419 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
420 Err(VerifyError::HttpError(http_client::Error::Instance(
421 format!("Failed with '{status}': {text}").into(),
422 )))
423 }
424 }
425 }
426 }
427}
428
/// Type-state marker: the builder has not been given an API key yet.
#[derive(Debug, Default, Clone, Copy)]
pub struct NeedsApiKey;
431
432#[derive(Clone)]
434pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
435 base_url: String,
436 api_key: ApiKey,
437 headers: HeaderMap,
438 http_client: Option<H>,
439 ext: Ext,
440}
441
442impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
443where
444 H: Default,
445 ExtBuilder: ProviderBuilder + Default,
446{
447 fn default() -> Self {
448 Self {
449 api_key: NeedsApiKey,
450 headers: Default::default(),
451 base_url: ExtBuilder::BASE_URL.into(),
452 http_client: None,
453 ext: Default::default(),
454 }
455 }
456}
457
458impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
459 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
462 ClientBuilder {
463 api_key: api_key.into(),
464 base_url: self.base_url,
465 headers: self.headers,
466 http_client: self.http_client,
467 ext: self.ext,
468 }
469 }
470}
471
472impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
473where
474 Ext: Clone,
475{
476 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
478 where
479 F: FnOnce(Ext) -> NewExt,
480 {
481 let ClientBuilder {
482 base_url,
483 api_key,
484 headers,
485 http_client,
486 ext,
487 } = self;
488
489 let new_ext = f(ext.clone());
490
491 ClientBuilder {
492 base_url,
493 api_key,
494 headers,
495 http_client,
496 ext: new_ext,
497 }
498 }
499
500 pub fn base_url<S>(self, base_url: S) -> Self
502 where
503 S: AsRef<str>,
504 {
505 Self {
506 base_url: base_url.as_ref().to_string(),
507 ..self
508 }
509 }
510
511 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
513 ClientBuilder {
514 http_client: Some(http_client),
515 base_url: self.base_url,
516 api_key: self.api_key,
517 headers: self.headers,
518 ext: self.ext,
519 }
520 }
521
522 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
523 &mut self.headers
524 }
525
526 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
527 &mut self.ext
528 }
529}
530
531impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
532 pub(crate) fn get_api_key(&self) -> &ApiKey {
533 &self.api_key
534 }
535}
536
537impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
538 pub fn ext(&self) -> &Ext {
539 &self.ext
540 }
541}
542
543impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
544where
545 ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
546 Ext: Provider<Builder = ExtBuilder>,
547 Key: ApiKey,
548 H: Default,
549{
550 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
551 let ext = self.ext.clone();
552
553 self = ext.finish(self)?;
554 let ext = Ext::build(&self)?;
555
556 let ClientBuilder {
557 http_client,
558 base_url,
559 mut headers,
560 api_key,
561 ..
562 } = self;
563
564 if let Some((k, v)) = api_key.into_header().transpose()? {
565 headers.insert(k, v);
566 }
567
568 let http_client = http_client.unwrap_or_default();
569
570 Ok(Client {
571 http_client,
572 base_url: Arc::from(base_url.as_str()),
573 headers: Arc::new(headers),
574 ext,
575 })
576 }
577}
578
579impl<M, Ext, H> CompletionClient for Client<Ext, H>
580where
581 Ext: Capabilities<H, Completion = Capable<M>>,
582 M: CompletionModel<Client = Self>,
583{
584 type CompletionModel = M;
585
586 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
587 M::make(self, model)
588 }
589}
590
591impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
592where
593 Ext: Capabilities<H, Embeddings = Capable<M>>,
594 M: EmbeddingModel<Client = Self>,
595{
596 type EmbeddingModel = M;
597
598 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
599 M::make(self, model, None)
600 }
601
602 fn embedding_model_with_ndims(
603 &self,
604 model: impl Into<String>,
605 ndims: usize,
606 ) -> Self::EmbeddingModel {
607 M::make(self, model, Some(ndims))
608 }
609}
610
611impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
612where
613 Ext: Capabilities<H, Transcription = Capable<M>>,
614 M: TranscriptionModel<Client = Self> + WasmCompatSend,
615{
616 type TranscriptionModel = M;
617
618 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
619 M::make(self, model)
620 }
621}
622
/// Clients whose provider declares `ImageGeneration = Capable<M>` hand out image
/// generation models of type `M`.
#[cfg(feature = "image")]
impl<Ext, H, M> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Create the image generation model named `model` bound to this client.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        M::make(self, model)
    }
}
635
/// Clients whose provider declares `AudioGeneration = Capable<M>` hand out audio
/// generation models of type `M`.
#[cfg(feature = "audio")]
impl<Ext, H, M> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Create the audio generation model named `model` bound to this client.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        M::make(self, model)
    }
}