model_gateway_rs/sdk/
openai.rs

1use crate::error::Result;
2use serde::{Deserialize, Serialize};
3use service_utils_rs::utils::{ByteStream, Request};
4
5/// Role in chat messages.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum Role {
9    System,
10    User,
11    Assistant,
12}
13
14/// Single chat message.
15#[derive(Debug, Clone, Serialize)]
16pub struct ChatMessage {
17    pub role: Role,
18    pub content: String,
19}
20
21impl ChatMessage {
22    pub fn user(content: &str) -> Self {
23        Self {
24            role: Role::User,
25            content: content.to_string(),
26        }
27    }
28}
29
30/// Request body for chat completion.
31
32#[derive(Debug, Deserialize)]
33pub struct ChatChoice {
34    pub index: u32,
35    pub message: ChatMessageResponse,
36    pub finish_reason: Option<String>,
37}
38
39#[derive(Debug, Deserialize)]
40pub struct ChatMessageResponse {
41    pub role: Role,
42    pub content: String,
43}
44
45#[derive(Debug, Deserialize)]
46pub struct ChatUsage {
47    pub prompt_tokens: u32,
48    pub completion_tokens: u32,
49    pub total_tokens: u32,
50}
51
52#[derive(Debug, Clone, Serialize)]
53pub struct ChatRequest {
54    pub model: String,
55    pub messages: Vec<ChatMessage>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub stream: Option<bool>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub temperature: Option<f32>,
60}
61
62#[derive(Debug, Deserialize)]
63pub struct ChatResponse {
64    pub id: String,
65    pub object: String,
66    pub created: u64,
67    pub model: String,
68    pub choices: Vec<ChatChoice>,
69    pub usage: Option<ChatUsage>,
70}
71
72impl ChatResponse {
73    /// Get the first choice's message content.
74    pub fn first_message(&self) -> Option<String> {
75        self.choices
76            .first()
77            .map(|choice| choice.message.content.clone())
78    }
79}
80
81/// ChatCompletion client using your wrapped Request.
82pub struct OpenAIClient {
83    request: Request,
84    model: String,
85}
86
87impl OpenAIClient {
88    pub fn new(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
89        let mut request = Request::new();
90        request.set_base_url(base_url)?;
91        request.set_default_headers(vec![
92            ("Content-Type", "application/json".to_string()),
93            ("Authorization", format!("Bearer {}", api_key)),
94        ])?;
95        Ok(Self {
96            request,
97            model: model.to_string(),
98        })
99    }
100
101    /// Send a chat request and get full response.
102    pub async fn chat_once(&self, messages: Vec<ChatMessage>) -> Result<ChatResponse> {
103        let body = ChatRequest {
104            model: self.model.clone(),
105            messages,
106            stream: None,
107            temperature: None,
108        };
109        let payload = serde_json::to_value(body)?;
110        let response = self
111            .request
112            .post("chat/completions", &payload, None)
113            .await?;
114        let json: ChatResponse = response.json().await?;
115        Ok(json)
116    }
117
118    /// Send a chat request and get response stream (SSE).
119    pub async fn chat_stream(&self, messages: Vec<ChatMessage>) -> Result<ByteStream> {
120        let body = ChatRequest {
121            model: self.model.clone(),
122            messages,
123            stream: Some(true),
124            temperature: None,
125        };
126        let payload = serde_json::to_value(body)?;
127        let r = self
128            .request
129            .post_stream("chat/completions", &payload, None)
130            .await?;
131        Ok(r)
132    }
133}