// mistralai_client/v1/client.rs

use futures::stream::StreamExt;
use futures::Stream;
use log::debug;
use reqwest::Error as ReqwestError;
use std::{
    any::Any,
    collections::HashMap,
    sync::{Arc, Mutex},
    time::Duration,
};

use crate::v1::{chat, chat_stream, constants, embedding, error, model_list, tool, utils};

13#[derive(Debug)]
14pub struct Client {
15    pub api_key: String,
16    pub endpoint: String,
17    pub max_retries: u32,
18    pub timeout: u32,
19
20    functions: Arc<Mutex<HashMap<String, Box<dyn tool::Function>>>>,
21    last_function_call_result: Arc<Mutex<Option<Box<dyn Any + Send>>>>,
22}
23
24impl Client {
25    /// Constructs a new `Client`.
26    ///
27    /// # Arguments
28    ///
29    /// * `api_key`     - An optional API key.
30    ///                   If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable.
31    /// * `endpoint`    - An optional custom API endpoint. Defaults to the official API endpoint if not provided.
32    /// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`.
33    /// * `timeout`     - Optional timeout in seconds for requests. Defaults to `120`.
34    ///
35    /// # Examples
36    ///
37    /// ```
38    /// use mistralai_client::v1::client::Client;
39    ///
40    /// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60));
41    /// assert!(client.is_ok());
42    /// ```
43    ///
44    /// # Errors
45    ///
46    /// This method fails whenever neither the `api_key` is provided
47    /// nor the `MISTRAL_API_KEY` environment variable is set.
48    pub fn new(
49        api_key: Option<String>,
50        endpoint: Option<String>,
51        max_retries: Option<u32>,
52        timeout: Option<u32>,
53    ) -> Result<Self, error::ClientError> {
54        let api_key = match api_key {
55            Some(api_key_from_param) => api_key_from_param,
56            None => {
57                std::env::var("MISTRAL_API_KEY").map_err(|_| error::ClientError::MissingApiKey)?
58            }
59        };
60        let endpoint = endpoint.unwrap_or(constants::API_URL_BASE.to_string());
61        let max_retries = max_retries.unwrap_or(5);
62        let timeout = timeout.unwrap_or(120);
63
64        let functions: Arc<_> = Arc::new(Mutex::new(HashMap::new()));
65        let last_function_call_result = Arc::new(Mutex::new(None));
66
67        Ok(Self {
68            api_key,
69            endpoint,
70            max_retries,
71            timeout,
72
73            functions,
74            last_function_call_result,
75        })
76    }
77
78    /// Synchronously sends a chat completion request and returns the response.
79    ///
80    /// # Arguments
81    ///
82    /// * `model` - The [Model] to use for the chat completion.
83    /// * `messages` - A vector of [ChatMessage] to send as part of the chat.
84    /// * `options` - Optional [ChatParams] to customize the request.
85    ///
86    /// # Returns
87    ///
88    /// Returns a [Result] containing the `ChatResponse` if the request is successful,
89    /// or an [ApiError] if there is an error.
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use mistralai_client::v1::{
95    ///     chat::{ChatMessage, ChatMessageRole},
96    ///     client::Client,
97    ///     constants::Model,
98    /// };
99    ///
100    /// let client = Client::new(None, None, None, None).unwrap();
101    /// let messages = vec![ChatMessage {
102    ///     role: ChatMessageRole::User,
103    ///     content: "Hello, world!".to_string(),
104    ///     tool_calls: None,
105    /// }];
106    /// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap();
107    /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
108    /// ```
109    pub fn chat(
110        &self,
111        model: constants::Model,
112        messages: Vec<chat::ChatMessage>,
113        options: Option<chat::ChatParams>,
114    ) -> Result<chat::ChatResponse, error::ApiError> {
115        let request = chat::ChatRequest::new(model, messages, false, options);
116
117        let response = self.post_sync("/chat/completions", &request)?;
118        let result = response.json::<chat::ChatResponse>();
119        match result {
120            Ok(data) => {
121                utils::debug_pretty_json_from_struct("Response Data", &data);
122
123                self.call_function_if_any(data.clone());
124
125                Ok(data)
126            }
127            Err(error) => Err(self.to_api_error(error)),
128        }
129    }
130
131    /// Asynchronously sends a chat completion request and returns the response.
132    ///
133    /// # Arguments
134    ///
135    /// * `model` - The [Model] to use for the chat completion.
136    /// * `messages` - A vector of [ChatMessage] to send as part of the chat.
137    /// * `options` - Optional [ChatParams] to customize the request.
138    ///
139    /// # Returns
140    ///
141    /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
142    /// or an [ApiError] if there is an error.
143    ///
144    /// # Examples
145    ///
146    /// ```
147    /// use mistralai_client::v1::{
148    ///     chat::{ChatMessage, ChatMessageRole},
149    ///     client::Client,
150    ///     constants::Model,
151    /// };
152    ///
153    /// #[tokio::main]
154    /// async fn main() {
155    ///     let client = Client::new(None, None, None, None).unwrap();
156    ///     let messages = vec![ChatMessage {
157    ///         role: ChatMessageRole::User,
158    ///         content: "Hello, world!".to_string(),
159    ///         tool_calls: None,
160    ///     }];
161    ///     let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap();
162    ///     println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
163    /// }
164    /// ```
165    pub async fn chat_async(
166        &self,
167        model: constants::Model,
168        messages: Vec<chat::ChatMessage>,
169        options: Option<chat::ChatParams>,
170    ) -> Result<chat::ChatResponse, error::ApiError> {
171        let request = chat::ChatRequest::new(model, messages, false, options);
172
173        let response = self.post_async("/chat/completions", &request).await?;
174        let result = response.json::<chat::ChatResponse>().await;
175        match result {
176            Ok(data) => {
177                utils::debug_pretty_json_from_struct("Response Data", &data);
178
179                self.call_function_if_any_async(data.clone()).await;
180
181                Ok(data)
182            }
183            Err(error) => Err(self.to_api_error(error)),
184        }
185    }
186
187    /// Asynchronously sends a chat completion request and returns a stream of message chunks.
188    ///
189    /// # Arguments
190    ///
191    /// * `model` - The [Model] to use for the chat completion.
192    /// * `messages` - A vector of [ChatMessage] to send as part of the chat.
193    /// * `options` - Optional [ChatParams] to customize the request.
194    ///
195    /// # Returns
196    ///
197    /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
198    /// or an [ApiError] if there is an error.
199    ///
200    /// # Examples
201    ///
202    /// ```
203    /// use futures::stream::StreamExt;
204    /// use mistralai_client::v1::{
205    ///     chat::{ChatMessage, ChatMessageRole},
206    ///     client::Client,
207    ///     constants::Model,
208    /// };
209    /// use std::io::{self, Write};
210    ///
211    /// #[tokio::main]
212    /// async fn main() {
213    ///     let client = Client::new(None, None, None, None).unwrap();
214    ///     let messages = vec![ChatMessage {
215    ///         role: ChatMessageRole::User,
216    ///         content: "Hello, world!".to_string(),
217    ///         tool_calls: None,
218    ///     }];
219    ///
220    ///     let stream_result = client
221    ///         .chat_stream(Model::OpenMistral7b,messages, None)
222    ///         .await
223    ///         .unwrap();
224    ///     stream_result
225    ///         .for_each(|chunk_result| async {
226    ///             match chunk_result {
227    ///                 Ok(chunks) => chunks.iter().for_each(|chunk| {
228    ///                     print!("{}", chunk.choices[0].delta.content);
229    ///                     io::stdout().flush().unwrap();
230    ///                     // => "Once upon a time, [...]"
231    ///                 }),
232    ///                 Err(error) => {
233    ///                     eprintln!("Error processing chunk: {:?}", error)
234    ///                 }
235    ///             }
236    ///         })
237    ///         .await;
238    ///     print!("\n") // To persist the last chunk output.
239    /// }
240    pub async fn chat_stream(
241        &self,
242        model: constants::Model,
243        messages: Vec<chat::ChatMessage>,
244        options: Option<chat::ChatParams>,
245    ) -> Result<
246        impl Stream<Item = Result<Vec<chat_stream::ChatStreamChunk>, error::ApiError>>,
247        error::ApiError,
248    > {
249        let request = chat::ChatRequest::new(model, messages, true, options);
250        let response = self
251            .post_stream("/chat/completions", &request)
252            .await
253            .map_err(|e| error::ApiError {
254                message: e.to_string(),
255            })?;
256        if !response.status().is_success() {
257            let status = response.status();
258            let text = response.text().await.unwrap_or_default();
259            return Err(error::ApiError {
260                message: format!("{}: {}", status, text),
261            });
262        }
263
264        let deserialized_stream = response.bytes_stream().then(|bytes_result| async move {
265            match bytes_result {
266                Ok(bytes) => match String::from_utf8(bytes.to_vec()) {
267                    Ok(message) => {
268                        let chunks = message
269                            .lines()
270                            .filter_map(
271                                |line| match chat_stream::get_chunk_from_stream_message_line(line) {
272                                    Ok(Some(chunks)) => Some(chunks),
273                                    Ok(None) => None,
274                                    Err(_error) => None,
275                                },
276                            )
277                            .flatten()
278                            .collect();
279
280                        Ok(chunks)
281                    }
282                    Err(e) => Err(error::ApiError {
283                        message: e.to_string(),
284                    }),
285                },
286                Err(e) => Err(error::ApiError {
287                    message: e.to_string(),
288                }),
289            }
290        });
291
292        Ok(deserialized_stream)
293    }
294
295    pub fn embeddings(
296        &self,
297        model: constants::EmbedModel,
298        input: Vec<String>,
299        options: Option<embedding::EmbeddingRequestOptions>,
300    ) -> Result<embedding::EmbeddingResponse, error::ApiError> {
301        let request = embedding::EmbeddingRequest::new(model, input, options);
302
303        let response = self.post_sync("/embeddings", &request)?;
304        let result = response.json::<embedding::EmbeddingResponse>();
305        match result {
306            Ok(data) => {
307                utils::debug_pretty_json_from_struct("Response Data", &data);
308
309                Ok(data)
310            }
311            Err(error) => Err(self.to_api_error(error)),
312        }
313    }
314
315    pub async fn embeddings_async(
316        &self,
317        model: constants::EmbedModel,
318        input: Vec<String>,
319        options: Option<embedding::EmbeddingRequestOptions>,
320    ) -> Result<embedding::EmbeddingResponse, error::ApiError> {
321        let request = embedding::EmbeddingRequest::new(model, input, options);
322
323        let response = self.post_async("/embeddings", &request).await?;
324        let result = response.json::<embedding::EmbeddingResponse>().await;
325        match result {
326            Ok(data) => {
327                utils::debug_pretty_json_from_struct("Response Data", &data);
328
329                Ok(data)
330            }
331            Err(error) => Err(self.to_api_error(error)),
332        }
333    }
334
335    pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
336        let mut result_lock = self.last_function_call_result.lock().unwrap();
337
338        result_lock.take()
339    }
340
341    pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> {
342        let response = self.get_sync("/models")?;
343        let result = response.json::<model_list::ModelListResponse>();
344        match result {
345            Ok(data) => {
346                utils::debug_pretty_json_from_struct("Response Data", &data);
347
348                Ok(data)
349            }
350            Err(error) => Err(self.to_api_error(error)),
351        }
352    }
353
354    pub async fn list_models_async(
355        &self,
356    ) -> Result<model_list::ModelListResponse, error::ApiError> {
357        let response = self.get_async("/models").await?;
358        let result = response.json::<model_list::ModelListResponse>().await;
359        match result {
360            Ok(data) => {
361                utils::debug_pretty_json_from_struct("Response Data", &data);
362
363                Ok(data)
364            }
365            Err(error) => Err(self.to_api_error(error)),
366        }
367    }
368
369    pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) {
370        let mut functions = self.functions.lock().unwrap();
371
372        functions.insert(name, function);
373    }
374
375    fn build_request_sync(
376        &self,
377        request: reqwest::blocking::RequestBuilder,
378    ) -> reqwest::blocking::RequestBuilder {
379        let user_agent = format!(
380            "ivangabriele/mistralai-client-rs/{}",
381            env!("CARGO_PKG_VERSION")
382        );
383
384        let request_builder = request
385            .bearer_auth(&self.api_key)
386            .header("Accept", "application/json")
387            .header("User-Agent", user_agent);
388
389        request_builder
390    }
391
392    fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
393        let user_agent = format!(
394            "ivangabriele/mistralai-client-rs/{}",
395            env!("CARGO_PKG_VERSION")
396        );
397
398        let request_builder = request
399            .bearer_auth(&self.api_key)
400            .header("Accept", "application/json")
401            .header("User-Agent", user_agent);
402
403        request_builder
404    }
405
406    fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
407        let user_agent = format!(
408            "ivangabriele/mistralai-client-rs/{}",
409            env!("CARGO_PKG_VERSION")
410        );
411
412        let request_builder = request
413            .bearer_auth(&self.api_key)
414            .header("Accept", "text/event-stream")
415            .header("User-Agent", user_agent);
416
417        request_builder
418    }
419
420    fn call_function_if_any(&self, response: chat::ChatResponse) -> () {
421        let next_result = match response.choices.get(0) {
422            Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
423                Some(tool_calls) => match tool_calls.get(0) {
424                    Some(first_tool_call) => {
425                        let functions = self.functions.lock().unwrap();
426                        match functions.get(&first_tool_call.function.name) {
427                            Some(function) => {
428                                let runtime = tokio::runtime::Runtime::new().unwrap();
429                                let result = runtime.block_on(async {
430                                    function
431                                        .execute(first_tool_call.function.arguments.to_owned())
432                                        .await
433                                });
434
435                                Some(result)
436                            }
437                            None => None,
438                        }
439                    }
440                    None => None,
441                },
442                None => None,
443            },
444            None => None,
445        };
446
447        let mut last_result_lock = self.last_function_call_result.lock().unwrap();
448        *last_result_lock = next_result;
449    }
450
451    async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () {
452        let next_result = match response.choices.get(0) {
453            Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
454                Some(tool_calls) => match tool_calls.get(0) {
455                    Some(first_tool_call) => {
456                        let functions = self.functions.lock().unwrap();
457                        match functions.get(&first_tool_call.function.name) {
458                            Some(function) => {
459                                let result = function
460                                    .execute(first_tool_call.function.arguments.to_owned())
461                                    .await;
462
463                                Some(result)
464                            }
465                            None => None,
466                        }
467                    }
468                    None => None,
469                },
470                None => None,
471            },
472            None => None,
473        };
474
475        let mut last_result_lock = self.last_function_call_result.lock().unwrap();
476        *last_result_lock = next_result;
477    }
478
479    fn get_sync(&self, path: &str) -> Result<reqwest::blocking::Response, error::ApiError> {
480        let reqwest_client = reqwest::blocking::Client::new();
481        let url = format!("{}{}", self.endpoint, path);
482        debug!("Request URL: {}", url);
483
484        let request = self.build_request_sync(reqwest_client.get(url));
485
486        let result = request.send();
487        match result {
488            Ok(response) => {
489                if response.status().is_success() {
490                    Ok(response)
491                } else {
492                    let response_status = response.status();
493                    let response_body = response.text().unwrap_or_default();
494                    debug!("Response Status: {}", &response_status);
495                    utils::debug_pretty_json_from_string("Response Data", &response_body);
496
497                    Err(error::ApiError {
498                        message: format!("{}: {}", response_status, response_body),
499                    })
500                }
501            }
502            Err(error) => Err(error::ApiError {
503                message: error.to_string(),
504            }),
505        }
506    }
507
508    async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
509        let reqwest_client = reqwest::Client::new();
510        let url = format!("{}{}", self.endpoint, path);
511        debug!("Request URL: {}", url);
512
513        let request_builder = reqwest_client.get(url);
514        let request = self.build_request_async(request_builder);
515
516        let result = request.send().await;
517        match result {
518            Ok(response) => {
519                if response.status().is_success() {
520                    Ok(response)
521                } else {
522                    let response_status = response.status();
523                    let response_body = response.text().await.unwrap_or_default();
524                    debug!("Response Status: {}", &response_status);
525                    utils::debug_pretty_json_from_string("Response Data", &response_body);
526
527                    Err(error::ApiError {
528                        message: format!("{}: {}", response_status, response_body),
529                    })
530                }
531            }
532            Err(error) => Err(error::ApiError {
533                message: error.to_string(),
534            }),
535        }
536    }
537
538    fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>(
539        &self,
540        path: &str,
541        params: &T,
542    ) -> Result<reqwest::blocking::Response, error::ApiError> {
543        let reqwest_client = reqwest::blocking::Client::new();
544        let url = format!("{}{}", self.endpoint, path);
545        debug!("Request URL: {}", url);
546        utils::debug_pretty_json_from_struct("Request Body", params);
547
548        let request_builder = reqwest_client.post(url).json(params);
549        let request = self.build_request_sync(request_builder);
550
551        let result = request.send();
552        match result {
553            Ok(response) => {
554                if response.status().is_success() {
555                    Ok(response)
556                } else {
557                    let response_status = response.status();
558                    let response_body = response.text().unwrap_or_default();
559                    debug!("Response Status: {}", &response_status);
560                    utils::debug_pretty_json_from_string("Response Data", &response_body);
561
562                    Err(error::ApiError {
563                        message: format!("{}: {}", response_body, response_status),
564                    })
565                }
566            }
567            Err(error) => Err(error::ApiError {
568                message: error.to_string(),
569            }),
570        }
571    }
572
573    async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
574        &self,
575        path: &str,
576        params: &T,
577    ) -> Result<reqwest::Response, error::ApiError> {
578        let reqwest_client = reqwest::Client::new();
579        let url = format!("{}{}", self.endpoint, path);
580        debug!("Request URL: {}", url);
581        utils::debug_pretty_json_from_struct("Request Body", params);
582
583        let request_builder = reqwest_client.post(url).json(params);
584        let request = self.build_request_async(request_builder);
585
586        let result = request.send().await;
587        match result {
588            Ok(response) => {
589                if response.status().is_success() {
590                    Ok(response)
591                } else {
592                    let response_status = response.status();
593                    let response_body = response.text().await.unwrap_or_default();
594                    debug!("Response Status: {}", &response_status);
595                    utils::debug_pretty_json_from_string("Response Data", &response_body);
596
597                    Err(error::ApiError {
598                        message: format!("{}: {}", response_status, response_body),
599                    })
600                }
601            }
602            Err(error) => Err(error::ApiError {
603                message: error.to_string(),
604            }),
605        }
606    }
607
608    async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>(
609        &self,
610        path: &str,
611        params: &T,
612    ) -> Result<reqwest::Response, error::ApiError> {
613        let reqwest_client = reqwest::Client::new();
614        let url = format!("{}{}", self.endpoint, path);
615        debug!("Request URL: {}", url);
616        utils::debug_pretty_json_from_struct("Request Body", params);
617
618        let request_builder = reqwest_client.post(url).json(params);
619        let request = self.build_request_stream(request_builder);
620
621        let result = request.send().await;
622        match result {
623            Ok(response) => {
624                if response.status().is_success() {
625                    Ok(response)
626                } else {
627                    let response_status = response.status();
628                    let response_body = response.text().await.unwrap_or_default();
629                    debug!("Response Status: {}", &response_status);
630                    utils::debug_pretty_json_from_string("Response Data", &response_body);
631
632                    Err(error::ApiError {
633                        message: format!("{}: {}", response_status, response_body),
634                    })
635                }
636            }
637            Err(error) => Err(error::ApiError {
638                message: error.to_string(),
639            }),
640        }
641    }
642
643    fn to_api_error(&self, err: ReqwestError) -> error::ApiError {
644        error::ApiError {
645            message: err.to_string(),
646        }
647    }
648}