mod api;
mod middleware;

pub use api::*;

use anyhow::{anyhow, Result};
use bytes::Bytes;
use derive_builder::Builder;
use middleware::RetryMiddleware;
use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, RequestBuilder};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use reqwest_tracing::TracingMiddleware;
use schemars::{schema_for, JsonSchema};
use std::time::Duration;
use tracing::error;

/// Per-request timeout, in seconds.
const TIMEOUT: u64 = 60;
/// Fallback retry count for the default HTTP client.
const MAX_RETRIES: u32 = 3;

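/// Thin client for an OpenAI-compatible HTTP API. Build it with [`LlmSdk::new`],
/// [`LlmSdk::new_with_base_url`], or [`LlmSdkBuilder`]; the underlying HTTP
/// client (tracing + retry middleware) is assembled automatically.
///
/// Minimal usage sketch (not compiled; request types live in the `api` module):
///
/// ```ignore
/// let sdk = LlmSdk::new(std::env::var("OPENAI_API_KEY")?);
/// let req = /* build a ChatCompletionRequest from the `api` module */;
/// let res = sdk.chat_completion(req).await?;
/// ```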
#[derive(Debug, Clone, Builder)]
pub struct LlmSdk {
    #[builder(setter(into), default = r#""https://api.openai.com/v1".into()"#)]
    pub(crate) base_url: String,
    #[builder(setter(into))]
    pub(crate) token: String,
    #[allow(dead_code)]
    #[builder(default = "MAX_RETRIES")]
    pub(crate) max_retries: u32,
    #[builder(setter(skip), default = "self.default_client()")]
    pub(crate) client: ClientWithMiddleware,
}

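/// Conversion from a typed API request into a ready-to-send [`RequestBuilder`]
/// rooted at `base_url`; implemented by the request types in the `api` module.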
pub trait IntoRequest {
    fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder;
}

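/// JSON Schema helper: every [`JsonSchema`] type gets [`ToSchema::to_schema`]
/// for free via the blanket impl at the bottom of this file.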
pub trait ToSchema: JsonSchema {
    fn to_schema() -> serde_json::Value;
}

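// The builder skips the `client` setter and fills it in via `default_client`,
// so callers never wire up reqwest middleware themselves.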
impl LlmSdkBuilder {
    fn default_client(&self) -> ClientWithMiddleware {
        let retry_policy = ExponentialBackoff::builder()
            .build_with_max_retries(self.max_retries.unwrap_or(MAX_RETRIES));
        let m = RetryTransientMiddleware::new_with_policy(retry_policy);
        ClientBuilder::new(reqwest::Client::new())
            .with(TracingMiddleware::default())
            .with(RetryMiddleware::from(m))
            .build()
    }
}

impl LlmSdk {
    /// Create an SDK that talks to the default OpenAI endpoint.
    pub fn new(token: impl Into<String>) -> Self {
        // Every other field has a default, so the build cannot fail.
        LlmSdkBuilder::default().token(token).build().unwrap()
    }

    /// Create an SDK pointed at a custom, OpenAI-compatible base URL.
    pub fn new_with_base_url(token: impl Into<String>, base_url: impl Into<String>) -> Self {
        LlmSdkBuilder::default()
            .token(token)
            .base_url(base_url)
            .build()
            .unwrap()
    }

    pub async fn chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.json::<ChatCompletionResponse>().await?)
    }

    pub async fn create_image(&self, req: CreateImageRequest) -> Result<CreateImageResponse> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.json::<CreateImageResponse>().await?)
    }

    pub async fn speech(&self, req: SpeechRequest) -> Result<Bytes> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.bytes().await?)
    }

    pub async fn whisper(&self, req: WhisperRequest) -> Result<WhisperResponse> {
        // Check the requested format before `req` is consumed below: only the
        // JSON format can be deserialized directly into `WhisperResponse`.
        let is_json = req.response_format == WhisperResponseFormat::Json;
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        let ret = if is_json {
            res.json::<WhisperResponse>().await?
        } else {
            let text = res.text().await?;
            WhisperResponse { text }
        };
        Ok(ret)
    }

    pub async fn embedding(&self, req: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.json().await?)
    }

    fn prepare_request(&self, req: impl IntoRequest) -> RequestBuilder {
        let req = req.into_request(&self.base_url, self.client.clone());
        // Skip bearer auth when no token was provided.
        let req = if self.token.is_empty() {
            req
        } else {
            req.bearer_auth(&self.token)
        };
        req.timeout(Duration::from_secs(TIMEOUT))
    }
}
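
// The per-endpoint methods above all share one shape: `prepare_request` +
// `send_and_log` + deserialize. A hypothetical new endpoint would only need an
// `IntoRequest` impl plus a thin wrapper like this (sketch only;
// `ModerationRequest` / `ModerationResponse` are not part of this crate):
//
//     pub async fn moderation(&self, req: ModerationRequest) -> Result<ModerationResponse> {
//         let req = self.prepare_request(req);
//         let res = req.send_and_log().await?;
//         Ok(res.json().await?)
//     }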

/// Send the request, then log and surface non-success responses as errors
/// instead of returning the failed `Response`.
trait SendAndLog {
    async fn send_and_log(self) -> Result<Response>;
}

impl SendAndLog for RequestBuilder {
    async fn send_and_log(self) -> Result<Response> {
        let res = self.send().await?;
        let status = res.status();
        if status.is_client_error() || status.is_server_error() {
            let text = res.text().await?;
            error!("API failed: {}", text);
            return Err(anyhow!("API failed: {}", text));
        }
        Ok(res)
    }
}

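// Blanket impl: any `JsonSchema` type can produce its schema as a
// `serde_json::Value` via `schema_for!`.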
impl<T: JsonSchema> ToSchema for T {
    fn to_schema() -> serde_json::Value {
        serde_json::to_value(schema_for!(Self)).unwrap()
    }
}
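
// Test-only setup: install a tracing subscriber before tests run and share a
// single SDK instance that reads `OPENAI_API_KEY` from the environment.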
#[cfg(test)]
#[ctor::ctor]
fn init() {
    tracing_subscriber::fmt::init();
}

#[cfg(test)]
lazy_static::lazy_static! {
    static ref SDK: LlmSdk = LlmSdk::new(std::env::var("OPENAI_API_KEY").unwrap());
}