// mistral_api/completion.rs
use crate::client::Endpoint;
use reqwest::{Client, Request};
use serde::{Deserialize, Serialize};
use url::Url;

6#[derive(Debug, Deserialize, Serialize)]
7#[serde(rename_all = "snake_case")]
8pub enum MessageRole {
9 System,
10 User,
11 Assistant,
12}
13
/// A single chat message: who produced it and its text content.
/// Appears both in outgoing requests (`ChatCompletion.messages`) and in
/// responses (`Choice.message`), hence both serde derives.
#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
    // Speaker of this message (system / user / assistant).
    pub role: MessageRole,
    // The message text.
    pub content: String,
}
19
20#[derive(Debug, Deserialize)]
21#[serde(rename_all = "snake_case")]
22pub enum FinishReason {
23 Stop,
24 Length,
25 ModelLength,
26}
27
/// One candidate completion returned by the chat endpoint.
#[derive(Debug, Deserialize)]
pub struct Choice {
    // Position of this choice within the response's `choices` array.
    pub index: i32,
    // The message generated for this choice.
    pub message: Message,
    // Why generation stopped for this choice.
    pub finish_reason: FinishReason,
}
34
35#[derive(Debug, Deserialize)]
36pub struct Usage {
37 pub prompt_tokens: i32,
38 pub completion_tokens: i32,
39 pub total_tokens: i32,
40}
41
/// Deserialized body of a successful `/v1/chat/completions` response.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionResponse {
    // Server-assigned identifier for this completion.
    pub id: String,
    // Object type tag returned by the API (e.g. a "chat.completion"-style
    // marker — exact value comes from the server, not enforced here).
    pub object: String,
    // Creation timestamp; presumably Unix seconds — TODO confirm against API docs.
    pub created: i64,
    // Model that produced the completion.
    pub model: String,
    // Candidate completions; typically one unless multiple were requested.
    pub choices: Vec<Choice>,
    // Token usage for this request.
    pub usage: Usage,
}
51
52#[derive(Debug, Deserialize, Serialize)]
53pub struct ChatCompletion {
54 model: String,
56 messages: Vec<Message>,
58 temperature: f32,
62 top_p: f32,
66 #[serde(skip_serializing_if = "Option::is_none")]
70 max_tokens: Option<i32>,
71 stream: bool,
73 safe_prompt: bool,
75 random_seed: Option<i32>,
77}
78
79impl ChatCompletion {
80 pub fn builder() -> ChatCompletionBuilder {
81 ChatCompletionBuilder::default()
82 }
83
84 pub fn new(model: &str) -> ChatCompletion {
85 ChatCompletionBuilder::default().build(model)
86 }
87
88 pub fn messages_mut(&mut self) -> &mut Vec<Message> {
89 &mut self.messages
90 }
91
92 pub fn append_message(&mut self, message: Message) {
93 self.messages.push(message);
94 }
95}
96
97impl Endpoint for ChatCompletion {
98 type Response = ChatCompletionResponse;
99
100 fn request(&self, client: &Client) -> Request {
101 let url = Url::parse("https://api.mistral.ai/v1/chat/completions").unwrap();
102 client
103 .post(url)
104 .header("Content-Type", "application/json")
105 .header(
106 "Authorization",
107 format!("Bearer {}", std::env::var("MISTRAL_API_KEY").unwrap()),
108 )
109 .json(self)
110 .build()
111 .unwrap()
112 }
113}
114
/// Builder for [`ChatCompletion`]'s optional sampling parameters.
///
/// Obtain via `ChatCompletion::builder()` (or `Default`), chain setters,
/// then call `build(model)`. Derives `Debug`/`Clone` so partially configured
/// builders can be logged and reused as templates.
#[derive(Debug, Clone)]
pub struct ChatCompletionBuilder {
    temperature: f32,
    top_p: f32,
    max_tokens: Option<i32>,
    stream: bool,
    safe_prompt: bool,
    random_seed: Option<i32>,
}
123
124impl ChatCompletionBuilder {
125 pub fn temperature(mut self, temperature: f32) -> Self {
126 self.temperature = temperature;
127 self
128 }
129
130 pub fn top_p(mut self, top_p: f32) -> Self {
131 self.top_p = top_p;
132 self
133 }
134
135 pub fn max_tokens(mut self, max_tokens: i32) -> Self {
136 self.max_tokens = Some(max_tokens);
137 self
138 }
139
140 pub fn stream(mut self, stream: bool) -> Self {
141 self.stream = stream;
142 self
143 }
144
145 pub fn safe_prompt(mut self, safe_prompt: bool) -> Self {
146 self.safe_prompt = safe_prompt;
147 self
148 }
149
150 pub fn random_seed(mut self, random_seed: i32) -> Self {
151 self.random_seed = Some(random_seed);
152 self
153 }
154
155 pub fn build(&self, model: &str) -> ChatCompletion {
156 ChatCompletion {
157 model: model.to_string(),
158 messages: vec![],
159 temperature: self.temperature,
160 top_p: self.top_p,
161 max_tokens: self.max_tokens,
162 stream: self.stream,
163 safe_prompt: self.safe_prompt,
164 random_seed: self.random_seed,
165 }
166 }
167}
168
169impl Default for ChatCompletionBuilder {
170 fn default() -> Self {
171 Self {
172 temperature: 0.7,
173 top_p: 1.0,
174 max_tokens: None,
175 stream: false,
176 safe_prompt: false,
177 random_seed: None,
178 }
179 }
180}