Skip to main content

openai_dive/v1/endpoints/
chat.rs

1use crate::v1::error::APIError;
2#[cfg(feature = "stream")]
3use crate::v1::resources::chat::ChatCompletionChunkResponse;
4#[cfg(feature = "stream")]
5use crate::v1::resources::chat::DeltaChatMessage;
6use crate::v1::resources::chat::{ChatCompletionParameters, ChatCompletionResponse};
7use crate::v1::resources::shared::ResponseWrapper;
8use crate::v1::{api::Client, helpers::format_response};
9#[cfg(feature = "stream")]
10use futures::Stream;
11#[cfg(feature = "stream")]
12use std::pin::Pin;
13#[cfg(feature = "stream")]
14use std::task::{Context, Poll};
15
16pub struct Chat<'a> {
17    pub client: &'a Client,
18}
19
20impl Client {
21    /// Given a list of messages comprising a conversation, the model will return a response.
22    pub fn chat(&self) -> Chat<'_> {
23        Chat { client: self }
24    }
25}
26
27impl Chat<'_> {
28    /// Creates a model response for the given chat conversation.
29    pub async fn create(
30        &self,
31        parameters: ChatCompletionParameters,
32    ) -> Result<ChatCompletionResponse, APIError> {
33        let wrapped_response = self.create_wrapped(parameters).await?;
34
35        Ok(wrapped_response.data)
36    }
37
38    /// Creates a model response for the given chat conversation.
39    pub async fn create_wrapped(
40        &self,
41        parameters: ChatCompletionParameters,
42    ) -> Result<ResponseWrapper<ChatCompletionResponse>, APIError> {
43        let response = self
44            .client
45            .post(
46                "/chat/completions",
47                &ChatCompletionParameters {
48                    query_params: None,
49                    ..parameters
50                },
51                parameters.query_params.as_ref(),
52            )
53            .await?;
54
55        let data: ChatCompletionResponse = format_response(response.data)?;
56
57        Ok(ResponseWrapper {
58            data,
59            headers: response.headers,
60        })
61    }
62
63    #[cfg(feature = "stream")]
64    /// Creates a model response for the given chat conversation.
65    pub async fn create_stream(
66        &self,
67        parameters: ChatCompletionParameters,
68    ) -> Result<
69        Pin<Box<dyn Stream<Item = Result<ChatCompletionChunkResponse, APIError>> + Send>>,
70        APIError,
71    > {
72        let mut stream_parameters = ChatCompletionParameters {
73            query_params: None,
74            ..parameters
75        };
76        stream_parameters.stream = Some(true);
77
78        Ok(self
79            .client
80            .post_stream(
81                "/chat/completions",
82                &stream_parameters,
83                stream_parameters.query_params.as_ref(),
84            )
85            .await)
86    }
87}
88
/// Role of the most recent explicitly role-tagged delta seen on a chat completion stream.
///
/// `RoleTrackingStream` uses it to attribute later untagged deltas to a role.
#[cfg(feature = "stream")]
enum CurrentRole {
    /// The last tagged delta was `DeltaChatMessage::User`.
    User,
    /// The last tagged delta was `DeltaChatMessage::System`.
    System,
    /// The last tagged delta was `DeltaChatMessage::Assistant`.
    Assistant,
}
95
/// Stream adapter over chat completion chunks that remembers the role of the last
/// role-tagged delta and uses it to rewrite subsequent `DeltaChatMessage::Untagged` deltas.
#[cfg(feature = "stream")]
pub struct RoleTrackingStream<S> {
    /// Inner stream of chat completion chunks.
    stream: S,
    /// Role taken from the most recent tagged delta; `None` until one has been observed.
    current_role: Option<CurrentRole>,
}
101
#[cfg(feature = "stream")]
impl<S> RoleTrackingStream<S> {
    /// Wraps `stream` in a role-tracking adapter that starts with no known role.
    pub fn new(stream: S) -> Self {
        let current_role = None;
        Self {
            stream,
            current_role,
        }
    }
}
111
#[cfg(feature = "stream")]
impl<S> Stream for RoleTrackingStream<S>
where
    S: Stream<Item = Result<ChatCompletionChunkResponse, APIError>> + Unpin,
{
    type Item = Result<ChatCompletionChunkResponse, APIError>;

    /// Polls the inner stream and rewrites untagged deltas using the last role seen.
    ///
    /// For every successfully received chunk, each choice's delta is inspected:
    /// - A User/System/Assistant-tagged delta updates the tracked role.
    /// - A `DeltaChatMessage::Untagged` delta is converted into the tracked role's variant.
    ///   User/System deltas are only converted when they carry `content`, because those variants
    ///   require it. Otherwise the delta is left untagged instead of panicking.
    ///
    /// `Pending`, errors and end-of-stream are forwarded unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        match Pin::new(&mut this.stream).poll_next(cx) {
            Poll::Ready(Some(Ok(mut chat_response))) => {
                for choice in chat_response.choices.iter_mut() {
                    // Remember the role of any explicitly tagged delta so later untagged deltas
                    // on this stream can be attributed to it.
                    let tagged_role = match &choice.delta {
                        DeltaChatMessage::User { .. } => Some(CurrentRole::User),
                        DeltaChatMessage::System { .. } => Some(CurrentRole::System),
                        DeltaChatMessage::Assistant { .. } => Some(CurrentRole::Assistant),
                        _ => None,
                    };
                    if let Some(role) = tagged_role {
                        this.current_role = Some(role);
                    }

                    if let DeltaChatMessage::Untagged {
                        content,
                        reasoning,
                        reasoning_content,
                        refusal,
                        tool_calls,
                        ..
                    } = &choice.delta
                    {
                        let retagged = match this.current_role {
                            // User/System variants need non-optional content. A content-less
                            // untagged chunk (remote data) must not crash the stream, so it
                            // is left as-is.
                            Some(CurrentRole::User) => {
                                content.clone().map(|content| DeltaChatMessage::User {
                                    name: Some("user".to_string()),
                                    content,
                                })
                            }
                            Some(CurrentRole::System) => {
                                content.clone().map(|content| DeltaChatMessage::System {
                                    name: Some("system".to_string()),
                                    content,
                                })
                            }
                            Some(CurrentRole::Assistant) => Some(DeltaChatMessage::Assistant {
                                name: Some("assistant".to_string()),
                                content: content.clone(),
                                reasoning: reasoning.clone(),
                                reasoning_content: reasoning_content.clone(),
                                refusal: refusal.clone(),
                                tool_calls: tool_calls.clone(),
                            }),
                            // No tagged delta seen yet: nothing to attribute the chunk to.
                            None => None,
                        };

                        if let Some(delta) = retagged {
                            choice.delta = delta;
                        }
                    }
                }

                Poll::Ready(Some(Ok(chat_response)))
            }
            other => other,
        }
    }
}