1use futures::stream::StreamExt;
2use futures::Stream;
3use log::debug;
4use reqwest::Error as ReqwestError;
5use std::{
6 any::Any,
7 collections::HashMap,
8 sync::{Arc, Mutex},
9};
10
11use crate::v1::{chat, chat_stream, constants, embedding, error, model_list, tool, utils};
12
/// Client for the Mistral AI HTTP API rooted at `endpoint`.
///
/// Holds the connection settings together with the registry of callable tool
/// functions and the result of the most recently executed tool function.
#[derive(Debug)]
pub struct Client {
    /// API key sent as a bearer token with every request.
    pub api_key: String,
    /// Base URL that request paths (e.g. `/chat/completions`) are appended to.
    pub endpoint: String,
    /// Configured retry budget.
    ///
    /// NOTE(review): no method in this part of the file reads this field —
    /// confirm whether retries are implemented elsewhere.
    pub max_retries: u32,
    /// Configured request timeout (presumably seconds — TODO confirm).
    ///
    /// NOTE(review): no method in this part of the file reads this field.
    pub timeout: u32,

    /// Tool functions registered through `register_function`, keyed by name.
    functions: Arc<Mutex<HashMap<String, Box<dyn tool::Function>>>>,
    /// Value returned by the last executed tool function; taken (and cleared)
    /// by `get_last_function_call_result`.
    last_function_call_result: Arc<Mutex<Option<Box<dyn Any + Send>>>>,
}
23
24impl Client {
25 pub fn new(
49 api_key: Option<String>,
50 endpoint: Option<String>,
51 max_retries: Option<u32>,
52 timeout: Option<u32>,
53 ) -> Result<Self, error::ClientError> {
54 let api_key = match api_key {
55 Some(api_key_from_param) => api_key_from_param,
56 None => {
57 std::env::var("MISTRAL_API_KEY").map_err(|_| error::ClientError::MissingApiKey)?
58 }
59 };
60 let endpoint = endpoint.unwrap_or(constants::API_URL_BASE.to_string());
61 let max_retries = max_retries.unwrap_or(5);
62 let timeout = timeout.unwrap_or(120);
63
64 let functions: Arc<_> = Arc::new(Mutex::new(HashMap::new()));
65 let last_function_call_result = Arc::new(Mutex::new(None));
66
67 Ok(Self {
68 api_key,
69 endpoint,
70 max_retries,
71 timeout,
72
73 functions,
74 last_function_call_result,
75 })
76 }
77
78 pub fn chat(
110 &self,
111 model: constants::Model,
112 messages: Vec<chat::ChatMessage>,
113 options: Option<chat::ChatParams>,
114 ) -> Result<chat::ChatResponse, error::ApiError> {
115 let request = chat::ChatRequest::new(model, messages, false, options);
116
117 let response = self.post_sync("/chat/completions", &request)?;
118 let result = response.json::<chat::ChatResponse>();
119 match result {
120 Ok(data) => {
121 utils::debug_pretty_json_from_struct("Response Data", &data);
122
123 self.call_function_if_any(data.clone());
124
125 Ok(data)
126 }
127 Err(error) => Err(self.to_api_error(error)),
128 }
129 }
130
131 pub async fn chat_async(
166 &self,
167 model: constants::Model,
168 messages: Vec<chat::ChatMessage>,
169 options: Option<chat::ChatParams>,
170 ) -> Result<chat::ChatResponse, error::ApiError> {
171 let request = chat::ChatRequest::new(model, messages, false, options);
172
173 let response = self.post_async("/chat/completions", &request).await?;
174 let result = response.json::<chat::ChatResponse>().await;
175 match result {
176 Ok(data) => {
177 utils::debug_pretty_json_from_struct("Response Data", &data);
178
179 self.call_function_if_any_async(data.clone()).await;
180
181 Ok(data)
182 }
183 Err(error) => Err(self.to_api_error(error)),
184 }
185 }
186
187 pub async fn chat_stream(
241 &self,
242 model: constants::Model,
243 messages: Vec<chat::ChatMessage>,
244 options: Option<chat::ChatParams>,
245 ) -> Result<
246 impl Stream<Item = Result<Vec<chat_stream::ChatStreamChunk>, error::ApiError>>,
247 error::ApiError,
248 > {
249 let request = chat::ChatRequest::new(model, messages, true, options);
250 let response = self
251 .post_stream("/chat/completions", &request)
252 .await
253 .map_err(|e| error::ApiError {
254 message: e.to_string(),
255 })?;
256 if !response.status().is_success() {
257 let status = response.status();
258 let text = response.text().await.unwrap_or_default();
259 return Err(error::ApiError {
260 message: format!("{}: {}", status, text),
261 });
262 }
263
264 let deserialized_stream = response.bytes_stream().then(|bytes_result| async move {
265 match bytes_result {
266 Ok(bytes) => match String::from_utf8(bytes.to_vec()) {
267 Ok(message) => {
268 let chunks = message
269 .lines()
270 .filter_map(
271 |line| match chat_stream::get_chunk_from_stream_message_line(line) {
272 Ok(Some(chunks)) => Some(chunks),
273 Ok(None) => None,
274 Err(_error) => None,
275 },
276 )
277 .flatten()
278 .collect();
279
280 Ok(chunks)
281 }
282 Err(e) => Err(error::ApiError {
283 message: e.to_string(),
284 }),
285 },
286 Err(e) => Err(error::ApiError {
287 message: e.to_string(),
288 }),
289 }
290 });
291
292 Ok(deserialized_stream)
293 }
294
295 pub fn embeddings(
296 &self,
297 model: constants::EmbedModel,
298 input: Vec<String>,
299 options: Option<embedding::EmbeddingRequestOptions>,
300 ) -> Result<embedding::EmbeddingResponse, error::ApiError> {
301 let request = embedding::EmbeddingRequest::new(model, input, options);
302
303 let response = self.post_sync("/embeddings", &request)?;
304 let result = response.json::<embedding::EmbeddingResponse>();
305 match result {
306 Ok(data) => {
307 utils::debug_pretty_json_from_struct("Response Data", &data);
308
309 Ok(data)
310 }
311 Err(error) => Err(self.to_api_error(error)),
312 }
313 }
314
315 pub async fn embeddings_async(
316 &self,
317 model: constants::EmbedModel,
318 input: Vec<String>,
319 options: Option<embedding::EmbeddingRequestOptions>,
320 ) -> Result<embedding::EmbeddingResponse, error::ApiError> {
321 let request = embedding::EmbeddingRequest::new(model, input, options);
322
323 let response = self.post_async("/embeddings", &request).await?;
324 let result = response.json::<embedding::EmbeddingResponse>().await;
325 match result {
326 Ok(data) => {
327 utils::debug_pretty_json_from_struct("Response Data", &data);
328
329 Ok(data)
330 }
331 Err(error) => Err(self.to_api_error(error)),
332 }
333 }
334
335 pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
336 let mut result_lock = self.last_function_call_result.lock().unwrap();
337
338 result_lock.take()
339 }
340
341 pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> {
342 let response = self.get_sync("/models")?;
343 let result = response.json::<model_list::ModelListResponse>();
344 match result {
345 Ok(data) => {
346 utils::debug_pretty_json_from_struct("Response Data", &data);
347
348 Ok(data)
349 }
350 Err(error) => Err(self.to_api_error(error)),
351 }
352 }
353
354 pub async fn list_models_async(
355 &self,
356 ) -> Result<model_list::ModelListResponse, error::ApiError> {
357 let response = self.get_async("/models").await?;
358 let result = response.json::<model_list::ModelListResponse>().await;
359 match result {
360 Ok(data) => {
361 utils::debug_pretty_json_from_struct("Response Data", &data);
362
363 Ok(data)
364 }
365 Err(error) => Err(self.to_api_error(error)),
366 }
367 }
368
369 pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) {
370 let mut functions = self.functions.lock().unwrap();
371
372 functions.insert(name, function);
373 }
374
375 fn build_request_sync(
376 &self,
377 request: reqwest::blocking::RequestBuilder,
378 ) -> reqwest::blocking::RequestBuilder {
379 let user_agent = format!(
380 "ivangabriele/mistralai-client-rs/{}",
381 env!("CARGO_PKG_VERSION")
382 );
383
384 let request_builder = request
385 .bearer_auth(&self.api_key)
386 .header("Accept", "application/json")
387 .header("User-Agent", user_agent);
388
389 request_builder
390 }
391
392 fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
393 let user_agent = format!(
394 "ivangabriele/mistralai-client-rs/{}",
395 env!("CARGO_PKG_VERSION")
396 );
397
398 let request_builder = request
399 .bearer_auth(&self.api_key)
400 .header("Accept", "application/json")
401 .header("User-Agent", user_agent);
402
403 request_builder
404 }
405
406 fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
407 let user_agent = format!(
408 "ivangabriele/mistralai-client-rs/{}",
409 env!("CARGO_PKG_VERSION")
410 );
411
412 let request_builder = request
413 .bearer_auth(&self.api_key)
414 .header("Accept", "text/event-stream")
415 .header("User-Agent", user_agent);
416
417 request_builder
418 }
419
420 fn call_function_if_any(&self, response: chat::ChatResponse) -> () {
421 let next_result = match response.choices.get(0) {
422 Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
423 Some(tool_calls) => match tool_calls.get(0) {
424 Some(first_tool_call) => {
425 let functions = self.functions.lock().unwrap();
426 match functions.get(&first_tool_call.function.name) {
427 Some(function) => {
428 let runtime = tokio::runtime::Runtime::new().unwrap();
429 let result = runtime.block_on(async {
430 function
431 .execute(first_tool_call.function.arguments.to_owned())
432 .await
433 });
434
435 Some(result)
436 }
437 None => None,
438 }
439 }
440 None => None,
441 },
442 None => None,
443 },
444 None => None,
445 };
446
447 let mut last_result_lock = self.last_function_call_result.lock().unwrap();
448 *last_result_lock = next_result;
449 }
450
451 async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () {
452 let next_result = match response.choices.get(0) {
453 Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
454 Some(tool_calls) => match tool_calls.get(0) {
455 Some(first_tool_call) => {
456 let functions = self.functions.lock().unwrap();
457 match functions.get(&first_tool_call.function.name) {
458 Some(function) => {
459 let result = function
460 .execute(first_tool_call.function.arguments.to_owned())
461 .await;
462
463 Some(result)
464 }
465 None => None,
466 }
467 }
468 None => None,
469 },
470 None => None,
471 },
472 None => None,
473 };
474
475 let mut last_result_lock = self.last_function_call_result.lock().unwrap();
476 *last_result_lock = next_result;
477 }
478
479 fn get_sync(&self, path: &str) -> Result<reqwest::blocking::Response, error::ApiError> {
480 let reqwest_client = reqwest::blocking::Client::new();
481 let url = format!("{}{}", self.endpoint, path);
482 debug!("Request URL: {}", url);
483
484 let request = self.build_request_sync(reqwest_client.get(url));
485
486 let result = request.send();
487 match result {
488 Ok(response) => {
489 if response.status().is_success() {
490 Ok(response)
491 } else {
492 let response_status = response.status();
493 let response_body = response.text().unwrap_or_default();
494 debug!("Response Status: {}", &response_status);
495 utils::debug_pretty_json_from_string("Response Data", &response_body);
496
497 Err(error::ApiError {
498 message: format!("{}: {}", response_status, response_body),
499 })
500 }
501 }
502 Err(error) => Err(error::ApiError {
503 message: error.to_string(),
504 }),
505 }
506 }
507
508 async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
509 let reqwest_client = reqwest::Client::new();
510 let url = format!("{}{}", self.endpoint, path);
511 debug!("Request URL: {}", url);
512
513 let request_builder = reqwest_client.get(url);
514 let request = self.build_request_async(request_builder);
515
516 let result = request.send().await;
517 match result {
518 Ok(response) => {
519 if response.status().is_success() {
520 Ok(response)
521 } else {
522 let response_status = response.status();
523 let response_body = response.text().await.unwrap_or_default();
524 debug!("Response Status: {}", &response_status);
525 utils::debug_pretty_json_from_string("Response Data", &response_body);
526
527 Err(error::ApiError {
528 message: format!("{}: {}", response_status, response_body),
529 })
530 }
531 }
532 Err(error) => Err(error::ApiError {
533 message: error.to_string(),
534 }),
535 }
536 }
537
538 fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>(
539 &self,
540 path: &str,
541 params: &T,
542 ) -> Result<reqwest::blocking::Response, error::ApiError> {
543 let reqwest_client = reqwest::blocking::Client::new();
544 let url = format!("{}{}", self.endpoint, path);
545 debug!("Request URL: {}", url);
546 utils::debug_pretty_json_from_struct("Request Body", params);
547
548 let request_builder = reqwest_client.post(url).json(params);
549 let request = self.build_request_sync(request_builder);
550
551 let result = request.send();
552 match result {
553 Ok(response) => {
554 if response.status().is_success() {
555 Ok(response)
556 } else {
557 let response_status = response.status();
558 let response_body = response.text().unwrap_or_default();
559 debug!("Response Status: {}", &response_status);
560 utils::debug_pretty_json_from_string("Response Data", &response_body);
561
562 Err(error::ApiError {
563 message: format!("{}: {}", response_body, response_status),
564 })
565 }
566 }
567 Err(error) => Err(error::ApiError {
568 message: error.to_string(),
569 }),
570 }
571 }
572
573 async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
574 &self,
575 path: &str,
576 params: &T,
577 ) -> Result<reqwest::Response, error::ApiError> {
578 let reqwest_client = reqwest::Client::new();
579 let url = format!("{}{}", self.endpoint, path);
580 debug!("Request URL: {}", url);
581 utils::debug_pretty_json_from_struct("Request Body", params);
582
583 let request_builder = reqwest_client.post(url).json(params);
584 let request = self.build_request_async(request_builder);
585
586 let result = request.send().await;
587 match result {
588 Ok(response) => {
589 if response.status().is_success() {
590 Ok(response)
591 } else {
592 let response_status = response.status();
593 let response_body = response.text().await.unwrap_or_default();
594 debug!("Response Status: {}", &response_status);
595 utils::debug_pretty_json_from_string("Response Data", &response_body);
596
597 Err(error::ApiError {
598 message: format!("{}: {}", response_status, response_body),
599 })
600 }
601 }
602 Err(error) => Err(error::ApiError {
603 message: error.to_string(),
604 }),
605 }
606 }
607
608 async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>(
609 &self,
610 path: &str,
611 params: &T,
612 ) -> Result<reqwest::Response, error::ApiError> {
613 let reqwest_client = reqwest::Client::new();
614 let url = format!("{}{}", self.endpoint, path);
615 debug!("Request URL: {}", url);
616 utils::debug_pretty_json_from_struct("Request Body", params);
617
618 let request_builder = reqwest_client.post(url).json(params);
619 let request = self.build_request_stream(request_builder);
620
621 let result = request.send().await;
622 match result {
623 Ok(response) => {
624 if response.status().is_success() {
625 Ok(response)
626 } else {
627 let response_status = response.status();
628 let response_body = response.text().await.unwrap_or_default();
629 debug!("Response Status: {}", &response_status);
630 utils::debug_pretty_json_from_string("Response Data", &response_body);
631
632 Err(error::ApiError {
633 message: format!("{}: {}", response_status, response_body),
634 })
635 }
636 }
637 Err(error) => Err(error::ApiError {
638 message: error.to_string(),
639 }),
640 }
641 }
642
643 fn to_api_error(&self, err: ReqwestError) -> error::ApiError {
644 error::ApiError {
645 message: err.to_string(),
646 }
647 }
648}