az_openai_rs/
completions.rs

1use crate::query_builder::{
2    ChatMessages, ChatQuery, ChatResponse, CompletionQuery, CompletionResponse,
3};
4use reqwest::StatusCode;
5
/// Builder for a text-completion request, created by [`completions`].
///
/// Settings are accumulated as `(key, value)` pairs and are only turned into
/// a request when the builder's `send` method is awaited.
#[derive(Debug, Clone, Default)]
pub struct Where<'a> {
    /// Recorded `(key, value)` settings, in the order they were set.
    query: Vec<(&'a str, String)>,
}
9
/// Builder for a chat-completion request, created by [`chat`].
///
/// Request settings and conversation messages are accumulated here and only
/// turned into a request when the builder's `send` method is awaited.
#[derive(Debug, Clone, Default)]
pub struct WhereChat<'a> {
    /// Recorded `(key, value)` settings, in the order they were set.
    query: Vec<(&'a str, String)>,
    /// Conversation as `(role, content)` pairs, in the order they were added.
    messages: Vec<(&'a str, String)>,
}
14
15pub fn chat<'a>() -> WhereChat<'a> {
16    WhereChat {
17        query: Vec::new(),
18        messages: Vec::new(),
19    }
20}
21
22pub fn completions<'a>() -> Where<'a> {
23    Where { query: Vec::new() }
24}
25
26impl<'a> Where<'a> {
27    pub fn prompt(mut self, input: &'a str) -> Self {
28        self.query.push(("prompt", String::from(input)));
29        self
30    }
31
32    pub async fn send(self) -> Result<CompletionResponse, StatusCode> {
33        let mut query: CompletionQuery = CompletionQuery {
34            prompt: String::new(),
35            max_tokens: 16,
36            temperature: 1.0,
37            top_p: 1.0,
38            frequency_penalty: 0.0,
39            presence_penalty: 0.0,
40        };
41
42        for (k, v) in self.query.into_iter() {
43            if k == "prompt" {
44                query.prompt = v;
45            }
46        }
47
48        let result: Result<CompletionResponse, StatusCode> =
49            crate::query_builder::build_completion(query).await;
50
51        match result {
52            Ok(t) => Ok(t),
53            Err(e) => Err(e),
54        }
55    }
56}
57
58impl<'a> WhereChat<'a> {
59    pub fn messages(mut self, messages: Vec<(&'a str, &'a str)>) -> Self {
60        for (role, content) in messages.into_iter() {
61            self.messages.push((role, String::from(content)));
62        }
63        self
64    }
65
66    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
67        self.query.push(("max_tokens", max_tokens.to_string()));
68        self
69    }
70
71    pub async fn send(self) -> Result<ChatResponse, StatusCode> {
72        let mut query: ChatQuery = ChatQuery {
73            messages: Vec::new(),
74            max_tokens: 16,
75            temperature: 1.0,
76            n: 1,
77            stream: false,
78            top_p: 1.0,
79            frequency_penalty: 0.0,
80            presence_penalty: 0.0,
81        };
82
83        for (k, v) in self.query.into_iter() {
84            if k == "max_tokens" {
85                query.max_tokens = v.parse::<u32>().unwrap();
86            }
87        }
88
89        for (role, content) in self.messages.into_iter() {
90            let new_msg: ChatMessages = ChatMessages {
91                role: String::from(role),
92                content: String::from(content),
93            };
94            query.messages.push(new_msg);
95        }
96
97        let result: Result<ChatResponse, StatusCode> =
98            crate::query_builder::build_chat(query).await;
99
100        match result {
101            Ok(t) => Ok(t),
102            Err(e) => Err(e),
103        }
104    }
105}