az_openai_rs/
completions.rs1use crate::query_builder::{
2 ChatMessages, ChatQuery, ChatResponse, CompletionQuery, CompletionResponse,
3};
4use reqwest::StatusCode;
5
/// Builder for a text-completions request, created by `completions()`.
///
/// Options are accumulated as `(key, value)` pairs and only turned into a request
/// when `send()` is awaited.
#[derive(Debug, Default)]
#[must_use = "a completions builder does nothing until `.send().await` is called"]
pub struct Where<'a> {
    /// Collected `(option name, stringified value)` pairs, in insertion order.
    query: Vec<(&'a str, String)>,
}
9
/// Builder for a chat-completions request, created by `chat()`.
///
/// Options and conversation messages are accumulated here and only turned into a
/// request when `send()` is awaited.
#[derive(Debug, Default)]
#[must_use = "a chat builder does nothing until `.send().await` is called"]
pub struct WhereChat<'a> {
    /// Collected `(option name, stringified value)` pairs, in insertion order.
    query: Vec<(&'a str, String)>,
    /// Conversation as `(role, content)` pairs, in insertion order.
    messages: Vec<(&'a str, String)>,
}
14
15pub fn chat<'a>() -> WhereChat<'a> {
16 WhereChat {
17 query: Vec::new(),
18 messages: Vec::new(),
19 }
20}
21
22pub fn completions<'a>() -> Where<'a> {
23 Where { query: Vec::new() }
24}
25
26impl<'a> Where<'a> {
27 pub fn prompt(mut self, input: &'a str) -> Self {
28 self.query.push(("prompt", String::from(input)));
29 self
30 }
31
32 pub async fn send(self) -> Result<CompletionResponse, StatusCode> {
33 let mut query: CompletionQuery = CompletionQuery {
34 prompt: String::new(),
35 max_tokens: 16,
36 temperature: 1.0,
37 top_p: 1.0,
38 frequency_penalty: 0.0,
39 presence_penalty: 0.0,
40 };
41
42 for (k, v) in self.query.into_iter() {
43 if k == "prompt" {
44 query.prompt = v;
45 }
46 }
47
48 let result: Result<CompletionResponse, StatusCode> =
49 crate::query_builder::build_completion(query).await;
50
51 match result {
52 Ok(t) => Ok(t),
53 Err(e) => Err(e),
54 }
55 }
56}
57
58impl<'a> WhereChat<'a> {
59 pub fn messages(mut self, messages: Vec<(&'a str, &'a str)>) -> Self {
60 for (role, content) in messages.into_iter() {
61 self.messages.push((role, String::from(content)));
62 }
63 self
64 }
65
66 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
67 self.query.push(("max_tokens", max_tokens.to_string()));
68 self
69 }
70
71 pub async fn send(self) -> Result<ChatResponse, StatusCode> {
72 let mut query: ChatQuery = ChatQuery {
73 messages: Vec::new(),
74 max_tokens: 16,
75 temperature: 1.0,
76 n: 1,
77 stream: false,
78 top_p: 1.0,
79 frequency_penalty: 0.0,
80 presence_penalty: 0.0,
81 };
82
83 for (k, v) in self.query.into_iter() {
84 if k == "max_tokens" {
85 query.max_tokens = v.parse::<u32>().unwrap();
86 }
87 }
88
89 for (role, content) in self.messages.into_iter() {
90 let new_msg: ChatMessages = ChatMessages {
91 role: String::from(role),
92 content: String::from(content),
93 };
94 query.messages.push(new_msg);
95 }
96
97 let result: Result<ChatResponse, StatusCode> =
98 crate::query_builder::build_chat(query).await;
99
100 match result {
101 Ok(t) => Ok(t),
102 Err(e) => Err(e),
103 }
104 }
105}