deepseek_api_client/
lib.rs

1use optional_default::OptionalDefault;
2use reqwest;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::fmt::format;
6use std::future::Future;
7use std::pin::Pin;
8use std::{boxed::Box, rc::Rc};
9
10use anyhow::{anyhow, Result};
11use serde_json::{json, Value};
/// Endpoint that every chat-completion request in this crate is POSTed to.
pub static CHAT_COMPLETION_API_URL: &str = "https://api.deepseek.com/chat/completions";

/// Model identifier used by the general-purpose chat helpers.
pub static DEEPSEEK_MODEL_CHAT: &str = "deepseek-chat";

/// Model identifier used by the code-completion and function-call helpers.
pub static DEEPSEEK_MODEL_CODER: &str = "deepseek-coder";
15
16#[derive(Serialize, Deserialize, Clone, Debug)]
17pub struct Message {
18    pub role: String,
19    pub content: String,
20}
21#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
22pub struct FunctionalCallObject {
23    r#type: String,
24    function: FunctionalCallObjectSingle,
25}
26#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
27pub struct FunctionalCallObjectSingle {
28    name: String,
29    description: String,
30    parameters: Value,
31}
32impl FunctionalCallObject {
33    fn new(fname: &str, fdesc: &str, parameters: Value) -> Self {
34        Self {
35            r#type: "function".to_owned(),
36            function: FunctionalCallObjectSingle {
37                name: fname.to_owned(),
38                description: fdesc.to_owned(),
39                parameters,
40            },
41        }
42    }
43}
44#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
45pub struct LogprobsObject {
46    token: String,
47    logprob: f32,
48    #[optional(default = None)]
49    bytes: Option<Vec<i32>>,
50}
51#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
52pub struct TopLogprobsObject {
53    token: String,
54    logprob: f32,
55    #[optional(default = None)]
56    bytes: Option<Vec<i32>>,
57    top_logprobs: Vec<LogprobsObject>,
58}
59#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
60pub struct LogprobsContent {
61    #[optional(default = None)]
62    content: Option<Vec<TopLogprobsObject>>,
63}
64#[derive(Serialize, Deserialize, OptionalDefault, Clone, Debug)]
65pub struct ChoiceObjectMessage {
66    #[optional(default = "assistant".to_owned())]
67    pub role: String,
68    pub content: Option<String>,
69}
70#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
71pub struct ChoiceObject {
72    #[optional(default = None)]
73    finish_reason: Option<String>,
74    index: i32,
75    #[optional(default = None)]
76    logprobs: Option<LogprobsContent>,
77    message: ChoiceObjectMessage,
78}
79#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
80struct Usage {
81    completion_tokens: i32,
82    prompt_tokens: i32,
83    total_tokens: i32,
84}
85#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
86struct ResponseFormatType {
87    r#type: String,
88}
89#[derive(Serialize, OptionalDefault, Debug)]
90pub struct RequestChat {
91    messages: Vec<Message>,
92    model: String,
93    #[optional(default = 0.0)]
94    frequency_penalty: f32,
95    #[optional(default = 2048)]
96    max_tokens: usize,
97    #[optional(default = 0.0)]
98    presence_penalty: f32,
99    #[optional(default = ResponseFormatType{r#type: "text".to_owned()} )]
100    response_format: ResponseFormatType,
101    #[optional(default = None)]
102    stop: Option<String>,
103    #[optional(default = false)]
104    stream: bool,
105    #[optional(default = None)]
106    stream_options: Option<String>,
107    #[optional(default = 1.0)]
108    temperature: f32,
109    #[optional(default = 1.0)]
110    top_p: f32,
111    #[optional(default = None)]
112    tools: Option<Vec<FunctionalCallObject>>,
113    #[optional(default = "none".to_owned())]
114    tool_choice: String,
115    #[optional(default = false)]
116    logprobs: bool,
117    #[optional(default = None)]
118    top_logprobs: Option<i32>,
119}
120
121#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
122pub struct ChatResponses {
123    id: String,
124    object: String,
125    created: i64,
126    model: String,
127    system_fingerprint: String,
128    choices: Vec<ChoiceObject>,
129    #[optional(default = None)]
130    usage: Option<Usage>,
131}
132
133#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
134pub struct ChatResponsesStream {
135    id: String,
136    object: String,
137    created: i64,
138    model: String,
139    system_fingerprint: String,
140    choices: Vec<ChoiceObjectChunk>,
141    #[optional(default = None)]
142    usage: Option<Usage>,
143}
144#[derive(Serialize, Deserialize, OptionalDefault, Debug)]
145pub struct ChoiceObjectChunk {
146    #[optional(default = None)]
147    finish_reason: Option<String>,
148    index: i32,
149    #[optional(default = None)]
150    logprobs: Option<LogprobsContent>,
151    delta: Value,
152}
153pub fn chat_DeepSeek_LLM_stream(
154    mut params: RequestChat,
155    api_key: &str,
156) -> Box<
157    dyn FnMut(
158        Vec<Message>,
159    ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
160> {
161    let api_key_rc = Rc::new(api_key.to_owned());
162    let c =  move  |messages: Vec<Message>| -> Pin<Box<dyn Future<Output = Result<reqwest::Response,reqwest::Error> >>> {
163        params.messages = messages;
164        let params_json =  serde_json::to_string(&params).unwrap();
165        let client = reqwest::Client::new();
166        let api_key = api_key_rc.clone();
167        let req = client.post(CHAT_COMPLETION_API_URL)
168        .header("Content-Type", "application/json")
169        .header("Authorization", format!("Bearer {}", api_key.to_string()))
170        .body(params_json)
171        .send();
172        Box::pin(req) 
173    };
174    Box::new(c)
175}
176pub fn chat_deepSeek_LLM_synchornous(
177    mut params: RequestChat,
178    api_key: &str,
179) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
180    let api_key_rc = Rc::new(api_key.to_owned());
181    let c = move |messages: Vec<Message>| -> Result<ChatResponses> {
182        params.messages = messages;
183        let params_json = serde_json::to_string(&params).unwrap();
184        let api_key = api_key_rc.clone();
185        let client = reqwest::blocking::Client::new();
186        let req = client
187            .post(CHAT_COMPLETION_API_URL)
188            .header("Content-Type", "application/json")
189            .header("Authorization", format!("Bearer {}", api_key.to_string()))
190            .body(params_json)
191            .send();
192
193        if let Ok(req) = req {
194            let s = req.text().unwrap();
195            let data = serde_json::from_str(&s);
196            if let Ok(data) = data {
197                return Ok(data);
198            }
199            return Err(anyhow!("Parse error {:?}", data));
200        }
201        Err(anyhow!("Can't connect to API"))
202    };
203    Box::new(c)
204}
205
206pub fn chat_DeepSeek_LLM(
207    mut params: RequestChat,
208    api_key: &str,
209) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
210    let api_key_rc = Rc::new(api_key.to_owned());
211    let f = move |messages: Vec<Message>| -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>> {
212        params.messages = messages;
213        let is_stream = params.stream;
214        let params_json = serde_json::to_string(&params).unwrap();
215
216        let api_key = api_key_rc.clone();
217        let c = async move {
218            let client = reqwest::Client::new();
219            let req = curl_post_request(
220                &client,
221                CHAT_COMPLETION_API_URL,
222                params_json,
223                api_key.to_string().as_str(),
224            );
225            if let Ok(req) = req {
226                let res = client.execute(req);
227
228                if let Ok(r) = res.await {
229                    let s = r.text().await;
230
231                    if let Ok(s) = s {
232                        if is_stream {
233                            let data = string_to_ChatResponses(&s);
234                            Ok(data)
235                        } else {
236                            let data = serde_json::from_str(&s);
237                            if data.is_ok() {
238                                let d: ChatResponses = data.unwrap();
239                                Ok(d)
240                            } else {
241                                Err(anyhow!("Parse error {:?}", data))
242                            }
243                        }
244                    } else {
245                        Err(anyhow!("Result response {:?}", s))
246                    }
247                } else {
248                    Err(anyhow!("Can't connect to API"))
249                }
250            } else {
251                Err(anyhow!("Request {:?}", req))
252            }
253        };
254        Box::pin(c)
255    };
256    Box::new(f)
257}
258pub fn get_response_text(d: &ChatResponses, ind: usize) -> Option<String> {
259    let response_index = d.choices.get(ind);
260    if let Some(response_index) = response_index {
261        response_index.message.content.clone()
262    } else {
263        None
264    }
265}
266pub fn string_to_ChatResponses(s: &str) -> ChatResponses {
267    let st = s.split("\n\n");
268    let fold_init: ChatResponses = ChatResponses!( id: "".to_owned(),
269                                object: "".to_owned(),
270                                created: 0,
271                                model: "".to_owned(),
272                                system_fingerprint: "".to_owned(),
273                                choices: vec![]);
274
275    let data: ChatResponses = st.filter_map(|item|{
276                                let sj = item.strip_prefix("data: ").unwrap_or(""); 
277                                let dt = serde_json::from_str::<ChatResponsesStream>(sj).ok();
278                                dt
279                            }).fold(fold_init,|mut acc,item|{
280                                if acc.choices.is_empty(){
281                                    acc.id = item.id;
282                                    acc.object = item.object;
283                                    acc.created = item.created;
284                                    acc.model = item.model;
285                                    acc.system_fingerprint = item.system_fingerprint;
286                                    let choice = item.choices.get(0).unwrap().delta.as_object().unwrap().get("content").unwrap().as_str().unwrap_or("").to_owned();
287                                    acc.choices = vec![ChoiceObject!(finish_reason: None,index: item.choices.get(0).unwrap().index,logprobs: None,message: ChoiceObjectMessage!(content: Some(choice) ))];
288                                }else{
289                                    let choice = item.choices.get(0).unwrap().delta.as_object().unwrap().get("content").unwrap().as_str().unwrap_or("").to_owned();
290                                    let acc_choices = acc.choices[0].message.content.clone().unwrap();
291                                    acc.choices[0].message.content = Some(acc_choices+&choice);
292                                } 
293                                acc
294                            });
295
296    data
297}
298fn curl_post_request(
299    client: &reqwest::Client,
300    url: &str,
301    params: String,
302    api_key: &str,
303) -> Result<reqwest::Request, reqwest::Error> {
304    let req = client
305        .post(url)
306        .header("Content-Type", "application/json")
307        .header("Authorization", format!("Bearer {}", api_key))
308        .body(params)
309        .build();
310    req
311}
312
313pub fn chat_completion(
314    api_key: &str,
315) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
316    let params = RequestChat! {
317        model: DEEPSEEK_MODEL_CHAT.to_owned(),
318        messages: vec![]
319    };
320    chat_DeepSeek_LLM(params, api_key)
321}
322pub fn code_completion(
323    api_key: &str,
324) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
325    let params = RequestChat! {
326        model: DEEPSEEK_MODEL_CODER.to_owned(),
327        stream:true,
328        messages: vec![]
329    };
330    chat_DeepSeek_LLM(params, api_key)
331}
332pub fn llm_function_call(
333    api_key: &str,
334    tools: Vec<FunctionalCallObject>,
335) -> Box<dyn FnMut(Vec<Message>) -> Pin<Box<dyn Future<Output = Result<ChatResponses>>>>> {
336    let params = RequestChat! {
337        model: DEEPSEEK_MODEL_CODER.to_owned(),
338        messages: vec![],
339        tools:Some(tools)
340    };
341    chat_DeepSeek_LLM(params, api_key)
342}
343pub fn chat_completion_stream(
344    api_key: &str,
345) -> Box<
346    dyn FnMut(
347        Vec<Message>,
348    ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
349> {
350    let params = RequestChat! {
351        model: DEEPSEEK_MODEL_CHAT.to_owned(),
352        stream:true,
353        messages: vec![]
354    };
355    chat_DeepSeek_LLM_stream(params, api_key)
356}
357pub fn code_completion_stream(
358    api_key: &str,
359) -> Box<
360    dyn FnMut(
361        Vec<Message>,
362    ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
363> {
364    let params = RequestChat! {
365        model: DEEPSEEK_MODEL_CODER.to_owned(),
366        stream:true,
367        messages: vec![]
368    };
369    chat_DeepSeek_LLM_stream(params, api_key)
370}
371pub fn llm_function_call_stream(
372    api_key: &str,
373    tools: Vec<FunctionalCallObject>,
374) -> Box<
375    dyn FnMut(
376        Vec<Message>,
377    ) -> Pin<Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>>>>,
378> {
379    let params = RequestChat! {
380        model: DEEPSEEK_MODEL_CODER.to_owned(),
381        stream:true,
382        messages: vec![],
383        tools:Some(tools)
384    };
385    chat_DeepSeek_LLM_stream(params, api_key)
386}
387pub fn chat_completion_sync(
388    api_key: &str,
389) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
390    let params = RequestChat! {
391        model: DEEPSEEK_MODEL_CHAT.to_owned(),
392        messages: vec![]
393    };
394    chat_deepSeek_LLM_synchornous(params, api_key)
395}
396pub fn code_completion_sync(
397    api_key: &str,
398) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
399    let params = RequestChat! {
400        model: DEEPSEEK_MODEL_CODER.to_owned(),
401        messages: vec![]
402    };
403    chat_deepSeek_LLM_synchornous(params, api_key)
404}
405pub fn llm_function_call_sync(
406    api_key: &str,
407    tools: Vec<FunctionalCallObject>,
408) -> Box<dyn FnMut(Vec<Message>) -> Result<ChatResponses>> {
409    let params = RequestChat! {
410        model: DEEPSEEK_MODEL_CODER.to_owned(),
411        messages: vec![],
412        tools:Some(tools)
413    };
414    chat_deepSeek_LLM_synchornous(params, api_key)
415}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use tokio::runtime::Runtime;

    // Fallback key for the live tests below. Prefer exporting the
    // `DEEPSEEK_API_KEY` environment variable over editing this value, so that
    // real keys never end up in version control.
    pub static DEEPSEEK_API_KEY: &str = "sk-.......................";

    /// Key for the live tests: the `DEEPSEEK_API_KEY` environment variable if
    /// it is set, otherwise the placeholder above.
    fn api_key() -> String {
        std::env::var("DEEPSEEK_API_KEY").unwrap_or_else(|_| DEEPSEEK_API_KEY.to_owned())
    }

    /// System and user prompt shared by the live completion tests.
    fn hello_world_messages() -> Vec<Message> {
        vec![
            Message {
                role: "system".to_owned(),
                content: "You are a helpful assistant".to_owned(),
            },
            Message {
                role: "user".to_owned(),
                content: "Write Hello world in rust".to_owned(),
            },
        ]
    }

    #[test]
    fn stream_chunks_are_merged() {
        let sse = concat!(
            r#"data: {"id":"abc","object":"chat.completion.chunk","created":1,"model":"deepseek-chat","system_fingerprint":"fp","choices":[{"index":0,"delta":{"role":"assistant","content":"Hel"},"logprobs":null,"finish_reason":null}]}"#,
            "\n\n",
            r#"data: {"id":"abc","object":"chat.completion.chunk","created":1,"model":"deepseek-chat","system_fingerprint":"fp","choices":[{"index":0,"delta":{"content":"lo"},"logprobs":null,"finish_reason":null}]}"#,
            "\n\n",
            "data: [DONE]\n\n"
        );
        let merged = string_to_ChatResponses(sse);
        assert_eq!(get_response_text(&merged, 0), Some("Hello".to_owned()));
        assert_eq!(get_response_text(&merged, 1), None);
        assert_eq!(merged.id, "abc");
        assert_eq!(merged.model, "deepseek-chat");
    }

    #[test]
    fn functional_call_object_serializes_as_tool() {
        let tool = FunctionalCallObject::new("get_weather", "desc", json!({"type": "object"}));
        let encoded = serde_json::to_value(&tool).unwrap();
        assert_eq!(
            encoded,
            json!({
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "desc",
                    "parameters": {"type": "object"}
                }
            })
        );
    }

    #[test]
    #[ignore] // live test: needs network access and a valid API key
    fn synchornous_completion_test() {
        let mut llm = chat_completion_sync(&api_key());
        let res = llm(hello_world_messages());
        let res_text = get_response_text(&res.unwrap(), 0);
        dbg!(res_text);
    }

    #[test]
    #[ignore] // live test: needs network access and a valid API key
    fn stream_completion_test() {
        let mut llm = chat_completion_stream(&api_key());
        let rt = Runtime::new().unwrap();

        let pending = llm(hello_world_messages());
        rt.block_on(async {
            let res = pending.await.unwrap();
            let mut stream = res.bytes_stream();
            while let Some(item) = stream.next().await {
                let item = item.unwrap();
                // A network chunk can end in the middle of a multi-byte
                // character, so decode lossily instead of panicking.
                let text = String::from_utf8_lossy(&item);
                let data = string_to_ChatResponses(&text);
                println!("{}", get_response_text(&data, 0).unwrap_or_default());
            }
        });
    }

    #[test]
    #[ignore] // live test: needs network access and a valid API key
    fn chat_completion_test() {
        let rt = Runtime::new().unwrap();
        let mut code_llm = code_completion(&api_key());
        let res = code_llm(hello_world_messages());
        let r = rt.block_on(async { get_response_text(&res.await.unwrap(), 0) });
        dbg!(&r);
        assert!(r.is_some());
    }

    #[test]
    #[ignore] // live test: needs network access and a valid API key
    fn function_call_test() {
        let rt = Runtime::new().unwrap();

        let tparam1 = json!({
            "type": "object",
            "required": ["location"],
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                }
            }
        });

        let t1 = FunctionalCallObject::new(
            "get_weather",
            "Get weather of an location, the user shoud supply a location first",
            tparam1,
        );

        let tools = vec![t1];
        let mut code_llm = llm_function_call(&api_key(), tools);

        let messages = vec![
            Message {
                role: "system".to_owned(),
                content: "You are a helpful assistant,your should reply in json format".to_owned(),
            },
            Message {
                role: "user".to_owned(),
                content: "How's the weather in Hangzhou?".to_owned(),
            },
        ];
        let res = code_llm(messages);
        let d = rt.block_on(async { res.await.unwrap() });
        dbg!(&d);
        // A tool-calling reply may carry no text content, so only require that
        // the API answered with at least one choice.
        assert!(!d.choices.is_empty());
        dbg!(get_response_text(&d, 0));
    }
}