openai_utils/
chat_completion_delta.rs1#![allow(dead_code)]
2
3use std::collections::HashMap;
4
5use crate::{AiAgent, calculate_message_tokens, ChatDelta, Choice, FunctionCall, Message, Usage};
6use futures_util::StreamExt;
7use log::trace;
8use reqwest_eventsource::Event;
9
10use crate::chat_completion_request::serialize;
11use crate::error::{InternalError, UtilsResult};
12use crate::{Chat, ChoiceDelta};
13use reqwest_eventsource::EventSource;
14use serde_derive::{Deserialize, Serialize};
15use tokio::sync::mpsc::{Receiver, Sender};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ChatCompletionDelta {
19 pub id: String,
20 pub object: String,
21 pub created: u64,
22 pub model: String,
23 pub choices: Vec<ChoiceDelta>,
24}
25
26pub struct DeltaReceiver<'a> {
27 pub receiver: Receiver<UtilsResult<ChatDelta>>,
28 pub builder: &'a AiAgent,
29 pub deltas: Vec<ChatCompletionDelta>,
30 usage: usize,
31}
32
33impl<'a> DeltaReceiver<'a> {
34 pub fn from(receiver: Receiver<UtilsResult<ChatDelta>>, builder: &'a AiAgent, usage: usize) -> Self {
35 Self {
36 receiver,
37 builder,
38 deltas: Vec::new(),
39 usage
40 }
41 }
42
43 pub async fn receive(
44 &mut self,
45 choice_index: i64,
46 ) -> anyhow::Result<Option<ChatCompletionDelta>> {
47 loop {
48 if let Some(delta) = self.receiver.recv().await {
49 let delta = delta?;
50 self.deltas.push(delta.clone());
51 for choice in &delta.choices {
52 if choice.index == choice_index {
53 continue;
54 }
55 return Ok(Some(delta));
56 }
57 } else {
58 return Ok(None);
59 }
60 }
61 }
62
63 pub async fn receive_content(&mut self, choice_index: i64) -> anyhow::Result<Option<String>> {
64 loop {
65 if let Some(delta) = self.receiver.recv().await {
66 let delta = delta?;
67 self.deltas.push(delta.clone());
68 for choice in &delta.choices {
69 if choice.index != choice_index {
70 continue;
71 }
72 if let Some(content) = &choice.delta.content {
73 return Ok(Some(content.clone()));
74 }
75 }
76 } else {
77 return Ok(None);
78 }
79 }
80 }
81
82 pub async fn receive_all(&mut self) -> anyhow::Result<Option<ChatCompletionDelta>> {
83 if let Some(delta) = self.receiver.recv().await {
84 let delta = delta?;
85 self.deltas.push(delta.clone());
86 Ok(Some(delta))
87 } else {
88 Ok(None)
89 }
90 }
91
92 pub async fn construct_chat(&mut self) -> anyhow::Result<Chat> {
93 while let Some(delta) = self.receive_all().await? {
95 if delta.choices[0].finish_reason.is_some() {
96 break;
97 }
98 }
99
100 if self.deltas.len() == 0 {
101 Err(InternalError::NoDeltasReceived)?
102 }
103
104 let choice_list: Vec<ChoiceDelta> = self
105 .deltas
106 .iter()
107 .flat_map(|delta| delta.choices.clone())
108 .collect();
109
110 let mut choices_map: HashMap<i64, Vec<ChoiceDelta>> = Default::default();
111 choice_list.into_iter().for_each(|choice| {
112 choices_map.entry(choice.index).or_default().push(choice);
113 });
114
115 let choices: Vec<Choice> = choices_map
116 .iter()
117 .map(|(i, choices)| {
118 let index = *i;
119 let mut finish_reason: String = Default::default();
120 let mut role: Option<String> = None;
122 let mut content: Option<String> = None;
123 let mut function_call = false;
124 let mut function_call_name: Option<String> = None;
125 let mut arguments: Option<String> = None;
126
127 choices.iter().for_each(|choice| {
128 if let Some(reason) = &choice.finish_reason {
129 finish_reason = reason.clone();
130 }
131
132 if let Some(role_) = &choice.delta.role {
133 role = Some(role_.clone());
134 }
135
136 if let Some(c) = &choice.delta.content {
137 if let Some(content_) = &mut content {
138 content_.push_str(c);
139 } else {
140 content = Some(c.clone());
141 }
142 }
143
144 if let Some(call) = &choice.delta.function_call {
145 function_call = true;
146 if let Some(name) = &call.name {
147 function_call_name = Some(name.clone());
148 }
149
150 if let Some(args) = &call.arguments {
151 if let Some(args_) = &mut arguments {
152 args_.push_str(args);
153 } else {
154 arguments = Some(args.clone());
155 }
156 }
157 }
158 });
159
160 Choice {
161 index,
162 message: Message {
163 role: role.unwrap(),
165 content,
166 name: None,
167 function_call: match function_call {
168 true => Some(FunctionCall {
169 name: function_call_name.unwrap(),
170 arguments: arguments.unwrap(),
171 }),
172 false => None,
173 },
174 },
175 finish_reason,
176 }
177 })
178 .collect();
179
180 let usage = Usage {
181 prompt_tokens: self.usage as u64,
182 completion_tokens: choices.iter().fold(0, |acc, c| acc + calculate_message_tokens(&c.message)) as u64,
183 total_tokens: choices.iter().fold(0, |acc, c| acc + calculate_message_tokens(&c.message)) as u64 + self.usage as u64,
184 };
185
186 let res = Ok(Chat {
187 id: self.deltas[0].id.clone(),
188 object: self.deltas[0].object.clone(),
189 created: self.deltas[0].created,
190 model: self.deltas[0].model.clone(),
191 choices,
193 usage,
195 });
196
197 trace!("response: {res:#?}");
198
199 res
200 }
201}
202
203pub async fn forward_stream(
204 mut es: EventSource,
205 tx: Sender<UtilsResult<ChatDelta>>,
206) -> anyhow::Result<()> {
207 while let Some(event) = es.next().await {
209 let event = event?;
211
212 if let Event::Message(message) = event {
214 if message.data == "[DONE]" {
216 break;
217 }
218
219 let chat = serialize(&message.data);
221 tx.send(chat).await?;
222 }
223 }
224
225 Ok(())
226}