model_gateway_rs/sdk/openai.rs

use async_trait::async_trait;
use service_utils_rs::utils::{ByteStream, Request};

use crate::{
    error::Result,
    model::llm::{ChatMessages, ChatRequest, ChatResponse},
    sdk::ModelSDK,
};

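/// Chat client for an OpenAI-compatible `chat/completions` API, backed by the
/// `Request` helper from `service_utils_rs`, with a fixed model name per client.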
pub struct OpenAIClient {
    request: Request,
    model: String,
}

impl OpenAIClient {
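    /// Creates a client with a 60-second request timeout, the given base URL,
    /// and default headers for JSON bodies and Bearer authentication.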
    pub fn new(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
        let mut request = Request::with_timeout(60)?;
        request.set_base_url(base_url)?;
        request.set_default_headers(vec![
            ("Content-Type", "application/json".to_string()),
            ("Authorization", format!("Bearer {}", api_key)),
        ])?;
        Ok(Self {
            request,
            model: model.to_string(),
        })
    }
}

#[async_trait]
impl ModelSDK for OpenAIClient {
    type Input = ChatMessages;
    type Output = ChatResponse;

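    /// Sends a non-streaming chat completion request and deserializes the full
    /// JSON response body into a `ChatResponse`.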
    async fn chat_once(&self, messages: Self::Input) -> Result<Self::Output> {
        let body = ChatRequest {
            model: self.model.clone(),
            messages: messages.0,
            stream: None,
            temperature: None,
        };
        let payload = serde_json::to_value(body)?;
        let response = self
            .request
            .post("chat/completions", &payload, None)
            .await?;
        let json: ChatResponse = response.json().await?;
        Ok(json)
    }

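    /// Sends the same request with `stream` set to `true` and returns the raw
    /// response byte stream for the caller to consume incrementally.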
    async fn chat_stream(&self, messages: Self::Input) -> Result<ByteStream> {
        let body = ChatRequest {
            model: self.model.clone(),
            messages: messages.0,
            stream: Some(true),
            temperature: None,
        };
        let payload = serde_json::to_value(body)?;
        let stream = self
            .request
            .post_stream("chat/completions", &payload, None)
            .await?;
        Ok(stream)
    }
}
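
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of the original module): how a caller
// might construct the client and run a one-shot completion. The environment
// variable, base URL, and model name below are assumptions made for the
// example; building a `ChatMessages` value is left to the caller, since its
// constructor lives in `model::llm`.
// ---------------------------------------------------------------------------
#[allow(dead_code)]
async fn example_chat(messages: ChatMessages) -> Result<ChatResponse> {
    // Read the key from the environment (variable name is an assumption).
    let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
    // Base URL and model name are placeholders for whatever the deployment uses.
    let client = OpenAIClient::new(&api_key, "https://api.openai.com/v1/", "gpt-4o-mini")?;
    // `chat_once` comes from the `ModelSDK` trait implemented above.
    client.chat_once(messages).await
}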