use crate::{
    constants::OPENAI_API_URL,
    primitives::{
        ChatCompletion, Completion, CompletionRequest, CompletionResponse, ModelChoice,
        OpenAICompletionRequest,
    },
    utils::parse_chunk,
    ToolSet,
};
use futures_util::StreamExt;
use std::{cell::OnceCell, io::Write};

/// Client for an OpenAI-compatible chat completions API (OpenAI itself or a
/// local Ollama server).
pub struct Client {
    pub(crate) url: String,
    pub(crate) client: reqwest::Client,
}

impl Client {
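    /// Creates a client that talks to the OpenAI API using the given API key.
    ///
    /// A minimal usage sketch (illustrative only; assumes the
    /// `OPENAI_API_KEY` environment variable is set):
    ///
    /// ```ignore
    /// let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
    /// let client = Client::new(&api_key);
    /// ```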
    pub fn new(api_key: &str) -> Self {
        let client = make_client(api_key);

        Client {
            url: OPENAI_API_URL.to_owned(),
            client,
        }
    }

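    /// Creates a client that targets an Ollama server's OpenAI-compatible
    /// endpoint. A placeholder API key is sent, since the Ollama endpoint
    /// does not validate the bearer token.
    ///
    /// A minimal usage sketch (illustrative only; assumes a local Ollama
    /// server on its default port):
    ///
    /// ```ignore
    /// let client = Client::with_ollama("http://localhost:11434/v1");
    /// ```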
    pub fn with_ollama(url: &str) -> Self {
        let client = make_client("ollama");

        Client {
            url: url.to_owned(),
            client,
        }
    }
}

impl Completion for Client {
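    /// Sends a completion request and feeds the model output to `callback`:
    /// plain messages are passed through as text, and tool calls are resolved
    /// against `tools` before their results are passed on. When `req.stream`
    /// is true the response is consumed chunk by chunk. If `callback` has not
    /// been set, output is silently dropped.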
    async fn post<F>(
        &self,
        req: CompletionRequest,
        tools: &ToolSet,
        callback: OnceCell<F>,
    ) -> Result<(), ()>
    where
        F: Fn(&str) + Send,
    {
        let enable_stream = req.stream.unwrap_or(false);
        let body = OpenAICompletionRequest::new(req);
        let url = format!("{}/chat/completions", self.url);
        let response = self
            .client
            .post(&url)
            .json(&body)
            .send()
            .await
            .expect("completion request should succeed");
        match enable_stream {
            true => {
                // Streaming: read the response body chunk by chunk as it arrives.
                let mut stream = response.bytes_stream();
                while let Some(item) = stream.next().await {
                    let data = &item.expect("stream chunk should be readable");
                    let chunk_str =
                        std::str::from_utf8(data).expect("stream chunk should be valid UTF-8");
                    match parse_chunk(chunk_str) {
                        Ok(chunk_response) => {
                            if let Ok(completion_response) =
                                CompletionResponse::try_from(chunk_response)
                            {
                                match completion_response {
                                    // Plain text content: forward it to the callback.
                                    CompletionResponse {
                                        choice: ModelChoice::Message(msg),
                                        ..
                                    } => {
                                        if let Some(callback) = callback.get() {
                                            callback(&msg);
                                            std::io::stdout()
                                                .flush()
                                                .expect("Failed to flush stdout");
                                        }
                                    }
                                    // Tool call: invoke the tool, then forward its result.
                                    CompletionResponse {
                                        choice: ModelChoice::ToolCall(toolname, args),
                                        ..
                                    } => {
                                        if let Ok(res) =
                                            tools.invoke(&toolname, args.to_string()).await
                                        {
                                            if let Some(callback) = callback.get() {
                                                callback(&res);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(err) => eprintln!("OpenAI error parsing chunk: {}", err),
                    }
                }
            }
            false => {
                // Non-streaming: deserialize the complete response body in one shot.
                let chat_completion = response.json::<ChatCompletion>().await;
                if let Ok(chat_completion) = chat_completion {
                    if let Ok(completion_response) = CompletionResponse::try_from(chat_completion) {
                        match completion_response {
                            CompletionResponse {
                                choice: ModelChoice::Message(msg),
                                ..
                            } => {
                                if let Some(callback) = callback.get() {
                                    callback(&msg);
                                }
                            }
                            CompletionResponse {
                                choice: ModelChoice::ToolCall(toolname, args),
                                ..
                            } => {
                                if let Ok(res) = tools.invoke(&toolname, args.to_string()).await {
                                    if let Some(callback) = callback.get() {
                                        callback(&res);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }
}

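/// Builds a `reqwest::Client` with the API key attached as a default
/// `Authorization: Bearer ...` header on every request.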
fn make_client(api_key: &str) -> reqwest::Client {
    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(
        "Authorization",
        format!("Bearer {}", api_key)
            .parse()
            .expect("Bearer token should parse"),
    );

    reqwest::Client::builder()
        .default_headers(headers)
        .build()
        .expect("openai client should build")
}