// model_gateway_rs/sdk/openai.rs
use crate::error::Result;
use serde::{Deserialize, Serialize};
use service_utils_rs::utils::{ByteStream, Request};

5#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9 System,
10 User,
11 Assistant,
12}
13
14#[derive(Debug, Clone, Serialize)]
16pub struct ChatMessage {
17 pub role: Role,
18 pub content: String,
19}
20
21impl ChatMessage {
22 pub fn user(content: &str) -> Self {
23 Self {
24 role: Role::User,
25 content: content.to_string(),
26 }
27 }
28}
29
30#[derive(Debug, Deserialize)]
33pub struct ChatChoice {
34 pub index: u32,
35 pub message: ChatMessageResponse,
36 pub finish_reason: Option<String>,
37}
38
39#[derive(Debug, Deserialize)]
40pub struct ChatMessageResponse {
41 pub role: Role,
42 pub content: String,
43}
44
45#[derive(Debug, Deserialize)]
46pub struct ChatUsage {
47 pub prompt_tokens: u32,
48 pub completion_tokens: u32,
49 pub total_tokens: u32,
50}
51
52#[derive(Debug, Clone, Serialize)]
53pub struct ChatRequest {
54 pub model: String,
55 pub messages: Vec<ChatMessage>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub stream: Option<bool>,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 pub temperature: Option<f32>,
60}
61
62#[derive(Debug, Deserialize)]
63pub struct ChatResponse {
64 pub id: String,
65 pub object: String,
66 pub created: u64,
67 pub model: String,
68 pub choices: Vec<ChatChoice>,
69 pub usage: Option<ChatUsage>,
70}
71
72impl ChatResponse {
73 pub fn first_message(&self) -> Option<String> {
75 self.choices
76 .first()
77 .map(|choice| choice.message.content.clone())
78 }
79}
80
81pub struct OpenAIClient {
83 request: Request,
84 model: String,
85}
86
87impl OpenAIClient {
88 pub fn new(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
89 let mut request = Request::new();
90 request.set_base_url(base_url)?;
91 request.set_default_headers(vec![
92 ("Content-Type", "application/json".to_string()),
93 ("Authorization", format!("Bearer {}", api_key)),
94 ])?;
95 Ok(Self {
96 request,
97 model: model.to_string(),
98 })
99 }
100
101 pub async fn chat_once(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse> {
103 let body = ChatRequest {
104 model: self.model.clone(),
105 messages,
106 stream: None,
107 temperature: None,
108 };
109 let payload = serde_json::to_value(body)?;
110 let response = self
111 .request
112 .post("chat/completions", &payload, None)
113 .await?;
114 let json: ChatResponse = response.json().await?;
115 Ok(json)
116 }
117
118 pub async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<ByteStream> {
120 let body = ChatRequest {
121 model: self.model.clone(),
122 messages,
123 stream: Some(true),
124 temperature: None,
125 };
126 let payload = serde_json::to_value(body)?;
127 let r = self
128 .request
129 .post_stream("chat/completions", &payload, None)
130 .await?;
131 Ok(r)
132 }
133}