// openai_gpt_client/chat.rs

use std::collections::{HashMap, VecDeque};

use log::{debug, info};
use serde::{Deserialize, Deserializer, Serialize};

use crate::{client::Stop, model_variants::ModelId};

8#[serde_with::skip_serializing_none]
9#[derive(Debug, Default, Serialize)]
10pub struct ChatRequest {
11    pub model: ModelId,
12    pub messages: Vec<ChatMessage>,
13    pub temperature: Option<f64>,
14    pub top_p: Option<f64>,
15    pub n: Option<i32>,
16    pub stream: Option<bool>,
17    pub stop: Option<Stop>,
18    pub max_tokens: Option<i32>,
19    pub presence_penalty: Option<f64>,
20    pub frequency_penalty: Option<f64>,
21    pub logit_bias: Option<HashMap<String, f64>>,
22    pub user: Option<String>,
23}
24
25#[derive(Debug, Deserialize)]
26pub struct ChatResponse {
27    pub id: Option<String>,
28    pub object: Option<String>,
29    pub created: Option<i32>,
30    pub choices: Vec<ChatCompletionChoice>,
31    pub usage: Option<ChatCompletionUsage>,
32}
33#[derive(Debug, Deserialize)]
34pub struct ChatCompletionUsage {
35    pub prompt_tokens: i32,
36    pub completion_tokens: i32,
37    pub total_tokens: i32,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct ChatCompletionChoice {
42    pub index: i32,
43    pub message: ChatMessage,
44    pub finish_reason: Option<String>,
45}
46#[derive(Debug, Deserialize, Serialize, Copy, Clone)]
47#[serde(rename_all = "lowercase")]
48pub enum Role {
49    System,
50    User,
51    Assistant,
52}
53
54#[derive(Debug, Deserialize, Serialize, Clone)]
55pub struct ChatMessage {
56    pub role: Role,
57    pub content: String,
58}
59
60#[derive(Debug)]
61pub struct ChatHistory {
62    queue: VecDeque<ChatMessage>,
63    prompt_message: Option<ChatMessage>,
64    summary_message: Option<ChatMessage>,
65}
66
67impl ChatHistory {
68    pub fn new(prompt_message: Option<String>) -> Self {
69        ChatHistory {
70            queue: Default::default(),
71            prompt_message: prompt_message.map(|prompt| ChatMessage {
72                role: Role::System,
73                content: prompt,
74            }),
75            summary_message: Default::default(),
76        }
77    }
78
79    pub fn add_initial_messages(&mut self, messages: Vec<ChatMessage>) {
80        self.queue.extend(messages)
81    }
82
83    /// Remove messages to summarize the conversation if it is too long
84    /// This is a naive implementation that just removes the oldest messages
85    pub fn summary_needed(
86        &mut self,
87        summary_prompt: ChatMessage,
88        max_length: usize,
89        summarize_length: usize,
90    ) -> Option<Vec<ChatMessage>> {
91        let total_length: usize = self.queue.iter().map(|message| message.content.len()).sum();
92        if total_length < max_length - summarize_length {
93            return None;
94        }
95        let mut selected_messages = vec![];
96        if let Some(prompt_message) = &self.prompt_message {
97            selected_messages.push(prompt_message.clone());
98        }
99        selected_messages.extend(self.summary_message.iter().cloned());
100        let mut current_length = 0;
101        while current_length < summarize_length {
102            if let Some(message) = self.queue.pop_front() {
103                current_length += message.content.len();
104                selected_messages.push(message);
105            } else {
106                break;
107            }
108        }
109        selected_messages.push(summary_prompt);
110        Some(selected_messages)
111    }
112
113    pub fn add_message(&mut self, message: ChatMessage) {
114        self.queue.push_back(message);
115    }
116
117    pub fn clear_history(&mut self) {
118        self.queue.clear();
119    }
120
121    pub fn get_history(&self) -> Vec<ChatMessage> {
122        let mut messages = vec![];
123        if let Some(prompt_message) = &self.prompt_message {
124            messages.push(prompt_message.clone());
125        }
126        messages.extend(self.summary_message.iter().cloned());
127        messages.extend(self.queue.iter().cloned());
128        debug!(
129            "Length of history: {}",
130            messages
131                .iter()
132                .map(|message| message.content.len())
133                .sum::<usize>()
134        );
135        messages
136    }
137
138    pub fn set_summary(&mut self, summary_message: String) {
139        info!("Set summary message:\n {}", summary_message);
140        self.summary_message = Some(ChatMessage {
141            role: Role::System,
142            content: summary_message,
143        });
144    }
145}