1use optional_default::OptionalDefault;
2use reqwest;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt::format;
6use std::future::Future;
7use std::pin::Pin;
8use std::{boxed::Box, rc::Rc};
9
10use anyhow::{anyhow, Result};
11use serde_json::{json, Value};
/// Endpoint every request in this module is POSTed to.
pub static CHAT_COMPLETION_API_URL: &str = "https://api.deepseek.com/chat/completions";
/// Model id of the general chat model.
pub static DEEPSEEK_MODEL_CHAT: &str = "deepseek-chat";
/// Model id of the code model.
pub static DEEPSEEK_MODEL_CODER: &str = "deepseek-coder";
15
16#[derive(Serialize, Deserialize, Clone, Debug)]
17pub struct Message {
18 pub role: String,
19 pub content: String,
20}
21#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
22pub struct FunctionalCallObject {
23 r#type: String,
24 function: FunctionalCallObjectSingle,
25}
26#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
27pub struct FunctionalCallObjectSingle {
28 name: String,
29 description: String,
30 parameters: Value,
31}
32impl FunctionalCallObject {
33 fn new(fname: &str, fdesc: &str, parameters: Value) -> Self {
34 Self {
35 r#type: "function".to_owned(),
36 function: FunctionalCallObjectSingle {
37 name: fname.to_owned(),
38 description: fdesc.to_owned(),
39 parameters,
40 },
41 }
42 }
43}
44#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
45pub struct LogprobsObject {
46 token: String,
47 logprob: f32,
48 #[optional(default = None)]
49 bytes: Option<Vec<i32>>,
50}
51#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
52pub struct TopLogprobsObject {
53 token: String,
54 logprob: f32,
55 #[optional(default = None)]
56 bytes: Option<Vec<i32>>,
57 top_logprobs: Vec<LogprobsObject>,
58}
59#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
60pub struct LogprobsContent {
61 #[optional(default = None)]
62 content: Option<Vec<TopLogprobsObject>>,
63}
64#[derive(Serialize, Deserialize, OptionalDefault, Clone, Debug)]
65pub struct ChoiceObjectMessage {
66 #[optional(default = "assistant".to_owned())]
67 pub role: String,
68 pub content: Option<String>,
69}
70#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
71pub struct ChoiceObject {
72 #[optional(default = None)]
73 finish_reason: Option<String>,
74 index: i32,
75 #[optional(default = None)]
76 logprobs: Option<LogprobsContent>,
77 message: ChoiceObjectMessage,
78}
79#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
80struct Usage {
81 completion_tokens: i32,
82 prompt_tokens: i32,
83 total_tokens: i32,
84}
85#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
86struct ResponseFormatType {
87 r#type: String,
88}
89#[derive(Serialize, OptionalDefault, Debug)]
90pub struct RequestChat {
91 messages: Vec<Message>,
92 model: String,
93 #[optional(default = 0.0)]
94 frequency_penalty: f32,
95 #[optional(default = 2048)]
96 max_tokens: usize,
97 #[optional(default = 0.0)]
98 presence_penalty: f32,
99 #[optional(default = ResponseFormatType{r#type: "text".to_owned()} )]
100 response_format: ResponseFormatType,
101 #[optional(default = None)]
102 stop: Option<String>,
103 #[optional(default = false)]
104 stream: bool,
105 #[optional(default = None)]
106 stream_options: Option<String>,
107 #[optional(default = 1.0)]
108 temperature: f32,
109 #[optional(default = 1.0)]
110 top_p: f32,
111 #[optional(default = None)]
112 tools: Option<Vec<FunctionalCallObject>>,
113 #[optional(default = "none".to_owned())]
114 tool_choice: String,
115 #[optional(default = false)]
116 logprobs: bool,
117 #[optional(default = None)]
118 top_logprobs: Option<i32>,
119}
120
121#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
122pub struct ChatResponses {
123 id: String,
124 object: String,
125 created: i64,
126 model: String,
127 system_fingerprint: String,
128 choices: Vec<ChoiceObject>,
129 #[optional(default = None)]
130 usage: Option<Usage>,
131}
132
133#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
134pub struct ChatResponsesStream {
135 id: String,
136 object: String,
137 created: i64,
138 model: String,
139 system_fingerprint: String,
140 choices: Vec<ChoiceObjectChunk>,
141 #[optional(default = None)]
142 usage: Option<Usage>,
143}
144#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
145pub struct ChoiceObjectChunk {
146 #[optional(default = None)]
147 finish_reason: Option<String>,
148 index: i32,
149 #[optional(default = None)]
150 logprobs: Option<LogprobsContent>,
151 delta: Value,
152}
153pub fn chat_DeepSeek_LLM_stream(
154 mut params: RequestChat,
155 api_key: &str,
156) -> Box<
157 dyn FnMut(
158 Vec<Message>,
159 ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
160> {
161 let api_key_rc = Rc::new(api_key.to_owned());
162 let c = move |messages: Vec<Message>| -> Pin<Box<dyn Future<Output = Result<reqwest::Response,reqwest::Error> >>> {
163 params.messages = messages;
164 let params_json = serde_json::to_string(¶ms).unwrap();
165 let client = reqwest::Client::new();
166 let api_key = api_key_rc.clone();
167 let req = client.post(CHAT_COMPLETION_API_URL)
168 .header("Content-Type", "application/json")
169 .header("Authorization", format!("Bearer {}", api_key.to_string()))
170 .body(params_json)
171 .send();
172 Box::pin(req)
173 };
174 Box::new(c)
175}
176pub fn chat_deepSeek_LLM_synchornous(
177 mut params: RequestChat,
178 api_key: &str,
179) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
180 let api_key_rc = Rc::new(api_key.to_owned());
181 let c = move |messages: Vec<Message>| -> Result<ChatResponses> {
182 params.messages = messages;
183 let params_json = serde_json::to_string(¶ms).unwrap();
184 let api_key = api_key_rc.clone();
185 let client = reqwest::blocking::Client::new();
186 let req = client
187 .post(CHAT_COMPLETION_API_URL)
188 .header("Content-Type", "application/json")
189 .header("Authorization", format!("Bearer {}", api_key.to_string()))
190 .body(params_json)
191 .send();
192
193 if let Ok(req) = req {
194 let s = req.text().unwrap();
195 let data = serde_json::from_str(&s);
196 if let Ok(data) = data {
197 return Ok(data);
198 }
199 return Err(anyhow!("Parse error {:?}", data));
200 }
201 Err(anyhow!("Can't connect to API"))
202 };
203 Box::new(c)
204}
205
206pub fn chat_DeepSeek_LLM(
207 mut params: RequestChat,
208 api_key: &str,
209) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
210 let api_key_rc = Rc::new(api_key.to_owned());
211 let f = move |messages: Vec<Message>| -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>> {
212 params.messages = messages;
213 let is_stream = params.stream;
214 let params_json = serde_json::to_string(¶ms).unwrap();
215
216 let api_key = api_key_rc.clone();
217 let c = async move {
218 let client = reqwest::Client::new();
219 let req = curl_post_request(
220 &client,
221 CHAT_COMPLETION_API_URL,
222 params_json,
223 api_key.to_string().as_str(),
224 );
225 if let Ok(req) = req {
226 let res = client.execute(req);
227
228 if let Ok(r) = res.await {
229 let s = r.text().await;
230
231 if let Ok(s) = s {
232 if is_stream {
233 let data = string_to_ChatResponses(&s);
234 Ok(data)
235 } else {
236 let data = serde_json::from_str(&s);
237 if data.is_ok() {
238 let d: ChatResponses = data.unwrap();
239 Ok(d)
240 } else {
241 Err(anyhow!("Parse error {:?}", data))
242 }
243 }
244 } else {
245 Err(anyhow!("Result response {:?}", s))
246 }
247 } else {
248 Err(anyhow!("Can't connect to API"))
249 }
250 } else {
251 Err(anyhow!("Request {:?}", req))
252 }
253 };
254 Box::pin(c)
255 };
256 Box::new(f)
257}
258pub fn get_response_text(d: &ChatResponses, ind: usize) -> Option<String> {
259 let response_index = d.choices.get(ind);
260 if let Some(response_index) = response_index {
261 response_index.message.content.clone()
262 } else {
263 None
264 }
265}
266pub fn string_to_ChatResponses(s: &str) -> ChatResponses {
267 let st = s.split("\n\n");
268 let fold_init: ChatResponses = ChatResponses!( id: "".to_owned(),
269 object: "".to_owned(),
270 created: 0,
271 model: "".to_owned(),
272 system_fingerprint: "".to_owned(),
273 choices: vec![]);
274
275 let data: ChatResponses = st.filter_map(|item|{
276 let sj = item.strip_prefix("data: ").unwrap_or("");
277 let dt = serde_json::from_str::<ChatResponsesStream>(sj).ok();
278 dt
279 }).fold(fold_init,|mut acc,item|{
280 if acc.choices.is_empty(){
281 acc.id = item.id;
282 acc.object = item.object;
283 acc.created = item.created;
284 acc.model = item.model;
285 acc.system_fingerprint = item.system_fingerprint;
286 let choice = item.choices.get(0).unwrap().delta.as_object().unwrap().get("content").unwrap().as_str().unwrap_or("").to_owned();
287 acc.choices = vec![ChoiceObject!(finish_reason: None,index: item.choices.get(0).unwrap().index,logprobs: None,message: ChoiceObjectMessage!(content: Some(choice) ))];
288 }else{
289 let choice = item.choices.get(0).unwrap().delta.as_object().unwrap().get("content").unwrap().as_str().unwrap_or("").to_owned();
290 let acc_choices = acc.choices[0].message.content.clone().unwrap();
291 acc.choices[0].message.content = Some(acc_choices+&choice);
292 }
293 acc
294 });
295
296 data
297}
298fn curl_post_request(
299 client: &reqwest::Client,
300 url: &str,
301 params: String,
302 api_key: &str,
303) -> Result<reqwest::Request, reqwest::Error> {
304 let req = client
305 .post(url)
306 .header("Content-Type", "application/json")
307 .header("Authorization", format!("Bearer {}", api_key))
308 .body(params)
309 .build();
310 req
311}
312
313pub fn chat_completion(
314 api_key: &str,
315) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
316 let params = RequestChat! {
317 model: DEEPSEEK_MODEL_CHAT.to_owned(),
318 messages: vec![]
319 };
320 chat_DeepSeek_LLM(params, api_key)
321}
322pub fn code_completion(
323 api_key: &str,
324) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
325 let params = RequestChat! {
326 model: DEEPSEEK_MODEL_CODER.to_owned(),
327 stream:true,
328 messages: vec![]
329 };
330 chat_DeepSeek_LLM(params, api_key)
331}
332pub fn llm_function_call(
333 api_key: &str,
334 tools: Vec<FunctionalCallObject>,
335) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
336 let params = RequestChat! {
337 model: DEEPSEEK_MODEL_CODER.to_owned(),
338 messages: vec![],
339 tools:Some(tools)
340 };
341 chat_DeepSeek_LLM(params, api_key)
342}
343pub fn chat_completion_stream(
344 api_key: &str,
345) -> Box<
346 dyn FnMut(
347 Vec<Message>,
348 ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
349> {
350 let params = RequestChat! {
351 model: DEEPSEEK_MODEL_CHAT.to_owned(),
352 stream:true,
353 messages: vec![]
354 };
355 chat_DeepSeek_LLM_stream(params, api_key)
356}
357pub fn code_completion_stream(
358 api_key: &str,
359) -> Box<
360 dyn FnMut(
361 Vec<Message>,
362 ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
363> {
364 let params = RequestChat! {
365 model: DEEPSEEK_MODEL_CODER.to_owned(),
366 stream:true,
367 messages: vec![]
368 };
369 chat_DeepSeek_LLM_stream(params, api_key)
370}
371pub fn llm_function_call_stream(
372 api_key: &str,
373 tools: Vec<FunctionalCallObject>,
374) -> Box<
375 dyn FnMut(
376 Vec<Message>,
377 ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
378> {
379 let params = RequestChat! {
380 model: DEEPSEEK_MODEL_CODER.to_owned(),
381 stream:true,
382 messages: vec![],
383 tools:Some(tools)
384 };
385 chat_DeepSeek_LLM_stream(params, api_key)
386}
387pub fn chat_completion_sync(
388 api_key: &str,
389) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
390 let params = RequestChat! {
391 model: DEEPSEEK_MODEL_CHAT.to_owned(),
392 messages: vec![]
393 };
394 chat_deepSeek_LLM_synchornous(params, api_key)
395}
396pub fn code_completion_sync(
397 api_key: &str,
398) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
399 let params = RequestChat! {
400 model: DEEPSEEK_MODEL_CODER.to_owned(),
401 messages: vec![]
402 };
403 chat_deepSeek_LLM_synchornous(params, api_key)
404}
405pub fn llm_function_call_sync(
406 api_key: &str,
407 tools: Vec<FunctionalCallObject>,
408) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
409 let params = RequestChat! {
410 model: DEEPSEEK_MODEL_CODER.to_owned(),
411 messages: vec![],
412 tools:Some(tools)
413 };
414 chat_deepSeek_LLM_synchornous(params, api_key)
415}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use tokio::runtime::Runtime;

    /// Placeholder key, used only when `DEEPSEEK_API_KEY` is not set in the environment.
    pub static DEEPSEEK_API_KEY: &str = "sk-.......................";

    /// Reads the key from the `DEEPSEEK_API_KEY` environment variable so real
    /// credentials never need to be committed; falls back to the placeholder.
    fn api_key() -> String {
        std::env::var("DEEPSEEK_API_KEY").unwrap_or_else(|_| DEEPSEEK_API_KEY.to_owned())
    }

    /// System + user prompt shared by the completion tests.
    fn hello_world_messages() -> Vec<Message> {
        vec![
            Message {
                role: "system".to_owned(),
                content: "You are a helpful assistant".to_owned(),
            },
            Message {
                role: "user".to_owned(),
                content: "Write Hello world in rust".to_owned(),
            },
        ]
    }

    // Every test below calls the live DeepSeek API. They are ignored by default
    // so `cargo test` passes offline. Run them with
    // `DEEPSEEK_API_KEY=sk-... cargo test -- --ignored`.

    #[test]
    #[ignore = "calls the live DeepSeek API (needs network and DEEPSEEK_API_KEY)"]
    fn synchornous_completion_test() {
        let mut llm = chat_completion_sync(&api_key());
        let res = llm(hello_world_messages()).expect("synchronous request failed");
        let res_text = get_response_text(&res, 0);
        dbg!(&res_text);
        assert!(res_text.is_some());
    }

    #[test]
    #[ignore = "calls the live DeepSeek API (needs network and DEEPSEEK_API_KEY)"]
    fn stream_completion_test() {
        let mut llm = chat_completion_stream(&api_key());
        let rt = Runtime::new().unwrap();

        let pending = llm(hello_world_messages());
        rt.block_on(async {
            let res = pending.await.unwrap();
            let mut stream = res.bytes_stream();
            // A network chunk can end in the middle of an SSE event or of a UTF-8
            // character, so only complete lines are handed to the parser. Splitting
            // at b'\n' never cuts a multi-byte UTF-8 sequence.
            let mut buf: Vec<u8> = Vec::new();
            while let Some(item) = stream.next().await {
                buf.extend_from_slice(&item.unwrap());
                if let Some(end) = buf.iter().rposition(|&b| b == b'\n') {
                    let complete: Vec<u8> = buf.drain(..=end).collect();
                    let s = match std::str::from_utf8(&complete) {
                        Ok(v) => v,
                        Err(e) => panic!("Invalid UTF-8 sequence: {}", e),
                    };
                    let data = string_to_ChatResponses(s);
                    let res = get_response_text(&data, 0).unwrap_or_default();
                    println!("{}", res);
                }
            }
        });
    }

    #[test]
    #[ignore = "calls the live DeepSeek API (needs network and DEEPSEEK_API_KEY)"]
    fn chat_completion_test() {
        let rt = Runtime::new().unwrap();
        let mut code_llm = code_completion(&api_key());
        let res = code_llm(hello_world_messages());
        let r = rt.block_on(async { get_response_text(&res.await.unwrap(), 0) });
        dbg!(&r);
        assert!(r.is_some());
    }

    #[test]
    #[ignore = "calls the live DeepSeek API (needs network and DEEPSEEK_API_KEY)"]
    fn function_call_test() {
        let rt = Runtime::new().unwrap();

        let location_schema = json!({
            "type": "object",
            "required": ["location"],
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                }
            }
        });

        let get_weather = FunctionalCallObject::new(
            "get_weather",
            "Get weather of an location, the user shoud supply a location first",
            location_schema,
        );

        let mut tool_llm = llm_function_call(&api_key(), vec![get_weather]);

        let messages = vec![
            Message {
                role: "system".to_owned(),
                content: "You are a helpful assistant,your should reply in json format".to_owned(),
            },
            Message {
                role: "user".to_owned(),
                content: "How's the weather in Hangzhou?".to_owned(),
            },
        ];
        let res = tool_llm(messages);
        let r = rt.block_on(async {
            let d = res.await.unwrap();
            dbg!(&d);
            get_response_text(&d, 0)
        });
        dbg!(&r);
        assert!(r.is_some());
    }
}