// openai_gpt_client/chat.rs
use std::collections::{HashMap, VecDeque};

use log::{debug, info};
use serde::{Deserialize, Serialize};

use crate::{client::Stop, model_variants::ModelId};
/// Request payload for a chat-completion call.
///
/// `skip_serializing_none` drops every `None` field from the serialized
/// JSON, so optional parameters are simply absent rather than `null`.
/// Field names presumably mirror the OpenAI chat-completions API
/// parameters of the same name — confirm against the client's endpoint.
#[serde_with::skip_serializing_none]
#[derive(Debug, Default, Serialize)]
pub struct ChatRequest {
    /// Model to run the completion against.
    pub model: ModelId,
    /// Conversation messages, in order.
    pub messages: Vec<ChatMessage>,
    /// Sampling temperature.
    pub temperature: Option<f64>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f64>,
    /// Number of completions to generate.
    pub n: Option<i32>,
    /// Whether to stream the response.
    pub stream: Option<bool>,
    /// Stop sequence(s) — see `crate::client::Stop`.
    pub stop: Option<Stop>,
    /// Maximum tokens to generate in the completion.
    pub max_tokens: Option<i32>,
    pub presence_penalty: Option<f64>,
    pub frequency_penalty: Option<f64>,
    /// Per-token bias map; keys are token ids as strings.
    pub logit_bias: Option<HashMap<String, f64>>,
    /// End-user identifier forwarded to the API.
    pub user: Option<String>,
}
24
/// Deserialized response body of a chat-completion call.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    /// Response identifier, when the server provides one.
    pub id: Option<String>,
    /// Object type tag reported by the server.
    pub object: Option<String>,
    /// Creation timestamp (presumably Unix seconds — verify against the
    /// API; NOTE(review): `i32` overflows in 2038, consider `i64`).
    pub created: Option<i32>,
    /// One entry per requested completion (`n` in the request).
    pub choices: Vec<ChatCompletionChoice>,
    /// Token accounting, when returned.
    pub usage: Option<ChatCompletionUsage>,
}
/// Token counts reported by the API for one request/response pair.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionUsage {
    /// Tokens consumed by the input messages.
    pub prompt_tokens: i32,
    /// Tokens generated in the completion.
    pub completion_tokens: i32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: i32,
}
39
/// A single generated completion within a [`ChatResponse`].
#[derive(Debug, Deserialize)]
pub struct ChatCompletionChoice {
    /// Position of this choice among the `n` requested completions.
    pub index: i32,
    /// The generated assistant message.
    pub message: ChatMessage,
    /// Why generation stopped (e.g. server-reported reason), if given.
    pub finish_reason: Option<String>,
}
/// Author role of a chat message.
///
/// Serialized in lowercase (`"system"` / `"user"` / `"assistant"`)
/// via `rename_all`.
#[derive(Debug, Deserialize, Serialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
53
/// One message in a conversation: who said it and what was said.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatMessage {
    /// Author of the message.
    pub role: Role,
    /// Message text.
    pub content: String,
}
59
/// Rolling conversation history with an optional fixed system prompt and
/// an optional running summary of older, drained messages.
#[derive(Debug)]
pub struct ChatHistory {
    // FIFO of conversation messages; oldest entries are drained first
    // when summarization kicks in.
    queue: VecDeque<ChatMessage>,
    // Optional system prompt, always emitted first by `get_history`.
    prompt_message: Option<ChatMessage>,
    // Optional system message holding the summary of drained history.
    summary_message: Option<ChatMessage>,
}
66
67impl ChatHistory {
68 pub fn new(prompt_message: Option<String>) -> Self {
69 ChatHistory {
70 queue: Default::default(),
71 prompt_message: prompt_message.map(|prompt| ChatMessage {
72 role: Role::System,
73 content: prompt,
74 }),
75 summary_message: Default::default(),
76 }
77 }
78
79 pub fn add_initial_messages(&mut self, messages: Vec<ChatMessage>) {
80 self.queue.extend(messages)
81 }
82
83 pub fn summary_needed(
86 &mut self,
87 summary_prompt: ChatMessage,
88 max_length: usize,
89 summarize_length: usize,
90 ) -> Option<Vec<ChatMessage>> {
91 let total_length: usize = self.queue.iter().map(|message| message.content.len()).sum();
92 if total_length < max_length - summarize_length {
93 return None;
94 }
95 let mut selected_messages = vec![];
96 if let Some(prompt_message) = &self.prompt_message {
97 selected_messages.push(prompt_message.clone());
98 }
99 selected_messages.extend(self.summary_message.iter().cloned());
100 let mut current_length = 0;
101 while current_length < summarize_length {
102 if let Some(message) = self.queue.pop_front() {
103 current_length += message.content.len();
104 selected_messages.push(message);
105 } else {
106 break;
107 }
108 }
109 selected_messages.push(summary_prompt);
110 Some(selected_messages)
111 }
112
113 pub fn add_message(&mut self, message: ChatMessage) {
114 self.queue.push_back(message);
115 }
116
117 pub fn clear_history(&mut self) {
118 self.queue.clear();
119 }
120
121 pub fn get_history(&self) -> Vec<ChatMessage> {
122 let mut messages = vec![];
123 if let Some(prompt_message) = &self.prompt_message {
124 messages.push(prompt_message.clone());
125 }
126 messages.extend(self.summary_message.iter().cloned());
127 messages.extend(self.queue.iter().cloned());
128 debug!(
129 "Length of history: {}",
130 messages
131 .iter()
132 .map(|message| message.content.len())
133 .sum::<usize>()
134 );
135 messages
136 }
137
138 pub fn set_summary(&mut self, summary_message: String) {
139 info!("Set summary message:\n {}", summary_message);
140 self.summary_message = Some(ChatMessage {
141 role: Role::System,
142 content: summary_message,
143 });
144 }
145}