openai_utils/chat_completion_delta.rs

1#![allow(dead_code)]
2
3use std::collections::HashMap;
4
5use crate::{AiAgent, calculate_message_tokens, ChatDelta, Choice, FunctionCall, Message, Usage};
6use futures_util::StreamExt;
7use log::trace;
8use reqwest_eventsource::Event;
9
10use crate::chat_completion_request::serialize;
11use crate::error::{InternalError, UtilsResult};
12use crate::{Chat, ChoiceDelta};
13use reqwest_eventsource::EventSource;
14use serde_derive::{Deserialize, Serialize};
15use tokio::sync::mpsc::{Receiver, Sender};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ChatCompletionDelta {
19    pub id: String,
20    pub object: String,
21    pub created: u64,
22    pub model: String,
23    pub choices: Vec<ChoiceDelta>,
24}
25
26pub struct DeltaReceiver<'a> {
27    pub receiver: Receiver<UtilsResult<ChatDelta>>,
28    pub builder: &'a AiAgent,
29    pub deltas: Vec<ChatCompletionDelta>,
30    usage: usize,
31}
32
33impl<'a> DeltaReceiver<'a> {
34    pub fn from(receiver: Receiver<UtilsResult<ChatDelta>>, builder: &'a AiAgent, usage: usize) -> Self {
35        Self {
36            receiver,
37            builder,
38            deltas: Vec::new(),
39            usage
40        }
41    }
42
43    pub async fn receive(
44        &mut self,
45        choice_index: i64,
46    ) -> anyhow::Result<Option<ChatCompletionDelta>> {
47        loop {
48            if let Some(delta) = self.receiver.recv().await {
49                let delta = delta?;
50                self.deltas.push(delta.clone());
51                for choice in &delta.choices {
52                    if choice.index == choice_index {
53                        continue;
54                    }
55                    return Ok(Some(delta));
56                }
57            } else {
58                return Ok(None);
59            }
60        }
61    }
62
63    pub async fn receive_content(&mut self, choice_index: i64) -> anyhow::Result<Option<String>> {
64        loop {
65            if let Some(delta) = self.receiver.recv().await {
66                let delta = delta?;
67                self.deltas.push(delta.clone());
68                for choice in &delta.choices {
69                    if choice.index != choice_index {
70                        continue;
71                    }
72                    if let Some(content) = &choice.delta.content {
73                        return Ok(Some(content.clone()));
74                    }
75                }
76            } else {
77                return Ok(None);
78            }
79        }
80    }
81
82    pub async fn receive_all(&mut self) -> anyhow::Result<Option<ChatCompletionDelta>> {
83        if let Some(delta) = self.receiver.recv().await {
84            let delta = delta?;
85            self.deltas.push(delta.clone());
86            Ok(Some(delta))
87        } else {
88            Ok(None)
89        }
90    }
91
92    pub async fn construct_chat(&mut self) -> anyhow::Result<Chat> {
93        // make sure you get the full response first
94        while let Some(delta) = self.receive_all().await? {
95            if delta.choices[0].finish_reason.is_some() {
96                break;
97            }
98        }
99
100        if self.deltas.len() == 0 {
101            Err(InternalError::NoDeltasReceived)?
102        }
103
104        let choice_list: Vec<ChoiceDelta> = self
105            .deltas
106            .iter()
107            .flat_map(|delta| delta.choices.clone())
108            .collect();
109
110        let mut choices_map: HashMap<i64, Vec<ChoiceDelta>> = Default::default();
111        choice_list.into_iter().for_each(|choice| {
112            choices_map.entry(choice.index).or_default().push(choice);
113        });
114
115        let choices: Vec<Choice> = choices_map
116            .iter()
117            .map(|(i, choices)| {
118                let index = *i;
119                let mut finish_reason: String = Default::default();
120                // message part
121                let mut role: Option<String> = None;
122                let mut content: Option<String> = None;
123                let mut function_call = false;
124                let mut function_call_name: Option<String> = None;
125                let mut arguments: Option<String> = None;
126
127                choices.iter().for_each(|choice| {
128                    if let Some(reason) = &choice.finish_reason {
129                        finish_reason = reason.clone();
130                    }
131
132                    if let Some(role_) = &choice.delta.role {
133                        role = Some(role_.clone());
134                    }
135
136                    if let Some(c) = &choice.delta.content {
137                        if let Some(content_) = &mut content {
138                            content_.push_str(c);
139                        } else {
140                            content = Some(c.clone());
141                        }
142                    }
143
144                    if let Some(call) = &choice.delta.function_call {
145                        function_call = true;
146                        if let Some(name) = &call.name {
147                            function_call_name = Some(name.clone());
148                        }
149
150                        if let Some(args) = &call.arguments {
151                            if let Some(args_) = &mut arguments {
152                                args_.push_str(args);
153                            } else {
154                                arguments = Some(args.clone());
155                            }
156                        }
157                    }
158                });
159
160                Choice {
161                    index,
162                    message: Message {
163                        // role should always be there, panic otherwise make this return an error later
164                        role: role.unwrap(),
165                        content,
166                        name: None,
167                        function_call: match function_call {
168                            true => Some(FunctionCall {
169                                name: function_call_name.unwrap(),
170                                arguments: arguments.unwrap(),
171                            }),
172                            false => None,
173                        },
174                    },
175                    finish_reason,
176                }
177            })
178            .collect();
179
180        let usage = Usage {
181            prompt_tokens: self.usage as u64,
182            completion_tokens: choices.iter().fold(0, |acc, c| acc + calculate_message_tokens(&c.message)) as u64,
183            total_tokens: choices.iter().fold(0, |acc, c| acc + calculate_message_tokens(&c.message)) as u64 + self.usage as u64,
184        };
185
186        let res = Ok(Chat {
187            id: self.deltas[0].id.clone(),
188            object: self.deltas[0].object.clone(),
189            created: self.deltas[0].created,
190            model: self.deltas[0].model.clone(),
191            //will be computed
192            choices,
193            // approximation
194            usage,
195        });
196
197        trace!("response: {res:#?}");
198
199        res
200    }
201}
202
203pub async fn forward_stream(
204    mut es: EventSource,
205    tx: Sender<UtilsResult<ChatDelta>>,
206) -> anyhow::Result<()> {
207    // Process each event from the EventSource
208    while let Some(event) = es.next().await {
209        // Handle errors in the event
210        let event = event?;
211
212        // Process Message events
213        if let Event::Message(message) = event {
214            // Break the loop if the message data is "[DONE]"
215            if message.data == "[DONE]" {
216                break;
217            }
218
219            // Serialize the message data and send it
220            let chat = serialize(&message.data);
221            tx.send(chat).await?;
222        }
223    }
224
225    Ok(())
226}