// gpt_rs/chat/mod.rs

1pub mod config;
2pub mod model;
3
4use super::client::Client;
5use super::entry_point::{EntryPoint, Function, Version};
6use anyhow::Result;
7use config::ChatConfig;
8use model::{CompletionRequest, CompletionResponse, Message};
9
10#[derive(Debug, Default, Clone, PartialEq)]
11pub struct ChatBuilder {
12    version: Option<Version>,
13    config: Option<ChatConfig>,
14    request: Option<Vec<Message>>,
15}
16
17impl ChatBuilder {
18    pub fn version(mut self, value: Version) -> Self {
19        self.version = Some(value);
20        self
21    }
22    pub fn config(mut self, value: ChatConfig) -> Self {
23        self.config = Some(value);
24        self
25    }
26    pub fn request(mut self, value: Vec<Message>) -> Self {
27        self.request = Some(value);
28        self
29    }
30    pub fn build(&self) -> Result<Chat> {
31        let version = self.version.unwrap_or(Default::default());
32        let config = match self.config.clone() {
33            Some(val) => val,
34            _ => return Err(anyhow::anyhow!("Config must be set.")),
35        };
36        let entry_point = EntryPoint::default()
37            .set_version(version)
38            .set_function(Function::CreateChatCompletion);
39        let request = match &self.request {
40            Some(val) => val.clone(),
41            _ => return Err(anyhow::anyhow!("Request must be set.")),
42        };
43        Ok(Chat {
44            entry_point,
45            config,
46            request,
47        })
48    }
49}
50
51#[derive(Debug, Clone, PartialEq)]
52pub struct Chat {
53    entry_point: EntryPoint,
54    config: ChatConfig,
55    request: Vec<Message>,
56}
57
58impl Chat {
59    pub fn new(config: ChatConfig, request: Vec<Message>) -> Self {
60        Self {
61            entry_point: Default::default(),
62            config,
63            request,
64        }
65    }
66    pub fn builder() -> ChatBuilder {
67        Default::default()
68    }
69    pub async fn execute(&self, client: &Client) -> Result<CompletionResponse> {
70        let request = CompletionRequest {
71            model: self.config.model.as_ref().to_string(),
72            temperature: self.config.temperature,
73            messages: self.request.clone(),
74        };
75        let res = client.post(&self.entry_point.path(), request).await?;
76        let res: CompletionResponse = serde_json::from_str(res.text().await?.as_str())?;
77        Ok(res)
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::super::client::Config;
    use super::*;

    /// `build` must reject a builder that is missing either required field
    /// and accept one that has both.
    #[test]
    fn builder() -> Result<()> {
        let empty = ChatBuilder::default();
        assert!(empty.build().is_err());

        // A config alone is not enough: the request is required too.
        let config_only = empty.clone().config(ChatConfig::default());
        assert!(config_only.build().is_err());

        // A request alone is not enough either.
        let request_only = empty.request(vec![]);
        assert!(request_only.build().is_err());

        let complete = config_only.request(vec![]);
        assert!(complete.build().is_ok());
        Ok(())
    }

    /// Live round trip against the API. It is ignored by default because it
    /// needs network access and credentials read by `Config::from_env`. Run it
    /// with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access and API credentials in the environment"]
    async fn post() -> Result<()> {
        let client = Client::new(Config::from_env()?)?;
        // Propagate the error instead of `assert!(res.is_ok())`, so a failure
        // shows the underlying cause.
        Chat::builder()
            .config(Default::default())
            .request(vec!["hi".into()])
            .build()?
            .execute(&client)
            .await?;
        Ok(())
    }
}
106}