// model_gateway_rs/sdk/openai.rs

use async_trait::async_trait;
use toolcraft_request::{ByteStream, Request};

use crate::{
    error::Result,
    model::{
        llm::{LlmInput, LlmOutput},
        openai::{OpenAiChatRequest, OpenAiChatResponse},
    },
    sdk::ModelSDK,
};
12
/// Client for an OpenAI-compatible chat-completions API, built on
/// `toolcraft_request::Request`.
pub struct OpenAiSdk {
    // HTTP client pre-configured in `new` with the base URL and the
    // Content-Type / Authorization default headers.
    request: Request,
    // Model identifier sent in the body of every chat request.
    model: String,
}
18
19impl OpenAiSdk {
20 pub fn new(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
21 let mut request = Request::new()?;
22 request.set_base_url(base_url)?;
23 request.set_default_headers(vec![
24 ("Content-Type", "application/json".to_string()),
25 ("Authorization", format!("Bearer {api_key}")),
26 ])?;
27 Ok(Self {
28 request,
29 model: model.to_string(),
30 })
31 }
32}
33
34#[async_trait]
35impl ModelSDK for OpenAiSdk {
36 type Input = LlmInput;
37 type Output = LlmOutput;
38
39 async fn chat_once(&self, input: Self::Input) -> Result<Self::Output> {
41 let body = OpenAiChatRequest {
42 model: self.model.clone(),
43 messages: input.messages,
44 stream: None,
45 temperature: None,
46 };
47 let payload = serde_json::to_value(body)?;
48 let response = self
49 .request
50 .post("chat/completions", &payload, None)
51 .await?;
52 let json: OpenAiChatResponse = response.json().await?;
53 Ok(json.into())
54 }
55
56 async fn chat_stream(&self, input: Self::Input) -> Result<ByteStream> {
58 let body = OpenAiChatRequest {
59 model: self.model.clone(),
60 messages: input.messages,
61 stream: Some(true),
62 temperature: None,
63 };
64 let payload = serde_json::to_value(body)?;
65 let r = self
66 .request
67 .post_stream("chat/completions", &payload, None)
68 .await?;
69 Ok(r)
70 }
71}