llm_sdk/lib.rs

mod api;
mod middleware;

pub use api::*;

use anyhow::{anyhow, Result};
use bytes::Bytes;
use derive_builder::Builder;
use middleware::RetryMiddleware;
use reqwest::Response;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, RequestBuilder};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use reqwest_tracing::TracingMiddleware;
use schemars::{schema_for, JsonSchema};
use std::time::Duration;
use tracing::error;

// Per-request timeout, in seconds.
const TIMEOUT: u64 = 60;
// Default number of retries for transient request failures.
const MAX_RETRIES: u32 = 3;

#[derive(Debug, Clone, Builder)]
pub struct LlmSdk {
    #[builder(setter(into), default = r#""https://api.openai.com/v1".into()"#)]
    pub(crate) base_url: String,
    #[builder(setter(into))]
    pub(crate) token: String,
    #[allow(dead_code)]
    #[builder(default = "MAX_RETRIES")]
    pub(crate) max_retries: u32,
    #[builder(setter(skip), default = "self.default_client()")]
    pub(crate) client: ClientWithMiddleware,
}

/// Convert a typed API request into a `RequestBuilder` aimed at `base_url`.
pub trait IntoRequest {
    fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder;
}

/// For tool functions. If you have a function that you want ChatGPT to call,
/// put all of its parameters into a struct and derive `schemars::JsonSchema`
/// for it. Then use `YourStruct::to_schema()` to generate the JSON schema for
/// tools.
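///
/// A minimal illustrative sketch (`GetWeatherArgs` is a hypothetical struct,
/// not part of this crate):
///
/// ```ignore
/// #[derive(schemars::JsonSchema)]
/// struct GetWeatherArgs {
///     /// The city to query, e.g. "Boston".
///     city: String,
/// }
///
/// let schema = GetWeatherArgs::to_schema();
/// assert_eq!(schema["type"], "object");
/// ```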
pub trait ToSchema: JsonSchema {
    fn to_schema() -> serde_json::Value;
}

impl LlmSdkBuilder {
    // Private helper method with access to the builder struct.
    fn default_client(&self) -> ClientWithMiddleware {
        let retry_policy = ExponentialBackoff::builder()
            .build_with_max_retries(self.max_retries.unwrap_or(MAX_RETRIES));
        let m = RetryTransientMiddleware::new_with_policy(retry_policy);
        ClientBuilder::new(reqwest::Client::new())
            // Trace HTTP requests. See the tracing crate to make use of these traces.
            .with(TracingMiddleware::default())
            // Retry failed requests.
            .with(RetryMiddleware::from(m))
            .build()
    }
}

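// A sketch of how a caller could tune retries through the generated builder;
// the `max_retries` setter comes from `derive_builder` and feeds the retry
// policy above (values here are illustrative):
//
//     let sdk = LlmSdkBuilder::default()
//         .token("sk-...")
//         .max_retries(5)
//         .build()?;
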
impl LlmSdk {
    pub fn new(token: impl Into<String>) -> Self {
        // All required builder fields are set, so build() cannot fail here.
        LlmSdkBuilder::default().token(token).build().unwrap()
    }

    pub fn new_with_base_url(token: impl Into<String>, base_url: impl Into<String>) -> Self {
        LlmSdkBuilder::default()
            .token(token)
            .base_url(base_url)
            .build()
            .unwrap()
    }

    pub async fn chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.json::<ChatCompletionResponse>().await?)
    }

    pub async fn create_image(&self, req: CreateImageRequest) -> Result<CreateImageResponse> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.json::<CreateImageResponse>().await?)
    }

    pub async fn speech(&self, req: SpeechRequest) -> Result<Bytes> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.bytes().await?)
    }

    pub async fn whisper(&self, req: WhisperRequest) -> Result<WhisperResponse> {
        // Only the JSON format can be deserialized; other formats come back as
        // plain text, so wrap them in a WhisperResponse manually.
        let is_json = req.response_format == WhisperResponseFormat::Json;
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        let ret = if is_json {
            res.json::<WhisperResponse>().await?
        } else {
            let text = res.text().await?;
            WhisperResponse { text }
        };
        Ok(ret)
    }

    pub async fn embedding(&self, req: EmbeddingRequest) -> Result<EmbeddingResponse> {
        let req = self.prepare_request(req);
        let res = req.send_and_log().await?;
        Ok(res.json::<EmbeddingResponse>().await?)
    }

    fn prepare_request(&self, req: impl IntoRequest) -> RequestBuilder {
        let req = req.into_request(&self.base_url, self.client.clone());
        // Skip bearer auth when no token is configured (e.g. a local endpoint).
        let req = if self.token.is_empty() {
            req
        } else {
            req.bearer_auth(&self.token)
        };
        req.timeout(Duration::from_secs(TIMEOUT))
    }
}

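// Typical call shape, given some `req: ChatCompletionRequest` built from the
// `api` module (request construction omitted; see `api` for the real types):
//
//     let sdk = LlmSdk::new(std::env::var("OPENAI_API_KEY")?);
//     let resp = sdk.chat_completion(req).await?;
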
trait SendAndLog {
    async fn send_and_log(self) -> Result<Response>;
}

impl SendAndLog for RequestBuilder {
    async fn send_and_log(self) -> Result<Response> {
        let res = self.send().await?;
        let status = res.status();
        if status.is_client_error() || status.is_server_error() {
            let text = res.text().await?;
            error!("API failed with status {}: {}", status, text);
            return Err(anyhow!("API failed with status {}: {}", status, text));
        }
        Ok(res)
    }
}

impl<T: JsonSchema> ToSchema for T {
    fn to_schema() -> serde_json::Value {
        // A derived schema always serializes to JSON, so this unwrap is safe.
        serde_json::to_value(schema_for!(Self)).unwrap()
    }
}

#[cfg(test)]
#[ctor::ctor]
fn init() {
    tracing_subscriber::fmt::init();
}

#[cfg(test)]
lazy_static::lazy_static! {
    static ref SDK: LlmSdk = LlmSdk::new(std::env::var("OPENAI_API_KEY").unwrap());
}
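
// A small test sketch exercising the blanket `ToSchema` impl above;
// `GetWeatherArgs` is a hypothetical tool-argument struct used only here.
#[cfg(test)]
mod schema_tests {
    use super::*;

    #[derive(JsonSchema)]
    struct GetWeatherArgs {
        /// City name, e.g. "Boston"
        #[allow(dead_code)]
        city: String,
    }

    #[test]
    fn to_schema_generates_object_schema() {
        let schema = GetWeatherArgs::to_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["city"].is_object());
    }
}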