#![allow(dead_code)]
use std::collections::HashMap;
use crate::{ChatDelta, AiAgent, Choice, FunctionCall, Message};
use futures_util::StreamExt;
use reqwest_eventsource::Event;
use crate::chat_completion_request::serialize;
use crate::{Chat, ChoiceDelta};
use reqwest_eventsource::EventSource;
use serde_derive::{Deserialize, Serialize};
use tokio::sync::mpsc::{Receiver, Sender};
use crate::error::ApiResult;
/// One chunk of a streamed chat completion.
///
/// Derives serde's `Serialize`/`Deserialize`, so the field names below are the
/// keys used on the wire. `DeltaReceiver::construct_chat` copies `id`,
/// `object`, `created` and `model` from the first chunk it has recorded into
/// the assembled `Chat`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionDelta {
/// Identifier of the completion this chunk belongs to.
pub id: String,
/// Object-type tag reported alongside the chunk.
pub object: String,
/// Creation time of the completion.
/// NOTE(review): presumably Unix seconds — confirm against the API schema.
pub created: i64,
/// Name of the model that produced the chunk.
pub model: String,
/// Per-choice fragments carried by this chunk; each carries its own `index`.
pub choices: Vec<ChoiceDelta>,
}
/// Consumer side of a streamed chat completion.
///
/// Pulls `ApiResult<ChatDelta>` items from a tokio mpsc channel and keeps a
/// history of every successfully received delta so that a full `Chat` can be
/// assembled afterwards.
pub struct DeltaReceiver<'a> {
/// Channel end from which streamed deltas (or API errors) are read.
pub receiver: Receiver<ApiResult<ChatDelta>>,
/// Borrowed agent associated with this stream. None of the methods in the
/// `impl` block below read it.
pub builder: &'a AiAgent,
/// Every delta received so far, in arrival order.
pub deltas: Vec<ChatCompletionDelta>,
}
impl<'a> DeltaReceiver<'a> {
pub fn from(
receiver: Receiver<ApiResult<ChatDelta>>,
builder: &'a AiAgent,
) -> Self {
Self {
receiver,
builder,
deltas: Vec::new(),
}
}
pub async fn receive(&mut self, choice_index: i64) -> anyhow::Result<Option<ChatCompletionDelta>> {
loop {
if let Some(delta) = self.receiver.recv().await {
let delta = delta?;
self.deltas.push(delta.clone());
for choice in &delta.choices {
if choice.index == choice_index {
continue;
}
return Ok(Some(delta));
}
} else {
return Ok(None);
}
}
}
pub async fn receive_content(&mut self, choice_index: i64) -> anyhow::Result<Option<String>> {
loop {
if let Some(delta) = self.receiver.recv().await {
let delta = delta?;
self.deltas.push(delta.clone());
for choice in &delta.choices {
if choice.index != choice_index {
continue;
}
if let Some(content) = &choice.delta.content {
return Ok(Some(content.clone()));
}
}
} else {
return Ok(None);
}
}
}
pub async fn receive_all(&mut self) -> anyhow::Result<Option<ChatCompletionDelta>> {
if let Some(delta) = self.receiver.recv().await {
let delta = delta?;
self.deltas.push(delta.clone());
Ok(Some(delta))
} else {
Ok(None)
}
}
pub async fn construct_chat(&mut self) -> anyhow::Result<Chat> {
while let Some(delta) = self.receive_all().await? {
if delta.choices[0].finish_reason.is_some() {
break;
}
}
let choice_list: Vec<ChoiceDelta> = self
.deltas
.iter()
.flat_map(|delta| delta.choices.clone())
.collect();
let mut choices_map: HashMap<i64, Vec<ChoiceDelta>> = Default::default();
choice_list.into_iter().for_each(|choice| {
choices_map.entry(choice.index).or_default().push(choice);
});
let choices: Vec<Choice> = choices_map
.iter()
.map(|(i, choices)| {
let index = *i;
let mut finish_reason: String = Default::default();
let mut role: Option<String> = None;
let mut content: Option<String> = None;
let mut function_call = false;
let mut function_call_name: Option<String> = None;
let mut arguments: Option<String> = None;
choices.iter().for_each(|choice| {
if let Some(reason) = &choice.finish_reason {
finish_reason = reason.clone();
}
if let Some(role_) = &choice.delta.role {
role = Some(role_.clone());
}
if let Some(c) = &choice.delta.content {
if let Some(content_) = &mut content {
content_.push_str(c);
} else {
content = Some(c.clone());
}
}
if let Some(call) = &choice.delta.function_call {
function_call = true;
if let Some(name) = &call.name {
function_call_name = Some(name.clone());
}
if let Some(args) = &call.arguments {
if let Some(args_) = &mut arguments {
args_.push_str(args);
} else {
arguments = Some(args.clone());
}
}
}
});
Choice {
index,
message: Message {
role: role.unwrap(),
content,
name: None,
function_call: match function_call {
true => Some(FunctionCall {
name: function_call_name.unwrap(),
arguments: arguments.unwrap(),
}),
false => None,
},
},
finish_reason,
}
})
.collect();
Ok(Chat {
id: self.deltas[0].id.clone(),
object: self.deltas[0].object.clone(),
created: self.deltas[0].created,
model: self.deltas[0].model.clone(),
choices,
usage: crate::Usage {
prompt_tokens: 0,
completion_tokens: self.deltas.len() as i64,
total_tokens: self.deltas.len() as i64,
},
})
}
}
/// Forwards server-sent events from `es` into the channel `tx`.
///
/// Every `Event::Message` payload is passed through `serialize` and the
/// resulting `ApiResult<ChatDelta>` is sent on `tx`; other event kinds are
/// ignored. Forwarding stops when the literal `[DONE]` payload is received or
/// when the event source is exhausted. `es` and `tx` are dropped on return,
/// which closes the channel for the receiving side once no other sender
/// clones remain.
///
/// # Errors
/// Returns an error if the event source yields an error or if the receiving
/// half of the channel has been closed.
pub async fn forward_stream(
    mut es: EventSource,
    tx: Sender<ApiResult<ChatDelta>>,
) -> anyhow::Result<()> {
    while let Some(event) = es.next().await {
        let event = match event {
            Ok(event) => event,
            // Report the failure to the caller instead of panicking the task.
            Err(err) => return Err(anyhow::anyhow!("event stream failed: {:#?}", err)),
        };
        if let Event::Message(message) = event {
            if message.data == "[DONE]" {
                break;
            }
            let chat = serialize(&message.data);
            tx.send(chat).await?;
        }
    }
    Ok(())
}