1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33 completion::CompletionModel,
34 embeddings::EmbeddingModel,
35 http_client::{self, Builder, HttpClientExt, LazyBody, Request, Response, make_auth_header},
36 prelude::TranscriptionClient,
37 transcription::TranscriptionModel,
38 wasm_compat::{WasmCompatSend, WasmCompatSync},
39};
40
41#[derive(Debug, Error)]
42#[non_exhaustive]
43pub enum ClientBuilderError {
44 #[error("reqwest error: {0}")]
45 HttpError(
46 #[from]
47 #[source]
48 reqwest::Error,
49 ),
50 #[error("invalid property: {0}")]
51 InvalidProperty(&'static str),
52}
53
/// A provider client that can be constructed either from an explicit input value
/// or from the process environment.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Builds the client from an explicitly supplied `input`.
    fn from_val(input: Self::Input) -> Self;

    /// Builds the client from environment configuration.
    ///
    /// NOTE(review): the signature is infallible, so implementations presumably
    /// panic when configuration is missing — confirm per provider.
    fn from_env() -> Self;
}
65
66use crate::completion::{GetTokenUsage, Usage};
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct FinalCompletionResponse {
71 pub usage: Option<Usage>,
72}
73
74impl GetTokenUsage for FinalCompletionResponse {
75 fn token_usage(&self) -> Option<Usage> {
76 self.usage
77 }
78}
79
80pub trait ApiKey: Sized {
83 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
84 None
85 }
86}
87
88pub struct BearerAuth(String);
90
91impl ApiKey for BearerAuth {
92 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
93 Some(make_auth_header(self.0))
94 }
95}
96
97impl<S> From<S> for BearerAuth
98where
99 S: Into<String>,
100{
101 fn from(value: S) -> Self {
102 Self(value.into())
103 }
104}
105
/// Unit marker meaning "nothing here".
///
/// Used as the "no API key" type and as the "unsupported" value in a
/// capability slot.
#[derive(Clone, Copy, Debug, Default)]
pub struct Nothing;
110
111impl ApiKey for Nothing {}
112
113impl TryFrom<String> for Nothing {
114 type Error = &'static str;
115
116 fn try_from(_: String) -> Result<Self, Self::Error> {
117 Err(
118 "Tried to create a Nothing from a string - this should not happen, please file an issue",
119 )
120 }
121}
122
123#[derive(Clone)]
125pub struct Client<Ext = Nothing, H = reqwest::Client> {
126 base_url: String,
127 headers: HeaderMap,
128 http_client: H,
129 ext: Ext,
130}
131
/// Lets a provider extension append its own fields to `Client`'s `Debug` output.
pub trait DebugExt: Debug {
    /// Extra `(field name, value)` pairs to print.
    ///
    /// The default yields nothing.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        core::iter::empty()
    }
}
137
138impl<Ext, H> std::fmt::Debug for Client<Ext, H>
139where
140 Ext: DebugExt,
141 H: std::fmt::Debug,
142{
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 let mut d = &mut f.debug_struct("Client");
145
146 d = d
147 .field("base_url", &self.base_url)
148 .field(
149 "headers",
150 &self.headers.iter().filter_map(|(k, v)| {
151 (!(k == http::header::AUTHORIZATION || k.as_str().contains("api-key")))
152 .then_some((k, v))
153 }),
154 )
155 .field("http_client", &self.http_client);
156
157 self.ext
158 .fields()
159 .fold(d, |d, (name, field)| d.field(name, field))
160 .finish()
161 }
162}
163
/// The wire protocol a request will use.
///
/// It is passed to `Provider::build_uri` so a provider can choose a different
/// URI per transport. The default implementation ignores it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
    /// Plain request/response HTTP.
    Http,
    /// Server-sent events stream.
    Sse,
    /// Newline-delimited JSON stream.
    NdJson,
}
169
170pub trait Provider: Sized {
174 const VERIFY_PATH: &'static str;
175
176 type Builder: ProviderBuilder;
177
178 fn build<H>(
179 builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
180 ) -> http_client::Result<Self>;
181
182 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
183 base_url.to_string() + "/" + path.trim_start_matches('/')
184 }
185
186 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
187 Ok(req)
188 }
189}
190
/// Zero-sized marker stating that a capability is supported, served by model type `M`.
pub struct Capable<M>(PhantomData<M>);
193
/// Compile-time flag describing whether a capability slot is supported.
pub trait Capability {
    /// `true` when the capability is available.
    const CAPABLE: bool;
}
197
198impl<M> Capability for Capable<M> {
199 const CAPABLE: bool = true;
200}
201
202impl Capability for Nothing {
203 const CAPABLE: bool = false;
204}
205
206pub trait Capabilities<H = reqwest::Client> {
208 type Completion: Capability;
209 type Embeddings: Capability;
210 type Transcription: Capability;
211 #[cfg(feature = "image")]
212 type ImageGeneration: Capability;
213 #[cfg(feature = "audio")]
214 type AudioGeneration: Capability;
215}
216
217pub trait ProviderBuilder: Sized {
222 type Output: Provider;
223 type ApiKey;
224
225 const BASE_URL: &'static str;
226
227 fn finish<H>(
230 &self,
231 builder: ClientBuilder<Self, Self::ApiKey, H>,
232 ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
233 Ok(builder)
234 }
235}
236
237impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
238where
239 ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
240 Ext: Provider<Builder = ExtBuilder>,
241 H: Default + HttpClientExt,
242 Key: ApiKey,
243{
244 pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
245 Self::builder().api_key(api_key).build()
246 }
247}
248
249impl<Ext, H> Client<Ext, H> {
250 pub fn base_url(&self) -> &str {
251 &self.base_url
252 }
253
254 pub fn headers(&self) -> &HeaderMap {
255 &self.headers
256 }
257
258 pub fn http_client(&self) -> &H {
259 &self.http_client
260 }
261
262 pub fn ext(&self) -> &Ext {
263 &self.ext
264 }
265
266 pub fn from_parts(base_url: String, headers: HeaderMap, http_client: H, ext: Ext) -> Self {
269 Self {
270 base_url,
271 headers,
272 http_client,
273 ext,
274 }
275 }
276}
277
278impl<Ext, Builder, H> Client<Ext, H>
279where
280 H: Default + HttpClientExt,
281 Ext: Provider<Builder = Builder>,
282 Builder: Default + ProviderBuilder,
283{
284 pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
285 ClientBuilder::default()
286 }
287}
288
289impl<Ext, H> Client<Ext, H>
290where
291 Ext: Provider,
292{
293 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
294 where
295 S: AsRef<str>,
296 {
297 let uri = self
298 .ext
299 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
300
301 let mut req = Request::post(uri);
302
303 if let Some(hs) = req.headers_mut() {
304 hs.extend(self.headers.clone());
305 }
306
307 self.ext.with_custom(req)
308 }
309
310 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
311 where
312 S: AsRef<str>,
313 {
314 let uri = self
315 .ext
316 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
317
318 let mut req = Request::post(uri);
319
320 if let Some(hs) = req.headers_mut() {
321 hs.extend(self.headers.clone());
322 }
323
324 self.ext.with_custom(req)
325 }
326
327 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
328 where
329 S: AsRef<str>,
330 {
331 let uri = self
332 .ext
333 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
334
335 self.ext.with_custom(Request::get(uri))
336 }
337}
338
339impl<Ext, H> Client<Ext, H>
340where
341 H: HttpClientExt,
342{
343 pub async fn send<T, U>(
344 &self,
345 mut req: Request<T>,
346 ) -> http_client::Result<Response<LazyBody<U>>>
347 where
348 T: std::fmt::Debug + Into<Bytes> + WasmCompatSend,
349 U: std::fmt::Debug + From<Bytes> + WasmCompatSend + 'static,
350 {
351 req.headers_mut().insert(
352 http::header::CONTENT_TYPE,
353 http::HeaderValue::from_static("application/json"),
354 );
355
356 dbg!(&req);
357
358 self.http_client.send(req).await
359 }
360
361 pub async fn send_streaming<U, R>(
362 &self,
363 req: Request<U>,
364 ) -> Result<http_client::StreamingResponse, http_client::Error>
365 where
366 U: std::fmt::Debug + Into<Bytes>,
367 {
368 self.http_client.send_streaming(req).await
369 }
370}
371
372impl<Ext, H> VerifyClient for Client<Ext, H>
373where
374 H: HttpClientExt,
375 Ext: DebugExt + Provider + WasmCompatSync,
376{
377 async fn verify(&self) -> Result<(), VerifyError> {
378 use http::StatusCode;
379
380 let req = self
381 .get(Ext::VERIFY_PATH)?
382 .body(http_client::NoBody)
383 .map_err(http_client::Error::from)?;
384
385 let response = self.http_client.send(req).await?;
386
387 match response.status() {
388 StatusCode::OK => Ok(()),
389 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
390 Err(VerifyError::InvalidAuthentication)
391 }
392 StatusCode::INTERNAL_SERVER_ERROR => {
393 let text = http_client::text(response).await?;
394 Err(VerifyError::ProviderError(text))
395 }
396 status if status.as_u16() == 529 => {
397 let text = http_client::text(response).await?;
398 Err(VerifyError::ProviderError(text))
399 }
400 _ => {
401 let status = response.status();
402
403 if status.is_success() {
404 Ok(())
405 } else {
406 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
407 Err(VerifyError::HttpError(http_client::Error::Instance(
408 format!("Failed with '{status}': {text}").into(),
409 )))
410 }
411 }
412 }
413 }
414}
415
/// Type-state placeholder for a [`ClientBuilder`] whose API key has not been supplied.
#[derive(Clone, Copy, Debug, Default)]
pub struct NeedsApiKey;
418
419#[derive(Clone)]
421pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
422 base_url: String,
423 api_key: ApiKey,
424 headers: HeaderMap,
425 http_client: Option<H>,
426 ext: Ext,
427}
428
429impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
430where
431 H: Default,
432 ExtBuilder: ProviderBuilder + Default,
433{
434 fn default() -> Self {
435 Self {
436 api_key: NeedsApiKey,
437 headers: Default::default(),
438 base_url: ExtBuilder::BASE_URL.into(),
439 http_client: None,
440 ext: Default::default(),
441 }
442 }
443}
444
445impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
446 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
449 ClientBuilder {
450 api_key: api_key.into(),
451 base_url: self.base_url,
452 headers: self.headers,
453 http_client: self.http_client,
454 ext: self.ext,
455 }
456 }
457}
458
459impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
460where
461 Ext: Clone,
462{
463 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
465 where
466 F: FnOnce(Ext) -> NewExt,
467 {
468 let ClientBuilder {
469 base_url,
470 api_key,
471 headers,
472 http_client,
473 ext,
474 } = self;
475
476 let new_ext = f(ext.clone());
477
478 ClientBuilder {
479 base_url,
480 api_key,
481 headers,
482 http_client,
483 ext: new_ext,
484 }
485 }
486
487 pub fn base_url<S>(self, base_url: S) -> Self
489 where
490 S: AsRef<str>,
491 {
492 Self {
493 base_url: base_url.as_ref().to_string(),
494 ..self
495 }
496 }
497
498 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
500 ClientBuilder {
501 http_client: Some(http_client),
502 base_url: self.base_url,
503 api_key: self.api_key,
504 headers: self.headers,
505 ext: self.ext,
506 }
507 }
508
509 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
510 &mut self.headers
511 }
512
513 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
514 &mut self.ext
515 }
516}
517
518impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
519 pub(crate) fn get_api_key(&self) -> &ApiKey {
520 &self.api_key
521 }
522}
523
524impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
525 pub fn ext(&self) -> &Ext {
526 &self.ext
527 }
528}
529
530impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
531where
532 ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
533 Ext: Provider<Builder = ExtBuilder>,
534 Key: ApiKey,
535 H: Default,
536{
537 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
538 let ext = self.ext.clone();
539
540 self = ext.finish(self)?;
541 let ext = Ext::build(&self)?;
542
543 let ClientBuilder {
544 http_client,
545 base_url,
546 mut headers,
547 api_key,
548 ..
549 } = self;
550
551 if let Some((k, v)) = api_key.into_header().transpose()? {
552 headers.insert(k, v);
553 }
554
555 let http_client = http_client.unwrap_or_default();
556
557 Ok(Client {
558 http_client,
559 base_url,
560 headers,
561 ext,
562 })
563 }
564}
565
566impl<M, Ext, H> CompletionClient for Client<Ext, H>
567where
568 Ext: Capabilities<H, Completion = Capable<M>>,
569 M: CompletionModel<Client = Self>,
570{
571 type CompletionModel = M;
572
573 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
574 M::make(self, model)
575 }
576}
577
578impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
579where
580 Ext: Capabilities<H, Embeddings = Capable<M>>,
581 M: EmbeddingModel<Client = Self>,
582{
583 type EmbeddingModel = M;
584
585 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
586 M::make(self, model, None)
587 }
588
589 fn embedding_model_with_ndims(
590 &self,
591 model: impl Into<String>,
592 ndims: usize,
593 ) -> Self::EmbeddingModel {
594 M::make(self, model, Some(ndims))
595 }
596}
597
598impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
599where
600 Ext: Capabilities<H, Transcription = Capable<M>>,
601 M: TranscriptionModel<Client = Self> + WasmCompatSend,
602{
603 type TranscriptionModel = M;
604
605 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
606 M::make(self, model)
607 }
608}
609
/// With the `image` feature, a client whose provider declares
/// `ImageGeneration = Capable<M>` hands out image-generation models of type `M`.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    M: ImageGenerationModel<Client = Self>,
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
{
    type ImageGenerationModel = M;

    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        <M as ImageGenerationModel>::make(self, model)
    }
}
622
/// With the `audio` feature, a client whose provider declares
/// `AudioGeneration = Capable<M>` hands out audio-generation models of type `M`.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    M: AudioGenerationModel<Client = Self>,
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
{
    type AudioGenerationModel = M;

    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        <M as AudioGenerationModel>::make(self, model)
    }
}