gptshell/
chat.rs

1use crate::err::ApiError;
2use crate::http_client;
3use crate::output::Output;
4use serde::{Deserialize, Serialize};
5use serde_json::from_str;
6use serde_json::Result as SerdeResult;
7
8#[derive(Debug, Clone, Deserialize, Serialize)]
9pub struct GptChat {
10    messages: Vec<Message>,
11}
12
13pub trait History {
14    fn new() -> GptChat;
15    fn add(&mut self, message: Message);
16    fn pop(&mut self);
17    fn get_all(&self) -> Vec<Message>;
18    fn flush(&mut self);
19}
20
21impl History for GptChat {
22    fn new() -> GptChat {
23        GptChat { messages: vec![] }
24    }
25
26    fn add(&mut self, message: Message) {
27        self.messages.push(message);
28    }
29
30    fn pop(&mut self) {
31        self.messages.pop();
32    }
33
34    fn get_all(&self) -> Vec<Message> {
35        self.messages.clone()
36    }
37
38    fn flush(&mut self) {
39        self.messages = vec![];
40    }
41}
42
43// Based off create chat completion
44// See API reference here https://platform.openai.com/docs/api-reference/chat/create
45#[derive(Debug, Deserialize, Serialize)]
46pub struct ChatCreateCompletionParams {
47    pub model: Option<String>,
48    pub messages: Option<Vec<Message>>,
49    pub temperature: Option<f64>,
50    pub max_tokens: Option<i32>,
51    //TODO: readd
52    // stop: Option<Vec<String>>,
53    // stream: Option<bool>,
54    // n: Option<i32>,
55    // top_n: Option<f64>,
56    // presence_penalty: Option<f64>,
57    // frequency_penalty: Option<f64>,
58    // user: Option<String>,
59}
60
61#[derive(Debug, Deserialize, Serialize)]
62pub struct ChatCreateCompletionResponse {
63    id: String,
64    object: Option<String>,
65    created_at: Option<i64>,
66    choices: Option<Vec<Choice>>,
67    usage: Option<Usage>,
68}
69
70#[derive(Debug, Deserialize, Serialize)]
71struct Choice {
72    message: Option<Message>,
73    index: Option<i32>,
74    logprobs: Option<i32>,
75    finish_reason: Option<String>,
76}
77
78#[derive(Debug, Clone, Deserialize, Serialize)]
79pub struct Message {
80    pub role: Option<String>,
81    pub content: Option<String>,
82}
83
84#[derive(Debug, Deserialize, Serialize)]
85struct Usage {
86    prompt_tokens: Option<i32>,
87    completion_tokens: Option<i32>,
88    total_tokens: Option<i32>,
89}
90
91#[derive(Debug, Deserialize)]
92pub struct Error {
93    pub message: String,
94    pub r#type: String,
95    pub param: Option<String>,
96    pub code: Option<String>,
97}
98
99#[derive(Debug, Deserialize)]
100pub struct ErrorResponse {
101    pub error: Error,
102}
103
104#[derive(Debug, Deserialize)]
105#[serde(untagged)]
106pub enum Response {
107    ChatCreateCompletion(ChatCreateCompletionResponse),
108    Error(ErrorResponse),
109}
110
111
112impl Output for ChatCreateCompletionResponse {
113    fn get_output(&self) -> String {
114        let mut output = String::from("");
115        for choice in self.choices.iter() {
116            for message in choice.iter() {
117                let lines = &message.message;
118                let some_lines = lines;
119                match some_lines {
120                    Some(some_lines) => {
121                        for line in &some_lines.content {
122                            if line.trim().is_empty() {
123                                continue; // ignore empty or whitespace-only lines
124                            }
125                            output.push_str(&line.to_string());
126                            output.push_str("\n");
127                        }
128                    }
129                    None => {}
130                }
131            }
132        }
133        output
134    }
135}
136
137impl Output for ErrorResponse {
138    fn get_output(&self) -> String {
139        let output = format!("{:?}", self);
140        String::from(output)
141    }
142}
143
144pub trait MessageHistory {
145    fn save_messages(&self, history: &mut GptChat);
146}
147
148impl MessageHistory for ChatCreateCompletionResponse {
149    fn save_messages(&self, history: &mut GptChat) {
150        for choice in self.choices.iter() {
151            for message in choice.iter() {
152                let lines = &message.message;
153                let some_lines = lines;
154                match some_lines {
155                    Some(some_lines) => history.add(Message {
156                        role: some_lines.role.clone(),
157                        content: some_lines.content.clone(),
158                    }),
159                    None => {}
160                }
161            }
162        }
163    }
164}
165
166fn parse_chat_response(response: String) -> SerdeResult<Response> {
167    match from_str(&response) {
168        //TODO: make sure decoding is happining here properly
169        Ok(c) => return Ok(c),
170        Err(e) => {
171            return Err(e);
172        }
173    };
174}
175
176pub async fn process_chat_prompt(
177    request_defaults: ChatCreateCompletionParams,
178) -> Result<ChatCreateCompletionResponse, ApiError> {
179    //TODO: Readd Language
180    let result = http_client::send_chat_request(request_defaults).await;
181    match result {
182        Ok(response) => match parse_chat_response(response) {
183            Ok(completion) => {
184                match completion {
185                    Response::ChatCreateCompletion(r)  => {
186                        return Ok(r);
187                    }
188                    Response::Error(e) => {
189                        return Err(ApiError::new(&e.get_output()));
190                    }
191
192                }
193            }
194            Err(e) => {
195                return Err(ApiError::new(&e.to_string()));
196            }
197        },
198        Err(e) => {
199            return Err(ApiError::new(&e.to_string()));
200        }
201    }
202}