mistral_api/completion.rs

use crate::client::Endpoint;
use reqwest::{Client, Request};
use serde::{Deserialize, Serialize};
use url::Url;

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
    System,
    User,
    Assistant,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
    pub role: MessageRole,
    pub content: String,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ModelLength,
}

#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: i32,
    pub message: Message,
    pub finish_reason: FinishReason,
}

#[derive(Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: i32,
    pub completion_tokens: i32,
    pub total_tokens: i32,
}

#[derive(Debug, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ChatCompletion {
    /// ID of the model to use. Use the List Available Models API to see all of your available models, or see the model overview for model descriptions.
    model: String,
    /// The conversation to generate a completion for, encoded as a list of messages, each with a role and content. The role of the first message should be user or system.
    messages: Vec<Message>,
    /// What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 make the output more random, while lower values like 0.2 make it more focused and deterministic.
    ///
    /// We generally recommend altering this or top_p but not both.
    temperature: f32,
    /// Nucleus sampling, where the model considers the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
    ///
    /// We generally recommend altering this or temperature but not both.
    top_p: f32,
    /// The maximum number of tokens to generate in the completion.
    ///
    /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i32>,
    /// Whether to stream back partial progress. If set, tokens are sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. Otherwise, the server holds the request open until the timeout or until completion, and the response contains the full result as JSON.
    stream: bool,
    /// Whether to inject a safety prompt before all conversations.
    safe_prompt: bool,
    /// The seed to use for random sampling. If set, different calls will generate deterministic results.
    #[serde(skip_serializing_if = "Option::is_none")]
    random_seed: Option<i32>,
}

impl ChatCompletion {
    pub fn builder() -> ChatCompletionBuilder {
        ChatCompletionBuilder::default()
    }

    pub fn new(model: &str) -> ChatCompletion {
        ChatCompletionBuilder::default().build(model)
    }

    pub fn messages_mut(&mut self) -> &mut Vec<Message> {
        &mut self.messages
    }

    pub fn append_message(&mut self, message: Message) {
        self.messages.push(message);
    }
}
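
// Illustrative usage (a sketch, not part of the crate's API surface): build a
// request body with the builder defaults below and append a single user
// message. The function name and model string are placeholders.
#[allow(dead_code)]
fn example_request_body() -> ChatCompletion {
    let mut completion = ChatCompletion::new("mistral-small-latest");
    completion.append_message(Message {
        role: MessageRole::User,
        content: "Write a haiku about Rust.".to_string(),
    });
    completion
}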

impl Endpoint for ChatCompletion {
    type Response = ChatCompletionResponse;

    fn request(&self, client: &Client) -> Request {
        let url = Url::parse("https://api.mistral.ai/v1/chat/completions").unwrap();
        client
            .post(url)
            .header("Content-Type", "application/json")
            .header(
                "Authorization",
                format!(
                    "Bearer {}",
                    std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY must be set")
                ),
            )
            .json(self)
            .build()
            .unwrap()
    }
}
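
// A hedged sketch of how the prepared request could be executed directly with
// reqwest, bypassing whatever dispatch `crate::client` provides around the
// `Endpoint` trait. Assumes reqwest's `json` feature and an async runtime; the
// function name is a placeholder.
#[allow(dead_code)]
async fn example_send(completion: &ChatCompletion) -> Result<ChatCompletionResponse, reqwest::Error> {
    let client = Client::new();
    // Build the POST request defined by the `Endpoint` impl above.
    let request = completion.request(&client);
    // Execute it and decode the JSON body into the typed response.
    client.execute(request).await?.json::<ChatCompletionResponse>().await
}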

pub struct ChatCompletionBuilder {
    temperature: f32,
    top_p: f32,
    max_tokens: Option<i32>,
    stream: bool,
    safe_prompt: bool,
    random_seed: Option<i32>,
}

impl ChatCompletionBuilder {
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = temperature;
        self
    }

    pub fn top_p(mut self, top_p: f32) -> Self {
        self.top_p = top_p;
        self
    }

    pub fn max_tokens(mut self, max_tokens: i32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn stream(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    pub fn safe_prompt(mut self, safe_prompt: bool) -> Self {
        self.safe_prompt = safe_prompt;
        self
    }

    pub fn random_seed(mut self, random_seed: i32) -> Self {
        self.random_seed = Some(random_seed);
        self
    }

    pub fn build(&self, model: &str) -> ChatCompletion {
        ChatCompletion {
            model: model.to_string(),
            messages: vec![],
            temperature: self.temperature,
            top_p: self.top_p,
            max_tokens: self.max_tokens,
            stream: self.stream,
            safe_prompt: self.safe_prompt,
            random_seed: self.random_seed,
        }
    }
}

impl Default for ChatCompletionBuilder {
    fn default() -> Self {
        Self {
            temperature: 0.7,
            top_p: 1.0,
            max_tokens: None,
            stream: false,
            safe_prompt: false,
            random_seed: None,
        }
    }
}
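
// Illustrative builder usage (a sketch): override a few defaults and fix the
// sampling seed for reproducible output. The model name is a placeholder.
#[allow(dead_code)]
fn example_builder() -> ChatCompletion {
    ChatCompletion::builder()
        .temperature(0.2)
        .max_tokens(256)
        .random_seed(42)
        .build("mistral-small-latest")
}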