1use crate::err::ApiError;
2use crate::http_client;
3use crate::output::Output;
4use serde::{Deserialize, Serialize};
5use serde_json::from_str;
6use serde_json::Result as SerdeResult;
7
8#[derive(Debug, Clone, Deserialize, Serialize)]
9pub struct GptChat {
10 messages: Vec<Message>,
11}
12
13pub trait History {
14 fn new() -> GptChat;
15 fn add(&mut self, message: Message);
16 fn pop(&mut self);
17 fn get_all(&self) -> Vec<Message>;
18 fn flush(&mut self);
19}
20
21impl History for GptChat {
22 fn new() -> GptChat {
23 GptChat { messages: vec![] }
24 }
25
26 fn add(&mut self, message: Message) {
27 self.messages.push(message);
28 }
29
30 fn pop(&mut self) {
31 self.messages.pop();
32 }
33
34 fn get_all(&self) -> Vec<Message> {
35 self.messages.clone()
36 }
37
38 fn flush(&mut self) {
39 self.messages = vec![];
40 }
41}
42
43#[derive(Debug, Deserialize, Serialize)]
46pub struct ChatCreateCompletionParams {
47 pub model: Option<String>,
48 pub messages: Option<Vec<Message>>,
49 pub temperature: Option<f64>,
50 pub max_tokens: Option<i32>,
51 }
60
61#[derive(Debug, Deserialize, Serialize)]
62pub struct ChatCreateCompletionResponse {
63 id: String,
64 object: Option<String>,
65 created_at: Option<i64>,
66 choices: Option<Vec<Choice>>,
67 usage: Option<Usage>,
68}
69
70#[derive(Debug, Deserialize, Serialize)]
71struct Choice {
72 message: Option<Message>,
73 index: Option<i32>,
74 logprobs: Option<i32>,
75 finish_reason: Option<String>,
76}
77
78#[derive(Debug, Clone, Deserialize, Serialize)]
79pub struct Message {
80 pub role: Option<String>,
81 pub content: Option<String>,
82}
83
84#[derive(Debug, Deserialize, Serialize)]
85struct Usage {
86 prompt_tokens: Option<i32>,
87 completion_tokens: Option<i32>,
88 total_tokens: Option<i32>,
89}
90
91#[derive(Debug, Deserialize)]
92pub struct Error {
93 pub message: String,
94 pub r#type: String,
95 pub param: Option<String>,
96 pub code: Option<String>,
97}
98
99#[derive(Debug, Deserialize)]
100pub struct ErrorResponse {
101 pub error: Error,
102}
103
104#[derive(Debug, Deserialize)]
105#[serde(untagged)]
106pub enum Response {
107 ChatCreateCompletion(ChatCreateCompletionResponse),
108 Error(ErrorResponse),
109}
110
111
112impl Output for ChatCreateCompletionResponse {
113 fn get_output(&self) -> String {
114 let mut output = String::from("");
115 for choice in self.choices.iter() {
116 for message in choice.iter() {
117 let lines = &message.message;
118 let some_lines = lines;
119 match some_lines {
120 Some(some_lines) => {
121 for line in &some_lines.content {
122 if line.trim().is_empty() {
123 continue; }
125 output.push_str(&line.to_string());
126 output.push_str("\n");
127 }
128 }
129 None => {}
130 }
131 }
132 }
133 output
134 }
135}
136
137impl Output for ErrorResponse {
138 fn get_output(&self) -> String {
139 let output = format!("{:?}", self);
140 String::from(output)
141 }
142}
143
144pub trait MessageHistory {
145 fn save_messages(&self, history: &mut GptChat);
146}
147
148impl MessageHistory for ChatCreateCompletionResponse {
149 fn save_messages(&self, history: &mut GptChat) {
150 for choice in self.choices.iter() {
151 for message in choice.iter() {
152 let lines = &message.message;
153 let some_lines = lines;
154 match some_lines {
155 Some(some_lines) => history.add(Message {
156 role: some_lines.role.clone(),
157 content: some_lines.content.clone(),
158 }),
159 None => {}
160 }
161 }
162 }
163 }
164}
165
166fn parse_chat_response(response: String) -> SerdeResult<Response> {
167 match from_str(&response) {
168 Ok(c) => return Ok(c),
170 Err(e) => {
171 return Err(e);
172 }
173 };
174}
175
176pub async fn process_chat_prompt(
177 request_defaults: ChatCreateCompletionParams,
178) -> Result<ChatCreateCompletionResponse, ApiError> {
179 let result = http_client::send_chat_request(request_defaults).await;
181 match result {
182 Ok(response) => match parse_chat_response(response) {
183 Ok(completion) => {
184 match completion {
185 Response::ChatCreateCompletion(r) => {
186 return Ok(r);
187 }
188 Response::Error(e) => {
189 return Err(ApiError::new(&e.get_output()));
190 }
191
192 }
193 }
194 Err(e) => {
195 return Err(ApiError::new(&e.to_string()));
196 }
197 },
198 Err(e) => {
199 return Err(ApiError::new(&e.to_string()));
200 }
201 }
202}