topkio/
openai.rs

1use crate::{
2    constants::OPENAI_API_URL,
3    primitives::{
4        ChatCompletion, Completion, CompletionRequest, CompletionResponse, ModelChoice,
5        OpenAICompletionRequest,
6    },
7    utils::parse_chunk,
8    ToolSet,
9};
10use futures_util::StreamExt;
11use std::{cell::OnceCell, io::Write};
12
13pub struct Client {
14    pub(crate) url: String,
15    pub(crate) client: reqwest::Client,
16}
17
18impl Client {
19    pub fn new(api_key: &str) -> Self {
20        let client = make_client(api_key);
21
22        Client {
23            url: OPENAI_API_URL.to_owned(),
24            client,
25        }
26    }
27
28    pub fn with_ollama(url: &str) -> Self {
29        let client = make_client("ollama");
30
31        Client {
32            url: url.to_owned(),
33            client,
34        }
35    }
36}
37
38impl Completion for Client {
39    async fn post<F>(
40        &self,
41        req: CompletionRequest,
42        tools: &ToolSet,
43        callback: OnceCell<F>,
44    ) -> Result<(), ()>
45    where
46        F: Fn(&str) + Send,
47    {
48        let enable_stream = req.stream.unwrap_or(false);
49        let body = OpenAICompletionRequest::new(req);
50        let url = format!("{}/chat/completions", self.url);
51        let response = self
52            .client
53            .post(&url)
54            .json(&body)
55            .send()
56            .await
57            .expect("openai completion msg");
58
59        match enable_stream {
60            true => {
61                let mut stream = response.bytes_stream();
62                while let Some(item) = stream.next().await {
63                    let data = &item.expect("msg");
64                    let chunk_str = std::str::from_utf8(data).expect("OpenAI expect utf8.");
65                    match parse_chunk(chunk_str) {
66                        Ok(chunk_response) => {
67                            if let Ok(completion_response) =
68                                CompletionResponse::try_from(chunk_response)
69                            {
70                                match completion_response {
71                                    CompletionResponse {
72                                        choice: ModelChoice::Message(msg),
73                                        ..
74                                    } => {
75                                        if let Some(callback) = callback.get() {
76                                            callback(&msg);
77                                            std::io::stdout()
78                                                .flush()
79                                                .expect("Failed to flush stdout");
80                                        }
81                                    }
82                                    CompletionResponse {
83                                        choice: ModelChoice::ToolCall(toolname, args),
84                                        ..
85                                    } => {
86                                        if let Ok(res) =
87                                            tools.invoke(&toolname, args.to_string()).await
88                                        {
89                                            if let Some(callback) = callback.get() {
90                                                callback(&res);
91                                            }
92                                        }
93                                    }
94                                }
95                            }
96                        }
97                        Err(err) => println!("OpenAI error parsing chunk: {}", err),
98                    }
99                }
100            }
101            false => {
102                let chat_completion = response.json::<ChatCompletion>().await;
103                if let Ok(chat_completion) = chat_completion {
104                    if let Ok(completion_response) = CompletionResponse::try_from(chat_completion) {
105                        match completion_response {
106                            CompletionResponse {
107                                choice: ModelChoice::Message(msg),
108                                ..
109                            } => {
110                                if let Some(callback) = callback.get() {
111                                    callback(&msg);
112                                }
113                            }
114                            CompletionResponse {
115                                choice: ModelChoice::ToolCall(toolname, args),
116                                ..
117                            } => {
118                                if let Ok(res) = tools.invoke(&toolname, args.to_string()).await {
119                                    if let Some(callback) = callback.get() {
120                                        callback(&res);
121                                    }
122                                }
123                            }
124                        }
125                    }
126                }
127            }
128        }
129
130        Ok(())
131    }
132}
133
134fn make_client(api_key: &str) -> reqwest::Client {
135    let mut headers = reqwest::header::HeaderMap::new();
136    headers.insert(
137        "Authorization",
138        format!("Bearer {}", api_key)
139            .parse()
140            .expect("Bearer token should parse"),
141    );
142
143    reqwest::Client::builder()
144        .default_headers(headers)
145        .build()
146        .expect("openai client should build")
147}