1pub mod config;
2pub mod model;
3
4use super::client::Client;
5use super::entry_point::{EntryPoint, Function, Version};
6use anyhow::Result;
7use config::ChatConfig;
8use model::{CompletionRequest, CompletionResponse, Message};
9
/// Builder for [`Chat`].
///
/// `config` and `request` must both be set before calling `build`; `version` is optional.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct ChatBuilder {
    // API version; `build` falls back to `Version`'s default when this is `None`.
    version: Option<Version>,
    // Model configuration; `build` returns an error when this is `None`.
    config: Option<ChatConfig>,
    // Messages to send; `build` returns an error when this is `None`.
    request: Option<Vec<Message>>,
}
16
17impl ChatBuilder {
18 pub fn version(mut self, value: Version) -> Self {
19 self.version = Some(value);
20 self
21 }
22 pub fn config(mut self, value: ChatConfig) -> Self {
23 self.config = Some(value);
24 self
25 }
26 pub fn request(mut self, value: Vec<Message>) -> Self {
27 self.request = Some(value);
28 self
29 }
30 pub fn build(&self) -> Result<Chat> {
31 let version = self.version.unwrap_or(Default::default());
32 let config = match self.config.clone() {
33 Some(val) => val,
34 _ => return Err(anyhow::anyhow!("Config must be set.")),
35 };
36 let entry_point = EntryPoint::default()
37 .set_version(version)
38 .set_function(Function::CreateChatCompletion);
39 let request = match &self.request {
40 Some(val) => val.clone(),
41 _ => return Err(anyhow::anyhow!("Request must be set.")),
42 };
43 Ok(Chat {
44 entry_point,
45 config,
46 request,
47 })
48 }
49}
50
/// A chat completion request bound to an API entry point; send it with `execute`.
#[derive(Debug, Clone, PartialEq)]
pub struct Chat {
    // Entry point whose `path()` is the target of the POST in `execute`.
    entry_point: EntryPoint,
    // Supplies the `model` and `temperature` fields of each request.
    config: ChatConfig,
    // Sent as the request's `messages`.
    request: Vec<Message>,
}
57
58impl Chat {
59 pub fn new(config: ChatConfig, request: Vec<Message>) -> Self {
60 Self {
61 entry_point: Default::default(),
62 config,
63 request,
64 }
65 }
66 pub fn builder() -> ChatBuilder {
67 Default::default()
68 }
69 pub async fn execute(&self, client: &Client) -> Result<CompletionResponse> {
70 let request = CompletionRequest {
71 model: self.config.model.as_ref().to_string(),
72 temperature: self.config.temperature,
73 messages: self.request.clone(),
74 };
75 let res = client.post(&self.entry_point.path(), request).await?;
76 let res: CompletionResponse = serde_json::from_str(res.text().await?.as_str())?;
77 Ok(res)
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::super::client::Config;
    use super::*;

    #[test]
    fn builder() -> Result<()> {
        let builder = ChatBuilder::default();
        // Neither `config` nor `request` is set.
        assert!(builder.clone().build().is_err());
        // `config` alone is not enough: `request` is also required.
        assert!(builder.clone().config(ChatConfig::default()).build().is_err());
        let builder = builder.config(ChatConfig::default()).request(vec![]);
        assert!(builder.build().is_ok());
        Ok(())
    }

    // Calls the live API with credentials from `Config::from_env`, so it cannot pass in an
    // environment without them; run explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "calls the live API; needs credentials from the environment (Config::from_env)"]
    async fn post() -> Result<()> {
        let client = Client::new(Config::from_env()?)?;
        // Propagate with `?` so a failure reports its cause instead of a bare `is_ok()` false.
        Chat::builder()
            .config(Default::default())
            .request(vec!["hi".into()])
            .build()?
            .execute(&client)
            .await?;
        Ok(())
    }
}