openai_dive/v1/endpoints/
chat.rs1use crate::v1::error::APIError;
2#[cfg(feature = "stream")]
3use crate::v1::resources::chat::ChatCompletionChunkResponse;
4#[cfg(feature = "stream")]
5use crate::v1::resources::chat::DeltaChatMessage;
6use crate::v1::resources::chat::{ChatCompletionParameters, ChatCompletionResponse};
7use crate::v1::resources::shared::ResponseWrapper;
8use crate::v1::{api::Client, helpers::format_response};
9#[cfg(feature = "stream")]
10use futures::Stream;
11#[cfg(feature = "stream")]
12use std::pin::Pin;
13#[cfg(feature = "stream")]
14use std::task::{Context, Poll};
15
16pub struct Chat<'a> {
17 pub client: &'a Client,
18}
19
20impl Client {
21 pub fn chat(&self) -> Chat<'_> {
23 Chat { client: self }
24 }
25}
26
27impl Chat<'_> {
28 pub async fn create(
30 &self,
31 parameters: ChatCompletionParameters,
32 ) -> Result<ChatCompletionResponse, APIError> {
33 let wrapped_response = self.create_wrapped(parameters).await?;
34
35 Ok(wrapped_response.data)
36 }
37
38 pub async fn create_wrapped(
40 &self,
41 parameters: ChatCompletionParameters,
42 ) -> Result<ResponseWrapper<ChatCompletionResponse>, APIError> {
43 let response = self
44 .client
45 .post(
46 "/chat/completions",
47 &ChatCompletionParameters {
48 query_params: None,
49 ..parameters
50 },
51 parameters.query_params.as_ref(),
52 )
53 .await?;
54
55 let data: ChatCompletionResponse = format_response(response.data)?;
56
57 Ok(ResponseWrapper {
58 data,
59 headers: response.headers,
60 })
61 }
62
63 #[cfg(feature = "stream")]
64 pub async fn create_stream(
66 &self,
67 parameters: ChatCompletionParameters,
68 ) -> Result<
69 Pin<Box<dyn Stream<Item = Result<ChatCompletionChunkResponse, APIError>> + Send>>,
70 APIError,
71 > {
72 let mut stream_parameters = ChatCompletionParameters {
73 query_params: None,
74 ..parameters
75 };
76 stream_parameters.stream = Some(true);
77
78 Ok(self
79 .client
80 .post_stream(
81 "/chat/completions",
82 &stream_parameters,
83 stream_parameters.query_params.as_ref(),
84 )
85 .await)
86 }
87}
88
/// Role announced by the most recent explicitly tagged delta in a chat
/// completion stream.
#[cfg(feature = "stream")]
enum CurrentRole {
    /// The last tagged delta was `DeltaChatMessage::Assistant`.
    Assistant,
    /// The last tagged delta was `DeltaChatMessage::System`.
    System,
    /// The last tagged delta was `DeltaChatMessage::User`.
    User,
}
95
/// Stream adapter over chat completion chunks that remembers the role of the
/// last explicitly tagged delta.
///
/// It uses that role to rewrite later `Untagged` deltas into the corresponding
/// role-tagged variant.
#[cfg(feature = "stream")]
pub struct RoleTrackingStream<S> {
    /// Role from the most recent tagged delta; `None` until one has been seen.
    current_role: Option<CurrentRole>,
    /// The wrapped chunk stream.
    stream: S,
}
101
#[cfg(feature = "stream")]
impl<S> RoleTrackingStream<S> {
    /// Wraps `stream`. No role is assumed until the first tagged delta arrives.
    pub fn new(stream: S) -> Self {
        RoleTrackingStream {
            current_role: None,
            stream,
        }
    }
}
111
#[cfg(feature = "stream")]
impl<S> Stream for RoleTrackingStream<S>
where
    S: Stream<Item = Result<ChatCompletionChunkResponse, APIError>> + Unpin,
{
    type Item = Result<ChatCompletionChunkResponse, APIError>;

    /// Polls the inner stream and re-tags `Untagged` deltas with the most
    /// recently seen role.
    ///
    /// Role tracking and rewriting work as follows:
    ///
    /// - Whenever a choice carries an explicit `User`, `System`, or `Assistant`
    ///   delta, that role is remembered.
    /// - Later `Untagged` deltas are rewritten into the remembered variant, with
    ///   `name` set to the role string (`"user"`, `"system"` or `"assistant"`).
    /// - Before any role has been seen, `Untagged` deltas are left as they are.
    /// - `User` and `System` require content, so an `Untagged` delta without
    ///   content is left untagged in those roles rather than panicking.
    ///
    /// `Pending`, `Err` items and end-of-stream are forwarded unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        let mut chat_response = match Pin::new(&mut this.stream).poll_next(cx) {
            Poll::Ready(Some(Ok(chat_response))) => chat_response,
            other => return other,
        };

        for choice in chat_response.choices.iter_mut() {
            match &choice.delta {
                DeltaChatMessage::User { .. } => this.current_role = Some(CurrentRole::User),
                DeltaChatMessage::System { .. } => this.current_role = Some(CurrentRole::System),
                DeltaChatMessage::Assistant { .. } => {
                    this.current_role = Some(CurrentRole::Assistant)
                }
                _ => {}
            }

            // The replacement variant owns the untagged fields, so move them out
            // with `take()` instead of cloning; `choice.delta` is overwritten anyway.
            if let DeltaChatMessage::Untagged {
                content,
                reasoning,
                reasoning_content,
                refusal,
                name: _,
                tool_calls,
                tool_call_id: _,
            } = &mut choice.delta
            {
                match this.current_role {
                    Some(CurrentRole::User) => {
                        // Previously `unwrap()`ed, which panicked on content-less chunks.
                        if let Some(content) = content.take() {
                            choice.delta = DeltaChatMessage::User {
                                name: Some("user".to_string()),
                                content,
                            };
                        }
                    }
                    Some(CurrentRole::System) => {
                        if let Some(content) = content.take() {
                            choice.delta = DeltaChatMessage::System {
                                name: Some("system".to_string()),
                                content,
                            };
                        }
                    }
                    Some(CurrentRole::Assistant) => {
                        choice.delta = DeltaChatMessage::Assistant {
                            name: Some("assistant".to_string()),
                            content: content.take(),
                            reasoning: reasoning.take(),
                            reasoning_content: reasoning_content.take(),
                            refusal: refusal.take(),
                            tool_calls: tool_calls.take(),
                        };
                    }
                    None => {}
                }
            }
        }

        Poll::Ready(Some(Ok(chat_response)))
    }
}